diff --git a/.config/hakari.toml b/.config/hakari.toml index 982542ca39..2050065cc2 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -23,6 +23,8 @@ workspace-members = [ ] third-party = [ { name = "reqwest", version = "0.11.27" }, + # build of remote_server should not include scap / its x11 dependency + { name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" }, ] [final-excludes] diff --git a/.github/actionlint.yml b/.github/actionlint.yml new file mode 100644 index 0000000000..d93ec5b15e --- /dev/null +++ b/.github/actionlint.yml @@ -0,0 +1,30 @@ +# Configuration related to self-hosted runner. +self-hosted-runner: + # Labels of self-hosted runner in array of strings. + labels: + # GitHub-hosted Runners + - github-8vcpu-ubuntu-2404 + - github-16vcpu-ubuntu-2404 + - windows-2025-16 + - windows-2025-32 + - windows-2025-64 + # Buildjet Ubuntu 20.04 - AMD x86_64 + - buildjet-2vcpu-ubuntu-2004 + - buildjet-4vcpu-ubuntu-2004 + - buildjet-8vcpu-ubuntu-2004 + - buildjet-16vcpu-ubuntu-2004 + - buildjet-32vcpu-ubuntu-2004 + # Buildjet Ubuntu 22.04 - AMD x86_64 + - buildjet-2vcpu-ubuntu-2204 + - buildjet-4vcpu-ubuntu-2204 + - buildjet-8vcpu-ubuntu-2204 + - buildjet-16vcpu-ubuntu-2204 + - buildjet-32vcpu-ubuntu-2204 + # Buildjet Ubuntu 22.04 - Graviton aarch64 + - buildjet-8vcpu-ubuntu-2204-arm + - buildjet-16vcpu-ubuntu-2204-arm + - buildjet-32vcpu-ubuntu-2204-arm + - buildjet-64vcpu-ubuntu-2204-arm + # Self Hosted Runners + - self-mini-macos + - self-32vcpu-windows-2022 diff --git a/.github/actions/install_trusted_signing/action.yml b/.github/actions/install_trusted_signing/action.yml deleted file mode 100644 index a99ff08eb1..0000000000 --- a/.github/actions/install_trusted_signing/action.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: "Trusted Signing on Windows" -description: "Install trusted signing on Windows." - -# Modified from https://github.com/Azure/trusted-signing-action -runs: - using: "composite" - steps: - - name: Set variables - id: set-variables - shell: "pwsh" - run: | - $defaultPath = $env:PSModulePath -split ';' | Select-Object -First 1 - "PSMODULEPATH=$defaultPath" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - - "TRUSTED_SIGNING_MODULE_VERSION=0.5.3" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - "BUILD_TOOLS_NUGET_VERSION=10.0.22621.3233" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - "TRUSTED_SIGNING_NUGET_VERSION=1.0.53" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - "DOTNET_SIGNCLI_NUGET_VERSION=0.9.1-beta.24469.1" | Out-File -FilePath $env:GITHUB_OUTPUT -Append - - - name: Cache TrustedSigning PowerShell module - id: cache-module - uses: actions/cache@v4 - env: - cache-name: cache-module - with: - path: ${{ steps.set-variables.outputs.PSMODULEPATH }}\TrustedSigning\${{ steps.set-variables.outputs.TRUSTED_SIGNING_MODULE_VERSION }} - key: TrustedSigning-${{ steps.set-variables.outputs.TRUSTED_SIGNING_MODULE_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Cache Microsoft.Windows.SDK.BuildTools NuGet package - id: cache-buildtools - uses: actions/cache@v4 - env: - cache-name: cache-buildtools - with: - path: ~\AppData\Local\TrustedSigning\Microsoft.Windows.SDK.BuildTools\Microsoft.Windows.SDK.BuildTools.${{ steps.set-variables.outputs.BUILD_TOOLS_NUGET_VERSION }} - key: Microsoft.Windows.SDK.BuildTools-${{ steps.set-variables.outputs.BUILD_TOOLS_NUGET_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Cache Microsoft.Trusted.Signing.Client NuGet package - id: cache-tsclient - uses: actions/cache@v4 - env: - cache-name: cache-tsclient - with: - path: ~\AppData\Local\TrustedSigning\Microsoft.Trusted.Signing.Client\Microsoft.Trusted.Signing.Client.${{ steps.set-variables.outputs.TRUSTED_SIGNING_NUGET_VERSION }} - key: Microsoft.Trusted.Signing.Client-${{ steps.set-variables.outputs.TRUSTED_SIGNING_NUGET_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Cache SignCli NuGet package - id: cache-signcli - uses: actions/cache@v4 - env: - cache-name: cache-signcli - with: - path: ~\AppData\Local\TrustedSigning\sign\sign.${{ steps.set-variables.outputs.DOTNET_SIGNCLI_NUGET_VERSION }} - key: SignCli-${{ steps.set-variables.outputs.DOTNET_SIGNCLI_NUGET_VERSION }} - if: ${{ inputs.cache-dependencies == 'true' }} - - - name: Install Trusted Signing module - shell: "pwsh" - run: | - Install-Module -Name TrustedSigning -RequiredVersion ${{ steps.set-variables.outputs.TRUSTED_SIGNING_MODULE_VERSION }} -Force -Repository PSGallery - if: ${{ inputs.cache-dependencies != 'true' || steps.cache-module.outputs.cache-hit != 'true' }} diff --git a/.github/workflows/bump_patch_version.yml b/.github/workflows/bump_patch_version.yml index 02857a9151..8a48ff96f1 100644 --- a/.github/workflows/bump_patch_version.yml +++ b/.github/workflows/bump_patch_version.yml @@ -28,7 +28,7 @@ jobs: run: | set -eux - channel=$(cat crates/zed/RELEASE_CHANNEL) + channel="$(cat crates/zed/RELEASE_CHANNEL)" tag_suffix="" case $channel in @@ -43,9 +43,9 @@ jobs: ;; esac which cargo-set-version > /dev/null || cargo install cargo-edit - output=$(cargo set-version -p zed --bump patch 2>&1 | sed 's/.* //') + output="$(cargo set-version -p zed --bump patch 2>&1 | sed 's/.* //')" export GIT_COMMITTER_NAME="Zed Bot" export GIT_COMMITTER_EMAIL="hi@zed.dev" git commit -am "Bump to $output for @$GITHUB_ACTOR" --author "Zed Bot " - git tag v${output}${tag_suffix} - git push origin HEAD v${output}${tag_suffix} + git tag "v${output}${tag_suffix}" + git push origin HEAD "v${output}${tag_suffix}" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 25a1ed8670..a4da5e99ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,9 @@ env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 RUST_BACKTRACE: 1 + DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} + DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} + ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} jobs: job_spec: @@ -31,6 +34,7 @@ jobs: run_license: ${{ steps.filter.outputs.run_license }} run_docs: ${{ steps.filter.outputs.run_docs }} run_nix: ${{ steps.filter.outputs.run_nix }} + run_actionlint: ${{ steps.filter.outputs.run_actionlint }} runs-on: - ubuntu-latest steps: @@ -44,38 +48,40 @@ jobs: run: | if [ -z "$GITHUB_BASE_REF" ]; then echo "Not in a PR context (i.e., push to main/stable/preview)" - COMPARE_REV=$(git rev-parse HEAD~1) + COMPARE_REV="$(git rev-parse HEAD~1)" else echo "In a PR context comparing to pull_request.base.ref" git fetch origin "$GITHUB_BASE_REF" --depth=350 - COMPARE_REV=$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD) + COMPARE_REV="$(git merge-base "origin/${GITHUB_BASE_REF}" HEAD)" fi - # Specify anything which should skip full CI in this regex: + CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" ${{ github.sha }})" + + # Specify anything which should potentially skip full test suite in this regex: # - docs/ + # - script/update_top_ranking_issues/ # - .github/ISSUE_TEMPLATE/ # - .github/workflows/ (except .github/workflows/ci.yml) - SKIP_REGEX='^(docs/|\.github/(ISSUE_TEMPLATE|workflows/(?!ci)))' - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep -vP "$SKIP_REGEX") ]]; then - echo "run_tests=true" >> $GITHUB_OUTPUT - else - echo "run_tests=false" >> $GITHUB_OUTPUT - fi - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep '^docs/') ]]; then - echo "run_docs=true" >> $GITHUB_OUTPUT - else - echo "run_docs=false" >> $GITHUB_OUTPUT - fi - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep -P '^(Cargo.lock|script/.*licenses)') ]]; then - echo "run_license=true" >> $GITHUB_OUTPUT - else - echo "run_license=false" >> $GITHUB_OUTPUT - fi - NIX_REGEX='^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)' - if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep "$NIX_REGEX") ]]; then - echo "run_nix=true" >> $GITHUB_OUTPUT - else - echo "run_nix=false" >> $GITHUB_OUTPUT - fi + SKIP_REGEX='^(docs/|script/update_top_ranking_issues/|\.github/(ISSUE_TEMPLATE|workflows/(?!ci)))' + + echo "$CHANGED_FILES" | grep -qvP "$SKIP_REGEX" && \ + echo "run_tests=true" >> "$GITHUB_OUTPUT" || \ + echo "run_tests=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^docs/' && \ + echo "run_docs=true" >> "$GITHUB_OUTPUT" || \ + echo "run_docs=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^\.github/(workflows/|actions/|actionlint.yml)' && \ + echo "run_actionlint=true" >> "$GITHUB_OUTPUT" || \ + echo "run_actionlint=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^(Cargo.lock|script/.*licenses)' && \ + echo "run_license=true" >> "$GITHUB_OUTPUT" || \ + echo "run_license=false" >> "$GITHUB_OUTPUT" + + echo "$CHANGED_FILES" | grep -qP '^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)' && \ + echo "run_nix=true" >> "$GITHUB_OUTPUT" || \ + echo "run_nix=false" >> "$GITHUB_OUTPUT" migration_checks: name: Check Postgres and Protobuf migrations, mergability @@ -85,8 +91,7 @@ jobs: needs.job_spec.outputs.run_tests == 'true' timeout-minutes: 60 runs-on: - - self-hosted - - macOS + - self-mini-macos steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -108,11 +113,11 @@ jobs: run: | if [ -z "$GITHUB_BASE_REF" ]; then - echo "BUF_BASE_BRANCH=$(git merge-base origin/main HEAD)" >> $GITHUB_ENV + echo "BUF_BASE_BRANCH=$(git merge-base origin/main HEAD)" >> "$GITHUB_ENV" else git checkout -B temp - git merge -q origin/$GITHUB_BASE_REF -m "merge main into temp" - echo "BUF_BASE_BRANCH=$GITHUB_BASE_REF" >> $GITHUB_ENV + git merge -q "origin/$GITHUB_BASE_REF" -m "merge main into temp" + echo "BUF_BASE_BRANCH=$GITHUB_BASE_REF" >> "$GITHUB_ENV" fi - uses: bufbuild/buf-setup-action@v1 @@ -136,7 +141,7 @@ jobs: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Install cargo-hakari uses: clechasseur/rs-cargo@8435b10f6e71c2e3d4d3b7573003a8ce4bfc6386 # v2 with: @@ -174,7 +179,7 @@ jobs: - name: Prettier Check on /docs working-directory: ./docs run: | - pnpm dlx prettier@${PRETTIER_VERSION} . --check || { + pnpm dlx "prettier@${PRETTIER_VERSION}" . --check || { echo "To fix, run from the root of the Zed repo:" echo " cd docs && pnpm dlx prettier@${PRETTIER_VERSION} . --write && cd .." false @@ -184,7 +189,7 @@ jobs: - name: Prettier Check on default.json run: | - pnpm dlx prettier@${PRETTIER_VERSION} assets/settings/default.json --check || { + pnpm dlx "prettier@${PRETTIER_VERSION}" assets/settings/default.json --check || { echo "To fix, run from the root of the Zed repo:" echo " pnpm dlx prettier@${PRETTIER_VERSION} assets/settings/default.json --write" false @@ -230,6 +235,20 @@ jobs: - name: Build docs uses: ./.github/actions/build_docs + actionlint: + runs-on: ubuntu-latest + if: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_actionlint == 'true' + needs: [job_spec] + steps: + - uses: actions/checkout@v4 + - name: Download actionlint + id: get_actionlint + run: bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) + shell: bash + - name: Check workflow files + run: ${{ steps.get_actionlint.outputs.executable }} -color + shell: bash + macos_tests: timeout-minutes: 60 name: (macOS) Run Clippy and tests @@ -238,8 +257,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - self-hosted - - macOS + - self-mini-macos steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -308,7 +326,7 @@ jobs: - buildjet-16vcpu-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -360,7 +378,7 @@ jobs: - buildjet-8vcpu-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -390,7 +408,7 @@ jobs: windows_tests: timeout-minutes: 60 - name: (Windows) Run Tests + name: (Windows) Run Clippy and tests needs: [job_spec] if: | github.repository_owner == 'zed-industries' && @@ -440,6 +458,7 @@ jobs: - job_spec - style - check_docs + - actionlint - migration_checks # run_tests: If adding required tests, add them here and to script below. - workspace_hack @@ -461,6 +480,11 @@ jobs: if [[ "${{ needs.job_spec.outputs.run_docs }}" == "true" ]]; then [[ "${{ needs.check_docs.result }}" != 'success' ]] && { RET_CODE=1; echo "docs checks failed"; } fi + + if [[ "${{ needs.job_spec.outputs.run_actionlint }}" == "true" ]]; then + [[ "${{ needs.actionlint.result }}" != 'success' ]] && { RET_CODE=1; echo "actionlint checks failed"; } + fi + # Only check test jobs if they were supposed to run if [[ "${{ needs.job_spec.outputs.run_tests }}" == "true" ]]; then [[ "${{ needs.workspace_hack.result }}" != 'success' ]] && { RET_CODE=1; echo "Workspace Hack failed"; } @@ -480,8 +504,7 @@ jobs: timeout-minutes: 120 name: Create a macOS bundle runs-on: - - self-hosted - - bundle + - self-mini-macos if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') @@ -492,9 +515,6 @@ jobs: APPLE_NOTARIZATION_KEY: ${{ secrets.APPLE_NOTARIZATION_KEY }} APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }} APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} steps: - name: Install Node uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 @@ -577,10 +597,6 @@ jobs: startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') needs: [linux_tests] - env: - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -634,10 +650,6 @@ jobs: startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') needs: [linux_tests] - env: - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -686,20 +698,18 @@ jobs: timeout-minutes: 60 runs-on: github-8vcpu-ubuntu-2404 if: | + false && ( startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') + ) needs: [linux_tests] name: Build Zed on FreeBSD - # env: - # MYTOKEN : ${{ secrets.MYTOKEN }} - # MYTOKEN2: "value2" steps: - uses: actions/checkout@v4 - name: Build FreeBSD remote-server id: freebsd-build uses: vmactions/freebsd-vm@c3ae29a132c8ef1924775414107a97cac042aad5 # v1.2.0 with: - # envs: "MYTOKEN MYTOKEN2" usesh: true release: 13.5 copyback: true @@ -757,7 +767,7 @@ jobs: timeout-minutes: 120 name: Create a Windows installer runs-on: [self-hosted, Windows, X64] - if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }} + if: false && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) needs: [windows_tests] env: AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} @@ -766,8 +776,6 @@ jobs: ACCOUNT_NAME: ${{ vars.AZURE_SIGNING_ACCOUNT_NAME }} CERT_PROFILE_NAME: ${{ vars.AZURE_SIGNING_CERT_PROFILE_NAME }} ENDPOINT: ${{ vars.AZURE_SIGNING_ENDPOINT }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} FILE_DIGEST: SHA256 TIMESTAMP_DIGEST: SHA256 TIMESTAMP_SERVER: "http://timestamp.acs.microsoft.com" @@ -784,9 +792,6 @@ jobs: # This exports RELEASE_CHANNEL into env (GITHUB_ENV) script/determine-release-channel.ps1 - - name: Install trusted signing - uses: ./.github/actions/install_trusted_signing - - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} run: script/bundle-windows.ps1 @@ -800,6 +805,7 @@ jobs: - name: Upload Artifacts to release uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1 + # Re-enable when we are ready to publish windows preview releases if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview with: draft: true @@ -813,12 +819,11 @@ jobs: if: | startsWith(github.ref, 'refs/tags/v') && endsWith(github.ref, '-pre') && !endsWith(github.ref, '.0-pre') - needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64, freebsd] + needs: [bundle-mac, bundle-linux-x86_x64, bundle-linux-aarch64, bundle-windows-x64] runs-on: - - self-hosted - - bundle + - self-mini-macos steps: - name: gh release - run: gh release edit $GITHUB_REF_NAME --draft=false + run: gh release edit "$GITHUB_REF_NAME" --draft=false env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index 3e253978b7..31dda1fa6d 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -18,7 +18,7 @@ jobs: URL="https://zed.dev/releases/stable/latest" fi - echo "URL=$URL" >> $GITHUB_OUTPUT + echo "URL=$URL" >> "$GITHUB_OUTPUT" - name: Get content uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 id: get-content @@ -50,9 +50,9 @@ jobs: PREVIEW_TAG="${VERSION}-pre" if git rev-parse "$PREVIEW_TAG" > /dev/null 2>&1; then - echo "was_promoted_from_preview=true" >> $GITHUB_OUTPUT + echo "was_promoted_from_preview=true" >> "$GITHUB_OUTPUT" else - echo "was_promoted_from_preview=false" >> $GITHUB_OUTPUT + echo "was_promoted_from_preview=false" >> "$GITHUB_OUTPUT" fi - name: Send release notes email diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index cfd455f920..f7348a1069 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -79,12 +79,12 @@ jobs: - name: Build docker image run: | docker build -f Dockerfile-collab \ - --build-arg GITHUB_SHA=$GITHUB_SHA \ - --tag registry.digitalocean.com/zed/collab:$GITHUB_SHA \ + --build-arg "GITHUB_SHA=$GITHUB_SHA" \ + --tag "registry.digitalocean.com/zed/collab:$GITHUB_SHA" \ . - name: Publish docker image - run: docker push registry.digitalocean.com/zed/collab:${GITHUB_SHA} + run: docker push "registry.digitalocean.com/zed/collab:${GITHUB_SHA}" - name: Prune Docker system run: docker system prune --filter 'until=72h' -f @@ -131,7 +131,8 @@ jobs: source script/lib/deploy-helpers.sh export_vars_for_environment $ZED_KUBE_NAMESPACE - export ZED_DO_CERTIFICATE_ID=$(doctl compute certificate list --format ID --no-header) + ZED_DO_CERTIFICATE_ID="$(doctl compute certificate list --format ID --no-header)" + export ZED_DO_CERTIFICATE_ID export ZED_IMAGE_ID="registry.digitalocean.com/zed/collab:${GITHUB_SHA}" export ZED_SERVICE_NAME=collab diff --git a/.github/workflows/eval.yml b/.github/workflows/eval.yml index 6eefdfea95..2ad302a602 100644 --- a/.github/workflows/eval.yml +++ b/.github/workflows/eval.yml @@ -35,7 +35,7 @@ jobs: - buildjet-16vcpu-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 155fc484f5..beacd27774 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -43,8 +43,8 @@ jobs: - name: Set path if: ${{ ! matrix.system.install_nix }} run: | - echo "/nix/var/nix/profiles/default/bin" >> $GITHUB_PATH - echo "/Users/administrator/.nix-profile/bin" >> $GITHUB_PATH + echo "/nix/var/nix/profiles/default/bin" >> "$GITHUB_PATH" + echo "/Users/administrator/.nix-profile/bin" >> "$GITHUB_PATH" - uses: cachix/install-nix-action@02a151ada4993995686f9ed4f1be7cfbb229e56f # v31 if: ${{ matrix.system.install_nix }} @@ -56,11 +56,13 @@ jobs: name: zed authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}" pushFilter: "${{ inputs.cachix-filter }}" - cachixArgs: '-v' + cachixArgs: "-v" - run: nix build .#${{ inputs.flake-output }} -L --accept-flake-config - name: Limit /nix/store to 50GB on macs if: ${{ ! matrix.system.install_nix }} run: | - [ $(du -sm /nix/store | cut -f1) -gt 50000 ] && nix-collect-garbage -d || : + if [ "$(du -sm /nix/store | cut -f1)" -gt 50000 ]; then + nix-collect-garbage -d || true + fi diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index df9f6ef40f..f799133ea7 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -12,6 +12,9 @@ env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 RUST_BACKTRACE: 1 + ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} + DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} + DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} jobs: style: @@ -82,8 +85,7 @@ jobs: name: Create a macOS bundle if: github.repository_owner == 'zed-industries' runs-on: - - self-hosted - - bundle + - self-mini-macos needs: tests env: MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }} @@ -91,9 +93,6 @@ jobs: APPLE_NOTARIZATION_KEY: ${{ secrets.APPLE_NOTARIZATION_KEY }} APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }} APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} steps: - name: Install Node uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 @@ -125,10 +124,6 @@ jobs: runs-on: - buildjet-16vcpu-ubuntu-2004 needs: tests - env: - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -136,7 +131,7 @@ jobs: clean: false - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Install Linux dependencies run: ./script/linux && ./script/install-mold 2.34.0 @@ -164,10 +159,6 @@ jobs: runs-on: - buildjet-16vcpu-ubuntu-2204-arm needs: tests - env: - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} - ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }} steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -195,12 +186,9 @@ jobs: freebsd: timeout-minutes: 60 - if: github.repository_owner == 'zed-industries' + if: false && github.repository_owner == 'zed-industries' runs-on: github-8vcpu-ubuntu-2404 needs: tests - env: - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} name: Build Zed on FreeBSD # env: # MYTOKEN : ${{ secrets.MYTOKEN }} @@ -257,8 +245,6 @@ jobs: ACCOUNT_NAME: ${{ vars.AZURE_SIGNING_ACCOUNT_NAME }} CERT_PROFILE_NAME: ${{ vars.AZURE_SIGNING_CERT_PROFILE_NAME }} ENDPOINT: ${{ vars.AZURE_SIGNING_ENDPOINT }} - DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }} - DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }} FILE_DIGEST: SHA256 TIMESTAMP_DIGEST: SHA256 TIMESTAMP_SERVER: "http://timestamp.acs.microsoft.com" @@ -276,9 +262,6 @@ jobs: Write-Host "Publishing version: $version on release channel nightly" "nightly" | Set-Content -Path "crates/zed/RELEASE_CHANNEL" - - name: Install trusted signing - uses: ./.github/actions/install_trusted_signing - - name: Build Zed installer working-directory: ${{ env.ZED_WORKSPACE }} run: script/bundle-windows.ps1 diff --git a/.github/workflows/unit_evals.yml b/.github/workflows/unit_evals.yml index 705caff37a..cb4e39d151 100644 --- a/.github/workflows/unit_evals.yml +++ b/.github/workflows/unit_evals.yml @@ -26,7 +26,7 @@ jobs: - buildjet-16vcpu-ubuntu-2204 steps: - name: Add Rust to the PATH - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.zed/settings.json b/.zed/settings.json index 1ef6bc28f7..68e05a426f 100644 --- a/.zed/settings.json +++ b/.zed/settings.json @@ -40,7 +40,7 @@ }, "file_types": { "Dockerfile": ["Dockerfile*[!dockerignore]"], - "JSONC": ["assets/**/*.json", "renovate.json"], + "JSONC": ["**/assets/**/*.json", "renovate.json"], "Git Ignore": ["dockerignore"] }, "hard_tabs": false, diff --git a/Cargo.lock b/Cargo.lock index 38bb7819ca..2c65131db0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,35 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "acp_thread" +version = "0.1.0" +dependencies = [ + "agent-client-protocol", + "agentic-coding-protocol", + "anyhow", + "assistant_tool", + "async-pipe", + "buffer_diff", + "editor", + "env_logger 0.11.8", + "futures 0.3.31", + "gpui", + "indoc", + "itertools 0.14.0", + "language", + "markdown", + "project", + "serde", + "serde_json", + "settings", + "smol", + "tempfile", + "ui", + "util", + "workspace-hack", +] + [[package]] name = "activity_indicator" version = "0.1.0" @@ -107,6 +136,53 @@ dependencies = [ "zstd", ] +[[package]] +name = "agent-client-protocol" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fb7f39671e02f8a1aeb625652feae40b6fc2597baaa97e028a98863477aecbd" +dependencies = [ + "schemars", + "serde", + "serde_json", +] + +[[package]] +name = "agent_servers" +version = "0.1.0" +dependencies = [ + "acp_thread", + "agent-client-protocol", + "agentic-coding-protocol", + "anyhow", + "collections", + "context_server", + "env_logger 0.11.8", + "futures 0.3.31", + "gpui", + "indoc", + "itertools 0.14.0", + "language", + "libc", + "log", + "nix 0.29.0", + "paths", + "project", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "strum 0.27.1", + "tempfile", + "ui", + "util", + "uuid", + "watch", + "which 6.0.3", + "workspace-hack", +] + [[package]] name = "agent_settings" version = "0.1.0" @@ -130,8 +206,12 @@ dependencies = [ name = "agent_ui" version = "0.1.0" dependencies = [ + "acp_thread", "agent", + "agent-client-protocol", + "agent_servers", "agent_settings", + "ai_onboarding", "anyhow", "assistant_context", "assistant_slash_command", @@ -143,6 +223,7 @@ dependencies = [ "chrono", "client", "collections", + "command_palette_hooks", "component", "context_server", "db", @@ -164,6 +245,7 @@ dependencies = [ "jsonschema", "language", "language_model", + "language_models", "languages", "log", "lsp", @@ -191,6 +273,7 @@ dependencies = [ "settings", "smol", "streaming_diff", + "task", "telemetry", "telemetry_events", "terminal", @@ -201,6 +284,7 @@ dependencies = [ "time_format", "tree-sitter-md", "ui", + "ui_input", "unindent", "urlencoding", "util", @@ -212,6 +296,24 @@ dependencies = [ "zed_llm_client", ] +[[package]] +name = "agentic-coding-protocol" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e6ae951b36fa2f8d9dd6e1af6da2fcaba13d7c866cf6a9e65deda9dc6c5fe4" +dependencies = [ + "anyhow", + "chrono", + "derive_more 2.0.1", + "futures 0.3.31", + "log", + "parking_lot", + "schemars", + "semver", + "serde", + "serde_json", +] + [[package]] name = "ahash" version = "0.7.8" @@ -247,6 +349,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "ai_onboarding" +version = "0.1.0" +dependencies = [ + "client", + "component", + "gpui", + "language_model", + "proto", + "serde", + "smallvec", + "telemetry", + "ui", + "workspace-hack", + "zed_actions", +] + [[package]] name = "alacritty_terminal" version = "0.25.1-dev" @@ -610,7 +729,7 @@ dependencies = [ "anyhow", "async-trait", "collections", - "derive_more", + "derive_more 0.99.19", "extension", "futures 0.3.31", "gpui", @@ -673,10 +792,11 @@ dependencies = [ "clock", "collections", "ctor", - "derive_more", + "derive_more 0.99.19", "futures 0.3.31", "gpui", "icons", + "indoc", "language", "language_model", "log", @@ -709,7 +829,8 @@ dependencies = [ "clock", "collections", "component", - "derive_more", + "derive_more 0.99.19", + "diffy", "editor", "feature_flags", "fs", @@ -1166,7 +1287,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more", + "derive_more 0.99.19", "gpui", "parking_lot", "rodio", @@ -1765,9 +1886,7 @@ version = "0.1.0" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", - "futures 0.3.31", "http_client", - "tokio", "workspace-hack", ] @@ -2078,7 +2197,7 @@ dependencies = [ [[package]] name = "blade-graphics" version = "0.6.0" -source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad" +source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" dependencies = [ "ash", "ash-window", @@ -2111,7 +2230,7 @@ dependencies = [ [[package]] name = "blade-macros" version = "0.3.0" -source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad" +source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" dependencies = [ "proc-macro2", "quote", @@ -2121,7 +2240,7 @@ dependencies = [ [[package]] name = "blade-util" version = "0.2.0" -source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad" +source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5" dependencies = [ "blade-graphics", "bytemuck", @@ -2859,7 +2978,7 @@ dependencies = [ "cocoa 0.26.0", "collections", "credentials_provider", - "derive_more", + "derive_more 0.99.19", "feature_flags", "fs", "futures 0.3.31", @@ -3043,10 +3162,11 @@ dependencies = [ "context_server", "ctor", "dap", + "dap-types", "dap_adapters", "dashmap 6.1.0", "debugger_ui", - "derive_more", + "derive_more 0.99.19", "editor", "envy", "extension", @@ -3099,6 +3219,7 @@ dependencies = [ "session", "settings", "sha2", + "smol", "sqlx", "strum 0.27.1", "subtle", @@ -3251,7 +3372,7 @@ name = "command_palette_hooks" version = "0.1.0" dependencies = [ "collections", - "derive_more", + "derive_more 0.99.19", "gpui", "workspace-hack", ] @@ -3339,12 +3460,14 @@ dependencies = [ "futures 0.3.31", "gpui", "log", + "net", "parking_lot", "postage", "schemars", "serde", "serde_json", "smol", + "tempfile", "url", "util", "workspace-hack", @@ -4160,6 +4283,7 @@ dependencies = [ "serde", "serde_json", "shlex", + "smol", "task", "util", "workspace-hack", @@ -4326,17 +4450,21 @@ dependencies = [ "futures 0.3.31", "fuzzy", "gpui", + "hex", "indoc", "itertools 0.14.0", "language", "log", "menu", + "notifications", "parking_lot", + "parse_int", "paths", "picker", "pretty_assertions", "project", "rpc", + "schemars", "serde", "serde_json", "serde_json_lenient", @@ -4455,6 +4583,27 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", + "unicode-xid", +] + [[package]] name = "derive_refineable" version = "0.1.0" @@ -6153,7 +6302,7 @@ dependencies = [ "askpass", "async-trait", "collections", - "derive_more", + "derive_more 0.99.19", "futures 0.3.31", "git2", "gpui", @@ -6224,6 +6373,7 @@ dependencies = [ "buffer_diff", "call", "chrono", + "client", "collections", "command_palette_hooks", "component", @@ -7170,7 +7320,7 @@ dependencies = [ "core-video", "cosmic-text", "ctor", - "derive_more", + "derive_more 0.99.19", "embed-resource", "env_logger 0.11.8", "etagere", @@ -7265,9 +7415,9 @@ dependencies = [ [[package]] name = "grid" -version = "0.13.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d196ffc1627db18a531359249b2bf8416178d84b729f3cebeb278f285fb9b58c" +checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa" [[package]] name = "group" @@ -7716,9 +7866,10 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes 1.10.1", - "derive_more", + "derive_more 0.99.19", "futures 0.3.31", "http 1.3.1", + "http-body 1.0.1", "log", "serde", "serde_json", @@ -8154,7 +8305,7 @@ dependencies = [ "async-trait", "cargo_metadata", "collections", - "derive_more", + "derive_more 0.99.19", "extension", "fs", "futures 0.3.31", @@ -8881,6 +9032,7 @@ dependencies = [ "task", "text", "theme", + "toml 0.8.20", "tree-sitter", "tree-sitter-elixir", "tree-sitter-embedded-template", @@ -8913,6 +9065,7 @@ dependencies = [ "gpui", "language", "lsp", + "project", "serde", "serde_json", "util", @@ -8951,6 +9104,7 @@ dependencies = [ name = "language_models" version = "0.1.0" dependencies = [ + "ai_onboarding", "anthropic", "anyhow", "aws-config", @@ -8961,12 +9115,11 @@ dependencies = [ "client", "collections", "component", + "convert_case 0.8.0", "copilot", "credentials_provider", "deepseek", "editor", - "feature_flags", - "fs", "futures 0.3.31", "google_ai", "gpui", @@ -9000,6 +9153,7 @@ dependencies = [ "util", "vercel", "workspace-hack", + "x_ai", "zed_llm_client", ] @@ -9032,7 +9186,6 @@ dependencies = [ "collections", "copilot", "editor", - "feature_flags", "futures 0.3.31", "gpui", "itertools 0.14.0", @@ -9592,12 +9745,11 @@ dependencies = [ [[package]] name = "lsp-types" version = "0.95.1" -source = "git+https://github.com/zed-industries/lsp-types?rev=c9c189f1c5dd53c624a419ce35bc77ad6a908d18#c9c189f1c5dd53c624a419ce35bc77ad6a908d18" +source = "git+https://github.com/zed-industries/lsp-types?rev=39f629bdd03d59abd786ed9fc27e8bca02c0c0ec#39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" dependencies = [ "bitflags 1.3.2", "serde", "serde_json", - "serde_repr", "url", ] @@ -10192,6 +10344,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "nc" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "net", + "smol", + "workspace-hack", +] + [[package]] name = "ndk" version = "0.8.0" @@ -10840,6 +11003,23 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "onboarding" +version = "0.1.0" +dependencies = [ + "anyhow", + "command_palette_hooks", + "db", + "feature_flags", + "fs", + "gpui", + "settings", + "theme", + "ui", + "workspace", + "workspace-hack", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -11209,6 +11389,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parse_int" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c464266693329dd5a8715098c7f86e6c5fd5d985018b8318f53d9c6c2b21a31" +dependencies = [ + "num-traits", +] + [[package]] name = "partial-json-fixer" version = "0.5.3" @@ -12252,6 +12441,7 @@ dependencies = [ "anyhow", "askpass", "async-trait", + "base64 0.22.1", "buffer_diff", "circular-buffer", "client", @@ -12297,6 +12487,7 @@ dependencies = [ "sha2", "shellexpand 2.1.2", "shlex", + "smallvec", "smol", "snippet", "snippet_provider", @@ -14031,7 +14222,7 @@ dependencies = [ [[package]] name = "scap" version = "0.0.8" -source = "git+https://github.com/zed-industries/scap?rev=08f0a01417505cc0990b9931a37e5120db92e0d0#08f0a01417505cc0990b9931a37e5120db92e0d0" +source = "git+https://github.com/zed-industries/scap?rev=808aa5c45b41e8f44729d02e38fd00a2fe2722e7#808aa5c45b41e8f44729d02e38fd00a2fe2722e7" dependencies = [ "anyhow", "cocoa 0.25.0", @@ -14078,10 +14269,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984" dependencies = [ + "chrono", "dyn-clone", "indexmap", "ref-cast", "schemars_derive", + "semver", "serde", "serde_json", ] @@ -14600,19 +14793,25 @@ dependencies = [ "fs", "fuzzy", "gpui", + "itertools 0.14.0", "language", "log", "menu", + "notifications", "paths", "project", "schemars", "search", "serde", + "serde_json", "settings", + "telemetry", + "tempfile", "theme", "tree-sitter-json", "tree-sitter-rust", "ui", + "ui_input", "util", "workspace", "workspace-hack", @@ -15775,13 +15974,12 @@ dependencies = [ [[package]] name = "taffy" -version = "0.4.4" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ec17858c2d465b2f734b798b920818a974faf0babb15d7fef81818a4b2d16f1" +checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c" dependencies = [ "arrayvec", "grid", - "num-traits", "serde", "slotmap", ] @@ -16040,7 +16238,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more", + "derive_more 0.99.19", "fs", "futures 0.3.31", "gpui", @@ -16323,6 +16521,7 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" name = "title_bar" version = "0.1.0" dependencies = [ + "anyhow", "auto_update", "call", "chrono", @@ -16339,6 +16538,7 @@ dependencies = [ "schemars", "serde", "settings", + "settings_ui", "smallvec", "story", "telemetry", @@ -18307,7 +18507,6 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "feature_flags", "futures 0.3.31", "gpui", "http_client", @@ -18568,8 +18767,7 @@ dependencies = [ [[package]] name = "windows-capture" version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59d10b4be8b907c7055bc7270dd68d2b920978ffacc1599dcb563a79f0e68d16" +source = "git+https://github.com/zed-industries/windows-capture.git?rev=f0d6c1b6691db75461b732f6d5ff56eed002eeb9#f0d6c1b6691db75461b732f6d5ff56eed002eeb9" dependencies = [ "clap", "ctrlc", @@ -19579,6 +19777,7 @@ dependencies = [ "rustix 1.0.7", "rustls 0.23.26", "rustls-webpki 0.103.1", + "schemars", "scopeguard", "sea-orm", "sea-query-binder", @@ -19731,6 +19930,17 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" +[[package]] +name = "x_ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "schemars", + "serde", + "strum 0.27.1", + "workspace-hack", +] + [[package]] name = "xattr" version = "0.2.3" @@ -19972,10 +20182,11 @@ dependencies = [ [[package]] name = "zed" -version = "0.195.0" +version = "0.198.0" dependencies = [ "activity_indicator", "agent", + "agent_servers", "agent_settings", "agent_ui", "anyhow", @@ -20012,6 +20223,7 @@ dependencies = [ "extension", "extension_host", "extensions_ui", + "feature_flags", "feedback", "file_finder", "fs", @@ -20045,9 +20257,11 @@ dependencies = [ "menu", "migrator", "mimalloc", + "nc", "nix 0.29.0", "node_runtime", "notifications", + "onboarding", "outline", "outline_panel", "parking_lot", @@ -20354,6 +20568,7 @@ dependencies = [ name = "zeta" version = "0.1.0" dependencies = [ + "ai_onboarding", "anyhow", "arrayvec", "call", @@ -20361,6 +20576,7 @@ dependencies = [ "clock", "collections", "command_palette_hooks", + "copilot", "ctor", "db", "editor", @@ -20375,8 +20591,6 @@ dependencies = [ "language_model", "log", "menu", - "migrator", - "paths", "postage", "project", "proto", diff --git a/Cargo.toml b/Cargo.toml index a4d8b3cb95..ea01003f36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,9 +2,12 @@ resolver = "2" members = [ "crates/activity_indicator", + "crates/acp_thread", "crates/agent_ui", "crates/agent", "crates/agent_settings", + "crates/ai_onboarding", + "crates/agent_servers", "crates/anthropic", "crates/askpass", "crates/assets", @@ -100,10 +103,12 @@ members = [ "crates/migrator", "crates/mistral", "crates/multi_buffer", + "crates/nc", "crates/net", "crates/node_runtime", "crates/notifications", "crates/ollama", + "crates/onboarding", "crates/open_ai", "crates/open_router", "crates/outline", @@ -177,6 +182,7 @@ members = [ "crates/welcome", "crates/workspace", "crates/worktree", + "crates/x_ai", "crates/zed", "crates/zed_actions", "crates/zeta", @@ -216,11 +222,14 @@ edition = "2024" # Workspace member crates # -activity_indicator = { path = "crates/activity_indicator" } +acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } +activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } +agent_servers = { path = "crates/agent_servers" } ai = { path = "crates/ai" } +ai_onboarding = { path = "crates/ai_onboarding" } anthropic = { path = "crates/anthropic" } askpass = { path = "crates/askpass" } assets = { path = "crates/assets" } @@ -312,10 +321,12 @@ menu = { path = "crates/menu" } migrator = { path = "crates/migrator" } mistral = { path = "crates/mistral" } multi_buffer = { path = "crates/multi_buffer" } +nc = { path = "crates/nc" } net = { path = "crates/net" } node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } ollama = { path = "crates/ollama" } +onboarding = { path = "crates/onboarding" } open_ai = { path = "crates/open_ai" } open_router = { path = "crates/open_router", features = ["schemars"] } outline = { path = "crates/outline" } @@ -390,6 +401,7 @@ web_search_providers = { path = "crates/web_search_providers" } welcome = { path = "crates/welcome" } workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } +x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zeta = { path = "crates/zeta" } @@ -400,6 +412,8 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # +agentic-coding-protocol = "0.0.10" +agent-client-protocol = "0.0.10" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -427,9 +441,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] } aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] } base64 = "0.22" bitflags = "2.6.0" -blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" } -blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" } -blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" } +blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } +blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } +blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" } blake3 = "1.5.3" bytes = "1.0" cargo_metadata = "0.19" @@ -469,6 +483,7 @@ heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" html5ever = "0.27.0" http = "1.1" +http-body = "1.0" hyper = "0.14" ignore = "0.4.22" image = "0.25.1" @@ -482,18 +497,18 @@ json_dotpath = "1.1" jsonschema = "0.30.0" jsonwebtoken = "9.3" jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } -jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } +jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" } libc = "0.2" libsqlite3-sys = { version = "0.30.1", features = ["bundled"] } linkify = "0.10.0" log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } -lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" } +lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "39f629bdd03d59abd786ed9fc27e8bca02c0c0ec" } markup5ever_rcdom = "0.3.0" metal = "0.29" moka = { version = "0.12.10", features = ["sync"] } naga = { version = "25.0", features = ["wgsl-in"] } nanoid = "0.4" -nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } +nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" } nix = "0.29" num-format = "0.4.4" objc = "0.2" @@ -502,6 +517,7 @@ ordered-float = "2.1.1" palette = { version = "0.7.5", default-features = false, features = ["std"] } parking_lot = "0.12.1" partial-json-fixer = "0.5.3" +parse_int = "0.9" pathdiff = "0.2" pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } @@ -533,7 +549,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77 "stream", ] } rsa = "0.9.6" -runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [ +runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [ "async-dispatcher-runtime", ] } rust-embed = { version = "8.4", features = ["include-exclude"] } @@ -541,7 +557,7 @@ rustc-demangle = "0.1.23" rustc-hash = "2.1.0" rustls = { version = "0.23.26" } rustls-platform-verifier = "0.5.0" -scap = { git = "https://github.com/zed-industries/scap", rev = "08f0a01417505cc0990b9931a37e5120db92e0d0", default-features = false } +scap = { git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7", default-features = false } schemars = { version = "1.0", features = ["indexmap2"] } semver = "1.0" serde = { version = "1.0", features = ["derive", "rc"] } @@ -695,6 +711,7 @@ features = [ [patch.crates-io] notify = { git = "https://github.com/zed-industries/notify.git", rev = "bbb9ea5ae52b253e095737847e367c30653a2e96" } notify-types = { git = "https://github.com/zed-industries/notify.git", rev = "bbb9ea5ae52b253e095737847e367c30653a2e96" } +windows-capture = { git = "https://github.com/zed-industries/windows-capture.git", rev = "f0d6c1b6691db75461b732f6d5ff56eed002eeb9" } # Makes the workspace hack crate refer to the local one, but only when you're building locally workspace-hack = { path = "tooling/workspace-hack" } @@ -703,6 +720,11 @@ workspace-hack = { path = "tooling/workspace-hack" } split-debuginfo = "unpacked" codegen-units = 16 +# mirror configuration for crates compiled for the build platform +# (without this cargo will compile ~400 crates twice) +[profile.dev.build-override] +codegen-units = 16 + [profile.dev.package] taffy = { opt-level = 3 } cranelift-codegen = { opt-level = 3 } diff --git a/assets/icons/ai_claude.svg b/assets/icons/ai_claude.svg new file mode 100644 index 0000000000..a3e3e1f4cd --- /dev/null +++ b/assets/icons/ai_claude.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/ai_gemini.svg b/assets/icons/ai_gemini.svg new file mode 100644 index 0000000000..bdde44ed24 --- /dev/null +++ b/assets/icons/ai_gemini.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/ai_open_ai_compat.svg b/assets/icons/ai_open_ai_compat.svg new file mode 100644 index 0000000000..f6557caac3 --- /dev/null +++ b/assets/icons/ai_open_ai_compat.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg index cc8597729a..94f2849146 100644 --- a/assets/icons/ai_open_router.svg +++ b/assets/icons/ai_open_router.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg new file mode 100644 index 0000000000..289525c8ef --- /dev/null +++ b/assets/icons/ai_x_ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/clipboard.svg b/assets/icons/clipboard.svg deleted file mode 100644 index 5c8842f3b7..0000000000 --- a/assets/icons/clipboard.svg +++ /dev/null @@ -1 +0,0 @@ - diff --git a/assets/icons/debug.svg b/assets/icons/debug.svg index 8cea0c4604..ff51e42b1a 100644 --- a/assets/icons/debug.svg +++ b/assets/icons/debug.svg @@ -1 +1,12 @@ - + + + + + + + + + + + + diff --git a/assets/icons/search_code.svg b/assets/icons/equal.svg similarity index 56% rename from assets/icons/search_code.svg rename to assets/icons/equal.svg index 1cc9affeb8..9b3a151a12 100644 --- a/assets/icons/search_code.svg +++ b/assets/icons/equal.svg @@ -1 +1 @@ - + diff --git a/assets/icons/file_delete.svg b/assets/icons/file_delete.svg deleted file mode 100644 index b84f79958f..0000000000 --- a/assets/icons/file_delete.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/assets/icons/file_tree.svg b/assets/icons/file_tree.svg index 4c921b1351..a140cd70b1 100644 --- a/assets/icons/file_tree.svg +++ b/assets/icons/file_tree.svg @@ -1,5 +1,5 @@ - - - + + + diff --git a/assets/icons/git_branch_small.svg b/assets/icons/git_branch_small.svg index d23fc176ac..22832d6fed 100644 --- a/assets/icons/git_branch_small.svg +++ b/assets/icons/git_branch_small.svg @@ -1,6 +1,7 @@ - - - - - + + + + + + diff --git a/assets/icons/list_tree.svg b/assets/icons/list_tree.svg index 8cf157ec13..09872a60f7 100644 --- a/assets/icons/list_tree.svg +++ b/assets/icons/list_tree.svg @@ -1 +1,7 @@ - \ No newline at end of file + + + + + + + diff --git a/assets/icons/location_edit.svg b/assets/icons/location_edit.svg new file mode 100644 index 0000000000..de82e8db4e --- /dev/null +++ b/assets/icons/location_edit.svg @@ -0,0 +1 @@ + diff --git a/assets/icons/new_from_summary.svg b/assets/icons/new_from_summary.svg new file mode 100644 index 0000000000..3b61ca51a0 --- /dev/null +++ b/assets/icons/new_from_summary.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/assets/icons/new_text_thread.svg b/assets/icons/new_text_thread.svg new file mode 100644 index 0000000000..75afa934a0 --- /dev/null +++ b/assets/icons/new_text_thread.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/assets/icons/new_thread.svg b/assets/icons/new_thread.svg new file mode 100644 index 0000000000..8c2596a4c9 --- /dev/null +++ b/assets/icons/new_thread.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/play_filled.svg b/assets/icons/play_filled.svg new file mode 100644 index 0000000000..387304ef04 --- /dev/null +++ b/assets/icons/play_filled.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/terminal_alt.svg b/assets/icons/terminal_alt.svg new file mode 100644 index 0000000000..7afb89db21 --- /dev/null +++ b/assets/icons/terminal_alt.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/todo_complete.svg b/assets/icons/todo_complete.svg new file mode 100644 index 0000000000..9fa2e818bb --- /dev/null +++ b/assets/icons/todo_complete.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/todo_pending.svg b/assets/icons/todo_pending.svg new file mode 100644 index 0000000000..dfb013b52b --- /dev/null +++ b/assets/icons/todo_pending.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/assets/icons/todo_progress.svg b/assets/icons/todo_progress.svg new file mode 100644 index 0000000000..9b2ed7375d --- /dev/null +++ b/assets/icons/todo_progress.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/icons/tool_bulb.svg b/assets/icons/tool_bulb.svg new file mode 100644 index 0000000000..54d5ac5fd7 --- /dev/null +++ b/assets/icons/tool_bulb.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/tool_copy.svg b/assets/icons/tool_copy.svg new file mode 100644 index 0000000000..e722d8a022 --- /dev/null +++ b/assets/icons/tool_copy.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/tool_delete_file.svg b/assets/icons/tool_delete_file.svg new file mode 100644 index 0000000000..3276f3d78e --- /dev/null +++ b/assets/icons/tool_delete_file.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/tool_diagnostics.svg b/assets/icons/tool_diagnostics.svg new file mode 100644 index 0000000000..c659d96781 --- /dev/null +++ b/assets/icons/tool_diagnostics.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/tool_folder.svg b/assets/icons/tool_folder.svg new file mode 100644 index 0000000000..9d3ac299d2 --- /dev/null +++ b/assets/icons/tool_folder.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/tool_hammer.svg b/assets/icons/tool_hammer.svg new file mode 100644 index 0000000000..e66173ce70 --- /dev/null +++ b/assets/icons/tool_hammer.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/tool_notification.svg b/assets/icons/tool_notification.svg new file mode 100644 index 0000000000..7510b32040 --- /dev/null +++ b/assets/icons/tool_notification.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/tool_pencil.svg b/assets/icons/tool_pencil.svg new file mode 100644 index 0000000000..b913015c08 --- /dev/null +++ b/assets/icons/tool_pencil.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/tool_read.svg b/assets/icons/tool_read.svg new file mode 100644 index 0000000000..458cbb3660 --- /dev/null +++ b/assets/icons/tool_read.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/assets/icons/tool_regex.svg b/assets/icons/tool_regex.svg new file mode 100644 index 0000000000..0432cd570f --- /dev/null +++ b/assets/icons/tool_regex.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/tool_search.svg b/assets/icons/tool_search.svg new file mode 100644 index 0000000000..4f2750cfa2 --- /dev/null +++ b/assets/icons/tool_search.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/tool_terminal.svg b/assets/icons/tool_terminal.svg new file mode 100644 index 0000000000..5154fa8e70 --- /dev/null +++ b/assets/icons/tool_terminal.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/tool_web.svg b/assets/icons/tool_web.svg new file mode 100644 index 0000000000..6250a9f05a --- /dev/null +++ b/assets/icons/tool_web.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/user_group.svg b/assets/icons/user_group.svg index aa99277646..ac1f7bdc63 100644 --- a/assets/icons/user_group.svg +++ b/assets/icons/user_group.svg @@ -1,3 +1,5 @@ - + + + diff --git a/assets/icons/zed_assistant.svg b/assets/icons/zed_assistant.svg index 693d86f929..d21252de8c 100644 --- a/assets/icons/zed_assistant.svg +++ b/assets/icons/zed_assistant.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 6f50945828..31adef8cd5 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -269,7 +269,15 @@ } }, { - "context": "MessageEditor > Editor", + "context": "AgentPanel && external_agent_thread", + "use_key_equivalents": true, + "bindings": { + "ctrl-n": "agent::NewExternalAgentThread", + "ctrl-alt-t": "agent::NewThread" + } + }, + { + "context": "MessageEditor && !Picker > Editor && !use_modifier_to_send", "bindings": { "enter": "agent::Chat", "ctrl-enter": "agent::ChatWithFollow", @@ -279,6 +287,17 @@ "ctrl-shift-n": "agent::RejectAll" } }, + { + "context": "MessageEditor && !Picker > Editor && use_modifier_to_send", + "bindings": { + "ctrl-enter": "agent::Chat", + "enter": "editor::Newline", + "ctrl-i": "agent::ToggleProfileSelector", + "shift-ctrl-r": "agent::OpenAgentDiff", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" + } + }, { "context": "EditMessageEditor > Editor", "bindings": { @@ -306,6 +325,16 @@ "enter": "agent::AcceptSuggestedContext" } }, + { + "context": "AcpThread > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "agent::Chat", + "up": "agent::PreviousHistoryMessage", + "down": "agent::NextHistoryMessage", + "shift-ctrl-r": "agent::OpenAgentDiff" + } + }, { "context": "ThreadHistory", "bindings": { @@ -401,7 +430,7 @@ "ctrl-shift-pagedown": "pane::SwapItemRight", "ctrl-f4": ["pane::CloseActiveItem", { "close_pinned": false }], "ctrl-w": ["pane::CloseActiveItem", { "close_pinned": false }], - "alt-ctrl-t": ["pane::CloseInactiveItems", { "close_pinned": false }], + "alt-ctrl-t": ["pane::CloseOtherItems", { "close_pinned": false }], "alt-ctrl-shift-w": "workspace::CloseInactiveTabsAndPanes", "ctrl-k e": ["pane::CloseItemsToTheLeft", { "close_pinned": false }], "ctrl-k t": ["pane::CloseItemsToTheRight", { "close_pinned": false }], @@ -454,11 +483,10 @@ "ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip "ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch "ctrl-k ctrl-i": "editor::Hover", + "ctrl-k ctrl-b": "editor::BlameHover", "ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }], - "ctrl-u": "editor::UndoSelection", - "ctrl-shift-u": "editor::RedoSelection", - "f8": "editor::GoToDiagnostic", - "shift-f8": "editor::GoToPreviousDiagnostic", + "f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }], + "shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }], "f2": "editor::Rename", "f12": "editor::GoToDefinition", "alt-f12": "editor::GoToDefinitionSplit", @@ -568,7 +596,7 @@ "ctrl-shift-f": "pane::DeploySearch", "ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }], "ctrl-shift-t": "pane::ReopenClosedItem", - "ctrl-k ctrl-s": "zed::OpenKeymap", + "ctrl-k ctrl-s": "zed::OpenKeymapEditor", "ctrl-k ctrl-t": "theme_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", @@ -634,6 +662,8 @@ { "context": "Editor", "bindings": { + "ctrl-u": "editor::UndoSelection", + "ctrl-shift-u": "editor::RedoSelection", "ctrl-shift-j": "editor::JoinLines", "ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart", "ctrl-alt-h": "editor::DeleteToPreviousSubwordStart", @@ -838,6 +868,7 @@ "alt-shift-y": "git::UnstageFile", "ctrl-alt-y": "git::ToggleStaged", "space": "git::ToggleStaged", + "shift-space": "git::StageRange", "tab": "git_panel::FocusEditor", "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", @@ -898,7 +929,7 @@ } }, { - "context": "GitPanel > Editor", + "context": "CommitEditor > Editor", "bindings": { "escape": "git_panel::FocusChanges", "tab": "git_panel::FocusChanges", @@ -944,9 +975,14 @@ "context": "CollabPanel && not_editing", "bindings": { "ctrl-backspace": "collab_panel::Remove", - "space": "menu::Confirm", - "ctrl-up": "collab_panel::MoveChannelUp", - "ctrl-down": "collab_panel::MoveChannelDown" + "space": "menu::Confirm" + } + }, + { + "context": "CollabPanel", + "bindings": { + "alt-up": "collab_panel::MoveChannelUp", + "alt-down": "collab_panel::MoveChannelDown" } }, { @@ -980,6 +1016,7 @@ { "context": "FileFinder || (FileFinder > Picker > Editor)", "bindings": { + "ctrl-p": "file_finder::Toggle", "ctrl-shift-a": "file_finder::ToggleSplitMenu", "ctrl-shift-i": "file_finder::ToggleFilterMenu" } @@ -1095,7 +1132,40 @@ "context": "KeymapEditor", "use_key_equivalents": true, "bindings": { - "ctrl-f": "search::FocusSearch" + "ctrl-f": "search::FocusSearch", + "alt-find": "keymap_editor::ToggleKeystrokeSearch", + "alt-ctrl-f": "keymap_editor::ToggleKeystrokeSearch", + "alt-c": "keymap_editor::ToggleConflictFilter", + "enter": "keymap_editor::EditBinding", + "alt-enter": "keymap_editor::CreateBinding", + "ctrl-c": "keymap_editor::CopyAction", + "ctrl-shift-c": "keymap_editor::CopyContext", + "ctrl-t": "keymap_editor::ShowMatchingKeybinds" + } + }, + { + "context": "KeystrokeInput", + "use_key_equivalents": true, + "bindings": { + "enter": "keystroke_input::StartRecording", + "escape escape escape": "keystroke_input::StopRecording", + "delete": "keystroke_input::ClearKeystrokes" + } + }, + { + "context": "KeybindEditorModal", + "use_key_equivalents": true, + "bindings": { + "ctrl-enter": "menu::Confirm", + "escape": "menu::Cancel" + } + }, + { + "context": "KeybindEditorModal > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext" } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index cbc90c05e6..f942c6f8ae 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -310,7 +310,15 @@ } }, { - "context": "MessageEditor > Editor", + "context": "AgentPanel && external_agent_thread", + "use_key_equivalents": true, + "bindings": { + "cmd-n": "agent::NewExternalAgentThread", + "cmd-alt-t": "agent::NewThread" + } + }, + { + "context": "MessageEditor && !Picker > Editor && !use_modifier_to_send", "use_key_equivalents": true, "bindings": { "enter": "agent::Chat", @@ -321,6 +329,18 @@ "cmd-shift-n": "agent::RejectAll" } }, + { + "context": "MessageEditor && !Picker > Editor && use_modifier_to_send", + "use_key_equivalents": true, + "bindings": { + "cmd-enter": "agent::Chat", + "enter": "editor::Newline", + "cmd-i": "agent::ToggleProfileSelector", + "shift-ctrl-r": "agent::OpenAgentDiff", + "cmd-shift-y": "agent::KeepAll", + "cmd-shift-n": "agent::RejectAll" + } + }, { "context": "EditMessageEditor > Editor", "use_key_equivalents": true, @@ -357,6 +377,16 @@ "ctrl--": "pane::GoBack" } }, + { + "context": "AcpThread > Editor", + "use_key_equivalents": true, + "bindings": { + "enter": "agent::Chat", + "up": "agent::PreviousHistoryMessage", + "down": "agent::NextHistoryMessage", + "shift-ctrl-r": "agent::OpenAgentDiff" + } + }, { "context": "ThreadHistory", "bindings": { @@ -459,7 +489,7 @@ "ctrl-shift-pageup": "pane::SwapItemLeft", "ctrl-shift-pagedown": "pane::SwapItemRight", "cmd-w": ["pane::CloseActiveItem", { "close_pinned": false }], - "alt-cmd-t": ["pane::CloseInactiveItems", { "close_pinned": false }], + "alt-cmd-t": ["pane::CloseOtherItems", { "close_pinned": false }], "ctrl-alt-cmd-w": "workspace::CloseInactiveTabsAndPanes", "cmd-k e": ["pane::CloseItemsToTheLeft", { "close_pinned": false }], "cmd-k t": ["pane::CloseItemsToTheRight", { "close_pinned": false }], @@ -507,11 +537,10 @@ "ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch "cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch "cmd-k cmd-i": "editor::Hover", + "cmd-k cmd-b": "editor::BlameHover", "cmd-/": ["editor::ToggleComments", { "advance_downwards": false }], - "cmd-u": "editor::UndoSelection", - "cmd-shift-u": "editor::RedoSelection", - "f8": "editor::GoToDiagnostic", - "shift-f8": "editor::GoToPreviousDiagnostic", + "f8": ["editor::GoToDiagnostic", { "severity": { "min": "hint", "max": "error" } }], + "shift-f8": ["editor::GoToPreviousDiagnostic", { "severity": { "min": "hint", "max": "error" } }], "f2": "editor::Rename", "f12": "editor::GoToDefinition", "alt-f12": "editor::GoToDefinitionSplit", @@ -634,7 +663,7 @@ "cmd-shift-f": "pane::DeploySearch", "cmd-shift-h": ["pane::DeploySearch", { "replace_enabled": true }], "cmd-shift-t": "pane::ReopenClosedItem", - "cmd-k cmd-s": "zed::OpenKeymap", + "cmd-k cmd-s": "zed::OpenKeymapEditor", "cmd-k cmd-t": "theme_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", @@ -696,6 +725,8 @@ "context": "Editor", "use_key_equivalents": true, "bindings": { + "cmd-u": "editor::UndoSelection", + "cmd-shift-u": "editor::RedoSelection", "ctrl-j": "editor::JoinLines", "ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart", "ctrl-alt-h": "editor::DeleteToPreviousSubwordStart", @@ -912,6 +943,7 @@ "enter": "menu::Confirm", "cmd-alt-y": "git::ToggleStaged", "space": "git::ToggleStaged", + "shift-space": "git::StageRange", "cmd-y": "git::StageFile", "cmd-shift-y": "git::UnstageFile", "alt-down": "git_panel::FocusEditor", @@ -944,7 +976,7 @@ } }, { - "context": "GitPanel > Editor", + "context": "CommitEditor > Editor", "use_key_equivalents": true, "bindings": { "enter": "editor::Newline", @@ -1005,9 +1037,15 @@ "use_key_equivalents": true, "bindings": { "ctrl-backspace": "collab_panel::Remove", - "space": "menu::Confirm", - "cmd-up": "collab_panel::MoveChannelUp", - "cmd-down": "collab_panel::MoveChannelDown" + "space": "menu::Confirm" + } + }, + { + "context": "CollabPanel", + "use_key_equivalents": true, + "bindings": { + "alt-up": "collab_panel::MoveChannelUp", + "alt-down": "collab_panel::MoveChannelDown" } }, { @@ -1079,13 +1117,16 @@ "ctrl-cmd-space": "terminal::ShowCharacterPalette", "cmd-c": "terminal::Copy", "cmd-v": "terminal::Paste", + "cmd-f": "buffer_search::Deploy", "cmd-a": "editor::SelectAll", "cmd-k": "terminal::Clear", "cmd-n": "workspace::NewTerminal", "ctrl-enter": "assistant::InlineAssist", "ctrl-_": null, // emacs undo // Some nice conveniences - "cmd-backspace": ["terminal::SendText", "\u0015"], + "cmd-backspace": ["terminal::SendText", "\u0015"], // ctrl-u: clear line + "alt-delete": ["terminal::SendText", "\u001bd"], // alt-d: delete word forward + "cmd-delete": ["terminal::SendText", "\u000b"], // ctrl-k: delete to end of line "cmd-right": ["terminal::SendText", "\u0005"], "cmd-left": ["terminal::SendText", "\u0001"], // Terminal.app compatibility @@ -1194,7 +1235,39 @@ "context": "KeymapEditor", "use_key_equivalents": true, "bindings": { - "cmd-f": "search::FocusSearch" + "cmd-f": "search::FocusSearch", + "cmd-alt-f": "keymap_editor::ToggleKeystrokeSearch", + "cmd-alt-c": "keymap_editor::ToggleConflictFilter", + "enter": "keymap_editor::EditBinding", + "alt-enter": "keymap_editor::CreateBinding", + "cmd-c": "keymap_editor::CopyAction", + "cmd-shift-c": "keymap_editor::CopyContext", + "cmd-t": "keymap_editor::ShowMatchingKeybinds" + } + }, + { + "context": "KeystrokeInput", + "use_key_equivalents": true, + "bindings": { + "enter": "keystroke_input::StartRecording", + "escape escape escape": "keystroke_input::StopRecording", + "delete": "keystroke_input::ClearKeystrokes" + } + }, + { + "context": "KeybindEditorModal", + "use_key_equivalents": true, + "bindings": { + "cmd-enter": "menu::Confirm", + "escape": "menu::Cancel" + } + }, + { + "context": "KeybindEditorModal > Editor", + "use_key_equivalents": true, + "bindings": { + "up": "menu::SelectPrevious", + "down": "menu::SelectNext" } } ] diff --git a/assets/keymaps/initial.json b/assets/keymaps/initial.json index 0cfd28f0e5..8e4fe59f44 100644 --- a/assets/keymaps/initial.json +++ b/assets/keymaps/initial.json @@ -13,9 +13,9 @@ } }, { - "context": "Editor && vim_mode == insert && !menu", + "context": "Editor && vim_mode == insert", "bindings": { - // "j k": "vim::SwitchToNormalMode" + // "j k": "vim::NormalBefore" } } ] diff --git a/assets/keymaps/linux/emacs.json b/assets/keymaps/linux/emacs.json index 0c633efabe..0ff3796f03 100755 --- a/assets/keymaps/linux/emacs.json +++ b/assets/keymaps/linux/emacs.json @@ -114,7 +114,7 @@ "ctrl-x o": "workspace::ActivateNextPane", // other-window "ctrl-x k": "pane::CloseActiveItem", // kill-buffer "ctrl-x 0": "pane::CloseActiveItem", // delete-window - "ctrl-x 1": "pane::CloseInactiveItems", // delete-other-windows + "ctrl-x 1": "pane::CloseOtherItems", // delete-other-windows "ctrl-x 2": "pane::SplitDown", // split-window-below "ctrl-x 3": "pane::SplitRight", // split-window-right "ctrl-x ctrl-f": "file_finder::Toggle", // find-file diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index dbf50b0fce..629333663d 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -66,22 +66,51 @@ "context": "Editor && mode == full", "bindings": { "ctrl-f12": "outline::Toggle", - "alt-7": "outline::Toggle", + "ctrl-r": ["buffer_search::Deploy", { "replace_enabled": true }], "ctrl-shift-n": "file_finder::Toggle", "ctrl-g": "go_to_line::Toggle", "alt-enter": "editor::ToggleCodeActions" } }, + { + "context": "BufferSearchBar", + "bindings": { + "shift-enter": "search::SelectPreviousMatch" + } + }, + { + "context": "BufferSearchBar || ProjectSearchBar", + "bindings": { + "alt-c": "search::ToggleCaseSensitive", + "alt-e": "search::ToggleSelection", + "alt-x": "search::ToggleRegex", + "alt-w": "search::ToggleWholeWord" + } + }, { "context": "Workspace", "bindings": { + "ctrl-shift-f12": "workspace::CloseAllDocks", + "ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], + "alt-shift-f10": "task::Spawn", + "ctrl-e": "file_finder::Toggle", + "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "ctrl-shift-n": "file_finder::Toggle", "ctrl-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", "ctrl-alt-shift-n": "project_symbols::Toggle", + "alt-0": "git_panel::ToggleFocus", "alt-1": "workspace::ToggleLeftDock", - "ctrl-e": "tab_switcher::Toggle", - "alt-6": "diagnostics::Deploy" + "alt-5": "debug_panel::ToggleFocus", + "alt-6": "diagnostics::Deploy", + "alt-7": "outline_panel::ToggleFocus" + } + }, + { + "context": "Workspace || Editor", + "bindings": { + "alt-f12": "terminal_panel::ToggleFocus", + "ctrl-shift-k": "git::Push" } }, { @@ -95,10 +124,33 @@ "context": "ProjectPanel", "bindings": { "enter": "project_panel::Open", + "ctrl-shift-f": "project_panel::NewSearchInDirectory", "backspace": ["project_panel::Trash", { "skip_prompt": false }], "delete": ["project_panel::Trash", { "skip_prompt": false }], "shift-delete": ["project_panel::Delete", { "skip_prompt": false }], "shift-f6": "project_panel::Rename" } + }, + { + "context": "Terminal", + "bindings": { + "ctrl-shift-t": "workspace::NewTerminal", + "alt-f12": "workspace::CloseActiveDock", + "alt-left": "pane::ActivatePreviousItem", + "alt-right": "pane::ActivateNextItem", + "ctrl-up": "terminal::ScrollLineUp", + "ctrl-down": "terminal::ScrollLineDown", + "shift-pageup": "terminal::ScrollPageUp", + "shift-pagedown": "terminal::ScrollPageDown" + } + }, + { "context": "GitPanel", "bindings": { "alt-0": "workspace::CloseActiveDock" } }, + { "context": "ProjectPanel", "bindings": { "alt-1": "workspace::CloseActiveDock" } }, + { "context": "DebugPanel", "bindings": { "alt-5": "workspace::CloseActiveDock" } }, + { "context": "Diagnostics > Editor", "bindings": { "alt-6": "pane::CloseActiveItem" } }, + { "context": "OutlinePanel", "bindings": { "alt-7": "workspace::CloseActiveDock" } }, + { + "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { "escape": "editor::ToggleFocus" } } ] diff --git a/assets/keymaps/macos/emacs.json b/assets/keymaps/macos/emacs.json index 0c633efabe..0ff3796f03 100755 --- a/assets/keymaps/macos/emacs.json +++ b/assets/keymaps/macos/emacs.json @@ -114,7 +114,7 @@ "ctrl-x o": "workspace::ActivateNextPane", // other-window "ctrl-x k": "pane::CloseActiveItem", // kill-buffer "ctrl-x 0": "pane::CloseActiveItem", // delete-window - "ctrl-x 1": "pane::CloseInactiveItems", // delete-other-windows + "ctrl-x 1": "pane::CloseOtherItems", // delete-other-windows "ctrl-x 2": "pane::SplitDown", // split-window-below "ctrl-x 3": "pane::SplitRight", // split-window-right "ctrl-x ctrl-f": "file_finder::Toggle", // find-file diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index 22c6f18383..e8b796f534 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -3,6 +3,7 @@ "bindings": { "cmd-{": "pane::ActivatePreviousItem", "cmd-}": "pane::ActivateNextItem", + "cmd-0": "git_panel::ToggleFocus", // overrides `cmd-0` zoom reset "ctrl-f2": "debugger::Stop", "f6": "debugger::Pause", "f7": "debugger::StepInto", @@ -63,28 +64,55 @@ "context": "Editor && mode == full", "bindings": { "cmd-f12": "outline::Toggle", - "cmd-7": "outline::Toggle", + "cmd-r": ["buffer_search::Deploy", { "replace_enabled": true }], "cmd-shift-o": "file_finder::Toggle", "cmd-l": "go_to_line::Toggle", "alt-enter": "editor::ToggleCodeActions" } }, { - "context": "BufferSearchBar > Editor", + "context": "BufferSearchBar", "bindings": { "shift-enter": "search::SelectPreviousMatch" } }, + { + "context": "BufferSearchBar || ProjectSearchBar", + "bindings": { + "alt-c": "search::ToggleCaseSensitive", + "alt-e": "search::ToggleSelection", + "alt-x": "search::ToggleRegex", + "alt-w": "search::ToggleWholeWord", + "ctrl-alt-c": "search::ToggleCaseSensitive", + "ctrl-alt-e": "search::ToggleSelection", + "ctrl-alt-w": "search::ToggleWholeWord", + "ctrl-alt-x": "search::ToggleRegex" + } + }, { "context": "Workspace", "bindings": { + "cmd-shift-f12": "workspace::CloseAllDocks", + "cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], + "ctrl-alt-r": "task::Spawn", + "cmd-e": "file_finder::Toggle", + "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "cmd-shift-o": "file_finder::Toggle", "cmd-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", "cmd-alt-o": "project_symbols::Toggle", // JetBrains: Go to Symbol "cmd-o": "project_symbols::Toggle", // JetBrains: Go to Class - "cmd-1": "workspace::ToggleLeftDock", - "cmd-6": "diagnostics::Deploy" + "cmd-1": "project_panel::ToggleFocus", + "cmd-5": "debug_panel::ToggleFocus", + "cmd-6": "diagnostics::Deploy", + "cmd-7": "outline_panel::ToggleFocus" + } + }, + { + "context": "Workspace || Editor", + "bindings": { + "alt-f12": "terminal_panel::ToggleFocus", + "cmd-shift-k": "git::Push" } }, { @@ -98,11 +126,31 @@ "context": "ProjectPanel", "bindings": { "enter": "project_panel::Open", + "cmd-shift-f": "project_panel::NewSearchInDirectory", "cmd-backspace": ["project_panel::Trash", { "skip_prompt": false }], "backspace": ["project_panel::Trash", { "skip_prompt": false }], "delete": ["project_panel::Trash", { "skip_prompt": false }], "shift-delete": ["project_panel::Delete", { "skip_prompt": false }], "shift-f6": "project_panel::Rename" } + }, + { + "context": "Terminal", + "bindings": { + "cmd-t": "workspace::NewTerminal", + "alt-f12": "workspace::CloseActiveDock", + "cmd-up": "terminal::ScrollLineUp", + "cmd-down": "terminal::ScrollLineDown", + "shift-pageup": "terminal::ScrollPageUp", + "shift-pagedown": "terminal::ScrollPageDown" + } + }, + { "context": "GitPanel", "bindings": { "cmd-0": "workspace::CloseActiveDock" } }, + { "context": "DebugPanel", "bindings": { "cmd-5": "workspace::CloseActiveDock" } }, + { "context": "Diagnostics > Editor", "bindings": { "cmd-6": "pane::CloseActiveItem" } }, + { "context": "OutlinePanel", "bindings": { "cmd-7": "workspace::CloseActiveDock" } }, + { + "context": "Dock || Workspace || Terminal || OutlinePanel || ProjectPanel || CollabPanel || (Editor && mode == auto_height)", + "bindings": { "escape": "editor::ToggleFocus" } } ] diff --git a/assets/keymaps/macos/textmate.json b/assets/keymaps/macos/textmate.json index dccb675f6c..0bd8873b17 100644 --- a/assets/keymaps/macos/textmate.json +++ b/assets/keymaps/macos/textmate.json @@ -6,7 +6,7 @@ } }, { - "context": "Editor", + "context": "Editor && mode == full", "bindings": { "cmd-l": "go_to_line::Toggle", "ctrl-shift-d": "editor::DuplicateLineDown", @@ -15,7 +15,12 @@ "cmd-enter": "editor::NewlineBelow", "cmd-alt-enter": "editor::NewlineAbove", "cmd-shift-l": "editor::SelectLine", - "cmd-shift-t": "outline::Toggle", + "cmd-shift-t": "outline::Toggle" + } + }, + { + "context": "Editor", + "bindings": { "alt-backspace": "editor::DeleteToPreviousWordStart", "alt-shift-backspace": "editor::DeleteToNextWordEnd", "alt-delete": "editor::DeleteToNextWordEnd", @@ -39,10 +44,6 @@ "ctrl-_": "editor::ConvertToSnakeCase" } }, - { - "context": "Editor && mode == full", - "bindings": {} - }, { "context": "BufferSearchBar", "bindings": { diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 571192a479..6458ac1510 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -124,6 +124,7 @@ "g r a": "editor::ToggleCodeActions", "g g": "vim::StartOfDocument", "g h": "editor::Hover", + "g B": "editor::BlameHover", "g t": "pane::ActivateNextItem", "g shift-t": "pane::ActivatePreviousItem", "g d": "editor::GoToDefinition", @@ -219,6 +220,8 @@ { "context": "vim_mode == normal", "bindings": { + "i": "vim::InsertBefore", + "a": "vim::InsertAfter", "ctrl-[": "editor::Cancel", ":": "command_palette::Toggle", "c": "vim::PushChange", @@ -352,9 +355,7 @@ "shift-d": "vim::DeleteToEndOfLine", "shift-j": "vim::JoinLines", "shift-y": "vim::YankLine", - "i": "vim::InsertBefore", "shift-i": "vim::InsertFirstNonWhitespace", - "a": "vim::InsertAfter", "shift-a": "vim::InsertEndOfLine", "o": "vim::InsertLineBelow", "shift-o": "vim::InsertLineAbove", @@ -376,7 +377,10 @@ { "context": "vim_mode == helix_normal && !menu", "bindings": { + "i": "vim::HelixInsert", + "a": "vim::HelixAppend", "ctrl-[": "editor::Cancel", + ";": "vim::HelixCollapseSelection", ":": "command_palette::Toggle", "left": "vim::WrappingLeft", "right": "vim::WrappingRight", @@ -466,7 +470,7 @@ } }, { - "context": "vim_mode == insert && showing_signature_help && !showing_completions", + "context": "(vim_mode == insert || vim_mode == normal) && showing_signature_help && !showing_completions", "bindings": { "ctrl-p": "editor::SignatureHelpPrevious", "ctrl-n": "editor::SignatureHelpNext" @@ -723,7 +727,7 @@ } }, { - "context": "AgentPanel || GitPanel || ProjectPanel || CollabPanel || OutlinePanel || ChatPanel || VimControl || EmptyPane || SharedScreen || MarkdownPreview || KeyContextView || DebugPanel", + "context": "VimControl || !Editor && !Terminal", "bindings": { // window related commands (ctrl-w X) "ctrl-w": null, @@ -781,7 +785,7 @@ } }, { - "context": "ChangesList || EmptyPane || SharedScreen || MarkdownPreview || KeyContextView || Welcome", + "context": "!Editor && !Terminal", "bindings": { ":": "command_palette::Toggle", "g /": "pane::DeploySearch" @@ -841,6 +845,7 @@ "i": "git_panel::FocusEditor", "x": "git::ToggleStaged", "shift-x": "git::StageAll", + "g x": "git::StageRange", "shift-u": "git::UnstageAll" } }, @@ -856,6 +861,14 @@ "shift-n": null } }, + { + "context": "Picker > Editor", + "bindings": { + "ctrl-h": "editor::Backspace", + "ctrl-u": "editor::DeleteToBeginningOfLine", + "ctrl-w": "editor::DeleteToPreviousWordStart" + } + }, { "context": "GitCommit > Editor && VimControl && vim_mode == normal", "bindings": { diff --git a/assets/settings/default.json b/assets/settings/default.json index 8c105b2c1e..3a7a48efc2 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -84,7 +84,7 @@ "bottom_dock_layout": "contained", // The direction that you want to split panes horizontally. Defaults to "up" "pane_split_direction_horizontal": "up", - // The direction that you want to split panes horizontally. Defaults to "left" + // The direction that you want to split panes vertically. Defaults to "left" "pane_split_direction_vertical": "left", // Centered layout related settings. "centered_layout": { @@ -197,6 +197,8 @@ // "inline" // 3. Place snippets at the bottom of the completion list: // "bottom" + // 4. Do not show snippets in the completion list: + // "none" "snippet_sort_order": "inline", // How to highlight the current line in the editor. // @@ -362,7 +364,9 @@ // Whether to show user picture in the titlebar. "show_user_picture": true, // Whether to show the sign in button in the titlebar. - "show_sign_in": true + "show_sign_in": true, + // Whether to show the menus in the titlebar. + "show_menus": false }, // Scrollbar related settings "scrollbar": { @@ -687,7 +691,10 @@ // 5. Never show the scrollbar: // "never" "show": null - } + }, + // Default depth to expand outline items in the current file. + // Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper. + "expand_outlines_with_depth": 100 }, "collaboration_panel": { // Whether to show the collaboration panel button in the status bar. @@ -815,7 +822,7 @@ "edit_file": true, "fetch": true, "list_directory": true, - "project_notifications": true, + "project_notifications": false, "move_path": true, "now": true, "find_path": true, @@ -835,7 +842,7 @@ "diagnostics": true, "fetch": true, "list_directory": true, - "project_notifications": true, + "project_notifications": false, "now": true, "find_path": true, "read_file": true, @@ -1072,6 +1079,10 @@ // Send anonymized usage data like what languages you're using Zed with. "metrics": true }, + // Whether to disable all AI features in Zed. + // + // Default: false + "disable_ai": false, // Automatically update Zed. This setting may be ignored on Linux if // installed through a package manager. "auto_update": true, @@ -1133,6 +1144,7 @@ "**/.svn", "**/.hg", "**/.jj", + "**/.repo", "**/CVS", "**/.DS_Store", "**/Thumbs.db", @@ -1155,16 +1167,14 @@ // Control whether the git blame information is shown inline, // in the currently focused line. "inline_blame": { - "enabled": true + "enabled": true, // Sets a delay after which the inline blame information is shown. // Delay is restarted with every cursor movement. - // "delay_ms": 600 - // + "delay_ms": 0, // Whether or not to display the git commit summary on the same line. - // "show_commit_summary": false - // + "show_commit_summary": false, // The minimum column number to show the inline blame information at - // "min_column": 0 + "min_column": 0 }, // How git hunks are displayed visually in the editor. // This setting can take two values: @@ -1377,11 +1387,11 @@ // This will be merged with the platform's default font fallbacks // "font_fallbacks": ["FiraCode Nerd Fonts"], // The weight of the editor font in standard CSS units from 100 to 900. - // "font_weight": 400 + "font_weight": 400, // Sets the maximum number of lines in the terminal's scrollback buffer. // Default: 10_000, maximum: 100_000 (all bigger values set will be treated as 100_000), 0 disables the scrolling. // Existing terminals will not pick up this change until they are recreated. - // "max_scroll_history_lines": 10000, + "max_scroll_history_lines": 10000, // The minimum APCA perceptual contrast between foreground and background colors. // APCA (Accessible Perceptual Contrast Algorithm) is more accurate than WCAG 2.x, // especially for dark mode. Values range from 0 to 106. @@ -1670,6 +1680,10 @@ "allowed": true } }, + "SystemVerilog": { + "format_on_save": "off", + "use_on_type_format": false + }, "Vue.js": { "language_servers": ["vue-language-server", "..."], "prettier": { @@ -1705,6 +1719,7 @@ "openai": { "api_url": "https://api.openai.com/v1" }, + "openai_compatible": {}, "open_router": { "api_url": "https://openrouter.ai/api/v1" }, @@ -1855,6 +1870,8 @@ "read_ssh_config": true, // Configures context servers for use by the agent. "context_servers": {}, + // Configures agent servers available in the agent panel. + "agent_servers": {}, "debugger": { "stepping_granularity": "line", "save_breakpoints": true, diff --git a/assets/settings/initial_debug_tasks.json b/assets/settings/initial_debug_tasks.json index 78fc1fc5f0..af4512bd51 100644 --- a/assets/settings/initial_debug_tasks.json +++ b/assets/settings/initial_debug_tasks.json @@ -15,13 +15,15 @@ "adapter": "JavaScript", "program": "$ZED_FILE", "request": "launch", - "cwd": "$ZED_WORKTREE_ROOT" + "cwd": "$ZED_WORKTREE_ROOT", + "type": "pwa-node" }, { "label": "JavaScript debug terminal", "adapter": "JavaScript", "request": "launch", "cwd": "$ZED_WORKTREE_ROOT", - "console": "integratedTerminal" + "console": "integratedTerminal", + "type": "pwa-node" } ] diff --git a/assets/settings/initial_user_settings.json b/assets/settings/initial_user_settings.json index 71f3beb1d6..5ac2063bdb 100644 --- a/assets/settings/initial_user_settings.json +++ b/assets/settings/initial_user_settings.json @@ -8,7 +8,7 @@ // command palette (cmd-shift-p / ctrl-shift-p) { "ui_font_size": 16, - "buffer_font_size": 16, + "buffer_font_size": 15, "theme": { "mode": "system", "light": "One Light", diff --git a/compose.yml b/compose.yml index 4cd4c86df6..d0d9bac425 100644 --- a/compose.yml +++ b/compose.yml @@ -59,5 +59,11 @@ services: depends_on: - postgres + stripe-mock: + image: stripe/stripe-mock:v0.178.0 + ports: + - 12111:12111 + - 12112:12112 + volumes: postgres_data: diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml new file mode 100644 index 0000000000..011f26f364 --- /dev/null +++ b/crates/acp_thread/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "acp_thread" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/acp_thread.rs" +doctest = false + +[features] +test-support = ["gpui/test-support", "project/test-support"] + +[dependencies] +agent-client-protocol.workspace = true +agentic-coding-protocol.workspace = true +anyhow.workspace = true +assistant_tool.workspace = true +buffer_diff.workspace = true +editor.workspace = true +futures.workspace = true +gpui.workspace = true +itertools.workspace = true +language.workspace = true +markdown.workspace = true +project.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +async-pipe.workspace = true +env_logger.workspace = true +gpui = { workspace = true, "features" = ["test-support"] } +indoc.workspace = true +project = { workspace = true, "features" = ["test-support"] } +tempfile.workspace = true +util.workspace = true +settings.workspace = true diff --git a/crates/acp_thread/LICENSE-GPL b/crates/acp_thread/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/acp_thread/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs new file mode 100644 index 0000000000..3c6c21205f --- /dev/null +++ b/crates/acp_thread/src/acp_thread.rs @@ -0,0 +1,1571 @@ +mod connection; +mod old_acp_support; +pub use connection::*; +pub use old_acp_support::*; + +use agent_client_protocol as acp; +use anyhow::{Context as _, Result}; +use assistant_tool::ActionLog; +use buffer_diff::BufferDiff; +use editor::{Bias, MultiBuffer, PathKey}; +use futures::{FutureExt, channel::oneshot, future::BoxFuture}; +use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; +use itertools::Itertools; +use language::{ + Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, + text_diff, +}; +use markdown::Markdown; +use project::{AgentLocation, Project}; +use std::collections::HashMap; +use std::error::Error; +use std::fmt::Formatter; +use std::rc::Rc; +use std::{ + fmt::Display, + mem, + path::{Path, PathBuf}, + sync::Arc, +}; +use ui::App; +use util::ResultExt; + +#[derive(Debug)] +pub struct UserMessage { + pub content: ContentBlock, +} + +impl UserMessage { + pub fn from_acp( + message: impl IntoIterator, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let mut content = ContentBlock::Empty; + for chunk in message { + content.append(chunk, &language_registry, cx) + } + Self { content: content } + } + + fn to_markdown(&self, cx: &App) -> String { + format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) + } +} + +#[derive(Debug)] +pub struct MentionPath<'a>(&'a Path); + +impl<'a> MentionPath<'a> { + const PREFIX: &'static str = "@file:"; + + pub fn new(path: &'a Path) -> Self { + MentionPath(path) + } + + pub fn try_parse(url: &'a str) -> Option { + let path = url.strip_prefix(Self::PREFIX)?; + Some(MentionPath(Path::new(path))) + } + + pub fn path(&self) -> &Path { + self.0 + } +} + +impl Display for MentionPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "[@{}]({}{})", + self.0.file_name().unwrap_or_default().display(), + Self::PREFIX, + self.0.display() + ) + } +} + +#[derive(Debug, PartialEq)] +pub struct AssistantMessage { + pub chunks: Vec, +} + +impl AssistantMessage { + pub fn to_markdown(&self, cx: &App) -> String { + format!( + "## Assistant\n\n{}\n\n", + self.chunks + .iter() + .map(|chunk| chunk.to_markdown(cx)) + .join("\n\n") + ) + } +} + +#[derive(Debug, PartialEq)] +pub enum AssistantMessageChunk { + Message { block: ContentBlock }, + Thought { block: ContentBlock }, +} + +impl AssistantMessageChunk { + pub fn from_str(chunk: &str, language_registry: &Arc, cx: &mut App) -> Self { + Self::Message { + block: ContentBlock::new(chunk.into(), language_registry, cx), + } + } + + fn to_markdown(&self, cx: &App) -> String { + match self { + Self::Message { block } => block.to_markdown(cx).to_string(), + Self::Thought { block } => { + format!("\n{}\n", block.to_markdown(cx)) + } + } + } +} + +#[derive(Debug)] +pub enum AgentThreadEntry { + UserMessage(UserMessage), + AssistantMessage(AssistantMessage), + ToolCall(ToolCall), +} + +impl AgentThreadEntry { + fn to_markdown(&self, cx: &App) -> String { + match self { + Self::UserMessage(message) => message.to_markdown(cx), + Self::AssistantMessage(message) => message.to_markdown(cx), + Self::ToolCall(tool_call) => tool_call.to_markdown(cx), + } + } + + pub fn diffs(&self) -> impl Iterator { + if let AgentThreadEntry::ToolCall(call) = self { + itertools::Either::Left(call.diffs()) + } else { + itertools::Either::Right(std::iter::empty()) + } + } + + pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> { + if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self { + Some(locations) + } else { + None + } + } +} + +#[derive(Debug)] +pub struct ToolCall { + pub id: acp::ToolCallId, + pub label: Entity, + pub kind: acp::ToolKind, + pub content: Vec, + pub status: ToolCallStatus, + pub locations: Vec, +} + +impl ToolCall { + fn from_acp( + tool_call: acp::ToolCall, + status: ToolCallStatus, + language_registry: Arc, + cx: &mut App, + ) -> Self { + Self { + id: tool_call.id, + label: cx.new(|cx| { + Markdown::new( + tool_call.label.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + kind: tool_call.kind, + content: tool_call + .content + .into_iter() + .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) + .collect(), + locations: tool_call.locations, + status, + } + } + + pub fn diffs(&self) -> impl Iterator { + self.content.iter().filter_map(|content| match content { + ToolCallContent::ContentBlock { .. } => None, + ToolCallContent::Diff { diff } => Some(diff), + }) + } + + fn to_markdown(&self, cx: &App) -> String { + let mut markdown = format!( + "**Tool Call: {}**\nStatus: {}\n\n", + self.label.read(cx).source(), + self.status + ); + for content in &self.content { + markdown.push_str(content.to_markdown(cx).as_str()); + markdown.push_str("\n\n"); + } + markdown + } +} + +#[derive(Debug)] +pub enum ToolCallStatus { + WaitingForConfirmation { + options: Vec, + respond_tx: oneshot::Sender, + }, + Allowed { + status: acp::ToolCallStatus, + }, + Rejected, + Canceled, +} + +impl Display for ToolCallStatus { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation", + ToolCallStatus::Allowed { status } => match status { + acp::ToolCallStatus::InProgress => "In Progress", + acp::ToolCallStatus::Completed => "Completed", + acp::ToolCallStatus::Failed => "Failed", + }, + ToolCallStatus::Rejected => "Rejected", + ToolCallStatus::Canceled => "Canceled", + } + ) + } +} + +#[derive(Debug, PartialEq, Clone)] +pub enum ContentBlock { + Empty, + Markdown { markdown: Entity }, +} + +impl ContentBlock { + pub fn new( + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) -> Self { + let mut this = Self::Empty; + this.append(block, language_registry, cx); + this + } + + pub fn new_combined( + blocks: impl IntoIterator, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let mut this = Self::Empty; + for block in blocks { + this.append(block, &language_registry, cx); + } + this + } + + pub fn append( + &mut self, + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) { + let new_content = match block { + acp::ContentBlock::Text(text_content) => text_content.text.clone(), + acp::ContentBlock::ResourceLink(resource_link) => { + if let Some(path) = resource_link.uri.strip_prefix("file://") { + format!("{}", MentionPath(path.as_ref())) + } else { + resource_link.uri.clone() + } + } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => String::new(), + }; + + match self { + ContentBlock::Empty => { + *self = ContentBlock::Markdown { + markdown: cx.new(|cx| { + Markdown::new( + new_content.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + }; + } + ContentBlock::Markdown { markdown } => { + markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx)); + } + } + } + + fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { + match self { + ContentBlock::Empty => "", + ContentBlock::Markdown { markdown } => markdown.read(cx).source(), + } + } + + pub fn markdown(&self) -> Option<&Entity> { + match self { + ContentBlock::Empty => None, + ContentBlock::Markdown { markdown } => Some(markdown), + } + } +} + +#[derive(Debug)] +pub enum ToolCallContent { + ContentBlock { content: ContentBlock }, + Diff { diff: Diff }, +} + +impl ToolCallContent { + pub fn from_acp( + content: acp::ToolCallContent, + language_registry: Arc, + cx: &mut App, + ) -> Self { + match content { + acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock { + content: ContentBlock::new(content, &language_registry, cx), + }, + acp::ToolCallContent::Diff { diff } => Self::Diff { + diff: Diff::from_acp(diff, language_registry, cx), + }, + } + } + + pub fn to_markdown(&self, cx: &App) -> String { + match self { + Self::ContentBlock { content } => content.to_markdown(cx).to_string(), + Self::Diff { diff } => diff.to_markdown(cx), + } + } +} + +#[derive(Debug)] +pub struct Diff { + pub multibuffer: Entity, + pub path: PathBuf, + pub new_buffer: Entity, + pub old_buffer: Entity, + _task: Task>, +} + +impl Diff { + pub fn from_acp( + diff: acp::Diff, + language_registry: Arc, + cx: &mut App, + ) -> Self { + let acp::Diff { + path, + old_text, + new_text, + } = diff; + + let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); + + let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); + let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); + let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); + let old_buffer_snapshot = old_buffer.read(cx).snapshot(); + let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); + let diff_task = buffer_diff.update(cx, |diff, cx| { + diff.set_base_text( + old_buffer_snapshot, + Some(language_registry.clone()), + new_buffer_snapshot, + cx, + ) + }); + + let task = cx.spawn({ + let multibuffer = multibuffer.clone(); + let path = path.clone(); + let new_buffer = new_buffer.clone(); + async move |cx| { + diff_task.await?; + + multibuffer + .update(cx, |multibuffer, cx| { + let hunk_ranges = { + let buffer = new_buffer.read(cx); + let diff = buffer_diff.read(cx); + diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>() + }; + + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&new_buffer, cx), + new_buffer.clone(), + hunk_ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff.clone(), cx); + }) + .log_err(); + + if let Some(language) = language_registry + .language_for_file_path(&path) + .await + .log_err() + { + new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?; + } + + anyhow::Ok(()) + } + }); + + Self { + multibuffer, + path, + new_buffer, + old_buffer, + _task: task, + } + } + + fn to_markdown(&self, cx: &App) -> String { + let buffer_text = self + .multibuffer + .read(cx) + .all_buffers() + .iter() + .map(|buffer| buffer.read(cx).text()) + .join("\n"); + format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text) + } +} + +#[derive(Debug, Default)] +pub struct Plan { + pub entries: Vec, +} + +#[derive(Debug)] +pub struct PlanStats<'a> { + pub in_progress_entry: Option<&'a PlanEntry>, + pub pending: u32, + pub completed: u32, +} + +impl Plan { + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn stats(&self) -> PlanStats<'_> { + let mut stats = PlanStats { + in_progress_entry: None, + pending: 0, + completed: 0, + }; + + for entry in &self.entries { + match &entry.status { + acp::PlanEntryStatus::Pending => { + stats.pending += 1; + } + acp::PlanEntryStatus::InProgress => { + stats.in_progress_entry = stats.in_progress_entry.or(Some(entry)); + } + acp::PlanEntryStatus::Completed => { + stats.completed += 1; + } + } + } + + stats + } +} + +#[derive(Debug)] +pub struct PlanEntry { + pub content: Entity, + pub priority: acp::PlanEntryPriority, + pub status: acp::PlanEntryStatus, +} + +impl PlanEntry { + pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self { + Self { + content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)), + priority: entry.priority, + status: entry.status, + } + } +} + +pub struct AcpThread { + title: SharedString, + entries: Vec, + plan: Plan, + project: Entity, + action_log: Entity, + shared_buffers: HashMap, BufferSnapshot>, + send_task: Option>, + connection: Rc, + session_id: acp::SessionId, +} + +pub enum AcpThreadEvent { + NewEntry, + EntryUpdated(usize), +} + +impl EventEmitter for AcpThread {} + +#[derive(PartialEq, Eq)] +pub enum ThreadStatus { + Idle, + WaitingForToolConfirmation, + Generating, +} + +#[derive(Debug, Clone)] +pub enum LoadError { + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, + Exited(i32), + Other(SharedString), +} + +impl Display for LoadError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message), + LoadError::Exited(status) => write!(f, "Server exited with status {}", status), + LoadError::Other(msg) => write!(f, "{}", msg), + } + } +} + +impl Error for LoadError {} + +impl AcpThread { + pub fn new( + connection: Rc, + project: Entity, + session_id: acp::SessionId, + cx: &mut Context, + ) -> Self { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + Self { + action_log, + shared_buffers: Default::default(), + entries: Default::default(), + plan: Default::default(), + title: connection.name().into(), + project, + send_task: None, + connection, + session_id, + } + } + + pub fn action_log(&self) -> &Entity { + &self.action_log + } + + pub fn project(&self) -> &Entity { + &self.project + } + + pub fn title(&self) -> SharedString { + self.title.clone() + } + + pub fn entries(&self) -> &[AgentThreadEntry] { + &self.entries + } + + pub fn status(&self) -> ThreadStatus { + if self.send_task.is_some() { + if self.waiting_for_tool_confirmation() { + ThreadStatus::WaitingForToolConfirmation + } else { + ThreadStatus::Generating + } + } else { + ThreadStatus::Idle + } + } + + pub fn has_pending_edit_tool_calls(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(_) => return false, + AgentThreadEntry::ToolCall(call) if call.diffs().next().is_some() => return true, + AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} + } + } + + false + } + + pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { + self.entries.push(entry); + cx.emit(AcpThreadEvent::NewEntry); + } + + pub fn push_assistant_chunk( + &mut self, + chunk: acp::ContentBlock, + is_thought: bool, + cx: &mut Context, + ) { + let language_registry = self.project.read(cx).languages().clone(); + let entries_len = self.entries.len(); + if let Some(last_entry) = self.entries.last_mut() + && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry + { + cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + match (chunks.last_mut(), is_thought) { + (Some(AssistantMessageChunk::Message { block }), false) + | (Some(AssistantMessageChunk::Thought { block }), true) => { + block.append(chunk, &language_registry, cx) + } + _ => { + let block = ContentBlock::new(chunk, &language_registry, cx); + if is_thought { + chunks.push(AssistantMessageChunk::Thought { block }) + } else { + chunks.push(AssistantMessageChunk::Message { block }) + } + } + } + } else { + let block = ContentBlock::new(chunk, &language_registry, cx); + let chunk = if is_thought { + AssistantMessageChunk::Thought { block } + } else { + AssistantMessageChunk::Message { block } + }; + + self.push_entry( + AgentThreadEntry::AssistantMessage(AssistantMessage { + chunks: vec![chunk], + }), + cx, + ); + } + } + + pub fn update_tool_call( + &mut self, + id: acp::ToolCallId, + status: acp::ToolCallStatus, + content: Option>, + cx: &mut Context, + ) -> Result<()> { + let languages = self.project.read(cx).languages().clone(); + let (ix, current_call) = self.tool_call_mut(&id).context("Tool call not found")?; + + if let Some(content) = content { + current_call.content = content + .into_iter() + .map(|chunk| ToolCallContent::from_acp(chunk, languages.clone(), cx)) + .collect(); + } + current_call.status = ToolCallStatus::Allowed { status }; + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + + Ok(()) + } + + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. + pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { + let status = ToolCallStatus::Allowed { + status: tool_call.status, + }; + self.upsert_tool_call_inner(tool_call, status, cx) + } + + pub fn upsert_tool_call_inner( + &mut self, + tool_call: acp::ToolCall, + status: ToolCallStatus, + cx: &mut Context, + ) { + let language_registry = self.project.read(cx).languages().clone(); + let call = ToolCall::from_acp(tool_call, status, language_registry, cx); + + let location = call.locations.last().cloned(); + + if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { + *current_call = call; + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } else { + self.push_entry(AgentThreadEntry::ToolCall(call), cx); + } + + if let Some(location) = location { + self.set_project_location(location, cx) + } + } + + fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { + // The tool call we are looking for is typically the last one, or very close to the end. + // At the moment, it doesn't seem like a hashmap would be a good fit for this use case. + self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(index, tool_call)| { + if let AgentThreadEntry::ToolCall(tool_call) = tool_call + && &tool_call.id == id + { + Some((index, tool_call)) + } else { + None + } + }) + } + + pub fn request_tool_call_permission( + &mut self, + tool_call: acp::ToolCall, + options: Vec, + cx: &mut Context, + ) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + + let status = ToolCallStatus::WaitingForConfirmation { + options, + respond_tx: tx, + }; + + self.upsert_tool_call_inner(tool_call, status, cx); + rx + } + + pub fn authorize_tool_call( + &mut self, + id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, + cx: &mut Context, + ) { + let Some((ix, call)) = self.tool_call_mut(&id) else { + return; + }; + + let new_status = match option_kind { + acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => { + ToolCallStatus::Rejected + } + acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => { + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + } + } + }; + + let curr_status = mem::replace(&mut call.status, new_status); + + if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { + respond_tx.send(option_id).log_err(); + } else if cfg!(debug_assertions) { + panic!("tried to authorize an already authorized tool call"); + } + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } + + pub fn plan(&self) -> &Plan { + &self.plan + } + + pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context) { + self.plan = Plan { + entries: request + .entries + .into_iter() + .map(|entry| PlanEntry::from_acp(entry, cx)) + .collect(), + }; + + cx.notify(); + } + + fn clear_completed_plan_entries(&mut self, cx: &mut Context) { + self.plan + .entries + .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); + cx.notify(); + } + + pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { + self.project.update(cx, |project, cx| { + let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { + return; + }; + let buffer = project.open_buffer(path, cx); + cx.spawn(async move |project, cx| { + let buffer = buffer.await?; + + project.update(cx, |project, cx| { + let position = if let Some(line) = location.line { + let snapshot = buffer.read(cx).snapshot(); + let point = snapshot.clip_point(Point::new(line, 0), Bias::Left); + snapshot.anchor_before(point) + } else { + Anchor::MIN + }; + + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + }) + }) + .detach_and_log_err(cx); + }); + } + + /// Returns true if the last turn is awaiting tool authorization + pub fn waiting_for_tool_confirmation(&self) -> bool { + for entry in self.entries.iter().rev() { + match &entry { + AgentThreadEntry::ToolCall(call) => match call.status { + ToolCallStatus::WaitingForConfirmation { .. } => return true, + ToolCallStatus::Allowed { .. } + | ToolCallStatus::Rejected + | ToolCallStatus::Canceled => continue, + }, + AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => { + // Reached the beginning of the turn + return false; + } + } + } + false + } + + pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { + self.connection.authenticate(cx) + } + + #[cfg(any(test, feature = "test-support"))] + pub fn send_raw( + &mut self, + message: &str, + cx: &mut Context, + ) -> BoxFuture<'static, Result<()>> { + self.send( + vec![acp::ContentBlock::Text(acp::TextContent { + text: message.to_string(), + annotations: None, + })], + cx, + ) + } + + pub fn send( + &mut self, + message: Vec, + cx: &mut Context, + ) -> BoxFuture<'static, Result<()>> { + let block = ContentBlock::new_combined( + message.clone(), + self.project.read(cx).languages().clone(), + cx, + ); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { content: block }), + cx, + ); + self.clear_completed_plan_entries(cx); + + let (tx, rx) = oneshot::channel(); + let cancel_task = self.cancel(cx); + + self.send_task = Some(cx.spawn(async move |this, cx| { + async { + cancel_task.await; + + let result = this + .update(cx, |this, cx| { + this.connection.prompt( + acp::PromptToolArguments { + prompt: message, + session_id: this.session_id.clone(), + }, + cx, + ) + })? + .await; + tx.send(result).log_err(); + this.update(cx, |this, _cx| this.send_task.take())?; + anyhow::Ok(()) + } + .await + .log_err(); + })); + + async move { + match rx.await { + Ok(Err(e)) => Err(e)?, + _ => Ok(()), + } + } + .boxed() + } + + pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { + let Some(send_task) = self.send_task.take() else { + return Task::ready(()); + }; + + for entry in self.entries.iter_mut() { + if let AgentThreadEntry::ToolCall(call) = entry { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. } + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress + } + ); + + if cancel { + call.status = ToolCallStatus::Canceled; + } + } + } + + self.connection.cancel(&self.session_id, cx); + + // Wait for the send task to complete + cx.foreground_executor().spawn(send_task) + } + + pub fn read_text_file( + &self, + path: PathBuf, + line: Option, + limit: Option, + reuse_shared_snapshot: bool, + cx: &mut Context, + ) -> Task> { + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |this, cx| { + let load = project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(&path, cx) + .context("invalid path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + let buffer = load??.await?; + + let snapshot = if reuse_shared_snapshot { + this.read_with(cx, |this, _| { + this.shared_buffers.get(&buffer.clone()).cloned() + }) + .log_err() + .flatten() + } else { + None + }; + + let snapshot = if let Some(snapshot) = snapshot { + snapshot + } else { + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + })?; + project.update(cx, |project, cx| { + let position = buffer + .read(cx) + .snapshot() + .anchor_before(Point::new(line.unwrap_or_default(), 0)); + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + })?; + + buffer.update(cx, |buffer, _| buffer.snapshot())? + }; + + this.update(cx, |this, _| { + let text = snapshot.text(); + this.shared_buffers.insert(buffer.clone(), snapshot); + if line.is_none() && limit.is_none() { + return Ok(text); + } + let limit = limit.unwrap_or(u32::MAX) as usize; + let Some(line) = line else { + return Ok(text.lines().take(limit).collect::()); + }; + + let count = text.lines().count(); + if count < line as usize { + anyhow::bail!("There are only {} lines", count); + } + Ok(text + .lines() + .skip(line as usize + 1) + .take(limit) + .collect::()) + })? + }) + } + + pub fn write_text_file( + &self, + path: PathBuf, + content: String, + cx: &mut Context, + ) -> Task> { + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |this, cx| { + let load = project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(&path, cx) + .context("invalid path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + let buffer = load??.await?; + let snapshot = this.update(cx, |this, cx| { + this.shared_buffers + .get(&buffer) + .cloned() + .unwrap_or_else(|| buffer.read(cx).snapshot()) + })?; + let edits = cx + .background_executor() + .spawn(async move { + let old_text = snapshot.text(); + text_diff(old_text.as_str(), &content) + .into_iter() + .map(|(range, replacement)| { + ( + snapshot.anchor_after(range.start) + ..snapshot.anchor_before(range.end), + replacement, + ) + }) + .collect::>() + }) + .await; + cx.update(|cx| { + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: edits + .last() + .map(|(range, _)| range.end) + .unwrap_or(Anchor::MIN), + }), + cx, + ); + }); + + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + }); + buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + }); + action_log.update(cx, |action_log, cx| { + action_log.buffer_edited(buffer.clone(), cx); + }); + })?; + project + .update(cx, |project, cx| project.save_buffer(buffer, cx))? + .await + }) + } + + pub fn to_markdown(&self, cx: &App) -> String { + self.entries.iter().map(|e| e.to_markdown(cx)).collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agentic_coding_protocol as acp_old; + use anyhow::anyhow; + use async_pipe::{PipeReader, PipeWriter}; + use futures::{channel::mpsc, future::LocalBoxFuture, select}; + use gpui::{AsyncApp, TestAppContext}; + use indoc::indoc; + use project::FakeFs; + use serde_json::json; + use settings::SettingsStore; + use smol::{future::BoxedLocal, stream::StreamExt as _}; + use std::{cell::RefCell, rc::Rc, time::Duration}; + + use util::path; + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + } + + #[gpui::test] + async fn test_thinking_concatenation(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (thread, fake_server) = fake_acp_thread(project, cx); + + fake_server.update(cx, |fake_server, _| { + fake_server.on_user_message(move |_, server, mut cx| async move { + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp_old::StreamAssistantMessageChunkParams { + chunk: acp_old::AssistantMessageChunk::Thought { + thought: "Thinking ".into(), + }, + }) + })? + .await + .unwrap(); + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp_old::StreamAssistantMessageChunkParams { + chunk: acp_old::AssistantMessageChunk::Thought { + thought: "hard!".into(), + }, + }) + })? + .await + .unwrap(); + + Ok(()) + }) + }); + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) + .await + .unwrap(); + + let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); + assert_eq!( + output, + indoc! {r#" + ## User + + Hello from Zed! + + ## Assistant + + + Thinking hard! + + + "#} + ); + } + + #[gpui::test] + async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"})) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let (thread, fake_server) = fake_acp_thread(project.clone(), cx); + let (worktree, pathbuf) = project + .update(cx, |project, cx| { + project.find_or_create_worktree(path!("/tmp/foo"), true, cx) + }) + .await + .unwrap(); + let buffer = project + .update(cx, |project, cx| { + project.open_buffer((worktree.read(cx).id(), pathbuf), cx) + }) + .await + .unwrap(); + + let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); + let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); + + fake_server.update(cx, |fake_server, _| { + fake_server.on_user_message(move |_, server, mut cx| { + let read_file_tx = read_file_tx.clone(); + async move { + let content = server + .update(&mut cx, |server, _| { + server.send_to_zed(acp_old::ReadTextFileParams { + path: path!("/tmp/foo").into(), + line: None, + limit: None, + }) + })? + .await + .unwrap(); + assert_eq!(content.content, "one\ntwo\nthree\n"); + read_file_tx.take().unwrap().send(()).unwrap(); + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp_old::WriteTextFileParams { + path: path!("/tmp/foo").into(), + content: "one\ntwo\nthree\nfour\nfive\n".to_string(), + }) + })? + .await + .unwrap(); + Ok(()) + } + }) + }); + + let request = thread.update(cx, |thread, cx| { + thread.send_raw("Extend the count in /tmp/foo", cx) + }); + read_file_rx.await.ok(); + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "zero\n".to_string())], None, cx); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "zero\none\ntwo\nthree\nfour\nfive\n" + ); + assert_eq!( + String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(), + "zero\none\ntwo\nthree\nfour\nfive\n" + ); + request.await.unwrap(); + } + + #[gpui::test] + async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (thread, fake_server) = fake_acp_thread(project, cx); + + let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>(); + + let tool_call_id = Rc::new(RefCell::new(None)); + let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx))); + fake_server.update(cx, |fake_server, _| { + let tool_call_id = tool_call_id.clone(); + fake_server.on_user_message(move |_, server, mut cx| { + let end_turn_rx = end_turn_rx.clone(); + let tool_call_id = tool_call_id.clone(); + async move { + let tool_call_result = server + .update(&mut cx, |server, _| { + server.send_to_zed(acp_old::PushToolCallParams { + label: "Fetch".to_string(), + icon: acp_old::Icon::Globe, + content: None, + locations: vec![], + }) + })? + .await + .unwrap(); + *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id); + end_turn_rx.take().unwrap().await.ok(); + + Ok(()) + } + }) + }); + + let request = thread.update(cx, |thread, cx| { + thread.send_raw("Fetch https://example.com", cx) + }); + + run_until_first_tool_call(&thread, cx).await; + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.entries[1], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + .. + }, + .. + }) + )); + }); + + cx.run_until_parked(); + + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + + thread.read_with(cx, |thread, _| { + assert!(matches!( + &thread.entries[1], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. + }) + )); + }); + + fake_server + .update(cx, |fake_server, _| { + fake_server.send_to_zed(acp_old::UpdateToolCallParams { + tool_call_id: tool_call_id.borrow().unwrap(), + status: acp_old::ToolCallStatus::Finished, + content: None, + }) + }) + .await + .unwrap(); + + drop(end_turn_tx); + assert!(request.await.unwrap_err().to_string().contains("canceled")); + + thread.read_with(cx, |thread, _| { + assert!(matches!( + thread.entries[1], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Completed, + .. + }, + .. + }) + )); + }); + } + + async fn run_until_first_tool_call( + thread: &Entity, + cx: &mut TestAppContext, + ) -> usize { + let (mut tx, mut rx) = mpsc::channel::(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + for (ix, entry) in thread.read(cx).entries.iter().enumerate() { + if matches!(entry, AgentThreadEntry::ToolCall(_)) { + return tx.try_send(ix).unwrap(); + } + } + }) + }); + + select! { + _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => { + panic!("Timeout waiting for tool call") + } + ix = rx.next().fuse() => { + drop(subscription); + ix.unwrap() + } + } + } + + pub fn fake_acp_thread( + project: Entity, + cx: &mut TestAppContext, + ) -> (Entity, Entity) { + let (stdin_tx, stdin_rx) = async_pipe::pipe(); + let (stdout_tx, stdout_rx) = async_pipe::pipe(); + + let thread = cx.new(|cx| { + let foreground_executor = cx.foreground_executor().clone(); + let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()), + stdin_tx, + stdout_rx, + move |fut| { + foreground_executor.spawn(fut).detach(); + }, + ); + + let io_task = cx.background_spawn({ + async move { + io_fut.await.log_err(); + Ok(()) + } + }); + let connection = OldAcpAgentConnection { + name: "test", + connection, + child_status: io_task, + }; + + AcpThread::new( + Rc::new(connection), + project, + acp::SessionId("test".into()), + cx, + ) + }); + let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); + (thread, agent) + } + + pub struct FakeAcpServer { + connection: acp_old::ClientConnection, + + _io_task: Task<()>, + on_user_message: Option< + Rc< + dyn Fn( + acp_old::SendUserMessageParams, + Entity, + AsyncApp, + ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>, + >, + >, + } + + #[derive(Clone)] + struct FakeAgent { + server: Entity, + cx: AsyncApp, + cancel_tx: Rc>>>, + } + + impl acp_old::Agent for FakeAgent { + async fn initialize( + &self, + params: acp_old::InitializeParams, + ) -> Result { + Ok(acp_old::InitializeResponse { + protocol_version: params.protocol_version, + is_authenticated: true, + }) + } + + async fn authenticate(&self) -> Result<(), acp_old::Error> { + Ok(()) + } + + async fn cancel_send_message(&self) -> Result<(), acp_old::Error> { + if let Some(cancel_tx) = self.cancel_tx.take() { + cancel_tx.send(()).log_err(); + } + Ok(()) + } + + async fn send_user_message( + &self, + request: acp_old::SendUserMessageParams, + ) -> Result<(), acp_old::Error> { + let (cancel_tx, cancel_rx) = oneshot::channel(); + self.cancel_tx.replace(Some(cancel_tx)); + + let mut cx = self.cx.clone(); + let handler = self + .server + .update(&mut cx, |server, _| server.on_user_message.clone()) + .ok() + .flatten(); + if let Some(handler) = handler { + select! { + _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()), + _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()), + } + } else { + Err(anyhow::anyhow!("No handler for on_user_message").into()) + } + } + } + + impl FakeAcpServer { + fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context) -> Self { + let agent = FakeAgent { + server: cx.entity(), + cx: cx.to_async(), + cancel_tx: Default::default(), + }; + let foreground_executor = cx.foreground_executor().clone(); + + let (connection, io_fut) = acp_old::ClientConnection::connect_to_client( + agent.clone(), + stdout, + stdin, + move |fut| { + foreground_executor.spawn(fut).detach(); + }, + ); + FakeAcpServer { + connection: connection, + on_user_message: None, + _io_task: cx.background_spawn(async move { + io_fut.await.log_err(); + }), + } + } + + fn on_user_message( + &mut self, + handler: impl for<'a> Fn( + acp_old::SendUserMessageParams, + Entity, + AsyncApp, + ) -> F + + 'static, + ) where + F: Future> + 'static, + { + self.on_user_message + .replace(Rc::new(move |request, server, cx| { + handler(request, server, cx).boxed_local() + })); + } + + fn send_to_zed( + &self, + message: T, + ) -> BoxedLocal> { + self.connection + .request(message) + .map(|f| f.map_err(|err| anyhow!(err))) + .boxed_local() + } + } +} diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs new file mode 100644 index 0000000000..fde167da5f --- /dev/null +++ b/crates/acp_thread/src/connection.rs @@ -0,0 +1,26 @@ +use std::{path::Path, rc::Rc}; + +use agent_client_protocol as acp; +use anyhow::Result; +use gpui::{AsyncApp, Entity, Task}; +use project::Project; +use ui::App; + +use crate::AcpThread; + +pub trait AgentConnection { + fn name(&self) -> &'static str; + + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>>; + + fn authenticate(&self, cx: &mut App) -> Task>; + + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task>; + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); +} diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs new file mode 100644 index 0000000000..316a5bcf25 --- /dev/null +++ b/crates/acp_thread/src/old_acp_support.rs @@ -0,0 +1,461 @@ +// Translates old acp agents into the new schema +use agent_client_protocol as acp; +use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; +use anyhow::{Context as _, Result}; +use futures::channel::oneshot; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use project::Project; +use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; +use ui::App; + +use crate::{AcpThread, AcpThreadEvent, AgentConnection, ToolCallContent, ToolCallStatus}; + +#[derive(Clone)] +pub struct OldAcpClientDelegate { + thread: Rc>>, + cx: AsyncApp, + next_tool_call_id: Rc>, + // sent_buffer_versions: HashMap, HashMap>, +} + +impl OldAcpClientDelegate { + pub fn new(thread: Rc>>, cx: AsyncApp) -> Self { + Self { + thread, + cx, + next_tool_call_id: Rc::new(RefCell::new(0)), + } + } +} + +impl acp_old::Client for OldAcpClientDelegate { + async fn stream_assistant_message_chunk( + &self, + params: acp_old::StreamAssistantMessageChunkParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| match params.chunk { + acp_old::AssistantMessageChunk::Text { text } => { + thread.push_assistant_chunk(text.into(), false, cx) + } + acp_old::AssistantMessageChunk::Thought { thought } => { + thread.push_assistant_chunk(thought.into(), true, cx) + } + }) + .ok(); + })?; + + Ok(()) + } + + async fn request_tool_call_confirmation( + &self, + request: acp_old::RequestToolCallConfirmationParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + let tool_call = into_new_tool_call( + acp::ToolCallId(old_acp_id.to_string().into()), + request.tool_call, + ); + + let mut options = match request.confirmation { + acp_old::ToolCallConfirmation::Edit { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow Edits".to_string(), + )], + acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", root_command), + )], + acp_old::ToolCallConfirmation::Mcp { + server_name, + tool_name, + .. + } => vec![ + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", server_name), + ), + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", tool_name), + ), + ], + acp_old::ToolCallConfirmation::Fetch { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + acp_old::ToolCallConfirmation::Other { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + }; + + options.extend([ + ( + acp_old::ToolCallConfirmationOutcome::Allow, + acp::PermissionOptionKind::AllowOnce, + "Allow".to_string(), + ), + ( + acp_old::ToolCallConfirmationOutcome::Reject, + acp::PermissionOptionKind::RejectOnce, + "Reject".to_string(), + ), + ]); + + let mut outcomes = Vec::with_capacity(options.len()); + let mut acp_options = Vec::with_capacity(options.len()); + + for (index, (outcome, kind, label)) in options.into_iter().enumerate() { + outcomes.push(outcome); + acp_options.push(acp::PermissionOption { + id: acp::PermissionOptionId(index.to_string().into()), + label, + kind, + }) + } + + let response = cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.request_tool_call_permission(tool_call, acp_options, cx) + }) + })? + .context("Failed to update thread")? + .await; + + let outcome = match response { + Ok(option_id) => outcomes[option_id.0.parse::().unwrap_or(0)], + Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel, + }; + + Ok(acp_old::RequestToolCallConfirmationResponse { + id: acp_old::ToolCallId(old_acp_id), + outcome: outcome, + }) + } + + async fn push_tool_call( + &self, + request: acp_old::PushToolCallParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.upsert_tool_call( + into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(acp_old::PushToolCallResponse { + id: acp_old::ToolCallId(old_acp_id), + }) + } + + async fn update_tool_call( + &self, + request: acp_old::UpdateToolCallParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + let languages = thread.project.read(cx).languages().clone(); + + if let Some((ix, tool_call)) = thread + .tool_call_mut(&acp::ToolCallId(request.tool_call_id.0.to_string().into())) + { + tool_call.status = ToolCallStatus::Allowed { + status: into_new_tool_call_status(request.status), + }; + tool_call.content = request + .content + .into_iter() + .map(|content| { + ToolCallContent::from_acp( + into_new_tool_call_content(content), + languages.clone(), + cx, + ) + }) + .collect(); + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + anyhow::Ok(()) + } else { + anyhow::bail!("Tool call not found") + } + }) + })? + .context("Failed to update thread")??; + + Ok(()) + } + + async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.update_plan( + acp::Plan { + entries: request + .entries + .into_iter() + .map(into_new_plan_entry) + .collect(), + }, + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(()) + } + + async fn read_text_file( + &self, + acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.read_text_file(path, line, limit, false, cx) + }) + })? + .context("Failed to update thread")? + .await?; + Ok(acp_old::ReadTextFileResponse { content }) + } + + async fn write_text_file( + &self, + acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams, + ) -> Result<(), acp_old::Error> { + self.cx + .update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| thread.write_text_file(path, content, cx)) + })? + .context("Failed to update thread")? + .await?; + + Ok(()) + } +} + +fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { + acp::ToolCall { + id: id, + label: request.label, + kind: acp_kind_from_old_icon(request.icon), + status: acp::ToolCallStatus::InProgress, + content: request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect(), + locations: request + .locations + .into_iter() + .map(into_new_tool_call_location) + .collect(), + } +} + +fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind { + match icon { + acp_old::Icon::FileSearch => acp::ToolKind::Search, + acp_old::Icon::Folder => acp::ToolKind::Search, + acp_old::Icon::Globe => acp::ToolKind::Search, + acp_old::Icon::Hammer => acp::ToolKind::Other, + acp_old::Icon::LightBulb => acp::ToolKind::Think, + acp_old::Icon::Pencil => acp::ToolKind::Edit, + acp_old::Icon::Regex => acp::ToolKind::Search, + acp_old::Icon::Terminal => acp::ToolKind::Execute, + } +} + +fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus { + match status { + acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress, + acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed, + acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed, + } +} + +fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { + match content { + acp_old::ToolCallContent::Markdown { markdown } => acp::ToolCallContent::ContentBlock { + content: acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: markdown, + }), + }, + acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { + diff: into_new_diff(diff), + }, + } +} + +fn into_new_diff(diff: acp_old::Diff) -> acp::Diff { + acp::Diff { + path: diff.path, + old_text: diff.old_text, + new_text: diff.new_text, + } +} + +fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation { + acp::ToolCallLocation { + path: location.path, + line: location.line, + } +} + +fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry { + acp::PlanEntry { + content: entry.content, + priority: into_new_plan_priority(entry.priority), + status: into_new_plan_status(entry.status), + } +} + +fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority { + match priority { + acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low, + acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium, + acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High, + } +} + +fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus { + match status { + acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending, + acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress, + acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed, + } +} + +#[derive(Debug)] +pub struct Unauthenticated; + +impl Error for Unauthenticated {} +impl fmt::Display for Unauthenticated { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Unauthenticated") + } +} + +pub struct OldAcpAgentConnection { + pub name: &'static str, + pub connection: acp_old::AgentConnection, + pub child_status: Task>, +} + +impl AgentConnection for OldAcpAgentConnection { + fn name(&self) -> &'static str { + self.name + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let task = self.connection.request_any( + acp_old::InitializeParams { + protocol_version: acp_old::ProtocolVersion::latest(), + } + .into_any(), + ); + cx.spawn(async move |cx| { + let result = task.await?; + let result = acp_old::InitializeParams::response_from_any(result)?; + + if !result.is_authenticated { + anyhow::bail!(Unauthenticated) + } + + cx.update(|cx| { + let thread = cx.new(|cx| { + let session_id = acp::SessionId("acp-old-no-id".into()); + AcpThread::new(self.clone(), project, session_id, cx) + }); + thread + }) + }) + } + + fn authenticate(&self, cx: &mut App) -> Task> { + let task = self + .connection + .request_any(acp_old::AuthenticateParams.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + Ok(()) + }) + } + + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + let chunks = params + .prompt + .into_iter() + .filter_map(|block| match block { + acp::ContentBlock::Text(text) => { + Some(acp_old::UserMessageChunk::Text { text: text.text }) + } + acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path { + path: link.uri.into(), + }), + _ => None, + }) + .collect(); + + let task = self + .connection + .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + anyhow::Ok(()) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { + let task = self + .connection + .request_any(acp_old::CancelSendMessageParams.into_any()); + cx.foreground_executor() + .spawn(async move { + task.await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx) + } +} diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs index b07c541821..f8ea7173d8 100644 --- a/crates/activity_indicator/src/activity_indicator.rs +++ b/crates/activity_indicator/src/activity_indicator.rs @@ -231,7 +231,6 @@ impl ActivityIndicator { status, } => { let create_buffer = project.update(cx, |project, cx| project.create_buffer(cx)); - let project = project.clone(); let status = status.clone(); let server_name = server_name.clone(); cx.spawn_in(window, async move |workspace, cx| { @@ -247,8 +246,7 @@ impl ActivityIndicator { workspace.update_in(cx, |workspace, window, cx| { workspace.add_item_to_active_pane( Box::new(cx.new(|cx| { - let mut editor = - Editor::for_buffer(buffer, Some(project.clone()), window, cx); + let mut editor = Editor::for_buffer(buffer, None, window, cx); editor.set_read_only(true); editor })), @@ -448,7 +446,7 @@ impl ActivityIndicator { .into_any_element(), ), message: format!("Debug: {}", session.read(cx).adapter()), - tooltip_message: Some(session.read(cx).label().to_string()), + tooltip_message: session.read(cx).label().map(|label| label.to_string()), on_click: None, }); } diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index da7de1e312..4c6d2b2b0b 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -38,7 +38,7 @@ impl Tool for ContextServerTool { } fn icon(&self) -> IconName { - IconName::Cog + IconName::ToolHammer } fn source(&self) -> ToolSource { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1f2654dac5..1af27ca8a7 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -21,6 +21,7 @@ use gpui::{ AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, Window, }; +use http_client::StatusCode; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest, @@ -46,12 +47,24 @@ use std::{ time::{Duration, Instant}, }; use thiserror::Error; -use util::{ResultExt as _, debug_panic, post_inc}; +use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; -const MAX_RETRY_ATTEMPTS: u8 = 3; -const BASE_RETRY_DELAY_SECS: u64 = 5; +const MAX_RETRY_ATTEMPTS: u8 = 4; +const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); + +#[derive(Debug, Clone)] +enum RetryStrategy { + ExponentialBackoff { + initial_delay: Duration, + max_attempts: u8, + }, + Fixed { + delay: Duration, + max_attempts: u8, + }, +} #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, @@ -383,6 +396,7 @@ pub struct Thread { remaining_turns: u32, configured_model: Option, profile: AgentProfile, + last_error_context: Option<(Arc, CompletionIntent)>, } #[derive(Clone, Debug)] @@ -476,10 +490,11 @@ impl Thread { retry_state: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, - configured_model, + configured_model: configured_model.clone(), profile: AgentProfile::new(profile_id, tools), } } @@ -600,6 +615,7 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + last_error_context: None, last_received_chunk_at: None, request_callback: None, remaining_turns: u32::MAX, @@ -1251,9 +1267,58 @@ impl Thread { self.flush_notifications(model.clone(), intent, cx); - let request = self.to_completion_request(model.clone(), intent, cx); + let _checkpoint = self.finalize_pending_checkpoint(cx); + self.stream_completion( + self.to_completion_request(model.clone(), intent, cx), + model, + intent, + window, + cx, + ); + } - self.stream_completion(request, model, intent, window, cx); + pub fn retry_last_completion( + &mut self, + window: Option, + cx: &mut Context, + ) { + // Clear any existing error state + self.retry_state = None; + + // Use the last error context if available, otherwise fall back to configured model + let (model, intent) = if let Some((model, intent)) = self.last_error_context.take() { + (model, intent) + } else if let Some(configured_model) = self.configured_model.as_ref() { + let model = configured_model.model.clone(); + let intent = if self.has_pending_tool_uses() { + CompletionIntent::ToolResults + } else { + CompletionIntent::UserPrompt + }; + (model, intent) + } else if let Some(configured_model) = self.get_or_init_configured_model(cx) { + let model = configured_model.model.clone(); + let intent = if self.has_pending_tool_uses() { + CompletionIntent::ToolResults + } else { + CompletionIntent::UserPrompt + }; + (model, intent) + } else { + return; + }; + + self.send_to_model(model, intent, window, cx); + } + + pub fn enable_burn_mode_and_retry( + &mut self, + window: Option, + cx: &mut Context, + ) { + self.completion_mode = CompletionMode::Burn; + cx.emit(ThreadEvent::ProfileChanged); + self.retry_last_completion(window, cx); } pub fn used_tools_since_last_user_message(&self) -> bool { @@ -1284,6 +1349,7 @@ impl Thread { tool_choice: None, stop: Vec::new(), temperature: AgentSettings::temperature_for_model(&model, cx), + thinking_allowed: true, }; let available_tools = self.available_tools(cx, model.clone()); @@ -1449,6 +1515,7 @@ impl Thread { tool_choice: None, stop: Vec::new(), temperature: AgentSettings::temperature_for_model(model, cx), + thinking_allowed: false, }; for message in &self.messages { @@ -1515,21 +1582,21 @@ impl Thread { model: Arc, cx: &mut App, ) -> Option { - let action_log = self.action_log.read(cx); - - action_log.unnotified_stale_buffers(cx).next()?; - // Represent notification as a simulated `project_notifications` tool call let tool_name = Arc::from("project_notifications"); - let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else { - debug_panic!("`project_notifications` tool not found"); - return None; - }; + let tool = self.tools.read(cx).tool(&tool_name, cx)?; if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) { return None; } + if self + .action_log + .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none()) + { + return None; + } + let input = serde_json::json!({}); let request = Arc::new(LanguageModelRequest::default()); // unused let window = None; @@ -1931,18 +1998,6 @@ impl Thread { project.set_agent_location(None, cx); }); - fn emit_generic_error(error: &anyhow::Error, cx: &mut Context) { - let error_message = error - .chain() - .map(|err| err.to_string()) - .collect::>() - .join("\n"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error interacting with language model".into(), - message: SharedString::from(error_message.clone()), - })); - } - if error.is::() { cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); } else if let Some(error) = @@ -1954,9 +2009,10 @@ impl Thread { } else if let Some(completion_error) = error.downcast_ref::() { - use LanguageModelCompletionError::*; match &completion_error { - PromptTooLarge { tokens, .. } => { + LanguageModelCompletionError::PromptTooLarge { + tokens, .. + } => { let tokens = tokens.unwrap_or_else(|| { // We didn't get an exact token count from the API, so fall back on our estimate. thread @@ -1977,63 +2033,28 @@ impl Thread { }); cx.notify(); } - RateLimitExceeded { - retry_after: Some(retry_after), - .. - } - | ServerOverloaded { - retry_after: Some(retry_after), - .. - } => { - thread.handle_rate_limit_error( - &completion_error, - *retry_after, - model.clone(), - intent, - window, - cx, - ); - retry_scheduled = true; - } - RateLimitExceeded { .. } | ServerOverloaded { .. } => { - retry_scheduled = thread.handle_retryable_error( - &completion_error, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); + _ => { + if let Some(retry_strategy) = + Thread::get_retry_strategy(completion_error) + { + log::info!( + "Retrying with {:?} for language model completion error {:?}", + retry_strategy, + completion_error + ); + + retry_scheduled = thread + .handle_retryable_error_with_delay( + &completion_error, + Some(retry_strategy), + model.clone(), + intent, + window, + cx, + ); } } - ApiInternalServerError { .. } - | ApiReadResponseError { .. } - | HttpSend { .. } => { - retry_scheduled = thread.handle_retryable_error( - &completion_error, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); - } - } - NoApiKey { .. } - | HttpResponseError { .. } - | BadRequestFormat { .. } - | AuthenticationError { .. } - | PermissionError { .. } - | ApiEndpointNotFound { .. } - | SerializeRequest { .. } - | BuildRequestBody { .. } - | DeserializeResponse { .. } - | Other { .. } => emit_generic_error(error, cx), } - } else { - emit_generic_error(error, cx); } if !retry_scheduled { @@ -2160,73 +2181,141 @@ impl Thread { }); } - fn handle_rate_limit_error( - &mut self, - error: &LanguageModelCompletionError, - retry_after: Duration, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) { - // For rate limit errors, we only retry once with the specified duration - let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs()); - log::warn!( - "Retrying completion request in {} seconds: {error:?}", - retry_after.as_secs(), - ); + fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option { + use LanguageModelCompletionError::*; - // Add a UI-only message instead of a regular message - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text(retry_message)], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: false, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - // Schedule the retry - let thread_handle = cx.entity().downgrade(); - - cx.spawn(async move |_thread, cx| { - cx.background_executor().timer(retry_after).await; - - thread_handle - .update(cx, |thread, cx| { - // Retry the completion - thread.send_to_model(model, intent, window, cx); + // General strategy here: + // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all. + // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff. + // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times. + match error { + HttpResponseError { + status_code: StatusCode::TOO_MANY_REQUESTS, + .. + } => Some(RetryStrategy::ExponentialBackoff { + initial_delay: BASE_RETRY_DELAY, + max_attempts: MAX_RETRY_ATTEMPTS, + }), + ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, }) - .log_err(); - }) - .detach(); - } - - fn handle_retryable_error( - &mut self, - error: &LanguageModelCompletionError, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) -> bool { - self.handle_retryable_error_with_delay(error, None, model, intent, window, cx) + } + UpstreamProviderError { + status, + retry_after, + .. + } => match *status { + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + // Internal Server Error could be anything, retry up to 3 times. + max_attempts: 3, + }), + status => { + // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"), + // but we frequently get them in practice. See https://http.dev/529 + if status.as_u16() == 529 { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } else { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: 2, + }) + } + } + }, + ApiInternalServerError { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + ApiReadResponseError { .. } + | HttpSend { .. } + | DeserializeResponse { .. } + | BadRequestFormat { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + // Retrying these errors definitely shouldn't help. + HttpResponseError { + status_code: + StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED, + .. + } + | AuthenticationError { .. } + | PermissionError { .. } + | NoApiKey { .. } + | ApiEndpointNotFound { .. } + | PromptTooLarge { .. } => None, + // These errors might be transient, so retry them + SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }), + // Retry all other 4xx and 5xx errors once. + HttpResponseError { status_code, .. } + if status_code.is_client_error() || status_code.is_server_error() => + { + Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }) + } + // Conservatively assume that any other errors are non-retryable + HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 2, + }), + } } fn handle_retryable_error_with_delay( &mut self, error: &LanguageModelCompletionError, - custom_delay: Option, + strategy: Option, model: Arc, intent: CompletionIntent, window: Option, cx: &mut Context, ) -> bool { + // Store context for the Retry button + self.last_error_context = Some((model.clone(), intent)); + + // Only auto-retry if Burn Mode is enabled + if self.completion_mode != CompletionMode::Burn { + // Show error with retry options + cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { + message: format!( + "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.", + error + ) + .into(), + can_enable_burn_mode: true, + })); + return false; + } + + let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else { + return false; + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + let retry_state = self.retry_state.get_or_insert(RetryState { attempt: 0, - max_attempts: MAX_RETRY_ATTEMPTS, + max_attempts, intent, }); @@ -2236,20 +2325,24 @@ impl Thread { let intent = retry_state.intent; if attempt <= max_attempts { - // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff - let delay = if let Some(custom_delay) = custom_delay { - custom_delay - } else { - let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, }; // Add a transient message to inform the user let delay_secs = delay.as_secs(); - let retry_message = format!( - "{error}. Retrying (attempt {attempt} of {max_attempts}) \ - in {delay_secs} seconds..." - ); + let retry_message = if max_attempts == 1 { + format!("{error}. Retrying in {delay_secs} seconds...") + } else { + format!( + "{error}. Retrying (attempt {attempt} of {max_attempts}) \ + in {delay_secs} seconds..." + ) + }; log::warn!( "Retrying completion request (attempt {attempt} of {max_attempts}) \ in {delay_secs} seconds: {error:?}", @@ -2288,18 +2381,15 @@ impl Thread { // Max retries exceeded self.retry_state = None; - let notification_text = if max_attempts == 1 { - "Failed after retrying.".into() - } else { - format!("Failed after retrying {} times.", max_attempts).into() - }; - // Stop generating since we're giving up on retrying. self.pending_completions.clear(); - cx.emit(ThreadEvent::RetriesFailed { - message: notification_text, - }); + // Show error alongside a Retry button, but no + // Enable Burn Mode button (since it's already enabled) + cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { + message: format!("Failed after retrying: {}", error).into(), + can_enable_burn_mode: false, + })); false } @@ -3211,6 +3301,11 @@ pub enum ThreadError { header: SharedString, message: SharedString, }, + #[error("Retryable error: {message}")] + RetryableError { + message: SharedString, + can_enable_burn_mode: bool, + }, } #[derive(Debug, Clone)] @@ -3256,9 +3351,6 @@ pub enum ThreadEvent { CancelEditing, CompletionCanceled, ProfileChanged, - RetriesFailed { - message: SharedString, - }, } impl EventEmitter for Thread {} @@ -3286,7 +3378,6 @@ mod tests { use futures::stream::BoxStream; use gpui::TestAppContext; use http_client; - use indoc::indoc; use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; use language_model::{ LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, @@ -3615,6 +3706,7 @@ fn main() {{ } #[gpui::test] + #[ignore] // turn this test on when project_notifications tool is re-enabled async fn test_stale_buffer_notification(cx: &mut TestAppContext) { init_test_settings(cx); @@ -3647,6 +3739,7 @@ fn main() {{ cx, ); }); + cx.run_until_parked(); // We shouldn't have a stale buffer notification yet let notifications = thread.read_with(cx, |thread, _| { @@ -3676,11 +3769,13 @@ fn main() {{ cx, ) }); + cx.run_until_parked(); // Check for the stale buffer warning thread.update(cx, |thread, cx| { thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) }); + cx.run_until_parked(); let notifications = thread.read_with(cx, |thread, _cx| { find_tool_uses(thread, "project_notifications") @@ -3694,12 +3789,8 @@ fn main() {{ panic!("`project_notifications` should return text"); }; - let expected_content = indoc! {"[The following is an auto-generated notification; do not reply] - - These files have changed since the last read: - - code.rs - "}; - assert_eq!(notification_content, expected_content); + assert!(notification_content.contains("These files have changed since the last read:")); + assert!(notification_content.contains("code.rs")); // Insert another user message and flush notifications again thread.update(cx, |thread, cx| { @@ -3715,6 +3806,7 @@ fn main() {{ thread.update(cx, |thread, cx| { thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) }); + cx.run_until_parked(); // There should be no new notifications (we already flushed one) let notifications = thread.read_with(cx, |thread, _cx| { @@ -4169,6 +4261,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -4190,7 +4287,7 @@ fn main() {{ assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); assert_eq!( retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have default max attempts" + "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors" ); }); @@ -4242,6 +4339,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns internal server error let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); @@ -4263,7 +4365,7 @@ fn main() {{ let retry_state = thread.retry_state.as_ref().unwrap(); assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + retry_state.max_attempts, 3, "Should have correct max attempts" ); }); @@ -4279,8 +4381,9 @@ fn main() {{ if let MessageSegment::Text(text) = seg { text.contains("internal") && text.contains("Fake") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + && text.contains("Retrying") + && text.contains("attempt 1 of 3") + && text.contains("seconds") } else { false } @@ -4318,8 +4421,13 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + + // Create model that returns internal server error + let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); // Insert a user message thread.update(cx, |thread, cx| { @@ -4369,50 +4477,25 @@ fn main() {{ assert!(thread.retry_state.is_some(), "Should have retry state"); let retry_state = thread.retry_state.as_ref().unwrap(); assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, 3, + "Internal server errors should retry up to 3 times" + ); }); // Advance clock for first retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); - // Should have scheduled second retry - count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 2, "Should have scheduled second retry"); - - // Check retry state updated - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Advance clock for second retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); + // Advance clock for second retry + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); - // Should have scheduled third retry - // Count all retry messages now + // Advance clock for third retry + cx.executor().advance_clock(BASE_RETRY_DELAY); + cx.run_until_parked(); + + // Should have completed all retries - count retry messages let retry_count = thread.update(cx, |thread, _| { thread .messages @@ -4430,56 +4513,24 @@ fn main() {{ .count() }); assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have scheduled third retry" + retry_count, 3, + "Should have 3 retries for internal server errors" ); - // Check retry state updated + // For internal server errors, we retry 3 times and then give up + // Check that retry_state is cleared after all retries thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!( - retry_state.attempt, MAX_RETRY_ATTEMPTS, - "Should be at max retry attempt" - ); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after all retries" ); }); - // Advance clock for third retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); - cx.run_until_parked(); - - // No more retries should be scheduled after clock was advanced. - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should not exceed max retries" - ); - - // Final completion count should be initial + max retries + // Verify total attempts (1 initial + 3 retries) assert_eq!( *completion_count.lock(), - (MAX_RETRY_ATTEMPTS + 1) as usize, - "Should have made initial + max retry attempts" + 4, + "Should have attempted once plus 3 retries" ); } @@ -4490,6 +4541,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -4499,13 +4555,13 @@ fn main() {{ }); // Track events - let retries_failed = Arc::new(Mutex::new(false)); - let retries_failed_clone = retries_failed.clone(); + let stopped_with_error = Arc::new(Mutex::new(false)); + let stopped_with_error_clone = stopped_with_error.clone(); let _subscription = thread.update(cx, |_, cx| { cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::RetriesFailed { .. } = event { - *retries_failed_clone.lock() = true; + if let ThreadEvent::Stopped(Err(_)) = event { + *stopped_with_error_clone.lock() = true; } }) }); @@ -4517,23 +4573,11 @@ fn main() {{ cx.run_until_parked(); // Advance through all retries - for i in 0..MAX_RETRY_ATTEMPTS { - let delay = if i == 0 { - BASE_RETRY_DELAY_SECS - } else { - BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) - }; - cx.executor().advance_clock(Duration::from_secs(delay)); + for _ in 0..MAX_RETRY_ATTEMPTS { + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); } - // After the 3rd retry is scheduled, we need to wait for it to execute and fail - // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) - let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); - cx.executor() - .advance_clock(Duration::from_secs(final_delay)); - cx.run_until_parked(); - let retry_count = thread.update(cx, |thread, _| { thread .messages @@ -4551,14 +4595,14 @@ fn main() {{ .count() }); - // After max retries, should emit RetriesFailed event + // After max retries, should emit Stopped(Err(...)) event assert_eq!( retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have attempted max retries" + "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors" ); assert!( - *retries_failed.lock(), - "Should emit RetriesFailed event after max retries exceeded" + *stopped_with_error.lock(), + "Should emit Stopped(Err(...)) event after max retries exceeded" ); // Retry state should be cleared @@ -4576,7 +4620,7 @@ fn main() {{ .count(); assert_eq!( retry_messages, MAX_RETRY_ATTEMPTS as usize, - "Should have one retry message per attempt" + "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors" ); }); } @@ -4588,6 +4632,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // We'll use a wrapper to switch behavior after first failure struct RetryTestModel { inner: Arc, @@ -4714,8 +4763,7 @@ fn main() {{ }); // Wait for retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // Stream some successful content @@ -4757,6 +4805,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create a model that fails once then succeeds struct FailOnceModel { inner: Arc, @@ -4877,8 +4930,7 @@ fn main() {{ }); // Wait for retry delay - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.executor().advance_clock(BASE_RETRY_DELAY); cx.run_until_parked(); // The retry should now use our FailOnceModel which should succeed @@ -4919,6 +4971,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create a model that returns rate limit error with retry_after struct RateLimitModel { inner: Arc, @@ -5037,9 +5094,15 @@ fn main() {{ thread.read_with(cx, |thread, _| { assert!( - thread.retry_state.is_none(), - "Rate limit errors should not set retry_state" + thread.retry_state.is_some(), + "Rate limit errors should set retry_state" ); + if let Some(retry_state) = &thread.retry_state { + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Rate limit errors should use MAX_RETRY_ATTEMPTS" + ); + } }); // Verify we have one retry message @@ -5072,18 +5135,15 @@ fn main() {{ .find(|msg| msg.role == Role::System && msg.ui_only) .expect("Should have a retry message"); - // Check that the message doesn't contain attempt count + // Check that the message contains attempt count since we use retry_state if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { assert!( - !text.contains("attempt"), - "Rate limit retry message should not contain attempt count" + text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)), + "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS" ); assert!( - text.contains(&format!( - "Retrying in {} seconds", - TEST_RATE_LIMIT_RETRY_SECS - )), - "Rate limit retry message should contain retry delay" + text.contains("Retrying"), + "Rate limit retry message should contain retry text" ); } }); @@ -5189,6 +5249,79 @@ fn main() {{ ); } + #[gpui::test] + async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Ensure we're in Normal mode (not Burn mode) + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Normal); + }); + + // Track error events + let error_events = Arc::new(Mutex::new(Vec::new())); + let error_events_clone = error_events.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::ShowError(error) = event { + error_events_clone.lock().push(error.clone()); + } + }) + }); + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + // Verify no retry state was created + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Should not have retry state in Normal mode" + ); + }); + + // Check that a retryable error was reported + let errors = error_events.lock(); + assert!(!errors.is_empty(), "Should have received an error event"); + + if let ThreadError::RetryableError { + message: _, + can_enable_burn_mode, + } = &errors[0] + { + assert!( + *can_enable_burn_mode, + "Error should indicate burn mode can be enabled" + ); + } else { + panic!("Expected RetryableError, got {:?}", errors[0]); + } + + // Verify the thread is no longer generating + thread.read_with(cx, |thread, _| { + assert!( + !thread.is_generating(), + "Should not be generating after error without retry" + ); + }); + } + #[gpui::test] async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { init_test_settings(cx); @@ -5196,6 +5329,11 @@ fn main() {{ let project = create_test_project(cx, json!({})).await; let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + // Enable Burn Mode to allow retries + thread.update(cx, |thread, _| { + thread.set_completion_mode(CompletionMode::Burn); + }); + // Create model that returns overloaded error let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); @@ -5357,7 +5495,7 @@ fn main() {{ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - let provider = Arc::new(FakeLanguageModelProvider); + let provider = Arc::new(FakeLanguageModelProvider::default()); let model = provider.test_model(); let model: Arc = Arc::new(model); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 0347156cd4..cc7cb50c91 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -41,6 +41,9 @@ use std::{ }; use util::ResultExt as _; +pub static ZED_STATELESS: std::sync::LazyLock = + std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum DataType { #[serde(rename = "json")] @@ -874,7 +877,11 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = Connection::open_file(&sqlite_path.to_string_lossy()); + let connection = if *ZED_STATELESS { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else { + Connection::open_file(&sqlite_path.to_string_lossy()) + }; connection.exec(indoc! {" CREATE TABLE IF NOT EXISTS threads ( diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml new file mode 100644 index 0000000000..4371f7684d --- /dev/null +++ b/crates/agent_servers/Cargo.toml @@ -0,0 +1,55 @@ +[package] +name = "agent_servers" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[features] +test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"] +e2e = [] + +[lints] +workspace = true + +[lib] +path = "src/agent_servers.rs" +doctest = false + +[dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true +agentic-coding-protocol.workspace = true +anyhow.workspace = true +collections.workspace = true +context_server.workspace = true +futures.workspace = true +gpui.workspace = true +itertools.workspace = true +log.workspace = true +paths.workspace = true +project.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +strum.workspace = true +tempfile.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +watch.workspace = true +which.workspace = true +workspace-hack.workspace = true + +[target.'cfg(unix)'.dependencies] +libc.workspace = true +nix.workspace = true + +[dev-dependencies] +env_logger.workspace = true +language.workspace = true +indoc.workspace = true +acp_thread = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/LICENSE-GPL b/crates/agent_servers/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/agent_servers/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs new file mode 100644 index 0000000000..660f61f907 --- /dev/null +++ b/crates/agent_servers/src/agent_servers.rs @@ -0,0 +1,164 @@ +mod claude; +mod gemini; +mod settings; + +#[cfg(test)] +mod e2e_tests; + +pub use claude::*; +pub use gemini::*; +pub use settings::*; + +use acp_thread::AgentConnection; +use anyhow::Result; +use collections::HashMap; +use gpui::{App, AsyncApp, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{ + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use util::ResultExt as _; + +pub fn init(cx: &mut App) { + settings::init(cx); +} + +pub trait AgentServer: Send { + fn logo(&self) -> ui::IconName; + fn name(&self) -> &'static str; + fn empty_state_headline(&self) -> &'static str; + fn empty_state_message(&self) -> &'static str; + + fn connect( + &self, + // these will go away when old_acp is fully removed + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>>; +} + +impl std::fmt::Debug for AgentServerCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let filtered_env = self.env.as_ref().map(|env| { + env.iter() + .map(|(k, v)| { + ( + k, + if util::redact::should_redact(k) { + "[REDACTED]" + } else { + v + }, + ) + }) + .collect::>() + }); + + f.debug_struct("AgentServerCommand") + .field("path", &self.path) + .field("args", &self.args) + .field("env", &filtered_env) + .finish() + } +} + +pub enum AgentServerVersion { + Supported, + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, +} + +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] +pub struct AgentServerCommand { + #[serde(rename = "command")] + pub path: PathBuf, + #[serde(default)] + pub args: Vec, + pub env: Option>, +} + +impl AgentServerCommand { + pub(crate) async fn resolve( + path_bin_name: &'static str, + extra_args: &[&'static str], + settings: Option, + project: &Entity, + cx: &mut AsyncApp, + ) -> Option { + if let Some(agent_settings) = settings { + return Some(Self { + path: agent_settings.command.path, + args: agent_settings + .command + .args + .into_iter() + .chain(extra_args.iter().map(|arg| arg.to_string())) + .collect(), + env: agent_settings.command.env, + }); + } else { + find_bin_in_path(path_bin_name, project, cx) + .await + .map(|path| Self { + path, + args: extra_args.iter().map(|arg| arg.to_string()).collect(), + env: None, + }) + } + } +} + +async fn find_bin_in_path( + bin_name: &'static str, + project: &Entity, + cx: &mut AsyncApp, +) -> Option { + let (env_task, root_dir) = project + .update(cx, |project, cx| { + let worktree = project.visible_worktrees(cx).next(); + match worktree { + Some(worktree) => { + let env_task = project.environment().update(cx, |env, cx| { + env.get_worktree_environment(worktree.clone(), cx) + }); + + let path = worktree.read(cx).abs_path(); + (env_task, path) + } + None => { + let path: Arc = paths::home_dir().as_path().into(); + let env_task = project.environment().update(cx, |env, cx| { + env.get_directory_environment(path.clone(), cx) + }); + (env_task, path) + } + } + }) + .log_err()?; + + cx.background_executor() + .spawn(async move { + let which_result = if cfg!(windows) { + which::which(bin_name) + } else { + let env = env_task.await.unwrap_or_default(); + let shell_path = env.get("PATH").cloned(); + which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref()) + }; + + if let Err(which::Error::CannotFindBinaryPath) = which_result { + return None; + } + + which_result.log_err() + }) + .await +} diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs new file mode 100644 index 0000000000..d63d8c43cf --- /dev/null +++ b/crates/agent_servers/src/claude.rs @@ -0,0 +1,839 @@ +mod mcp_server; +pub mod tools; + +use collections::HashMap; +use context_server::listener::McpServerTool; +use project::Project; +use settings::SettingsStore; +use smol::process::Child; +use std::cell::RefCell; +use std::fmt::Display; +use std::path::Path; +use std::pin::pin; +use std::rc::Rc; +use uuid::Uuid; + +use agent_client_protocol as acp; +use anyhow::{Result, anyhow}; +use futures::channel::oneshot; +use futures::{AsyncBufReadExt, AsyncWriteExt}; +use futures::{ + AsyncRead, AsyncWrite, FutureExt, StreamExt, + channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, + io::BufReader, + select_biased, +}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; +use serde::{Deserialize, Serialize}; +use util::ResultExt; + +use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; +use crate::claude::tools::ClaudeTool; +use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; +use acp_thread::{AcpThread, AgentConnection}; + +#[derive(Clone)] +pub struct ClaudeCode; + +impl AgentServer for ClaudeCode { + fn name(&self) -> &'static str { + "Claude Code" + } + + fn empty_state_headline(&self) -> &'static str { + self.name() + } + + fn empty_state_message(&self) -> &'static str { + "" + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiClaude + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + let connection = ClaudeAgentConnection { + sessions: Default::default(), + }; + + Task::ready(Ok(Rc::new(connection) as _)) + } +} + +#[cfg(unix)] +fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { + let pid = nix::unistd::Pid::from_raw(pid); + + nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) + .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) +} + +#[cfg(windows)] +fn send_interrupt(_pid: i32) -> anyhow::Result<()> { + panic!("Cancel not implemented on Windows") +} + +struct ClaudeAgentConnection { + sessions: Rc>>, +} + +impl AgentConnection for ClaudeAgentConnection { + fn name(&self) -> &'static str { + ClaudeCode.name() + } + + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let cwd = cwd.to_owned(); + cx.spawn(async move |cx| { + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; + + let mut mcp_servers = HashMap::default(); + mcp_servers.insert( + mcp_server::SERVER_NAME.to_string(), + permission_mcp_server.server_config()?, + ); + let mcp_config = McpConfig { mcp_servers }; + + let mcp_config_file = tempfile::NamedTempFile::new()?; + let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts(); + + let mut mcp_config_file = smol::fs::File::from(mcp_config_file); + mcp_config_file + .write_all(serde_json::to_string(&mcp_config)?.as_bytes()) + .await?; + mcp_config_file.flush().await?; + + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).claude.clone() + })?; + + let Some(command) = + AgentServerCommand::resolve("claude", &[], settings, &project, cx).await + else { + anyhow::bail!("Failed to find claude binary"); + }; + + let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded::>>(); + + let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); + + log::trace!("Starting session with id: {}", session_id); + + cx.background_spawn({ + let session_id = session_id.clone(); + async move { + let mut outgoing_rx = Some(outgoing_rx); + let mut mode = ClaudeSessionMode::Start; + + loop { + let mut child = spawn_claude( + &command, + mode, + session_id.clone(), + &mcp_config_path, + &cwd, + ) + .await?; + mode = ClaudeSessionMode::Resume; + + let pid = child.id(); + log::trace!("Spawned (pid: {})", pid); + + let mut io_fut = pin!( + ClaudeAgentSession::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + child.stdin.take().unwrap(), + child.stdout.take().unwrap(), + ) + .fuse() + ); + + select_biased! { + done_tx = cancel_rx.next() => { + if let Some(done_tx) = done_tx { + log::trace!("Interrupted (pid: {})", pid); + let result = send_interrupt(pid as i32); + outgoing_rx.replace(io_fut.await?); + done_tx.send(result).log_err(); + continue; + } + } + result = io_fut => { + result?; + } + } + + log::trace!("Stopped (pid: {})", pid); + break; + } + + drop(mcp_config_path); + anyhow::Ok(()) + } + }) + .detach(); + + let end_turn_tx = Rc::new(RefCell::new(None)); + let handler_task = cx.spawn({ + let end_turn_tx = end_turn_tx.clone(); + let thread_rx = thread_rx.clone(); + async move |cx| { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentSession::handle_message( + thread_rx.clone(), + message, + end_turn_tx.clone(), + cx, + ) + .await + } + } + }); + + let thread = + cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; + + thread_tx.send(thread.downgrade())?; + + let session = ClaudeAgentSession { + outgoing_tx, + end_turn_tx, + cancel_tx, + _handler_task: handler_task, + _mcp_server: Some(permission_mcp_server), + }; + + self.sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } + + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(¶ms.session_id) else { + return Task::ready(Err(anyhow!( + "Attempted to send message to nonexistent session {}", + params.session_id + ))); + }; + + let (tx, rx) = oneshot::channel(); + session.end_turn_tx.borrow_mut().replace(tx); + + let mut content = String::new(); + for chunk in params.prompt { + match chunk { + acp::ContentBlock::Text(text_content) => { + content.push_str(&text_content.text); + } + acp::ContentBlock::ResourceLink(resource_link) => { + content.push_str(&format!("@{}", resource_link.uri)); + } + acp::ContentBlock::Audio(_) + | acp::ContentBlock::Image(_) + | acp::ContentBlock::Resource(_) => { + // TODO + } + } + } + + if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: Some(params.session_id.to_string()), + }) { + return Task::ready(Err(anyhow!(err))); + } + + cx.foreground_executor().spawn(async move { + rx.await??; + Ok(()) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(&session_id) else { + log::warn!("Attempted to cancel nonexistent session {}", session_id); + return; + }; + + let (done_tx, done_rx) = oneshot::channel(); + if session + .cancel_tx + .unbounded_send(done_tx) + .log_err() + .is_some() + { + let end_turn_tx = session.end_turn_tx.clone(); + cx.foreground_executor() + .spawn(async move { + done_rx.await??; + if let Some(end_turn_tx) = end_turn_tx.take() { + end_turn_tx.send(Ok(())).ok(); + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + } +} + +#[derive(Clone, Copy)] +enum ClaudeSessionMode { + Start, + Resume, +} + +async fn spawn_claude( + command: &AgentServerCommand, + mode: ClaudeSessionMode, + session_id: acp::SessionId, + mcp_config_path: &Path, + root_dir: &Path, +) -> Result { + let child = util::command::new_smol_command(&command.path) + .args([ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--print", + "--verbose", + "--mcp-config", + mcp_config_path.to_string_lossy().as_ref(), + "--permission-prompt-tool", + &format!( + "mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::PermissionTool::NAME, + ), + "--allowedTools", + &format!( + "mcp__{}__{},mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::EditTool::NAME, + mcp_server::SERVER_NAME, + mcp_server::ReadTool::NAME + ), + "--disallowedTools", + "Read,Edit", + ]) + .args(match mode { + ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()], + ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()], + }) + .args(command.args.iter().map(|arg| arg.as_str())) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + Ok(child) +} + +struct ClaudeAgentSession { + outgoing_tx: UnboundedSender, + end_turn_tx: Rc>>>>, + cancel_tx: UnboundedSender>>, + _mcp_server: Option, + _handler_task: Task<()>, +} + +impl ClaudeAgentSession { + async fn handle_message( + mut thread_rx: watch::Receiver>, + message: SdkMessage, + end_turn_tx: Rc>>>>, + cx: &mut AsyncApp, + ) { + match message { + SdkMessage::Assistant { + message, + session_id: _, + } + | SdkMessage::User { + message, + session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + + for chunk in message.content.chunks() { + match chunk { + ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk(text.into(), false, cx) + }) + .log_err(); + } + ContentChunk::ToolUse { id, name, input } => { + let claude_tool = ClaudeTool::infer(&name, input); + + thread + .update(cx, |thread, cx| { + if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { + thread.update_plan( + acp::Plan { + entries: params + .todos + .into_iter() + .map(Into::into) + .collect(), + }, + cx, + ) + } else { + thread.upsert_tool_call( + claude_tool.as_acp(acp::ToolCallId(id.into())), + cx, + ); + } + }) + .log_err(); + } + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallId(tool_use_id.into()), + acp::ToolCallStatus::Completed, + (!content.is_empty()).then(|| vec![content.into()]), + cx, + ) + }) + .log_err(); + } + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::Thinking + | ContentChunk::RedactedThinking + | ContentChunk::WebSearchToolResult => { + thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) + }) + .log_err(); + } + } + } + } + SdkMessage::Result { + is_error, subtype, .. + } => { + if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { + if is_error { + end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok(); + } else { + end_turn_tx.send(Ok(())).ok(); + } + } + } + SdkMessage::System { .. } => {} + } + } + + async fn handle_io( + mut outgoing_rx: UnboundedReceiver, + incoming_tx: UnboundedSender, + mut outgoing_bytes: impl Unpin + AsyncWrite, + incoming_bytes: impl Unpin + AsyncRead, + ) -> Result> { + let mut output_reader = BufReader::new(incoming_bytes); + let mut outgoing_line = Vec::new(); + let mut incoming_line = String::new(); + loop { + select_biased! { + message = outgoing_rx.next() => { + if let Some(message) = message { + outgoing_line.clear(); + serde_json::to_writer(&mut outgoing_line, &message)?; + log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line)); + outgoing_line.push(b'\n'); + outgoing_bytes.write_all(&outgoing_line).await.ok(); + } else { + break; + } + } + bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { + if bytes_read? == 0 { + break + } + log::trace!("recv: {}", &incoming_line); + match serde_json::from_str::(&incoming_line) { + Ok(message) => { + incoming_tx.unbounded_send(message).log_err(); + } + Err(error) => { + log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}"); + } + } + incoming_line.clear(); + } + } + } + + Ok(outgoing_rx) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Message { + role: Role, + content: Content, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_sequence: Option, + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +enum Content { + UntaggedText(String), + Chunks(Vec), +} + +impl Content { + pub fn chunks(self) -> impl Iterator { + match self { + Self::Chunks(chunks) => chunks.into_iter(), + Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(), + } + } +} + +impl Display for Content { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Content::UntaggedText(txt) => write!(f, "{}", txt), + Content::Chunks(chunks) => { + for chunk in chunks { + write!(f, "{}", chunk)?; + } + Ok(()) + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ContentChunk { + Text { + text: String, + }, + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + content: Content, + tool_use_id: String, + }, + // TODO + Image, + Document, + Thinking, + RedactedThinking, + WebSearchToolResult, + #[serde(untagged)] + UntaggedText(String), +} + +impl Display for ContentChunk { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContentChunk::Text { text } => write!(f, "{}", text), + ContentChunk::UntaggedText(text) => write!(f, "{}", text), + ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::Thinking + | ContentChunk::RedactedThinking + | ContentChunk::ToolUse { .. } + | ContentChunk::WebSearchToolResult => { + write!(f, "\n{:?}\n", &self) + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Usage { + input_tokens: u32, + cache_creation_input_tokens: u32, + cache_read_input_tokens: u32, + output_tokens: u32, + service_tier: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum Role { + System, + Assistant, + User, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MessageParam { + role: Role, + content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum SdkMessage { + // An assistant message + Assistant { + message: Message, // from Anthropic SDK + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + }, + + // A user message + User { + message: Message, // from Anthropic SDK + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + }, + + // Emitted as the last message in a conversation + Result { + subtype: ResultErrorType, + duration_ms: f64, + duration_api_ms: f64, + is_error: bool, + num_turns: i32, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + session_id: String, + total_cost_usd: f64, + }, + // Emitted as the first message at the start of a conversation + System { + cwd: String, + session_id: String, + tools: Vec, + model: String, + mcp_servers: Vec, + #[serde(rename = "apiKeySource")] + api_key_source: String, + #[serde(rename = "permissionMode")] + permission_mode: PermissionMode, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ResultErrorType { + Success, + ErrorMaxTurns, + ErrorDuringExecution, +} + +impl Display for ResultErrorType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResultErrorType::Success => write!(f, "success"), + ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"), + ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct McpServer { + name: String, + status: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +enum PermissionMode { + Default, + AcceptEdits, + BypassPermissions, + Plan, +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use serde_json::json; + + crate::common_e2e_tests!(ClaudeCode); + + pub fn local_command() -> AgentServerCommand { + AgentServerCommand { + path: "claude".into(), + args: vec![], + env: None, + } + } + + #[test] + fn test_deserialize_content_untagged_text() { + let json = json!("Hello, world!"); + let content: Content = serde_json::from_value(json).unwrap(); + match content { + Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"), + _ => panic!("Expected UntaggedText variant"), + } + } + + #[test] + fn test_deserialize_content_chunks() { + let json = json!([ + { + "type": "text", + "text": "Hello" + }, + { + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": {"operation": "add", "a": 1, "b": 2} + } + ]); + let content: Content = serde_json::from_value(json).unwrap(); + match content { + Content::Chunks(chunks) => { + assert_eq!(chunks.len(), 2); + match &chunks[0] { + ContentChunk::Text { text } => assert_eq!(text, "Hello"), + _ => panic!("Expected Text chunk"), + } + match &chunks[1] { + ContentChunk::ToolUse { id, name, input } => { + assert_eq!(id, "tool_123"); + assert_eq!(name, "calculator"); + assert_eq!(input["operation"], "add"); + assert_eq!(input["a"], 1); + assert_eq!(input["b"], 2); + } + _ => panic!("Expected ToolUse chunk"), + } + } + _ => panic!("Expected Chunks variant"), + } + } + + #[test] + fn test_deserialize_tool_result_untagged_text() { + let json = json!({ + "type": "tool_result", + "content": "Result content", + "tool_use_id": "tool_456" + }); + let chunk: ContentChunk = serde_json::from_value(json).unwrap(); + match chunk { + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + match content { + Content::UntaggedText(text) => assert_eq!(text, "Result content"), + _ => panic!("Expected UntaggedText content"), + } + assert_eq!(tool_use_id, "tool_456"); + } + _ => panic!("Expected ToolResult variant"), + } + } + + #[test] + fn test_deserialize_tool_result_chunks() { + let json = json!({ + "type": "tool_result", + "content": [ + { + "type": "text", + "text": "Processing complete" + }, + { + "type": "text", + "text": "Result: 42" + } + ], + "tool_use_id": "tool_789" + }); + let chunk: ContentChunk = serde_json::from_value(json).unwrap(); + match chunk { + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + match content { + Content::Chunks(chunks) => { + assert_eq!(chunks.len(), 2); + match &chunks[0] { + ContentChunk::Text { text } => assert_eq!(text, "Processing complete"), + _ => panic!("Expected Text chunk"), + } + match &chunks[1] { + ContentChunk::Text { text } => assert_eq!(text, "Result: 42"), + _ => panic!("Expected Text chunk"), + } + } + _ => panic!("Expected Chunks content"), + } + assert_eq!(tool_use_id, "tool_789"); + } + _ => panic!("Expected ToolResult variant"), + } + } +} diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs new file mode 100644 index 0000000000..a320a6d37f --- /dev/null +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -0,0 +1,297 @@ +use std::path::PathBuf; + +use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; +use acp_thread::AcpThread; +use agent_client_protocol as acp; +use anyhow::{Context, Result}; +use collections::HashMap; +use context_server::listener::{McpServerTool, ToolResponse}; +use context_server::types::{ + Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, + ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, +}; +use gpui::{App, AsyncApp, Task, WeakEntity}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +pub struct ClaudeZedMcpServer { + server: context_server::listener::McpServer, +} + +pub const SERVER_NAME: &str = "zed"; + +impl ClaudeZedMcpServer { + pub async fn new( + thread_rx: watch::Receiver>, + cx: &AsyncApp, + ) -> Result { + let mut mcp_server = context_server::listener::McpServer::new(cx).await?; + mcp_server.handle_request::(Self::handle_initialize); + + mcp_server.add_tool(PermissionTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(ReadTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(EditTool { + thread_rx: thread_rx.clone(), + }); + + Ok(Self { server: mcp_server }) + } + + pub fn server_config(&self) -> Result { + let zed_path = std::env::current_exe() + .context("finding current executable path for use in mcp_server")?; + + Ok(McpServerConfig { + command: zed_path, + args: vec![ + "--nc".into(), + self.server.socket_path().display().to_string(), + ], + env: None, + }) + } + + fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { + cx.foreground_executor().spawn(async move { + Ok(InitializeResponse { + protocol_version: ProtocolVersion("2025-06-18".into()), + capabilities: ServerCapabilities { + experimental: None, + logging: None, + completions: None, + prompts: None, + resources: None, + tools: Some(ToolsCapabilities { + list_changed: Some(false), + }), + }, + server_info: Implementation { + name: SERVER_NAME.into(), + version: "0.1.0".into(), + }, + meta: None, + }) + }) + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfig { + pub mcp_servers: HashMap, +} + +#[derive(Serialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfig { + pub command: PathBuf, + pub args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, +} + +// Tools + +#[derive(Clone)] +pub struct PermissionTool { + thread_rx: watch::Receiver>, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl McpServerTool for PermissionTool { + type Input = PermissionToolParams; + type Output = (); + + const NAME: &'static str = "Confirmation"; + + fn description(&self) -> &'static str { + "Request permission for tool calls" + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); + let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); + let allow_option_id = acp::PermissionOptionId("allow".into()); + let reject_option_id = acp::PermissionOptionId("reject".into()); + + let chosen_option = thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission( + claude_tool.as_acp(tool_call_id), + vec![ + acp::PermissionOption { + id: allow_option_id.clone(), + label: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: reject_option_id.clone(), + label: "Reject".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + cx, + ) + })? + .await?; + + let response = if chosen_option == allow_option_id { + PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, + } + } else { + PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: input.input, + } + }; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&response)?, + }], + structured_content: (), + }) + } +} + +#[derive(Clone)] +pub struct ReadTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for ReadTool { + type Input = ReadToolParams; + type Output = (); + + const NAME: &'static str = "Read"; + + fn description(&self) -> &'static str { + "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." + } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: None, + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { text: content }], + structured_content: (), + }) + } +} + +#[derive(Clone)] +pub struct EditTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for EditTool { + type Input = EditToolParams; + type Output = (); + + const NAME: &'static str = "Edit"; + + fn description(&self) -> &'static str { + "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." + } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path.clone(), None, None, true, cx) + })? + .await?; + + let new_content = content.replace(&input.old_text, &input.new_text); + if new_content == content { + return Err(anyhow::anyhow!("The old_text was not found in the content")); + } + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.abs_path, new_content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: (), + }) + } +} diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs new file mode 100644 index 0000000000..ed25f9af7f --- /dev/null +++ b/crates/agent_servers/src/claude/tools.rs @@ -0,0 +1,674 @@ +use std::path::PathBuf; + +use agent_client_protocol as acp; +use itertools::Itertools; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use util::ResultExt; + +pub enum ClaudeTool { + Task(Option), + NotebookRead(Option), + NotebookEdit(Option), + Edit(Option), + MultiEdit(Option), + ReadFile(Option), + Write(Option), + Ls(Option), + Glob(Option), + Grep(Option), + Terminal(Option), + WebFetch(Option), + WebSearch(Option), + TodoWrite(Option), + ExitPlanMode(Option), + Other { + name: String, + input: serde_json::Value, + }, +} + +impl ClaudeTool { + pub fn infer(tool_name: &str, input: serde_json::Value) -> Self { + match tool_name { + // Known tools + "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()), + "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()), + "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()), + "Write" => Self::Write(serde_json::from_value(input).log_err()), + "LS" => Self::Ls(serde_json::from_value(input).log_err()), + "Glob" => Self::Glob(serde_json::from_value(input).log_err()), + "Grep" => Self::Grep(serde_json::from_value(input).log_err()), + "Bash" => Self::Terminal(serde_json::from_value(input).log_err()), + "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()), + "WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()), + "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()), + "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()), + "Task" => Self::Task(serde_json::from_value(input).log_err()), + "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()), + "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()), + // Inferred from name + _ => { + let tool_name = tool_name.to_lowercase(); + + if tool_name.contains("edit") || tool_name.contains("write") { + Self::Edit(None) + } else if tool_name.contains("terminal") { + Self::Terminal(None) + } else { + Self::Other { + name: tool_name.to_string(), + input, + } + } + } + } + } + + pub fn label(&self) -> String { + match &self { + Self::Task(Some(params)) => params.description.clone(), + Self::Task(None) => "Task".into(), + Self::NotebookRead(Some(params)) => { + format!("Read Notebook {}", params.notebook_path.display()) + } + Self::NotebookRead(None) => "Read Notebook".into(), + Self::NotebookEdit(Some(params)) => { + format!("Edit Notebook {}", params.notebook_path.display()) + } + Self::NotebookEdit(None) => "Edit Notebook".into(), + Self::Terminal(Some(params)) => format!("`{}`", params.command), + Self::Terminal(None) => "Terminal".into(), + Self::ReadFile(_) => "Read File".into(), + Self::Ls(Some(params)) => { + format!("List Directory {}", params.path.display()) + } + Self::Ls(None) => "List Directory".into(), + Self::Edit(Some(params)) => { + format!("Edit {}", params.abs_path.display()) + } + Self::Edit(None) => "Edit".into(), + Self::MultiEdit(Some(params)) => { + format!("Multi Edit {}", params.file_path.display()) + } + Self::MultiEdit(None) => "Multi Edit".into(), + Self::Write(Some(params)) => { + format!("Write {}", params.file_path.display()) + } + Self::Write(None) => "Write".into(), + Self::Glob(Some(params)) => { + format!("Glob `{params}`") + } + Self::Glob(None) => "Glob".into(), + Self::Grep(Some(params)) => format!("`{params}`"), + Self::Grep(None) => "Grep".into(), + Self::WebFetch(Some(params)) => format!("Fetch {}", params.url), + Self::WebFetch(None) => "Fetch".into(), + Self::WebSearch(Some(params)) => format!("Web Search: {}", params), + Self::WebSearch(None) => "Web Search".into(), + Self::TodoWrite(Some(params)) => format!( + "Update TODOs: {}", + params.todos.iter().map(|todo| &todo.content).join(", ") + ), + Self::TodoWrite(None) => "Update TODOs".into(), + Self::ExitPlanMode(_) => "Exit Plan Mode".into(), + Self::Other { name, .. } => name.clone(), + } + } + pub fn content(&self) -> Vec { + match &self { + Self::Other { input, .. } => vec![ + format!( + "```json\n{}```", + serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) + ) + .into(), + ], + Self::Task(Some(params)) => vec![params.prompt.clone().into()], + Self::NotebookRead(Some(params)) => { + vec![params.notebook_path.display().to_string().into()] + } + Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], + Self::Terminal(Some(params)) => vec![ + format!( + "`{}`\n\n{}", + params.command, + params.description.as_deref().unwrap_or_default() + ) + .into(), + ], + Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], + Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], + Self::Glob(Some(params)) => vec![params.to_string().into()], + Self::Grep(Some(params)) => vec![format!("`{params}`").into()], + Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], + Self::WebSearch(Some(params)) => vec![params.to_string().into()], + Self::TodoWrite(Some(params)) => vec![ + params + .todos + .iter() + .map(|todo| { + format!( + "- {} {}: {}", + match todo.status { + TodoStatus::Completed => "βœ…", + TodoStatus::InProgress => "🚧", + TodoStatus::Pending => "⬜", + }, + todo.priority, + todo.content + ) + }) + .join("\n") + .into(), + ], + Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], + Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.abs_path.clone(), + old_text: Some(params.old_text.clone()), + new_text: params.new_text.clone(), + }, + }], + Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: None, + new_text: params.content.clone(), + }, + }], + Self::MultiEdit(Some(params)) => { + // todo: show multiple edits in a multibuffer? + params + .edits + .first() + .map(|edit| { + vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }] + }) + .unwrap_or_default() + } + Self::Task(None) + | Self::NotebookRead(None) + | Self::NotebookEdit(None) + | Self::Terminal(None) + | Self::ReadFile(None) + | Self::Ls(None) + | Self::Glob(None) + | Self::Grep(None) + | Self::WebFetch(None) + | Self::WebSearch(None) + | Self::TodoWrite(None) + | Self::ExitPlanMode(None) + | Self::Edit(None) + | Self::Write(None) + | Self::MultiEdit(None) => vec![], + } + } + + pub fn kind(&self) -> acp::ToolKind { + match self { + Self::Task(_) => acp::ToolKind::Think, + Self::NotebookRead(_) => acp::ToolKind::Read, + Self::NotebookEdit(_) => acp::ToolKind::Edit, + Self::Edit(_) => acp::ToolKind::Edit, + Self::MultiEdit(_) => acp::ToolKind::Edit, + Self::Write(_) => acp::ToolKind::Edit, + Self::ReadFile(_) => acp::ToolKind::Read, + Self::Ls(_) => acp::ToolKind::Search, + Self::Glob(_) => acp::ToolKind::Search, + Self::Grep(_) => acp::ToolKind::Search, + Self::Terminal(_) => acp::ToolKind::Execute, + Self::WebSearch(_) => acp::ToolKind::Search, + Self::WebFetch(_) => acp::ToolKind::Fetch, + Self::TodoWrite(_) => acp::ToolKind::Think, + Self::ExitPlanMode(_) => acp::ToolKind::Think, + Self::Other { .. } => acp::ToolKind::Other, + } + } + + pub fn locations(&self) -> Vec { + match &self { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { + path: abs_path.clone(), + line: None, + }], + Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { + vec![acp::ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::Write(Some(WriteToolParams { file_path, .. })) => { + vec![acp::ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::ReadFile(Some(ReadToolParams { + abs_path, offset, .. + })) => vec![acp::ToolCallLocation { + path: abs_path.clone(), + line: *offset, + }], + Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { + vec![acp::ToolCallLocation { + path: notebook_path.clone(), + line: None, + }] + } + Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { + vec![acp::ToolCallLocation { + path: notebook_path.clone(), + line: None, + }] + } + Self::Glob(Some(GlobToolParams { + path: Some(path), .. + })) => vec![acp::ToolCallLocation { + path: path.clone(), + line: None, + }], + Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { + path: path.clone(), + line: None, + }], + Self::Grep(Some(GrepToolParams { + path: Some(path), .. + })) => vec![acp::ToolCallLocation { + path: PathBuf::from(path), + line: None, + }], + Self::Task(_) + | Self::NotebookRead(None) + | Self::NotebookEdit(None) + | Self::Edit(None) + | Self::MultiEdit(None) + | Self::Write(None) + | Self::ReadFile(None) + | Self::Ls(None) + | Self::Glob(_) + | Self::Grep(_) + | Self::Terminal(_) + | Self::WebFetch(_) + | Self::WebSearch(_) + | Self::TodoWrite(_) + | Self::ExitPlanMode(_) + | Self::Other { .. } => vec![], + } + } + + pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { + acp::ToolCall { + id, + kind: self.kind(), + status: acp::ToolCallStatus::InProgress, + label: self.label(), + content: self.content(), + locations: self.locations(), + } + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct EditToolParams { + /// The absolute path to the file to read. + pub abs_path: PathBuf, + /// The old text to replace (must be unique in the file) + pub old_text: String, + /// The new text. + pub new_text: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct ReadToolParams { + /// The absolute path to the file to read. + pub abs_path: PathBuf, + /// Which line to start reading from. Omit to start from the beginning. + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option, + /// How many lines to read. Omit for the whole file. + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WriteToolParams { + /// Absolute path for new file + pub file_path: PathBuf, + /// File content + pub content: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct BashToolParams { + /// Shell command to execute + pub command: String, + /// 5-10 word description of what command does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Timeout in ms (max 600000ms/10min, default 120000ms) + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct GlobToolParams { + /// Glob pattern like **/*.js or src/**/*.ts + pub pattern: String, + /// Directory to search in (omit for current directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +impl std::fmt::Display for GlobToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(path) = &self.path { + write!(f, "{}", path.display())?; + } + write!(f, "{}", self.pattern) + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct LsToolParams { + /// Absolute path to directory + pub path: PathBuf, + /// Array of glob patterns to ignore + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub ignore: Vec, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct GrepToolParams { + /// Regex pattern to search for + pub pattern: String, + /// File/directory to search (defaults to current directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// "content" (shows lines), "files_with_matches" (default), "count" + #[serde(skip_serializing_if = "Option::is_none")] + pub output_mode: Option, + /// Filter files with glob pattern like "*.js" + #[serde(skip_serializing_if = "Option::is_none")] + pub glob: Option, + /// File type filter like "js", "py", "rust" + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub file_type: Option, + /// Case insensitive search + #[serde(rename = "-i", default, skip_serializing_if = "is_false")] + pub case_insensitive: bool, + /// Show line numbers (content mode only) + #[serde(rename = "-n", default, skip_serializing_if = "is_false")] + pub line_numbers: bool, + /// Lines after match (content mode only) + #[serde(rename = "-A", skip_serializing_if = "Option::is_none")] + pub after_context: Option, + /// Lines before match (content mode only) + #[serde(rename = "-B", skip_serializing_if = "Option::is_none")] + pub before_context: Option, + /// Lines before and after match (content mode only) + #[serde(rename = "-C", skip_serializing_if = "Option::is_none")] + pub context: Option, + /// Enable multiline/cross-line matching + #[serde(default, skip_serializing_if = "is_false")] + pub multiline: bool, + /// Limit output to first N results + #[serde(skip_serializing_if = "Option::is_none")] + pub head_limit: Option, +} + +impl std::fmt::Display for GrepToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "grep")?; + + // Boolean flags + if self.case_insensitive { + write!(f, " -i")?; + } + if self.line_numbers { + write!(f, " -n")?; + } + + // Context options + if let Some(after) = self.after_context { + write!(f, " -A {}", after)?; + } + if let Some(before) = self.before_context { + write!(f, " -B {}", before)?; + } + if let Some(context) = self.context { + write!(f, " -C {}", context)?; + } + + // Output mode + if let Some(mode) = &self.output_mode { + match mode { + GrepOutputMode::FilesWithMatches => write!(f, " -l")?, + GrepOutputMode::Count => write!(f, " -c")?, + GrepOutputMode::Content => {} // Default mode + } + } + + // Head limit + if let Some(limit) = self.head_limit { + write!(f, " | head -{}", limit)?; + } + + // Glob pattern + if let Some(glob) = &self.glob { + write!(f, " --include=\"{}\"", glob)?; + } + + // File type + if let Some(file_type) = &self.file_type { + write!(f, " --type={}", file_type)?; + } + + // Multiline + if self.multiline { + write!(f, " -P")?; // Perl-compatible regex for multiline + } + + // Pattern (escaped if contains special characters) + write!(f, " \"{}\"", self.pattern)?; + + // Path + if let Some(path) = &self.path { + write!(f, " {}", path)?; + } + + Ok(()) + } +} + +#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)] +#[serde(rename_all = "snake_case")] +pub enum TodoPriority { + High, + Medium, + Low, +} + +impl Into for TodoPriority { + fn into(self) -> acp::PlanEntryPriority { + match self { + TodoPriority::High => acp::PlanEntryPriority::High, + TodoPriority::Medium => acp::PlanEntryPriority::Medium, + TodoPriority::Low => acp::PlanEntryPriority::Low, + } + } +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum TodoStatus { + Pending, + InProgress, + Completed, +} + +impl Into for TodoStatus { + fn into(self) -> acp::PlanEntryStatus { + match self { + TodoStatus::Pending => acp::PlanEntryStatus::Pending, + TodoStatus::InProgress => acp::PlanEntryStatus::InProgress, + TodoStatus::Completed => acp::PlanEntryStatus::Completed, + } + } +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct Todo { + /// Unique identifier + pub id: String, + /// Task description + pub content: String, + /// Priority level of the todo + pub priority: TodoPriority, + /// Current status of the todo + pub status: TodoStatus, +} + +impl Into for Todo { + fn into(self) -> acp::PlanEntry { + acp::PlanEntry { + content: self.content, + priority: self.priority.into(), + status: self.status.into(), + } + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct TodoWriteToolParams { + pub todos: Vec, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct ExitPlanModeToolParams { + /// Implementation plan in markdown format + pub plan: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct TaskToolParams { + /// Short 3-5 word description of task + pub description: String, + /// Detailed task for agent to perform + pub prompt: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct NotebookReadToolParams { + /// Absolute path to .ipynb file + pub notebook_path: PathBuf, + /// Specific cell ID to read + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_id: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum CellType { + Code, + Markdown, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum EditMode { + Replace, + Insert, + Delete, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct NotebookEditToolParams { + /// Absolute path to .ipynb file + pub notebook_path: PathBuf, + /// New cell content + pub new_source: String, + /// Cell ID to edit + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_id: Option, + /// Type of cell (code or markdown) + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_type: Option, + /// Edit operation mode + #[serde(skip_serializing_if = "Option::is_none")] + pub edit_mode: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct MultiEditItem { + /// The text to search for and replace + pub old_string: String, + /// The replacement text + pub new_string: String, + /// Whether to replace all occurrences or just the first + #[serde(default, skip_serializing_if = "is_false")] + pub replace_all: bool, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct MultiEditToolParams { + /// Absolute path to file + pub file_path: PathBuf, + /// List of edits to apply + pub edits: Vec, +} + +fn is_false(v: &bool) -> bool { + !*v +} + +#[derive(Deserialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum GrepOutputMode { + Content, + FilesWithMatches, + Count, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WebFetchToolParams { + /// Valid URL to fetch + #[serde(rename = "url")] + pub url: String, + /// What to extract from content + pub prompt: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WebSearchToolParams { + /// Search query (min 2 chars) + pub query: String, + /// Only include these domains + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub allowed_domains: Vec, + /// Exclude these domains + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub blocked_domains: Vec, +} + +impl std::fmt::Display for WebSearchToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "\"{}\"", self.query)?; + + if !self.allowed_domains.is_empty() { + write!(f, " (allowed: {})", self.allowed_domains.join(", "))?; + } + + if !self.blocked_domains.is_empty() { + write!(f, " (blocked: {})", self.blocked_domains.join(", "))?; + } + + Ok(()) + } +} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs new file mode 100644 index 0000000000..9bc6fd60fe --- /dev/null +++ b/crates/agent_servers/src/e2e_tests.rs @@ -0,0 +1,411 @@ +use std::{path::Path, sync::Arc, time::Duration}; + +use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; +use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; +use agent_client_protocol as acp; + +use futures::{FutureExt, StreamExt, channel::mpsc, select}; +use gpui::{Entity, TestAppContext}; +use indoc::indoc; +use project::{FakeFs, Project}; +use serde_json::json; +use settings::{Settings, SettingsStore}; +use util::path; + +pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!( + thread.entries()[1], + AgentThreadEntry::AssistantMessage(_) + )); + }); +} + +pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let _fs = init_test(cx).await; + + let tempdir = tempfile::tempdir().unwrap(); + std::fs::write( + tempdir.path().join("foo.rs"), + indoc! {" + fn main() { + println!(\"Hello, world!\"); + } + "}, + ) + .expect("failed to write file"); + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; + let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await; + thread + .update(cx, |thread, cx| { + thread.send( + vec![ + acp::ContentBlock::Text(acp::TextContent { + text: "Read the file ".into(), + annotations: None, + }), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: "foo.rs".into(), + name: "foo.rs".into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + acp::ContentBlock::Text(acp::TextContent { + text: " and tell me what the content of the println! is".into(), + annotations: None, + }), + ], + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, cx| { + assert_eq!(thread.entries().len(), 3); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); + let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { + panic!("Expected AssistantMessage") + }; + assert!( + assistant_message.to_markdown(cx).contains("Hello, world!"), + "unexpected assistant message: {:?}", + assistant_message.to_markdown(cx) + ); + }); +} + +pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + fs.insert_tree( + path!("/private/tmp"), + json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), + ) + .await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Read the '/private/tmp/foo' file and tell me what you see.", + cx, + ) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _cx| { + assert!(thread.entries().iter().any(|entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + ) + })); + assert!( + thread + .entries() + .iter() + .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) + ); + }); +} + +pub async fn test_tool_call_with_confirmation( + server: impl AgentServer + 'static, + cx: &mut TestAppContext, +) { + let fs = init_test(cx).await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw( + r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#, + cx, + ) + }); + + run_until_first_tool_call( + &thread, + |entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) + ) + }, + cx, + ) + .await; + + let tool_call_id = thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + content, + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) = &thread + .entries() + .iter() + .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_))) + .unwrap() + else { + panic!(); + }; + + assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch"))); + + id.clone() + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id, + acp::PermissionOptionId("0".into()), + acp::PermissionOptionKind::AllowOnce, + cx, + ); + + assert!(thread.entries().iter().any(|entry| matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + ))); + }); + + full_turn.await.unwrap(); + + thread.read_with(cx, |thread, cx| { + let AgentThreadEntry::ToolCall(ToolCall { + content, + status: ToolCallStatus::Allowed { .. }, + .. + }) = thread + .entries() + .iter() + .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_))) + .unwrap() + else { + panic!(); + }; + + assert!( + content.iter().any(|c| c.to_markdown(cx).contains("Hello")), + "Expected content to contain 'Hello'" + ); + }); +} + +pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw( + r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#, + cx, + ) + }); + + let first_tool_call_ix = run_until_first_tool_call( + &thread, + |entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) + ) + }, + cx, + ) + .await; + + thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + content, + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!("{:?}", thread.entries()[1]); + }; + + assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch"))); + + id.clone() + }); + + let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); + full_turn.await.unwrap(); + thread.read_with(cx, |thread, _| { + let AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!(); + }; + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw(r#"Stop running and say goodbye to me."#, cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _| { + assert!(matches!( + &thread.entries().last().unwrap(), + AgentThreadEntry::AssistantMessage(..), + )) + }); +} + +#[macro_export] +macro_rules! common_e2e_tests { + ($server:expr) => { + mod common_e2e { + use super::*; + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn basic(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_basic($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn path_mentions(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_path_mentions($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn tool_call(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn cancel(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_cancel($server, cx).await; + } + } + }; +} + +// Helpers + +pub async fn init_test(cx: &mut TestAppContext) -> Arc { + env_logger::try_init().ok(); + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + crate::settings::init(cx); + + crate::AllAgentServersSettings::override_global( + AllAgentServersSettings { + claude: Some(AgentServerSettings { + command: crate::claude::tests::local_command(), + }), + gemini: Some(AgentServerSettings { + command: crate::gemini::tests::local_command(), + }), + }, + cx, + ); + }); + + cx.executor().allow_parking(); + + FakeFs::new(cx.executor()) +} + +pub async fn new_test_thread( + server: impl AgentServer + 'static, + project: Entity, + current_dir: impl AsRef, + cx: &mut TestAppContext, +) -> Entity { + let connection = cx + .update(|cx| server.connect(current_dir.as_ref(), &project, cx)) + .await + .unwrap(); + + let thread = connection + .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) + .await + .unwrap(); + + thread +} + +pub async fn run_until_first_tool_call( + thread: &Entity, + wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static, + cx: &mut TestAppContext, +) -> usize { + let (mut tx, mut rx) = mpsc::channel::(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + for (ix, entry) in thread.read(cx).entries().iter().enumerate() { + if wait_until(entry) { + return tx.try_send(ix).unwrap(); + } + } + }) + }); + + select! { + // We have to use a smol timer here because + // cx.background_executor().timer isn't real in the test context + _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => { + panic!("Timeout waiting for tool call") + } + ix = rx.next().fuse() => { + drop(subscription); + ix.unwrap() + } + } +} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs new file mode 100644 index 0000000000..47b965cdad --- /dev/null +++ b/crates/agent_servers/src/gemini.rs @@ -0,0 +1,205 @@ +use anyhow::anyhow; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; +use util::ResultExt as _; + +use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; +use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; +use agentic_coding_protocol as acp_old; +use anyhow::{Context as _, Result}; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use project::Project; +use settings::SettingsStore; +use ui::App; + +use crate::AllAgentServersSettings; + +#[derive(Clone)] +pub struct Gemini; + +const ACP_ARG: &str = "--experimental-acp"; + +impl AgentServer for Gemini { + fn name(&self) -> &'static str { + "Gemini" + } + + fn empty_state_headline(&self) -> &'static str { + "Welcome to Gemini" + } + + fn empty_state_message(&self) -> &'static str { + "Ask questions, edit files, run commands.\nBe specific for the best results." + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiGemini + } + + fn connect( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let root_dir = root_dir.to_path_buf(); + let project = project.clone(); + let this = self.clone(); + let name = self.name(); + + cx.spawn(async move |cx| { + let command = this.command(&project, cx).await?; + + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) + } + } + }; + drop(io_task); + result + }); + + let connection: Rc = Rc::new(OldAcpAgentConnection { + name, + connection, + child_status, + }); + + Ok(connection) + }) + } +} + +impl Gemini { + async fn command( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; + + if let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + { + return Ok(command); + }; + + let (fs, node_runtime) = project.update(cx, |project, _| { + (project.fs().clone(), project.node_runtime().cloned()) + })?; + let node_runtime = node_runtime.context("gemini not found on path")?; + + let directory = ::paths::agent_servers_dir().join("gemini"); + fs.create_dir(&directory).await?; + node_runtime + .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) + .await?; + let path = directory.join("node_modules/.bin/gemini"); + + Ok(AgentServerCommand { + path, + args: vec![ACP_ARG.into()], + env: None, + }) + } + + async fn version(&self, command: &AgentServerCommand) -> Result { + let version_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--version") + .kill_on_drop(true) + .output(); + + let help_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--help") + .kill_on_drop(true) + .output(); + + let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; + + let current_version = String::from_utf8(version_output?.stdout)?; + let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); + + if supported { + Ok(AgentServerVersion::Supported) + } else { + Ok(AgentServerVersion::Unsupported { + error_message: format!( + "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", + current_version + ).into(), + upgrade_message: "Upgrade Gemini to Latest".into(), + upgrade_command: "npm install -g @google/gemini-cli@latest".into(), + }) + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::AgentServerCommand; + use std::path::Path; + + crate::common_e2e_tests!(Gemini); + + pub fn local_command() -> AgentServerCommand { + let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../../gemini-cli/packages/cli") + .to_string_lossy() + .to_string(); + + AgentServerCommand { + path: "node".into(), + args: vec![cli_path, ACP_ARG.into()], + env: None, + } + } +} diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs new file mode 100644 index 0000000000..645674b5f1 --- /dev/null +++ b/crates/agent_servers/src/settings.rs @@ -0,0 +1,45 @@ +use crate::AgentServerCommand; +use anyhow::Result; +use gpui::App; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsSources}; + +pub fn init(cx: &mut App) { + AllAgentServersSettings::register(cx); +} + +#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)] +pub struct AllAgentServersSettings { + pub gemini: Option, + pub claude: Option, +} + +#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] +pub struct AgentServerSettings { + #[serde(flatten)] + pub command: AgentServerCommand, +} + +impl settings::Settings for AllAgentServersSettings { + const KEY: Option<&'static str> = Some("agent_servers"); + + type FileContent = Self; + + fn load(sources: SettingsSources, _: &mut App) -> Result { + let mut settings = AllAgentServersSettings::default(); + + for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() { + if gemini.is_some() { + settings.gemini = gemini.clone(); + } + if claude.is_some() { + settings.claude = claude.clone(); + } + } + + Ok(settings) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 131cd2dc3f..13b966608c 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -69,6 +69,7 @@ pub struct AgentSettings { pub enable_feedback: bool, pub expand_edit_card: bool, pub expand_terminal_card: bool, + pub use_modifier_to_send: bool, } impl AgentSettings { @@ -174,6 +175,10 @@ impl AgentSettingsContent { self.single_file_review = Some(allow); } + pub fn set_use_modifier_to_send(&mut self, always_use: bool) { + self.use_modifier_to_send = Some(always_use); + } + pub fn set_profile(&mut self, profile_id: AgentProfileId) { self.default_profile = Some(profile_id); } @@ -301,6 +306,10 @@ pub struct AgentSettingsContent { /// /// Default: true expand_terminal_card: Option, + /// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel. + /// + /// Default: false + use_modifier_to_send: Option, } #[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)] @@ -456,6 +465,10 @@ impl Settings for AgentSettings { &mut settings.expand_terminal_card, value.expand_terminal_card, ); + merge( + &mut settings.use_modifier_to_send, + value.use_modifier_to_send, + ); settings .model_parameters diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 070e8eb585..fbd53e8d09 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -13,14 +13,15 @@ path = "src/agent_ui.rs" doctest = false [features] -test-support = [ - "gpui/test-support", - "language/test-support", -] +test-support = ["gpui/test-support", "language/test-support"] [dependencies] +acp_thread.workspace = true +agent-client-protocol.workspace = true agent.workspace = true +agent_servers.workspace = true agent_settings.workspace = true +ai_onboarding.workspace = true anyhow.workspace = true assistant_context.workspace = true assistant_slash_command.workspace = true @@ -31,6 +32,7 @@ buffer_diff.workspace = true chrono.workspace = true client.workspace = true collections.workspace = true +command_palette_hooks.workspace = true component.workspace = true context_server.workspace = true db.workspace = true @@ -52,6 +54,7 @@ itertools.workspace = true jsonschema.workspace = true language.workspace = true language_model.workspace = true +language_models.workspace = true log.workspace = true lsp.workspace = true markdown.workspace = true @@ -76,6 +79,7 @@ serde_json_lenient.workspace = true settings.workspace = true smol.workspace = true streaming_diff.workspace = true +task.workspace = true telemetry.workspace = true telemetry_events.workspace = true terminal.workspace = true @@ -85,6 +89,7 @@ theme.workspace = true time.workspace = true time_format.workspace = true ui.workspace = true +ui_input.workspace = true urlencoding.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs new file mode 100644 index 0000000000..cc476b1a86 --- /dev/null +++ b/crates/agent_ui/src/acp.rs @@ -0,0 +1,6 @@ +mod completion_provider; +mod message_history; +mod thread_view; + +pub use message_history::MessageHistory; +pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs new file mode 100644 index 0000000000..fca4ae0300 --- /dev/null +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -0,0 +1,574 @@ +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; + +use anyhow::Result; +use collections::HashMap; +use editor::display_map::CreaseId; +use editor::{CompletionProvider, Editor, ExcerptId}; +use file_icons::FileIcons; +use gpui::{App, Entity, Task, WeakEntity}; +use language::{Buffer, CodeLabel, HighlightId}; +use lsp::CompletionContext; +use parking_lot::Mutex; +use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId}; +use rope::Point; +use text::{Anchor, ToPoint}; +use ui::prelude::*; +use workspace::Workspace; + +use crate::context_picker::MentionLink; +use crate::context_picker::file_context_picker::{extract_file_name_and_directory, search_files}; + +#[derive(Default)] +pub struct MentionSet { + paths_by_crease_id: HashMap, +} + +impl MentionSet { + pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) { + self.paths_by_crease_id.insert(crease_id, path); + } + + pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option { + self.paths_by_crease_id.get(&crease_id).cloned() + } + + pub fn drain(&mut self) -> impl Iterator { + self.paths_by_crease_id.drain().map(|(id, _)| id) + } +} + +pub struct ContextPickerCompletionProvider { + workspace: WeakEntity, + editor: WeakEntity, + mention_set: Arc>, +} + +impl ContextPickerCompletionProvider { + pub fn new( + mention_set: Arc>, + workspace: WeakEntity, + editor: WeakEntity, + ) -> Self { + Self { + mention_set, + workspace, + editor, + } + } + + fn completion_for_path( + project_path: ProjectPath, + path_prefix: &str, + is_recent: bool, + is_directory: bool, + excerpt_id: ExcerptId, + source_range: Range, + editor: Entity, + mention_set: Arc>, + cx: &App, + ) -> Completion { + let (file_name, directory) = + extract_file_name_and_directory(&project_path.path, path_prefix); + + let label = + build_code_label_for_full_path(&file_name, directory.as_ref().map(|s| s.as_ref()), cx); + let full_path = if let Some(directory) = directory { + format!("{}{}", directory, file_name) + } else { + file_name.to_string() + }; + + let crease_icon_path = if is_directory { + FileIcons::get_folder_icon(false, cx).unwrap_or_else(|| IconName::Folder.path().into()) + } else { + FileIcons::get_icon(Path::new(&full_path), cx) + .unwrap_or_else(|| IconName::File.path().into()) + }; + let completion_icon_path = if is_recent { + IconName::HistoryRerun.path().into() + } else { + crease_icon_path.clone() + }; + + let new_text = format!("{} ", MentionLink::for_file(&file_name, &full_path)); + let new_text_len = new_text.len(); + Completion { + replace_range: source_range.clone(), + new_text, + label, + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(completion_icon_path), + insert_text_mode: None, + confirm: Some(confirm_completion_callback( + crease_icon_path, + file_name, + project_path, + excerpt_id, + source_range.start, + new_text_len - 1, + editor, + mention_set, + )), + } + } +} + +fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel { + let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); + let mut label = CodeLabel::default(); + + label.push_str(&file_name, None); + label.push_str(" ", None); + + if let Some(directory) = directory { + label.push_str(&directory, comment_id); + } + + label.filter_range = 0..label.text().len(); + + label +} + +impl CompletionProvider for ContextPickerCompletionProvider { + fn completions( + &self, + excerpt_id: ExcerptId, + buffer: &Entity, + buffer_position: Anchor, + _trigger: CompletionContext, + _window: &mut Window, + cx: &mut Context, + ) -> Task>> { + let state = buffer.update(cx, |buffer, _cx| { + let position = buffer_position.to_point(buffer); + let line_start = Point::new(position.row, 0); + let offset_to_line = buffer.point_to_offset(line_start); + let mut lines = buffer.text_for_range(line_start..position).lines(); + let line = lines.next()?; + MentionCompletion::try_parse(line, offset_to_line) + }); + let Some(state) = state else { + return Task::ready(Ok(Vec::new())); + }; + + let Some(workspace) = self.workspace.upgrade() else { + return Task::ready(Ok(Vec::new())); + }; + + let snapshot = buffer.read(cx).snapshot(); + let source_range = snapshot.anchor_before(state.source_range.start) + ..snapshot.anchor_after(state.source_range.end); + + let editor = self.editor.clone(); + let mention_set = self.mention_set.clone(); + let MentionCompletion { argument, .. } = state; + let query = argument.unwrap_or_else(|| "".to_string()); + + let search_task = search_files(query.clone(), Arc::::default(), &workspace, cx); + + cx.spawn(async move |_, cx| { + let matches = search_task.await; + let Some(editor) = editor.upgrade() else { + return Ok(Vec::new()); + }; + + let completions = cx.update(|cx| { + matches + .into_iter() + .map(|mat| { + let path_match = &mat.mat; + let project_path = ProjectPath { + worktree_id: WorktreeId::from_usize(path_match.worktree_id), + path: path_match.path.clone(), + }; + + Self::completion_for_path( + project_path, + &path_match.path_prefix, + mat.is_recent, + path_match.is_dir, + excerpt_id, + source_range.clone(), + editor.clone(), + mention_set.clone(), + cx, + ) + }) + .collect() + })?; + + Ok(vec![CompletionResponse { + completions, + // Since this does its own filtering (see `filter_completions()` returns false), + // there is no benefit to computing whether this set of completions is incomplete. + is_incomplete: true, + }]) + }) + } + + fn is_completion_trigger( + &self, + buffer: &Entity, + position: language::Anchor, + _text: &str, + _trigger_in_words: bool, + _menu_is_open: bool, + cx: &mut Context, + ) -> bool { + let buffer = buffer.read(cx); + let position = position.to_point(buffer); + let line_start = Point::new(position.row, 0); + let offset_to_line = buffer.point_to_offset(line_start); + let mut lines = buffer.text_for_range(line_start..position).lines(); + if let Some(line) = lines.next() { + MentionCompletion::try_parse(line, offset_to_line) + .map(|completion| { + completion.source_range.start <= offset_to_line + position.column as usize + && completion.source_range.end >= offset_to_line + position.column as usize + }) + .unwrap_or(false) + } else { + false + } + } + + fn sort_completions(&self) -> bool { + false + } + + fn filter_completions(&self) -> bool { + false + } +} + +fn confirm_completion_callback( + crease_icon_path: SharedString, + crease_text: SharedString, + project_path: ProjectPath, + excerpt_id: ExcerptId, + start: Anchor, + content_len: usize, + editor: Entity, + mention_set: Arc>, +) -> Arc bool + Send + Sync> { + Arc::new(move |_, window, cx| { + let crease_text = crease_text.clone(); + let crease_icon_path = crease_icon_path.clone(); + let editor = editor.clone(); + let project_path = project_path.clone(); + let mention_set = mention_set.clone(); + window.defer(cx, move |window, cx| { + let crease_id = crate::context_picker::insert_crease_for_mention( + excerpt_id, + start, + content_len, + crease_text.clone(), + crease_icon_path, + editor.clone(), + window, + cx, + ); + if let Some(crease_id) = crease_id { + mention_set.lock().insert(crease_id, project_path); + } + }); + false + }) +} + +#[derive(Debug, Default, PartialEq)] +struct MentionCompletion { + source_range: Range, + argument: Option, +} + +impl MentionCompletion { + fn try_parse(line: &str, offset_to_line: usize) -> Option { + let last_mention_start = line.rfind('@')?; + if last_mention_start >= line.len() { + return Some(Self::default()); + } + if last_mention_start > 0 + && line + .chars() + .nth(last_mention_start - 1) + .map_or(false, |c| !c.is_whitespace()) + { + return None; + } + + let rest_of_line = &line[last_mention_start + 1..]; + let mut argument = None; + + let mut parts = rest_of_line.split_whitespace(); + let mut end = last_mention_start + 1; + if let Some(argument_text) = parts.next() { + end += argument_text.len(); + argument = Some(argument_text.to_string()); + } + + Some(Self { + source_range: last_mention_start + offset_to_line..end + offset_to_line, + argument, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext}; + use project::{Project, ProjectPath}; + use serde_json::json; + use settings::SettingsStore; + use std::{ops::Deref, rc::Rc}; + use util::path; + use workspace::{AppState, Item}; + + #[test] + fn test_mention_completion_parse() { + assert_eq!(MentionCompletion::try_parse("Lorem Ipsum", 0), None); + + assert_eq!( + MentionCompletion::try_parse("Lorem @", 0), + Some(MentionCompletion { + source_range: 6..7, + argument: None, + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @main", 0), + Some(MentionCompletion { + source_range: 6..11, + argument: Some("main".to_string()), + }) + ); + + assert_eq!(MentionCompletion::try_parse("test@", 0), None); + } + + struct AtMentionEditor(Entity); + + impl Item for AtMentionEditor { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for AtMentionEditor {} + + impl Focusable for AtMentionEditor { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx).clone() + } + } + + impl Render for AtMentionEditor { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + + #[gpui::test] + async fn test_context_completion_provider(cx: &mut TestAppContext) { + init_test(cx); + + let app_state = cx.update(AppState::test); + + cx.update(|cx| { + language::init(cx); + editor::init(cx); + workspace::init(app_state.clone(), cx); + Project::init_settings(cx); + }); + + app_state + .fs + .as_fake() + .insert_tree( + path!("/dir"), + json!({ + "editor": "", + "a": { + "one.txt": "", + "two.txt": "", + "three.txt": "", + "four.txt": "" + }, + "b": { + "five.txt": "", + "six.txt": "", + "seven.txt": "", + "eight.txt": "", + } + }), + ) + .await; + + let project = Project::test(app_state.fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let workspace = window.root(cx).unwrap(); + + let worktree = project.update(cx, |project, cx| { + let mut worktrees = project.worktrees(cx).collect::>(); + assert_eq!(worktrees.len(), 1); + worktrees.pop().unwrap() + }); + let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id()); + + let mut cx = VisualTestContext::from_window(*window.deref(), cx); + + let paths = vec![ + path!("a/one.txt"), + path!("a/two.txt"), + path!("a/three.txt"), + path!("a/four.txt"), + path!("b/five.txt"), + path!("b/six.txt"), + path!("b/seven.txt"), + path!("b/eight.txt"), + ]; + + let mut opened_editors = Vec::new(); + for path in paths { + let buffer = workspace + .update_in(&mut cx, |workspace, window, cx| { + workspace.open_path( + ProjectPath { + worktree_id, + path: Path::new(path).into(), + }, + None, + false, + window, + cx, + ) + }) + .await + .unwrap(); + opened_editors.push(buffer); + } + + let editor = workspace.update_in(&mut cx, |workspace, window, cx| { + let editor = cx.new(|cx| { + Editor::new( + editor::EditorMode::full(), + multi_buffer::MultiBuffer::build_simple("", cx), + None, + window, + cx, + ) + }); + workspace.active_pane().update(cx, |pane, cx| { + pane.add_item( + Box::new(cx.new(|_| AtMentionEditor(editor.clone()))), + true, + true, + None, + window, + cx, + ); + }); + editor + }); + + let mention_set = Arc::new(Mutex::new(MentionSet::default())); + + let editor_entity = editor.downgrade(); + editor.update_in(&mut cx, |editor, window, cx| { + window.focus(&editor.focus_handle(cx)); + editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( + mention_set.clone(), + workspace.downgrade(), + editor_entity, + )))); + }); + + cx.simulate_input("Lorem "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem "); + assert!(!editor.has_visible_completions_menu()); + }); + + cx.simulate_input("@"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @"); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + current_completion_labels(editor), + &[ + "eight.txt dir/b/", + "seven.txt dir/b/", + "six.txt dir/b/", + "five.txt dir/b/", + "four.txt dir/a/", + "three.txt dir/a/", + "two.txt dir/a/", + "one.txt dir/a/", + "dir ", + "a dir/", + "four.txt dir/a/", + "one.txt dir/a/", + "three.txt dir/a/", + "two.txt dir/a/", + "b dir/", + "eight.txt dir/b/", + "five.txt dir/b/", + "seven.txt dir/b/", + "six.txt dir/b/", + "editor dir/" + ] + ); + }); + + // Select and confirm "File" + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem [@four.txt](@file:dir/a/four.txt) "); + }); + } + + fn current_completion_labels(editor: &Editor) -> Vec { + let completions = editor.current_completions().expect("Missing completions"); + completions + .into_iter() + .map(|completion| completion.label.text.to_string()) + .collect::>() + } + + pub(crate) fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + theme::init(theme::LoadThemes::JustBase, cx); + client::init_settings(cx); + language::init(cx); + Project::init_settings(cx); + workspace::init_settings(cx); + editor::init_settings(cx); + }); + } +} diff --git a/crates/agent_ui/src/acp/message_history.rs b/crates/agent_ui/src/acp/message_history.rs new file mode 100644 index 0000000000..d0fb1f0990 --- /dev/null +++ b/crates/agent_ui/src/acp/message_history.rs @@ -0,0 +1,87 @@ +pub struct MessageHistory { + items: Vec, + current: Option, +} + +impl Default for MessageHistory { + fn default() -> Self { + MessageHistory { + items: Vec::new(), + current: None, + } + } +} + +impl MessageHistory { + pub fn push(&mut self, message: T) { + self.current.take(); + self.items.push(message); + } + + pub fn reset_position(&mut self) { + self.current.take(); + } + + pub fn prev(&mut self) -> Option<&T> { + if self.items.is_empty() { + return None; + } + + let new_ix = self + .current + .get_or_insert(self.items.len()) + .saturating_sub(1); + + self.current = Some(new_ix); + self.items.get(new_ix) + } + + pub fn next(&mut self) -> Option<&T> { + let current = self.current.as_mut()?; + *current += 1; + + self.items.get(*current).or_else(|| { + self.current.take(); + None + }) + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prev_next() { + let mut history = MessageHistory::default(); + + // Test empty history + assert_eq!(history.prev(), None); + assert_eq!(history.next(), None); + + // Add some messages + history.push("first"); + history.push("second"); + history.push("third"); + + // Test prev navigation + assert_eq!(history.prev(), Some(&"third")); + assert_eq!(history.prev(), Some(&"second")); + assert_eq!(history.prev(), Some(&"first")); + assert_eq!(history.prev(), Some(&"first")); + + assert_eq!(history.next(), Some(&"second")); + + // Test mixed navigation + history.push("fourth"); + assert_eq!(history.prev(), Some(&"fourth")); + assert_eq!(history.prev(), Some(&"third")); + assert_eq!(history.next(), Some(&"fourth")); + assert_eq!(history.next(), None); + + // Test that push resets navigation + history.prev(); + history.prev(); + history.push("fifth"); + assert_eq!(history.prev(), Some(&"fifth")); + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs new file mode 100644 index 0000000000..7f5de9db5f --- /dev/null +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -0,0 +1,2434 @@ +use acp_thread::{AgentConnection, Plan}; +use agent_servers::AgentServer; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; +use std::time::Duration; + +use agent_client_protocol as acp; +use assistant_tool::ActionLog; +use buffer_diff::BufferDiff; +use collections::{HashMap, HashSet}; +use editor::{ + AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, + EditorStyle, MinimapVisibility, MultiBuffer, PathKey, +}; +use file_icons::FileIcons; +use gpui::{ + Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, + Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, + Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, + pulsating_between, +}; +use language::language_settings::SoftWrap; +use language::{Buffer, Language}; +use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; +use parking_lot::Mutex; +use project::Project; +use settings::Settings as _; +use text::Anchor; +use theme::ThemeSettings; +use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*}; +use util::ResultExt; +use workspace::{CollaboratorId, Workspace}; +use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; + +use ::acp_thread::{ + AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, + LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, +}; + +use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; +use crate::acp::message_history::MessageHistory; +use crate::agent_diff::AgentDiff; +use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; +use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll}; + +const RESPONSE_PADDING_X: Pixels = px(19.); + +pub struct AcpThreadView { + agent: Rc, + workspace: WeakEntity, + project: Entity, + thread_state: ThreadState, + diff_editors: HashMap>, + message_editor: Entity, + message_set_from_history: bool, + _message_editor_subscription: Subscription, + mention_set: Arc>, + last_error: Option>, + list_state: ListState, + auth_task: Option>, + expanded_tool_calls: HashSet, + expanded_thinking_blocks: HashSet<(usize, usize)>, + edits_expanded: bool, + plan_expanded: bool, + editor_expanded: bool, + message_history: Rc>>>, + _cancel_task: Option>, +} + +enum ThreadState { + Loading { + _task: Task<()>, + }, + Ready { + thread: Entity, + _subscription: [Subscription; 2], + }, + LoadError(LoadError), + Unauthenticated { + connection: Rc, + }, +} + +impl AcpThreadView { + pub fn new( + agent: Rc, + workspace: WeakEntity, + project: Entity, + message_history: Rc>>>, + min_lines: usize, + max_lines: Option, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let language = Language::new( + language::LanguageConfig { + completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), + ..Default::default() + }, + None, + ); + + let mention_set = Arc::new(Mutex::new(MentionSet::default())); + + let message_editor = cx.new(|cx| { + let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + + let mut editor = Editor::new( + editor::EditorMode::AutoHeight { + min_lines, + max_lines: max_lines, + }, + buffer, + None, + window, + cx, + ); + editor.set_placeholder_text("Message the agent - @ to include files", cx); + editor.set_show_indent_guides(false, cx); + editor.set_soft_wrap(); + editor.set_use_modal_editing(true); + editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( + mention_set.clone(), + workspace.clone(), + cx.weak_entity(), + )))); + editor.set_context_menu_options(ContextMenuOptions { + min_entries_visible: 12, + max_entries_visible: 12, + placement: Some(ContextMenuPlacement::Above), + }); + editor + }); + + let message_editor_subscription = cx.subscribe(&message_editor, |this, _, event, _| { + if let editor::EditorEvent::BufferEdited = &event { + if !this.message_set_from_history { + this.message_history.borrow_mut().reset_position(); + } + this.message_set_from_history = false; + } + }); + + let mention_set = mention_set.clone(); + + let list_state = ListState::new( + 0, + gpui::ListAlignment::Bottom, + px(2048.0), + cx.processor({ + move |this: &mut Self, index: usize, window, cx| { + let Some((entry, len)) = this.thread().and_then(|thread| { + let entries = &thread.read(cx).entries(); + Some((entries.get(index)?, entries.len())) + }) else { + return Empty.into_any(); + }; + this.render_entry(index, len, entry, window, cx) + } + }), + ); + + Self { + agent: agent.clone(), + workspace: workspace.clone(), + project: project.clone(), + thread_state: Self::initial_state(agent, workspace, project, window, cx), + message_editor, + message_set_from_history: false, + _message_editor_subscription: message_editor_subscription, + mention_set, + diff_editors: Default::default(), + list_state: list_state, + last_error: None, + auth_task: None, + expanded_tool_calls: HashSet::default(), + expanded_thinking_blocks: HashSet::default(), + edits_expanded: false, + plan_expanded: false, + editor_expanded: false, + message_history, + _cancel_task: None, + } + } + + fn initial_state( + agent: Rc, + workspace: WeakEntity, + project: Entity, + window: &mut Window, + cx: &mut Context, + ) -> ThreadState { + let root_dir = project + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).abs_path()) + .unwrap_or_else(|| paths::home_dir().as_path().into()); + + let connect_task = agent.connect(&root_dir, &project, cx); + let load_task = cx.spawn_in(window, async move |this, cx| { + let connection = match connect_task.await { + Ok(thread) => thread, + Err(err) => { + this.update(cx, |this, cx| { + this.handle_load_error(err, cx); + cx.notify(); + }) + .log_err(); + return; + } + }; + + let result = match connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + .await + { + Err(e) => { + let mut cx = cx.clone(); + if e.downcast_ref::().is_some() { + this.update(&mut cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { connection }; + cx.notify(); + }) + .ok(); + return; + } else { + Err(e) + } + } + Ok(session_id) => Ok(session_id), + }; + + this.update_in(cx, |this, window, cx| { + match result { + Ok(thread) => { + let thread_subscription = + cx.subscribe_in(&thread, window, Self::handle_thread_event); + + let action_log = thread.read(cx).action_log().clone(); + let action_log_subscription = + cx.observe(&action_log, |_, _, cx| cx.notify()); + + this.list_state + .splice(0..0, thread.read(cx).entries().len()); + + AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); + + this.thread_state = ThreadState::Ready { + thread, + _subscription: [thread_subscription, action_log_subscription], + }; + + cx.notify(); + } + Err(err) => { + this.handle_load_error(err, cx); + } + }; + }) + .log_err(); + }); + + ThreadState::Loading { _task: load_task } + } + + fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context) { + if let Some(load_err) = err.downcast_ref::() { + self.thread_state = ThreadState::LoadError(load_err.clone()); + } else { + self.thread_state = ThreadState::LoadError(LoadError::Other(err.to_string().into())) + } + cx.notify(); + } + + pub fn thread(&self) -> Option<&Entity> { + match &self.thread_state { + ThreadState::Ready { thread, .. } => Some(thread), + ThreadState::Unauthenticated { .. } + | ThreadState::Loading { .. } + | ThreadState::LoadError(..) => None, + } + } + + pub fn title(&self, cx: &App) -> SharedString { + match &self.thread_state { + ThreadState::Ready { thread, .. } => thread.read(cx).title(), + ThreadState::Loading { .. } => "Loading…".into(), + ThreadState::LoadError(_) => "Failed to load".into(), + ThreadState::Unauthenticated { .. } => "Not authenticated".into(), + } + } + + pub fn cancel(&mut self, cx: &mut Context) { + self.last_error.take(); + + if let Some(thread) = self.thread() { + self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); + } + } + + pub fn expand_message_editor( + &mut self, + _: &ExpandMessageEditor, + _window: &mut Window, + cx: &mut Context, + ) { + self.set_editor_is_expanded(!self.editor_expanded, cx); + cx.notify(); + } + + fn set_editor_is_expanded(&mut self, is_expanded: bool, cx: &mut Context) { + self.editor_expanded = is_expanded; + self.message_editor.update(cx, |editor, _| { + if self.editor_expanded { + editor.set_mode(EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: false, + }) + } else { + editor.set_mode(EditorMode::AutoHeight { + min_lines: MIN_EDITOR_LINES, + max_lines: Some(MAX_EDITOR_LINES), + }) + } + }); + cx.notify(); + } + + fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context) { + self.last_error.take(); + + let mut ix = 0; + let mut chunks: Vec = Vec::new(); + let project = self.project.clone(); + self.message_editor.update(cx, |editor, cx| { + let text = editor.text(cx); + editor.display_map.update(cx, |map, cx| { + let snapshot = map.snapshot(cx); + for (crease_id, crease) in snapshot.crease_snapshot.creases() { + if let Some(project_path) = + self.mention_set.lock().path_for_crease_id(crease_id) + { + let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); + if crease_range.start > ix { + chunks.push(text[ix..crease_range.start].into()); + } + if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { + let path_str = abs_path.display().to_string(); + chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: path_str.clone(), + name: path_str, + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + })); + } + ix = crease_range.end; + } + } + + if ix < text.len() { + let last_chunk = text[ix..].trim(); + if !last_chunk.is_empty() { + chunks.push(last_chunk.into()); + } + } + }) + }); + + if chunks.is_empty() { + return; + } + + let Some(thread) = self.thread() else { return }; + let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); + + cx.spawn(async move |this, cx| { + let result = task.await; + + this.update(cx, |this, cx| { + if let Err(err) = result { + this.last_error = + Some(cx.new(|cx| Markdown::new(err.to_string().into(), None, None, cx))) + } + }) + }) + .detach(); + + let mention_set = self.mention_set.clone(); + + self.set_editor_is_expanded(false, cx); + self.message_editor.update(cx, |editor, cx| { + editor.clear(window, cx); + editor.remove_creases(mention_set.lock().drain(), cx) + }); + + self.message_history.borrow_mut().push(chunks); + } + + fn previous_history_message( + &mut self, + _: &PreviousHistoryMessage, + window: &mut Window, + cx: &mut Context, + ) { + self.message_set_from_history = Self::set_draft_message( + self.message_editor.clone(), + self.mention_set.clone(), + self.project.clone(), + self.message_history.borrow_mut().prev(), + window, + cx, + ); + } + + fn next_history_message( + &mut self, + _: &NextHistoryMessage, + window: &mut Window, + cx: &mut Context, + ) { + self.message_set_from_history = Self::set_draft_message( + self.message_editor.clone(), + self.mention_set.clone(), + self.project.clone(), + self.message_history.borrow_mut().next(), + window, + cx, + ); + } + + fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { + if let Some(thread) = self.thread() { + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err(); + } + } + + fn open_edited_buffer( + &mut self, + buffer: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { + return; + }; + + let Some(diff) = + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err() + else { + return; + }; + + diff.update(cx, |diff, cx| { + diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx) + }) + } + + fn set_draft_message( + message_editor: Entity, + mention_set: Arc>, + project: Entity, + message: Option<&Vec>, + window: &mut Window, + cx: &mut Context, + ) -> bool { + cx.notify(); + + let Some(message) = message else { + return false; + }; + + let mut text = String::new(); + let mut mentions = Vec::new(); + + for chunk in message { + match chunk { + acp::ContentBlock::Text(text_content) => { + text.push_str(&text_content.text); + } + acp::ContentBlock::ResourceLink(resource_link) => { + let path = Path::new(&resource_link.uri); + let start = text.len(); + let content = MentionPath::new(&path).to_string(); + text.push_str(&content); + let end = text.len(); + if let Some(project_path) = + project.read(cx).project_path_for_absolute_path(&path, cx) + { + let filename: SharedString = path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string() + .into(); + mentions.push((start..end, project_path, filename)); + } + } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => {} + } + } + + let snapshot = message_editor.update(cx, |editor, cx| { + editor.set_text(text, window, cx); + editor.buffer().read(cx).snapshot(cx) + }); + + for (range, project_path, filename) in mentions { + let crease_icon_path = if project_path.path.is_dir() { + FileIcons::get_folder_icon(false, cx) + .unwrap_or_else(|| IconName::Folder.path().into()) + } else { + FileIcons::get_icon(Path::new(project_path.path.as_ref()), cx) + .unwrap_or_else(|| IconName::File.path().into()) + }; + + let anchor = snapshot.anchor_before(range.start); + let crease_id = crate::context_picker::insert_crease_for_mention( + anchor.excerpt_id, + anchor.text_anchor, + range.end - range.start, + filename, + crease_icon_path, + message_editor.clone(), + window, + cx, + ); + if let Some(crease_id) = crease_id { + mention_set.lock().insert(crease_id, project_path); + } + } + + true + } + + fn handle_thread_event( + &mut self, + thread: &Entity, + event: &AcpThreadEvent, + window: &mut Window, + cx: &mut Context, + ) { + let count = self.list_state.item_count(); + match event { + AcpThreadEvent::NewEntry => { + let index = thread.read(cx).entries().len() - 1; + self.sync_thread_entry_view(index, window, cx); + self.list_state.splice(count..count, 1); + } + AcpThreadEvent::EntryUpdated(index) => { + let index = *index; + self.sync_thread_entry_view(index, window, cx); + self.list_state.splice(index..index + 1, 1); + } + } + cx.notify(); + } + + fn sync_thread_entry_view( + &mut self, + entry_ix: usize, + window: &mut Window, + cx: &mut Context, + ) { + let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { + return; + }; + + let multibuffers = multibuffers.collect::>(); + + for multibuffer in multibuffers { + if self.diff_editors.contains_key(&multibuffer.entity_id()) { + return; + } + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + multibuffer.clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + }); + editor + }); + let entity_id = multibuffer.entity_id(); + cx.observe_release(&multibuffer, move |this, _, _| { + this.diff_editors.remove(&entity_id); + }) + .detach(); + + self.diff_editors.insert(entity_id, editor); + } + } + + fn entry_diff_multibuffers( + &self, + entry_ix: usize, + cx: &App, + ) -> Option>> { + let entry = self.thread()?.read(cx).entries().get(entry_ix)?; + Some(entry.diffs().map(|diff| diff.multibuffer.clone())) + } + + fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + let ThreadState::Unauthenticated { ref connection } = self.thread_state else { + return; + }; + + self.last_error.take(); + let authenticate = connection.authenticate(cx); + self.auth_task = Some(cx.spawn_in(window, { + let project = self.project.clone(); + let agent = self.agent.clone(); + async move |this, cx| { + let result = authenticate.await; + + this.update_in(cx, |this, window, cx| { + if let Err(err) = result { + this.last_error = Some(cx.new(|cx| { + Markdown::new(format!("Error: {err}").into(), None, None, cx) + })) + } else { + this.thread_state = Self::initial_state( + agent, + this.workspace.clone(), + project.clone(), + window, + cx, + ) + } + this.auth_task.take() + }) + .ok(); + } + })); + } + + fn authorize_tool_call( + &mut self, + tool_call_id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { + return; + }; + thread.update(cx, |thread, cx| { + thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); + }); + cx.notify(); + } + + fn render_entry( + &self, + index: usize, + total_entries: usize, + entry: &AgentThreadEntry, + window: &mut Window, + cx: &Context, + ) -> AnyElement { + match &entry { + AgentThreadEntry::UserMessage(message) => div() + .py_4() + .px_2() + .child( + v_flex() + .p_3() + .gap_1p5() + .rounded_lg() + .shadow_md() + .bg(cx.theme().colors().editor_background) + .border_1() + .border_color(cx.theme().colors().border) + .text_xs() + .children(message.content.markdown().map(|md| { + self.render_markdown( + md.clone(), + user_message_markdown_style(window, cx), + ) + })), + ) + .into_any(), + AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { + let style = default_markdown_style(false, window, cx); + let message_body = v_flex() + .w_full() + .gap_2p5() + .children(chunks.iter().enumerate().filter_map( + |(chunk_ix, chunk)| match chunk { + AssistantMessageChunk::Message { block } => { + block.markdown().map(|md| { + self.render_markdown(md.clone(), style.clone()) + .into_any_element() + }) + } + AssistantMessageChunk::Thought { block } => { + block.markdown().map(|md| { + self.render_thinking_block( + index, + chunk_ix, + md.clone(), + window, + cx, + ) + .into_any_element() + }) + } + }, + )) + .into_any(); + + v_flex() + .px_5() + .py_1() + .when(index + 1 == total_entries, |this| this.pb_4()) + .w_full() + .text_ui(cx) + .child(message_body) + .into_any() + } + AgentThreadEntry::ToolCall(tool_call) => div() + .py_1p5() + .px_5() + .child(self.render_tool_call(index, tool_call, window, cx)) + .into_any(), + } + } + + fn tool_card_header_bg(&self, cx: &Context) -> Hsla { + cx.theme() + .colors() + .element_background + .blend(cx.theme().colors().editor_foreground.opacity(0.025)) + } + + fn tool_card_border_color(&self, cx: &Context) -> Hsla { + cx.theme().colors().border.opacity(0.6) + } + + fn tool_name_font_size(&self) -> Rems { + rems_from_px(13.) + } + + fn render_thinking_block( + &self, + entry_ix: usize, + chunk_ix: usize, + chunk: Entity, + window: &Window, + cx: &Context, + ) -> AnyElement { + let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix)); + let key = (entry_ix, chunk_ix); + let is_open = self.expanded_thinking_blocks.contains(&key); + + v_flex() + .child( + h_flex() + .id(header_id) + .group("disclosure-header") + .w_full() + .justify_between() + .opacity(0.8) + .hover(|style| style.opacity(1.)) + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::ToolBulb) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child( + div() + .text_size(self.tool_name_font_size()) + .child("Thinking"), + ), + ) + .child( + div().visible_on_hover("disclosure-header").child( + Disclosure::new("thinking-disclosure", is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronDown) + .on_click(cx.listener({ + move |this, _event, _window, cx| { + if is_open { + this.expanded_thinking_blocks.remove(&key); + } else { + this.expanded_thinking_blocks.insert(key); + } + cx.notify(); + } + })), + ), + ) + .on_click(cx.listener({ + move |this, _event, _window, cx| { + if is_open { + this.expanded_thinking_blocks.remove(&key); + } else { + this.expanded_thinking_blocks.insert(key); + } + cx.notify(); + } + })), + ) + .when(is_open, |this| { + this.child( + div() + .relative() + .mt_1p5() + .ml(px(7.)) + .pl_4() + .border_l_1() + .border_color(self.tool_card_border_color(cx)) + .text_ui_sm(cx) + .child( + self.render_markdown(chunk, default_markdown_style(false, window, cx)), + ), + ) + }) + .into_any_element() + } + + fn render_tool_call( + &self, + entry_ix: usize, + tool_call: &ToolCall, + window: &Window, + cx: &Context, + ) -> Div { + let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); + + let status_icon = match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { .. } => None, + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + .. + } => Some( + Icon::new(IconName::ArrowCircle) + .color(Color::Accent) + .size(IconSize::Small) + .with_animation( + "running", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), + ) + .into_any(), + ), + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Completed, + .. + } => None, + ToolCallStatus::Rejected + | ToolCallStatus::Canceled + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Failed, + .. + } => Some( + Icon::new(IconName::X) + .color(Color::Error) + .size(IconSize::Small) + .into_any_element(), + ), + }; + + let needs_confirmation = match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { .. } => true, + _ => tool_call + .content + .iter() + .any(|content| matches!(content, ToolCallContent::Diff { .. })), + }; + + let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; + let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); + + v_flex() + .when(needs_confirmation, |this| { + this.rounded_lg() + .border_1() + .border_color(self.tool_card_border_color(cx)) + .bg(cx.theme().colors().editor_background) + .overflow_hidden() + }) + .child( + h_flex() + .id(header_id) + .w_full() + .gap_1() + .justify_between() + .map(|this| { + if needs_confirmation { + this.px_2() + .py_1() + .rounded_t_md() + .bg(self.tool_card_header_bg(cx)) + .border_b_1() + .border_color(self.tool_card_border_color(cx)) + } else { + this.opacity(0.8).hover(|style| style.opacity(1.)) + } + }) + .child( + h_flex() + .id("tool-call-header") + .overflow_x_scroll() + .map(|this| { + if needs_confirmation { + this.text_xs() + } else { + this.text_size(self.tool_name_font_size()) + } + }) + .gap_1p5() + .child( + Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolRead, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolBulb, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::Other => IconName::ToolHammer, + }) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child(if tool_call.locations.len() == 1 { + let name = tool_call.locations[0] + .path + .file_name() + .unwrap_or_default() + .display() + .to_string(); + + h_flex() + .id(("open-tool-call-location", entry_ix)) + .child(name) + .w_full() + .max_w_full() + .pr_1() + .gap_0p5() + .cursor_pointer() + .rounded_sm() + .opacity(0.8) + .hover(|label| { + label.opacity(1.).bg(cx + .theme() + .colors() + .element_hover + .opacity(0.5)) + }) + .tooltip(Tooltip::text("Jump to File")) + .on_click(cx.listener(move |this, _, window, cx| { + this.open_tool_call_location(entry_ix, 0, window, cx); + })) + .into_any_element() + } else { + self.render_markdown( + tool_call.label.clone(), + default_markdown_style(needs_confirmation, window, cx), + ) + .into_any() + }), + ) + .child( + h_flex() + .gap_0p5() + .when(is_collapsible, |this| { + this.child( + Disclosure::new(("expand", entry_ix), is_open) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronDown) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + cx.notify(); + } + })), + ) + }) + .children(status_icon), + ) + .on_click(cx.listener({ + let id = tool_call.id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + if is_open { + this.expanded_tool_calls.remove(&id); + } else { + this.expanded_tool_calls.insert(id.clone()); + } + cx.notify(); + } + })), + ) + .when(is_open, |this| { + this.child( + v_flex() + .text_xs() + .when(is_collapsible, |this| { + this.mt_1() + .border_1() + .border_color(self.tool_card_border_color(cx)) + .bg(cx.theme().colors().editor_background) + .rounded_lg() + }) + .map(|this| { + if is_open { + match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { options, .. } => this + .children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + .child(self.render_permission_buttons( + options, + entry_ix, + tool_call.id.clone(), + cx, + )), + ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { + this.children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + } + ToolCallStatus::Rejected => this, + } + } else { + this + } + }), + ) + }) + } + + fn render_tool_call_content( + &self, + content: &ToolCallContent, + window: &Window, + cx: &Context, + ) -> AnyElement { + match content { + ToolCallContent::ContentBlock { content } => { + if let Some(md) = content.markdown() { + div() + .p_2() + .child( + self.render_markdown( + md.clone(), + default_markdown_style(false, window, cx), + ), + ) + .into_any_element() + } else { + Empty.into_any_element() + } + } + ToolCallContent::Diff { + diff: Diff { multibuffer, .. }, + .. + } => self.render_diff_editor(multibuffer), + } + } + + fn render_permission_buttons( + &self, + options: &[acp::PermissionOption], + entry_ix: usize, + tool_call_id: acp::ToolCallId, + cx: &Context, + ) -> Div { + h_flex() + .py_1p5() + .px_1p5() + .gap_1() + .justify_end() + .border_t_1() + .border_color(self.tool_card_border_color(cx)) + .children(options.iter().map(|option| { + let option_id = SharedString::from(option.id.0.clone()); + Button::new((option_id, entry_ix), option.label.clone()) + .map(|this| match option.kind { + acp::PermissionOptionKind::AllowOnce => { + this.icon(IconName::Check).icon_color(Color::Success) + } + acp::PermissionOptionKind::AllowAlways => { + this.icon(IconName::CheckDouble).icon_color(Color::Success) + } + acp::PermissionOptionKind::RejectOnce => { + this.icon(IconName::X).icon_color(Color::Error) + } + acp::PermissionOptionKind::RejectAlways => { + this.icon(IconName::X).icon_color(Color::Error) + } + }) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .on_click(cx.listener({ + let tool_call_id = tool_call_id.clone(); + let option_id = option.id.clone(); + let option_kind = option.kind; + move |this, _, _, cx| { + this.authorize_tool_call( + tool_call_id.clone(), + option_id.clone(), + option_kind, + cx, + ); + } + })) + })) + } + + fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { + v_flex() + .h_full() + .child( + if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) { + editor.clone().into_any_element() + } else { + Empty.into_any() + }, + ) + .into_any() + } + + fn render_agent_logo(&self) -> AnyElement { + Icon::new(self.agent.logo()) + .color(Color::Muted) + .size(IconSize::XLarge) + .into_any_element() + } + + fn render_error_agent_logo(&self) -> AnyElement { + let logo = Icon::new(self.agent.logo()) + .color(Color::Muted) + .size(IconSize::XLarge) + .into_any_element(); + + h_flex() + .relative() + .justify_center() + .child(div().opacity(0.3).child(logo)) + .child( + h_flex().absolute().right_1().bottom_0().child( + Icon::new(IconName::XCircle) + .color(Color::Error) + .size(IconSize::Small), + ), + ) + .into_any_element() + } + + fn render_empty_state(&self, cx: &App) -> AnyElement { + let loading = matches!(&self.thread_state, ThreadState::Loading { .. }); + + v_flex() + .size_full() + .items_center() + .justify_center() + .child(if loading { + h_flex() + .justify_center() + .child(self.render_agent_logo()) + .with_animation( + "pulsating_icon", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 1.0)), + |icon, delta| icon.opacity(delta), + ) + .into_any() + } else { + self.render_agent_logo().into_any_element() + }) + .child(h_flex().mt_4().mb_1().justify_center().child(if loading { + div() + .child(LoadingLabel::new("").size(LabelSize::Large)) + .into_any_element() + } else { + Headline::new(self.agent.empty_state_headline()) + .size(HeadlineSize::Medium) + .into_any_element() + })) + .child( + div() + .max_w_1_2() + .text_sm() + .text_center() + .map(|this| { + if loading { + this.invisible() + } else { + this.text_color(cx.theme().colors().text_muted) + } + }) + .child(self.agent.empty_state_message()), + ) + .into_any() + } + + fn render_pending_auth_state(&self) -> AnyElement { + v_flex() + .items_center() + .justify_center() + .child(self.render_error_agent_logo()) + .child( + h_flex() + .mt_4() + .mb_1() + .justify_center() + .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)), + ) + .into_any() + } + + fn render_error_state(&self, e: &LoadError, cx: &Context) -> AnyElement { + let mut container = v_flex() + .items_center() + .justify_center() + .child(self.render_error_agent_logo()) + .child( + v_flex() + .mt_4() + .mb_2() + .gap_0p5() + .text_center() + .items_center() + .child(Headline::new("Failed to launch").size(HeadlineSize::Medium)) + .child( + Label::new(e.to_string()) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ); + + if let LoadError::Unsupported { + upgrade_message, + upgrade_command, + .. + } = &e + { + let upgrade_message = upgrade_message.clone(); + let upgrade_command = upgrade_command.clone(); + container = container.child(Button::new("upgrade", upgrade_message).on_click( + cx.listener(move |this, _, window, cx| { + this.workspace + .update(cx, |workspace, cx| { + let project = workspace.project().read(cx); + let cwd = project.first_project_directory(cx); + let shell = project.terminal_settings(&cwd, cx).shell.clone(); + let spawn_in_terminal = task::SpawnInTerminal { + id: task::TaskId("install".to_string()), + full_label: upgrade_command.clone(), + label: upgrade_command.clone(), + command: Some(upgrade_command.clone()), + args: Vec::new(), + command_label: upgrade_command.clone(), + cwd, + env: Default::default(), + use_new_terminal: true, + allow_concurrent_runs: true, + reveal: Default::default(), + reveal_target: Default::default(), + hide: Default::default(), + shell, + show_summary: true, + show_command: true, + show_rerun: false, + }; + workspace + .spawn_in_terminal(spawn_in_terminal, window, cx) + .detach(); + }) + .ok(); + }), + )); + } + + container.into_any() + } + + fn render_activity_bar( + &self, + thread_entity: &Entity, + window: &mut Window, + cx: &Context, + ) -> Option { + let thread = thread_entity.read(cx); + let action_log = thread.action_log(); + let changed_buffers = action_log.read(cx).changed_buffers(cx); + let plan = thread.plan(); + + if changed_buffers.is_empty() && plan.is_empty() { + return None; + } + + let editor_bg_color = cx.theme().colors().editor_background; + let active_color = cx.theme().colors().element_selected; + let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); + + let pending_edits = thread.has_pending_edit_tool_calls(); + + v_flex() + .mt_1() + .mx_2() + .bg(bg_edit_files_disclosure) + .border_1() + .border_b_0() + .border_color(cx.theme().colors().border) + .rounded_t_md() + .shadow(vec![gpui::BoxShadow { + color: gpui::black().opacity(0.15), + offset: point(px(1.), px(-1.)), + blur_radius: px(3.), + spread_radius: px(0.), + }]) + .when(!plan.is_empty(), |this| { + this.child(self.render_plan_summary(plan, window, cx)) + .when(self.plan_expanded, |parent| { + parent.child(self.render_plan_entries(plan, window, cx)) + }) + }) + .when(!changed_buffers.is_empty(), |this| { + this.child(Divider::horizontal()) + .child(self.render_edits_summary( + action_log, + &changed_buffers, + self.edits_expanded, + pending_edits, + window, + cx, + )) + .when(self.edits_expanded, |parent| { + parent.child(self.render_edited_files( + action_log, + &changed_buffers, + pending_edits, + cx, + )) + }) + }) + .into_any() + .into() + } + + fn render_plan_summary(&self, plan: &Plan, window: &mut Window, cx: &Context) -> Div { + let stats = plan.stats(); + + let title = if let Some(entry) = stats.in_progress_entry + && !self.plan_expanded + { + h_flex() + .w_full() + .gap_1() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .justify_between() + .child( + h_flex() + .gap_1() + .child( + Label::new("Current:") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(MarkdownElement::new( + entry.content.clone(), + plan_label_markdown_style(&entry.status, window, cx), + )), + ) + .when(stats.pending > 0, |this| { + this.child( + Label::new(format!("{} left", stats.pending)) + .size(LabelSize::Small) + .color(Color::Muted) + .mr_1(), + ) + }) + } else { + let status_label = if stats.pending == 0 { + "All Done".to_string() + } else if stats.completed == 0 { + format!("{}", plan.entries.len()) + } else { + format!("{}/{}", stats.completed, plan.entries.len()) + }; + + h_flex() + .w_full() + .gap_1() + .justify_between() + .child( + Label::new("Plan") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Label::new(status_label) + .size(LabelSize::Small) + .color(Color::Muted) + .mr_1(), + ) + }; + + h_flex() + .p_1() + .justify_between() + .when(self.plan_expanded, |this| { + this.border_b_1().border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .id("plan_summary") + .w_full() + .gap_1() + .child(Disclosure::new("plan_disclosure", self.plan_expanded)) + .child(title) + .on_click(cx.listener(|this, _, _, cx| { + this.plan_expanded = !this.plan_expanded; + cx.notify(); + })), + ) + } + + fn render_plan_entries(&self, plan: &Plan, window: &mut Window, cx: &Context) -> Div { + v_flex().children(plan.entries.iter().enumerate().flat_map(|(index, entry)| { + let element = h_flex() + .py_1() + .px_2() + .gap_2() + .justify_between() + .bg(cx.theme().colors().editor_background) + .when(index < plan.entries.len() - 1, |parent| { + parent.border_color(cx.theme().colors().border).border_b_1() + }) + .child( + h_flex() + .id(("plan_entry", index)) + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(match entry.status { + acp::PlanEntryStatus::Pending => Icon::new(IconName::TodoPending) + .size(IconSize::Small) + .color(Color::Muted) + .into_any_element(), + acp::PlanEntryStatus::InProgress => Icon::new(IconName::TodoProgress) + .size(IconSize::Small) + .color(Color::Accent) + .with_animation( + "running", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate(percentage(delta))) + }, + ) + .into_any_element(), + acp::PlanEntryStatus::Completed => Icon::new(IconName::TodoComplete) + .size(IconSize::Small) + .color(Color::Success) + .into_any_element(), + }) + .child(MarkdownElement::new( + entry.content.clone(), + plan_label_markdown_style(&entry.status, window, cx), + )), + ); + + Some(element) + })) + } + + fn render_edits_summary( + &self, + action_log: &Entity, + changed_buffers: &BTreeMap, Entity>, + expanded: bool, + pending_edits: bool, + window: &mut Window, + cx: &Context, + ) -> Div { + const EDIT_NOT_READY_TOOLTIP_LABEL: &str = "Wait until file edits are complete."; + + let focus_handle = self.focus_handle(cx); + + h_flex() + .p_1() + .justify_between() + .when(expanded, |this| { + this.border_b_1().border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .id("edits-container") + .cursor_pointer() + .w_full() + .gap_1() + .child(Disclosure::new("edits-disclosure", expanded)) + .map(|this| { + if pending_edits { + this.child( + Label::new(format!( + "Editing {} {}…", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .color(Color::Muted) + .size(LabelSize::Small) + .with_animation( + "edit-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.3, 0.7)), + |label, delta| label.alpha(delta), + ), + ) + } else { + this.child( + Label::new("Edits") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Label::new("β€’").size(LabelSize::XSmall).color(Color::Muted)) + .child( + Label::new(format!( + "{} {}", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) + .on_click(cx.listener(|this, _, _, cx| { + this.edits_expanded = !this.edits_expanded; + cx.notify(); + })), + ) + .child( + h_flex() + .gap_1() + .child( + IconButton::new("review-changes", IconName::ListTodo) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Review Changes", + &OpenAgentDiff, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(|_, _, window, cx| { + window.dispatch_action(OpenAgentDiff.boxed_clone(), cx); + })), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child( + Button::new("reject-all-changes", "Reject All") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .when(pending_edits, |this| { + this.tooltip(Tooltip::text(EDIT_NOT_READY_TOOLTIP_LABEL)) + }) + .key_binding( + KeyBinding::for_action_in( + &RejectAll, + &focus_handle.clone(), + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click({ + let action_log = action_log.clone(); + cx.listener(move |_, _, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.reject_all_edits(cx).detach(); + }) + }) + }), + ) + .child( + Button::new("keep-all-changes", "Keep All") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .when(pending_edits, |this| { + this.tooltip(Tooltip::text(EDIT_NOT_READY_TOOLTIP_LABEL)) + }) + .key_binding( + KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click({ + let action_log = action_log.clone(); + cx.listener(move |_, _, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.keep_all_edits(cx); + }) + }) + }), + ), + ) + } + + fn render_edited_files( + &self, + action_log: &Entity, + changed_buffers: &BTreeMap, Entity>, + pending_edits: bool, + cx: &Context, + ) -> Div { + let editor_bg_color = cx.theme().colors().editor_background; + + v_flex().children(changed_buffers.into_iter().enumerate().flat_map( + |(index, (buffer, _diff))| { + let file = buffer.read(cx).file()?; + let path = file.path(); + + let file_path = path.parent().and_then(|parent| { + let parent_str = parent.to_string_lossy(); + + if parent_str.is_empty() { + None + } else { + Some( + Label::new(format!("/{}{}", parent_str, std::path::MAIN_SEPARATOR_STR)) + .color(Color::Muted) + .size(LabelSize::XSmall) + .buffer_font(cx), + ) + } + }); + + let file_name = path.file_name().map(|name| { + Label::new(name.to_string_lossy().to_string()) + .size(LabelSize::XSmall) + .buffer_font(cx) + }); + + let file_icon = FileIcons::get_icon(&path, cx) + .map(Icon::from_path) + .map(|icon| icon.color(Color::Muted).size(IconSize::Small)) + .unwrap_or_else(|| { + Icon::new(IconName::File) + .color(Color::Muted) + .size(IconSize::Small) + }); + + let overlay_gradient = linear_gradient( + 90., + linear_color_stop(editor_bg_color, 1.), + linear_color_stop(editor_bg_color.opacity(0.2), 0.), + ); + + let element = h_flex() + .group("edited-code") + .id(("file-container", index)) + .relative() + .py_1() + .pl_2() + .pr_1() + .gap_2() + .justify_between() + .bg(editor_bg_color) + .when(index < changed_buffers.len() - 1, |parent| { + parent.border_color(cx.theme().colors().border).border_b_1() + }) + .child( + h_flex() + .id(("file-name", index)) + .pr_8() + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .child(file_icon) + .child(h_flex().gap_0p5().children(file_name).children(file_path)) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.open_edited_buffer(&buffer, window, cx); + }) + }), + ) + .child( + h_flex() + .gap_1() + .visible_on_hover("edited-code") + .child( + Button::new("review", "Review") + .label_size(LabelSize::Small) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.open_edited_buffer(&buffer, window, cx); + }) + }), + ) + .child(Divider::vertical().color(DividerColor::BorderVariant)) + .child( + Button::new("reject-file", "Reject") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .on_click({ + let buffer = buffer.clone(); + let action_log = action_log.clone(); + move |_, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log + .reject_edits_in_ranges( + buffer.clone(), + vec![Anchor::MIN..Anchor::MAX], + cx, + ) + .detach_and_log_err(cx); + }) + } + }), + ) + .child( + Button::new("keep-file", "Keep") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .on_click({ + let buffer = buffer.clone(); + let action_log = action_log.clone(); + move |_, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.keep_edits_in_range( + buffer.clone(), + Anchor::MIN..Anchor::MAX, + cx, + ); + }) + } + }), + ), + ) + .child( + div() + .id("gradient-overlay") + .absolute() + .h_full() + .w_12() + .top_0() + .bottom_0() + .right(px(152.)) + .bg(overlay_gradient), + ); + + Some(element) + }, + )) + } + + fn render_message_editor(&mut self, window: &mut Window, cx: &mut Context) -> AnyElement { + let focus_handle = self.message_editor.focus_handle(cx); + let editor_bg_color = cx.theme().colors().editor_background; + let (expand_icon, expand_tooltip) = if self.editor_expanded { + (IconName::Minimize, "Minimize Message Editor") + } else { + (IconName::Maximize, "Expand Message Editor") + }; + + v_flex() + .on_action(cx.listener(Self::expand_message_editor)) + .p_2() + .gap_2() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(editor_bg_color) + .when(self.editor_expanded, |this| { + this.h(vh(0.8, window)).size_full().justify_between() + }) + .child( + v_flex() + .relative() + .size_full() + .pt_1() + .pr_2p5() + .child(div().flex_1().child({ + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small + .rems(cx) + .to_pixels(settings.agent_font_size(cx)); + let line_height = settings.buffer_line_height.value() * font_size; + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + + EditorElement::new( + &self.message_editor, + EditorStyle { + background: editor_bg_color, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ) + })) + .child( + h_flex() + .absolute() + .top_0() + .right_0() + .opacity(0.5) + .hover(|this| this.opacity(1.0)) + .child( + IconButton::new("toggle-height", expand_icon) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + expand_tooltip, + &ExpandMessageEditor, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(|_, _, window, cx| { + window.dispatch_action(Box::new(ExpandMessageEditor), cx); + })), + ), + ), + ) + .child( + h_flex() + .flex_none() + .justify_between() + .child(self.render_follow_toggle(cx)) + .child(self.render_send_button(cx)), + ) + .into_any() + } + + fn render_send_button(&self, cx: &mut Context) -> AnyElement { + if self.thread().map_or(true, |thread| { + thread.read(cx).status() == ThreadStatus::Idle + }) { + let is_editor_empty = self.message_editor.read(cx).is_empty(cx); + IconButton::new("send-message", IconName::Send) + .icon_color(Color::Accent) + .style(ButtonStyle::Filled) + .disabled(self.thread().is_none() || is_editor_empty) + .on_click(cx.listener(|this, _, window, cx| { + this.chat(&Chat, window, cx); + })) + .when(!is_editor_empty, |button| { + button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) + }) + .when(is_editor_empty, |button| { + button.tooltip(Tooltip::text("Type a message to submit")) + }) + .into_any_element() + } else { + IconButton::new("stop-generation", IconName::StopFilled) + .icon_color(Color::Error) + .style(ButtonStyle::Tinted(ui::TintColor::Error)) + .tooltip(move |window, cx| { + Tooltip::for_action("Stop Generation", &editor::actions::Cancel, window, cx) + }) + .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) + .into_any_element() + } + } + + fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { + let following = self + .workspace + .read_with(cx, |workspace, _| { + workspace.is_being_followed(CollaboratorId::Agent) + }) + .unwrap_or(false); + + IconButton::new("follow-agent", IconName::Crosshair) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(following) + .selected_icon_color(Some(Color::Custom(cx.theme().players().agent().cursor))) + .tooltip(move |window, cx| { + if following { + Tooltip::for_action("Stop Following Agent", &Follow, window, cx) + } else { + Tooltip::with_meta( + "Follow Agent", + Some(&Follow), + "Track the agent's location as it reads and edits files.", + window, + cx, + ) + } + }) + .on_click(cx.listener(move |this, _, window, cx| { + this.workspace + .update(cx, |workspace, cx| { + if following { + workspace.unfollow(CollaboratorId::Agent, window, cx); + } else { + workspace.follow(CollaboratorId::Agent, window, cx); + } + }) + .ok(); + })) + } + + fn render_markdown(&self, markdown: Entity, style: MarkdownStyle) -> MarkdownElement { + let workspace = self.workspace.clone(); + MarkdownElement::new(markdown, style).on_url_click(move |text, window, cx| { + Self::open_link(text, &workspace, window, cx); + }) + } + + fn open_link( + url: SharedString, + workspace: &WeakEntity, + window: &mut Window, + cx: &mut App, + ) { + let Some(workspace) = workspace.upgrade() else { + cx.open_url(&url); + return; + }; + + if let Some(mention_path) = MentionPath::try_parse(&url) { + workspace.update(cx, |workspace, cx| { + let project = workspace.project(); + let Some((path, entry)) = project.update(cx, |project, cx| { + let path = project.find_project_path(mention_path.path(), cx)?; + let entry = project.entry_for_path(&path, cx)?; + Some((path, entry)) + }) else { + return; + }; + + if entry.is_dir() { + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry.id)); + }); + } else { + workspace + .open_path(path, None, true, window, cx) + .detach_and_log_err(cx); + } + }) + } else { + cx.open_url(&url); + } + } + + fn open_tool_call_location( + &self, + entry_ix: usize, + location_ix: usize, + window: &mut Window, + cx: &mut Context, + ) -> Option<()> { + let location = self + .thread()? + .read(cx) + .entries() + .get(entry_ix)? + .locations()? + .get(location_ix)?; + + let project_path = self + .project + .read(cx) + .find_project_path(&location.path, cx)?; + + let open_task = self + .workspace + .update(cx, |worskpace, cx| { + worskpace.open_path(project_path, None, true, window, cx) + }) + .log_err()?; + + window + .spawn(cx, async move |cx| { + let item = open_task.await?; + + let Some(active_editor) = item.downcast::() else { + return anyhow::Ok(()); + }; + + active_editor.update_in(cx, |editor, window, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let first_hunk = editor + .diff_hunks_in_ranges( + &[editor::Anchor::min()..editor::Anchor::max()], + &snapshot, + ) + .next(); + if let Some(first_hunk) = first_hunk { + let first_hunk_start = first_hunk.multi_buffer_range().start; + editor.change_selections(Default::default(), window, cx, |selections| { + selections.select_anchor_ranges([first_hunk_start..first_hunk_start]); + }) + } + })?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + None + } + + pub fn open_thread_as_markdown( + &self, + workspace: Entity, + window: &mut Window, + cx: &mut App, + ) -> Task> { + let markdown_language_task = workspace + .read(cx) + .app_state() + .languages + .language_for_name("Markdown"); + + let (thread_summary, markdown) = if let Some(thread) = self.thread() { + let thread = thread.read(cx); + (thread.title().to_string(), thread.to_markdown(cx)) + } else { + return Task::ready(Ok(())); + }; + + window.spawn(cx, async move |cx| { + let markdown_language = markdown_language_task.await?; + + workspace.update_in(cx, |workspace, window, cx| { + let project = workspace.project().clone(); + + if !project.read(cx).is_local() { + anyhow::bail!("failed to open active thread as markdown in remote project"); + } + + let buffer = project.update(cx, |project, cx| { + project.create_local_buffer(&markdown, Some(markdown_language), cx) + }); + let buffer = cx.new(|cx| { + MultiBuffer::singleton(buffer, cx).with_title(thread_summary.clone()) + }); + + workspace.add_item_to_active_pane( + Box::new(cx.new(|cx| { + let mut editor = + Editor::for_multibuffer(buffer, Some(project.clone()), window, cx); + editor.set_breadcrumb_header(thread_summary); + editor + })), + None, + true, + window, + cx, + ); + + anyhow::Ok(()) + })??; + anyhow::Ok(()) + }) + } + + fn scroll_to_top(&mut self, cx: &mut Context) { + self.list_state.scroll_to(ListOffset::default()); + cx.notify(); + } +} + +impl Focusable for AcpThreadView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.message_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadView { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let open_as_markdown = IconButton::new("open-as-markdown", IconName::DocumentText) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + .tooltip(Tooltip::text("Open Thread as Markdown")) + .on_click(cx.listener(move |this, _, window, cx| { + if let Some(workspace) = this.workspace.upgrade() { + this.open_thread_as_markdown(workspace, window, cx) + .detach_and_log_err(cx); + } + })); + + let scroll_to_top = IconButton::new("scroll_to_top", IconName::ArrowUpAlt) + .icon_size(IconSize::XSmall) + .icon_color(Color::Ignored) + .tooltip(Tooltip::text("Scroll To Top")) + .on_click(cx.listener(move |this, _, _, cx| { + this.scroll_to_top(cx); + })); + + v_flex() + .size_full() + .key_context("AcpThread") + .on_action(cx.listener(Self::chat)) + .on_action(cx.listener(Self::previous_history_message)) + .on_action(cx.listener(Self::next_history_message)) + .on_action(cx.listener(Self::open_agent_diff)) + .child(match &self.thread_state { + ThreadState::Unauthenticated { .. } => { + v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child( + h_flex().mt_1p5().justify_center().child( + Button::new("sign-in", format!("Sign in to {}", self.agent.name())) + .on_click(cx.listener(|this, _, window, cx| { + this.authenticate(window, cx) + })), + ), + ) + } + ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), + ThreadState::LoadError(e) => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_error_state(e, cx)), + ThreadState::Ready { thread, .. } => v_flex().flex_1().map(|this| { + if self.list_state.item_count() > 0 { + this.child( + list(self.list_state.clone()) + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any(), + ) + .child( + h_flex() + .group("controls") + .mt_1() + .mr_1() + .py_2() + .px(RESPONSE_PADDING_X) + .opacity(0.4) + .hover(|style| style.opacity(1.)) + .flex_wrap() + .justify_end() + .child(open_as_markdown) + .child(scroll_to_top) + .into_any_element(), + ) + .children(match thread.read(cx).status() { + ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => None, + ThreadStatus::Generating => div() + .px_5() + .py_2() + .child(LoadingLabel::new("").size(LabelSize::Small)) + .into(), + }) + .children(self.render_activity_bar(&thread, window, cx)) + } else { + this.child(self.render_empty_state(cx)) + } + }), + }) + .when_some(self.last_error.clone(), |el, error| { + el.child( + div() + .p_2() + .text_xs() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().status().error_background) + .child( + self.render_markdown(error, default_markdown_style(false, window, cx)), + ), + ) + }) + .child(self.render_message_editor(window, cx)) + } +} + +fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { + let mut style = default_markdown_style(false, window, cx); + let mut text_style = window.text_style(); + let theme_settings = ThemeSettings::get_global(cx); + + let buffer_font = theme_settings.buffer_font.family.clone(); + let buffer_font_size = TextSize::Small.rems(cx); + + text_style.refine(&TextStyleRefinement { + font_family: Some(buffer_font), + font_size: Some(buffer_font_size.into()), + ..Default::default() + }); + + style.base_text_style = text_style; + style.link_callback = Some(Rc::new(move |url, cx| { + if MentionPath::try_parse(url).is_some() { + let colors = cx.theme().colors(); + Some(TextStyleRefinement { + background_color: Some(colors.element_background), + ..Default::default() + }) + } else { + None + } + })); + style +} + +fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle { + let theme_settings = ThemeSettings::get_global(cx); + let colors = cx.theme().colors(); + + let buffer_font_size = TextSize::Small.rems(cx); + + let mut text_style = window.text_style(); + let line_height = buffer_font_size * 1.75; + + let font_family = if buffer_font { + theme_settings.buffer_font.family.clone() + } else { + theme_settings.ui_font.family.clone() + }; + + let font_size = if buffer_font { + TextSize::Small.rems(cx) + } else { + TextSize::Default.rems(cx) + }; + + text_style.refine(&TextStyleRefinement { + font_family: Some(font_family), + font_fallbacks: theme_settings.ui_font.fallbacks.clone(), + font_features: Some(theme_settings.ui_font.features.clone()), + font_size: Some(font_size.into()), + line_height: Some(line_height.into()), + color: Some(cx.theme().colors().text), + ..Default::default() + }); + + MarkdownStyle { + base_text_style: text_style.clone(), + syntax: cx.theme().syntax().clone(), + selection_background_color: cx.theme().colors().element_selection_background, + code_block_overflow_x_scroll: true, + table_overflow_x_scroll: true, + heading_level_styles: Some(HeadingLevelStyles { + h1: Some(TextStyleRefinement { + font_size: Some(rems(1.15).into()), + ..Default::default() + }), + h2: Some(TextStyleRefinement { + font_size: Some(rems(1.1).into()), + ..Default::default() + }), + h3: Some(TextStyleRefinement { + font_size: Some(rems(1.05).into()), + ..Default::default() + }), + h4: Some(TextStyleRefinement { + font_size: Some(rems(1.).into()), + ..Default::default() + }), + h5: Some(TextStyleRefinement { + font_size: Some(rems(0.95).into()), + ..Default::default() + }), + h6: Some(TextStyleRefinement { + font_size: Some(rems(0.875).into()), + ..Default::default() + }), + }), + code_block: StyleRefinement { + padding: EdgesRefinement { + top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), + }, + margin: EdgesRefinement { + top: Some(Length::Definite(Pixels(8.).into())), + left: Some(Length::Definite(Pixels(0.).into())), + right: Some(Length::Definite(Pixels(0.).into())), + bottom: Some(Length::Definite(Pixels(12.).into())), + }, + border_style: Some(BorderStyle::Solid), + border_widths: EdgesRefinement { + top: Some(AbsoluteLength::Pixels(Pixels(1.))), + left: Some(AbsoluteLength::Pixels(Pixels(1.))), + right: Some(AbsoluteLength::Pixels(Pixels(1.))), + bottom: Some(AbsoluteLength::Pixels(Pixels(1.))), + }, + border_color: Some(colors.border_variant), + background: Some(colors.editor_background.into()), + text: Some(TextStyleRefinement { + font_family: Some(theme_settings.buffer_font.family.clone()), + font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), + font_features: Some(theme_settings.buffer_font.features.clone()), + font_size: Some(buffer_font_size.into()), + ..Default::default() + }), + ..Default::default() + }, + inline_code: TextStyleRefinement { + font_family: Some(theme_settings.buffer_font.family.clone()), + font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), + font_features: Some(theme_settings.buffer_font.features.clone()), + font_size: Some(buffer_font_size.into()), + background_color: Some(colors.editor_foreground.opacity(0.08)), + ..Default::default() + }, + link: TextStyleRefinement { + background_color: Some(colors.editor_foreground.opacity(0.025)), + underline: Some(UnderlineStyle { + color: Some(colors.text_accent.opacity(0.5)), + thickness: px(1.), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + } +} + +fn plan_label_markdown_style( + status: &acp::PlanEntryStatus, + window: &Window, + cx: &App, +) -> MarkdownStyle { + let default_md_style = default_markdown_style(false, window, cx); + + MarkdownStyle { + base_text_style: TextStyle { + color: cx.theme().colors().text_muted, + strikethrough: if matches!(status, acp::PlanEntryStatus::Completed) { + Some(gpui::StrikethroughStyle { + thickness: px(1.), + color: Some(cx.theme().colors().text_muted.opacity(0.8)), + }) + } else { + None + }, + ..default_md_style.base_text_style + }, + ..default_md_style + } +} diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index a4553fc901..e27c318221 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -787,6 +787,15 @@ impl ActiveThread { .unwrap() } }); + + let workspace_subscription = if let Some(workspace) = workspace.upgrade() { + Some(cx.observe_release(&workspace, |this, _, cx| { + this.dismiss_notifications(cx); + })) + } else { + None + }; + let mut this = Self { language_registry, thread_store, @@ -834,6 +843,10 @@ impl ActiveThread { } } + if let Some(subscription) = workspace_subscription { + this._subscriptions.push(subscription); + } + this } @@ -983,30 +996,57 @@ impl ActiveThread { | ThreadEvent::SummaryChanged => { self.save_thread(cx); } - ThreadEvent::Stopped(reason) => match reason { - Ok(StopReason::EndTurn | StopReason::MaxTokens) => { - let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); - self.play_notification_sound(window, cx); - self.show_notification( - if used_tools { - "Finished running tools" - } else { - "New message" - }, - IconName::ZedAssistant, - window, - cx, - ); + ThreadEvent::Stopped(reason) => { + match reason { + Ok(StopReason::EndTurn | StopReason::MaxTokens) => { + let used_tools = self.thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + Ok(StopReason::ToolUse) => { + // Don't notify for intermediate tool use + } + Ok(StopReason::Refusal) => { + self.notify_with_sound( + "Language model refused to respond", + IconName::Warning, + window, + cx, + ); + } + Err(error) => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + + let error_message = error + .chain() + .map(|err| err.to_string()) + .collect::>() + .join("\n"); + self.last_error = Some(ThreadError::Message { + header: "Error".into(), + message: error_message.into(), + }); + } } - _ => {} - }, + } ThreadEvent::ToolConfirmationNeeded => { - self.play_notification_sound(window, cx); - self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx); + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); } ThreadEvent::ToolUseLimitReached => { - self.play_notification_sound(window, cx); - self.show_notification( + self.notify_with_sound( "Consecutive tool use limit reached.", IconName::Warning, window, @@ -1149,9 +1189,6 @@ impl ActiveThread { self.save_thread(cx); cx.notify(); } - ThreadEvent::RetriesFailed { message } => { - self.show_notification(message, ui::IconName::Warning, window, cx); - } } } @@ -1206,6 +1243,17 @@ impl ActiveThread { } } + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + fn pop_up( &mut self, icon: IconName, @@ -1461,6 +1509,7 @@ impl ActiveThread { &configured_model.model, cx, ), + thinking_allowed: true, }; Some(configured_model.model.count_tokens(request, cx)) @@ -2580,8 +2629,8 @@ impl ActiveThread { h_flex() .gap_1p5() .child( - Icon::new(IconName::LightBulb) - .size(IconSize::XSmall) + Icon::new(IconName::ToolBulb) + .size(IconSize::Small) .color(Color::Muted), ) .child(LoadingLabel::new("Thinking").size(LabelSize::Small)), @@ -2994,7 +3043,7 @@ impl ActiveThread { .overflow_x_scroll() .child( Icon::new(tool_use.icon) - .size(IconSize::XSmall) + .size(IconSize::Small) .color(Color::Muted), ) .child( @@ -3153,7 +3202,10 @@ impl ActiveThread { .border_color(self.tool_card_border_color(cx)) .rounded_b_lg() .child( - LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small) + div() + .min_w(rems_from_px(145.)) + .child(LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small) + ) ) .child( h_flex() @@ -3198,7 +3250,6 @@ impl ActiveThread { }, )) }) - .child(ui::Divider::vertical()) .child({ let tool_id = tool_use.id.clone(); Button::new("allow-tool-action", "Allow") @@ -3673,8 +3724,11 @@ pub(crate) fn open_context( AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| { if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| { - panel.open_thread(thread_context.thread.clone(), window, cx); + let thread = thread_context.thread.clone(); + window.defer(cx, move |window, cx| { + panel.update(cx, |panel, cx| { + panel.open_thread(thread, window, cx); + }); }); } }), @@ -3682,8 +3736,11 @@ pub(crate) fn open_context( AgentContextHandle::TextThread(text_thread_context) => { workspace.update(cx, |workspace, cx| { if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| { - panel.open_prompt_editor(text_thread_context.context.clone(), window, cx) + let context = text_thread_context.context.clone(); + window.defer(cx, move |window, cx| { + panel.update(cx, |panel, cx| { + panel.open_prompt_editor(context, window, cx) + }); }); } }) @@ -3838,7 +3895,7 @@ mod tests { LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.set_default_model( Some(ConfiguredModel { - provider: Arc::new(FakeLanguageModelProvider), + provider: Arc::new(FakeLanguageModelProvider::default()), model, }), cx, @@ -3922,7 +3979,7 @@ mod tests { LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.set_default_model( Some(ConfiguredModel { - provider: Arc::new(FakeLanguageModelProvider), + provider: Arc::new(FakeLanguageModelProvider::default()), model: model.clone(), }), cx, diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 8bfdd50761..cacd409ac6 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -1,3 +1,4 @@ +mod add_llm_provider_modal; mod configure_context_server_modal; mod manage_profiles_modal; mod tool_picker; @@ -24,10 +25,11 @@ use project::{ context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; +use proto::Plan; use settings::{Settings, update_settings_file}; use ui::{ - ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, - Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*, + Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, + Scrollbar, ScrollbarState, Switch, SwitchColor, SwitchField, Tooltip, prelude::*, }; use util::ResultExt as _; use workspace::Workspace; @@ -36,7 +38,10 @@ use zed_actions::ExtensionCategoryFilter; pub(crate) use configure_context_server_modal::ConfigureContextServerModal; pub(crate) use manage_profiles_modal::ManageProfilesModal; -use crate::AddContextServer; +use crate::{ + AddContextServer, + agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider}, +}; pub struct AgentConfiguration { fs: Arc, @@ -171,7 +176,24 @@ impl AgentConfiguration { .copied() .unwrap_or(false); + let is_zed_provider = provider.id() == ZED_CLOUD_PROVIDER_ID; + let current_plan = if is_zed_provider { + self.workspace + .upgrade() + .and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) + } else { + None + }; + + let is_signed_in = self + .workspace + .read_with(cx, |workspace, _| { + workspace.client().status().borrow().is_connected() + }) + .unwrap_or(false); + v_flex() + .w_full() .when(is_expanded, |this| this.mb_2()) .child( div() @@ -202,20 +224,39 @@ impl AgentConfiguration { .hover(|hover| hover.bg(cx.theme().colors().element_hover)) .child( h_flex() + .w_full() .gap_2() .child( Icon::new(provider.icon()) .size(IconSize::Small) .color(Color::Muted), ) - .child(Label::new(provider_name.clone()).size(LabelSize::Large)) - .when( - provider.is_authenticated(cx) && !is_expanded, - |parent| { - parent.child( - Icon::new(IconName::Check).color(Color::Success), + .child( + h_flex() + .w_full() + .gap_1() + .child( + Label::new(provider_name.clone()) + .size(LabelSize::Large), ) - }, + .map(|this| { + if is_zed_provider && is_signed_in { + this.child( + self.render_zed_plan_info(current_plan, cx), + ) + } else { + this.when( + provider.is_authenticated(cx) + && !is_expanded, + |parent| { + parent.child( + Icon::new(IconName::Check) + .color(Color::Success), + ) + }, + ) + } + }), ), ) .child( @@ -276,21 +317,78 @@ impl AgentConfiguration { let providers = LanguageModelRegistry::read_global(cx).providers(); v_flex() + .w_full() .child( - v_flex() + h_flex() .p(DynamicSpacing::Base16.rems(cx)) .pr(DynamicSpacing::Base20.rems(cx)) .pb_0() .mb_2p5() - .gap_0p5() - .child(Headline::new("LLM Providers")) + .items_start() + .justify_between() .child( - Label::new("Add at least one provider to use AI-powered features.") - .color(Color::Muted), + v_flex() + .w_full() + .gap_0p5() + .child( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child(Headline::new("LLM Providers")) + .child( + PopoverMenu::new("add-provider-popover") + .trigger( + Button::new("add-provider", "Add Provider") + .icon_position(IconPosition::Start) + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .label_size(LabelSize::Small), + ) + .anchor(gpui::Corner::TopRight) + .menu({ + let workspace = self.workspace.clone(); + move |window, cx| { + Some(ContextMenu::build( + window, + cx, + |menu, _window, _cx| { + menu.header("Compatible APIs").entry( + "OpenAI", + None, + { + let workspace = + workspace.clone(); + move |window, cx| { + workspace + .update(cx, |workspace, cx| { + AddLlmProviderModal::toggle( + LlmCompatibleProvider::OpenAi, + workspace, + window, + cx, + ); + }) + .log_err(); + } + }, + ) + }, + )) + } + }), + ), + ) + .child( + Label::new("Add at least one provider to use AI-powered features.") + .color(Color::Muted), + ), ), ) .child( div() + .w_full() .pl(DynamicSpacing::Base08.rems(cx)) .pr(DynamicSpacing::Base20.rems(cx)) .children( @@ -303,119 +401,74 @@ impl AgentConfiguration { fn render_command_permission(&mut self, cx: &mut Context) -> impl IntoElement { let always_allow_tool_actions = AgentSettings::get_global(cx).always_allow_tool_actions; + let fs = self.fs.clone(); - h_flex() - .gap_4() - .justify_between() - .flex_wrap() - .child( - v_flex() - .gap_0p5() - .max_w_5_6() - .child(Label::new("Allow running editing tools without asking for confirmation")) - .child( - Label::new( - "The agent can perform potentially destructive actions without asking for your confirmation.", - ) - .color(Color::Muted), - ), - ) - .child( - Switch::new( - "always-allow-tool-actions-switch", - always_allow_tool_actions.into(), - ) - .color(SwitchColor::Accent) - .on_click({ - let fs = self.fs.clone(); - move |state, _window, cx| { - let allow = state == &ToggleState::Selected; - update_settings_file::( - fs.clone(), - cx, - move |settings, _| { - settings.set_always_allow_tool_actions(allow); - }, - ); - } - }), - ) + SwitchField::new( + "single-file-review", + "Enable single-file agent reviews", + "Agent edits are also displayed in single-file editors for review.", + always_allow_tool_actions, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_always_allow_tool_actions(allow); + }); + }, + ) } fn render_single_file_review(&mut self, cx: &mut Context) -> impl IntoElement { let single_file_review = AgentSettings::get_global(cx).single_file_review; + let fs = self.fs.clone(); - h_flex() - .gap_4() - .justify_between() - .flex_wrap() - .child( - v_flex() - .gap_0p5() - .max_w_5_6() - .child(Label::new("Enable single-file agent reviews")) - .child( - Label::new( - "Agent edits are also displayed in single-file editors for review.", - ) - .color(Color::Muted), - ), - ) - .child( - Switch::new("single-file-review-switch", single_file_review.into()) - .color(SwitchColor::Accent) - .on_click({ - let fs = self.fs.clone(); - move |state, _window, cx| { - let allow = state == &ToggleState::Selected; - update_settings_file::( - fs.clone(), - cx, - move |settings, _| { - settings.set_single_file_review(allow); - }, - ); - } - }), - ) + SwitchField::new( + "single-file-review", + "Enable single-file agent reviews", + "Agent edits are also displayed in single-file editors for review.", + single_file_review, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_single_file_review(allow); + }); + }, + ) } fn render_sound_notification(&mut self, cx: &mut Context) -> impl IntoElement { let play_sound_when_agent_done = AgentSettings::get_global(cx).play_sound_when_agent_done; + let fs = self.fs.clone(); - h_flex() - .gap_4() - .justify_between() - .flex_wrap() - .child( - v_flex() - .gap_0p5() - .max_w_5_6() - .child(Label::new("Play sound when finished generating")) - .child( - Label::new( - "Hear a notification sound when the agent is done generating changes or needs your input.", - ) - .color(Color::Muted), - ), - ) - .child( - Switch::new("play-sound-notification-switch", play_sound_when_agent_done.into()) - .color(SwitchColor::Accent) - .on_click({ - let fs = self.fs.clone(); - move |state, _window, cx| { - let allow = state == &ToggleState::Selected; - update_settings_file::( - fs.clone(), - cx, - move |settings, _| { - settings.set_play_sound_when_agent_done(allow); - }, - ); - } - }), - ) + SwitchField::new( + "sound-notification", + "Play sound when finished generating", + "Hear a notification sound when the agent is done generating changes or needs your input.", + play_sound_when_agent_done, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_play_sound_when_agent_done(allow); + }); + }, + ) + } + + fn render_modifier_to_send(&mut self, cx: &mut Context) -> impl IntoElement { + let use_modifier_to_send = AgentSettings::get_global(cx).use_modifier_to_send; + let fs = self.fs.clone(); + + SwitchField::new( + "modifier-send", + "Use modifier to submit a message", + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.", + use_modifier_to_send, + move |state, _window, cx| { + let allow = state == &ToggleState::Selected; + update_settings_file::(fs.clone(), cx, move |settings, _| { + settings.set_use_modifier_to_send(allow); + }); + }, + ) } fn render_general_settings_section(&mut self, cx: &mut Context) -> impl IntoElement { @@ -429,6 +482,38 @@ impl AgentConfiguration { .child(self.render_command_permission(cx)) .child(self.render_single_file_review(cx)) .child(self.render_sound_notification(cx)) + .child(self.render_modifier_to_send(cx)) + } + + fn render_zed_plan_info(&self, plan: Option, cx: &mut Context) -> impl IntoElement { + if let Some(plan) = plan { + let free_chip_bg = cx + .theme() + .colors() + .editor_background + .opacity(0.5) + .blend(cx.theme().colors().text_accent.opacity(0.05)); + + let pro_chip_bg = cx + .theme() + .colors() + .editor_background + .opacity(0.5) + .blend(cx.theme().colors().text_accent.opacity(0.2)); + + let (plan_name, label_color, bg_color) = match plan { + Plan::Free => ("Free", Color::Default, free_chip_bg), + Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), + Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), + }; + + Chip::new(plan_name.to_string()) + .bg_color(bg_color) + .label_color(label_color) + .into_any_element() + } else { + div().into_any_element() + } } fn render_context_servers_section( @@ -491,6 +576,7 @@ impl AgentConfiguration { category_filter: Some( ExtensionCategoryFilter::ContextServers, ), + id: None, } .boxed_clone(), cx, diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs new file mode 100644 index 0000000000..94b32d156b --- /dev/null +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -0,0 +1,639 @@ +use std::sync::Arc; + +use anyhow::Result; +use collections::HashSet; +use fs::Fs; +use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task}; +use language_model::LanguageModelRegistry; +use language_models::{ + AllLanguageModelSettings, OpenAiCompatibleSettingsContent, + provider::open_ai_compatible::AvailableModel, +}; +use settings::update_settings_file; +use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*}; +use ui_input::SingleLineInput; +use workspace::{ModalView, Workspace}; + +#[derive(Clone, Copy)] +pub enum LlmCompatibleProvider { + OpenAi, +} + +impl LlmCompatibleProvider { + fn name(&self) -> &'static str { + match self { + LlmCompatibleProvider::OpenAi => "OpenAI", + } + } + + fn api_url(&self) -> &'static str { + match self { + LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1", + } + } +} + +struct AddLlmProviderInput { + provider_name: Entity, + api_url: Entity, + api_key: Entity, + models: Vec, +} + +impl AddLlmProviderInput { + fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self { + let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx); + let api_url = single_line_input("API URL", provider.api_url(), None, window, cx); + let api_key = single_line_input( + "API Key", + "000000000000000000000000000000000000000000000000", + None, + window, + cx, + ); + + Self { + provider_name, + api_url, + api_key, + models: vec![ModelInput::new(window, cx)], + } + } + + fn add_model(&mut self, window: &mut Window, cx: &mut App) { + self.models.push(ModelInput::new(window, cx)); + } + + fn remove_model(&mut self, index: usize) { + self.models.remove(index); + } +} + +struct ModelInput { + name: Entity, + max_completion_tokens: Entity, + max_output_tokens: Entity, + max_tokens: Entity, +} + +impl ModelInput { + fn new(window: &mut Window, cx: &mut App) -> Self { + let model_name = single_line_input( + "Model Name", + "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro", + None, + window, + cx, + ); + let max_completion_tokens = single_line_input( + "Max Completion Tokens", + "200000", + Some("200000"), + window, + cx, + ); + let max_output_tokens = single_line_input( + "Max Output Tokens", + "Max Output Tokens", + Some("32000"), + window, + cx, + ); + let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx); + Self { + name: model_name, + max_completion_tokens, + max_output_tokens, + max_tokens, + } + } + + fn parse(&self, cx: &App) -> Result { + let name = self.name.read(cx).text(cx); + if name.is_empty() { + return Err(SharedString::from("Model Name cannot be empty")); + } + Ok(AvailableModel { + name, + display_name: None, + max_completion_tokens: Some( + self.max_completion_tokens + .read(cx) + .text(cx) + .parse::() + .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?, + ), + max_output_tokens: Some( + self.max_output_tokens + .read(cx) + .text(cx) + .parse::() + .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?, + ), + max_tokens: self + .max_tokens + .read(cx) + .text(cx) + .parse::() + .map_err(|_| SharedString::from("Max Tokens must be a number"))?, + }) + } +} + +fn single_line_input( + label: impl Into, + placeholder: impl Into, + text: Option<&str>, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let input = SingleLineInput::new(window, cx, placeholder).label(label); + if let Some(text) = text { + input + .editor() + .update(cx, |editor, cx| editor.set_text(text, window, cx)); + } + input + }) +} + +fn save_provider_to_settings( + input: &AddLlmProviderInput, + cx: &mut App, +) -> Task> { + let provider_name: Arc = input.provider_name.read(cx).text(cx).into(); + if provider_name.is_empty() { + return Task::ready(Err("Provider Name cannot be empty".into())); + } + + if LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .any(|provider| { + provider.id().0.as_ref() == provider_name.as_ref() + || provider.name().0.as_ref() == provider_name.as_ref() + }) + { + return Task::ready(Err( + "Provider Name is already taken by another provider".into() + )); + } + + let api_url = input.api_url.read(cx).text(cx); + if api_url.is_empty() { + return Task::ready(Err("API URL cannot be empty".into())); + } + + let api_key = input.api_key.read(cx).text(cx); + if api_key.is_empty() { + return Task::ready(Err("API Key cannot be empty".into())); + } + + let mut models = Vec::new(); + let mut model_names: HashSet = HashSet::default(); + for model in &input.models { + match model.parse(cx) { + Ok(model) => { + if !model_names.insert(model.name.clone()) { + return Task::ready(Err("Model Names must be unique".into())); + } + models.push(model) + } + Err(err) => return Task::ready(Err(err)), + } + } + + let fs = ::global(cx); + let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes()); + cx.spawn(async move |cx| { + task.await + .map_err(|_| "Failed to write API key to keychain")?; + cx.update(|cx| { + update_settings_file::(fs, cx, |settings, _cx| { + settings.openai_compatible.get_or_insert_default().insert( + provider_name, + OpenAiCompatibleSettingsContent { + api_url, + available_models: models, + }, + ); + }); + }) + .ok(); + Ok(()) + }) +} + +pub struct AddLlmProviderModal { + provider: LlmCompatibleProvider, + input: AddLlmProviderInput, + focus_handle: FocusHandle, + last_error: Option, +} + +impl AddLlmProviderModal { + pub fn toggle( + provider: LlmCompatibleProvider, + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context, + ) { + workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx)); + } + + fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context) -> Self { + Self { + input: AddLlmProviderInput::new(provider, window, cx), + provider, + last_error: None, + focus_handle: cx.focus_handle(), + } + } + + fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context) { + let task = save_provider_to_settings(&self.input, cx); + cx.spawn(async move |this, cx| { + let result = task.await; + this.update(cx, |this, cx| match result { + Ok(_) => { + cx.emit(DismissEvent); + } + Err(error) => { + this.last_error = Some(error); + cx.notify(); + } + }) + }) + .detach_and_log_err(cx); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } + + fn render_section(&self) -> Section { + Section::new() + .child(self.input.provider_name.clone()) + .child(self.input.api_url.clone()) + .child(self.input.api_key.clone()) + } + + fn render_model_section(&self, cx: &mut Context) -> Section { + Section::new().child( + v_flex() + .gap_2() + .child( + h_flex() + .justify_between() + .child(Label::new("Models").size(LabelSize::Small)) + .child( + Button::new("add-model", "Add Model") + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.input.add_model(window, cx); + cx.notify(); + })), + ), + ) + .children( + self.input + .models + .iter() + .enumerate() + .map(|(ix, _)| self.render_model(ix, cx)), + ), + ) + } + + fn render_model(&self, ix: usize, cx: &mut Context) -> impl IntoElement + use<> { + let has_more_than_one_model = self.input.models.len() > 1; + let model = &self.input.models[ix]; + + v_flex() + .p_2() + .gap_2() + .rounded_sm() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.6)) + .bg(cx.theme().colors().element_active.opacity(0.15)) + .child(model.name.clone()) + .child( + h_flex() + .gap_2() + .child(model.max_completion_tokens.clone()) + .child(model.max_output_tokens.clone()), + ) + .child(model.max_tokens.clone()) + .when(has_more_than_one_model, |this| { + this.child( + Button::new(("remove-model", ix), "Remove Model") + .icon(IconName::Trash) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .label_size(LabelSize::Small) + .style(ButtonStyle::Outlined) + .full_width() + .on_click(cx.listener(move |this, _, _window, cx| { + this.input.remove_model(ix); + cx.notify(); + })), + ) + }) + } +} + +impl EventEmitter for AddLlmProviderModal {} + +impl Focusable for AddLlmProviderModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl ModalView for AddLlmProviderModal {} + +impl Render for AddLlmProviderModal { + fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context) -> impl IntoElement { + let focus_handle = self.focus_handle(cx); + + div() + .id("add-llm-provider-modal") + .key_context("AddLlmProviderModal") + .w(rems(34.)) + .elevation_3(cx) + .on_action(cx.listener(Self::cancel)) + .capture_any_mouse_down(cx.listener(|this, _, window, cx| { + this.focus_handle(cx).focus(window); + })) + .child( + Modal::new("configure-context-server", None) + .header(ModalHeader::new().headline("Add LLM Provider").description( + match self.provider { + LlmCompatibleProvider::OpenAi => { + "This provider will use an OpenAI compatible API." + } + }, + )) + .when_some(self.last_error.clone(), |this, error| { + this.section( + Section::new().child( + Banner::new() + .severity(ui::Severity::Warning) + .child(div().text_xs().child(error)), + ), + ) + }) + .child( + v_flex() + .id("modal_content") + .max_h_128() + .overflow_y_scroll() + .gap_2() + .child(self.render_section()) + .child(self.render_model_section(cx)), + ) + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("cancel", "Cancel") + .key_binding( + KeyBinding::for_action_in( + &menu::Cancel, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, window, cx| { + this.cancel(&menu::Cancel, window, cx) + })), + ) + .child( + Button::new("save-server", "Save Provider") + .key_binding( + KeyBinding::for_action_in( + &menu::Confirm, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _event, window, cx| { + this.confirm(&menu::Confirm, window, cx) + })), + ), + ), + ), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use editor::EditorSettings; + use fs::FakeFs; + use gpui::{TestAppContext, VisualTestContext}; + use language::language_settings; + use language_model::{ + LanguageModelProviderId, LanguageModelProviderName, + fake_provider::FakeLanguageModelProvider, + }; + use project::Project; + use settings::{Settings as _, SettingsStore}; + use util::path; + + #[gpui::test] + async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + assert_eq!( + save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await, + Some("Provider Name cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await, + Some("API URL cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await, + Some("API Key cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("", "200000", "200000", "32000")], + cx, + ) + .await, + Some("Model Name cannot be empty".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("somemodel", "abc", "200000", "32000")], + cx, + ) + .await, + Some("Max Tokens must be a number".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("somemodel", "200000", "abc", "32000")], + cx, + ) + .await, + Some("Max Completion Tokens must be a number".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![("somemodel", "200000", "200000", "abc")], + cx, + ) + .await, + Some("Max Output Tokens must be a number".into()) + ); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "somekey", + vec![ + ("somemodel", "200000", "200000", "32000"), + ("somemodel", "200000", "200000", "32000"), + ], + cx, + ) + .await, + Some("Model Names must be unique".into()) + ); + } + + #[gpui::test] + async fn test_save_provider_name_conflict(cx: &mut TestAppContext) { + let cx = setup_test(cx).await; + + cx.update(|_window, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.register_provider( + FakeLanguageModelProvider::new( + LanguageModelProviderId::new("someprovider"), + LanguageModelProviderName::new("Some Provider"), + ), + cx, + ); + }); + }); + + assert_eq!( + save_provider_validation_errors( + "someprovider", + "someurl", + "someapikey", + vec![("somemodel", "200000", "200000", "32000")], + cx, + ) + .await, + Some("Provider Name is already taken by another provider".into()) + ); + } + + async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext { + cx.update(|cx| { + let store = SettingsStore::test(cx); + cx.set_global(store); + workspace::init_settings(cx); + Project::init_settings(cx); + theme::init(theme::LoadThemes::JustBase, cx); + language_settings::init(cx); + EditorSettings::register(cx); + language_model::init_settings(cx); + language_models::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| ::set_global(fs.clone(), cx)); + let project = Project::test(fs, [path!("/dir").as_ref()], cx).await; + let (_, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + cx + } + + async fn save_provider_validation_errors( + provider_name: &str, + api_url: &str, + api_key: &str, + models: Vec<(&str, &str, &str, &str)>, + cx: &mut VisualTestContext, + ) -> Option { + fn set_text( + input: &Entity, + text: &str, + window: &mut Window, + cx: &mut App, + ) { + input.update(cx, |input, cx| { + input.editor().update(cx, |editor, cx| { + editor.set_text(text, window, cx); + }); + }); + } + + let task = cx.update(|window, cx| { + let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx); + set_text(&input.provider_name, provider_name, window, cx); + set_text(&input.api_url, api_url, window, cx); + set_text(&input.api_key, api_key, window, cx); + + for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in + models.iter().enumerate() + { + if i >= input.models.len() { + input.models.push(ModelInput::new(window, cx)); + } + let model = &mut input.models[i]; + set_text(&model.name, name, window, cx); + set_text(&model.max_tokens, max_tokens, window, cx); + set_text( + &model.max_completion_tokens, + max_completion_tokens, + window, + cx, + ); + set_text(&model.max_output_tokens, max_output_tokens, window, cx); + } + save_provider_to_settings(&input, cx) + }); + + task.await.err() + } +} diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index ba0021c33c..06d035d836 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -1,4 +1,5 @@ use std::{ + path::PathBuf, sync::{Arc, Mutex}, time::Duration, }; @@ -188,7 +189,7 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) } None => ( "some-mcp-server".to_string(), - "".to_string(), + PathBuf::new(), "[]".to_string(), "{}".to_string(), ), @@ -199,13 +200,14 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) /// The name of your MCP server "{name}": {{ /// The command which runs the MCP server - "command": "{command}", + "command": "{}", /// The arguments to pass to the MCP server "args": {args}, /// The environment variables to set "env": {env} }} -}}"# +}}"#, + command.display() ) } @@ -740,7 +742,9 @@ fn wait_for_context_server( }); cx.spawn(async move |_cx| { - let result = rx.await.unwrap(); + let result = rx + .await + .map_err(|_| Arc::from("Context server store was dropped"))?; drop(subscription); result }) diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 1a0f3ff27d..ec0a11f86b 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1,7 +1,9 @@ use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; -use agent::{Thread, ThreadEvent}; +use acp_thread::{AcpThread, AcpThreadEvent}; +use agent::{Thread, ThreadEvent, ThreadSummary}; use agent_settings::AgentSettings; use anyhow::Result; +use assistant_tool::ActionLog; use buffer_diff::DiffHunkStatus; use collections::{HashMap, HashSet}; use editor::{ @@ -41,16 +43,108 @@ use zed_actions::assistant::ToggleFocus; pub struct AgentDiffPane { multibuffer: Entity, editor: Entity, - thread: Entity, + thread: AgentDiffThread, focus_handle: FocusHandle, workspace: WeakEntity, title: SharedString, _subscriptions: Vec, } +#[derive(PartialEq, Eq, Clone)] +pub enum AgentDiffThread { + Native(Entity), + AcpThread(Entity), +} + +impl AgentDiffThread { + fn project(&self, cx: &App) -> Entity { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).project().clone(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).project().clone(), + } + } + fn action_log(&self, cx: &App) -> Entity { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).action_log().clone(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).action_log().clone(), + } + } + + fn summary(&self, cx: &App) -> ThreadSummary { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).summary().clone(), + AgentDiffThread::AcpThread(thread) => ThreadSummary::Ready(thread.read(cx).title()), + } + } + + fn is_generating(&self, cx: &App) -> bool { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).is_generating(), + AgentDiffThread::AcpThread(thread) => { + thread.read(cx).status() == acp_thread::ThreadStatus::Generating + } + } + } + + fn has_pending_edit_tool_uses(&self, cx: &App) -> bool { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).has_pending_edit_tool_uses(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).has_pending_edit_tool_calls(), + } + } + + fn downgrade(&self) -> WeakAgentDiffThread { + match self { + AgentDiffThread::Native(thread) => WeakAgentDiffThread::Native(thread.downgrade()), + AgentDiffThread::AcpThread(thread) => { + WeakAgentDiffThread::AcpThread(thread.downgrade()) + } + } + } +} + +impl From> for AgentDiffThread { + fn from(entity: Entity) -> Self { + AgentDiffThread::Native(entity) + } +} + +impl From> for AgentDiffThread { + fn from(entity: Entity) -> Self { + AgentDiffThread::AcpThread(entity) + } +} + +#[derive(PartialEq, Eq, Clone)] +pub enum WeakAgentDiffThread { + Native(WeakEntity), + AcpThread(WeakEntity), +} + +impl WeakAgentDiffThread { + pub fn upgrade(&self) -> Option { + match self { + WeakAgentDiffThread::Native(weak) => weak.upgrade().map(AgentDiffThread::Native), + WeakAgentDiffThread::AcpThread(weak) => weak.upgrade().map(AgentDiffThread::AcpThread), + } + } +} + +impl From> for WeakAgentDiffThread { + fn from(entity: WeakEntity) -> Self { + WeakAgentDiffThread::Native(entity) + } +} + +impl From> for WeakAgentDiffThread { + fn from(entity: WeakEntity) -> Self { + WeakAgentDiffThread::AcpThread(entity) + } +} + impl AgentDiffPane { pub fn deploy( - thread: Entity, + thread: impl Into, workspace: WeakEntity, window: &mut Window, cx: &mut App, @@ -61,14 +155,16 @@ impl AgentDiffPane { } pub fn deploy_in_workspace( - thread: Entity, + thread: impl Into, workspace: &mut Workspace, window: &mut Window, cx: &mut Context, ) -> Entity { + let thread = thread.into(); let existing_diff = workspace .items_of_type::(cx) .find(|diff| diff.read(cx).thread == thread); + if let Some(existing_diff) = existing_diff { workspace.activate_item(&existing_diff, true, true, window, cx); existing_diff @@ -81,7 +177,7 @@ impl AgentDiffPane { } pub fn new( - thread: Entity, + thread: AgentDiffThread, workspace: WeakEntity, window: &mut Window, cx: &mut Context, @@ -89,7 +185,7 @@ impl AgentDiffPane { let focus_handle = cx.focus_handle(); let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite)); - let project = thread.read(cx).project().clone(); + let project = thread.project(cx).clone(); let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); @@ -100,16 +196,27 @@ impl AgentDiffPane { editor }); - let action_log = thread.read(cx).action_log().clone(); + let action_log = thread.action_log(cx).clone(); + let mut this = Self { - _subscriptions: vec![ - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - }), - ], + _subscriptions: [ + Some( + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), + ), + match &thread { + AgentDiffThread::Native(thread) => { + Some(cx.subscribe(&thread, |this, _thread, event, cx| { + this.handle_thread_event(event, cx) + })) + } + AgentDiffThread::AcpThread(_) => None, + }, + ] + .into_iter() + .flatten() + .collect(), title: SharedString::default(), multibuffer, editor, @@ -123,8 +230,7 @@ impl AgentDiffPane { } fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context) { - let thread = self.thread.read(cx); - let changed_buffers = thread.action_log().read(cx).changed_buffers(cx); + let changed_buffers = self.thread.action_log(cx).read(cx).changed_buffers(cx); let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::>(); for (buffer, diff_handle) in changed_buffers { @@ -211,7 +317,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let new_title = self.thread.summary(cx).unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -275,14 +381,15 @@ impl AgentDiffPane { fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context) { self.thread - .update(cx, |thread, cx| thread.keep_all_edits(cx)); + .action_log(cx) + .update(cx, |action_log, cx| action_log.keep_all_edits(cx)) } } fn keep_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut Context, ) { @@ -297,7 +404,7 @@ fn keep_edits_in_selection( fn reject_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut Context, ) { @@ -311,7 +418,7 @@ fn reject_edits_in_selection( fn keep_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -326,8 +433,8 @@ fn keep_edits_in_ranges( for hunk in &diff_hunks_in_ranges { let buffer = multibuffer.read(cx).buffer(hunk.buffer_id); if let Some(buffer) = buffer { - thread.update(cx, |thread, cx| { - thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) + thread.action_log(cx).update(cx, |action_log, cx| { + action_log.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) }); } } @@ -336,7 +443,7 @@ fn keep_edits_in_ranges( fn reject_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -362,8 +469,9 @@ fn reject_edits_in_ranges( for (buffer, ranges) in ranges_by_buffer { thread - .update(cx, |thread, cx| { - thread.reject_edits_in_ranges(buffer, ranges, cx) + .action_log(cx) + .update(cx, |action_log, cx| { + action_log.reject_edits_in_ranges(buffer, ranges, cx) }) .detach_and_log_err(cx); } @@ -461,7 +569,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let summary = self.thread.summary(cx).unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default @@ -641,7 +749,7 @@ impl Render for AgentDiffPane { } } -fn diff_hunk_controls(thread: &Entity) -> editor::RenderDiffHunkControlsFn { +fn diff_hunk_controls(thread: &AgentDiffThread) -> editor::RenderDiffHunkControlsFn { let thread = thread.clone(); Arc::new( @@ -676,7 +784,7 @@ fn render_diff_hunk_controls( hunk_range: Range, is_created_file: bool, line_height: Pixels, - thread: &Entity, + thread: &AgentDiffThread, editor: &Entity, window: &mut Window, cx: &mut App, @@ -1112,11 +1220,8 @@ impl Render for AgentDiffToolbar { return Empty.into_any(); }; - let has_pending_edit_tool_use = agent_diff - .read(cx) - .thread - .read(cx) - .has_pending_edit_tool_uses(); + let has_pending_edit_tool_use = + agent_diff.read(cx).thread.has_pending_edit_tool_uses(cx); if has_pending_edit_tool_use { return div().px_2().child(spinner_icon).into_any(); @@ -1187,8 +1292,8 @@ pub enum EditorState { } struct WorkspaceThread { - thread: WeakEntity, - _thread_subscriptions: [Subscription; 2], + thread: WeakAgentDiffThread, + _thread_subscriptions: (Subscription, Subscription), singleton_editors: HashMap, HashMap, Subscription>>, _settings_subscription: Subscription, _workspace_subscription: Option, @@ -1212,23 +1317,23 @@ impl AgentDiff { pub fn set_active_thread( workspace: &WeakEntity, - thread: &Entity, + thread: impl Into, window: &mut Window, cx: &mut App, ) { Self::global(cx).update(cx, |this, cx| { - this.register_active_thread_impl(workspace, thread, window, cx); + this.register_active_thread_impl(workspace, thread.into(), window, cx); }); } fn register_active_thread_impl( &mut self, workspace: &WeakEntity, - thread: &Entity, + thread: AgentDiffThread, window: &mut Window, cx: &mut Context, ) { - let action_log = thread.read(cx).action_log().clone(); + let action_log = thread.action_log(cx).clone(); let action_log_subscription = cx.observe_in(&action_log, window, { let workspace = workspace.clone(); @@ -1237,17 +1342,25 @@ impl AgentDiff { } }); - let thread_subscription = cx.subscribe_in(&thread, window, { - let workspace = workspace.clone(); - move |this, _thread, event, window, cx| { - this.handle_thread_event(&workspace, event, window, cx) - } - }); + let thread_subscription = match &thread { + AgentDiffThread::Native(thread) => cx.subscribe_in(&thread, window, { + let workspace = workspace.clone(); + move |this, _thread, event, window, cx| { + this.handle_native_thread_event(&workspace, event, window, cx) + } + }), + AgentDiffThread::AcpThread(thread) => cx.subscribe_in(&thread, window, { + let workspace = workspace.clone(); + move |this, thread, event, window, cx| { + this.handle_acp_thread_event(&workspace, thread, event, window, cx) + } + }), + }; if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { // replace thread and action log subscription, but keep editors workspace_thread.thread = thread.downgrade(); - workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription]; + workspace_thread._thread_subscriptions = (action_log_subscription, thread_subscription); self.update_reviewing_editors(&workspace, window, cx); return; } @@ -1272,7 +1385,7 @@ impl AgentDiff { workspace.clone(), WorkspaceThread { thread: thread.downgrade(), - _thread_subscriptions: [action_log_subscription, thread_subscription], + _thread_subscriptions: (action_log_subscription, thread_subscription), singleton_editors: HashMap::default(), _settings_subscription: settings_subscription, _workspace_subscription: workspace_subscription, @@ -1319,7 +1432,7 @@ impl AgentDiff { fn register_review_action( workspace: &mut Workspace, - review: impl Fn(&Entity, &Entity, &mut Window, &mut App) -> PostReviewState + review: impl Fn(&Entity, &AgentDiffThread, &mut Window, &mut App) -> PostReviewState + 'static, this: &Entity, ) { @@ -1338,7 +1451,7 @@ impl AgentDiff { }); } - fn handle_thread_event( + fn handle_native_thread_event( &mut self, workspace: &WeakEntity, event: &ThreadEvent, @@ -1375,11 +1488,42 @@ impl AgentDiff { | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached | ThreadEvent::CancelEditing - | ThreadEvent::RetriesFailed { .. } | ThreadEvent::ProfileChanged => {} } } + fn handle_acp_thread_event( + &mut self, + workspace: &WeakEntity, + thread: &Entity, + event: &AcpThreadEvent, + window: &mut Window, + cx: &mut Context, + ) { + match event { + AcpThreadEvent::NewEntry => { + if thread + .read(cx) + .entries() + .last() + .map_or(false, |entry| entry.diffs().next().is_some()) + { + self.update_reviewing_editors(workspace, window, cx); + } + } + AcpThreadEvent::EntryUpdated(ix) => { + if thread + .read(cx) + .entries() + .get(*ix) + .map_or(false, |entry| entry.diffs().next().is_some()) + { + self.update_reviewing_editors(workspace, window, cx); + } + } + } + } + fn handle_workspace_event( &mut self, workspace: &Entity, @@ -1485,7 +1629,7 @@ impl AgentDiff { return; }; - let action_log = thread.read(cx).action_log(); + let action_log = thread.action_log(cx); let changed_buffers = action_log.read(cx).changed_buffers(cx); let mut unaffected = self.reviewing_editors.clone(); @@ -1510,7 +1654,7 @@ impl AgentDiff { multibuffer.add_diff(diff_handle.clone(), cx); }); - let new_state = if thread.read(cx).is_generating() { + let new_state = if thread.is_generating(cx) { EditorState::Generating } else { EditorState::Reviewing @@ -1606,7 +1750,7 @@ impl AgentDiff { fn keep_all( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1626,7 +1770,7 @@ impl AgentDiff { fn reject_all( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1646,7 +1790,7 @@ impl AgentDiff { fn keep( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1659,7 +1803,7 @@ impl AgentDiff { fn reject( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1682,7 +1826,7 @@ impl AgentDiff { fn review_in_active_editor( &mut self, workspace: &mut Workspace, - review: impl Fn(&Entity, &Entity, &mut Window, &mut App) -> PostReviewState, + review: impl Fn(&Entity, &AgentDiffThread, &mut Window, &mut App) -> PostReviewState, window: &mut Window, cx: &mut Context, ) -> Option>> { @@ -1703,7 +1847,7 @@ impl AgentDiff { if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { - let changed_buffers = thread.read(cx).action_log().read(cx).changed_buffers(cx); + let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx); let mut keys = changed_buffers.keys().cycle(); keys.find(|k| *k == &curr_buffer); @@ -1801,8 +1945,9 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + let thread = + AgentDiffThread::Native(thread_store.update(cx, |store, cx| store.create_thread(cx))); + let action_log = cx.read(|cx| thread.action_log(cx)); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); @@ -1988,8 +2133,9 @@ mod tests { }); // Set the active thread + let thread = AgentDiffThread::Native(thread); cx.update(|window, cx| { - AgentDiff::set_active_thread(&workspace.downgrade(), &thread, window, cx) + AgentDiff::set_active_thread(&workspace.downgrade(), thread.clone(), window, cx) }); let buffer1 = project diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index f7b9157bbb..b989e7bf1e 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -1,8 +1,6 @@ use crate::{ ModelUsageContext, - language_model_selector::{ - LanguageModelSelector, ToggleModelSelector, language_model_selector, - }, + language_model_selector::{LanguageModelSelector, language_model_selector}, }; use agent_settings::AgentSettings; use fs::Fs; @@ -12,6 +10,7 @@ use picker::popover_menu::PickerPopoverMenu; use settings::update_settings_file; use std::sync::Arc; use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*}; +use zed_actions::agent::ToggleModelSelector; pub struct AgentModelSelector { selector: Entity, @@ -96,22 +95,18 @@ impl Render for AgentModelSelector { let model_name = model .as_ref() .map(|model| model.model.name().0) - .unwrap_or_else(|| SharedString::from("No model selected")); - let provider_icon = model - .as_ref() - .map(|model| model.provider.icon()) - .unwrap_or_else(|| IconName::Ai); + .unwrap_or_else(|| SharedString::from("Select a Model")); + + let provider_icon = model.as_ref().map(|model| model.provider.icon()); let focus_handle = self.focus_handle.clone(); PickerPopoverMenu::new( self.selector.clone(), ButtonLike::new("active-model") - .child( - Icon::new(provider_icon) - .color(Color::Muted) - .size(IconSize::XSmall), - ) + .when_some(provider_icon, |this, icon| { + this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)) + }) .child( Label::new(model_name) .color(Color::Muted) diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 5f58e0bd8d..4b3db4bc1d 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1,18 +1,24 @@ -use std::ops::Range; +use std::cell::RefCell; +use std::ops::{Not, Range}; use std::path::Path; use std::rc::Rc; use std::sync::Arc; use std::time::Duration; +use agent_servers::AgentServer; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; -use crate::language_model_selector::ToggleModelSelector; +use crate::NewExternalAgentThread; +use crate::agent_diff::AgentDiffThread; +use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; +use crate::ui::NewThreadButton; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, ToggleBurnMode, ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu, + acp::AcpThreadView, active_thread::{self, ActiveThread, ActiveThreadEvent}, agent_configuration::{AgentConfiguration, AssistantConfigurationEvent}, agent_diff::AgentDiff, @@ -23,7 +29,7 @@ use crate::{ render_remaining_tokens, }, thread_history::{HistoryEntryElement, ThreadHistory}, - ui::AgentOnboardingModal, + ui::{AgentOnboardingModal, EndTrialUpsell}, }; use agent::{ Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, @@ -32,22 +38,24 @@ use agent::{ thread_store::{TextThreadStore, ThreadStore}, }; use agent_settings::{AgentDockPosition, AgentSettings, CompletionMode, DefaultView}; +use ai_onboarding::AgentPanelOnboarding; use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; -use client::{UserStore, zed_urls}; +use client::{DisableAiSettings, UserStore, zed_urls}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; +use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; use gpui::{ Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem, Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, Hsla, - KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, linear_color_stop, - linear_gradient, prelude::*, pulsating_between, + KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, + pulsating_between, }; use language::LanguageRegistry; use language_model::{ - ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, + ConfigurationError, ConfiguredModel, LanguageModelProviderTosView, LanguageModelRegistry, }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; @@ -59,8 +67,8 @@ use theme::ThemeSettings; use time::UtcOffset; use ui::utils::WithRemSize; use ui::{ - Banner, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, - PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, + Banner, Callout, ContextMenu, ContextMenuEntry, ElevationIndex, KeyBinding, PopoverMenu, + PopoverMenuHandle, ProgressBar, Tab, Tooltip, prelude::*, }; use util::ResultExt as _; use workspace::{ @@ -69,7 +77,7 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding}, + agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector}, assistant::{OpenRulesLibrary, ToggleFocus}, }; use zed_llm_client::{CompletionIntent, UsageLimit}; @@ -109,6 +117,14 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx)); } }) + .register_action(|workspace, action: &NewExternalAgentThread, window, cx| { + if let Some(panel) = workspace.panel::(cx) { + workspace.focus_panel::(window, cx); + panel.update(cx, |panel, cx| { + panel.new_external_thread(action.agent, window, cx) + }); + } + }) .register_action(|workspace, action: &OpenRulesLibrary, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); @@ -125,7 +141,8 @@ pub fn init(cx: &mut App) { let thread = thread.read(cx).thread().clone(); AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx); } - ActiveView::TextThread { .. } + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} } @@ -171,7 +188,7 @@ pub fn init(cx: &mut App) { window.refresh(); }) .register_action(|_workspace, _: &ResetTrialUpsell, _window, cx| { - Upsell::set_dismissed(false, cx); + OnboardingUpsell::set_dismissed(false, cx); }) .register_action(|_workspace, _: &ResetTrialEndUpsell, _window, cx| { TrialEndUpsell::set_dismissed(false, cx); @@ -188,6 +205,9 @@ enum ActiveView { message_editor: Entity, _subscriptions: Vec, }, + ExternalAgentThread { + thread_view: Entity, + }, TextThread { context_editor: Entity, title_editor: Entity, @@ -207,7 +227,9 @@ enum WhichFontSize { impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { - ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont, + ActiveView::Thread { .. } + | ActiveView::ExternalAgentThread { .. } + | ActiveView::History => WhichFontSize::AgentFont, ActiveView::TextThread { .. } => WhichFontSize::BufferFont, ActiveView::Configuration => WhichFontSize::None, } @@ -238,6 +260,7 @@ impl ActiveView { thread.scroll_to_bottom(cx); }); } + ActiveView::ExternalAgentThread { .. } => {} ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -416,18 +439,21 @@ pub struct AgentPanel { configuration_subscription: Option, local_timezone: UtcOffset, active_view: ActiveView, + acp_message_history: + Rc>>>, previous_view: Option, history_store: Entity, history: Entity, hovered_recent_history_item: Option, - assistant_dropdown_menu_handle: PopoverMenuHandle, + new_thread_menu_handle: PopoverMenuHandle, + agent_panel_menu_handle: PopoverMenuHandle, assistant_navigation_menu_handle: PopoverMenuHandle, assistant_navigation_menu: Option>, width: Option, height: Option, zoomed: bool, pending_serialization: Option>>, - hide_upsell: bool, + onboarding: Entity, } impl AgentPanel { @@ -529,6 +555,7 @@ impl AgentPanel { let user_store = workspace.app_state().user_store.clone(); let project = workspace.project(); let language_registry = project.read(cx).languages().clone(); + let client = workspace.client().clone(); let workspace = workspace.weak_handle(); let weak_self = cx.entity().downgrade(); @@ -537,6 +564,17 @@ impl AgentPanel { let inline_assist_context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); + let thread_id = thread.read(cx).id().clone(); + + let history_store = cx.new(|cx| { + HistoryStore::new( + thread_store.clone(), + context_store.clone(), + [HistoryEntryId::Thread(thread_id)], + cx, + ) + }); + let message_editor = cx.new(|cx| { MessageEditor::new( fs.clone(), @@ -546,22 +584,13 @@ impl AgentPanel { prompt_store.clone(), thread_store.downgrade(), context_store.downgrade(), + Some(history_store.downgrade()), thread.clone(), window, cx, ) }); - let thread_id = thread.read(cx).id().clone(); - let history_store = cx.new(|cx| { - HistoryStore::new( - thread_store.clone(), - context_store.clone(), - [HistoryEntryId::Thread(thread_id)], - cx, - ) - }); - cx.observe(&history_store, |_, _, cx| cx.notify()).detach(); let active_thread = cx.new(|cx| { @@ -607,7 +636,7 @@ impl AgentPanel { } }; - AgentDiff::set_active_thread(&workspace, &thread, window, cx); + AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); let weak_panel = weak_self.clone(); @@ -653,7 +682,8 @@ impl AgentPanel { .clone() .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); } - ActiveView::TextThread { .. } + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} }, @@ -661,6 +691,17 @@ impl AgentPanel { }, ); + let onboarding = cx.new(|cx| { + AgentPanelOnboarding::new( + user_store.clone(), + client, + |_window, cx| { + OnboardingUpsell::set_dismissed(true, cx); + }, + cx, + ) + }); + Self { active_view, workspace, @@ -680,17 +721,19 @@ impl AgentPanel { .unwrap(), inline_assist_context_store, previous_view: None, + acp_message_history: Default::default(), history_store: history_store.clone(), history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)), hovered_recent_history_item: None, - assistant_dropdown_menu_handle: PopoverMenuHandle::default(), + new_thread_menu_handle: PopoverMenuHandle::default(), + agent_panel_menu_handle: PopoverMenuHandle::default(), assistant_navigation_menu_handle: PopoverMenuHandle::default(), assistant_navigation_menu: None, width: None, height: None, zoomed: false, pending_serialization: None, - hide_upsell: false, + onboarding, } } @@ -703,6 +746,7 @@ impl AgentPanel { if workspace .panel::(cx) .is_some_and(|panel| panel.read(cx).enabled(cx)) + && !DisableAiSettings::get_global(cx).disable_ai { workspace.toggle_panel_focus::(window, cx); } @@ -733,6 +777,9 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } + ActiveView::ExternalAgentThread { thread_view, .. } => { + thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx)); + } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} } } @@ -740,18 +787,18 @@ impl AgentPanel { fn active_message_editor(&self) -> Option<&Entity> { match &self.active_view { ActiveView::Thread { message_editor, .. } => Some(message_editor), - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => None, } } fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context) { - // Preserve chat box text when using creating new thread from summary' - let preserved_text = if action.from_thread_id.is_some() { - self.active_message_editor() - .map(|editor| editor.read(cx).get_text(cx).trim().to_string()) - } else { - None - }; + // Preserve chat box text when using creating new thread + let preserved_text = self + .active_message_editor() + .map(|editor| editor.read(cx).get_text(cx).trim().to_string()); let thread = self .thread_store @@ -806,6 +853,7 @@ impl AgentPanel { self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), + Some(self.history_store.downgrade()), thread.clone(), window, cx, @@ -823,7 +871,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -862,6 +910,81 @@ impl AgentPanel { context_editor.focus_handle(cx).focus(window); } + fn new_external_thread( + &mut self, + agent_choice: Option, + window: &mut Window, + cx: &mut Context, + ) { + let workspace = self.workspace.clone(); + let project = self.project.clone(); + let message_history = self.acp_message_history.clone(); + + const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; + + #[derive(Default, Serialize, Deserialize)] + struct LastUsedExternalAgent { + agent: crate::ExternalAgent, + } + + cx.spawn_in(window, async move |this, cx| { + let server: Rc = match agent_choice { + Some(agent) => { + cx.background_spawn(async move { + if let Some(serialized) = + serde_json::to_string(&LastUsedExternalAgent { agent }).log_err() + { + KEY_VALUE_STORE + .write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized) + .await + .log_err(); + } + }) + .detach(); + + agent.server() + } + None => cx + .background_spawn(async move { + KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) + }) + .await + .log_err() + .flatten() + .and_then(|value| { + serde_json::from_str::(&value).log_err() + }) + .unwrap_or_default() + .agent + .server(), + }; + + this.update_in(cx, |this, window, cx| { + let thread_view = cx.new(|cx| { + crate::acp::AcpThreadView::new( + server, + workspace.clone(), + project, + message_history, + MIN_EDITOR_LINES, + Some(MAX_EDITOR_LINES), + window, + cx, + ) + }); + + this.set_active_view( + ActiveView::ExternalAgentThread { + thread_view: thread_view.clone(), + }, + window, + cx, + ); + }) + }) + .detach_and_log_err(cx); + } + fn deploy_rules_library( &mut self, action: &OpenRulesLibrary, @@ -994,6 +1117,7 @@ impl AgentPanel { cx, ) }); + let message_editor = cx.new(|cx| { MessageEditor::new( self.fs.clone(), @@ -1003,6 +1127,7 @@ impl AgentPanel { self.prompt_store.clone(), self.thread_store.downgrade(), self.context_store.downgrade(), + Some(self.history_store.downgrade()), thread.clone(), window, cx, @@ -1012,7 +1137,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context) { @@ -1025,6 +1150,9 @@ impl AgentPanel { ActiveView::Thread { message_editor, .. } => { message_editor.focus_handle(cx).focus(window); } + ActiveView::ExternalAgentThread { thread_view } => { + thread_view.focus_handle(cx).focus(window); + } ActiveView::TextThread { context_editor, .. } => { context_editor.focus_handle(cx).focus(window); } @@ -1052,7 +1180,7 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - self.assistant_dropdown_menu_handle.toggle(window, cx); + self.agent_panel_menu_handle.toggle(window, cx); } pub fn increase_font_size( @@ -1140,11 +1268,19 @@ impl AgentPanel { let thread = thread.read(cx).thread().clone(); self.workspace .update(cx, |workspace, cx| { - AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx) + AgentDiffPane::deploy_in_workspace( + AgentDiffThread::Native(thread), + workspace, + window, + cx, + ) }) .log_err(); } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => {} } } @@ -1197,6 +1333,13 @@ impl AgentPanel { ) .detach_and_log_err(cx); } + ActiveView::ExternalAgentThread { thread_view } => { + thread_view + .update(cx, |thread_view, cx| { + thread_view.open_thread_as_markdown(workspace, window, cx) + }) + .detach_and_log_err(cx); + } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} } } @@ -1224,6 +1367,19 @@ impl AgentPanel { } self.new_thread(&NewThread::default(), window, cx); + if let Some((thread, model)) = + self.active_thread(cx).zip(provider.default_model(cx)) + { + thread.update(cx, |thread, cx| { + thread.set_configured_model( + Some(ConfiguredModel { + provider: provider.clone(), + model, + }), + cx, + ); + }); + } } } } @@ -1351,7 +1507,8 @@ impl AgentPanel { } }) } - _ => {} + ActiveView::ExternalAgentThread { .. } => {} + ActiveView::History | ActiveView::Configuration => {} } if current_is_special && !new_is_special { @@ -1365,6 +1522,8 @@ impl AgentPanel { self.active_view = new_view; } + self.acp_message_history.borrow_mut().reset_position(); + self.focus_handle(cx).focus(window); } @@ -1437,6 +1596,7 @@ impl Focusable for AgentPanel { fn focus_handle(&self, cx: &App) -> FocusHandle { match &self.active_view { ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx), + ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), ActiveView::History => self.history.focus_handle(cx), ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), ActiveView::Configuration => { @@ -1526,7 +1686,7 @@ impl Panel for AgentPanel { } fn enabled(&self, cx: &App) -> bool { - AgentSettings::get_global(cx).enabled + DisableAiSettings::get_global(cx).disable_ai.not() && AgentSettings::get_global(cx).enabled } fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool { @@ -1593,6 +1753,11 @@ impl AgentPanel { .into_any_element(), } } + ActiveView::ExternalAgentThread { thread_view } => { + Label::new(thread_view.read(cx).title(cx)) + .truncate() + .into_any_element() + } ActiveView::TextThread { title_editor, context_editor, @@ -1727,10 +1892,112 @@ impl AgentPanel { let active_thread = match &self.active_view { ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => None, }; - let agent_extra_menu = PopoverMenu::new("agent-options-menu") + let new_thread_menu = PopoverMenu::new("new_thread_menu") + .trigger_with_tooltip( + IconButton::new("new_thread_menu_btn", IconName::Plus).icon_size(IconSize::Small), + Tooltip::text("New Thread…"), + ) + .anchor(Corner::TopRight) + .with_handle(self.new_thread_menu_handle.clone()) + .menu({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + let active_thread = active_thread.clone(); + Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { + menu = menu + .context(focus_handle.clone()) + .when(cx.has_flag::(), |this| { + this.header("Zed Agent") + }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::NewThread) + .icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::NewTextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action(NewTextThread.boxed_clone(), cx); + }), + ) + .when_some(active_thread, |this, active_thread| { + let thread = active_thread.read(cx); + + if !thread.is_empty() { + let thread_id = thread.id().clone(); + this.item( + ContextMenuEntry::new("New From Summary") + .icon(IconName::NewFromSummary) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + Box::new(NewThread { + from_thread_id: Some(thread_id.clone()), + }), + cx, + ); + }), + ) + } else { + this + } + }) + .when(cx.has_flag::(), |this| { + this.separator() + .header("External Agents") + .item( + ContextMenuEntry::new("New Gemini Thread") + .icon(IconName::AiGemini) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Claude Code Thread") + .icon(IconName::AiClaude) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::ClaudeCode, + ), + } + .boxed_clone(), + cx, + ); + }), + ) + }); + menu + })) + } + }); + + let agent_panel_menu = PopoverMenu::new("agent-options-menu") .trigger_with_tooltip( IconButton::new("agent-options-menu", IconName::Ellipsis) .icon_size(IconSize::Small), @@ -1748,41 +2015,9 @@ impl AgentPanel { }, ) .anchor(Corner::TopRight) - .with_handle(self.assistant_dropdown_menu_handle.clone()) + .with_handle(self.agent_panel_menu_handle.clone()) .menu(move |window, cx| { - let active_thread = active_thread.clone(); - Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { - menu = menu - .action("New Thread", NewThread::default().boxed_clone()) - .action("New Text Thread", NewTextThread.boxed_clone()) - .when_some(active_thread, |this, active_thread| { - let thread = active_thread.read(cx); - if !thread.is_empty() { - this.action( - "New From Summary", - Box::new(NewThread { - from_thread_id: Some(thread.id().clone()), - }), - ) - } else { - this - } - }) - .separator(); - - menu = menu - .header("MCP Servers") - .action( - "View Server Extensions", - Box::new(zed_actions::Extensions { - category_filter: Some( - zed_actions::ExtensionCategoryFilter::ContextServers, - ), - }), - ) - .action("Add Custom Server…", Box::new(AddContextServer)) - .separator(); - + Some(ContextMenu::build(window, cx, |mut menu, _window, _| { if let Some(usage) = usage { menu = menu .header_with_link("Prompt Usage", "Manage", account_url.clone()) @@ -1820,6 +2055,20 @@ impl AgentPanel { .separator() } + menu = menu + .header("MCP Servers") + .action( + "View Server Extensions", + Box::new(zed_actions::Extensions { + category_filter: Some( + zed_actions::ExtensionCategoryFilter::ContextServers, + ), + id: None, + }), + ) + .action("Add Custom Server…", Box::new(AddContextServer)) + .separator(); + menu = menu .action("Rules…", Box::new(OpenRulesLibrary::default())) .action("Settings", Box::new(OpenConfiguration)) @@ -1861,71 +2110,52 @@ impl AgentPanel { .px(DynamicSpacing::Base08.rems(cx)) .border_l_1() .border_color(cx.theme().colors().border) - .child( - IconButton::new("new", IconName::Plus) - .icon_size(IconSize::Small) - .style(ButtonStyle::Subtle) - .tooltip(move |window, cx| { - Tooltip::for_action_in( - "New Thread", - &NewThread::default(), - &focus_handle, - window, - cx, - ) - }) - .on_click(move |_event, window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ); - }), - ) - .child(agent_extra_menu), + .child(new_thread_menu) + .child(agent_panel_menu), ), ) } fn render_token_count(&self, cx: &App) -> Option { - let (active_thread, message_editor) = match &self.active_view { + match &self.active_view { ActiveView::Thread { thread, message_editor, .. - } => (thread.read(cx), message_editor.read(cx)), - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { - return None; - } - }; + } => { + let active_thread = thread.read(cx); + let message_editor = message_editor.read(cx); - let editor_empty = message_editor.is_editor_fully_empty(cx); + let editor_empty = message_editor.is_editor_fully_empty(cx); - if active_thread.is_empty() && editor_empty { - return None; - } + if active_thread.is_empty() && editor_empty { + return None; + } - let thread = active_thread.thread().read(cx); - let is_generating = thread.is_generating(); - let conversation_token_usage = thread.total_token_usage()?; + let thread = active_thread.thread().read(cx); + let is_generating = thread.is_generating(); + let conversation_token_usage = thread.total_token_usage()?; - let (total_token_usage, is_estimating) = - if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() { - let combined = thread - .token_usage_up_to_message(editing_message_id) - .add(unsent_tokens); + let (total_token_usage, is_estimating) = + if let Some((editing_message_id, unsent_tokens)) = + active_thread.editing_message_id() + { + let combined = thread + .token_usage_up_to_message(editing_message_id) + .add(unsent_tokens); - (combined, unsent_tokens > 0) - } else { - let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0); - let combined = conversation_token_usage.add(unsent_tokens); + (combined, unsent_tokens > 0) + } else { + let unsent_tokens = + message_editor.last_estimated_token_count().unwrap_or(0); + let combined = conversation_token_usage.add(unsent_tokens); - (combined, unsent_tokens > 0) - }; + (combined, unsent_tokens > 0) + }; - let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count(); + let is_waiting_to_update_token_count = + message_editor.is_waiting_to_update_token_count(); - match &self.active_view { - ActiveView::Thread { .. } => { if total_token_usage.total == 0 { return None; } @@ -2002,7 +2232,11 @@ impl AgentPanel { Some(element.into_any_element()) } - _ => None, + ActiveView::ExternalAgentThread { .. } + | ActiveView::History + | ActiveView::Configuration => { + return None; + } } } @@ -2011,188 +2245,91 @@ impl AgentPanel { return false; } + match &self.active_view { + ActiveView::Thread { thread, .. } => { + if thread + .read(cx) + .thread() + .read(cx) + .configured_model() + .map_or(false, |model| { + model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }) + { + return false; + } + } + ActiveView::TextThread { .. } => { + if LanguageModelRegistry::global(cx) + .read(cx) + .default_model() + .map_or(false, |model| { + model.provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }) + { + return false; + } + } + ActiveView::ExternalAgentThread { .. } + | ActiveView::History + | ActiveView::Configuration => return false, + } + let plan = self.user_store.read(cx).current_plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); matches!(plan, Some(Plan::Free)) && has_previous_trial } - fn should_render_upsell(&self, cx: &mut Context) -> bool { + fn should_render_onboarding(&self, cx: &mut Context) -> bool { + if OnboardingUpsell::dismissed() { + return false; + } + match &self.active_view { - ActiveView::Thread { thread, .. } => { - let is_using_zed_provider = thread - .read(cx) - .thread() - .read(cx) - .configured_model() - .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID); + ActiveView::Thread { .. } | ActiveView::TextThread { .. } => { + let history_is_empty = self + .history_store + .update(cx, |store, cx| store.recent_entries(1, cx).is_empty()); - if !is_using_zed_provider { - return false; - } + let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .any(|provider| { + provider.is_authenticated(cx) + && provider.id() != language_model::ZED_CLOUD_PROVIDER_ID + }); + + history_is_empty || !has_configured_non_zed_providers } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { - return false; - } - }; - - if self.hide_upsell || Upsell::dismissed() { - return false; + ActiveView::ExternalAgentThread { .. } + | ActiveView::History + | ActiveView::Configuration => false, } - - let plan = self.user_store.read(cx).current_plan(); - if matches!(plan, Some(Plan::ZedPro | Plan::ZedProTrial)) { - return false; - } - - let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - if has_previous_trial { - return false; - } - - true } - fn render_upsell( + fn render_onboarding( &self, _window: &mut Window, cx: &mut Context, ) -> Option { - if !self.should_render_upsell(cx) { + if !self.should_render_onboarding(cx) { return None; } - if self.user_store.read(cx).account_too_young() { - Some(self.render_young_account_upsell(cx).into_any_element()) - } else { - Some(self.render_trial_upsell(cx).into_any_element()) - } - } + let thread_view = matches!(&self.active_view, ActiveView::Thread { .. }); + let text_thread_view = matches!(&self.active_view, ActiveView::TextThread { .. }); - fn render_young_account_upsell(&self, cx: &mut Context) -> impl IntoElement { - let checkbox = CheckboxWithLabel::new( - "dont-show-again", - Label::new("Don't show again").color(Color::Muted), - ToggleState::Unselected, - move |toggle_state, _window, cx| { - let toggle_state_bool = toggle_state.selected(); - - Upsell::set_dismissed(toggle_state_bool, cx); - }, - ); - - let contents = div() - .size_full() - .gap_2() - .flex() - .flex_col() - .child(Headline::new("Build better with Zed Pro").size(HeadlineSize::Small)) - .child( - Label::new("Your GitHub account was created less than 30 days ago, so we can't offer you a free trial.") - .size(LabelSize::Small), - ) - .child( - Label::new( - "Use your own API keys, upgrade to Zed Pro or send an email to billing-support@zed.dev.", - ) - .color(Color::Muted), - ) - .child( - h_flex() - .w_full() - .px_neg_1() - .justify_between() - .items_center() - .child(h_flex().items_center().gap_1().child(checkbox)) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "Not Now") - .style(ButtonStyle::Transparent) - .color(Color::Muted) - .on_click({ - let agent_panel = cx.entity(); - move |_, _, cx| { - agent_panel.update(cx, |this, cx| { - this.hide_upsell = true; - cx.notify(); - }); - } - }), - ) - .child( - Button::new("cta-button", "Upgrade to Zed Pro") - .style(ButtonStyle::Transparent) - .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))), - ), - ), - ); - - self.render_upsell_container(cx, contents) - } - - fn render_trial_upsell(&self, cx: &mut Context) -> impl IntoElement { - let checkbox = CheckboxWithLabel::new( - "dont-show-again", - Label::new("Don't show again").color(Color::Muted), - ToggleState::Unselected, - move |toggle_state, _window, cx| { - let toggle_state_bool = toggle_state.selected(); - - Upsell::set_dismissed(toggle_state_bool, cx); - }, - ); - - let contents = div() - .size_full() - .gap_2() - .flex() - .flex_col() - .child(Headline::new("Build better with Zed Pro").size(HeadlineSize::Small)) - .child( - Label::new("Try Zed Pro for free for 14 days - no credit card required.") - .size(LabelSize::Small), - ) - .child( - Label::new( - "Use your own API keys or enable usage-based billing once you hit the cap.", - ) - .color(Color::Muted), - ) - .child( - h_flex() - .w_full() - .px_neg_1() - .justify_between() - .items_center() - .child(h_flex().items_center().gap_1().child(checkbox)) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "Not Now") - .style(ButtonStyle::Transparent) - .color(Color::Muted) - .on_click({ - let agent_panel = cx.entity(); - move |_, _, cx| { - agent_panel.update(cx, |this, cx| { - this.hide_upsell = true; - cx.notify(); - }); - } - }), - ) - .child( - Button::new("cta-button", "Start Trial") - .style(ButtonStyle::Transparent) - .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))), - ), - ), - ); - - self.render_upsell_container(cx, contents) + Some( + div() + .when(thread_view, |this| { + this.size_full().bg(cx.theme().colors().panel_background) + }) + .when(text_thread_view, |this| { + this.bg(cx.theme().colors().editor_background) + }) + .child(self.onboarding.clone()), + ) } fn render_trial_end_upsell( @@ -2204,141 +2341,37 @@ impl AgentPanel { return None; } - Some( - self.render_upsell_container( - cx, - div() - .size_full() - .gap_2() - .flex() - .flex_col() - .child( - Headline::new("Your Zed Pro trial has expired.").size(HeadlineSize::Small), - ) - .child( - Label::new("You've been automatically reset to the free plan.") - .size(LabelSize::Small), - ) - .child( - h_flex() - .w_full() - .px_neg_1() - .justify_between() - .items_center() - .child(div()) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "Stay on Free") - .style(ButtonStyle::Transparent) - .color(Color::Muted) - .on_click({ - let agent_panel = cx.entity(); - move |_, _, cx| { - agent_panel.update(cx, |_this, cx| { - TrialEndUpsell::set_dismissed(true, cx); - cx.notify(); - }); - } - }), - ) - .child( - Button::new("cta-button", "Upgrade to Zed Pro") - .style(ButtonStyle::Transparent) - .on_click(|_, _, cx| { - cx.open_url(&zed_urls::account_url(cx)) - }), - ), - ), - ), - ), - ) + Some(EndTrialUpsell::new(Arc::new({ + let this = cx.entity(); + move |_, cx| { + this.update(cx, |_this, cx| { + TrialEndUpsell::set_dismissed(true, cx); + cx.notify(); + }); + } + }))) } - fn render_upsell_container(&self, cx: &mut Context, content: Div) -> Div { - div().p_2().child( - v_flex() - .w_full() - .elevation_2(cx) - .rounded(px(8.)) - .bg(cx.theme().colors().background.alpha(0.5)) - .p(px(3.)) - .child( - div() - .gap_2() - .flex() - .flex_col() - .size_full() - .border_1() - .rounded(px(5.)) - .border_color(cx.theme().colors().text.alpha(0.1)) - .overflow_hidden() - .relative() - .bg(cx.theme().colors().panel_background) - .px_4() - .py_3() - .child( - div() - .absolute() - .top_0() - .right(px(-1.0)) - .w(px(441.)) - .h(px(167.)) - .child( - Vector::new( - VectorName::Grid, - rems_from_px(441.), - rems_from_px(167.), - ) - .color(ui::Color::Custom(cx.theme().colors().text.alpha(0.1))), - ), - ) - .child( - div() - .absolute() - .top(px(-8.0)) - .right_0() - .w(px(400.)) - .h(px(92.)) - .child( - Vector::new( - VectorName::AiGrid, - rems_from_px(400.), - rems_from_px(92.), - ) - .color(ui::Color::Custom(cx.theme().colors().text.alpha(0.32))), - ), - ) - // .child( - // div() - // .absolute() - // .top_0() - // .right(px(360.)) - // .size(px(401.)) - // .overflow_hidden() - // .bg(cx.theme().colors().panel_background) - // ) - .child( - div() - .absolute() - .top_0() - .right_0() - .w(px(660.)) - .h(px(401.)) - .overflow_hidden() - .bg(linear_gradient( - 75., - linear_color_stop( - cx.theme().colors().panel_background.alpha(0.01), - 1.0, - ), - linear_color_stop(cx.theme().colors().panel_background, 0.45), - )), - ) - .child(content), - ), - ) + fn render_empty_state_section_header( + &self, + label: impl Into, + action_slot: Option, + cx: &mut Context, + ) -> impl IntoElement { + h_flex() + .mt_2() + .pl_1p5() + .pb_1() + .w_full() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child( + Label::new(label.into()) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .children(action_slot) } fn render_thread_empty_state( @@ -2351,8 +2384,10 @@ impl AgentPanel { .update(cx, |this, cx| this.recent_entries(6, cx)); let model_registry = LanguageModelRegistry::read_global(cx); + let configuration_error = model_registry.configuration_error(model_registry.default_model(), cx); + let no_error = configuration_error.is_none(); let focus_handle = self.focus_handle(cx); @@ -2360,11 +2395,9 @@ impl AgentPanel { .size_full() .bg(cx.theme().colors().panel_background) .when(recent_history.is_empty(), |this| { - let configuration_error_ref = &configuration_error; this.child( v_flex() .size_full() - .max_w_80() .mx_auto() .justify_center() .items_center() @@ -2372,156 +2405,100 @@ impl AgentPanel { .child(h_flex().child(Headline::new("Welcome to the Agent Panel"))) .when(no_error, |parent| { parent + .child(h_flex().child( + Label::new("Ask and build anything.").color(Color::Muted), + )) .child( - h_flex().child( - Label::new("Ask and build anything.") - .color(Color::Muted) - .mb_2p5(), - ), - ) - .child( - Button::new("new-thread", "Start New Thread") - .icon(IconName::Plus) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &NewThread::default(), - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ) - }), - ) - .child( - Button::new("context", "Add Context") - .icon(IconName::FileCode) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &ToggleContextPicker, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - ToggleContextPicker.boxed_clone(), - cx, - ) - }), - ) - .child( - Button::new("mode", "Switch Model") - .icon(IconName::DatabaseZap) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &ToggleModelSelector, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - ToggleModelSelector.boxed_clone(), - cx, - ) - }), - ) - .child( - Button::new("settings", "View Settings") - .icon(IconName::Settings) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, - ) - }), + v_flex() + .mt_2() + .gap_1() + .max_w_48() + .child( + Button::new("context", "Add Context") + .label_size(LabelSize::Small) + .icon(IconName::FileCode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &ToggleContextPicker, + &focus_handle, + window, + cx, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + ToggleContextPicker.boxed_clone(), + cx, + ) + }), + ) + .child( + Button::new("mode", "Switch Model") + .label_size(LabelSize::Small) + .icon(IconName::DatabaseZap) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &ToggleModelSelector, + &focus_handle, + window, + cx, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + ToggleModelSelector.boxed_clone(), + cx, + ) + }), + ) + .child( + Button::new("settings", "View Settings") + .label_size(LabelSize::Small) + .icon(IconName::Settings) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .full_width() + .key_binding(KeyBinding::for_action_in( + &OpenConfiguration, + &focus_handle, + window, + cx, + )) + .on_click(|_event, window, cx| { + window.dispatch_action( + OpenConfiguration.boxed_clone(), + cx, + ) + }), + ), ) }) - .map(|parent| match configuration_error_ref { - Some( - err @ (ConfigurationError::ModelNotFound - | ConfigurationError::ProviderNotAuthenticated(_) - | ConfigurationError::NoProvider), - ) => parent - .child(h_flex().child( - Label::new(err.to_string()).color(Color::Muted).mb_2p5(), - )) - .child( - Button::new("settings", "Configure a Provider") - .icon(IconName::Settings) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .full_width() - .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - )) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, - ) - }), - ), - Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { - parent.children(provider.render_accept_terms( - LanguageModelProviderTosView::ThreadFreshStart, - cx, - )) - } - None => parent, + .when_some(configuration_error.as_ref(), |this, err| { + this.child(self.render_configuration_error( + err, + &focus_handle, + window, + cx, + )) }), ) }) .when(!recent_history.is_empty(), |parent| { let focus_handle = focus_handle.clone(); - let configuration_error_ref = &configuration_error; - parent .overflow_hidden() .p_1p5() .justify_end() .gap_1() .child( - h_flex() - .pl_1p5() - .pb_1() - .w_full() - .justify_between() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .child( - Label::new("Recent") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( + self.render_empty_state_section_header( + "Recent", + Some( Button::new("view-history", "View All") .style(ButtonStyle::Subtle) .label_size(LabelSize::Small) @@ -2536,8 +2513,11 @@ impl AgentPanel { ) .on_click(move |_event, window, cx| { window.dispatch_action(OpenHistory.boxed_clone(), cx); - }), + }) + .into_any_element(), ), + cx, + ), ) .child( v_flex() @@ -2565,49 +2545,162 @@ impl AgentPanel { }, )), ) - .map(|parent| match configuration_error_ref { - Some( - err @ (ConfigurationError::ModelNotFound - | ConfigurationError::ProviderNotAuthenticated(_) - | ConfigurationError::NoProvider), - ) => parent.child( - Banner::new() - .severity(ui::Severity::Warning) - .child(Label::new(err.to_string()).size(LabelSize::Small)) - .action_slot( - Button::new("settings", "Configure Provider") - .style(ButtonStyle::Tinted(ui::TintColor::Warning)) - .label_size(LabelSize::Small) - .key_binding( - KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - ) - .map(|kb| kb.size(rems_from_px(12.))), + .child(self.render_empty_state_section_header("Start", None, cx)) + .child( + v_flex() + .p_1() + .gap_2() + .child( + h_flex() + .w_full() + .gap_2() + .child( + NewThreadButton::new( + "new-thread-btn", + "New Thread", + IconName::NewThread, ) - .on_click(|_event, window, cx| { - window.dispatch_action( - OpenConfiguration.boxed_clone(), - cx, + .keybinding(KeyBinding::for_action_in( + &NewThread::default(), + &self.focus_handle(cx), + window, + cx, + )) + .on_click( + |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ) + }, + ), + ) + .child( + NewThreadButton::new( + "new-text-thread-btn", + "New Text Thread", + IconName::NewTextThread, + ) + .keybinding(KeyBinding::for_action_in( + &NewTextThread, + &self.focus_handle(cx), + window, + cx, + )) + .on_click( + |window, cx| { + window.dispatch_action(Box::new(NewTextThread), cx) + }, + ), + ), + ) + .when(cx.has_flag::(), |this| { + this.child( + h_flex() + .w_full() + .gap_2() + .child( + NewThreadButton::new( + "new-gemini-thread-btn", + "New Gemini Thread", + IconName::AiGemini, ) - }), - ), - ), - Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => { - parent.child(Banner::new().severity(ui::Severity::Warning).child( - h_flex().w_full().children(provider.render_accept_terms( - LanguageModelProviderTosView::ThreadEmptyState, - cx, - )), - )) - } - None => parent, + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::Gemini, + ), + }), + cx, + ) + }, + ), + ) + .child( + NewThreadButton::new( + "new-claude-thread-btn", + "New Claude Code Thread", + IconName::AiClaude, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::ClaudeCode, + ), + }), + cx, + ) + }, + ), + ), + ) + }), + ) + .when_some(configuration_error.as_ref(), |this, err| { + this.child(self.render_configuration_error(err, &focus_handle, window, cx)) }) }) } + fn render_configuration_error( + &self, + configuration_error: &ConfigurationError, + focus_handle: &FocusHandle, + window: &mut Window, + cx: &mut App, + ) -> impl IntoElement { + match configuration_error { + ConfigurationError::ModelNotFound + | ConfigurationError::ProviderNotAuthenticated(_) + | ConfigurationError::NoProvider => Banner::new() + .severity(ui::Severity::Warning) + .child(Label::new(configuration_error.to_string())) + .action_slot( + Button::new("settings", "Configure Provider") + .style(ButtonStyle::Tinted(ui::TintColor::Warning)) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &OpenConfiguration, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_event, window, cx| { + window.dispatch_action(OpenConfiguration.boxed_clone(), cx) + }), + ), + ConfigurationError::ProviderPendingTermsAcceptance(provider) => { + Banner::new().severity(ui::Severity::Warning).child( + h_flex().w_full().children( + provider.render_accept_terms( + LanguageModelProviderTosView::ThreadEmptyState, + cx, + ), + ), + ) + } + } + } + fn render_tool_use_limit_reached( &self, window: &mut Window, @@ -2615,6 +2708,9 @@ impl AgentPanel { ) -> Option { let active_thread = match &self.active_view { ActiveView::Thread { thread, .. } => thread, + ActiveView::ExternalAgentThread { .. } => { + return None; + } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { return None; } @@ -2737,7 +2833,7 @@ impl AgentPanel { this.clear_last_error(); }); - cx.open_url(&zed_urls::account_url(cx)); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)); cx.notify(); } })) @@ -2819,6 +2915,23 @@ impl AgentPanel { .size(IconSize::Small) .color(Color::Error); + let retry_button = Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.retry_last_completion(Some(window.window_handle()), cx); + }); + }); + } + }); + div() .border_t_1() .border_color(cx.theme().colors().border) @@ -2827,13 +2940,76 @@ impl AgentPanel { .icon(icon) .title(header) .description(message.clone()) - .primary_action(self.dismiss_error_button(thread, cx)) - .secondary_action(self.create_copy_button(message_with_header)) + .primary_action(retry_button) + .secondary_action(self.dismiss_error_button(thread, cx)) + .tertiary_action(self.create_copy_button(message_with_header)) .bg_color(self.error_callout_bg(cx)), ) .into_any_element() } + fn render_retryable_error( + &self, + message: SharedString, + can_enable_burn_mode: bool, + thread: &Entity, + cx: &mut Context, + ) -> AnyElement { + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + let retry_button = Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.retry_last_completion(Some(window.window_handle()), cx); + }); + }); + } + }); + + let mut callout = Callout::new() + .icon(icon) + .title("Error") + .description(message.clone()) + .bg_color(self.error_callout_bg(cx)) + .primary_action(retry_button); + + if can_enable_burn_mode { + let burn_mode_button = Button::new("enable_burn_retry", "Enable Burn Mode and Retry") + .icon(IconName::ZedBurnMode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click({ + let thread = thread.clone(); + move |_, window, cx| { + thread.update(cx, |thread, cx| { + thread.clear_last_error(); + thread.thread().update(cx, |thread, cx| { + thread.enable_burn_mode_and_retry(Some(window.window_handle()), cx); + }); + }); + } + }); + callout = callout.secondary_action(burn_mode_button); + } + + div() + .border_t_1() + .border_color(cx.theme().colors().border) + .child(callout) + .into_any_element() + } + fn render_prompt_editor( &self, context_editor: &Entity, @@ -2961,6 +3137,9 @@ impl AgentPanel { .detach(); }); } + ActiveView::ExternalAgentThread { .. } => { + unimplemented!() + } ActiveView::TextThread { context_editor, .. } => { context_editor.update(cx, |context_editor, cx| { TextThreadEditor::insert_dragged_files( @@ -2979,8 +3158,10 @@ impl AgentPanel { fn key_context(&self) -> KeyContext { let mut key_context = KeyContext::new_with_defaults(); key_context.add("AgentPanel"); - if matches!(self.active_view, ActiveView::TextThread { .. }) { - key_context.add("prompt_editor"); + match &self.active_view { + ActiveView::ExternalAgentThread { .. } => key_context.add("external_agent_thread"), + ActiveView::TextThread { .. } => key_context.add("prompt_editor"), + ActiveView::Thread { .. } | ActiveView::History | ActiveView::Configuration => {} } key_context } @@ -3034,6 +3215,7 @@ impl Render for AgentPanel { }); this.continue_conversation(window, cx); } + ActiveView::ExternalAgentThread { .. } => {} ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -3041,7 +3223,7 @@ impl Render for AgentPanel { })) .on_action(cx.listener(Self::toggle_burn_mode)) .child(self.render_toolbar(window, cx)) - .children(self.render_upsell(window, cx)) + .children(self.render_onboarding(window, cx)) .children(self.render_trial_end_upsell(window, cx)) .map(|parent| match &self.active_view { ActiveView::Thread { @@ -3050,12 +3232,14 @@ impl Render for AgentPanel { .. } => parent .relative() - .child(if thread.read(cx).is_empty() { - self.render_thread_empty_state(window, cx) - .into_any_element() - } else { - thread.clone().into_any_element() - }) + .child( + if thread.read(cx).is_empty() && !self.should_render_onboarding(cx) { + self.render_thread_empty_state(window, cx) + .into_any_element() + } else { + thread.clone().into_any_element() + }, + ) .children(self.render_tool_use_limit_reached(window, cx)) .when_some(thread.read(cx).last_error(), |this, last_error| { this.child( @@ -3069,23 +3253,73 @@ impl Render for AgentPanel { ThreadError::Message { header, message } => { self.render_error_message(header, message, thread, cx) } + ThreadError::RetryableError { + message, + can_enable_burn_mode, + } => self.render_retryable_error( + message, + can_enable_burn_mode, + thread, + cx, + ), }) .into_any(), ) }) - .child(h_flex().child(message_editor.clone())) + .child(h_flex().relative().child(message_editor.clone()).when( + !LanguageModelRegistry::read_global(cx).has_authenticated_provider(cx), + |this| { + this.child( + div() + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll(), + ) + }, + )) + .child(self.render_drag_target(cx)), + ActiveView::ExternalAgentThread { thread_view, .. } => parent + .relative() + .child(thread_view.clone()) .child(self.render_drag_target(cx)), ActiveView::History => parent.child(self.history.clone()), ActiveView::TextThread { context_editor, buffer_search_bar, .. - } => parent.child(self.render_prompt_editor( - context_editor, - buffer_search_bar, - window, - cx, - )), + } => { + let model_registry = LanguageModelRegistry::read_global(cx); + let configuration_error = + model_registry.configuration_error(model_registry.default_model(), cx); + parent + .map(|this| { + if !self.should_render_onboarding(cx) + && let Some(err) = configuration_error.as_ref() + { + this.child( + div().bg(cx.theme().colors().editor_background).p_2().child( + self.render_configuration_error( + err, + &self.focus_handle(cx), + window, + cx, + ), + ), + ) + } else { + this + } + }) + .child(self.render_prompt_editor( + context_editor, + buffer_search_bar, + window, + cx, + )) + } ActiveView::Configuration => parent.children(self.configuration.clone()), }); @@ -3254,9 +3488,9 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { } } -struct Upsell; +struct OnboardingUpsell; -impl Dismissable for Upsell { +impl Dismissable for OnboardingUpsell { const KEY: &'static str = "dismissed-trial-upsell"; } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index e488cf5a1e..cac0f1adac 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -1,3 +1,4 @@ +mod acp; mod active_thread; mod agent_configuration; mod agent_diff; @@ -24,12 +25,14 @@ mod thread_history; mod tool_compatibility; mod ui; +use std::rc::Rc; use std::sync::Arc; use agent::{Thread, ThreadId}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; -use client::Client; +use client::{Client, DisableAiSettings}; +use command_palette_hooks::CommandPaletteFilter; use feature_flags::FeatureFlagAppExt as _; use fs::Fs; use gpui::{Action, App, Entity, actions}; @@ -39,8 +42,9 @@ use language_model::{ }; use prompt_store::PromptBuilder; use schemars::JsonSchema; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; +use std::any::TypeId; pub use crate::active_thread::ActiveThread; use crate::agent_configuration::{ConfigureContextServerModal, ManageProfilesModal}; @@ -50,6 +54,7 @@ use crate::slash_command_settings::SlashCommandSettings; pub use agent_diff::{AgentDiffPane, AgentDiffToolbar}; pub use text_thread_editor::{AgentPanelDelegate, TextThreadEditor}; pub use ui::preview::{all_agent_previews, get_agent_preview}; +use zed_actions; actions!( agent, @@ -76,8 +81,6 @@ actions!( AddContextServer, /// Removes the currently selected thread. RemoveSelectedThread, - /// Starts a chat conversation with the agent. - Chat, /// Starts a chat conversation with follow-up enabled. ChatWithFollow, /// Cycles to the next inline assist suggestion. @@ -132,6 +135,32 @@ pub struct NewThread { from_thread_id: Option, } +/// Creates a new external agent conversation thread. +#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = agent)] +#[serde(deny_unknown_fields)] +pub struct NewExternalAgentThread { + /// Which agent to use for the conversation. + agent: Option, +} + +#[derive(Default, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +enum ExternalAgent { + #[default] + Gemini, + ClaudeCode, +} + +impl ExternalAgent { + pub fn server(&self) -> Rc { + match self { + ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), + ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + } + } +} + /// Opens the profile management interface for configuring agent tools and settings. #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] @@ -215,6 +244,66 @@ pub fn init( }) .detach(); cx.observe_new(ManageProfilesModal::register).detach(); + + // Update command palette filter based on AI settings + update_command_palette_filter(cx); + + // Watch for settings changes + cx.observe_global::(|app_cx| { + // When settings change, update the command palette filter + update_command_palette_filter(app_cx); + }) + .detach(); +} + +fn update_command_palette_filter(cx: &mut App) { + let disable_ai = DisableAiSettings::get_global(cx).disable_ai; + CommandPaletteFilter::update_global(cx, |filter, _| { + if disable_ai { + filter.hide_namespace("agent"); + filter.hide_namespace("assistant"); + filter.hide_namespace("zed_predict_onboarding"); + filter.hide_namespace("edit_prediction"); + + use editor::actions::{ + AcceptEditPrediction, AcceptPartialEditPrediction, NextEditPrediction, + PreviousEditPrediction, ShowEditPrediction, ToggleEditPrediction, + }; + let edit_prediction_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + filter.hide_action_types(&edit_prediction_actions); + filter.hide_action_types(&[TypeId::of::()]); + } else { + filter.show_namespace("agent"); + filter.show_namespace("assistant"); + filter.show_namespace("zed_predict_onboarding"); + + filter.show_namespace("edit_prediction"); + + use editor::actions::{ + AcceptEditPrediction, AcceptPartialEditPrediction, NextEditPrediction, + PreviousEditPrediction, ShowEditPrediction, ToggleEditPrediction, + }; + let edit_prediction_actions = [ + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + filter.show_action_types(edit_prediction_actions.iter()); + + filter + .show_action_types([TypeId::of::()].iter()); + } + }); } fn init_language_model_settings(cx: &mut App) { diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 117dcf4f8e..64498e9281 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -475,6 +475,7 @@ impl CodegenAlternative { stop: Vec::new(), temperature, messages: vec![request_message], + thinking_allowed: false, } })) } diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 73fc0b36ce..5cc56b014e 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -1,6 +1,6 @@ mod completion_provider; mod fetch_context_picker; -mod file_context_picker; +pub(crate) mod file_context_picker; mod rules_context_picker; mod symbol_context_picker; mod thread_context_picker; diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index c9c173a68b..44ec050ae2 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -16,7 +16,7 @@ use agent::{ }; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; -use client::telemetry::Telemetry; +use client::{DisableAiSettings, telemetry::Telemetry}; use collections::{HashMap, HashSet, VecDeque, hash_map}; use editor::SelectionEffects; use editor::{ @@ -57,6 +57,17 @@ pub fn init( cx: &mut App, ) { cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry)); + + cx.observe_global::(|cx| { + if DisableAiSettings::get_global(cx).disable_ai { + // Hide any active inline assist UI when AI is disabled + InlineAssistant::update_global(cx, |assistant, cx| { + assistant.cancel_all_active_completions(cx); + }); + } + }) + .detach(); + cx.observe_new(|_workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; @@ -141,6 +152,26 @@ impl InlineAssistant { .detach(); } + /// Hides all active inline assists when AI is disabled + pub fn cancel_all_active_completions(&mut self, cx: &mut App) { + // Cancel all active completions in editors + for (editor_handle, _) in self.assists_by_editor.iter() { + if let Some(editor) = editor_handle.upgrade() { + let windows = cx.windows(); + if !windows.is_empty() { + let window = windows[0]; + let _ = window.update(cx, |_, window, cx| { + editor.update(cx, |editor, cx| { + if editor.has_active_inline_completion() { + editor.cancel(&Default::default(), window, cx); + } + }); + }); + } + } + } + } + fn handle_workspace_event( &mut self, workspace: Entity, @@ -176,7 +207,7 @@ impl InlineAssistant { window: &mut Window, cx: &mut App, ) { - let is_assistant2_enabled = true; + let is_assistant2_enabled = !DisableAiSettings::get_global(cx).disable_ai; if let Some(editor) = item.act_as::(cx) { editor.update(cx, |editor, cx| { @@ -199,6 +230,13 @@ impl InlineAssistant { cx, ); + if DisableAiSettings::get_global(cx).disable_ai { + // Cancel any active completions + if editor.has_active_inline_completion() { + editor.cancel(&Default::default(), window, cx); + } + } + // Remove the Assistant1 code action provider, as it still might be registered. editor.remove_code_action_provider("assistant".into(), window, cx); } else { @@ -219,7 +257,7 @@ impl InlineAssistant { cx: &mut Context, ) { let settings = AgentSettings::get_global(cx); - if !settings.enabled { + if !settings.enabled || DisableAiSettings::get_global(cx).disable_ai { return; } @@ -660,7 +698,6 @@ impl InlineAssistant { height: Some(prompt_editor_height), render: build_assist_editor_renderer(prompt_editor), priority: 0, - render_in_minimap: false, }, BlockProperties { style: BlockStyle::Sticky, @@ -675,7 +712,6 @@ impl InlineAssistant { .into_any_element() }), priority: 0, - render_in_minimap: false, }, ]; @@ -1451,7 +1487,6 @@ impl InlineAssistant { .into_any_element() }), priority: 0, - render_in_minimap: false, }); } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index 7a61eef748..ade7a5e13d 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -2,7 +2,6 @@ use crate::agent_model_selector::AgentModelSelector; use crate::buffer_codegen::BufferCodegen; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; -use crate::language_model_selector::ToggleModelSelector; use crate::message_editor::{ContextCreasesAddon, extract_message_creases, insert_message_creases}; use crate::terminal_codegen::TerminalCodegen; use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext}; @@ -38,6 +37,7 @@ use ui::{ CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, prelude::*, }; use workspace::Workspace; +use zed_actions::agent::ToggleModelSelector; pub struct PromptEditor { pub editor: Entity, diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index ff18a95f3f..655e87d7cd 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -3,9 +3,7 @@ use std::{cmp::Reverse, sync::Arc}; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; -use gpui::{ - Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task, actions, -}; +use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task}; use language_model::{ AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, @@ -15,15 +13,6 @@ use picker::{Picker, PickerDelegate}; use proto::Plan; use ui::{ListItem, ListItemSpacing, prelude::*}; -actions!( - agent, - [ - /// Toggles the language model selector dropdown. - #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])] - ToggleModelSelector - ] -); - const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; type OnModelChanged = Arc, &mut App) + 'static>; diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 38065b828a..c160f1de04 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -2,18 +2,20 @@ use std::collections::BTreeMap; use std::rc::Rc; use std::sync::Arc; +use crate::agent_diff::AgentDiffThread; use crate::agent_model_selector::AgentModelSelector; -use crate::language_model_selector::ToggleModelSelector; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use crate::ui::{ MaxModeTooltip, preview::{AgentPreview, UsageCallout}, }; +use agent::history_store::HistoryStore; use agent::{ context::{AgentContextKey, ContextLoadResult, load_context}, context_store::ContextStoreEvent, }; use agent_settings::{AgentSettings, CompletionMode}; +use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; use client::UserStore; use collections::{HashMap, HashSet}; @@ -28,12 +30,14 @@ use fs::Fs; use futures::future::Shared; use futures::{FutureExt as _, future}; use gpui::{ - Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, - WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, + Animation, AnimationExt, App, Entity, EventEmitter, Focusable, IntoElement, KeyContext, + Subscription, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, + pulsating_between, }; use language::{Buffer, Language, Point}; use language_model::{ - ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, + ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent, + ZED_CLOUD_PROVIDER_ID, }; use multi_buffer; use project::Project; @@ -47,13 +51,15 @@ use ui::{ }; use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; +use zed_actions::agent::Chat; +use zed_actions::agent::ToggleModelSelector; use zed_llm_client::CompletionIntent; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::profile_selector::ProfileSelector; use crate::{ - ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll, + ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll, ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode, ToggleContextPicker, ToggleProfileSelector, register_agent_preview, }; @@ -63,6 +69,9 @@ use agent::{ thread_store::{TextThreadStore, ThreadStore}, }; +pub const MIN_EDITOR_LINES: usize = 4; +pub const MAX_EDITOR_LINES: usize = 8; + #[derive(RegisterComponent)] pub struct MessageEditor { thread: Entity, @@ -73,6 +82,7 @@ pub struct MessageEditor { user_store: Entity, context_store: Entity, prompt_store: Option>, + history_store: Option>, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, @@ -86,9 +96,6 @@ pub struct MessageEditor { _subscriptions: Vec, } -const MIN_EDITOR_LINES: usize = 4; -const MAX_EDITOR_LINES: usize = 8; - pub(crate) fn create_editor( workspace: WeakEntity, context_store: WeakEntity, @@ -130,6 +137,7 @@ pub(crate) fn create_editor( placement: Some(ContextMenuPlacement::Above), }); editor.register_addon(ContextCreasesAddon::new()); + editor.register_addon(MessageEditorAddon::new()); editor }); @@ -156,6 +164,7 @@ impl MessageEditor { prompt_store: Option>, thread_store: WeakEntity, text_thread_store: WeakEntity, + history_store: Option>, thread: Entity, window: &mut Window, cx: &mut Context, @@ -228,6 +237,7 @@ impl MessageEditor { workspace, context_store, prompt_store, + history_store, context_strip, context_picker_menu_handle, load_context_task: None, @@ -474,9 +484,12 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { - if let Ok(diff) = - AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx) - { + if let Ok(diff) = AgentDiffPane::deploy( + AgentDiffThread::Native(self.thread.clone()), + self.workspace.clone(), + window, + cx, + ) { let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx); diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx)); } @@ -604,7 +617,11 @@ impl MessageEditor { ) } - fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { + fn render_follow_toggle( + &self, + is_model_selected: bool, + cx: &mut Context, + ) -> impl IntoElement { let following = self .workspace .read_with(cx, |workspace, _| { @@ -613,6 +630,7 @@ impl MessageEditor { .unwrap_or(false); IconButton::new("follow-agent", IconName::Crosshair) + .disabled(!is_model_selected) .icon_size(IconSize::Small) .icon_color(Color::Muted) .toggle_state(following) @@ -700,11 +718,11 @@ impl MessageEditor { cx.listener(|this, _: &RejectAll, window, cx| this.handle_reject_all(window, cx)), ) .capture_action(cx.listener(Self::paste)) - .gap_2() .p_2() - .bg(editor_bg_color) + .gap_2() .border_t_1() .border_color(cx.theme().colors().border) + .bg(editor_bg_color) .child( h_flex() .justify_between() @@ -781,7 +799,7 @@ impl MessageEditor { .justify_between() .child( h_flex() - .child(self.render_follow_toggle(cx)) + .child(self.render_follow_toggle(is_model_selected, cx)) .children(self.render_burn_mode_toggle(cx)), ) .child( @@ -897,6 +915,10 @@ impl MessageEditor { .on_click({ let focus_handle = focus_handle.clone(); move |_event, window, cx| { + telemetry::event!( + "Agent Message Sent", + agent = "zed", + ); focus_handle.dispatch_action( &Chat, window, cx, ); @@ -1453,6 +1475,7 @@ impl MessageEditor { tool_choice: None, stop: vec![], temperature: AgentSettings::temperature_for_model(&model.model, cx), + thinking_allowed: true, }; Some(model.model.count_tokens(request, cx)) @@ -1483,6 +1506,31 @@ pub struct ContextCreasesAddon { _subscription: Option, } +pub struct MessageEditorAddon {} + +impl MessageEditorAddon { + pub fn new() -> Self { + Self {} + } +} + +impl Addon for MessageEditorAddon { + fn to_any(&self) -> &dyn std::any::Any { + self + } + + fn to_any_mut(&mut self) -> Option<&mut dyn std::any::Any> { + Some(self) + } + + fn extend_key_context(&self, key_context: &mut KeyContext, cx: &App) { + let settings = agent_settings::AgentSettings::get_global(cx); + if settings.use_modifier_to_send { + key_context.add("use_modifier_to_send"); + } + } +} + impl Addon for ContextCreasesAddon { fn to_any(&self) -> &dyn std::any::Any { self @@ -1618,8 +1666,38 @@ impl Render for MessageEditor { let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5; + let has_configured_providers = LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .count() + > 0; + + let is_signed_out = self + .workspace + .read_with(cx, |workspace, _| { + workspace.client().status().borrow().is_signed_out() + }) + .unwrap_or(true); + + let has_history = self + .history_store + .as_ref() + .and_then(|hs| hs.update(cx, |hs, cx| hs.entries(cx).len() > 0).ok()) + .unwrap_or(false) + || self + .thread + .read_with(cx, |thread, _| thread.messages().len() > 0); + v_flex() .size_full() + .bg(cx.theme().colors().panel_background) + .when( + !has_history && is_signed_out && has_configured_providers, + |this| this.child(cx.new(ApiKeysWithProviders::new)), + ) .when(changed_buffers.len() > 0, |parent| { parent.child(self.render_edits_bar(&changed_buffers, window, cx)) }) @@ -1709,6 +1787,7 @@ impl AgentPreview for MessageEditor { None, thread_store.downgrade(), text_thread_store.downgrade(), + None, thread, window, cx, diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index 162b45413f..91867957cd 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -297,6 +297,7 @@ impl TerminalInlineAssistant { tool_choice: None, stop: Vec::new(), temperature, + thinking_allowed: false, } })) } diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index de7606dbfb..3df0a48aa4 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1,8 +1,6 @@ use crate::{ burn_mode_tooltip::BurnModeTooltip, - language_model_selector::{ - LanguageModelSelector, ToggleModelSelector, language_model_selector, - }, + language_model_selector::{LanguageModelSelector, language_model_selector}, }; use agent_settings::{AgentSettings, CompletionMode}; use anyhow::Result; @@ -38,8 +36,7 @@ use language::{ language_settings::{SoftWrap, all_language_settings}, }; use language_model::{ - ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelProviderTosView, - LanguageModelRegistry, Role, + ConfigurationError, LanguageModelExt, LanguageModelImage, LanguageModelRegistry, Role, }; use multi_buffer::MultiBufferRow; use picker::{Picker, popover_menu::PickerPopoverMenu}; @@ -74,6 +71,7 @@ use workspace::{ pane, searchable::{SearchEvent, SearchableItem}, }; +use zed_actions::agent::ToggleModelSelector; use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}; use assistant_context::{ @@ -1256,7 +1254,6 @@ impl TextThreadEditor { ), priority: usize::MAX, render: render_block(MessageMetadata::from(message)), - render_in_minimap: false, }; let mut new_blocks = vec![]; let mut block_index_to_message = vec![]; @@ -1858,7 +1855,6 @@ impl TextThreadEditor { .into_any_element() }), priority: 0, - render_in_minimap: false, }) }) .collect::>(); @@ -1897,108 +1893,6 @@ impl TextThreadEditor { .update(cx, |context, cx| context.summarize(true, cx)); } - fn render_notice(&self, cx: &mut Context) -> Option { - // This was previously gated behind the `zed-pro` feature flag. Since we - // aren't planning to ship that right now, we're just hard-coding this - // value to not show the nudge. - let nudge = Some(false); - - let model_registry = LanguageModelRegistry::read_global(cx); - - if nudge.map_or(false, |value| value) { - Some( - h_flex() - .p_3() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .bg(cx.theme().colors().editor_background) - .justify_between() - .child( - h_flex() - .gap_3() - .child(Icon::new(IconName::ZedAssistant).color(Color::Accent)) - .child(Label::new("Zed AI is here! Get started by signing in β†’")), - ) - .child( - Button::new("sign-in", "Sign in") - .size(ButtonSize::Compact) - .style(ButtonStyle::Filled) - .on_click(cx.listener(|this, _event, _window, cx| { - let client = this - .workspace - .read_with(cx, |workspace, _| workspace.client().clone()) - .log_err(); - - if let Some(client) = client { - cx.spawn(async move |context_editor, cx| { - match client.authenticate_and_connect(true, cx).await { - util::ConnectionResult::Timeout => { - log::error!("Authentication timeout") - } - util::ConnectionResult::ConnectionReset => { - log::error!("Connection reset") - } - util::ConnectionResult::Result(r) => { - if r.log_err().is_some() { - context_editor - .update(cx, |_, cx| cx.notify()) - .ok(); - } - } - } - }) - .detach() - } - })), - ) - .into_any_element(), - ) - } else if let Some(configuration_error) = - model_registry.configuration_error(model_registry.default_model(), cx) - { - Some( - h_flex() - .px_3() - .py_2() - .border_b_1() - .border_color(cx.theme().colors().border_variant) - .bg(cx.theme().colors().editor_background) - .justify_between() - .child( - h_flex() - .gap_3() - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .child(Label::new(configuration_error.to_string())), - ) - .child( - Button::new("open-configuration", "Configure Providers") - .size(ButtonSize::Compact) - .icon(Some(IconName::SlidersVertical)) - .icon_size(IconSize::Small) - .icon_position(IconPosition::Start) - .style(ButtonStyle::Filled) - .on_click({ - let focus_handle = self.focus_handle(cx).clone(); - move |_event, window, cx| { - focus_handle.dispatch_action( - &zed_actions::agent::OpenConfiguration, - window, - cx, - ); - } - }), - ) - .into_any_element(), - ) - } else { - None - } - } - fn render_send_button(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let focus_handle = self.focus_handle(cx).clone(); @@ -2130,12 +2024,13 @@ impl TextThreadEditor { .map(|default| default.model); let model_name = match active_model { Some(model) => model.name().0, - None => SharedString::from("No model selected"), + None => SharedString::from("Select Model"), }; let active_provider = LanguageModelRegistry::read_global(cx) .default_model() .map(|default| default.provider); + let provider_icon = match active_provider { Some(provider) => provider.icon(), None => IconName::Ai, @@ -2583,20 +2478,7 @@ impl EventEmitter for TextThreadEditor {} impl Render for TextThreadEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let provider = LanguageModelRegistry::read_global(cx) - .default_model() - .map(|default| default.provider); - - let accept_terms = if self.show_accept_terms { - provider.as_ref().and_then(|provider| { - provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx) - }) - } else { - None - }; - let language_model_selector = self.language_model_selector_menu_handle.clone(); - let burn_mode_toggle = self.render_burn_mode_toggle(cx); v_flex() .key_context("ContextEditor") @@ -2613,28 +2495,12 @@ impl Render for TextThreadEditor { language_model_selector.toggle(window, cx); }) .size_full() - .children(self.render_notice(cx)) .child( div() .flex_grow() .bg(cx.theme().colors().editor_background) .child(self.editor.clone()), ) - .when_some(accept_terms, |this, element| { - this.child( - div() - .absolute() - .right_3() - .bottom_12() - .max_w_96() - .py_2() - .px_3() - .elevation_2(cx) - .bg(cx.theme().colors().surface_background) - .occlude() - .child(element), - ) - }) .children(self.render_last_error(cx)) .child( h_flex() @@ -2651,7 +2517,7 @@ impl Render for TextThreadEditor { h_flex() .gap_0p5() .child(self.render_inject_context_menu(cx)) - .when_some(burn_mode_toggle, |this, element| this.child(element)), + .children(self.render_burn_mode_toggle(cx)), ) .child( h_flex() diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index 43cd0f5e89..b477a8c385 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -1,11 +1,14 @@ mod agent_notification; mod burn_mode_tooltip; mod context_pill; +mod end_trial_upsell; +mod new_thread_button; mod onboarding_modal; pub mod preview; -mod upsell; pub use agent_notification::*; pub use burn_mode_tooltip::*; pub use context_pill::*; +pub use end_trial_upsell::*; +pub use new_thread_button::*; pub use onboarding_modal::*; diff --git a/crates/agent_ui/src/ui/end_trial_upsell.rs b/crates/agent_ui/src/ui/end_trial_upsell.rs new file mode 100644 index 0000000000..36770c2197 --- /dev/null +++ b/crates/agent_ui/src/ui/end_trial_upsell.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +use ai_onboarding::{AgentPanelOnboardingCard, BulletItem}; +use client::zed_urls; +use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; +use ui::{Divider, List, Tooltip, prelude::*}; + +#[derive(IntoElement, RegisterComponent)] +pub struct EndTrialUpsell { + dismiss_upsell: Arc, +} + +impl EndTrialUpsell { + pub fn new(dismiss_upsell: Arc) -> Self { + Self { dismiss_upsell } + } +} + +impl RenderOnce for EndTrialUpsell { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let pro_section = v_flex() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child( + List::new() + .child(BulletItem::new("500 prompts with Claude models")) + .child(BulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )), + ) + .child( + Button::new("cta-button", "Upgrade to Zed Pro") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!("Upgrade To Pro Clicked", state = "end-of-trial"); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ); + + let free_section = v_flex() + .mt_1p5() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6))) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child( + List::new() + .child(BulletItem::new("50 prompts with the Claude models")) + .child(BulletItem::new("2,000 accepted edit predictions")), + ); + + AgentPanelOnboardingCard::new() + .child(Headline::new("Your Zed Pro Trial has expired")) + .child( + Label::new("You've been automatically reset to the Free plan.") + .color(Color::Muted) + .mb_2(), + ) + .child(pro_section) + .child(free_section) + .child( + h_flex().absolute().top_4().right_4().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click({ + let callback = self.dismiss_upsell.clone(); + move |_, window, cx| { + telemetry::event!("Banner Dismissed", source = "AI Onboarding"); + callback(window, cx) + } + }), + ), + ) + } +} + +impl Component for EndTrialUpsell { + fn scope() -> ComponentScope { + ComponentScope::Agent + } + + fn sort_name() -> &'static str { + "AgentEndTrialUpsell" + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + Some( + v_flex() + .p_4() + .gap_4() + .child(EndTrialUpsell { + dismiss_upsell: Arc::new(|_, _| {}), + }) + .into_any_element(), + ) + } +} diff --git a/crates/agent_ui/src/ui/new_thread_button.rs b/crates/agent_ui/src/ui/new_thread_button.rs new file mode 100644 index 0000000000..7764144150 --- /dev/null +++ b/crates/agent_ui/src/ui/new_thread_button.rs @@ -0,0 +1,75 @@ +use gpui::{ClickEvent, ElementId, IntoElement, ParentElement, Styled}; +use ui::prelude::*; + +#[derive(IntoElement)] +pub struct NewThreadButton { + id: ElementId, + label: SharedString, + icon: IconName, + keybinding: Option, + on_click: Option>, +} + +impl NewThreadButton { + pub fn new(id: impl Into, label: impl Into, icon: IconName) -> Self { + Self { + id: id.into(), + label: label.into(), + icon, + keybinding: None, + on_click: None, + } + } + + pub fn keybinding(mut self, keybinding: Option) -> Self { + self.keybinding = keybinding; + self + } + + pub fn on_click(mut self, handler: F) -> Self + where + F: Fn(&mut Window, &mut App) + 'static, + { + self.on_click = Some(Box::new( + move |_: &ClickEvent, window: &mut Window, cx: &mut App| handler(window, cx), + )); + self + } +} + +impl RenderOnce for NewThreadButton { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + h_flex() + .id(self.id) + .w_full() + .py_1p5() + .px_2() + .gap_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border.opacity(0.4)) + .bg(cx.theme().colors().element_active.opacity(0.2)) + .hover(|style| { + style + .bg(cx.theme().colors().element_hover) + .border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(self.icon) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Label::new(self.label).size(LabelSize::Small)), + ) + .when_some(self.keybinding, |this, keybinding| { + this.child(keybinding.size(rems_from_px(10.))) + }) + .when_some(self.on_click, |this, on_click| { + this.on_click(move |event, window, cx| on_click(event, window, cx)) + }) + } +} diff --git a/crates/agent_ui/src/ui/upsell.rs b/crates/agent_ui/src/ui/upsell.rs deleted file mode 100644 index f311aade22..0000000000 --- a/crates/agent_ui/src/ui/upsell.rs +++ /dev/null @@ -1,163 +0,0 @@ -use component::{Component, ComponentScope, single_example}; -use gpui::{ - AnyElement, App, ClickEvent, IntoElement, ParentElement, RenderOnce, SharedString, Styled, - Window, -}; -use theme::ActiveTheme; -use ui::{ - Button, ButtonCommon, ButtonStyle, Checkbox, Clickable, Color, Label, LabelCommon, - RegisterComponent, ToggleState, h_flex, v_flex, -}; - -/// A component that displays an upsell message with a call-to-action button -/// -/// # Example -/// ``` -/// let upsell = Upsell::new( -/// "Upgrade to Zed Pro", -/// "Get access to advanced AI features and more", -/// "Upgrade Now", -/// Box::new(|_, _window, cx| { -/// cx.open_url("https://zed.dev/pricing"); -/// }), -/// Box::new(|_, _window, cx| { -/// // Handle dismiss -/// }), -/// Box::new(|checked, window, cx| { -/// // Handle don't show again -/// }), -/// ); -/// ``` -#[derive(IntoElement, RegisterComponent)] -pub struct Upsell { - title: SharedString, - message: SharedString, - cta_text: SharedString, - on_click: Box, - on_dismiss: Box, - on_dont_show_again: Box, -} - -impl Upsell { - /// Create a new upsell component - pub fn new( - title: impl Into, - message: impl Into, - cta_text: impl Into, - on_click: Box, - on_dismiss: Box, - on_dont_show_again: Box, - ) -> Self { - Self { - title: title.into(), - message: message.into(), - cta_text: cta_text.into(), - on_click, - on_dismiss, - on_dont_show_again, - } - } -} - -impl RenderOnce for Upsell { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - v_flex() - .w_full() - .p_4() - .gap_3() - .bg(cx.theme().colors().surface_background) - .rounded_md() - .border_1() - .border_color(cx.theme().colors().border) - .child( - v_flex() - .gap_1() - .child( - Label::new(self.title) - .size(ui::LabelSize::Large) - .weight(gpui::FontWeight::BOLD), - ) - .child(Label::new(self.message).color(Color::Muted)), - ) - .child( - h_flex() - .w_full() - .justify_between() - .items_center() - .child( - h_flex() - .items_center() - .gap_1() - .child( - Checkbox::new("dont-show-again", ToggleState::Unselected).on_click( - move |_, window, cx| { - (self.on_dont_show_again)(true, window, cx); - }, - ), - ) - .child( - Label::new("Don't show again") - .color(Color::Muted) - .size(ui::LabelSize::Small), - ), - ) - .child( - h_flex() - .gap_2() - .child( - Button::new("dismiss-button", "No Thanks") - .style(ButtonStyle::Subtle) - .on_click(self.on_dismiss), - ) - .child( - Button::new("cta-button", self.cta_text) - .style(ButtonStyle::Filled) - .on_click(self.on_click), - ), - ), - ) - } -} - -impl Component for Upsell { - fn scope() -> ComponentScope { - ComponentScope::Agent - } - - fn name() -> &'static str { - "Upsell" - } - - fn description() -> Option<&'static str> { - Some("A promotional component that displays a message with a call-to-action.") - } - - fn preview(window: &mut Window, cx: &mut App) -> Option { - let examples = vec![ - single_example( - "Default", - Upsell::new( - "Upgrade to Zed Pro", - "Get unlimited access to AI features and more with Zed Pro. Unlock advanced AI capabilities and other premium features.", - "Upgrade Now", - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - ).render(window, cx).into_any_element(), - ), - single_example( - "Short Message", - Upsell::new( - "Try Zed Pro for free", - "Start your 7-day trial today.", - "Start Trial", - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - Box::new(|_, _, _| {}), - ).render(window, cx).into_any_element(), - ), - ]; - - Some(v_flex().gap_4().children(examples).into_any_element()) - } -} diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml new file mode 100644 index 0000000000..9031e14e29 --- /dev/null +++ b/crates/ai_onboarding/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "ai_onboarding" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/ai_onboarding.rs" + +[features] +default = [] + +[dependencies] +client.workspace = true +component.workspace = true +gpui.workspace = true +language_model.workspace = true +proto.workspace = true +serde.workspace = true +smallvec.workspace = true +telemetry.workspace = true +ui.workspace = true +workspace-hack.workspace = true +zed_actions.workspace = true diff --git a/crates/ai_onboarding/LICENSE-GPL b/crates/ai_onboarding/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/ai_onboarding/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs new file mode 100644 index 0000000000..5f56e4d26e --- /dev/null +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -0,0 +1,146 @@ +use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; +use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use ui::{Divider, List, prelude::*}; + +use crate::BulletItem; + +pub struct ApiKeysWithProviders { + configured_providers: Vec<(IconName, SharedString)>, +} + +impl ApiKeysWithProviders { + pub fn new(cx: &mut Context) -> Self { + cx.subscribe( + &LanguageModelRegistry::global(cx), + |this: &mut Self, _registry, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + this.configured_providers = Self::compute_configured_providers(cx) + } + _ => {} + }, + ) + .detach(); + + Self { + configured_providers: Self::compute_configured_providers(cx), + } + } + + fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> { + LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .map(|provider| (provider.icon(), provider.name().0.clone())) + .collect() + } +} + +impl Render for ApiKeysWithProviders { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let configured_providers_list = + self.configured_providers + .iter() + .cloned() + .map(|(icon, name)| { + h_flex() + .gap_1p5() + .child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) + .child(Label::new(name)) + }); + div() + .mx_2p5() + .p_1() + .pb_0() + .gap_2() + .rounded_t_lg() + .border_t_1() + .border_x_1() + .border_color(cx.theme().colors().border.opacity(0.5)) + .bg(cx.theme().colors().background.alpha(0.5)) + .shadow(vec![gpui::BoxShadow { + color: gpui::black().opacity(0.15), + offset: point(px(1.), px(-1.)), + blur_radius: px(3.), + spread_radius: px(0.), + }]) + .child( + h_flex() + .px_2p5() + .py_1p5() + .gap_2() + .flex_wrap() + .rounded_t(px(5.)) + .overflow_hidden() + .border_t_1() + .border_x_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().panel_background) + .child( + h_flex() + .min_w_0() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted) + ) + .child( + div() + .w_full() + .child( + Label::new("Start now using API keys from your environment for the following providers:") + .color(Color::Muted) + ) + ) + ) + .children(configured_providers_list) + ) + } +} + +#[derive(IntoElement)] +pub struct ApiKeysWithoutProviders; + +impl ApiKeysWithoutProviders { + pub fn new() -> Self { + Self + } +} + +impl RenderOnce for ApiKeysWithoutProviders { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("API Keys") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child(List::new().child(BulletItem::new( + "Add your own keys to use AI without signing in.", + ))) + .child( + Button::new("configure-providers", "Configure Providers") + .full_width() + .style(ButtonStyle::Outlined) + .on_click(move |_, window, cx| { + window.dispatch_action( + zed_actions::agent::OpenConfiguration.boxed_clone(), + cx, + ); + }), + ) + } +} diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_card.rs b/crates/ai_onboarding/src/agent_panel_onboarding_card.rs new file mode 100644 index 0000000000..c63c592642 --- /dev/null +++ b/crates/ai_onboarding/src/agent_panel_onboarding_card.rs @@ -0,0 +1,83 @@ +use gpui::{AnyElement, IntoElement, ParentElement, linear_color_stop, linear_gradient}; +use smallvec::SmallVec; +use ui::{Vector, VectorName, prelude::*}; + +#[derive(IntoElement)] +pub struct AgentPanelOnboardingCard { + children: SmallVec<[AnyElement; 2]>, +} + +impl AgentPanelOnboardingCard { + pub fn new() -> Self { + Self { + children: SmallVec::new(), + } + } +} + +impl ParentElement for AgentPanelOnboardingCard { + fn extend(&mut self, elements: impl IntoIterator) { + self.children.extend(elements) + } +} + +impl RenderOnce for AgentPanelOnboardingCard { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + div() + .m_2p5() + .p(px(3.)) + .elevation_2(cx) + .rounded_lg() + .bg(cx.theme().colors().background.alpha(0.5)) + .child( + v_flex() + .relative() + .size_full() + .px_4() + .py_3() + .gap_2() + .border_1() + .rounded(px(5.)) + .border_color(cx.theme().colors().text.alpha(0.1)) + .overflow_hidden() + .bg(cx.theme().colors().panel_background) + .child( + div() + .opacity(0.5) + .absolute() + .top(px(-8.0)) + .right_0() + .w(px(400.)) + .h(px(92.)) + .rounded_md() + .child( + Vector::new( + VectorName::AiGrid, + rems_from_px(400.), + rems_from_px(92.), + ) + .color(Color::Custom(cx.theme().colors().text.alpha(0.32))), + ), + ) + .child( + div() + .absolute() + .top_0p5() + .right_0p5() + .w(px(660.)) + .h(px(401.)) + .overflow_hidden() + .rounded_md() + .bg(linear_gradient( + 75., + linear_color_stop( + cx.theme().colors().panel_background.alpha(0.01), + 1.0, + ), + linear_color_stop(cx.theme().colors().panel_background, 0.45), + )), + ) + .children(self.children), + ) + } +} diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs new file mode 100644 index 0000000000..e8a62f7ff2 --- /dev/null +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use client::{Client, UserStore}; +use gpui::{Entity, IntoElement, ParentElement}; +use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use ui::prelude::*; + +use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding}; + +pub struct AgentPanelOnboarding { + user_store: Entity, + client: Arc, + configured_providers: Vec<(IconName, SharedString)>, + continue_with_zed_ai: Arc, +} + +impl AgentPanelOnboarding { + pub fn new( + user_store: Entity, + client: Arc, + continue_with_zed_ai: impl Fn(&mut Window, &mut App) + 'static, + cx: &mut Context, + ) -> Self { + cx.subscribe( + &LanguageModelRegistry::global(cx), + |this: &mut Self, _registry, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + this.configured_providers = Self::compute_available_providers(cx) + } + _ => {} + }, + ) + .detach(); + + Self { + user_store, + client, + configured_providers: Self::compute_available_providers(cx), + continue_with_zed_ai: Arc::new(continue_with_zed_ai), + } + } + + fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> { + LanguageModelRegistry::read_global(cx) + .providers() + .iter() + .filter(|provider| { + provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID + }) + .map(|provider| (provider.icon(), provider.name().0.clone())) + .collect() + } +} + +impl Render for AgentPanelOnboarding { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let enrolled_in_trial = matches!( + self.user_store.read(cx).current_plan(), + Some(proto::Plan::ZedProTrial) + ); + + let is_pro_user = matches!( + self.user_store.read(cx).current_plan(), + Some(proto::Plan::ZedPro) + ); + + AgentPanelOnboardingCard::new() + .child( + ZedAiOnboarding::new( + self.client.clone(), + &self.user_store, + self.continue_with_zed_ai.clone(), + cx, + ) + .with_dismiss({ + let callback = self.continue_with_zed_ai.clone(); + move |window, cx| callback(window, cx) + }), + ) + .map(|this| { + if enrolled_in_trial || is_pro_user || self.configured_providers.len() >= 1 { + this + } else { + this.child(ApiKeysWithoutProviders::new()) + } + }) + } +} diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs new file mode 100644 index 0000000000..7fffb60ecc --- /dev/null +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -0,0 +1,492 @@ +mod agent_api_keys_onboarding; +mod agent_panel_onboarding_card; +mod agent_panel_onboarding_content; +mod edit_prediction_onboarding_content; +mod young_account_banner; + +pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders}; +pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; +pub use agent_panel_onboarding_content::AgentPanelOnboarding; +pub use edit_prediction_onboarding_content::EditPredictionOnboarding; +pub use young_account_banner::YoungAccountBanner; + +use std::sync::Arc; + +use client::{Client, UserStore, zed_urls}; +use gpui::{AnyElement, Entity, IntoElement, ParentElement, SharedString}; +use ui::{Divider, List, ListItem, RegisterComponent, TintColor, Tooltip, prelude::*}; + +#[derive(IntoElement)] +pub struct BulletItem { + label: SharedString, +} + +impl BulletItem { + pub fn new(label: impl Into) -> Self { + Self { + label: label.into(), + } + } +} + +impl RenderOnce for BulletItem { + fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement { + let line_height = 0.85 * window.line_height(); + + ListItem::new("list-item") + .selectable(false) + .child( + h_flex() + .w_full() + .min_w_0() + .gap_1() + .items_start() + .child( + h_flex().h(line_height).justify_center().child( + Icon::new(IconName::Dash) + .size(IconSize::XSmall) + .color(Color::Hidden), + ), + ) + .child(div().w_full().min_w_0().child(Label::new(self.label))), + ) + .into_any_element() + } +} + +pub enum SignInStatus { + SignedIn, + SigningIn, + SignedOut, +} + +impl From for SignInStatus { + fn from(status: client::Status) -> Self { + if status.is_signing_in() { + Self::SigningIn + } else if status.is_signed_out() { + Self::SignedOut + } else { + Self::SignedIn + } + } +} + +#[derive(RegisterComponent, IntoElement)] +pub struct ZedAiOnboarding { + pub sign_in_status: SignInStatus, + pub has_accepted_terms_of_service: bool, + pub plan: Option, + pub account_too_young: bool, + pub continue_with_zed_ai: Arc, + pub sign_in: Arc, + pub accept_terms_of_service: Arc, + pub dismiss_onboarding: Option>, +} + +impl ZedAiOnboarding { + pub fn new( + client: Arc, + user_store: &Entity, + continue_with_zed_ai: Arc, + cx: &mut App, + ) -> Self { + let store = user_store.read(cx); + let status = *client.status().borrow(); + + Self { + sign_in_status: status.into(), + has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), + plan: store.current_plan(), + account_too_young: store.account_too_young(), + continue_with_zed_ai, + accept_terms_of_service: Arc::new({ + let store = user_store.clone(); + move |_window, cx| { + let task = store.update(cx, |store, cx| store.accept_terms_of_service(cx)); + task.detach_and_log_err(cx); + } + }), + sign_in: Arc::new(move |_window, cx| { + cx.spawn({ + let client = client.clone(); + async move |cx| { + client.authenticate_and_connect(true, cx).await; + } + }) + .detach(); + }), + dismiss_onboarding: None, + } + } + + pub fn with_dismiss( + mut self, + dismiss_callback: impl Fn(&mut Window, &mut App) + 'static, + ) -> Self { + self.dismiss_onboarding = Some(Arc::new(dismiss_callback)); + self + } + + fn free_plan_definition(&self, cx: &mut App) -> impl IntoElement { + v_flex() + .mt_2() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + Label::new("Free") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ) + .child( + Label::new("(Current Plan)") + .size(LabelSize::Small) + .color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6))) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child( + List::new() + .child(BulletItem::new("50 prompts per month with Claude models")) + .child(BulletItem::new( + "2,000 accepted edit predictions with Zeta, our open-source model", + )), + ) + } + + fn pro_trial_definition(&self) -> impl IntoElement { + List::new() + .child(BulletItem::new("150 prompts with Claude models")) + .child(BulletItem::new( + "Unlimited accepted edit predictions with Zeta, our open-source model", + )) + } + + fn pro_plan_definition(&self, cx: &mut App) -> impl IntoElement { + v_flex().mt_2().gap_1().map(|this| { + if self.account_too_young { + this.child( + h_flex() + .gap_2() + .child( + Label::new("Pro") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child( + List::new() + .child(BulletItem::new("500 prompts per month with Claude models")) + .child(BulletItem::new( + "Unlimited accepted edit predictions with Zeta, our open-source model", + )) + .child(BulletItem::new("$20 USD per month")), + ) + .child( + Button::new("pro", "Get Started") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!("Upgrade To Pro Clicked", state = "young-account"); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) + }), + ) + } else { + this.child( + h_flex() + .gap_2() + .child( + Label::new("Pro Trial") + .size(LabelSize::Small) + .color(Color::Accent) + .buffer_font(cx), + ) + .child(Divider::horizontal()), + ) + .child( + List::new() + .child(self.pro_trial_definition()) + .child(BulletItem::new( + "Try it out for 14 days for free, no credit card required", + )), + ) + .child( + Button::new("pro", "Start Free Trial") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(move |_, _window, cx| { + telemetry::event!("Start Trial Clicked", state = "post-sign-in"); + cx.open_url(&zed_urls::start_trial_url(cx)) + }), + ) + } + }) + } + + fn render_accept_terms_of_service(&self) -> AnyElement { + v_flex() + .gap_1() + .w_full() + .child(Headline::new("Accept Terms of Service")) + .child( + Label::new("We don’t sell your data, track you across the web, or compromise your privacy.") + .color(Color::Muted) + .mb_2(), + ) + .child( + Button::new("terms_of_service", "Review Terms of Service") + .full_width() + .style(ButtonStyle::Outlined) + .icon(IconName::ArrowUpRight) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(move |_, _window, cx| { + telemetry::event!("Review Terms of Service Clicked"); + cx.open_url(&zed_urls::terms_of_service(cx)) + }), + ) + .child( + Button::new("accept_terms", "Accept") + .full_width() + .style(ButtonStyle::Tinted(TintColor::Accent)) + .on_click({ + let callback = self.accept_terms_of_service.clone(); + move |_, window, cx| { + telemetry::event!("Terms of Service Accepted"); + (callback)(window, cx)} + }), + ) + .into_any_element() + } + + fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement { + let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn); + + v_flex() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .child( + Label::new("Sign in to try Zed Pro for 14 days, no credit card required.") + .color(Color::Muted) + .mb_2(), + ) + .child(self.pro_trial_definition()) + .child( + Button::new("sign_in", "Try Zed Pro for Free") + .disabled(signing_in) + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click({ + let callback = self.sign_in.clone(); + move |_, window, cx| { + telemetry::event!("Start Trial Clicked", state = "pre-sign-in"); + callback(window, cx) + } + }), + ) + .into_any_element() + } + + fn render_free_plan_state(&self, cx: &mut App) -> AnyElement { + let young_account_banner = YoungAccountBanner; + + v_flex() + .relative() + .gap_1() + .child(Headline::new("Welcome to Zed AI")) + .map(|this| { + if self.account_too_young { + this.child(young_account_banner) + } else { + this.child(self.free_plan_definition(cx)).when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); + + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, + ) + } + }) + .child(self.pro_plan_definition(cx)) + .into_any_element() + } + + fn render_trial_state(&self, _cx: &mut App) -> AnyElement { + v_flex() + .relative() + .gap_1() + .child(Headline::new("Welcome to the Zed Pro Trial")) + .child( + Label::new("Here's what you get for the next 14 days:") + .color(Color::Muted) + .mb_2(), + ) + .child( + List::new() + .child(BulletItem::new("150 prompts with Claude models")) + .child(BulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )), + ) + .when_some( + self.dismiss_onboarding.as_ref(), + |this, dismiss_callback| { + let callback = dismiss_callback.clone(); + this.child( + h_flex().absolute().top_0().right_0().child( + IconButton::new("dismiss_onboarding", IconName::Close) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Dismiss")) + .on_click(move |_, window, cx| { + telemetry::event!( + "Banner Dismissed", + source = "AI Onboarding", + ); + callback(window, cx) + }), + ), + ) + }, + ) + .into_any_element() + } + + fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement { + v_flex() + .gap_1() + .child(Headline::new("Welcome to Zed Pro")) + .child( + Label::new("Here's what you get:") + .color(Color::Muted) + .mb_2(), + ) + .child( + List::new() + .child(BulletItem::new("500 prompts with Claude models")) + .child(BulletItem::new( + "Unlimited edit predictions with Zeta, our open-source model", + )), + ) + .child( + Button::new("pro", "Continue with Zed Pro") + .full_width() + .style(ButtonStyle::Outlined) + .on_click({ + let callback = self.continue_with_zed_ai.clone(); + move |_, window, cx| { + telemetry::event!("Banner Dismissed", source = "AI Onboarding"); + callback(window, cx) + } + }), + ) + .into_any_element() + } +} + +impl RenderOnce for ZedAiOnboarding { + fn render(self, _window: &mut ui::Window, cx: &mut App) -> impl IntoElement { + if matches!(self.sign_in_status, SignInStatus::SignedIn) { + if self.has_accepted_terms_of_service { + match self.plan { + None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), + Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), + Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), + } + } else { + self.render_accept_terms_of_service() + } + } else { + self.render_sign_in_disclaimer(cx) + } + } +} + +impl Component for ZedAiOnboarding { + fn scope() -> ComponentScope { + ComponentScope::Agent + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + fn onboarding( + sign_in_status: SignInStatus, + has_accepted_terms_of_service: bool, + plan: Option, + account_too_young: bool, + ) -> AnyElement { + ZedAiOnboarding { + sign_in_status, + has_accepted_terms_of_service, + plan, + account_too_young, + continue_with_zed_ai: Arc::new(|_, _| {}), + sign_in: Arc::new(|_, _| {}), + accept_terms_of_service: Arc::new(|_, _| {}), + dismiss_onboarding: None, + } + .into_any_element() + } + + Some( + v_flex() + .p_4() + .gap_4() + .children(vec![ + single_example( + "Not Signed-in", + onboarding(SignInStatus::SignedOut, false, None, false), + ), + single_example( + "Not Accepted ToS", + onboarding(SignInStatus::SignedIn, false, None, false), + ), + single_example( + "Account too young", + onboarding(SignInStatus::SignedIn, false, None, true), + ), + single_example( + "Free Plan", + onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), + ), + single_example( + "Pro Trial", + onboarding( + SignInStatus::SignedIn, + true, + Some(proto::Plan::ZedProTrial), + false, + ), + ), + single_example( + "Pro Plan", + onboarding( + SignInStatus::SignedIn, + true, + Some(proto::Plan::ZedPro), + false, + ), + ), + ]) + .into_any_element(), + ) + } +} diff --git a/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs b/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs new file mode 100644 index 0000000000..e883d8da8c --- /dev/null +++ b/crates/ai_onboarding/src/edit_prediction_onboarding_content.rs @@ -0,0 +1,73 @@ +use std::sync::Arc; + +use client::{Client, UserStore}; +use gpui::{Entity, IntoElement, ParentElement}; +use ui::prelude::*; + +use crate::ZedAiOnboarding; + +pub struct EditPredictionOnboarding { + user_store: Entity, + client: Arc, + copilot_is_configured: bool, + continue_with_zed_ai: Arc, + continue_with_copilot: Arc, +} + +impl EditPredictionOnboarding { + pub fn new( + user_store: Entity, + client: Arc, + copilot_is_configured: bool, + continue_with_zed_ai: Arc, + continue_with_copilot: Arc, + _cx: &mut Context, + ) -> Self { + Self { + user_store, + copilot_is_configured, + client, + continue_with_zed_ai, + continue_with_copilot, + } + } +} + +impl Render for EditPredictionOnboarding { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let github_copilot = v_flex() + .gap_1() + .child(Label::new(if self.copilot_is_configured { + "Alternatively, you can continue to use GitHub Copilot as that's already set up." + } else { + "Alternatively, you can use GitHub Copilot as your edit prediction provider." + })) + .child( + Button::new( + "configure-copilot", + if self.copilot_is_configured { + "Use Copilot" + } else { + "Configure Copilot" + }, + ) + .full_width() + .style(ButtonStyle::Outlined) + .on_click({ + let callback = self.continue_with_copilot.clone(); + move |_, window, cx| callback(window, cx) + }), + ); + + v_flex() + .gap_2() + .child(ZedAiOnboarding::new( + self.client.clone(), + &self.user_store, + self.continue_with_zed_ai.clone(), + cx, + )) + .child(ui::Divider::horizontal()) + .child(github_copilot) + } +} diff --git a/crates/ai_onboarding/src/young_account_banner.rs b/crates/ai_onboarding/src/young_account_banner.rs new file mode 100644 index 0000000000..a43625a60e --- /dev/null +++ b/crates/ai_onboarding/src/young_account_banner.rs @@ -0,0 +1,21 @@ +use gpui::{IntoElement, ParentElement}; +use ui::{Banner, prelude::*}; + +#[derive(IntoElement)] +pub struct YoungAccountBanner; + +impl RenderOnce for YoungAccountBanner { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + const YOUNG_ACCOUNT_DISCLAIMER: &str = "To prevent abuse of our service, we cannot offer plans to GitHub accounts created fewer than 30 days ago. To request an exception, reach out to billing-support@zed.dev."; + + let label = div() + .w_full() + .text_sm() + .text_color(cx.theme().colors().text_muted) + .child(YOUNG_ACCOUNT_DISCLAIMER); + + div() + .my_1() + .child(Banner::new().severity(ui::Severity::Warning).child(label)) + } +} diff --git a/crates/assistant_context/src/assistant_context.rs b/crates/assistant_context/src/assistant_context.rs index aaaef15250..136468e084 100644 --- a/crates/assistant_context/src/assistant_context.rs +++ b/crates/assistant_context/src/assistant_context.rs @@ -2293,6 +2293,7 @@ impl AssistantContext { tool_choice: None, stop: Vec::new(), temperature: model.and_then(|model| AgentSettings::temperature_for_model(model, cx)), + thinking_allowed: true, }; for message in self.messages(cx) { if message.status != MessageStatus::Done { diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_context/src/assistant_context_tests.rs index dba3bfde61..f139d525d3 100644 --- a/crates/assistant_context/src/assistant_context_tests.rs +++ b/crates/assistant_context/src/assistant_context_tests.rs @@ -1323,7 +1323,7 @@ fn setup_context_editor_with_fake_model( ) -> (Entity, Arc) { let registry = Arc::new(LanguageRegistry::test(cx.executor().clone())); - let fake_provider = Arc::new(FakeLanguageModelProvider); + let fake_provider = Arc::new(FakeLanguageModelProvider::default()); let fake_model = Arc::new(fake_provider.test_model()); cx.update(|cx| { diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_context/src/context_store.rs index 3400913eb8..3090a7b234 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_context/src/context_store.rs @@ -767,6 +767,11 @@ impl ContextStore { fn reload(&mut self, cx: &mut Context) -> Task> { let fs = self.fs.clone(); cx.spawn(async move |this, cx| { + pub static ZED_STATELESS: LazyLock = + LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + if *ZED_STATELESS { + return Ok(()); + } fs.create_dir(contexts_dir()).await?; let mut paths = fs.read_dir(contexts_dir()).await?; diff --git a/crates/assistant_slash_command/src/extension_slash_command.rs b/crates/assistant_slash_command/src/extension_slash_command.rs index 6cc1f73c47..74c46ffb5f 100644 --- a/crates/assistant_slash_command/src/extension_slash_command.rs +++ b/crates/assistant_slash_command/src/extension_slash_command.rs @@ -34,6 +34,11 @@ impl ExtensionSlashCommandProxy for SlashCommandRegistryProxy { self.slash_command_registry .register_command(ExtensionSlashCommand::new(extension, command), false) } + + fn unregister_slash_command(&self, command_name: Arc) { + self.slash_command_registry + .unregister_command_by_name(&command_name) + } } /// An adapter that allows an [`LspAdapterDelegate`] to be used as a [`WorktreeDelegate`]. diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml index 5a54e86eac..acbe674b02 100644 --- a/crates/assistant_tool/Cargo.toml +++ b/crates/assistant_tool/Cargo.toml @@ -40,6 +40,7 @@ collections = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } ctor.workspace = true gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true language = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] } log.workspace = true diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 2071a1f444..672c048872 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -8,7 +8,10 @@ use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint}; use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle}; use std::{cmp, ops::Range, sync::Arc}; use text::{Edit, Patch, Rope}; -use util::RangeExt; +use util::{ + RangeExt, ResultExt as _, + paths::{PathStyle, RemotePathBuf}, +}; /// Tracks actions performed by tools in a thread pub struct ActionLog { @@ -18,8 +21,6 @@ pub struct ActionLog { edited_since_project_diagnostics_check: bool, /// The project this action log is associated with project: Entity, - /// Tracks which buffer versions have already been notified as changed externally - notified_versions: BTreeMap, clock::Global>, } impl ActionLog { @@ -29,7 +30,6 @@ impl ActionLog { tracked_buffers: BTreeMap::default(), edited_since_project_diagnostics_check: false, project, - notified_versions: BTreeMap::default(), } } @@ -47,6 +47,65 @@ impl ActionLog { self.edited_since_project_diagnostics_check } + pub fn latest_snapshot(&self, buffer: &Entity) -> Option { + Some(self.tracked_buffers.get(buffer)?.snapshot.clone()) + } + + /// Return a unified diff patch with user edits made since last read or notification + pub fn unnotified_user_edits(&self, cx: &Context) -> Option { + let diffs = self + .tracked_buffers + .values() + .filter_map(|tracked| { + if !tracked.may_have_unnotified_user_edits { + return None; + } + + let text_with_latest_user_edits = tracked.diff_base.to_string(); + let text_with_last_seen_user_edits = tracked.last_seen_base.to_string(); + if text_with_latest_user_edits == text_with_last_seen_user_edits { + return None; + } + let patch = language::unified_diff( + &text_with_last_seen_user_edits, + &text_with_latest_user_edits, + ); + + let buffer = tracked.buffer.clone(); + let file_path = buffer + .read(cx) + .file() + .map(|file| RemotePathBuf::new(file.full_path(cx), PathStyle::Posix).to_proto()) + .unwrap_or_else(|| format!("buffer_{}", buffer.entity_id())); + + let mut result = String::new(); + result.push_str(&format!("--- a/{}\n", file_path)); + result.push_str(&format!("+++ b/{}\n", file_path)); + result.push_str(&patch); + + Some(result) + }) + .collect::>(); + + if diffs.is_empty() { + return None; + } + + let unified_diff = diffs.join("\n\n"); + Some(unified_diff) + } + + /// Return a unified diff patch with user edits made since last read/notification + /// and mark them as notified + pub fn flush_unnotified_user_edits(&mut self, cx: &Context) -> Option { + let patch = self.unnotified_user_edits(cx); + self.tracked_buffers.values_mut().for_each(|tracked| { + tracked.may_have_unnotified_user_edits = false; + tracked.last_seen_base = tracked.diff_base.clone(); + }); + patch + } + fn track_buffer_internal( &mut self, buffer: Entity, @@ -55,7 +114,6 @@ impl ActionLog { ) -> &mut TrackedBuffer { let status = if is_created { if let Some(tracked) = self.tracked_buffers.remove(&buffer) { - self.notified_versions.remove(&buffer); match tracked.status { TrackedBufferStatus::Created { existing_file_content, @@ -97,26 +155,31 @@ impl ActionLog { let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx)); let (diff_update_tx, diff_update_rx) = mpsc::unbounded(); let diff_base; + let last_seen_base; let unreviewed_edits; if is_created { diff_base = Rope::default(); + last_seen_base = Rope::default(); unreviewed_edits = Patch::new(vec![Edit { old: 0..1, new: 0..text_snapshot.max_point().row + 1, }]) } else { diff_base = buffer.read(cx).as_rope().clone(); + last_seen_base = diff_base.clone(); unreviewed_edits = Patch::default(); } TrackedBuffer { buffer: buffer.clone(), diff_base, + last_seen_base, unreviewed_edits, snapshot: text_snapshot.clone(), status, version: buffer.read(cx).version(), diff, diff_update: diff_update_tx, + may_have_unnotified_user_edits: false, _open_lsp_handle: open_lsp_handle, _maintain_diff: cx.spawn({ let buffer = buffer.clone(); @@ -170,7 +233,6 @@ impl ActionLog { // If the buffer had been edited by a tool, but it got // deleted externally, we want to stop tracking it. self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); } cx.notify(); } @@ -184,7 +246,6 @@ impl ActionLog { // resurrected externally, we want to clear the edits we // were tracking and reset the buffer's state. self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); self.track_buffer_internal(buffer, false, cx); } cx.notify(); @@ -258,10 +319,10 @@ impl ActionLog { buffer_snapshot: text::BufferSnapshot, cx: &mut AsyncApp, ) -> Result<()> { - let rebase = this.read_with(cx, |this, cx| { + let rebase = this.update(cx, |this, cx| { let tracked_buffer = this .tracked_buffers - .get(buffer) + .get_mut(buffer) .context("buffer not tracked")?; let rebase = cx.background_spawn({ @@ -269,23 +330,35 @@ impl ActionLog { let old_snapshot = tracked_buffer.snapshot.clone(); let new_snapshot = buffer_snapshot.clone(); let unreviewed_edits = tracked_buffer.unreviewed_edits.clone(); + let edits = diff_snapshots(&old_snapshot, &new_snapshot); + let mut has_user_changes = false; async move { - let edits = diff_snapshots(&old_snapshot, &new_snapshot); if let ChangeAuthor::User = author { - apply_non_conflicting_edits( + has_user_changes = apply_non_conflicting_edits( &unreviewed_edits, edits, &mut base_text, new_snapshot.as_rope(), ); } - (Arc::new(base_text.to_string()), base_text) + + (Arc::new(base_text.to_string()), base_text, has_user_changes) } }); anyhow::Ok(rebase) })??; - let (new_base_text, new_diff_base) = rebase.await; + let (new_base_text, new_diff_base, has_user_changes) = rebase.await; + + this.update(cx, |this, _| { + let tracked_buffer = this + .tracked_buffers + .get_mut(buffer) + .context("buffer not tracked") + .unwrap(); + tracked_buffer.may_have_unnotified_user_edits |= has_user_changes; + })?; + Self::update_diff( this, buffer, @@ -490,7 +563,6 @@ impl ActionLog { match tracked_buffer.status { TrackedBufferStatus::Created { .. } => { self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); cx.notify(); } TrackedBufferStatus::Modified => { @@ -516,7 +588,6 @@ impl ActionLog { match tracked_buffer.status { TrackedBufferStatus::Deleted => { self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); cx.notify(); } _ => { @@ -625,7 +696,6 @@ impl ActionLog { }; self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); cx.notify(); task } @@ -639,7 +709,6 @@ impl ActionLog { // Clear all tracked edits for this buffer and start over as if we just read it. self.tracked_buffers.remove(&buffer); - self.notified_versions.remove(&buffer); self.buffer_read(buffer.clone(), cx); cx.notify(); save @@ -715,6 +784,22 @@ impl ActionLog { cx.notify(); } + pub fn reject_all_edits(&mut self, cx: &mut Context) -> Task<()> { + let futures = self.changed_buffers(cx).into_keys().map(|buffer| { + let reject = self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx); + + async move { + reject.await.log_err(); + } + }); + + let task = futures::future::join_all(futures); + + cx.spawn(async move |_, _| { + task.await; + }) + } + /// Returns the set of buffers that contain edits that haven't been reviewed by the user. pub fn changed_buffers(&self, cx: &App) -> BTreeMap, Entity> { self.tracked_buffers @@ -724,33 +809,6 @@ impl ActionLog { .collect() } - /// Returns stale buffers that haven't been notified yet - pub fn unnotified_stale_buffers<'a>( - &'a self, - cx: &'a App, - ) -> impl Iterator> { - self.stale_buffers(cx).filter(|buffer| { - let buffer_entity = buffer.read(cx); - self.notified_versions - .get(buffer) - .map_or(true, |notified_version| { - *notified_version != buffer_entity.version - }) - }) - } - - /// Marks the given buffers as notified at their current versions - pub fn mark_buffers_as_notified( - &mut self, - buffers: impl IntoIterator>, - cx: &App, - ) { - for buffer in buffers { - let version = buffer.read(cx).version.clone(); - self.notified_versions.insert(buffer, version); - } - } - /// Iterate over buffers changed since last read or edited by the model pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator> { self.tracked_buffers @@ -772,11 +830,12 @@ fn apply_non_conflicting_edits( edits: Vec>, old_text: &mut Rope, new_text: &Rope, -) { +) -> bool { let mut old_edits = patch.edits().iter().cloned().peekable(); let mut new_edits = edits.into_iter().peekable(); let mut applied_delta = 0i32; let mut rebased_delta = 0i32; + let mut has_made_changes = false; while let Some(mut new_edit) = new_edits.next() { let mut conflict = false; @@ -826,8 +885,10 @@ fn apply_non_conflicting_edits( &new_text.chunks_in_range(new_bytes).collect::(), ); applied_delta += new_edit.new_len() as i32 - new_edit.old_len() as i32; + has_made_changes = true; } } + has_made_changes } fn diff_snapshots( @@ -894,12 +955,14 @@ enum TrackedBufferStatus { struct TrackedBuffer { buffer: Entity, diff_base: Rope, + last_seen_base: Rope, unreviewed_edits: Patch, status: TrackedBufferStatus, version: clock::Global, diff: Entity, snapshot: text::BufferSnapshot, diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>, + may_have_unnotified_user_edits: bool, _open_lsp_handle: OpenLspBufferHandle, _maintain_diff: Task<()>, _subscription: Subscription, @@ -930,6 +993,7 @@ mod tests { use super::*; use buffer_diff::DiffHunkStatusKind; use gpui::TestAppContext; + use indoc::indoc; use language::Point; use project::{FakeFs, Fs, Project, RemoveOptions}; use rand::prelude::*; @@ -1212,6 +1276,110 @@ mod tests { assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); } + #[gpui::test(iterations = 10)] + async fn test_user_edits_notifications(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/dir"), + json!({"file": indoc! {" + abc + def + ghi + jkl + mno"}}), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + // Agent edits + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx) + .unwrap() + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + indoc! {" + abc + deF + GHI + jkl + mno"} + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![HunkStatus { + range: Point::new(1, 0)..Point::new(3, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\nghi\n".into(), + }], + )] + ); + + // User edits + buffer.update(cx, |buffer, cx| { + buffer.edit( + [ + (Point::new(0, 2)..Point::new(0, 2), "X"), + (Point::new(3, 0)..Point::new(3, 0), "Y"), + ], + None, + cx, + ) + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + indoc! {" + abXc + deF + GHI + Yjkl + mno"} + ); + + // User edits should be stored separately from agent's + let user_edits = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx)); + assert_eq!( + user_edits.expect("should have some user edits"), + indoc! {" + --- a/dir/file + +++ b/dir/file + @@ -1,5 +1,5 @@ + -abc + +abXc + def + ghi + -jkl + +Yjkl + mno + "} + ); + + action_log.update(cx, |log, cx| { + log.keep_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx) + }); + cx.run_until_parked(); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 10)] async fn test_creating_files(cx: &mut TestAppContext) { init_test(cx); @@ -2201,4 +2369,61 @@ mod tests { .collect() }) } + + #[gpui::test] + async fn test_format_patch(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/dir"), + json!({"test.txt": "line 1\nline 2\nline 3\n"}), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/test.txt", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + cx.update(|cx| { + // Track the buffer and mark it as read first + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx); + }); + + // Make some edits to create a patch + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 0)..Point::new(1, 6), "CHANGED")], None, cx) + .unwrap(); // Replace "line2" with "CHANGED" + }); + }); + + cx.run_until_parked(); + + // Get the patch + let patch = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx)); + + // Verify the patch format contains expected unified diff elements + assert_eq!( + patch.unwrap(), + indoc! {" + --- a/dir/test.txt + +++ b/dir/test.txt + @@ -1,3 +1,3 @@ + line 1 + -line 2 + +CHANGED + line 3 + "} + ); + } } diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 2b8958feb1..146800e094 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -20,6 +20,7 @@ anyhow.workspace = true assistant_tool.workspace = true buffer_diff.workspace = true chrono.workspace = true +client.workspace = true collections.workspace = true component.workspace = true derive_more.workspace = true @@ -63,6 +64,7 @@ which.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_llm_client.workspace = true +diffy = "0.4.2" [dev-dependencies] lsp = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index eef792f526..57fdc51336 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -20,14 +20,13 @@ mod thinking_tool; mod ui; mod web_search_tool; -use std::sync::Arc; - use assistant_tool::ToolRegistry; use copy_path_tool::CopyPathTool; use gpui::{App, Entity}; use http_client::HttpClientWithUrl; use language_model::LanguageModelRegistry; use move_path_tool::MovePathTool; +use std::sync::Arc; use web_search_tool::WebSearchTool; pub(crate) use templates::*; diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 28d6bef9dd..1922b5677a 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -57,7 +57,7 @@ impl Tool for CopyPathTool { } fn icon(&self) -> IconName { - IconName::Clipboard + IconName::ToolCopy } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index b3e198c1b5..224e8357e5 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -46,7 +46,7 @@ impl Tool for CreateDirectoryTool { } fn icon(&self) -> IconName { - IconName::Folder + IconName::ToolFolder } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index e45c1976d1..b13f9863c9 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -46,7 +46,7 @@ impl Tool for DeletePathTool { } fn icon(&self) -> IconName { - IconName::FileDelete + IconName::ToolDeleteFile } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 3b6d38fc06..84595a37b7 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -59,7 +59,7 @@ impl Tool for DiagnosticsTool { } fn icon(&self) -> IconName { - IconName::XCircle + IconName::ToolDiagnostics } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index c2540633f7..0184dff36c 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -719,6 +719,7 @@ impl EditAgent { tools, stop: Vec::new(), temperature: None, + thinking_allowed: true, }; Ok(self.model.stream_completion_text(request, cx).await?.stream) diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 8df8f677f2..eda7eee0e3 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -12,6 +12,7 @@ use collections::HashMap; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; use gpui::{AppContext, TestAppContext, Timer}; +use http_client::StatusCode; use indoc::{formatdoc, indoc}; use language_model::{ LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, @@ -365,17 +366,23 @@ fn eval_disable_cursor_blinking() { // Model | Pass rate // ============================================ // - // claude-3.7-sonnet | 0.99 (2025-06-14) - // claude-sonnet-4 | 0.85 (2025-06-14) - // gemini-2.5-pro-preview-latest | 0.97 (2025-06-16) - // gemini-2.5-flash-preview-04-17 | - // gpt-4.1 | + // claude-3.7-sonnet | 0.59 (2025-07-14) + // claude-sonnet-4 | 0.81 (2025-07-14) + // gemini-2.5-pro | 0.95 (2025-07-14) + // gemini-2.5-flash-preview-04-17 | 0.78 (2025-07-14) + // gpt-4.1 | 0.00 (2025-07-14) (follows edit_description too literally) let input_file_path = "root/editor.rs"; let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs"); let edit_description = "Comment out the call to `BlinkManager::enable`"; + let possible_diffs = vec![ + include_str!("evals/fixtures/disable_cursor_blinking/possible-01.diff"), + include_str!("evals/fixtures/disable_cursor_blinking/possible-02.diff"), + include_str!("evals/fixtures/disable_cursor_blinking/possible-03.diff"), + include_str!("evals/fixtures/disable_cursor_blinking/possible-04.diff"), + ]; eval( 100, - 0.95, + 0.51, 0.05, EvalInput::from_conversation( vec![ @@ -433,11 +440,7 @@ fn eval_disable_cursor_blinking() { ), ], Some(input_file_content.into()), - EvalAssertion::judge_diff(indoc! {" - - Calls to BlinkManager in `observe_window_activation` were commented out - - The call to `blink_manager.enable` above the call to show_cursor_names was commented out - - All the edits have valid indentation - "}), + EvalAssertion::assert_diff_any(possible_diffs), ), ); } @@ -1263,6 +1266,7 @@ impl EvalAssertion { content: vec![prompt.into()], cache: false, }], + thinking_allowed: true, ..Default::default() }; let mut response = retry_on_rate_limit(async || { @@ -1599,6 +1603,7 @@ impl EditAgentTest { let conversation = LanguageModelRequest { messages, tools, + thinking_allowed: true, ..Default::default() }; @@ -1671,6 +1676,30 @@ async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Timer::after(retry_after + jitter).await; continue; } + LanguageModelCompletionError::UpstreamProviderError { + status, + retry_after, + .. + } => { + // Only retry for specific status codes + let should_retry = matches!( + *status, + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE + ) || status.as_u16() == 529; + + if !should_retry { + return Err(err.into()); + } + + // Use server-provided retry_after if available, otherwise use default + let retry_after = retry_after.unwrap_or(Duration::from_secs(5)); + let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0)); + eprintln!( + "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}" + ); + Timer::after(retry_after + jitter).await; + continue; + } _ => return Err(err.into()), }, Err(err) => return Err(err), diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff new file mode 100644 index 0000000000..1a38a1967f --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff @@ -0,0 +1,28 @@ +--- before.rs 2025-07-07 11:37:48.434629001 +0300 ++++ expected.rs 2025-07-14 10:33:53.346906775 +0300 +@@ -1780,11 +1780,11 @@ + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff new file mode 100644 index 0000000000..b484cce48f --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff @@ -0,0 +1,29 @@ +@@ -1778,13 +1778,13 @@ + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { +- let active = window.is_window_active(); ++ // let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff new file mode 100644 index 0000000000..431e34e48a --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff @@ -0,0 +1,34 @@ +@@ -1774,17 +1774,17 @@ + cx.observe(&buffer, Self::on_buffer_changed), + cx.subscribe_in(&buffer, window, Self::on_buffer_event), + cx.observe_in(&display_map, window, Self::on_display_map_changed), +- cx.observe(&blink_manager, |_, _, cx| cx.notify()), ++ // cx.observe(&blink_manager, |_, _, cx| cx.notify()), + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { +- let active = window.is_window_active(); ++ // let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff new file mode 100644 index 0000000000..64a6b85dd3 --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff @@ -0,0 +1,33 @@ +@@ -1774,17 +1774,17 @@ + cx.observe(&buffer, Self::on_buffer_changed), + cx.subscribe_in(&buffer, window, Self::on_buffer_event), + cx.observe_in(&display_map, window, Self::on_display_map_changed), +- cx.observe(&blink_manager, |_, _, cx| cx.notify()), ++ // cx.observe(&blink_manager, |_, _, cx| cx.notify()), + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { +- if active { +- blink_manager.enable(cx); +- } else { +- blink_manager.disable(cx); +- } ++ // if active { ++ // blink_manager.enable(cx); ++ // } else { ++ // blink_manager.disable(cx); ++ // } + }); + }), + ], +@@ -18463,7 +18463,7 @@ + } + + self.blink_manager.update(cx, |blink_manager, cx| { +- blink_manager.enable(cx); ++ // blink_manager.enable(cx); + }); + self.show_cursor_names(window, cx); + self.buffer.update(cx, |buffer, cx| { diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index baf62c11f2..6413677bd9 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -139,7 +139,7 @@ impl Tool for EditFileTool { } fn icon(&self) -> IconName { - IconName::Pencil + IconName::ToolPencil } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { @@ -278,6 +278,9 @@ impl Tool for EditFileTool { .unwrap_or(false); if format_on_save_enabled { + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; let format_task = project.update(cx, |project, cx| { project.format( HashSet::from_iter([buffer.clone()]), @@ -783,8 +786,8 @@ impl ToolCard for EditFileToolCard { .child( h_flex() .child( - Icon::new(IconName::Pencil) - .size(IconSize::XSmall) + Icon::new(IconName::ToolPencil) + .size(IconSize::Small) .color(Color::Muted), ) .child( diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 82b15b7a86..54d49359ba 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -69,10 +69,9 @@ impl FetchTool { .to_str() .context("invalid Content-Type header")?; let content_type = match content_type { - "text/html" => ContentType::Html, - "text/plain" => ContentType::Plaintext, + "text/html" | "application/xhtml+xml" => ContentType::Html, "application/json" => ContentType::Json, - _ => ContentType::Html, + _ => ContentType::Plaintext, }; match content_type { @@ -130,7 +129,7 @@ impl Tool for FetchTool { } fn icon(&self) -> IconName { - IconName::Globe + IconName::ToolWeb } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index 86e67a8f58..fd0e44e42c 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -68,7 +68,7 @@ impl Tool for FindPathTool { } fn icon(&self) -> IconName { - IconName::SearchCode + IconName::ToolSearch } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { @@ -313,7 +313,7 @@ impl ToolCard for FindPathToolCard { .mb_2() .gap_1() .child( - ToolCallCardHeader::new(IconName::SearchCode, matches_label) + ToolCallCardHeader::new(IconName::ToolSearch, matches_label) .with_code_path(&self.glob) .disclosure_slot( Disclosure::new("path-search-disclosure", self.expanded) diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index eb4c8d38e5..053273d71b 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -70,7 +70,7 @@ impl Tool for GrepTool { } fn icon(&self) -> IconName { - IconName::Regex + IconName::ToolRegex } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index aef186b9ae..723416e2ce 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -58,7 +58,7 @@ impl Tool for ListDirectoryTool { } fn icon(&self) -> IconName { - IconName::Folder + IconName::ToolFolder } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs index 01dcbba4ac..7567926dca 100644 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -6,8 +6,7 @@ use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchem use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::fmt::Write as _; -use std::sync::Arc; +use std::{fmt::Write, sync::Arc}; use ui::IconName; #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -31,7 +30,7 @@ impl Tool for ProjectNotificationsTool { } fn icon(&self) -> IconName { - IconName::Envelope + IconName::ToolNotification } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { @@ -52,39 +51,113 @@ impl Tool for ProjectNotificationsTool { _window: Option, cx: &mut App, ) -> ToolResult { - let mut stale_files = String::new(); - let mut notified_buffers = Vec::new(); - - for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) { - if let Some(file) = stale_file.read(cx).file() { - writeln!(&mut stale_files, "- {}", file.path().display()).ok(); - notified_buffers.push(stale_file.clone()); - } - } - - if !notified_buffers.is_empty() { - action_log.update(cx, |log, cx| { - log.mark_buffers_as_notified(notified_buffers, cx); - }); - } - - let response = if stale_files.is_empty() { - "No new notifications".to_string() - } else { - // NOTE: Changes to this prompt require a symmetric update in the LLM Worker - const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt"); - format!("{HEADER}{stale_files}").replace("\r\n", "\n") + let Some(user_edits_diff) = + action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx)) + else { + return result("No new notifications"); }; - Task::ready(Ok(response.into())).into() + // NOTE: Changes to this prompt require a symmetric update in the LLM Worker + const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt"); + const MAX_BYTES: usize = 8000; + let diff = fit_patch_to_size(&user_edits_diff, MAX_BYTES); + result(&format!("{HEADER}\n\n```diff\n{diff}\n```\n").replace("\r\n", "\n")) } } +fn result(response: &str) -> ToolResult { + Task::ready(Ok(response.to_string().into())).into() +} + +/// Make sure that the patch fits into the size limit (in bytes). +/// Compress the patch by omitting some parts if needed. +/// Unified diff format is assumed. +fn fit_patch_to_size(patch: &str, max_size: usize) -> String { + if patch.len() <= max_size { + return patch.to_string(); + } + + // Compression level 1: remove context lines in diff bodies, but + // leave the counts and positions of inserted/deleted lines + let mut current_size = patch.len(); + let mut file_patches = split_patch(&patch); + file_patches.sort_by_key(|patch| patch.len()); + let compressed_patches = file_patches + .iter() + .rev() + .map(|patch| { + if current_size > max_size { + let compressed = compress_patch(patch).unwrap_or_else(|_| patch.to_string()); + current_size -= patch.len() - compressed.len(); + compressed + } else { + patch.to_string() + } + }) + .collect::>(); + + if current_size <= max_size { + return compressed_patches.join("\n\n"); + } + + // Compression level 2: list paths of the changed files only + let filenames = file_patches + .iter() + .map(|patch| { + let patch = diffy::Patch::from_str(patch).unwrap(); + let path = patch + .modified() + .and_then(|path| path.strip_prefix("b/")) + .unwrap_or_default(); + format!("- {path}\n") + }) + .collect::>(); + + filenames.join("") +} + +/// Split a potentially multi-file patch into multiple single-file patches +fn split_patch(patch: &str) -> Vec { + let mut result = Vec::new(); + let mut current_patch = String::new(); + + for line in patch.lines() { + if line.starts_with("---") && !current_patch.is_empty() { + result.push(current_patch.trim_end_matches('\n').into()); + current_patch = String::new(); + } + current_patch.push_str(line); + current_patch.push('\n'); + } + + if !current_patch.is_empty() { + result.push(current_patch.trim_end_matches('\n').into()); + } + + result +} + +fn compress_patch(patch: &str) -> anyhow::Result { + let patch = diffy::Patch::from_str(patch)?; + let mut out = String::new(); + + writeln!(out, "--- {}", patch.original().unwrap_or("a"))?; + writeln!(out, "+++ {}", patch.modified().unwrap_or("b"))?; + + for hunk in patch.hunks() { + writeln!(out, "@@ -{} +{} @@", hunk.old_range(), hunk.new_range())?; + writeln!(out, "[...skipped...]")?; + } + + Ok(out) +} + #[cfg(test)] mod tests { use super::*; use assistant_tool::ToolResultContent; use gpui::{AppContext, TestAppContext}; + use indoc::indoc; use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider}; use project::{FakeFs, Project}; use serde_json::json; @@ -123,10 +196,11 @@ mod tests { action_log.update(cx, |log, cx| { log.buffer_read(buffer.clone(), cx); }); + cx.run_until_parked(); // Run the tool before any changes let tool = Arc::new(ProjectNotificationsTool); - let provider = Arc::new(FakeLanguageModelProvider); + let provider = Arc::new(FakeLanguageModelProvider::default()); let model: Arc = Arc::new(provider.test_model()); let request = Arc::new(LanguageModelRequest::default()); let tool_input = json!({}); @@ -142,6 +216,7 @@ mod tests { cx, ) }); + cx.run_until_parked(); let response = result.output.await.unwrap(); let response_text = match &response.content { @@ -158,6 +233,7 @@ mod tests { buffer.update(cx, |buffer, cx| { buffer.edit([(1..1, "\nChange!\n")], None, cx); }); + cx.run_until_parked(); // Run the tool again let result = cx.update(|cx| { @@ -171,6 +247,7 @@ mod tests { cx, ) }); + cx.run_until_parked(); // This time the buffer is stale, so the tool should return a notification let response = result.output.await.unwrap(); @@ -179,10 +256,12 @@ mod tests { _ => panic!("Expected text response"), }; - let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n"; - assert_eq!( - response_text.as_str(), - expected_content, + assert!( + response_text.contains("These files have changed"), + "Tool should return the stale buffer notification" + ); + assert!( + response_text.contains("test/code.rs"), "Tool should return the stale buffer notification" ); @@ -198,6 +277,7 @@ mod tests { cx, ) }); + cx.run_until_parked(); let response = result.output.await.unwrap(); let response_text = match &response.content { @@ -212,6 +292,61 @@ mod tests { ); } + #[test] + fn test_patch_compression() { + // Given a patch that doesn't fit into the size budget + let patch = indoc! {" + --- a/dir/test.txt + +++ b/dir/test.txt + @@ -1,3 +1,3 @@ + line 1 + -line 2 + +CHANGED + line 3 + @@ -10,2 +10,2 @@ + line 10 + -line 11 + +line eleven + + + --- a/dir/another.txt + +++ b/dir/another.txt + @@ -100,1 +1,1 @@ + -before + +after + "}; + + // When the size deficit can be compensated by dropping the body, + // then the body should be trimmed for larger files first + let limit = patch.len() - 10; + let compressed = fit_patch_to_size(patch, limit); + let expected = indoc! {" + --- a/dir/test.txt + +++ b/dir/test.txt + @@ -1,3 +1,3 @@ + [...skipped...] + @@ -10,2 +10,2 @@ + [...skipped...] + + + --- a/dir/another.txt + +++ b/dir/another.txt + @@ -100,1 +1,1 @@ + -before + +after"}; + assert_eq!(compressed, expected); + + // When the size deficit is too large, then only file paths + // should be returned + let limit = 10; + let compressed = fit_patch_to_size(patch, limit); + let expected = indoc! {" + - dir/another.txt + - dir/test.txt + "}; + assert_eq!(compressed, expected); + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 4d40fc6a7c..dc504e2dc4 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -18,7 +18,6 @@ use serde::{Deserialize, Serialize}; use settings::Settings; use std::sync::Arc; use ui::IconName; -use util::markdown::MarkdownInlineCode; /// If the model requests to read a file whose size exceeds this, then #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -68,7 +67,7 @@ impl Tool for ReadFileTool { } fn icon(&self) -> IconName { - IconName::FileSearch + IconName::ToolRead } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { @@ -78,11 +77,21 @@ impl Tool for ReadFileTool { fn ui_text(&self, input: &serde_json::Value) -> String { match serde_json::from_value::(input.clone()) { Ok(input) => { - let path = MarkdownInlineCode(&input.path); + let path = &input.path; match (input.start_line, input.end_line) { - (Some(start), None) => format!("Read file {path} (from line {start})"), - (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"), - _ => format!("Read file {path}"), + (Some(start), Some(end)) => { + format!( + "[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))", + path, start, end, path, start, end + ) + } + (Some(start), None) => { + format!( + "[Read file `{}` (from line {})](@selection:{}:({}-{}))", + path, start, path, start, start + ) + } + _ => format!("[Read file `{}`](@file:{})", path, path), } } Err(_) => "Read file".to_string(), @@ -276,7 +285,10 @@ impl Tool for ReadFileTool { Using the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the - implementations of symbols in the outline." + implementations of symbols in the outline. + + Alternatively, you can fall back to the `grep` tool (if available) + to search the file for specific content." } .into()) } diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 6641873182..03e76f6a5b 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -90,7 +90,7 @@ impl Tool for TerminalTool { } fn icon(&self) -> IconName { - IconName::Terminal + IconName::ToolTerminal } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 4641b7359e..422204f97d 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -37,7 +37,7 @@ impl Tool for ThinkingTool { } fn icon(&self) -> IconName { - IconName::LightBulb + IconName::ToolBulb } fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { diff --git a/crates/assistant_tools/src/ui/tool_call_card_header.rs b/crates/assistant_tools/src/ui/tool_call_card_header.rs index a19ea8f2b7..b71453373f 100644 --- a/crates/assistant_tools/src/ui/tool_call_card_header.rs +++ b/crates/assistant_tools/src/ui/tool_call_card_header.rs @@ -82,7 +82,7 @@ impl RenderOnce for ToolCallCardHeader { .child( h_flex().h(line_height).justify_center().child( Icon::new(self.icon) - .size(IconSize::XSmall) + .size(IconSize::Small) .color(Color::Muted), ), ) diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index 9430ac9d9e..24bc8e9cba 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -143,6 +143,8 @@ impl ToolCard for WebSearchToolCard { _workspace: WeakEntity, cx: &mut Context, ) -> impl IntoElement { + let icon = IconName::ToolWeb; + let header = match self.response.as_ref() { Some(Ok(response)) => { let text: SharedString = if response.results.len() == 1 { @@ -150,13 +152,12 @@ impl ToolCard for WebSearchToolCard { } else { format!("{} results", response.results.len()).into() }; - ToolCallCardHeader::new(IconName::Globe, "Searched the Web") - .with_secondary_text(text) + ToolCallCardHeader::new(icon, "Searched the Web").with_secondary_text(text) } Some(Err(error)) => { - ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string()) + ToolCallCardHeader::new(icon, "Web Search").with_error(error.to_string()) } - None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(), + None => ToolCallCardHeader::new(icon, "Searching the Web").loading(), }; let content = self.response.as_ref().and_then(|response| match response { diff --git a/crates/aws_http_client/Cargo.toml b/crates/aws_http_client/Cargo.toml index 3760f70fe0..2749286d4c 100644 --- a/crates/aws_http_client/Cargo.toml +++ b/crates/aws_http_client/Cargo.toml @@ -17,7 +17,5 @@ default = [] [dependencies] aws-smithy-runtime-api.workspace = true aws-smithy-types.workspace = true -futures.workspace = true http_client.workspace = true -tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } workspace-hack.workspace = true diff --git a/crates/aws_http_client/src/aws_http_client.rs b/crates/aws_http_client/src/aws_http_client.rs index 6adb995747..d08c8e64a7 100644 --- a/crates/aws_http_client/src/aws_http_client.rs +++ b/crates/aws_http_client/src/aws_http_client.rs @@ -11,14 +11,11 @@ use aws_smithy_runtime_api::client::result::ConnectorError; use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; use aws_smithy_runtime_api::http::{Headers, StatusCode}; use aws_smithy_types::body::SdkBody; -use futures::AsyncReadExt; -use http_client::{AsyncBody, Inner}; +use http_client::AsyncBody; use http_client::{HttpClient, Request}; -use tokio::runtime::Handle; struct AwsHttpConnector { client: Arc, - handle: Handle, } impl std::fmt::Debug for AwsHttpConnector { @@ -42,18 +39,17 @@ impl AwsConnector for AwsHttpConnector { .client .send(Request::from_parts(parts, convert_to_async_body(body))); - let handle = self.handle.clone(); - HttpConnectorFuture::new(async move { let response = match response.await { Ok(response) => response, Err(err) => return Err(ConnectorError::other(err.into(), None)), }; let (parts, body) = response.into_parts(); - let body = convert_to_sdk_body(body, handle).await; - let mut response = - HttpResponse::new(StatusCode::try_from(parts.status.as_u16()).unwrap(), body); + let mut response = HttpResponse::new( + StatusCode::try_from(parts.status.as_u16()).unwrap(), + convert_to_sdk_body(body), + ); let headers = match Headers::try_from(parts.headers) { Ok(headers) => headers, @@ -70,7 +66,6 @@ impl AwsConnector for AwsHttpConnector { #[derive(Clone)] pub struct AwsHttpClient { client: Arc, - handler: Handle, } impl std::fmt::Debug for AwsHttpClient { @@ -80,11 +75,8 @@ impl std::fmt::Debug for AwsHttpClient { } impl AwsHttpClient { - pub fn new(client: Arc, handle: Handle) -> Self { - Self { - client, - handler: handle, - } + pub fn new(client: Arc) -> Self { + Self { client } } } @@ -96,25 +88,12 @@ impl AwsClient for AwsHttpClient { ) -> SharedHttpConnector { SharedHttpConnector::new(AwsHttpConnector { client: self.client.clone(), - handle: self.handler.clone(), }) } } -pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody { - match body.0 { - Inner::Empty => SdkBody::empty(), - Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()), - Inner::AsyncReader(mut reader) => { - let buffer = handle.spawn(async move { - let mut buffer = Vec::new(); - let _ = reader.read_to_end(&mut buffer).await; - buffer - }); - - SdkBody::from(buffer.await.unwrap_or_default()) - } - } +pub fn convert_to_sdk_body(body: AsyncBody) -> SdkBody { + SdkBody::from_body_1_x(body) } pub fn convert_to_async_body(body: SdkBody) -> AsyncBody { diff --git a/crates/buffer_diff/src/buffer_diff.rs b/crates/buffer_diff/src/buffer_diff.rs index ee09fda46e..97f529fe37 100644 --- a/crates/buffer_diff/src/buffer_diff.rs +++ b/crates/buffer_diff/src/buffer_diff.rs @@ -343,8 +343,7 @@ impl BufferDiffInner { .. } in hunks.iter().cloned() { - let preceding_pending_hunks = - old_pending_hunks.slice(&buffer_range.start, Bias::Left, buffer); + let preceding_pending_hunks = old_pending_hunks.slice(&buffer_range.start, Bias::Left); pending_hunks.append(preceding_pending_hunks, buffer); // Skip all overlapping or adjacent old pending hunks @@ -355,7 +354,7 @@ impl BufferDiffInner { .cmp(&buffer_range.end, buffer) .is_le() }) { - old_pending_hunks.next(buffer); + old_pending_hunks.next(); } if (stage && secondary_status == DiffHunkSecondaryStatus::NoSecondaryHunk) @@ -379,10 +378,10 @@ impl BufferDiffInner { ); } // append the remainder - pending_hunks.append(old_pending_hunks.suffix(buffer), buffer); + pending_hunks.append(old_pending_hunks.suffix(), buffer); let mut unstaged_hunk_cursor = unstaged_diff.hunks.cursor::(buffer); - unstaged_hunk_cursor.next(buffer); + unstaged_hunk_cursor.next(); // then, iterate over all pending hunks (both new ones and the existing ones) and compute the edits let mut prev_unstaged_hunk_buffer_end = 0; @@ -397,8 +396,7 @@ impl BufferDiffInner { }) = pending_hunks_iter.next() { // Advance unstaged_hunk_cursor to skip unstaged hunks before current hunk - let skipped_unstaged = - unstaged_hunk_cursor.slice(&buffer_range.start, Bias::Left, buffer); + let skipped_unstaged = unstaged_hunk_cursor.slice(&buffer_range.start, Bias::Left); if let Some(unstaged_hunk) = skipped_unstaged.last() { prev_unstaged_hunk_base_text_end = unstaged_hunk.diff_base_byte_range.end; @@ -425,7 +423,7 @@ impl BufferDiffInner { buffer_offset_range.end = buffer_offset_range.end.max(unstaged_hunk_offset_range.end); - unstaged_hunk_cursor.next(buffer); + unstaged_hunk_cursor.next(); continue; } } @@ -514,7 +512,7 @@ impl BufferDiffInner { }); let anchor_iter = iter::from_fn(move || { - cursor.next(buffer); + cursor.next(); cursor.item() }) .flat_map(move |hunk| { @@ -531,12 +529,12 @@ impl BufferDiffInner { }); let mut pending_hunks_cursor = self.pending_hunks.cursor::(buffer); - pending_hunks_cursor.next(buffer); + pending_hunks_cursor.next(); let mut secondary_cursor = None; if let Some(secondary) = secondary.as_ref() { let mut cursor = secondary.hunks.cursor::(buffer); - cursor.next(buffer); + cursor.next(); secondary_cursor = Some(cursor); } @@ -564,7 +562,7 @@ impl BufferDiffInner { .cmp(&pending_hunks_cursor.start().buffer_range.start, buffer) .is_gt() { - pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left, buffer); + pending_hunks_cursor.seek_forward(&start_anchor, Bias::Left); } if let Some(pending_hunk) = pending_hunks_cursor.item() { @@ -590,7 +588,7 @@ impl BufferDiffInner { .cmp(&secondary_cursor.start().buffer_range.start, buffer) .is_gt() { - secondary_cursor.seek_forward(&start_anchor, Bias::Left, buffer); + secondary_cursor.seek_forward(&start_anchor, Bias::Left); } if let Some(secondary_hunk) = secondary_cursor.item() { @@ -635,7 +633,7 @@ impl BufferDiffInner { }); iter::from_fn(move || { - cursor.prev(buffer); + cursor.prev(); let hunk = cursor.item()?; let range = hunk.buffer_range.to_point(buffer); @@ -653,8 +651,8 @@ impl BufferDiffInner { fn compare(&self, old: &Self, new_snapshot: &text::BufferSnapshot) -> Option> { let mut new_cursor = self.hunks.cursor::<()>(new_snapshot); let mut old_cursor = old.hunks.cursor::<()>(new_snapshot); - old_cursor.next(new_snapshot); - new_cursor.next(new_snapshot); + old_cursor.next(); + new_cursor.next(); let mut start = None; let mut end = None; @@ -669,7 +667,7 @@ impl BufferDiffInner { Ordering::Less => { start.get_or_insert(new_hunk.buffer_range.start); end.replace(new_hunk.buffer_range.end); - new_cursor.next(new_snapshot); + new_cursor.next(); } Ordering::Equal => { if new_hunk != old_hunk { @@ -686,25 +684,25 @@ impl BufferDiffInner { } } - new_cursor.next(new_snapshot); - old_cursor.next(new_snapshot); + new_cursor.next(); + old_cursor.next(); } Ordering::Greater => { start.get_or_insert(old_hunk.buffer_range.start); end.replace(old_hunk.buffer_range.end); - old_cursor.next(new_snapshot); + old_cursor.next(); } } } (Some(new_hunk), None) => { start.get_or_insert(new_hunk.buffer_range.start); end.replace(new_hunk.buffer_range.end); - new_cursor.next(new_snapshot); + new_cursor.next(); } (None, Some(old_hunk)) => { start.get_or_insert(old_hunk.buffer_range.start); end.replace(old_hunk.buffer_range.end); - old_cursor.next(new_snapshot); + old_cursor.next(); } (None, None) => break, } diff --git a/crates/call/src/call_impl/room.rs b/crates/call/src/call_impl/room.rs index 31ca144cf8..afeee4c924 100644 --- a/crates/call/src/call_impl/room.rs +++ b/crates/call/src/call_impl/room.rs @@ -11,15 +11,18 @@ use client::{ use collections::{BTreeMap, HashMap, HashSet}; use fs::Fs; use futures::{FutureExt, StreamExt}; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use gpui::{ + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource, + ScreenCaptureStream, Task, WeakEntity, +}; use gpui_tokio::Tokio; use language::LanguageRegistry; use livekit::{LocalTrackPublication, ParticipantIdentity, RoomEvent}; -use livekit_client::{self as livekit, TrackSid}; +use livekit_client::{self as livekit, AudioStream, TrackSid}; use postage::{sink::Sink, stream::Stream, watch}; use project::Project; use settings::Settings as _; -use std::{any::Any, future::Future, mem, rc::Rc, sync::Arc, time::Duration}; +use std::{future::Future, mem, rc::Rc, sync::Arc, time::Duration}; use util::{ResultExt, TryFutureExt, post_inc}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); @@ -1251,12 +1254,21 @@ impl Room { }) } - pub fn is_screen_sharing(&self) -> bool { + pub fn is_sharing_screen(&self) -> bool { self.live_kit.as_ref().map_or(false, |live_kit| { !matches!(live_kit.screen_track, LocalTrack::None) }) } + pub fn shared_screen_id(&self) -> Option { + self.live_kit.as_ref().and_then(|lk| match lk.screen_track { + LocalTrack::Published { ref _stream, .. } => { + _stream.metadata().ok().map(|meta| meta.id) + } + _ => None, + }) + } + pub fn is_sharing_mic(&self) -> bool { self.live_kit.as_ref().map_or(false, |live_kit| { !matches!(live_kit.microphone_track, LocalTrack::None) @@ -1369,11 +1381,15 @@ impl Room { }) } - pub fn share_screen(&mut self, cx: &mut Context) -> Task> { + pub fn share_screen( + &mut self, + source: Rc, + cx: &mut Context, + ) -> Task> { if self.status.is_offline() { return Task::ready(Err(anyhow!("room is offline"))); } - if self.is_screen_sharing() { + if self.is_sharing_screen() { return Task::ready(Err(anyhow!("screen was already shared"))); } @@ -1386,13 +1402,8 @@ impl Room { return Task::ready(Err(anyhow!("live-kit was not initialized"))); }; - let sources = cx.screen_capture_sources(); - cx.spawn(async move |this, cx| { - let sources = sources.await??; - let source = sources.first().context("no display found")?; - - let publication = participant.publish_screenshare_track(&**source, cx).await; + let publication = participant.publish_screenshare_track(&*source, cx).await; this.update(cx, |this, cx| { let live_kit = this @@ -1419,7 +1430,7 @@ impl Room { } else { live_kit.screen_track = LocalTrack::Published { track_publication: publication, - _stream: Box::new(stream), + _stream: stream, }; cx.notify(); } @@ -1485,7 +1496,7 @@ impl Room { } } - pub fn unshare_screen(&mut self, cx: &mut Context) -> Result<()> { + pub fn unshare_screen(&mut self, play_sound: bool, cx: &mut Context) -> Result<()> { anyhow::ensure!(!self.status.is_offline(), "room is offline"); let live_kit = self @@ -1509,7 +1520,10 @@ impl Room { cx.notify(); } - Audio::play_sound(Sound::StopScreenshare, cx); + if play_sound { + Audio::play_sound(Sound::StopScreenshare, cx); + } + Ok(()) } } @@ -1617,8 +1631,8 @@ fn spawn_room_connection( struct LiveKitRoom { room: Rc, - screen_track: LocalTrack, - microphone_track: LocalTrack, + screen_track: LocalTrack, + microphone_track: LocalTrack, /// Tracks whether we're currently in a muted state due to auto-mute from deafening or manual mute performed by user. muted_by_user: bool, deafened: bool, @@ -1656,18 +1670,18 @@ impl LiveKitRoom { } } -enum LocalTrack { +enum LocalTrack { None, Pending { publish_id: usize, }, Published { track_publication: LocalTrackPublication, - _stream: Box, + _stream: Box, }, } -impl Default for LocalTrack { +impl Default for LocalTrack { fn default() -> Self { Self::None } diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 8394972d43..866e3ccd90 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -333,7 +333,7 @@ impl ChannelChat { if first_id <= message_id { let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(&()); let message_id = ChannelMessageId::Saved(message_id); - cursor.seek(&message_id, Bias::Left, &()); + cursor.seek(&message_id, Bias::Left); return ControlFlow::Break( if cursor .item() @@ -499,7 +499,7 @@ impl ChannelChat { pub fn message(&self, ix: usize) -> &ChannelMessage { let mut cursor = self.messages.cursor::(&()); - cursor.seek(&Count(ix), Bias::Right, &()); + cursor.seek(&Count(ix), Bias::Right); cursor.item().unwrap() } @@ -516,13 +516,13 @@ impl ChannelChat { pub fn messages_in_range(&self, range: Range) -> impl Iterator { let mut cursor = self.messages.cursor::(&()); - cursor.seek(&Count(range.start), Bias::Right, &()); + cursor.seek(&Count(range.start), Bias::Right); cursor.take(range.len()) } pub fn pending_messages(&self) -> impl Iterator { let mut cursor = self.messages.cursor::(&()); - cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &()); + cursor.seek(&ChannelMessageId::Pending(0), Bias::Left); cursor } @@ -588,9 +588,9 @@ impl ChannelChat { .collect::>(); let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(&()); - let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); + let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left); let start_ix = old_cursor.start().1.0; - let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &()); + let removed_messages = old_cursor.slice(&last_message.id, Bias::Right); let removed_count = removed_messages.summary().count; let new_count = messages.summary().count; let end_ix = start_ix + removed_count; @@ -599,10 +599,10 @@ impl ChannelChat { let mut ranges = Vec::>::new(); if new_messages.last().unwrap().is_pending() { - new_messages.append(old_cursor.suffix(&()), &()); + new_messages.append(old_cursor.suffix(), &()); } else { new_messages.append( - old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()), + old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left), &(), ); @@ -617,7 +617,7 @@ impl ChannelChat { } else { new_messages.push(message.clone(), &()); } - old_cursor.next(&()); + old_cursor.next(); } } @@ -641,12 +641,12 @@ impl ChannelChat { fn message_removed(&mut self, id: u64, cx: &mut Context) { let mut cursor = self.messages.cursor::(&()); - let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &()); + let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left); if let Some(item) = cursor.item() { if item.id == ChannelMessageId::Saved(id) { let deleted_message_ix = messages.summary().count; - cursor.next(&()); - messages.append(cursor.suffix(&()), &()); + cursor.next(); + messages.append(cursor.suffix(), &()); drop(cursor); self.messages = messages; @@ -680,7 +680,7 @@ impl ChannelChat { cx: &mut Context, ) { let mut cursor = self.messages.cursor::(&()); - let mut messages = cursor.slice(&id, Bias::Left, &()); + let mut messages = cursor.slice(&id, Bias::Left); let ix = messages.summary().count; if let Some(mut message_to_update) = cursor.item().cloned() { @@ -688,10 +688,10 @@ impl ChannelChat { message_to_update.mentions = mentions; message_to_update.edited_at = edited_at; messages.push(message_to_update, &()); - cursor.next(&()); + cursor.next(); } - messages.append(cursor.suffix(&()), &()); + messages.append(cursor.suffix(), &()); drop(cursor); self.messages = messages; diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index d6ddf79ea6..287c62b753 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -315,19 +315,19 @@ fn main() -> Result<()> { }); let stdin_pipe_handle: Option>> = - stdin_tmp_file.map(|tmp_file| { + stdin_tmp_file.map(|mut tmp_file| { thread::spawn(move || { - let stdin = std::io::stdin().lock(); - if io::IsTerminal::is_terminal(&stdin) { - return Ok(()); + let mut stdin = std::io::stdin().lock(); + if !io::IsTerminal::is_terminal(&stdin) { + io::copy(&mut stdin, &mut tmp_file)?; } - return pipe_to_tmp(stdin, tmp_file); + Ok(()) }) }); - let anonymous_fd_pipe_handles: Vec>> = anonymous_fd_tmp_files + let anonymous_fd_pipe_handles: Vec<_> = anonymous_fd_tmp_files .into_iter() - .map(|(file, tmp_file)| thread::spawn(move || pipe_to_tmp(file, tmp_file))) + .map(|(mut file, mut tmp_file)| thread::spawn(move || io::copy(&mut file, &mut tmp_file))) .collect(); if args.foreground { @@ -349,22 +349,6 @@ fn main() -> Result<()> { Ok(()) } -fn pipe_to_tmp(mut src: impl io::Read, mut dest: fs::File) -> Result<()> { - let mut buffer = [0; 8 * 1024]; - loop { - let bytes_read = match src.read(&mut buffer) { - Err(err) if err.kind() == io::ErrorKind::Interrupted => continue, - res => res?, - }; - if bytes_read == 0 { - break; - } - io::Write::write_all(&mut dest, &buffer[..bytes_read])?; - } - io::Write::flush(&mut dest)?; - Ok(()) -} - fn anonymous_fd(path: &str) -> Option { #[cfg(target_os = "linux")] { diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index c4211f72c8..81bb95b514 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -151,6 +151,7 @@ impl Settings for ProxySettings { pub fn init_settings(cx: &mut App) { TelemetrySettings::register(cx); + DisableAiSettings::register(cx); ClientSettings::register(cx); ProxySettings::register(cx); } @@ -301,6 +302,13 @@ impl Status { matches!(self, Self::Connected { .. }) } + pub fn is_signing_in(&self) -> bool { + matches!( + self, + Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting + ) + } + pub fn is_signed_out(&self) -> bool { matches!(self, Self::SignedOut | Self::UpgradeRequired) } @@ -541,6 +549,33 @@ impl settings::Settings for TelemetrySettings { } } +/// Whether to disable all AI features in Zed. +/// +/// Default: false +#[derive(Copy, Clone, Debug)] +pub struct DisableAiSettings { + pub disable_ai: bool, +} + +impl settings::Settings for DisableAiSettings { + const KEY: Option<&'static str> = Some("disable_ai"); + + type FileContent = Option; + + fn load(sources: SettingsSources, _: &mut App) -> Result { + Ok(Self { + disable_ai: sources + .user + .or(sources.server) + .copied() + .flatten() + .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), + }) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} + impl Client { pub fn new( clock: Arc, diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 4983fda5ef..7d39464e4a 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -358,13 +358,13 @@ impl Telemetry { worktree_id: WorktreeId, updated_entries_set: &UpdatedEntriesSet, ) { - let Some(project_type_names) = self.detect_project_types(worktree_id, updated_entries_set) + let Some(project_types) = self.detect_project_types(worktree_id, updated_entries_set) else { return; }; - for project_type_name in project_type_names { - telemetry::event!("Project Opened", project_type = project_type_name); + for project_type in project_types { + telemetry::event!("Project Opened", project_type = project_type); } } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 61e3064eb4..5ed258aa8e 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -764,6 +764,18 @@ impl UserStore { } pub fn current_plan(&self) -> Option { + #[cfg(debug_assertions)] + if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { + return match plan.as_str() { + "free" => Some(proto::Plan::Free), + "trial" => Some(proto::Plan::ZedProTrial), + "pro" => Some(proto::Plan::ZedPro), + _ => { + panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); + } + }; + } + self.current_plan } diff --git a/crates/client/src/zed_urls.rs b/crates/client/src/zed_urls.rs index bfdae468fb..693c7bf836 100644 --- a/crates/client/src/zed_urls.rs +++ b/crates/client/src/zed_urls.rs @@ -17,3 +17,21 @@ fn server_url(cx: &App) -> &str { pub fn account_url(cx: &App) -> String { format!("{server_url}/account", server_url = server_url(cx)) } + +/// Returns the URL to the start trial page on zed.dev. +pub fn start_trial_url(cx: &App) -> String { + format!( + "{server_url}/account/start-trial", + server_url = server_url(cx) + ) +} + +/// Returns the URL to the upgrade page on zed.dev. +pub fn upgrade_to_zed_pro_url(cx: &App) -> String { + format!("{server_url}/account/upgrade", server_url = server_url(cx)) +} + +/// Returns the URL to Zed's terms of service. +pub fn terms_of_service(cx: &App) -> String { + format!("{server_url}/terms-of-service", server_url = server_url(cx)) +} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 55c15cac5a..d3b5048283 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -35,7 +35,7 @@ dashmap.workspace = true derive_more.workspace = true envy = "0.4.2" futures.workspace = true -gpui = { workspace = true, features = ["screen-capture"] } +gpui.workspace = true hex.workspace = true http_client.workspace = true jsonwebtoken.workspace = true @@ -94,6 +94,7 @@ context_server.workspace = true ctor.workspace = true dap = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] } +dap-types.workspace = true debugger_ui = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } extension.workspace = true @@ -126,6 +127,7 @@ sea-orm = { version = "1.1.0-rc.1", features = ["sqlx-sqlite"] } serde_json.workspace = true session = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } +smol.workspace = true sqlx = { version = "0.8", features = ["sqlite"] } task.workspace = true theme.workspace = true diff --git a/crates/collab/seed.default.json b/crates/collab/seed.default.json index dee924e103..983594d623 100644 --- a/crates/collab/seed.default.json +++ b/crates/collab/seed.default.json @@ -1,12 +1,33 @@ { "admins": [ "nathansobo", - "as-cii", "maxbrunsfeld", - "iamnbutler", - "mikayla-maki", + "as-cii", "JosephTLyons", - "rgbkrk" + "maxdeviant", + "SomeoneToIgnore", + "mikayla-maki", + "agu-z", + "osiewicz", + "ConradIrwin", + "benbrandt", + "bennetbo", + "smitbarmase", + "notpeter", + "rgbkrk", + "JunkuiZhang", + "Anthony-Eid", + "rtfeldman", + "danilo-leal", + "MrSubidubi", + "cole-miller", + "osyvokon", + "probably-neb", + "mgsloan", + "P1n3appl3", + "mslzed", + "franciskafyi", + "katie-z-geer" ], "channels": ["zed"] } diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 7fca27c5c2..3b0f5396a7 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -11,7 +11,9 @@ use crate::{ db::{User, UserId}, rpc, }; +use ::rpc::proto; use anyhow::Context as _; +use axum::extract; use axum::{ Extension, Json, Router, body::Body, @@ -23,6 +25,7 @@ use axum::{ routing::{get, post}, }; use axum_extra::response::ErasedJson; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::sync::{Arc, OnceLock}; use tower::ServiceBuilder; @@ -100,6 +103,8 @@ pub fn routes(rpc_server: Arc) -> Router<(), Body> { .route("/user", get(update_or_create_authenticated_user)) .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) + .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) + .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(billing::router()) .merge(contributors::router()) @@ -334,3 +339,90 @@ async fn create_access_token( encrypted_access_token, })) } + +#[derive(Serialize)] +struct RefreshLlmTokensResponse {} + +async fn refresh_llm_tokens( + Path(user_id): Path, + Extension(rpc_server): Extension>, +) -> Result> { + rpc_server.refresh_llm_tokens_for_user(user_id).await; + + Ok(Json(RefreshLlmTokensResponse {})) +} + +#[derive(Debug, Serialize, Deserialize)] +struct UpdatePlanBody { + pub plan: zed_llm_client::Plan, + pub subscription_period: SubscriptionPeriod, + pub usage: zed_llm_client::CurrentUsage, + pub trial_started_at: Option>, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +struct SubscriptionPeriod { + pub started_at: DateTime, + pub ended_at: DateTime, +} + +#[derive(Serialize)] +struct UpdatePlanResponse {} + +async fn update_plan( + Path(user_id): Path, + Extension(rpc_server): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let plan = match body.plan { + zed_llm_client::Plan::ZedFree => proto::Plan::Free, + zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, + }; + + let update_user_plan = proto::UpdateUserPlan { + plan: plan.into(), + trial_started_at: body + .trial_started_at + .map(|trial_started_at| trial_started_at.timestamp() as u64), + is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled), + usage: Some(proto::SubscriptionUsage { + model_requests_usage_amount: body.usage.model_requests.used, + model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)), + edit_predictions_usage_amount: body.usage.edit_predictions.used, + edit_predictions_usage_limit: Some(usage_limit_to_proto( + body.usage.edit_predictions.limit, + )), + }), + subscription_period: Some(proto::SubscriptionPeriod { + started_at: body.subscription_period.started_at.timestamp() as u64, + ended_at: body.subscription_period.ended_at.timestamp() as u64, + }), + account_too_young: Some(body.is_account_too_young), + has_overdue_invoices: Some(body.has_overdue_invoices), + }; + + rpc_server + .update_plan_for_user(user_id, update_user_plan) + .await?; + + Ok(Json(UpdatePlanResponse {})) +} + +fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit { + proto::UsageLimit { + variant: Some(match limit { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + } +} diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index c8df066cbf..9a27e22f87 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,657 +1,40 @@ use anyhow::{Context as _, bail}; -use axum::{ - Extension, Json, Router, - extract::{self, Query}, - routing::{get, post}, -}; -use chrono::{DateTime, SecondsFormat, Utc}; -use collections::HashSet; +use axum::{Extension, Json, Router, extract, routing::post}; +use chrono::{DateTime, Utc}; +use collections::{HashMap, HashSet}; use reqwest::StatusCode; use sea_orm::ActiveValue; use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::{str::FromStr, sync::Arc, time::Duration}; -use stripe::{ - BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession, - CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, - CreateBillingPortalSessionFlowDataAfterCompletionRedirect, - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm, - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems, - CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents, - PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus, -}; +use std::{sync::Arc, time::Duration}; +use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus}; use util::{ResultExt, maybe}; +use zed_llm_client::LanguageModelProvider; -use crate::api::events::SnowflakeRow; use crate::db::billing_subscription::{ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, }; -use crate::llm::db::subscription_usage_meter::CompletionMode; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND}; +use crate::llm::db::subscription_usage_meter::{self, CompletionMode}; use crate::rpc::{ResultExt as _, Server}; use crate::stripe_client::{ StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription, - StripeSubscriptionId, UpdateCustomerParams, + StripeSubscriptionId, }; use crate::{AppState, Error, Result}; use crate::{db::UserId, llm::db::LlmDatabase}; use crate::{ db::{ - BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams, + CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams, - UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, billing_customer, + UpdateBillingSubscriptionParams, billing_customer, }, stripe_billing::StripeBilling, }; pub fn router() -> Router { - Router::new() - .route( - "/billing/preferences", - get(get_billing_preferences).put(update_billing_preferences), - ) - .route( - "/billing/subscriptions", - get(list_billing_subscriptions).post(create_billing_subscription), - ) - .route( - "/billing/subscriptions/manage", - post(manage_billing_subscription), - ) - .route( - "/billing/subscriptions/sync", - post(sync_billing_subscription), - ) - .route("/billing/usage", get(get_current_usage)) -} - -#[derive(Debug, Deserialize)] -struct GetBillingPreferencesParams { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct BillingPreferencesResponse { - trial_started_at: Option, - max_monthly_llm_usage_spending_in_cents: i32, - model_request_overages_enabled: bool, - model_request_overages_spend_limit_in_cents: i32, -} - -async fn get_billing_preferences( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(params.github_user_id) - .await? - .context("user not found")?; - - let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; - let preferences = app.db.get_billing_preferences(user.id).await?; - - Ok(Json(BillingPreferencesResponse { - trial_started_at: billing_customer - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| { - trial_started_at - .and_utc() - .to_rfc3339_opts(SecondsFormat::Millis, true) - }), - max_monthly_llm_usage_spending_in_cents: preferences - .as_ref() - .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| { - preferences.max_monthly_llm_usage_spending_in_cents - }), - model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents - }), - })) -} - -#[derive(Debug, Deserialize)] -struct UpdateBillingPreferencesBody { - github_user_id: i32, - #[serde(default)] - max_monthly_llm_usage_spending_in_cents: i32, - #[serde(default)] - model_request_overages_enabled: bool, - #[serde(default)] - model_request_overages_spend_limit_in_cents: i32, -} - -async fn update_billing_preferences( - Extension(app): Extension>, - Extension(rpc_server): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; - - let max_monthly_llm_usage_spending_in_cents = - body.max_monthly_llm_usage_spending_in_cents.max(0); - let model_request_overages_spend_limit_in_cents = - body.model_request_overages_spend_limit_in_cents.max(0); - - let billing_preferences = - if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? { - app.db - .update_billing_preferences( - user.id, - &UpdateBillingPreferencesParams { - max_monthly_llm_usage_spending_in_cents: ActiveValue::set( - max_monthly_llm_usage_spending_in_cents, - ), - model_request_overages_enabled: ActiveValue::set( - body.model_request_overages_enabled, - ), - model_request_overages_spend_limit_in_cents: ActiveValue::set( - model_request_overages_spend_limit_in_cents, - ), - }, - ) - .await? - } else { - app.db - .create_billing_preferences( - user.id, - &crate::db::CreateBillingPreferencesParams { - max_monthly_llm_usage_spending_in_cents, - model_request_overages_enabled: body.model_request_overages_enabled, - model_request_overages_spend_limit_in_cents, - }, - ) - .await? - }; - - SnowflakeRow::new( - "Billing Preferences Updated", - Some(user.metrics_id), - user.admin, - None, - json!({ - "user_id": user.id, - "model_request_overages_enabled": billing_preferences.model_request_overages_enabled, - "model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents, - "max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents, - }), + Router::new().route( + "/billing/subscriptions/sync", + post(sync_billing_subscription), ) - .write(&app.kinesis_client, &app.config.kinesis_stream) - .await - .log_err(); - - rpc_server.refresh_llm_tokens_for_user(user.id).await; - - Ok(Json(BillingPreferencesResponse { - trial_started_at: billing_customer - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| { - trial_started_at - .and_utc() - .to_rfc3339_opts(SecondsFormat::Millis, true) - }), - max_monthly_llm_usage_spending_in_cents: billing_preferences - .max_monthly_llm_usage_spending_in_cents, - model_request_overages_enabled: billing_preferences.model_request_overages_enabled, - model_request_overages_spend_limit_in_cents: billing_preferences - .model_request_overages_spend_limit_in_cents, - })) -} - -#[derive(Debug, Deserialize)] -struct ListBillingSubscriptionsParams { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct BillingSubscriptionJson { - id: BillingSubscriptionId, - name: String, - status: StripeSubscriptionStatus, - period: Option, - trial_end_at: Option, - cancel_at: Option, - /// Whether this subscription can be canceled. - is_cancelable: bool, -} - -#[derive(Debug, Serialize)] -struct BillingSubscriptionPeriodJson { - start_at: String, - end_at: String, -} - -#[derive(Debug, Serialize)] -struct ListBillingSubscriptionsResponse { - subscriptions: Vec, -} - -async fn list_billing_subscriptions( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(params.github_user_id) - .await? - .context("user not found")?; - - let subscriptions = app.db.get_billing_subscriptions(user.id).await?; - - Ok(Json(ListBillingSubscriptionsResponse { - subscriptions: subscriptions - .into_iter() - .map(|subscription| BillingSubscriptionJson { - id: subscription.id, - name: match subscription.kind { - Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(), - Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(), - Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(), - None => "Zed LLM Usage".to_string(), - }, - status: subscription.stripe_subscription_status, - period: maybe!({ - let start_at = subscription.current_period_start_at()?; - let end_at = subscription.current_period_end_at()?; - - Some(BillingSubscriptionPeriodJson { - start_at: start_at.to_rfc3339_opts(SecondsFormat::Millis, true), - end_at: end_at.to_rfc3339_opts(SecondsFormat::Millis, true), - }) - }), - trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) { - maybe!({ - let end_at = subscription.stripe_current_period_end?; - let end_at = DateTime::from_timestamp(end_at, 0)?; - - Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true)) - }) - } else { - None - }, - cancel_at: subscription.stripe_cancel_at.map(|cancel_at| { - cancel_at - .and_utc() - .to_rfc3339_opts(SecondsFormat::Millis, true) - }), - is_cancelable: subscription.kind != Some(SubscriptionKind::ZedFree) - && subscription.stripe_subscription_status.is_cancelable() - && subscription.stripe_cancel_at.is_none(), - }) - .collect(), - })) -} - -#[derive(Debug, PartialEq, Clone, Copy, Deserialize)] -#[serde(rename_all = "snake_case")] -enum ProductCode { - ZedPro, - ZedProTrial, -} - -#[derive(Debug, Deserialize)] -struct CreateBillingSubscriptionBody { - github_user_id: i32, - product: ProductCode, -} - -#[derive(Debug, Serialize)] -struct CreateBillingSubscriptionResponse { - checkout_session_url: String, -} - -/// Initiates a Stripe Checkout session for creating a billing subscription. -async fn create_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::error!("failed to retrieve Stripe billing object"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? { - let is_checkout_allowed = body.product == ProductCode::ZedProTrial - && existing_subscription.kind == Some(SubscriptionKind::ZedFree); - - if !is_checkout_allowed { - return Err(Error::http( - StatusCode::CONFLICT, - "user already has an active subscription".into(), - )); - } - } - - let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; - if let Some(existing_billing_customer) = &existing_billing_customer { - if existing_billing_customer.has_overdue_invoices { - return Err(Error::http( - StatusCode::PAYMENT_REQUIRED, - "user has overdue invoices".into(), - )); - } - } - - let customer_id = if let Some(existing_customer) = &existing_billing_customer { - let customer_id = StripeCustomerId(existing_customer.stripe_customer_id.clone().into()); - if let Some(email) = user.email_address.as_deref() { - stripe_billing - .client() - .update_customer(&customer_id, UpdateCustomerParams { email: Some(email) }) - .await - // Update of email address is best-effort - continue checkout even if it fails - .context("error updating stripe customer email address") - .log_err(); - } - customer_id - } else { - stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await? - }; - - let success_url = format!( - "{}/account?checkout_complete=1", - app.config.zed_dot_dev_url() - ); - - let checkout_session_url = match body.product { - ProductCode::ZedPro => { - stripe_billing - .checkout_with_zed_pro(&customer_id, &user.github_login, &success_url) - .await? - } - ProductCode::ZedProTrial => { - if let Some(existing_billing_customer) = &existing_billing_customer { - if existing_billing_customer.trial_started_at.is_some() { - return Err(Error::http( - StatusCode::FORBIDDEN, - "user already used free trial".into(), - )); - } - } - - let feature_flags = app.db.get_user_flags(user.id).await?; - - stripe_billing - .checkout_with_zed_pro_trial( - &customer_id, - &user.github_login, - feature_flags, - &success_url, - ) - .await? - } - }; - - Ok(Json(CreateBillingSubscriptionResponse { - checkout_session_url, - })) -} - -#[derive(Debug, PartialEq, Deserialize)] -#[serde(rename_all = "snake_case")] -enum ManageSubscriptionIntent { - /// The user intends to manage their subscription. - /// - /// This will open the Stripe billing portal without putting the user in a specific flow. - ManageSubscription, - /// The user intends to update their payment method. - UpdatePaymentMethod, - /// The user intends to upgrade to Zed Pro. - UpgradeToPro, - /// The user intends to cancel their subscription. - Cancel, - /// The user intends to stop the cancellation of their subscription. - StopCancellation, -} - -#[derive(Debug, Deserialize)] -struct ManageBillingSubscriptionBody { - github_user_id: i32, - intent: ManageSubscriptionIntent, - /// The ID of the subscription to manage. - subscription_id: BillingSubscriptionId, - redirect_to: Option, -} - -#[derive(Debug, Serialize)] -struct ManageBillingSubscriptionResponse { - billing_portal_session_url: Option, -} - -/// Initiates a Stripe customer portal session for managing a billing subscription. -async fn manage_billing_subscription( - Extension(app): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(body.github_user_id) - .await? - .context("user not found")?; - - let Some(stripe_client) = app.real_stripe_client.clone() else { - log::error!("failed to retrieve Stripe client"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::error!("failed to retrieve Stripe billing object"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; - - let customer = app - .db - .get_billing_customer_by_user_id(user.id) - .await? - .context("billing customer not found")?; - let customer_id = CustomerId::from_str(&customer.stripe_customer_id) - .context("failed to parse customer ID")?; - - let subscription = app - .db - .get_billing_subscription_by_id(body.subscription_id) - .await? - .context("subscription not found")?; - let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) - .context("failed to parse subscription ID")?; - - if body.intent == ManageSubscriptionIntent::StopCancellation { - let updated_stripe_subscription = Subscription::update( - &stripe_client, - &subscription_id, - stripe::UpdateSubscription { - cancel_at_period_end: Some(false), - ..Default::default() - }, - ) - .await?; - - app.db - .update_billing_subscription( - subscription.id, - &UpdateBillingSubscriptionParams { - stripe_cancel_at: ActiveValue::set( - updated_stripe_subscription - .cancel_at - .and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0)) - .map(|time| time.naive_utc()), - ), - ..Default::default() - }, - ) - .await?; - - return Ok(Json(ManageBillingSubscriptionResponse { - billing_portal_session_url: None, - })); - } - - let flow = match body.intent { - ManageSubscriptionIntent::ManageSubscription => None, - ManageSubscriptionIntent::UpgradeToPro => { - let zed_pro_price_id: stripe::PriceId = - stripe_billing.zed_pro_price_id().await?.try_into()?; - let zed_free_price_id: stripe::PriceId = - stripe_billing.zed_free_price_id().await?.try_into()?; - - let stripe_subscription = - Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?; - - let is_on_zed_pro_trial = stripe_subscription.status == SubscriptionStatus::Trialing - && stripe_subscription.items.data.iter().any(|item| { - item.price - .as_ref() - .map_or(false, |price| price.id == zed_pro_price_id) - }); - if is_on_zed_pro_trial { - let payment_methods = PaymentMethod::list( - &stripe_client, - &stripe::ListPaymentMethods { - customer: Some(stripe_subscription.customer.id()), - ..Default::default() - }, - ) - .await?; - - let has_payment_method = !payment_methods.data.is_empty(); - if !has_payment_method { - return Err(Error::http( - StatusCode::BAD_REQUEST, - "missing payment method".into(), - )); - } - - // If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early. - Subscription::update( - &stripe_client, - &stripe_subscription.id, - stripe::UpdateSubscription { - trial_end: Some(stripe::Scheduled::now()), - ..Default::default() - }, - ) - .await?; - - return Ok(Json(ManageBillingSubscriptionResponse { - billing_portal_session_url: None, - })); - } - - let subscription_item_to_update = stripe_subscription - .items - .data - .iter() - .find_map(|item| { - let price = item.price.as_ref()?; - - if price.id == zed_free_price_id { - Some(item.id.clone()) - } else { - None - } - }) - .context("No subscription item to update")?; - - Some(CreateBillingPortalSessionFlowData { - type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm, - subscription_update_confirm: Some( - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm { - subscription: subscription.stripe_subscription_id, - items: vec![ - CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems { - id: subscription_item_to_update.to_string(), - price: Some(zed_pro_price_id.to_string()), - quantity: Some(1), - }, - ], - discounts: None, - }, - ), - ..Default::default() - }) - } - ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData { - type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate, - after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { - type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, - redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { - return_url: format!( - "{}{path}", - app.config.zed_dot_dev_url(), - path = body.redirect_to.unwrap_or_else(|| "/account".to_string()) - ), - }), - ..Default::default() - }), - ..Default::default() - }), - ManageSubscriptionIntent::Cancel => { - if subscription.kind == Some(SubscriptionKind::ZedFree) { - return Err(Error::http( - StatusCode::BAD_REQUEST, - "free subscription cannot be canceled".into(), - )); - } - - Some(CreateBillingPortalSessionFlowData { - type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, - after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { - type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, - redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { - return_url: format!("{}/account", app.config.zed_dot_dev_url()), - }), - ..Default::default() - }), - subscription_cancel: Some( - stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel { - subscription: subscription.stripe_subscription_id, - retention: None, - }, - ), - ..Default::default() - }) - } - ManageSubscriptionIntent::StopCancellation => unreachable!(), - }; - - let mut params = CreateBillingPortalSession::new(customer_id); - params.flow_data = flow; - let return_url = format!("{}/account", app.config.zed_dot_dev_url()); - params.return_url = Some(&return_url); - - let session = BillingPortalSession::create(&stripe_client, params).await?; - - Ok(Json(ManageBillingSubscriptionResponse { - billing_portal_session_url: Some(session.url), - })) } #[derive(Debug, Deserialize)] @@ -1144,7 +527,7 @@ async fn handle_customer_subscription_event( // When the user's subscription changes, push down any changes to their plan. rpc_server - .update_plan_for_user(billing_customer.user_id) + .update_plan_for_user_legacy(billing_customer.user_id) .await .trace_err(); @@ -1157,157 +540,6 @@ async fn handle_customer_subscription_event( Ok(()) } -#[derive(Debug, Deserialize)] -struct GetCurrentUsageParams { - github_user_id: i32, -} - -#[derive(Debug, Serialize)] -struct UsageCounts { - pub used: i32, - pub limit: Option, - pub remaining: Option, -} - -#[derive(Debug, Serialize)] -struct ModelRequestUsage { - pub model: String, - pub mode: CompletionMode, - pub requests: i32, -} - -#[derive(Debug, Serialize)] -struct CurrentUsage { - pub model_requests: UsageCounts, - pub model_request_usage: Vec, - pub edit_predictions: UsageCounts, -} - -#[derive(Debug, Default, Serialize)] -struct GetCurrentUsageResponse { - pub plan: String, - pub current_usage: Option, -} - -async fn get_current_usage( - Extension(app): Extension>, - Query(params): Query, -) -> Result> { - let user = app - .db - .get_user_by_github_user_id(params.github_user_id) - .await? - .context("user not found")?; - - let feature_flags = app.db.get_user_flags(user.id).await?; - let has_extended_trial = feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG); - - let Some(llm_db) = app.llm_db.clone() else { - return Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "LLM database not available".into(), - )); - }; - - let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else { - return Ok(Json(GetCurrentUsageResponse::default())); - }; - - let subscription_period = maybe!({ - let period_start_at = subscription.current_period_start_at()?; - let period_end_at = subscription.current_period_end_at()?; - - Some((period_start_at, period_end_at)) - }); - - let Some((period_start_at, period_end_at)) = subscription_period else { - return Ok(Json(GetCurrentUsageResponse::default())); - }; - - let usage = llm_db - .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) - .await?; - - let plan = subscription - .kind - .map(Into::into) - .unwrap_or(zed_llm_client::Plan::ZedFree); - - let model_requests_limit = match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial && has_extended_trial { - 1_000 - } else { - limit - }; - - Some(limit) - } - zed_llm_client::UsageLimit::Unlimited => None, - }; - - let edit_predictions_limit = match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => Some(limit), - zed_llm_client::UsageLimit::Unlimited => None, - }; - - let Some(usage) = usage else { - return Ok(Json(GetCurrentUsageResponse { - plan: plan.as_str().to_string(), - current_usage: Some(CurrentUsage { - model_requests: UsageCounts { - used: 0, - limit: model_requests_limit, - remaining: model_requests_limit, - }, - model_request_usage: Vec::new(), - edit_predictions: UsageCounts { - used: 0, - limit: edit_predictions_limit, - remaining: edit_predictions_limit, - }, - }), - })); - }; - - let subscription_usage_meters = llm_db - .get_current_subscription_usage_meters_for_user(user.id, Utc::now()) - .await?; - - let model_request_usage = subscription_usage_meters - .into_iter() - .filter_map(|(usage_meter, _usage)| { - let model = llm_db.model_by_id(usage_meter.model_id).ok()?; - - Some(ModelRequestUsage { - model: model.name.clone(), - mode: usage_meter.mode, - requests: usage_meter.requests, - }) - }) - .collect::>(); - - Ok(Json(GetCurrentUsageResponse { - plan: plan.as_str().to_string(), - current_usage: Some(CurrentUsage { - model_requests: UsageCounts { - used: usage.model_requests, - limit: model_requests_limit, - remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)), - }, - model_request_usage, - edit_predictions: UsageCounts { - used: usage.edit_predictions, - limit: edit_predictions_limit, - remaining: edit_predictions_limit - .map(|limit| (limit - usage.edit_predictions).max(0)), - }, - }), - })) -} - impl From for StripeSubscriptionStatus { fn from(value: SubscriptionStatus) -> Self { match value { @@ -1416,18 +648,21 @@ async fn sync_model_request_usage_with_stripe( let usage_meters = llm_db .get_current_subscription_usage_meters(Utc::now()) .await?; - let usage_meters = usage_meters - .into_iter() - .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id)) - .collect::>(); - let user_ids = usage_meters - .iter() - .map(|(_, usage)| usage.user_id) - .collect::>(); - let billing_subscriptions = app - .db - .get_active_zed_pro_billing_subscriptions(user_ids) - .await?; + let mut usage_meters_by_user_id = + HashMap::>::default(); + for (usage_meter, usage) in usage_meters { + let meters = usage_meters_by_user_id.entry(usage.user_id).or_default(); + meters.push(usage_meter); + } + + log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions"); + let get_zed_pro_subscriptions_started_at = Utc::now(); + let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?; + log::info!( + "Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}", + billing_subscriptions.len(), + Utc::now() - get_zed_pro_subscriptions_started_at + ); let claude_sonnet_4 = stripe_billing .find_price_by_lookup_key("claude-sonnet-4-requests") @@ -1451,59 +686,90 @@ async fn sync_model_request_usage_with_stripe( .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") .await?; - let usage_meter_count = usage_meters.len(); + let model_mode_combinations = [ + ("claude-opus-4", CompletionMode::Max), + ("claude-opus-4", CompletionMode::Normal), + ("claude-sonnet-4", CompletionMode::Max), + ("claude-sonnet-4", CompletionMode::Normal), + ("claude-3-7-sonnet", CompletionMode::Max), + ("claude-3-7-sonnet", CompletionMode::Normal), + ("claude-3-5-sonnet", CompletionMode::Normal), + ]; - log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters"); + let billing_subscription_count = billing_subscriptions.len(); - for (usage_meter, usage) in usage_meters { + log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions"); + + for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions { maybe!(async { - let Some((billing_customer, billing_subscription)) = - billing_subscriptions.get(&usage.user_id) - else { - bail!( - "Attempted to sync usage meter for user who is not a Stripe customer: {}", - usage.user_id - ); - }; + if staff_user_ids.contains(&user_id) { + return anyhow::Ok(()); + } let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); let stripe_subscription_id = StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); - let model = llm_db.model_by_id(usage_meter.model_id)?; + let usage_meters = usage_meters_by_user_id.get(&user_id); - let (price, meter_event_name) = match model.name.as_str() { - "claude-opus-4" => match usage_meter.mode { - CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), - CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), - }, - "claude-sonnet-4" => match usage_meter.mode { - CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), - CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"), - }, - "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), - "claude-3-7-sonnet" => match usage_meter.mode { - CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"), - CompletionMode::Max => { - (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") + for (model, mode) in &model_mode_combinations { + let Ok(model) = + llm_db.model(LanguageModelProvider::Anthropic, model) + else { + log::warn!("Failed to load model for user {user_id}: {model}"); + continue; + }; + + let (price, meter_event_name) = match model.name.as_str() { + "claude-opus-4" => match mode { + CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), + CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), + }, + "claude-sonnet-4" => match mode { + CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), + CompletionMode::Max => { + (&claude_sonnet_4_max, "claude_sonnet_4/requests/max") + } + }, + "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), + "claude-3-7-sonnet" => match mode { + CompletionMode::Normal => { + (&claude_3_7_sonnet, "claude_3_7_sonnet/requests") + } + CompletionMode::Max => { + (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") + } + }, + model_name => { + bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") } - }, - model_name => { - bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") - } - }; + }; - stripe_billing - .subscribe_to_price(&stripe_subscription_id, price) - .await?; - stripe_billing - .bill_model_request_usage( - &stripe_customer_id, - meter_event_name, - usage_meter.requests, - ) - .await?; + let model_requests = usage_meters + .and_then(|usage_meters| { + usage_meters + .iter() + .find(|meter| meter.model_id == model.id && meter.mode == *mode) + }) + .map(|usage_meter| usage_meter.requests) + .unwrap_or(0); + + if model_requests > 0 { + stripe_billing + .subscribe_to_price(&stripe_subscription_id, price) + .await?; + } + + stripe_billing + .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests) + .await + .with_context(|| { + format!( + "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}", + ) + })?; + } Ok(()) }) @@ -1512,7 +778,7 @@ async fn sync_model_request_usage_with_stripe( } log::info!( - "Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}", + "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}", Utc::now() - started_at ); diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 6ccc86c520..bc7dd152b0 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -389,53 +389,58 @@ pub async fn post_panic( } } - let backtrace = if panic.backtrace.len() > 25 { - let total = panic.backtrace.len(); - format!( - "{}\n and {} more", - panic - .backtrace - .iter() - .take(20) - .cloned() - .collect::>() - .join("\n"), - total - 20 - ) - } else { - panic.backtrace.join("\n") - }; - if !report_to_slack(&panic) { return Ok(()); } - let backtrace_with_summary = panic.payload + "\n" + &backtrace; - if let Some(slack_panics_webhook) = app.config.slack_panics_webhook.clone() { + let backtrace = if panic.backtrace.len() > 25 { + let total = panic.backtrace.len(); + format!( + "{}\n and {} more", + panic + .backtrace + .iter() + .take(20) + .cloned() + .collect::>() + .join("\n"), + total - 20 + ) + } else { + panic.backtrace.join("\n") + }; + let backtrace_with_summary = panic.payload + "\n" + &backtrace; + + let version = if panic.release_channel == "nightly" + && !panic.app_version.contains("remote-server") + && let Some(sha) = panic.app_commit_sha + { + format!("Zed Nightly {}", sha.chars().take(7).collect::()) + } else { + panic.app_version + }; + let payload = slack::WebhookBody::new(|w| { w.add_section(|s| s.text(slack::Text::markdown("Panic request".to_string()))) .add_section(|s| { - s.add_field(slack::Text::markdown(format!( - "*Version:*\n {} ", - panic.app_version - ))) - .add_field({ - let hostname = app.config.blob_store_url.clone().unwrap_or_default(); - let hostname = hostname.strip_prefix("https://").unwrap_or_else(|| { - hostname.strip_prefix("http://").unwrap_or_default() - }); + s.add_field(slack::Text::markdown(format!("*Version:*\n {version} ",))) + .add_field({ + let hostname = app.config.blob_store_url.clone().unwrap_or_default(); + let hostname = hostname.strip_prefix("https://").unwrap_or_else(|| { + hostname.strip_prefix("http://").unwrap_or_default() + }); - slack::Text::markdown(format!( - "*{} {}:*\n", - panic.os_name, - panic.os_version.unwrap_or_default(), - CRASH_REPORTS_BUCKET, - hostname, - incident_id, - incident_id.chars().take(8).collect::(), - )) - }) + slack::Text::markdown(format!( + "*{} {}:*\n", + panic.os_name, + panic.os_version.unwrap_or_default(), + CRASH_REPORTS_BUCKET, + hostname, + incident_id, + incident_id.chars().take(8).collect::(), + )) + }) }) .add_rich_text(|r| r.add_preformatted(|p| p.add_text(backtrace_with_summary))) }); diff --git a/crates/collab/src/cents.rs b/crates/collab/src/cents.rs deleted file mode 100644 index a05971f141..0000000000 --- a/crates/collab/src/cents.rs +++ /dev/null @@ -1,83 +0,0 @@ -use serde::Serialize; - -/// A number of cents. -#[derive( - Debug, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - Clone, - Copy, - derive_more::Add, - derive_more::AddAssign, - derive_more::Sub, - derive_more::SubAssign, - Serialize, -)] -pub struct Cents(pub u32); - -impl Cents { - pub const ZERO: Self = Self(0); - - pub const fn new(cents: u32) -> Self { - Self(cents) - } - - pub const fn from_dollars(dollars: u32) -> Self { - Self(dollars * 100) - } - - pub fn saturating_sub(self, other: Cents) -> Self { - Self(self.0.saturating_sub(other.0)) - } -} - -#[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_cents_new() { - assert_eq!(Cents::new(50), Cents(50)); - } - - #[test] - fn test_cents_from_dollars() { - assert_eq!(Cents::from_dollars(1), Cents(100)); - assert_eq!(Cents::from_dollars(5), Cents(500)); - } - - #[test] - fn test_cents_zero() { - assert_eq!(Cents::ZERO, Cents(0)); - } - - #[test] - fn test_cents_add() { - assert_eq!(Cents(50) + Cents(30), Cents(80)); - } - - #[test] - fn test_cents_add_assign() { - let mut cents = Cents(50); - cents += Cents(30); - assert_eq!(cents, Cents(80)); - } - - #[test] - fn test_cents_saturating_sub() { - assert_eq!(Cents(50).saturating_sub(Cents(30)), Cents(20)); - assert_eq!(Cents(30).saturating_sub(Cents(50)), Cents(0)); - } - - #[test] - fn test_cents_ordering() { - assert!(Cents(50) > Cents(30)); - assert!(Cents(30) < Cents(50)); - assert_eq!(Cents(50), Cents(50)); - } -} diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index cc29245697..8cd1e3ea83 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -42,9 +42,6 @@ pub use tests::TestDb; pub use ids::*; pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams}; -pub use queries::billing_preferences::{ - CreateBillingPreferencesParams, UpdateBillingPreferencesParams, -}; pub use queries::billing_subscriptions::{ CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams, }; diff --git a/crates/collab/src/db/queries/billing_preferences.rs b/crates/collab/src/db/queries/billing_preferences.rs index 1a6fbe946a..f370964ecd 100644 --- a/crates/collab/src/db/queries/billing_preferences.rs +++ b/crates/collab/src/db/queries/billing_preferences.rs @@ -1,21 +1,5 @@ -use anyhow::Context as _; - use super::*; -#[derive(Debug)] -pub struct CreateBillingPreferencesParams { - pub max_monthly_llm_usage_spending_in_cents: i32, - pub model_request_overages_enabled: bool, - pub model_request_overages_spend_limit_in_cents: i32, -} - -#[derive(Debug, Default)] -pub struct UpdateBillingPreferencesParams { - pub max_monthly_llm_usage_spending_in_cents: ActiveValue, - pub model_request_overages_enabled: ActiveValue, - pub model_request_overages_spend_limit_in_cents: ActiveValue, -} - impl Database { /// Returns the billing preferences for the given user, if they exist. pub async fn get_billing_preferences( @@ -30,62 +14,4 @@ impl Database { }) .await } - - /// Creates new billing preferences for the given user. - pub async fn create_billing_preferences( - &self, - user_id: UserId, - params: &CreateBillingPreferencesParams, - ) -> Result { - self.transaction(|tx| async move { - let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel { - user_id: ActiveValue::set(user_id), - max_monthly_llm_usage_spending_in_cents: ActiveValue::set( - params.max_monthly_llm_usage_spending_in_cents, - ), - model_request_overages_enabled: ActiveValue::set( - params.model_request_overages_enabled, - ), - model_request_overages_spend_limit_in_cents: ActiveValue::set( - params.model_request_overages_spend_limit_in_cents, - ), - ..Default::default() - }) - .exec_with_returning(&*tx) - .await?; - - Ok(preferences) - }) - .await - } - - /// Updates the billing preferences for the given user. - pub async fn update_billing_preferences( - &self, - user_id: UserId, - params: &UpdateBillingPreferencesParams, - ) -> Result { - self.transaction(|tx| async move { - let preferences = billing_preference::Entity::update_many() - .set(billing_preference::ActiveModel { - max_monthly_llm_usage_spending_in_cents: params - .max_monthly_llm_usage_spending_in_cents - .clone(), - model_request_overages_enabled: params.model_request_overages_enabled.clone(), - model_request_overages_spend_limit_in_cents: params - .model_request_overages_spend_limit_in_cents - .clone(), - ..Default::default() - }) - .filter(billing_preference::Column::UserId.eq(user_id)) - .exec_with_returning(&*tx) - .await?; - - Ok(preferences - .into_iter() - .next() - .context("billing preferences not found")?) - }) - .await - } } diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index f25d0abeaa..9f82e3dbc4 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -199,6 +199,33 @@ impl Database { pub async fn get_active_zed_pro_billing_subscriptions( &self, + ) -> Result> { + self.transaction(|tx| async move { + let mut rows = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .select_also(billing_customer::Entity) + .filter( + billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::Active), + ) + .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) + .order_by_asc(billing_subscription::Column::Id) + .stream(&*tx) + .await?; + + let mut subscriptions = HashMap::default(); + while let Some(row) = rows.next().await { + if let (subscription, Some(customer)) = row? { + subscriptions.insert(customer.user_id, (customer, subscription)); + } + } + Ok(subscriptions) + }) + .await + } + + pub async fn get_active_zed_pro_billing_subscriptions_for_users( + &self, user_ids: HashSet, ) -> Result> { self.transaction(|tx| { diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 2b20c8f080..905859ca69 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,6 +1,5 @@ pub mod api; pub mod auth; -mod cents; pub mod db; pub mod env; pub mod executor; @@ -21,7 +20,6 @@ use axum::{ http::{HeaderMap, StatusCode}, response::IntoResponse, }; -pub use cents::*; use db::{ChannelId, Database}; use executor::Executor; use llm::db::LlmDatabase; diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index cf5dec6e28..de74858168 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,8 +1,6 @@ pub mod db; mod token; -use crate::Cents; - pub use token::*; pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; @@ -12,9 +10,3 @@ pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-chec /// The minimum account age an account must have in order to use the LLM service. pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30); - -/// The default value to use for maximum spend per month if the user did not -/// explicitly set a maximum spend. -/// -/// Used to prevent surprise bills. -pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 753e591914..5c5de2f36e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -829,7 +829,7 @@ impl Server { // This arrangement ensures we will attempt to process earlier messages first, but fall // back to processing messages arrived later in the spirit of making progress. let mut foreground_message_handlers = FuturesUnordered::new(); - let concurrent_handlers = Arc::new(Semaphore::new(256)); + let concurrent_handlers = Arc::new(Semaphore::new(512)); loop { let next_message = async { let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); @@ -1002,7 +1002,26 @@ impl Server { Ok(()) } - pub async fn update_plan_for_user(self: &Arc, user_id: UserId) -> Result<()> { + pub async fn update_plan_for_user( + self: &Arc, + user_id: UserId, + update_user_plan: proto::UpdateUserPlan, + ) -> Result<()> { + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer + .send(connection_id, update_user_plan.clone()) + .trace_err(); + } + + Ok(()) + } + + /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan` + /// message on the Collab server. + /// + /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint. + pub async fn update_plan_for_user_legacy(self: &Arc, user_id: UserId) -> Result<()> { let user = self .app_state .db @@ -1018,14 +1037,7 @@ impl Server { ) .await?; - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, update_user_plan.clone()) - .trace_err(); - } - - Ok(()) + self.update_plan_for_user(user_id, update_user_plan).await } pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) { @@ -2836,62 +2848,117 @@ async fn make_update_user_plan_message( account_too_young: Some(account_too_young), has_overdue_invoices: billing_customer .map(|billing_customer| billing_customer.has_overdue_invoices), - usage: usage.map(|usage| { - let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, - }; - - let model_requests_limit = match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial - && feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) - { - 1_000 - } else { - limit - }; - - zed_llm_client::UsageLimit::Limited(limit) - } - zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: usage.model_requests as u32, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit { - zed_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - zed_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: usage.edit_predictions as u32, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - zed_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } - }), + usage: Some( + usage + .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags)) + .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)), + ), }) } +fn model_requests_limit( + plan: zed_llm_client::Plan, + feature_flags: &Vec, +) -> zed_llm_client::UsageLimit { + match plan.model_requests_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == zed_llm_client::Plan::ZedProTrial + && feature_flags + .iter() + .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) + { + 1_000 + } else { + limit + }; + + zed_llm_client::UsageLimit::Limited(limit) + } + zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, + } +} + +fn subscription_usage_to_proto( + plan: proto::Plan, + usage: crate::llm::db::subscription_usage::Model, + feature_flags: &Vec, +) -> proto::SubscriptionUsage { + let plan = match plan { + proto::Plan::Free => zed_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + }; + + proto::SubscriptionUsage { + model_requests_usage_amount: usage.model_requests as u32, + model_requests_usage_limit: Some(proto::UsageLimit { + variant: Some(match model_requests_limit(plan, feature_flags) { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + }), + edit_predictions_usage_amount: usage.edit_predictions as u32, + edit_predictions_usage_limit: Some(proto::UsageLimit { + variant: Some(match plan.edit_predictions_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + }), + } +} + +fn make_default_subscription_usage( + plan: proto::Plan, + feature_flags: &Vec, +) -> proto::SubscriptionUsage { + let plan = match plan { + proto::Plan::Free => zed_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + }; + + proto::SubscriptionUsage { + model_requests_usage_amount: 0, + model_requests_usage_limit: Some(proto::UsageLimit { + variant: Some(match model_requests_limit(plan, feature_flags) { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + }), + edit_predictions_usage_amount: 0, + edit_predictions_usage_limit: Some(proto::UsageLimit { + variant: Some(match plan.edit_predictions_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + }), + } +} + async fn update_user_plan(session: &Session) -> Result<()> { let db = session.db().await; @@ -4112,6 +4179,13 @@ async fn accept_terms_of_service( response.send(proto::AcceptTermsOfServiceResponse { accepted_tos_at: accepted_tos_at.timestamp() as u64, })?; + + // When the user accepts the terms of service, we want to refresh their LLM + // token to grant access. + session + .peer + .send(session.connection_id, proto::RefreshLlmToken {})?; + Ok(()) } diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 8bf6c08158..850b716a9f 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use anyhow::{Context as _, anyhow}; +use anyhow::anyhow; use chrono::Utc; use collections::HashMap; use stripe::SubscriptionStatus; @@ -9,15 +9,10 @@ use uuid::Uuid; use crate::Result; use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_client::{ - RealStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, + RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams, - StripeCustomerId, StripeCustomerUpdate, StripeCustomerUpdateAddress, StripeCustomerUpdateName, - StripeMeter, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, + StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, UpdateSubscriptionParams, @@ -30,8 +25,6 @@ pub struct StripeBilling { #[derive(Default)] struct StripeBillingState { - meters_by_event_name: HashMap, - price_ids_by_meter_id: HashMap, prices_by_lookup_key: HashMap, } @@ -60,24 +53,11 @@ impl StripeBilling { let mut state = self.state.write().await; - let (meters, prices) = - futures::try_join!(self.client.list_meters(), self.client.list_prices())?; - - for meter in meters { - state - .meters_by_event_name - .insert(meter.event_name.clone(), meter); - } + let prices = self.client.list_prices().await?; for price in prices { if let Some(lookup_key) = price.lookup_key.clone() { - state.prices_by_lookup_key.insert(lookup_key, price.clone()); - } - - if let Some(recurring) = price.recurring { - if let Some(meter) = recurring.meter { - state.price_ids_by_meter_id.insert(meter, price.id); - } + state.prices_by_lookup_key.insert(lookup_key, price); } } @@ -229,93 +209,6 @@ impl StripeBilling { Ok(()) } - pub async fn checkout_with_zed_pro( - &self, - customer_id: &StripeCustomerId, - github_login: &str, - success_url: &str, - ) -> Result { - let zed_pro_price_id = self.zed_pro_price_id().await?; - - let mut params = StripeCreateCheckoutSessionParams::default(); - params.mode = Some(StripeCheckoutSessionMode::Subscription); - params.customer = Some(customer_id); - params.client_reference_id = Some(github_login); - params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(zed_pro_price_id.to_string()), - quantity: Some(1), - }]); - params.success_url = Some(success_url); - params.billing_address_collection = Some(StripeBillingAddressCollection::Required); - params.customer_update = Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }); - - let session = self.client.create_checkout_session(params).await?; - Ok(session.url.context("no checkout session URL")?) - } - - pub async fn checkout_with_zed_pro_trial( - &self, - customer_id: &StripeCustomerId, - github_login: &str, - feature_flags: Vec, - success_url: &str, - ) -> Result { - let zed_pro_price_id = self.zed_pro_price_id().await?; - - let eligible_for_extended_trial = feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG); - - let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 }; - - let mut subscription_metadata = std::collections::HashMap::new(); - if eligible_for_extended_trial { - subscription_metadata.insert( - "promo_feature_flag".to_string(), - AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(), - ); - } - - let mut params = StripeCreateCheckoutSessionParams::default(); - params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData { - trial_period_days: Some(trial_period_days), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - metadata: if !subscription_metadata.is_empty() { - Some(subscription_metadata) - } else { - None - }, - }); - params.mode = Some(StripeCheckoutSessionMode::Subscription); - params.payment_method_collection = - Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired); - params.customer = Some(customer_id); - params.client_reference_id = Some(github_login); - params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(zed_pro_price_id.to_string()), - quantity: Some(1), - }]); - params.success_url = Some(success_url); - params.billing_address_collection = Some(StripeBillingAddressCollection::Required); - params.customer_update = Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }); - - let session = self.client.create_checkout_session(params).await?; - Ok(session.url.context("no checkout session URL")?) - } - pub async fn subscribe_to_zed_free( &self, customer_id: StripeCustomerId, @@ -342,6 +235,7 @@ impl StripeBilling { price: Some(zed_free_price_id), quantity: Some(1), }], + automatic_tax: Some(StripeAutomaticTax { enabled: true }), }; let subscription = self.client.create_subscription(params).await?; diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 9ffcb2ba6c..6e75a4d874 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -73,6 +73,7 @@ pub enum StripeCancellationDetailsReason { pub struct StripeCreateSubscriptionParams { pub customer: StripeCustomerId, pub items: Vec, + pub automatic_tax: Option, } #[derive(Debug)] @@ -190,6 +191,7 @@ pub struct StripeCreateCheckoutSessionParams<'a> { pub success_url: Option<&'a str>, pub billing_address_collection: Option, pub customer_update: Option, + pub tax_id_collection: Option, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -218,6 +220,16 @@ pub struct StripeCreateCheckoutSessionSubscriptionData { pub trial_settings: Option, } +#[derive(Debug, PartialEq, Clone)] +pub struct StripeTaxIdCollection { + pub enabled: bool, +} + +#[derive(Debug, Clone)] +pub struct StripeAutomaticTax { + pub enabled: bool, +} + #[derive(Debug)] pub struct StripeCheckoutSession { pub url: Option, diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index 11b210dd0e..9bb08443ec 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -14,8 +14,8 @@ use crate::stripe_client::{ StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, UpdateCustomerParams, - UpdateSubscriptionParams, + StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection, + UpdateCustomerParams, UpdateSubscriptionParams, }; #[derive(Debug, Clone)] @@ -38,6 +38,7 @@ pub struct StripeCreateCheckoutSessionCall { pub success_url: Option, pub billing_address_collection: Option, pub customer_update: Option, + pub tax_id_collection: Option, } pub struct FakeStripeClient { @@ -236,6 +237,7 @@ impl StripeClient for FakeStripeClient { success_url: params.success_url.map(|url| url.to_string()), billing_address_collection: params.billing_address_collection, customer_update: params.customer_update, + tax_id_collection: params.tax_id_collection, }); Ok(StripeCheckoutSession { diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index 7108e8d759..07c191ff30 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -10,16 +10,17 @@ use stripe::{ CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings, CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, - CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription, - SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateCustomer, UpdateSubscriptionItems, - UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, + CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price, + PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, + UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, + UpdateSubscriptionTrialSettingsEndBehavior, UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, }; use crate::stripe_client::{ - CreateCustomerParams, StripeBillingAddressCollection, StripeCancellationDetails, - StripeCancellationDetailsReason, StripeCheckoutSession, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeClient, + CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection, + StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession, + StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, @@ -27,8 +28,8 @@ use crate::stripe_client::{ StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateCustomerParams, - UpdateSubscriptionParams, + StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection, + UpdateCustomerParams, UpdateSubscriptionParams, }; pub struct RealStripeClient { @@ -151,6 +152,7 @@ impl StripeClient for RealStripeClient { }) .collect(), ); + create_subscription.automatic_tax = params.automatic_tax.map(Into::into); let subscription = Subscription::create(&self.client, create_subscription).await?; @@ -366,6 +368,15 @@ impl From for StripeSubscriptionItem { } } +impl From for CreateSubscriptionAutomaticTax { + fn from(value: StripeAutomaticTax) -> Self { + Self { + enabled: value.enabled, + liability: None, + } + } +} + impl From for UpdateSubscriptionTrialSettings { fn from(value: StripeSubscriptionTrialSettings) -> Self { Self { @@ -448,6 +459,7 @@ impl<'a> TryFrom> for CreateCheckoutSessio success_url: value.success_url, billing_address_collection: value.billing_address_collection.map(Into::into), customer_update: value.customer_update.map(Into::into), + tax_id_collection: value.tax_id_collection.map(Into::into), ..Default::default() }) } @@ -590,3 +602,11 @@ impl From for stripe::CreateCheckoutSessionCustomerUpdate } } } + +impl From for stripe::CreateCheckoutSessionTaxIdCollection { + fn from(value: StripeTaxIdCollection) -> Self { + stripe::CreateCheckoutSessionTaxIdCollection { + enabled: value.enabled, + } + } +} diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 2cc3ca76d1..73ab2b8167 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -2246,8 +2246,11 @@ async fn test_lsp_document_color(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo }); } -#[gpui::test(iterations = 10)] -async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { +async fn test_lsp_pull_diagnostics( + should_stream_workspace_diagnostic: bool, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { let mut server = TestServer::start(cx_a.executor()).await; let executor = cx_a.executor(); let client_a = server.create_client(cx_a, "user_a").await; @@ -2396,12 +2399,25 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp let closure_workspace_diagnostics_pulls_made = workspace_diagnostics_pulls_made.clone(); let closure_workspace_diagnostics_pulls_result_ids = workspace_diagnostics_pulls_result_ids.clone(); + let (workspace_diagnostic_cancel_tx, closure_workspace_diagnostic_cancel_rx) = + smol::channel::bounded::<()>(1); + let (closure_workspace_diagnostic_received_tx, workspace_diagnostic_received_rx) = + smol::channel::bounded::<()>(1); + let expected_workspace_diagnostic_token = lsp::ProgressToken::String(format!( + "workspace/diagnostic-{}-1", + fake_language_server.server.server_id() + )); + let closure_expected_workspace_diagnostic_token = expected_workspace_diagnostic_token.clone(); let mut workspace_diagnostics_pulls_handle = fake_language_server .set_request_handler::( move |params, _| { let workspace_requests_made = closure_workspace_diagnostics_pulls_made.clone(); let workspace_diagnostics_pulls_result_ids = closure_workspace_diagnostics_pulls_result_ids.clone(); + let workspace_diagnostic_cancel_rx = closure_workspace_diagnostic_cancel_rx.clone(); + let workspace_diagnostic_received_tx = closure_workspace_diagnostic_received_tx.clone(); + let expected_workspace_diagnostic_token = + closure_expected_workspace_diagnostic_token.clone(); async move { let workspace_request_count = workspace_requests_made.fetch_add(1, atomic::Ordering::Release) + 1; @@ -2411,6 +2427,21 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp .await .extend(params.previous_result_ids.into_iter().map(|id| id.value)); } + if should_stream_workspace_diagnostic && !workspace_diagnostic_cancel_rx.is_closed() + { + assert_eq!( + params.partial_result_params.partial_result_token, + Some(expected_workspace_diagnostic_token) + ); + workspace_diagnostic_received_tx.send(()).await.unwrap(); + workspace_diagnostic_cancel_rx.recv().await.unwrap(); + workspace_diagnostic_cancel_rx.close(); + // https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#partialResults + // > The final response has to be empty in terms of result values. + return Ok(lsp::WorkspaceDiagnosticReportResult::Report( + lsp::WorkspaceDiagnosticReport { items: Vec::new() }, + )); + } Ok(lsp::WorkspaceDiagnosticReportResult::Report( lsp::WorkspaceDiagnosticReport { items: vec![ @@ -2479,7 +2510,11 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp }, ); - workspace_diagnostics_pulls_handle.next().await.unwrap(); + if should_stream_workspace_diagnostic { + workspace_diagnostic_received_rx.recv().await.unwrap(); + } else { + workspace_diagnostics_pulls_handle.next().await.unwrap(); + } assert_eq!( 1, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), @@ -2503,10 +2538,10 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp "Expected single diagnostic, but got: {all_diagnostics:?}" ); let diagnostic = &all_diagnostics[0]; - let expected_messages = [ - expected_workspace_pull_diagnostics_main_message, - expected_pull_diagnostic_main_message, - ]; + let mut expected_messages = vec![expected_pull_diagnostic_main_message]; + if !should_stream_workspace_diagnostic { + expected_messages.push(expected_workspace_pull_diagnostics_main_message); + } assert!( expected_messages.contains(&diagnostic.diagnostic.message.as_str()), "Expected {expected_messages:?} on the host, but got: {}", @@ -2556,6 +2591,70 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp version: None, }, ); + + if should_stream_workspace_diagnostic { + fake_language_server.notify::(&lsp::ProgressParams { + token: expected_workspace_diagnostic_token.clone(), + value: lsp::ProgressParamsValue::WorkspaceDiagnostic( + lsp::WorkspaceDiagnosticReportResult::Report(lsp::WorkspaceDiagnosticReport { + items: vec![ + lsp::WorkspaceDocumentDiagnosticReport::Full( + lsp::WorkspaceFullDocumentDiagnosticReport { + uri: lsp::Url::from_file_path(path!("/a/main.rs")).unwrap(), + version: None, + full_document_diagnostic_report: + lsp::FullDocumentDiagnosticReport { + result_id: Some(format!( + "workspace_{}", + workspace_diagnostics_pulls_made + .fetch_add(1, atomic::Ordering::Release) + + 1 + )), + items: vec![lsp::Diagnostic { + range: lsp::Range { + start: lsp::Position { + line: 0, + character: 1, + }, + end: lsp::Position { + line: 0, + character: 2, + }, + }, + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: + expected_workspace_pull_diagnostics_main_message + .to_string(), + ..lsp::Diagnostic::default() + }], + }, + }, + ), + lsp::WorkspaceDocumentDiagnosticReport::Full( + lsp::WorkspaceFullDocumentDiagnosticReport { + uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + version: None, + full_document_diagnostic_report: + lsp::FullDocumentDiagnosticReport { + result_id: Some(format!( + "workspace_{}", + workspace_diagnostics_pulls_made + .fetch_add(1, atomic::Ordering::Release) + + 1 + )), + items: Vec::new(), + }, + }, + ), + ], + }), + ), + }); + }; + + let mut workspace_diagnostic_start_count = + workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire); + executor.run_until_parked(); editor_a_main.update(cx_a, |editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); @@ -2599,7 +2698,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); executor.run_until_parked(); assert_eq!( - 1, + workspace_diagnostic_start_count, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "Workspace diagnostics should not be changed as the remote client does not initialize the workspace diagnostics pull" ); @@ -2646,7 +2745,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); executor.run_until_parked(); assert_eq!( - 1, + workspace_diagnostic_start_count, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "The remote client still did not anything to trigger the workspace diagnostics pull" ); @@ -2673,6 +2772,75 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); } }); + + if should_stream_workspace_diagnostic { + fake_language_server.notify::(&lsp::ProgressParams { + token: expected_workspace_diagnostic_token.clone(), + value: lsp::ProgressParamsValue::WorkspaceDiagnostic( + lsp::WorkspaceDiagnosticReportResult::Report(lsp::WorkspaceDiagnosticReport { + items: vec![lsp::WorkspaceDocumentDiagnosticReport::Full( + lsp::WorkspaceFullDocumentDiagnosticReport { + uri: lsp::Url::from_file_path(path!("/a/lib.rs")).unwrap(), + version: None, + full_document_diagnostic_report: lsp::FullDocumentDiagnosticReport { + result_id: Some(format!( + "workspace_{}", + workspace_diagnostics_pulls_made + .fetch_add(1, atomic::Ordering::Release) + + 1 + )), + items: vec![lsp::Diagnostic { + range: lsp::Range { + start: lsp::Position { + line: 0, + character: 1, + }, + end: lsp::Position { + line: 0, + character: 2, + }, + }, + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: expected_workspace_pull_diagnostics_lib_message + .to_string(), + ..lsp::Diagnostic::default() + }], + }, + }, + )], + }), + ), + }); + workspace_diagnostic_start_count = + workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire); + workspace_diagnostic_cancel_tx.send(()).await.unwrap(); + workspace_diagnostics_pulls_handle.next().await.unwrap(); + executor.run_until_parked(); + editor_b_lib.update(cx_b, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let all_diagnostics = snapshot + .diagnostics_in_range(0..snapshot.len()) + .collect::>(); + let expected_messages = [ + expected_workspace_pull_diagnostics_lib_message, + // TODO bug: the pushed diagnostics are not being sent to the client when they open the corresponding buffer. + // expected_push_diagnostic_lib_message, + ]; + assert_eq!( + all_diagnostics.len(), + 1, + "Expected pull diagnostics, but got: {all_diagnostics:?}" + ); + for diagnostic in all_diagnostics { + assert!( + expected_messages.contains(&diagnostic.diagnostic.message.as_str()), + "The client should get both push and pull messages: {expected_messages:?}, but got: {}", + diagnostic.diagnostic.message + ); + } + }); + }; + { assert!( diagnostics_pulls_result_ids.lock().await.len() > 0, @@ -2701,7 +2869,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 2, + workspace_diagnostic_start_count + 1, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "After client lib.rs edits, the workspace diagnostics request should follow" ); @@ -2720,7 +2888,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 3, + workspace_diagnostic_start_count + 2, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "After client main.rs edits, the workspace diagnostics pull should follow" ); @@ -2739,7 +2907,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 4, + workspace_diagnostic_start_count + 3, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "After host main.rs edits, the workspace diagnostics pull should follow" ); @@ -2769,7 +2937,7 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp ); workspace_diagnostics_pulls_handle.next().await.unwrap(); assert_eq!( - 5, + workspace_diagnostic_start_count + 4, workspace_diagnostics_pulls_made.load(atomic::Ordering::Acquire), "Another workspace diagnostics pull should happen after the diagnostics refresh server request" ); @@ -2840,6 +3008,19 @@ async fn test_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestApp }); } +#[gpui::test(iterations = 10)] +async fn test_non_streamed_lsp_pull_diagnostics( + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + test_lsp_pull_diagnostics(false, cx_a, cx_b).await; +} + +#[gpui::test(iterations = 10)] +async fn test_streamed_lsp_pull_diagnostics(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { + test_lsp_pull_diagnostics(true, cx_a, cx_b).await; +} + #[gpui::test(iterations = 10)] async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { let mut server = TestServer::start(cx_a.executor()).await; diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index a77112213f..d9fd8ffeb2 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -439,7 +439,7 @@ async fn test_basic_following( editor_a1.item_id() ); - #[cfg(all(not(target_os = "macos"), not(target_os = "windows")))] + // #[cfg(all(not(target_os = "macos"), not(target_os = "windows")))] { use crate::rpc::RECONNECT_TIMEOUT; use gpui::TestScreenCaptureSource; @@ -456,11 +456,19 @@ async fn test_basic_following( .await .unwrap(); cx_b.set_screen_capture_sources(vec![display]); + let source = cx_b + .read(|cx| cx.screen_capture_sources()) + .await + .unwrap() + .unwrap() + .into_iter() + .next() + .unwrap(); active_call_b .update(cx_b, |call, cx| { call.room() .unwrap() - .update(cx, |room, cx| room.share_screen(cx)) + .update(cx, |room, cx| room.share_screen(source, cx)) }) .await .unwrap(); @@ -1013,7 +1021,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T // and some of which were originally opened by client B. workspace_b.update_in(cx_b, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.close_inactive_items(&Default::default(), window, cx) + pane.close_other_items(&Default::default(), None, window, cx) .detach(); }); }); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index d1099a327a..9795c27574 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -277,11 +277,19 @@ async fn test_basic_calls( let events_b = active_call_events(cx_b); let events_c = active_call_events(cx_c); cx_a.set_screen_capture_sources(vec![display]); + let screen_a = cx_a + .update(|cx| cx.screen_capture_sources()) + .await + .unwrap() + .unwrap() + .into_iter() + .next() + .unwrap(); active_call_a .update(cx_a, |call, cx| { call.room() .unwrap() - .update(cx, |room, cx| room.share_screen(cx)) + .update(cx, |room, cx| room.share_screen(screen_a, cx)) }) .await .unwrap(); @@ -6312,11 +6320,20 @@ async fn test_join_call_after_screen_was_shared( // User A shares their screen let display = gpui::TestScreenCaptureSource::new(); cx_a.set_screen_capture_sources(vec![display]); + let screen_a = cx_a + .update(|cx| cx.screen_capture_sources()) + .await + .unwrap() + .unwrap() + .into_iter() + .next() + .unwrap(); + active_call_a .update(cx_a, |call, cx| { call.room() .unwrap() - .update(cx, |room, cx| room.share_screen(cx)) + .update(cx, |room, cx| room.share_screen(screen_a, cx)) }) .await .unwrap(); diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index 7aeb381c02..8ab6e6910c 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -2,6 +2,7 @@ use crate::tests::TestServer; use call::ActiveCall; use collections::{HashMap, HashSet}; +use dap::{Capabilities, adapters::DebugTaskDefinition, transport::RequestHandling}; use debugger_ui::debugger_panel::DebugPanel; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs as _, RemoveOptions}; @@ -22,6 +23,7 @@ use language::{ use node_runtime::NodeRuntime; use project::{ ProjectPath, + debugger::session::ThreadId, lsp_store::{FormatTrigger, LspFormatTarget}, }; use remote::SshRemoteClient; @@ -29,7 +31,11 @@ use remote_server::{HeadlessAppState, HeadlessProject}; use rpc::proto; use serde_json::json; use settings::SettingsStore; -use std::{path::Path, sync::Arc}; +use std::{ + path::Path, + sync::{Arc, atomic::AtomicUsize}, +}; +use task::TcpArgumentsTemplate; use util::path; #[gpui::test(iterations = 10)] @@ -688,3 +694,162 @@ async fn test_remote_server_debugger( shutdown_session.await.unwrap(); } + +#[gpui::test] +async fn test_slow_adapter_startup_retries( + cx_a: &mut TestAppContext, + server_cx: &mut TestAppContext, + executor: BackgroundExecutor, +) { + cx_a.update(|cx| { + release_channel::init(SemanticVersion::default(), cx); + command_palette_hooks::init(cx); + zlog::init_test(); + dap_adapters::init(cx); + }); + server_cx.update(|cx| { + release_channel::init(SemanticVersion::default(), cx); + dap_adapters::init(cx); + }); + let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let remote_fs = FakeFs::new(server_cx.executor()); + remote_fs + .insert_tree( + path!("/code"), + json!({ + "lib.rs": "fn one() -> usize { 1 }" + }), + ) + .await; + + // User A connects to the remote project via SSH. + server_cx.update(HeadlessProject::init); + let remote_http_client = Arc::new(BlockedHttpClient); + let node = NodeRuntime::unavailable(); + let languages = Arc::new(LanguageRegistry::new(server_cx.executor())); + let _headless_project = server_cx.new(|cx| { + client::init_settings(cx); + HeadlessProject::new( + HeadlessAppState { + session: server_ssh, + fs: remote_fs.clone(), + http_client: remote_http_client, + node_runtime: node, + languages, + extension_host_proxy: Arc::new(ExtensionHostProxy::new()), + }, + cx, + ) + }); + + let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; + let mut server = TestServer::start(server_cx.executor()).await; + let client_a = server.create_client(cx_a, "user_a").await; + cx_a.update(|cx| { + debugger_ui::init(cx); + command_palette_hooks::init(cx); + }); + let (project_a, _) = client_a + .build_ssh_project(path!("/code"), client_ssh.clone(), cx_a) + .await; + + let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a); + + let debugger_panel = workspace + .update_in(cx_a, |_workspace, window, cx| { + cx.spawn_in(window, DebugPanel::load) + }) + .await + .unwrap(); + + workspace.update_in(cx_a, |workspace, window, cx| { + workspace.add_panel(debugger_panel, window, cx); + }); + + cx_a.run_until_parked(); + let debug_panel = workspace + .update(cx_a, |workspace, cx| workspace.panel::(cx)) + .unwrap(); + + let workspace_window = cx_a + .window_handle() + .downcast::() + .unwrap(); + + let count = Arc::new(AtomicUsize::new(0)); + let session = debugger_ui::tests::start_debug_session_with( + &workspace_window, + cx_a, + DebugTaskDefinition { + adapter: "fake-adapter".into(), + label: "test".into(), + config: json!({ + "request": "launch" + }), + tcp_connection: Some(TcpArgumentsTemplate { + port: None, + host: None, + timeout: None, + }), + }, + move |client| { + let count = count.clone(); + client.on_request_ext::(move |_seq, _request| { + if count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) < 5 { + return RequestHandling::Exit; + } + RequestHandling::Respond(Ok(Capabilities::default())) + }); + }, + ) + .unwrap(); + cx_a.run_until_parked(); + + let client = session.update(cx_a, |session, _| session.adapter_client().unwrap()); + client + .fake_event(dap::messages::Events::Stopped(dap::StoppedEvent { + reason: dap::StoppedEventReason::Pause, + description: None, + thread_id: Some(1), + preserve_focus_hint: None, + text: None, + all_threads_stopped: None, + hit_breakpoint_ids: None, + })) + .await; + + cx_a.run_until_parked(); + + let active_session = debug_panel + .update(cx_a, |this, _| this.active_session()) + .unwrap(); + + let running_state = active_session.update(cx_a, |active_session, _| { + active_session.running_state().clone() + }); + + assert_eq!( + client.id(), + running_state.read_with(cx_a, |running_state, _| running_state.session_id()) + ); + assert_eq!( + ThreadId(1), + running_state.read_with(cx_a, |running_state, _| running_state + .selected_thread_id() + .unwrap()) + ); + + let shutdown_session = workspace.update(cx_a, |workspace, cx| { + workspace.project().update(cx, |project, cx| { + project.dap_store().update(cx, |dap_store, cx| { + dap_store.shutdown_session(session.read(cx).session_id(), cx) + }) + }) + }); + + client_ssh.update(cx_a, |a, _| { + a.shutdown_processes(Some(proto::ShutdownRemoteServer {}), executor) + }); + + shutdown_session.await.unwrap(); +} diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index c19eb0a234..5c5bcd5832 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -3,17 +3,11 @@ use std::sync::Arc; use chrono::{Duration, Utc}; use pretty_assertions::assert_eq; -use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_billing::StripeBilling; use crate::stripe_client::{ - FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems, - StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeCustomerUpdate, - StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeMeter, StripeMeterId, StripePrice, - StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, - StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings, - StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, + FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, + StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, + StripeSubscriptionItemId, UpdateSubscriptionItems, }; fn make_stripe_billing() -> (StripeBilling, Arc) { @@ -364,240 +358,3 @@ async fn test_bill_model_request_usage() { ); assert_eq!(create_meter_event_calls[0].value, 73); } - -#[gpui::test] -async fn test_checkout_with_zed_pro() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - let github_login = "zeduser1"; - let success_url = "https://example.com/success"; - - // It returns an error when the Zed Pro price doesn't exist. - { - let result = stripe_billing - .checkout_with_zed_pro(&customer_id, github_login, success_url) - .await; - - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - r#"no price ID found for "zed-pro""# - ); - } - - // Successful checkout. - { - let price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(2000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - stripe_billing.initialize().await.unwrap(); - - let checkout_url = stripe_billing - .checkout_with_zed_pro(&customer_id, github_login, success_url) - .await - .unwrap(); - - assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); - - let create_checkout_session_calls = stripe_client - .create_checkout_session_calls - .lock() - .drain(..) - .collect::>(); - assert_eq!(create_checkout_session_calls.len(), 1); - let call = create_checkout_session_calls.into_iter().next().unwrap(); - assert_eq!(call.customer, Some(customer_id)); - assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); - assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); - assert_eq!( - call.line_items, - Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(price.id.to_string()), - quantity: Some(1) - }]) - ); - assert_eq!(call.payment_method_collection, None); - assert_eq!(call.subscription_data, None); - assert_eq!(call.success_url.as_deref(), Some(success_url)); - assert_eq!( - call.billing_address_collection, - Some(StripeBillingAddressCollection::Required) - ); - assert_eq!( - call.customer_update, - Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }) - ); - } -} - -#[gpui::test] -async fn test_checkout_with_zed_pro_trial() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - let customer_id = StripeCustomerId("cus_test".into()); - let github_login = "zeduser1"; - let success_url = "https://example.com/success"; - - // It returns an error when the Zed Pro price doesn't exist. - { - let result = stripe_billing - .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url) - .await; - - assert!(result.is_err()); - assert_eq!( - result.err().unwrap().to_string(), - r#"no price ID found for "zed-pro""# - ); - } - - let price = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(2000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - stripe_client - .prices - .lock() - .insert(price.id.clone(), price.clone()); - - stripe_billing.initialize().await.unwrap(); - - // Successful checkout. - { - let checkout_url = stripe_billing - .checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url) - .await - .unwrap(); - - assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); - - let create_checkout_session_calls = stripe_client - .create_checkout_session_calls - .lock() - .drain(..) - .collect::>(); - assert_eq!(create_checkout_session_calls.len(), 1); - let call = create_checkout_session_calls.into_iter().next().unwrap(); - assert_eq!(call.customer.as_ref(), Some(&customer_id)); - assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); - assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); - assert_eq!( - call.line_items, - Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(price.id.to_string()), - quantity: Some(1) - }]) - ); - assert_eq!( - call.payment_method_collection, - Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired) - ); - assert_eq!( - call.subscription_data, - Some(StripeCreateCheckoutSessionSubscriptionData { - trial_period_days: Some(14), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - metadata: None, - }) - ); - assert_eq!(call.success_url.as_deref(), Some(success_url)); - assert_eq!( - call.billing_address_collection, - Some(StripeBillingAddressCollection::Required) - ); - assert_eq!( - call.customer_update, - Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }) - ); - } - - // Successful checkout with extended trial. - { - let checkout_url = stripe_billing - .checkout_with_zed_pro_trial( - &customer_id, - github_login, - vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()], - success_url, - ) - .await - .unwrap(); - - assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay")); - - let create_checkout_session_calls = stripe_client - .create_checkout_session_calls - .lock() - .drain(..) - .collect::>(); - assert_eq!(create_checkout_session_calls.len(), 1); - let call = create_checkout_session_calls.into_iter().next().unwrap(); - assert_eq!(call.customer, Some(customer_id)); - assert_eq!(call.client_reference_id.as_deref(), Some(github_login)); - assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription)); - assert_eq!( - call.line_items, - Some(vec![StripeCreateCheckoutSessionLineItems { - price: Some(price.id.to_string()), - quantity: Some(1) - }]) - ); - assert_eq!( - call.payment_method_collection, - Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired) - ); - assert_eq!( - call.subscription_data, - Some(StripeCreateCheckoutSessionSubscriptionData { - trial_period_days: Some(60), - trial_settings: Some(StripeSubscriptionTrialSettings { - end_behavior: StripeSubscriptionTrialSettingsEndBehavior { - missing_payment_method: - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - metadata: Some(std::collections::HashMap::from_iter([( - "promo_feature_flag".into(), - AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into() - )])), - }) - ); - assert_eq!(call.success_url.as_deref(), Some(success_url)); - assert_eq!( - call.billing_address_collection, - Some(StripeBillingAddressCollection::Required) - ); - assert_eq!( - call.customer_update, - Some(StripeCustomerUpdate { - address: Some(StripeCustomerUpdateAddress::Auto), - name: Some(StripeCustomerUpdateName::Auto), - shipping: None, - }) - ); - } -} diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index ec23e2c3f5..4d5973481e 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -144,10 +144,22 @@ pub fn init(cx: &mut App) { if let Some(room) = room { window.defer(cx, move |_window, cx| { room.update(cx, |room, cx| { - if room.is_screen_sharing() { - room.unshare_screen(cx).ok(); + if room.is_sharing_screen() { + room.unshare_screen(true, cx).ok(); } else { - room.share_screen(cx).detach_and_log_err(cx); + let sources = cx.screen_capture_sources(); + + cx.spawn(async move |room, cx| { + let sources = sources.await??; + let first = sources.into_iter().next(); + if let Some(first) = first { + room.update(cx, |room, cx| room.share_screen(first, cx))? + .await + } else { + Ok(()) + } + }) + .detach_and_log_err(cx); }; }); }); @@ -528,10 +540,10 @@ impl CollabPanel { project_id: project.id, worktree_root_names: project.worktree_root_names.clone(), host_user_id: user_id, - is_last: projects.peek().is_none() && !room.is_screen_sharing(), + is_last: projects.peek().is_none() && !room.is_sharing_screen(), }); } - if room.is_screen_sharing() { + if room.is_sharing_screen() { self.entries.push(ListEntry::ParticipantScreen { peer_id: None, is_last: true, diff --git a/crates/command_palette/src/command_palette.rs b/crates/command_palette/src/command_palette.rs index abb8978d5a..dfaede0dc4 100644 --- a/crates/command_palette/src/command_palette.rs +++ b/crates/command_palette/src/command_palette.rs @@ -242,7 +242,7 @@ impl CommandPaletteDelegate { self.selected_ix = cmp::min(self.selected_ix, self.matches.len() - 1); } } - /// + /// Hit count for each command in the palette. /// We only account for commands triggered directly via command palette and not by e.g. keystrokes because /// if a user already knows a keystroke for a command, they are unlikely to use a command palette to look for it. diff --git a/crates/component/src/component_layout.rs b/crates/component/src/component_layout.rs index b749ea20ea..58bf1d8f0c 100644 --- a/crates/component/src/component_layout.rs +++ b/crates/component/src/component_layout.rs @@ -48,20 +48,20 @@ impl RenderOnce for ComponentExample { ) .child( div() - .flex() - .w_full() - .rounded_xl() .min_h(px(100.)) - .justify_center() + .w_full() .p_8() + .flex() + .items_center() + .justify_center() + .rounded_xl() .border_1() .border_color(cx.theme().colors().border.opacity(0.5)) .bg(pattern_slash( - cx.theme().colors().surface_background.opacity(0.5), + cx.theme().colors().surface_background.opacity(0.25), 12.0, 12.0, )) - .shadow_xs() .child(self.element), ) .into_any_element() @@ -118,8 +118,8 @@ impl RenderOnce for ComponentExampleGroup { .flex() .items_center() .gap_3() - .pb_1() - .child(div().h_px().w_4().bg(cx.theme().colors().border)) + .mt_4() + .mb_1() .child( div() .flex_none() diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 96bb9e071f..5e4f8369c4 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -21,12 +21,14 @@ collections.workspace = true futures.workspace = true gpui.workspace = true log.workspace = true +net.workspace = true parking_lot.workspace = true postage.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true smol.workspace = true +tempfile.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 83d815432d..8c5e7da0f1 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use futures::{FutureExt, StreamExt, channel::oneshot, select}; +use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; @@ -10,15 +10,19 @@ use smol::channel; use std::{ fmt, path::PathBuf, + pin::pin, sync::{ Arc, atomic::{AtomicI32, Ordering::SeqCst}, }, time::{Duration, Instant}, }; -use util::TryFutureExt; +use util::{ResultExt, TryFutureExt}; -use crate::transport::{StdioTransport, Transport}; +use crate::{ + transport::{StdioTransport, Transport}, + types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled}, +}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603; type ResponseHandler = Box)>; type NotificationHandler = Box; +type RequestHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -70,12 +75,21 @@ fn is_null_value(value: &T) -> bool { } #[derive(Serialize, Deserialize)] -struct Request<'a, T> { - jsonrpc: &'static str, - id: RequestId, - method: &'a str, +pub struct Request<'a, T> { + pub jsonrpc: &'static str, + pub id: RequestId, + pub method: &'a str, #[serde(skip_serializing_if = "is_null_value")] - params: T, + pub params: T, +} + +#[derive(Serialize, Deserialize)] +pub struct AnyRequest<'a> { + pub jsonrpc: &'a str, + pub id: RequestId, + pub method: &'a str, + #[serde(skip_serializing_if = "is_null_value")] + pub params: Option<&'a RawValue>, } #[derive(Serialize, Deserialize)] @@ -88,18 +102,18 @@ struct AnyResponse<'a> { result: Option<&'a RawValue>, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[allow(dead_code)] -struct Response { - jsonrpc: &'static str, - id: RequestId, +pub(crate) struct Response { + pub jsonrpc: &'static str, + pub id: RequestId, #[serde(flatten)] - value: CspResult, + pub value: CspResult, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -enum CspResult { +pub(crate) enum CspResult { #[serde(rename = "result")] Ok(Option), #[allow(dead_code)] @@ -123,8 +137,9 @@ struct AnyNotification<'a> { } #[derive(Debug, Serialize, Deserialize)] -struct Error { - message: String, +pub(crate) struct Error { + pub message: String, + pub code: i32, } #[derive(Debug, Clone, Deserialize)] @@ -175,15 +190,23 @@ impl Client { Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); + let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { - Self::handle_input(transport, notification_handlers, response_handlers, cx) - .log_err() - .await + Self::handle_input( + transport, + notification_handlers, + request_handlers, + response_handlers, + cx, + ) + .log_err() + .await } }); let receive_err_task = cx.spawn({ @@ -229,13 +252,24 @@ impl Client { async fn handle_input( transport: Arc, notification_handlers: Arc>>, + request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut AsyncApp, ) -> anyhow::Result<()> { let mut receiver = transport.receive(); while let Some(message) = receiver.next().await { - if let Ok(response) = serde_json::from_str::(&message) { + log::trace!("recv: {}", &message); + if let Ok(request) = serde_json::from_str::(&message) { + let mut request_handlers = request_handlers.lock(); + if let Some(handler) = request_handlers.get_mut(request.method) { + handler( + request.id, + request.params.unwrap_or(RawValue::NULL), + cx.clone(), + ); + } + } else if let Ok(response) = serde_json::from_str::(&message) { if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handler) = handlers.remove(&response.id) { handler(Ok(message.to_string())); @@ -246,6 +280,8 @@ impl Client { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { handler(notification.params.unwrap_or(Value::Null), cx.clone()); } + } else { + log::error!("Unhandled JSON from context_server: {}", message); } } @@ -293,6 +329,24 @@ impl Client { &self, method: &str, params: impl Serialize, + ) -> Result { + self.request_impl(method, params, None).await + } + + pub async fn cancellable_request( + &self, + method: &str, + params: impl Serialize, + cancel_rx: oneshot::Receiver<()>, + ) -> Result { + self.request_impl(method, params, Some(cancel_rx)).await + } + + pub async fn request_impl( + &self, + method: &str, + params: impl Serialize, + cancel_rx: Option>, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -329,6 +383,16 @@ impl Client { send?; let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + let mut cancel_fut = pin!( + match cancel_rx { + Some(rx) => future::Either::Left(async { + rx.await.log_err(); + }), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + select! { response = rx.fuse() => { let elapsed = started.elapsed(); @@ -347,6 +411,16 @@ impl Client { Err(_) => anyhow::bail!("cancelled") } } + _ = cancel_fut => { + self.notify( + Cancelled::METHOD, + ClientNotification::Cancelled(CancelledParams { + request_id: RequestId::Int(id), + reason: None + }) + ).log_err(); + anyhow::bail!("Request cancelled") + } _ = timeout => { log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); anyhow::bail!("Context server request timeout"); diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 905435fcce..f2517feb27 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,13 +1,14 @@ pub mod client; +pub mod listener; pub mod protocol; #[cfg(any(test, feature = "test-support"))] pub mod test; pub mod transport; pub mod types; -use std::fmt::Display; use std::path::Path; use std::sync::Arc; +use std::{fmt::Display, path::PathBuf}; use anyhow::Result; use client::Client; @@ -30,7 +31,7 @@ impl Display for ContextServerId { #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] pub struct ContextServerCommand { #[serde(rename = "command")] - pub path: String, + pub path: PathBuf, pub args: Vec, pub env: Option>, } diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs new file mode 100644 index 0000000000..192f530816 --- /dev/null +++ b/crates/context_server/src/listener.rs @@ -0,0 +1,439 @@ +use ::serde::{Deserialize, Serialize}; +use anyhow::{Context as _, Result}; +use collections::HashMap; +use futures::{ + AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, + channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, + io::BufReader, + select_biased, +}; +use gpui::{App, AppContext, AsyncApp, Task}; +use net::async_net::{UnixListener, UnixStream}; +use schemars::JsonSchema; +use serde::de::DeserializeOwned; +use serde_json::{json, value::RawValue}; +use smol::stream::StreamExt; +use std::{ + cell::RefCell, + path::{Path, PathBuf}, + rc::Rc, +}; +use util::ResultExt; + +use crate::{ + client::{CspResult, RequestId, Response}, + types::{ + CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations, + ToolResponseContent, + requests::{CallTool, ListTools}, + }, +}; + +pub struct McpServer { + socket_path: PathBuf, + tools: Rc>>, + handlers: Rc>>, + _server_task: Task<()>, +} + +struct RegisteredTool { + tool: Tool, + handler: ToolHandler, +} + +type ToolHandler = Box< + dyn Fn( + Option, + &mut AsyncApp, + ) -> Task>>, +>; +type RequestHandler = Box>, &App) -> Task>; + +impl McpServer { + pub fn new(cx: &AsyncApp) -> Task> { + let task = cx.background_spawn(async move { + let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?; + let socket_path = temp_dir.path().join("mcp.sock"); + let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?; + + anyhow::Ok((temp_dir, socket_path, listener)) + }); + + cx.spawn(async move |cx| { + let (temp_dir, socket_path, listener) = task.await?; + let tools = Rc::new(RefCell::new(HashMap::default())); + let handlers = Rc::new(RefCell::new(HashMap::default())); + let server_task = cx.spawn({ + let tools = tools.clone(); + let handlers = handlers.clone(); + async move |cx| { + while let Ok((stream, _)) = listener.accept().await { + Self::serve_connection(stream, tools.clone(), handlers.clone(), cx); + } + drop(temp_dir) + } + }); + Ok(Self { + socket_path, + _server_task: server_task, + tools, + handlers: handlers, + }) + }) + } + + pub fn add_tool(&mut self, tool: T) { + let output_schema = schemars::schema_for!(T::Output); + let unit_schema = schemars::schema_for!(()); + + let registered_tool = RegisteredTool { + tool: Tool { + name: T::NAME.into(), + description: Some(tool.description().into()), + input_schema: schemars::schema_for!(T::Input).into(), + output_schema: if output_schema == unit_schema { + None + } else { + Some(output_schema.into()) + }, + annotations: Some(tool.annotations()), + }, + handler: Box::new({ + let tool = tool.clone(); + move |input_value, cx| { + let input = match input_value { + Some(input) => serde_json::from_value(input), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let tool = tool.clone(); + match input { + Ok(input) => cx.spawn(async move |cx| { + let output = tool.run(input, cx).await?; + + Ok(ToolResponse { + content: output.content, + structured_content: serde_json::to_value(output.structured_content) + .unwrap_or_default(), + }) + }), + Err(err) => Task::ready(Err(err.into())), + } + } + }), + }; + + self.tools.borrow_mut().insert(T::NAME, registered_tool); + } + + pub fn handle_request( + &mut self, + f: impl Fn(R::Params, &App) -> Task> + 'static, + ) { + let f = Box::new(f); + self.handlers.borrow_mut().insert( + R::METHOD, + Box::new(move |req_id, opt_params, cx| { + let result = match opt_params { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let params: R::Params = match result { + Ok(params) => params, + Err(e) => { + return Task::ready( + serde_json::to_string(&Response:: { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Error(Some(crate::client::Error { + message: format!("{e}"), + code: -32700, + })), + }) + .unwrap(), + ); + } + }; + let task = f(params, cx); + cx.background_spawn(async move { + match task.await { + Ok(result) => serde_json::to_string(&Response { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Ok(Some(result)), + }) + .unwrap(), + Err(e) => serde_json::to_string(&Response { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Error::(Some(crate::client::Error { + message: format!("{e}"), + code: -32603, + })), + }) + .unwrap(), + } + }) + }), + ); + } + + pub fn socket_path(&self) -> &Path { + &self.socket_path + } + + fn serve_connection( + stream: UnixStream, + tools: Rc>>, + handlers: Rc>>, + cx: &mut AsyncApp, + ) { + let (read, write) = smol::io::split(stream); + let (incoming_tx, mut incoming_rx) = unbounded(); + let (outgoing_tx, outgoing_rx) = unbounded(); + + cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read)) + .detach(); + + cx.spawn(async move |cx| { + while let Some(request) = incoming_rx.next().await { + let Some(request_id) = request.id.clone() else { + continue; + }; + + if request.method == CallTool::METHOD { + Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx) + .await; + } else if request.method == ListTools::METHOD { + Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx); + } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + if let Some(task) = cx + .update(|cx| handler(request_id, request.params, cx)) + .log_err() + { + cx.spawn(async move |_| { + let response = task.await; + outgoing_tx.unbounded_send(response).ok(); + }) + .detach(); + } + } else { + Self::send_err( + request_id, + format!("unhandled method {}", request.method), + &outgoing_tx, + ); + } + } + }) + .detach(); + } + + fn handle_list_tools( + request_id: RequestId, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + ) { + let response = ListToolsResponse { + tools: tools.borrow().values().map(|t| t.tool.clone()).collect(), + next_cursor: None, + meta: None, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + } + + async fn handle_call_tool( + request_id: RequestId, + params: Option>, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + cx: &mut AsyncApp, + ) { + let result: Result = match params.as_ref() { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + match result { + Ok(params) => { + if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + let task = (tool.handler)(params.arguments, cx); + cx.spawn(async move |_| { + let response = match task.await { + Ok(result) => CallToolResponse { + content: result.content, + is_error: Some(false), + meta: None, + structured_content: if result.structured_content.is_null() { + None + } else { + Some(result.structured_content) + }, + }, + Err(err) => CallToolResponse { + content: vec![ToolResponseContent::Text { + text: err.to_string(), + }], + is_error: Some(true), + meta: None, + structured_content: None, + }, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + }) + .detach(); + } else { + Self::send_err( + request_id, + format!("Tool not found: {}", params.name), + &outgoing_tx, + ); + } + } + Err(err) => { + Self::send_err(request_id, err.to_string(), &outgoing_tx); + } + } + } + + fn send_err( + request_id: RequestId, + message: impl Into, + outgoing_tx: &UnboundedSender, + ) { + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response::<()> { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Error(Some(crate::client::Error { + message: message.into(), + code: -32601, + })), + }) + .unwrap(), + ) + .ok(); + } + + async fn handle_io( + mut outgoing_rx: UnboundedReceiver, + incoming_tx: UnboundedSender, + mut outgoing_bytes: impl Unpin + AsyncWrite, + incoming_bytes: impl Unpin + AsyncRead, + ) -> Result<()> { + let mut output_reader = BufReader::new(incoming_bytes); + let mut incoming_line = String::new(); + loop { + select_biased! { + message = outgoing_rx.next().fuse() => { + if let Some(message) = message { + log::trace!("send: {}", &message); + outgoing_bytes.write_all(message.as_bytes()).await?; + outgoing_bytes.write_all(&[b'\n']).await?; + } else { + break; + } + } + bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { + if bytes_read? == 0 { + break + } + log::trace!("recv: {}", &incoming_line); + match serde_json::from_str(&incoming_line) { + Ok(message) => { + incoming_tx.unbounded_send(message).log_err(); + } + Err(error) => { + outgoing_bytes.write_all(serde_json::to_string(&json!({ + "jsonrpc": "2.0", + "error": json!({ + "code": -32603, + "message": format!("Failed to parse: {error}"), + }), + }))?.as_bytes()).await?; + outgoing_bytes.write_all(&[b'\n']).await?; + log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}"); + } + } + incoming_line.clear(); + } + } + } + Ok(()) + } +} + +pub trait McpServerTool { + type Input: DeserializeOwned + JsonSchema; + type Output: Serialize + JsonSchema; + + const NAME: &'static str; + + fn description(&self) -> &'static str; + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: None, + read_only_hint: None, + destructive_hint: None, + idempotent_hint: None, + open_world_hint: None, + } + } + + fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> impl Future>>; +} + +pub struct ToolResponse { + pub content: Vec, + pub structured_content: T, +} + +#[derive(Serialize, Deserialize)] +struct RawRequest { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option>, +} + +#[derive(Serialize, Deserialize)] +struct RawResponse { + jsonrpc: &'static str, + id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option>, +} diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index d8bbac60d6..7263f502fa 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -6,6 +6,9 @@ //! of messages. use anyhow::Result; +use futures::channel::oneshot; +use gpui::AsyncApp; +use serde_json::Value; use crate::client::Client; use crate::types::{self, Notification, Request}; @@ -95,7 +98,24 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } + pub async fn cancellable_request( + &self, + params: T::Params, + cancel_rx: oneshot::Receiver<()>, + ) -> Result { + self.inner + .cancellable_request(T::METHOD, params, cancel_rx) + .await + } + pub fn notify(&self, params: T::Params) -> Result<()> { self.inner.notify(T::METHOD, params) } + + pub fn on_notification(&self, method: &'static str, f: F) + where + F: 'static + Send + FnMut(Value, AsyncApp), + { + self.inner.on_notification(method, f); + } } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 8e3daf9e22..cd97ff95bc 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -3,6 +3,8 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; +use crate::client::RequestId; + pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; @@ -100,6 +102,7 @@ pub mod notifications { notification!("notifications/initialized", Initialized, ()); notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/message", Message, MessageParams); + notification!("notifications/cancelled", Cancelled, CancelledParams); notification!( "notifications/resources/updated", ResourcesUpdated, @@ -153,7 +156,7 @@ pub struct InitializeParams { pub struct CallToolParams { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, + pub arguments: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, } @@ -492,18 +495,20 @@ pub struct RootsCapabilities { pub list_changed: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub input_schema: serde_json::Value, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_schema: Option, #[serde(skip_serializing_if = "Option::is_none")] pub annotations: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolAnnotations { /// A human-readable title for the tool. @@ -617,11 +622,14 @@ pub enum ClientNotification { Initialized, Progress(ProgressParams), RootsListChanged, - Cancelled { - request_id: String, - #[serde(skip_serializing_if = "Option::is_none")] - reason: Option, - }, + Cancelled(CancelledParams), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CancelledParams { + pub request_id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -673,6 +681,8 @@ pub struct CallToolResponse { pub is_error: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub structured_content: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index e4370d2e67..e11242cb15 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -6,6 +6,7 @@ mod sign_in; use crate::sign_in::initiate_sign_in_within_workspace; use ::fs::Fs; use anyhow::{Context as _, Result, anyhow}; +use client::DisableAiSettings; use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter; use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared}; @@ -25,6 +26,7 @@ use node_runtime::NodeRuntime; use parking_lot::Mutex; use request::StatusNotification; use serde_json::json; +use settings::Settings; use settings::SettingsStore; use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace}; use std::collections::hash_map::Entry; @@ -93,26 +95,34 @@ pub fn init( let copilot_auth_action_types = [TypeId::of::()]; let copilot_no_auth_action_types = [TypeId::of::()]; let status = handle.read(cx).status(); + + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; let filter = CommandPaletteFilter::global_mut(cx); - match status { - Status::Disabled => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.hide_action_types(&copilot_no_auth_action_types); - } - Status::Authorized => { - filter.hide_action_types(&copilot_no_auth_action_types); - filter.show_action_types( - copilot_action_types - .iter() - .chain(&copilot_auth_action_types), - ); - } - _ => { - filter.hide_action_types(&copilot_action_types); - filter.hide_action_types(&copilot_auth_action_types); - filter.show_action_types(copilot_no_auth_action_types.iter()); + if is_ai_disabled { + filter.hide_action_types(&copilot_action_types); + filter.hide_action_types(&copilot_auth_action_types); + filter.hide_action_types(&copilot_no_auth_action_types); + } else { + match status { + Status::Disabled => { + filter.hide_action_types(&copilot_action_types); + filter.hide_action_types(&copilot_auth_action_types); + filter.hide_action_types(&copilot_no_auth_action_types); + } + Status::Authorized => { + filter.hide_action_types(&copilot_no_auth_action_types); + filter.show_action_types( + copilot_action_types + .iter() + .chain(&copilot_auth_action_types), + ); + } + _ => { + filter.hide_action_types(&copilot_action_types); + filter.hide_action_types(&copilot_auth_action_types); + filter.show_action_types(copilot_no_auth_action_types.iter()); + } } } }) @@ -209,8 +219,14 @@ impl Status { matches!(self, Status::Authorized) } - pub fn is_disabled(&self) -> bool { - matches!(self, Status::Disabled) + pub fn is_configured(&self) -> bool { + matches!( + self, + Status::Starting { .. } + | Status::Error(_) + | Status::SigningIn { .. } + | Status::Authorized + ) } } diff --git a/crates/dap/src/adapters.rs b/crates/dap/src/adapters.rs index d9f26b3b34..0c88f37ff8 100644 --- a/crates/dap/src/adapters.rs +++ b/crates/dap/src/adapters.rs @@ -378,6 +378,14 @@ pub trait DebugAdapter: 'static + Send + Sync { fn label_for_child_session(&self, _args: &StartDebuggingRequestArguments) -> Option { None } + + fn compact_child_session(&self) -> bool { + false + } + + fn prefer_thread_name(&self) -> bool { + false + } } #[cfg(any(test, feature = "test-support"))] @@ -442,10 +450,18 @@ impl DebugAdapter for FakeAdapter { _: Option>, _: &mut AsyncApp, ) -> Result { + let connection = task_definition + .tcp_connection + .as_ref() + .map(|connection| TcpArguments { + host: connection.host(), + port: connection.port.unwrap_or(17), + timeout: connection.timeout, + }); Ok(DebugAdapterBinary { command: Some("command".into()), arguments: vec![], - connection: None, + connection, envs: HashMap::default(), cwd: None, request_args: StartDebuggingRequestArguments { diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index ff082e3b76..86a15b2d8a 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -2,7 +2,7 @@ use crate::{ adapters::DebugAdapterBinary, transport::{IoKind, LogKind, TransportDelegate}, }; -use anyhow::{Context as _, Result}; +use anyhow::Result; use dap_types::{ messages::{Message, Response}, requests::Request, @@ -110,9 +110,7 @@ impl DebugAdapterClient { self.transport_delegate .pending_requests .lock() - .as_mut() - .context("client is closed")? - .insert(sequence_id, callback_tx); + .insert(sequence_id, callback_tx)?; log::debug!( "Client {} send `{}` request with sequence_id: {}", @@ -170,6 +168,7 @@ impl DebugAdapterClient { pub fn kill(&self) { log::debug!("Killing DAP process"); self.transport_delegate.transport.lock().kill(); + self.transport_delegate.pending_requests.lock().shutdown(); } pub fn has_adapter_logs(&self) -> bool { @@ -184,11 +183,34 @@ impl DebugAdapterClient { } #[cfg(any(test, feature = "test-support"))] - pub fn on_request(&self, handler: F) + pub fn on_request(&self, mut handler: F) where F: 'static + Send + FnMut(u64, R::Arguments) -> Result, + { + use crate::transport::RequestHandling; + + self.transport_delegate + .transport + .lock() + .as_fake() + .on_request::(move |seq, request| { + RequestHandling::Respond(handler(seq, request)) + }); + } + + #[cfg(any(test, feature = "test-support"))] + pub fn on_request_ext(&self, handler: F) + where + F: 'static + + Send + + FnMut( + u64, + R::Arguments, + ) -> crate::transport::RequestHandling< + Result, + >, { self.transport_delegate .transport diff --git a/crates/dap/src/registry.rs b/crates/dap/src/registry.rs index 9435b16b92..d56e2f8f34 100644 --- a/crates/dap/src/registry.rs +++ b/crates/dap/src/registry.rs @@ -46,6 +46,7 @@ impl DapRegistry { let name = adapter.name(); let _previous_value = self.0.write().adapters.insert(name, adapter); } + pub fn add_locator(&self, locator: Arc) { self.0.write().locators.insert(locator.name(), locator); } diff --git a/crates/dap/src/transport.rs b/crates/dap/src/transport.rs index 14370f66e4..6dadf1cf35 100644 --- a/crates/dap/src/transport.rs +++ b/crates/dap/src/transport.rs @@ -49,6 +49,12 @@ pub enum IoKind { StdErr, } +#[cfg(any(test, feature = "test-support"))] +pub enum RequestHandling { + Respond(T), + Exit, +} + type LogHandlers = Arc>>; pub trait Transport: Send + Sync { @@ -76,7 +82,11 @@ async fn start( ) -> Result> { #[cfg(any(test, feature = "test-support"))] if cfg!(any(test, feature = "test-support")) { - return Ok(Box::new(FakeTransport::start(cx).await?)); + if let Some(connection) = binary.connection.clone() { + return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?)); + } else { + return Ok(Box::new(FakeTransport::start_stdio(cx).await?)); + } } if binary.connection.is_some() { @@ -90,11 +100,57 @@ async fn start( } } +pub(crate) struct PendingRequests { + inner: Option>>>, +} + +impl PendingRequests { + fn new() -> Self { + Self { + inner: Some(HashMap::default()), + } + } + + fn flush(&mut self, e: anyhow::Error) { + let Some(inner) = self.inner.as_mut() else { + return; + }; + for (_, sender) in inner.drain() { + sender.send(Err(e.cloned())).ok(); + } + } + + pub(crate) fn insert( + &mut self, + sequence_id: u64, + callback_tx: oneshot::Sender>, + ) -> anyhow::Result<()> { + let Some(inner) = self.inner.as_mut() else { + bail!("client is closed") + }; + inner.insert(sequence_id, callback_tx); + Ok(()) + } + + pub(crate) fn remove( + &mut self, + sequence_id: u64, + ) -> anyhow::Result>>> { + let Some(inner) = self.inner.as_mut() else { + bail!("client is closed"); + }; + Ok(inner.remove(&sequence_id)) + } + + pub(crate) fn shutdown(&mut self) { + self.flush(anyhow!("transport shutdown")); + self.inner = None; + } +} + pub(crate) struct TransportDelegate { log_handlers: LogHandlers, - // TODO this should really be some kind of associative channel - pub(crate) pending_requests: - Arc>>>>>, + pub(crate) pending_requests: Arc>, pub(crate) transport: Mutex>, pub(crate) server_tx: smol::lock::Mutex>>, tasks: Mutex>>, @@ -108,7 +164,7 @@ impl TransportDelegate { transport: Mutex::new(transport), log_handlers, server_tx: Default::default(), - pending_requests: Arc::new(Mutex::new(Some(HashMap::default()))), + pending_requests: Arc::new(Mutex::new(PendingRequests::new())), tasks: Default::default(), }) } @@ -151,24 +207,10 @@ impl TransportDelegate { Ok(()) => { pending_requests .lock() - .take() - .into_iter() - .flatten() - .for_each(|(_, request)| { - request - .send(Err(anyhow!("debugger shutdown unexpectedly"))) - .ok(); - }); + .flush(anyhow!("debugger shutdown unexpectedly")); } Err(e) => { - pending_requests - .lock() - .take() - .into_iter() - .flatten() - .for_each(|(_, request)| { - request.send(Err(e.cloned())).ok(); - }); + pending_requests.lock().flush(e); } } })); @@ -286,7 +328,7 @@ impl TransportDelegate { async fn recv_from_server( server_stdout: Stdout, mut message_handler: DapMessageHandler, - pending_requests: Arc>>>>>, + pending_requests: Arc>, log_handlers: Option, ) -> Result<()> where @@ -303,14 +345,10 @@ impl TransportDelegate { ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"), ConnectionResult::ConnectionReset => { log::info!("Debugger closed the connection"); - break Ok(()); + return Ok(()); } ConnectionResult::Result(Ok(Message::Response(res))) => { - let tx = pending_requests - .lock() - .as_mut() - .context("client is closed")? - .remove(&res.request_seq); + let tx = pending_requests.lock().remove(res.request_seq)?; if let Some(tx) = tx { if let Err(e) = tx.send(Self::process_response(res)) { log::trace!("Did not send response `{:?}` for a cancelled", e); @@ -704,8 +742,7 @@ impl Drop for StdioTransport { } #[cfg(any(test, feature = "test-support"))] -type RequestHandler = - Box dap_types::messages::Response>; +type RequestHandler = Box RequestHandling>; #[cfg(any(test, feature = "test-support"))] type ResponseHandler = Box; @@ -716,23 +753,38 @@ pub struct FakeTransport { request_handlers: Arc>>, // for reverse request responses response_handlers: Arc>>, - - stdin_writer: Option, - stdout_reader: Option, message_handler: Option>>, + kind: FakeTransportKind, +} + +#[cfg(any(test, feature = "test-support"))] +pub enum FakeTransportKind { + Stdio { + stdin_writer: Option, + stdout_reader: Option, + }, + Tcp { + connection: TcpArguments, + executor: BackgroundExecutor, + }, } #[cfg(any(test, feature = "test-support"))] impl FakeTransport { pub fn on_request(&self, mut handler: F) where - F: 'static + Send + FnMut(u64, R::Arguments) -> Result, + F: 'static + + Send + + FnMut(u64, R::Arguments) -> RequestHandling>, { self.request_handlers.lock().insert( R::COMMAND, Box::new(move |seq, args| { let result = handler(seq, serde_json::from_value(args).unwrap()); - let response = match result { + let RequestHandling::Respond(response) = result else { + return RequestHandling::Exit; + }; + let response = match response { Ok(response) => Response { seq: seq + 1, request_seq: seq, @@ -750,7 +802,7 @@ impl FakeTransport { message: None, }, }; - response + RequestHandling::Respond(response) }), ); } @@ -764,86 +816,75 @@ impl FakeTransport { .insert(R::COMMAND, Box::new(handler)); } - async fn start(cx: &mut AsyncApp) -> Result { + async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result { + Ok(Self { + request_handlers: Arc::new(Mutex::new(HashMap::default())), + response_handlers: Arc::new(Mutex::new(HashMap::default())), + message_handler: None, + kind: FakeTransportKind::Tcp { + connection, + executor: cx.background_executor().clone(), + }, + }) + } + + async fn handle_messages( + request_handlers: Arc>>, + response_handlers: Arc>>, + stdin_reader: PipeReader, + stdout_writer: PipeWriter, + ) -> Result<()> { use dap_types::requests::{Request, RunInTerminal, StartDebugging}; use serde_json::json; - let (stdin_writer, stdin_reader) = async_pipe::pipe(); - let (stdout_writer, stdout_reader) = async_pipe::pipe(); - - let mut this = Self { - request_handlers: Arc::new(Mutex::new(HashMap::default())), - response_handlers: Arc::new(Mutex::new(HashMap::default())), - stdin_writer: Some(stdin_writer), - stdout_reader: Some(stdout_reader), - message_handler: None, - }; - - let request_handlers = this.request_handlers.clone(); - let response_handlers = this.response_handlers.clone(); + let mut reader = BufReader::new(stdin_reader); let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer)); + let mut buffer = String::new(); - this.message_handler = Some(cx.background_spawn(async move { - let mut reader = BufReader::new(stdin_reader); - let mut buffer = String::new(); - - loop { - match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None) - .await - { - ConnectionResult::Timeout => { - anyhow::bail!("Timed out when connecting to debugger"); - } - ConnectionResult::ConnectionReset => { - log::info!("Debugger closed the connection"); - break Ok(()); - } - ConnectionResult::Result(Err(e)) => break Err(e), - ConnectionResult::Result(Ok(message)) => { - match message { - Message::Request(request) => { - // redirect reverse requests to stdout writer/reader - if request.command == RunInTerminal::COMMAND - || request.command == StartDebugging::COMMAND - { - let message = - serde_json::to_string(&Message::Request(request)).unwrap(); - - let mut writer = stdout_writer.lock().await; - writer - .write_all( - TransportDelegate::build_rpc_message(message) - .as_bytes(), - ) - .await - .unwrap(); - writer.flush().await.unwrap(); - } else { - let response = if let Some(handle) = - request_handlers.lock().get_mut(request.command.as_str()) - { - handle(request.seq, request.arguments.unwrap_or(json!({}))) - } else { - panic!("No request handler for {}", request.command); - }; - let message = - serde_json::to_string(&Message::Response(response)) - .unwrap(); - - let mut writer = stdout_writer.lock().await; - writer - .write_all( - TransportDelegate::build_rpc_message(message) - .as_bytes(), - ) - .await - .unwrap(); - writer.flush().await.unwrap(); - } - } - Message::Event(event) => { + loop { + match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await { + ConnectionResult::Timeout => { + anyhow::bail!("Timed out when connecting to debugger"); + } + ConnectionResult::ConnectionReset => { + log::info!("Debugger closed the connection"); + break Ok(()); + } + ConnectionResult::Result(Err(e)) => break Err(e), + ConnectionResult::Result(Ok(message)) => { + match message { + Message::Request(request) => { + // redirect reverse requests to stdout writer/reader + if request.command == RunInTerminal::COMMAND + || request.command == StartDebugging::COMMAND + { let message = - serde_json::to_string(&Message::Event(event)).unwrap(); + serde_json::to_string(&Message::Request(request)).unwrap(); + + let mut writer = stdout_writer.lock().await; + writer + .write_all( + TransportDelegate::build_rpc_message(message).as_bytes(), + ) + .await + .unwrap(); + writer.flush().await.unwrap(); + } else { + let response = if let Some(handle) = + request_handlers.lock().get_mut(request.command.as_str()) + { + handle(request.seq, request.arguments.unwrap_or(json!({}))) + } else { + panic!("No request handler for {}", request.command); + }; + let response = match response { + RequestHandling::Respond(response) => response, + RequestHandling::Exit => { + break Err(anyhow!("exit in response to request")); + } + }; + let message = + serde_json::to_string(&Message::Response(response)).unwrap(); let mut writer = stdout_writer.lock().await; writer @@ -854,20 +895,56 @@ impl FakeTransport { .unwrap(); writer.flush().await.unwrap(); } - Message::Response(response) => { - if let Some(handle) = - response_handlers.lock().get(response.command.as_str()) - { - handle(response); - } else { - log::error!("No response handler for {}", response.command); - } + } + Message::Event(event) => { + let message = serde_json::to_string(&Message::Event(event)).unwrap(); + + let mut writer = stdout_writer.lock().await; + writer + .write_all(TransportDelegate::build_rpc_message(message).as_bytes()) + .await + .unwrap(); + writer.flush().await.unwrap(); + } + Message::Response(response) => { + if let Some(handle) = + response_handlers.lock().get(response.command.as_str()) + { + handle(response); + } else { + log::error!("No response handler for {}", response.command); } } } } } - })); + } + } + + async fn start_stdio(cx: &mut AsyncApp) -> Result { + let (stdin_writer, stdin_reader) = async_pipe::pipe(); + let (stdout_writer, stdout_reader) = async_pipe::pipe(); + let kind = FakeTransportKind::Stdio { + stdin_writer: Some(stdin_writer), + stdout_reader: Some(stdout_reader), + }; + + let mut this = Self { + request_handlers: Arc::new(Mutex::new(HashMap::default())), + response_handlers: Arc::new(Mutex::new(HashMap::default())), + message_handler: None, + kind, + }; + + let request_handlers = this.request_handlers.clone(); + let response_handlers = this.response_handlers.clone(); + + this.message_handler = Some(cx.background_spawn(Self::handle_messages( + request_handlers, + response_handlers, + stdin_reader, + stdout_writer, + ))); Ok(this) } @@ -876,7 +953,10 @@ impl FakeTransport { #[cfg(any(test, feature = "test-support"))] impl Transport for FakeTransport { fn tcp_arguments(&self) -> Option { - None + match &self.kind { + FakeTransportKind::Stdio { .. } => None, + FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()), + } } fn connect( @@ -887,12 +967,33 @@ impl Transport for FakeTransport { Box, )>, > { - let result = util::maybe!({ - Ok(( - Box::new(self.stdin_writer.take().context("Cannot reconnect")?) as _, - Box::new(self.stdout_reader.take().context("Cannot reconnect")?) as _, - )) - }); + let result = match &mut self.kind { + FakeTransportKind::Stdio { + stdin_writer, + stdout_reader, + } => util::maybe!({ + Ok(( + Box::new(stdin_writer.take().context("Cannot reconnect")?) as _, + Box::new(stdout_reader.take().context("Cannot reconnect")?) as _, + )) + }), + FakeTransportKind::Tcp { executor, .. } => { + let (stdin_writer, stdin_reader) = async_pipe::pipe(); + let (stdout_writer, stdout_reader) = async_pipe::pipe(); + + let request_handlers = self.request_handlers.clone(); + let response_handlers = self.response_handlers.clone(); + + self.message_handler = Some(executor.spawn(Self::handle_messages( + request_handlers, + response_handlers, + stdin_reader, + stdout_writer, + ))); + + Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _)) + } + }; Task::ready(result) } diff --git a/crates/dap_adapters/Cargo.toml b/crates/dap_adapters/Cargo.toml index 65544fbb6a..e7366785c8 100644 --- a/crates/dap_adapters/Cargo.toml +++ b/crates/dap_adapters/Cargo.toml @@ -36,6 +36,7 @@ paths.workspace = true serde.workspace = true serde_json.workspace = true shlex.workspace = true +smol.workspace = true task.workspace = true util.workspace = true workspace-hack.workspace = true diff --git a/crates/dap_adapters/src/dap_adapters.rs b/crates/dap_adapters/src/dap_adapters.rs index a147861f8d..a4e6beb249 100644 --- a/crates/dap_adapters/src/dap_adapters.rs +++ b/crates/dap_adapters/src/dap_adapters.rs @@ -13,7 +13,6 @@ use dap::{ DapRegistry, adapters::{ self, AdapterVersion, DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName, - GithubRepo, }, configure_tcp_connection, }; diff --git a/crates/dap_adapters/src/go.rs b/crates/dap_adapters/src/go.rs index d32f5cbf34..22d8262b93 100644 --- a/crates/dap_adapters/src/go.rs +++ b/crates/dap_adapters/src/go.rs @@ -547,6 +547,7 @@ async fn handle_envs( } }; + let mut env_vars = HashMap::default(); for path in env_files { let Some(path) = path .and_then(|s| PathBuf::from_str(s).ok()) @@ -556,13 +557,33 @@ async fn handle_envs( }; if let Ok(file) = fs.open_sync(&path).await { - envs.extend(dotenvy::from_read_iter(file).filter_map(Result::ok)) + let file_envs: HashMap = dotenvy::from_read_iter(file) + .filter_map(Result::ok) + .collect(); + envs.extend(file_envs.iter().map(|(k, v)| (k.clone(), v.clone()))); + env_vars.extend(file_envs); } else { warn!("While starting Go debug session: failed to read env file {path:?}"); }; } + let mut env_obj: serde_json::Map = serde_json::Map::new(); + + for (k, v) in env_vars { + env_obj.insert(k, Value::String(v)); + } + + if let Some(existing_env) = config.get("env").and_then(|v| v.as_object()) { + for (k, v) in existing_env { + env_obj.insert(k.clone(), v.clone()); + } + } + + if !env_obj.is_empty() { + config.insert("env".to_string(), Value::Object(env_obj)); + } + // remove envFile now that it's been handled - config.remove("entry"); + config.remove("envFile"); Some(()) } diff --git a/crates/dap_adapters/src/javascript.rs b/crates/dap_adapters/src/javascript.rs index 76c1d1fb7b..2d19921a0f 100644 --- a/crates/dap_adapters/src/javascript.rs +++ b/crates/dap_adapters/src/javascript.rs @@ -54,20 +54,6 @@ impl JsDebugAdapter { user_args: Option>, _: &mut AsyncApp, ) -> Result { - let adapter_path = if let Some(user_installed_path) = user_installed_path { - user_installed_path - } else { - let adapter_path = paths::debug_adapters_dir().join(self.name().as_ref()); - - let file_name_prefix = format!("{}_", self.name()); - - util::fs::find_file_name_in_dir(adapter_path.as_path(), |file_name| { - file_name.starts_with(&file_name_prefix) - }) - .await - .context("Couldn't find JavaScript dap directory")? - }; - let tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default(); let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?; @@ -136,21 +122,27 @@ impl JsDebugAdapter { .or_insert(true.into()); } + let adapter_path = if let Some(user_installed_path) = user_installed_path { + user_installed_path + } else { + let adapter_path = paths::debug_adapters_dir().join(self.name().as_ref()); + + let file_name_prefix = format!("{}_", self.name()); + + util::fs::find_file_name_in_dir(adapter_path.as_path(), |file_name| { + file_name.starts_with(&file_name_prefix) + }) + .await + .context("Couldn't find JavaScript dap directory")? + .join(Self::ADAPTER_PATH) + }; + let arguments = if let Some(mut args) = user_args { - args.insert( - 0, - adapter_path - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), - ); + args.insert(0, adapter_path.to_string_lossy().to_string()); args } else { vec![ - adapter_path - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), + adapter_path.to_string_lossy().to_string(), port.to_string(), host.to_string(), ] @@ -534,6 +526,14 @@ impl DebugAdapter for JsDebugAdapter { .filter(|name| !name.is_empty())?; Some(label.to_owned()) } + + fn compact_child_session(&self) -> bool { + true + } + + fn prefer_thread_name(&self) -> bool { + true + } } fn normalize_task_type(task_type: &mut Value) { diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index dc3d15e124..aa64fea6ed 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -1,31 +1,39 @@ use crate::*; use anyhow::Context as _; -use dap::adapters::latest_github_release; use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition}; -use gpui::{AppContext, AsyncApp, SharedString}; +use gpui::{AsyncApp, SharedString}; use json_dotpath::DotPaths; -use language::{LanguageName, Toolchain}; +use language::LanguageName; +use paths::debug_adapters_dir; use serde_json::Value; +use smol::lock::OnceCell; use std::net::Ipv4Addr; use std::{ collections::HashMap, ffi::OsStr, path::{Path, PathBuf}, - sync::OnceLock, }; -use util::ResultExt; #[derive(Default)] pub(crate) struct PythonDebugAdapter { - checked: OnceLock<()>, + python_venv_base: OnceCell, String>>, } impl PythonDebugAdapter { const ADAPTER_NAME: &'static str = "Debugpy"; const DEBUG_ADAPTER_NAME: DebugAdapterName = DebugAdapterName(SharedString::new_static(Self::ADAPTER_NAME)); - const ADAPTER_PACKAGE_NAME: &'static str = "debugpy"; - const ADAPTER_PATH: &'static str = "src/debugpy/adapter"; + const PYTHON_ADAPTER_IN_VENV: &'static str = if cfg!(target_os = "windows") { + "Scripts/python3" + } else { + "bin/python3" + }; + const ADAPTER_PATH: &'static str = if cfg!(target_os = "windows") { + "debugpy-venv/Scripts/debugpy-adapter" + } else { + "debugpy-venv/bin/debugpy-adapter" + }; + const LANGUAGE_NAME: &'static str = "Python"; async fn generate_debugpy_arguments( @@ -40,36 +48,18 @@ impl PythonDebugAdapter { "Using user-installed debugpy adapter from: {}", user_installed_path.display() ); - vec![ - user_installed_path - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), - ] + vec![user_installed_path.to_string_lossy().to_string()] } else if installed_in_venv { log::debug!("Using venv-installed debugpy"); vec!["-m".to_string(), "debugpy.adapter".to_string()] } else { let adapter_path = paths::debug_adapters_dir().join(Self::DEBUG_ADAPTER_NAME.as_ref()); - let file_name_prefix = format!("{}_", Self::ADAPTER_NAME); - - let debugpy_dir = - util::fs::find_file_name_in_dir(adapter_path.as_path(), |file_name| { - file_name.starts_with(&file_name_prefix) - }) - .await - .context("Debugpy directory not found")?; - - log::debug!( - "Using GitHub-downloaded debugpy adapter from: {}", - debugpy_dir.display() - ); - vec![ - debugpy_dir - .join(Self::ADAPTER_PATH) - .to_string_lossy() - .to_string(), - ] + let path = adapter_path + .join(Self::ADAPTER_PATH) + .to_string_lossy() + .into_owned(); + log::debug!("Using pip debugpy adapter from: {path}"); + vec![path] }; args.extend(if let Some(args) = user_args { @@ -105,44 +95,67 @@ impl PythonDebugAdapter { request, }) } - async fn fetch_latest_adapter_version( - &self, - delegate: &Arc, - ) -> Result { - let github_repo = GithubRepo { - repo_name: Self::ADAPTER_PACKAGE_NAME.into(), - repo_owner: "microsoft".into(), - }; - fetch_latest_adapter_version_from_github(github_repo, delegate.as_ref()).await - } - - async fn install_binary( - adapter_name: DebugAdapterName, - version: AdapterVersion, - delegate: Arc, - ) -> Result<()> { - let version_path = adapters::download_adapter_from_github( - adapter_name, - version, - adapters::DownloadedFileType::GzipTar, - delegate.as_ref(), - ) - .await?; - // only needed when you install the latest version for the first time - if let Some(debugpy_dir) = - util::fs::find_file_name_in_dir(version_path.as_path(), |file_name| { - file_name.starts_with("microsoft-debugpy-") - }) + async fn ensure_venv(delegate: &dyn DapDelegate) -> Result> { + let python_path = Self::find_base_python(delegate) .await - { - // TODO Debugger: Rename folder instead of moving all files to another folder - // We're doing unnecessary IO work right now - util::fs::move_folder_files_to_folder(debugpy_dir.as_path(), version_path.as_path()) + .context("Could not find Python installation for DebugPy")?; + let work_dir = debug_adapters_dir().join(Self::ADAPTER_NAME); + let mut path = work_dir.clone(); + path.push("debugpy-venv"); + if !path.exists() { + util::command::new_smol_command(python_path) + .arg("-m") + .arg("venv") + .arg("debugpy-venv") + .current_dir(work_dir) + .spawn()? + .output() .await?; } - Ok(()) + Ok(path.into()) + } + + // Find "baseline", user python version from which we'll create our own venv. + async fn find_base_python(delegate: &dyn DapDelegate) -> Option { + for path in ["python3", "python"] { + if let Some(path) = delegate.which(path.as_ref()).await { + return Some(path); + } + } + None + } + + async fn base_venv(&self, delegate: &dyn DapDelegate) -> Result, String> { + const BINARY_DIR: &str = if cfg!(target_os = "windows") { + "Scripts" + } else { + "bin" + }; + self.python_venv_base + .get_or_init(move || async move { + let venv_base = Self::ensure_venv(delegate) + .await + .map_err(|e| format!("{e}"))?; + let pip_path = venv_base.join(BINARY_DIR).join("pip3"); + let installation_succeeded = util::command::new_smol_command(pip_path.as_path()) + .arg("install") + .arg("debugpy") + .arg("-U") + .output() + .await + .map_err(|e| format!("{e}"))? + .status + .success(); + if !installation_succeeded { + return Err("debugpy installation failed".into()); + } + + Ok(venv_base) + }) + .await + .clone() } async fn get_installed_binary( @@ -151,15 +164,15 @@ impl PythonDebugAdapter { config: &DebugTaskDefinition, user_installed_path: Option, user_args: Option>, - toolchain: Option, + python_from_toolchain: Option, installed_in_venv: bool, ) -> Result { const BINARY_NAMES: [&str; 3] = ["python3", "python", "py"]; let tcp_connection = config.tcp_connection.clone().unwrap_or_default(); let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?; - let python_path = if let Some(toolchain) = toolchain { - Some(toolchain.path.to_string()) + let python_path = if let Some(toolchain) = python_from_toolchain { + Some(toolchain) } else { let mut name = None; @@ -640,25 +653,28 @@ impl DebugAdapter for PythonDebugAdapter { &config, None, user_args, - Some(toolchain.clone()), + Some(toolchain.path.to_string()), true, ) .await; } } } - - if self.checked.set(()).is_ok() { - delegate.output_to_console(format!("Checking latest version of {}...", self.name())); - if let Some(version) = self.fetch_latest_adapter_version(delegate).await.log_err() { - cx.background_spawn(Self::install_binary(self.name(), version, delegate.clone())) - .await - .context("Failed to install debugpy")?; - } - } - - self.get_installed_binary(delegate, &config, None, user_args, toolchain, false) + let toolchain = self + .base_venv(&**delegate) .await + .map_err(|e| anyhow::anyhow!(e))? + .join(Self::PYTHON_ADAPTER_IN_VENV); + + self.get_installed_binary( + delegate, + &config, + None, + user_args, + Some(toolchain.to_string_lossy().into_owned()), + false, + ) + .await } fn label_for_child_session(&self, args: &StartDebuggingRequestArguments) -> Option { @@ -671,24 +687,6 @@ impl DebugAdapter for PythonDebugAdapter { } } -async fn fetch_latest_adapter_version_from_github( - github_repo: GithubRepo, - delegate: &dyn DapDelegate, -) -> Result { - let release = latest_github_release( - &format!("{}/{}", github_repo.repo_owner, github_repo.repo_name), - false, - false, - delegate.http_client(), - ) - .await?; - - Ok(AdapterVersion { - tag_name: release.tag_name, - url: release.tarball_url, - }) -} - #[cfg(test)] mod tests { use super::*; @@ -700,7 +698,7 @@ mod tests { let port = 5678; // Case 1: User-defined debugpy path (highest precedence) - let user_path = PathBuf::from("/custom/path/to/debugpy"); + let user_path = PathBuf::from("/custom/path/to/debugpy/src/debugpy/adapter"); let user_args = PythonDebugAdapter::generate_debugpy_arguments( &host, port, @@ -717,7 +715,7 @@ mod tests { .await .unwrap(); - assert!(user_args[0].ends_with("src/debugpy/adapter")); + assert_eq!(user_args[0], "/custom/path/to/debugpy/src/debugpy/adapter"); assert_eq!(user_args[1], "--host=127.0.0.1"); assert_eq!(user_args[2], "--port=5678"); diff --git a/crates/debugger_tools/src/dap_log.rs b/crates/debugger_tools/src/dap_log.rs index f2f193cad4..b806381d25 100644 --- a/crates/debugger_tools/src/dap_log.rs +++ b/crates/debugger_tools/src/dap_log.rs @@ -32,12 +32,19 @@ use workspace::{ ui::{Button, Clickable, ContextMenu, Label, LabelCommon, PopoverMenu, h_flex}, }; +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum View { + AdapterLogs, + RpcMessages, + InitializationSequence, +} + struct DapLogView { editor: Entity, focus_handle: FocusHandle, log_store: Entity, editor_subscriptions: Vec, - current_view: Option<(SessionId, LogKind)>, + current_view: Option<(SessionId, View)>, project: Entity, _subscriptions: Vec, } @@ -77,6 +84,7 @@ struct DebugAdapterState { id: SessionId, log_messages: VecDeque, rpc_messages: RpcMessages, + session_label: SharedString, adapter_name: DebugAdapterName, has_adapter_logs: bool, is_terminated: bool, @@ -121,12 +129,18 @@ impl MessageKind { } impl DebugAdapterState { - fn new(id: SessionId, adapter_name: DebugAdapterName, has_adapter_logs: bool) -> Self { + fn new( + id: SessionId, + adapter_name: DebugAdapterName, + session_label: SharedString, + has_adapter_logs: bool, + ) -> Self { Self { id, log_messages: VecDeque::new(), rpc_messages: RpcMessages::new(), adapter_name, + session_label, has_adapter_logs, is_terminated: false, } @@ -371,18 +385,22 @@ impl LogStore { return None; }; - let (adapter_name, has_adapter_logs) = session.read_with(cx, |session, _| { - ( - session.adapter(), - session - .adapter_client() - .map_or(false, |client| client.has_adapter_logs()), - ) - }); + let (adapter_name, session_label, has_adapter_logs) = + session.read_with(cx, |session, _| { + ( + session.adapter(), + session.label(), + session + .adapter_client() + .map_or(false, |client| client.has_adapter_logs()), + ) + }); state.insert(DebugAdapterState::new( id.session_id, adapter_name, + session_label + .unwrap_or_else(|| format!("Session {} (child)", id.session_id.0).into()), has_adapter_logs, )); @@ -506,12 +524,13 @@ impl Render for DapLogToolbarItemView { current_client .map(|sub_item| { Cow::Owned(format!( - "{} ({}) - {}", + "{} - {} - {}", sub_item.adapter_name, - sub_item.session_id.0, + sub_item.session_label, match sub_item.selected_entry { - LogKind::Adapter => ADAPTER_LOGS, - LogKind::Rpc => RPC_MESSAGES, + View::AdapterLogs => ADAPTER_LOGS, + View::RpcMessages => RPC_MESSAGES, + View::InitializationSequence => INITIALIZATION_SEQUENCE, } )) }) @@ -529,8 +548,8 @@ impl Render for DapLogToolbarItemView { .pl_2() .child( Label::new(format!( - "{}. {}", - row.session_id.0, row.adapter_name, + "{} - {}", + row.adapter_name, row.session_label )) .color(workspace::ui::Color::Muted), ) @@ -669,9 +688,16 @@ impl DapLogView { let events_subscriptions = cx.subscribe(&log_store, |log_view, _, event, cx| match event { Event::NewLogEntry { id, entry, kind } => { - if log_view.current_view == Some((id.session_id, *kind)) - && log_view.project == *id.project - { + let is_current_view = match (log_view.current_view, *kind) { + (Some((i, View::AdapterLogs)), LogKind::Adapter) + | (Some((i, View::RpcMessages)), LogKind::Rpc) + if i == id.session_id => + { + log_view.project == *id.project + } + _ => false, + }; + if is_current_view { log_view.editor.update(cx, |editor, cx| { editor.set_read_only(false); let last_point = editor.buffer().read(cx).len(cx); @@ -768,10 +794,11 @@ impl DapLogView { .map(|state| DapMenuItem { session_id: state.id, adapter_name: state.adapter_name.clone(), + session_label: state.session_label.clone(), has_adapter_logs: state.has_adapter_logs, selected_entry: self .current_view - .map_or(LogKind::Adapter, |(_, kind)| kind), + .map_or(View::AdapterLogs, |(_, kind)| kind), }) .collect::>() }) @@ -789,7 +816,7 @@ impl DapLogView { .map(|state| log_contents(state.iter().cloned())) }); if let Some(rpc_log) = rpc_log { - self.current_view = Some((id.session_id, LogKind::Rpc)); + self.current_view = Some((id.session_id, View::RpcMessages)); let (editor, editor_subscriptions) = Self::editor_for_logs(rpc_log, window, cx); let language = self.project.read(cx).languages().language_for_name("JSON"); editor @@ -830,7 +857,7 @@ impl DapLogView { .map(|state| log_contents(state.iter().cloned())) }); if let Some(message_log) = message_log { - self.current_view = Some((id.session_id, LogKind::Adapter)); + self.current_view = Some((id.session_id, View::AdapterLogs)); let (editor, editor_subscriptions) = Self::editor_for_logs(message_log, window, cx); editor .read(cx) @@ -859,7 +886,7 @@ impl DapLogView { .map(|state| log_contents(state.iter().cloned())) }); if let Some(rpc_log) = rpc_log { - self.current_view = Some((id.session_id, LogKind::Rpc)); + self.current_view = Some((id.session_id, View::InitializationSequence)); let (editor, editor_subscriptions) = Self::editor_for_logs(rpc_log, window, cx); let language = self.project.read(cx).languages().language_for_name("JSON"); editor @@ -899,11 +926,12 @@ fn log_contents(lines: impl Iterator) -> String { } #[derive(Clone, PartialEq)] -pub(crate) struct DapMenuItem { - pub session_id: SessionId, - pub adapter_name: DebugAdapterName, - pub has_adapter_logs: bool, - pub selected_entry: LogKind, +struct DapMenuItem { + session_id: SessionId, + session_label: SharedString, + adapter_name: DebugAdapterName, + has_adapter_logs: bool, + selected_entry: View, } const ADAPTER_LOGS: &str = "Adapter Logs"; diff --git a/crates/debugger_ui/Cargo.toml b/crates/debugger_ui/Cargo.toml index fc543a47f9..df4125860f 100644 --- a/crates/debugger_ui/Cargo.toml +++ b/crates/debugger_ui/Cargo.toml @@ -16,13 +16,13 @@ doctest = false test-support = [ "dap/test-support", "dap_adapters/test-support", + "debugger_tools/test-support", "editor/test-support", "gpui/test-support", "project/test-support", "util/test-support", "workspace/test-support", "unindent", - "debugger_tools" ] [dependencies] @@ -35,22 +35,27 @@ command_palette_hooks.workspace = true dap.workspace = true dap_adapters = { workspace = true, optional = true } db.workspace = true +debugger_tools.workspace = true editor.workspace = true file_icons.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true +hex.workspace = true indoc.workspace = true itertools.workspace = true language.workspace = true log.workspace = true menu.workspace = true +notifications.workspace = true parking_lot.workspace = true +parse_int.workspace = true paths.workspace = true picker.workspace = true pretty_assertions.workspace = true project.workspace = true rpc.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true @@ -63,14 +68,13 @@ telemetry.workspace = true terminal_view.workspace = true text.workspace = true theme.workspace = true -tree-sitter.workspace = true tree-sitter-json.workspace = true +tree-sitter.workspace = true ui.workspace = true -util.workspace = true -workspace.workspace = true -workspace-hack.workspace = true -debugger_tools = { workspace = true, optional = true } unindent = { workspace = true, optional = true } +util.workspace = true +workspace-hack.workspace = true +workspace.workspace = true zed_actions.workspace = true [dev-dependencies] @@ -80,8 +84,8 @@ debugger_tools = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } +tree-sitter-go.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } zlog.workspace = true -tree-sitter-go.workspace = true diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index 988f6f4019..d81c593484 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -2,6 +2,7 @@ use crate::persistence::DebuggerPaneItem; use crate::session::DebugSession; use crate::session::running::RunningState; use crate::session::running::breakpoint_list::BreakpointList; + use crate::{ ClearAllBreakpoints, Continue, CopyDebugAdapterArguments, Detach, FocusBreakpointList, FocusConsole, FocusFrames, FocusLoadedSources, FocusModules, FocusTerminal, FocusVariables, @@ -9,6 +10,7 @@ use crate::{ ToggleExpandItem, ToggleSessionPicker, ToggleThreadPicker, persistence, spawn_task_or_modal, }; use anyhow::{Context as _, Result, anyhow}; +use collections::IndexMap; use dap::adapters::DebugAdapterName; use dap::debugger_settings::DebugPanelDockPosition; use dap::{ @@ -26,7 +28,7 @@ use text::ToPoint as _; use itertools::Itertools as _; use language::Buffer; -use project::debugger::session::{Session, SessionStateEvent}; +use project::debugger::session::{Session, SessionQuirks, SessionState, SessionStateEvent}; use project::{DebugScenarioContext, Fs, ProjectPath, TaskSourceKind, WorktreeId}; use project::{Project, debugger::session::ThreadStatus}; use rpc::proto::{self}; @@ -35,7 +37,7 @@ use std::sync::{Arc, LazyLock}; use task::{DebugScenario, TaskContext}; use tree_sitter::{Query, StreamingIterator as _}; use ui::{ContextMenu, Divider, PopoverMenuHandle, Tooltip, prelude::*}; -use util::{ResultExt, maybe}; +use util::{ResultExt, debug_panic, maybe}; use workspace::SplitDirection; use workspace::item::SaveOptions; use workspace::{ @@ -63,13 +65,14 @@ pub enum DebugPanelEvent { pub struct DebugPanel { size: Pixels, - sessions: Vec>, active_session: Option>, project: Entity, workspace: WeakEntity, focus_handle: FocusHandle, context_menu: Option<(Entity, Point, Subscription)>, debug_scenario_scheduled_last: bool, + pub(crate) sessions_with_children: + IndexMap, Vec>>, pub(crate) thread_picker_menu_handle: PopoverMenuHandle, pub(crate) session_picker_menu_handle: PopoverMenuHandle, fs: Arc, @@ -100,7 +103,7 @@ impl DebugPanel { Self { size: px(300.), - sessions: vec![], + sessions_with_children: Default::default(), active_session: None, focus_handle, breakpoint_list: BreakpointList::new( @@ -138,8 +141,9 @@ impl DebugPanel { }); } - pub(crate) fn sessions(&self) -> Vec> { - self.sessions.clone() + #[cfg(test)] + pub(crate) fn sessions(&self) -> impl Iterator> { + self.sessions_with_children.keys().cloned() } pub fn active_session(&self) -> Option> { @@ -185,12 +189,20 @@ impl DebugPanel { cx: &mut Context, ) { let dap_store = self.project.read(cx).dap_store(); + let Some(adapter) = DapRegistry::global(cx).adapter(&scenario.adapter) else { + return; + }; + let quirks = SessionQuirks { + compact: adapter.compact_child_session(), + prefer_thread_name: adapter.prefer_thread_name(), + }; let session = dap_store.update(cx, |dap_store, cx| { dap_store.new_session( - scenario.label.clone(), + Some(scenario.label.clone()), DebugAdapterName(scenario.adapter.clone()), task_context.clone(), None, + quirks, cx, ) }); @@ -267,22 +279,34 @@ impl DebugPanel { } }); - cx.spawn(async move |_, cx| { - if let Err(error) = task.await { - log::error!("{error}"); - session - .update(cx, |session, cx| { - session - .console_output(cx) - .unbounded_send(format!("error: {}", error)) - .ok(); - session.shutdown(cx) - })? - .await; + let boot_task = cx.spawn({ + let session = session.clone(); + + async move |_, cx| { + if let Err(error) = task.await { + log::error!("{error}"); + session + .update(cx, |session, cx| { + session + .console_output(cx) + .unbounded_send(format!("error: {}", error)) + .ok(); + session.shutdown(cx) + })? + .await; + } + anyhow::Ok(()) } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); + }); + + session.update(cx, |session, _| match &mut session.mode { + SessionState::Building(state_task) => { + *state_task = Some(boot_task); + } + SessionState::Running(_) => { + debug_panic!("Session state should be in building because we are just starting it"); + } + }); } pub(crate) fn rerun_last_session( @@ -363,14 +387,15 @@ impl DebugPanel { }; let dap_store_handle = self.project.read(cx).dap_store().clone(); - let label = curr_session.read(cx).label().clone(); + let label = curr_session.read(cx).label(); + let quirks = curr_session.read(cx).quirks(); let adapter = curr_session.read(cx).adapter().clone(); let binary = curr_session.read(cx).binary().cloned().unwrap(); let task_context = curr_session.read(cx).task_context().clone(); let curr_session_id = curr_session.read(cx).session_id(); - self.sessions - .retain(|session| session.read(cx).session_id(cx) != curr_session_id); + self.sessions_with_children + .retain(|session, _| session.read(cx).session_id(cx) != curr_session_id); let task = dap_store_handle.update(cx, |dap_store, cx| { dap_store.shutdown_session(curr_session_id, cx) }); @@ -379,7 +404,7 @@ impl DebugPanel { task.await.log_err(); let (session, task) = dap_store_handle.update(cx, |dap_store, cx| { - let session = dap_store.new_session(label, adapter, task_context, None, cx); + let session = dap_store.new_session(label, adapter, task_context, None, quirks, cx); let task = session.update(cx, |session, cx| { session.boot(binary, worktree, dap_store_handle.downgrade(), cx) @@ -425,6 +450,7 @@ impl DebugPanel { let dap_store_handle = self.project.read(cx).dap_store().clone(); let label = self.label_for_child_session(&parent_session, request, cx); let adapter = parent_session.read(cx).adapter().clone(); + let quirks = parent_session.read(cx).quirks(); let Some(mut binary) = parent_session.read(cx).binary().cloned() else { log::error!("Attempted to start a child-session without a binary"); return; @@ -438,6 +464,7 @@ impl DebugPanel { adapter, task_context, Some(parent_session.clone()), + quirks, cx, ); @@ -463,8 +490,8 @@ impl DebugPanel { cx: &mut Context, ) { let Some(session) = self - .sessions - .iter() + .sessions_with_children + .keys() .find(|other| entity_id == other.entity_id()) .cloned() else { @@ -498,15 +525,14 @@ impl DebugPanel { } session.update(cx, |session, cx| session.shutdown(cx)).ok(); this.update(cx, |this, cx| { - this.sessions.retain(|other| entity_id != other.entity_id()); - + this.retain_sessions(|other| entity_id != other.entity_id()); if let Some(active_session_id) = this .active_session .as_ref() .map(|session| session.entity_id()) { if active_session_id == entity_id { - this.active_session = this.sessions.first().cloned(); + this.active_session = this.sessions_with_children.keys().next().cloned(); } } cx.notify() @@ -622,6 +648,14 @@ impl DebugPanel { .on_click(move |_, _, cx| cx.open_url("https://zed.dev/docs/debugger")) .tooltip(Tooltip::text("Open Documentation")) }; + let logs_button = || { + IconButton::new("debug-open-logs", IconName::ScrollText) + .icon_size(IconSize::Small) + .on_click(move |_, window, cx| { + window.dispatch_action(debugger_tools::OpenDebugAdapterLogs.boxed_clone(), cx) + }) + .tooltip(Tooltip::text("Open Debug Adapter Logs")) + }; Some( div.border_b_1() @@ -805,13 +839,24 @@ impl DebugPanel { .on_click(window.listener_for( &running_state, |this, _, _window, cx| { - this.stop_thread(cx); + if this.session().read(cx).is_building() { + this.session().update(cx, |session, cx| { + session.shutdown(cx).detach() + }); + } else { + this.stop_thread(cx); + } + }, + )) + .disabled(active_session.as_ref().is_none_or( + |session| { + session + .read(cx) + .session(cx) + .read(cx) + .is_terminated() }, )) - .disabled( - thread_status != ThreadStatus::Stopped - && thread_status != ThreadStatus::Running, - ) .tooltip({ let focus_handle = focus_handle.clone(); let label = if capabilities @@ -873,6 +918,7 @@ impl DebugPanel { .justify_around() .when(is_side, |this| { this.child(new_session_button()) + .child(logs_button()) .child(documentation_button()) }), ) @@ -922,6 +968,7 @@ impl DebugPanel { )) .when(!is_side, |this| { this.child(new_session_button()) + .child(logs_button()) .child(documentation_button()) }), ), @@ -966,8 +1013,8 @@ impl DebugPanel { cx: &mut Context, ) { if let Some(session) = self - .sessions - .iter() + .sessions_with_children + .keys() .find(|session| session.read(cx).session_id(cx) == session_id) { self.activate_session(session.clone(), window, cx); @@ -980,7 +1027,7 @@ impl DebugPanel { window: &mut Window, cx: &mut Context, ) { - debug_assert!(self.sessions.contains(&session_item)); + debug_assert!(self.sessions_with_children.contains_key(&session_item)); session_item.focus_handle(cx).focus(window); session_item.update(cx, |this, cx| { this.running_state().update(cx, |this, cx| { @@ -1251,18 +1298,27 @@ impl DebugPanel { parent_session: &Entity, request: &StartDebuggingRequestArguments, cx: &mut Context<'_, Self>, - ) -> SharedString { + ) -> Option { let adapter = parent_session.read(cx).adapter(); if let Some(adapter) = DapRegistry::global(cx).adapter(&adapter) { if let Some(label) = adapter.label_for_child_session(request) { - return label.into(); + return Some(label.into()); } } - let mut label = parent_session.read(cx).label().clone(); - if !label.ends_with("(child)") { - label = format!("{label} (child)").into(); + None + } + + fn retain_sessions(&mut self, keep: impl Fn(&Entity) -> bool) { + self.sessions_with_children + .retain(|session, _| keep(session)); + for children in self.sessions_with_children.values_mut() { + children.retain(|child| { + let Some(child) = child.upgrade() else { + return false; + }; + keep(&child) + }); } - label } } @@ -1292,11 +1348,11 @@ async fn register_session_inner( let serialized_layout = persistence::get_serialized_layout(adapter_name).await; let debug_session = this.update_in(cx, |this, window, cx| { let parent_session = this - .sessions - .iter() + .sessions_with_children + .keys() .find(|p| Some(p.read(cx).session_id(cx)) == session.read(cx).parent_id(cx)) .cloned(); - this.sessions.retain(|session| { + this.retain_sessions(|session| { !session .read(cx) .running_state() @@ -1327,13 +1383,23 @@ async fn register_session_inner( ) .detach(); let insert_position = this - .sessions - .iter() + .sessions_with_children + .keys() .position(|session| Some(session) == parent_session.as_ref()) .map(|position| position + 1) - .unwrap_or(this.sessions.len()); + .unwrap_or(this.sessions_with_children.len()); // Maintain topological sort order of sessions - this.sessions.insert(insert_position, debug_session.clone()); + let (_, old) = this.sessions_with_children.insert_before( + insert_position, + debug_session.clone(), + Default::default(), + ); + debug_assert!(old.is_none()); + if let Some(parent_session) = parent_session { + this.sessions_with_children + .entry(parent_session) + .and_modify(|children| children.push(debug_session.downgrade())); + } debug_session })?; @@ -1373,7 +1439,7 @@ impl Panel for DebugPanel { cx: &mut Context, ) { if position.axis() != self.position(window, cx).axis() { - self.sessions.iter().for_each(|session_item| { + self.sessions_with_children.keys().for_each(|session_item| { session_item.update(cx, |item, cx| { item.running_state() .update(cx, |state, _| state.invert_axies()) @@ -1694,6 +1760,7 @@ impl Render for DebugPanel { category_filter: Some( zed_actions::ExtensionCategoryFilter::DebugAdapters, ), + id: None, } .boxed_clone(), cx, @@ -1739,6 +1806,7 @@ impl Render for DebugPanel { .child(breakpoint_list) .child(Divider::vertical()) .child(welcome_experience) + .child(Divider::vertical()) } else { this.items_end() .child(welcome_experience) diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index 2056232e9b..9eac59af83 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -3,10 +3,12 @@ use std::any::TypeId; use dap::debugger_settings::DebuggerSettings; use debugger_panel::DebugPanel; use editor::Editor; -use gpui::{App, DispatchPhase, EntityInputHandler, actions}; +use gpui::{Action, App, DispatchPhase, EntityInputHandler, actions}; use new_process_modal::{NewProcessModal, NewProcessMode}; use onboarding_modal::DebuggerOnboardingModal; use project::debugger::{self, breakpoint_store::SourceBreakpoint, session::ThreadStatus}; +use schemars::JsonSchema; +use serde::Deserialize; use session::DebugSession; use settings::Settings; use stack_trace_view::StackTraceView; @@ -86,6 +88,20 @@ actions!( ] ); +/// Extends selection down by a specified number of lines. +#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] +#[action(namespace = debugger)] +#[serde(deny_unknown_fields)] +/// Set a data breakpoint on the selected variable or memory region. +pub struct ToggleDataBreakpoint { + /// The type of data breakpoint + /// Read & Write + /// Read + /// Write + #[serde(default)] + pub access_type: Option, +} + actions!( dev, [ diff --git a/crates/debugger_ui/src/dropdown_menus.rs b/crates/debugger_ui/src/dropdown_menus.rs index f93aceae09..dca15eb052 100644 --- a/crates/debugger_ui/src/dropdown_menus.rs +++ b/crates/debugger_ui/src/dropdown_menus.rs @@ -1,16 +1,82 @@ -use std::time::Duration; +use std::{rc::Rc, time::Duration}; use collections::HashMap; -use gpui::{Animation, AnimationExt as _, Entity, Transformation, percentage}; +use gpui::{Animation, AnimationExt as _, Entity, Transformation, WeakEntity, percentage}; use project::debugger::session::{ThreadId, ThreadStatus}; use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*}; -use util::truncate_and_trailoff; +use util::{maybe, truncate_and_trailoff}; use crate::{ debugger_panel::DebugPanel, session::{DebugSession, running::RunningState}, }; +struct SessionListEntry { + ancestors: Vec>, + leaf: Entity, +} + +impl SessionListEntry { + pub(crate) fn label_element(&self, depth: usize, cx: &mut App) -> AnyElement { + const MAX_LABEL_CHARS: usize = 150; + + let mut label = String::new(); + for ancestor in &self.ancestors { + label.push_str(&ancestor.update(cx, |ancestor, cx| { + ancestor.label(cx).unwrap_or("(child)".into()) + })); + label.push_str(" Β» "); + } + label.push_str( + &self + .leaf + .update(cx, |leaf, cx| leaf.label(cx).unwrap_or("(child)".into())), + ); + let label = truncate_and_trailoff(&label, MAX_LABEL_CHARS); + + let is_terminated = self + .leaf + .read(cx) + .running_state + .read(cx) + .session() + .read(cx) + .is_terminated(); + let icon = { + if is_terminated { + Some(Indicator::dot().color(Color::Error)) + } else { + match self + .leaf + .read(cx) + .running_state + .read(cx) + .thread_status(cx) + .unwrap_or_default() + { + project::debugger::session::ThreadStatus::Stopped => { + Some(Indicator::dot().color(Color::Conflict)) + } + _ => Some(Indicator::dot().color(Color::Success)), + } + } + }; + + h_flex() + .id("session-label") + .ml(depth * px(16.0)) + .gap_2() + .when_some(icon, |this, indicator| this.child(indicator)) + .justify_between() + .child( + Label::new(label) + .size(LabelSize::Small) + .when(is_terminated, |this| this.strikethrough()), + ) + .into_any_element() + } +} + impl DebugPanel { fn dropdown_label(label: impl Into) -> Label { const MAX_LABEL_CHARS: usize = 50; @@ -25,145 +91,205 @@ impl DebugPanel { window: &mut Window, cx: &mut Context, ) -> Option { - if let Some(running_state) = running_state { - let sessions = self.sessions().clone(); - let weak = cx.weak_entity(); - let running_state = running_state.read(cx); - let label = if let Some(active_session) = active_session.clone() { - active_session.read(cx).session(cx).read(cx).label() - } else { - SharedString::new_static("Unknown Session") - }; + let running_state = running_state?; - let is_terminated = running_state.session().read(cx).is_terminated(); - let is_started = active_session - .is_some_and(|session| session.read(cx).session(cx).read(cx).is_started()); + let mut session_entries = Vec::with_capacity(self.sessions_with_children.len() * 3); + let mut sessions_with_children = self.sessions_with_children.iter().peekable(); - let session_state_indicator = if is_terminated { - Indicator::dot().color(Color::Error).into_any_element() - } else if !is_started { - Icon::new(IconName::ArrowCircle) - .size(IconSize::Small) - .color(Color::Muted) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ) - .into_any_element() + while let Some((root, children)) = sessions_with_children.next() { + let root_entry = if let Ok([single_child]) = <&[_; 1]>::try_from(children.as_slice()) + && let Some(single_child) = single_child.upgrade() + && single_child.read(cx).quirks.compact + { + sessions_with_children.next(); + SessionListEntry { + leaf: single_child.clone(), + ancestors: vec![root.clone()], + } } else { - match running_state.thread_status(cx).unwrap_or_default() { - ThreadStatus::Stopped => { - Indicator::dot().color(Color::Conflict).into_any_element() - } - _ => Indicator::dot().color(Color::Success).into_any_element(), + SessionListEntry { + leaf: root.clone(), + ancestors: Vec::new(), } }; + session_entries.push(root_entry); - let trigger = h_flex() - .gap_2() - .child(session_state_indicator) - .justify_between() - .child( - DebugPanel::dropdown_label(label) - .when(is_terminated, |this| this.strikethrough()), - ) - .into_any_element(); - - Some( - DropdownMenu::new_with_element( - "debugger-session-list", - trigger, - ContextMenu::build(window, cx, move |mut this, _, cx| { - let context_menu = cx.weak_entity(); - let mut session_depths = HashMap::default(); - for session in sessions.into_iter() { - let weak_session = session.downgrade(); - let weak_session_id = weak_session.entity_id(); - let session_id = session.read(cx).session_id(cx); - let parent_depth = session - .read(cx) - .session(cx) - .read(cx) - .parent_id(cx) - .and_then(|parent_id| session_depths.get(&parent_id).cloned()); - let self_depth = - *session_depths.entry(session_id).or_insert_with(|| { - parent_depth.map(|depth| depth + 1).unwrap_or(0usize) - }); - this = this.custom_entry( - { - let weak = weak.clone(); - let context_menu = context_menu.clone(); - move |_, cx| { - weak_session - .read_with(cx, |session, cx| { - let context_menu = context_menu.clone(); - - let id: SharedString = - format!("debug-session-{}", session_id.0) - .into(); - - h_flex() - .w_full() - .group(id.clone()) - .justify_between() - .child(session.label_element(self_depth, cx)) - .child( - IconButton::new( - "close-debug-session", - IconName::Close, - ) - .visible_on_hover(id.clone()) - .icon_size(IconSize::Small) - .on_click({ - let weak = weak.clone(); - move |_, window, cx| { - weak.update(cx, |panel, cx| { - panel.close_session( - weak_session_id, - window, - cx, - ); - }) - .ok(); - context_menu - .update(cx, |this, cx| { - this.cancel( - &Default::default(), - window, - cx, - ); - }) - .ok(); - } - }), - ) - .into_any_element() - }) - .unwrap_or_else(|_| div().into_any_element()) - } - }, - { - let weak = weak.clone(); - move |window, cx| { - weak.update(cx, |panel, cx| { - panel.activate_session(session.clone(), window, cx); - }) - .ok(); - } - }, - ); - } - this + session_entries.extend( + sessions_with_children + .by_ref() + .take_while(|(session, _)| { + session + .read(cx) + .session(cx) + .read(cx) + .parent_id(cx) + .is_some() + }) + .map(|(session, _)| SessionListEntry { + leaf: session.clone(), + ancestors: vec![], }), - ) - .style(DropdownStyle::Ghost) - .handle(self.session_picker_menu_handle.clone()), - ) - } else { - None + ); } + + let weak = cx.weak_entity(); + let trigger_label = if let Some(active_session) = active_session.clone() { + active_session.update(cx, |active_session, cx| { + active_session.label(cx).unwrap_or("(child)".into()) + }) + } else { + SharedString::new_static("Unknown Session") + }; + let running_state = running_state.read(cx); + + let is_terminated = running_state.session().read(cx).is_terminated(); + let is_started = active_session + .is_some_and(|session| session.read(cx).session(cx).read(cx).is_started()); + + let session_state_indicator = if is_terminated { + Indicator::dot().color(Color::Error).into_any_element() + } else if !is_started { + Icon::new(IconName::ArrowCircle) + .size(IconSize::Small) + .color(Color::Muted) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), + ) + .into_any_element() + } else { + match running_state.thread_status(cx).unwrap_or_default() { + ThreadStatus::Stopped => Indicator::dot().color(Color::Conflict).into_any_element(), + _ => Indicator::dot().color(Color::Success).into_any_element(), + } + }; + + let trigger = h_flex() + .gap_2() + .child(session_state_indicator) + .justify_between() + .child( + DebugPanel::dropdown_label(trigger_label) + .when(is_terminated, |this| this.strikethrough()), + ) + .into_any_element(); + + let menu = DropdownMenu::new_with_element( + "debugger-session-list", + trigger, + ContextMenu::build(window, cx, move |mut this, _, cx| { + let context_menu = cx.weak_entity(); + let mut session_depths = HashMap::default(); + for session_entry in session_entries { + let session_id = session_entry.leaf.read(cx).session_id(cx); + let parent_depth = session_entry + .ancestors + .first() + .unwrap_or(&session_entry.leaf) + .read(cx) + .session(cx) + .read(cx) + .parent_id(cx) + .and_then(|parent_id| session_depths.get(&parent_id).cloned()); + let self_depth = *session_depths + .entry(session_id) + .or_insert_with(|| parent_depth.map(|depth| depth + 1).unwrap_or(0usize)); + this = this.custom_entry( + { + let weak = weak.clone(); + let context_menu = context_menu.clone(); + let ancestors: Rc<[_]> = session_entry + .ancestors + .iter() + .map(|session| session.downgrade()) + .collect(); + let leaf = session_entry.leaf.downgrade(); + move |window, cx| { + Self::render_session_menu_entry( + weak.clone(), + context_menu.clone(), + ancestors.clone(), + leaf.clone(), + self_depth, + window, + cx, + ) + } + }, + { + let weak = weak.clone(); + let leaf = session_entry.leaf.clone(); + move |window, cx| { + weak.update(cx, |panel, cx| { + panel.activate_session(leaf.clone(), window, cx); + }) + .ok(); + } + }, + ); + } + this + }), + ) + .style(DropdownStyle::Ghost) + .handle(self.session_picker_menu_handle.clone()); + + Some(menu) + } + + fn render_session_menu_entry( + weak: WeakEntity, + context_menu: WeakEntity, + ancestors: Rc<[WeakEntity]>, + leaf: WeakEntity, + self_depth: usize, + _window: &mut Window, + cx: &mut App, + ) -> AnyElement { + let Some(session_entry) = maybe!({ + let ancestors = ancestors + .iter() + .map(|ancestor| ancestor.upgrade()) + .collect::>>()?; + let leaf = leaf.upgrade()?; + Some(SessionListEntry { ancestors, leaf }) + }) else { + return div().into_any_element(); + }; + + let id: SharedString = format!( + "debug-session-{}", + session_entry.leaf.read(cx).session_id(cx).0 + ) + .into(); + let session_entity_id = session_entry.leaf.entity_id(); + + h_flex() + .w_full() + .group(id.clone()) + .justify_between() + .child(session_entry.label_element(self_depth, cx)) + .child( + IconButton::new("close-debug-session", IconName::Close) + .visible_on_hover(id.clone()) + .icon_size(IconSize::Small) + .on_click({ + let weak = weak.clone(); + move |_, window, cx| { + weak.update(cx, |panel, cx| { + panel.close_session(session_entity_id, window, cx); + }) + .ok(); + context_menu + .update(cx, |this, cx| { + this.cancel(&Default::default(), window, cx); + }) + .ok(); + } + }), + ) + .into_any_element() } pub(crate) fn render_thread_dropdown( diff --git a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 6d7fa244a2..42f77ab056 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -766,14 +766,7 @@ impl Render for NewProcessModal { )) .child( h_flex() - .child(div().child(self.adapter_drop_down_menu(window, cx))) - .child( - Button::new("debugger-spawn", "Start") - .on_click(cx.listener(|this, _, window, cx| { - this.start_new_session(window, cx) - })) - .disabled(disabled), - ), + .child(div().child(self.adapter_drop_down_menu(window, cx))), ) }), NewProcessMode::Debug => el, diff --git a/crates/debugger_ui/src/persistence.rs b/crates/debugger_ui/src/persistence.rs index d15244c349..3a0ad7a40e 100644 --- a/crates/debugger_ui/src/persistence.rs +++ b/crates/debugger_ui/src/persistence.rs @@ -11,7 +11,7 @@ use workspace::{Member, Pane, PaneAxis, Workspace}; use crate::session::running::{ self, DebugTerminal, RunningState, SubView, breakpoint_list::BreakpointList, console::Console, - loaded_source_list::LoadedSourceList, module_list::ModuleList, + loaded_source_list::LoadedSourceList, memory_view::MemoryView, module_list::ModuleList, stack_frame_list::StackFrameList, variable_list::VariableList, }; @@ -24,6 +24,7 @@ pub(crate) enum DebuggerPaneItem { Modules, LoadedSources, Terminal, + MemoryView, } impl DebuggerPaneItem { @@ -36,6 +37,7 @@ impl DebuggerPaneItem { DebuggerPaneItem::Modules, DebuggerPaneItem::LoadedSources, DebuggerPaneItem::Terminal, + DebuggerPaneItem::MemoryView, ]; VARIANTS } @@ -43,6 +45,9 @@ impl DebuggerPaneItem { pub(crate) fn is_supported(&self, capabilities: &Capabilities) -> bool { match self { DebuggerPaneItem::Modules => capabilities.supports_modules_request.unwrap_or_default(), + DebuggerPaneItem::MemoryView => capabilities + .supports_read_memory_request + .unwrap_or_default(), DebuggerPaneItem::LoadedSources => capabilities .supports_loaded_sources_request .unwrap_or_default(), @@ -59,6 +64,7 @@ impl DebuggerPaneItem { DebuggerPaneItem::Modules => SharedString::new_static("Modules"), DebuggerPaneItem::LoadedSources => SharedString::new_static("Sources"), DebuggerPaneItem::Terminal => SharedString::new_static("Terminal"), + DebuggerPaneItem::MemoryView => SharedString::new_static("Memory View"), } } pub(crate) fn tab_tooltip(self) -> SharedString { @@ -80,6 +86,7 @@ impl DebuggerPaneItem { DebuggerPaneItem::Terminal => { "Provides an interactive terminal session within the debugging environment." } + DebuggerPaneItem::MemoryView => "Allows inspection of memory contents.", }; SharedString::new_static(tooltip) } @@ -204,6 +211,7 @@ pub(crate) fn deserialize_pane_layout( breakpoint_list: &Entity, loaded_sources: &Entity, terminal: &Entity, + memory_view: &Entity, subscriptions: &mut HashMap, window: &mut Window, cx: &mut Context, @@ -228,6 +236,7 @@ pub(crate) fn deserialize_pane_layout( breakpoint_list, loaded_sources, terminal, + memory_view, subscriptions, window, cx, @@ -298,6 +307,12 @@ pub(crate) fn deserialize_pane_layout( DebuggerPaneItem::Terminal, cx, )), + DebuggerPaneItem::MemoryView => Box::new(SubView::new( + memory_view.focus_handle(cx), + memory_view.clone().into(), + DebuggerPaneItem::MemoryView, + cx, + )), }) .collect(); diff --git a/crates/debugger_ui/src/session.rs b/crates/debugger_ui/src/session.rs index 482297b136..73cfef78cc 100644 --- a/crates/debugger_ui/src/session.rs +++ b/crates/debugger_ui/src/session.rs @@ -5,14 +5,13 @@ use dap::client::SessionId; use gpui::{ App, Axis, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, }; -use project::Project; use project::debugger::session::Session; use project::worktree_store::WorktreeStore; +use project::{Project, debugger::session::SessionQuirks}; use rpc::proto; use running::RunningState; -use std::{cell::OnceCell, sync::OnceLock}; -use ui::{Indicator, Tooltip, prelude::*}; -use util::truncate_and_trailoff; +use std::cell::OnceCell; +use ui::prelude::*; use workspace::{ CollaboratorId, FollowableItem, ViewId, Workspace, item::{self, Item}, @@ -20,8 +19,8 @@ use workspace::{ pub struct DebugSession { remote_id: Option, - running_state: Entity, - label: OnceLock, + pub(crate) running_state: Entity, + pub(crate) quirks: SessionQuirks, stack_trace_view: OnceCell>, _worktree_store: WeakEntity, workspace: WeakEntity, @@ -57,6 +56,7 @@ impl DebugSession { cx, ) }); + let quirks = session.read(cx).quirks(); cx.new(|cx| Self { _subscriptions: [cx.subscribe(&running_state, |_, _, _, cx| { @@ -64,7 +64,7 @@ impl DebugSession { })], remote_id: None, running_state, - label: OnceLock::new(), + quirks, stack_trace_view: OnceCell::new(), _worktree_store: project.read(cx).worktree_store().downgrade(), workspace, @@ -110,65 +110,28 @@ impl DebugSession { .update(cx, |state, cx| state.shutdown(cx)); } - pub(crate) fn label(&self, cx: &App) -> SharedString { - if let Some(label) = self.label.get() { - return label.clone(); - } - - let session = self.running_state.read(cx).session(); - - self.label - .get_or_init(|| session.read(cx).label()) - .to_owned() - } - - pub(crate) fn running_state(&self) -> &Entity { - &self.running_state - } - - pub(crate) fn label_element(&self, depth: usize, cx: &App) -> AnyElement { - const MAX_LABEL_CHARS: usize = 150; - - let label = self.label(cx); - let label = truncate_and_trailoff(&label, MAX_LABEL_CHARS); - - let is_terminated = self - .running_state - .read(cx) - .session() - .read(cx) - .is_terminated(); - let icon = { - if is_terminated { - Some(Indicator::dot().color(Color::Error)) - } else { - match self - .running_state - .read(cx) - .thread_status(cx) - .unwrap_or_default() - { - project::debugger::session::ThreadStatus::Stopped => { - Some(Indicator::dot().color(Color::Conflict)) - } - _ => Some(Indicator::dot().color(Color::Success)), + pub(crate) fn label(&self, cx: &mut App) -> Option { + let session = self.running_state.read(cx).session().clone(); + session.update(cx, |session, cx| { + let session_label = session.label(); + let quirks = session.quirks(); + let mut single_thread_name = || { + let threads = session.threads(cx); + match threads.as_slice() { + [(thread, _)] => Some(SharedString::from(&thread.name)), + _ => None, } + }; + if quirks.prefer_thread_name { + single_thread_name().or(session_label) + } else { + session_label.or_else(single_thread_name) } - }; + }) + } - h_flex() - .id("session-label") - .tooltip(Tooltip::text(format!("Session {}", self.session_id(cx).0,))) - .ml(depth * px(16.0)) - .gap_2() - .when_some(icon, |this, indicator| this.child(indicator)) - .justify_between() - .child( - Label::new(label) - .size(LabelSize::Small) - .when(is_terminated, |this| this.strikethrough()), - ) - .into_any_element() + pub fn running_state(&self) -> &Entity { + &self.running_state } } diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index af8c14aef7..2651a94520 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1,16 +1,17 @@ pub(crate) mod breakpoint_list; pub(crate) mod console; pub(crate) mod loaded_source_list; +pub(crate) mod memory_view; pub(crate) mod module_list; pub mod stack_frame_list; pub mod variable_list; - use std::{any::Any, ops::ControlFlow, path::PathBuf, sync::Arc, time::Duration}; use crate::{ ToggleExpandItem, new_process_modal::resolve_path, persistence::{self, DebuggerPaneItem, SerializedLayout}, + session::running::memory_view::MemoryView, }; use super::DebugPanelItemEvent; @@ -34,7 +35,7 @@ use loaded_source_list::LoadedSourceList; use module_list::ModuleList; use project::{ DebugScenarioContext, Project, WorktreeId, - debugger::session::{Session, SessionEvent, ThreadId, ThreadStatus}, + debugger::session::{self, Session, SessionEvent, SessionStateEvent, ThreadId, ThreadStatus}, terminals::TerminalKind, }; use rpc::proto::ViewId; @@ -81,6 +82,7 @@ pub struct RunningState { _schedule_serialize: Option>, pub(crate) scenario: Option, pub(crate) scenario_context: Option, + memory_view: Entity, } impl RunningState { @@ -676,14 +678,36 @@ impl RunningState { let session_id = session.read(cx).session_id(); let weak_state = cx.weak_entity(); let stack_frame_list = cx.new(|cx| { - StackFrameList::new(workspace.clone(), session.clone(), weak_state, window, cx) + StackFrameList::new( + workspace.clone(), + session.clone(), + weak_state.clone(), + window, + cx, + ) }); let debug_terminal = parent_terminal.unwrap_or_else(|| cx.new(|cx| DebugTerminal::empty(window, cx))); - - let variable_list = - cx.new(|cx| VariableList::new(session.clone(), stack_frame_list.clone(), window, cx)); + let memory_view = cx.new(|cx| { + MemoryView::new( + session.clone(), + workspace.clone(), + stack_frame_list.downgrade(), + window, + cx, + ) + }); + let variable_list = cx.new(|cx| { + VariableList::new( + session.clone(), + stack_frame_list.clone(), + memory_view.clone(), + weak_state.clone(), + window, + cx, + ) + }); let module_list = cx.new(|cx| ModuleList::new(session.clone(), workspace.clone(), cx)); @@ -770,6 +794,15 @@ impl RunningState { cx.on_focus_out(&focus_handle, window, |this, _, window, cx| { this.serialize_layout(window, cx); }), + cx.subscribe( + &session, + |this, session, event: &SessionStateEvent, cx| match event { + SessionStateEvent::Shutdown if session.read(cx).is_building() => { + this.shutdown(cx); + } + _ => {} + }, + ), ]; let mut pane_close_subscriptions = HashMap::default(); @@ -786,6 +819,7 @@ impl RunningState { &breakpoint_list, &loaded_source_list, &debug_terminal, + &memory_view, &mut pane_close_subscriptions, window, cx, @@ -814,6 +848,7 @@ impl RunningState { let active_pane = panes.first_pane(); Self { + memory_view, session, workspace, focus_handle, @@ -884,6 +919,7 @@ impl RunningState { let weak_project = project.downgrade(); let weak_workspace = workspace.downgrade(); let is_local = project.read(cx).is_local(); + cx.spawn_in(window, async move |this, cx| { let DebugScenario { adapter, @@ -1224,6 +1260,12 @@ impl RunningState { item_kind, cx, )), + DebuggerPaneItem::MemoryView => Box::new(SubView::new( + self.memory_view.focus_handle(cx), + self.memory_view.clone().into(), + item_kind, + cx, + )), } } @@ -1408,7 +1450,14 @@ impl RunningState { &self.module_list } - pub(crate) fn activate_item(&self, item: DebuggerPaneItem, window: &mut Window, cx: &mut App) { + pub(crate) fn activate_item( + &mut self, + item: DebuggerPaneItem, + window: &mut Window, + cx: &mut Context, + ) { + self.ensure_pane_item(item, window, cx); + let (variable_list_position, pane) = self .panes .panes() @@ -1420,9 +1469,10 @@ impl RunningState { .map(|view| (view, pane)) }) .unwrap(); + pane.update(cx, |this, cx| { this.activate_item(variable_list_position, true, true, window, cx); - }) + }); } #[cfg(test)] @@ -1459,7 +1509,7 @@ impl RunningState { } } - pub(crate) fn selected_thread_id(&self) -> Option { + pub fn selected_thread_id(&self) -> Option { self.thread_id } @@ -1599,9 +1649,21 @@ impl RunningState { }) .log_err(); - self.session.update(cx, |session, cx| { + let is_building = self.session.update(cx, |session, cx| { session.shutdown(cx).detach(); - }) + matches!(session.mode, session::SessionState::Building(_)) + }); + + if is_building { + self.debug_terminal.update(cx, |terminal, cx| { + if let Some(view) = terminal.terminal.as_ref() { + view.update(cx, |view, cx| { + view.terminal() + .update(cx, |terminal, _| terminal.kill_active_task()) + }) + } + }) + } } pub fn stop_thread(&self, cx: &mut Context) { diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index 78c87db2e6..6ac4b1c878 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -24,10 +24,10 @@ use project::{ }; use ui::{ ActiveTheme, AnyElement, App, ButtonCommon, Clickable, Color, Context, Disableable, Div, - Divider, FluentBuilder as _, Icon, IconButton, IconName, IconSize, Indicator, - InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ListItem, ParentElement, - Render, RenderOnce, Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, - Styled, Toggleable, Tooltip, Window, div, h_flex, px, v_flex, + Divider, FluentBuilder as _, Icon, IconButton, IconName, IconSize, InteractiveElement, + IntoElement, Label, LabelCommon, LabelSize, ListItem, ParentElement, Render, RenderOnce, + Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable, + Tooltip, Window, div, h_flex, px, v_flex, }; use util::ResultExt; use workspace::Workspace; @@ -46,6 +46,7 @@ actions!( pub(crate) enum SelectedBreakpointKind { Source, Exception, + Data, } pub(crate) struct BreakpointList { workspace: WeakEntity, @@ -188,6 +189,9 @@ impl BreakpointList { BreakpointEntryKind::ExceptionBreakpoint(bp) => { (SelectedBreakpointKind::Exception, bp.is_enabled) } + BreakpointEntryKind::DataBreakpoint(bp) => { + (SelectedBreakpointKind::Data, bp.0.is_enabled) + } }) }) } @@ -391,7 +395,8 @@ impl BreakpointList { let row = line_breakpoint.breakpoint.row; self.go_to_line_breakpoint(path, row, window, cx); } - BreakpointEntryKind::ExceptionBreakpoint(_) => {} + BreakpointEntryKind::DataBreakpoint(_) + | BreakpointEntryKind::ExceptionBreakpoint(_) => {} } } @@ -421,6 +426,10 @@ impl BreakpointList { let id = exception_breakpoint.id.clone(); self.toggle_exception_breakpoint(&id, cx); } + BreakpointEntryKind::DataBreakpoint(data_breakpoint) => { + let id = data_breakpoint.0.dap.data_id.clone(); + self.toggle_data_breakpoint(&id, cx); + } } cx.notify(); } @@ -441,7 +450,7 @@ impl BreakpointList { let row = line_breakpoint.breakpoint.row; self.edit_line_breakpoint(path, row, BreakpointEditAction::Toggle, cx); } - BreakpointEntryKind::ExceptionBreakpoint(_) => {} + _ => {} } cx.notify(); } @@ -490,6 +499,14 @@ impl BreakpointList { cx.notify(); } + fn toggle_data_breakpoint(&mut self, id: &str, cx: &mut Context) { + if let Some(session) = &self.session { + session.update(cx, |this, cx| { + this.toggle_data_breakpoint(&id, cx); + }); + } + } + fn toggle_exception_breakpoint(&mut self, id: &str, cx: &mut Context) { if let Some(session) = &self.session { session.update(cx, |this, cx| { @@ -642,6 +659,7 @@ impl BreakpointList { SelectedBreakpointKind::Exception => { "Exception Breakpoints cannot be removed from the breakpoint list" } + SelectedBreakpointKind::Data => "Remove data breakpoint from a breakpoint list", }); let toggle_label = selection_kind.map(|(_, is_enabled)| { if is_enabled { @@ -783,8 +801,20 @@ impl Render for BreakpointList { weak: weak.clone(), }) }); - self.breakpoints - .extend(breakpoints.chain(exception_breakpoints)); + let data_breakpoints = self.session.as_ref().into_iter().flat_map(|session| { + session + .read(cx) + .data_breakpoints() + .map(|state| BreakpointEntry { + kind: BreakpointEntryKind::DataBreakpoint(DataBreakpoint(state.clone())), + weak: weak.clone(), + }) + }); + self.breakpoints.extend( + breakpoints + .chain(data_breakpoints) + .chain(exception_breakpoints), + ); v_flex() .id("breakpoint-list") .key_context("BreakpointList") @@ -905,7 +935,11 @@ impl LineBreakpoint { .ok(); } }) - .child(Indicator::icon(Icon::new(icon_name)).color(Color::Debugger)) + .child( + Icon::new(icon_name) + .color(Color::Debugger) + .size(IconSize::XSmall), + ) .on_mouse_down(MouseButton::Left, move |_, _, _| {}); ListItem::new(SharedString::from(format!( @@ -996,6 +1030,103 @@ struct ExceptionBreakpoint { data: ExceptionBreakpointsFilter, is_enabled: bool, } +#[derive(Clone, Debug)] +struct DataBreakpoint(project::debugger::session::DataBreakpointState); + +impl DataBreakpoint { + fn render( + &self, + props: SupportedBreakpointProperties, + strip_mode: Option, + ix: usize, + is_selected: bool, + focus_handle: FocusHandle, + list: WeakEntity, + ) -> ListItem { + let color = if self.0.is_enabled { + Color::Debugger + } else { + Color::Muted + }; + let is_enabled = self.0.is_enabled; + let id = self.0.dap.data_id.clone(); + ListItem::new(SharedString::from(format!( + "data-breakpoint-ui-item-{}", + self.0.dap.data_id + ))) + .rounded() + .start_slot( + div() + .id(SharedString::from(format!( + "data-breakpoint-ui-item-{}-click-handler", + self.0.dap.data_id + ))) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + if is_enabled { + "Disable Data Breakpoint" + } else { + "Enable Data Breakpoint" + }, + &ToggleEnableBreakpoint, + &focus_handle, + window, + cx, + ) + } + }) + .on_click({ + let list = list.clone(); + move |_, _, cx| { + list.update(cx, |this, cx| { + this.toggle_data_breakpoint(&id, cx); + }) + .ok(); + } + }) + .cursor_pointer() + .child( + Icon::new(IconName::Binary) + .color(color) + .size(IconSize::Small), + ), + ) + .child( + h_flex() + .w_full() + .mr_4() + .py_0p5() + .justify_between() + .child( + v_flex() + .py_1() + .gap_1() + .min_h(px(26.)) + .justify_center() + .id(("data-breakpoint-label", ix)) + .child( + Label::new(self.0.context.human_readable_label()) + .size(LabelSize::Small) + .line_height_style(ui::LineHeightStyle::UiLabel), + ), + ) + .child(BreakpointOptionsStrip { + props, + breakpoint: BreakpointEntry { + kind: BreakpointEntryKind::DataBreakpoint(self.clone()), + weak: list, + }, + is_selected, + focus_handle, + strip_mode, + index: ix, + }), + ) + .toggle_state(is_selected) + } +} impl ExceptionBreakpoint { fn render( @@ -1062,7 +1193,11 @@ impl ExceptionBreakpoint { } }) .cursor_pointer() - .child(Indicator::icon(Icon::new(IconName::Flame)).color(color)), + .child( + Icon::new(IconName::Flame) + .color(color) + .size(IconSize::Small), + ), ) .child( h_flex() @@ -1105,6 +1240,7 @@ impl ExceptionBreakpoint { enum BreakpointEntryKind { LineBreakpoint(LineBreakpoint), ExceptionBreakpoint(ExceptionBreakpoint), + DataBreakpoint(DataBreakpoint), } #[derive(Clone, Debug)] @@ -1140,6 +1276,14 @@ impl BreakpointEntry { focus_handle, self.weak.clone(), ), + BreakpointEntryKind::DataBreakpoint(data_breakpoint) => data_breakpoint.render( + props.for_data_breakpoints(), + strip_mode, + ix, + is_selected, + focus_handle, + self.weak.clone(), + ), } } @@ -1155,6 +1299,11 @@ impl BreakpointEntry { exception_breakpoint.id ) .into(), + BreakpointEntryKind::DataBreakpoint(data_breakpoint) => format!( + "data-breakpoint-control-strip--{}", + data_breakpoint.0.dap.data_id + ) + .into(), } } @@ -1172,8 +1321,8 @@ impl BreakpointEntry { BreakpointEntryKind::LineBreakpoint(line_breakpoint) => { line_breakpoint.breakpoint.condition.is_some() } - // We don't support conditions on exception breakpoints - BreakpointEntryKind::ExceptionBreakpoint(_) => false, + // We don't support conditions on exception/data breakpoints + _ => false, } } @@ -1225,6 +1374,10 @@ impl SupportedBreakpointProperties { // TODO: we don't yet support conditions for exception breakpoints at the data layer, hence all props are disabled here. Self::empty() } + fn for_data_breakpoints(self) -> Self { + // TODO: we don't yet support conditions for data breakpoints at the data layer, hence all props are disabled here. + Self::empty() + } } #[derive(IntoElement)] struct BreakpointOptionsStrip { diff --git a/crates/debugger_ui/src/session/running/console.rs b/crates/debugger_ui/src/session/running/console.rs index 9375c8820b..1385bec54e 100644 --- a/crates/debugger_ui/src/session/running/console.rs +++ b/crates/debugger_ui/src/session/running/console.rs @@ -12,7 +12,7 @@ use gpui::{ Action as _, AppContext, Context, Corner, Entity, FocusHandle, Focusable, HighlightStyle, Hsla, Render, Subscription, Task, TextStyle, WeakEntity, actions, }; -use language::{Buffer, CodeLabel, ToOffset}; +use language::{Anchor, Buffer, CodeLabel, TextBufferSnapshot, ToOffset}; use menu::{Confirm, SelectNext, SelectPrevious}; use project::{ Completion, CompletionResponse, @@ -637,27 +637,13 @@ impl ConsoleQueryBarCompletionProvider { }); let snapshot = buffer.read(cx).text_snapshot(); - let query = snapshot.text(); - let replace_range = { - let buffer_offset = buffer_position.to_offset(&snapshot); - let reversed_chars = snapshot.reversed_chars_for_range(0..buffer_offset); - let mut word_len = 0; - for ch in reversed_chars { - if ch.is_alphanumeric() || ch == '_' { - word_len += 1; - } else { - break; - } - } - let word_start_offset = buffer_offset - word_len; - let start_anchor = snapshot.anchor_at(word_start_offset, Bias::Left); - start_anchor..buffer_position - }; + let buffer_text = snapshot.text(); + cx.spawn(async move |_, cx| { const LIMIT: usize = 10; let matches = fuzzy::match_strings( &string_matches, - &query, + &buffer_text, true, true, LIMIT, @@ -672,7 +658,12 @@ impl ConsoleQueryBarCompletionProvider { let variable_value = variables.get(&string_match.string)?; Some(project::Completion { - replace_range: replace_range.clone(), + replace_range: Self::replace_range_for_completion( + &buffer_text, + buffer_position, + string_match.string.as_bytes(), + &snapshot, + ), new_text: string_match.string.clone(), label: CodeLabel { filter_range: 0..string_match.string.len(), @@ -697,6 +688,28 @@ impl ConsoleQueryBarCompletionProvider { }) } + fn replace_range_for_completion( + buffer_text: &String, + buffer_position: Anchor, + new_bytes: &[u8], + snapshot: &TextBufferSnapshot, + ) -> Range { + let buffer_offset = buffer_position.to_offset(&snapshot); + let buffer_bytes = &buffer_text.as_bytes()[0..buffer_offset]; + + let mut prefix_len = 0; + for i in (0..new_bytes.len()).rev() { + if buffer_bytes.ends_with(&new_bytes[0..i]) { + prefix_len = i; + break; + } + } + + let start = snapshot.clip_offset(buffer_offset - prefix_len, Bias::Left); + + snapshot.anchor_before(start)..buffer_position + } + const fn completion_type_score(completion_type: CompletionItemType) -> usize { match completion_type { CompletionItemType::Field | CompletionItemType::Property => 0, @@ -744,6 +757,8 @@ impl ConsoleQueryBarCompletionProvider { cx.background_executor().spawn(async move { let completions = completion_task.await?; + let buffer_text = snapshot.text(); + let completions = completions .into_iter() .map(|completion| { @@ -753,26 +768,14 @@ impl ConsoleQueryBarCompletionProvider { .as_ref() .unwrap_or(&completion.label) .to_owned(); - let buffer_text = snapshot.text(); - let buffer_bytes = buffer_text.as_bytes(); - let new_bytes = new_text.as_bytes(); - - let mut prefix_len = 0; - for i in (0..new_bytes.len()).rev() { - if buffer_bytes.ends_with(&new_bytes[0..i]) { - prefix_len = i; - break; - } - } - - let buffer_offset = buffer_position.to_offset(&snapshot); - let start = buffer_offset - prefix_len; - let start = snapshot.clip_offset(start, Bias::Left); - let start = snapshot.anchor_before(start); - let replace_range = start..buffer_position; project::Completion { - replace_range, + replace_range: Self::replace_range_for_completion( + &buffer_text, + buffer_position, + new_text.as_bytes(), + &snapshot, + ), new_text, label: CodeLabel { filter_range: 0..completion.label.len(), @@ -944,3 +947,64 @@ fn color_fetcher(color: ansi::Color) -> fn(&Theme) -> Hsla { }; color_fetcher } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::init_test; + use editor::test::editor_test_context::EditorTestContext; + use gpui::TestAppContext; + use language::Point; + + #[track_caller] + fn assert_completion_range( + input: &str, + expect: &str, + replacement: &str, + cx: &mut EditorTestContext, + ) { + cx.set_state(input); + + let buffer_position = + cx.editor(|editor, _, cx| editor.selections.newest::(cx).start); + + let snapshot = &cx.buffer_snapshot(); + + let replace_range = ConsoleQueryBarCompletionProvider::replace_range_for_completion( + &cx.buffer_text(), + snapshot.anchor_before(buffer_position), + replacement.as_bytes(), + &snapshot, + ); + + cx.update_editor(|editor, _, cx| { + editor.edit( + vec![( + snapshot.offset_for_anchor(&replace_range.start) + ..snapshot.offset_for_anchor(&replace_range.end), + replacement, + )], + cx, + ); + }); + + pretty_assertions::assert_eq!(expect, cx.display_text()); + } + + #[gpui::test] + async fn test_determine_completion_replace_range(cx: &mut TestAppContext) { + init_test(cx); + + let mut cx = EditorTestContext::new(cx).await; + + assert_completion_range("resΛ‡", "result", "result", &mut cx); + assert_completion_range("print(resΛ‡)", "print(result)", "result", &mut cx); + assert_completion_range("$author->nΛ‡", "$author->name", "$author->name", &mut cx); + assert_completion_range( + "$author->books[Λ‡", + "$author->books[0]", + "$author->books[0]", + &mut cx, + ); + } +} diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs new file mode 100644 index 0000000000..7b62a1d55d --- /dev/null +++ b/crates/debugger_ui/src/session/running/memory_view.rs @@ -0,0 +1,984 @@ +use std::{ + cell::LazyCell, + fmt::Write, + ops::RangeInclusive, + sync::{Arc, LazyLock}, + time::Duration, +}; + +use editor::{Editor, EditorElement, EditorStyle}; +use gpui::{ + Action, AppContext, DismissEvent, DragMoveEvent, Empty, Entity, FocusHandle, Focusable, + MouseButton, Point, ScrollStrategy, ScrollWheelEvent, Stateful, Subscription, Task, TextStyle, + UniformList, UniformListScrollHandle, WeakEntity, actions, anchored, deferred, point, + uniform_list, +}; +use notifications::status_toast::{StatusToast, ToastIcon}; +use project::debugger::{MemoryCell, dap_command::DataBreakpointContext, session::Session}; +use settings::Settings; +use theme::ThemeSettings; +use ui::{ + ActiveTheme, AnyElement, App, Color, Context, ContextMenu, Div, Divider, DropdownMenu, Element, + FluentBuilder, Icon, IconName, InteractiveElement, IntoElement, Label, LabelCommon, + ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString, + StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex, +}; +use util::ResultExt; +use workspace::Workspace; + +use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList}; + +actions!(debugger, [GoToSelectedAddress]); + +pub(crate) struct MemoryView { + workspace: WeakEntity, + scroll_handle: UniformListScrollHandle, + scroll_state: ScrollbarState, + show_scrollbar: bool, + stack_frame_list: WeakEntity, + hide_scrollbar_task: Option>, + focus_handle: FocusHandle, + view_state: ViewState, + query_editor: Entity, + session: Entity, + width_picker_handle: PopoverMenuHandle, + is_writing_memory: bool, + open_context_menu: Option<(Entity, Point, Subscription)>, +} + +impl Focusable for MemoryView { + fn focus_handle(&self, _: &ui::App) -> FocusHandle { + self.focus_handle.clone() + } +} +#[derive(Clone, Debug)] +struct Drag { + start_address: u64, + end_address: u64, +} + +impl Drag { + fn contains(&self, address: u64) -> bool { + let range = self.memory_range(); + range.contains(&address) + } + + fn memory_range(&self) -> RangeInclusive { + if self.start_address < self.end_address { + self.start_address..=self.end_address + } else { + self.end_address..=self.start_address + } + } +} +#[derive(Clone, Debug)] +enum SelectedMemoryRange { + DragUnderway(Drag), + DragComplete(Drag), +} + +impl SelectedMemoryRange { + fn contains(&self, address: u64) -> bool { + match self { + SelectedMemoryRange::DragUnderway(drag) => drag.contains(address), + SelectedMemoryRange::DragComplete(drag) => drag.contains(address), + } + } + fn is_dragging(&self) -> bool { + matches!(self, SelectedMemoryRange::DragUnderway(_)) + } + fn drag(&self) -> &Drag { + match self { + SelectedMemoryRange::DragUnderway(drag) => drag, + SelectedMemoryRange::DragComplete(drag) => drag, + } + } +} + +#[derive(Clone)] +struct ViewState { + /// Uppermost row index + base_row: u64, + /// How many cells per row do we have? + line_width: ViewWidth, + selection: Option, +} + +impl ViewState { + fn new(base_row: u64, line_width: ViewWidth) -> Self { + Self { + base_row, + line_width, + selection: None, + } + } + fn row_count(&self) -> u64 { + // This was picked fully arbitrarily. There's no incentive for us to care about page sizes other than the fact that it seems to be a good + // middle ground for data size. + const PAGE_SIZE: u64 = 4096; + PAGE_SIZE / self.line_width.width as u64 + } + fn schedule_scroll_down(&mut self) { + self.base_row = self.base_row.saturating_add(1) + } + fn schedule_scroll_up(&mut self) { + self.base_row = self.base_row.saturating_sub(1); + } +} + +struct ScrollbarDragging; + +static HEX_BYTES_MEMOIZED: LazyLock<[SharedString; 256]> = + LazyLock::new(|| std::array::from_fn(|byte| SharedString::from(format!("{byte:02X}")))); +static UNKNOWN_BYTE: SharedString = SharedString::new_static("??"); +impl MemoryView { + pub(crate) fn new( + session: Entity, + workspace: WeakEntity, + stack_frame_list: WeakEntity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let view_state = ViewState::new(0, WIDTHS[4].clone()); + let scroll_handle = UniformListScrollHandle::default(); + + let query_editor = cx.new(|cx| Editor::single_line(window, cx)); + + let scroll_state = ScrollbarState::new(scroll_handle.clone()); + let mut this = Self { + workspace, + scroll_state, + scroll_handle, + stack_frame_list, + show_scrollbar: false, + hide_scrollbar_task: None, + focus_handle: cx.focus_handle(), + view_state, + query_editor, + session, + width_picker_handle: Default::default(), + is_writing_memory: true, + open_context_menu: None, + }; + this.change_query_bar_mode(false, window, cx); + cx.on_focus_out(&this.focus_handle, window, |this, _, window, cx| { + this.change_query_bar_mode(false, window, cx); + cx.notify(); + }) + .detach(); + this + } + fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context) { + const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); + self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| { + cx.background_executor() + .timer(SCROLLBAR_SHOW_INTERVAL) + .await; + panel + .update(cx, |panel, cx| { + panel.show_scrollbar = false; + cx.notify(); + }) + .log_err(); + })) + } + + fn render_vertical_scrollbar(&self, cx: &mut Context) -> Option> { + if !(self.show_scrollbar || self.scroll_state.is_dragging()) { + return None; + } + Some( + div() + .occlude() + .id("memory-view-vertical-scrollbar") + .on_drag_move(cx.listener(|this, evt, _, cx| { + let did_handle = this.handle_scroll_drag(evt); + cx.notify(); + if did_handle { + cx.stop_propagation() + } + })) + .on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty)) + .on_hover(|_, _, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _, cx| { + cx.stop_propagation(); + }) + .on_mouse_up( + MouseButton::Left, + cx.listener(|_, _, _, cx| { + cx.stop_propagation(); + }), + ) + .on_scroll_wheel(cx.listener(|_, _, _, cx| { + cx.notify(); + })) + .h_full() + .absolute() + .right_1() + .top_1() + .bottom_0() + .w(px(12.)) + .cursor_default() + .children(Scrollbar::vertical(self.scroll_state.clone())), + ) + } + + fn render_memory(&self, cx: &mut Context) -> UniformList { + let weak = cx.weak_entity(); + let session = self.session.clone(); + let view_state = self.view_state.clone(); + uniform_list( + "debugger-memory-view", + self.view_state.row_count() as usize, + move |range, _, cx| { + let mut line_buffer = Vec::with_capacity(view_state.line_width.width as usize); + let memory_start = + (view_state.base_row + range.start as u64) * view_state.line_width.width as u64; + let memory_end = (view_state.base_row + range.end as u64) + * view_state.line_width.width as u64 + - 1; + let mut memory = session.update(cx, |this, cx| { + this.read_memory(memory_start..=memory_end, cx) + }); + let mut rows = Vec::with_capacity(range.end - range.start); + for ix in range { + line_buffer.extend((&mut memory).take(view_state.line_width.width as usize)); + rows.push(render_single_memory_view_line( + &line_buffer, + ix as u64, + weak.clone(), + cx, + )); + line_buffer.clear(); + } + rows + }, + ) + .track_scroll(self.scroll_handle.clone()) + .on_scroll_wheel(cx.listener(|this, evt: &ScrollWheelEvent, window, _| { + let delta = evt.delta.pixel_delta(window.line_height()); + let scroll_handle = this.scroll_state.scroll_handle(); + let size = scroll_handle.content_size(); + let viewport = scroll_handle.viewport(); + let current_offset = scroll_handle.offset(); + let first_entry_offset_boundary = size.height / this.view_state.row_count() as f32; + let last_entry_offset_boundary = size.height - first_entry_offset_boundary; + if first_entry_offset_boundary + viewport.size.height > current_offset.y.abs() { + // The topmost entry is visible, hence if we're scrolling up, we need to load extra lines. + this.view_state.schedule_scroll_up(); + } else if last_entry_offset_boundary < current_offset.y.abs() + viewport.size.height { + this.view_state.schedule_scroll_down(); + } + scroll_handle.set_offset(current_offset + point(px(0.), delta.y)); + })) + } + fn render_query_bar(&self, cx: &Context) -> impl IntoElement { + EditorElement::new( + &self.query_editor, + Self::editor_style(&self.query_editor, cx), + ) + } + pub(super) fn go_to_memory_reference( + &mut self, + memory_reference: &str, + evaluate_name: Option<&str>, + stack_frame_id: Option, + cx: &mut Context, + ) { + use parse_int::parse; + let Ok(as_address) = parse::(&memory_reference) else { + return; + }; + let access_size = evaluate_name + .map(|typ| { + self.session.update(cx, |this, cx| { + this.data_access_size(stack_frame_id, typ, cx) + }) + }) + .unwrap_or_else(|| Task::ready(None)); + cx.spawn(async move |this, cx| { + let access_size = access_size.await.unwrap_or(1); + this.update(cx, |this, cx| { + this.view_state.selection = Some(SelectedMemoryRange::DragComplete(Drag { + start_address: as_address, + end_address: as_address + access_size - 1, + })); + this.jump_to_address(as_address, cx); + }) + .ok(); + }) + .detach(); + } + + fn handle_memory_drag(&mut self, evt: &DragMoveEvent) { + if !self + .view_state + .selection + .as_ref() + .is_some_and(|selection| selection.is_dragging()) + { + return; + } + let row_count = self.view_state.row_count(); + debug_assert!(row_count > 1); + let scroll_handle = self.scroll_state.scroll_handle(); + let viewport = scroll_handle.viewport(); + + if viewport.bottom() < evt.event.position.y { + self.view_state.schedule_scroll_down(); + } else if viewport.top() > evt.event.position.y { + self.view_state.schedule_scroll_up(); + } + } + + fn handle_scroll_drag(&mut self, evt: &DragMoveEvent) -> bool { + if !self.scroll_state.is_dragging() { + return false; + } + let row_count = self.view_state.row_count(); + debug_assert!(row_count > 1); + let scroll_handle = self.scroll_state.scroll_handle(); + let viewport = scroll_handle.viewport(); + + if viewport.bottom() < evt.event.position.y { + self.view_state.schedule_scroll_down(); + true + } else if viewport.top() > evt.event.position.y { + self.view_state.schedule_scroll_up(); + true + } else { + false + } + } + + fn editor_style(editor: &Entity, cx: &Context) -> EditorStyle { + let is_read_only = editor.read(cx).read_only(cx); + let settings = ThemeSettings::get_global(cx); + let theme = cx.theme(); + let text_style = TextStyle { + color: if is_read_only { + theme.colors().text_muted + } else { + theme.colors().text + }, + font_family: settings.buffer_font.family.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: TextSize::Small.rems(cx).into(), + font_weight: settings.buffer_font.weight, + + ..Default::default() + }; + EditorStyle { + background: theme.colors().editor_background, + local_player: theme.players().local(), + text: text_style, + ..Default::default() + } + } + + fn render_width_picker(&self, window: &mut Window, cx: &mut Context) -> DropdownMenu { + let weak = cx.weak_entity(); + let selected_width = self.view_state.line_width.clone(); + DropdownMenu::new( + "memory-view-width-picker", + selected_width.label.clone(), + ContextMenu::build(window, cx, |mut this, window, cx| { + for width in &WIDTHS { + let weak = weak.clone(); + let width = width.clone(); + this = this.entry(width.label.clone(), None, move |_, cx| { + _ = weak.update(cx, |this, _| { + // Convert base ix between 2 line widths to keep the shown memory address roughly the same. + // All widths are powers of 2, so the conversion should be lossless. + match this.view_state.line_width.width.cmp(&width.width) { + std::cmp::Ordering::Less => { + // We're converting up. + let shift = width.width.trailing_zeros() + - this.view_state.line_width.width.trailing_zeros(); + this.view_state.base_row >>= shift; + } + std::cmp::Ordering::Greater => { + // We're converting down. + let shift = this.view_state.line_width.width.trailing_zeros() + - width.width.trailing_zeros(); + this.view_state.base_row <<= shift; + } + _ => {} + } + this.view_state.line_width = width.clone(); + }); + }); + } + if let Some(ix) = WIDTHS + .iter() + .position(|width| width.width == selected_width.width) + { + for _ in 0..=ix { + this.select_next(&Default::default(), window, cx); + } + } + this + }), + ) + .handle(self.width_picker_handle.clone()) + } + + fn page_down(&mut self, _: &menu::SelectLast, _: &mut Window, cx: &mut Context) { + self.view_state.base_row = self + .view_state + .base_row + .overflowing_add(self.view_state.row_count()) + .0; + cx.notify(); + } + fn page_up(&mut self, _: &menu::SelectFirst, _: &mut Window, cx: &mut Context) { + self.view_state.base_row = self + .view_state + .base_row + .overflowing_sub(self.view_state.row_count()) + .0; + cx.notify(); + } + + fn change_query_bar_mode( + &mut self, + is_writing_memory: bool, + window: &mut Window, + cx: &mut Context, + ) { + if is_writing_memory == self.is_writing_memory { + return; + } + if !self.is_writing_memory { + self.query_editor.update(cx, |this, cx| { + this.clear(window, cx); + this.set_placeholder_text("Write to Selected Memory Range", cx); + }); + self.is_writing_memory = true; + self.query_editor.focus_handle(cx).focus(window); + } else { + self.query_editor.update(cx, |this, cx| { + this.clear(window, cx); + this.set_placeholder_text("Go to Memory Address / Expression", cx); + }); + self.is_writing_memory = false; + } + } + + fn toggle_data_breakpoint( + &mut self, + _: &crate::ToggleDataBreakpoint, + _: &mut Window, + cx: &mut Context, + ) { + let Some(SelectedMemoryRange::DragComplete(selection)) = self.view_state.selection.clone() + else { + return; + }; + let range = selection.memory_range(); + let context = Arc::new(DataBreakpointContext::Address { + address: range.start().to_string(), + bytes: Some(*range.end() - *range.start()), + }); + + self.session.update(cx, |this, cx| { + let data_breakpoint_info = this.data_breakpoint_info(context.clone(), None, cx); + cx.spawn(async move |this, cx| { + if let Some(info) = data_breakpoint_info.await { + let Some(data_id) = info.data_id.clone() else { + return; + }; + _ = this.update(cx, |this, cx| { + this.create_data_breakpoint( + context, + data_id.clone(), + dap::DataBreakpoint { + data_id, + access_type: None, + condition: None, + hit_condition: None, + }, + cx, + ); + }); + } + }) + .detach(); + }) + } + + fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + if let Some(SelectedMemoryRange::DragComplete(drag)) = &self.view_state.selection { + // Go into memory writing mode. + if !self.is_writing_memory { + let should_return = self.session.update(cx, |session, cx| { + if !session + .capabilities() + .supports_write_memory_request + .unwrap_or_default() + { + let adapter_name = session.adapter(); + // We cannot write memory with this adapter. + _ = self.workspace.update(cx, |this, cx| { + this.toggle_status_toast( + StatusToast::new(format!( + "Debug Adapter `{adapter_name}` does not support writing to memory" + ), cx, |this, cx| { + cx.spawn(async move |this, cx| { + cx.background_executor().timer(Duration::from_secs(2)).await; + _ = this.update(cx, |_, cx| { + cx.emit(DismissEvent) + }); + }).detach(); + this.icon(ToastIcon::new(IconName::XCircle).color(Color::Error)) + }), + cx, + ); + }); + true + } else { + false + } + }); + if should_return { + return; + } + + self.change_query_bar_mode(true, window, cx); + } else if self.query_editor.focus_handle(cx).is_focused(window) { + let mut text = self.query_editor.read(cx).text(cx); + if text.chars().any(|c| !c.is_ascii_hexdigit()) { + // Interpret this text as a string and oh-so-conveniently convert it. + text = text.bytes().map(|byte| format!("{:02x}", byte)).collect(); + } + self.session.update(cx, |this, cx| { + let range = drag.memory_range(); + + if let Ok(as_hex) = hex::decode(text) { + this.write_memory(*range.start(), &as_hex, cx); + } + }); + self.change_query_bar_mode(false, window, cx); + } + + cx.notify(); + return; + } + // Just change the currently viewed address. + if !self.query_editor.focus_handle(cx).is_focused(window) { + return; + } + self.jump_to_query_bar_address(cx); + } + + fn jump_to_query_bar_address(&mut self, cx: &mut Context) { + use parse_int::parse; + let text = self.query_editor.read(cx).text(cx); + + let Ok(as_address) = parse::(&text) else { + return self.jump_to_expression(text, cx); + }; + self.jump_to_address(as_address, cx); + } + + fn jump_to_address(&mut self, address: u64, cx: &mut Context) { + self.view_state.base_row = (address & !0xfff) / self.view_state.line_width.width as u64; + let line_ix = (address & 0xfff) / self.view_state.line_width.width as u64; + self.scroll_handle + .scroll_to_item(line_ix as usize, ScrollStrategy::Center); + cx.notify(); + } + + fn jump_to_expression(&mut self, expr: String, cx: &mut Context) { + let Ok(selected_frame) = self + .stack_frame_list + .update(cx, |this, _| this.opened_stack_frame_id()) + else { + return; + }; + let expr = format!("?${{{expr}}}"); + let reference = self.session.update(cx, |this, cx| { + this.memory_reference_of_expr(selected_frame, expr, cx) + }); + cx.spawn(async move |this, cx| { + if let Some((reference, typ)) = reference.await { + _ = this.update(cx, |this, cx| { + let sizeof_expr = if typ.as_ref().is_some_and(|t| { + t.chars() + .all(|c| c.is_whitespace() || c.is_alphabetic() || c == '*') + }) { + typ.as_deref() + } else { + None + }; + this.go_to_memory_reference(&reference, sizeof_expr, selected_frame, cx); + }); + } + }) + .detach(); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + self.view_state.selection = None; + cx.notify(); + } + + /// Jump to memory pointed to by selected memory range. + fn go_to_address( + &mut self, + _: &GoToSelectedAddress, + window: &mut Window, + cx: &mut Context, + ) { + let Some(SelectedMemoryRange::DragComplete(drag)) = self.view_state.selection.clone() + else { + return; + }; + let range = drag.memory_range(); + let Some(memory): Option> = self.session.update(cx, |this, cx| { + this.read_memory(range, cx).map(|cell| cell.0).collect() + }) else { + return; + }; + if memory.len() > 8 { + return; + } + let zeros_to_write = 8 - memory.len(); + let mut acc = String::from("0x"); + acc.extend(std::iter::repeat("00").take(zeros_to_write)); + let as_query = memory.into_iter().rev().fold(acc, |mut acc, byte| { + _ = write!(&mut acc, "{:02x}", byte); + acc + }); + self.query_editor.update(cx, |this, cx| { + this.set_text(as_query, window, cx); + }); + self.jump_to_query_bar_address(cx); + } + + fn deploy_memory_context_menu( + &mut self, + range: RangeInclusive, + position: Point, + window: &mut Window, + cx: &mut Context, + ) { + let session = self.session.clone(); + let context_menu = ContextMenu::build(window, cx, |menu, _, cx| { + let range_too_large = range.end() - range.start() > std::mem::size_of::() as u64; + let caps = session.read(cx).capabilities(); + let supports_data_breakpoints = caps.supports_data_breakpoints.unwrap_or_default() + && caps.supports_data_breakpoint_bytes.unwrap_or_default(); + let memory_unreadable = LazyCell::new(|| { + session.update(cx, |this, cx| { + this.read_memory(range.clone(), cx) + .any(|cell| cell.0.is_none()) + }) + }); + + let mut menu = menu.action_disabled_when( + range_too_large || *memory_unreadable, + "Go To Selected Address", + GoToSelectedAddress.boxed_clone(), + ); + + if supports_data_breakpoints { + menu = menu.action_disabled_when( + *memory_unreadable, + "Set Data Breakpoint", + ToggleDataBreakpoint { access_type: None }.boxed_clone(), + ); + } + menu.context(self.focus_handle.clone()) + }); + + cx.focus_view(&context_menu, window); + let subscription = cx.subscribe_in( + &context_menu, + window, + |this, _, _: &DismissEvent, window, cx| { + if this.open_context_menu.as_ref().is_some_and(|context_menu| { + context_menu.0.focus_handle(cx).contains_focused(window, cx) + }) { + cx.focus_self(window); + } + this.open_context_menu.take(); + cx.notify(); + }, + ); + + self.open_context_menu = Some((context_menu, position, subscription)); + } +} + +#[derive(Clone)] +struct ViewWidth { + width: u8, + label: SharedString, +} + +impl ViewWidth { + const fn new(width: u8, label: &'static str) -> Self { + Self { + width, + label: SharedString::new_static(label), + } + } +} + +static WIDTHS: [ViewWidth; 7] = [ + ViewWidth::new(1, "1 byte"), + ViewWidth::new(2, "2 bytes"), + ViewWidth::new(4, "4 bytes"), + ViewWidth::new(8, "8 bytes"), + ViewWidth::new(16, "16 bytes"), + ViewWidth::new(32, "32 bytes"), + ViewWidth::new(64, "64 bytes"), +]; + +fn render_single_memory_view_line( + memory: &[MemoryCell], + ix: u64, + weak: gpui::WeakEntity, + cx: &mut App, +) -> AnyElement { + let Ok(view_state) = weak.update(cx, |this, _| this.view_state.clone()) else { + return div().into_any(); + }; + let base_address = (view_state.base_row + ix) * view_state.line_width.width as u64; + + h_flex() + .id(( + "memory-view-row-full", + ix * view_state.line_width.width as u64, + )) + .size_full() + .gap_x_2() + .child( + div() + .child( + Label::new(format!("{:016X}", base_address)) + .buffer_font(cx) + .size(ui::LabelSize::Small) + .color(Color::Muted), + ) + .px_1() + .border_r_1() + .border_color(Color::Muted.color(cx)), + ) + .child( + h_flex() + .id(( + "memory-view-row-raw-memory", + ix * view_state.line_width.width as u64, + )) + .px_1() + .children(memory.iter().enumerate().map(|(cell_ix, cell)| { + let weak = weak.clone(); + div() + .id(("memory-view-row-raw-memory-cell", cell_ix as u64)) + .px_0p5() + .when_some(view_state.selection.as_ref(), |this, selection| { + this.when(selection.contains(base_address + cell_ix as u64), |this| { + let weak = weak.clone(); + + this.bg(Color::Selected.color(cx).opacity(0.2)).when( + !selection.is_dragging(), + |this| { + let selection = selection.drag().memory_range(); + this.on_mouse_down( + MouseButton::Right, + move |click, window, cx| { + _ = weak.update(cx, |this, cx| { + this.deploy_memory_context_menu( + selection.clone(), + click.position, + window, + cx, + ) + }); + cx.stop_propagation(); + }, + ) + }, + ) + }) + }) + .child( + Label::new( + cell.0 + .map(|val| HEX_BYTES_MEMOIZED[val as usize].clone()) + .unwrap_or_else(|| UNKNOWN_BYTE.clone()), + ) + .buffer_font(cx) + .when(cell.0.is_none(), |this| this.color(Color::Muted)) + .size(ui::LabelSize::Small), + ) + .on_drag( + Drag { + start_address: base_address + cell_ix as u64, + end_address: base_address + cell_ix as u64, + }, + { + let weak = weak.clone(); + move |drag, _, _, cx| { + _ = weak.update(cx, |this, _| { + this.view_state.selection = + Some(SelectedMemoryRange::DragUnderway(drag.clone())); + }); + + cx.new(|_| Empty) + } + }, + ) + .on_drop({ + let weak = weak.clone(); + move |drag: &Drag, _, cx| { + _ = weak.update(cx, |this, _| { + this.view_state.selection = + Some(SelectedMemoryRange::DragComplete(Drag { + start_address: drag.start_address, + end_address: base_address + cell_ix as u64, + })); + }); + } + }) + .drag_over(move |style, drag: &Drag, _, cx| { + _ = weak.update(cx, |this, _| { + this.view_state.selection = + Some(SelectedMemoryRange::DragUnderway(Drag { + start_address: drag.start_address, + end_address: base_address + cell_ix as u64, + })); + }); + + style + }) + })), + ) + .child( + h_flex() + .id(( + "memory-view-row-ascii-memory", + ix * view_state.line_width.width as u64, + )) + .h_full() + .px_1() + .mr_4() + // .gap_x_1p5() + .border_x_1() + .border_color(Color::Muted.color(cx)) + .children(memory.iter().enumerate().map(|(ix, cell)| { + let as_character = char::from(cell.0.unwrap_or(0)); + let as_visible = if as_character.is_ascii_graphic() { + as_character + } else { + 'Β·' + }; + div() + .px_0p5() + .when_some(view_state.selection.as_ref(), |this, selection| { + this.when(selection.contains(base_address + ix as u64), |this| { + this.bg(Color::Selected.color(cx).opacity(0.2)) + }) + }) + .child( + Label::new(format!("{as_visible}")) + .buffer_font(cx) + .when(cell.0.is_none(), |this| this.color(Color::Muted)) + .size(ui::LabelSize::Small), + ) + })), + ) + .into_any() +} + +impl Render for MemoryView { + fn render( + &mut self, + window: &mut ui::Window, + cx: &mut ui::Context, + ) -> impl ui::IntoElement { + let (icon, tooltip_text) = if self.is_writing_memory { + (IconName::Pencil, "Edit memory at a selected address") + } else { + ( + IconName::LocationEdit, + "Change address of currently viewed memory", + ) + }; + v_flex() + .id("Memory-view") + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::go_to_address)) + .p_1() + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::toggle_data_breakpoint)) + .on_action(cx.listener(Self::page_down)) + .on_action(cx.listener(Self::page_up)) + .size_full() + .track_focus(&self.focus_handle) + .on_hover(cx.listener(|this, hovered, window, cx| { + if *hovered { + this.show_scrollbar = true; + this.hide_scrollbar_task.take(); + cx.notify(); + } else if !this.focus_handle.contains_focused(window, cx) { + this.hide_scrollbar(window, cx); + } + })) + .child( + h_flex() + .w_full() + .mb_0p5() + .gap_1() + .child( + h_flex() + .w_full() + .rounded_md() + .border_1() + .gap_x_2() + .px_2() + .py_0p5() + .mb_0p5() + .bg(cx.theme().colors().editor_background) + .when_else( + self.query_editor + .focus_handle(cx) + .contains_focused(window, cx), + |this| this.border_color(cx.theme().colors().border_focused), + |this| this.border_color(cx.theme().colors().border_transparent), + ) + .child( + div() + .id("memory-view-editor-icon") + .child(Icon::new(icon).size(ui::IconSize::XSmall)) + .tooltip(Tooltip::text(tooltip_text)), + ) + .child(self.render_query_bar(cx)), + ) + .child(self.render_width_picker(window, cx)), + ) + .child(Divider::horizontal()) + .child( + v_flex() + .size_full() + .on_drag_move(cx.listener(|this, evt, _, _| { + this.handle_memory_drag(&evt); + })) + .child(self.render_memory(cx).size_full()) + .children(self.open_context_menu.as_ref().map(|(menu, position, _)| { + deferred( + anchored() + .position(*position) + .anchor(gpui::Corner::TopLeft) + .child(menu.clone()), + ) + .with_priority(1) + })) + .children(self.render_vertical_scrollbar(cx)), + ) + } +} diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index bdb095bde3..906e482687 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -1,3 +1,5 @@ +use crate::session::running::{RunningState, memory_view::MemoryView}; + use super::stack_frame_list::{StackFrameList, StackFrameListEvent}; use dap::{ ScopePresentationHint, StackFrameId, VariablePresentationHint, VariablePresentationHintKind, @@ -7,13 +9,17 @@ use editor::Editor; use gpui::{ Action, AnyElement, ClickEvent, ClipboardItem, Context, DismissEvent, Empty, Entity, FocusHandle, Focusable, Hsla, MouseButton, MouseDownEvent, Point, Stateful, Subscription, - TextStyleRefinement, UniformListScrollHandle, actions, anchored, deferred, uniform_list, + TextStyleRefinement, UniformListScrollHandle, WeakEntity, actions, anchored, deferred, + uniform_list, }; use menu::{SelectFirst, SelectLast, SelectNext, SelectPrevious}; -use project::debugger::session::{Session, SessionEvent, Watcher}; +use project::debugger::{ + dap_command::DataBreakpointContext, + session::{Session, SessionEvent, Watcher}, +}; use std::{collections::HashMap, ops::Range, sync::Arc}; use ui::{ContextMenu, ListItem, ScrollableHandle, Scrollbar, ScrollbarState, Tooltip, prelude::*}; -use util::debug_panic; +use util::{debug_panic, maybe}; actions!( variable_list, @@ -32,6 +38,8 @@ actions!( AddWatch, /// Removes the selected variable from the watch list. RemoveWatch, + /// Jump to variable's memory location. + GoToMemory, ] ); @@ -86,30 +94,30 @@ impl EntryPath { } #[derive(Debug, Clone, PartialEq)] -enum EntryKind { +enum DapEntry { Watcher(Watcher), Variable(dap::Variable), Scope(dap::Scope), } -impl EntryKind { +impl DapEntry { fn as_watcher(&self) -> Option<&Watcher> { match self { - EntryKind::Watcher(watcher) => Some(watcher), + DapEntry::Watcher(watcher) => Some(watcher), _ => None, } } fn as_variable(&self) -> Option<&dap::Variable> { match self { - EntryKind::Variable(dap) => Some(dap), + DapEntry::Variable(dap) => Some(dap), _ => None, } } fn as_scope(&self) -> Option<&dap::Scope> { match self { - EntryKind::Scope(dap) => Some(dap), + DapEntry::Scope(dap) => Some(dap), _ => None, } } @@ -117,38 +125,38 @@ impl EntryKind { #[cfg(test)] fn name(&self) -> &str { match self { - EntryKind::Watcher(watcher) => &watcher.expression, - EntryKind::Variable(dap) => &dap.name, - EntryKind::Scope(dap) => &dap.name, + DapEntry::Watcher(watcher) => &watcher.expression, + DapEntry::Variable(dap) => &dap.name, + DapEntry::Scope(dap) => &dap.name, } } } #[derive(Debug, Clone, PartialEq)] struct ListEntry { - dap_kind: EntryKind, + entry: DapEntry, path: EntryPath, } impl ListEntry { fn as_watcher(&self) -> Option<&Watcher> { - self.dap_kind.as_watcher() + self.entry.as_watcher() } fn as_variable(&self) -> Option<&dap::Variable> { - self.dap_kind.as_variable() + self.entry.as_variable() } fn as_scope(&self) -> Option<&dap::Scope> { - self.dap_kind.as_scope() + self.entry.as_scope() } fn item_id(&self) -> ElementId { use std::fmt::Write; - let mut id = match &self.dap_kind { - EntryKind::Watcher(watcher) => format!("watcher-{}", watcher.expression), - EntryKind::Variable(dap) => format!("variable-{}", dap.name), - EntryKind::Scope(dap) => format!("scope-{}", dap.name), + let mut id = match &self.entry { + DapEntry::Watcher(watcher) => format!("watcher-{}", watcher.expression), + DapEntry::Variable(dap) => format!("variable-{}", dap.name), + DapEntry::Scope(dap) => format!("scope-{}", dap.name), }; for name in self.path.indices.iter() { _ = write!(id, "-{}", name); @@ -158,10 +166,10 @@ impl ListEntry { fn item_value_id(&self) -> ElementId { use std::fmt::Write; - let mut id = match &self.dap_kind { - EntryKind::Watcher(watcher) => format!("watcher-{}", watcher.expression), - EntryKind::Variable(dap) => format!("variable-{}", dap.name), - EntryKind::Scope(dap) => format!("scope-{}", dap.name), + let mut id = match &self.entry { + DapEntry::Watcher(watcher) => format!("watcher-{}", watcher.expression), + DapEntry::Variable(dap) => format!("variable-{}", dap.name), + DapEntry::Scope(dap) => format!("scope-{}", dap.name), }; for name in self.path.indices.iter() { _ = write!(id, "-{}", name); @@ -188,13 +196,17 @@ pub struct VariableList { focus_handle: FocusHandle, edited_path: Option<(EntryPath, Entity)>, disabled: bool, + memory_view: Entity, + weak_running: WeakEntity, _subscriptions: Vec, } impl VariableList { - pub fn new( + pub(crate) fn new( session: Entity, stack_frame_list: Entity, + memory_view: Entity, + weak_running: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -211,6 +223,7 @@ impl VariableList { SessionEvent::Variables | SessionEvent::Watchers => { this.build_entries(cx); } + _ => {} }), cx.on_focus_out(&focus_handle, window, |this, _, _, cx| { @@ -234,6 +247,8 @@ impl VariableList { edited_path: None, entries: Default::default(), entry_states: Default::default(), + weak_running, + memory_view, } } @@ -284,7 +299,7 @@ impl VariableList { scope.variables_reference, scope.variables_reference, EntryPath::for_scope(&scope.name), - EntryKind::Scope(scope), + DapEntry::Scope(scope), ) }) .collect::>(); @@ -298,7 +313,7 @@ impl VariableList { watcher.variables_reference, watcher.variables_reference, EntryPath::for_watcher(watcher.expression.clone()), - EntryKind::Watcher(watcher.clone()), + DapEntry::Watcher(watcher.clone()), ) }) .collect::>(), @@ -309,9 +324,9 @@ impl VariableList { while let Some((container_reference, variables_reference, mut path, dap_kind)) = stack.pop() { match &dap_kind { - EntryKind::Watcher(watcher) => path = path.with_child(watcher.expression.clone()), - EntryKind::Variable(dap) => path = path.with_name(dap.name.clone().into()), - EntryKind::Scope(dap) => path = path.with_child(dap.name.clone().into()), + DapEntry::Watcher(watcher) => path = path.with_child(watcher.expression.clone()), + DapEntry::Variable(dap) => path = path.with_name(dap.name.clone().into()), + DapEntry::Scope(dap) => path = path.with_child(dap.name.clone().into()), } let var_state = self @@ -336,7 +351,7 @@ impl VariableList { }); entries.push(ListEntry { - dap_kind, + entry: dap_kind, path: path.clone(), }); @@ -349,7 +364,7 @@ impl VariableList { variables_reference, child.variables_reference, path.with_child(child.name.clone().into()), - EntryKind::Variable(child), + DapEntry::Variable(child), ) })); } @@ -380,9 +395,9 @@ impl VariableList { pub fn completion_variables(&self, _cx: &mut Context) -> Vec { self.entries .iter() - .filter_map(|entry| match &entry.dap_kind { - EntryKind::Variable(dap) => Some(dap.clone()), - EntryKind::Scope(_) | EntryKind::Watcher { .. } => None, + .filter_map(|entry| match &entry.entry { + DapEntry::Variable(dap) => Some(dap.clone()), + DapEntry::Scope(_) | DapEntry::Watcher { .. } => None, }) .collect() } @@ -400,12 +415,12 @@ impl VariableList { .get(ix) .and_then(|entry| Some(entry).zip(self.entry_states.get(&entry.path)))?; - match &entry.dap_kind { - EntryKind::Watcher { .. } => { + match &entry.entry { + DapEntry::Watcher { .. } => { Some(self.render_watcher(entry, *state, window, cx)) } - EntryKind::Variable(_) => Some(self.render_variable(entry, *state, window, cx)), - EntryKind::Scope(_) => Some(self.render_scope(entry, *state, cx)), + DapEntry::Variable(_) => Some(self.render_variable(entry, *state, window, cx)), + DapEntry::Scope(_) => Some(self.render_scope(entry, *state, cx)), } }) .collect() @@ -562,6 +577,51 @@ impl VariableList { } } + fn jump_to_variable_memory( + &mut self, + _: &GoToMemory, + window: &mut Window, + cx: &mut Context, + ) { + _ = maybe!({ + let selection = self.selection.as_ref()?; + let entry = self.entries.iter().find(|entry| &entry.path == selection)?; + let var = entry.entry.as_variable()?; + let memory_reference = var.memory_reference.as_deref()?; + + let sizeof_expr = if var.type_.as_ref().is_some_and(|t| { + t.chars() + .all(|c| c.is_whitespace() || c.is_alphabetic() || c == '*') + }) { + var.type_.as_deref() + } else { + var.evaluate_name + .as_deref() + .map(|name| name.strip_prefix("/nat ").unwrap_or_else(|| name)) + }; + self.memory_view.update(cx, |this, cx| { + this.go_to_memory_reference( + memory_reference, + sizeof_expr, + self.selected_stack_frame_id, + cx, + ); + }); + let weak_panel = self.weak_running.clone(); + + window.defer(cx, move |window, cx| { + _ = weak_panel.update(cx, |this, cx| { + this.activate_item( + crate::persistence::DebuggerPaneItem::MemoryView, + window, + cx, + ); + }); + }); + Some(()) + }); + } + fn deploy_list_entry_context_menu( &mut self, entry: ListEntry, @@ -569,49 +629,197 @@ impl VariableList { window: &mut Window, cx: &mut Context, ) { - let supports_set_variable = self - .session - .read(cx) - .capabilities() - .supports_set_variable - .unwrap_or_default(); + let (supports_set_variable, supports_data_breakpoints, supports_go_to_memory) = + self.session.read_with(cx, |session, _| { + ( + session + .capabilities() + .supports_set_variable + .unwrap_or_default(), + session + .capabilities() + .supports_data_breakpoints + .unwrap_or_default(), + session + .capabilities() + .supports_read_memory_request + .unwrap_or_default(), + ) + }); + let can_toggle_data_breakpoint = entry + .as_variable() + .filter(|_| supports_data_breakpoints) + .and_then(|variable| { + let variables_reference = self + .entry_states + .get(&entry.path) + .map(|state| state.parent_reference)?; + Some(self.session.update(cx, |session, cx| { + session.data_breakpoint_info( + Arc::new(DataBreakpointContext::Variable { + variables_reference, + name: variable.name.clone(), + bytes: None, + }), + None, + cx, + ) + })) + }); - let context_menu = ContextMenu::build(window, cx, |menu, _, _| { - menu.when(entry.as_variable().is_some(), |menu| { - menu.action("Copy Name", CopyVariableName.boxed_clone()) - .action("Copy Value", CopyVariableValue.boxed_clone()) - .when(supports_set_variable, |menu| { - menu.action("Edit Value", EditVariable.boxed_clone()) + let focus_handle = self.focus_handle.clone(); + cx.spawn_in(window, async move |this, cx| { + let can_toggle_data_breakpoint = if let Some(task) = can_toggle_data_breakpoint { + task.await + } else { + None + }; + cx.update(|window, cx| { + let context_menu = ContextMenu::build(window, cx, |menu, _, _| { + menu.when_some(entry.as_variable(), |menu, _| { + menu.action("Copy Name", CopyVariableName.boxed_clone()) + .action("Copy Value", CopyVariableValue.boxed_clone()) + .when(supports_set_variable, |menu| { + menu.action("Edit Value", EditVariable.boxed_clone()) + }) + .when(supports_go_to_memory, |menu| { + menu.action("Go To Memory", GoToMemory.boxed_clone()) + }) + .action("Watch Variable", AddWatch.boxed_clone()) + .when_some(can_toggle_data_breakpoint, |mut menu, data_info| { + menu = menu.separator(); + if let Some(access_types) = data_info.access_types { + for access in access_types { + menu = menu.action( + format!( + "Toggle {} Data Breakpoint", + match access { + dap::DataBreakpointAccessType::Read => "Read", + dap::DataBreakpointAccessType::Write => "Write", + dap::DataBreakpointAccessType::ReadWrite => + "Read/Write", + } + ), + crate::ToggleDataBreakpoint { + access_type: Some(access), + } + .boxed_clone(), + ); + } + + menu + } else { + menu.action( + "Toggle Data Breakpoint", + crate::ToggleDataBreakpoint { access_type: None } + .boxed_clone(), + ) + } + }) }) - .action("Watch Variable", AddWatch.boxed_clone()) - }) - .when(entry.as_watcher().is_some(), |menu| { - menu.action("Copy Name", CopyVariableName.boxed_clone()) - .action("Copy Value", CopyVariableValue.boxed_clone()) - .when(supports_set_variable, |menu| { - menu.action("Edit Value", EditVariable.boxed_clone()) + .when(entry.as_watcher().is_some(), |menu| { + menu.action("Copy Name", CopyVariableName.boxed_clone()) + .action("Copy Value", CopyVariableValue.boxed_clone()) + .when(supports_set_variable, |menu| { + menu.action("Edit Value", EditVariable.boxed_clone()) + }) + .action("Remove Watch", RemoveWatch.boxed_clone()) }) - .action("Remove Watch", RemoveWatch.boxed_clone()) + .context(focus_handle.clone()) + }); + + _ = this.update(cx, |this, cx| { + cx.focus_view(&context_menu, window); + let subscription = cx.subscribe_in( + &context_menu, + window, + |this, _, _: &DismissEvent, window, cx| { + if this.open_context_menu.as_ref().is_some_and(|context_menu| { + context_menu.0.focus_handle(cx).contains_focused(window, cx) + }) { + cx.focus_self(window); + } + this.open_context_menu.take(); + cx.notify(); + }, + ); + + this.open_context_menu = Some((context_menu, position, subscription)); + }); }) - .context(self.focus_handle.clone()) + }) + .detach(); + } + + fn toggle_data_breakpoint( + &mut self, + data_info: &crate::ToggleDataBreakpoint, + _window: &mut Window, + cx: &mut Context, + ) { + let Some(entry) = self + .selection + .as_ref() + .and_then(|selection| self.entries.iter().find(|entry| &entry.path == selection)) + else { + return; + }; + + let Some((name, var_ref)) = entry.as_variable().map(|var| &var.name).zip( + self.entry_states + .get(&entry.path) + .map(|state| state.parent_reference), + ) else { + return; + }; + + let context = Arc::new(DataBreakpointContext::Variable { + variables_reference: var_ref, + name: name.clone(), + bytes: None, + }); + let data_breakpoint = self.session.update(cx, |session, cx| { + session.data_breakpoint_info(context.clone(), None, cx) }); - cx.focus_view(&context_menu, window); - let subscription = cx.subscribe_in( - &context_menu, - window, - |this, _, _: &DismissEvent, window, cx| { - if this.open_context_menu.as_ref().is_some_and(|context_menu| { - context_menu.0.focus_handle(cx).contains_focused(window, cx) - }) { - cx.focus_self(window); - } - this.open_context_menu.take(); - cx.notify(); - }, - ); + let session = self.session.downgrade(); + let access_type = data_info.access_type; + cx.spawn(async move |_, cx| { + let Some((data_id, access_types)) = data_breakpoint + .await + .and_then(|info| Some((info.data_id?, info.access_types))) + else { + return; + }; - self.open_context_menu = Some((context_menu, position, subscription)); + // Because user's can manually add this action to the keymap + // we check if access type is supported + let access_type = match access_types { + None => None, + Some(access_types) => { + if access_type.is_some_and(|access_type| access_types.contains(&access_type)) { + access_type + } else { + None + } + } + }; + _ = session.update(cx, |session, cx| { + session.create_data_breakpoint( + context, + data_id.clone(), + dap::DataBreakpoint { + data_id, + access_type, + condition: None, + hit_condition: None, + }, + cx, + ); + cx.notify(); + }); + }) + .detach(); } fn copy_variable_name( @@ -628,10 +836,10 @@ impl VariableList { return; }; - let variable_name = match &entry.dap_kind { - EntryKind::Variable(dap) => dap.name.clone(), - EntryKind::Watcher(watcher) => watcher.expression.to_string(), - EntryKind::Scope(_) => return, + let variable_name = match &entry.entry { + DapEntry::Variable(dap) => dap.name.clone(), + DapEntry::Watcher(watcher) => watcher.expression.to_string(), + DapEntry::Scope(_) => return, }; cx.write_to_clipboard(ClipboardItem::new_string(variable_name)); @@ -651,10 +859,10 @@ impl VariableList { return; }; - let variable_value = match &entry.dap_kind { - EntryKind::Variable(dap) => dap.value.clone(), - EntryKind::Watcher(watcher) => watcher.value.to_string(), - EntryKind::Scope(_) => return, + let variable_value = match &entry.entry { + DapEntry::Variable(dap) => dap.value.clone(), + DapEntry::Watcher(watcher) => watcher.value.to_string(), + DapEntry::Scope(_) => return, }; cx.write_to_clipboard(ClipboardItem::new_string(variable_value)); @@ -669,10 +877,10 @@ impl VariableList { return; }; - let variable_value = match &entry.dap_kind { - EntryKind::Watcher(watcher) => watcher.value.to_string(), - EntryKind::Variable(variable) => variable.value.clone(), - EntryKind::Scope(_) => return, + let variable_value = match &entry.entry { + DapEntry::Watcher(watcher) => watcher.value.to_string(), + DapEntry::Variable(variable) => variable.value.clone(), + DapEntry::Scope(_) => return, }; let editor = Self::create_variable_editor(&variable_value, window, cx); @@ -753,7 +961,7 @@ impl VariableList { "{}{} {}{}", INDENT.repeat(state.depth - 1), if state.is_expanded { "v" } else { ">" }, - entry.dap_kind.name(), + entry.entry.name(), if self.selection.as_ref() == Some(&entry.path) { " <=== selected" } else { @@ -770,8 +978,8 @@ impl VariableList { pub(crate) fn scopes(&self) -> Vec { self.entries .iter() - .filter_map(|entry| match &entry.dap_kind { - EntryKind::Scope(scope) => Some(scope), + .filter_map(|entry| match &entry.entry { + DapEntry::Scope(scope) => Some(scope), _ => None, }) .cloned() @@ -785,10 +993,10 @@ impl VariableList { let mut idx = 0; for entry in self.entries.iter() { - match &entry.dap_kind { - EntryKind::Watcher { .. } => continue, - EntryKind::Variable(dap) => scopes[idx].1.push(dap.clone()), - EntryKind::Scope(scope) => { + match &entry.entry { + DapEntry::Watcher { .. } => continue, + DapEntry::Variable(dap) => scopes[idx].1.push(dap.clone()), + DapEntry::Scope(scope) => { if scopes.len() > 0 { idx += 1; } @@ -806,8 +1014,8 @@ impl VariableList { pub(crate) fn variables(&self) -> Vec { self.entries .iter() - .filter_map(|entry| match &entry.dap_kind { - EntryKind::Variable(variable) => Some(variable), + .filter_map(|entry| match &entry.entry { + DapEntry::Variable(variable) => Some(variable), _ => None, }) .cloned() @@ -1358,6 +1566,8 @@ impl Render for VariableList { .on_action(cx.listener(Self::edit_variable)) .on_action(cx.listener(Self::add_watcher)) .on_action(cx.listener(Self::remove_watcher)) + .on_action(cx.listener(Self::toggle_data_breakpoint)) + .on_action(cx.listener(Self::jump_to_variable_memory)) .child( uniform_list( "variable-list", diff --git a/crates/debugger_ui/src/tests/debugger_panel.rs b/crates/debugger_ui/src/tests/debugger_panel.rs index 05bca8131a..505df09cfb 100644 --- a/crates/debugger_ui/src/tests/debugger_panel.rs +++ b/crates/debugger_ui/src/tests/debugger_panel.rs @@ -427,7 +427,7 @@ async fn test_handle_start_debugging_request( let sessions = workspace .update(cx, |workspace, _window, cx| { let debug_panel = workspace.panel::(cx).unwrap(); - debug_panel.read(cx).sessions() + debug_panel.read(cx).sessions().collect::>() }) .unwrap(); assert_eq!(sessions.len(), 1); @@ -451,7 +451,7 @@ async fn test_handle_start_debugging_request( .unwrap() .read(cx) .session(cx); - let current_sessions = debug_panel.read(cx).sessions(); + let current_sessions = debug_panel.read(cx).sessions().collect::>(); assert_eq!(active_session, current_sessions[1].read(cx).session(cx)); assert_eq!( active_session.read(cx).parent_session(), @@ -1796,7 +1796,7 @@ async fn test_debug_adapters_shutdown_on_app_quit( let panel = workspace.panel::(cx).unwrap(); panel.read_with(cx, |panel, _| { assert!( - !panel.sessions().is_empty(), + panel.sessions().next().is_some(), "Debug session should be active" ); }); diff --git a/crates/debugger_ui/src/tests/inline_values.rs b/crates/debugger_ui/src/tests/inline_values.rs index 45cab2a306..9f921ec969 100644 --- a/crates/debugger_ui/src/tests/inline_values.rs +++ b/crates/debugger_ui/src/tests/inline_values.rs @@ -2241,3 +2241,34 @@ func main() { ) .await; } + +#[gpui::test] +async fn test_trim_multi_line_inline_value(executor: BackgroundExecutor, cx: &mut TestAppContext) { + let variables = [("y", "hello\n world")]; + + let before = r#" +fn main() { + let y = "hello\n world"; +} +"# + .unindent(); + + let after = r#" +fn main() { + let y: hello… = "hello\n world"; +} +"# + .unindent(); + + test_inline_values_util( + &variables, + &[], + &before, + &after, + None, + rust_lang(), + executor, + cx, + ) + .await; +} diff --git a/crates/debugger_ui/src/tests/module_list.rs b/crates/debugger_ui/src/tests/module_list.rs index 49cfd6fcf8..09c90cbc4a 100644 --- a/crates/debugger_ui/src/tests/module_list.rs +++ b/crates/debugger_ui/src/tests/module_list.rs @@ -111,7 +111,6 @@ async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext) }); running_state.update_in(cx, |this, window, cx| { - this.ensure_pane_item(DebuggerPaneItem::Modules, window, cx); this.activate_item(DebuggerPaneItem::Modules, window, cx); cx.refresh_windows(); }); diff --git a/crates/diagnostics/src/diagnostic_renderer.rs b/crates/diagnostics/src/diagnostic_renderer.rs index 77bb249733..ce7b253702 100644 --- a/crates/diagnostics/src/diagnostic_renderer.rs +++ b/crates/diagnostics/src/diagnostic_renderer.rs @@ -144,7 +144,6 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer { style: BlockStyle::Flex, render: Arc::new(move |bcx| block.render_block(editor.clone(), bcx)), priority: 1, - render_in_minimap: false, } }) .collect() diff --git a/crates/diagnostics/src/diagnostics.rs b/crates/diagnostics/src/diagnostics.rs index 1daa9025b6..ba64ba0eed 100644 --- a/crates/diagnostics/src/diagnostics.rs +++ b/crates/diagnostics/src/diagnostics.rs @@ -80,6 +80,7 @@ pub(crate) struct ProjectDiagnosticsEditor { include_warnings: bool, update_excerpts_task: Option>>, cargo_diagnostics_fetch: CargoDiagnosticsFetchState, + diagnostic_summary_update: Task<()>, _subscription: Subscription, } @@ -179,7 +180,16 @@ impl ProjectDiagnosticsEditor { path, } => { this.paths_to_update.insert(path.clone()); - this.summary = project.read(cx).diagnostic_summary(false, cx); + let project = project.clone(); + this.diagnostic_summary_update = cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + this.update(cx, |this, cx| { + this.summary = project.read(cx).diagnostic_summary(false, cx); + }) + .log_err(); + }); cx.emit(EditorEvent::TitleChanged); if this.editor.focus_handle(cx).contains_focused(window, cx) || this.focus_handle.contains_focused(window, cx) { @@ -276,6 +286,7 @@ impl ProjectDiagnosticsEditor { cancel_task: None, diagnostic_sources: Arc::new(Vec::new()), }, + diagnostic_summary_update: Task::ready(()), _subscription: project_event_subscription, }; this.update_all_diagnostics(true, window, cx); @@ -656,7 +667,6 @@ impl ProjectDiagnosticsEditor { block.render_block(editor.clone(), bcx) }), priority: 1, - render_in_minimap: false, } }); let block_ids = this.editor.update(cx, |editor, cx| { diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 0d47eaf367..1364aaf853 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -14,7 +14,10 @@ use indoc::indoc; use language::{DiagnosticSourceKind, Rope}; use lsp::LanguageServerId; use pretty_assertions::assert_eq; -use project::FakeFs; +use project::{ + FakeFs, + project_settings::{GoToDiagnosticSeverity, GoToDiagnosticSeverityFilter}, +}; use rand::{Rng, rngs::StdRng, seq::IteratorRandom as _}; use serde_json::json; use settings::SettingsStore; @@ -1005,7 +1008,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) cx.run_until_parked(); cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); assert_eq!( editor .active_diagnostic_group() @@ -1047,7 +1050,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) "}); cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); assert_eq!(editor.active_diagnostic_group(), None); }); cx.assert_editor_state(indoc! {" @@ -1126,7 +1129,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Fourth diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc def: i32) -> Λ‡u32 { @@ -1135,7 +1138,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Third diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc Λ‡def: i32) -> u32 { @@ -1144,7 +1147,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Second diagnostic, same place cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc Λ‡def: i32) -> u32 { @@ -1153,7 +1156,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // First diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abcΛ‡ def: i32) -> u32 { @@ -1162,7 +1165,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Wrapped over, fourth diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc def: i32) -> Λ‡u32 { @@ -1181,7 +1184,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // First diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abcΛ‡ def: i32) -> u32 { @@ -1190,7 +1193,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Second diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc Λ‡def: i32) -> u32 { @@ -1199,7 +1202,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Third diagnostic, same place cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc Λ‡def: i32) -> u32 { @@ -1208,7 +1211,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Fourth diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abc def: i32) -> Λ‡u32 { @@ -1217,7 +1220,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { // Wrapped around, first diagnostic cx.update_editor(|editor, window, cx| { - editor.go_to_diagnostic(&GoToDiagnostic, window, cx); + editor.go_to_diagnostic(&GoToDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" fn func(abcΛ‡ def: i32) -> u32 { @@ -1441,6 +1444,128 @@ async fn test_diagnostics_with_code(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { + init_test(cx); + + let mut cx = EditorTestContext::new(cx).await; + let lsp_store = + cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + + cx.set_state(indoc! {"error warning info hiΛ‡nt"}); + + cx.update(|_, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Url::from_file_path(path!("/root/file")).unwrap(), + version: None, + diagnostics: vec![ + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 5), + ), + severity: Some(lsp::DiagnosticSeverity::ERROR), + ..Default::default() + }, + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 6), + lsp::Position::new(0, 13), + ), + severity: Some(lsp::DiagnosticSeverity::WARNING), + ..Default::default() + }, + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 14), + lsp::Position::new(0, 18), + ), + severity: Some(lsp::DiagnosticSeverity::INFORMATION), + ..Default::default() + }, + lsp::Diagnostic { + range: lsp::Range::new( + lsp::Position::new(0, 19), + lsp::Position::new(0, 23), + ), + severity: Some(lsp::DiagnosticSeverity::HINT), + ..Default::default() + }, + ], + }, + None, + DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap() + }); + }); + cx.run_until_parked(); + + macro_rules! go { + ($severity:expr) => { + cx.update_editor(|editor, window, cx| { + editor.go_to_diagnostic( + &GoToDiagnostic { + severity: $severity, + }, + window, + cx, + ); + }); + }; + } + + // Default, should cycle through all diagnostics + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"Λ‡error warning info hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"error Λ‡warning info hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"error warning Λ‡info hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"error warning info Λ‡hint"}); + go!(GoToDiagnosticSeverityFilter::default()); + cx.assert_editor_state(indoc! {"Λ‡error warning info hint"}); + + let only_info = GoToDiagnosticSeverityFilter::Only(GoToDiagnosticSeverity::Information); + go!(only_info); + cx.assert_editor_state(indoc! {"error warning Λ‡info hint"}); + go!(only_info); + cx.assert_editor_state(indoc! {"error warning Λ‡info hint"}); + + let no_hints = GoToDiagnosticSeverityFilter::Range { + min: GoToDiagnosticSeverity::Information, + max: GoToDiagnosticSeverity::Error, + }; + + go!(no_hints); + cx.assert_editor_state(indoc! {"Λ‡error warning info hint"}); + go!(no_hints); + cx.assert_editor_state(indoc! {"error Λ‡warning info hint"}); + go!(no_hints); + cx.assert_editor_state(indoc! {"error warning Λ‡info hint"}); + go!(no_hints); + cx.assert_editor_state(indoc! {"Λ‡error warning info hint"}); + + let warning_info = GoToDiagnosticSeverityFilter::Range { + min: GoToDiagnosticSeverity::Information, + max: GoToDiagnosticSeverity::Warning, + }; + + go!(warning_info); + cx.assert_editor_state(indoc! {"error Λ‡warning info hint"}); + go!(warning_info); + cx.assert_editor_state(indoc! {"error warning Λ‡info hint"}); + go!(warning_info); + cx.assert_editor_state(indoc! {"error Λ‡warning info hint"}); +} + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { zlog::init_test(); diff --git a/crates/diagnostics/src/items.rs b/crates/diagnostics/src/items.rs index b5f9e901bb..7ac6d101f3 100644 --- a/crates/diagnostics/src/items.rs +++ b/crates/diagnostics/src/items.rs @@ -6,9 +6,10 @@ use gpui::{ WeakEntity, Window, }; use language::Diagnostic; -use project::project_settings::ProjectSettings; +use project::project_settings::{GoToDiagnosticSeverityFilter, ProjectSettings}; use settings::Settings; use ui::{Button, ButtonLike, Color, Icon, IconName, Label, Tooltip, h_flex, prelude::*}; +use util::ResultExt; use workspace::{StatusItemView, ToolbarItemEvent, Workspace, item::ItemHandle}; use crate::{Deploy, IncludeWarnings, ProjectDiagnosticsEditor}; @@ -20,6 +21,7 @@ pub struct DiagnosticIndicator { current_diagnostic: Option, _observe_active_editor: Option, diagnostics_update: Task<()>, + diagnostic_summary_update: Task<()>, } impl Render for DiagnosticIndicator { @@ -77,7 +79,7 @@ impl Render for DiagnosticIndicator { .tooltip(|window, cx| { Tooltip::for_action( "Next Diagnostic", - &editor::actions::GoToDiagnostic, + &editor::actions::GoToDiagnostic::default(), window, cx, ) @@ -135,8 +137,16 @@ impl DiagnosticIndicator { } project::Event::DiagnosticsUpdated { .. } => { - this.summary = project.read(cx).diagnostic_summary(false, cx); - cx.notify(); + this.diagnostic_summary_update = cx.spawn(async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(30)) + .await; + this.update(cx, |this, cx| { + this.summary = project.read(cx).diagnostic_summary(false, cx); + cx.notify(); + }) + .log_err(); + }); } _ => {} @@ -150,13 +160,19 @@ impl DiagnosticIndicator { current_diagnostic: None, _observe_active_editor: None, diagnostics_update: Task::ready(()), + diagnostic_summary_update: Task::ready(()), } } fn go_to_next_diagnostic(&mut self, window: &mut Window, cx: &mut Context) { if let Some(editor) = self.active_editor.as_ref().and_then(|e| e.upgrade()) { editor.update(cx, |editor, cx| { - editor.go_to_diagnostic_impl(editor::Direction::Next, window, cx); + editor.go_to_diagnostic_impl( + editor::Direction::Next, + GoToDiagnosticSeverityFilter::default(), + window, + cx, + ); }) } } diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index c8e945c7e8..8eeeb6f0c5 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -243,7 +243,6 @@ struct ActionDef { fn dump_all_gpui_actions() -> Vec { let mut actions = gpui::generate_list_of_all_registered_actions() - .into_iter() .map(|action| ActionDef { name: action.name, human_name: command_palette::humanize_action_name(action.name), diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 70ec8ea00f..1212651cb3 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -1,6 +1,7 @@ //! This module contains all actions supported by [`Editor`]. use super::*; use gpui::{Action, actions}; +use project::project_settings::GoToDiagnosticSeverityFilter; use schemars::JsonSchema; use util::serde::default_true; @@ -258,6 +259,13 @@ pub struct SpawnNearestTask { pub reveal: task::RevealStrategy, } +#[derive(Clone, PartialEq, Action)] +#[action(no_json, no_register)] +pub struct DiffClipboardWithSelectionData { + pub clipboard_text: String, + pub editor: Entity, +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Default)] pub enum UuidVersion { #[default] @@ -265,6 +273,24 @@ pub enum UuidVersion { V7, } +/// Goes to the next diagnostic in the file. +#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct GoToDiagnostic { + #[serde(default)] + pub severity: GoToDiagnosticSeverityFilter, +} + +/// Goes to the previous diagnostic in the file. +#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct GoToPreviousDiagnostic { + #[serde(default)] + pub severity: GoToDiagnosticSeverityFilter, +} + actions!( debugger, [ @@ -303,6 +329,8 @@ actions!( ApplyDiffHunk, /// Deletes the character before the cursor. Backspace, + /// Shows git blame information for the current line. + BlameHover, /// Cancels the current operation. Cancel, /// Cancels the running flycheck operation. @@ -337,6 +365,8 @@ actions!( ConvertToLowerCase, /// Toggles the case of selected text. ConvertToOppositeCase, + /// Converts selected text to sentence case. + ConvertToSentenceCase, /// Converts selected text to snake_case. ConvertToSnakeCase, /// Converts selected text to Title Case. @@ -377,6 +407,8 @@ actions!( DeleteToNextSubwordEnd, /// Deletes to the start of the previous subword. DeleteToPreviousSubwordStart, + /// Diffs the text stored in the clipboard against the current selection. + DiffClipboardWithSelection, /// Displays names of all active cursors. DisplayCursorNames, /// Duplicates the current line below. @@ -406,10 +438,14 @@ actions!( FoldRecursive, /// Folds the selected ranges. FoldSelectedRanges, + /// Toggles focus back to the last active buffer. + ToggleFocus, /// Toggles folding at the current position. ToggleFold, /// Toggles recursive folding at the current position. ToggleFoldRecursive, + /// Toggles all folds in a buffer or all excerpts in multibuffer. + ToggleFoldAll, /// Formats the entire document. Format, /// Formats only the selected text. @@ -422,8 +458,6 @@ actions!( GoToDefinition, /// Goes to definition in a split pane. GoToDefinitionSplit, - /// Goes to the next diagnostic in the file. - GoToDiagnostic, /// Goes to the next diff hunk. GoToHunk, /// Goes to the previous diff hunk. @@ -438,8 +472,6 @@ actions!( GoToParentModule, /// Goes to the previous change in the file. GoToPreviousChange, - /// Goes to the previous diagnostic in the file. - GoToPreviousDiagnostic, /// Goes to the type definition of the symbol at cursor. GoToTypeDefinition, /// Goes to type definition in a split pane. diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 8fbae8d605..52446ceafc 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -844,7 +844,7 @@ impl CompletionsMenu { .with_sizing_behavior(ListSizingBehavior::Infer) .w(rems(34.)); - Popover::new().child(list).into_any_element() + Popover::new().child(div().child(list)).into_any_element() } fn render_aside( @@ -1074,6 +1074,20 @@ impl CompletionsMenu { .and_then(|q| q.chars().next()) .and_then(|c| c.to_lowercase().next()); + if snippet_sort_order == SnippetSortOrder::None { + matches.retain(|string_match| { + let completion = &completions[string_match.candidate_id]; + + let is_snippet = matches!( + &completion.source, + CompletionSource::Lsp { lsp_completion, .. } + if lsp_completion.kind == Some(CompletionItemKind::SNIPPET) + ); + + !is_snippet + }); + } + matches.sort_unstable_by_key(|string_match| { let completion = &completions[string_match.candidate_id]; @@ -1112,6 +1126,7 @@ impl CompletionsMenu { SnippetSortOrder::Top => Reverse(if is_snippet { 1 } else { 0 }), SnippetSortOrder::Bottom => Reverse(if is_snippet { 0 } else { 1 }), SnippetSortOrder::Inline => Reverse(0), + SnippetSortOrder::None => Reverse(0), }; let sort_positions = string_match.positions.clone(); let sort_exact = Reverse(if Some(completion.label.filter_text()) == query { @@ -1369,7 +1384,7 @@ impl CodeActionsMenu { } } - fn visible(&self) -> bool { + pub fn visible(&self) -> bool { !self.actions.is_empty() } diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index aa2408d6d9..5425d5a8b9 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -271,7 +271,6 @@ impl DisplayMap { height: Some(height), style, priority, - render_in_minimap: true, } }), ); @@ -1663,7 +1662,6 @@ pub mod tests { height: Some(height), render: Arc::new(|_| div().into_any()), priority, - render_in_minimap: true, } }) .collect::>(); @@ -2029,7 +2027,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], cx, ); @@ -2227,7 +2224,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { placement: BlockPlacement::Below( @@ -2237,7 +2233,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ], cx, @@ -2344,7 +2339,6 @@ pub mod tests { style: BlockStyle::Sticky, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], cx, ) @@ -2420,7 +2414,6 @@ pub mod tests { style: BlockStyle::Fixed, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], cx, ); diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index ea754da03f..85495a2611 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -193,7 +193,6 @@ pub struct CustomBlock { style: BlockStyle, render: Arc>, priority: usize, - pub(crate) render_in_minimap: bool, } #[derive(Clone)] @@ -205,7 +204,6 @@ pub struct BlockProperties

{ pub style: BlockStyle, pub render: RenderBlock, pub priority: usize, - pub render_in_minimap: bool, } impl Debug for BlockProperties

{ @@ -526,10 +524,10 @@ impl BlockMap { // * Isomorphic transforms that end *at* the start of the edit // * Below blocks that end at the start of the edit // However, if we hit a replace block that ends at the start of the edit we want to reconstruct it. - new_transforms.append(cursor.slice(&old_start, Bias::Left, &()), &()); + new_transforms.append(cursor.slice(&old_start, Bias::Left), &()); if let Some(transform) = cursor.item() { if transform.summary.input_rows > 0 - && cursor.end(&()) == old_start + && cursor.end() == old_start && transform .block .as_ref() @@ -537,13 +535,13 @@ impl BlockMap { { // Preserve the transform (push and next) new_transforms.push(transform.clone(), &()); - cursor.next(&()); + cursor.next(); // Preserve below blocks at end of edit while let Some(transform) = cursor.item() { if transform.block.as_ref().map_or(false, |b| b.place_below()) { new_transforms.push(transform.clone(), &()); - cursor.next(&()); + cursor.next(); } else { break; } @@ -581,8 +579,8 @@ impl BlockMap { let mut new_end = WrapRow(edit.new.end); loop { // Seek to the transform starting at or after the end of the edit - cursor.seek(&old_end, Bias::Left, &()); - cursor.next(&()); + cursor.seek(&old_end, Bias::Left); + cursor.next(); // Extend edit to the end of the discarded transform so it is reconstructed in full let transform_rows_after_edit = cursor.start().0 - old_end.0; @@ -594,8 +592,8 @@ impl BlockMap { if next_edit.old.start <= cursor.start().0 { old_end = WrapRow(next_edit.old.end); new_end = WrapRow(next_edit.new.end); - cursor.seek(&old_end, Bias::Left, &()); - cursor.next(&()); + cursor.seek(&old_end, Bias::Left); + cursor.next(); edits.next(); } else { break; @@ -610,7 +608,7 @@ impl BlockMap { // Discard below blocks at the end of the edit. They'll be reconstructed. while let Some(transform) = cursor.item() { if transform.block.as_ref().map_or(false, |b| b.place_below()) { - cursor.next(&()); + cursor.next(); } else { break; } @@ -722,7 +720,7 @@ impl BlockMap { push_isomorphic(&mut new_transforms, rows_after_last_block, wrap_snapshot); } - new_transforms.append(cursor.suffix(&()), &()); + new_transforms.append(cursor.suffix(), &()); debug_assert_eq!( new_transforms.summary().input_rows, wrap_snapshot.max_point().row() + 1 @@ -973,7 +971,7 @@ impl BlockMapReader<'_> { ); let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); - cursor.seek(&start_wrap_row, Bias::Left, &()); + cursor.seek(&start_wrap_row, Bias::Left); while let Some(transform) = cursor.item() { if cursor.start().0 > end_wrap_row { break; @@ -984,7 +982,7 @@ impl BlockMapReader<'_> { return Some(cursor.start().1); } } - cursor.next(&()); + cursor.next(); } None @@ -1044,7 +1042,6 @@ impl BlockMapWriter<'_> { render: Arc::new(Mutex::new(block.render)), style: block.style, priority: block.priority, - render_in_minimap: block.render_in_minimap, }); self.0.custom_blocks.insert(block_ix, new_block.clone()); self.0.custom_blocks_by_id.insert(id, new_block); @@ -1079,7 +1076,6 @@ impl BlockMapWriter<'_> { style: block.style, render: block.render.clone(), priority: block.priority, - render_in_minimap: block.render_in_minimap, }; let new_block = Arc::new(new_block); *block = new_block.clone(); @@ -1297,7 +1293,7 @@ impl BlockSnapshot { let max_output_row = cmp::min(rows.end, self.transforms.summary().output_rows); let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(rows.start), Bias::Right, &()); + cursor.seek(&BlockRow(rows.start), Bias::Right); let transform_output_start = cursor.start().0.0; let transform_input_start = cursor.start().1.0; @@ -1329,7 +1325,7 @@ impl BlockSnapshot { pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&start_row, Bias::Right, &()); + cursor.seek(&start_row, Bias::Right); let (output_start, input_start) = cursor.start(); let overshoot = if cursor .item() @@ -1350,9 +1346,9 @@ impl BlockSnapshot { pub fn blocks_in_range(&self, rows: Range) -> impl Iterator { let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&BlockRow(rows.start), Bias::Left, &()); - while cursor.start().0 < rows.start && cursor.end(&()).0 <= rows.start { - cursor.next(&()); + cursor.seek(&BlockRow(rows.start), Bias::Left); + while cursor.start().0 < rows.start && cursor.end().0 <= rows.start { + cursor.next(); } std::iter::from_fn(move || { @@ -1368,10 +1364,10 @@ impl BlockSnapshot { break; } if let Some(block) = &transform.block { - cursor.next(&()); + cursor.next(); return Some((start_row, block)); } else { - cursor.next(&()); + cursor.next(); } } None @@ -1381,7 +1377,7 @@ impl BlockSnapshot { pub fn sticky_header_excerpt(&self, position: f32) -> Option> { let top_row = position as u32; let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&BlockRow(top_row), Bias::Right, &()); + cursor.seek(&BlockRow(top_row), Bias::Right); while let Some(transform) = cursor.item() { match &transform.block { @@ -1390,7 +1386,7 @@ impl BlockSnapshot { } Some(block) if block.is_buffer_header() => return None, _ => { - cursor.prev(&()); + cursor.prev(); continue; } } @@ -1418,7 +1414,7 @@ impl BlockSnapshot { let wrap_row = WrapRow(wrap_point.row()); let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&wrap_row, Bias::Left, &()); + cursor.seek(&wrap_row, Bias::Left); while let Some(transform) = cursor.item() { if let Some(block) = transform.block.as_ref() { @@ -1429,7 +1425,7 @@ impl BlockSnapshot { break; } - cursor.next(&()); + cursor.next(); } None @@ -1446,7 +1442,7 @@ impl BlockSnapshot { pub fn longest_row_in_range(&self, range: Range) -> BlockRow { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + cursor.seek(&range.start, Bias::Right); let mut longest_row = range.start; let mut longest_row_chars = 0; @@ -1457,7 +1453,7 @@ impl BlockSnapshot { let wrap_start_row = input_start.0 + overshoot; let wrap_end_row = cmp::min( input_start.0 + (range.end.0 - output_start.0), - cursor.end(&()).1.0, + cursor.end().1.0, ); let summary = self .wrap_snapshot @@ -1465,12 +1461,12 @@ impl BlockSnapshot { longest_row = BlockRow(range.start.0 + summary.longest_row); longest_row_chars = summary.longest_row_chars; } - cursor.next(&()); + cursor.next(); } let cursor_start_row = cursor.start().0; if range.end > cursor_start_row { - let summary = cursor.summary::<_, TransformSummary>(&range.end, Bias::Right, &()); + let summary = cursor.summary::<_, TransformSummary>(&range.end, Bias::Right); if summary.longest_row_chars > longest_row_chars { longest_row = BlockRow(cursor_start_row.0 + summary.longest_row); longest_row_chars = summary.longest_row_chars; @@ -1497,7 +1493,7 @@ impl BlockSnapshot { pub(super) fn line_len(&self, row: BlockRow) -> u32 { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(row.0), Bias::Right, &()); + cursor.seek(&BlockRow(row.0), Bias::Right); if let Some(transform) = cursor.item() { let (output_start, input_start) = cursor.start(); let overshoot = row.0 - output_start.0; @@ -1515,13 +1511,13 @@ impl BlockSnapshot { pub(super) fn is_block_line(&self, row: BlockRow) -> bool { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&row, Bias::Right, &()); + cursor.seek(&row, Bias::Right); cursor.item().map_or(false, |t| t.block.is_some()) } pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&row, Bias::Right, &()); + cursor.seek(&row, Bias::Right); let Some(transform) = cursor.item() else { return false; }; @@ -1533,7 +1529,7 @@ impl BlockSnapshot { .wrap_snapshot .make_wrap_point(Point::new(row.0, 0), Bias::Left); let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); - cursor.seek(&WrapRow(wrap_point.row()), Bias::Right, &()); + cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); cursor.item().map_or(false, |transform| { transform .block @@ -1544,17 +1540,17 @@ impl BlockSnapshot { pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(point.row), Bias::Right, &()); + cursor.seek(&BlockRow(point.row), Bias::Right); let max_input_row = WrapRow(self.transforms.summary().input_rows); let mut search_left = - (bias == Bias::Left && cursor.start().1.0 > 0) || cursor.end(&()).1 == max_input_row; + (bias == Bias::Left && cursor.start().1.0 > 0) || cursor.end().1 == max_input_row; let mut reversed = false; loop { if let Some(transform) = cursor.item() { let (output_start_row, input_start_row) = cursor.start(); - let (output_end_row, input_end_row) = cursor.end(&()); + let (output_end_row, input_end_row) = cursor.end(); let output_start = Point::new(output_start_row.0, 0); let input_start = Point::new(input_start_row.0, 0); let input_end = Point::new(input_end_row.0, 0); @@ -1588,23 +1584,23 @@ impl BlockSnapshot { } if search_left { - cursor.prev(&()); + cursor.prev(); } else { - cursor.next(&()); + cursor.next(); } } else if reversed { return self.max_point(); } else { reversed = true; search_left = !search_left; - cursor.seek(&BlockRow(point.row), Bias::Right, &()); + cursor.seek(&BlockRow(point.row), Bias::Right); } } } pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint { let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>(&()); - cursor.seek(&WrapRow(wrap_point.row()), Bias::Right, &()); + cursor.seek(&WrapRow(wrap_point.row()), Bias::Right); if let Some(transform) = cursor.item() { if transform.block.is_some() { BlockPoint::new(cursor.start().1.0, 0) @@ -1622,7 +1618,7 @@ impl BlockSnapshot { pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint { let mut cursor = self.transforms.cursor::<(BlockRow, WrapRow)>(&()); - cursor.seek(&BlockRow(block_point.row), Bias::Right, &()); + cursor.seek(&BlockRow(block_point.row), Bias::Right); if let Some(transform) = cursor.item() { match transform.block.as_ref() { Some(block) => { @@ -1634,7 +1630,7 @@ impl BlockSnapshot { } else if bias == Bias::Left { WrapPoint::new(cursor.start().1.0, 0) } else { - let wrap_row = cursor.end(&()).1.0 - 1; + let wrap_row = cursor.end().1.0 - 1; WrapPoint::new(wrap_row, self.wrap_snapshot.line_len(wrap_row)) } } @@ -1654,14 +1650,14 @@ impl BlockChunks<'_> { /// Go to the next transform fn advance(&mut self) { self.input_chunk = Chunk::default(); - self.transforms.next(&()); + self.transforms.next(); while let Some(transform) = self.transforms.item() { if transform .block .as_ref() .map_or(false, |block| block.height() == 0) { - self.transforms.next(&()); + self.transforms.next(); } else { break; } @@ -1676,7 +1672,7 @@ impl BlockChunks<'_> { let start_output_row = self.transforms.start().0.0; if start_output_row < self.max_output_row { let end_input_row = cmp::min( - self.transforms.end(&()).1.0, + self.transforms.end().1.0, start_input_row + (self.max_output_row - start_output_row), ); self.input_chunks.seek(start_input_row..end_input_row); @@ -1700,7 +1696,7 @@ impl<'a> Iterator for BlockChunks<'a> { let transform = self.transforms.item()?; if transform.block.is_some() { let block_start = self.transforms.start().0.0; - let mut block_end = self.transforms.end(&()).0.0; + let mut block_end = self.transforms.end().0.0; self.advance(); if self.transforms.item().is_none() { block_end -= 1; @@ -1735,7 +1731,7 @@ impl<'a> Iterator for BlockChunks<'a> { } } - let transform_end = self.transforms.end(&()).0.0; + let transform_end = self.transforms.end().0.0; let (prefix_rows, prefix_bytes) = offset_for_row(self.input_chunk.text, transform_end - self.output_row); self.output_row += prefix_rows; @@ -1774,15 +1770,15 @@ impl Iterator for BlockRows<'_> { self.started = true; } - if self.output_row.0 >= self.transforms.end(&()).0.0 { - self.transforms.next(&()); + if self.output_row.0 >= self.transforms.end().0.0 { + self.transforms.next(); while let Some(transform) = self.transforms.item() { if transform .block .as_ref() .map_or(false, |block| block.height() == 0) { - self.transforms.next(&()); + self.transforms.next(); } else { break; } @@ -1976,7 +1972,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -1984,7 +1979,6 @@ mod tests { height: Some(2), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -1992,7 +1986,6 @@ mod tests { height: Some(3), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); @@ -2217,7 +2210,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2225,7 +2217,6 @@ mod tests { height: Some(2), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2233,7 +2224,6 @@ mod tests { height: Some(3), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); @@ -2322,7 +2312,6 @@ mod tests { render: Arc::new(|_| div().into_any()), height: Some(1), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2330,7 +2319,6 @@ mod tests { render: Arc::new(|_| div().into_any()), height: Some(1), priority: 0, - render_in_minimap: true, }, ]); @@ -2370,7 +2358,6 @@ mod tests { height: Some(4), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }])[0]; let blocks_snapshot = block_map.read(wraps_snapshot, Default::default()); @@ -2424,7 +2411,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2432,7 +2418,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2440,7 +2425,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); @@ -2455,7 +2439,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2463,7 +2446,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2471,7 +2453,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); @@ -2571,7 +2552,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2579,7 +2559,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2587,7 +2566,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); let excerpt_blocks_3 = writer.insert(vec![ @@ -2597,7 +2575,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, BlockProperties { style: BlockStyle::Fixed, @@ -2605,7 +2582,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }, ]); @@ -2653,7 +2629,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }]); let blocks_snapshot = block_map.read(wrap_snapshot.clone(), Patch::default()); let blocks = blocks_snapshot @@ -3011,7 +2986,6 @@ mod tests { height: Some(height), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, } }) .collect::>(); @@ -3032,7 +3006,6 @@ mod tests { style: props.style, render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, })); for (block_properties, block_id) in block_properties.iter().zip(block_ids) { @@ -3557,7 +3530,6 @@ mod tests { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }])[0]; let blocks_snapshot = block_map.read(wraps_snapshot.clone(), Default::default()); diff --git a/crates/editor/src/display_map/crease_map.rs b/crates/editor/src/display_map/crease_map.rs index e6fe4270ec..bdac982fa7 100644 --- a/crates/editor/src/display_map/crease_map.rs +++ b/crates/editor/src/display_map/crease_map.rs @@ -52,15 +52,15 @@ impl CreaseSnapshot { ) -> Option<&'a Crease> { let start = snapshot.anchor_before(Point::new(row.0, 0)); let mut cursor = self.creases.cursor::(snapshot); - cursor.seek(&start, Bias::Left, snapshot); + cursor.seek(&start, Bias::Left); while let Some(item) = cursor.item() { match Ord::cmp(&item.crease.range().start.to_point(snapshot).row, &row.0) { - Ordering::Less => cursor.next(snapshot), + Ordering::Less => cursor.next(), Ordering::Equal => { if item.crease.range().start.is_valid(snapshot) { return Some(&item.crease); } else { - cursor.next(snapshot); + cursor.next(); } } Ordering::Greater => break, @@ -76,11 +76,11 @@ impl CreaseSnapshot { ) -> impl 'a + Iterator> { let start = snapshot.anchor_before(Point::new(range.start.0, 0)); let mut cursor = self.creases.cursor::(snapshot); - cursor.seek(&start, Bias::Left, snapshot); + cursor.seek(&start, Bias::Left); std::iter::from_fn(move || { while let Some(item) = cursor.item() { - cursor.next(snapshot); + cursor.next(); let crease_range = item.crease.range(); let crease_start = crease_range.start.to_point(snapshot); let crease_end = crease_range.end.to_point(snapshot); @@ -102,13 +102,13 @@ impl CreaseSnapshot { let mut cursor = self.creases.cursor::(snapshot); let mut results = Vec::new(); - cursor.next(snapshot); + cursor.next(); while let Some(item) = cursor.item() { let crease_range = item.crease.range(); let start_point = crease_range.start.to_point(snapshot); let end_point = crease_range.end.to_point(snapshot); results.push((item.id, start_point..end_point)); - cursor.next(snapshot); + cursor.next(); } results @@ -298,7 +298,7 @@ impl CreaseMap { let mut cursor = self.snapshot.creases.cursor::(snapshot); for crease in creases { let crease_range = crease.range().clone(); - new_creases.append(cursor.slice(&crease_range, Bias::Left, snapshot), snapshot); + new_creases.append(cursor.slice(&crease_range, Bias::Left), snapshot); let id = self.next_id; self.next_id.0 += 1; @@ -306,7 +306,7 @@ impl CreaseMap { new_creases.push(CreaseItem { crease, id }, snapshot); new_ids.push(id); } - new_creases.append(cursor.suffix(snapshot), snapshot); + new_creases.append(cursor.suffix(), snapshot); new_creases }; new_ids @@ -332,9 +332,9 @@ impl CreaseMap { let mut cursor = self.snapshot.creases.cursor::(snapshot); for (id, range) in &removals { - new_creases.append(cursor.slice(range, Bias::Left, snapshot), snapshot); + new_creases.append(cursor.slice(range, Bias::Left), snapshot); while let Some(item) = cursor.item() { - cursor.next(snapshot); + cursor.next(); if item.id == *id { break; } else { @@ -343,7 +343,7 @@ impl CreaseMap { } } - new_creases.append(cursor.suffix(snapshot), snapshot); + new_creases.append(cursor.suffix(), snapshot); new_creases }; diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index f37e7063e7..829d34ff58 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -99,7 +99,7 @@ impl FoldPoint { pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint { let mut cursor = snapshot.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&self, Bias::Right, &()); + cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayPoint(cursor.start().1.0 + overshoot) } @@ -108,7 +108,7 @@ impl FoldPoint { let mut cursor = snapshot .transforms .cursor::<(FoldPoint, TransformSummary)>(&()); - cursor.seek(&self, Bias::Right, &()); + cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().1.output.lines; let mut offset = cursor.start().1.output.len; if !overshoot.is_zero() { @@ -187,10 +187,10 @@ impl FoldMapWriter<'_> { width: None, }, ); - new_tree.append(cursor.slice(&fold.range, Bias::Right, buffer), buffer); + new_tree.append(cursor.slice(&fold.range, Bias::Right), buffer); new_tree.push(fold, buffer); } - new_tree.append(cursor.suffix(buffer), buffer); + new_tree.append(cursor.suffix(), buffer); new_tree }; @@ -252,7 +252,7 @@ impl FoldMapWriter<'_> { fold_ixs_to_delete.push(*folds_cursor.start()); self.0.snapshot.fold_metadata_by_id.remove(&fold.id); } - folds_cursor.next(buffer); + folds_cursor.next(); } } @@ -263,10 +263,10 @@ impl FoldMapWriter<'_> { let mut cursor = self.0.snapshot.folds.cursor::(buffer); let mut folds = SumTree::new(buffer); for fold_ix in fold_ixs_to_delete { - folds.append(cursor.slice(&fold_ix, Bias::Right, buffer), buffer); - cursor.next(buffer); + folds.append(cursor.slice(&fold_ix, Bias::Right), buffer); + cursor.next(); } - folds.append(cursor.suffix(buffer), buffer); + folds.append(cursor.suffix(), buffer); folds }; @@ -412,7 +412,7 @@ impl FoldMap { let mut new_transforms = SumTree::::default(); let mut cursor = self.snapshot.transforms.cursor::(&()); - cursor.seek(&InlayOffset(0), Bias::Right, &()); + cursor.seek(&InlayOffset(0), Bias::Right); while let Some(mut edit) = inlay_edits_iter.next() { if let Some(item) = cursor.item() { @@ -421,19 +421,19 @@ impl FoldMap { |transform| { if !transform.is_fold() { transform.summary.add_summary(&item.summary, &()); - cursor.next(&()); + cursor.next(); } }, &(), ); } } - new_transforms.append(cursor.slice(&edit.old.start, Bias::Left, &()), &()); + new_transforms.append(cursor.slice(&edit.old.start, Bias::Left), &()); edit.new.start -= edit.old.start - *cursor.start(); edit.old.start = *cursor.start(); - cursor.seek(&edit.old.end, Bias::Right, &()); - cursor.next(&()); + cursor.seek(&edit.old.end, Bias::Right); + cursor.next(); let mut delta = edit.new_len().0 as isize - edit.old_len().0 as isize; loop { @@ -449,8 +449,8 @@ impl FoldMap { if next_edit.old.end >= edit.old.end { edit.old.end = next_edit.old.end; - cursor.seek(&edit.old.end, Bias::Right, &()); - cursor.next(&()); + cursor.seek(&edit.old.end, Bias::Right); + cursor.next(); } } else { break; @@ -467,11 +467,7 @@ impl FoldMap { .snapshot .folds .cursor::(&inlay_snapshot.buffer); - folds_cursor.seek( - &FoldRange(anchor..Anchor::max()), - Bias::Left, - &inlay_snapshot.buffer, - ); + folds_cursor.seek(&FoldRange(anchor..Anchor::max()), Bias::Left); let mut folds = iter::from_fn({ let inlay_snapshot = &inlay_snapshot; @@ -485,7 +481,7 @@ impl FoldMap { ..inlay_snapshot.to_inlay_offset(buffer_end), ) }); - folds_cursor.next(&inlay_snapshot.buffer); + folds_cursor.next(); item } }) @@ -558,7 +554,7 @@ impl FoldMap { } } - new_transforms.append(cursor.suffix(&()), &()); + new_transforms.append(cursor.suffix(), &()); if new_transforms.is_empty() { let text_summary = inlay_snapshot.text_summary(); push_isomorphic(&mut new_transforms, text_summary); @@ -575,31 +571,31 @@ impl FoldMap { let mut new_transforms = new_transforms.cursor::<(InlayOffset, FoldOffset)>(&()); for mut edit in inlay_edits { - old_transforms.seek(&edit.old.start, Bias::Left, &()); + old_transforms.seek(&edit.old.start, Bias::Left); if old_transforms.item().map_or(false, |t| t.is_fold()) { edit.old.start = old_transforms.start().0; } let old_start = old_transforms.start().1.0 + (edit.old.start - old_transforms.start().0).0; - old_transforms.seek_forward(&edit.old.end, Bias::Right, &()); + old_transforms.seek_forward(&edit.old.end, Bias::Right); if old_transforms.item().map_or(false, |t| t.is_fold()) { - old_transforms.next(&()); + old_transforms.next(); edit.old.end = old_transforms.start().0; } let old_end = old_transforms.start().1.0 + (edit.old.end - old_transforms.start().0).0; - new_transforms.seek(&edit.new.start, Bias::Left, &()); + new_transforms.seek(&edit.new.start, Bias::Left); if new_transforms.item().map_or(false, |t| t.is_fold()) { edit.new.start = new_transforms.start().0; } let new_start = new_transforms.start().1.0 + (edit.new.start - new_transforms.start().0).0; - new_transforms.seek_forward(&edit.new.end, Bias::Right, &()); + new_transforms.seek_forward(&edit.new.end, Bias::Right); if new_transforms.item().map_or(false, |t| t.is_fold()) { - new_transforms.next(&()); + new_transforms.next(); edit.new.end = new_transforms.start().0; } let new_end = @@ -656,10 +652,10 @@ impl FoldSnapshot { let mut summary = TextSummary::default(); let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + cursor.seek(&range.start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = range.start.0 - cursor.start().0.0; - let end_in_transform = cmp::min(range.end, cursor.end(&()).0).0 - cursor.start().0.0; + let end_in_transform = cmp::min(range.end, cursor.end().0).0 - cursor.start().0.0; if let Some(placeholder) = transform.placeholder.as_ref() { summary = TextSummary::from( &placeholder.text @@ -678,10 +674,10 @@ impl FoldSnapshot { } } - if range.end > cursor.end(&()).0 { - cursor.next(&()); + if range.end > cursor.end().0 { + cursor.next(); summary += &cursor - .summary::<_, TransformSummary>(&range.end, Bias::Right, &()) + .summary::<_, TransformSummary>(&range.end, Bias::Right) .output; if let Some(transform) = cursor.item() { let end_in_transform = range.end.0 - cursor.start().0.0; @@ -705,19 +701,16 @@ impl FoldSnapshot { pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint { let mut cursor = self.transforms.cursor::<(InlayPoint, FoldPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); if cursor.item().map_or(false, |t| t.is_fold()) { if bias == Bias::Left || point == cursor.start().0 { cursor.start().1 } else { - cursor.end(&()).1 + cursor.end().1 } } else { let overshoot = point.0 - cursor.start().0.0; - FoldPoint(cmp::min( - cursor.start().1.0 + overshoot, - cursor.end(&()).1.0, - )) + FoldPoint(cmp::min(cursor.start().1.0 + overshoot, cursor.end().1.0)) } } @@ -742,7 +735,7 @@ impl FoldSnapshot { let fold_point = FoldPoint::new(start_row, 0); let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&fold_point, Bias::Left, &()); + cursor.seek(&fold_point, Bias::Left); let overshoot = fold_point.0 - cursor.start().0.0; let inlay_point = InlayPoint(cursor.start().1.0 + overshoot); @@ -773,7 +766,7 @@ impl FoldSnapshot { let mut folds = intersecting_folds(&self.inlay_snapshot, &self.folds, range, false); iter::from_fn(move || { let item = folds.item(); - folds.next(&self.inlay_snapshot.buffer); + folds.next(); item }) } @@ -785,7 +778,7 @@ impl FoldSnapshot { let buffer_offset = offset.to_offset(&self.inlay_snapshot.buffer); let inlay_offset = self.inlay_snapshot.to_inlay_offset(buffer_offset); let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&inlay_offset, Bias::Right, &()); + cursor.seek(&inlay_offset, Bias::Right); cursor.item().map_or(false, |t| t.placeholder.is_some()) } @@ -794,7 +787,7 @@ impl FoldSnapshot { .inlay_snapshot .to_inlay_point(Point::new(buffer_row.0, 0)); let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&inlay_point, Bias::Right, &()); + cursor.seek(&inlay_point, Bias::Right); loop { match cursor.item() { Some(transform) => { @@ -808,11 +801,11 @@ impl FoldSnapshot { None => return false, } - if cursor.end(&()).row() == inlay_point.row() { - cursor.next(&()); + if cursor.end().row() == inlay_point.row() { + cursor.next(); } else { inlay_point.0 += Point::new(1, 0); - cursor.seek(&inlay_point, Bias::Right, &()); + cursor.seek(&inlay_point, Bias::Right); } } } @@ -824,14 +817,14 @@ impl FoldSnapshot { highlights: Highlights<'a>, ) -> FoldChunks<'a> { let mut transform_cursor = self.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); - transform_cursor.seek(&range.start, Bias::Right, &()); + transform_cursor.seek(&range.start, Bias::Right); let inlay_start = { let overshoot = range.start.0 - transform_cursor.start().0.0; transform_cursor.start().1 + InlayOffset(overshoot) }; - let transform_end = transform_cursor.end(&()); + let transform_end = transform_cursor.end(); let inlay_end = if transform_cursor .item() @@ -879,14 +872,14 @@ impl FoldSnapshot { pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint { let mut cursor = self.transforms.cursor::<(FoldPoint, InlayPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); if let Some(transform) = cursor.item() { let transform_start = cursor.start().0.0; if transform.placeholder.is_some() { if point.0 == transform_start || matches!(bias, Bias::Left) { FoldPoint(transform_start) } else { - FoldPoint(cursor.end(&()).0.0) + FoldPoint(cursor.end().0.0) } } else { let overshoot = InlayPoint(point.0 - transform_start); @@ -945,7 +938,7 @@ fn intersecting_folds<'a>( start_cmp == Ordering::Less && end_cmp == Ordering::Greater } }); - cursor.next(buffer); + cursor.next(); cursor } @@ -1211,7 +1204,7 @@ pub struct FoldRows<'a> { impl FoldRows<'_> { pub(crate) fn seek(&mut self, row: u32) { let fold_point = FoldPoint::new(row, 0); - self.cursor.seek(&fold_point, Bias::Left, &()); + self.cursor.seek(&fold_point, Bias::Left); let overshoot = fold_point.0 - self.cursor.start().0.0; let inlay_point = InlayPoint(self.cursor.start().1.0 + overshoot); self.input_rows.seek(inlay_point.row()); @@ -1224,8 +1217,8 @@ impl Iterator for FoldRows<'_> { fn next(&mut self) -> Option { let mut traversed_fold = false; - while self.fold_point > self.cursor.end(&()).0 { - self.cursor.next(&()); + while self.fold_point > self.cursor.end().0 { + self.cursor.next(); traversed_fold = true; if self.cursor.item().is_none() { break; @@ -1330,14 +1323,14 @@ pub struct FoldChunks<'a> { impl FoldChunks<'_> { pub(crate) fn seek(&mut self, range: Range) { - self.transform_cursor.seek(&range.start, Bias::Right, &()); + self.transform_cursor.seek(&range.start, Bias::Right); let inlay_start = { let overshoot = range.start.0 - self.transform_cursor.start().0.0; self.transform_cursor.start().1 + InlayOffset(overshoot) }; - let transform_end = self.transform_cursor.end(&()); + let transform_end = self.transform_cursor.end(); let inlay_end = if self .transform_cursor @@ -1376,10 +1369,10 @@ impl<'a> Iterator for FoldChunks<'a> { self.inlay_chunk.take(); self.inlay_offset += InlayOffset(transform.summary.input.len); - while self.inlay_offset >= self.transform_cursor.end(&()).1 + while self.inlay_offset >= self.transform_cursor.end().1 && self.transform_cursor.item().is_some() { - self.transform_cursor.next(&()); + self.transform_cursor.next(); } self.output_offset.0 += placeholder.text.len(); @@ -1396,7 +1389,7 @@ impl<'a> Iterator for FoldChunks<'a> { && self.inlay_chunks.offset() != self.inlay_offset { let transform_start = self.transform_cursor.start(); - let transform_end = self.transform_cursor.end(&()); + let transform_end = self.transform_cursor.end(); let inlay_end = if self.max_output_offset < transform_end.0 { let overshoot = self.max_output_offset.0 - transform_start.0.0; transform_start.1 + InlayOffset(overshoot) @@ -1417,14 +1410,14 @@ impl<'a> Iterator for FoldChunks<'a> { if let Some((buffer_chunk_start, mut inlay_chunk)) = self.inlay_chunk.clone() { let chunk = &mut inlay_chunk.chunk; let buffer_chunk_end = buffer_chunk_start + InlayOffset(chunk.text.len()); - let transform_end = self.transform_cursor.end(&()).1; + let transform_end = self.transform_cursor.end().1; let chunk_end = buffer_chunk_end.min(transform_end); chunk.text = &chunk.text [(self.inlay_offset - buffer_chunk_start).0..(chunk_end - buffer_chunk_start).0]; if chunk_end == transform_end { - self.transform_cursor.next(&()); + self.transform_cursor.next(); } else if chunk_end == buffer_chunk_end { self.inlay_chunk.take(); } @@ -1456,7 +1449,7 @@ impl FoldOffset { let mut cursor = snapshot .transforms .cursor::<(FoldOffset, TransformSummary)>(&()); - cursor.seek(&self, Bias::Right, &()); + cursor.seek(&self, Bias::Right); let overshoot = if cursor.item().map_or(true, |t| t.is_fold()) { Point::new(0, (self.0 - cursor.start().0.0) as u32) } else { @@ -1470,7 +1463,7 @@ impl FoldOffset { #[cfg(test)] pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset { let mut cursor = snapshot.transforms.cursor::<(FoldOffset, InlayOffset)>(&()); - cursor.seek(&self, Bias::Right, &()); + cursor.seek(&self, Bias::Right); let overshoot = self.0 - cursor.start().0.0; InlayOffset(cursor.start().1.0 + overshoot) } diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index f7a696860a..a36d18ff6d 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -263,7 +263,7 @@ pub struct InlayChunk<'a> { impl InlayChunks<'_> { pub fn seek(&mut self, new_range: Range) { - self.transforms.seek(&new_range.start, Bias::Right, &()); + self.transforms.seek(&new_range.start, Bias::Right); let buffer_range = self.snapshot.to_buffer_offset(new_range.start) ..self.snapshot.to_buffer_offset(new_range.end); @@ -296,12 +296,12 @@ impl<'a> Iterator for InlayChunks<'a> { *chunk = self.buffer_chunks.next().unwrap(); } - let desired_bytes = self.transforms.end(&()).0.0 - self.output_offset.0; + let desired_bytes = self.transforms.end().0.0 - self.output_offset.0; // If we're already at the transform boundary, skip to the next transform if desired_bytes == 0 { self.inlay_chunks = None; - self.transforms.next(&()); + self.transforms.next(); return self.next(); } @@ -397,7 +397,7 @@ impl<'a> Iterator for InlayChunks<'a> { let inlay_chunks = self.inlay_chunks.get_or_insert_with(|| { let start = offset_in_inlay; - let end = cmp::min(self.max_output_offset, self.transforms.end(&()).0) + let end = cmp::min(self.max_output_offset, self.transforms.end().0) - self.transforms.start().0; inlay.text.chunks_in_range(start.0..end.0) }); @@ -441,9 +441,9 @@ impl<'a> Iterator for InlayChunks<'a> { } }; - if self.output_offset >= self.transforms.end(&()).0 { + if self.output_offset >= self.transforms.end().0 { self.inlay_chunks = None; - self.transforms.next(&()); + self.transforms.next(); } Some(chunk) @@ -453,7 +453,7 @@ impl<'a> Iterator for InlayChunks<'a> { impl InlayBufferRows<'_> { pub fn seek(&mut self, row: u32) { let inlay_point = InlayPoint::new(row, 0); - self.transforms.seek(&inlay_point, Bias::Left, &()); + self.transforms.seek(&inlay_point, Bias::Left); let mut buffer_point = self.transforms.start().1; let buffer_row = MultiBufferRow(if row == 0 { @@ -487,7 +487,7 @@ impl Iterator for InlayBufferRows<'_> { self.inlay_row += 1; self.transforms - .seek_forward(&InlayPoint::new(self.inlay_row, 0), Bias::Left, &()); + .seek_forward(&InlayPoint::new(self.inlay_row, 0), Bias::Left); Some(buffer_row) } @@ -556,18 +556,18 @@ impl InlayMap { let mut cursor = snapshot.transforms.cursor::<(usize, InlayOffset)>(&()); let mut buffer_edits_iter = buffer_edits.iter().peekable(); while let Some(buffer_edit) = buffer_edits_iter.next() { - new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left, &()), &()); + new_transforms.append(cursor.slice(&buffer_edit.old.start, Bias::Left), &()); if let Some(Transform::Isomorphic(transform)) = cursor.item() { - if cursor.end(&()).0 == buffer_edit.old.start { + if cursor.end().0 == buffer_edit.old.start { push_isomorphic(&mut new_transforms, *transform); - cursor.next(&()); + cursor.next(); } } // Remove all the inlays and transforms contained by the edit. let old_start = cursor.start().1 + InlayOffset(buffer_edit.old.start - cursor.start().0); - cursor.seek(&buffer_edit.old.end, Bias::Right, &()); + cursor.seek(&buffer_edit.old.end, Bias::Right); let old_end = cursor.start().1 + InlayOffset(buffer_edit.old.end - cursor.start().0); @@ -625,20 +625,20 @@ impl InlayMap { // we can push its remainder. if buffer_edits_iter .peek() - .map_or(true, |edit| edit.old.start >= cursor.end(&()).0) + .map_or(true, |edit| edit.old.start >= cursor.end().0) { let transform_start = new_transforms.summary().input.len; let transform_end = - buffer_edit.new.end + (cursor.end(&()).0 - buffer_edit.old.end); + buffer_edit.new.end + (cursor.end().0 - buffer_edit.old.end); push_isomorphic( &mut new_transforms, buffer_snapshot.text_summary_for_range(transform_start..transform_end), ); - cursor.next(&()); + cursor.next(); } } - new_transforms.append(cursor.suffix(&()), &()); + new_transforms.append(cursor.suffix(), &()); if new_transforms.is_empty() { new_transforms.push(Transform::Isomorphic(Default::default()), &()); } @@ -773,7 +773,7 @@ impl InlaySnapshot { let mut cursor = self .transforms .cursor::<(InlayOffset, (InlayPoint, usize))>(&()); - cursor.seek(&offset, Bias::Right, &()); + cursor.seek(&offset, Bias::Right); let overshoot = offset.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { @@ -803,7 +803,7 @@ impl InlaySnapshot { let mut cursor = self .transforms .cursor::<(InlayPoint, (InlayOffset, Point))>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); let overshoot = point.0 - cursor.start().0.0; match cursor.item() { Some(Transform::Isomorphic(_)) => { @@ -822,7 +822,7 @@ impl InlaySnapshot { } pub fn to_buffer_point(&self, point: InlayPoint) -> Point { let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { let overshoot = point.0 - cursor.start().0.0; @@ -834,7 +834,7 @@ impl InlaySnapshot { } pub fn to_buffer_offset(&self, offset: InlayOffset) -> usize { let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); - cursor.seek(&offset, Bias::Right, &()); + cursor.seek(&offset, Bias::Right); match cursor.item() { Some(Transform::Isomorphic(_)) => { let overshoot = offset - cursor.start().0; @@ -847,19 +847,19 @@ impl InlaySnapshot { pub fn to_inlay_offset(&self, offset: usize) -> InlayOffset { let mut cursor = self.transforms.cursor::<(usize, InlayOffset)>(&()); - cursor.seek(&offset, Bias::Left, &()); + cursor.seek(&offset, Bias::Left); loop { match cursor.item() { Some(Transform::Isomorphic(_)) => { - if offset == cursor.end(&()).0 { + if offset == cursor.end().0 { while let Some(Transform::Inlay(inlay)) = cursor.next_item() { if inlay.position.bias() == Bias::Right { break; } else { - cursor.next(&()); + cursor.next(); } } - return cursor.end(&()).1; + return cursor.end().1; } else { let overshoot = offset - cursor.start().0; return InlayOffset(cursor.start().1.0 + overshoot); @@ -867,7 +867,7 @@ impl InlaySnapshot { } Some(Transform::Inlay(inlay)) => { if inlay.position.bias() == Bias::Left { - cursor.next(&()); + cursor.next(); } else { return cursor.start().1; } @@ -880,19 +880,19 @@ impl InlaySnapshot { } pub fn to_inlay_point(&self, point: Point) -> InlayPoint { let mut cursor = self.transforms.cursor::<(Point, InlayPoint)>(&()); - cursor.seek(&point, Bias::Left, &()); + cursor.seek(&point, Bias::Left); loop { match cursor.item() { Some(Transform::Isomorphic(_)) => { - if point == cursor.end(&()).0 { + if point == cursor.end().0 { while let Some(Transform::Inlay(inlay)) = cursor.next_item() { if inlay.position.bias() == Bias::Right { break; } else { - cursor.next(&()); + cursor.next(); } } - return cursor.end(&()).1; + return cursor.end().1; } else { let overshoot = point - cursor.start().0; return InlayPoint(cursor.start().1.0 + overshoot); @@ -900,7 +900,7 @@ impl InlaySnapshot { } Some(Transform::Inlay(inlay)) => { if inlay.position.bias() == Bias::Left { - cursor.next(&()); + cursor.next(); } else { return cursor.start().1; } @@ -914,7 +914,7 @@ impl InlaySnapshot { pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint { let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); - cursor.seek(&point, Bias::Left, &()); + cursor.seek(&point, Bias::Left); loop { match cursor.item() { Some(Transform::Isomorphic(transform)) => { @@ -923,7 +923,7 @@ impl InlaySnapshot { if inlay.position.bias() == Bias::Left { return point; } else if bias == Bias::Left { - cursor.prev(&()); + cursor.prev(); } else if transform.first_line_chars == 0 { point.0 += Point::new(1, 0); } else { @@ -932,12 +932,12 @@ impl InlaySnapshot { } else { return point; } - } else if cursor.end(&()).0 == point { + } else if cursor.end().0 == point { if let Some(Transform::Inlay(inlay)) = cursor.next_item() { if inlay.position.bias() == Bias::Right { return point; } else if bias == Bias::Right { - cursor.next(&()); + cursor.next(); } else if point.0.column == 0 { point.0.row -= 1; point.0.column = self.line_len(point.0.row); @@ -970,7 +970,7 @@ impl InlaySnapshot { } _ => return point, } - } else if point == cursor.end(&()).0 && inlay.position.bias() == Bias::Left { + } else if point == cursor.end().0 && inlay.position.bias() == Bias::Left { match cursor.next_item() { Some(Transform::Inlay(inlay)) => { if inlay.position.bias() == Bias::Right { @@ -983,9 +983,9 @@ impl InlaySnapshot { if bias == Bias::Left { point = cursor.start().0; - cursor.prev(&()); + cursor.prev(); } else { - cursor.next(&()); + cursor.next(); point = cursor.start().0; } } @@ -993,9 +993,9 @@ impl InlaySnapshot { bias = bias.invert(); if bias == Bias::Left { point = cursor.start().0; - cursor.prev(&()); + cursor.prev(); } else { - cursor.next(&()); + cursor.next(); point = cursor.start().0; } } @@ -1011,7 +1011,7 @@ impl InlaySnapshot { let mut summary = TextSummary::default(); let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + cursor.seek(&range.start, Bias::Right); let overshoot = range.start.0 - cursor.start().0.0; match cursor.item() { @@ -1019,22 +1019,22 @@ impl InlaySnapshot { let buffer_start = cursor.start().1; let suffix_start = buffer_start + overshoot; let suffix_end = - buffer_start + (cmp::min(cursor.end(&()).0, range.end).0 - cursor.start().0.0); + buffer_start + (cmp::min(cursor.end().0, range.end).0 - cursor.start().0.0); summary = self.buffer.text_summary_for_range(suffix_start..suffix_end); - cursor.next(&()); + cursor.next(); } Some(Transform::Inlay(inlay)) => { let suffix_start = overshoot; - let suffix_end = cmp::min(cursor.end(&()).0, range.end).0 - cursor.start().0.0; + let suffix_end = cmp::min(cursor.end().0, range.end).0 - cursor.start().0.0; summary = inlay.text.cursor(suffix_start).summary(suffix_end); - cursor.next(&()); + cursor.next(); } None => {} } if range.end > cursor.start().0 { summary += cursor - .summary::<_, TransformSummary>(&range.end, Bias::Right, &()) + .summary::<_, TransformSummary>(&range.end, Bias::Right) .output; let overshoot = range.end.0 - cursor.start().0.0; @@ -1060,7 +1060,7 @@ impl InlaySnapshot { pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> { let mut cursor = self.transforms.cursor::<(InlayPoint, Point)>(&()); let inlay_point = InlayPoint::new(row, 0); - cursor.seek(&inlay_point, Bias::Left, &()); + cursor.seek(&inlay_point, Bias::Left); let max_buffer_row = self.buffer.max_row(); let mut buffer_point = cursor.start().1; @@ -1101,7 +1101,7 @@ impl InlaySnapshot { highlights: Highlights<'a>, ) -> InlayChunks<'a> { let mut cursor = self.transforms.cursor::<(InlayOffset, usize)>(&()); - cursor.seek(&range.start, Bias::Right, &()); + cursor.seek(&range.start, Bias::Right); let buffer_range = self.to_buffer_offset(range.start)..self.to_buffer_offset(range.end); let buffer_chunks = CustomHighlightsChunks::new( diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index a29bf53882..d55577826e 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -72,7 +72,7 @@ pub struct WrapRows<'a> { impl WrapRows<'_> { pub(crate) fn seek(&mut self, start_row: u32) { self.transforms - .seek(&WrapPoint::new(start_row, 0), Bias::Left, &()); + .seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = self.transforms.start().1.row(); if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { input_row += start_row - self.transforms.start().0.row(); @@ -340,7 +340,7 @@ impl WrapSnapshot { let mut tab_edits_iter = tab_edits.iter().peekable(); new_transforms = - old_cursor.slice(&tab_edits_iter.peek().unwrap().old.start, Bias::Right, &()); + old_cursor.slice(&tab_edits_iter.peek().unwrap().old.start, Bias::Right); while let Some(edit) = tab_edits_iter.next() { if edit.new.start > TabPoint::from(new_transforms.summary().input.lines) { @@ -356,31 +356,29 @@ impl WrapSnapshot { )); } - old_cursor.seek_forward(&edit.old.end, Bias::Right, &()); + old_cursor.seek_forward(&edit.old.end, Bias::Right); if let Some(next_edit) = tab_edits_iter.peek() { - if next_edit.old.start > old_cursor.end(&()) { - if old_cursor.end(&()) > edit.old.end { + if next_edit.old.start > old_cursor.end() { + if old_cursor.end() > edit.old.end { let summary = self .tab_snapshot - .text_summary_for_range(edit.old.end..old_cursor.end(&())); + .text_summary_for_range(edit.old.end..old_cursor.end()); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); - new_transforms.append( - old_cursor.slice(&next_edit.old.start, Bias::Right, &()), - &(), - ); + old_cursor.next(); + new_transforms + .append(old_cursor.slice(&next_edit.old.start, Bias::Right), &()); } } else { - if old_cursor.end(&()) > edit.old.end { + if old_cursor.end() > edit.old.end { let summary = self .tab_snapshot - .text_summary_for_range(edit.old.end..old_cursor.end(&())); + .text_summary_for_range(edit.old.end..old_cursor.end()); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); - new_transforms.append(old_cursor.suffix(&()), &()); + old_cursor.next(); + new_transforms.append(old_cursor.suffix(), &()); } } } @@ -441,7 +439,6 @@ impl WrapSnapshot { new_transforms = old_cursor.slice( &TabPoint::new(row_edits.peek().unwrap().old_rows.start, 0), Bias::Right, - &(), ); while let Some(edit) = row_edits.next() { @@ -516,34 +513,31 @@ impl WrapSnapshot { } new_transforms.extend(edit_transforms, &()); - old_cursor.seek_forward(&TabPoint::new(edit.old_rows.end, 0), Bias::Right, &()); + old_cursor.seek_forward(&TabPoint::new(edit.old_rows.end, 0), Bias::Right); if let Some(next_edit) = row_edits.peek() { - if next_edit.old_rows.start > old_cursor.end(&()).row() { - if old_cursor.end(&()) > TabPoint::new(edit.old_rows.end, 0) { + if next_edit.old_rows.start > old_cursor.end().row() { + if old_cursor.end() > TabPoint::new(edit.old_rows.end, 0) { let summary = self.tab_snapshot.text_summary_for_range( - TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(&()), + TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(), ); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); + old_cursor.next(); new_transforms.append( - old_cursor.slice( - &TabPoint::new(next_edit.old_rows.start, 0), - Bias::Right, - &(), - ), + old_cursor + .slice(&TabPoint::new(next_edit.old_rows.start, 0), Bias::Right), &(), ); } } else { - if old_cursor.end(&()) > TabPoint::new(edit.old_rows.end, 0) { + if old_cursor.end() > TabPoint::new(edit.old_rows.end, 0) { let summary = self.tab_snapshot.text_summary_for_range( - TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(&()), + TabPoint::new(edit.old_rows.end, 0)..old_cursor.end(), ); new_transforms.push_or_extend(Transform::isomorphic(summary)); } - old_cursor.next(&()); - new_transforms.append(old_cursor.suffix(&()), &()); + old_cursor.next(); + new_transforms.append(old_cursor.suffix(), &()); } } } @@ -570,19 +564,19 @@ impl WrapSnapshot { tab_edit.new.start.0.column = 0; tab_edit.new.end.0 += Point::new(1, 0); - old_cursor.seek(&tab_edit.old.start, Bias::Right, &()); + old_cursor.seek(&tab_edit.old.start, Bias::Right); let mut old_start = old_cursor.start().output.lines; old_start += tab_edit.old.start.0 - old_cursor.start().input.lines; - old_cursor.seek(&tab_edit.old.end, Bias::Right, &()); + old_cursor.seek(&tab_edit.old.end, Bias::Right); let mut old_end = old_cursor.start().output.lines; old_end += tab_edit.old.end.0 - old_cursor.start().input.lines; - new_cursor.seek(&tab_edit.new.start, Bias::Right, &()); + new_cursor.seek(&tab_edit.new.start, Bias::Right); let mut new_start = new_cursor.start().output.lines; new_start += tab_edit.new.start.0 - new_cursor.start().input.lines; - new_cursor.seek(&tab_edit.new.end, Bias::Right, &()); + new_cursor.seek(&tab_edit.new.end, Bias::Right); let mut new_end = new_cursor.start().output.lines; new_end += tab_edit.new.end.0 - new_cursor.start().input.lines; @@ -605,7 +599,7 @@ impl WrapSnapshot { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - transforms.seek(&output_start, Bias::Right, &()); + transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(transforms.start().1.0); if transforms.item().map_or(false, |t| t.is_isomorphic()) { input_start.0 += output_start.0 - transforms.start().0.0; @@ -633,7 +627,7 @@ impl WrapSnapshot { pub fn line_len(&self, row: u32) -> u32 { let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left, &()); + cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Left); if cursor .item() .map_or(false, |transform| transform.is_isomorphic()) @@ -658,10 +652,10 @@ impl WrapSnapshot { let end = WrapPoint::new(rows.end, 0); let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&start, Bias::Right, &()); + cursor.seek(&start, Bias::Right); if let Some(transform) = cursor.item() { let start_in_transform = start.0 - cursor.start().0.0; - let end_in_transform = cmp::min(end, cursor.end(&()).0).0 - cursor.start().0.0; + let end_in_transform = cmp::min(end, cursor.end().0).0 - cursor.start().0.0; if transform.is_isomorphic() { let tab_start = TabPoint(cursor.start().1.0 + start_in_transform); let tab_end = TabPoint(cursor.start().1.0 + end_in_transform); @@ -678,12 +672,12 @@ impl WrapSnapshot { }; } - cursor.next(&()); + cursor.next(); } if rows.end > cursor.start().0.row() { summary += &cursor - .summary::<_, TransformSummary>(&WrapPoint::new(rows.end, 0), Bias::Right, &()) + .summary::<_, TransformSummary>(&WrapPoint::new(rows.end, 0), Bias::Right) .output; if let Some(transform) = cursor.item() { @@ -712,7 +706,7 @@ impl WrapSnapshot { pub fn soft_wrap_indent(&self, row: u32) -> Option { let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Right, &()); + cursor.seek(&WrapPoint::new(row + 1, 0), Bias::Right); cursor.item().and_then(|transform| { if transform.is_isomorphic() { None @@ -728,7 +722,7 @@ impl WrapSnapshot { pub fn row_infos(&self, start_row: u32) -> WrapRows<'_> { let mut transforms = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left, &()); + transforms.seek(&WrapPoint::new(start_row, 0), Bias::Left); let mut input_row = transforms.start().1.row(); if transforms.item().map_or(false, |t| t.is_isomorphic()) { input_row += start_row - transforms.start().0.row(); @@ -748,7 +742,7 @@ impl WrapSnapshot { pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint { let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); let mut tab_point = cursor.start().1.0; if cursor.item().map_or(false, |t| t.is_isomorphic()) { tab_point += point.0 - cursor.start().0.0; @@ -766,14 +760,14 @@ impl WrapSnapshot { pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint { let mut cursor = self.transforms.cursor::<(TabPoint, WrapPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); WrapPoint(cursor.start().1.0 + (point.0 - cursor.start().0.0)) } pub fn clip_point(&self, mut point: WrapPoint, bias: Bias) -> WrapPoint { if bias == Bias::Left { let mut cursor = self.transforms.cursor::(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); if cursor.item().map_or(false, |t| !t.is_isomorphic()) { point = *cursor.start(); *point.column_mut() -= 1; @@ -791,16 +785,16 @@ impl WrapSnapshot { *point.column_mut() = 0; let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); if cursor.item().is_none() { - cursor.prev(&()); + cursor.prev(); } while let Some(transform) = cursor.item() { if transform.is_isomorphic() && cursor.start().1.column() == 0 { - return cmp::min(cursor.end(&()).0.row(), point.row()); + return cmp::min(cursor.end().0.row(), point.row()); } else { - cursor.prev(&()); + cursor.prev(); } } @@ -811,12 +805,12 @@ impl WrapSnapshot { point.0 += Point::new(1, 0); let mut cursor = self.transforms.cursor::<(WrapPoint, TabPoint)>(&()); - cursor.seek(&point, Bias::Right, &()); + cursor.seek(&point, Bias::Right); while let Some(transform) = cursor.item() { if transform.is_isomorphic() && cursor.start().1.column() == 0 { return Some(cmp::max(cursor.start().0.row(), point.row())); } else { - cursor.next(&()); + cursor.next(); } } @@ -889,7 +883,7 @@ impl WrapChunks<'_> { pub(crate) fn seek(&mut self, rows: Range) { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); - self.transforms.seek(&output_start, Bias::Right, &()); + self.transforms.seek(&output_start, Bias::Right); let mut input_start = TabPoint(self.transforms.start().1.0); if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { input_start.0 += output_start.0 - self.transforms.start().0.0; @@ -930,7 +924,7 @@ impl<'a> Iterator for WrapChunks<'a> { } self.output_position.0 += summary; - self.transforms.next(&()); + self.transforms.next(); return Some(Chunk { text: &display_text[start_ix..end_ix], ..Default::default() @@ -942,7 +936,7 @@ impl<'a> Iterator for WrapChunks<'a> { } let mut input_len = 0; - let transform_end = self.transforms.end(&()).0; + let transform_end = self.transforms.end().0; for c in self.input_chunk.text.chars() { let char_len = c.len_utf8(); input_len += char_len; @@ -954,7 +948,7 @@ impl<'a> Iterator for WrapChunks<'a> { } if self.output_position >= transform_end { - self.transforms.next(&()); + self.transforms.next(); break; } } @@ -982,7 +976,7 @@ impl Iterator for WrapRows<'_> { self.output_row += 1; self.transforms - .seek_forward(&WrapPoint::new(self.output_row, 0), Bias::Left, &()); + .seek_forward(&WrapPoint::new(self.output_row, 0), Bias::Left); if self.transforms.item().map_or(false, |t| t.is_isomorphic()) { self.input_buffer_row = self.input_buffer_rows.next().unwrap(); self.soft_wrapped = false; diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index c5fe0db74c..8f57fb1a20 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -109,10 +109,10 @@ use inline_completion::{EditPredictionProvider, InlineCompletionProviderHandle}; pub use items::MAX_TAB_TITLE_LEN; use itertools::Itertools; use language::{ - AutoindentMode, BracketMatch, BracketPair, Buffer, Capability, CharKind, CodeLabel, - CursorShape, DiagnosticEntry, DiffOptions, DocumentationConfig, EditPredictionsMode, - EditPreview, HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point, - Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery, + AutoindentMode, BlockCommentConfig, BracketMatch, BracketPair, Buffer, Capability, CharKind, + CodeLabel, CursorShape, DiagnosticEntry, DiffOptions, EditPredictionsMode, EditPreview, + HighlightedText, IndentKind, IndentSize, Language, OffsetRangeExt, Point, Selection, + SelectionGoal, TextObject, TransactionId, TreeSitterOptions, WordsQuery, language_settings::{ self, InlayHintSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode, all_language_settings, language_settings, @@ -134,7 +134,7 @@ use project::{ session::{Session, SessionEvent}, }, git_store::{GitStoreEvent, RepositoryEvent}, - project_settings::DiagnosticSeverity, + project_settings::{DiagnosticSeverity, GoToDiagnosticSeverityFilter}, }; pub use git::blame::BlameRenderer; @@ -213,6 +213,7 @@ use workspace::{ notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt}, searchable::SearchEvent, }; +use zed_actions; use crate::{ code_context_menus::CompletionsMenuSource, @@ -356,6 +357,7 @@ pub fn init(cx: &mut App) { workspace.register_action(Editor::new_file_vertical); workspace.register_action(Editor::new_file_horizontal); workspace.register_action(Editor::cancel_language_server_work); + workspace.register_action(Editor::toggle_focus); }, ) .detach(); @@ -482,9 +484,7 @@ pub enum SelectMode { #[derive(Clone, PartialEq, Eq, Debug)] pub enum EditorMode { - SingleLine { - auto_width: bool, - }, + SingleLine, AutoHeight { min_lines: usize, max_lines: Option, @@ -951,6 +951,7 @@ struct InlineBlamePopover { hide_task: Option>, popover_bounds: Option>, popover_state: InlineBlamePopoverState, + keyboard_grace: bool, } enum SelectionDragState { @@ -1662,13 +1663,7 @@ impl Editor { pub fn single_line(window: &mut Window, cx: &mut Context) -> Self { let buffer = cx.new(|cx| Buffer::local("", cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - Self::new( - EditorMode::SingleLine { auto_width: false }, - buffer, - None, - window, - cx, - ) + Self::new(EditorMode::SingleLine, buffer, None, window, cx) } pub fn multi_line(window: &mut Window, cx: &mut Context) -> Self { @@ -1677,18 +1672,6 @@ impl Editor { Self::new(EditorMode::full(), buffer, None, window, cx) } - pub fn auto_width(window: &mut Window, cx: &mut Context) -> Self { - let buffer = cx.new(|cx| Buffer::local("", cx)); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - Self::new( - EditorMode::SingleLine { auto_width: true }, - buffer, - None, - window, - cx, - ) - } - pub fn auto_height( min_lines: usize, max_lines: usize, @@ -1795,6 +1778,7 @@ impl Editor { ); let full_mode = mode.is_full(); + let is_minimap = mode.is_minimap(); let diagnostics_max_severity = if full_mode { EditorSettings::get_global(cx) .diagnostics_max_severity @@ -1855,13 +1839,19 @@ impl Editor { let selections = SelectionsCollection::new(display_map.clone(), buffer.clone()); - let blink_manager = cx.new(|cx| BlinkManager::new(CURSOR_BLINK_INTERVAL, cx)); + let blink_manager = cx.new(|cx| { + let mut blink_manager = BlinkManager::new(CURSOR_BLINK_INTERVAL, cx); + if is_minimap { + blink_manager.disable(cx); + } + blink_manager + }); let soft_wrap_mode_override = matches!(mode, EditorMode::SingleLine { .. }) .then(|| language_settings::SoftWrap::None); let mut project_subscriptions = Vec::new(); - if mode.is_full() { + if full_mode { if let Some(project) = project.as_ref() { project_subscriptions.push(cx.subscribe_in( project, @@ -1972,18 +1962,23 @@ impl Editor { let inlay_hint_settings = inlay_hint_settings(selections.newest_anchor().head(), &buffer_snapshot, cx); let focus_handle = cx.focus_handle(); - cx.on_focus(&focus_handle, window, Self::handle_focus) - .detach(); - cx.on_focus_in(&focus_handle, window, Self::handle_focus_in) - .detach(); - cx.on_focus_out(&focus_handle, window, Self::handle_focus_out) - .detach(); - cx.on_blur(&focus_handle, window, Self::handle_blur) - .detach(); - cx.observe_pending_input(window, Self::observe_pending_input) - .detach(); + if !is_minimap { + cx.on_focus(&focus_handle, window, Self::handle_focus) + .detach(); + cx.on_focus_in(&focus_handle, window, Self::handle_focus_in) + .detach(); + cx.on_focus_out(&focus_handle, window, Self::handle_focus_out) + .detach(); + cx.on_blur(&focus_handle, window, Self::handle_blur) + .detach(); + cx.observe_pending_input(window, Self::observe_pending_input) + .detach(); + } - let show_indent_guides = if matches!(mode, EditorMode::SingleLine { .. }) { + let show_indent_guides = if matches!( + mode, + EditorMode::SingleLine { .. } | EditorMode::Minimap { .. } + ) { Some(false) } else { None @@ -2049,10 +2044,10 @@ impl Editor { minimap_visibility: MinimapVisibility::for_mode(&mode, cx), offset_content: !matches!(mode, EditorMode::SingleLine { .. }), show_breadcrumbs: EditorSettings::get_global(cx).toolbar.breadcrumbs, - show_gutter: mode.is_full(), - show_line_numbers: None, + show_gutter: full_mode, + show_line_numbers: (!full_mode).then_some(false), use_relative_line_numbers: None, - disable_expand_excerpt_buttons: false, + disable_expand_excerpt_buttons: !full_mode, show_git_diff_gutter: None, show_code_actions: None, show_runnables: None, @@ -2086,7 +2081,7 @@ impl Editor { document_highlights_task: None, linked_editing_range_task: None, pending_rename: None, - searchable: true, + searchable: !is_minimap, cursor_shape: EditorSettings::get_global(cx) .cursor_shape .unwrap_or_default(), @@ -2094,9 +2089,9 @@ impl Editor { autoindent_mode: Some(AutoindentMode::EachLine), collapse_matches: false, workspace: None, - input_enabled: true, - use_modal_editing: mode.is_full(), - read_only: mode.is_minimap(), + input_enabled: !is_minimap, + use_modal_editing: full_mode, + read_only: is_minimap, use_autoclose: true, use_auto_surround: true, auto_replace_emoji_shortcode: false, @@ -2112,11 +2107,10 @@ impl Editor { edit_prediction_preview: EditPredictionPreview::Inactive { released_too_fast: false, }, - inline_diagnostics_enabled: mode.is_full(), - diagnostics_enabled: mode.is_full(), + inline_diagnostics_enabled: full_mode, + diagnostics_enabled: full_mode, inline_value_cache: InlineValueCache::new(inlay_hint_settings.show_value_hints), inlay_hint_cache: InlayHintCache::new(inlay_hint_settings), - gutter_hovered: false, pixel_position_of_newest_cursor: None, last_bounds: None, @@ -2139,9 +2133,10 @@ impl Editor { show_git_blame_inline: false, show_selection_menu: None, show_git_blame_inline_delay_task: None, - git_blame_inline_enabled: ProjectSettings::get_global(cx).git.inline_blame_enabled(), + git_blame_inline_enabled: full_mode + && ProjectSettings::get_global(cx).git.inline_blame_enabled(), render_diff_hunk_controls: Arc::new(render_diff_hunk_controls), - serialize_dirty_buffers: !mode.is_minimap() + serialize_dirty_buffers: !is_minimap && ProjectSettings::get_global(cx) .session .restore_unsaved_buffers, @@ -2152,27 +2147,31 @@ impl Editor { breakpoint_store, gutter_breakpoint_indicator: (None, None), hovered_diff_hunk_row: None, - _subscriptions: vec![ - cx.observe(&buffer, Self::on_buffer_changed), - cx.subscribe_in(&buffer, window, Self::on_buffer_event), - cx.observe_in(&display_map, window, Self::on_display_map_changed), - cx.observe(&blink_manager, |_, _, cx| cx.notify()), - cx.observe_global_in::(window, Self::settings_changed), - observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), - cx.observe_window_activation(window, |editor, window, cx| { - let active = window.is_window_active(); - editor.blink_manager.update(cx, |blink_manager, cx| { - if active { - blink_manager.enable(cx); - } else { - blink_manager.disable(cx); - } - }); - if active { - editor.show_mouse_cursor(cx); - } - }), - ], + _subscriptions: (!is_minimap) + .then(|| { + vec![ + cx.observe(&buffer, Self::on_buffer_changed), + cx.subscribe_in(&buffer, window, Self::on_buffer_event), + cx.observe_in(&display_map, window, Self::on_display_map_changed), + cx.observe(&blink_manager, |_, _, cx| cx.notify()), + cx.observe_global_in::(window, Self::settings_changed), + observe_buffer_font_size_adjustment(cx, |_, cx| cx.notify()), + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { + if active { + blink_manager.enable(cx); + } else { + blink_manager.disable(cx); + } + }); + if active { + editor.show_mouse_cursor(cx); + } + }), + ] + }) + .unwrap_or_default(), tasks_update_task: None, pull_diagnostics_task: Task::ready(()), colors: None, @@ -2203,6 +2202,11 @@ impl Editor { selection_drag_state: SelectionDragState::None, folding_newlines: Task::ready(()), }; + + if is_minimap { + return editor; + } + if let Some(breakpoints) = editor.breakpoint_store.as_ref() { editor ._subscriptions @@ -2322,7 +2326,10 @@ impl Editor { editor.update_lsp_data(false, None, window, cx); } - editor.report_editor_event("Editor Opened", None, cx); + if editor.mode.is_full() { + editor.report_editor_event("Editor Opened", None, cx); + } + editor } @@ -2377,13 +2384,17 @@ impl Editor { } match self.context_menu.borrow().as_ref() { - Some(CodeContextMenu::Completions(_)) => { - key_context.add("menu"); - key_context.add("showing_completions"); + Some(CodeContextMenu::Completions(menu)) => { + if menu.visible() { + key_context.add("menu"); + key_context.add("showing_completions"); + } } - Some(CodeContextMenu::CodeActions(_)) => { - key_context.add("menu"); - key_context.add("showing_code_actions") + Some(CodeContextMenu::CodeActions(menu)) => { + if menu.visible() { + key_context.add("menu"); + key_context.add("showing_code_actions") + } } None => {} } @@ -4381,7 +4392,7 @@ impl Editor { .take_while(|c| c.is_whitespace()) .count(); let comment_candidate = snapshot - .chars_for_range(range) + .chars_for_range(range.clone()) .skip(num_of_whitespaces) .take(max_len_of_delimiter) .collect::(); @@ -4397,6 +4408,24 @@ impl Editor { }) .max_by_key(|(_, len)| *len)?; + if let Some(BlockCommentConfig { + start: block_start, .. + }) = language.block_comment() + { + let block_start_trimmed = block_start.trim_end(); + if block_start_trimmed.starts_with(delimiter.trim_end()) { + let line_content = snapshot + .chars_for_range(range) + .skip(num_of_whitespaces) + .take(block_start_trimmed.len()) + .collect::(); + + if line_content.starts_with(block_start_trimmed) { + return None; + } + } + } + let cursor_is_placed_after_comment_marker = num_of_whitespaces + trimmed_len <= start_point.column as usize; if cursor_is_placed_after_comment_marker { @@ -4418,13 +4447,12 @@ impl Editor { return None; } - let DocumentationConfig { + let BlockCommentConfig { start: start_tag, end: end_tag, prefix: delimiter, tab_size: len, - } = language.documentation()?; - + } = language.documentation_comment()?; let is_within_block_comment = buffer .language_scope_at(start_point) .is_some_and(|scope| scope.override_name() == Some("comment")); @@ -4494,7 +4522,7 @@ impl Editor { let cursor_is_at_start_of_end_tag = column == end_tag_offset; if cursor_is_at_start_of_end_tag { - indent_on_extra_newline.len = (*len).into(); + indent_on_extra_newline.len = *len; } } cursor_is_before_end_tag @@ -4507,7 +4535,7 @@ impl Editor { && cursor_is_before_end_tag_if_exists { if cursor_is_after_start_tag { - indent_on_newline.len = (*len).into(); + indent_on_newline.len = *len; } Some(delimiter.clone()) } else { @@ -5426,7 +5454,7 @@ impl Editor { }; let (word_replace_range, word_to_exclude) = if let (word_range, Some(CharKind::Word)) = - buffer_snapshot.surrounding_word(buffer_position) + buffer_snapshot.surrounding_word(buffer_position, false) { let word_to_exclude = buffer_snapshot .text_for_range(word_range.clone()) @@ -6492,21 +6520,55 @@ impl Editor { } } + pub fn blame_hover(&mut self, _: &BlameHover, window: &mut Window, cx: &mut Context) { + let snapshot = self.snapshot(window, cx); + let cursor = self.selections.newest::(cx).head(); + let Some((buffer, point, _)) = snapshot.buffer_snapshot.point_to_buffer_point(cursor) + else { + return; + }; + + let Some(blame) = self.blame.as_ref() else { + return; + }; + + let row_info = RowInfo { + buffer_id: Some(buffer.remote_id()), + buffer_row: Some(point.row), + ..Default::default() + }; + let Some(blame_entry) = blame + .update(cx, |blame, cx| blame.blame_for_rows(&[row_info], cx).next()) + .flatten() + else { + return; + }; + + let anchor = self.selections.newest_anchor().head(); + let position = self.to_pixel_point(anchor, &snapshot, window); + if let (Some(position), Some(last_bounds)) = (position, self.last_bounds) { + self.show_blame_popover(&blame_entry, position + last_bounds.origin, true, cx); + }; + } + fn show_blame_popover( &mut self, blame_entry: &BlameEntry, position: gpui::Point, + ignore_timeout: bool, cx: &mut Context, ) { if let Some(state) = &mut self.inline_blame_popover { state.hide_task.take(); } else { - let delay = EditorSettings::get_global(cx).hover_popover_delay; + let blame_popover_delay = EditorSettings::get_global(cx).hover_popover_delay; let blame_entry = blame_entry.clone(); let show_task = cx.spawn(async move |editor, cx| { - cx.background_executor() - .timer(std::time::Duration::from_millis(delay)) - .await; + if !ignore_timeout { + cx.background_executor() + .timer(std::time::Duration::from_millis(blame_popover_delay)) + .await; + } editor .update(cx, |editor, cx| { editor.inline_blame_popover_show_task.take(); @@ -6535,6 +6597,7 @@ impl Editor { commit_message: details, markdown, }, + keyboard_grace: ignore_timeout, }); cx.notify(); }) @@ -6580,8 +6643,8 @@ impl Editor { } let snapshot = cursor_buffer.read(cx).snapshot(); - let (start_word_range, _) = snapshot.surrounding_word(cursor_buffer_position); - let (end_word_range, _) = snapshot.surrounding_word(tail_buffer_position); + let (start_word_range, _) = snapshot.surrounding_word(cursor_buffer_position, false); + let (end_word_range, _) = snapshot.surrounding_word(tail_buffer_position, false); if start_word_range != end_word_range { self.document_highlights_task.take(); self.clear_background_highlights::(cx); @@ -10426,7 +10489,6 @@ impl Editor { cloned_prompt.clone().into_any_element() }), priority: 0, - render_in_minimap: true, }]; let focus_handle = bp_prompt.focus_handle(cx); @@ -10816,17 +10878,6 @@ impl Editor { }); } - pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context) { - self.manipulate_text(window, cx, |text| { - let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); - if has_upper_case_characters { - text.to_lowercase() - } else { - text.to_uppercase() - } - }) - } - fn manipulate_immutable_lines( &mut self, window: &mut Window, @@ -11082,6 +11133,26 @@ impl Editor { }) } + pub fn convert_to_sentence_case( + &mut self, + _: &ConvertToSentenceCase, + window: &mut Window, + cx: &mut Context, + ) { + self.manipulate_text(window, cx, |text| text.to_case(Case::Sentence)) + } + + pub fn toggle_case(&mut self, _: &ToggleCase, window: &mut Window, cx: &mut Context) { + self.manipulate_text(window, cx, |text| { + let has_upper_case_characters = text.chars().any(|c| c.is_uppercase()); + if has_upper_case_characters { + text.to_lowercase() + } else { + text.to_uppercase() + } + }) + } + pub fn convert_to_rot13( &mut self, _: &ConvertToRot13, @@ -12094,6 +12165,41 @@ impl Editor { }); } + pub fn diff_clipboard_with_selection( + &mut self, + _: &DiffClipboardWithSelection, + window: &mut Window, + cx: &mut Context, + ) { + let selections = self.selections.all::(cx); + + if selections.is_empty() { + log::warn!("There should always be at least one selection in Zed. This is a bug."); + return; + }; + + let clipboard_text = match cx.read_from_clipboard() { + Some(item) => match item.entries().first() { + Some(ClipboardEntry::String(text)) => Some(text.text().to_string()), + _ => None, + }, + None => None, + }; + + let Some(clipboard_text) = clipboard_text else { + log::warn!("Clipboard doesn't contain text."); + return; + }; + + window.dispatch_action( + Box::new(DiffClipboardWithSelectionData { + clipboard_text, + editor: cx.entity(), + }), + cx, + ); + } + pub fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context) { self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); if let Some(item) = cx.read_from_clipboard() { @@ -14253,8 +14359,11 @@ impl Editor { (position..position, first_prefix.clone()) })); } - } else if let Some((full_comment_prefix, comment_suffix)) = - language.block_comment_delimiters() + } else if let Some(BlockCommentConfig { + start: full_comment_prefix, + end: comment_suffix, + .. + }) = language.block_comment() { let comment_prefix = full_comment_prefix.trim_end_matches(' '); let comment_prefix_whitespace = &full_comment_prefix[comment_prefix.len()..]; @@ -15047,7 +15156,7 @@ impl Editor { pub fn go_to_diagnostic( &mut self, - _: &GoToDiagnostic, + action: &GoToDiagnostic, window: &mut Window, cx: &mut Context, ) { @@ -15055,12 +15164,12 @@ impl Editor { return; } self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); - self.go_to_diagnostic_impl(Direction::Next, window, cx) + self.go_to_diagnostic_impl(Direction::Next, action.severity, window, cx) } pub fn go_to_prev_diagnostic( &mut self, - _: &GoToPreviousDiagnostic, + action: &GoToPreviousDiagnostic, window: &mut Window, cx: &mut Context, ) { @@ -15068,12 +15177,13 @@ impl Editor { return; } self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx); - self.go_to_diagnostic_impl(Direction::Prev, window, cx) + self.go_to_diagnostic_impl(Direction::Prev, action.severity, window, cx) } pub fn go_to_diagnostic_impl( &mut self, direction: Direction, + severity: GoToDiagnosticSeverityFilter, window: &mut Window, cx: &mut Context, ) { @@ -15089,9 +15199,11 @@ impl Editor { fn filtered( snapshot: EditorSnapshot, + severity: GoToDiagnosticSeverityFilter, diagnostics: impl Iterator>, ) -> impl Iterator> { diagnostics + .filter(move |entry| severity.matches(entry.diagnostic.severity)) .filter(|entry| entry.range.start != entry.range.end) .filter(|entry| !entry.diagnostic.is_unnecessary) .filter(move |entry| !snapshot.intersects_fold(entry.range.start)) @@ -15100,12 +15212,14 @@ impl Editor { let snapshot = self.snapshot(window, cx); let before = filtered( snapshot.clone(), + severity, buffer .diagnostics_in_range(0..selection.start) .filter(|entry| entry.range.start <= selection.start), ); let after = filtered( snapshot, + severity, buffer .diagnostics_in_range(selection.start..buffer.len()) .filter(|entry| entry.range.start >= selection.start), @@ -16122,7 +16236,6 @@ impl Editor { } }), priority: 0, - render_in_minimap: true, }], Some(Autoscroll::fit()), cx, @@ -16864,7 +16977,7 @@ impl Editor { now: Instant, window: &mut Window, cx: &mut Context, - ) { + ) -> Option { self.end_selection(window, cx); if let Some(tx_id) = self .buffer @@ -16874,7 +16987,10 @@ impl Editor { .insert_transaction(tx_id, self.selections.disjoint_anchors()); cx.emit(EditorEvent::TransactionBegun { transaction_id: tx_id, - }) + }); + Some(tx_id) + } else { + None } } @@ -16902,6 +17018,17 @@ impl Editor { } } + pub fn modify_transaction_selection_history( + &mut self, + transaction_id: TransactionId, + modify: impl FnOnce(&mut (Arc<[Selection]>, Option]>>)), + ) -> bool { + self.selection_history + .transaction_mut(transaction_id) + .map(modify) + .is_some() + } + pub fn set_mark(&mut self, _: &actions::SetMark, window: &mut Window, cx: &mut Context) { if self.selection_mark_mode { self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { @@ -16931,6 +17058,18 @@ impl Editor { cx.notify(); } + pub fn toggle_focus( + workspace: &mut Workspace, + _: &actions::ToggleFocus, + window: &mut Window, + cx: &mut Context, + ) { + let Some(item) = workspace.recent_active_item_by_type::(cx) else { + return; + }; + workspace.activate_item(&item, true, true, window, cx); + } + pub fn toggle_fold( &mut self, _: &actions::ToggleFold, @@ -17056,6 +17195,46 @@ impl Editor { } } + pub fn toggle_fold_all( + &mut self, + _: &actions::ToggleFoldAll, + window: &mut Window, + cx: &mut Context, + ) { + if self.buffer.read(cx).is_singleton() { + let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); + let has_folds = display_map + .folds_in_range(0..display_map.buffer_snapshot.len()) + .next() + .is_some(); + + if has_folds { + self.unfold_all(&actions::UnfoldAll, window, cx); + } else { + self.fold_all(&actions::FoldAll, window, cx); + } + } else { + let buffer_ids = self.buffer.read(cx).excerpt_buffer_ids(); + let should_unfold = buffer_ids + .iter() + .any(|buffer_id| self.is_buffer_folded(*buffer_id, cx)); + + self.toggle_fold_multiple_buffers = cx.spawn_in(window, async move |editor, cx| { + editor + .update_in(cx, |editor, _, cx| { + for buffer_id in buffer_ids { + if should_unfold { + editor.unfold_buffer(buffer_id, cx); + } else { + editor.fold_buffer(buffer_id, cx); + } + } + }) + .ok(); + }); + } + } + fn fold_at_level( &mut self, fold_at: &FoldAtLevel, @@ -17985,7 +18164,7 @@ impl Editor { parent: cx.weak_entity(), }, self.buffer.clone(), - self.project.clone(), + None, Some(self.display_map.clone()), window, cx, @@ -19639,8 +19818,9 @@ impl Editor { Anchor::in_buffer(excerpt_id, buffer_id, hint.position), hint.text(), ); - - new_inlays.push(inlay); + if !inlay.text.chars().contains(&'\n') { + new_inlays.push(inlay); + } }); } @@ -19868,14 +20048,12 @@ impl Editor { } fn settings_changed(&mut self, window: &mut Window, cx: &mut Context) { - let new_severity = if self.diagnostics_enabled() { - EditorSettings::get_global(cx) + if self.diagnostics_enabled() { + let new_severity = EditorSettings::get_global(cx) .diagnostics_max_severity - .unwrap_or(DiagnosticSeverity::Hint) - } else { - DiagnosticSeverity::Off - }; - self.set_max_diagnostics_severity(new_severity, cx); + .unwrap_or(DiagnosticSeverity::Hint); + self.set_max_diagnostics_severity(new_severity, cx); + } self.tasks_update_task = Some(self.refresh_runnables(window, cx)); self.update_edit_prediction_settings(cx); self.refresh_inline_completion(true, false, window, cx); @@ -20487,6 +20665,7 @@ impl Editor { if event.blurred != self.focus_handle { self.last_focused_descendant = Some(event.blurred); } + self.selection_drag_state = SelectionDragState::None; self.refresh_inlay_hints(InlayHintRefreshReason::ModifiersChanged(false), cx); } @@ -22057,7 +22236,7 @@ impl SemanticsProvider for Entity { // Fallback on using TreeSitter info to determine identifier range buffer.read_with(cx, |buffer, _| { let snapshot = buffer.snapshot(); - let (range, kind) = snapshot.surrounding_word(position); + let (range, kind) = snapshot.surrounding_word(position, false); if kind != Some(CharKind::Word) { return None; } @@ -22102,7 +22281,7 @@ fn consume_contiguous_rows( selections: &mut Peekable>>, ) -> (MultiBufferRow, MultiBufferRow) { contiguous_row_selections.push(selection.clone()); - let start_row = MultiBufferRow(selection.start.row); + let start_row = starting_row(selection, display_map); let mut end_row = ending_row(selection, display_map); while let Some(next_selection) = selections.peek() { @@ -22116,6 +22295,14 @@ fn consume_contiguous_rows( (start_row, end_row) } +fn starting_row(selection: &Selection, display_map: &DisplaySnapshot) -> MultiBufferRow { + if selection.start.column > 0 { + MultiBufferRow(display_map.prev_line_boundary(selection.start).0.row) + } else { + MultiBufferRow(selection.start.row) + } +} + fn ending_row(next_selection: &Selection, display_map: &DisplaySnapshot) -> MultiBufferRow { if next_selection.end.column > 0 || next_selection.is_empty() { MultiBufferRow(display_map.next_line_boundary(next_selection.end).0.row + 1) diff --git a/crates/editor/src/editor_settings.rs b/crates/editor/src/editor_settings.rs index 5d8379ddfb..14f46c0e60 100644 --- a/crates/editor/src/editor_settings.rs +++ b/crates/editor/src/editor_settings.rs @@ -395,6 +395,8 @@ pub enum SnippetSortOrder { Inline, /// Place snippets at the bottom of the completion list Bottom, + /// Do not show snippets in the completion list + None, } #[derive(Clone, Default, Serialize, Deserialize, JsonSchema)] diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index aea84de9b0..03b047e92e 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -55,7 +55,7 @@ use util::{ uri, }; use workspace::{ - CloseActiveItem, CloseAllItems, CloseInactiveItems, MoveItemToPaneInDirection, NavigationEntry, + CloseActiveItem, CloseAllItems, CloseOtherItems, MoveItemToPaneInDirection, NavigationEntry, OpenOptions, ViewId, item::{FollowEvent, FollowableItem, Item, ItemHandle, SaveOptions}, }; @@ -2875,11 +2875,11 @@ async fn test_newline_documentation_comments(cx: &mut TestAppContext) { let language = Arc::new( Language::new( LanguageConfig { - documentation: Some(language::DocumentationConfig { + documentation_comment: Some(language::BlockCommentConfig { start: "/**".into(), end: "*/".into(), prefix: "* ".into(), - tab_size: NonZeroU32::new(1).unwrap(), + tab_size: 1, }), ..LanguageConfig::default() @@ -3080,6 +3080,50 @@ async fn test_newline_documentation_comments(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_newline_comments_with_block_comment(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.defaults.tab_size = NonZeroU32::new(4) + }); + + let lua_language = Arc::new(Language::new( + LanguageConfig { + line_comments: vec!["--".into()], + block_comment: Some(language::BlockCommentConfig { + start: "--[[".into(), + prefix: "".into(), + end: "]]".into(), + tab_size: 0, + }), + ..LanguageConfig::default() + }, + None, + )); + + let mut cx = EditorTestContext::new(cx).await; + cx.update_buffer(|buffer, cx| buffer.set_language(Some(lua_language), cx)); + + // Line with line comment should extend + cx.set_state(indoc! {" + --ˇ + "}); + cx.update_editor(|e, window, cx| e.newline(&Newline, window, cx)); + cx.assert_editor_state(indoc! {" + -- + --ˇ + "}); + + // Line with block comment that matches line comment should not extend + cx.set_state(indoc! {" + --[[ˇ + "}); + cx.update_editor(|e, window, cx| e.newline(&Newline, window, cx)); + cx.assert_editor_state(indoc! {" + --[[ + ˇ + "}); +} + #[gpui::test] fn test_insert_with_old_selections(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -4680,6 +4724,23 @@ async fn test_toggle_case(cx: &mut TestAppContext) { "}); } +#[gpui::test] +async fn test_convert_to_sentence_case(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + + cx.set_state(indoc! {" + «implement-windows-supportˇ» + "}); + cx.update_editor(|e, window, cx| { + e.convert_to_sentence_case(&ConvertToSentenceCase, window, cx) + }); + cx.assert_editor_state(indoc! {" + «Implement windows supportˇ» + "}); +} + #[gpui::test] async fn test_manipulate_text(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -5025,6 +5086,33 @@ fn test_move_line_up_down(cx: &mut TestAppContext) { }); } +#[gpui::test] +fn test_move_line_up_selection_at_end_of_fold(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let editor = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple("\n\n\n\n\n\naaaa\nbbbb\ncccc", cx); + build_editor(buffer, window, cx) + }); + _ = editor.update(cx, |editor, window, cx| { + editor.fold_creases( + vec![Crease::simple( + Point::new(6, 4)..Point::new(7, 4), + FoldPlaceholder::test(), + )], + true, + window, + cx, + ); + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_ranges([Point::new(7, 4)..Point::new(7, 4)]) + }); + assert_eq!(editor.display_text(cx), "\n\n\n\n\n\naaaa⋯\ncccc"); + editor.move_line_up(&MoveLineUp, window, cx); + let buffer_text = editor.buffer.read(cx).snapshot(cx).text(); + assert_eq!(buffer_text, "\n\n\n\n\naaaa\nbbbb\n\ncccc"); + }); +} + #[gpui::test] fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -5042,7 +5130,6 @@ fn test_move_line_up_down_with_blocks(cx: &mut TestAppContext) { height: Some(1), render: Arc::new(|_| div().into_any()), priority: 0, - render_in_minimap: true, }], Some(Autoscroll::fit()), cx, @@ -5085,7 +5172,6 @@ async fn test_selections_and_replace_blocks(cx: &mut TestAppContext) { style: BlockStyle::Sticky, render: Arc::new(|_| gpui::div().into_any_element()), priority: 0, - render_in_minimap: true, }], None, cx, @@ -9533,6 +9619,74 @@ async fn test_document_format_during_save(cx: &mut TestAppContext) { } } +#[gpui::test] +async fn test_redo_after_noop_format(cx: &mut TestAppContext) { + init_test(cx, |settings| { + settings.defaults.ensure_final_newline_on_save = Some(false); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_file(path!("/file.txt"), "foo".into()).await; + + let project = Project::test(fs, [path!("/file.txt").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/file.txt"), cx) + }) + .await + .unwrap(); + + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = cx.add_window_view(|window, cx| { + build_editor_with_project(project.clone(), buffer, window, cx) + }); + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(SelectionEffects::default(), window, cx, |s| { + s.select_ranges([0..0]) + }); + }); + assert!(!cx.read(|cx| editor.is_dirty(cx))); + + editor.update_in(cx, |editor, window, cx| { + editor.handle_input("\n", window, cx) + }); + cx.run_until_parked(); + save(&editor, &project, cx).await; + assert_eq!("\nfoo", editor.read_with(cx, |editor, cx| editor.text(cx))); + + editor.update_in(cx, |editor, window, cx| { + editor.undo(&Default::default(), window, cx); + }); + save(&editor, &project, cx).await; + assert_eq!("foo", editor.read_with(cx, |editor, cx| editor.text(cx))); + + editor.update_in(cx, |editor, window, cx| { + editor.redo(&Default::default(), window, cx); + }); + cx.run_until_parked(); + assert_eq!("\nfoo", editor.read_with(cx, |editor, cx| editor.text(cx))); + + async fn save(editor: &Entity, project: &Entity, cx: &mut VisualTestContext) { + let save = editor + .update_in(cx, |editor, window, cx| { + editor.save( + SaveOptions { + format: true, + autosave: false, + }, + project.clone(), + window, + cx, + ) + }) + .unwrap(); + cx.executor().start_waiting(); + save.await; + assert!(!cx.read(|cx| editor.is_dirty(cx))); + } +} + #[gpui::test] async fn test_multibuffer_format_during_save(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -13701,7 +13855,12 @@ async fn test_toggle_block_comment(cx: &mut TestAppContext) { Language::new( LanguageConfig { name: "HTML".into(), - block_comment: Some(("".into())), + block_comment: Some(BlockCommentConfig { + start: "".into(), + tab_size: 0, + }), ..Default::default() }, Some(tree_sitter_html::LANGUAGE.into()), @@ -14697,7 +14856,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu executor.run_until_parked(); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -14706,7 +14865,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu "}); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -14715,7 +14874,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu "}); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -14724,7 +14883,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu "}); cx.update_editor(|editor, window, cx| { - editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic, window, cx); + editor.go_to_prev_diagnostic(&GoToPreviousDiagnostic::default(), window, cx); }); cx.assert_editor_state(indoc! {" @@ -16722,7 +16881,7 @@ async fn test_multibuffer_reverts(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_mutlibuffer_in_navigation_history(cx: &mut TestAppContext) { +async fn test_multibuffer_in_navigation_history(cx: &mut TestAppContext) { init_test(cx, |_| {}); let cols = 4; @@ -21426,7 +21585,7 @@ println!("5"); .unwrap(); pane_1 .update_in(cx, |pane, window, cx| { - pane.close_inactive_items(&CloseInactiveItems::default(), window, cx) + pane.close_other_items(&CloseOtherItems::default(), None, window, cx) }) .await .unwrap(); @@ -21462,7 +21621,7 @@ println!("5"); .unwrap(); pane_2 .update_in(cx, |pane, window, cx| { - pane.close_inactive_items(&CloseInactiveItems::default(), window, cx) + pane.close_other_items(&CloseOtherItems::default(), None, window, cx) }) .await .unwrap(); @@ -22671,7 +22830,7 @@ pub(crate) fn init_test(cx: &mut TestAppContext, f: fn(&mut AllLanguageSettingsC workspace::init_settings(cx); crate::init(cx); }); - + zlog::init_test(); update_test_language_settings(cx, f); } diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 8a5bfb3bab..7e77f113ac 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -9,7 +9,7 @@ use crate::{ LineUp, MAX_LINE_LEN, MINIMAP_FONT_SIZE, MULTI_BUFFER_EXCERPT_HEADER_HEIGHT, OpenExcerpts, PageDown, PageUp, PhantomBreakpointIndicator, Point, RowExt, RowRangeExt, SelectPhase, SelectedTextHighlight, Selection, SelectionDragState, SoftWrap, StickyHeaderExcerpt, ToPoint, - ToggleFold, + ToggleFold, ToggleFoldAll, code_context_menus::{CodeActionsMenu, MENU_ASIDE_MAX_WIDTH, MENU_ASIDE_MIN_WIDTH, MENU_GAP}, display_map::{ Block, BlockContext, BlockStyle, ChunkRendererId, DisplaySnapshot, EditorMargins, @@ -216,6 +216,7 @@ impl EditorElement { register_action(editor, window, Editor::newline_above); register_action(editor, window, Editor::newline_below); register_action(editor, window, Editor::backspace); + register_action(editor, window, Editor::blame_hover); register_action(editor, window, Editor::delete); register_action(editor, window, Editor::tab); register_action(editor, window, Editor::backtab); @@ -229,7 +230,6 @@ impl EditorElement { register_action(editor, window, Editor::sort_lines_case_insensitive); register_action(editor, window, Editor::reverse_lines); register_action(editor, window, Editor::shuffle_lines); - register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_indentation_to_spaces); register_action(editor, window, Editor::convert_indentation_to_tabs); register_action(editor, window, Editor::convert_to_upper_case); @@ -240,6 +240,8 @@ impl EditorElement { register_action(editor, window, Editor::convert_to_upper_camel_case); register_action(editor, window, Editor::convert_to_lower_camel_case); register_action(editor, window, Editor::convert_to_opposite_case); + register_action(editor, window, Editor::convert_to_sentence_case); + register_action(editor, window, Editor::toggle_case); register_action(editor, window, Editor::convert_to_rot13); register_action(editor, window, Editor::convert_to_rot47); register_action(editor, window, Editor::delete_to_previous_word_start); @@ -261,6 +263,7 @@ impl EditorElement { register_action(editor, window, Editor::kill_ring_yank); register_action(editor, window, Editor::copy); register_action(editor, window, Editor::copy_and_trim); + register_action(editor, window, Editor::diff_clipboard_with_selection); register_action(editor, window, Editor::paste); register_action(editor, window, Editor::undo); register_action(editor, window, Editor::redo); @@ -416,6 +419,7 @@ impl EditorElement { register_action(editor, window, Editor::fold_recursive); register_action(editor, window, Editor::toggle_fold); register_action(editor, window, Editor::toggle_fold_recursive); + register_action(editor, window, Editor::toggle_fold_all); register_action(editor, window, Editor::unfold_lines); register_action(editor, window, Editor::unfold_recursive); register_action(editor, window, Editor::unfold_all); @@ -948,6 +952,7 @@ impl EditorElement { if !pending_nonempty_selections && hovered_link_modifier && text_hitbox.is_hovered(window) { let point = position_map.point_for_position(event.up.position); editor.handle_click_hovered_link(point, event.modifiers(), window, cx); + editor.selection_drag_state = SelectionDragState::None; cx.stop_propagation(); } @@ -1141,10 +1146,14 @@ impl EditorElement { .as_ref() .and_then(|state| state.popover_bounds) .map_or(false, |bounds| bounds.contains(&event.position)); + let keyboard_grace = editor + .inline_blame_popover + .as_ref() + .map_or(false, |state| state.keyboard_grace); if mouse_over_inline_blame || mouse_over_popover { - editor.show_blame_popover(&blame_entry, event.position, cx); - } else { + editor.show_blame_popover(&blame_entry, event.position, false, cx); + } else if !keyboard_grace { editor.hide_blame_popover(cx); } } else { @@ -2093,16 +2102,19 @@ impl EditorElement { window: &mut Window, cx: &mut App, ) -> HashMap { - if self.editor.read(cx).mode().is_minimap() { - return HashMap::default(); - } - - let max_severity = match ProjectSettings::get_global(cx) - .diagnostics - .inline - .max_severity - .unwrap_or_else(|| self.editor.read(cx).diagnostics_max_severity) - .into_lsp() + let max_severity = match self + .editor + .read(cx) + .inline_diagnostics_enabled() + .then(|| { + ProjectSettings::get_global(cx) + .diagnostics + .inline + .max_severity + .unwrap_or_else(|| self.editor.read(cx).diagnostics_max_severity) + .into_lsp() + }) + .flatten() { Some(max_severity) => max_severity, None => return HashMap::default(), @@ -2618,9 +2630,6 @@ impl EditorElement { window: &mut Window, cx: &mut App, ) -> Option> { - if self.editor.read(cx).mode().is_minimap() { - return None; - } let indent_guides = self.editor.update(cx, |editor, cx| { editor.indent_guides(visible_buffer_range, snapshot, cx) })?; @@ -3084,9 +3093,9 @@ impl EditorElement { window: &mut Window, cx: &mut App, ) -> Arc> { - let include_line_numbers = snapshot.show_line_numbers.unwrap_or_else(|| { - EditorSettings::get_global(cx).gutter.line_numbers && snapshot.mode.is_full() - }); + let include_line_numbers = snapshot + .show_line_numbers + .unwrap_or_else(|| EditorSettings::get_global(cx).gutter.line_numbers); if !include_line_numbers { return Arc::default(); } @@ -3399,22 +3408,18 @@ impl EditorElement { div() .size_full() - .children( - (!snapshot.mode.is_minimap() || custom.render_in_minimap).then(|| { - custom.render(&mut BlockContext { - window, - app: cx, - anchor_x, - margins: editor_margins, - line_height, - em_width, - block_id, - selected, - max_width: text_hitbox.size.width.max(*scroll_width), - editor_style: &self.style, - }) - }), - ) + .child(custom.render(&mut BlockContext { + window, + app: cx, + anchor_x, + margins: editor_margins, + line_height, + em_width, + block_id, + selected, + max_width: text_hitbox.size.width.max(*scroll_width), + editor_style: &self.style, + })) .into_any() } @@ -3620,24 +3625,37 @@ impl EditorElement { .tooltip({ let focus_handle = focus_handle.clone(); move |window, cx| { - Tooltip::for_action_in( + Tooltip::with_meta_in( "Toggle Excerpt Fold", - &ToggleFold, + Some(&ToggleFold), + "Alt+click to toggle all", &focus_handle, window, cx, ) } }) - .on_click(move |_, _, cx| { - if is_folded { + .on_click(move |event, window, cx| { + if event.modifiers().alt { + // Alt+click toggles all buffers editor.update(cx, |editor, cx| { - editor.unfold_buffer(buffer_id, cx); + editor.toggle_fold_all( + &ToggleFoldAll, + window, + cx, + ); }); } else { - editor.update(cx, |editor, cx| { - editor.fold_buffer(buffer_id, cx); - }); + // Regular click toggles single buffer + if is_folded { + editor.update(cx, |editor, cx| { + editor.unfold_buffer(buffer_id, cx); + }); + } else { + editor.update(cx, |editor, cx| { + editor.fold_buffer(buffer_id, cx); + }); + } } }), ), @@ -3993,6 +4011,7 @@ impl EditorElement { let available_width = hitbox.bounds.size.width - right_margin; let mut header = v_flex() + .w_full() .relative() .child( div() @@ -6762,7 +6781,7 @@ impl EditorElement { } fn paint_mouse_listeners(&mut self, layout: &EditorLayout, window: &mut Window, cx: &mut App) { - if self.editor.read(cx).mode.is_minimap() { + if layout.mode.is_minimap() { return; } @@ -7777,46 +7796,13 @@ impl Element for EditorElement { editor.set_style(self.style.clone(), window, cx); let layout_id = match editor.mode { - EditorMode::SingleLine { auto_width } => { + EditorMode::SingleLine => { let rem_size = window.rem_size(); - let height = self.style.text.line_height_in_pixels(rem_size); - if auto_width { - let editor_handle = cx.entity().clone(); - let style = self.style.clone(); - window.request_measured_layout( - Style::default(), - move |_, _, window, cx| { - let editor_snapshot = editor_handle - .update(cx, |editor, cx| editor.snapshot(window, cx)); - let line = Self::layout_lines( - DisplayRow(0)..DisplayRow(1), - &editor_snapshot, - &style, - px(f32::MAX), - |_| false, // Single lines never soft wrap - window, - cx, - ) - .pop() - .unwrap(); - - let font_id = - window.text_system().resolve_font(&style.text.font()); - let font_size = - style.text.font_size.to_pixels(window.rem_size()); - let em_width = - window.text_system().em_width(font_id, font_size).unwrap(); - - size(line.width + em_width, height) - }, - ) - } else { - let mut style = Style::default(); - style.size.height = height.into(); - style.size.width = relative(1.).into(); - window.request_layout(style, None, cx) - } + let mut style = Style::default(); + style.size.height = height.into(); + style.size.width = relative(1.).into(); + window.request_layout(style, None, cx) } EditorMode::AutoHeight { min_lines, @@ -7889,9 +7875,14 @@ impl Element for EditorElement { line_height: Some(self.style.text.line_height), ..Default::default() }; - let focus_handle = self.editor.focus_handle(cx); - window.set_view_id(self.editor.entity_id()); - window.set_focus_handle(&focus_handle, cx); + + let is_minimap = self.editor.read(cx).mode.is_minimap(); + + if !is_minimap { + let focus_handle = self.editor.focus_handle(cx); + window.set_view_id(self.editor.entity_id()); + window.set_focus_handle(&focus_handle, cx); + } let rem_size = self.rem_size(cx); window.with_rem_size(rem_size, |window| { @@ -7953,17 +7944,11 @@ impl Element for EditorElement { right: right_margin, }; - // Offset the content_bounds from the text_bounds by the gutter margin (which - // is roughly half a character wide) to make hit testing work more like how we want. - let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); - - let editor_content_width = editor_width - content_offset.x; - snapshot = self.editor.update(cx, |editor, cx| { editor.last_bounds = Some(bounds); editor.gutter_dimensions = gutter_dimensions; editor.set_visible_line_count(bounds.size.height / line_height, window, cx); - editor.set_visible_column_count(editor_content_width / em_advance); + editor.set_visible_column_count(editor_width / em_advance); if matches!( editor.mode, @@ -7975,10 +7960,10 @@ impl Element for EditorElement { let wrap_width = match editor.soft_wrap_mode(cx) { SoftWrap::GitDiff => None, SoftWrap::None => Some(wrap_width_for(MAX_LINE_LEN as u32 / 2)), - SoftWrap::EditorWidth => Some(editor_content_width), + SoftWrap::EditorWidth => Some(editor_width), SoftWrap::Column(column) => Some(wrap_width_for(column)), SoftWrap::Bounded(column) => { - Some(editor_content_width.min(wrap_width_for(column))) + Some(editor_width.min(wrap_width_for(column))) } }; @@ -8003,13 +7988,12 @@ impl Element for EditorElement { HitboxBehavior::Normal, ); + // Offset the content_bounds from the text_bounds by the gutter margin (which + // is roughly half a character wide) to make hit testing work more like how we want. + let content_offset = point(editor_margins.gutter.margin, Pixels::ZERO); let content_origin = text_hitbox.origin + content_offset; - let editor_text_bounds = - Bounds::from_corners(content_origin, bounds.bottom_right()); - - let height_in_lines = editor_text_bounds.size.height / line_height; - + let height_in_lines = bounds.size.height / line_height; let max_row = snapshot.max_point().row().as_f32(); // The max scroll position for the top of the window @@ -8035,23 +8019,25 @@ impl Element for EditorElement { } }; - // TODO: Autoscrolling for both axes - let mut autoscroll_request = None; - let mut autoscroll_containing_element = false; - let mut autoscroll_horizontally = false; - self.editor.update(cx, |editor, cx| { - autoscroll_request = editor.autoscroll_request(); - autoscroll_containing_element = + let ( + autoscroll_request, + autoscroll_containing_element, + needs_horizontal_autoscroll, + ) = self.editor.update(cx, |editor, cx| { + let autoscroll_request = editor.autoscroll_request(); + let autoscroll_containing_element = autoscroll_request.is_some() || editor.has_pending_selection(); - // TODO: Is this horizontal or vertical?! - autoscroll_horizontally = editor.autoscroll_vertically( - bounds, - line_height, - max_scroll_top, - window, - cx, - ); - snapshot = editor.snapshot(window, cx); + + let (needs_horizontal_autoscroll, was_scrolled) = editor + .autoscroll_vertically(bounds, line_height, max_scroll_top, window, cx); + if was_scrolled.0 { + snapshot = editor.snapshot(window, cx); + } + ( + autoscroll_request, + autoscroll_containing_element, + needs_horizontal_autoscroll, + ) }); let mut scroll_position = snapshot.scroll_position(); @@ -8327,18 +8313,22 @@ impl Element for EditorElement { window, cx, ); - let new_renrerer_widths = line_layouts - .iter() - .flat_map(|layout| &layout.fragments) - .filter_map(|fragment| { - if let LineFragment::Element { id, size, .. } = fragment { - Some((*id, size.width)) - } else { - None - } - }); - if self.editor.update(cx, |editor, cx| { - editor.update_renderer_widths(new_renrerer_widths, cx) + let new_renderer_widths = (!is_minimap).then(|| { + line_layouts + .iter() + .flat_map(|layout| &layout.fragments) + .filter_map(|fragment| { + if let LineFragment::Element { id, size, .. } = fragment { + Some((*id, size.width)) + } else { + None + } + }) + }); + if new_renderer_widths.is_some_and(|new_renderer_widths| { + self.editor.update(cx, |editor, cx| { + editor.update_renderer_widths(new_renderer_widths, cx) + }) }) { // If the fold widths have changed, we need to prepaint // the element again to account for any changes in @@ -8387,7 +8377,6 @@ impl Element for EditorElement { glyph_grid_cell, size(longest_line_width, max_row.as_f32() * line_height), longest_line_blame_width, - editor_width, EditorSettings::get_global(cx), ); @@ -8401,27 +8390,31 @@ impl Element for EditorElement { let sticky_header_excerpt_id = sticky_header_excerpt.as_ref().map(|top| top.excerpt.id); - let blocks = window.with_element_namespace("blocks", |window| { - self.render_blocks( - start_row..end_row, - &snapshot, - &hitbox, - &text_hitbox, - editor_width, - &mut scroll_width, - &editor_margins, - em_width, - gutter_dimensions.full_width(), - line_height, - &mut line_layouts, - &local_selections, - &selected_buffer_ids, - is_row_soft_wrapped, - sticky_header_excerpt_id, - window, - cx, - ) - }); + let blocks = (!is_minimap) + .then(|| { + window.with_element_namespace("blocks", |window| { + self.render_blocks( + start_row..end_row, + &snapshot, + &hitbox, + &text_hitbox, + editor_width, + &mut scroll_width, + &editor_margins, + em_width, + gutter_dimensions.full_width(), + line_height, + &mut line_layouts, + &local_selections, + &selected_buffer_ids, + is_row_soft_wrapped, + sticky_header_excerpt_id, + window, + cx, + ) + }) + }) + .unwrap_or_else(|| Ok((Vec::default(), HashMap::default()))); let (mut blocks, row_block_types) = match blocks { Ok(blocks) => blocks, Err(resized_blocks) => { @@ -8455,30 +8448,27 @@ impl Element for EditorElement { MultiBufferRow(end_anchor.to_point(&snapshot.buffer_snapshot).row); let scroll_max = point( - ((scroll_width - editor_content_width) / em_advance).max(0.0), + ((scroll_width - editor_width) / em_advance).max(0.0), max_scroll_top, ); self.editor.update(cx, |editor, cx| { - let clamped = editor.scroll_manager.clamp_scroll_left(scroll_max.x); + if editor.scroll_manager.clamp_scroll_left(scroll_max.x) { + scroll_position.x = scroll_position.x.min(scroll_max.x); + } - let autoscrolled = if autoscroll_horizontally { - editor.autoscroll_horizontally( + if needs_horizontal_autoscroll.0 + && let Some(new_scroll_position) = editor.autoscroll_horizontally( start_row, - editor_content_width, + editor_width, scroll_width, em_advance, &line_layouts, window, cx, ) - } else { - false - }; - - if clamped || autoscrolled { - snapshot = editor.snapshot(window, cx); - scroll_position = snapshot.scroll_position(); + { + scroll_position = new_scroll_position; } }); @@ -8593,7 +8583,9 @@ impl Element for EditorElement { } } else { log::error!( - "bug: line_ix {} is out of bounds - row_infos.len(): {}, line_layouts.len(): {}, crease_trailers.len(): {}", + "bug: line_ix {} is out of bounds - row_infos.len(): {}, \ + line_layouts.len(): {}, \ + crease_trailers.len(): {}", line_ix, row_infos.len(), line_layouts.len(), @@ -8614,29 +8606,6 @@ impl Element for EditorElement { cx, ); - self.editor.update(cx, |editor, cx| { - let clamped = editor.scroll_manager.clamp_scroll_left(scroll_max.x); - - let autoscrolled = if autoscroll_horizontally { - editor.autoscroll_horizontally( - start_row, - editor_content_width, - scroll_width, - em_advance, - &line_layouts, - window, - cx, - ) - } else { - false - }; - - if clamped || autoscrolled { - snapshot = editor.snapshot(window, cx); - scroll_position = snapshot.scroll_position(); - } - }); - let line_elements = self.prepaint_lines( start_row, &mut line_layouts, @@ -8862,7 +8831,7 @@ impl Element for EditorElement { underline: None, strikethrough: None, }], - None + None, ); let space_invisible = window.text_system().shape_line( "‒".into(), @@ -8875,7 +8844,7 @@ impl Element for EditorElement { underline: None, strikethrough: None, }], - None + None, ); let mode = snapshot.mode.clone(); @@ -8977,19 +8946,21 @@ impl Element for EditorElement { window: &mut Window, cx: &mut App, ) { - let focus_handle = self.editor.focus_handle(cx); - let key_context = self - .editor - .update(cx, |editor, cx| editor.key_context(window, cx)); + if !layout.mode.is_minimap() { + let focus_handle = self.editor.focus_handle(cx); + let key_context = self + .editor + .update(cx, |editor, cx| editor.key_context(window, cx)); - window.set_key_context(key_context); - window.handle_input( - &focus_handle, - ElementInputHandler::new(bounds, self.editor.clone()), - cx, - ); - self.register_actions(window, cx); - self.register_key_listeners(window, cx, layout); + window.set_key_context(key_context); + window.handle_input( + &focus_handle, + ElementInputHandler::new(bounds, self.editor.clone()), + cx, + ); + self.register_actions(window, cx); + self.register_key_listeners(window, cx, layout); + } let text_style = TextStyleRefinement { font_size: Some(self.style.text.font_size), @@ -9070,7 +9041,6 @@ impl ScrollbarLayoutInformation { glyph_grid_cell: Size, document_size: Size, longest_line_blame_width: Pixels, - editor_width: Pixels, settings: &EditorSettings, ) -> Self { let vertical_overscroll = match settings.scroll_beyond_last_line { @@ -9081,19 +9051,11 @@ impl ScrollbarLayoutInformation { } }; - let right_margin = if document_size.width + longest_line_blame_width >= editor_width { - glyph_grid_cell.width - } else { - px(0.0) - }; - - let overscroll = size(right_margin + longest_line_blame_width, vertical_overscroll); - - let scroll_range = document_size + overscroll; + let overscroll = size(longest_line_blame_width, vertical_overscroll); ScrollbarLayoutInformation { editor_bounds, - scroll_range, + scroll_range: document_size + overscroll, glyph_grid_cell, } } @@ -9198,7 +9160,7 @@ struct EditorScrollbars { impl EditorScrollbars { pub fn from_scrollbar_axes( - settings_visibility: ScrollbarAxes, + show_scrollbar: ScrollbarAxes, layout_information: &ScrollbarLayoutInformation, content_offset: gpui::Point, scroll_position: gpui::Point, @@ -9236,22 +9198,13 @@ impl EditorScrollbars { }; let mut create_scrollbar_layout = |axis| { - settings_visibility - .along(axis) + let viewport_size = viewport_size.along(axis); + let scroll_range = scroll_range.along(axis); + + // We always want a vertical scrollbar track for scrollbar diagnostic visibility. + (show_scrollbar.along(axis) + && (axis == ScrollbarAxis::Vertical || scroll_range > viewport_size)) .then(|| { - ( - viewport_size.along(axis) - content_offset.along(axis), - scroll_range.along(axis), - ) - }) - .filter(|(viewport_size, scroll_range)| { - // The scrollbar should only be rendered if the content does - // not entirely fit into the editor - // However, this only applies to the horizontal scrollbar, as information about the - // vertical scrollbar layout is always needed for scrollbar diagnostics. - axis != ScrollbarAxis::Horizontal || viewport_size < scroll_range - }) - .map(|(viewport_size, scroll_range)| { ScrollbarLayout::new( window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal), viewport_size, @@ -10298,7 +10251,6 @@ mod tests { height: Some(3), render: Arc::new(|cx| div().h(3. * cx.window.line_height()).into_any()), priority: 0, - render_in_minimap: true, }], None, cx, @@ -10388,7 +10340,7 @@ mod tests { }); for editor_mode_without_invisibles in [ - EditorMode::SingleLine { auto_width: false }, + EditorMode::SingleLine, EditorMode::AutoHeight { min_lines: 1, max_lines: Some(100), diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs index d4c9e37895..fc350a5a15 100644 --- a/crates/editor/src/git/blame.rs +++ b/crates/editor/src/git/blame.rs @@ -296,7 +296,7 @@ impl GitBlame { let row = info .buffer_row .filter(|_| info.buffer_id == Some(buffer_id))?; - cursor.seek_forward(&row, Bias::Right, &()); + cursor.seek_forward(&row, Bias::Right); cursor.item()?.blame.clone() }) } @@ -389,7 +389,7 @@ impl GitBlame { } } - new_entries.append(cursor.slice(&edit.old.start, Bias::Right, &()), &()); + new_entries.append(cursor.slice(&edit.old.start, Bias::Right), &()); if edit.new.start > new_entries.summary().rows { new_entries.push( @@ -401,7 +401,7 @@ impl GitBlame { ); } - cursor.seek(&edit.old.end, Bias::Right, &()); + cursor.seek(&edit.old.end, Bias::Right); if !edit.new.is_empty() { new_entries.push( GitBlameEntry { @@ -412,7 +412,7 @@ impl GitBlame { ); } - let old_end = cursor.end(&()); + let old_end = cursor.end(); if row_edits .peek() .map_or(true, |next_edit| next_edit.old.start >= old_end) @@ -421,18 +421,18 @@ impl GitBlame { if old_end > edit.old.end { new_entries.push( GitBlameEntry { - rows: cursor.end(&()) - edit.old.end, + rows: cursor.end() - edit.old.end, blame: entry.blame.clone(), }, &(), ); } - cursor.next(&()); + cursor.next(); } } } - new_entries.append(cursor.suffix(&()), &()); + new_entries.append(cursor.suffix(), &()); drop(cursor); self.buffer_snapshot = new_snapshot; diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 2e4631a62b..ca635a2132 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -813,7 +813,13 @@ impl Item for Editor { window: &mut Window, cx: &mut Context, ) -> Task> { - self.report_editor_event("Editor Saved", None, cx); + // Add meta data tracking # of auto saves + if options.autosave { + self.report_editor_event("Editor Autosaved", None, cx); + } else { + self.report_editor_event("Editor Saved", None, cx); + } + let buffers = self.buffer().clone().read(cx).all_buffers(); let buffers = buffers .into_iter() @@ -1220,7 +1226,20 @@ impl SerializableItem for Editor { abs_path: None, contents: None, .. - } => Task::ready(Err(anyhow!("No path or contents found for buffer"))), + } => window.spawn(cx, async move |cx| { + let buffer = project + .update(cx, |project, cx| project.create_buffer(cx))? + .await?; + + cx.update(|window, cx| { + cx.new(|cx| { + let mut editor = Editor::for_buffer(buffer, Some(project), window, cx); + + editor.read_metadata_from_db(item_id, workspace_id, window, cx); + editor + }) + }) + }), } } @@ -2092,5 +2111,38 @@ mod tests { assert!(editor.has_conflict(cx)); // The editor should have a conflict }); } + + // Test case 5: Deserialize with no path, no content, no language, and no old mtime (new, empty, unsaved buffer) + { + let project = Project::test(fs.clone(), [path!("/file.rs").as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap(); + + let item_id = 10000 as ItemId; + let serialized_editor = SerializedEditor { + abs_path: None, + contents: None, + language: None, + mtime: None, + }; + + DB.save_serialized_editor(item_id, workspace_id, serialized_editor) + .await + .unwrap(); + + let deserialized = + deserialize_editor(item_id, workspace_id, workspace, project, cx).await; + + deserialized.update(cx, |editor, cx| { + assert_eq!(editor.text(cx), ""); + assert!(!editor.is_dirty(cx)); + assert!(!editor.has_conflict(cx)); + + let buffer = editor.buffer().read(cx).as_singleton().unwrap().read(cx); + assert!(buffer.file().is_none()); + }); + } } } diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index b3007d3091..ecaf7c11e4 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -12,7 +12,7 @@ use crate::{ }; pub use autoscroll::{Autoscroll, AutoscrollStrategy}; use core::fmt::Debug; -use gpui::{App, Axis, Context, Global, Pixels, Task, Window, point, px}; +use gpui::{Along, App, Axis, Context, Global, Pixels, Task, Window, point, px}; use language::language_settings::{AllLanguageSettings, SoftWrap}; use language::{Bias, Point}; pub use scroll_amount::ScrollAmount; @@ -27,6 +27,8 @@ use workspace::{ItemId, WorkspaceId}; pub const SCROLL_EVENT_SEPARATION: Duration = Duration::from_millis(28); const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1); +pub struct WasScrolled(pub(crate) bool); + #[derive(Default)] pub struct ScrollbarAutoHide(pub bool); @@ -47,14 +49,14 @@ impl ScrollAnchor { } pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> gpui::Point { - let mut scroll_position = self.offset; - if self.anchor == Anchor::min() { - scroll_position.y = 0.; - } else { - let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32(); - scroll_position.y += scroll_top; - } - scroll_position + self.offset.apply_along(Axis::Vertical, |offset| { + if self.anchor == Anchor::min() { + 0. + } else { + let scroll_top = self.anchor.to_display_point(snapshot).row().as_f32(); + (offset + scroll_top).max(0.) + } + }) } pub fn top_row(&self, buffer: &MultiBufferSnapshot) -> u32 { @@ -215,87 +217,56 @@ impl ScrollManager { workspace_id: Option, window: &mut Window, cx: &mut Context, - ) { - let (new_anchor, top_row) = if scroll_position.y <= 0. && scroll_position.x <= 0. { - ( - ScrollAnchor { - anchor: Anchor::min(), - offset: scroll_position.max(&gpui::Point::default()), - }, - 0, - ) - } else if scroll_position.y <= 0. { - let buffer_point = map - .clip_point( - DisplayPoint::new(DisplayRow(0), scroll_position.x as u32), - Bias::Left, - ) - .to_point(map); - let anchor = map.buffer_snapshot.anchor_at(buffer_point, Bias::Right); - - ( - ScrollAnchor { - anchor: anchor, - offset: scroll_position.max(&gpui::Point::default()), - }, - 0, - ) - } else { - let scroll_top = scroll_position.y; - let scroll_top = match EditorSettings::get_global(cx).scroll_beyond_last_line { - ScrollBeyondLastLine::OnePage => scroll_top, - ScrollBeyondLastLine::Off => { - if let Some(height_in_lines) = self.visible_line_count { - let max_row = map.max_point().row().0 as f32; - scroll_top.min(max_row - height_in_lines + 1.).max(0.) - } else { - scroll_top - } + ) -> WasScrolled { + let scroll_top = scroll_position.y.max(0.); + let scroll_top = match EditorSettings::get_global(cx).scroll_beyond_last_line { + ScrollBeyondLastLine::OnePage => scroll_top, + ScrollBeyondLastLine::Off => { + if let Some(height_in_lines) = self.visible_line_count { + let max_row = map.max_point().row().0 as f32; + scroll_top.min(max_row - height_in_lines + 1.).max(0.) + } else { + scroll_top } - ScrollBeyondLastLine::VerticalScrollMargin => { - if let Some(height_in_lines) = self.visible_line_count { - let max_row = map.max_point().row().0 as f32; - scroll_top - .min(max_row - height_in_lines + 1. + self.vertical_scroll_margin) - .max(0.) - } else { - scroll_top - } + } + ScrollBeyondLastLine::VerticalScrollMargin => { + if let Some(height_in_lines) = self.visible_line_count { + let max_row = map.max_point().row().0 as f32; + scroll_top + .min(max_row - height_in_lines + 1. + self.vertical_scroll_margin) + .max(0.) + } else { + scroll_top } - }; - - let scroll_top_row = DisplayRow(scroll_top as u32); - let scroll_top_buffer_point = map - .clip_point( - DisplayPoint::new(scroll_top_row, scroll_position.x as u32), - Bias::Left, - ) - .to_point(map); - let top_anchor = map - .buffer_snapshot - .anchor_at(scroll_top_buffer_point, Bias::Right); - - ( - ScrollAnchor { - anchor: top_anchor, - offset: point( - scroll_position.x.max(0.), - scroll_top - top_anchor.to_display_point(map).row().as_f32(), - ), - }, - scroll_top_buffer_point.row, - ) + } }; + let scroll_top_row = DisplayRow(scroll_top as u32); + let scroll_top_buffer_point = map + .clip_point( + DisplayPoint::new(scroll_top_row, scroll_position.x as u32), + Bias::Left, + ) + .to_point(map); + let top_anchor = map + .buffer_snapshot + .anchor_at(scroll_top_buffer_point, Bias::Right); + self.set_anchor( - new_anchor, - top_row, + ScrollAnchor { + anchor: top_anchor, + offset: point( + scroll_position.x.max(0.), + scroll_top - top_anchor.to_display_point(map).row().as_f32(), + ), + }, + scroll_top_buffer_point.row, local, autoscroll, workspace_id, window, cx, - ); + ) } fn set_anchor( @@ -307,7 +278,7 @@ impl ScrollManager { workspace_id: Option, window: &mut Window, cx: &mut Context, - ) { + ) -> WasScrolled { let adjusted_anchor = if self.forbid_vertical_scroll { ScrollAnchor { offset: gpui::Point::new(anchor.offset.x, self.anchor.offset.y), @@ -317,10 +288,14 @@ impl ScrollManager { anchor }; + self.autoscroll_request.take(); + if self.anchor == adjusted_anchor { + return WasScrolled(false); + } + self.anchor = adjusted_anchor; cx.emit(EditorEvent::ScrollPositionChanged { local, autoscroll }); self.show_scrollbars(window, cx); - self.autoscroll_request.take(); if let Some(workspace_id) = workspace_id { let item_id = cx.entity().entity_id().as_u64() as ItemId; @@ -342,6 +317,8 @@ impl ScrollManager { .detach() } cx.notify(); + + WasScrolled(true) } pub fn show_scrollbars(&mut self, window: &mut Window, cx: &mut Context) { @@ -552,13 +529,13 @@ impl Editor { scroll_position: gpui::Point, window: &mut Window, cx: &mut Context, - ) { + ) -> WasScrolled { let mut position = scroll_position; if self.scroll_manager.forbid_vertical_scroll { let current_position = self.scroll_position(cx); position.y = current_position.y; } - self.set_scroll_position_internal(position, true, false, window, cx); + self.set_scroll_position_internal(position, true, false, window, cx) } /// Scrolls so that `row` is at the top of the editor view. @@ -590,7 +567,7 @@ impl Editor { autoscroll: bool, window: &mut Window, cx: &mut Context, - ) { + ) -> WasScrolled { let map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); self.set_scroll_position_taking_display_map( scroll_position, @@ -599,7 +576,7 @@ impl Editor { map, window, cx, - ); + ) } fn set_scroll_position_taking_display_map( @@ -610,7 +587,7 @@ impl Editor { display_map: DisplaySnapshot, window: &mut Window, cx: &mut Context, - ) { + ) -> WasScrolled { hide_hover(self, cx); let workspace_id = self.workspace.as_ref().and_then(|workspace| workspace.1); @@ -624,7 +601,7 @@ impl Editor { scroll_position }; - self.scroll_manager.set_scroll_position( + let editor_was_scrolled = self.scroll_manager.set_scroll_position( adjusted_position, &display_map, local, @@ -636,6 +613,7 @@ impl Editor { self.refresh_inlay_hints(InlayHintRefreshReason::NewLinesShown, cx); self.refresh_colors(false, None, window, cx); + editor_was_scrolled } pub fn scroll_position(&self, cx: &mut Context) -> gpui::Point { diff --git a/crates/editor/src/scroll/autoscroll.rs b/crates/editor/src/scroll/autoscroll.rs index 340277633a..e8a1f8da73 100644 --- a/crates/editor/src/scroll/autoscroll.rs +++ b/crates/editor/src/scroll/autoscroll.rs @@ -1,6 +1,6 @@ use crate::{ DisplayRow, Editor, EditorMode, LineWithInvisibles, RowExt, SelectionEffects, - display_map::ToDisplayPoint, + display_map::ToDisplayPoint, scroll::WasScrolled, }; use gpui::{Bounds, Context, Pixels, Window, px}; use language::Point; @@ -99,19 +99,21 @@ impl AutoscrollStrategy { } } +pub(crate) struct NeedsHorizontalAutoscroll(pub(crate) bool); + impl Editor { pub fn autoscroll_request(&self) -> Option { self.scroll_manager.autoscroll_request() } - pub fn autoscroll_vertically( + pub(crate) fn autoscroll_vertically( &mut self, bounds: Bounds, line_height: Pixels, max_scroll_top: f32, window: &mut Window, cx: &mut Context, - ) -> bool { + ) -> (NeedsHorizontalAutoscroll, WasScrolled) { let viewport_height = bounds.size.height; let visible_lines = viewport_height / line_height; let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); @@ -129,12 +131,14 @@ impl Editor { scroll_position.y = max_scroll_top; } - if original_y != scroll_position.y { - self.set_scroll_position(scroll_position, window, cx); - } + let editor_was_scrolled = if original_y != scroll_position.y { + self.set_scroll_position(scroll_position, window, cx) + } else { + WasScrolled(false) + }; let Some((autoscroll, local)) = self.scroll_manager.autoscroll_request.take() else { - return false; + return (NeedsHorizontalAutoscroll(false), editor_was_scrolled); }; let mut target_top; @@ -212,7 +216,7 @@ impl Editor { target_bottom = target_top + 1.; } - match strategy { + let was_autoscrolled = match strategy { AutoscrollStrategy::Fit | AutoscrollStrategy::Newest => { let margin = margin.min(self.scroll_manager.vertical_scroll_margin); let target_top = (target_top - margin).max(0.0); @@ -225,39 +229,42 @@ impl Editor { if needs_scroll_up && !needs_scroll_down { scroll_position.y = target_top; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); - } - if !needs_scroll_up && needs_scroll_down { + } else if !needs_scroll_up && needs_scroll_down { scroll_position.y = target_bottom - visible_lines; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + } + + if needs_scroll_up ^ needs_scroll_down { + self.set_scroll_position_internal(scroll_position, local, true, window, cx) + } else { + WasScrolled(false) } } AutoscrollStrategy::Center => { scroll_position.y = (target_top - margin).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::Focused => { let margin = margin.min(self.scroll_manager.vertical_scroll_margin); scroll_position.y = (target_top - margin).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::Top => { scroll_position.y = (target_top).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::Bottom => { scroll_position.y = (target_bottom - visible_lines).max(0.0); - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::TopRelative(lines) => { scroll_position.y = target_top - lines as f32; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } AutoscrollStrategy::BottomRelative(lines) => { scroll_position.y = target_bottom + lines as f32; - self.set_scroll_position_internal(scroll_position, local, true, window, cx); + self.set_scroll_position_internal(scroll_position, local, true, window, cx) } - } + }; self.scroll_manager.last_autoscroll = Some(( self.scroll_manager.anchor.offset, @@ -266,7 +273,8 @@ impl Editor { strategy, )); - true + let was_scrolled = WasScrolled(editor_was_scrolled.0 || was_autoscrolled.0); + (NeedsHorizontalAutoscroll(true), was_scrolled) } pub(crate) fn autoscroll_horizontally( @@ -278,7 +286,7 @@ impl Editor { layouts: &[LineWithInvisibles], window: &mut Window, cx: &mut Context, - ) -> bool { + ) -> Option> { let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let selections = self.selections.all::(cx); let mut scroll_position = self.scroll_manager.scroll_position(&display_map); @@ -319,22 +327,26 @@ impl Editor { target_right = target_right.min(scroll_width); if target_right - target_left > viewport_width { - return false; + return None; } let scroll_left = self.scroll_manager.anchor.offset.x * em_advance; let scroll_right = scroll_left + viewport_width; - if target_left < scroll_left { + let was_scrolled = if target_left < scroll_left { scroll_position.x = target_left / em_advance; - self.set_scroll_position_internal(scroll_position, true, true, window, cx); - true + self.set_scroll_position_internal(scroll_position, true, true, window, cx) } else if target_right > scroll_right { scroll_position.x = (target_right - viewport_width) / em_advance; - self.set_scroll_position_internal(scroll_position, true, true, window, cx); - true + self.set_scroll_position_internal(scroll_position, true, true, window, cx) } else { - false + WasScrolled(false) + }; + + if was_scrolled.0 { + Some(scroll_position) + } else { + None } } diff --git a/crates/editor/src/test/editor_lsp_test_context.rs b/crates/editor/src/test/editor_lsp_test_context.rs index f7f34135f3..c59786b1eb 100644 --- a/crates/editor/src/test/editor_lsp_test_context.rs +++ b/crates/editor/src/test/editor_lsp_test_context.rs @@ -14,7 +14,8 @@ use futures::Future; use gpui::{Context, Entity, Focusable as _, VisualTestContext, Window}; use indoc::indoc; use language::{ - FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageQueries, point_to_lsp, + BlockCommentConfig, FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageQueries, + point_to_lsp, }; use lsp::{notification, request}; use multi_buffer::ToPointUtf16; @@ -269,7 +270,12 @@ impl EditorLspTestContext { path_suffixes: vec!["html".into()], ..Default::default() }, - block_comment: Some(("".into())), + block_comment: Some(BlockCommentConfig { + start: "".into(), + tab_size: 0, + }), completion_query_characters: ['-'].into_iter().collect(), ..Default::default() }, diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 39121377bb..a02b4a7f0b 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -8,6 +8,7 @@ mod tool_metrics; use assertions::{AssertionsReport, display_error_row}; use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git}; +use language_extension::LspAccess; pub(crate) use tool_metrics::*; use ::fs::RealFs; @@ -415,7 +416,11 @@ pub fn init(cx: &mut App) -> Arc { language::init(cx); debug_adapter_extension::init(extension_host_proxy.clone(), cx); - language_extension::init(extension_host_proxy.clone(), languages.clone()); + language_extension::init( + LspAccess::Noop, + extension_host_proxy.clone(), + languages.clone(), + ); language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), node_runtime.clone(), cx); diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 904eca83e6..7ce3b1fdf1 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -221,9 +221,6 @@ impl ExampleContext { ThreadEvent::ShowError(thread_error) => { tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); } - ThreadEvent::RetriesFailed { .. } => { - // Ignore retries failed events - } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel(); @@ -425,6 +422,13 @@ impl AppContext for ExampleContext { self.app.update_entity(handle, update) } + fn as_mut<'a, T>(&'a mut self, handle: &Entity) -> Self::Result> + where + T: 'static, + { + self.app.as_mut(handle) + } + fn read_entity( &self, handle: &Entity, diff --git a/crates/eval/src/explorer.html b/crates/eval/src/explorer.html index fec4597163..04c41090d3 100644 --- a/crates/eval/src/explorer.html +++ b/crates/eval/src/explorer.html @@ -324,20 +324,8 @@

Thread Explorer

- - + + @@ -368,8 +352,7 @@ ← Previous
- Thread 1 of - 1: + Thread 1 of 1: Default Thread