diff --git a/.config/hakari.toml b/.config/hakari.toml index 2050065cc2..8ce0b77490 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -25,6 +25,8 @@ third-party = [ { name = "reqwest", version = "0.11.27" }, # build of remote_server should not include scap / its x11 dependency { name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" }, + # build of remote_server should not need to include on libalsa through rodio + { name = "rodio" }, ] [final-excludes] @@ -32,7 +34,6 @@ workspace-members = [ "zed_extension_api", # exclude all extensions - "zed_emmet", "zed_glsl", "zed_html", "zed_proto", diff --git a/.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml b/.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml new file mode 100644 index 0000000000..826c2b8027 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml @@ -0,0 +1,35 @@ +name: Bug Report (Windows Alpha) +description: Zed Windows Alpha Related Bugs +type: "Bug" +labels: ["windows"] +title: "Windows Alpha: " +body: + - type: textarea + attributes: + label: Summary + description: Describe the bug with a one-line summary, and provide detailed reproduction steps + value: | + + SUMMARY_SENTENCE_HERE + + ### Description + + Steps to trigger the problem: + 1. + 2. + 3. + + **Expected Behavior**: + **Actual Behavior**: + + validations: + required: true + - type: textarea + id: environment + attributes: + label: Zed Version and System Specs + description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"' + placeholder: | + Output of "zed: copy system specs into clipboard" + validations: + required: true diff --git a/.github/actions/run_tests_windows/action.yml b/.github/actions/run_tests_windows/action.yml index cbe95e82c1..e3e3b7142e 100644 --- a/.github/actions/run_tests_windows/action.yml +++ b/.github/actions/run_tests_windows/action.yml @@ -20,7 +20,168 @@ runs: with: node-version: "18" + - name: Configure crash dumps + shell: powershell + run: | + # Record the start time for this CI run + $runStartTime = Get-Date + $runStartTimeStr = $runStartTime.ToString("yyyy-MM-dd HH:mm:ss") + Write-Host "CI run started at: $runStartTimeStr" + + # Save the timestamp for later use + echo "CI_RUN_START_TIME=$($runStartTime.Ticks)" >> $env:GITHUB_ENV + + # Create crash dump directory in workspace (non-persistent) + $dumpPath = "$env:GITHUB_WORKSPACE\crash_dumps" + New-Item -ItemType Directory -Force -Path $dumpPath | Out-Null + + Write-Host "Setting up crash dump detection..." + Write-Host "Workspace dump path: $dumpPath" + + # Note: We're NOT modifying registry on stateful runners + # Instead, we'll check default Windows crash locations after tests + - name: Run tests shell: powershell working-directory: ${{ inputs.working-directory }} - run: cargo nextest run --workspace --no-fail-fast + run: | + $env:RUST_BACKTRACE = "full" + + # Enable Windows debugging features + $env:_NT_SYMBOL_PATH = "srv*https://msdl.microsoft.com/download/symbols" + + # .NET crash dump environment variables (ephemeral) + $env:COMPlus_DbgEnableMiniDump = "1" + $env:COMPlus_DbgMiniDumpType = "4" + $env:COMPlus_CreateDumpDiagnostics = "1" + + cargo nextest run --workspace --no-fail-fast + continue-on-error: true + + - name: Analyze crash dumps + if: always() + shell: powershell + run: | + Write-Host "Checking for crash dumps..." + + # Get the CI run start time from the environment + $runStartTime = [DateTime]::new([long]$env:CI_RUN_START_TIME) + Write-Host "Only analyzing dumps created after: $($runStartTime.ToString('yyyy-MM-dd HH:mm:ss'))" + + # Check all possible crash dump locations + $searchPaths = @( + "$env:GITHUB_WORKSPACE\crash_dumps", + "$env:LOCALAPPDATA\CrashDumps", + "$env:TEMP", + "$env:GITHUB_WORKSPACE", + "$env:USERPROFILE\AppData\Local\CrashDumps", + "C:\Windows\System32\config\systemprofile\AppData\Local\CrashDumps" + ) + + $dumps = @() + foreach ($path in $searchPaths) { + if (Test-Path $path) { + Write-Host "Searching in: $path" + $found = Get-ChildItem "$path\*.dmp" -ErrorAction SilentlyContinue | Where-Object { + $_.CreationTime -gt $runStartTime + } + if ($found) { + $dumps += $found + Write-Host " Found $($found.Count) dump(s) from this CI run" + } + } + } + + if ($dumps) { + Write-Host "Found $($dumps.Count) crash dump(s)" + + # Install debugging tools if not present + $cdbPath = "C:\Program Files (x86)\Windows Kits\10\Debuggers\x64\cdb.exe" + if (-not (Test-Path $cdbPath)) { + Write-Host "Installing Windows Debugging Tools..." + $url = "https://go.microsoft.com/fwlink/?linkid=2237387" + Invoke-WebRequest -Uri $url -OutFile winsdksetup.exe + Start-Process -Wait winsdksetup.exe -ArgumentList "/features OptionId.WindowsDesktopDebuggers /quiet" + } + + foreach ($dump in $dumps) { + Write-Host "`n==================================" + Write-Host "Analyzing crash dump: $($dump.Name)" + Write-Host "Size: $([math]::Round($dump.Length / 1MB, 2)) MB" + Write-Host "Time: $($dump.CreationTime)" + Write-Host "==================================" + + # Set symbol path + $env:_NT_SYMBOL_PATH = "srv*C:\symbols*https://msdl.microsoft.com/download/symbols" + + # Run analysis + $analysisOutput = & $cdbPath -z $dump.FullName -c "!analyze -v; ~*k; lm; q" 2>&1 | Out-String + + # Extract key information + if ($analysisOutput -match "ExceptionCode:\s*([\w]+)") { + Write-Host "Exception Code: $($Matches[1])" + if ($Matches[1] -eq "c0000005") { + Write-Host "Exception Type: ACCESS VIOLATION" + } + } + + if ($analysisOutput -match "EXCEPTION_RECORD:\s*(.+)") { + Write-Host "Exception Record: $($Matches[1])" + } + + if ($analysisOutput -match "FAULTING_IP:\s*\n(.+)") { + Write-Host "Faulting Instruction: $($Matches[1])" + } + + # Save full analysis + $analysisFile = "$($dump.FullName).analysis.txt" + $analysisOutput | Out-File -FilePath $analysisFile + Write-Host "`nFull analysis saved to: $analysisFile" + + # Print stack trace section + Write-Host "`n--- Stack Trace Preview ---" + $stackSection = $analysisOutput -split "STACK_TEXT:" | Select-Object -Last 1 + $stackLines = $stackSection -split "`n" | Select-Object -First 20 + $stackLines | ForEach-Object { Write-Host $_ } + Write-Host "--- End Stack Trace Preview ---" + } + + Write-Host "`n⚠️ Crash dumps detected! Download the 'crash-dumps' artifact for detailed analysis." + + # Copy dumps to workspace for artifact upload + $artifactPath = "$env:GITHUB_WORKSPACE\crash_dumps_collected" + New-Item -ItemType Directory -Force -Path $artifactPath | Out-Null + + foreach ($dump in $dumps) { + $destName = "$($dump.Directory.Name)_$($dump.Name)" + Copy-Item $dump.FullName -Destination "$artifactPath\$destName" + if (Test-Path "$($dump.FullName).analysis.txt") { + Copy-Item "$($dump.FullName).analysis.txt" -Destination "$artifactPath\$destName.analysis.txt" + } + } + + Write-Host "Copied $($dumps.Count) dump(s) to artifact directory" + } else { + Write-Host "No crash dumps from this CI run found" + } + + - name: Upload crash dumps + if: always() + uses: actions/upload-artifact@v4 + with: + name: crash-dumps-${{ github.run_id }}-${{ github.run_attempt }} + path: | + crash_dumps_collected/*.dmp + crash_dumps_collected/*.txt + if-no-files-found: ignore + retention-days: 7 + + - name: Check test results + shell: powershell + working-directory: ${{ inputs.working-directory }} + run: | + # Re-check test results to fail the job if tests failed + if ($LASTEXITCODE -ne 0) { + Write-Host "Tests failed with exit code: $LASTEXITCODE" + exit $LASTEXITCODE + } diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 928c47a4a7..f4ba227168 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -718,7 +718,7 @@ jobs: timeout-minutes: 60 runs-on: github-8vcpu-ubuntu-2404 if: | - ( startsWith(github.ref, 'refs/tags/v') + false && ( startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') ) needs: [linux_tests] name: Build Zed on FreeBSD @@ -851,3 +851,12 @@ jobs: run: gh release edit "$GITHUB_REF_NAME" --draft=false env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Create Sentry release + uses: getsentry/action-release@526942b68292201ac6bbb99b9a0747d4abee354c # v3 + env: + SENTRY_ORG: zed-dev + SENTRY_PROJECT: zed + SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} + with: + environment: production diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index ed9f4c8450..0cc6737a45 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -316,3 +316,12 @@ jobs: git config user.email github-actions@github.com git tag -f nightly git push origin nightly --force + + - name: Create Sentry release + uses: getsentry/action-release@526942b68292201ac6bbb99b9a0747d4abee354c # v3 + env: + SENTRY_ORG: zed-dev + SENTRY_PROJECT: zed + SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} + with: + environment: production diff --git a/.github/workflows/unit_evals.yml b/.github/workflows/unit_evals.yml index 2e03fb028f..c03cf8b087 100644 --- a/.github/workflows/unit_evals.yml +++ b/.github/workflows/unit_evals.yml @@ -3,7 +3,7 @@ name: Run Unit Evals on: schedule: # GitHub might drop jobs at busy times, so we choose a random time in the middle of the night. - - cron: "47 1 * * *" + - cron: "47 1 * * 2" workflow_dispatch: concurrency: diff --git a/Cargo.lock b/Cargo.lock index 6f434e8685..b4bf705eb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,32 +6,65 @@ version = 4 name = "acp_thread" version = "0.1.0" dependencies = [ + "action_log", + "agent", "agent-client-protocol", "anyhow", - "assistant_tool", "buffer_diff", + "collections", "editor", "env_logger 0.11.8", + "file_icons", "futures 0.3.31", "gpui", "indoc", "itertools 0.14.0", "language", - "language_model", "markdown", "parking_lot", "project", + "prompt_store", "rand 0.8.5", "serde", "serde_json", "settings", "smol", "tempfile", + "terminal", "ui", + "url", "util", + "uuid", + "watch", "workspace-hack", ] +[[package]] +name = "action_log" +version = "0.1.0" +dependencies = [ + "anyhow", + "buffer_diff", + "clock", + "collections", + "ctor", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "log", + "pretty_assertions", + "project", + "rand 0.8.5", + "serde_json", + "settings", + "text", + "util", + "watch", + "workspace-hack", + "zlog", +] + [[package]] name = "activity_indicator" version = "0.1.0" @@ -84,6 +117,7 @@ dependencies = [ name = "agent" version = "0.1.0" dependencies = [ + "action_log", "agent_settings", "anyhow", "assistant_context", @@ -138,9 +172,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.23" +version = "0.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8" +checksum = "2ab66add8be8d6a963f5bf4070045c1bbf36472837654c73e2298dd16bda5bf7" dependencies = [ "anyhow", "futures 0.3.31", @@ -156,23 +190,29 @@ name = "agent2" version = "0.1.0" dependencies = [ "acp_thread", + "action_log", "agent-client-protocol", "agent_servers", "agent_settings", "anyhow", "assistant_tool", "assistant_tools", + "chrono", "client", "clock", "cloud_llm_client", "collections", + "context_server", "ctor", + "editor", "env_logger 0.11.8", "fs", "futures 0.3.31", "gpui", "gpui_tokio", "handlebars 4.5.0", + "html_to_markdown", + "http_client", "indoc", "itertools 0.14.0", "language", @@ -180,7 +220,9 @@ dependencies = [ "language_models", "log", "lsp", + "open", "paths", + "portable-pty", "pretty_assertions", "project", "prompt_store", @@ -191,12 +233,22 @@ dependencies = [ "serde_json", "settings", "smol", + "task", + "tempfile", + "terminal", + "text", + "theme", + "tree-sitter-rust", "ui", + "unindent", "util", "uuid", "watch", + "web_search", + "which 6.0.3", "workspace-hack", "worktree", + "zlog", ] [[package]] @@ -261,6 +313,7 @@ name = "agent_ui" version = "0.1.0" dependencies = [ "acp_thread", + "action_log", "agent", "agent-client-protocol", "agent2", @@ -294,7 +347,6 @@ dependencies = [ "gpui", "html_to_markdown", "http_client", - "indexed_docs", "indoc", "inventory", "itertools 0.14.0", @@ -342,6 +394,7 @@ dependencies = [ "ui", "ui_input", "unindent", + "url", "urlencoding", "util", "uuid", @@ -818,7 +871,6 @@ dependencies = [ "gpui", "html_to_markdown", "http_client", - "indexed_docs", "language", "pretty_assertions", "project", @@ -842,13 +894,13 @@ dependencies = [ name = "assistant_tool" version = "0.1.0" dependencies = [ + "action_log", "anyhow", "buffer_diff", "clock", "collections", "ctor", "derive_more 0.99.19", - "futures 0.3.31", "gpui", "icons", "indoc", @@ -865,7 +917,6 @@ dependencies = [ "settings", "text", "util", - "watch", "workspace", "workspace-hack", "zlog", @@ -875,6 +926,7 @@ dependencies = [ name = "assistant_tools" version = "0.1.0" dependencies = [ + "action_log", "agent_settings", "anyhow", "assistant_tool", @@ -1208,26 +1260,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "async-stripe" -version = "0.40.0" -source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735" -dependencies = [ - "chrono", - "futures-util", - "http-types", - "hyper 0.14.32", - "hyper-rustls 0.24.2", - "serde", - "serde_json", - "serde_path_to_error", - "serde_qs 0.10.1", - "smart-default 0.6.0", - "smol_str 0.1.24", - "thiserror 1.0.69", - "tokio", -] - [[package]] name = "async-tar" version = "0.5.0" @@ -1250,9 +1282,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -2029,12 +2061,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.21.7" @@ -3227,7 +3253,6 @@ dependencies = [ "anyhow", "assistant_context", "assistant_slash_command", - "async-stripe", "async-trait", "async-tungstenite", "audio", @@ -3243,7 +3268,6 @@ dependencies = [ "chrono", "client", "clock", - "cloud_llm_client", "collab_ui", "collections", "command_palette_hooks", @@ -3254,7 +3278,6 @@ dependencies = [ "dap_adapters", "dashmap 6.1.0", "debugger_ui", - "derive_more 0.99.19", "editor", "envy", "extension", @@ -3270,7 +3293,6 @@ dependencies = [ "http_client", "hyper 0.14.32", "indoc", - "jsonwebtoken", "language", "language_model", "livekit_api", @@ -3316,7 +3338,6 @@ dependencies = [ "telemetry_events", "text", "theme", - "thiserror 2.0.12", "time", "tokio", "toml 0.8.20", @@ -3818,7 +3839,7 @@ dependencies = [ "rustc-hash 1.1.0", "rustybuzz 0.14.1", "self_cell", - "smol_str 0.2.2", + "smol_str", "swash", "sys-locale", "ttf-parser 0.21.1", @@ -4014,6 +4035,9 @@ dependencies = [ "log", "minidumper", "paths", + "release_channel", + "serde", + "serde_json", "smol", "workspace-hack", ] @@ -6321,17 +6345,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.15" @@ -6396,6 +6409,7 @@ dependencies = [ "log", "parking_lot", "pretty_assertions", + "rand 0.8.5", "regex", "rope", "schemars", @@ -7827,6 +7841,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + [[package]] name = "html5ever" version = "0.27.0" @@ -7928,27 +7948,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" -[[package]] -name = "http-types" -version = "2.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" -dependencies = [ - "anyhow", - "async-channel 1.9.0", - "base64 0.13.1", - "futures-lite 1.13.0", - "http 0.2.12", - "infer", - "pin-project-lite", - "rand 0.7.3", - "serde", - "serde_json", - "serde_qs 0.8.5", - "serde_urlencoded", - "url", -] - [[package]] name = "http_client" version = "0.1.0" @@ -8382,34 +8381,6 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" -[[package]] -name = "indexed_docs" -version = "0.1.0" -dependencies = [ - "anyhow", - "async-trait", - "cargo_metadata", - "collections", - "derive_more 0.99.19", - "extension", - "fs", - "futures 0.3.31", - "fuzzy", - "gpui", - "heed", - "html_to_markdown", - "http_client", - "indexmap", - "indoc", - "parking_lot", - "paths", - "pretty_assertions", - "serde", - "strum 0.27.1", - "util", - "workspace-hack", -] - [[package]] name = "indexmap" version = "2.9.0" @@ -8427,12 +8398,6 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" -[[package]] -name = "infer" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" - [[package]] name = "inherent" version = "1.0.12" @@ -9655,6 +9620,7 @@ dependencies = [ "objc", "parking_lot", "postage", + "rodio", "scap", "serde", "serde_json", @@ -10208,7 +10174,7 @@ dependencies = [ "num-traits", "range-map", "scroll", - "smart-default 0.7.1", + "smart-default", ] [[package]] @@ -11100,14 +11066,13 @@ dependencies = [ "ai_onboarding", "anyhow", "client", - "command_palette_hooks", "component", "db", "documented", "editor", - "feature_flags", "fs", "fuzzy", + "git", "gpui", "itertools 0.14.0", "language", @@ -11119,6 +11084,7 @@ dependencies = [ "schemars", "serde", "settings", + "telemetry", "theme", "ui", "util", @@ -11194,6 +11160,7 @@ dependencies = [ "anyhow", "futures 0.3.31", "http_client", + "log", "schemars", "serde", "serde_json", @@ -13081,19 +13048,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -13115,16 +13069,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", -] - [[package]] name = "rand_chacha" version = "0.3.1" @@ -13145,15 +13089,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", -] - [[package]] name = "rand_core" version = "0.6.4" @@ -13172,15 +13107,6 @@ dependencies = [ "getrandom 0.3.2", ] -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "range-map" version = "0.2.0" @@ -13522,6 +13448,7 @@ dependencies = [ name = "remote_server" version = "0.1.0" dependencies = [ + "action_log", "anyhow", "askpass", "assistant_tool", @@ -13914,6 +13841,7 @@ checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183" dependencies = [ "cpal", "dasp_sample", + "hound", "num-rational", "symphonia", "tracing", @@ -14833,28 +14761,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_qs" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - -[[package]] -name = "serde_qs" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - [[package]] name = "serde_repr" version = "0.1.20" @@ -14996,8 +14902,10 @@ dependencies = [ "ui", "ui_input", "util", + "vim", "workspace", "workspace-hack", + "zed_actions", ] [[package]] @@ -15229,17 +15137,6 @@ dependencies = [ "serde", ] -[[package]] -name = "smart-default" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "smart-default" version = "0.7.1" @@ -15268,15 +15165,6 @@ dependencies = [ "futures-lite 2.6.0", ] -[[package]] -name = "smol_str" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" -dependencies = [ - "serde", -] - [[package]] name = "smol_str" version = "0.2.2" @@ -17978,6 +17866,7 @@ dependencies = [ "command_palette_hooks", "db", "editor", + "env_logger 0.11.8", "futures 0.3.31", "git_ui", "gpui", @@ -18124,12 +18013,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -18831,33 +18714,6 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" -[[package]] -name = "welcome" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "component", - "db", - "documented", - "editor", - "fuzzy", - "gpui", - "install_cli", - "language", - "picker", - "project", - "serde", - "settings", - "telemetry", - "ui", - "util", - "vim_mode_setting", - "workspace", - "workspace-hack", - "zed_actions", -] - [[package]] name = "which" version = "4.4.2" @@ -20243,7 +20099,7 @@ dependencies = [ [[package]] name = "xim" version = "0.4.0" -source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd" +source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" dependencies = [ "ahash 0.8.11", "hashbrown 0.14.5", @@ -20256,7 +20112,7 @@ dependencies = [ [[package]] name = "xim-ctext" version = "0.3.0" -source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd" +source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" dependencies = [ "encoding_rs", ] @@ -20264,7 +20120,7 @@ dependencies = [ [[package]] name = "xim-parser" version = "0.2.1" -source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd" +source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" dependencies = [ "bitflags 2.9.0", ] @@ -20472,7 +20328,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.200.0" +version = "0.201.0" dependencies = [ "activity_indicator", "agent", @@ -20542,6 +20398,7 @@ dependencies = [ "language_tools", "languages", "libc", + "livekit_client", "log", "markdown", "markdown_preview", @@ -20612,7 +20469,6 @@ dependencies = [ "watch", "web_search", "web_search_providers", - "welcome", "windows 0.61.1", "winresource", "workspace", @@ -20634,13 +20490,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "zed_emmet" -version = "0.0.4" -dependencies = [ - "zed_extension_api 0.1.0", -] - [[package]] name = "zed_extension_api" version = "0.1.0" @@ -20875,6 +20724,7 @@ dependencies = [ "menu", "postage", "project", + "rand 0.8.5", "regex", "release_channel", "reqwest_client", diff --git a/Cargo.toml b/Cargo.toml index 998e727602..b3105bd97c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ "crates/acp_thread", + "crates/action_log", "crates/activity_indicator", "crates/agent", "crates/agent2", @@ -80,7 +81,6 @@ members = [ "crates/http_client_tls", "crates/icons", "crates/image_viewer", - "crates/indexed_docs", "crates/edit_prediction", "crates/edit_prediction_button", "crates/inspector_ui", @@ -184,7 +184,6 @@ members = [ "crates/watch", "crates/web_search", "crates/web_search_providers", - "crates/welcome", "crates/workspace", "crates/worktree", "crates/x_ai", @@ -199,7 +198,6 @@ members = [ # Extensions # - "extensions/emmet", "extensions/glsl", "extensions/html", "extensions/proto", @@ -229,6 +227,7 @@ edition = "2024" # acp_thread = { path = "crates/acp_thread" } +action_log = { path = "crates/action_log" } agent = { path = "crates/agent" } agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } @@ -305,7 +304,6 @@ http_client = { path = "crates/http_client" } http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } -indexed_docs = { path = "crates/indexed_docs" } edit_prediction = { path = "crates/edit_prediction" } edit_prediction_button = { path = "crates/edit_prediction_button" } inspector_ui = { path = "crates/inspector_ui" } @@ -362,6 +360,7 @@ remote_server = { path = "crates/remote_server" } repl = { path = "crates/repl" } reqwest_client = { path = "crates/reqwest_client" } rich_text = { path = "crates/rich_text" } +rodio = { version = "0.21.1", default-features = false } rope = { path = "crates/rope" } rpc = { path = "crates/rpc" } rules_library = { path = "crates/rules_library" } @@ -410,7 +409,6 @@ vim_mode_setting = { path = "crates/vim_mode_setting" } watch = { path = "crates/watch" } web_search = { path = "crates/web_search" } web_search_providers = { path = "crates/web_search_providers" } -welcome = { path = "crates/welcome" } workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } x_ai = { path = "crates/x_ai" } @@ -425,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.23" +agent-client-protocol = "0.0.25" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -666,20 +664,6 @@ workspace-hack = "0.1.0" yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" } zstd = "0.11" -[workspace.dependencies.async-stripe] -git = "https://github.com/zed-industries/async-stripe" -rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735" -default-features = false -features = [ - "runtime-tokio-hyper-rustls", - "billing", - "checkout", - "events", - # The features below are only enabled to get the `events` feature to build. - "chrono", - "connect", -] - [workspace.dependencies.windows] version = "0.61" features = [ @@ -712,6 +696,7 @@ features = [ "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Ole", + "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_SystemInformation", "Win32_System_SystemServices", @@ -839,6 +824,7 @@ style = { level = "allow", priority = -1 } module_inception = { level = "deny" } question_mark = { level = "deny" } redundant_closure = { level = "deny" } +declare_interior_mutable_const = { level = "deny" } # Individual rules that have violations in the codebase: type_complexity = "allow" # We often return trait objects from `new` functions. diff --git a/Dockerfile-collab b/Dockerfile-collab index 2dafe296c7..c1621d6ee6 100644 --- a/Dockerfile-collab +++ b/Dockerfile-collab @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:1.2 -FROM rust:1.88-bookworm as builder +FROM rust:1.89-bookworm as builder WORKDIR app COPY . . diff --git a/assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf b/assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf new file mode 100644 index 0000000000..1d66b1a2e9 Binary files /dev/null and b/assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf differ diff --git a/assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf b/assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf new file mode 100644 index 0000000000..e07bc1f527 Binary files /dev/null and b/assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf differ diff --git a/assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf b/assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf new file mode 100644 index 0000000000..efe8a1fb9d Binary files /dev/null and b/assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf differ diff --git a/assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf b/assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf new file mode 100644 index 0000000000..bd6817d520 Binary files /dev/null and b/assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf differ diff --git a/assets/fonts/plex-mono/license.txt b/assets/fonts/ibm-plex-sans/license.txt similarity index 100% rename from assets/fonts/plex-mono/license.txt rename to assets/fonts/ibm-plex-sans/license.txt diff --git a/assets/fonts/lilex/Lilex-Bold.ttf b/assets/fonts/lilex/Lilex-Bold.ttf new file mode 100644 index 0000000000..45930ee30b Binary files /dev/null and b/assets/fonts/lilex/Lilex-Bold.ttf differ diff --git a/assets/fonts/lilex/Lilex-BoldItalic.ttf b/assets/fonts/lilex/Lilex-BoldItalic.ttf new file mode 100644 index 0000000000..10c6ab5f74 Binary files /dev/null and b/assets/fonts/lilex/Lilex-BoldItalic.ttf differ diff --git a/assets/fonts/lilex/Lilex-Italic.ttf b/assets/fonts/lilex/Lilex-Italic.ttf new file mode 100644 index 0000000000..e7aef10f7e Binary files /dev/null and b/assets/fonts/lilex/Lilex-Italic.ttf differ diff --git a/assets/fonts/lilex/Lilex-Regular.ttf b/assets/fonts/lilex/Lilex-Regular.ttf new file mode 100644 index 0000000000..cb98a69b0f Binary files /dev/null and b/assets/fonts/lilex/Lilex-Regular.ttf differ diff --git a/assets/fonts/plex-sans/license.txt b/assets/fonts/lilex/OFL.txt similarity index 96% rename from assets/fonts/plex-sans/license.txt rename to assets/fonts/lilex/OFL.txt index f72f76504c..156240bc90 100644 --- a/assets/fonts/plex-sans/license.txt +++ b/assets/fonts/lilex/OFL.txt @@ -1,8 +1,9 @@ -Copyright © 2017 IBM Corp. with Reserved Font Name "Plex" +Copyright 2019 The Lilex Project Authors (https://github.com/mishamyrt/Lilex) This Font Software is licensed under the SIL Open Font License, Version 1.1. This license is copied below, and is also available with a FAQ at: -http://scripts.sil.org/OFL +https://scripts.sil.org/OFL + ----------------------------------------------------------- SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 @@ -89,4 +90,4 @@ COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM -OTHER DEALINGS IN THE FONT SOFTWARE. \ No newline at end of file +OTHER DEALINGS IN THE FONT SOFTWARE. diff --git a/assets/fonts/plex-mono/ZedPlexMono-Bold.ttf b/assets/fonts/plex-mono/ZedPlexMono-Bold.ttf deleted file mode 100644 index d5f4b5e285..0000000000 Binary files a/assets/fonts/plex-mono/ZedPlexMono-Bold.ttf and /dev/null differ diff --git a/assets/fonts/plex-mono/ZedPlexMono-BoldItalic.ttf b/assets/fonts/plex-mono/ZedPlexMono-BoldItalic.ttf deleted file mode 100644 index 05eaf7cccd..0000000000 Binary files a/assets/fonts/plex-mono/ZedPlexMono-BoldItalic.ttf and /dev/null differ diff --git a/assets/fonts/plex-mono/ZedPlexMono-Italic.ttf b/assets/fonts/plex-mono/ZedPlexMono-Italic.ttf deleted file mode 100644 index 3b07821757..0000000000 Binary files a/assets/fonts/plex-mono/ZedPlexMono-Italic.ttf and /dev/null differ diff --git a/assets/fonts/plex-mono/ZedPlexMono-Regular.ttf b/assets/fonts/plex-mono/ZedPlexMono-Regular.ttf deleted file mode 100644 index 61dbb58361..0000000000 Binary files a/assets/fonts/plex-mono/ZedPlexMono-Regular.ttf and /dev/null differ diff --git a/assets/fonts/plex-sans/ZedPlexSans-Bold.ttf b/assets/fonts/plex-sans/ZedPlexSans-Bold.ttf deleted file mode 100644 index f1e66392f7..0000000000 Binary files a/assets/fonts/plex-sans/ZedPlexSans-Bold.ttf and /dev/null differ diff --git a/assets/fonts/plex-sans/ZedPlexSans-BoldItalic.ttf b/assets/fonts/plex-sans/ZedPlexSans-BoldItalic.ttf deleted file mode 100644 index 7612dc5167..0000000000 Binary files a/assets/fonts/plex-sans/ZedPlexSans-BoldItalic.ttf and /dev/null differ diff --git a/assets/fonts/plex-sans/ZedPlexSans-Italic.ttf b/assets/fonts/plex-sans/ZedPlexSans-Italic.ttf deleted file mode 100644 index 8769c232ee..0000000000 Binary files a/assets/fonts/plex-sans/ZedPlexSans-Italic.ttf and /dev/null differ diff --git a/assets/fonts/plex-sans/ZedPlexSans-Regular.ttf b/assets/fonts/plex-sans/ZedPlexSans-Regular.ttf deleted file mode 100644 index 3ea293d59a..0000000000 Binary files a/assets/fonts/plex-sans/ZedPlexSans-Regular.ttf and /dev/null differ diff --git a/assets/icons/arrow_circle.svg b/assets/icons/arrow_circle.svg index 790428702e..76363c6270 100644 --- a/assets/icons/arrow_circle.svg +++ b/assets/icons/arrow_circle.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/blocks.svg b/assets/icons/blocks.svg index 128ca84ef1..e1690e2642 100644 --- a/assets/icons/blocks.svg +++ b/assets/icons/blocks.svg @@ -1 +1,3 @@ - + + + diff --git a/assets/icons/debug_disabled_log_breakpoint.svg b/assets/icons/debug_disabled_log_breakpoint.svg index a028ead3a0..2ccc37623d 100644 --- a/assets/icons/debug_disabled_log_breakpoint.svg +++ b/assets/icons/debug_disabled_log_breakpoint.svg @@ -1,3 +1,5 @@ - + + + diff --git a/assets/icons/debug_ignore_breakpoints.svg b/assets/icons/debug_ignore_breakpoints.svg index a0bbabfb26..b2a345d314 100644 --- a/assets/icons/debug_ignore_breakpoints.svg +++ b/assets/icons/debug_ignore_breakpoints.svg @@ -1 +1,3 @@ - + + + diff --git a/assets/icons/debug_log_breakpoint.svg b/assets/icons/debug_log_breakpoint.svg index 7c652db1e9..22eae9d029 100644 --- a/assets/icons/debug_log_breakpoint.svg +++ b/assets/icons/debug_log_breakpoint.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/icons/folder_search.svg b/assets/icons/folder_search.svg index 15b0705dd6..d1bc537c98 100644 --- a/assets/icons/folder_search.svg +++ b/assets/icons/folder_search.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/json.svg b/assets/icons/json.svg new file mode 100644 index 0000000000..5f012f8838 --- /dev/null +++ b/assets/icons/json.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/icons/maximize.svg b/assets/icons/maximize.svg index c51b71aaf0..ee03a2c021 100644 --- a/assets/icons/maximize.svg +++ b/assets/icons/maximize.svg @@ -1,6 +1,6 @@ - - + + diff --git a/assets/icons/minimize.svg b/assets/icons/minimize.svg index 97d4699687..ea825f054e 100644 --- a/assets/icons/minimize.svg +++ b/assets/icons/minimize.svg @@ -1,6 +1,6 @@ - - - - + + + + diff --git a/assets/icons/scissors.svg b/assets/icons/scissors.svg index 89d246841e..430293f913 100644 --- a/assets/icons/scissors.svg +++ b/assets/icons/scissors.svg @@ -1 +1,3 @@ - + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 708432393c..01c0b4e969 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -239,6 +239,7 @@ "ctrl-shift-a": "agent::ToggleContextPicker", "ctrl-shift-j": "agent::ToggleNavigationMenu", "ctrl-shift-i": "agent::ToggleOptionsMenu", + "ctrl-alt-shift-n": "agent::ToggleNewThreadMenu", "shift-alt-escape": "agent::ExpandMessageEditor", "ctrl->": "assistant::QuoteSelection", "ctrl-alt-e": "agent::RemoveAllContext", @@ -330,8 +331,6 @@ "use_key_equivalents": true, "bindings": { "enter": "agent::Chat", - "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage", "shift-ctrl-r": "agent::OpenAgentDiff", "ctrl-shift-y": "agent::KeepAll", "ctrl-shift-n": "agent::RejectAll" diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index abb741af29..e5b7fff9e1 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -279,6 +279,7 @@ "cmd-shift-a": "agent::ToggleContextPicker", "cmd-shift-j": "agent::ToggleNavigationMenu", "cmd-shift-i": "agent::ToggleOptionsMenu", + "cmd-alt-shift-n": "agent::ToggleNewThreadMenu", "shift-alt-escape": "agent::ExpandMessageEditor", "cmd->": "assistant::QuoteSelection", "cmd-alt-e": "agent::RemoveAllContext", @@ -382,8 +383,6 @@ "use_key_equivalents": true, "bindings": { "enter": "agent::Chat", - "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage", "shift-ctrl-r": "agent::OpenAgentDiff", "cmd-shift-y": "agent::KeepAll", "cmd-shift-n": "agent::RejectAll" diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 57edb1e4c1..be6d34a134 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -58,6 +58,8 @@ "[ space": "vim::InsertEmptyLineAbove", "[ e": "editor::MoveLineUp", "] e": "editor::MoveLineDown", + "[ f": "workspace::FollowNextCollaborator", + "] f": "workspace::FollowNextCollaborator", // Word motions "w": "vim::NextWordStart", @@ -333,10 +335,14 @@ "ctrl-x ctrl-c": "editor::ShowEditPrediction", // zed specific "ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific "ctrl-x ctrl-z": "editor::Cancel", + "ctrl-x ctrl-e": "vim::LineDown", + "ctrl-x ctrl-y": "vim::LineUp", "ctrl-w": "editor::DeleteToPreviousWordStart", "ctrl-u": "editor::DeleteToBeginningOfLine", "ctrl-t": "vim::Indent", "ctrl-d": "vim::Outdent", + "ctrl-y": "vim::InsertFromAbove", + "ctrl-e": "vim::InsertFromBelow", "ctrl-k": ["vim::PushDigraph", {}], "ctrl-v": ["vim::PushLiteral", {}], "ctrl-shift-v": "editor::Paste", // note: this is *very* similar to ctrl-v in vim, but ctrl-shift-v on linux is the typical shortcut for paste when ctrl-v is already in use. @@ -386,7 +392,7 @@ "right": "vim::WrappingRight", "h": "vim::WrappingLeft", "l": "vim::WrappingRight", - "y": "editor::Copy", + "y": "vim::HelixYank", "alt-;": "vim::OtherEnd", "ctrl-r": "vim::Redo", "f": ["vim::PushFindForward", { "before": false, "multiline": true }], @@ -403,6 +409,7 @@ "g w": "vim::PushRewrap", "insert": "vim::InsertBefore", "alt-.": "vim::RepeatFind", + "alt-s": ["editor::SplitSelectionIntoLines", { "keep_selections": true }], // tree-sitter related commands "[ x": "editor::SelectLargerSyntaxNode", "] x": "editor::SelectSmallerSyntaxNode", diff --git a/assets/settings/default.json b/assets/settings/default.json index 28cf591ee7..6a8b034268 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -28,7 +28,9 @@ "edit_prediction_provider": "zed" }, // The name of a font to use for rendering text in the editor - "buffer_font_family": "Zed Plex Mono", + // ".ZedMono" currently aliases to Lilex + // but this may change in the future. + "buffer_font_family": ".ZedMono", // Set the buffer text's font fallbacks, this will be merged with // the platform's default fallbacks. "buffer_font_fallbacks": null, @@ -54,7 +56,9 @@ "buffer_line_height": "comfortable", // The name of a font to use for rendering text in the UI // You can set this to ".SystemUIFont" to use the system font - "ui_font_family": "Zed Plex Sans", + // ".ZedSans" currently aliases to "IBM Plex Sans", but this may + // change in the future + "ui_font_family": ".ZedSans", // Set the UI's font fallbacks, this will be merged with the platform's // default font fallbacks. "ui_font_fallbacks": null, @@ -67,8 +71,8 @@ "ui_font_weight": 400, // The default font size for text in the UI "ui_font_size": 16, - // The default font size for text in the agent panel - "agent_font_size": 16, + // The default font size for text in the agent panel. Falls back to the UI font size if unset. + "agent_font_size": null, // How much to fade out unused code. "unnecessary_code_fade": 0.3, // Active pane styling settings. @@ -82,10 +86,10 @@ // Layout mode of the bottom dock. Defaults to "contained" // choices: contained, full, left_aligned, right_aligned "bottom_dock_layout": "contained", - // The direction that you want to split panes horizontally. Defaults to "up" - "pane_split_direction_horizontal": "up", - // The direction that you want to split panes vertically. Defaults to "left" - "pane_split_direction_vertical": "left", + // The direction that you want to split panes horizontally. Defaults to "down" + "pane_split_direction_horizontal": "down", + // The direction that you want to split panes vertically. Defaults to "right" + "pane_split_direction_vertical": "right", // Centered layout related settings. "centered_layout": { // The relative width of the left padding of the central pane from the @@ -883,11 +887,6 @@ }, // The settings for slash commands. "slash_commands": { - // Settings for the `/docs` slash command. - "docs": { - // Whether `/docs` is enabled. - "enabled": false - }, // Settings for the `/project` slash command. "project": { // Whether `/project` is enabled. @@ -1252,7 +1251,9 @@ // Status bar-related settings. "status_bar": { // Whether to show the active language button in the status bar. - "active_language_button": true + "active_language_button": true, + // Whether to show the cursor position button in the status bar. + "cursor_position_button": true }, // Settings specific to the terminal "terminal": { @@ -1402,7 +1403,7 @@ // "font_size": 15, // Set the terminal's font family. If this option is not included, // the terminal will default to matching the buffer's font family. - // "font_family": "Zed Plex Mono", + // "font_family": ".ZedMono", // Set the terminal's font fallbacks. If this option is not included, // the terminal will default to matching the buffer's font fallbacks. // This will be merged with the platform's default font fallbacks diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 1831c7e473..2b9a6513c8 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -13,27 +13,35 @@ path = "src/acp_thread.rs" doctest = false [features] -test-support = ["gpui/test-support", "project/test-support"] +test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"] [dependencies] +action_log.workspace = true agent-client-protocol.workspace = true +agent.workspace = true anyhow.workspace = true -assistant_tool.workspace = true buffer_diff.workspace = true +collections.workspace = true editor.workspace = true +file_icons.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true -language_model.workspace = true markdown.workspace = true +parking_lot = { workspace = true, optional = true } project.workspace = true +prompt_store.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true +terminal.workspace = true ui.workspace = true +url.workspace = true util.workspace = true +uuid.workspace = true +watch.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 1df0e1def7..3bb1b99ba1 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,85 +1,63 @@ mod connection; mod diff; +mod mention; +mod terminal; pub use connection::*; pub use diff::*; +pub use mention::*; +pub use terminal::*; +use action_log::ActionLog; use agent_client_protocol as acp; -use anyhow::{Context as _, Result}; -use assistant_tool::ActionLog; +use anyhow::{Context as _, Result, anyhow}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; -use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; +use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use itertools::Itertools; -use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff}; +use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff}; use markdown::Markdown; -use project::{AgentLocation, Project}; +use project::{AgentLocation, Project, git_store::GitStoreCheckpoint}; use std::collections::HashMap; use std::error::Error; -use std::fmt::Formatter; +use std::fmt::{Formatter, Write}; +use std::ops::Range; use std::process::ExitStatus; use std::rc::Rc; -use std::{ - fmt::Display, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; use util::ResultExt; #[derive(Debug)] pub struct UserMessage { + pub id: Option, pub content: ContentBlock, -} - -impl UserMessage { - pub fn from_acp( - message: impl IntoIterator, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let mut content = ContentBlock::Empty; - for chunk in message { - content.append(chunk, &language_registry, cx) - } - Self { content: content } - } - - fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) - } + pub chunks: Vec, + pub checkpoint: Option, } #[derive(Debug)] -pub struct MentionPath<'a>(&'a Path); - -impl<'a> MentionPath<'a> { - const PREFIX: &'static str = "@file:"; - - pub fn new(path: &'a Path) -> Self { - MentionPath(path) - } - - pub fn try_parse(url: &'a str) -> Option { - let path = url.strip_prefix(Self::PREFIX)?; - Some(MentionPath(Path::new(path))) - } - - pub fn path(&self) -> &Path { - self.0 - } +pub struct Checkpoint { + git_checkpoint: GitStoreCheckpoint, + pub show: bool, } -impl Display for MentionPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "[@{}]({}{})", - self.0.file_name().unwrap_or_default().display(), - Self::PREFIX, - self.0.display() - ) +impl UserMessage { + fn to_markdown(&self, cx: &App) -> String { + let mut markdown = String::new(); + if self + .checkpoint + .as_ref() + .map_or(false, |checkpoint| checkpoint.show) + { + writeln!(markdown, "## User (checkpoint)").unwrap(); + } else { + writeln!(markdown, "## User").unwrap(); + } + writeln!(markdown).unwrap(); + writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap(); + writeln!(markdown).unwrap(); + markdown } } @@ -131,7 +109,7 @@ pub enum AgentThreadEntry { } impl AgentThreadEntry { - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), @@ -139,6 +117,14 @@ impl AgentThreadEntry { } } + pub fn user_message(&self) -> Option<&UserMessage> { + if let AgentThreadEntry::UserMessage(message) = self { + Some(message) + } else { + None + } + } + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) @@ -147,9 +133,25 @@ impl AgentThreadEntry { } } - pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> { - if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self { - Some(locations) + pub fn terminals(&self) -> impl Iterator> { + if let AgentThreadEntry::ToolCall(call) = self { + itertools::Either::Left(call.terminals()) + } else { + itertools::Either::Right(std::iter::empty()) + } + } + + pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> { + if let AgentThreadEntry::ToolCall(ToolCall { + locations, + resolved_locations, + .. + }) = self + { + Some(( + locations.get(ix)?.clone(), + resolved_locations.get(ix)?.clone()?, + )) } else { None } @@ -164,6 +166,7 @@ pub struct ToolCall { pub content: Vec, pub status: ToolCallStatus, pub locations: Vec, + pub resolved_locations: Vec>, pub raw_input: Option, pub raw_output: Option, } @@ -192,6 +195,7 @@ impl ToolCall { .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) .collect(), locations: tool_call.locations, + resolved_locations: Vec::default(), status, raw_input: tool_call.raw_input, raw_output: tool_call.raw_output, @@ -244,14 +248,32 @@ impl ToolCall { } if let Some(raw_output) = raw_output { + if self.content.is_empty() { + if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx) + { + self.content + .push(ToolCallContent::ContentBlock(ContentBlock::Markdown { + markdown, + })); + } + } self.raw_output = Some(raw_output); } } pub fn diffs(&self) -> impl Iterator> { self.content.iter().filter_map(|content| match content { - ToolCallContent::ContentBlock { .. } => None, - ToolCallContent::Diff { diff } => Some(diff), + ToolCallContent::Diff(diff) => Some(diff), + ToolCallContent::ContentBlock(_) => None, + ToolCallContent::Terminal(_) => None, + }) + } + + pub fn terminals(&self) -> impl Iterator> { + self.content.iter().filter_map(|content| match content { + ToolCallContent::Terminal(terminal) => Some(terminal), + ToolCallContent::ContentBlock(_) => None, + ToolCallContent::Diff(_) => None, }) } @@ -267,6 +289,57 @@ impl ToolCall { } markdown } + + async fn resolve_location( + location: acp::ToolCallLocation, + project: WeakEntity, + cx: &mut AsyncApp, + ) -> Option { + let buffer = project + .update(cx, |project, cx| { + if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) { + Some(project.open_buffer(path, cx)) + } else { + None + } + }) + .ok()??; + let buffer = buffer.await.log_err()?; + let position = buffer + .update(cx, |buffer, _| { + if let Some(row) = location.line { + let snapshot = buffer.snapshot(); + let column = snapshot.indent_size_for_line(row).len; + let point = snapshot.clip_point(Point::new(row, column), Bias::Left); + snapshot.anchor_before(point) + } else { + Anchor::MIN + } + }) + .ok()?; + + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }) + } + + fn resolve_locations( + &self, + project: Entity, + cx: &mut App, + ) -> Task>> { + let locations = self.locations.clone(); + project.update(cx, |_, cx| { + cx.spawn(async move |project, cx| { + let mut new_locations = Vec::new(); + for location in locations { + new_locations.push(Self::resolve_location(location, project.clone(), cx).await); + } + new_locations + }) + }) + } } #[derive(Debug)] @@ -306,6 +379,7 @@ impl Display for ToolCallStatus { pub enum ContentBlock { Empty, Markdown { markdown: Entity }, + ResourceLink { resource_link: acp::ResourceLink }, } impl ContentBlock { @@ -337,43 +411,78 @@ impl ContentBlock { language_registry: &Arc, cx: &mut App, ) { - let new_content = match block { - acp::ContentBlock::Text(text_content) => text_content.text.clone(), - acp::ContentBlock::ResourceLink(resource_link) => { - if let Some(path) = resource_link.uri.strip_prefix("file://") { - format!("{}", MentionPath(path.as_ref())) - } else { - resource_link.uri.clone() - } + if matches!(self, ContentBlock::Empty) { + if let acp::ContentBlock::ResourceLink(resource_link) = block { + *self = ContentBlock::ResourceLink { resource_link }; + return; } - acp::ContentBlock::Image(_) - | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => String::new(), - }; + } + + let new_content = self.block_string_contents(block); match self { ContentBlock::Empty => { - *self = ContentBlock::Markdown { - markdown: cx.new(|cx| { - Markdown::new( - new_content.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), - }; + *self = Self::create_markdown_block(new_content, language_registry, cx); } ContentBlock::Markdown { markdown } => { markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx)); } + ContentBlock::ResourceLink { resource_link } => { + let existing_content = Self::resource_link_md(&resource_link.uri); + let combined = format!("{}\n{}", existing_content, new_content); + + *self = Self::create_markdown_block(combined, language_registry, cx); + } } } + fn create_markdown_block( + content: String, + language_registry: &Arc, + cx: &mut App, + ) -> ContentBlock { + ContentBlock::Markdown { + markdown: cx + .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)), + } + } + + fn block_string_contents(&self, block: acp::ContentBlock) -> String { + match block { + acp::ContentBlock::Text(text_content) => text_content.text.clone(), + acp::ContentBlock::ResourceLink(resource_link) => { + Self::resource_link_md(&resource_link.uri) + } + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: + acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents { + uri, + .. + }), + .. + }) => Self::resource_link_md(&uri), + acp::ContentBlock::Image(image) => Self::image_md(&image), + acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(), + } + } + + fn resource_link_md(uri: &str) -> String { + if let Some(uri) = MentionUri::parse(&uri).log_err() { + uri.as_link().to_string() + } else { + uri.to_string() + } + } + + fn image_md(_image: &acp::ImageContent) -> String { + "`Image`".into() + } + fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { match self { ContentBlock::Empty => "", ContentBlock::Markdown { markdown } => markdown.read(cx).source(), + ContentBlock::ResourceLink { resource_link } => &resource_link.uri, } } @@ -381,14 +490,23 @@ impl ContentBlock { match self { ContentBlock::Empty => None, ContentBlock::Markdown { markdown } => Some(markdown), + ContentBlock::ResourceLink { .. } => None, + } + } + + pub fn resource_link(&self) -> Option<&acp::ResourceLink> { + match self { + ContentBlock::ResourceLink { resource_link } => Some(resource_link), + _ => None, } } } #[derive(Debug)] pub enum ToolCallContent { - ContentBlock { content: ContentBlock }, - Diff { diff: Entity }, + ContentBlock(ContentBlock), + Diff(Entity), + Terminal(Entity), } impl ToolCallContent { @@ -398,19 +516,20 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::Content { content } => Self::ContentBlock { - content: ContentBlock::new(content, &language_registry, cx), - }, - acp::ToolCallContent::Diff { diff } => Self::Diff { - diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)), - }, + acp::ToolCallContent::Content { content } => { + Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) + } + acp::ToolCallContent::Diff { diff } => { + Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) + } } } pub fn to_markdown(&self, cx: &App) -> String { match self { - Self::ContentBlock { content } => content.to_markdown(cx).to_string(), - Self::Diff { diff } => diff.read(cx).to_markdown(cx), + Self::ContentBlock(content) => content.to_markdown(cx).to_string(), + Self::Diff(diff) => diff.read(cx).to_markdown(cx), + Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx), } } } @@ -419,6 +538,7 @@ impl ToolCallContent { pub enum ToolCallUpdate { UpdateFields(acp::ToolCallUpdate), UpdateDiff(ToolCallUpdateDiff), + UpdateTerminal(ToolCallUpdateTerminal), } impl ToolCallUpdate { @@ -426,6 +546,7 @@ impl ToolCallUpdate { match self { Self::UpdateFields(update) => &update.id, Self::UpdateDiff(diff) => &diff.id, + Self::UpdateTerminal(terminal) => &terminal.id, } } } @@ -448,6 +569,18 @@ pub struct ToolCallUpdateDiff { pub diff: Entity, } +impl From for ToolCallUpdate { + fn from(terminal: ToolCallUpdateTerminal) -> Self { + Self::UpdateTerminal(terminal) + } +} + +#[derive(Debug, PartialEq)] +pub struct ToolCallUpdateTerminal { + pub id: acp::ToolCallId, + pub terminal: Entity, +} + #[derive(Debug, Default)] pub struct Plan { pub entries: Vec, @@ -522,6 +655,7 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), + EntriesRemoved(Range), ToolAuthorizationRequired, Stopped, Error, @@ -583,6 +717,10 @@ impl AcpThread { } } + pub fn connection(&self) -> &Rc { + &self.connection + } + pub fn action_log(&self) -> &Entity { &self.action_log } @@ -654,10 +792,10 @@ impl AcpThread { &mut self, update: acp::SessionUpdate, cx: &mut Context, - ) -> Result<()> { + ) -> Result<(), acp::Error> { match update { acp::SessionUpdate::UserMessageChunk { content } => { - self.push_user_content_block(content, cx); + self.push_user_content_block(None, content, cx); } acp::SessionUpdate::AgentMessageChunk { content } => { self.push_assistant_content_block(content, false, cx); @@ -666,7 +804,7 @@ impl AcpThread { self.push_assistant_content_block(content, true, cx); } acp::SessionUpdate::ToolCall(tool_call) => { - self.upsert_tool_call(tool_call, cx); + self.upsert_tool_call(tool_call, cx)?; } acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { self.update_tool_call(tool_call_update, cx)?; @@ -678,18 +816,39 @@ impl AcpThread { Ok(()) } - pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { + pub fn push_user_content_block( + &mut self, + message_id: Option, + chunk: acp::ContentBlock, + cx: &mut Context, + ) { let language_registry = self.project.read(cx).languages().clone(); let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() - && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry + && let AgentThreadEntry::UserMessage(UserMessage { + id, + content, + chunks, + .. + }) = last_entry { - content.append(chunk, &language_registry, cx); - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + *id = message_id.or(id.take()); + content.append(chunk.clone(), &language_registry, cx); + chunks.push(chunk); + let idx = entries_len - 1; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); } else { - let content = ContentBlock::new(chunk, &language_registry, cx); - self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); + let content = ContentBlock::new(chunk.clone(), &language_registry, cx); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { + id: message_id, + content, + chunks: vec![chunk], + checkpoint: None, + }), + cx, + ); } } @@ -704,7 +863,8 @@ impl AcpThread { if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry { - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + let idx = entries_len - 1; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); match (chunks.last_mut(), is_thought) { (Some(AssistantMessageChunk::Message { block }), false) | (Some(AssistantMessageChunk::Thought { block }), true) => { @@ -754,13 +914,23 @@ impl AcpThread { .context("Tool call not found")?; match update { ToolCallUpdate::UpdateFields(update) => { + let location_updated = update.fields.locations.is_some(); current_call.update_fields(update.fields, languages, cx); + if location_updated { + self.resolve_locations(update.id.clone(), cx); + } } ToolCallUpdate::UpdateDiff(update) => { current_call.content.clear(); current_call .content - .push(ToolCallContent::Diff { diff: update.diff }); + .push(ToolCallContent::Diff(update.diff)); + } + ToolCallUpdate::UpdateTerminal(update) => { + current_call.content.clear(); + current_call + .content + .push(ToolCallContent::Terminal(update.terminal)); } } @@ -770,35 +940,40 @@ impl AcpThread { } /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. - pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { + pub fn upsert_tool_call( + &mut self, + tool_call: acp::ToolCall, + cx: &mut Context, + ) -> Result<(), acp::Error> { let status = ToolCallStatus::Allowed { status: tool_call.status, }; - self.upsert_tool_call_inner(tool_call, status, cx) + self.upsert_tool_call_inner(tool_call.into(), status, cx) } + /// Fails if id does not match an existing entry. pub fn upsert_tool_call_inner( &mut self, - tool_call: acp::ToolCall, + tool_call_update: acp::ToolCallUpdate, status: ToolCallStatus, cx: &mut Context, - ) { + ) -> Result<(), acp::Error> { let language_registry = self.project.read(cx).languages().clone(); - let call = ToolCall::from_acp(tool_call, status, language_registry, cx); + let id = tool_call_update.id.clone(); - let location = call.locations.last().cloned(); - - if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { - *current_call = call; + if let Some((ix, current_call)) = self.tool_call_mut(&id) { + current_call.update_fields(tool_call_update.fields, language_registry, cx); + current_call.status = status; cx.emit(AcpThreadEvent::EntryUpdated(ix)); } else { + let call = + ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx); self.push_entry(AgentThreadEntry::ToolCall(call), cx); - } + }; - if let Some(location) = location { - self.set_project_location(location, cx) - } + self.resolve_locations(id, cx); + Ok(()) } fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { @@ -819,43 +994,58 @@ impl AcpThread { }) } - pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { - self.project.update(cx, |project, cx| { - let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { - return; - }; - let buffer = project.open_buffer(path, cx); - cx.spawn(async move |project, cx| { - let buffer = buffer.await?; - - project.update(cx, |project, cx| { - let position = if let Some(line) = location.line { - let snapshot = buffer.read(cx).snapshot(); - let point = snapshot.clip_point(Point::new(line, 0), Bias::Left); - snapshot.anchor_before(point) - } else { - Anchor::MIN - }; - - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position, - }), - cx, - ); - }) + pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context) { + let project = self.project.clone(); + let Some((_, tool_call)) = self.tool_call_mut(&id) else { + return; + }; + let task = tool_call.resolve_locations(project, cx); + cx.spawn(async move |this, cx| { + let resolved_locations = task.await; + this.update(cx, |this, cx| { + let project = this.project.clone(); + let Some((ix, tool_call)) = this.tool_call_mut(&id) else { + return; + }; + if let Some(Some(location)) = resolved_locations.last() { + project.update(cx, |project, cx| { + if let Some(agent_location) = project.agent_location() { + let should_ignore = agent_location.buffer == location.buffer + && location + .buffer + .update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let old_position = + agent_location.position.to_point(&snapshot); + let new_position = location.position.to_point(&snapshot); + // ignore this so that when we get updates from the edit tool + // the position doesn't reset to the startof line + old_position.row == new_position.row + && old_position.column > new_position.column + }) + .ok() + .unwrap_or_default(); + if !should_ignore { + project.set_agent_location(Some(location.clone()), cx); + } + } + }); + } + if tool_call.resolved_locations != resolved_locations { + tool_call.resolved_locations = resolved_locations; + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } }) - .detach_and_log_err(cx); - }); + }) + .detach(); } pub fn request_tool_call_authorization( &mut self, - tool_call: acp::ToolCall, + tool_call: acp::ToolCallUpdate, options: Vec, cx: &mut Context, - ) -> oneshot::Receiver { + ) -> Result, acp::Error> { let (tx, rx) = oneshot::channel(); let status = ToolCallStatus::WaitingForConfirmation { @@ -863,9 +1053,9 @@ impl AcpThread { respond_tx: tx, }; - self.upsert_tool_call_inner(tool_call, status, cx); + self.upsert_tool_call_inner(tool_call, status, cx)?; cx.emit(AcpThreadEvent::ToolAuthorizationRequired); - rx + Ok(rx) } pub fn authorize_tool_call( @@ -981,66 +1171,112 @@ impl AcpThread { self.project.read(cx).languages().clone(), cx, ); + let request = acp::PromptRequest { + prompt: message.clone(), + session_id: self.session_id.clone(), + }; + let git_store = self.project.read(cx).git_store().clone(); + + let message_id = if self + .connection + .session_editor(&self.session_id, cx) + .is_some() + { + Some(UserMessageId::new()) + } else { + None + }; self.push_entry( - AgentThreadEntry::UserMessage(UserMessage { content: block }), + AgentThreadEntry::UserMessage(UserMessage { + id: message_id.clone(), + content: block, + chunks: message, + checkpoint: None, + }), cx, ); + + self.run_turn(cx, async move |this, cx| { + let old_checkpoint = git_store + .update(cx, |git, cx| git.checkpoint(cx))? + .await + .context("failed to get old checkpoint") + .log_err(); + this.update(cx, |this, cx| { + if let Some((_ix, message)) = this.last_user_message() { + message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint { + git_checkpoint, + show: false, + }); + } + this.connection.prompt(message_id, request, cx) + })? + .await + }) + } + + pub fn resume(&mut self, cx: &mut Context) -> BoxFuture<'static, Result<()>> { + self.run_turn(cx, async move |this, cx| { + this.update(cx, |this, cx| { + this.connection + .resume(&this.session_id, cx) + .map(|resume| resume.run(cx)) + })? + .context("resuming a session is not supported")? + .await + }) + } + + fn run_turn( + &mut self, + cx: &mut Context, + f: impl 'static + AsyncFnOnce(WeakEntity, &mut AsyncApp) -> Result, + ) -> BoxFuture<'static, Result<()>> { self.clear_completed_plan_entries(cx); let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); self.send_task = Some(cx.spawn(async move |this, cx| { - async { - cancel_task.await; - - let result = this - .update(cx, |this, cx| { - this.connection.prompt( - acp::PromptRequest { - prompt: message, - session_id: this.session_id.clone(), - }, - cx, - ) - })? - .await; - - tx.send(result).log_err(); - - anyhow::Ok(()) - } - .await - .log_err(); + cancel_task.await; + tx.send(f(this, cx).await).ok(); })); - cx.spawn(async move |this, cx| match rx.await { - Ok(Err(e)) => { - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) - .log_err(); - Err(e)? - } - result => { - let cancelled = matches!( - result, - Ok(Ok(acp::PromptResponse { - stop_reason: acp::StopReason::Cancelled - })) - ); + cx.spawn(async move |this, cx| { + let response = rx.await; - // We only take the task if the current prompt wasn't cancelled. - // - // This prompt may have been cancelled because another one was sent - // while it was still generating. In these cases, dropping `send_task` - // would cause the next generation to be cancelled. - if !cancelled { - this.update(cx, |this, _cx| this.send_task.take()).ok(); + this.update(cx, |this, cx| this.update_last_checkpoint(cx))? + .await?; + + this.update(cx, |this, cx| { + match response { + Ok(Err(e)) => { + this.send_task.take(); + cx.emit(AcpThreadEvent::Error); + Err(e) + } + result => { + let cancelled = matches!( + result, + Ok(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled + })) + ); + + // We only take the task if the current prompt wasn't cancelled. + // + // This prompt may have been cancelled because another one was sent + // while it was still generating. In these cases, dropping `send_task` + // would cause the next generation to be cancelled. + if !cancelled { + this.send_task.take(); + } + + cx.emit(AcpThreadEvent::Stopped); + Ok(()) + } } - - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) - .log_err(); - Ok(()) - } + })? }) .boxed() } @@ -1072,6 +1308,122 @@ impl AcpThread { cx.foreground_executor().spawn(send_task) } + /// Rewinds this thread to before the entry at `index`, removing it and all + /// subsequent entries while reverting any changes made from that point. + pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { + let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else { + return Task::ready(Err(anyhow!("not supported"))); + }; + let Some(message) = self.user_message(&id) else { + return Task::ready(Err(anyhow!("message not found"))); + }; + + let checkpoint = message + .checkpoint + .as_ref() + .map(|c| c.git_checkpoint.clone()); + + let git_store = self.project.read(cx).git_store().clone(); + cx.spawn(async move |this, cx| { + if let Some(checkpoint) = checkpoint { + git_store + .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))? + .await?; + } + + cx.update(|cx| session_editor.truncate(id.clone(), cx))? + .await?; + this.update(cx, |this, cx| { + if let Some((ix, _)) = this.user_message_mut(&id) { + let range = ix..this.entries.len(); + this.entries.truncate(ix); + cx.emit(AcpThreadEvent::EntriesRemoved(range)); + } + }) + }) + } + + fn update_last_checkpoint(&mut self, cx: &mut Context) -> Task> { + let git_store = self.project.read(cx).git_store().clone(); + + let old_checkpoint = if let Some((_, message)) = self.last_user_message() { + if let Some(checkpoint) = message.checkpoint.as_ref() { + checkpoint.git_checkpoint.clone() + } else { + return Task::ready(Ok(())); + } + } else { + return Task::ready(Ok(())); + }; + + let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); + cx.spawn(async move |this, cx| { + let new_checkpoint = new_checkpoint + .await + .context("failed to get new checkpoint") + .log_err(); + if let Some(new_checkpoint) = new_checkpoint { + let equal = git_store + .update(cx, |git, cx| { + git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) + })? + .await + .unwrap_or(true); + this.update(cx, |this, cx| { + let (ix, message) = this.last_user_message().context("no user message")?; + let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?; + checkpoint.show = !equal; + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + anyhow::Ok(()) + })??; + } + + Ok(()) + }) + } + + fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> { + self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(ix, entry)| { + if let AgentThreadEntry::UserMessage(message) = entry { + Some((ix, message)) + } else { + None + } + }) + } + + fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { + self.entries.iter().find_map(|entry| { + if let AgentThreadEntry::UserMessage(message) = entry { + if message.id.as_ref() == Some(&id) { + Some(message) + } else { + None + } + } else { + None + } + }) + } + + fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { + self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { + if let AgentThreadEntry::UserMessage(message) = entry { + if message.id.as_ref() == Some(&id) { + Some((ix, message)) + } else { + None + } + } else { + None + } + }) + } + pub fn read_text_file( &self, path: PathBuf, @@ -1225,6 +1577,48 @@ impl AcpThread { } } +fn markdown_for_raw_output( + raw_output: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Option> { + match raw_output { + serde_json::Value::Null => None, + serde_json::Value::Bool(value) => Some(cx.new(|cx| { + Markdown::new( + value.to_string().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + serde_json::Value::Number(value) => Some(cx.new(|cx| { + Markdown::new( + value.to_string().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + serde_json::Value::String(value) => Some(cx.new(|cx| { + Markdown::new( + value.clone().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + value => Some(cx.new(|cx| { + Markdown::new( + format!("```json\n{}\n```", value).into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + } +} + #[cfg(test)] mod tests { use super::*; @@ -1232,13 +1626,19 @@ mod tests { use futures::{channel::mpsc, future::LocalBoxFuture, select}; use gpui::{AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; - use project::FakeFs; + use project::{FakeFs, Fs}; use rand::Rng as _; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{cell::RefCell, rc::Rc, time::Duration}; - + use std::{ + any::Any, + cell::RefCell, + path::Path, + rc::Rc, + sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, + time::Duration, + }; use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1259,17 +1659,14 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); // Test creating a new user message thread.update(cx, |thread, cx| { thread.push_user_content_block( + None, acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "Hello, ".to_string(), @@ -1281,6 +1678,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.id, None); assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); } else { panic!("Expected UserMessage"); @@ -1288,8 +1686,10 @@ mod tests { }); // Test appending to existing user message + let message_1_id = UserMessageId::new(); thread.update(cx, |thread, cx| { thread.push_user_content_block( + Some(message_1_id.clone()), acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "world!".to_string(), @@ -1301,6 +1701,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.id, Some(message_1_id)); assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); } else { panic!("Expected UserMessage"); @@ -1319,8 +1720,10 @@ mod tests { ); }); + let message_2_id = UserMessageId::new(); thread.update(cx, |thread, cx| { thread.push_user_content_block( + Some(message_2_id.clone()), acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "New user message".to_string(), @@ -1332,6 +1735,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 3); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { + assert_eq!(user_msg.id, Some(message_2_id)); assert_eq!(user_msg.content.to_markdown(cx), "New user message"); } else { panic!("Expected UserMessage at index 2"); @@ -1375,11 +1779,7 @@ mod tests { )); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1462,7 +1862,7 @@ mod tests { .unwrap(); let thread = cx - .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) + .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -1525,11 +1925,7 @@ mod tests { })); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1637,10 +2033,11 @@ mod tests { } })); - let thread = connection - .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) .await .unwrap(); @@ -1648,6 +2045,180 @@ mod tests { assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); } + #[gpui::test(iterations = 10)] + async fn test_checkpoints(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/test"), + json!({ + ".git": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + + let simulate_changes = Arc::new(AtomicBool::new(true)); + let next_filename = Arc::new(AtomicUsize::new(0)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let simulate_changes = simulate_changes.clone(); + let next_filename = next_filename.clone(); + let fs = fs.clone(); + move |request, thread, mut cx| { + let fs = fs.clone(); + let simulate_changes = simulate_changes.clone(); + let next_filename = next_filename.clone(); + async move { + if simulate_changes.load(SeqCst) { + let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst)); + fs.write(Path::new(&filename), b"").await?; + } + + let acp::ContentBlock::Text(content) = &request.prompt[0] else { + panic!("expected text content block"); + }; + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: content.text.to_uppercase().into(), + }, + cx, + ) + .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + "} + ); + }); + assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + ## User (checkpoint) + + ipsum + + ## Assistant + + IPSUM + + "} + ); + }); + assert_eq!( + fs.files(), + vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + ); + + // Checkpoint isn't stored when there are no changes. + simulate_changes.store(false, SeqCst); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + ## User (checkpoint) + + ipsum + + ## Assistant + + IPSUM + + ## User + + dolor + + ## Assistant + + DOLOR + + "} + ); + }); + assert_eq!( + fs.files(), + vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + ); + + // Rewinding the conversation truncates the history and restores the checkpoint. + thread + .update(cx, |thread, cx| { + let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { + panic!("unexpected entries {:?}", thread.entries) + }; + thread.rewind(message.id.clone().unwrap(), cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + "} + ); + }); + assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + } + async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1729,7 +2300,7 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::AsyncApp, + cx: &mut gpui::App, ) -> Task>> { let session_id = acp::SessionId( rand::thread_rng() @@ -1739,9 +2310,8 @@ mod tests { .collect::() .into(), ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); + let thread = + cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); self.sessions.lock().insert(session_id, thread.downgrade()); Task::ready(Ok(thread)) } @@ -1756,6 +2326,7 @@ mod tests { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -1784,5 +2355,29 @@ mod tests { }) .detach(); } + + fn session_editor( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(FakeAgentSessionEditor { + _session_id: session_id.clone(), + })) + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + struct FakeAgentSessionEditor { + _session_id: acp::SessionId, + } + + impl AgentSessionEditor for FakeAgentSessionEditor { + fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index cf06563bee..7497d2309f 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,18 +1,98 @@ -use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; - +use crate::AcpThread; use agent_client_protocol::{self as acp}; use anyhow::Result; -use gpui::{AsyncApp, Entity, Task}; -use language_model::LanguageModel; +use collections::IndexMap; +use gpui::{Entity, SharedString, Task}; use project::Project; -use ui::App; +use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; +use ui::{App, IconName}; +use uuid::Uuid; -use crate::AcpThread; +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UserMessageId(Arc); + +impl UserMessageId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string().into()) + } +} + +pub trait AgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut App, + ) -> Task>>; + + fn auth_methods(&self) -> &[acp::AuthMethod]; + + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; + + fn prompt( + &self, + user_message_id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task>; + + fn resume( + &self, + _session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + None + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + fn session_editor( + &self, + _session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + None + } + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. + fn model_selector(&self) -> Option> { + None + } + + fn into_any(self: Rc) -> Rc; +} + +impl dyn AgentConnection { + pub fn downcast(self: Rc) -> Option> { + self.into_any().downcast().ok() + } +} + +pub trait AgentSessionEditor { + fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +} + +pub trait AgentSessionResume { + fn run(&self, cx: &mut App) -> Task>; +} + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") + } +} /// Trait for agents that support listing, selecting, and querying language models. /// /// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. -pub trait ModelSelector: 'static { +pub trait AgentModelSelector: 'static { /// Lists all available language models for this agent. /// /// # Parameters @@ -20,7 +100,7 @@ pub trait ModelSelector: 'static { /// /// # Returns /// A task resolving to the list of models or an error (e.g., if no models are configured). - fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; + fn list_models(&self, cx: &mut App) -> Task>; /// Selects a model for a specific session (thread). /// @@ -37,8 +117,8 @@ pub trait ModelSelector: 'static { fn select_model( &self, session_id: acp::SessionId, - model: Arc, - cx: &mut AsyncApp, + model_id: AgentModelId, + cx: &mut App, ) -> Task>; /// Retrieves the currently selected model for a specific session (thread). @@ -52,42 +132,207 @@ pub trait ModelSelector: 'static { fn selected_model( &self, session_id: &acp::SessionId, - cx: &mut AsyncApp, - ) -> Task>>; + cx: &mut App, + ) -> Task>; + + /// Whenever the model list is updated the receiver will be notified. + fn watch(&self, cx: &mut App) -> watch::Receiver<()>; } -pub trait AgentConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>>; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AgentModelId(pub SharedString); - fn auth_methods(&self) -> &[acp::AuthMethod]; +impl std::ops::Deref for AgentModelId { + type Target = SharedString; - fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) - -> Task>; - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); - - /// Returns this agent as an [Rc] if the model selection capability is supported. - /// - /// If the agent does not support model selection, returns [None]. - /// This allows sharing the selector in UI components. - fn model_selector(&self) -> Option> { - None // Default impl for agents that don't support it + fn deref(&self) -> &Self::Target { + &self.0 } } -#[derive(Debug)] -pub struct AuthRequired; - -impl Error for AuthRequired {} -impl fmt::Display for AuthRequired { +impl fmt::Display for AgentModelId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AuthRequired") + self.0.fmt(f) } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AgentModelInfo { + pub id: AgentModelId, + pub name: SharedString, + pub icon: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AgentModelGroupName(pub SharedString); + +#[derive(Debug, Clone)] +pub enum AgentModelList { + Flat(Vec), + Grouped(IndexMap>), +} + +impl AgentModelList { + pub fn is_empty(&self) -> bool { + match self { + AgentModelList::Flat(models) => models.is_empty(), + AgentModelList::Grouped(groups) => groups.is_empty(), + } + } +} + +#[cfg(feature = "test-support")] +mod test_support { + use std::sync::Arc; + + use collections::HashMap; + use futures::future::try_join_all; + use gpui::{AppContext as _, WeakEntity}; + use parking_lot::Mutex; + + use super::*; + + #[derive(Clone, Default)] + pub struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + next_prompt_updates: Arc>>, + } + + impl StubAgentConnection { + pub fn new() -> Self { + Self { + next_prompt_updates: Default::default(), + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + pub fn set_next_prompt_updates(&self, updates: Vec) { + *self.next_prompt_updates.lock() = updates; + } + + pub fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + + pub fn send_update( + &self, + session_id: acp::SessionId, + update: acp::SessionUpdate, + cx: &mut App, + ) { + self.sessions + .lock() + .get(&session_id) + .unwrap() + .update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + }) + .unwrap(); + } + } + + impl AgentConnection for StubAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); + let thread = + cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in self.next_prompt_updates.lock().drain(..) { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization( + tool_call.clone().into(), + options.clone(), + cx, + ) + })?; + permission?.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + + fn session_editor( + &self, + _session_id: &agent_client_protocol::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(StubAgentSessionEditor)) + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + struct StubAgentSessionEditor; + + impl AgentSessionEditor for StubAgentSessionEditor { + fn truncate(&self, _: UserMessageId, _: &mut App) -> Task> { + Task::ready(Ok(())) + } + } +} + +#[cfg(feature = "test-support")] +pub use test_support::*; diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs index 9cc6271360..a2c2d6c322 100644 --- a/crates/acp_thread/src/diff.rs +++ b/crates/acp_thread/src/diff.rs @@ -174,6 +174,10 @@ impl Diff { buffer_text ) } + + pub fn has_revealed_range(&self, cx: &App) -> bool { + self.multibuffer().read(cx).excerpt_paths().next().is_some() + } } pub struct PendingDiff { diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs new file mode 100644 index 0000000000..b9b021c4ca --- /dev/null +++ b/crates/acp_thread/src/mention.rs @@ -0,0 +1,460 @@ +use agent::ThreadId; +use anyhow::{Context as _, Result, bail}; +use file_icons::FileIcons; +use prompt_store::{PromptId, UserPromptId}; +use std::{ + fmt, + ops::Range, + path::{Path, PathBuf}, + str::FromStr, +}; +use ui::{App, IconName, SharedString}; +use url::Url; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MentionUri { + File { + abs_path: PathBuf, + is_directory: bool, + }, + Symbol { + path: PathBuf, + name: String, + line_range: Range, + }, + Thread { + id: ThreadId, + name: String, + }, + TextThread { + path: PathBuf, + name: String, + }, + Rule { + id: PromptId, + name: String, + }, + Selection { + path: PathBuf, + line_range: Range, + }, + Fetch { + url: Url, + }, +} + +impl MentionUri { + pub fn parse(input: &str) -> Result { + let url = url::Url::parse(input)?; + let path = url.path(); + match url.scheme() { + "file" => { + if let Some(fragment) = url.fragment() { + let range = fragment + .strip_prefix("L") + .context("Line range must start with \"L\"")?; + let (start, end) = range + .split_once(":") + .context("Line range must use colon as separator")?; + let line_range = start + .parse::() + .context("Parsing line range start")? + .checked_sub(1) + .context("Line numbers should be 1-based")? + ..end + .parse::() + .context("Parsing line range end")? + .checked_sub(1) + .context("Line numbers should be 1-based")?; + if let Some(name) = single_query_param(&url, "symbol")? { + Ok(Self::Symbol { + name, + path: path.into(), + line_range, + }) + } else { + Ok(Self::Selection { + path: path.into(), + line_range, + }) + } + } else { + let file_path = + PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path)); + let is_directory = input.ends_with("/"); + + Ok(Self::File { + abs_path: file_path, + is_directory, + }) + } + } + "zed" => { + if let Some(thread_id) = path.strip_prefix("/agent/thread/") { + let name = single_query_param(&url, "name")?.context("Missing thread name")?; + Ok(Self::Thread { + id: thread_id.into(), + name, + }) + } else if let Some(path) = path.strip_prefix("/agent/text-thread/") { + let name = single_query_param(&url, "name")?.context("Missing thread name")?; + Ok(Self::TextThread { + path: path.into(), + name, + }) + } else if let Some(rule_id) = path.strip_prefix("/agent/rule/") { + let name = single_query_param(&url, "name")?.context("Missing rule name")?; + let rule_id = UserPromptId(rule_id.parse()?); + Ok(Self::Rule { + id: rule_id.into(), + name, + }) + } else { + bail!("invalid zed url: {:?}", input); + } + } + "http" | "https" => Ok(MentionUri::Fetch { url }), + other => bail!("unrecognized scheme {:?}", other), + } + } + + pub fn name(&self) -> String { + match self { + MentionUri::File { abs_path, .. } => abs_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .into_owned(), + MentionUri::Symbol { name, .. } => name.clone(), + MentionUri::Thread { name, .. } => name.clone(), + MentionUri::TextThread { name, .. } => name.clone(), + MentionUri::Rule { name, .. } => name.clone(), + MentionUri::Selection { + path, line_range, .. + } => selection_name(path, line_range), + MentionUri::Fetch { url } => url.to_string(), + } + } + + pub fn icon_path(&self, cx: &mut App) -> SharedString { + match self { + MentionUri::File { + abs_path, + is_directory, + } => { + if *is_directory { + FileIcons::get_folder_icon(false, cx) + .unwrap_or_else(|| IconName::Folder.path().into()) + } else { + FileIcons::get_icon(&abs_path, cx) + .unwrap_or_else(|| IconName::File.path().into()) + } + } + MentionUri::Symbol { .. } => IconName::Code.path().into(), + MentionUri::Thread { .. } => IconName::Thread.path().into(), + MentionUri::TextThread { .. } => IconName::Thread.path().into(), + MentionUri::Rule { .. } => IconName::Reader.path().into(), + MentionUri::Selection { .. } => IconName::Reader.path().into(), + MentionUri::Fetch { .. } => IconName::ToolWeb.path().into(), + } + } + + pub fn as_link<'a>(&'a self) -> MentionLink<'a> { + MentionLink(self) + } + + pub fn to_uri(&self) -> Url { + match self { + MentionUri::File { + abs_path, + is_directory, + } => { + let mut url = Url::parse("file:///").unwrap(); + let mut path = abs_path.to_string_lossy().to_string(); + if *is_directory && !path.ends_with("/") { + path.push_str("/"); + } + url.set_path(&path); + url + } + MentionUri::Symbol { + path, + name, + line_range, + } => { + let mut url = Url::parse("file:///").unwrap(); + url.set_path(&path.to_string_lossy()); + url.query_pairs_mut().append_pair("symbol", name); + url.set_fragment(Some(&format!( + "L{}:{}", + line_range.start + 1, + line_range.end + 1 + ))); + url + } + MentionUri::Selection { path, line_range } => { + let mut url = Url::parse("file:///").unwrap(); + url.set_path(&path.to_string_lossy()); + url.set_fragment(Some(&format!( + "L{}:{}", + line_range.start + 1, + line_range.end + 1 + ))); + url + } + MentionUri::Thread { name, id } => { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path(&format!("/agent/thread/{id}")); + url.query_pairs_mut().append_pair("name", name); + url + } + MentionUri::TextThread { path, name } => { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy())); + url.query_pairs_mut().append_pair("name", name); + url + } + MentionUri::Rule { name, id } => { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path(&format!("/agent/rule/{id}")); + url.query_pairs_mut().append_pair("name", name); + url + } + MentionUri::Fetch { url } => url.clone(), + } + } +} + +impl FromStr for MentionUri { + type Err = anyhow::Error; + + fn from_str(s: &str) -> anyhow::Result { + Self::parse(s) + } +} + +pub struct MentionLink<'a>(&'a MentionUri); + +impl fmt::Display for MentionLink<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[@{}]({})", self.0.name(), self.0.to_uri()) + } +} + +fn single_query_param(url: &Url, name: &'static str) -> Result> { + let pairs = url.query_pairs().collect::>(); + match pairs.as_slice() { + [] => Ok(None), + [(k, v)] => { + if k != name { + bail!("invalid query parameter") + } + + Ok(Some(v.to_string())) + } + _ => bail!("too many query pairs"), + } +} + +pub fn selection_name(path: &Path, line_range: &Range) -> String { + format!( + "{} ({}:{})", + path.file_name().unwrap_or_default().display(), + line_range.start + 1, + line_range.end + 1 + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_file_uri() { + let file_uri = "file:///path/to/file.rs"; + let parsed = MentionUri::parse(file_uri).unwrap(); + match &parsed { + MentionUri::File { + abs_path, + is_directory, + } => { + assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs"); + assert!(!is_directory); + } + _ => panic!("Expected File variant"), + } + assert_eq!(parsed.to_uri().to_string(), file_uri); + } + + #[test] + fn test_parse_directory_uri() { + let file_uri = "file:///path/to/dir/"; + let parsed = MentionUri::parse(file_uri).unwrap(); + match &parsed { + MentionUri::File { + abs_path, + is_directory, + } => { + assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir/"); + assert!(is_directory); + } + _ => panic!("Expected File variant"), + } + assert_eq!(parsed.to_uri().to_string(), file_uri); + } + + #[test] + fn test_to_directory_uri_with_slash() { + let uri = MentionUri::File { + abs_path: PathBuf::from("/path/to/dir/"), + is_directory: true, + }; + assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/"); + } + + #[test] + fn test_to_directory_uri_without_slash() { + let uri = MentionUri::File { + abs_path: PathBuf::from("/path/to/dir"), + is_directory: true, + }; + assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/"); + } + + #[test] + fn test_parse_symbol_uri() { + let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20"; + let parsed = MentionUri::parse(symbol_uri).unwrap(); + match &parsed { + MentionUri::Symbol { + path, + name, + line_range, + } => { + assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); + assert_eq!(name, "MySymbol"); + assert_eq!(line_range.start, 9); + assert_eq!(line_range.end, 19); + } + _ => panic!("Expected Symbol variant"), + } + assert_eq!(parsed.to_uri().to_string(), symbol_uri); + } + + #[test] + fn test_parse_selection_uri() { + let selection_uri = "file:///path/to/file.rs#L5:15"; + let parsed = MentionUri::parse(selection_uri).unwrap(); + match &parsed { + MentionUri::Selection { path, line_range } => { + assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"); + assert_eq!(line_range.start, 4); + assert_eq!(line_range.end, 14); + } + _ => panic!("Expected Selection variant"), + } + assert_eq!(parsed.to_uri().to_string(), selection_uri); + } + + #[test] + fn test_parse_thread_uri() { + let thread_uri = "zed:///agent/thread/session123?name=Thread+name"; + let parsed = MentionUri::parse(thread_uri).unwrap(); + match &parsed { + MentionUri::Thread { + id: thread_id, + name, + } => { + assert_eq!(thread_id.to_string(), "session123"); + assert_eq!(name, "Thread name"); + } + _ => panic!("Expected Thread variant"), + } + assert_eq!(parsed.to_uri().to_string(), thread_uri); + } + + #[test] + fn test_parse_rule_uri() { + let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some+rule"; + let parsed = MentionUri::parse(rule_uri).unwrap(); + match &parsed { + MentionUri::Rule { id, name } => { + assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52"); + assert_eq!(name, "Some rule"); + } + _ => panic!("Expected Rule variant"), + } + assert_eq!(parsed.to_uri().to_string(), rule_uri); + } + + #[test] + fn test_parse_fetch_http_uri() { + let http_uri = "http://example.com/path?query=value#fragment"; + let parsed = MentionUri::parse(http_uri).unwrap(); + match &parsed { + MentionUri::Fetch { url } => { + assert_eq!(url.to_string(), http_uri); + } + _ => panic!("Expected Fetch variant"), + } + assert_eq!(parsed.to_uri().to_string(), http_uri); + } + + #[test] + fn test_parse_fetch_https_uri() { + let https_uri = "https://example.com/api/endpoint"; + let parsed = MentionUri::parse(https_uri).unwrap(); + match &parsed { + MentionUri::Fetch { url } => { + assert_eq!(url.to_string(), https_uri); + } + _ => panic!("Expected Fetch variant"), + } + assert_eq!(parsed.to_uri().to_string(), https_uri); + } + + #[test] + fn test_invalid_scheme() { + assert!(MentionUri::parse("ftp://example.com").is_err()); + assert!(MentionUri::parse("ssh://example.com").is_err()); + assert!(MentionUri::parse("unknown://example.com").is_err()); + } + + #[test] + fn test_invalid_zed_path() { + assert!(MentionUri::parse("zed:///invalid/path").is_err()); + assert!(MentionUri::parse("zed:///agent/unknown/test").is_err()); + } + + #[test] + fn test_invalid_line_range_format() { + // Missing L prefix + assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err()); + + // Missing colon separator + assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err()); + + // Invalid numbers + assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err()); + assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err()); + } + + #[test] + fn test_invalid_query_parameters() { + // Invalid query parameter name + assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err()); + + // Too many query parameters + assert!( + MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err() + ); + } + + #[test] + fn test_zero_based_line_numbers() { + // Test that 0-based line numbers are rejected (should be 1-based) + assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err()); + assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err()); + assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err()); + } +} diff --git a/crates/acp_thread/src/terminal.rs b/crates/acp_thread/src/terminal.rs new file mode 100644 index 0000000000..41d7fb89bb --- /dev/null +++ b/crates/acp_thread/src/terminal.rs @@ -0,0 +1,93 @@ +use gpui::{App, AppContext, Context, Entity}; +use language::LanguageRegistry; +use markdown::Markdown; +use std::{path::PathBuf, process::ExitStatus, sync::Arc, time::Instant}; + +pub struct Terminal { + command: Entity, + working_dir: Option, + terminal: Entity, + started_at: Instant, + output: Option, +} + +pub struct TerminalOutput { + pub ended_at: Instant, + pub exit_status: Option, + pub was_content_truncated: bool, + pub original_content_len: usize, + pub content_line_count: usize, + pub finished_with_empty_output: bool, +} + +impl Terminal { + pub fn new( + command: String, + working_dir: Option, + terminal: Entity, + language_registry: Arc, + cx: &mut Context, + ) -> Self { + Self { + command: cx.new(|cx| { + Markdown::new( + format!("```\n{}\n```", command).into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + working_dir, + terminal, + started_at: Instant::now(), + output: None, + } + } + + pub fn finish( + &mut self, + exit_status: Option, + original_content_len: usize, + truncated_content_len: usize, + content_line_count: usize, + finished_with_empty_output: bool, + cx: &mut Context, + ) { + self.output = Some(TerminalOutput { + ended_at: Instant::now(), + exit_status, + was_content_truncated: truncated_content_len < original_content_len, + original_content_len, + content_line_count, + finished_with_empty_output, + }); + cx.notify(); + } + + pub fn command(&self) -> &Entity { + &self.command + } + + pub fn working_dir(&self) -> &Option { + &self.working_dir + } + + pub fn started_at(&self) -> Instant { + self.started_at + } + + pub fn output(&self) -> Option<&TerminalOutput> { + self.output.as_ref() + } + + pub fn inner(&self) -> &Entity { + &self.terminal + } + + pub fn to_markdown(&self, cx: &App) -> String { + format!( + "Terminal:\n```\n{}\n```\n", + self.terminal.read(cx).get_content() + ) + } +} diff --git a/crates/action_log/Cargo.toml b/crates/action_log/Cargo.toml new file mode 100644 index 0000000000..1a389e8859 --- /dev/null +++ b/crates/action_log/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "action_log" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lib] +path = "src/action_log.rs" + +[lints] +workspace = true + +[dependencies] +anyhow.workspace = true +buffer_diff.workspace = true +clock.workspace = true +collections.workspace = true +futures.workspace = true +gpui.workspace = true +language.workspace = true +project.workspace = true +text.workspace = true +util.workspace = true +watch.workspace = true +workspace-hack.workspace = true + + +[dev-dependencies] +buffer_diff = { workspace = true, features = ["test-support"] } +collections = { workspace = true, features = ["test-support"] } +clock = { workspace = true, features = ["test-support"] } +ctor.workspace = true +gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true +language = { workspace = true, features = ["test-support"] } +log.workspace = true +pretty_assertions.workspace = true +project = { workspace = true, features = ["test-support"] } +rand.workspace = true +serde_json.workspace = true +settings = { workspace = true, features = ["test-support"] } +text = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/indexed_docs/LICENSE-GPL b/crates/action_log/LICENSE-GPL similarity index 100% rename from crates/indexed_docs/LICENSE-GPL rename to crates/action_log/LICENSE-GPL diff --git a/crates/assistant_tool/src/action_log.rs b/crates/action_log/src/action_log.rs similarity index 99% rename from crates/assistant_tool/src/action_log.rs rename to crates/action_log/src/action_log.rs index 025aba060d..c4eaffc228 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -17,8 +17,6 @@ use util::{ pub struct ActionLog { /// Buffers that we want to notify the model about when they change. tracked_buffers: BTreeMap, TrackedBuffer>, - /// Has the model edited a file since it last checked diagnostics? - edited_since_project_diagnostics_check: bool, /// The project this action log is associated with project: Entity, } @@ -28,7 +26,6 @@ impl ActionLog { pub fn new(project: Entity) -> Self { Self { tracked_buffers: BTreeMap::default(), - edited_since_project_diagnostics_check: false, project, } } @@ -37,16 +34,6 @@ impl ActionLog { &self.project } - /// Notifies a diagnostics check - pub fn checked_project_diagnostics(&mut self) { - self.edited_since_project_diagnostics_check = false; - } - - /// Returns true if any files have been edited since the last project diagnostics check - pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool { - self.edited_since_project_diagnostics_check - } - pub fn latest_snapshot(&self, buffer: &Entity) -> Option { Some(self.tracked_buffers.get(buffer)?.snapshot.clone()) } @@ -543,14 +530,11 @@ impl ActionLog { /// Mark a buffer as created by agent, so we can refresh it in the context pub fn buffer_created(&mut self, buffer: Entity, cx: &mut Context) { - self.edited_since_project_diagnostics_check = true; self.track_buffer_internal(buffer.clone(), true, cx); } /// Mark a buffer as edited by agent, so we can refresh it in the context pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { - self.edited_since_project_diagnostics_check = true; - let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); if let TrackedBufferStatus::Deleted = tracked_buffer.status { tracked_buffer.status = TrackedBufferStatus::Modified; diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs index f8ea7173d8..7c562aaba4 100644 --- a/crates/activity_indicator/src/activity_indicator.rs +++ b/crates/activity_indicator/src/activity_indicator.rs @@ -716,18 +716,10 @@ impl ActivityIndicator { })), tooltip_message: Some(Self::version_tooltip_message(&version)), }), - AutoUpdateStatus::Updated { - binary_path, - version, - } => Some(Content { + AutoUpdateStatus::Updated { version } => Some(Content { icon: None, message: "Click to restart and update Zed".to_string(), - on_click: Some(Arc::new({ - let reload = workspace::Reload { - binary_path: Some(binary_path.clone()), - }; - move |_, _, cx| workspace::reload(&reload, cx) - })), + on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))), tooltip_message: Some(Self::version_tooltip_message(&version)), }), AutoUpdateStatus::Errored => Some(Content { diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 7bc0e82cad..53ad2f4967 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -19,6 +19,7 @@ test-support = [ ] [dependencies] +action_log.workspace = true agent_settings.workspace = true anyhow.workspace = true assistant_context.workspace = true diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs index 34ea1c8df7..38e697dd9b 100644 --- a/crates/agent/src/agent_profile.rs +++ b/crates/agent/src/agent_profile.rs @@ -326,7 +326,7 @@ mod tests { _input: serde_json::Value, _request: Arc, _project: Entity, - _action_log: Entity, + _action_log: Entity, _model: Arc, _window: Option, _cx: &mut App, diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index 85e8ac7451..22d1a72bf5 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -1,7 +1,8 @@ use std::sync::Arc; +use action_log::ActionLog; use anyhow::{Result, anyhow, bail}; -use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource}; +use assistant_tool::{Tool, ToolResult, ToolSource}; use context_server::{ContextServerId, types}; use gpui::{AnyWindowHandle, App, Entity, Task}; use icons::IconName; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 048aa4245d..f3f1088483 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -8,9 +8,10 @@ use crate::{ }, tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, }; +use action_log::ActionLog; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; +use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; @@ -843,11 +844,17 @@ impl Thread { .await .unwrap_or(false); - if !equal { - this.update(cx, |this, cx| { - this.insert_checkpoint(pending_checkpoint, cx) - })?; - } + this.update(cx, |this, cx| { + this.pending_checkpoint = if equal { + Some(pending_checkpoint) + } else { + this.insert_checkpoint(pending_checkpoint, cx); + Some(ThreadCheckpoint { + message_id: this.next_message_id, + git_checkpoint: final_checkpoint, + }) + } + })?; Ok(()) } @@ -2267,6 +2274,15 @@ impl Thread { max_attempts: 3, }) } + Other(err) + if err.is::() + || err.is::() => + { + // Retrying won't help for Payment Required or Model Request Limit errors (where + // the user must upgrade to usage-based billing to get more requests, or else wait + // for a significant amount of time for the request limit to reset). + None + } // Conservatively assume that any other errors are non-retryable HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { delay: BASE_RETRY_DELAY, diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cc7cb50c91..12c94a522d 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -205,6 +205,22 @@ impl ThreadStore { (this, ready_rx) } + #[cfg(any(test, feature = "test-support"))] + pub fn fake(project: Entity, cx: &mut App) -> Self { + Self { + project, + tools: cx.new(|_| ToolWorkingSet::default()), + prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()), + prompt_store: None, + context_server_tool_ids: HashMap::default(), + threads: Vec::new(), + project_context: SharedProjectContext::default(), + reload_system_prompt_tx: mpsc::channel(0).0, + _reload_system_prompt_task: Task::ready(()), + _subscriptions: vec![], + } + } + fn handle_project_event( &mut self, _project: Entity, diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 3e19895a31..ac1840e5e5 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "agent2" version = "0.1.0" -edition = "2021" +edition.workspace = true +publish.workspace = true license = "GPL-3.0-or-later" -publish = false [lib] path = "src/agent2.rs" @@ -13,25 +13,32 @@ workspace = true [dependencies] acp_thread.workspace = true +action_log.workspace = true agent-client-protocol.workspace = true agent_servers.workspace = true agent_settings.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true +chrono.workspace = true cloud_llm_client.workspace = true collections.workspace = true +context_server.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } +html_to_markdown.workspace = true +http_client.workspace = true indoc.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true language_models.workspace = true log.workspace = true +open.workspace = true paths.workspace = true +portable-pty.workspace = true project.workspace = true prompt_store.workspace = true rust-embed.workspace = true @@ -40,16 +47,23 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true +task.workspace = true +terminal.workspace = true +text.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true +web_search.workspace = true +which.workspace = true workspace-hack.workspace = true [dev-dependencies] ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } +context_server = { workspace = true, "features" = ["test-support"] } +editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } @@ -57,8 +71,14 @@ gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } lsp = { workspace = true, "features" = ["test-support"] } +pretty_assertions.workspace = true project = { workspace = true, "features" = ["test-support"] } reqwest_client.workspace = true settings = { workspace = true, "features" = ["test-support"] } +tempfile.workspace = true +terminal = { workspace = true, "features" = ["test-support"] } +theme = { workspace = true, "features" = ["test-support"] } +tree-sitter-rust.workspace = true +unindent = { workspace = true } worktree = { workspace = true, "features" = ["test-support"] } -pretty_assertions.workspace = true +zlog.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index df061cd5ed..d63e3f8134 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,17 +1,27 @@ -use crate::{templates::Templates, AgentResponseEvent, Thread}; -use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization}; -use acp_thread::ModelSelector; +use crate::{ + AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, + DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, + MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, + ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, +}; +use acp_thread::AgentModelSelector; use agent_client_protocol as acp; -use anyhow::{anyhow, Context as _, Result}; -use futures::{future, StreamExt}; +use agent_settings::AgentSettings; +use anyhow::{Context as _, Result, anyhow}; +use collections::{HashSet, IndexMap}; +use fs::Fs; +use futures::channel::mpsc; +use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; -use language_model::{LanguageModel, LanguageModelRegistry}; +use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, }; +use settings::update_settings_file; +use std::any::Any; use std::cell::RefCell; use std::collections::HashMap; use std::path::Path; @@ -44,6 +54,104 @@ struct Session { _subscription: Subscription, } +pub struct LanguageModels { + /// Access language model by ID + models: HashMap>, + /// Cached list for returning language model information + model_list: acp_thread::AgentModelList, + refresh_models_rx: watch::Receiver<()>, + refresh_models_tx: watch::Sender<()>, +} + +impl LanguageModels { + fn new(cx: &App) -> Self { + let (refresh_models_tx, refresh_models_rx) = watch::channel(()); + let mut this = Self { + models: HashMap::default(), + model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), + refresh_models_rx, + refresh_models_tx, + }; + this.refresh_list(cx); + this + } + + fn refresh_list(&mut self, cx: &App) { + let providers = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .into_iter() + .filter(|provider| provider.is_authenticated(cx)) + .collect::>(); + + let mut language_model_list = IndexMap::default(); + let mut recommended_models = HashSet::default(); + + let mut recommended = Vec::new(); + for provider in &providers { + for model in provider.recommended_models(cx) { + recommended_models.insert(model.id()); + recommended.push(Self::map_language_model_to_info(&model, &provider)); + } + } + if !recommended.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName("Recommended".into()), + recommended, + ); + } + + let mut models = HashMap::default(); + for provider in providers { + let mut provider_models = Vec::new(); + for model in provider.provided_models(cx) { + let model_info = Self::map_language_model_to_info(&model, &provider); + let model_id = model_info.id.clone(); + if !recommended_models.contains(&model.id()) { + provider_models.push(model_info); + } + models.insert(model_id, model); + } + if !provider_models.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName(provider.name().0.clone()), + provider_models, + ); + } + } + + self.models = models; + self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); + self.refresh_models_tx.send(()).ok(); + } + + fn watch(&self) -> watch::Receiver<()> { + self.refresh_models_rx.clone() + } + + pub fn model_from_id( + &self, + model_id: &acp_thread::AgentModelId, + ) -> Option> { + self.models.get(model_id).cloned() + } + + fn map_language_model_to_info( + model: &Arc, + provider: &Arc, + ) -> acp_thread::AgentModelInfo { + acp_thread::AgentModelInfo { + id: Self::model_id(model), + name: model.name().0, + icon: Some(provider.icon()), + } + } + + fn model_id(model: &Arc) -> acp_thread::AgentModelId { + acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) + } +} + pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, @@ -51,10 +159,14 @@ pub struct NativeAgent { project_context: Rc>, project_context_needs_refresh: watch::Sender<()>, _maintain_project_context: Task>, + context_server_registry: Entity, /// Shared templates for all threads templates: Arc, + /// Cached model information + models: LanguageModels, project: Entity, prompt_store: Option>, + fs: Arc, _subscriptions: Vec, } @@ -63,6 +175,7 @@ impl NativeAgent { project: Entity, templates: Arc, prompt_store: Option>, + fs: Arc, cx: &mut AsyncApp, ) -> Result> { log::info!("Creating new NativeAgent"); @@ -72,7 +185,13 @@ impl NativeAgent { .await; cx.new(|cx| { - let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; + let mut subscriptions = vec![ + cx.subscribe(&project, Self::handle_project_event), + cx.subscribe( + &LanguageModelRegistry::global(cx), + Self::handle_models_updated_event, + ), + ]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) } @@ -86,14 +205,23 @@ impl NativeAgent { _maintain_project_context: cx.spawn(async move |this, cx| { Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await }), + context_server_registry: cx.new(|cx| { + ContextServerRegistry::new(project.read(cx).context_server_store(), cx) + }), templates, + models: LanguageModels::new(cx), project, prompt_store, + fs, _subscriptions: subscriptions, } }) } + pub fn models(&self) -> &LanguageModels { + &self.models + } + async fn maintain_project_context( this: WeakEntity, mut needs_refresh: watch::Receiver<()>, @@ -289,211 +417,63 @@ impl NativeAgent { ) { self.project_context_needs_refresh.send(()).ok(); } + + fn handle_models_updated_event( + &mut self, + _registry: Entity, + _event: &language_model::Event, + cx: &mut Context, + ) { + self.models.refresh_list(cx); + for session in self.sessions.values_mut() { + session.thread.update(cx, |thread, _| { + let model_id = LanguageModels::model_id(&thread.model()); + if let Some(model) = self.models.model_from_id(&model_id) { + thread.set_model(model.clone()); + } + }); + } + } } /// Wrapper struct that implements the AgentConnection trait #[derive(Clone)] pub struct NativeAgentConnection(pub Entity); -impl ModelSelector for NativeAgentConnection { - fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { - log::debug!("NativeAgentConnection::list_models called"); - cx.spawn(async move |cx| { - cx.update(|cx| { - let registry = LanguageModelRegistry::read_global(cx); - let models = registry.available_models(cx).collect::>(); - log::info!("Found {} available models", models.len()); - if models.is_empty() { - Err(anyhow::anyhow!("No models available")) - } else { - Ok(models) - } - })? - }) +impl NativeAgentConnection { + pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option> { + self.0 + .read(cx) + .sessions + .get(session_id) + .map(|session| session.thread.clone()) } - fn select_model( + fn run_turn( &self, session_id: acp::SessionId, - model: Arc, - cx: &mut AsyncApp, - ) -> Task> { - log::info!( - "Setting model for session {}: {:?}", - session_id, - model.name() - ); - let agent = self.0.clone(); - - cx.spawn(async move |cx| { - agent.update(cx, |agent, cx| { - if let Some(session) = agent.sessions.get(&session_id) { - session.thread.update(cx, |thread, _cx| { - thread.selected_model = model; - }); - Ok(()) - } else { - Err(anyhow!("Session not found")) - } - })? - }) - } - - fn selected_model( - &self, - session_id: &acp::SessionId, - cx: &mut AsyncApp, - ) -> Task>> { - let agent = self.0.clone(); - let session_id = session_id.clone(); - cx.spawn(async move |cx| { - let thread = agent - .read_with(cx, |agent, _| { - agent - .sessions - .get(&session_id) - .map(|session| session.thread.clone()) - })? - .ok_or_else(|| anyhow::anyhow!("Session not found"))?; - let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; - Ok(selected) - }) - } -} - -impl acp_thread::AgentConnection for NativeAgentConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let agent = self.0.clone(); - log::info!("Creating new thread for project at: {:?}", cwd); - - cx.spawn(async move |cx| { - log::debug!("Starting thread creation in async context"); - - // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); - log::info!("Created session with ID: {}", session_id); - - // Create AcpThread - let acp_thread = cx.update(|cx| { - cx.new(|cx| { - acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx) - }) - })?; - let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; - - // Create Thread - let thread = agent.update( - cx, - |agent, cx: &mut gpui::Context| -> Result<_> { - // Fetch default model from registry settings - let registry = LanguageModelRegistry::read_global(cx); - - // Log available models for debugging - let available_count = registry.available_models(cx).count(); - log::debug!("Total available models: {}", available_count); - - let default_model = registry - .default_model() - .map(|configured| { - log::info!( - "Using configured default model: {:?} from provider: {:?}", - configured.model.name(), - configured.provider.name() - ); - configured.model - }) - .ok_or_else(|| { - log::warn!("No default model configured in settings"); - anyhow!("No default model configured. Please configure a default model in settings.") - })?; - - let thread = cx.new(|cx| { - let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); - thread.add_tool(ThinkingTool); - thread.add_tool(FindPathTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); - thread.add_tool(EditFileTool::new(cx.entity())); - thread - }); - - Ok(thread) - }, - )??; - - // Store the session - agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread, - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }) - }, - ); - })?; - - Ok(acp_thread) - }) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] // No auth for in-process - } - - fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { - Task::ready(Ok(())) - } - - fn model_selector(&self) -> Option> { - Some(Rc::new(self.clone()) as Rc) - } - - fn prompt( - &self, - params: acp::PromptRequest, cx: &mut App, + f: impl 'static + + FnOnce( + Entity, + &mut App, + ) -> Result>>, ) -> Task> { - let session_id = params.session_id.clone(); - let agent = self.0.clone(); - log::info!("Received prompt request for session: {}", session_id); - log::debug!("Prompt blocks count: {}", params.prompt.len()); + let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + }) else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + log::debug!("Found session for: {}", session_id); + let mut response_stream = match f(thread, cx) { + Ok(stream) => stream, + Err(err) => return Task::ready(Err(err)), + }; cx.spawn(async move |cx| { - // Get session - let (thread, acp_thread) = agent - .update(cx, |agent, _| { - agent - .sessions - .get_mut(&session_id) - .map(|s| (s.thread.clone(), s.acp_thread.clone())) - })? - .ok_or_else(|| { - log::error!("Session not found: {}", session_id); - anyhow::anyhow!("Session not found") - })?; - log::debug!("Found session for: {}", session_id); - - // Convert prompt to message - let message = convert_prompt_to_message(params.prompt); - log::info!("Converted prompt to message: {} chars", message.len()); - log::debug!("Message content: {}", message); - - // Get model using the ModelSelector capability (always available for agent2) - // Get the selected model from the thread directly - let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; - - // Send to thread - log::info!("Sending message to thread with model: {:?}", model.name()); - let mut response_stream = - thread.update(cx, |thread, cx| thread.send(model, message, cx))?; - // Handle response stream and forward to session.acp_thread while let Some(result) = response_stream.next().await { match result { @@ -534,10 +514,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { thread.request_tool_call_authorization(tool_call, options, cx) })?; cx.background_spawn(async move { - if let Some(option) = recv - .await - .context("authorization sender was dropped") - .log_err() + if let Some(recv) = recv.log_err() + && let Some(option) = recv + .await + .context("authorization sender was dropped") + .log_err() { response .send(option) @@ -550,7 +531,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { AgentResponseEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) - })?; + })??; } AgentResponseEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { @@ -565,8 +546,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } Err(e) => { log::error!("Error in model response stream: {:?}", e); - // TODO: Consider sending an error message to the UI - break; + return Err(e); } } } @@ -577,6 +557,246 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }) }) } +} + +impl AgentModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut App) -> Task> { + log::debug!("NativeAgentConnection::list_models called"); + let list = self.0.read(cx).models.model_list.clone(); + Task::ready(if list.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(list) + }) + } + + fn select_model( + &self, + session_id: acp::SessionId, + model_id: acp_thread::AgentModelId, + cx: &mut App, + ) -> Task> { + log::info!("Setting model for session {}: {}", session_id, model_id); + let Some(thread) = self + .0 + .read(cx) + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + + let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else { + return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); + }; + + thread.update(cx, |thread, _cx| { + thread.set_model(model.clone()); + }); + + update_settings_file::( + self.0.read(cx).fs.clone(), + cx, + move |settings, _cx| { + settings.set_model(model); + }, + ); + + Task::ready(Ok(())) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut App, + ) -> Task> { + let session_id = session_id.clone(); + + let Some(thread) = self + .0 + .read(cx) + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + let model = thread.read(cx).model().clone(); + let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) + else { + return Task::ready(Err(anyhow!("Provider not found"))); + }; + Task::ready(Ok(LanguageModels::map_language_model_to_info( + &model, &provider, + ))) + } + + fn watch(&self, cx: &mut App) -> watch::Receiver<()> { + self.0.read(cx).models.watch() + } +} + +impl acp_thread::AgentConnection for NativeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut App, + ) -> Task>> { + let agent = self.0.clone(); + log::info!("Creating new thread for project at: {:?}", cwd); + + cx.spawn(async move |cx| { + log::debug!("Starting thread creation in async context"); + + // Generate session ID + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + log::info!("Created session with ID: {}", session_id); + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new( + "agent2", + self.clone(), + project.clone(), + session_id.clone(), + cx, + ) + }) + })?; + let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; + + // Create Thread + let thread = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry + .default_model() + .and_then(|default_model| { + agent + .models + .model_from_id(&LanguageModels::model_id(&default_model.model)) + }) + .ok_or_else(|| { + log::warn!("No default model configured in settings"); + anyhow!( + "No default model. Please configure a default model in settings." + ) + })?; + + let thread = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + default_model, + cx, + ); + thread.add_tool(CopyPathTool::new(project.clone())); + thread.add_tool(CreateDirectoryTool::new(project.clone())); + thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); + thread.add_tool(DiagnosticsTool::new(project.clone())); + thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); + thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(GrepTool::new(project.clone())); + thread.add_tool(ListDirectoryTool::new(project.clone())); + thread.add_tool(MovePathTool::new(project.clone())); + thread.add_tool(NowTool); + thread.add_tool(OpenTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(TerminalTool::new(project.clone(), cx)); + thread.add_tool(ThinkingTool); + thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + thread + }); + + Ok(thread) + }, + )??; + + // Store the session + agent.update(cx, |agent, cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.downgrade(), + _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + }, + ); + })?; + + Ok(acp_thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process + } + + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn prompt( + &self, + id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let id = id.expect("UserMessageId is required"); + let session_id = params.session_id.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + self.run_turn(session_id, cx, |thread, cx| { + let content: Vec = params + .prompt + .into_iter() + .map(Into::into) + .collect::>(); + log::info!("Converted prompt to message: {} chars", content.len()); + log::debug!("Message id: {:?}", id); + log::debug!("Message content: {:?}", content); + + Ok(thread.update(cx, |thread, cx| { + log::info!( + "Sending message to thread with model: {:?}", + thread.model().name() + ); + thread.send(id, content, cx) + })) + }) + } + + fn resume( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionResume { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { log::info!("Cancelling on session: {}", session_id); @@ -586,44 +806,51 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } }); } -} -/// Convert ACP content blocks to a message string -fn convert_prompt_to_message(blocks: Vec) -> String { - log::debug!("Converting {} content blocks to message", blocks.len()); - let mut message = String::new(); - - for block in blocks { - match block { - acp::ContentBlock::Text(text) => { - log::trace!("Processing text block: {} chars", text.text.len()); - message.push_str(&text.text); - } - acp::ContentBlock::ResourceLink(link) => { - log::trace!("Processing resource link: {}", link.uri); - message.push_str(&format!(" @{} ", link.uri)); - } - acp::ContentBlock::Image(_) => { - log::trace!("Processing image block"); - message.push_str(" [image] "); - } - acp::ContentBlock::Audio(_) => { - log::trace!("Processing audio block"); - message.push_str(" [audio] "); - } - acp::ContentBlock::Resource(resource) => { - log::trace!("Processing resource block: {:?}", resource.resource); - message.push_str(&format!(" [resource: {:?}] ", resource.resource)); - } - } + fn session_editor( + &self, + session_id: &agent_client_protocol::SessionId, + cx: &mut App, + ) -> Option> { + self.0.update(cx, |agent, _cx| { + agent + .sessions + .get(session_id) + .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) + }) } - message + fn into_any(self: Rc) -> Rc { + self + } +} + +struct NativeAgentSessionEditor(Entity); + +impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { + fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { + Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + } +} + +struct NativeAgentSessionResume { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionResume for NativeAgentSessionResume { + fn run(&self, cx: &mut App) -> Task> { + self.connection + .run_turn(self.session_id.clone(), cx, |thread, cx| { + thread.update(cx, |thread, cx| thread.resume(cx)) + }) + } } #[cfg(test)] mod tests { use super::*; + use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; use serde_json::json; @@ -641,9 +868,15 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [], cx).await; - let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async()) - .await - .unwrap(); + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); agent.read_with(cx, |agent, _| { assert_eq!(agent.project_context.borrow().worktrees, vec![]) }); @@ -684,13 +917,127 @@ mod tests { }); } + #[gpui::test] + async fn test_listing_models(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({ "a": {} })).await; + let project = Project::test(fs.clone(), [], cx).await; + let connection = NativeAgentConnection( + NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(), + ); + + let models = cx.update(|cx| connection.list_models(cx)).await.unwrap(); + + let acp_thread::AgentModelList::Grouped(models) = models else { + panic!("Unexpected model group"); + }; + assert_eq!( + models, + IndexMap::from_iter([( + AgentModelGroupName("Fake".into()), + vec![AgentModelInfo { + id: AgentModelId("fake/fake".into()), + name: "Fake".into(), + icon: Some(ui::IconName::ZedAssistant), + }] + )]) + ); + } + + #[gpui::test] + async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_model": { + "provider": "foo", + "model": "bar" + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + + // Create the agent and connection + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = NativeAgentConnection(agent.clone()); + + // Create a thread/session + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + }) + .await + .unwrap(); + + let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); + + // Select a model + let model_id = AgentModelId("fake/fake".into()); + cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx)) + .await + .unwrap(); + + // Verify the thread has the selected model + agent.read_with(cx, |agent, _| { + let session = agent.sessions.get(&session_id).unwrap(); + session.thread.read_with(cx, |thread, _| { + assert_eq!(thread.model().id().0, "fake"); + }); + }); + + cx.run_until_parked(); + + // Verify settings file was updated + let settings_content = fs.load(paths::settings_file()).await.unwrap(); + let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); + + // Check that the agent settings contain the selected model + assert_eq!( + settings_json["agent"]["default_model"]["model"], + json!("fake") + ); + assert_eq!( + settings_json["agent"]["default_model"]["provider"], + json!("fake") + ); + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); Project::init_settings(cx); + agent_settings::init(cx); language::init(cx); + LanguageModelRegistry::test(cx); }); } } diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index dd0188b548..cadd88a846 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -1,16 +1,24 @@ -use std::path::Path; -use std::rc::Rc; +use std::{path::Path, rc::Rc, sync::Arc}; use agent_servers::AgentServer; use anyhow::Result; +use fs::Fs; use gpui::{App, Entity, Task}; use project::Project; use prompt_store::PromptStore; -use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; +use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; #[derive(Clone)] -pub struct NativeAgentServer; +pub struct NativeAgentServer { + fs: Arc, +} + +impl NativeAgentServer { + pub fn new(fs: Arc) -> Self { + Self { fs } + } +} impl AgentServer for NativeAgentServer { fn name(&self) -> &'static str { @@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer { _root_dir ); let project = project.clone(); + let fs = self.fs.clone(); let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { log::debug!("Creating templates for native agent"); @@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer { let prompt_store = prompt_store.await?; log::debug!("Creating native agent entity"); - let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?; + let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?; // Create the connection wrapper let connection = NativeAgentConnection(agent); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 273da1dae5..cc8bd483bb 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,24 +1,29 @@ use super::*; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId}; +use action_log::ActionLog; use agent_client_protocol::{self as acp}; +use agent_settings::AgentProfileId; use anyhow::Result; -use assistant_tool::ActionLog; use client::{Client, UserStore}; -use fs::FakeFs; +use fs::{FakeFs, Fs}; use futures::channel::mpsc::UnboundedReceiver; -use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext}; +use gpui::{ + App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, +}; use indoc::indoc; use language_model::{ - fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError, - LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult, - LanguageModelToolUse, MessageContent, Role, StopReason, + LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, + LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent, + Role, StopReason, fake_provider::FakeLanguageModel, }; +use pretty_assertions::assert_eq; use project::Project; use prompt_store::ProjectContext; use reqwest_client::ReqwestClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; +use settings::SettingsStore; use smol::stream::StreamExt; use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; @@ -29,19 +34,23 @@ use test_tools::*; #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_echo(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; let events = thread .update(cx, |thread, cx| { - thread.send(model.clone(), "Testing: Reply with 'Hello'", cx) + thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx) }) .collect() .await; thread.update(cx, |thread, _cx| { assert_eq!( - thread.messages().last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] - ); + thread.last_message().unwrap().to_markdown(), + indoc! {" + ## Assistant + + Hello + "} + ) }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } @@ -49,18 +58,18 @@ async fn test_echo(cx: &mut TestAppContext) { #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_thinking(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await; + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await; let events = thread .update(cx, |thread, cx| { thread.send( - model.clone(), - indoc! {" + UserMessageId::new(), + [indoc! {" Testing: Generate a thinking step where you just think the word 'Think', and have your final answer be 'Hello' - "}, + "}], cx, ) }) @@ -68,9 +77,10 @@ async fn test_thinking(cx: &mut TestAppContext) { .await; thread.update(cx, |thread, _cx| { assert_eq!( - thread.messages().last().unwrap().to_markdown(), + thread.last_message().unwrap().to_markdown(), indoc! {" - ## assistant + ## Assistant + Think Hello "} @@ -91,7 +101,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) { project_context.borrow_mut().shell = "test-shell".into(); thread.update(cx, |thread, _| thread.add_tool(EchoTool)); - thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); assert_eq!( @@ -118,18 +130,146 @@ async fn test_system_prompt(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_prompt_caching(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + // Send initial user message and verify it's cached + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Message 1"], cx) + }); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 1".into()], + cache: true + }] + ); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text( + "Response to Message 1".into(), + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Send another user message and verify only the latest is cached + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Message 2"], cx) + }); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec!["Response to Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 2".into()], + cache: true + } + ] + ); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text( + "Response to Message 2".into(), + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Simulate a tool call and verify that the latest tool result is cached + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Use the echo tool"], cx) + }); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_1".into(), + name: EchoTool.name().into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_1".into(), + tool_name: EchoTool.name().into(), + is_error: false, + content: "test".into(), + output: Some("test".into()), + }; + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec!["Response to Message 1".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Message 2".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec!["Response to Message 2".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Use the echo tool".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: true + } + ] + ); +} + #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_basic_tool_calls(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; // Test a tool call that's likely to complete *before* streaming stops. let events = thread .update(cx, |thread, cx| { thread.add_tool(EchoTool); thread.send( - model.clone(), - "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + UserMessageId::new(), + ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], cx, ) }) @@ -143,8 +283,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { thread.remove_tool(&AgentTool::name(&EchoTool)); thread.add_tool(DelayTool); thread.send( - model.clone(), - "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + UserMessageId::new(), + [ + "Now call the delay tool with 200ms.", + "When the timer goes off, then you echo the output of the tool.", + ], cx, ) }) @@ -152,31 +295,36 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { .await; assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); thread.update(cx, |thread, _cx| { - assert!(thread - .messages() - .last() - .unwrap() - .content - .iter() - .any(|content| { - if let MessageContent::Text(text) = content { - text.contains("Ding") - } else { - false - } - })); + assert!( + thread + .last_message() + .unwrap() + .as_agent_message() + .unwrap() + .content + .iter() + .any(|content| { + if let AgentMessageContent::Text(text) = content { + text.contains("Ding") + } else { + false + } + }), + "{}", + thread.to_markdown() + ); }); } #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_streaming_tool_calls(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; // Test a tool call that's likely to complete *before* streaming stops. let mut events = thread.update(cx, |thread, cx| { thread.add_tool(WordListTool); - thread.send(model.clone(), "Test the word_list tool.", cx) + thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) }); let mut saw_partial_tool_use = false; @@ -184,8 +332,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message - let last_content = thread.messages().last().unwrap().content.last().unwrap(); - if let MessageContent::ToolUse(last_tool_use) = last_content { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + let last_content = agent_message.content.last().unwrap(); + if let AgentMessageContent::ToolUse(last_tool_use) = last_content { assert_eq!(last_tool_use.name.as_ref(), "word_list"); if tool_call.status == acp::ToolCallStatus::Pending { if !last_tool_use.is_input_complete @@ -223,7 +373,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let mut events = thread.update(cx, |thread, cx| { thread.add_tool(ToolRequiringPermission); - thread.send(model.clone(), "abc", cx) + thread.send(UserMessageId::new(), ["abc"], cx) }); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -267,14 +417,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { assert_eq!( message.content, vec![ - MessageContent::ToolResult(LanguageModelToolResult { + language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), tool_name: ToolRequiringPermission.name().into(), is_error: false, content: "Allowed".into(), output: Some("Allowed".into()) }), - MessageContent::ToolResult(LanguageModelToolResult { + language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), tool_name: ToolRequiringPermission.name().into(), is_error: true, @@ -283,6 +433,67 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { }) ] ); + + // Simulate yet another tool call. + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_3".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + + // Respond by always allowing tools. + let tool_call_auth_3 = next_tool_call_authorization(&mut events).await; + tool_call_auth_3 + .response + .send(tool_call_auth_3.options[0].id.clone()) + .unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + } + )] + ); + + // Simulate a final tool call, ensuring we don't trigger authorization. + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_4".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: "tool_id_4".into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + } + )] + ); } #[gpui::test] @@ -290,7 +501,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + let mut events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { @@ -310,8 +523,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed)); } +#[gpui::test] +async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread.update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["abc"], cx) + }); + cx.run_until_parked(); + let tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: EchoTool.name().into(), + raw_input: "{}".into(), + input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), + is_input_complete: true, + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.end_last_completion_stream(); + + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_id_1".into(), + tool_name: EchoTool.name().into(), + is_error: false, + content: "def".into(), + output: Some("def".into()), + }; + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use.clone())], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result.clone())], + cache: true + }, + ] + ); + + // Simulate reaching tool use limit. + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( + cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached, + )); + fake_model.end_last_completion_stream(); + let last_event = events.collect::>().await.pop().unwrap(); + assert!( + last_event + .unwrap_err() + .is::() + ); + + let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: true + } + ] + ); + + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into())); + fake_model.end_last_completion_stream(); + events.collect::>().await; + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.last_message().unwrap().to_markdown(), + indoc! {" + ## Assistant + + Done + "} + ) + }); + + // Ensure we error if calling resume when tool use limit was *not* reached. + let error = thread + .update(cx, |thread, cx| thread.resume(cx)) + .unwrap_err(); + assert_eq!( + error.to_string(), + "can only resume after tool use limit is reached" + ) +} + +#[gpui::test] +async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread.update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["abc"], cx) + }); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_id_1".into(), + name: EchoTool.name().into(), + raw_input: "{}".into(), + input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), + is_input_complete: true, + }; + let tool_result = LanguageModelToolResult { + tool_use_id: "tool_id_1".into(), + tool_name: EchoTool.name().into(), + is_error: false, + content: "def".into(), + output: Some("def".into()), + }; + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( + cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached, + )); + fake_model.end_last_completion_stream(); + let last_event = events.collect::>().await.pop().unwrap(); + assert!( + last_event + .unwrap_err() + .is::() + ); + + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), vec!["ghi"], cx) + }); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["abc".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec!["ghi".into()], + cache: true + } + ] + ); +} + async fn expect_tool_call( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCall { let event = events .next() @@ -327,7 +726,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -336,7 +735,7 @@ async fn expect_tool_call_update_fields( .unwrap(); match event { AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { - return update + return update; } event => { panic!("Unexpected event {event:?}"); @@ -345,7 +744,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -375,15 +774,19 @@ async fn next_tool_call_authorization( #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; // Test concurrent tool calls with different delay times let events = thread .update(cx, |thread, cx| { thread.add_tool(DelayTool); thread.send( - model.clone(), - "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + UserMessageId::new(), + [ + "Call the delay tool twice in the same message.", + "Once with 100ms. Once with 300ms.", + "When both timers are complete, describe the outputs.", + ], cx, ) }) @@ -394,12 +797,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]); thread.update(cx, |thread, _cx| { - let last_message = thread.messages().last().unwrap(); - let text = last_message + let last_message = thread.last_message().unwrap(); + let agent_message = last_message.as_agent_message().unwrap(); + let text = agent_message .content .iter() .filter_map(|content| { - if let MessageContent::Text(text) = content { + if let AgentMessageContent::Text(text) = content { Some(text.as_str()) } else { None @@ -411,17 +815,93 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_profiles(cx: &mut TestAppContext) { + let ThreadTest { + model, thread, fs, .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(DelayTool); + thread.add_tool(EchoTool); + thread.add_tool(InfiniteTool); + }); + + // Override profiles and wait for settings to be loaded. + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "profiles": { + "test-1": { + "name": "Test Profile 1", + "tools": { + EchoTool.name(): true, + DelayTool.name(): true, + } + }, + "test-2": { + "name": "Test Profile 2", + "tools": { + InfiniteTool.name(): true, + } + } + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + + // Test that test-1 profile (default) has echo and delay tools + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-1".into())); + thread.send(UserMessageId::new(), ["test"], cx); + }); + cx.run_until_parked(); + + let mut pending_completions = fake_model.pending_completions(); + assert_eq!(pending_completions.len(), 1); + let completion = pending_completions.pop().unwrap(); + let tool_names: Vec = completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]); + fake_model.end_last_completion_stream(); + + // Switch to test-2 profile, and verify that it has only the infinite tool. + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-2".into())); + thread.send(UserMessageId::new(), ["test2"], cx) + }); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!(pending_completions.len(), 1); + let completion = pending_completions.pop().unwrap(); + let tool_names: Vec = completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + assert_eq!(tool_names, vec![InfiniteTool.name()]); +} + #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_cancellation(cx: &mut TestAppContext) { - let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; + let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await; let mut events = thread.update(cx, |thread, cx| { thread.add_tool(InfiniteTool); thread.add_tool(EchoTool); thread.send( - model.clone(), - "Call the echo tool and then call the infinite tool, then explain their output", + UserMessageId::new(), + ["Call the echo tool, then call the infinite tool, then explain their output"], cx, ) }); @@ -466,14 +946,20 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Ensure we can still send a new message after cancellation. let events = thread .update(cx, |thread, cx| { - thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx) + thread.send( + UserMessageId::new(), + ["Testing: reply with 'Hello' then stop."], + cx, + ) }) .collect::>() .await; thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); assert_eq!( - thread.messages().last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] + agent_message.content, + vec![AgentMessageContent::Text("Hello".to_string())] ); }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); @@ -484,13 +970,16 @@ async fn test_refusal(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx)); + let events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( thread.to_markdown(), indoc! {" - ## user + ## User + Hello "} ); @@ -502,9 +991,12 @@ async fn test_refusal(cx: &mut TestAppContext) { assert_eq!( thread.to_markdown(), indoc! {" - ## user + ## User + Hello - ## assistant + + ## Assistant + Hey! "} ); @@ -520,6 +1012,85 @@ async fn test_refusal(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_truncate(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let message_id = UserMessageId::new(); + thread.update(cx, |thread, cx| { + thread.send(message_id.clone(), ["Hello"], cx) + }); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + "} + ); + }); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + + ## Assistant + + Hey! + "} + ); + }); + + thread + .update(cx, |thread, _cx| thread.truncate(message_id)) + .unwrap(); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + }); + + // Ensure we can still send a new message after truncation. + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hi"], cx) + }); + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + "} + ); + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Ahoy!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + + ## Assistant + + Ahoy! + "} + ); + }); +} + #[gpui::test] async fn test_agent_connection(cx: &mut TestAppContext) { cx.update(settings::init); @@ -538,19 +1109,26 @@ async fn test_agent_connection(cx: &mut TestAppContext) { language_models::init(user_store.clone(), client.clone(), cx); Project::init_settings(cx); LanguageModelRegistry::test(cx); + agent_settings::init(cx); }); cx.executor().forbid_parking(); // Create a project for new_thread let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); fake_fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await; let cwd = Path::new("/test"); // Create agent and connection - let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async()) - .await - .unwrap(); + let agent = NativeAgent::new( + project.clone(), + templates.clone(), + None, + fake_fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); let connection = NativeAgentConnection(agent.clone()); // Test model_selector returns Some @@ -563,22 +1141,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Test list_models let listed_models = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - selector.list_models(&mut async_cx) - }) + .update(|cx| selector.list_models(cx)) .await .expect("list_models should succeed"); + let AgentModelList::Grouped(listed_models) = listed_models else { + panic!("Unexpected model list type"); + }; assert!(!listed_models.is_empty(), "should have at least one model"); - assert_eq!(listed_models[0].id().0, "fake"); + assert_eq!( + listed_models[&AgentModelGroupName("Fake".into())][0].id.0, + "fake/fake" + ); // Create a thread using new_thread let connection_rc = Rc::new(connection.clone()); let acp_thread = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - connection_rc.new_thread(project, cwd, &mut async_cx) - }) + .update(|cx| connection_rc.new_thread(project, cwd, cx)) .await .expect("new_thread should succeed"); @@ -587,12 +1165,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Test selected_model returns the default let model = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - selector.selected_model(&session_id, &mut async_cx) - }) + .update(|cx| selector.selected_model(&session_id, cx)) .await .expect("selected_model should succeed"); + let model = cx + .update(|cx| agent.read(cx).models().model_from_id(&model.id)) + .unwrap(); let model = model.as_fake(); assert_eq!(model.id().0, "fake", "should return default model"); @@ -626,6 +1204,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let result = cx .update(|cx| { connection.prompt( + Some(acp_thread::UserMessageId::new()), acp::PromptRequest { session_id: session_id.clone(), prompt: vec!["ghi".into()], @@ -648,7 +1227,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx)); + let mut events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Think"], cx) + }); cx.run_until_parked(); // Simulate streaming partial input. @@ -733,6 +1314,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { id: acp::ToolCallId("1".into()), fields: acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::Completed), + raw_output: Some("Finished thinking.".into()), ..Default::default() }, } @@ -740,9 +1322,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { } /// Filters out the stop events for asserting against in tests -fn stop_events( - result_events: Vec>, -) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { @@ -756,6 +1336,7 @@ struct ThreadTest { model: Arc, thread: Entity, project_context: Rc>, + fs: Arc, } enum TestModel { @@ -776,28 +1357,59 @@ impl TestModel { async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_profile": "test-profile", + "profiles": { + "test-profile": { + "name": "Test Profile", + "tools": { + EchoTool.name(): true, + DelayTool.name(): true, + WordListTool.name(): true, + ToolRequiringPermission.name(): true, + InfiniteTool.name(): true, + } + } + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.update(|cx| { settings::init(cx); Project::init_settings(cx); + agent_settings::init(cx); + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + watch_settings(fs.clone(), cx); }); + let templates = Templates::new(); - let fs = FakeFs::new(cx.background_executor.clone()); fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let model = cx .update(|cx| { - gpui_tokio::init(cx); - let http_client = ReqwestClient::user_agent("agent tests").unwrap(); - cx.set_http_client(Arc::new(http_client)); - - client::init_settings(cx); - let client = Client::production(cx); - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); - if let TestModel::Fake = model { Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>) } else { @@ -820,20 +1432,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { .await; let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, project_context.clone(), + context_server_registry, action_log, templates, model.clone(), + cx, ) }); ThreadTest { model, thread, project_context, + fs, } } @@ -844,3 +1461,26 @@ fn init_logger() { env_logger::init(); } } + +fn watch_settings(fs: Arc, cx: &mut App) { + let fs = fs.clone(); + cx.spawn({ + async move |cx| { + let mut new_settings_content_rx = settings::watch_config_file( + cx.background_executor(), + fs, + paths::settings_file().clone(), + ); + + while let Some(new_settings_content) = new_settings_content_rx.next().await { + cx.update(|cx| { + SettingsStore::update_global(cx, |settings, cx| { + settings.set_user_settings(&new_settings_content, cx) + }) + }) + .ok(); + } + } + }) + .detach(); +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index d06614f3fe..cbff44cedf 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -7,7 +7,7 @@ use std::future; #[derive(JsonSchema, Serialize, Deserialize)] pub struct EchoToolInput { /// The text to echo. - text: String, + pub text: String, } pub struct EchoTool; @@ -110,9 +110,9 @@ impl AgentTool for ToolRequiringPermission { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let auth_check = event_stream.authorize("Authorize?".into()); + let authorize = event_stream.authorize("Authorize?", cx); cx.foreground_executor().spawn(async move { - auth_check.await?; + authorize.await?; Ok("Allowed".to_string()) }) } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index f664e0f5d2..0741bb9e08 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,56 +1,322 @@ -use crate::{SystemPromptTemplate, Template, Templates}; -use acp_thread::Diff; +use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; +use acp_thread::{MentionUri, UserMessageId}; +use action_log::ActionLog; use agent_client_protocol as acp; -use anyhow::{anyhow, Context as _, Result}; -use assistant_tool::{adapt_schema_to_format, ActionLog}; -use cloud_llm_client::{CompletionIntent, CompletionMode}; -use collections::HashMap; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use anyhow::{Context as _, Result, anyhow}; +use assistant_tool::adapt_schema_to_format; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; +use collections::IndexMap; +use fs::Fs; use futures::{ channel::{mpsc, oneshot}, stream::FuturesUnordered, }; use gpui::{App, Context, Entity, SharedString, Task}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; -use log; use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; +use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; -use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; -use util::{markdown::MarkdownCodeBlock, ResultExt}; +use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; +use std::{fmt::Write, ops::Range}; +use util::{ResultExt, markdown::MarkdownCodeBlock}; +use uuid::Uuid; -#[derive(Debug, Clone)] -pub struct AgentMessage { - pub role: Role, - pub content: Vec, +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, +)] +pub struct ThreadId(Arc); + +impl ThreadId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string().into()) + } +} + +impl std::fmt::Display for ThreadId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for ThreadId { + fn from(value: &str) -> Self { + Self(value.into()) + } +} + +/// The ID of the user prompt that initiated a request. +/// +/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] +pub struct PromptId(Arc); + +impl PromptId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string().into()) + } +} + +impl std::fmt::Display for PromptId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Message { + User(UserMessage), + Agent(AgentMessage), + Resume, +} + +impl Message { + pub fn as_agent_message(&self) -> Option<&AgentMessage> { + match self { + Message::Agent(agent_message) => Some(agent_message), + _ => None, + } + } + + pub fn to_markdown(&self) -> String { + match self { + Message::User(message) => message.to_markdown(), + Message::Agent(message) => message.to_markdown(), + Message::Resume => "[resumed after tool use limit was reached]".into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UserMessage { + pub id: UserMessageId, + pub content: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UserMessageContent { + Text(String), + Mention { uri: MentionUri, content: String }, + Image(LanguageModelImage), +} + +impl UserMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = String::from("## User\n\n"); + + for content in &self.content { + match content { + UserMessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + UserMessageContent::Image(_) => { + markdown.push_str("\n"); + } + UserMessageContent::Mention { uri, content } => { + if !content.is_empty() { + let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content); + } else { + let _ = write!(&mut markdown, "{}\n", uri.as_link()); + } + } + } + } + + markdown + } + + fn to_request(&self) -> LanguageModelRequestMessage { + let mut message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + + const OPEN_CONTEXT: &str = "\n\ + The following items were attached by the user. \ + They are up-to-date and don't need to be re-read.\n\n"; + + const OPEN_FILES_TAG: &str = ""; + const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_THREADS_TAG: &str = ""; + const OPEN_FETCH_TAG: &str = ""; + const OPEN_RULES_TAG: &str = + "\nThe user has specified the following rules that should be applied:\n"; + + let mut file_context = OPEN_FILES_TAG.to_string(); + let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut thread_context = OPEN_THREADS_TAG.to_string(); + let mut fetch_context = OPEN_FETCH_TAG.to_string(); + let mut rules_context = OPEN_RULES_TAG.to_string(); + + for chunk in &self.content { + let chunk = match chunk { + UserMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + UserMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + UserMessageContent::Mention { uri, content } => { + match uri { + MentionUri::File { abs_path, .. } => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&abs_path, None), + text: &content.to_string(), + } + ) + .ok(); + } + MentionUri::Symbol { + path, line_range, .. + } + | MentionUri::Selection { + path, line_range, .. + } => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&path, Some(line_range)), + text: &content + } + ) + .ok(); + } + MentionUri::Thread { .. } => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::TextThread { .. } => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::Rule { .. } => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: "", + text: &content + } + ) + .ok(); + } + MentionUri::Fetch { url } => { + write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok(); + } + } + + language_model::MessageContent::Text(uri.as_link().to_string()) + } + }; + + message.content.push(chunk); + } + + let len_before_context = message.content.len(); + + if file_context.len() > OPEN_FILES_TAG.len() { + file_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(file_context)); + } + + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { + symbol_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(symbol_context)); + } + + if thread_context.len() > OPEN_THREADS_TAG.len() { + thread_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(thread_context)); + } + + if fetch_context.len() > OPEN_FETCH_TAG.len() { + fetch_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(fetch_context)); + } + + if rules_context.len() > OPEN_RULES_TAG.len() { + rules_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(rules_context)); + } + + if message.content.len() > len_before_context { + message.content.insert( + len_before_context, + language_model::MessageContent::Text(OPEN_CONTEXT.into()), + ); + message + .content + .push(language_model::MessageContent::Text("".into())); + } + + message + } +} + +fn codeblock_tag(full_path: &Path, line_range: Option<&Range>) -> String { + let mut result = String::new(); + + if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { + let _ = write!(result, "{} ", extension); + } + + let _ = write!(result, "{}", full_path.display()); + + if let Some(range) = line_range { + if range.start == range.end { + let _ = write!(result, ":{}", range.start + 1); + } else { + let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1); + } + } + + result } impl AgentMessage { pub fn to_markdown(&self) -> String { - let mut markdown = format!("## {}\n", self.role); + let mut markdown = String::from("## Assistant\n\n"); for content in &self.content { match content { - MessageContent::Text(text) => { + AgentMessageContent::Text(text) => { markdown.push_str(text); markdown.push('\n'); } - MessageContent::Thinking { text, .. } => { + AgentMessageContent::Thinking { text, .. } => { markdown.push_str(""); markdown.push_str(text); markdown.push_str("\n"); } - MessageContent::RedactedThinking(_) => markdown.push_str("\n"), - MessageContent::Image(_) => { + AgentMessageContent::RedactedThinking(_) => { + markdown.push_str("\n") + } + AgentMessageContent::Image(_) => { markdown.push_str("\n"); } - MessageContent::ToolUse(tool_use) => { + AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", tool_use.name, tool_use.id @@ -63,38 +329,111 @@ impl AgentMessage { } )); } - MessageContent::ToolResult(tool_result) => { - markdown.push_str(&format!( - "**Tool Result**: {} (ID: {})\n\n", - tool_result.tool_name, tool_result.tool_use_id - )); - if tool_result.is_error { - markdown.push_str("**ERROR:**\n"); - } + } + } - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}\n").ok(); - } - LanguageModelToolResultContent::Image(_) => { - writeln!(markdown, "\n").ok(); - } - } + for tool_result in self.tool_results.values() { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } - if let Some(output) = tool_result.output.as_ref() { - writeln!( - markdown, - "**Debug Output**:\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(output).unwrap() - ) - .unwrap(); - } + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); } + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); + } + } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); } } markdown } + + pub fn to_request(&self) -> Vec { + let mut assistant_message = LanguageModelRequestMessage { + role: Role::Assistant, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + for chunk in &self.content { + let chunk = match chunk { + AgentMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + AgentMessageContent::Thinking { text, signature } => { + language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + } + } + AgentMessageContent::RedactedThinking(value) => { + language_model::MessageContent::RedactedThinking(value.clone()) + } + AgentMessageContent::ToolUse(value) => { + language_model::MessageContent::ToolUse(value.clone()) + } + AgentMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + }; + assistant_message.content.push(chunk); + } + + let mut user_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; + + for tool_result in self.tool_results.values() { + user_message + .content + .push(language_model::MessageContent::ToolResult( + tool_result.clone(), + )); + } + + let mut messages = Vec::new(); + if !assistant_message.content.is_empty() { + messages.push(assistant_message); + } + if !user_message.content.is_empty() { + messages.push(user_message); + } + messages + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct AgentMessage { + pub content: Vec, + pub tool_results: IndexMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentMessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), } #[derive(Debug)] @@ -109,23 +448,28 @@ pub enum AgentResponseEvent { #[derive(Debug)] pub struct ToolCallAuthorization { - pub tool_call: acp::ToolCall, + pub tool_call: acp::ToolCallUpdate, pub options: Vec, pub response: oneshot::Sender, } pub struct Thread { - messages: Vec, + id: ThreadId, + prompt_id: PromptId, + messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. running_turn: Option>, - pending_tool_uses: HashMap, + pending_message: Option, tools: BTreeMap>, + tool_use_limit_reached: bool, + context_server_registry: Entity, + profile_id: AgentProfileId, project_context: Rc>, templates: Arc, - pub selected_model: Arc, + model: Arc, project: Entity, action_log: Entity, } @@ -134,19 +478,27 @@ impl Thread { pub fn new( project: Entity, project_context: Rc>, + context_server_registry: Entity, action_log: Entity, templates: Arc, - default_model: Arc, + model: Arc, + cx: &mut Context, ) -> Self { + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { + id: ThreadId::new(), + prompt_id: PromptId::new(), messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, - pending_tool_uses: HashMap::default(), + pending_message: None, tools: BTreeMap::default(), + tool_use_limit_reached: false, + context_server_registry, + profile_id, project_context, templates, - selected_model: default_model, + model, project, action_log, } @@ -160,12 +512,29 @@ impl Thread { &self.action_log } - pub fn set_mode(&mut self, mode: CompletionMode) { + pub fn model(&self) -> &Arc { + &self.model + } + + pub fn set_model(&mut self, model: Arc) { + self.model = model; + } + + pub fn completion_mode(&self) -> CompletionMode { + self.completion_mode + } + + pub fn set_completion_mode(&mut self, mode: CompletionMode) { self.completion_mode = mode; } - pub fn messages(&self) -> &[AgentMessage] { - &self.messages + #[cfg(any(test, feature = "test-support"))] + pub fn last_message(&self) -> Option { + if let Some(message) = self.pending_message.clone() { + Some(Message::Agent(message)) + } else { + self.messages.last().cloned() + } } pub fn add_tool(&mut self, tool: impl AgentTool) { @@ -176,114 +545,133 @@ impl Thread { self.tools.remove(name).is_some() } - pub fn cancel(&mut self) { - self.running_turn.take(); + pub fn profile(&self) -> &AgentProfileId { + &self.profile_id + } - let tool_results = self - .pending_tool_uses - .drain() - .map(|(tool_use_id, tool_use)| { - MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id, - tool_name: tool_use.name.clone(), - is_error: true, - content: LanguageModelToolResultContent::Text("Tool canceled by user".into()), - output: None, - }) - }) - .collect::>(); - self.last_user_message().content.extend(tool_results); + pub fn set_profile(&mut self, profile_id: AgentProfileId) { + self.profile_id = profile_id; + } + + pub fn cancel(&mut self) { + // TODO: do we need to emit a stop::cancel for ACP? + self.running_turn.take(); + self.flush_pending_message(); + } + + pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { + self.cancel(); + let Some(position) = self.messages.iter().position( + |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), + ) else { + return Err(anyhow!("Message not found")); + }; + self.messages.truncate(position); + Ok(()) + } + + pub fn resume( + &mut self, + cx: &mut Context, + ) -> Result>> { + anyhow::ensure!( + self.tool_use_limit_reached, + "can only resume after tool use limit is reached" + ); + + self.messages.push(Message::Resume); + cx.notify(); + + log::info!("Total messages in thread: {}", self.messages.len()); + Ok(self.run_turn(cx)) } /// Sending a message results in the model streaming a response, which could include tool calls. /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. - pub fn send( + pub fn send( &mut self, - model: Arc, - content: impl Into, + id: UserMessageId, + content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { - let content = content.into(); - log::info!("Thread::send called with model: {:?}", model.name()); + ) -> mpsc::UnboundedReceiver> + where + T: Into, + { + log::info!("Thread::send called with model: {:?}", self.model.name()); + self.advance_prompt_id(); + + let content = content.into_iter().map(Into::into).collect::>(); log::debug!("Thread::send content: {:?}", content); + self.messages + .push(Message::User(UserMessage { id, content })); cx.notify(); - let (events_tx, events_rx) = - mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); - let user_message_ix = self.messages.len(); - self.messages.push(AgentMessage { - role: Role::User, - content: vec![content], - }); log::info!("Total messages in thread: {}", self.messages.len()); - self.running_turn = Some(cx.spawn(async move |thread, cx| { + self.run_turn(cx) + } + + fn run_turn( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let model = self.model.clone(); + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = AgentResponseEventStream(events_tx); + let message_ix = self.messages.len().saturating_sub(1); + self.tool_use_limit_reached = false; + self.running_turn = Some(cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); - let turn_result = async { - // Perform one request, then keep looping if the model makes tool calls. + let turn_result: Result<()> = async { let mut completion_intent = CompletionIntent::UserPrompt; - 'outer: loop { + loop { log::debug!( "Building completion request with intent: {:?}", completion_intent ); - let request = thread.update(cx, |thread, cx| { - thread.build_completion_request(completion_intent, cx) + let request = this.update(cx, |this, cx| { + this.build_completion_request(completion_intent, cx) })?; - // println!( - // "request: {}", - // serde_json::to_string_pretty(&request).unwrap() - // ); - - // Stream events, appending to messages and collecting up tool uses. log::info!("Calling model.stream_completion"); let mut events = model.stream_completion(request, cx).await?; log::debug!("Stream completion started successfully"); + + let mut tool_use_limit_reached = false; let mut tool_uses = FuturesUnordered::new(); while let Some(event) = events.next().await { - match event { - Ok(LanguageModelCompletionEvent::Stop(reason)) => { + match event? { + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::ToolUseLimitReached, + ) => { + tool_use_limit_reached = true; + } + LanguageModelCompletionEvent::Stop(reason) => { event_stream.send_stop(reason); if reason == StopReason::Refusal { - thread.update(cx, |thread, _cx| { - thread.messages.truncate(user_message_ix); + this.update(cx, |this, _cx| { + this.flush_pending_message(); + this.messages.truncate(message_ix); })?; - break 'outer; + return Ok(()); } } - Ok(event) => { + event => { log::trace!("Received completion event: {:?}", event); - thread - .update(cx, |thread, cx| { - tool_uses.extend(thread.handle_streamed_completion_event( - event, - &event_stream, - cx, - )); - }) - .ok(); - } - Err(error) => { - log::error!("Error in completion stream: {:?}", error); - event_stream.send_error(error); - break; + this.update(cx, |this, cx| { + tool_uses.extend(this.handle_streamed_completion_event( + event, + &event_stream, + cx, + )); + }) + .ok(); } } } - // If there are no tool uses, the turn is done. - if tool_uses.is_empty() { - log::info!("No tool uses found, completing turn"); - break; - } - log::info!("Found {} tool uses to execute", tool_uses.len()); - - // As tool results trickle in, insert them in the last user - // message so that they can be sent on the next tick of the - // agentic loop. + let used_tools = tool_uses.is_empty(); while let Some(tool_result) = tool_uses.next().await { log::info!("Tool finished {:?}", tool_result); @@ -295,27 +683,34 @@ impl Thread { } else { acp::ToolCallStatus::Completed }), + raw_output: tool_result.output.clone(), ..Default::default() }, ); - thread - .update(cx, |thread, _cx| { - thread.pending_tool_uses.remove(&tool_result.tool_use_id); - thread - .last_user_message() - .content - .push(MessageContent::ToolResult(tool_result)); - }) - .ok(); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + }) + .ok(); } - completion_intent = CompletionIntent::ToolResults; + if tool_use_limit_reached { + log::info!("Tool use limit reached, completing turn"); + this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; + return Err(language_model::ToolUseLimitReachedError.into()); + } else if used_tools { + log::info!("No tool uses found, completing turn"); + return Ok(()); + } else { + this.update(cx, |this, _| this.flush_pending_message())?; + completion_intent = CompletionIntent::ToolResults; + } } - - Ok(()) } .await; + this.update(cx, |this, _| this.flush_pending_message()).ok(); if let Err(error) = turn_result { log::error!("Turn execution failed: {:?}", error); event_stream.send_error(error); @@ -326,7 +721,7 @@ impl Thread { events_rx } - pub fn build_system_message(&self) -> AgentMessage { + pub fn build_system_message(&self) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { project: &self.project_context.borrow(), @@ -336,9 +731,10 @@ impl Thread { .context("failed to build system prompt") .expect("Invalid template"); log::debug!("System message built"); - AgentMessage { + LanguageModelRequestMessage { role: Role::System, content: vec![prompt.into()], + cache: true, } } @@ -356,10 +752,8 @@ impl Thread { match event { StartMessage { .. } => { - self.messages.push(AgentMessage { - role: Role::Assistant, - content: Vec::new(), - }); + self.flush_pending_message(); + self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), Thinking { text, signature } => { @@ -392,16 +786,18 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - events_stream: &AgentResponseEventStream, + event_stream: &AgentResponseEventStream, cx: &mut Context, ) { - events_stream.send_text(&new_text); + event_stream.send_text(&new_text); - let last_message = self.last_assistant_message(); - if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + let last_message = self.pending_message(); + if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { text.push_str(&new_text); } else { - last_message.content.push(MessageContent::Text(new_text)); + last_message + .content + .push(AgentMessageContent::Text(new_text)); } cx.notify(); @@ -416,13 +812,14 @@ impl Thread { ) { event_stream.send_thinking(&new_text); - let last_message = self.last_assistant_message(); - if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() + let last_message = self.pending_message(); + if let Some(AgentMessageContent::Thinking { text, signature }) = + last_message.content.last_mut() { text.push_str(&new_text); *signature = new_signature.or(signature.take()); } else { - last_message.content.push(MessageContent::Thinking { + last_message.content.push(AgentMessageContent::Thinking { text: new_text, signature: new_signature, }); @@ -432,10 +829,10 @@ impl Thread { } fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { - let last_message = self.last_assistant_message(); + let last_message = self.pending_message(); last_message .content - .push(MessageContent::RedactedThinking(data)); + .push(AgentMessageContent::RedactedThinking(data)); cx.notify(); } @@ -448,14 +845,17 @@ impl Thread { cx.notify(); let tool = self.tools.get(tool_use.name.as_ref()).cloned(); - - self.pending_tool_uses - .insert(tool_use.id.clone(), tool_use.clone()); - let last_message = self.last_assistant_message(); + let mut title = SharedString::from(&tool_use.name); + let mut kind = acp::ToolKind::Other; + if let Some(tool) = tool.as_ref() { + title = tool.initial_title(tool_use.input.clone()); + kind = tool.kind(); + } // Ensure the last message ends in the current tool use + let last_message = self.pending_message(); let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { - if let MessageContent::ToolUse(last_tool_use) = content { + if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { *last_tool_use = tool_use.clone(); false @@ -467,18 +867,11 @@ impl Thread { } }); - let mut title = SharedString::from(&tool_use.name); - let mut kind = acp::ToolKind::Other; - if let Some(tool) = tool.as_ref() { - title = tool.initial_title(tool_use.input.clone()); - kind = tool.kind(); - } - if push_new_tool_use { event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); last_message .content - .push(MessageContent::ToolUse(tool_use.clone())); + .push(AgentMessageContent::ToolUse(tool_use.clone())); } else { event_stream.update_tool_call_fields( &tool_use.id, @@ -506,14 +899,16 @@ impl Thread { })); }; + let fs = self.project.read(cx).fs().clone(); let tool_event_stream = - ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone()); + ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs)); tool_event_stream.update_fields(acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::InProgress), ..Default::default() }); - let supports_images = self.selected_model.supports_images(); + let supports_images = self.model.supports_images(); let tool_result = tool.run(tool_use.input, tool_event_stream, cx); + log::info!("Running tool {}", tool_use.name); Some(cx.foreground_executor().spawn(async move { let tool_result = tool_result.await.and_then(|output| { if let LanguageModelToolResultContent::Image(_) = &output.llm_output { @@ -562,30 +957,37 @@ impl Thread { } } - /// Guarantees the last message is from the assistant and returns a mutable reference. - fn last_assistant_message(&mut self) -> &mut AgentMessage { - if self - .messages - .last() - .map_or(true, |m| m.role != Role::Assistant) - { - self.messages.push(AgentMessage { - role: Role::Assistant, - content: Vec::new(), - }); - } - self.messages.last_mut().unwrap() + fn pending_message(&mut self) -> &mut AgentMessage { + self.pending_message.get_or_insert_default() } - /// Guarantees the last message is from the user and returns a mutable reference. - fn last_user_message(&mut self) -> &mut AgentMessage { - if self.messages.last().map_or(true, |m| m.role != Role::User) { - self.messages.push(AgentMessage { - role: Role::User, - content: Vec::new(), - }); + fn flush_pending_message(&mut self) { + let Some(mut message) = self.pending_message.take() else { + return; + }; + + for content in &message.content { + let AgentMessageContent::ToolUse(tool_use) = content else { + continue; + }; + + if !message.tool_results.contains_key(&tool_use.id) { + message.tool_results.insert( + tool_use.id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text( + "Tool canceled by user".into(), + ), + output: None, + }, + ); + } } - self.messages.last_mut().unwrap() + + self.messages.push(Message::Agent(message)); } pub(crate) fn build_completion_request( @@ -600,34 +1002,36 @@ impl Thread { let messages = self.build_request_messages(); log::info!("Request will include {} messages", messages.len()); - let tools: Vec = self - .tools - .values() - .filter_map(|tool| { - let tool_name = tool.name().to_string(); - log::trace!("Including tool: {}", tool_name); - Some(LanguageModelRequestTool { - name: tool_name, - description: tool.description(cx).to_string(), - input_schema: tool - .input_schema(self.selected_model.tool_input_format()) - .log_err()?, + let tools = if let Some(tools) = self.tools(cx).log_err() { + tools + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description().to_string(), + input_schema: tool + .input_schema(self.model.tool_input_format()) + .log_err()?, + }) }) - }) - .collect(); + .collect() + } else { + Vec::new() + }; log::info!("Request includes {} tools", tools.len()); let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, + thread_id: Some(self.id.to_string()), + prompt_id: Some(self.prompt_id.to_string()), intent: Some(completion_intent), - mode: Some(self.completion_mode), + mode: Some(self.completion_mode.into()), messages, tools, tool_choice: None, stop: Vec::new(), - temperature: None, + temperature: AgentSettings::temperature_for_model(self.model(), cx), thinking_allowed: true, }; @@ -635,42 +1039,90 @@ impl Thread { request } + fn tools<'a>(&'a self, cx: &'a App) -> Result>> { + let profile = AgentSettings::get_global(cx) + .profiles + .get(&self.profile_id) + .context("profile not found")?; + let provider_id = self.model.provider_id(); + + Ok(self + .tools + .iter() + .filter(move |(_, tool)| tool.supported_provider(&provider_id)) + .filter_map(|(tool_name, tool)| { + if profile.is_tool_enabled(tool_name) { + Some(tool) + } else { + None + } + }) + .chain(self.context_server_registry.read(cx).servers().flat_map( + |(server_id, tools)| { + tools.iter().filter_map(|(tool_name, tool)| { + if profile.is_context_server_tool_enabled(&server_id.0, tool_name) { + Some(tool) + } else { + None + } + }) + }, + ))) + } + fn build_request_messages(&self) -> Vec { log::trace!( "Building request messages from {} thread messages", self.messages.len() ); - - let messages = Some(self.build_system_message()) - .iter() - .chain(self.messages.iter()) - .map(|message| { - log::trace!( - " - {} message with {} content items", - match message.role { - Role::System => "System", - Role::User => "User", - Role::Assistant => "Assistant", - }, - message.content.len() - ); - LanguageModelRequestMessage { - role: message.role, - content: message.content.clone(), + let mut messages = vec![self.build_system_message()]; + for message in &self.messages { + match message { + Message::User(message) => messages.push(message.to_request()), + Message::Agent(message) => messages.extend(message.to_request()), + Message::Resume => messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], cache: false, - } - }) - .collect(); + }), + } + } + + if let Some(message) = self.pending_message.as_ref() { + messages.extend(message.to_request()); + } + + if let Some(last_user_message) = messages + .iter_mut() + .rev() + .find(|message| message.role == Role::User) + { + last_user_message.cache = true; + } + messages } pub fn to_markdown(&self) -> String { let mut markdown = String::new(); - for message in &self.messages { + for (ix, message) in self.messages.iter().enumerate() { + if ix > 0 { + markdown.push('\n'); + } markdown.push_str(&message.to_markdown()); } + + if let Some(message) = self.pending_message.as_ref() { + markdown.push('\n'); + markdown.push_str(&message.to_markdown()); + } + markdown } + + fn advance_prompt_id(&mut self) { + self.prompt_id = PromptId::new(); + } } pub trait AgentTool @@ -682,7 +1134,7 @@ where fn name(&self) -> SharedString; - fn description(&self, _cx: &mut App) -> SharedString { + fn description(&self) -> SharedString { let schema = schemars::schema_for!(Self::Input); SharedString::new( schema @@ -702,6 +1154,12 @@ where schemars::schema_for!(Self::Input) } + /// Some tools rely on a provider for the underlying billing or other reasons. + /// Allow the tool to check if they are compatible, or should be filtered out. + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } + /// Runs the tool with the provided input. fn run( self: Arc, @@ -718,16 +1176,19 @@ where pub struct Erased(T); pub struct AgentToolOutput { - llm_output: LanguageModelToolResultContent, - raw_output: serde_json::Value, + pub llm_output: LanguageModelToolResultContent, + pub raw_output: serde_json::Value, } pub trait AnyAgentTool { fn name(&self) -> SharedString; - fn description(&self, cx: &mut App) -> SharedString; + fn description(&self) -> SharedString; fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } fn run( self: Arc, input: serde_json::Value, @@ -744,8 +1205,8 @@ where self.0.name() } - fn description(&self, cx: &mut App) -> SharedString { - self.0.description(cx) + fn description(&self) -> SharedString { + self.0.description() } fn kind(&self) -> agent_client_protocol::ToolKind { @@ -763,6 +1224,10 @@ where Ok(json) } + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + self.0.supported_provider(provider) + } + fn run( self: Arc, input: serde_json::Value, @@ -784,9 +1249,7 @@ where } #[derive(Clone)] -struct AgentResponseEventStream( - mpsc::UnboundedSender>, -); +struct AgentResponseEventStream(mpsc::UnboundedSender>); impl AgentResponseEventStream { fn send_text(&self, text: &str) { @@ -801,47 +1264,6 @@ impl AgentResponseEventStream { .ok(); } - fn authorize_tool_call( - &self, - id: &LanguageModelToolUseId, - title: String, - kind: acp::ToolKind, - input: serde_json::Value, - ) -> impl use<> + Future> { - let (response_tx, response_rx) = oneshot::channel(); - self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( - ToolCallAuthorization { - tool_call: Self::initial_tool_call(id, title, kind, input), - options: vec![ - acp::PermissionOption { - id: acp::PermissionOptionId("always_allow".into()), - name: "Always Allow".into(), - kind: acp::PermissionOptionKind::AllowAlways, - }, - acp::PermissionOption { - id: acp::PermissionOptionId("allow".into()), - name: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }, - acp::PermissionOption { - id: acp::PermissionOptionId("deny".into()), - name: "Deny".into(), - kind: acp::PermissionOptionKind::RejectOnce, - }, - ], - response: response_tx, - }, - ))) - .ok(); - async move { - match response_rx.await?.0.as_ref() { - "allow" | "always_allow" => Ok(()), - _ => Err(anyhow!("Permission to run tool denied by user")), - } - } - } - fn send_tool_call( &self, id: &LanguageModelToolUseId, @@ -893,18 +1315,6 @@ impl AgentResponseEventStream { .ok(); } - fn update_tool_call_diff(&self, tool_use_id: &LanguageModelToolUseId, diff: Entity) { - self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdateDiff { - id: acp::ToolCallId(tool_use_id.to_string().into()), - diff, - } - .into(), - ))) - .ok(); - } - fn send_stop(&self, reason: StopReason) { match reason { StopReason::EndTurn => { @@ -926,50 +1336,38 @@ impl AgentResponseEventStream { } } - fn send_error(&self, error: LanguageModelCompletionError) { - self.0.unbounded_send(Err(error)).ok(); + fn send_error(&self, error: impl Into) { + self.0.unbounded_send(Err(error.into())).ok(); } } #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - kind: acp::ToolKind, - input: serde_json::Value, stream: AgentResponseEventStream, + fs: Option>, } impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = - mpsc::unbounded::>(); + let (events_tx, events_rx) = mpsc::unbounded::>(); - let stream = ToolCallEventStream::new( - &LanguageModelToolUse { - id: "test_id".into(), - name: "test_tool".into(), - raw_input: String::new(), - input: serde_json::Value::Null, - is_input_complete: true, - }, - acp::ToolKind::Other, - AgentResponseEventStream(events_tx), - ); + let stream = + ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( - tool_use: &LanguageModelToolUse, - kind: acp::ToolKind, + tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream, + fs: Option>, ) -> Self { Self { - tool_use_id: tool_use.id.clone(), - kind, - input: tool_use.input.clone(), + tool_use_id, stream, + fs, } } @@ -978,28 +1376,95 @@ impl ToolCallEventStream { .update_tool_call_fields(&self.tool_use_id, fields); } - pub fn update_diff(&self, diff: Entity) { - self.stream.update_tool_call_diff(&self.tool_use_id, diff); + pub fn update_diff(&self, diff: Entity) { + self.stream + .0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp_thread::ToolCallUpdateDiff { + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + diff, + } + .into(), + ))) + .ok(); } - pub fn authorize(&self, title: String) -> impl use<> + Future> { - self.stream.authorize_tool_call( - &self.tool_use_id, - title, - self.kind.clone(), - self.input.clone(), - ) + pub fn update_terminal(&self, terminal: Entity) { + self.stream + .0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp_thread::ToolCallUpdateTerminal { + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + terminal, + } + .into(), + ))) + .ok(); + } + + pub fn authorize(&self, title: impl Into, cx: &mut App) -> Task> { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return Task::ready(Ok(())); + } + + let (response_tx, response_rx) = oneshot::channel(); + self.stream + .0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + ToolCallAuthorization { + tool_call: acp::ToolCallUpdate { + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + fields: acp::ToolCallUpdateFields { + title: Some(title.into()), + ..Default::default() + }, + }, + options: vec![ + acp::PermissionOption { + id: acp::PermissionOptionId("always_allow".into()), + name: "Always Allow".into(), + kind: acp::PermissionOptionKind::AllowAlways, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("allow".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("deny".into()), + name: "Deny".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + response: response_tx, + }, + ))) + .ok(); + let fs = self.fs.clone(); + cx.spawn(async move |cx| match response_rx.await?.0.as_ref() { + "always_allow" => { + if let Some(fs) = fs.clone() { + cx.update(|cx| { + update_settings_file::(fs, cx, |settings, _| { + settings.set_always_allow_tool_actions(true); + }); + })?; + } + + Ok(()) + } + "allow" => Ok(()), + _ => Err(anyhow!("Permission to run tool denied by user")), + }) } } #[cfg(test)] -pub struct ToolCallEventStreamReceiver( - mpsc::UnboundedReceiver>, -); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { - pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization { + pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { let event = self.0.next().await; if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { auth @@ -1007,11 +1472,23 @@ impl ToolCallEventStreamReceiver { panic!("Expected ToolCallAuthorization but got: {:?}", event); } } + + pub async fn expect_terminal(&mut self) -> Entity { + let event = self.0.next().await; + if let Some(Ok(AgentResponseEvent::ToolCallUpdate( + acp_thread::ToolCallUpdate::UpdateTerminal(update), + ))) = event + { + update.terminal + } else { + panic!("Expected terminal but got: {:?}", event); + } + } } #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 @@ -1024,3 +1501,66 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver { &mut self.0 } } + +impl From<&str> for UserMessageContent { + fn from(text: &str) -> Self { + Self::Text(text.into()) + } +} + +impl From for UserMessageContent { + fn from(value: acp::ContentBlock) -> Self { + match value { + acp::ContentBlock::Text(text_content) => Self::Text(text_content.text), + acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)), + acp::ContentBlock::Audio(_) => { + // TODO + Self::Text("[audio]".to_string()) + } + acp::ContentBlock::ResourceLink(resource_link) => { + match MentionUri::parse(&resource_link.uri) { + Ok(uri) => Self::Mention { + uri, + content: String::new(), + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri)) + } + } + } + acp::ContentBlock::Resource(resource) => match resource.resource { + acp::EmbeddedResourceResource::TextResourceContents(resource) => { + match MentionUri::parse(&resource.uri) { + Ok(uri) => Self::Mention { + uri, + content: resource.text, + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + Self::Text( + MarkdownCodeBlock { + tag: &resource.uri, + text: &resource.text, + } + .to_string(), + ) + } + } + } + acp::EmbeddedResourceResource::BlobResourceContents(_) => { + // TODO + Self::Text("[blob]".to_string()) + } + }, + } + } +} + +fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { + LanguageModelImage { + source: image_content.data.into(), + // TODO: make this optional? + size: gpui::Size::new(0.into(), 0.into()), + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index 5fe13db854..d1f2b3b1c7 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1,9 +1,35 @@ +mod context_server_registry; +mod copy_path_tool; +mod create_directory_tool; +mod delete_path_tool; +mod diagnostics_tool; mod edit_file_tool; +mod fetch_tool; mod find_path_tool; +mod grep_tool; +mod list_directory_tool; +mod move_path_tool; +mod now_tool; +mod open_tool; mod read_file_tool; +mod terminal_tool; mod thinking_tool; +mod web_search_tool; +pub use context_server_registry::*; +pub use copy_path_tool::*; +pub use create_directory_tool::*; +pub use delete_path_tool::*; +pub use diagnostics_tool::*; pub use edit_file_tool::*; +pub use fetch_tool::*; pub use find_path_tool::*; +pub use grep_tool::*; +pub use list_directory_tool::*; +pub use move_path_tool::*; +pub use now_tool::*; +pub use open_tool::*; pub use read_file_tool::*; +pub use terminal_tool::*; pub use thinking_tool::*; +pub use web_search_tool::*; diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs new file mode 100644 index 0000000000..db39e9278c --- /dev/null +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -0,0 +1,231 @@ +use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; +use agent_client_protocol::ToolKind; +use anyhow::{Result, anyhow, bail}; +use collections::{BTreeMap, HashMap}; +use context_server::ContextServerId; +use gpui::{App, Context, Entity, SharedString, Task}; +use project::context_server_store::{ContextServerStatus, ContextServerStore}; +use std::sync::Arc; +use util::ResultExt; + +pub struct ContextServerRegistry { + server_store: Entity, + registered_servers: HashMap, + _subscription: gpui::Subscription, +} + +struct RegisteredContextServer { + tools: BTreeMap>, + load_tools: Task>, +} + +impl ContextServerRegistry { + pub fn new(server_store: Entity, cx: &mut Context) -> Self { + let mut this = Self { + server_store: server_store.clone(), + registered_servers: HashMap::default(), + _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event), + }; + for server in server_store.read(cx).running_servers() { + this.reload_tools_for_server(server.id(), cx); + } + this + } + + pub fn servers( + &self, + ) -> impl Iterator< + Item = ( + &ContextServerId, + &BTreeMap>, + ), + > { + self.registered_servers + .iter() + .map(|(id, server)| (id, &server.tools)) + } + + fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context) { + let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else { + return; + }; + let Some(client) = server.client() else { + return; + }; + if !client.capable(context_server::protocol::ServerCapability::Tools) { + return; + } + + let registered_server = + self.registered_servers + .entry(server_id.clone()) + .or_insert(RegisteredContextServer { + tools: BTreeMap::default(), + load_tools: Task::ready(Ok(())), + }); + registered_server.load_tools = cx.spawn(async move |this, cx| { + let response = client + .request::(()) + .await; + + this.update(cx, |this, cx| { + let Some(registered_server) = this.registered_servers.get_mut(&server_id) else { + return; + }; + + registered_server.tools.clear(); + if let Some(response) = response.log_err() { + for tool in response.tools { + let tool = Arc::new(ContextServerTool::new( + this.server_store.clone(), + server.id(), + tool, + )); + registered_server.tools.insert(tool.name(), tool); + } + cx.notify(); + } + }) + }); + } + + fn handle_context_server_store_event( + &mut self, + _: Entity, + event: &project::context_server_store::Event, + cx: &mut Context, + ) { + match event { + project::context_server_store::Event::ServerStatusChanged { server_id, status } => { + match status { + ContextServerStatus::Starting => {} + ContextServerStatus::Running => { + self.reload_tools_for_server(server_id.clone(), cx); + } + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { + self.registered_servers.remove(&server_id); + cx.notify(); + } + } + } + } + } +} + +struct ContextServerTool { + store: Entity, + server_id: ContextServerId, + tool: context_server::types::Tool, +} + +impl ContextServerTool { + fn new( + store: Entity, + server_id: ContextServerId, + tool: context_server::types::Tool, + ) -> Self { + Self { + store, + server_id, + tool, + } + } +} + +impl AnyAgentTool for ContextServerTool { + fn name(&self) -> SharedString { + self.tool.name.clone().into() + } + + fn description(&self) -> SharedString { + self.tool.description.clone().unwrap_or_default().into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Other + } + + fn initial_title(&self, _input: serde_json::Value) -> SharedString { + format!("Run MCP tool `{}`", self.tool.name).into() + } + + fn input_schema( + &self, + format: language_model::LanguageModelToolSchemaFormat, + ) -> Result { + let mut schema = self.tool.input_schema.clone(); + assistant_tool::adapt_schema_to_format(&mut schema, format)?; + Ok(match schema { + serde_json::Value::Null => { + serde_json::json!({ "type": "object", "properties": [] }) + } + serde_json::Value::Object(map) if map.is_empty() => { + serde_json::json!({ "type": "object", "properties": [] }) + } + _ => schema, + }) + } + + fn run( + self: Arc, + input: serde_json::Value, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else { + return Task::ready(Err(anyhow!("Context server not found"))); + }; + let tool_name = self.tool.name.clone(); + let server_clone = server.clone(); + let input_clone = input.clone(); + + cx.spawn(async move |_cx| { + let Some(protocol) = server_clone.client() else { + bail!("Context server not initialized"); + }; + + let arguments = if let serde_json::Value::Object(map) = input_clone { + Some(map.into_iter().collect()) + } else { + None + }; + + log::trace!( + "Running tool: {} with arguments: {:?}", + tool_name, + arguments + ); + let response = protocol + .request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ) + .await?; + + let mut result = String::new(); + for content in response.content { + match content { + context_server::types::ToolResponseContent::Text { text } => { + result.push_str(&text); + } + context_server::types::ToolResponseContent::Image { .. } => { + log::warn!("Ignoring image content from tool response"); + } + context_server::types::ToolResponseContent::Audio { .. } => { + log::warn!("Ignoring audio content from tool response"); + } + context_server::types::ToolResponseContent::Resource { .. } => { + log::warn!("Ignoring resource content from tool response"); + } + } + } + Ok(AgentToolOutput { + raw_output: result.clone().into(), + llm_output: result.into(), + }) + }) + } +} diff --git a/crates/agent2/src/tools/copy_path_tool.rs b/crates/agent2/src/tools/copy_path_tool.rs new file mode 100644 index 0000000000..f973b86990 --- /dev/null +++ b/crates/agent2/src/tools/copy_path_tool.rs @@ -0,0 +1,118 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol::ToolKind; +use anyhow::{Context as _, Result, anyhow}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use util::markdown::MarkdownInlineCode; + +/// Copies a file or directory in the project, and returns confirmation that the +/// copy succeeded. +/// +/// Directory contents will be copied recursively (like `cp -r`). +/// +/// This tool should be used when it's desirable to create a copy of a file or +/// directory without modifying the original. It's much more efficient than +/// doing this by separately reading and then writing the file or directory's +/// contents, so this tool should be preferred over that approach whenever +/// copying is the goal. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct CopyPathToolInput { + /// The source path of the file or directory to copy. + /// If a directory is specified, its contents will be copied recursively (like `cp -r`). + /// + /// + /// If the project has the following files: + /// + /// - directory1/a/something.txt + /// - directory2/a/things.txt + /// - directory3/a/other.txt + /// + /// You can copy the first file by providing a source_path of "directory1/a/something.txt" + /// + pub source_path: String, + + /// The destination path where the file or directory should be copied to. + /// + /// + /// To copy "directory1/a/something.txt" to "directory2/b/copy.txt", + /// provide a destination_path of "directory2/b/copy.txt" + /// + pub destination_path: String, +} + +pub struct CopyPathTool { + project: Entity, +} + +impl CopyPathTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for CopyPathTool { + type Input = CopyPathToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "copy_path".into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Move + } + + fn initial_title(&self, input: Result) -> ui::SharedString { + if let Ok(input) = input { + let src = MarkdownInlineCode(&input.source_path); + let dest = MarkdownInlineCode(&input.destination_path); + format!("Copy {src} to {dest}").into() + } else { + "Copy path".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let copy_task = self.project.update(cx, |project, cx| { + match project + .find_project_path(&input.source_path, cx) + .and_then(|project_path| project.entry_for_path(&project_path, cx)) + { + Some(entity) => match project.find_project_path(&input.destination_path, cx) { + Some(project_path) => { + project.copy_entry(entity.id, None, project_path.path, cx) + } + None => Task::ready(Err(anyhow!( + "Destination path {} was outside the project.", + input.destination_path + ))), + }, + None => Task::ready(Err(anyhow!( + "Source path {} was not found in the project.", + input.source_path + ))), + } + }); + + cx.background_spawn(async move { + let _ = copy_task.await.with_context(|| { + format!( + "Copying {} to {}", + input.source_path, input.destination_path + ) + })?; + Ok(format!( + "Copied {} to {}", + input.source_path, input.destination_path + )) + }) + } +} diff --git a/crates/agent2/src/tools/create_directory_tool.rs b/crates/agent2/src/tools/create_directory_tool.rs new file mode 100644 index 0000000000..c173c5ae67 --- /dev/null +++ b/crates/agent2/src/tools/create_directory_tool.rs @@ -0,0 +1,89 @@ +use agent_client_protocol::ToolKind; +use anyhow::{Context as _, Result, anyhow}; +use gpui::{App, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use util::markdown::MarkdownInlineCode; + +use crate::{AgentTool, ToolCallEventStream}; + +/// Creates a new directory at the specified path within the project. Returns +/// confirmation that the directory was created. +/// +/// This tool creates a directory and all necessary parent directories (similar +/// to `mkdir -p`). It should be used whenever you need to create new +/// directories within the project. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct CreateDirectoryToolInput { + /// The path of the new directory. + /// + /// + /// If the project has the following structure: + /// + /// - directory1/ + /// - directory2/ + /// + /// You can create a new directory by providing a path of "directory1/new_directory" + /// + pub path: String, +} + +pub struct CreateDirectoryTool { + project: Entity, +} + +impl CreateDirectoryTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for CreateDirectoryTool { + type Input = CreateDirectoryToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "create_directory".into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Read + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + format!("Create directory {}", MarkdownInlineCode(&input.path)).into() + } else { + "Create directory".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let project_path = match self.project.read(cx).find_project_path(&input.path, cx) { + Some(project_path) => project_path, + None => { + return Task::ready(Err(anyhow!("Path to create was outside the project"))); + } + }; + let destination_path: Arc = input.path.as_str().into(); + + let create_entry = self.project.update(cx, |project, cx| { + project.create_entry(project_path.clone(), true, cx) + }); + + cx.spawn(async move |_cx| { + create_entry + .await + .with_context(|| format!("Creating directory {destination_path}"))?; + + Ok(format!("Created directory {destination_path}")) + }) + } +} diff --git a/crates/agent2/src/tools/delete_path_tool.rs b/crates/agent2/src/tools/delete_path_tool.rs new file mode 100644 index 0000000000..e013b3a3e7 --- /dev/null +++ b/crates/agent2/src/tools/delete_path_tool.rs @@ -0,0 +1,137 @@ +use crate::{AgentTool, ToolCallEventStream}; +use action_log::ActionLog; +use agent_client_protocol::ToolKind; +use anyhow::{Context as _, Result, anyhow}; +use futures::{SinkExt, StreamExt, channel::mpsc}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::{Project, ProjectPath}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Deletes the file or directory (and the directory's contents, recursively) at +/// the specified path in the project, and returns confirmation of the deletion. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct DeletePathToolInput { + /// The path of the file or directory to delete. + /// + /// + /// If the project has the following files: + /// + /// - directory1/a/something.txt + /// - directory2/a/things.txt + /// - directory3/a/other.txt + /// + /// You can delete the first file by providing a path of "directory1/a/something.txt" + /// + pub path: String, +} + +pub struct DeletePathTool { + project: Entity, + action_log: Entity, +} + +impl DeletePathTool { + pub fn new(project: Entity, action_log: Entity) -> Self { + Self { + project, + action_log, + } + } +} + +impl AgentTool for DeletePathTool { + type Input = DeletePathToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "delete_path".into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Delete + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + format!("Delete “`{}`”", input.path).into() + } else { + "Delete path".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let path = input.path; + let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else { + return Task::ready(Err(anyhow!( + "Couldn't delete {path} because that path isn't in this project." + ))); + }; + + let Some(worktree) = self + .project + .read(cx) + .worktree_for_id(project_path.worktree_id, cx) + else { + return Task::ready(Err(anyhow!( + "Couldn't delete {path} because that path isn't in this project." + ))); + }; + + let worktree_snapshot = worktree.read(cx).snapshot(); + let (mut paths_tx, mut paths_rx) = mpsc::channel(256); + cx.background_spawn({ + let project_path = project_path.clone(); + async move { + for entry in + worktree_snapshot.traverse_from_path(true, false, false, &project_path.path) + { + if !entry.path.starts_with(&project_path.path) { + break; + } + paths_tx + .send(ProjectPath { + worktree_id: project_path.worktree_id, + path: entry.path.clone(), + }) + .await?; + } + anyhow::Ok(()) + } + }) + .detach(); + + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |cx| { + while let Some(path) = paths_rx.next().await { + if let Ok(buffer) = project + .update(cx, |project, cx| project.open_buffer(path, cx))? + .await + { + action_log.update(cx, |action_log, cx| { + action_log.will_delete_buffer(buffer.clone(), cx) + })?; + } + } + + let deletion_task = project + .update(cx, |project, cx| { + project.delete_file(project_path, false, cx) + })? + .with_context(|| { + format!("Couldn't delete {path} because that path isn't in this project.") + })?; + deletion_task + .await + .with_context(|| format!("Deleting {path}"))?; + Ok(format!("Deleted {path}")) + }) + } +} diff --git a/crates/agent2/src/tools/diagnostics_tool.rs b/crates/agent2/src/tools/diagnostics_tool.rs new file mode 100644 index 0000000000..6ba8b7b377 --- /dev/null +++ b/crates/agent2/src/tools/diagnostics_tool.rs @@ -0,0 +1,163 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol as acp; +use anyhow::{Result, anyhow}; +use gpui::{App, Entity, Task}; +use language::{DiagnosticSeverity, OffsetRangeExt}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{fmt::Write, path::Path, sync::Arc}; +use ui::SharedString; +use util::markdown::MarkdownInlineCode; + +/// Get errors and warnings for the project or a specific file. +/// +/// This tool can be invoked after a series of edits to determine if further edits are necessary, or if the user asks to fix errors or warnings in their codebase. +/// +/// When a path is provided, shows all diagnostics for that specific file. +/// When no path is provided, shows a summary of error and warning counts for all files in the project. +/// +/// +/// To get diagnostics for a specific file: +/// { +/// "path": "src/main.rs" +/// } +/// +/// To get a project-wide diagnostic summary: +/// {} +/// +/// +/// +/// - If you think you can fix a diagnostic, make 1-2 attempts and then give up. +/// - Don't remove code you've generated just because you can't fix an error. The user can help you fix it. +/// +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct DiagnosticsToolInput { + /// The path to get diagnostics for. If not provided, returns a project-wide summary. + /// + /// This path should never be absolute, and the first component + /// of the path should always be a root directory in a project. + /// + /// + /// If the project has the following root directories: + /// + /// - lorem + /// - ipsum + /// + /// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`. + /// + pub path: Option, +} + +pub struct DiagnosticsTool { + project: Entity, +} + +impl DiagnosticsTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for DiagnosticsTool { + type Input = DiagnosticsToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "diagnostics".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Read + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Some(path) = input.ok().and_then(|input| match input.path { + Some(path) if !path.is_empty() => Some(path), + _ => None, + }) { + format!("Check diagnostics for {}", MarkdownInlineCode(&path)).into() + } else { + "Check project diagnostics".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + match input.path { + Some(path) if !path.is_empty() => { + let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else { + return Task::ready(Err(anyhow!("Could not find path {path} in project",))); + }; + + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)); + + cx.spawn(async move |cx| { + let mut output = String::new(); + let buffer = buffer.await?; + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + for (_, group) in snapshot.diagnostic_groups(None) { + let entry = &group.entries[group.primary_ix]; + let range = entry.range.to_point(&snapshot); + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + _ => continue, + }; + + writeln!( + output, + "{} at line {}: {}", + severity, + range.start.row + 1, + entry.diagnostic.message + )?; + } + + if output.is_empty() { + Ok("File doesn't have errors or warnings!".to_string()) + } else { + Ok(output) + } + }) + } + _ => { + let project = self.project.read(cx); + let mut output = String::new(); + let mut has_diagnostics = false; + + for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { + if summary.error_count > 0 || summary.warning_count > 0 { + let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) + else { + continue; + }; + + has_diagnostics = true; + output.push_str(&format!( + "{}: {} error(s), {} warning(s)\n", + Path::new(worktree.read(cx).root_name()) + .join(project_path.path) + .display(), + summary.error_count, + summary.warning_count + )); + } + } + + if has_diagnostics { + Task::ready(Ok(output)) + } else { + Task::ready(Ok("No errors or warnings found in the project.".into())) + } + } + } + } +} diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 0858bb501c..4b4f98daec 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -1,12 +1,13 @@ use crate::{AgentTool, Thread, ToolCallEventStream}; use acp_thread::Diff; -use agent_client_protocol as acp; -use anyhow::{anyhow, Context as _, Result}; +use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}; +use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; use gpui::{App, AppContext, AsyncApp, Entity, Task}; use indoc::formatdoc; +use language::ToPoint; use language::language_settings::{self, FormatOnSave}; use language_model::LanguageModelToolResultContent; use paths; @@ -133,7 +134,7 @@ impl EditFileTool { &self, input: &EditFileToolInput, event_stream: &ToolCallEventStream, - cx: &App, + cx: &mut App, ) -> Task> { if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { return Task::ready(Ok(())); @@ -147,8 +148,9 @@ impl EditFileTool { .components() .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) { - return cx.foreground_executor().spawn( - event_stream.authorize(format!("{} (local settings)", input.display_description)), + return event_stream.authorize( + format!("{} (local settings)", input.display_description), + cx, ); } @@ -156,9 +158,9 @@ impl EditFileTool { // so check for that edge case too. if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { if canonical_path.starts_with(paths::config_dir()) { - return cx.foreground_executor().spawn( - event_stream - .authorize(format!("{} (global settings)", input.display_description)), + return event_stream.authorize( + format!("{} (global settings)", input.display_description), + cx, ); } } @@ -173,8 +175,7 @@ impl EditFileTool { if project_path.is_some() { Task::ready(Ok(())) } else { - cx.foreground_executor() - .spawn(event_stream.authorize(input.display_description.clone())) + event_stream.authorize(&input.display_description, cx) } } } @@ -225,12 +226,22 @@ impl AgentTool for EditFileTool { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), }; + let abs_path = project.read(cx).absolute_path(&project_path, cx); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![acp::ToolCallLocation { + path: abs_path, + line: None, + }]), + ..Default::default() + }); + } let request = self.thread.update(cx, |thread, cx| { thread.build_completion_request(CompletionIntent::ToolResults, cx) }); let thread = self.thread.read(cx); - let model = thread.selected_model.clone(); + let model = thread.model().clone(); let action_log = thread.action_log().clone(); let authorize = self.authorize(&input, &event_stream, cx); @@ -283,13 +294,38 @@ impl AgentTool for EditFileTool { let mut hallucinated_old_text = false; let mut ambiguous_ranges = Vec::new(); + let mut emitted_location = false; while let Some(event) = events.next().await { match event { - EditAgentOutputEvent::Edited => {}, + EditAgentOutputEvent::Edited(range) => { + if !emitted_location { + let line = buffer.update(cx, |buffer, _cx| { + range.start.to_point(&buffer.snapshot()).row + }).ok(); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![ToolCallLocation { path: abs_path, line }]), + ..Default::default() + }); + } + emitted_location = true; + } + }, EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, EditAgentOutputEvent::ResolvingEditRange(range) => { - diff.update(cx, |card, cx| card.reveal_range(range, cx))?; + diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?; + // if !emitted_location { + // let line = buffer.update(cx, |buffer, _cx| { + // range.start.to_point(&buffer.snapshot()).row + // }).ok(); + // if let Some(abs_path) = abs_path.clone() { + // event_stream.update_fields(ToolCallUpdateFields { + // locations: Some(vec![ToolCallLocation { path: abs_path, line }]), + // ..Default::default() + // }); + // } + // } } } } @@ -454,10 +490,9 @@ fn resolve_path( #[cfg(test)] mod tests { - use crate::Templates; - use super::*; - use assistant_tool::ActionLog; + use crate::{ContextServerRegistry, Templates}; + use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; use gpui::{TestAppContext, UpdateGlobal}; @@ -475,9 +510,20 @@ mod tests { fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = - cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model)); + let thread = cx.new(|cx| { + Thread::new( + project, + Rc::default(), + context_server_registry, + action_log, + Templates::new(), + model, + cx, + ) + }); let result = cx .update(|cx| { let input = EditFileToolInput { @@ -661,14 +707,18 @@ mod tests { }); let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); @@ -792,15 +842,19 @@ mod tests { .unwrap(); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); @@ -914,15 +968,19 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -942,8 +1000,11 @@ mod tests { ) }); - let event = stream_rx.expect_tool_authorization().await; - assert_eq!(event.tool_call.title, "test 1 (local settings)"); + let event = stream_rx.expect_authorization().await; + assert_eq!( + event.tool_call.fields.title, + Some("test 1 (local settings)".into()) + ); // Test 2: Path outside project should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -959,8 +1020,8 @@ mod tests { ) }); - let event = stream_rx.expect_tool_authorization().await; - assert_eq!(event.tool_call.title, "test 2"); + let event = stream_rx.expect_authorization().await; + assert_eq!(event.tool_call.fields.title, Some("test 2".into())); // Test 3: Relative path without .zed should not require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -992,8 +1053,11 @@ mod tests { cx, ) }); - let event = stream_rx.expect_tool_authorization().await; - assert_eq!(event.tool_call.title, "test 4 (local settings)"); + let event = stream_rx.expect_authorization().await; + assert_eq!( + event.tool_call.fields.title, + Some("test 4 (local settings)".into()) + ); // Test 5: When always_allow_tool_actions is enabled, no confirmation needed cx.update(|cx| { @@ -1041,15 +1105,19 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1088,7 +1156,7 @@ mod tests { }); if should_confirm { - stream_rx.expect_tool_authorization().await; + stream_rx.expect_authorization().await; } else { auth.await.unwrap(); assert!( @@ -1148,14 +1216,18 @@ mod tests { .await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1192,7 +1264,7 @@ mod tests { }); if should_confirm { - stream_rx.expect_tool_authorization().await; + stream_rx.expect_authorization().await; } else { auth.await.unwrap(); assert!( @@ -1225,14 +1297,18 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1276,7 +1352,7 @@ mod tests { }); if should_confirm { - stream_rx.expect_tool_authorization().await; + stream_rx.expect_authorization().await; } else { auth.await.unwrap(); assert!( @@ -1305,14 +1381,18 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1339,7 +1419,7 @@ mod tests { ) }); - stream_rx.expect_tool_authorization().await; + stream_rx.expect_authorization().await; // Test outside path with different modes let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -1355,7 +1435,7 @@ mod tests { ) }); - stream_rx.expect_tool_authorization().await; + stream_rx.expect_authorization().await; // Test normal path with different modes let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -1382,14 +1462,18 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); diff --git a/crates/agent2/src/tools/fetch_tool.rs b/crates/agent2/src/tools/fetch_tool.rs new file mode 100644 index 0000000000..ae26c5fe19 --- /dev/null +++ b/crates/agent2/src/tools/fetch_tool.rs @@ -0,0 +1,155 @@ +use std::rc::Rc; +use std::sync::Arc; +use std::{borrow::Cow, cell::RefCell}; + +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, bail}; +use futures::AsyncReadExt as _; +use gpui::{App, AppContext as _, Task}; +use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; +use http_client::{AsyncBody, HttpClientWithUrl}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use ui::SharedString; +use util::markdown::MarkdownEscaped; + +use crate::{AgentTool, ToolCallEventStream}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +enum ContentType { + Html, + Plaintext, + Json, +} + +/// Fetches a URL and returns the content as Markdown. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FetchToolInput { + /// The URL to fetch. + url: String, +} + +pub struct FetchTool { + http_client: Arc, +} + +impl FetchTool { + pub fn new(http_client: Arc) -> Self { + Self { http_client } + } + + async fn build_message(http_client: Arc, url: &str) -> Result { + let url = if !url.starts_with("https://") && !url.starts_with("http://") { + Cow::Owned(format!("https://{url}")) + } else { + Cow::Borrowed(url) + }; + + let mut response = http_client.get(&url, AsyncBody::default(), true).await?; + + let mut body = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .context("error reading response body")?; + + if response.status().is_client_error() { + let text = String::from_utf8_lossy(body.as_slice()); + bail!( + "status error {}, response: {text:?}", + response.status().as_u16() + ); + } + + let Some(content_type) = response.headers().get("content-type") else { + bail!("missing Content-Type header"); + }; + let content_type = content_type + .to_str() + .context("invalid Content-Type header")?; + + let content_type = if content_type.starts_with("text/plain") { + ContentType::Plaintext + } else if content_type.starts_with("application/json") { + ContentType::Json + } else { + ContentType::Html + }; + + match content_type { + ContentType::Html => { + let mut handlers: Vec = vec![ + Rc::new(RefCell::new(markdown::WebpageChromeRemover)), + Rc::new(RefCell::new(markdown::ParagraphHandler)), + Rc::new(RefCell::new(markdown::HeadingHandler)), + Rc::new(RefCell::new(markdown::ListHandler)), + Rc::new(RefCell::new(markdown::TableHandler::new())), + Rc::new(RefCell::new(markdown::StyledTextHandler)), + ]; + if url.contains("wikipedia.org") { + use html_to_markdown::structure::wikipedia; + + handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover))); + handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler))); + handlers.push(Rc::new( + RefCell::new(wikipedia::WikipediaCodeHandler::new()), + )); + } else { + handlers.push(Rc::new(RefCell::new(markdown::CodeHandler))); + } + + convert_html_to_markdown(&body[..], &mut handlers) + } + ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()), + ContentType::Json => { + let json: serde_json::Value = serde_json::from_slice(&body)?; + + Ok(format!( + "```json\n{}\n```", + serde_json::to_string_pretty(&json)? + )) + } + } + } +} + +impl AgentTool for FetchTool { + type Input = FetchToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "fetch".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Fetch + } + + fn initial_title(&self, input: Result) -> SharedString { + match input { + Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(), + Err(_) => "Fetch URL".into(), + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let text = cx.background_spawn({ + let http_client = self.http_client.clone(); + async move { Self::build_message(http_client, &input.url).await } + }); + + cx.foreground_executor().spawn(async move { + let text = text.await?; + if text.trim().is_empty() { + bail!("no textual content found"); + } + Ok(text) + }) + } +} diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs index f4589e5600..552de144a7 100644 --- a/crates/agent2/src/tools/find_path_tool.rs +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -1,6 +1,6 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; -use anyhow::{anyhow, Result}; +use anyhow::{Result, anyhow}; use gpui::{App, AppContext, Entity, SharedString, Task}; use language_model::LanguageModelToolResultContent; use project::Project; @@ -139,9 +139,6 @@ impl AgentTool for FindPathTool { }) .collect(), ), - raw_output: Some(serde_json::json!({ - "paths": &matches, - })), ..Default::default() }); diff --git a/crates/agent2/src/tools/grep_tool.rs b/crates/agent2/src/tools/grep_tool.rs new file mode 100644 index 0000000000..e5d92b3c1d --- /dev/null +++ b/crates/agent2/src/tools/grep_tool.rs @@ -0,0 +1,1185 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol as acp; +use anyhow::{Result, anyhow}; +use futures::StreamExt; +use gpui::{App, Entity, SharedString, Task}; +use language::{OffsetRangeExt, ParseStatus, Point}; +use project::{ + Project, WorktreeSettings, + search::{SearchQuery, SearchResult}, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; +use std::{cmp, fmt::Write, sync::Arc}; +use util::RangeExt; +use util::markdown::MarkdownInlineCode; +use util::paths::PathMatcher; + +/// Searches the contents of files in the project with a regular expression +/// +/// - Prefer this tool to path search when searching for symbols in the project, because you won't need to guess what path it's in. +/// - Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.) +/// - Pass an `include_pattern` if you know how to narrow your search on the files system +/// - Never use this tool to search for paths. Only search file contents with this tool. +/// - Use this tool when you need to find files containing specific patterns +/// - Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages. +/// - DO NOT use HTML entities solely to escape characters in the tool parameters. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct GrepToolInput { + /// A regex pattern to search for in the entire project. Note that the regex + /// will be parsed by the Rust `regex` crate. + /// + /// Do NOT specify a path here! This will only be matched against the code **content**. + pub regex: String, + /// A glob pattern for the paths of files to include in the search. + /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts". + /// If omitted, all files in the project will be searched. + pub include_pattern: Option, + /// Optional starting position for paginated results (0-based). + /// When not provided, starts from the beginning. + #[serde(default)] + pub offset: u32, + /// Whether the regex is case-sensitive. Defaults to false (case-insensitive). + #[serde(default)] + pub case_sensitive: bool, +} + +impl GrepToolInput { + /// Which page of search results this is. + pub fn page(&self) -> u32 { + 1 + (self.offset / RESULTS_PER_PAGE) + } +} + +const RESULTS_PER_PAGE: u32 = 20; + +pub struct GrepTool { + project: Entity, +} + +impl GrepTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for GrepTool { + type Input = GrepToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "grep".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Search + } + + fn initial_title(&self, input: Result) -> SharedString { + match input { + Ok(input) => { + let page = input.page(); + let regex_str = MarkdownInlineCode(&input.regex); + let case_info = if input.case_sensitive { + " (case-sensitive)" + } else { + "" + }; + + if page > 1 { + format!("Get page {page} of search results for regex {regex_str}{case_info}") + } else { + format!("Search files for regex {regex_str}{case_info}") + } + } + Err(_) => "Search with regex".into(), + } + .into() + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + const CONTEXT_LINES: u32 = 2; + const MAX_ANCESTOR_LINES: u32 = 10; + + let include_matcher = match PathMatcher::new( + input + .include_pattern + .as_ref() + .into_iter() + .collect::>(), + ) { + Ok(matcher) => matcher, + Err(error) => { + return Task::ready(Err(anyhow!("invalid include glob pattern: {error}"))); + } + }; + + // Exclude global file_scan_exclusions and private_files settings + let exclude_matcher = { + let global_settings = WorktreeSettings::get_global(cx); + let exclude_patterns = global_settings + .file_scan_exclusions + .sources() + .iter() + .chain(global_settings.private_files.sources().iter()); + + match PathMatcher::new(exclude_patterns) { + Ok(matcher) => matcher, + Err(error) => { + return Task::ready(Err(anyhow!("invalid exclude pattern: {error}"))); + } + } + }; + + let query = match SearchQuery::regex( + &input.regex, + false, + input.case_sensitive, + false, + false, + include_matcher, + exclude_matcher, + true, // Always match file include pattern against *full project paths* that start with a project root. + None, + ) { + Ok(query) => query, + Err(error) => return Task::ready(Err(error)), + }; + + let results = self + .project + .update(cx, |project, cx| project.search(query, cx)); + + let project = self.project.downgrade(); + cx.spawn(async move |cx| { + futures::pin_mut!(results); + + let mut output = String::new(); + let mut skips_remaining = input.offset; + let mut matches_found = 0; + let mut has_more_matches = false; + + 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { + if ranges.is_empty() { + continue; + } + + let Ok((Some(path), mut parse_status)) = buffer.read_with(cx, |buffer, cx| { + (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status()) + }) else { + continue; + }; + + // Check if this file should be excluded based on its worktree settings + if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| { + project.find_project_path(&path, cx) + }) { + if cx.update(|cx| { + let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); + worktree_settings.is_path_excluded(&project_path.path) + || worktree_settings.is_path_private(&project_path.path) + }).unwrap_or(false) { + continue; + } + } + + while *parse_status.borrow() != ParseStatus::Idle { + parse_status.changed().await?; + } + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let mut ranges = ranges + .into_iter() + .map(|range| { + let matched = range.to_point(&snapshot); + let matched_end_line_len = snapshot.line_len(matched.end.row); + let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len); + let symbols = snapshot.symbols_containing(matched.start, None); + + if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) { + let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot); + let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES); + let end_col = snapshot.line_len(end_row); + let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col); + + if capped_ancestor_range.contains_inclusive(&full_lines) { + return (capped_ancestor_range, Some(full_ancestor_range), symbols) + } + } + + let mut matched = matched; + matched.start.column = 0; + matched.start.row = + matched.start.row.saturating_sub(CONTEXT_LINES); + matched.end.row = cmp::min( + snapshot.max_point().row, + matched.end.row + CONTEXT_LINES, + ); + matched.end.column = snapshot.line_len(matched.end.row); + + (matched, None, symbols) + }) + .peekable(); + + let mut file_header_written = false; + + while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){ + if skips_remaining > 0 { + skips_remaining -= 1; + continue; + } + + // We'd already found a full page of matches, and we just found one more. + if matches_found >= RESULTS_PER_PAGE { + has_more_matches = true; + break 'outer; + } + + while let Some((next_range, _, _)) = ranges.peek() { + if range.end.row >= next_range.start.row { + range.end = next_range.end; + ranges.next(); + } else { + break; + } + } + + if !file_header_written { + writeln!(output, "\n## Matches in {}", path.display())?; + file_header_written = true; + } + + let end_row = range.end.row; + output.push_str("\n### "); + + if let Some(parent_symbols) = &parent_symbols { + for symbol in parent_symbols { + write!(output, "{} › ", symbol.text)?; + } + } + + if range.start.row == end_row { + writeln!(output, "L{}", range.start.row + 1)?; + } else { + writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?; + } + + output.push_str("```\n"); + output.extend(snapshot.text_for_range(range)); + output.push_str("\n```\n"); + + if let Some(ancestor_range) = ancestor_range { + if end_row < ancestor_range.end.row { + let remaining_lines = ancestor_range.end.row - end_row; + writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?; + } + } + + matches_found += 1; + } + } + + if matches_found == 0 { + Ok("No matches found".into()) + } else if has_more_matches { + Ok(format!( + "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}", + input.offset + 1, + input.offset + matches_found, + input.offset + RESULTS_PER_PAGE, + )) + } else { + Ok(format!("Found {matches_found} matches:\n{output}")) + } + }) + } +} + +#[cfg(test)] +mod tests { + use crate::ToolCallEventStream; + + use super::*; + use gpui::{TestAppContext, UpdateGlobal}; + use language::{Language, LanguageConfig, LanguageMatcher}; + use project::{FakeFs, Project, WorktreeSettings}; + use serde_json::json; + use settings::SettingsStore; + use unindent::Unindent; + use util::path; + + #[gpui::test] + async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) { + init_test(cx); + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor().clone()); + fs.insert_tree( + path!("/root"), + serde_json::json!({ + "src": { + "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}", + "utils": { + "helper.rs": "fn helper() {\n println!(\"I'm a helper!\");\n}", + }, + }, + "tests": { + "test_main.rs": "fn test_main() {\n assert!(true);\n}", + } + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + // Test with include pattern for Rust files inside the root of the project + let input = GrepToolInput { + regex: "println".to_string(), + include_pattern: Some("root/**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + assert!(result.contains("main.rs"), "Should find matches in main.rs"); + assert!( + result.contains("helper.rs"), + "Should find matches in helper.rs" + ); + assert!( + !result.contains("test_main.rs"), + "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)" + ); + + // Test with include pattern for src directory only + let input = GrepToolInput { + regex: "fn".to_string(), + include_pattern: Some("root/**/src/**".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + assert!( + result.contains("main.rs"), + "Should find matches in src/main.rs" + ); + assert!( + result.contains("helper.rs"), + "Should find matches in src/utils/helper.rs" + ); + assert!( + !result.contains("test_main.rs"), + "Should not include test_main.rs as it's not in src directory" + ); + + // Test with empty include pattern (should default to all files) + let input = GrepToolInput { + regex: "fn".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + assert!(result.contains("main.rs"), "Should find matches in main.rs"); + assert!( + result.contains("helper.rs"), + "Should find matches in helper.rs" + ); + assert!( + result.contains("test_main.rs"), + "Should include test_main.rs" + ); + } + + #[gpui::test] + async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) { + init_test(cx); + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor().clone()); + fs.insert_tree( + path!("/root"), + serde_json::json!({ + "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true", + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + // Test case-insensitive search (default) + let input = GrepToolInput { + regex: "uppercase".to_string(), + include_pattern: Some("**/*.txt".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + assert!( + result.contains("UPPERCASE"), + "Case-insensitive search should match uppercase" + ); + + // Test case-sensitive search + let input = GrepToolInput { + regex: "uppercase".to_string(), + include_pattern: Some("**/*.txt".to_string()), + offset: 0, + case_sensitive: true, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + assert!( + !result.contains("UPPERCASE"), + "Case-sensitive search should not match uppercase" + ); + + // Test case-sensitive search + let input = GrepToolInput { + regex: "LOWERCASE".to_string(), + include_pattern: Some("**/*.txt".to_string()), + offset: 0, + case_sensitive: true, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + + assert!( + !result.contains("lowercase"), + "Case-sensitive search should match lowercase" + ); + + // Test case-sensitive search for lowercase pattern + let input = GrepToolInput { + regex: "lowercase".to_string(), + include_pattern: Some("**/*.txt".to_string()), + offset: 0, + case_sensitive: true, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + assert!( + result.contains("lowercase"), + "Case-sensitive search should match lowercase text" + ); + } + + /// Helper function to set up a syntax test environment + async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity { + use unindent::Unindent; + init_test(cx); + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor().clone()); + + // Create test file with syntax structures + fs.insert_tree( + path!("/root"), + serde_json::json!({ + "test_syntax.rs": r#" + fn top_level_function() { + println!("This is at the top level"); + } + + mod feature_module { + pub mod nested_module { + pub fn nested_function( + first_arg: String, + second_arg: i32, + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + } + } + + struct MyStruct { + field1: String, + field2: i32, + } + + impl MyStruct { + fn method_with_block() { + let condition = true; + if condition { + println!("Inside if block"); + } + } + + fn long_function() { + println!("Line 1"); + println!("Line 2"); + println!("Line 3"); + println!("Line 4"); + println!("Line 5"); + println!("Line 6"); + println!("Line 7"); + println!("Line 8"); + println!("Line 9"); + println!("Line 10"); + println!("Line 11"); + println!("Line 12"); + } + } + + trait Processor { + fn process(&self, input: &str) -> String; + } + + impl Processor for MyStruct { + fn process(&self, input: &str) -> String { + format!("Processed: {}", input) + } + } + "#.unindent().trim(), + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + project.update(cx, |project, _cx| { + project.languages().add(rust_lang().into()) + }); + + project + } + + #[gpui::test] + async fn test_grep_top_level_function(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line at the top level of the file + let input = GrepToolInput { + regex: "This is at the top level".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### fn top_level_function › L1-3 + ``` + fn top_level_function() { + println!("This is at the top level"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_function_body(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line inside a function body + let input = GrepToolInput { + regex: "Function in nested module".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14 + ``` + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_function_args_and_body(cx: &mut TestAppContext) { + let project = setup_syntax_test(cx).await; + + // Test: Line with a function argument + let input = GrepToolInput { + regex: "second_arg".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14 + ``` + pub fn nested_function( + first_arg: String, + second_arg: i32, + ) { + println!("Function in nested module"); + println!("{first_arg}"); + println!("{second_arg}"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_if_block(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line inside an if block + let input = GrepToolInput { + regex: "Inside if block".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn method_with_block › L26-28 + ``` + if condition { + println!("Inside if block"); + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_long_function_top(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line in the middle of a long function - should show message about remaining lines + let input = GrepToolInput { + regex: "Line 5".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn long_function › L31-41 + ``` + fn long_function() { + println!("Line 1"); + println!("Line 2"); + println!("Line 3"); + println!("Line 4"); + println!("Line 5"); + println!("Line 6"); + println!("Line 7"); + println!("Line 8"); + println!("Line 9"); + println!("Line 10"); + ``` + + 3 lines remaining in ancestor node. Read the file to see all. + "# + .unindent(); + assert_eq!(result, expected); + } + + #[gpui::test] + async fn test_grep_long_function_bottom(cx: &mut TestAppContext) { + use unindent::Unindent; + let project = setup_syntax_test(cx).await; + + // Test: Line in the long function + let input = GrepToolInput { + regex: "Line 12".to_string(), + include_pattern: Some("**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; + + let result = run_grep_tool(input, project.clone(), cx).await; + let expected = r#" + Found 1 matches: + + ## Matches in root/test_syntax.rs + + ### impl MyStruct › fn long_function › L41-45 + ``` + println!("Line 10"); + println!("Line 11"); + println!("Line 12"); + } + } + ``` + "# + .unindent(); + assert_eq!(result, expected); + } + + async fn run_grep_tool( + input: GrepToolInput, + project: Entity, + cx: &mut TestAppContext, + ) -> String { + let tool = Arc::new(GrepTool { project }); + let task = cx.update(|cx| tool.run(input, ToolCallEventStream::test().0, cx)); + + match task.await { + Ok(result) => { + if cfg!(windows) { + result.replace("root\\", "root/") + } else { + result.to_string() + } + } + Err(e) => panic!("Failed to run grep tool: {}", e), + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../../languages/src/rust/outline.scm")) + .unwrap() + } + + #[gpui::test] + async fn test_grep_security_boundaries(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + path!("/"), + json!({ + "project_root": { + "allowed_file.rs": "fn main() { println!(\"This file is in the project\"); }", + ".mysecrets": "SECRET_KEY=abc123\nfn secret() { /* private */ }", + ".secretdir": { + "config": "fn special_configuration() { /* excluded */ }" + }, + ".mymetadata": "fn custom_metadata() { /* excluded */ }", + "subdir": { + "normal_file.rs": "fn normal_file_content() { /* Normal */ }", + "special.privatekey": "fn private_key_content() { /* private */ }", + "data.mysensitive": "fn sensitive_data() { /* private */ }" + } + }, + "outside_project": { + "sensitive_file.rs": "fn outside_function() { /* This file is outside the project */ }" + } + }), + ) + .await; + + cx.update(|cx| { + use gpui::UpdateGlobal; + use project::WorktreeSettings; + use settings::SettingsStore; + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = Some(vec![ + "**/.secretdir".to_string(), + "**/.mymetadata".to_string(), + ]); + settings.private_files = Some(vec![ + "**/.mysecrets".to_string(), + "**/*.privatekey".to_string(), + "**/*.mysensitive".to_string(), + ]); + }); + }); + }); + + let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; + + // Searching for files outside the project worktree should return no results + let result = run_grep_tool( + GrepToolInput { + regex: "outside_function".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.is_empty(), + "grep_tool should not find files outside the project worktree" + ); + + // Searching within the project should succeed + let result = run_grep_tool( + GrepToolInput { + regex: "main".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.iter().any(|p| p.contains("allowed_file.rs")), + "grep_tool should be able to search files inside worktrees" + ); + + // Searching files that match file_scan_exclusions should return no results + let result = run_grep_tool( + GrepToolInput { + regex: "special_configuration".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.is_empty(), + "grep_tool should not search files in .secretdir (file_scan_exclusions)" + ); + + let result = run_grep_tool( + GrepToolInput { + regex: "custom_metadata".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.is_empty(), + "grep_tool should not search .mymetadata files (file_scan_exclusions)" + ); + + // Searching private files should return no results + let result = run_grep_tool( + GrepToolInput { + regex: "SECRET_KEY".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.is_empty(), + "grep_tool should not search .mysecrets (private_files)" + ); + + let result = run_grep_tool( + GrepToolInput { + regex: "private_key_content".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + + assert!( + paths.is_empty(), + "grep_tool should not search .privatekey files (private_files)" + ); + + let result = run_grep_tool( + GrepToolInput { + regex: "sensitive_data".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.is_empty(), + "grep_tool should not search .mysensitive files (private_files)" + ); + + // Searching a normal file should still work, even with private_files configured + let result = run_grep_tool( + GrepToolInput { + regex: "normal_file_content".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.iter().any(|p| p.contains("normal_file.rs")), + "Should be able to search normal files" + ); + + // Path traversal attempts with .. in include_pattern should not escape project + let result = run_grep_tool( + GrepToolInput { + regex: "outside_function".to_string(), + include_pattern: Some("../outside_project/**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + assert!( + paths.is_empty(), + "grep_tool should not allow escaping project boundaries with relative paths" + ); + } + + #[gpui::test] + async fn test_grep_with_multiple_worktree_settings(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + // Create first worktree with its own private files + fs.insert_tree( + path!("/worktree1"), + json!({ + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/fixture.*"], + "private_files": ["**/secret.rs"] + }"# + }, + "src": { + "main.rs": "fn main() { let secret_key = \"hidden\"; }", + "secret.rs": "const API_KEY: &str = \"secret_value\";", + "utils.rs": "pub fn get_config() -> String { \"config\".to_string() }" + }, + "tests": { + "test.rs": "fn test_secret() { assert!(true); }", + "fixture.sql": "SELECT * FROM secret_table;" + } + }), + ) + .await; + + // Create second worktree with different private files + fs.insert_tree( + path!("/worktree2"), + json!({ + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/internal.*"], + "private_files": ["**/private.js", "**/data.json"] + }"# + }, + "lib": { + "public.js": "export function getSecret() { return 'public'; }", + "private.js": "const SECRET_KEY = \"private_value\";", + "data.json": "{\"secret_data\": \"hidden\"}" + }, + "docs": { + "README.md": "# Documentation with secret info", + "internal.md": "Internal secret documentation" + } + }), + ) + .await; + + // Set global settings + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = + Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); + settings.private_files = Some(vec!["**/.env".to_string()]); + }); + }); + }); + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + + // Wait for worktrees to be fully scanned + cx.executor().run_until_parked(); + + // Search for "secret" - should exclude files based on worktree-specific settings + let result = run_grep_tool( + GrepToolInput { + regex: "secret".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + let paths = extract_paths_from_results(&result); + + // Should find matches in non-private files + assert!( + paths.iter().any(|p| p.contains("main.rs")), + "Should find 'secret' in worktree1/src/main.rs" + ); + assert!( + paths.iter().any(|p| p.contains("test.rs")), + "Should find 'secret' in worktree1/tests/test.rs" + ); + assert!( + paths.iter().any(|p| p.contains("public.js")), + "Should find 'secret' in worktree2/lib/public.js" + ); + assert!( + paths.iter().any(|p| p.contains("README.md")), + "Should find 'secret' in worktree2/docs/README.md" + ); + + // Should NOT find matches in private/excluded files based on worktree settings + assert!( + !paths.iter().any(|p| p.contains("secret.rs")), + "Should not search in worktree1/src/secret.rs (local private_files)" + ); + assert!( + !paths.iter().any(|p| p.contains("fixture.sql")), + "Should not search in worktree1/tests/fixture.sql (local file_scan_exclusions)" + ); + assert!( + !paths.iter().any(|p| p.contains("private.js")), + "Should not search in worktree2/lib/private.js (local private_files)" + ); + assert!( + !paths.iter().any(|p| p.contains("data.json")), + "Should not search in worktree2/lib/data.json (local private_files)" + ); + assert!( + !paths.iter().any(|p| p.contains("internal.md")), + "Should not search in worktree2/docs/internal.md (local file_scan_exclusions)" + ); + + // Test with `include_pattern` specific to one worktree + let result = run_grep_tool( + GrepToolInput { + regex: "secret".to_string(), + include_pattern: Some("worktree1/**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }, + project.clone(), + cx, + ) + .await; + + let paths = extract_paths_from_results(&result); + + // Should only find matches in worktree1 *.rs files (excluding private ones) + assert!( + paths.iter().any(|p| p.contains("main.rs")), + "Should find match in worktree1/src/main.rs" + ); + assert!( + paths.iter().any(|p| p.contains("test.rs")), + "Should find match in worktree1/tests/test.rs" + ); + assert!( + !paths.iter().any(|p| p.contains("secret.rs")), + "Should not find match in excluded worktree1/src/secret.rs" + ); + assert!( + paths.iter().all(|p| !p.contains("worktree2")), + "Should not find any matches in worktree2" + ); + } + + // Helper function to extract file paths from grep results + fn extract_paths_from_results(results: &str) -> Vec { + results + .lines() + .filter(|line| line.starts_with("## Matches in ")) + .map(|line| { + line.strip_prefix("## Matches in ") + .unwrap() + .trim() + .to_string() + }) + .collect() + } +} diff --git a/crates/agent2/src/tools/list_directory_tool.rs b/crates/agent2/src/tools/list_directory_tool.rs new file mode 100644 index 0000000000..61f21d8f95 --- /dev/null +++ b/crates/agent2/src/tools/list_directory_tool.rs @@ -0,0 +1,664 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol::ToolKind; +use anyhow::{Result, anyhow}; +use gpui::{App, Entity, SharedString, Task}; +use project::{Project, WorktreeSettings}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; +use std::fmt::Write; +use std::{path::Path, sync::Arc}; +use util::markdown::MarkdownInlineCode; + +/// Lists files and directories in a given path. Prefer the `grep` or +/// `find_path` tools when searching the codebase. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ListDirectoryToolInput { + /// The fully-qualified path of the directory to list in the project. + /// + /// This path should never be absolute, and the first component + /// of the path should always be a root directory in a project. + /// + /// + /// If the project has the following root directories: + /// + /// - directory1 + /// - directory2 + /// + /// You can list the contents of `directory1` by using the path `directory1`. + /// + /// + /// + /// If the project has the following root directories: + /// + /// - foo + /// - bar + /// + /// If you wanna list contents in the directory `foo/baz`, you should use the path `foo/baz`. + /// + pub path: String, +} + +pub struct ListDirectoryTool { + project: Entity, +} + +impl ListDirectoryTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for ListDirectoryTool { + type Input = ListDirectoryToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "list_directory".into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Read + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + let path = MarkdownInlineCode(&input.path); + format!("List the {path} directory's contents").into() + } else { + "List directory".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + // Sometimes models will return these even though we tell it to give a path and not a glob. + // When this happens, just list the root worktree directories. + if matches!(input.path.as_str(), "." | "" | "./" | "*") { + let output = self + .project + .read(cx) + .worktrees(cx) + .filter_map(|worktree| { + worktree.read(cx).root_entry().and_then(|entry| { + if entry.is_dir() { + entry.path.to_str() + } else { + None + } + }) + }) + .collect::>() + .join("\n"); + + return Task::ready(Ok(output)); + } + + let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { + return Task::ready(Err(anyhow!("Path {} not found in project", input.path))); + }; + let Some(worktree) = self + .project + .read(cx) + .worktree_for_id(project_path.worktree_id, cx) + else { + return Task::ready(Err(anyhow!("Worktree not found"))); + }; + + // Check if the directory whose contents we're listing is itself excluded or private + let global_settings = WorktreeSettings::get_global(cx); + if global_settings.is_path_excluded(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot list directory because its path matches the user's global `file_scan_exclusions` setting: {}", + &input.path + ))); + } + + if global_settings.is_path_private(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot list directory because its path matches the user's global `private_files` setting: {}", + &input.path + ))); + } + + let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); + if worktree_settings.is_path_excluded(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot list directory because its path matches the user's worktree`file_scan_exclusions` setting: {}", + &input.path + ))); + } + + if worktree_settings.is_path_private(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot list directory because its path matches the user's worktree `private_paths` setting: {}", + &input.path + ))); + } + + let worktree_snapshot = worktree.read(cx).snapshot(); + let worktree_root_name = worktree.read(cx).root_name().to_string(); + + let Some(entry) = worktree_snapshot.entry_for_path(&project_path.path) else { + return Task::ready(Err(anyhow!("Path not found: {}", input.path))); + }; + + if !entry.is_dir() { + return Task::ready(Err(anyhow!("{} is not a directory.", input.path))); + } + let worktree_snapshot = worktree.read(cx).snapshot(); + + let mut folders = Vec::new(); + let mut files = Vec::new(); + + for entry in worktree_snapshot.child_entries(&project_path.path) { + // Skip private and excluded files and directories + if global_settings.is_path_private(&entry.path) + || global_settings.is_path_excluded(&entry.path) + { + continue; + } + + if self + .project + .read(cx) + .find_project_path(&entry.path, cx) + .map(|project_path| { + let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); + + worktree_settings.is_path_excluded(&project_path.path) + || worktree_settings.is_path_private(&project_path.path) + }) + .unwrap_or(false) + { + continue; + } + + let full_path = Path::new(&worktree_root_name) + .join(&entry.path) + .display() + .to_string(); + if entry.is_dir() { + folders.push(full_path); + } else { + files.push(full_path); + } + } + + let mut output = String::new(); + + if !folders.is_empty() { + writeln!(output, "# Folders:\n{}", folders.join("\n")).unwrap(); + } + + if !files.is_empty() { + writeln!(output, "\n# Files:\n{}", files.join("\n")).unwrap(); + } + + if output.is_empty() { + writeln!(output, "{} is empty.", input.path).unwrap(); + } + + Task::ready(Ok(output)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{TestAppContext, UpdateGlobal}; + use indoc::indoc; + use project::{FakeFs, Project, WorktreeSettings}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + fn platform_paths(path_str: &str) -> String { + if cfg!(target_os = "windows") { + path_str.replace("/", "\\") + } else { + path_str.to_string() + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + + #[gpui::test] + async fn test_list_directory_separates_files_and_dirs(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "main.rs": "fn main() {}", + "lib.rs": "pub fn hello() {}", + "models": { + "user.rs": "struct User {}", + "post.rs": "struct Post {}" + }, + "utils": { + "helper.rs": "pub fn help() {}" + } + }, + "tests": { + "integration_test.rs": "#[test] fn test() {}" + }, + "README.md": "# Project", + "Cargo.toml": "[package]" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let tool = Arc::new(ListDirectoryTool::new(project)); + + // Test listing root directory + let input = ListDirectoryToolInput { + path: "project".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert_eq!( + output, + platform_paths(indoc! {" + # Folders: + project/src + project/tests + + # Files: + project/Cargo.toml + project/README.md + "}) + ); + + // Test listing src directory + let input = ListDirectoryToolInput { + path: "project/src".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert_eq!( + output, + platform_paths(indoc! {" + # Folders: + project/src/models + project/src/utils + + # Files: + project/src/lib.rs + project/src/main.rs + "}) + ); + + // Test listing directory with only files + let input = ListDirectoryToolInput { + path: "project/tests".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert!(!output.contains("# Folders:")); + assert!(output.contains("# Files:")); + assert!(output.contains(&platform_paths("project/tests/integration_test.rs"))); + } + + #[gpui::test] + async fn test_list_directory_empty_directory(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "empty_dir": {} + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let tool = Arc::new(ListDirectoryTool::new(project)); + + let input = ListDirectoryToolInput { + path: "project/empty_dir".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert_eq!(output, "project/empty_dir is empty.\n"); + } + + #[gpui::test] + async fn test_list_directory_error_cases(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "file.txt": "content" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let tool = Arc::new(ListDirectoryTool::new(project)); + + // Test non-existent path + let input = ListDirectoryToolInput { + path: "project/nonexistent".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await; + assert!(output.unwrap_err().to_string().contains("Path not found")); + + // Test trying to list a file instead of directory + let input = ListDirectoryToolInput { + path: "project/file.txt".into(), + }; + let output = cx + .update(|cx| tool.run(input, ToolCallEventStream::test().0, cx)) + .await; + assert!( + output + .unwrap_err() + .to_string() + .contains("is not a directory") + ); + } + + #[gpui::test] + async fn test_list_directory_security(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "normal_dir": { + "file1.txt": "content", + "file2.txt": "content" + }, + ".mysecrets": "SECRET_KEY=abc123", + ".secretdir": { + "config": "special configuration", + "secret.txt": "secret content" + }, + ".mymetadata": "custom metadata", + "visible_dir": { + "normal.txt": "normal content", + "special.privatekey": "private key content", + "data.mysensitive": "sensitive data", + ".hidden_subdir": { + "hidden_file.txt": "hidden content" + } + } + }), + ) + .await; + + // Configure settings explicitly + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = Some(vec![ + "**/.secretdir".to_string(), + "**/.mymetadata".to_string(), + "**/.hidden_subdir".to_string(), + ]); + settings.private_files = Some(vec![ + "**/.mysecrets".to_string(), + "**/*.privatekey".to_string(), + "**/*.mysensitive".to_string(), + ]); + }); + }); + }); + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let tool = Arc::new(ListDirectoryTool::new(project)); + + // Listing root directory should exclude private and excluded files + let input = ListDirectoryToolInput { + path: "project".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + + // Should include normal directories + assert!(output.contains("normal_dir"), "Should list normal_dir"); + assert!(output.contains("visible_dir"), "Should list visible_dir"); + + // Should NOT include excluded or private files + assert!( + !output.contains(".secretdir"), + "Should not list .secretdir (file_scan_exclusions)" + ); + assert!( + !output.contains(".mymetadata"), + "Should not list .mymetadata (file_scan_exclusions)" + ); + assert!( + !output.contains(".mysecrets"), + "Should not list .mysecrets (private_files)" + ); + + // Trying to list an excluded directory should fail + let input = ListDirectoryToolInput { + path: "project/.secretdir".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await; + assert!( + output + .unwrap_err() + .to_string() + .contains("file_scan_exclusions"), + "Error should mention file_scan_exclusions" + ); + + // Listing a directory should exclude private files within it + let input = ListDirectoryToolInput { + path: "project/visible_dir".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + + // Should include normal files + assert!(output.contains("normal.txt"), "Should list normal.txt"); + + // Should NOT include private files + assert!( + !output.contains("privatekey"), + "Should not list .privatekey files (private_files)" + ); + assert!( + !output.contains("mysensitive"), + "Should not list .mysensitive files (private_files)" + ); + + // Should NOT include subdirectories that match exclusions + assert!( + !output.contains(".hidden_subdir"), + "Should not list .hidden_subdir (file_scan_exclusions)" + ); + } + + #[gpui::test] + async fn test_list_directory_with_multiple_worktree_settings(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + // Create first worktree with its own private files + fs.insert_tree( + path!("/worktree1"), + json!({ + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/fixture.*"], + "private_files": ["**/secret.rs", "**/config.toml"] + }"# + }, + "src": { + "main.rs": "fn main() { println!(\"Hello from worktree1\"); }", + "secret.rs": "const API_KEY: &str = \"secret_key_1\";", + "config.toml": "[database]\nurl = \"postgres://localhost/db1\"" + }, + "tests": { + "test.rs": "mod tests { fn test_it() {} }", + "fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));" + } + }), + ) + .await; + + // Create second worktree with different private files + fs.insert_tree( + path!("/worktree2"), + json!({ + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/internal.*"], + "private_files": ["**/private.js", "**/data.json"] + }"# + }, + "lib": { + "public.js": "export function greet() { return 'Hello from worktree2'; }", + "private.js": "const SECRET_TOKEN = \"private_token_2\";", + "data.json": "{\"api_key\": \"json_secret_key\"}" + }, + "docs": { + "README.md": "# Public Documentation", + "internal.md": "# Internal Secrets and Configuration" + } + }), + ) + .await; + + // Set global settings + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = + Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); + settings.private_files = Some(vec!["**/.env".to_string()]); + }); + }); + }); + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + + // Wait for worktrees to be fully scanned + cx.executor().run_until_parked(); + + let tool = Arc::new(ListDirectoryTool::new(project)); + + // Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings + let input = ListDirectoryToolInput { + path: "worktree1/src".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert!(output.contains("main.rs"), "Should list main.rs"); + assert!( + !output.contains("secret.rs"), + "Should not list secret.rs (local private_files)" + ); + assert!( + !output.contains("config.toml"), + "Should not list config.toml (local private_files)" + ); + + // Test listing worktree1/tests - should exclude fixture.sql based on local settings + let input = ListDirectoryToolInput { + path: "worktree1/tests".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert!(output.contains("test.rs"), "Should list test.rs"); + assert!( + !output.contains("fixture.sql"), + "Should not list fixture.sql (local file_scan_exclusions)" + ); + + // Test listing worktree2/lib - should exclude private.js and data.json based on local settings + let input = ListDirectoryToolInput { + path: "worktree2/lib".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert!(output.contains("public.js"), "Should list public.js"); + assert!( + !output.contains("private.js"), + "Should not list private.js (local private_files)" + ); + assert!( + !output.contains("data.json"), + "Should not list data.json (local private_files)" + ); + + // Test listing worktree2/docs - should exclude internal.md based on local settings + let input = ListDirectoryToolInput { + path: "worktree2/docs".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await + .unwrap(); + assert!(output.contains("README.md"), "Should list README.md"); + assert!( + !output.contains("internal.md"), + "Should not list internal.md (local file_scan_exclusions)" + ); + + // Test trying to list an excluded directory directly + let input = ListDirectoryToolInput { + path: "worktree1/src/secret.rs".into(), + }; + let output = cx + .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .await; + assert!( + output + .unwrap_err() + .to_string() + .contains("Cannot list directory"), + ); + } +} diff --git a/crates/agent2/src/tools/move_path_tool.rs b/crates/agent2/src/tools/move_path_tool.rs new file mode 100644 index 0000000000..f8d5d0d176 --- /dev/null +++ b/crates/agent2/src/tools/move_path_tool.rs @@ -0,0 +1,123 @@ +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol::ToolKind; +use anyhow::{Context as _, Result, anyhow}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{path::Path, sync::Arc}; +use util::markdown::MarkdownInlineCode; + +/// Moves or rename a file or directory in the project, and returns confirmation +/// that the move succeeded. +/// +/// If the source and destination directories are the same, but the filename is +/// different, this performs a rename. Otherwise, it performs a move. +/// +/// This tool should be used when it's desirable to move or rename a file or +/// directory without changing its contents at all. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct MovePathToolInput { + /// The source path of the file or directory to move/rename. + /// + /// + /// If the project has the following files: + /// + /// - directory1/a/something.txt + /// - directory2/a/things.txt + /// - directory3/a/other.txt + /// + /// You can move the first file by providing a source_path of "directory1/a/something.txt" + /// + pub source_path: String, + + /// The destination path where the file or directory should be moved/renamed to. + /// If the paths are the same except for the filename, then this will be a rename. + /// + /// + /// To move "directory1/a/something.txt" to "directory2/b/renamed.txt", + /// provide a destination_path of "directory2/b/renamed.txt" + /// + pub destination_path: String, +} + +pub struct MovePathTool { + project: Entity, +} + +impl MovePathTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for MovePathTool { + type Input = MovePathToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "move_path".into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Move + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + let src = MarkdownInlineCode(&input.source_path); + let dest = MarkdownInlineCode(&input.destination_path); + let src_path = Path::new(&input.source_path); + let dest_path = Path::new(&input.destination_path); + + match dest_path + .file_name() + .and_then(|os_str| os_str.to_os_string().into_string().ok()) + { + Some(filename) if src_path.parent() == dest_path.parent() => { + let filename = MarkdownInlineCode(&filename); + format!("Rename {src} to {filename}").into() + } + _ => format!("Move {src} to {dest}").into(), + } + } else { + "Move path".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let rename_task = self.project.update(cx, |project, cx| { + match project + .find_project_path(&input.source_path, cx) + .and_then(|project_path| project.entry_for_path(&project_path, cx)) + { + Some(entity) => match project.find_project_path(&input.destination_path, cx) { + Some(project_path) => project.rename_entry(entity.id, project_path.path, cx), + None => Task::ready(Err(anyhow!( + "Destination path {} was outside the project.", + input.destination_path + ))), + }, + None => Task::ready(Err(anyhow!( + "Source path {} was not found in the project.", + input.source_path + ))), + } + }); + + cx.background_spawn(async move { + let _ = rename_task.await.with_context(|| { + format!("Moving {} to {}", input.source_path, input.destination_path) + })?; + Ok(format!( + "Moved {} to {}", + input.source_path, input.destination_path + )) + }) + } +} diff --git a/crates/agent2/src/tools/now_tool.rs b/crates/agent2/src/tools/now_tool.rs new file mode 100644 index 0000000000..a72ede26fe --- /dev/null +++ b/crates/agent2/src/tools/now_tool.rs @@ -0,0 +1,59 @@ +use std::sync::Arc; + +use agent_client_protocol as acp; +use anyhow::Result; +use chrono::{Local, Utc}; +use gpui::{App, SharedString, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::{AgentTool, ToolCallEventStream}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum Timezone { + /// Use UTC for the datetime. + Utc, + /// Use local time for the datetime. + Local, +} + +/// Returns the current datetime in RFC 3339 format. +/// Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct NowToolInput { + /// The timezone to use for the datetime. + timezone: Timezone, +} + +pub struct NowTool; + +impl AgentTool for NowTool { + type Input = NowToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "now".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title(&self, _input: Result) -> SharedString { + "Get current time".into() + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + let now = match input.timezone { + Timezone::Utc => Utc::now().to_rfc3339(), + Timezone::Local => Local::now().to_rfc3339(), + }; + Task::ready(Ok(format!("The current datetime is {now}."))) + } +} diff --git a/crates/agent2/src/tools/open_tool.rs b/crates/agent2/src/tools/open_tool.rs new file mode 100644 index 0000000000..36420560c1 --- /dev/null +++ b/crates/agent2/src/tools/open_tool.rs @@ -0,0 +1,170 @@ +use crate::AgentTool; +use agent_client_protocol::ToolKind; +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{path::PathBuf, sync::Arc}; +use util::markdown::MarkdownEscaped; + +/// This tool opens a file or URL with the default application associated with +/// it on the user's operating system: +/// +/// - On macOS, it's equivalent to the `open` command +/// - On Windows, it's equivalent to `start` +/// - On Linux, it uses something like `xdg-open`, `gio open`, `gnome-open`, `kde-open`, `wslview` as appropriate +/// +/// For example, it can open a web browser with a URL, open a PDF file with the +/// default PDF viewer, etc. +/// +/// You MUST ONLY use this tool when the user has explicitly requested opening +/// something. You MUST NEVER assume that the user would like for you to use +/// this tool. +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct OpenToolInput { + /// The path or URL to open with the default application. + path_or_url: String, +} + +pub struct OpenTool { + project: Entity, +} + +impl OpenTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for OpenTool { + type Input = OpenToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "open".into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Execute + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + format!("Open `{}`", MarkdownEscaped(&input.path_or_url)).into() + } else { + "Open file or URL".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: crate::ToolCallEventStream, + cx: &mut App, + ) -> Task> { + // If path_or_url turns out to be a path in the project, make it absolute. + let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); + let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx); + cx.background_spawn(async move { + authorize.await?; + + match abs_path { + Some(path) => open::that(path), + None => open::that(&input.path_or_url), + } + .context("Failed to open URL or file path")?; + + Ok(format!("Successfully opened {}", input.path_or_url)) + }) + } +} + +fn to_absolute_path( + potential_path: &str, + project: Entity, + cx: &mut App, +) -> Option { + let project = project.read(cx); + project + .find_project_path(PathBuf::from(potential_path), cx) + .and_then(|project_path| project.absolute_path(&project_path, cx)) +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use settings::SettingsStore; + use std::path::Path; + use tempfile::TempDir; + + #[gpui::test] + async fn test_to_absolute_path(cx: &mut TestAppContext) { + init_test(cx); + let temp_dir = TempDir::new().expect("Failed to create temp directory"); + let temp_path = temp_dir.path().to_string_lossy().to_string(); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + &temp_path, + serde_json::json!({ + "src": { + "main.rs": "fn main() {}", + "lib.rs": "pub fn lib_fn() {}" + }, + "docs": { + "readme.md": "# Project Documentation" + } + }), + ) + .await; + + // Use the temp_path as the root directory, not just its filename + let project = Project::test(fs.clone(), [temp_dir.path()], cx).await; + + // Test cases where the function should return Some + cx.update(|cx| { + // Project-relative paths should return Some + // Create paths using the last segment of the temp path to simulate a project-relative path + let root_dir_name = Path::new(&temp_path) + .file_name() + .unwrap_or_else(|| std::ffi::OsStr::new("temp")) + .to_string_lossy(); + + assert!( + to_absolute_path(&format!("{root_dir_name}/src/main.rs"), project.clone(), cx) + .is_some(), + "Failed to resolve main.rs path" + ); + + assert!( + to_absolute_path( + &format!("{root_dir_name}/docs/readme.md",), + project.clone(), + cx, + ) + .is_some(), + "Failed to resolve readme.md path" + ); + + // External URL should return None + let result = to_absolute_path("https://example.com", project.clone(), cx); + assert_eq!(result, None, "External URLs should return None"); + + // Path outside project + let result = to_absolute_path("../invalid/path", project.clone(), cx); + assert_eq!(result, None, "Paths outside the project should return None"); + }); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } +} diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs index 7bbe3ac4c1..f21643cbbb 100644 --- a/crates/agent2/src/tools/read_file_tool.rs +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -1,16 +1,16 @@ -use agent_client_protocol::{self as acp}; -use anyhow::{anyhow, Context, Result}; -use assistant_tool::{outline, ActionLog}; -use gpui::{Entity, Task}; +use action_log::ActionLog; +use agent_client_protocol::{self as acp, ToolCallUpdateFields}; +use anyhow::{Context as _, Result, anyhow}; +use assistant_tool::outline; +use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; -use language::{Anchor, Point}; +use language::Point; use language_model::{LanguageModelImage, LanguageModelToolResultContent}; -use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings}; +use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use std::sync::Arc; -use ui::{App, SharedString}; use crate::{AgentTool, ToolCallEventStream}; @@ -97,7 +97,7 @@ impl AgentTool for ReadFileTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { @@ -166,7 +166,9 @@ impl AgentTool for ReadFileTool { cx.spawn(async move |cx| { let buffer = cx .update(|cx| { - project.update(cx, |project, cx| project.open_buffer(project_path, cx)) + project.update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + }) })? .await?; if buffer.read_with(cx, |buffer, _| { @@ -178,19 +180,10 @@ impl AgentTool for ReadFileTool { anyhow::bail!("{file_path} not found"); } - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: Anchor::MIN, - }), - cx, - ); - })?; + let mut anchor = None; // Check if specific line ranges are provided - if input.start_line.is_some() || input.end_line.is_some() { - let mut anchor = None; + let result = if input.start_line.is_some() || input.end_line.is_some() { let result = buffer.read_with(cx, |buffer, _cx| { let text = buffer.text(); // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0. @@ -214,18 +207,6 @@ impl AgentTool for ReadFileTool { log.buffer_read(buffer.clone(), cx); })?; - if let Some(anchor) = anchor { - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: anchor, - }), - cx, - ); - })?; - } - Ok(result.into()) } else { // No line ranges specified, so check file size to see if it's too big. @@ -236,7 +217,7 @@ impl AgentTool for ReadFileTool { let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?; action_log.update(cx, |log, cx| { - log.buffer_read(buffer, cx); + log.buffer_read(buffer.clone(), cx); })?; Ok(result.into()) @@ -244,7 +225,8 @@ impl AgentTool for ReadFileTool { // File is too big, so return the outline // and a suggestion to read again with line numbers. let outline = - outline::file_outline(project, file_path, action_log, None, cx).await?; + outline::file_outline(project.clone(), file_path, action_log, None, cx) + .await?; Ok(formatdoc! {" This file was too big to read all at once. @@ -261,7 +243,28 @@ impl AgentTool for ReadFileTool { } .into()) } - } + }; + + project.update(cx, |project, cx| { + if let Some(abs_path) = project.absolute_path(&project_path, cx) { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: anchor.unwrap_or(text::Anchor::MIN), + }), + cx, + ); + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![acp::ToolCallLocation { + path: abs_path, + line: input.start_line.map(|line| line.saturating_sub(1)), + }]), + ..Default::default() + }); + } + })?; + + result }) } } @@ -270,7 +273,7 @@ impl AgentTool for ReadFileTool { mod test { use super::*; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; - use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; + use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent2/src/tools/terminal_tool.rs new file mode 100644 index 0000000000..ecb855ac34 --- /dev/null +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -0,0 +1,473 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use futures::{FutureExt as _, future::Shared}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::{Project, terminals::TerminalKind}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; +use util::{ResultExt, get_system_shell, markdown::MarkdownInlineCode}; + +use crate::{AgentTool, ToolCallEventStream}; + +const COMMAND_OUTPUT_LIMIT: usize = 16 * 1024; + +/// Executes a shell one-liner and returns the combined output. +/// +/// This tool spawns a process using the user's shell, reads from stdout and stderr (preserving the order of writes), and returns a string with the combined output result. +/// +/// The output results will be shown to the user already, only list it again if necessary, avoid being redundant. +/// +/// Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error. +/// +/// Do not use this tool for commands that run indefinitely, such as servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers that don't terminate on their own. +/// +/// Remember that each invocation of this tool will spawn a new shell process, so you can't rely on any state from previous invocations. +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct TerminalToolInput { + /// The one-liner command to execute. + command: String, + /// Working directory for the command. This must be one of the root directories of the project. + cd: String, +} + +pub struct TerminalTool { + project: Entity, + determine_shell: Shared>, +} + +impl TerminalTool { + pub fn new(project: Entity, cx: &mut App) -> Self { + let determine_shell = cx.background_spawn(async move { + if cfg!(windows) { + return get_system_shell(); + } + + if which::which("bash").is_ok() { + log::info!("agent selected bash for terminal tool"); + "bash".into() + } else { + let shell = get_system_shell(); + log::info!("agent selected {shell} for terminal tool"); + shell + } + }); + Self { + project, + determine_shell: determine_shell.shared(), + } + } +} + +impl AgentTool for TerminalTool { + type Input = TerminalToolInput; + type Output = String; + + fn name(&self) -> SharedString { + "terminal".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Execute + } + + fn initial_title(&self, input: Result) -> SharedString { + if let Ok(input) = input { + let mut lines = input.command.lines(); + let first_line = lines.next().unwrap_or_default(); + let remaining_line_count = lines.count(); + match remaining_line_count { + 0 => MarkdownInlineCode(&first_line).to_string().into(), + 1 => MarkdownInlineCode(&format!( + "{} - {} more line", + first_line, remaining_line_count + )) + .to_string() + .into(), + n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)) + .to_string() + .into(), + } + } else { + "Run terminal command".into() + } + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let language_registry = self.project.read(cx).languages().clone(); + let working_dir = match working_dir(&input, &self.project, cx) { + Ok(dir) => dir, + Err(err) => return Task::ready(Err(err)), + }; + let program = self.determine_shell.clone(); + let command = if cfg!(windows) { + format!("$null | & {{{}}}", input.command.replace("\"", "'")) + } else if let Some(cwd) = working_dir + .as_ref() + .and_then(|cwd| cwd.as_os_str().to_str()) + { + // Make sure once we're *inside* the shell, we cd into `cwd` + format!("(cd {cwd}; {}) self.project.update(cx, |project, cx| { + project.directory_environment(dir.as_path().into(), cx) + }), + None => Task::ready(None).shared(), + }; + + let env = cx.spawn(async move |_| { + let mut env = env.await.unwrap_or_default(); + if cfg!(unix) { + env.insert("PAGER".into(), "cat".into()); + } + env + }); + + let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx); + + cx.spawn({ + async move |cx| { + authorize.await?; + + let program = program.await; + let env = env.await; + let terminal = self + .project + .update(cx, |project, cx| { + project.create_terminal( + TerminalKind::Task(task::SpawnInTerminal { + command: Some(program), + args, + cwd: working_dir.clone(), + env, + ..Default::default() + }), + cx, + ) + })? + .await?; + let acp_terminal = cx.new(|cx| { + acp_thread::Terminal::new( + input.command.clone(), + working_dir.clone(), + terminal.clone(), + language_registry, + cx, + ) + })?; + event_stream.update_terminal(acp_terminal.clone()); + + let exit_status = terminal + .update(cx, |terminal, cx| terminal.wait_for_completed_task(cx))? + .await; + let (content, content_line_count) = terminal.read_with(cx, |terminal, _| { + (terminal.get_content(), terminal.total_lines()) + })?; + + let (processed_content, finished_with_empty_output) = process_content( + &content, + &input.command, + exit_status.map(portable_pty::ExitStatus::from), + ); + + acp_terminal + .update(cx, |terminal, cx| { + terminal.finish( + exit_status, + content.len(), + processed_content.len(), + content_line_count, + finished_with_empty_output, + cx, + ); + }) + .log_err(); + + Ok(processed_content) + } + }) + } +} + +fn process_content( + content: &str, + command: &str, + exit_status: Option, +) -> (String, bool) { + let should_truncate = content.len() > COMMAND_OUTPUT_LIMIT; + + let content = if should_truncate { + let mut end_ix = COMMAND_OUTPUT_LIMIT.min(content.len()); + while !content.is_char_boundary(end_ix) { + end_ix -= 1; + } + // Don't truncate mid-line, clear the remainder of the last line + end_ix = content[..end_ix].rfind('\n').unwrap_or(end_ix); + &content[..end_ix] + } else { + content + }; + let content = content.trim(); + let is_empty = content.is_empty(); + let content = format!("```\n{content}\n```"); + let content = if should_truncate { + format!( + "Command output too long. The first {} bytes:\n\n{content}", + content.len(), + ) + } else { + content + }; + + let content = match exit_status { + Some(exit_status) if exit_status.success() => { + if is_empty { + "Command executed successfully.".to_string() + } else { + content.to_string() + } + } + Some(exit_status) => { + if is_empty { + format!( + "Command \"{command}\" failed with exit code {}.", + exit_status.exit_code() + ) + } else { + format!( + "Command \"{command}\" failed with exit code {}.\n\n{content}", + exit_status.exit_code() + ) + } + } + None => { + format!( + "Command failed or was interrupted.\nPartial output captured:\n\n{}", + content, + ) + } + }; + (content, is_empty) +} + +fn working_dir( + input: &TerminalToolInput, + project: &Entity, + cx: &mut App, +) -> Result> { + let project = project.read(cx); + let cd = &input.cd; + + if cd == "." || cd == "" { + // Accept "." or "" as meaning "the one worktree" if we only have one worktree. + let mut worktrees = project.worktrees(cx); + + match worktrees.next() { + Some(worktree) => { + anyhow::ensure!( + worktrees.next().is_none(), + "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.", + ); + Ok(Some(worktree.read(cx).abs_path().to_path_buf())) + } + None => Ok(None), + } + } else { + let input_path = Path::new(cd); + + if input_path.is_absolute() { + // Absolute paths are allowed, but only if they're in one of the project's worktrees. + if project + .worktrees(cx) + .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path())) + { + return Ok(Some(input_path.into())); + } + } else { + if let Some(worktree) = project.worktree_for_root_name(cd, cx) { + return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); + } + } + + anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees."); + } +} + +#[cfg(test)] +mod tests { + use agent_settings::AgentSettings; + use editor::EditorSettings; + use fs::RealFs; + use gpui::{BackgroundExecutor, TestAppContext}; + use pretty_assertions::assert_eq; + use serde_json::json; + use settings::{Settings, SettingsStore}; + use terminal::terminal_settings::TerminalSettings; + use theme::ThemeSettings; + use util::test::TempTree; + + use crate::AgentResponseEvent; + + use super::*; + + fn init_test(executor: &BackgroundExecutor, cx: &mut TestAppContext) { + zlog::init_test(); + + executor.allow_parking(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + ThemeSettings::register(cx); + TerminalSettings::register(cx); + EditorSettings::register(cx); + AgentSettings::register(cx); + }); + } + + #[gpui::test] + async fn test_interactive_command(executor: BackgroundExecutor, cx: &mut TestAppContext) { + if cfg!(windows) { + return; + } + + init_test(&executor, cx); + + let fs = Arc::new(RealFs::new(None, executor)); + let tree = TempTree::new(json!({ + "project": {}, + })); + let project: Entity = + Project::test(fs, [tree.path().join("project").as_path()], cx).await; + + let input = TerminalToolInput { + command: "cat".to_owned(), + cd: tree + .path() + .join("project") + .as_path() + .to_string_lossy() + .to_string(), + }; + let (event_stream_tx, mut event_stream_rx) = ToolCallEventStream::test(); + let result = cx + .update(|cx| Arc::new(TerminalTool::new(project, cx)).run(input, event_stream_tx, cx)); + + let auth = event_stream_rx.expect_authorization().await; + auth.response.send(auth.options[0].id.clone()).unwrap(); + event_stream_rx.expect_terminal().await; + assert_eq!(result.await.unwrap(), "Command executed successfully."); + } + + #[gpui::test] + async fn test_working_directory(executor: BackgroundExecutor, cx: &mut TestAppContext) { + if cfg!(windows) { + return; + } + + init_test(&executor, cx); + + let fs = Arc::new(RealFs::new(None, executor)); + let tree = TempTree::new(json!({ + "project": {}, + "other-project": {}, + })); + let project: Entity = + Project::test(fs, [tree.path().join("project").as_path()], cx).await; + + let check = |input, expected, cx: &mut TestAppContext| { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let result = cx.update(|cx| { + Arc::new(TerminalTool::new(project.clone(), cx)).run(input, stream_tx, cx) + }); + cx.run_until_parked(); + let event = stream_rx.try_next(); + if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { + auth.response.send(auth.options[0].id.clone()).unwrap(); + } + + cx.spawn(async move |_| { + let output = result.await; + assert_eq!(output.ok(), expected); + }) + }; + + check( + TerminalToolInput { + command: "pwd".into(), + cd: ".".into(), + }, + Some(format!( + "```\n{}\n```", + tree.path().join("project").display() + )), + cx, + ) + .await; + + check( + TerminalToolInput { + command: "pwd".into(), + cd: "other-project".into(), + }, + None, // other-project is a dir, but *not* a worktree (yet) + cx, + ) + .await; + + // Absolute path above the worktree root + check( + TerminalToolInput { + command: "pwd".into(), + cd: tree.path().to_string_lossy().into(), + }, + None, + cx, + ) + .await; + + project + .update(cx, |project, cx| { + project.create_worktree(tree.path().join("other-project"), true, cx) + }) + .await + .unwrap(); + + check( + TerminalToolInput { + command: "pwd".into(), + cd: "other-project".into(), + }, + Some(format!( + "```\n{}\n```", + tree.path().join("other-project").display() + )), + cx, + ) + .await; + + check( + TerminalToolInput { + command: "pwd".into(), + cd: ".".into(), + }, + None, + cx, + ) + .await; + } +} diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs new file mode 100644 index 0000000000..c1c0970742 --- /dev/null +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use crate::{AgentTool, ToolCallEventStream}; +use agent_client_protocol as acp; +use anyhow::{Result, anyhow}; +use cloud_llm_client::WebSearchResponse; +use gpui::{App, AppContext, Task}; +use language_model::{ + LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, +}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use ui::prelude::*; +use web_search::WebSearchRegistry; + +/// Search the web for information using your query. +/// Use this when you need real-time information, facts, or data that might not be in your training. \ +/// Results will include snippets and links from relevant web pages. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct WebSearchToolInput { + /// The search term or question to query on the web. + query: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(transparent)] +pub struct WebSearchToolOutput(WebSearchResponse); + +impl From for LanguageModelToolResultContent { + fn from(value: WebSearchToolOutput) -> Self { + serde_json::to_string(&value.0) + .expect("Failed to serialize WebSearchResponse") + .into() + } +} + +pub struct WebSearchTool; + +impl AgentTool for WebSearchTool { + type Input = WebSearchToolInput; + type Output = WebSearchToolOutput; + + fn name(&self) -> SharedString { + "web_search".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Fetch + } + + fn initial_title(&self, _input: Result) -> SharedString { + "Searching the Web".into() + } + + /// We currently only support Zed Cloud as a provider. + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + provider == &ZED_CLOUD_PROVIDER_ID + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { + return Task::ready(Err(anyhow!("Web search is not available."))); + }; + + let search_task = provider.search(input.query, cx); + cx.background_spawn(async move { + let response = match search_task.await { + Ok(response) => response, + Err(err) => { + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some("Web Search Failed".to_string()), + ..Default::default() + }); + return Err(err); + } + }; + + let result_text = if response.results.len() == 1 { + "1 result".to_string() + } else { + format!("{} results", response.results.len()) + }; + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(format!("Searched the web: {result_text}")), + content: Some( + response + .results + .iter() + .map(|result| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: result.title.clone(), + uri: result.url.clone(), + title: Some(result.title.clone()), + description: Some(result.text.clone()), + mime_type: None, + annotations: None, + size: None, + }), + }) + .collect(), + ), + ..Default::default() + }); + Ok(WebSearchToolOutput(response)) + }) + } +} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 8d85435f92..74647f7313 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{cell::RefCell, path::Path, rc::Rc}; +use std::{any::Any, cell::RefCell, path::Path, rc::Rc}; use ui::App; use util::ResultExt as _; @@ -135,9 +135,9 @@ impl acp_old::Client for OldAcpClientDelegate { let response = cx .update(|cx| { self.thread.borrow().update(cx, |thread, cx| { - thread.request_tool_call_authorization(tool_call, acp_options, cx) + thread.request_tool_call_authorization(tool_call.into(), acp_options, cx) }) - })? + })?? .context("Failed to update thread")? .await; @@ -168,7 +168,7 @@ impl acp_old::Client for OldAcpClientDelegate { cx, ) }) - })? + })?? .context("Failed to update thread")?; Ok(acp_old::PushToolCallResponse { @@ -423,7 +423,7 @@ impl AgentConnection for AcpConnection { self: Rc, project: Entity, _cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let task = self.connection.request_any( acp_old::InitializeParams { @@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -506,4 +507,8 @@ impl AgentConnection for AcpConnection { }) .detach_and_log_err(cx) } + + fn into_any(self: Rc) -> Rc { + self + } } diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index ff71783b48..506ae80886 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -1,11 +1,13 @@ use agent_client_protocol::{self as acp, Agent as _}; use anyhow::anyhow; use collections::HashMap; +use futures::AsyncBufReadExt as _; use futures::channel::oneshot; +use futures::io::BufReader; use project::Project; -use std::cell::RefCell; use std::path::Path; use std::rc::Rc; +use std::{any::Any, cell::RefCell}; use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; @@ -40,12 +42,13 @@ impl AcpConnection { .current_dir(root_dir) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) + .stderr(std::process::Stdio::piped()) .kill_on_drop(true) .spawn()?; - let stdout = child.stdout.take().expect("Failed to take stdout"); - let stdin = child.stdin.take().expect("Failed to take stdin"); + let stdout = child.stdout.take().context("Failed to take stdout")?; + let stdin = child.stdin.take().context("Failed to take stdin")?; + let stderr = child.stderr.take().context("Failed to take stderr")?; log::trace!("Spawned (pid: {})", child.id()); let sessions = Rc::new(RefCell::new(HashMap::default())); @@ -63,6 +66,18 @@ impl AcpConnection { let io_task = cx.background_spawn(io_task); + cx.background_spawn(async move { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + while let Ok(n) = stderr.read_line(&mut line).await + && n > 0 + { + log::warn!("agent stderr: {}", &line); + line.clear(); + } + }) + .detach(); + cx.spawn({ let sessions = sessions.clone(); async move |cx| { @@ -111,7 +126,7 @@ impl AgentConnection for AcpConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let conn = self.connection.clone(); let sessions = self.sessions.clone(); @@ -171,6 +186,7 @@ impl AgentConnection for AcpConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -190,6 +206,10 @@ impl AgentConnection for AcpConnection { .spawn(async move { conn.cancel(params).await }) .detach(); } + + fn into_any(self: Rc) -> Rc { + self + } } struct ClientDelegate { @@ -213,7 +233,7 @@ impl acp::Client for ClientDelegate { thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) })?; - let result = rx.await; + let result = rx?.await; let outcome = match result { Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index c65508f152..4b3a173349 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,6 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; +use std::any::Any; use std::cell::RefCell; use std::fmt::Display; use std::path::Path; @@ -13,7 +14,7 @@ use std::rc::Rc; use uuid::Uuid; use agent_client_protocol as acp; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; use futures::channel::oneshot; use futures::{AsyncBufReadExt, AsyncWriteExt}; use futures::{ @@ -74,7 +75,7 @@ impl AgentConnection for ClaudeAgentConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let cwd = cwd.to_owned(); cx.spawn(async move |cx| { @@ -129,12 +130,25 @@ impl AgentConnection for ClaudeAgentConnection { &cwd, )?; - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); + let stdout = child.stdout.take().context("Failed to take stdout")?; + let stdin = child.stdin.take().context("Failed to take stdin")?; + let stderr = child.stderr.take().context("Failed to take stderr")?; let pid = child.id(); log::trace!("Spawned (pid: {})", pid); + cx.background_spawn(async move { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + while let Ok(n) = stderr.read_line(&mut line).await + && n > 0 + { + log::warn!("agent stderr: {}", &line); + line.clear(); + } + }) + .detach(); + cx.background_spawn(async move { let mut outgoing_rx = Some(outgoing_rx); @@ -210,6 +224,7 @@ impl AgentConnection for ClaudeAgentConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -288,6 +303,10 @@ impl AgentConnection for ClaudeAgentConnection { }) .log_err(); } + + fn into_any(self: Rc) -> Rc { + self + } } #[derive(Clone, Copy)] @@ -339,7 +358,7 @@ fn spawn_claude( .current_dir(root_dir) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) + .stderr(std::process::Stdio::piped()) .kill_on_drop(true) .spawn()?; @@ -423,7 +442,7 @@ impl ClaudeAgentSession { if !turn_state.borrow().is_cancelled() { thread .update(cx, |thread, cx| { - thread.push_user_content_block(text.into(), cx) + thread.push_user_content_block(None, text.into(), cx) }) .log_err(); } @@ -541,8 +560,9 @@ impl ClaudeAgentSession { thread.upsert_tool_call( claude_tool.as_acp(acp::ToolCallId(id.into())), cx, - ); + )?; } + anyhow::Ok(()) }) .log_err(); } diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 53a8556e74..22cb2f8f8d 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -154,7 +154,7 @@ impl McpServerTool for PermissionTool { let chosen_option = thread .update(cx, |thread, cx| { thread.request_tool_call_authorization( - claude_tool.as_acp(tool_call_id), + claude_tool.as_acp(tool_call_id).into(), vec![ acp::PermissionOption { id: allow_option_id.clone(), @@ -169,7 +169,7 @@ impl McpServerTool for PermissionTool { ], cx, ) - })? + })?? .await?; let response = if chosen_option == allow_option_id { diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index ec6ca29b9d..5af7010f26 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -422,8 +422,8 @@ pub async fn new_test_thread( .await .unwrap(); - let thread = connection - .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx)) .await .unwrap(); diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index a6b8633b34..402cf81678 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -48,6 +48,20 @@ pub struct AgentProfileSettings { pub context_servers: IndexMap, ContextServerPreset>, } +impl AgentProfileSettings { + pub fn is_tool_enabled(&self, tool_name: &str) -> bool { + self.tools.get(tool_name) == Some(&true) + } + + pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool { + self.enable_all_context_servers + || self + .context_servers + .get(server_id) + .map_or(false, |preset| preset.tools.get(tool_name) == Some(&true)) + } +} + #[derive(Debug, Clone, Default)] pub struct ContextServerPreset { pub tools: IndexMap, bool>, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index e6a79963d6..fd38ba1f7f 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -309,7 +309,7 @@ pub struct AgentSettingsContent { /// /// Default: true expand_terminal_card: Option, - /// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel. + /// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel. /// /// Default: false use_modifier_to_send: Option, @@ -442,10 +442,6 @@ impl Settings for AgentSettings { &mut settings.inline_alternatives, value.inline_alternatives.clone(), ); - merge( - &mut settings.always_allow_tool_actions, - value.always_allow_tool_actions, - ); merge( &mut settings.notify_when_agent_waiting, value.notify_when_agent_waiting, @@ -507,6 +503,20 @@ impl Settings for AgentSettings { } } + debug_assert_eq!( + sources.default.always_allow_tool_actions.unwrap_or(false), + false, + "For security, agent.always_allow_tool_actions should always be false in default.json. If it's true, that is a bug that should be fixed!" + ); + + // For security reasons, only trust the user's global settings for whether to always allow tool actions. + // If this could be overridden locally, an attacker could (e.g. by committing to source control and + // convincing you to switch branches) modify your project-local settings to disable the agent's safety checks. + settings.always_allow_tool_actions = sources + .user + .and_then(|setting| setting.always_allow_tool_actions) + .unwrap_or(false); + Ok(settings) } diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index c145df0eae..fbf8590e68 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -17,6 +17,7 @@ test-support = ["gpui/test-support", "language/test-support"] [dependencies] acp_thread.workspace = true +action_log.workspace = true agent-client-protocol.workspace = true agent.workspace = true agent2.workspace = true @@ -49,7 +50,6 @@ fuzzy.workspace = true gpui.workspace = true html_to_markdown.workspace = true http_client.workspace = true -indexed_docs.workspace = true indoc.workspace = true inventory.workspace = true itertools.workspace = true @@ -92,6 +92,7 @@ time.workspace = true time_format.workspace = true ui.workspace = true ui_input.workspace = true +url.workspace = true urlencoding.workspace = true util.workspace = true uuid.workspace = true @@ -101,6 +102,9 @@ workspace.workspace = true zed_actions.workspace = true [dev-dependencies] +acp_thread = { workspace = true, features = ["test-support"] } +agent = { workspace = true, features = ["test-support"] } +assistant_context = { workspace = true, features = ["test-support"] } assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index cc476b1a86..831d296eeb 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -1,6 +1,10 @@ mod completion_provider; -mod message_history; +mod entry_view_state; +mod message_editor; +mod model_selector; +mod model_selector_popover; mod thread_view; -pub use message_history::MessageHistory; +pub use model_selector::AcpModelSelector; +pub use model_selector_popover::AcpModelSelectorPopover; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index d8f452afa5..8a413fc91e 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,61 +1,426 @@ use std::ops::Range; -use std::path::Path; use std::sync::Arc; use std::sync::atomic::AtomicBool; +use acp_thread::MentionUri; use anyhow::Result; -use collections::HashMap; -use editor::display_map::CreaseId; use editor::{CompletionProvider, Editor, ExcerptId}; -use file_icons::FileIcons; +use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{App, Entity, Task, WeakEntity}; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; -use parking_lot::Mutex; -use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId}; +use project::{ + Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId, +}; +use prompt_store::PromptStore; use rope::Point; -use text::{Anchor, ToPoint}; +use text::{Anchor, ToPoint as _}; use ui::prelude::*; use workspace::Workspace; -use crate::context_picker::MentionLink; -use crate::context_picker::file_context_picker::{extract_file_name_and_directory, search_files}; +use agent::thread_store::{TextThreadStore, ThreadStore}; -#[derive(Default)] -pub struct MentionSet { - paths_by_crease_id: HashMap, +use crate::acp::message_editor::MessageEditor; +use crate::context_picker::file_context_picker::{FileMatch, search_files}; +use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; +use crate::context_picker::symbol_context_picker::SymbolMatch; +use crate::context_picker::symbol_context_picker::search_symbols; +use crate::context_picker::thread_context_picker::{ + ThreadContextEntry, ThreadMatch, search_threads, +}; +use crate::context_picker::{ + ContextPickerAction, ContextPickerEntry, ContextPickerMode, RecentEntry, + available_context_picker_entries, recent_context_picker_entries, selection_ranges, +}; + +pub(crate) enum Match { + File(FileMatch), + Symbol(SymbolMatch), + Thread(ThreadMatch), + Fetch(SharedString), + Rules(RulesContextEntry), + Entry(EntryMatch), } -impl MentionSet { - pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) { - self.paths_by_crease_id.insert(crease_id, path); - } +pub struct EntryMatch { + mat: Option, + entry: ContextPickerEntry, +} - pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option { - self.paths_by_crease_id.get(&crease_id).cloned() +impl Match { + pub fn score(&self) -> f64 { + match self { + Match::File(file) => file.mat.score, + Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.), + Match::Thread(_) => 1., + Match::Symbol(_) => 1., + Match::Rules(_) => 1., + Match::Fetch(_) => 1., + } } +} - pub fn drain(&mut self) -> impl Iterator { - self.paths_by_crease_id.drain().map(|(id, _)| id) +fn search( + mode: Option, + query: String, + cancellation_flag: Arc, + recent_entries: Vec, + prompt_store: Option>, + thread_store: WeakEntity, + text_thread_context_store: WeakEntity, + workspace: Entity, + cx: &mut App, +) -> Task> { + match mode { + Some(ContextPickerMode::File) => { + let search_files_task = + search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); + cx.background_spawn(async move { + search_files_task + .await + .into_iter() + .map(Match::File) + .collect() + }) + } + + Some(ContextPickerMode::Symbol) => { + let search_symbols_task = + search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx); + cx.background_spawn(async move { + search_symbols_task + .await + .into_iter() + .map(Match::Symbol) + .collect() + }) + } + + Some(ContextPickerMode::Thread) => { + if let Some((thread_store, context_store)) = thread_store + .upgrade() + .zip(text_thread_context_store.upgrade()) + { + let search_threads_task = search_threads( + query.clone(), + cancellation_flag.clone(), + thread_store, + context_store, + cx, + ); + cx.background_spawn(async move { + search_threads_task + .await + .into_iter() + .map(Match::Thread) + .collect() + }) + } else { + Task::ready(Vec::new()) + } + } + + Some(ContextPickerMode::Fetch) => { + if !query.is_empty() { + Task::ready(vec![Match::Fetch(query.into())]) + } else { + Task::ready(Vec::new()) + } + } + + Some(ContextPickerMode::Rules) => { + if let Some(prompt_store) = prompt_store.as_ref() { + let search_rules_task = + search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx); + cx.background_spawn(async move { + search_rules_task + .await + .into_iter() + .map(Match::Rules) + .collect::>() + }) + } else { + Task::ready(Vec::new()) + } + } + + None => { + if query.is_empty() { + let mut matches = recent_entries + .into_iter() + .map(|entry| match entry { + RecentEntry::File { + project_path, + path_prefix, + } => Match::File(FileMatch { + mat: fuzzy::PathMatch { + score: 1., + positions: Vec::new(), + worktree_id: project_path.worktree_id.to_usize(), + path: project_path.path, + path_prefix, + is_dir: false, + distance_to_relative_ancestor: 0, + }, + is_recent: true, + }), + RecentEntry::Thread(thread_context_entry) => Match::Thread(ThreadMatch { + thread: thread_context_entry, + is_recent: true, + }), + }) + .collect::>(); + + matches.extend( + available_context_picker_entries( + &prompt_store, + &Some(thread_store.clone()), + &workspace, + cx, + ) + .into_iter() + .map(|mode| { + Match::Entry(EntryMatch { + entry: mode, + mat: None, + }) + }), + ); + + Task::ready(matches) + } else { + let executor = cx.background_executor().clone(); + + let search_files_task = + search_files(query.clone(), cancellation_flag.clone(), &workspace, cx); + + let entries = available_context_picker_entries( + &prompt_store, + &Some(thread_store.clone()), + &workspace, + cx, + ); + let entry_candidates = entries + .iter() + .enumerate() + .map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword())) + .collect::>(); + + cx.background_spawn(async move { + let mut matches = search_files_task + .await + .into_iter() + .map(Match::File) + .collect::>(); + + let entry_matches = fuzzy::match_strings( + &entry_candidates, + &query, + false, + true, + 100, + &Arc::new(AtomicBool::default()), + executor, + ) + .await; + + matches.extend(entry_matches.into_iter().map(|mat| { + Match::Entry(EntryMatch { + entry: entries[mat.candidate_id], + mat: Some(mat), + }) + })); + + matches.sort_by(|a, b| { + b.score() + .partial_cmp(&a.score()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + matches + }) + } + } } } pub struct ContextPickerCompletionProvider { workspace: WeakEntity, - editor: WeakEntity, - mention_set: Arc>, + thread_store: WeakEntity, + text_thread_store: WeakEntity, + message_editor: WeakEntity, } impl ContextPickerCompletionProvider { pub fn new( - mention_set: Arc>, workspace: WeakEntity, - editor: WeakEntity, + thread_store: WeakEntity, + text_thread_store: WeakEntity, + message_editor: WeakEntity, ) -> Self { Self { - mention_set, workspace, - editor, + thread_store, + text_thread_store, + message_editor, + } + } + + fn completion_for_entry( + entry: ContextPickerEntry, + source_range: Range, + message_editor: WeakEntity, + workspace: &Entity, + cx: &mut App, + ) -> Option { + match entry { + ContextPickerEntry::Mode(mode) => Some(Completion { + replace_range: source_range.clone(), + new_text: format!("@{} ", mode.keyword()), + label: CodeLabel::plain(mode.label().to_string(), None), + icon_path: Some(mode.icon().path().into()), + documentation: None, + source: project::CompletionSource::Custom, + insert_text_mode: None, + // This ensures that when a user accepts this completion, the + // completion menu will still be shown after "@category " is + // inserted + confirm: Some(Arc::new(|_, _, _| true)), + }), + ContextPickerEntry::Action(action) => { + let (new_text, on_action) = match action { + ContextPickerAction::AddSelections => { + const PLACEHOLDER: &str = "selection "; + let selections = selection_ranges(workspace, cx) + .into_iter() + .enumerate() + .map(|(ix, (buffer, range))| { + ( + buffer, + range, + (PLACEHOLDER.len() * ix)..(PLACEHOLDER.len() * (ix + 1) - 1), + ) + }) + .collect::>(); + + let new_text: String = PLACEHOLDER.repeat(selections.len()); + + let callback = Arc::new({ + let source_range = source_range.clone(); + move |_, window: &mut Window, cx: &mut App| { + let selections = selections.clone(); + let message_editor = message_editor.clone(); + let source_range = source_range.clone(); + window.defer(cx, move |window, cx| { + message_editor + .update(cx, |message_editor, cx| { + message_editor.confirm_mention_for_selection( + source_range, + selections, + window, + cx, + ) + }) + .ok(); + }); + false + } + }); + + (new_text, callback) + } + }; + + Some(Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(action.label().to_string(), None), + icon_path: Some(action.icon().path().into()), + documentation: None, + source: project::CompletionSource::Custom, + insert_text_mode: None, + // This ensures that when a user accepts this completion, the + // completion menu will still be shown after "@category " is + // inserted + confirm: Some(on_action), + }) + } + } + } + + fn completion_for_thread( + thread_entry: ThreadContextEntry, + source_range: Range, + recent: bool, + editor: WeakEntity, + cx: &mut App, + ) -> Completion { + let uri = match &thread_entry { + ThreadContextEntry::Thread { id, title } => MentionUri::Thread { + id: id.clone(), + name: title.to_string(), + }, + ThreadContextEntry::Context { path, title } => MentionUri::TextThread { + path: path.to_path_buf(), + name: title.to_string(), + }, + }; + + let icon_for_completion = if recent { + IconName::HistoryRerun.path().into() + } else { + uri.icon_path(cx) + }; + + let new_text = format!("{} ", uri.as_link()); + + let new_text_len = new_text.len(); + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(thread_entry.title().to_string(), None), + documentation: None, + insert_text_mode: None, + source: project::CompletionSource::Custom, + icon_path: Some(icon_for_completion.clone()), + confirm: Some(confirm_completion_callback( + thread_entry.title().clone(), + source_range.start, + new_text_len - 1, + editor, + uri, + )), + } + } + + fn completion_for_rules( + rule: RulesContextEntry, + source_range: Range, + editor: WeakEntity, + cx: &mut App, + ) -> Completion { + let uri = MentionUri::Rule { + id: rule.prompt_id.into(), + name: rule.title.to_string(), + }; + let new_text = format!("{} ", uri.as_link()); + let new_text_len = new_text.len(); + let icon_path = uri.icon_path(cx); + Completion { + replace_range: source_range.clone(), + new_text, + label: CodeLabel::plain(rule.title.to_string(), None), + documentation: None, + insert_text_mode: None, + source: project::CompletionSource::Custom, + icon_path: Some(icon_path.clone()), + confirm: Some(confirm_completion_callback( + rule.title.clone(), + source_range.start, + new_text_len - 1, + editor, + uri, + )), } } @@ -64,38 +429,37 @@ impl ContextPickerCompletionProvider { path_prefix: &str, is_recent: bool, is_directory: bool, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, - cx: &App, - ) -> Completion { + message_editor: WeakEntity, + project: Entity, + cx: &mut App, + ) -> Option { let (file_name, directory) = - extract_file_name_and_directory(&project_path.path, path_prefix); + crate::context_picker::file_context_picker::extract_file_name_and_directory( + &project_path.path, + path_prefix, + ); let label = build_code_label_for_full_path(&file_name, directory.as_ref().map(|s| s.as_ref()), cx); - let full_path = if let Some(directory) = directory { - format!("{}{}", directory, file_name) - } else { - file_name.to_string() + + let abs_path = project.read(cx).absolute_path(&project_path, cx)?; + + let file_uri = MentionUri::File { + abs_path, + is_directory, }; - let crease_icon_path = if is_directory { - FileIcons::get_folder_icon(false, cx).unwrap_or_else(|| IconName::Folder.path().into()) - } else { - FileIcons::get_icon(Path::new(&full_path), cx) - .unwrap_or_else(|| IconName::File.path().into()) - }; + let crease_icon_path = file_uri.icon_path(cx); let completion_icon_path = if is_recent { IconName::HistoryRerun.path().into() } else { crease_icon_path.clone() }; - let new_text = format!("{} ", MentionLink::for_file(&file_name, &full_path)); + let new_text = format!("{} ", file_uri.as_link()); let new_text_len = new_text.len(); - Completion { + Some(Completion { replace_range: source_range.clone(), new_text, label, @@ -104,16 +468,83 @@ impl ContextPickerCompletionProvider { icon_path: Some(completion_icon_path), insert_text_mode: None, confirm: Some(confirm_completion_callback( - crease_icon_path, file_name, - project_path, - excerpt_id, source_range.start, new_text_len - 1, - editor, - mention_set, + message_editor, + file_uri, )), - } + }) + } + + fn completion_for_symbol( + symbol: Symbol, + source_range: Range, + message_editor: WeakEntity, + workspace: Entity, + cx: &mut App, + ) -> Option { + let project = workspace.read(cx).project().clone(); + + let label = CodeLabel::plain(symbol.name.clone(), None); + + let abs_path = project.read(cx).absolute_path(&symbol.path, cx)?; + let uri = MentionUri::Symbol { + path: abs_path, + name: symbol.name.clone(), + line_range: symbol.range.start.0.row..symbol.range.end.0.row, + }; + let new_text = format!("{} ", uri.as_link()); + let new_text_len = new_text.len(); + let icon_path = uri.icon_path(cx); + Some(Completion { + replace_range: source_range.clone(), + new_text, + label, + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(icon_path.clone()), + insert_text_mode: None, + confirm: Some(confirm_completion_callback( + symbol.name.clone().into(), + source_range.start, + new_text_len - 1, + message_editor, + uri, + )), + }) + } + + fn completion_for_fetch( + source_range: Range, + url_to_fetch: SharedString, + message_editor: WeakEntity, + cx: &mut App, + ) -> Option { + let new_text = format!("@fetch {} ", url_to_fetch.clone()); + let url_to_fetch = url::Url::parse(url_to_fetch.as_ref()) + .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) + .ok()?; + let mention_uri = MentionUri::Fetch { + url: url_to_fetch.clone(), + }; + let icon_path = mention_uri.icon_path(cx); + Some(Completion { + replace_range: source_range.clone(), + new_text: new_text.clone(), + label: CodeLabel::plain(url_to_fetch.to_string(), None), + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(icon_path.clone()), + insert_text_mode: None, + confirm: Some(confirm_completion_callback( + url_to_fetch.to_string().into(), + source_range.start, + new_text.len() - 1, + message_editor, + mention_uri, + )), + }) } } @@ -136,7 +567,7 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: impl CompletionProvider for ContextPickerCompletionProvider { fn completions( &self, - excerpt_id: ExcerptId, + _excerpt_id: ExcerptId, buffer: &Entity, buffer_position: Anchor, _trigger: CompletionContext, @@ -159,44 +590,115 @@ impl CompletionProvider for ContextPickerCompletionProvider { return Task::ready(Ok(Vec::new())); }; + let project = workspace.read(cx).project().clone(); let snapshot = buffer.read(cx).snapshot(); let source_range = snapshot.anchor_before(state.source_range.start) ..snapshot.anchor_after(state.source_range.end); - let editor = self.editor.clone(); - let mention_set = self.mention_set.clone(); - let MentionCompletion { argument, .. } = state; + let thread_store = self.thread_store.clone(); + let text_thread_store = self.text_thread_store.clone(); + let editor = self.message_editor.clone(); + let Ok((exclude_paths, exclude_threads)) = + self.message_editor.update(cx, |message_editor, _cx| { + message_editor.mentioned_path_and_threads() + }) + else { + return Task::ready(Ok(Vec::new())); + }; + + let MentionCompletion { mode, argument, .. } = state; let query = argument.unwrap_or_else(|| "".to_string()); - let search_task = search_files(query.clone(), Arc::::default(), &workspace, cx); + let recent_entries = recent_context_picker_entries( + Some(thread_store.clone()), + Some(text_thread_store.clone()), + workspace.clone(), + &exclude_paths, + &exclude_threads, + cx, + ); + + let prompt_store = thread_store + .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) + .ok() + .flatten(); + + let search_task = search( + mode, + query, + Arc::::default(), + recent_entries, + prompt_store, + thread_store.clone(), + text_thread_store.clone(), + workspace.clone(), + cx, + ); cx.spawn(async move |_, cx| { let matches = search_task.await; - let Some(editor) = editor.upgrade() else { - return Ok(Vec::new()); - }; let completions = cx.update(|cx| { matches .into_iter() - .map(|mat| { - let path_match = &mat.mat; - let project_path = ProjectPath { - worktree_id: WorktreeId::from_usize(path_match.worktree_id), - path: path_match.path.clone(), - }; + .filter_map(|mat| match mat { + Match::File(FileMatch { mat, is_recent }) => { + let project_path = ProjectPath { + worktree_id: WorktreeId::from_usize(mat.worktree_id), + path: mat.path.clone(), + }; - Self::completion_for_path( - project_path, - &path_match.path_prefix, - mat.is_recent, - path_match.is_dir, - excerpt_id, + Self::completion_for_path( + project_path, + &mat.path_prefix, + is_recent, + mat.is_dir, + source_range.clone(), + editor.clone(), + project.clone(), + cx, + ) + } + + Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( + symbol, source_range.clone(), editor.clone(), - mention_set.clone(), + workspace.clone(), cx, - ) + ), + + Match::Thread(ThreadMatch { + thread, is_recent, .. + }) => Some(Self::completion_for_thread( + thread, + source_range.clone(), + is_recent, + editor.clone(), + cx, + )), + + Match::Rules(user_rules) => Some(Self::completion_for_rules( + user_rules, + source_range.clone(), + editor.clone(), + cx, + )), + + Match::Fetch(url) => Self::completion_for_fetch( + source_range.clone(), + url, + editor.clone(), + cx, + ), + + Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( + entry, + source_range.clone(), + editor.clone(), + &workspace, + cx, + ), }) .collect() })?; @@ -246,35 +748,30 @@ impl CompletionProvider for ContextPickerCompletionProvider { } fn confirm_completion_callback( - crease_icon_path: SharedString, crease_text: SharedString, - project_path: ProjectPath, - excerpt_id: ExcerptId, start: Anchor, content_len: usize, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, + mention_uri: MentionUri, ) -> Arc bool + Send + Sync> { Arc::new(move |_, window, cx| { + let message_editor = message_editor.clone(); let crease_text = crease_text.clone(); - let crease_icon_path = crease_icon_path.clone(); - let editor = editor.clone(); - let project_path = project_path.clone(); - let mention_set = mention_set.clone(); + let mention_uri = mention_uri.clone(); window.defer(cx, move |window, cx| { - let crease_id = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - crease_text.clone(), - crease_icon_path, - editor.clone(), - window, - cx, - ); - if let Some(crease_id) = crease_id { - mention_set.lock().insert(crease_id, project_path); - } + message_editor + .clone() + .update(cx, |message_editor, cx| { + message_editor.confirm_completion( + crease_text, + start, + content_len, + mention_uri, + window, + cx, + ) + }) + .ok(); }); false }) @@ -283,6 +780,7 @@ fn confirm_completion_callback( #[derive(Debug, Default, PartialEq)] struct MentionCompletion { source_range: Range, + mode: Option, argument: Option, } @@ -302,17 +800,37 @@ impl MentionCompletion { } let rest_of_line = &line[last_mention_start + 1..]; + + let mut mode = None; let mut argument = None; let mut parts = rest_of_line.split_whitespace(); let mut end = last_mention_start + 1; - if let Some(argument_text) = parts.next() { - end += argument_text.len(); - argument = Some(argument_text.to_string()); + if let Some(mode_text) = parts.next() { + end += mode_text.len(); + + if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() { + mode = Some(parsed_mode); + } else { + argument = Some(mode_text.to_string()); + } + match rest_of_line[mode_text.len()..].find(|c: char| !c.is_whitespace()) { + Some(whitespace_count) => { + if let Some(argument_text) = parts.next() { + argument = Some(argument_text.to_string()); + end += whitespace_count + argument_text.len(); + } + } + None => { + // Rest of line is entirely whitespace + end += rest_of_line.len() - mode_text.len(); + } + } } Some(Self { source_range: last_mention_start + offset_to_line..end + offset_to_line, + mode, argument, }) } @@ -321,13 +839,6 @@ impl MentionCompletion { #[cfg(test)] mod tests { use super::*; - use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext}; - use project::{Project, ProjectPath}; - use serde_json::json; - use settings::SettingsStore; - use std::{ops::Deref, rc::Rc}; - use util::path; - use workspace::{AppState, Item}; #[test] fn test_mention_completion_parse() { @@ -337,238 +848,65 @@ mod tests { MentionCompletion::try_parse("Lorem @", 0), Some(MentionCompletion { source_range: 6..7, + mode: None, argument: None, }) ); + assert_eq!( + MentionCompletion::try_parse("Lorem @file", 0), + Some(MentionCompletion { + source_range: 6..11, + mode: Some(ContextPickerMode::File), + argument: None, + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file ", 0), + Some(MentionCompletion { + source_range: 6..12, + mode: Some(ContextPickerMode::File), + argument: None, + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file main.rs", 0), + Some(MentionCompletion { + source_range: 6..19, + mode: Some(ContextPickerMode::File), + argument: Some("main.rs".to_string()), + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file main.rs ", 0), + Some(MentionCompletion { + source_range: 6..19, + mode: Some(ContextPickerMode::File), + argument: Some("main.rs".to_string()), + }) + ); + + assert_eq!( + MentionCompletion::try_parse("Lorem @file main.rs Ipsum", 0), + Some(MentionCompletion { + source_range: 6..19, + mode: Some(ContextPickerMode::File), + argument: Some("main.rs".to_string()), + }) + ); + assert_eq!( MentionCompletion::try_parse("Lorem @main", 0), Some(MentionCompletion { source_range: 6..11, + mode: None, argument: Some("main".to_string()), }) ); assert_eq!(MentionCompletion::try_parse("test@", 0), None); } - - struct AtMentionEditor(Entity); - - impl Item for AtMentionEditor { - type Event = (); - - fn include_in_nav_history() -> bool { - false - } - - fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { - "Test".into() - } - } - - impl EventEmitter<()> for AtMentionEditor {} - - impl Focusable for AtMentionEditor { - fn focus_handle(&self, cx: &App) -> FocusHandle { - self.0.read(cx).focus_handle(cx).clone() - } - } - - impl Render for AtMentionEditor { - fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - self.0.clone().into_any_element() - } - } - - #[gpui::test] - async fn test_context_completion_provider(cx: &mut TestAppContext) { - init_test(cx); - - let app_state = cx.update(AppState::test); - - cx.update(|cx| { - language::init(cx); - editor::init(cx); - workspace::init(app_state.clone(), cx); - Project::init_settings(cx); - }); - - app_state - .fs - .as_fake() - .insert_tree( - path!("/dir"), - json!({ - "editor": "", - "a": { - "one.txt": "", - "two.txt": "", - "three.txt": "", - "four.txt": "" - }, - "b": { - "five.txt": "", - "six.txt": "", - "seven.txt": "", - "eight.txt": "", - } - }), - ) - .await; - - let project = Project::test(app_state.fs.clone(), [path!("/dir").as_ref()], cx).await; - let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let workspace = window.root(cx).unwrap(); - - let worktree = project.update(cx, |project, cx| { - let mut worktrees = project.worktrees(cx).collect::>(); - assert_eq!(worktrees.len(), 1); - worktrees.pop().unwrap() - }); - let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id()); - - let mut cx = VisualTestContext::from_window(*window.deref(), cx); - - let paths = vec![ - path!("a/one.txt"), - path!("a/two.txt"), - path!("a/three.txt"), - path!("a/four.txt"), - path!("b/five.txt"), - path!("b/six.txt"), - path!("b/seven.txt"), - path!("b/eight.txt"), - ]; - - let mut opened_editors = Vec::new(); - for path in paths { - let buffer = workspace - .update_in(&mut cx, |workspace, window, cx| { - workspace.open_path( - ProjectPath { - worktree_id, - path: Path::new(path).into(), - }, - None, - false, - window, - cx, - ) - }) - .await - .unwrap(); - opened_editors.push(buffer); - } - - let editor = workspace.update_in(&mut cx, |workspace, window, cx| { - let editor = cx.new(|cx| { - Editor::new( - editor::EditorMode::full(), - multi_buffer::MultiBuffer::build_simple("", cx), - None, - window, - cx, - ) - }); - workspace.active_pane().update(cx, |pane, cx| { - pane.add_item( - Box::new(cx.new(|_| AtMentionEditor(editor.clone()))), - true, - true, - None, - window, - cx, - ); - }); - editor - }); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); - - let editor_entity = editor.downgrade(); - editor.update_in(&mut cx, |editor, window, cx| { - window.focus(&editor.focus_handle(cx)); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace.downgrade(), - editor_entity, - )))); - }); - - cx.simulate_input("Lorem "); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem "); - assert!(!editor.has_visible_completions_menu()); - }); - - cx.simulate_input("@"); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem @"); - assert!(editor.has_visible_completions_menu()); - assert_eq!( - current_completion_labels(editor), - &[ - "eight.txt dir/b/", - "seven.txt dir/b/", - "six.txt dir/b/", - "five.txt dir/b/", - "four.txt dir/a/", - "three.txt dir/a/", - "two.txt dir/a/", - "one.txt dir/a/", - "dir ", - "a dir/", - "four.txt dir/a/", - "one.txt dir/a/", - "three.txt dir/a/", - "two.txt dir/a/", - "b dir/", - "eight.txt dir/b/", - "five.txt dir/b/", - "seven.txt dir/b/", - "six.txt dir/b/", - "editor dir/" - ] - ); - }); - - // Select and confirm "File" - editor.update_in(&mut cx, |editor, window, cx| { - assert!(editor.has_visible_completions_menu()); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); - editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); - }); - - cx.run_until_parked(); - - editor.update(&mut cx, |editor, cx| { - assert_eq!(editor.text(cx), "Lorem [@four.txt](@file:dir/a/four.txt) "); - }); - } - - fn current_completion_labels(editor: &Editor) -> Vec { - let completions = editor.current_completions().expect("Missing completions"); - completions - .into_iter() - .map(|completion| completion.label.text.to_string()) - .collect::>() - } - - pub(crate) fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let store = SettingsStore::test(cx); - cx.set_global(store); - theme::init(theme::LoadThemes::JustBase, cx); - client::init_settings(cx); - language::init(cx); - Project::init_settings(cx); - workspace::init_settings(cx); - editor::init_settings(cx); - }); - } } diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs new file mode 100644 index 0000000000..e99d1f6323 --- /dev/null +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -0,0 +1,444 @@ +use std::ops::Range; + +use acp_thread::{AcpThread, AgentThreadEntry}; +use agent::{TextThreadStore, ThreadStore}; +use collections::HashMap; +use editor::{Editor, EditorMode, MinimapVisibility}; +use gpui::{ + AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement, + WeakEntity, Window, +}; +use language::language_settings::SoftWrap; +use project::Project; +use settings::Settings as _; +use terminal_view::TerminalView; +use theme::ThemeSettings; +use ui::{Context, TextSize}; +use workspace::Workspace; + +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; + +pub struct EntryViewState { + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + entries: Vec, +} + +impl EntryViewState { + pub fn new( + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + ) -> Self { + Self { + workspace, + project, + thread_store, + text_thread_store, + entries: Vec::new(), + } + } + + pub fn entry(&self, index: usize) -> Option<&Entry> { + self.entries.get(index) + } + + pub fn sync_entry( + &mut self, + index: usize, + thread: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread_entry) = thread.read(cx).entries().get(index) else { + return; + }; + + match thread_entry { + AgentThreadEntry::UserMessage(message) => { + let has_id = message.id.is_some(); + let chunks = message.chunks.clone(); + let message_editor = cx.new(|cx| { + let mut editor = MessageEditor::new( + self.workspace.clone(), + self.project.clone(), + self.thread_store.clone(), + self.text_thread_store.clone(), + editor::EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ); + if !has_id { + editor.set_read_only(true, cx); + } + editor.set_message(chunks, window, cx); + editor + }); + cx.subscribe(&message_editor, move |_, editor, event, cx| { + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::MessageEditorEvent(editor, *event), + }) + }) + .detach(); + self.set_entry(index, Entry::UserMessage(message_editor)); + } + AgentThreadEntry::ToolCall(tool_call) => { + let terminals = tool_call.terminals().cloned().collect::>(); + let diffs = tool_call.diffs().cloned().collect::>(); + + let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) { + views + } else { + self.set_entry(index, Entry::empty()); + let Some(Entry::Content(views)) = self.entries.get_mut(index) else { + unreachable!() + }; + views + }; + + for terminal in terminals { + views.entry(terminal.entity_id()).or_insert_with(|| { + create_terminal( + self.workspace.clone(), + self.project.clone(), + terminal.clone(), + window, + cx, + ) + .into_any() + }); + } + + for diff in diffs { + views + .entry(diff.entity_id()) + .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any()); + } + } + AgentThreadEntry::AssistantMessage(_) => { + if index == self.entries.len() { + self.entries.push(Entry::empty()) + } + } + }; + } + + fn set_entry(&mut self, index: usize, entry: Entry) { + if index == self.entries.len() { + self.entries.push(entry); + } else { + self.entries[index] = entry; + } + } + + pub fn remove(&mut self, range: Range) { + self.entries.drain(range); + } + + pub fn settings_changed(&mut self, cx: &mut App) { + for entry in self.entries.iter() { + match entry { + Entry::UserMessage { .. } => {} + Entry::Content(response_views) => { + for view in response_views.values() { + if let Ok(diff_editor) = view.clone().downcast::() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor.set_text_style_refinement( + diff_editor_text_style_refinement(cx), + ); + cx.notify(); + }) + } + } + } + } + } + } +} + +impl EventEmitter for EntryViewState {} + +pub struct EntryViewEvent { + pub entry_index: usize, + pub view_event: ViewEvent, +} + +pub enum ViewEvent { + MessageEditorEvent(Entity, MessageEditorEvent), +} + +pub enum Entry { + UserMessage(Entity), + Content(HashMap), +} + +impl Entry { + pub fn message_editor(&self) -> Option<&Entity> { + match self { + Self::UserMessage(editor) => Some(editor), + Entry::Content(_) => None, + } + } + + pub fn editor_for_diff(&self, diff: &Entity) -> Option> { + self.content_map()? + .get(&diff.entity_id()) + .cloned() + .map(|entity| entity.downcast::().unwrap()) + } + + pub fn terminal( + &self, + terminal: &Entity, + ) -> Option> { + self.content_map()? + .get(&terminal.entity_id()) + .cloned() + .map(|entity| entity.downcast::().unwrap()) + } + + fn content_map(&self) -> Option<&HashMap> { + match self { + Self::Content(map) => Some(map), + _ => None, + } + } + + fn empty() -> Self { + Self::Content(HashMap::default()) + } + + #[cfg(test)] + pub fn has_content(&self) -> bool { + match self { + Self::Content(map) => !map.is_empty(), + Self::UserMessage(_) => false, + } + } +} + +fn create_terminal( + workspace: WeakEntity, + project: Entity, + terminal: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut view = TerminalView::new( + terminal.read(cx).inner().clone(), + workspace.clone(), + None, + project.downgrade(), + window, + cx, + ); + view.set_embedded_mode(Some(1000), cx); + view + }) +} + +fn create_editor_diff( + diff: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + diff.read(cx).multibuffer().clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + editor + }) +} + +fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { + TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + } +} + +#[cfg(test)] +mod tests { + use std::{path::Path, rc::Rc}; + + use acp_thread::{AgentConnection, StubAgentConnection}; + use agent::{TextThreadStore, ThreadStore}; + use agent_client_protocol as acp; + use agent_settings::AgentSettings; + use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; + use editor::{EditorSettings, RowInfo}; + use fs::FakeFs; + use gpui::{AppContext as _, SemanticVersion, TestAppContext}; + + use crate::acp::entry_view_state::EntryViewState; + use multi_buffer::MultiBufferRow; + use pretty_assertions::assert_matches; + use project::Project; + use serde_json::json; + use settings::{Settings as _, SettingsStore}; + use theme::ThemeSettings; + use util::path; + use workspace::Workspace; + + #[gpui::test] + async fn test_diff_sync(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "hello.txt": "hi world" + }), + ) + .await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let tool_call = acp::ToolCall { + id: acp::ToolCallId("tool".into()), + title: "Tool call".into(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::InProgress, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/hello.txt".into(), + old_text: Some("hi world".into()), + new_text: "hello world".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + }; + let connection = Rc::new(StubAgentConnection::new()); + let thread = cx + .update(|_, cx| { + connection + .clone() + .new_thread(project.clone(), Path::new(path!("/project")), cx) + }) + .await + .unwrap(); + let session_id = thread.update(cx, |thread, _| thread.session_id().clone()); + + cx.update(|_, cx| { + connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) + }); + + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let view_state = cx.new(|_cx| { + EntryViewState::new( + workspace.downgrade(), + project.clone(), + thread_store, + text_thread_store, + ) + }); + + view_state.update_in(cx, |view_state, window, cx| { + view_state.sync_entry(0, &thread, window, cx) + }); + + let diff = thread.read_with(cx, |thread, _cx| { + thread + .entries() + .get(0) + .unwrap() + .diffs() + .next() + .unwrap() + .clone() + }); + + cx.run_until_parked(); + + let diff_editor = view_state.read_with(cx, |view_state, _cx| { + view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap() + }); + assert_eq!( + diff_editor.read_with(cx, |editor, cx| editor.text(cx)), + "hi world\nhello world" + ); + let row_infos = diff_editor.read_with(cx, |editor, cx| { + let multibuffer = editor.buffer().read(cx); + multibuffer + .snapshot(cx) + .row_infos(MultiBufferRow(0)) + .collect::>() + }); + assert_matches!( + row_infos.as_slice(), + [ + RowInfo { + multibuffer_row: Some(MultiBufferRow(0)), + diff_status: Some(DiffHunkStatus { + kind: DiffHunkStatusKind::Deleted, + .. + }), + .. + }, + RowInfo { + multibuffer_row: Some(MultiBufferRow(1)), + diff_status: Some(DiffHunkStatus { + kind: DiffHunkStatusKind::Added, + .. + }), + .. + } + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs new file mode 100644 index 0000000000..f6fee3b87e --- /dev/null +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -0,0 +1,1671 @@ +use crate::{ + acp::completion_provider::ContextPickerCompletionProvider, + context_picker::fetch_context_picker::fetch_url_content, +}; +use acp_thread::{MentionUri, selection_name}; +use agent::{TextThreadStore, ThreadId, ThreadStore}; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, anyhow}; +use collections::{HashMap, HashSet}; +use editor::{ + Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, + EditorMode, EditorStyle, ExcerptId, FoldPlaceholder, MultiBuffer, ToOffset, + actions::Paste, + display_map::{Crease, CreaseId, FoldId}, +}; +use futures::{ + FutureExt as _, TryFutureExt as _, + future::{Shared, try_join_all}, +}; +use gpui::{ + AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, Image, + ImageFormat, Img, Task, TextStyle, WeakEntity, +}; +use language::{Buffer, Language}; +use language_model::LanguageModelImage; +use project::{CompletionIntent, Project}; +use rope::Point; +use settings::Settings; +use std::{ + ffi::OsStr, + fmt::Write, + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use text::OffsetRangeExt; +use theme::ThemeSettings; +use ui::{ + ActiveTheme, AnyElement, App, ButtonCommon, ButtonLike, ButtonStyle, Color, Icon, IconName, + IconSize, InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ParentElement, + Render, SelectableButton, SharedString, Styled, TextSize, TintColor, Toggleable, Window, div, + h_flex, +}; +use url::Url; +use util::ResultExt; +use workspace::{Workspace, notifications::NotifyResultExt as _}; +use zed_actions::agent::Chat; + +pub struct MessageEditor { + mention_set: MentionSet, + editor: Entity, + project: Entity, + workspace: WeakEntity, + thread_store: Entity, + text_thread_store: Entity, +} + +#[derive(Clone, Copy)] +pub enum MessageEditorEvent { + Send, + Cancel, + Focus, +} + +impl EventEmitter for MessageEditor {} + +impl MessageEditor { + pub fn new( + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + mode: EditorMode, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let language = Language::new( + language::LanguageConfig { + completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), + ..Default::default() + }, + None, + ); + let completion_provider = ContextPickerCompletionProvider::new( + workspace.clone(), + thread_store.downgrade(), + text_thread_store.downgrade(), + cx.weak_entity(), + ); + let mention_set = MentionSet::default(); + let editor = cx.new(|cx| { + let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + + let mut editor = Editor::new(mode, buffer, None, window, cx); + editor.set_placeholder_text("Message the agent - @ to include files", cx); + editor.set_show_indent_guides(false, cx); + editor.set_soft_wrap(); + editor.set_use_modal_editing(true); + editor.set_completion_provider(Some(Rc::new(completion_provider))); + editor.set_context_menu_options(ContextMenuOptions { + min_entries_visible: 12, + max_entries_visible: 12, + placement: Some(ContextMenuPlacement::Above), + }); + editor + }); + + cx.on_focus(&editor.focus_handle(cx), window, |_, _, cx| { + cx.emit(MessageEditorEvent::Focus) + }) + .detach(); + + Self { + editor, + project, + mention_set, + thread_store, + text_thread_store, + workspace, + } + } + + #[cfg(test)] + pub(crate) fn editor(&self) -> &Entity { + &self.editor + } + + #[cfg(test)] + pub(crate) fn mention_set(&mut self) -> &mut MentionSet { + &mut self.mention_set + } + + pub fn is_empty(&self, cx: &App) -> bool { + self.editor.read(cx).is_empty(cx) + } + + pub fn mentioned_path_and_threads(&self) -> (HashSet, HashSet) { + let mut excluded_paths = HashSet::default(); + let mut excluded_threads = HashSet::default(); + + for uri in self.mention_set.uri_by_crease_id.values() { + match uri { + MentionUri::File { abs_path, .. } => { + excluded_paths.insert(abs_path.clone()); + } + MentionUri::Thread { id, .. } => { + excluded_threads.insert(id.clone()); + } + _ => {} + } + } + + (excluded_paths, excluded_threads) + } + + pub fn confirm_completion( + &mut self, + crease_text: SharedString, + start: text::Anchor, + content_len: usize, + mention_uri: MentionUri, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self + .editor + .update(cx, |editor, cx| editor.snapshot(window, cx)); + let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { + return; + }; + let Some(anchor) = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, start) + else { + return; + }; + + let Some(crease_id) = crate::context_picker::insert_crease_for_mention( + *excerpt_id, + start, + content_len, + crease_text.clone(), + mention_uri.icon_path(cx), + self.editor.clone(), + window, + cx, + ) else { + return; + }; + + match mention_uri { + MentionUri::Fetch { url } => { + self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx); + } + MentionUri::File { + abs_path, + is_directory, + } => { + self.confirm_mention_for_file( + crease_id, + anchor, + abs_path, + is_directory, + window, + cx, + ); + } + MentionUri::Symbol { .. } + | MentionUri::Thread { .. } + | MentionUri::TextThread { .. } + | MentionUri::Rule { .. } + | MentionUri::Selection { .. } => { + self.mention_set.insert_uri(crease_id, mention_uri.clone()); + } + } + } + + fn confirm_mention_for_file( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + abs_path: PathBuf, + is_directory: bool, + window: &mut Window, + cx: &mut Context, + ) { + let extension = abs_path + .extension() + .and_then(OsStr::to_str) + .unwrap_or_default(); + + if Img::extensions().contains(&extension) && !extension.contains("svg") { + let project = self.project.clone(); + let Some(project_path) = project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return; + }; + let image = cx.spawn(async move |_, cx| { + let image = project + .update(cx, |project, cx| project.open_image(project_path, cx))? + .await?; + image.read_with(cx, |image, _cx| image.image.clone()) + }); + self.confirm_mention_for_image(crease_id, anchor, Some(abs_path), image, window, cx); + } else { + self.mention_set.insert_uri( + crease_id, + MentionUri::File { + abs_path, + is_directory, + }, + ); + } + } + + fn confirm_mention_for_fetch( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + url: url::Url, + window: &mut Window, + cx: &mut Context, + ) { + let Some(http_client) = self + .workspace + .update(cx, |workspace, _cx| workspace.client().http_client()) + .ok() + else { + return; + }; + + let url_string = url.to_string(); + let fetch = cx + .background_executor() + .spawn(async move { + fetch_url_content(http_client, url_string) + .map_err(|e| e.to_string()) + .await + }) + .shared(); + self.mention_set + .add_fetch_result(url.clone(), fetch.clone()); + + cx.spawn_in(window, async move |this, cx| { + let fetch = fetch.await.notify_async_err(cx); + this.update(cx, |this, cx| { + let mention_uri = MentionUri::Fetch { url }; + if fetch.is_some() { + this.mention_set.insert_uri(crease_id, mention_uri.clone()); + } else { + // Remove crease if we failed to fetch + this.editor.update(cx, |editor, cx| { + editor.display_map.update(cx, |display_map, cx| { + display_map.unfold_intersecting(vec![anchor..anchor], true, cx); + }); + editor.remove_creases([crease_id], cx); + }); + } + }) + .ok(); + }) + .detach(); + } + + pub fn confirm_mention_for_selection( + &mut self, + source_range: Range, + selections: Vec<(Entity, Range, Range)>, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self.editor.read(cx).buffer().read(cx).snapshot(cx); + let Some((&excerpt_id, _, _)) = snapshot.as_singleton() else { + return; + }; + let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, source_range.start) else { + return; + }; + + let offset = start.to_offset(&snapshot); + + for (buffer, selection_range, range_to_fold) in selections { + let range = snapshot.anchor_after(offset + range_to_fold.start) + ..snapshot.anchor_after(offset + range_to_fold.end); + + let path = buffer + .read(cx) + .file() + .map_or(PathBuf::from("untitled"), |file| file.path().to_path_buf()); + let snapshot = buffer.read(cx).snapshot(); + + let point_range = selection_range.to_point(&snapshot); + let line_range = point_range.start.row..point_range.end.row; + + let uri = MentionUri::Selection { + path: path.clone(), + line_range: line_range.clone(), + }; + let crease = crate::context_picker::crease_for_mention( + selection_name(&path, &line_range).into(), + uri.icon_path(cx), + range, + self.editor.downgrade(), + ); + + let crease_id = self.editor.update(cx, |editor, cx| { + let crease_ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + crease_ids.first().copied().unwrap() + }); + + self.mention_set + .insert_uri(crease_id, MentionUri::Selection { path, line_range }); + } + } + + pub fn contents( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Task>> { + let contents = self.mention_set.contents( + self.project.clone(), + self.thread_store.clone(), + self.text_thread_store.clone(), + window, + cx, + ); + let editor = self.editor.clone(); + + cx.spawn(async move |_, cx| { + let contents = contents.await?; + + editor.update(cx, |editor, cx| { + let mut ix = 0; + let mut chunks: Vec = Vec::new(); + let text = editor.text(cx); + editor.display_map.update(cx, |map, cx| { + let snapshot = map.snapshot(cx); + for (crease_id, crease) in snapshot.crease_snapshot.creases() { + // Skip creases that have been edited out of the message buffer. + if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { + continue; + } + + let Some(mention) = contents.get(&crease_id) else { + continue; + }; + + let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); + if crease_range.start > ix { + chunks.push(text[ix..crease_range.start].into()); + } + let chunk = match mention { + Mention::Text { uri, content } => { + acp::ContentBlock::Resource(acp::EmbeddedResource { + annotations: None, + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + mime_type: None, + text: content.clone(), + uri: uri.to_uri().to_string(), + }, + ), + }) + } + Mention::Image(mention_image) => { + acp::ContentBlock::Image(acp::ImageContent { + annotations: None, + data: mention_image.data.to_string(), + mime_type: mention_image.format.mime_type().into(), + uri: mention_image + .abs_path + .as_ref() + .map(|path| format!("file://{}", path.display())), + }) + } + }; + chunks.push(chunk); + ix = crease_range.end; + } + + if ix < text.len() { + let last_chunk = text[ix..].trim_end(); + if !last_chunk.is_empty() { + chunks.push(last_chunk.into()); + } + } + }); + + chunks + }) + }) + } + + pub fn clear(&mut self, window: &mut Window, cx: &mut Context) { + self.editor.update(cx, |editor, cx| { + editor.clear(window, cx); + editor.remove_creases(self.mention_set.drain(), cx) + }); + } + + fn send(&mut self, _: &Chat, _: &mut Window, cx: &mut Context) { + cx.emit(MessageEditorEvent::Send) + } + + fn cancel(&mut self, _: &editor::actions::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(MessageEditorEvent::Cancel) + } + + fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context) { + let images = cx + .read_from_clipboard() + .map(|item| { + item.into_entries() + .filter_map(|entry| { + if let ClipboardEntry::Image(image) = entry { + Some(image) + } else { + None + } + }) + .collect::>() + }) + .unwrap_or_default(); + + if images.is_empty() { + return; + } + cx.stop_propagation(); + + let replacement_text = "image"; + for image in images { + let (excerpt_id, text_anchor, multibuffer_anchor) = + self.editor.update(cx, |message_editor, cx| { + let snapshot = message_editor.snapshot(window, cx); + let (excerpt_id, _, buffer_snapshot) = + snapshot.buffer_snapshot.as_singleton().unwrap(); + + let text_anchor = buffer_snapshot.anchor_before(buffer_snapshot.len()); + let multibuffer_anchor = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, text_anchor); + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + format!("{replacement_text} "), + )], + cx, + ); + (*excerpt_id, text_anchor, multibuffer_anchor) + }); + + let content_len = replacement_text.len(); + let Some(anchor) = multibuffer_anchor else { + return; + }; + let Some(crease_id) = insert_crease_for_image( + excerpt_id, + text_anchor, + content_len, + None.clone(), + self.editor.clone(), + window, + cx, + ) else { + return; + }; + self.confirm_mention_for_image( + crease_id, + anchor, + None, + Task::ready(Ok(Arc::new(image))), + window, + cx, + ); + } + } + + pub fn insert_dragged_files( + &self, + paths: Vec, + window: &mut Window, + cx: &mut Context, + ) { + let buffer = self.editor.read(cx).buffer().clone(); + let Some(buffer) = buffer.read(cx).as_singleton() else { + return; + }; + for path in paths { + let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else { + continue; + }; + let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else { + continue; + }; + + let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len())); + let path_prefix = abs_path + .file_name() + .unwrap_or(path.path.as_os_str()) + .display() + .to_string(); + let Some(completion) = ContextPickerCompletionProvider::completion_for_path( + path, + &path_prefix, + false, + entry.is_dir(), + anchor..anchor, + cx.weak_entity(), + self.project.clone(), + cx, + ) else { + continue; + }; + + self.editor.update(cx, |message_editor, cx| { + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + completion.new_text, + )], + cx, + ); + }); + if let Some(confirm) = completion.confirm.clone() { + confirm(CompletionIntent::Complete, window, cx); + } + } + } + + pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context) { + self.editor.update(cx, |message_editor, cx| { + message_editor.set_read_only(read_only); + cx.notify() + }) + } + + fn confirm_mention_for_image( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + abs_path: Option, + image: Task>>, + window: &mut Window, + cx: &mut Context, + ) { + let editor = self.editor.clone(); + let task = cx + .spawn_in(window, async move |this, cx| { + let image = image.await.map_err(|e| e.to_string())?; + let format = image.format; + let image = cx + .update(|_, cx| LanguageModelImage::from_image(image, cx)) + .map_err(|e| e.to_string())? + .await; + if let Some(image) = image { + if let Some(abs_path) = abs_path.clone() { + this.update(cx, |this, _cx| { + this.mention_set.insert_uri( + crease_id, + MentionUri::File { + abs_path, + is_directory: false, + }, + ); + }) + .map_err(|e| e.to_string())?; + } + Ok(MentionImage { + abs_path, + data: image.source, + format, + }) + } else { + editor + .update(cx, |editor, cx| { + editor.display_map.update(cx, |display_map, cx| { + display_map.unfold_intersecting(vec![anchor..anchor], true, cx); + }); + editor.remove_creases([crease_id], cx); + }) + .ok(); + Err("Failed to convert image".to_string()) + } + }) + .shared(); + + cx.spawn_in(window, { + let task = task.clone(); + async move |_, cx| task.clone().await.notify_async_err(cx) + }) + .detach(); + + self.mention_set.insert_image(crease_id, task); + } + + pub fn set_mode(&mut self, mode: EditorMode, cx: &mut Context) { + self.editor.update(cx, |editor, cx| { + editor.set_mode(mode); + cx.notify() + }); + } + + pub fn set_message( + &mut self, + message: Vec, + window: &mut Window, + cx: &mut Context, + ) { + self.clear(window, cx); + + let mut text = String::new(); + let mut mentions = Vec::new(); + let mut images = Vec::new(); + + for chunk in message { + match chunk { + acp::ContentBlock::Text(text_content) => { + text.push_str(&text_content.text); + } + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: acp::EmbeddedResourceResource::TextResourceContents(resource), + .. + }) => { + if let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() { + let start = text.len(); + write!(&mut text, "{}", mention_uri.as_link()).ok(); + let end = text.len(); + mentions.push((start..end, mention_uri)); + } + } + acp::ContentBlock::Image(content) => { + let start = text.len(); + text.push_str("image"); + let end = text.len(); + images.push((start..end, content)); + } + acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) + | acp::ContentBlock::ResourceLink(_) => {} + } + } + + let snapshot = self.editor.update(cx, |editor, cx| { + editor.set_text(text, window, cx); + editor.buffer().read(cx).snapshot(cx) + }); + + for (range, mention_uri) in mentions { + let anchor = snapshot.anchor_before(range.start); + let crease_id = crate::context_picker::insert_crease_for_mention( + anchor.excerpt_id, + anchor.text_anchor, + range.end - range.start, + mention_uri.name().into(), + mention_uri.icon_path(cx), + self.editor.clone(), + window, + cx, + ); + + if let Some(crease_id) = crease_id { + self.mention_set.insert_uri(crease_id, mention_uri); + } + } + for (range, content) in images { + let Some(format) = ImageFormat::from_mime_type(&content.mime_type) else { + continue; + }; + let anchor = snapshot.anchor_before(range.start); + let abs_path = content + .uri + .as_ref() + .and_then(|uri| uri.strip_prefix("file://").map(|s| Path::new(s).into())); + + let name = content + .uri + .as_ref() + .and_then(|uri| { + uri.strip_prefix("file://") + .and_then(|path| Path::new(path).file_name()) + }) + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or("Image".to_owned()); + let crease_id = crate::context_picker::insert_crease_for_mention( + anchor.excerpt_id, + anchor.text_anchor, + range.end - range.start, + name.into(), + IconName::Image.path().into(), + self.editor.clone(), + window, + cx, + ); + let data: SharedString = content.data.to_string().into(); + + if let Some(crease_id) = crease_id { + self.mention_set.insert_image( + crease_id, + Task::ready(Ok(MentionImage { + abs_path, + data, + format, + })) + .shared(), + ); + } + } + cx.notify(); + } + + #[cfg(test)] + pub fn set_text(&mut self, text: &str, window: &mut Window, cx: &mut Context) { + self.editor.update(cx, |editor, cx| { + editor.set_text(text, window, cx); + }); + } + + #[cfg(test)] + pub fn text(&self, cx: &App) -> String { + self.editor.read(cx).text(cx) + } +} + +impl Focusable for MessageEditor { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.editor.focus_handle(cx) + } +} + +impl Render for MessageEditor { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div() + .key_context("MessageEditor") + .on_action(cx.listener(Self::send)) + .on_action(cx.listener(Self::cancel)) + .capture_action(cx.listener(Self::paste)) + .flex_1() + .child({ + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small + .rems(cx) + .to_pixels(settings.agent_font_size(cx)); + let line_height = settings.buffer_line_height.value() * font_size; + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + + EditorElement::new( + &self.editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ) + }) + } +} + +pub(crate) fn insert_crease_for_image( + excerpt_id: ExcerptId, + anchor: text::Anchor, + content_len: usize, + abs_path: Option>, + editor: Entity, + window: &mut Window, + cx: &mut App, +) -> Option { + let crease_label = abs_path + .as_ref() + .and_then(|path| path.file_name()) + .map(|name| name.to_string_lossy().to_string().into()) + .unwrap_or(SharedString::from("Image")); + + editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + + let start = snapshot.anchor_in_excerpt(excerpt_id, anchor)?; + + let start = start.bias_right(&snapshot); + let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len); + + let placeholder = FoldPlaceholder { + render: render_image_fold_icon_button(crease_label, cx.weak_entity()), + merge_adjacent: false, + ..Default::default() + }; + + let crease = Crease::Inline { + range: start..end, + placeholder, + render_toggle: None, + render_trailer: None, + metadata: None, + }; + + let ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + + Some(ids[0]) + }) +} + +fn render_image_fold_icon_button( + label: SharedString, + editor: WeakEntity, +) -> Arc, &mut App) -> AnyElement> { + Arc::new({ + move |fold_id, fold_range, cx| { + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); + + ButtonLike::new(fold_id) + .style(ButtonStyle::Filled) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + .toggle_state(is_in_text_selection) + .child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Image) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new(label.clone()) + .size(LabelSize::Small) + .buffer_font(cx) + .single_line(), + ), + ) + .into_any_element() + } + }) +} + +#[derive(Debug, Eq, PartialEq)] +pub enum Mention { + Text { uri: MentionUri, content: String }, + Image(MentionImage), +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MentionImage { + pub abs_path: Option, + pub data: SharedString, + pub format: ImageFormat, +} + +#[derive(Default)] +pub struct MentionSet { + pub(crate) uri_by_crease_id: HashMap, + fetch_results: HashMap>>>, + images: HashMap>>>, +} + +impl MentionSet { + pub fn insert_uri(&mut self, crease_id: CreaseId, uri: MentionUri) { + self.uri_by_crease_id.insert(crease_id, uri); + } + + pub fn add_fetch_result(&mut self, url: Url, content: Shared>>) { + self.fetch_results.insert(url, content); + } + + pub fn insert_image( + &mut self, + crease_id: CreaseId, + task: Shared>>, + ) { + self.images.insert(crease_id, task); + } + + pub fn drain(&mut self) -> impl Iterator { + self.fetch_results.clear(); + self.uri_by_crease_id + .drain() + .map(|(id, _)| id) + .chain(self.images.drain().map(|(id, _)| id)) + } + + pub fn contents( + &self, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + window: &mut Window, + cx: &mut App, + ) -> Task>> { + let mut processed_image_creases = HashSet::default(); + + let mut contents = self + .uri_by_crease_id + .iter() + .map(|(&crease_id, uri)| { + match uri { + MentionUri::File { abs_path, .. } => { + // TODO directories + let uri = uri.clone(); + let abs_path = abs_path.to_path_buf(); + + if let Some(task) = self.images.get(&crease_id).cloned() { + processed_image_creases.insert(crease_id); + return cx.spawn(async move |_| { + let image = task.await.map_err(|e| anyhow!("{e}"))?; + anyhow::Ok((crease_id, Mention::Image(image))) + }); + } + + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(abs_path, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + anyhow::Ok((crease_id, Mention::Text { uri, content })) + }) + } + MentionUri::Symbol { + path, line_range, .. + } + | MentionUri::Selection { + path, line_range, .. + } => { + let uri = uri.clone(); + let path_buf = path.clone(); + let line_range = line_range.clone(); + + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(&path_buf, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| { + buffer + .text_for_range( + Point::new(line_range.start, 0) + ..Point::new( + line_range.end, + buffer.line_len(line_range.end), + ), + ) + .collect() + })?; + + anyhow::Ok((crease_id, Mention::Text { uri, content })) + }) + } + MentionUri::Thread { id: thread_id, .. } => { + let open_task = thread_store.update(cx, |thread_store, cx| { + thread_store.open_thread(&thread_id, window, cx) + }); + + let uri = uri.clone(); + cx.spawn(async move |cx| { + let thread = open_task.await?; + let content = thread.read_with(cx, |thread, _cx| { + thread.latest_detailed_summary_or_text().to_string() + })?; + + anyhow::Ok((crease_id, Mention::Text { uri, content })) + }) + } + MentionUri::TextThread { path, .. } => { + let context = text_thread_store.update(cx, |text_thread_store, cx| { + text_thread_store.open_local_context(path.as_path().into(), cx) + }); + let uri = uri.clone(); + cx.spawn(async move |cx| { + let context = context.await?; + let xml = context.update(cx, |context, cx| context.to_xml(cx))?; + anyhow::Ok((crease_id, Mention::Text { uri, content: xml })) + }) + } + MentionUri::Rule { id: prompt_id, .. } => { + let Some(prompt_store) = thread_store.read(cx).prompt_store().clone() + else { + return Task::ready(Err(anyhow!("missing prompt store"))); + }; + let text_task = prompt_store.read(cx).load(*prompt_id, cx); + let uri = uri.clone(); + cx.spawn(async move |_| { + // TODO: report load errors instead of just logging + let text = text_task.await?; + anyhow::Ok((crease_id, Mention::Text { uri, content: text })) + }) + } + MentionUri::Fetch { url } => { + let Some(content) = self.fetch_results.get(&url).cloned() else { + return Task::ready(Err(anyhow!("missing fetch result"))); + }; + let uri = uri.clone(); + cx.spawn(async move |_| { + Ok(( + crease_id, + Mention::Text { + uri, + content: content.await.map_err(|e| anyhow::anyhow!("{e}"))?, + }, + )) + }) + } + } + }) + .collect::>(); + + // Handle images that didn't have a mention URI (because they were added by the paste handler). + contents.extend(self.images.iter().filter_map(|(crease_id, image)| { + if processed_image_creases.contains(crease_id) { + return None; + } + let crease_id = *crease_id; + let image = image.clone(); + Some(cx.spawn(async move |_| { + Ok(( + crease_id, + Mention::Image(image.await.map_err(|e| anyhow::anyhow!("{e}"))?), + )) + })) + })); + + cx.spawn(async move |_cx| { + let contents = try_join_all(contents).await?.into_iter().collect(); + anyhow::Ok(contents) + }) + } +} + +#[cfg(test)] +mod tests { + use std::{ops::Range, path::Path, sync::Arc}; + + use agent::{TextThreadStore, ThreadStore}; + use agent_client_protocol as acp; + use editor::{AnchorRangeExt as _, Editor, EditorMode}; + use fs::FakeFs; + use futures::StreamExt as _; + use gpui::{ + AppContext, Entity, EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext, + }; + use lsp::{CompletionContext, CompletionTriggerKind}; + use project::{CompletionIntent, Project, ProjectPath}; + use serde_json::json; + use text::Point; + use ui::{App, Context, IntoElement, Render, SharedString, Window}; + use util::path; + use workspace::{AppState, Item, Workspace}; + + use crate::acp::{ + message_editor::{Mention, MessageEditor}, + thread_view::tests::init_test, + }; + + #[gpui::test] + async fn test_at_mention_removal(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({"file": ""})).await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let message_editor = cx.update(|window, cx| { + cx.new(|cx| { + MessageEditor::new( + workspace.downgrade(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ) + }) + }); + let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone()); + + cx.run_until_parked(); + + let excerpt_id = editor.update(cx, |editor, cx| { + editor + .buffer() + .read(cx) + .excerpt_ids() + .into_iter() + .next() + .unwrap() + }); + let completions = editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello @file ", window, cx); + let buffer = editor.buffer().read(cx).as_singleton().unwrap(); + let completion_provider = editor.completion_provider().unwrap(); + completion_provider.completions( + excerpt_id, + &buffer, + text::Anchor::MAX, + CompletionContext { + trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER, + trigger_character: Some("@".into()), + }, + window, + cx, + ) + }); + let [_, completion]: [_; 2] = completions + .await + .unwrap() + .into_iter() + .flat_map(|response| response.completions) + .collect::>() + .try_into() + .unwrap(); + + editor.update_in(cx, |editor, window, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let start = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.start) + .unwrap(); + let end = snapshot + .anchor_in_excerpt(excerpt_id, completion.replace_range.end) + .unwrap(); + editor.edit([(start..end, completion.new_text)], cx); + (completion.confirm.unwrap())(CompletionIntent::Complete, window, cx); + }); + + cx.run_until_parked(); + + // Backspace over the inserted crease (and the following space). + editor.update_in(cx, |editor, window, cx| { + editor.backspace(&Default::default(), window, cx); + editor.backspace(&Default::default(), window, cx); + }); + + let content = message_editor + .update_in(cx, |message_editor, window, cx| { + message_editor.contents(window, cx) + }) + .await + .unwrap(); + + // We don't send a resource link for the deleted crease. + pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]); + } + + struct MessageEditorItem(Entity); + + impl Item for MessageEditorItem { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for MessageEditorItem {} + + impl Focusable for MessageEditorItem { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx).clone() + } + } + + impl Render for MessageEditorItem { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + + #[gpui::test] + async fn test_context_completion_provider(cx: &mut TestAppContext) { + init_test(cx); + + let app_state = cx.update(AppState::test); + + cx.update(|cx| { + language::init(cx); + editor::init(cx); + workspace::init(app_state.clone(), cx); + Project::init_settings(cx); + }); + + app_state + .fs + .as_fake() + .insert_tree( + path!("/dir"), + json!({ + "editor": "", + "a": { + "one.txt": "1", + "two.txt": "2", + "three.txt": "3", + "four.txt": "4" + }, + "b": { + "five.txt": "5", + "six.txt": "6", + "seven.txt": "7", + "eight.txt": "8", + } + }), + ) + .await; + + let project = Project::test(app_state.fs.clone(), [path!("/dir").as_ref()], cx).await; + let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let workspace = window.root(cx).unwrap(); + + let worktree = project.update(cx, |project, cx| { + let mut worktrees = project.worktrees(cx).collect::>(); + assert_eq!(worktrees.len(), 1); + worktrees.pop().unwrap() + }); + let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id()); + + let mut cx = VisualTestContext::from_window(*window, cx); + + let paths = vec![ + path!("a/one.txt"), + path!("a/two.txt"), + path!("a/three.txt"), + path!("a/four.txt"), + path!("b/five.txt"), + path!("b/six.txt"), + path!("b/seven.txt"), + path!("b/eight.txt"), + ]; + + let mut opened_editors = Vec::new(); + for path in paths { + let buffer = workspace + .update_in(&mut cx, |workspace, window, cx| { + workspace.open_path( + ProjectPath { + worktree_id, + path: Path::new(path).into(), + }, + None, + false, + window, + cx, + ) + }) + .await + .unwrap(); + opened_editors.push(buffer); + } + + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| { + let workspace_handle = cx.weak_entity(); + let message_editor = cx.new(|cx| { + MessageEditor::new( + workspace_handle, + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + EditorMode::AutoHeight { + max_lines: None, + min_lines: 1, + }, + window, + cx, + ) + }); + workspace.active_pane().update(cx, |pane, cx| { + pane.add_item( + Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))), + true, + true, + None, + window, + cx, + ); + }); + message_editor.read(cx).focus_handle(cx).focus(window); + let editor = message_editor.read(cx).editor().clone(); + (message_editor, editor) + }); + + cx.simulate_input("Lorem "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem "); + assert!(!editor.has_visible_completions_menu()); + }); + + cx.simulate_input("@"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @"); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + current_completion_labels(editor), + &[ + "eight.txt dir/b/", + "seven.txt dir/b/", + "six.txt dir/b/", + "five.txt dir/b/", + "Files & Directories", + "Symbols", + "Threads", + "Fetch" + ] + ); + }); + + // Select and confirm "File" + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.context_menu_next(&editor::actions::ContextMenuNext, window, cx); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @file "); + assert!(editor.has_visible_completions_menu()); + }); + + cx.simulate_input("one"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem @file one"); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), vec!["one.txt dir/a/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + assert!(editor.has_visible_completions_menu()); + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) "); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + window, + cx, + ) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + pretty_assertions::assert_eq!( + contents, + [Mention::Text { + content: "1".into(), + uri: "file:///dir/a/one.txt".parse().unwrap() + }] + ); + + cx.simulate_input(" "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!(editor.text(cx), "Lorem [@one.txt](file:///dir/a/one.txt) "); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + cx.simulate_input("Ipsum "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum ", + ); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + cx.simulate_input("@file "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum @file ", + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![Point::new(0, 6)..Point::new(0, 39)] + ); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + cx.run_until_parked(); + + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + window, + cx, + ) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + assert_eq!(contents.len(), 2); + pretty_assertions::assert_eq!( + contents[1], + Mention::Text { + content: "8".to_string(), + uri: "file:///dir/b/eight.txt".parse().unwrap(), + } + ); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) " + ); + assert!(!editor.has_visible_completions_menu()); + assert_eq!( + fold_ranges(editor, cx), + vec![ + Point::new(0, 6)..Point::new(0, 39), + Point::new(0, 47)..Point::new(0, 84) + ] + ); + }); + + let plain_text_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Plain Text".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["txt".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(&cx, |project, _| project.languages().clone()); + language_registry.add(plain_text_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Plain Text", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + workspace_symbol_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(&mut cx, |project, cx| { + project.open_local_buffer(path!("/dir/a/one.txt"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(&mut cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + cx.run_until_parked(); + + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::( + |_, _| async move { + Ok(Some(lsp::WorkspaceSymbolResponse::Flat(vec![ + #[allow(deprecated)] + lsp::SymbolInformation { + name: "MySymbol".into(), + location: lsp::Location { + uri: lsp::Url::from_file_path(path!("/dir/a/one.txt")).unwrap(), + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 1), + ), + }, + kind: lsp::SymbolKind::CONSTANT, + tags: None, + container_name: None, + deprecated: None, + }, + ]))) + }, + ); + + cx.simulate_input("@symbol "); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) @symbol " + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!( + current_completion_labels(editor), + &[ + "MySymbol", + ] + ); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( + project.clone(), + thread_store, + text_thread_store, + window, + cx, + ) + }) + .await + .unwrap() + .into_values() + .collect::>(); + + assert_eq!(contents.len(), 3); + pretty_assertions::assert_eq!( + contents[2], + Mention::Text { + content: "1".into(), + uri: "file:///dir/a/one.txt?symbol=MySymbol#L1:1" + .parse() + .unwrap(), + } + ); + + cx.run_until_parked(); + + editor.read_with(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + "Lorem [@one.txt](file:///dir/a/one.txt) Ipsum [@eight.txt](file:///dir/b/eight.txt) [@MySymbol](file:///dir/a/one.txt?symbol=MySymbol#L1:1) " + ); + }); + } + + fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec> { + let snapshot = editor.buffer().read(cx).snapshot(cx); + editor.display_map.update(cx, |display_map, cx| { + display_map + .snapshot(cx) + .folds_in_range(0..snapshot.len()) + .map(|fold| fold.range.to_point(&snapshot)) + .collect() + }) + } + + fn current_completion_labels(editor: &Editor) -> Vec { + let completions = editor.current_completions().expect("Missing completions"); + completions + .into_iter() + .map(|completion| completion.label.text.to_string()) + .collect::>() + } +} diff --git a/crates/agent_ui/src/acp/message_history.rs b/crates/agent_ui/src/acp/message_history.rs deleted file mode 100644 index c6106c7578..0000000000 --- a/crates/agent_ui/src/acp/message_history.rs +++ /dev/null @@ -1,92 +0,0 @@ -pub struct MessageHistory { - items: Vec, - current: Option, -} - -impl Default for MessageHistory { - fn default() -> Self { - MessageHistory { - items: Vec::new(), - current: None, - } - } -} - -impl MessageHistory { - pub fn push(&mut self, message: T) { - self.current.take(); - self.items.push(message); - } - - pub fn reset_position(&mut self) { - self.current.take(); - } - - pub fn prev(&mut self) -> Option<&T> { - if self.items.is_empty() { - return None; - } - - let new_ix = self - .current - .get_or_insert(self.items.len()) - .saturating_sub(1); - - self.current = Some(new_ix); - self.items.get(new_ix) - } - - pub fn next(&mut self) -> Option<&T> { - let current = self.current.as_mut()?; - *current += 1; - - self.items.get(*current).or_else(|| { - self.current.take(); - None - }) - } - - #[cfg(test)] - pub fn items(&self) -> &[T] { - &self.items - } -} -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_prev_next() { - let mut history = MessageHistory::default(); - - // Test empty history - assert_eq!(history.prev(), None); - assert_eq!(history.next(), None); - - // Add some messages - history.push("first"); - history.push("second"); - history.push("third"); - - // Test prev navigation - assert_eq!(history.prev(), Some(&"third")); - assert_eq!(history.prev(), Some(&"second")); - assert_eq!(history.prev(), Some(&"first")); - assert_eq!(history.prev(), Some(&"first")); - - assert_eq!(history.next(), Some(&"second")); - - // Test mixed navigation - history.push("fourth"); - assert_eq!(history.prev(), Some(&"fourth")); - assert_eq!(history.prev(), Some(&"third")); - assert_eq!(history.next(), Some(&"fourth")); - assert_eq!(history.next(), None); - - // Test that push resets navigation - history.prev(); - history.prev(); - history.push("fifth"); - assert_eq!(history.prev(), Some(&"fifth")); - } -} diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs new file mode 100644 index 0000000000..563afee65f --- /dev/null +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -0,0 +1,472 @@ +use std::{cmp::Reverse, rc::Rc, sync::Arc}; + +use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use agent_client_protocol as acp; +use anyhow::Result; +use collections::IndexMap; +use futures::FutureExt; +use fuzzy::{StringMatchCandidate, match_strings}; +use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity}; +use ordered_float::OrderedFloat; +use picker::{Picker, PickerDelegate}; +use ui::{ + AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window, + prelude::*, rems, +}; +use util::ResultExt; + +pub type AcpModelSelector = Picker; + +pub fn acp_model_selector( + session_id: acp::SessionId, + selector: Rc, + window: &mut Window, + cx: &mut Context, +) -> AcpModelSelector { + let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx); + Picker::list(delegate, window, cx) + .show_scrollbar(true) + .width(rems(20.)) + .max_height(Some(rems(20.).into())) +} + +enum AcpModelPickerEntry { + Separator(SharedString), + Model(AgentModelInfo), +} + +pub struct AcpModelPickerDelegate { + session_id: acp::SessionId, + selector: Rc, + filtered_entries: Vec, + models: Option, + selected_index: usize, + selected_model: Option, + _refresh_models_task: Task<()>, +} + +impl AcpModelPickerDelegate { + fn new( + session_id: acp::SessionId, + selector: Rc, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let mut rx = selector.watch(cx); + let refresh_models_task = cx.spawn_in(window, { + let session_id = session_id.clone(); + async move |this, cx| { + async fn refresh( + this: &WeakEntity>, + session_id: &acp::SessionId, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let (models_task, selected_model_task) = this.update(cx, |this, cx| { + ( + this.delegate.selector.list_models(cx), + this.delegate.selector.selected_model(session_id, cx), + ) + })?; + + let (models, selected_model) = futures::join!(models_task, selected_model_task); + + this.update_in(cx, |this, window, cx| { + this.delegate.models = models.ok(); + this.delegate.selected_model = selected_model.ok(); + this.delegate.update_matches(this.query(cx), window, cx) + })? + .await; + + Ok(()) + } + + refresh(&this, &session_id, cx).await.log_err(); + while let Ok(()) = rx.recv().await { + refresh(&this, &session_id, cx).await.log_err(); + } + } + }); + + Self { + session_id, + selector, + filtered_entries: Vec::new(), + models: None, + selected_model: None, + selected_index: 0, + _refresh_models_task: refresh_models_task, + } + } + + pub fn active_model(&self) -> Option<&AgentModelInfo> { + self.selected_model.as_ref() + } +} + +impl PickerDelegate for AcpModelPickerDelegate { + type ListItem = AnyElement; + + fn match_count(&self) -> usize { + self.filtered_entries.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) { + self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1)); + cx.notify(); + } + + fn can_select( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + match self.filtered_entries.get(ix) { + Some(AcpModelPickerEntry::Model(_)) => true, + Some(AcpModelPickerEntry::Separator(_)) | None => false, + } + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Select a model…".into() + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + cx.spawn_in(window, async move |this, cx| { + let filtered_models = match this + .read_with(cx, |this, cx| { + this.delegate.models.clone().map(move |models| { + fuzzy_search(models, query, cx.background_executor().clone()) + }) + }) + .ok() + .flatten() + { + Some(task) => task.await, + None => AgentModelList::Flat(vec![]), + }; + + this.update_in(cx, |this, window, cx| { + this.delegate.filtered_entries = + info_list_to_picker_entries(filtered_models).collect(); + // Finds the currently selected model in the list + let new_index = this + .delegate + .selected_model + .as_ref() + .and_then(|selected| { + this.delegate.filtered_entries.iter().position(|entry| { + if let AcpModelPickerEntry::Model(model_info) = entry { + model_info.id == selected.id + } else { + false + } + }) + }) + .unwrap_or(0); + this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx); + cx.notify(); + }) + .ok(); + }) + } + + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + if let Some(AcpModelPickerEntry::Model(model_info)) = + self.filtered_entries.get(self.selected_index) + { + self.selector + .select_model(self.session_id.clone(), model_info.id.clone(), cx) + .detach_and_log_err(cx); + self.selected_model = Some(model_info.clone()); + let current_index = self.selected_index; + self.set_selected_index(current_index, window, cx); + + cx.emit(DismissEvent); + } + } + + fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { + cx.emit(DismissEvent); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + cx: &mut Context>, + ) -> Option { + match self.filtered_entries.get(ix)? { + AcpModelPickerEntry::Separator(title) => Some( + div() + .px_2() + .pb_1() + .when(ix > 1, |this| { + this.mt_1() + .pt_2() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + Label::new(title) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + ), + AcpModelPickerEntry::Model(model_info) => { + let is_selected = Some(model_info) == self.selected_model.as_ref(); + + let model_icon_color = if is_selected { + Color::Accent + } else { + Color::Muted + }; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot::(model_info.icon.map(|icon| { + Icon::new(icon) + .color(model_icon_color) + .size(IconSize::Small) + })) + .child( + h_flex() + .w_full() + .pl_0p5() + .gap_1p5() + .w(px(240.)) + .child(Label::new(model_info.name.clone()).truncate()), + ) + .end_slot(div().pr_3().when(is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + })) + .into_any_element(), + ) + } + } + } + + fn render_footer( + &self, + _: &mut Window, + cx: &mut Context>, + ) -> Option { + Some( + h_flex() + .w_full() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .p_1() + .gap_4() + .justify_between() + .child( + Button::new("configure", "Configure") + .icon(IconName::Settings) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(|_, window, cx| { + window.dispatch_action( + zed_actions::agent::OpenSettings.boxed_clone(), + cx, + ); + }), + ) + .into_any(), + ) + } +} + +fn info_list_to_picker_entries( + model_list: AgentModelList, +) -> impl Iterator { + match model_list { + AgentModelList::Flat(list) => { + itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model)) + } + AgentModelList::Grouped(index_map) => { + itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| { + std::iter::once(AcpModelPickerEntry::Separator(group_name.0)) + .chain(models.into_iter().map(AcpModelPickerEntry::Model)) + })) + } + } +} + +async fn fuzzy_search( + model_list: AgentModelList, + query: String, + executor: BackgroundExecutor, +) -> AgentModelList { + async fn fuzzy_search_list( + model_list: Vec, + query: &str, + executor: BackgroundExecutor, + ) -> Vec { + let candidates = model_list + .iter() + .enumerate() + .map(|(ix, model)| { + StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name)) + }) + .collect::>(); + let mut matches = match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + executor, + ) + .await; + + matches.sort_unstable_by_key(|mat| { + let candidate = &candidates[mat.candidate_id]; + (Reverse(OrderedFloat(mat.score)), candidate.id) + }); + + matches + .into_iter() + .map(|mat| model_list[mat.candidate_id].clone()) + .collect() + } + + match model_list { + AgentModelList::Flat(model_list) => { + AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await) + } + AgentModelList::Grouped(index_map) => { + let groups = + futures::future::join_all(index_map.into_iter().map(|(group_name, models)| { + fuzzy_search_list(models, &query, executor.clone()) + .map(|results| (group_name, results)) + })) + .await; + AgentModelList::Grouped(IndexMap::from_iter( + groups + .into_iter() + .filter(|(_, results)| !results.is_empty()), + )) + } + } +} + +#[cfg(test)] +mod tests { + use gpui::TestAppContext; + + use super::*; + + fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList { + AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map( + |(group, models)| { + ( + acp_thread::AgentModelGroupName(group.to_string().into()), + models + .into_iter() + .map(|model| acp_thread::AgentModelInfo { + id: acp_thread::AgentModelId(model.to_string().into()), + name: model.to_string().into(), + icon: None, + }) + .collect::>(), + ) + }, + ))) + } + + fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) { + let AgentModelList::Grouped(groups) = result else { + panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result); + }; + + assert_eq!( + groups.len(), + expected.len(), + "Number of groups doesn't match" + ); + + for (i, (expected_group, expected_models)) in expected.iter().enumerate() { + let (actual_group, actual_models) = groups.get_index(i).unwrap(); + assert_eq!( + actual_group.0.as_ref(), + *expected_group, + "Group at position {} doesn't match expected group", + i + ); + assert_eq!( + actual_models.len(), + expected_models.len(), + "Number of models in group {} doesn't match", + expected_group + ); + + for (j, expected_model_name) in expected_models.iter().enumerate() { + assert_eq!( + actual_models[j].name, *expected_model_name, + "Model at position {} in group {} doesn't match expected model", + j, expected_group + ); + } + } + } + + #[gpui::test] + async fn test_fuzzy_match(cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ( + "zed", + vec![ + "Claude 3.7 Sonnet", + "Claude 3.7 Sonnet Thinking", + "gpt-4.1", + "gpt-4.1-nano", + ], + ), + ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]), + ("ollama", vec!["mistral", "deepseek"]), + ]); + + // Results should preserve models order whenever possible. + // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical + // similarity scores, but `zed/gpt-4.1` was higher in the models list, + // so it should appear first in the results. + let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await; + assert_models_eq( + results, + vec![ + ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]), + ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]), + ], + ); + + // Fuzzy search + let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await; + assert_models_eq( + results, + vec![ + ("zed", vec!["gpt-4.1-nano"]), + ("openai", vec!["gpt-4.1-nano"]), + ], + ); + } +} diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs new file mode 100644 index 0000000000..e52101113a --- /dev/null +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -0,0 +1,85 @@ +use std::rc::Rc; + +use acp_thread::AgentModelSelector; +use agent_client_protocol as acp; +use gpui::{Entity, FocusHandle}; +use picker::popover_menu::PickerPopoverMenu; +use ui::{ + ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*, +}; +use zed_actions::agent::ToggleModelSelector; + +use crate::acp::{AcpModelSelector, model_selector::acp_model_selector}; + +pub struct AcpModelSelectorPopover { + selector: Entity, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, +} + +impl AcpModelSelectorPopover { + pub(crate) fn new( + session_id: acp::SessionId, + selector: Rc, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, + window: &mut Window, + cx: &mut Context, + ) -> Self { + Self { + selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)), + menu_handle, + focus_handle, + } + } + + pub fn toggle(&self, window: &mut Window, cx: &mut Context) { + self.menu_handle.toggle(window, cx); + } +} + +impl Render for AcpModelSelectorPopover { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let model = self.selector.read(cx).delegate.active_model(); + let model_name = model + .as_ref() + .map(|model| model.name.clone()) + .unwrap_or_else(|| SharedString::from("Select a Model")); + + let model_icon = model.as_ref().and_then(|model| model.icon); + + let focus_handle = self.focus_handle.clone(); + + PickerPopoverMenu::new( + self.selector.clone(), + ButtonLike::new("active-model") + .when_some(model_icon, |this, icon| { + this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)) + }) + .child( + Label::new(model_name) + .color(Color::Muted) + .size(LabelSize::Small) + .ml_0p5(), + ) + .child( + Icon::new(IconName::ChevronDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + move |window, cx| { + Tooltip::for_action_in( + "Change Model", + &ToggleModelSelector, + &focus_handle, + window, + cx, + ) + }, + gpui::Corner::BottomRight, + cx, + ) + .with_handle(self.menu_handle.clone()) + .render(window, cx) + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index c811878c21..17341e4c8a 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,75 +1,115 @@ +use acp_thread::{ + AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, + LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, +}; use acp_thread::{AgentConnection, Plan}; +use action_log::ActionLog; +use agent::{TextThreadStore, ThreadStore}; +use agent_client_protocol::{self as acp}; use agent_servers::AgentServer; -use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting}; +use anyhow::bail; use audio::{Audio, Sound}; -use std::cell::RefCell; -use std::collections::BTreeMap; -use std::path::Path; -use std::process::ExitStatus; -use std::rc::Rc; -use std::sync::Arc; -use std::time::Duration; - -use agent_client_protocol as acp; -use assistant_tool::ActionLog; use buffer_diff::BufferDiff; +use client::zed_urls; use collections::{HashMap, HashSet}; -use editor::{ - AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MinimapVisibility, MultiBuffer, PathKey, -}; +use editor::scroll::Autoscroll; +use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; +use fs::Fs; use gpui::{ - Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, - SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, - Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, - linear_gradient, list, percentage, point, prelude::*, pulsating_between, + Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, + Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, + PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, + TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, + linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between, }; -use language::language_settings::SoftWrap; -use language::{Buffer, Language}; +use language::Buffer; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; -use parking_lot::Mutex; -use project::{CompletionIntent, Project}; +use project::Project; +use prompt_store::PromptId; +use rope::Point; use settings::{Settings as _, SettingsStore}; -use text::{Anchor, BufferSnapshot}; +use std::sync::Arc; +use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration}; +use text::Anchor; use theme::ThemeSettings; use ui::{ - Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*, + Callout, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding, PopoverMenuHandle, + Scrollbar, ScrollbarState, Tooltip, prelude::*, }; -use util::ResultExt; +use util::{ResultExt, size::format_file_size, time::duration_alt_display}; use workspace::{CollaboratorId, Workspace}; -use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; +use zed_actions::agent::{Chat, ToggleModelSelector}; +use zed_actions::assistant::OpenRulesLibrary; -use ::acp_thread::{ - AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, -}; - -use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; -use crate::acp::message_history::MessageHistory; +use super::entry_view_state::EntryViewState; +use crate::acp::AcpModelSelectorPopover; +use crate::acp::entry_view_state::{EntryViewEvent, ViewEvent}; +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::agent_diff::AgentDiff; -use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::profile_selector::{ProfileProvider, ProfileSelector}; +use crate::ui::{AgentNotification, AgentNotificationEvent, BurnModeTooltip}; use crate::{ - AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, + AgentDiffPane, AgentPanel, ContinueThread, ContinueWithBurnMode, ExpandMessageEditor, Follow, + KeepAll, OpenAgentDiff, RejectAll, ToggleBurnMode, ToggleProfileSelector, }; const RESPONSE_PADDING_X: Pixels = px(19.); +pub const MIN_EDITOR_LINES: usize = 4; +pub const MAX_EDITOR_LINES: usize = 8; + +enum ThreadError { + PaymentRequired, + ModelRequestLimitReached(cloud_llm_client::Plan), + ToolUseLimitReached, + Other(SharedString), +} + +impl ThreadError { + fn from_err(error: anyhow::Error) -> Self { + if error.is::() { + Self::PaymentRequired + } else if error.is::() { + Self::ToolUseLimitReached + } else if let Some(error) = + error.downcast_ref::() + { + Self::ModelRequestLimitReached(error.plan) + } else { + Self::Other(error.to_string().into()) + } + } +} + +impl ProfileProvider for Entity { + fn profile_id(&self, cx: &App) -> AgentProfileId { + self.read(cx).profile().clone() + } + + fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) { + self.update(cx, |thread, _cx| { + thread.set_profile(profile_id); + }); + } + + fn profiles_supported(&self, cx: &App) -> bool { + self.read(cx).model().supports_tools() + } +} pub struct AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, thread_state: ThreadState, - diff_editors: HashMap>, - message_editor: Entity, - message_set_from_history: Option, - _message_editor_subscription: Subscription, - mention_set: Arc>, + entry_view_state: Entity, + message_editor: Entity, + model_selector: Option>, + profile_selector: Option>, notifications: Vec>, notification_subscriptions: HashMap, Vec>, - last_error: Option>, + thread_error: Option, list_state: ListState, scrollbar_state: ScrollbarState, auth_task: Option>, @@ -78,9 +118,10 @@ pub struct AcpThreadView { edits_expanded: bool, plan_expanded: bool, editor_expanded: bool, - message_history: Rc>>>, + terminal_expanded: bool, + editing_message: Option, _cancel_task: Option>, - _subscriptions: [Subscription; 1], + _subscriptions: [Subscription; 3], } enum ThreadState { @@ -105,105 +146,66 @@ impl AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, - message_history: Rc>>>, - min_lines: usize, - max_lines: Option, + thread_store: Entity, + text_thread_store: Entity, window: &mut Window, cx: &mut Context, ) -> Self { - let language = Language::new( - language::LanguageConfig { - completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), - ..Default::default() - }, - None, - ); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); - let message_editor = cx.new(|cx| { - let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - - let mut editor = Editor::new( + MessageEditor::new( + workspace.clone(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), editor::EditorMode::AutoHeight { - min_lines, - max_lines: max_lines, + min_lines: MIN_EDITOR_LINES, + max_lines: Some(MAX_EDITOR_LINES), }, - buffer, - None, window, cx, - ); - editor.set_placeholder_text("Message the agent - @ to include files", cx); - editor.set_show_indent_guides(false, cx); - editor.set_soft_wrap(); - editor.set_use_modal_editing(true); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace.clone(), - cx.weak_entity(), - )))); - editor.set_context_menu_options(ContextMenuOptions { - min_entries_visible: 12, - max_entries_visible: 12, - placement: Some(ContextMenuPlacement::Above), - }); - editor + ) }); - let message_editor_subscription = - cx.subscribe(&message_editor, |this, editor, event, cx| { - if let editor::EditorEvent::BufferEdited = &event { - let buffer = editor - .read(cx) - .buffer() - .read(cx) - .as_singleton() - .unwrap() - .read(cx) - .snapshot(); - if let Some(message) = this.message_set_from_history.clone() - && message.version() != buffer.version() - { - this.message_set_from_history = None; - } - - if this.message_set_from_history.is_none() { - this.message_history.borrow_mut().reset_position(); - } - } - }); - - let mention_set = mention_set.clone(); - let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); - let subscription = cx.observe_global_in::(window, Self::settings_changed); + let entry_view_state = cx.new(|_| { + EntryViewState::new( + workspace.clone(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + ) + }); + + let subscriptions = [ + cx.observe_global_in::(window, Self::settings_changed), + cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event), + cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event), + ]; Self { agent: agent.clone(), workspace: workspace.clone(), project: project.clone(), + entry_view_state, thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, - message_set_from_history: None, - _message_editor_subscription: message_editor_subscription, - mention_set, + model_selector: None, + profile_selector: None, notifications: Vec::new(), notification_subscriptions: HashMap::default(), - diff_editors: Default::default(), list_state: list_state.clone(), scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), - last_error: None, + thread_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), expanded_thinking_blocks: HashSet::default(), + editing_message: None, edits_expanded: false, plan_expanded: false, editor_expanded: false, - message_history, - _subscriptions: [subscription], + terminal_expanded: true, + _subscriptions: subscriptions, _cancel_task: None, } } @@ -250,11 +252,18 @@ impl AcpThreadView { // }) // .ok(); - let result = match connection - .clone() - .new_thread(project.clone(), &root_dir, cx) - .await - { + let Some(result) = cx + .update(|_, cx| { + connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + }) + .log_err() + else { + return; + }; + + let result = match result.await { Err(e) => { let mut cx = cx.clone(); if e.is::() { @@ -268,7 +277,7 @@ impl AcpThreadView { Err(e) } } - Ok(session_id) => Ok(session_id), + Ok(thread) => Ok(thread), }; this.update_in(cx, |this, window, cx| { @@ -286,11 +295,40 @@ impl AcpThreadView { AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); + this.model_selector = + thread + .read(cx) + .connection() + .model_selector() + .map(|selector| { + cx.new(|cx| { + AcpModelSelectorPopover::new( + thread.read(cx).session_id().clone(), + selector, + PopoverMenuHandle::default(), + this.focus_handle(cx), + window, + cx, + ) + }) + }); + this.thread_state = ThreadState::Ready { thread, _subscription: [thread_subscription, action_log_subscription], }; + this.profile_selector = this.as_native_thread(cx).map(|thread| { + cx.new(|cx| { + ProfileSelector::new( + ::global(cx), + Arc::new(thread.clone()), + this.focus_handle(cx), + cx, + ) + }) + }); + cx.notify(); } Err(err) => { @@ -333,8 +371,8 @@ impl AcpThreadView { } } - pub fn cancel(&mut self, cx: &mut Context) { - self.last_error.take(); + pub fn cancel_generation(&mut self, cx: &mut Context) { + self.thread_error.take(); if let Some(thread) = self.thread() { self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); @@ -353,167 +391,186 @@ impl AcpThreadView { fn set_editor_is_expanded(&mut self, is_expanded: bool, cx: &mut Context) { self.editor_expanded = is_expanded; - self.message_editor.update(cx, |editor, _| { - if self.editor_expanded { - editor.set_mode(EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: false, - }) + self.message_editor.update(cx, |editor, cx| { + if is_expanded { + editor.set_mode( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: false, + }, + cx, + ) } else { - editor.set_mode(EditorMode::AutoHeight { - min_lines: MIN_EDITOR_LINES, - max_lines: Some(MAX_EDITOR_LINES), - }) + editor.set_mode( + EditorMode::AutoHeight { + min_lines: MIN_EDITOR_LINES, + max_lines: Some(MAX_EDITOR_LINES), + }, + cx, + ) } }); cx.notify(); } - fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context) { - self.last_error.take(); - - let mut ix = 0; - let mut chunks: Vec = Vec::new(); - let project = self.project.clone(); - self.message_editor.update(cx, |editor, cx| { - let text = editor.text(cx); - editor.display_map.update(cx, |map, cx| { - let snapshot = map.snapshot(cx); - for (crease_id, crease) in snapshot.crease_snapshot.creases() { - // Skip creases that have been edited out of the message buffer. - if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { - continue; - } - - if let Some(project_path) = - self.mention_set.lock().path_for_crease_id(crease_id) - { - let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); - if crease_range.start > ix { - chunks.push(text[ix..crease_range.start].into()); - } - if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - let path_str = abs_path.display().to_string(); - chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { - uri: path_str.clone(), - name: path_str, - annotations: None, - description: None, - mime_type: None, - size: None, - title: None, - })); - } - ix = crease_range.end; - } - } - - if ix < text.len() { - let last_chunk = text[ix..].trim_end(); - if !last_chunk.is_empty() { - chunks.push(last_chunk.into()); - } - } - }) - }); - - if chunks.is_empty() { - return; + pub fn handle_message_editor_event( + &mut self, + _: &Entity, + event: &MessageEditorEvent, + window: &mut Window, + cx: &mut Context, + ) { + match event { + MessageEditorEvent::Send => self.send(window, cx), + MessageEditorEvent::Cancel => self.cancel_generation(cx), + MessageEditorEvent::Focus => {} } + } + pub fn handle_entry_view_event( + &mut self, + _: &Entity, + event: &EntryViewEvent, + window: &mut Window, + cx: &mut Context, + ) { + match &event.view_event { + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => { + self.editing_message = Some(event.entry_index); + cx.notify(); + } + ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => { + self.regenerate(event.entry_index, editor, window, cx); + } + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => { + self.cancel_editing(&Default::default(), window, cx); + } + } + } + + fn resume_chat(&mut self, cx: &mut Context) { + self.thread_error.take(); let Some(thread) = self.thread() else { return; }; - let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); + let task = thread.update(cx, |thread, cx| thread.resume(cx)); cx.spawn(async move |this, cx| { let result = task.await; this.update(cx, |this, cx| { if let Err(err) = result { - this.last_error = - Some(cx.new(|cx| Markdown::new(err.to_string().into(), None, None, cx))) + this.handle_thread_error(err, cx); } }) }) .detach(); + } - let mention_set = self.mention_set.clone(); + fn send(&mut self, window: &mut Window, cx: &mut Context) { + let contents = self + .message_editor + .update(cx, |message_editor, cx| message_editor.contents(window, cx)); + self.send_impl(contents, window, cx) + } - self.set_editor_is_expanded(false, cx); + fn send_impl( + &mut self, + contents: Task>>, + window: &mut Window, + cx: &mut Context, + ) { + self.thread_error.take(); + self.editing_message.take(); - self.message_editor.update(cx, |editor, cx| { - editor.clear(window, cx); - editor.remove_creases(mention_set.lock().drain(), cx) + let Some(thread) = self.thread().cloned() else { + return; + }; + let task = cx.spawn_in(window, async move |this, cx| { + let contents = contents.await?; + + if contents.is_empty() { + return Ok(()); + } + + this.update_in(cx, |this, window, cx| { + this.set_editor_is_expanded(false, cx); + this.scroll_to_bottom(cx); + this.message_editor.update(cx, |message_editor, cx| { + message_editor.clear(window, cx); + }); + })?; + let send = thread.update(cx, |thread, cx| thread.send(contents, cx))?; + send.await }); - self.scroll_to_bottom(cx); - - self.message_history.borrow_mut().push(chunks); + cx.spawn(async move |this, cx| { + if let Err(err) = task.await { + this.update(cx, |this, cx| { + this.handle_thread_error(err, cx); + }) + .ok(); + } + }) + .detach(); } - fn previous_history_message( - &mut self, - _: &PreviousHistoryMessage, - window: &mut Window, - cx: &mut Context, - ) { - if self.message_set_from_history.is_none() && !self.message_editor.read(cx).is_empty(cx) { - self.message_editor.update(cx, |editor, cx| { - editor.move_up(&Default::default(), window, cx); - }); + fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let Some(thread) = self.thread().cloned() else { return; - } - - self.message_set_from_history = Self::set_draft_message( - self.message_editor.clone(), - self.mention_set.clone(), - self.project.clone(), - self.message_history - .borrow_mut() - .prev() - .map(|blocks| blocks.as_slice()), - window, - cx, - ); - } - - fn next_history_message( - &mut self, - _: &NextHistoryMessage, - window: &mut Window, - cx: &mut Context, - ) { - if self.message_set_from_history.is_none() { - self.message_editor.update(cx, |editor, cx| { - editor.move_down(&Default::default(), window, cx); - }); - return; - } - - let mut message_history = self.message_history.borrow_mut(); - let next_history = message_history.next(); - - let set_draft_message = Self::set_draft_message( - self.message_editor.clone(), - self.mention_set.clone(), - self.project.clone(), - Some( - next_history - .map(|blocks| blocks.as_slice()) - .unwrap_or_else(|| &[]), - ), - window, - cx, - ); - // If we reset the text to an empty string because we ran out of history, - // we don't want to mark it as coming from the history - self.message_set_from_history = if next_history.is_some() { - set_draft_message - } else { - None }; + + if let Some(index) = self.editing_message.take() { + if let Some(editor) = self + .entry_view_state + .read(cx) + .entry(index) + .and_then(|e| e.message_editor()) + .cloned() + { + editor.update(cx, |editor, cx| { + if let Some(user_message) = thread + .read(cx) + .entries() + .get(index) + .and_then(|e| e.user_message()) + { + editor.set_message(user_message.chunks.clone(), window, cx); + } + }) + } + }; + self.focus_handle(cx).focus(window); + cx.notify(); + } + + fn regenerate( + &mut self, + entry_ix: usize, + message_editor: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread().cloned() else { + return; + }; + + let Some(rewind) = thread.update(cx, |thread, cx| { + let user_message_id = thread.entries().get(entry_ix)?.user_message()?.id.clone()?; + Some(thread.rewind(user_message_id, cx)) + }) else { + return; + }; + + let contents = + message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx)); + + let task = cx.foreground_executor().spawn(async move { + rewind.await?; + contents.await + }); + self.send_impl(task, window, cx); } fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { @@ -543,82 +600,14 @@ impl AcpThreadView { }) } - fn set_draft_message( - message_editor: Entity, - mention_set: Arc>, - project: Entity, - message: Option<&[acp::ContentBlock]>, - window: &mut Window, - cx: &mut Context, - ) -> Option { + fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context) { + self.thread_error = Some(ThreadError::from_err(error)); cx.notify(); + } - let message = message?; - - let mut text = String::new(); - let mut mentions = Vec::new(); - - for chunk in message { - match chunk { - acp::ContentBlock::Text(text_content) => { - text.push_str(&text_content.text); - } - acp::ContentBlock::ResourceLink(resource_link) => { - let path = Path::new(&resource_link.uri); - let start = text.len(); - let content = MentionPath::new(&path).to_string(); - text.push_str(&content); - let end = text.len(); - if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(&path, cx) - { - let filename: SharedString = path - .file_name() - .unwrap_or_default() - .to_string_lossy() - .to_string() - .into(); - mentions.push((start..end, project_path, filename)); - } - } - acp::ContentBlock::Image(_) - | acp::ContentBlock::Audio(_) - | acp::ContentBlock::Resource(_) => {} - } - } - - let snapshot = message_editor.update(cx, |editor, cx| { - editor.set_text(text, window, cx); - editor.buffer().read(cx).snapshot(cx) - }); - - for (range, project_path, filename) in mentions { - let crease_icon_path = if project_path.path.is_dir() { - FileIcons::get_folder_icon(false, cx) - .unwrap_or_else(|| IconName::Folder.path().into()) - } else { - FileIcons::get_icon(Path::new(project_path.path.as_ref()), cx) - .unwrap_or_else(|| IconName::File.path().into()) - }; - - let anchor = snapshot.anchor_before(range.start); - let crease_id = crate::context_picker::insert_crease_for_mention( - anchor.excerpt_id, - anchor.text_anchor, - range.end - range.start, - filename, - crease_icon_path, - message_editor.clone(), - window, - cx, - ); - if let Some(crease_id) = crease_id { - mention_set.lock().insert(crease_id, project_path); - } - } - - let snapshot = snapshot.as_singleton().unwrap().2.clone(); - Some(snapshot.text) + fn clear_thread_error(&mut self, cx: &mut Context) { + self.thread_error = None; + cx.notify(); } fn handle_thread_event( @@ -628,17 +617,25 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let count = self.list_state.item_count(); match event { AcpThreadEvent::NewEntry => { - let index = thread.read(cx).entries().len() - 1; - self.sync_thread_entry_view(index, window, cx); - self.list_state.splice(count..count, 1); + let len = thread.read(cx).entries().len(); + let index = len - 1; + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(index, &thread, window, cx) + }); + self.list_state.splice(index..index, 1); } AcpThreadEvent::EntryUpdated(index) => { - let index = *index; - self.sync_thread_entry_view(index, window, cx); - self.list_state.splice(index..index + 1, 1); + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(*index, &thread, window, cx) + }); + self.list_state.splice(*index..index + 1, 1); + } + AcpThreadEvent::EntriesRemoved(range) => { + self.entry_view_state + .update(cx, |view_state, _cx| view_state.remove(range.clone())); + self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); @@ -671,74 +668,6 @@ impl AcpThreadView { cx.notify(); } - fn sync_thread_entry_view( - &mut self, - entry_ix: usize, - window: &mut Window, - cx: &mut Context, - ) { - let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { - return; - }; - - let multibuffers = multibuffers.collect::>(); - - for multibuffer in multibuffers { - if self.diff_editors.contains_key(&multibuffer.entity_id()) { - return; - } - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - editor - }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); - - self.diff_editors.insert(entity_id, editor); - } - } - - fn entry_diff_multibuffers( - &self, - entry_ix: usize, - cx: &App, - ) -> Option>> { - let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some( - entry - .diffs() - .map(|diff| diff.read(cx).multibuffer().clone()), - ) - } - fn authenticate( &mut self, method: acp::AuthMethodId, @@ -749,7 +678,7 @@ impl AcpThreadView { return; }; - self.last_error.take(); + self.thread_error.take(); let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); @@ -759,9 +688,7 @@ impl AcpThreadView { this.update_in(cx, |this, window, cx| { if let Err(err) = result { - this.last_error = Some(cx.new(|cx| { - Markdown::new(format!("Error: {err}").into(), None, None, cx) - })) + this.handle_thread_error(err, cx); } else { this.thread_state = Self::initial_state( agent, @@ -794,9 +721,19 @@ impl AcpThreadView { cx.notify(); } + fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context) { + let Some(thread) = self.thread() else { + return; + }; + thread + .update(cx, |thread, cx| thread.rewind(message_id.clone(), cx)) + .detach_and_log_err(cx); + cx.notify(); + } + fn render_entry( &self, - index: usize, + entry_ix: usize, total_entries: usize, entry: &AgentThreadEntry, window: &mut Window, @@ -804,8 +741,21 @@ impl AcpThreadView { ) -> AnyElement { let primary = match &entry { AgentThreadEntry::UserMessage(message) => div() + .id(("user_message", entry_ix)) .py_4() .px_2() + .children(message.id.clone().and_then(|message_id| { + message.checkpoint.as_ref()?.show.then(|| { + Button::new("restore-checkpoint", "Restore Checkpoint") + .icon(IconName::Undo) + .icon_size(IconSize::XSmall) + .icon_position(IconPosition::Start) + .label_size(LabelSize::XSmall) + .on_click(cx.listener(move |this, _, _window, cx| { + this.rewind(&message_id, cx); + })) + }) + })) .child( v_flex() .p_3() @@ -816,12 +766,16 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .children(message.content.markdown().map(|md| { - self.render_markdown( - md.clone(), - user_message_markdown_style(window, cx), - ) - })), + .children( + self.entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| entry.message_editor()) + .map(|editor| { + self.render_sent_message_editor(entry_ix, editor, cx) + .into_any_element() + }), + ), ) .into_any(), AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { @@ -840,7 +794,7 @@ impl AcpThreadView { AssistantMessageChunk::Thought { block } => { block.markdown().map(|md| { self.render_thinking_block( - index, + entry_ix, chunk_ix, md.clone(), window, @@ -856,25 +810,36 @@ impl AcpThreadView { v_flex() .px_5() .py_1() - .when(index + 1 == total_entries, |this| this.pb_4()) + .when(entry_ix + 1 == total_entries, |this| this.pb_4()) .w_full() .text_ui(cx) .child(message_body) .into_any() } - AgentThreadEntry::ToolCall(tool_call) => div() - .w_full() - .py_1p5() - .px_5() - .child(self.render_tool_call(index, tool_call, window, cx)) - .into_any(), + AgentThreadEntry::ToolCall(tool_call) => { + let has_terminals = tool_call.terminals().next().is_some(); + + div().w_full().py_1p5().px_5().map(|this| { + if has_terminals { + this.children(tool_call.terminals().map(|terminal| { + self.render_terminal_tool_call( + entry_ix, terminal, tool_call, window, cx, + ) + })) + } else { + this.child(self.render_tool_call(entry_ix, tool_call, window, cx)) + } + }) + } + .into_any(), }; let Some(thread) = self.thread() else { return primary; }; + let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); - if index == total_entries - 1 && !is_generating { + let primary = if entry_ix == total_entries - 1 && !is_generating { v_flex() .w_full() .child(primary) @@ -882,6 +847,28 @@ impl AcpThreadView { .into_any_element() } else { primary + }; + + if let Some(editing_index) = self.editing_message.as_ref() + && *editing_index < entry_ix + { + let backdrop = div() + .id(("backdrop", entry_ix)) + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll() + .on_click(cx.listener(Self::cancel_editing)); + + div() + .relative() + .child(primary) + .child(backdrop) + .into_any_element() + } else { + primary } } @@ -1017,10 +1004,10 @@ impl AcpThreadView { .size(IconSize::Small) .color(Color::Muted); + let base_container = h_flex().size_4().justify_center(); + if is_collapsible { - h_flex() - .size_4() - .justify_center() + base_container .child( div() .group_hover(&group_name, |s| s.invisible().w_0()) @@ -1051,7 +1038,7 @@ impl AcpThreadView { ), ) } else { - div().child(tool_icon) + base_container.child(tool_icon) } } @@ -1101,19 +1088,29 @@ impl AcpThreadView { ), }; - let needs_confirmation = match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { .. } => true, - _ => tool_call - .content - .iter() - .any(|content| matches!(content, ToolCallContent::Diff { .. })), - }; + let needs_confirmation = matches!( + tool_call.status, + ToolCallStatus::WaitingForConfirmation { .. } + ); + let is_edit = matches!(tool_call.kind, acp::ToolKind::Edit); + let has_diff = tool_call + .content + .iter() + .any(|content| matches!(content, ToolCallContent::Diff { .. })); + let has_nonempty_diff = tool_call.content.iter().any(|content| match content { + ToolCallContent::Diff(diff) => diff.read(cx).has_revealed_range(cx), + _ => false, + }); + let use_card_layout = needs_confirmation || is_edit || has_diff; - let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; - let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); + let is_collapsible = !tool_call.content.is_empty() && !use_card_layout; - let gradient_color = cx.theme().colors().panel_background; - let gradient_overlay = { + let is_open = tool_call.content.is_empty() + || needs_confirmation + || has_nonempty_diff + || self.expanded_tool_calls.contains(&tool_call.id); + + let gradient_overlay = |color: Hsla| { div() .absolute() .top_0() @@ -1122,13 +1119,47 @@ impl AcpThreadView { .h_full() .bg(linear_gradient( 90., - linear_color_stop(gradient_color, 1.), - linear_color_stop(gradient_color.opacity(0.2), 0.), + linear_color_stop(color, 1.), + linear_color_stop(color.opacity(0.2), 0.), )) }; + let gradient_color = if use_card_layout { + self.tool_card_header_bg(cx) + } else { + cx.theme().colors().panel_background + }; + + let tool_output_display = match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { options, .. } => v_flex() + .w_full() + .children(tool_call.content.iter().map(|content| { + div() + .child( + self.render_tool_call_content(entry_ix, content, tool_call, window, cx), + ) + .into_any_element() + })) + .child(self.render_permission_buttons( + options, + entry_ix, + tool_call.id.clone(), + tool_call.content.is_empty(), + cx, + )), + ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => v_flex() + .w_full() + .children(tool_call.content.iter().map(|content| { + div() + .child( + self.render_tool_call_content(entry_ix, content, tool_call, window, cx), + ) + .into_any_element() + })), + ToolCallStatus::Rejected => v_flex().size_0(), + }; v_flex() - .when(needs_confirmation, |this| { + .when(use_card_layout, |this| { this.rounded_lg() .border_1() .border_color(self.tool_card_border_color(cx)) @@ -1142,13 +1173,11 @@ impl AcpThreadView { .gap_1() .justify_between() .map(|this| { - if needs_confirmation { + if use_card_layout { this.pl_2() .pr_1() .py_1() .rounded_t_md() - .border_b_1() - .border_color(self.tool_card_border_color(cx)) .bg(self.tool_card_header_bg(cx)) } else { this.opacity(0.8).hover(|style| style.opacity(1.)) @@ -1159,13 +1188,6 @@ impl AcpThreadView { .group(&card_header_id) .relative() .w_full() - .map(|this| { - if tool_call.locations.len() == 1 { - this.gap_0() - } else { - this.gap_1p5() - } - }) .text_size(self.tool_name_font_size()) .child(self.render_tool_call_icon( card_header_id, @@ -1209,6 +1231,7 @@ impl AcpThreadView { .id("non-card-label-container") .w_full() .relative() + .ml_1p5() .overflow_hidden() .child( h_flex() @@ -1219,13 +1242,13 @@ impl AcpThreadView { .child(self.render_markdown( tool_call.label.clone(), default_markdown_style( - needs_confirmation, + needs_confirmation || is_edit || has_diff, window, cx, ), )), ) - .child(gradient_overlay) + .child(gradient_overlay(gradient_color)) .on_click(cx.listener({ let id = tool_call.id.clone(); move |this: &mut Self, _, _, cx: &mut Context| { @@ -1242,88 +1265,111 @@ impl AcpThreadView { ) .children(status_icon), ) - .when(is_open, |this| { - this.child( - v_flex() - .text_xs() - .when(is_collapsible, |this| { - this.mt_1() - .border_1() - .border_color(self.tool_card_border_color(cx)) - .bg(cx.theme().colors().editor_background) - .rounded_lg() - }) - .map(|this| { - if is_open { - match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { options, .. } => this - .children(tool_call.content.iter().map(|content| { - div() - .py_1p5() - .child( - self.render_tool_call_content( - content, window, cx, - ), - ) - .into_any_element() - })) - .child(self.render_permission_buttons( - options, - entry_ix, - tool_call.id.clone(), - tool_call.content.is_empty(), - cx, - )), - ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { - this.children(tool_call.content.iter().map(|content| { - div() - .py_1p5() - .child( - self.render_tool_call_content( - content, window, cx, - ), - ) - .into_any_element() - })) - } - ToolCallStatus::Rejected => this, - } - } else { - this - } - }), - ) - }) + .when(is_open, |this| this.child(tool_output_display)) } fn render_tool_call_content( &self, + entry_ix: usize, content: &ToolCallContent, + tool_call: &ToolCall, window: &Window, cx: &Context, ) -> AnyElement { match content { - ToolCallContent::ContentBlock { content } => { - if let Some(md) = content.markdown() { - div() - .p_2() - .child( - self.render_markdown( - md.clone(), - default_markdown_style(false, window, cx), - ), - ) - .into_any_element() + ToolCallContent::ContentBlock(content) => { + if let Some(resource_link) = content.resource_link() { + self.render_resource_link(resource_link, cx) + } else if let Some(markdown) = content.markdown() { + self.render_markdown_output(markdown.clone(), tool_call.id.clone(), window, cx) } else { Empty.into_any_element() } } - ToolCallContent::Diff { diff, .. } => { - self.render_diff_editor(&diff.read(cx).multibuffer()) + ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, &diff, cx), + ToolCallContent::Terminal(terminal) => { + self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } } } + fn render_markdown_output( + &self, + markdown: Entity, + tool_call_id: acp::ToolCallId, + window: &Window, + cx: &Context, + ) -> AnyElement { + let button_id = SharedString::from(format!("tool_output-{:?}", tool_call_id.clone())); + + v_flex() + .mt_1p5() + .ml(px(7.)) + .px_3p5() + .gap_2() + .border_l_1() + .border_color(self.tool_card_border_color(cx)) + .text_sm() + .text_color(cx.theme().colors().text_muted) + .child(self.render_markdown(markdown, default_markdown_style(false, window, cx))) + .child( + Button::new(button_id, "Collapse Output") + .full_width() + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .icon(IconName::ChevronUp) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(cx.listener({ + let id = tool_call_id.clone(); + move |this: &mut Self, _, _, cx: &mut Context| { + this.expanded_tool_calls.remove(&id); + cx.notify(); + } + })), + ) + .into_any_element() + } + + fn render_resource_link( + &self, + resource_link: &acp::ResourceLink, + cx: &Context, + ) -> AnyElement { + let uri: SharedString = resource_link.uri.clone().into(); + + let label: SharedString = if let Some(path) = resource_link.uri.strip_prefix("file://") { + path.to_string().into() + } else { + uri.clone() + }; + + let button_id = SharedString::from(format!("item-{}", uri.clone())); + + div() + .ml(px(7.)) + .pl_2p5() + .border_l_1() + .border_color(self.tool_card_border_color(cx)) + .overflow_hidden() + .child( + Button::new(button_id, label) + .label_size(LabelSize::Small) + .color(Color::Muted) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .truncate(true) + .on_click(cx.listener({ + let workspace = self.workspace.clone(); + move |_, _, window, cx: &mut Context| { + Self::open_link(uri.clone(), &workspace, window, cx); + } + })), + ) + .into_any_element() + } + fn render_permission_buttons( &self, options: &[acp::PermissionOption], @@ -1333,14 +1379,22 @@ impl AcpThreadView { cx: &Context, ) -> Div { h_flex() - .p_1p5() + .py_1() + .pl_2() + .pr_1() .gap_1() - .justify_end() + .justify_between() + .flex_wrap() .when(!empty_content, |this| { this.border_t_1() .border_color(self.tool_card_border_color(cx)) }) - .children(options.iter().map(|option| { + .child( + div() + .min_w(rems_from_px(145.)) + .child(LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small)), + ) + .child(h_flex().gap_0p5().children(options.iter().map(|option| { let option_id = SharedString::from(option.id.0.clone()); Button::new((option_id, entry_ix), option.name.clone()) .map(|this| match option.kind { @@ -1373,14 +1427,23 @@ impl AcpThreadView { ); } })) - })) + }))) } - fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { + fn render_diff_editor( + &self, + entry_ix: usize, + diff: &Entity, + cx: &Context, + ) -> AnyElement { v_flex() .h_full() + .border_t_1() + .border_color(self.tool_card_border_color(cx)) .child( - if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) { + if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix) + && let Some(editor) = entry.editor_for_diff(&diff) + { editor.clone().into_any_element() } else { Empty.into_any() @@ -1389,6 +1452,250 @@ impl AcpThreadView { .into_any() } + fn render_terminal_tool_call( + &self, + entry_ix: usize, + terminal: &Entity, + tool_call: &ToolCall, + window: &Window, + cx: &Context, + ) -> AnyElement { + let terminal_data = terminal.read(cx); + let working_dir = terminal_data.working_dir(); + let command = terminal_data.command(); + let started_at = terminal_data.started_at(); + + let tool_failed = matches!( + &tool_call.status, + ToolCallStatus::Rejected + | ToolCallStatus::Canceled + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Failed, + .. + } + ); + + let output = terminal_data.output(); + let command_finished = output.is_some(); + let truncated_output = output.is_some_and(|output| output.was_content_truncated); + let output_line_count = output.map(|output| output.content_line_count).unwrap_or(0); + + let command_failed = command_finished + && output.is_some_and(|o| o.exit_status.is_none_or(|status| !status.success())); + + let time_elapsed = if let Some(output) = output { + output.ended_at.duration_since(started_at) + } else { + started_at.elapsed() + }; + + let header_bg = cx + .theme() + .colors() + .element_background + .blend(cx.theme().colors().editor_foreground.opacity(0.025)); + let border_color = cx.theme().colors().border.opacity(0.6); + + let working_dir = working_dir + .as_ref() + .map(|path| format!("{}", path.display())) + .unwrap_or_else(|| "current directory".to_string()); + + let header = h_flex() + .id(SharedString::from(format!( + "terminal-tool-header-{}", + terminal.entity_id() + ))) + .flex_none() + .gap_1() + .justify_between() + .rounded_t_md() + .child( + div() + .id(("command-target-path", terminal.entity_id())) + .w_full() + .max_w_full() + .overflow_x_scroll() + .child( + Label::new(working_dir) + .buffer_font(cx) + .size(LabelSize::XSmall) + .color(Color::Muted), + ), + ) + .when(!command_finished, |header| { + header + .gap_1p5() + .child( + Button::new( + SharedString::from(format!("stop-terminal-{}", terminal.entity_id())), + "Stop", + ) + .icon(IconName::Stop) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .icon_color(Color::Error) + .label_size(LabelSize::Small) + .tooltip(move |window, cx| { + Tooltip::with_meta( + "Stop This Command", + None, + "Also possible by placing your cursor inside the terminal and using regular terminal bindings.", + window, + cx, + ) + }) + .on_click({ + let terminal = terminal.clone(); + cx.listener(move |_this, _event, _window, cx| { + let inner_terminal = terminal.read(cx).inner().clone(); + inner_terminal.update(cx, |inner_terminal, _cx| { + inner_terminal.kill_active_task(); + }); + }) + }), + ) + .child(Divider::vertical()) + .child( + Icon::new(IconName::ArrowCircle) + .size(IconSize::XSmall) + .color(Color::Info) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate(percentage(delta))) + }, + ), + ) + }) + .when(tool_failed || command_failed, |header| { + header.child( + div() + .id(("terminal-tool-error-code-indicator", terminal.entity_id())) + .child( + Icon::new(IconName::Close) + .size(IconSize::Small) + .color(Color::Error), + ) + .when_some(output.and_then(|o| o.exit_status), |this, status| { + this.tooltip(Tooltip::text(format!( + "Exited with code {}", + status.code().unwrap_or(-1), + ))) + }), + ) + }) + .when(truncated_output, |header| { + let tooltip = if let Some(output) = output { + if output_line_count + 10 > terminal::MAX_SCROLL_HISTORY_LINES { + "Output exceeded terminal max lines and was \ + truncated, the model received the first 16 KB." + .to_string() + } else { + format!( + "Output is {} long—to avoid unexpected token usage, \ + only 16 KB was sent back to the model.", + format_file_size(output.original_content_len as u64, true), + ) + } + } else { + "Output was truncated".to_string() + }; + + header.child( + h_flex() + .id(("terminal-tool-truncated-label", terminal.entity_id())) + .gap_1() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Ignored), + ) + .child( + Label::new("Truncated") + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .tooltip(Tooltip::text(tooltip)), + ) + }) + .when(time_elapsed > Duration::from_secs(10), |header| { + header.child( + Label::new(format!("({})", duration_alt_display(time_elapsed))) + .buffer_font(cx) + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + }) + .child( + Disclosure::new( + SharedString::from(format!( + "terminal-tool-disclosure-{}", + terminal.entity_id() + )), + self.terminal_expanded, + ) + .opened_icon(IconName::ChevronUp) + .closed_icon(IconName::ChevronDown) + .on_click(cx.listener(move |this, _event, _window, _cx| { + this.terminal_expanded = !this.terminal_expanded; + })), + ); + + let terminal_view = self + .entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| entry.terminal(&terminal)); + let show_output = self.terminal_expanded && terminal_view.is_some(); + + v_flex() + .mb_2() + .border_1() + .when(tool_failed || command_failed, |card| card.border_dashed()) + .border_color(border_color) + .rounded_lg() + .overflow_hidden() + .child( + v_flex() + .py_1p5() + .pl_2() + .pr_1p5() + .gap_0p5() + .bg(header_bg) + .text_xs() + .child(header) + .child( + MarkdownElement::new( + command.clone(), + terminal_command_markdown_style(window, cx), + ) + .code_block_renderer( + markdown::CodeBlockRenderer::Default { + copy_button: false, + copy_button_on_hover: true, + border: false, + }, + ), + ), + ) + .when(show_output, |this| { + this.child( + div() + .pt_2() + .border_t_1() + .when(tool_failed || command_failed, |card| card.border_dashed()) + .border_color(border_color) + .bg(cx.theme().colors().editor_background) + .rounded_b_md() + .text_ui_sm(cx) + .children(terminal_view.clone()), + ) + }) + .into_any() + } + fn render_agent_logo(&self) -> AnyElement { Icon::new(self.agent.logo()) .color(Color::Muted) @@ -1608,24 +1915,26 @@ impl AcpThreadView { parent.child(self.render_plan_entries(plan, window, cx)) }) }) - .when(!changed_buffers.is_empty(), |this| { + .when(!plan.is_empty() && !changed_buffers.is_empty(), |this| { this.child(Divider::horizontal().color(DividerColor::Border)) - .child(self.render_edits_summary( + }) + .when(!changed_buffers.is_empty(), |this| { + this.child(self.render_edits_summary( + action_log, + &changed_buffers, + self.edits_expanded, + pending_edits, + window, + cx, + )) + .when(self.edits_expanded, |parent| { + parent.child(self.render_edited_files( action_log, &changed_buffers, - self.edits_expanded, pending_edits, - window, cx, )) - .when(self.edits_expanded, |parent| { - parent.child(self.render_edited_files( - action_log, - &changed_buffers, - pending_edits, - cx, - )) - }) + }) }) .into_any() .into() @@ -2067,6 +2376,17 @@ impl AcpThreadView { v_flex() .on_action(cx.listener(Self::expand_message_editor)) + .on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| { + if let Some(profile_selector) = this.profile_selector.as_ref() { + profile_selector.read(cx).menu_handle().toggle(window, cx); + } + })) + .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { + if let Some(model_selector) = this.model_selector.as_ref() { + model_selector + .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); + } + })) .p_2() .gap_2() .border_t_1() @@ -2081,34 +2401,7 @@ impl AcpThreadView { .size_full() .pt_1() .pr_2p5() - .child(div().flex_1().child({ - let settings = ThemeSettings::get_global(cx); - let font_size = TextSize::Small - .rems(cx) - .to_pixels(settings.agent_font_size(cx)); - let line_height = settings.buffer_line_height.value() * font_size; - - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() - }; - - EditorElement::new( - &self.message_editor, - EditorStyle { - background: editor_bg_color, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() - }, - ) - })) + .child(self.message_editor.clone()) .child( h_flex() .absolute() @@ -2142,12 +2435,169 @@ impl AcpThreadView { h_flex() .flex_none() .justify_between() - .child(self.render_follow_toggle(cx)) - .child(self.render_send_button(cx)), + .child( + h_flex() + .gap_1() + .child(self.render_follow_toggle(cx)) + .children(self.render_burn_mode_toggle(cx)), + ) + .child( + h_flex() + .gap_1() + .children(self.profile_selector.clone()) + .children(self.model_selector.clone()) + .child(self.render_send_button(cx)), + ), ) .into_any() } + fn as_native_connection(&self, cx: &App) -> Option> { + let acp_thread = self.thread()?.read(cx); + acp_thread.connection().clone().downcast() + } + + fn as_native_thread(&self, cx: &App) -> Option> { + let acp_thread = self.thread()?.read(cx); + self.as_native_connection(cx)? + .thread(acp_thread.session_id(), cx) + } + + fn toggle_burn_mode( + &mut self, + _: &ToggleBurnMode, + _window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.as_native_thread(cx) else { + return; + }; + + thread.update(cx, |thread, _cx| { + let current_mode = thread.completion_mode(); + thread.set_completion_mode(match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }); + }); + } + + fn render_burn_mode_toggle(&self, cx: &mut Context) -> Option { + let thread = self.as_native_thread(cx)?.read(cx); + + if !thread.model().supports_burn_mode() { + return None; + } + + let active_completion_mode = thread.completion_mode(); + let burn_mode_enabled = active_completion_mode == CompletionMode::Burn; + let icon = if burn_mode_enabled { + IconName::ZedBurnModeOn + } else { + IconName::ZedBurnMode + }; + + Some( + IconButton::new("burn-mode", icon) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(burn_mode_enabled) + .selected_icon_color(Color::Error) + .on_click(cx.listener(|this, _event, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + })) + .tooltip(move |_window, cx| { + cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled)) + .into() + }) + .into_any_element(), + ) + } + + fn render_sent_message_editor( + &self, + entry_ix: usize, + editor: &Entity, + cx: &Context, + ) -> Div { + v_flex().w_full().gap_2().child(editor.clone()).when( + self.editing_message == Some(entry_ix), + |el| { + el.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::XSmall), + ) + .child( + Label::new("Editing will restart the thread from this point.") + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .child(self.render_sent_message_editor_buttons(entry_ix, editor, cx)), + ) + }, + ) + } + + fn render_sent_message_editor_buttons( + &self, + entry_ix: usize, + editor: &Entity, + cx: &Context, + ) -> Div { + h_flex() + .gap_0p5() + .flex_1() + .justify_end() + .child( + IconButton::new("cancel-edit-message", IconName::Close) + .shape(ui::IconButtonShape::Square) + .icon_color(Color::Error) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = editor.focus_handle(cx); + move |window, cx| { + Tooltip::for_action_in( + "Cancel Edit", + &menu::Cancel, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(Self::cancel_editing)), + ) + .child( + IconButton::new("confirm-edit-message", IconName::Return) + .disabled(editor.read(cx).is_empty(cx)) + .shape(ui::IconButtonShape::Square) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = editor.focus_handle(cx); + move |window, cx| { + Tooltip::for_action_in( + "Regenerate", + &menu::Confirm, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener({ + let editor = editor.clone(); + move |this, _, window, cx| { + this.regenerate(entry_ix, &editor, window, cx); + } + })), + ) + } + fn render_send_button(&self, cx: &mut Context) -> AnyElement { if self.thread().map_or(true, |thread| { thread.read(cx).status() == ThreadStatus::Idle @@ -2164,7 +2614,7 @@ impl AcpThreadView { button.tooltip(Tooltip::text("Type a message to submit")) }) .on_click(cx.listener(|this, _, window, cx| { - this.chat(&Chat, window, cx); + this.send(window, cx); })) .into_any_element() } else { @@ -2174,7 +2624,7 @@ impl AcpThreadView { .tooltip(move |window, cx| { Tooltip::for_action("Stop Generation", &editor::actions::Cancel, window, cx) }) - .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) + .on_click(cx.listener(|this, _event, _, cx| this.cancel_generation(cx))) .into_any_element() } } @@ -2236,26 +2686,95 @@ impl AcpThreadView { return; }; - if let Some(mention_path) = MentionPath::try_parse(&url) { - workspace.update(cx, |workspace, cx| { - let project = workspace.project(); - let Some((path, entry)) = project.update(cx, |project, cx| { - let path = project.find_project_path(mention_path.path(), cx)?; - let entry = project.entry_for_path(&path, cx)?; - Some((path, entry)) - }) else { - return; - }; + if let Some(mention) = MentionUri::parse(&url).log_err() { + workspace.update(cx, |workspace, cx| match mention { + MentionUri::File { abs_path, .. } => { + let project = workspace.project(); + let Some((path, entry)) = project.update(cx, |project, cx| { + let path = project.find_project_path(abs_path, cx)?; + let entry = project.entry_for_path(&path, cx)?; + Some((path, entry)) + }) else { + return; + }; - if entry.is_dir() { - project.update(cx, |_, cx| { - cx.emit(project::Event::RevealInProjectPanel(entry.id)); - }); - } else { - workspace - .open_path(path, None, true, window, cx) + if entry.is_dir() { + project.update(cx, |_, cx| { + cx.emit(project::Event::RevealInProjectPanel(entry.id)); + }); + } else { + workspace + .open_path(path, None, true, window, cx) + .detach_and_log_err(cx); + } + } + MentionUri::Symbol { + path, line_range, .. + } + | MentionUri::Selection { path, line_range } => { + let project = workspace.project(); + let Some((path, _)) = project.update(cx, |project, cx| { + let path = project.find_project_path(path, cx)?; + let entry = project.entry_for_path(&path, cx)?; + Some((path, entry)) + }) else { + return; + }; + + let item = workspace.open_path(path, None, true, window, cx); + window + .spawn(cx, async move |cx| { + let Some(editor) = item.await?.downcast::() else { + return Ok(()); + }; + let range = + Point::new(line_range.start, 0)..Point::new(line_range.start, 0); + editor + .update_in(cx, |editor, window, cx| { + editor.change_selections( + SelectionEffects::scroll(Autoscroll::center()), + window, + cx, + |s| s.select_ranges(vec![range]), + ); + }) + .ok(); + anyhow::Ok(()) + }) .detach_and_log_err(cx); } + MentionUri::Thread { id, .. } => { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel + .open_thread_by_id(&id, window, cx) + .detach_and_log_err(cx) + }); + } + } + MentionUri::TextThread { path, .. } => { + if let Some(panel) = workspace.panel::(cx) { + panel.update(cx, |panel, cx| { + panel + .open_saved_prompt_editor(path.as_path().into(), window, cx) + .detach_and_log_err(cx); + }); + } + } + MentionUri::Rule { id, .. } => { + let PromptId::User { uuid } = id else { + return; + }; + window.dispatch_action( + Box::new(OpenRulesLibrary { + prompt_to_select: Some(uuid.0), + }), + cx, + ) + } + MentionUri::Fetch { url } => { + cx.open_url(url.as_str()); + } }) } else { cx.open_url(&url); @@ -2269,26 +2788,24 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) -> Option<()> { - let location = self + let (tool_call_location, agent_location) = self .thread()? .read(cx) .entries() .get(entry_ix)? - .locations()? - .get(location_ix)?; + .location(location_ix)?; let project_path = self .project .read(cx) - .find_project_path(&location.path, cx)?; + .find_project_path(&tool_call_location.path, cx)?; let open_task = self .workspace - .update(cx, |worskpace, cx| { - worskpace.open_path(project_path, None, true, window, cx) + .update(cx, |workspace, cx| { + workspace.open_path(project_path, None, true, window, cx) }) .log_err()?; - window .spawn(cx, async move |cx| { let item = open_task.await?; @@ -2298,17 +2815,22 @@ impl AcpThreadView { }; active_editor.update_in(cx, |editor, window, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let first_hunk = editor - .diff_hunks_in_ranges( - &[editor::Anchor::min()..editor::Anchor::max()], - &snapshot, - ) - .next(); - if let Some(first_hunk) = first_hunk { - let first_hunk_start = first_hunk.multi_buffer_range().start; + let multibuffer = editor.buffer().read(cx); + let buffer = multibuffer.as_singleton(); + if agent_location.buffer.upgrade() == buffer { + let excerpt_id = multibuffer.excerpt_ids().first().cloned(); + let anchor = editor::Anchor::in_buffer( + excerpt_id.unwrap(), + buffer.unwrap().read(cx).remote_id(), + agent_location.position, + ); editor.change_selections(Default::default(), window, cx, |selections| { - selections.select_anchor_ranges([first_hunk_start..first_hunk_start]); + selections.select_anchor_ranges([anchor..anchor]); + }) + } else { + let row = tool_call_location.line.unwrap_or_default(); + editor.change_selections(Default::default(), window, cx, |selections| { + selections.select_ranges([Point::new(row, 0)..Point::new(row, 0)]); }) } })?; @@ -2346,7 +2868,7 @@ impl AcpThreadView { let project = workspace.project().clone(); if !project.read(cx).is_local() { - anyhow::bail!("failed to open active thread as markdown in remote project"); + bail!("failed to open active thread as markdown in remote project"); } let buffer = project.update(cx, |project, cx| { @@ -2538,7 +3060,8 @@ impl AcpThreadView { fn render_thread_controls(&self, cx: &Context) -> impl IntoElement { let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileMarkdown) - .icon_size(IconSize::XSmall) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .icon_color(Color::Ignored) .tooltip(Tooltip::text("Open Thread as Markdown")) .on_click(cx.listener(move |this, _, window, cx| { @@ -2549,7 +3072,8 @@ impl AcpThreadView { })); let scroll_to_top = IconButton::new("scroll_to_top", IconName::ArrowUp) - .icon_size(IconSize::XSmall) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .icon_color(Color::Ignored) .tooltip(Tooltip::text("Scroll To Top")) .on_click(cx.listener(move |this, _, _, cx| { @@ -2560,7 +3084,6 @@ impl AcpThreadView { .w_full() .mr_1() .pb_2() - .gap_1() .px(RESPONSE_PADDING_X) .opacity(0.4) .hover(|style| style.opacity(1.)) @@ -2604,67 +3127,201 @@ impl AcpThreadView { } fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { - for diff_editor in self.diff_editors.values() { - diff_editor.update(cx, |diff_editor, cx| { - diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - cx.notify(); - }) - } + self.entry_view_state.update(cx, |entry_view_state, cx| { + entry_view_state.settings_changed(cx); + }); } pub(crate) fn insert_dragged_files( &self, paths: Vec, - _added_worktrees: Vec>, + added_worktrees: Vec>, window: &mut Window, - cx: &mut Context<'_, Self>, + cx: &mut Context, ) { - let buffer = self.message_editor.read(cx).buffer().clone(); - let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else { - return; - }; - let Some(buffer) = buffer.read(cx).as_singleton() else { - return; - }; - for path in paths { - let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else { - continue; - }; - let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else { - continue; - }; + self.message_editor.update(cx, |message_editor, cx| { + message_editor.insert_dragged_files(paths, window, cx); + drop(added_worktrees); + }) + } - let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len())); - let path_prefix = abs_path - .file_name() - .unwrap_or(path.path.as_os_str()) - .display() - .to_string(); - let completion = ContextPickerCompletionProvider::completion_for_path( - path, - &path_prefix, - false, - entry.is_dir(), - excerpt_id, - anchor..anchor, - self.message_editor.clone(), - self.mention_set.clone(), - cx, - ); - - self.message_editor.update(cx, |message_editor, cx| { - message_editor.edit( - [( - multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), - completion.new_text, - )], - cx, - ); - }); - if let Some(confirm) = completion.confirm.clone() { - confirm(CompletionIntent::Complete, window, cx); + fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option
{ + let content = match self.thread_error.as_ref()? { + ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), + ThreadError::PaymentRequired => self.render_payment_required_error(cx), + ThreadError::ModelRequestLimitReached(plan) => { + self.render_model_request_limit_reached_error(*plan, cx) } - } + ThreadError::ToolUseLimitReached => { + self.render_tool_use_limit_reached_error(window, cx)? + } + }; + + Some( + div() + .border_t_1() + .border_color(cx.theme().colors().border) + .child(content), + ) + } + + fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout { + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + Callout::new() + .icon(icon) + .title("Error") + .description(error.clone()) + .secondary_action(self.create_copy_button(error.to_string())) + .primary_action(self.dismiss_error_button(cx)) + .bg_color(self.error_callout_bg(cx)) + } + + fn render_payment_required_error(&self, cx: &mut Context) -> Callout { + const ERROR_MESSAGE: &str = + "You reached your free usage limit. Upgrade to Zed Pro for more prompts."; + + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + Callout::new() + .icon(icon) + .title("Free Usage Exceeded") + .description(ERROR_MESSAGE) + .tertiary_action(self.upgrade_button(cx)) + .secondary_action(self.create_copy_button(ERROR_MESSAGE)) + .primary_action(self.dismiss_error_button(cx)) + .bg_color(self.error_callout_bg(cx)) + } + + fn render_model_request_limit_reached_error( + &self, + plan: cloud_llm_client::Plan, + cx: &mut Context, + ) -> Callout { + let error_message = match plan { + cloud_llm_client::Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", + cloud_llm_client::Plan::ZedProTrial | cloud_llm_client::Plan::ZedFree => { + "Upgrade to Zed Pro for more prompts." + } + }; + + let icon = Icon::new(IconName::XCircle) + .size(IconSize::Small) + .color(Color::Error); + + Callout::new() + .icon(icon) + .title("Model Prompt Limit Reached") + .description(error_message) + .tertiary_action(self.upgrade_button(cx)) + .secondary_action(self.create_copy_button(error_message)) + .primary_action(self.dismiss_error_button(cx)) + .bg_color(self.error_callout_bg(cx)) + } + + fn render_tool_use_limit_reached_error( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Option { + let thread = self.as_native_thread(cx)?; + let supports_burn_mode = thread.read(cx).model().supports_burn_mode(); + + let focus_handle = self.focus_handle(cx); + + let icon = Icon::new(IconName::Info) + .size(IconSize::Small) + .color(Color::Info); + + Some( + Callout::new() + .icon(icon) + .title("Consecutive tool use limit reached.") + .when(supports_burn_mode, |this| { + this.secondary_action( + Button::new("continue-burn-mode", "Continue with Burn Mode") + .style(ButtonStyle::Filled) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .layer(ElevationIndex::ModalSurface) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &ContinueWithBurnMode, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use.")) + .on_click({ + cx.listener(move |this, _, _window, cx| { + thread.update(cx, |thread, _cx| { + thread.set_completion_mode(CompletionMode::Burn); + }); + this.resume_chat(cx); + }) + }), + ) + }) + .primary_action( + Button::new("continue-conversation", "Continue") + .layer(ElevationIndex::ModalSurface) + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in(&ContinueThread, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click(cx.listener(|this, _, _window, cx| { + this.resume_chat(cx); + })), + ), + ) + } + + fn create_copy_button(&self, message: impl Into) -> impl IntoElement { + let message = message.into(); + + IconButton::new("copy", IconName::Copy) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(Tooltip::text("Copy Error Message")) + .on_click(move |_, _, cx| { + cx.write_to_clipboard(ClipboardItem::new_string(message.clone())) + }) + } + + fn dismiss_error_button(&self, cx: &mut Context) -> impl IntoElement { + IconButton::new("dismiss", IconName::Close) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .tooltip(Tooltip::text("Dismiss Error")) + .on_click(cx.listener({ + move |this, _, _, cx| { + this.clear_thread_error(cx); + cx.notify(); + } + })) + } + + fn upgrade_button(&self, cx: &mut Context) -> impl IntoElement { + Button::new("upgrade", "Upgrade") + .label_size(LabelSize::Small) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .on_click(cx.listener({ + move |this, _, _, cx| { + this.clear_thread_error(cx); + cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)); + } + })) + } + + fn error_callout_bg(&self, cx: &Context) -> Hsla { + cx.theme().status().error.opacity(0.08) } } @@ -2676,13 +3333,13 @@ impl Focusable for AcpThreadView { impl Render for AcpThreadView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let has_messages = self.list_state.item_count() > 0; + v_flex() .size_full() .key_context("AcpThread") - .on_action(cx.listener(Self::chat)) - .on_action(cx.listener(Self::previous_history_message)) - .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) + .on_action(cx.listener(Self::toggle_burn_mode)) .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { ThreadState::Unauthenticated { connection } => v_flex() @@ -2722,7 +3379,7 @@ impl Render for AcpThreadView { let thread_clone = thread.clone(); v_flex().flex_1().map(|this| { - if self.list_state.item_count() > 0 { + if has_messages { this.child( list( self.list_state.clone(), @@ -2741,69 +3398,37 @@ impl Render for AcpThreadView { .into_any(), ) .child(self.render_vertical_scrollbar(cx)) - .children(match thread_clone.read(cx).status() { - ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => { - None - } - ThreadStatus::Generating => div() - .px_5() - .py_2() - .child(LoadingLabel::new("").size(LabelSize::Small)) - .into(), - }) - .children(self.render_activity_bar(&thread_clone, window, cx)) + .children( + match thread_clone.read(cx).status() { + ThreadStatus::Idle + | ThreadStatus::WaitingForToolConfirmation => None, + ThreadStatus::Generating => div() + .px_5() + .py_2() + .child(LoadingLabel::new("").size(LabelSize::Small)) + .into(), + }, + ) } else { this.child(self.render_empty_state(cx)) } }) } }) - .when_some(self.last_error.clone(), |el, error| { - el.child( - div() - .p_2() - .text_xs() - .border_t_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().status().error_background) - .child( - self.render_markdown(error, default_markdown_style(false, window, cx)), - ), - ) + // The activity bar is intentionally rendered outside of the ThreadState::Ready match + // above so that the scrollbar doesn't render behind it. The current setup allows + // the scrollbar to stop exactly at the activity bar start. + .when(has_messages, |this| match &self.thread_state { + ThreadState::Ready { thread, .. } => { + this.children(self.render_activity_bar(thread, window, cx)) + } + _ => this, }) + .children(self.render_thread_error(window, cx)) .child(self.render_message_editor(window, cx)) } } -fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let mut style = default_markdown_style(false, window, cx); - let mut text_style = window.text_style(); - let theme_settings = ThemeSettings::get_global(cx); - - let buffer_font = theme_settings.buffer_font.family.clone(); - let buffer_font_size = TextSize::Small.rems(cx); - - text_style.refine(&TextStyleRefinement { - font_family: Some(buffer_font), - font_size: Some(buffer_font_size.into()), - ..Default::default() - }); - - style.base_text_style = text_style; - style.link_callback = Some(Rc::new(move |url, cx| { - if MentionPath::try_parse(url).is_some() { - let colors = cx.theme().colors(); - Some(TextStyleRefinement { - background_color: Some(colors.element_background), - ..Default::default() - }) - } else { - None - } - })); - style -} - fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); @@ -2943,31 +3568,32 @@ fn plan_label_markdown_style( } } -fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { - TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), +fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { + let default_md_style = default_markdown_style(true, window, cx); + + MarkdownStyle { + base_text_style: TextStyle { + ..default_md_style.base_text_style + }, + selection_background_color: cx.theme().colors().element_selection_background, ..Default::default() } } #[cfg(test)] -mod tests { +pub(crate) mod tests { + use acp_thread::StubAgentConnection; + use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::SessionId; use editor::EditorSettings; use fs::FakeFs; - use futures::future::try_join_all; - use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; - use lsp::{CompletionContext, CompletionTriggerKind}; - use project::CompletionIntent; - use rand::Rng; + use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext}; + use project::Project; use serde_json::json; use settings::SettingsStore; - use util::path; + use std::any::Any; + use std::path::Path; + use workspace::Item; use super::*; @@ -2995,7 +3621,7 @@ mod tests { cx.deactivate_window(); thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); + thread_view.send(window, cx); }); cx.run_until_parked(); @@ -3022,7 +3648,7 @@ mod tests { cx.deactivate_window(); thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); + thread_view.send(window, cx); }); cx.run_until_parked(); @@ -3049,8 +3675,8 @@ mod tests { raw_input: None, raw_output: None, }; - let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) - .with_permission_requests(HashMap::from_iter([( + let connection = + StubAgentConnection::new().with_permission_requests(HashMap::from_iter([( tool_call_id, vec![acp::PermissionOption { id: acp::PermissionOptionId("1".into()), @@ -3058,6 +3684,9 @@ mod tests { kind: acp::PermissionOptionKind::AllowOnce, }], )])); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); @@ -3068,7 +3697,7 @@ mod tests { cx.deactivate_window(); thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); + thread_view.send(window, cx); }); cx.run_until_parked(); @@ -3080,109 +3709,6 @@ mod tests { ); } - #[gpui::test] - async fn test_crease_removal(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree("/project", json!({"file": ""})).await; - let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; - let agent = StubAgentServer::default(); - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let thread_view = cx.update(|window, cx| { - cx.new(|cx| { - AcpThreadView::new( - Rc::new(agent), - workspace.downgrade(), - project, - Rc::new(RefCell::new(MessageHistory::default())), - 1, - None, - window, - cx, - ) - }) - }); - - cx.run_until_parked(); - - let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); - let excerpt_id = message_editor.update(cx, |editor, cx| { - editor - .buffer() - .read(cx) - .excerpt_ids() - .into_iter() - .next() - .unwrap() - }); - let completions = message_editor.update_in(cx, |editor, window, cx| { - editor.set_text("Hello @", window, cx); - let buffer = editor.buffer().read(cx).as_singleton().unwrap(); - let completion_provider = editor.completion_provider().unwrap(); - completion_provider.completions( - excerpt_id, - &buffer, - Anchor::MAX, - CompletionContext { - trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER, - trigger_character: Some("@".into()), - }, - window, - cx, - ) - }); - let [_, completion]: [_; 2] = completions - .await - .unwrap() - .into_iter() - .flat_map(|response| response.completions) - .collect::>() - .try_into() - .unwrap(); - - message_editor.update_in(cx, |editor, window, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let start = snapshot - .anchor_in_excerpt(excerpt_id, completion.replace_range.start) - .unwrap(); - let end = snapshot - .anchor_in_excerpt(excerpt_id, completion.replace_range.end) - .unwrap(); - editor.edit([(start..end, completion.new_text)], cx); - (completion.confirm.unwrap())(CompletionIntent::Complete, window, cx); - }); - - cx.run_until_parked(); - - // Backspace over the inserted crease (and the following space). - message_editor.update_in(cx, |editor, window, cx| { - editor.backspace(&Default::default(), window, cx); - editor.backspace(&Default::default(), window, cx); - }); - - thread_view.update_in(cx, |thread_view, window, cx| { - thread_view.chat(&Chat, window, cx); - }); - - cx.run_until_parked(); - - let content = thread_view.update_in(cx, |thread_view, _window, _cx| { - thread_view - .message_history - .borrow() - .items() - .iter() - .flatten() - .cloned() - .collect::>() - }); - - // We don't send a resource link for the deleted crease. - pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]); - } - async fn setup_thread_view( agent: impl AgentServer + 'static, cx: &mut TestAppContext, @@ -3192,15 +3718,19 @@ mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let thread_store = + cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); + let text_thread_store = + cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + let thread_view = cx.update(|window, cx| { cx.new(|cx| { AcpThreadView::new( Rc::new(agent), workspace.downgrade(), project, - Rc::new(RefCell::new(MessageHistory::default())), - 1, - None, + thread_store.clone(), + text_thread_store.clone(), window, cx, ) @@ -3210,6 +3740,50 @@ mod tests { (thread_view, cx) } + fn add_to_workspace(thread_view: Entity, cx: &mut VisualTestContext) { + let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone()); + + workspace + .update_in(cx, |workspace, window, cx| { + workspace.add_item_to_active_pane( + Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))), + None, + true, + window, + cx, + ); + }) + .unwrap(); + } + + struct ThreadViewItem(Entity); + + impl Item for ThreadViewItem { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for ThreadViewItem {} + + impl Focusable for ThreadViewItem { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx).clone() + } + } + + impl Render for ThreadViewItem { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + struct StubAgentServer { connection: C, } @@ -3231,19 +3805,19 @@ mod tests { C: 'static + AgentConnection + Send + Clone, { fn logo(&self) -> ui::IconName { - unimplemented!() + ui::IconName::Ai } fn name(&self) -> &'static str { - unimplemented!() + "Test" } fn empty_state_headline(&self) -> &'static str { - unimplemented!() + "Test" } fn empty_state_message(&self) -> &'static str { - unimplemented!() + "Test" } fn connect( @@ -3256,114 +3830,6 @@ mod tests { } } - #[derive(Clone, Default)] - struct StubAgentConnection { - sessions: Arc>>>, - permission_requests: HashMap>, - updates: Vec, - } - - impl StubAgentConnection { - fn new(updates: Vec) -> Self { - Self { - updates, - permission_requests: HashMap::default(), - sessions: Arc::default(), - } - } - - fn with_permission_requests( - mut self, - permission_requests: HashMap>, - ) -> Self { - self.permission_requests = permission_requests; - self - } - } - - impl AgentConnection for StubAgentConnection { - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - let session_id = SessionId( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(7) - .map(char::from) - .collect::() - .into(), - ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); - self.sessions.lock().insert(session_id, thread.downgrade()); - Task::ready(Ok(thread)) - } - - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } - - fn prompt( - &self, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let sessions = self.sessions.lock(); - let thread = sessions.get(¶ms.session_id).unwrap(); - let mut tasks = vec![]; - for update in &self.updates { - let thread = thread.clone(); - let update = update.clone(); - let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update - && let Some(options) = self.permission_requests.get(&tool_call.id) - { - Some((tool_call.clone(), options.clone())) - } else { - None - }; - let task = cx.spawn(async move |cx| { - if let Some((tool_call, options)) = permission_request { - let permission = thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization( - tool_call.clone(), - options.clone(), - cx, - ) - })?; - permission.await?; - } - thread.update(cx, |thread, cx| { - thread.handle_session_update(update.clone(), cx).unwrap(); - })?; - anyhow::Ok(()) - }); - tasks.push(task); - } - cx.spawn(async move |_| { - try_join_all(tasks).await?; - Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) - }) - } - - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - } - #[derive(Clone)] struct SaboteurAgentConnection; @@ -3372,19 +3838,17 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::AsyncApp, + cx: &mut gpui::App, ) -> Task>> { - Task::ready(Ok(cx - .new(|cx| { - AcpThread::new( - "SaboteurAgentConnection", - self, - project, - SessionId("test".into()), - cx, - ) - }) - .unwrap())) + Task::ready(Ok(cx.new(|cx| { + AcpThread::new( + "SaboteurAgentConnection", + self, + project, + SessionId("test".into()), + cx, + ) + }))) } fn auth_methods(&self) -> &[acp::AuthMethod] { @@ -3401,6 +3865,7 @@ mod tests { fn prompt( &self, + _id: Option, _params: acp::PromptRequest, _cx: &mut App, ) -> Task> { @@ -3410,9 +3875,13 @@ mod tests { fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { unimplemented!() } + + fn into_any(self: Rc) -> Rc { + self + } } - fn init_test(cx: &mut TestAppContext) { + pub(crate) fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -3425,4 +3894,319 @@ mod tests { EditorSettings::register(cx); }); } + + #[gpui::test] + async fn test_rewind_views(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "test1.txt": "old content 1", + "test2.txt": "old content 2" + }), + ) + .await; + let project = Project::test(fs, [Path::new("/project")], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_store = + cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); + let text_thread_store = + cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + + let connection = Rc::new(StubAgentConnection::new()); + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(StubAgentServer::new(connection.as_ref().clone())), + workspace.downgrade(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let thread = thread_view + .read_with(cx, |view, _| view.thread().cloned()) + .unwrap(); + + // First user message + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool1".into()), + title: "Edit file 1".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/test1.txt".into(), + old_text: Some("old content 1".into()), + new_text: "new content 1".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + })]); + + thread + .update(cx, |thread, cx| thread.send_raw("Give me a diff", cx)) + .await + .unwrap(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + }); + + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + }); + }); + + // Second user message + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool2".into()), + title: "Edit file 2".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/test2.txt".into(), + old_text: Some("old content 2".into()), + new_text: "new content 2".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + })]); + + thread + .update(cx, |thread, cx| thread.send_raw("Another one", cx)) + .await + .unwrap(); + cx.run_until_parked(); + + let second_user_message_id = thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 4); + let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else { + panic!(); + }; + user_message.id.clone().unwrap() + }); + + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + assert!( + entry_view_state + .entry(2) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(3).unwrap().has_content()); + }); + }); + + // Rewind to first message + thread + .update(cx, |thread, cx| thread.rewind(second_user_message_id, cx)) + .await + .unwrap(); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + }); + + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + + // Old views should be dropped + assert!(entry_view_state.entry(2).is_none()); + assert!(entry_view_state.entry(3).is_none()); + }); + }); + } + + #[gpui::test] + async fn test_message_editing_cancel(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); + }); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Cancel + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(editor::actions::Cancel), cx); + }); + + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, None); + }); + + user_message_editor.read_with(cx, |editor, cx| { + assert_eq!(editor.text(cx), "Original message to edit"); + }); + } + + #[gpui::test] + async fn test_message_editing_regenerate(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Send + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "New Response".into(), + annotations: None, + }), + }]); + + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(Chat), cx); + }); + + cx.run_until_parked(); + + thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + let entries = view.thread().unwrap().read(cx).entries(); + assert_eq!(entries.len(), 2); + assert_eq!( + entries[0].to_markdown(cx), + "## User\n\nEdited message content\n\n" + ); + assert_eq!( + entries[1].to_markdown(cx), + "## Assistant\n\nNew Response\n\n" + ); + + let new_editor = view.entry_view_state.read_with(cx, |state, _cx| { + assert!(!state.entry(1).unwrap().has_content()); + state.entry(0).unwrap().message_editor().unwrap().clone() + }); + + assert_eq!(new_editor.read(cx).text(cx), "Edited message content"); + }) + } } diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 5f72fa58c8..4a2dd88c33 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -465,7 +465,7 @@ impl AgentConfiguration { "modifier-send", "Use modifier to submit a message", Some( - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(), ), use_modifier_to_send, move |state, _window, cx| { @@ -1035,7 +1035,6 @@ fn extension_only_provides_context_server(manifest: &ExtensionManifest) -> bool && manifest.grammars.is_empty() && manifest.language_servers.is_empty() && manifest.slash_commands.is_empty() - && manifest.indexed_docs_providers.is_empty() && manifest.snippets.is_none() && manifest.debug_locators.is_empty() } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index e1ceaf761d..b9e1ea5d0a 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1,9 +1,9 @@ use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; use acp_thread::{AcpThread, AcpThreadEvent}; +use action_log::ActionLog; use agent::{Thread, ThreadEvent, ThreadSummary}; use agent_settings::AgentSettings; use anyhow::Result; -use assistant_tool::ActionLog; use buffer_diff::DiffHunkStatus; use collections::{HashMap, HashSet}; use editor::{ @@ -1521,7 +1521,8 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::Stopped + AcpThreadEvent::EntriesRemoved(_) + | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {} diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 87e4dd822c..44d605af57 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1,4 +1,3 @@ -use std::cell::RefCell; use std::ops::{Not, Range}; use std::path::Path; use std::rc::Rc; @@ -11,13 +10,12 @@ use serde::{Deserialize, Serialize}; use crate::NewExternalAgentThread; use crate::agent_diff::AgentDiffThread; -use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::ui::NewThreadButton; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, - ResetTrialUpsell, ToggleBurnMode, ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu, + ResetTrialUpsell, ToggleBurnMode, ToggleContextPicker, ToggleNavigationMenu, + ToggleNewThreadMenu, ToggleOptionsMenu, acp::AcpThreadView, active_thread::{self, ActiveThread, ActiveThreadEvent}, agent_configuration::{AgentConfiguration, AssistantConfigurationEvent}, @@ -86,6 +84,7 @@ const AGENT_PANEL_KEY: &str = "agent_panel"; #[derive(Serialize, Deserialize)] struct SerializedAgentPanel { width: Option, + selected_agent: Option, } pub fn init(cx: &mut App) { @@ -179,6 +178,14 @@ pub fn init(cx: &mut App) { }); } }) + .register_action(|workspace, _: &ToggleNewThreadMenu, window, cx| { + if let Some(panel) = workspace.panel::(cx) { + workspace.focus_panel::(window, cx); + panel.update(cx, |panel, cx| { + panel.toggle_new_thread_menu(&ToggleNewThreadMenu, window, cx); + }); + } + }) .register_action(|workspace, _: &OpenOnboardingModal, window, cx| { AgentOnboardingModal::toggle(workspace, window, cx) }) @@ -223,6 +230,36 @@ enum WhichFontSize { None, } +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum AgentType { + #[default] + Zed, + TextThread, + Gemini, + ClaudeCode, + NativeAgent, +} + +impl AgentType { + fn label(self) -> impl Into { + match self { + Self::Zed | Self::TextThread => "Zed", + Self::NativeAgent => "Agent 2", + Self::Gemini => "Gemini", + Self::ClaudeCode => "Claude Code", + } + } + + fn icon(self) -> IconName { + match self { + Self::Zed | Self::TextThread => IconName::AiZed, + Self::NativeAgent => IconName::ZedAssistant, + Self::Gemini => IconName::AiGemini, + Self::ClaudeCode => IconName::AiClaude, + } + } +} + impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { @@ -438,8 +475,6 @@ pub struct AgentPanel { configuration_subscription: Option, local_timezone: UtcOffset, active_view: ActiveView, - acp_message_history: - Rc>>>, previous_view: Option, history_store: Entity, history: Entity, @@ -453,16 +488,21 @@ pub struct AgentPanel { zoomed: bool, pending_serialization: Option>>, onboarding: Entity, + selected_agent: AgentType, } impl AgentPanel { fn serialize(&mut self, cx: &mut Context) { let width = self.width; + let selected_agent = self.selected_agent; self.pending_serialization = Some(cx.background_spawn(async move { KEY_VALUE_STORE .write_kvp( AGENT_PANEL_KEY.into(), - serde_json::to_string(&SerializedAgentPanel { width })?, + serde_json::to_string(&SerializedAgentPanel { + width, + selected_agent: Some(selected_agent), + })?, ) .await?; anyhow::Ok(()) @@ -531,6 +571,9 @@ impl AgentPanel { if let Some(serialized_panel) = serialized_panel { panel.update(cx, |panel, cx| { panel.width = serialized_panel.width.map(|w| w.round()); + if let Some(selected_agent) = serialized_panel.selected_agent { + panel.selected_agent = selected_agent; + } cx.notify(); }); } @@ -719,7 +762,6 @@ impl AgentPanel { .unwrap(), inline_assist_context_store, previous_view: None, - acp_message_history: Default::default(), history_store: history_store.clone(), history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)), hovered_recent_history_item: None, @@ -732,6 +774,7 @@ impl AgentPanel { zoomed: false, pending_serialization: None, onboarding, + selected_agent: AgentType::default(), } } @@ -775,10 +818,10 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } - ActiveView::ExternalAgentThread { thread_view, .. } => { - thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx)); - } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => {} } } @@ -915,7 +958,7 @@ impl AgentPanel { ) { let workspace = self.workspace.clone(); let project = self.project.clone(); - let message_history = self.acp_message_history.clone(); + let fs = self.fs.clone(); const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; @@ -924,6 +967,9 @@ impl AgentPanel { agent: crate::ExternalAgent, } + let thread_store = self.thread_store.clone(); + let text_thread_store = self.context_store.clone(); + cx.spawn_in(window, async move |this, cx| { let server: Rc = match agent_choice { Some(agent) => { @@ -939,7 +985,7 @@ impl AgentPanel { }) .detach(); - agent.server() + agent.server(fs) } None => cx .background_spawn(async move { @@ -953,7 +999,7 @@ impl AgentPanel { }) .unwrap_or_default() .agent - .server(), + .server(fs), }; this.update_in(cx, |this, window, cx| { @@ -962,9 +1008,8 @@ impl AgentPanel { server, workspace.clone(), project, - message_history, - MIN_EDITOR_LINES, - Some(MAX_EDITOR_LINES), + thread_store.clone(), + text_thread_store.clone(), window, cx, ) @@ -1173,6 +1218,15 @@ impl AgentPanel { self.agent_panel_menu_handle.toggle(window, cx); } + pub fn toggle_new_thread_menu( + &mut self, + _: &ToggleNewThreadMenu, + window: &mut Window, + cx: &mut Context, + ) { + self.new_thread_menu_handle.toggle(window, cx); + } + pub fn increase_font_size( &mut self, action: &IncreaseBufferFontSize, @@ -1203,13 +1257,11 @@ impl AgentPanel { ThemeSettings::get_global(cx).agent_font_size(cx) + delta; let _ = settings .agent_font_size - .insert(theme::clamp_font_size(agent_font_size).0); + .insert(Some(theme::clamp_font_size(agent_font_size).into())); }, ); } else { - theme::adjust_agent_font_size(cx, |size| { - *size += delta; - }); + theme::adjust_agent_font_size(cx, |size| size + delta); } } WhichFontSize::BufferFont => { @@ -1512,8 +1564,6 @@ impl AgentPanel { self.active_view = new_view; } - self.acp_message_history.borrow_mut().reset_position(); - self.focus_handle(cx).focus(window); } @@ -1580,6 +1630,17 @@ impl AgentPanel { menu } + + pub fn set_selected_agent(&mut self, agent: AgentType, cx: &mut Context) { + if self.selected_agent != agent { + self.selected_agent = agent; + self.serialize(cx); + } + } + + pub fn selected_agent(&self) -> AgentType { + self.selected_agent + } } impl Focusable for AgentPanel { @@ -1810,200 +1871,24 @@ impl AgentPanel { .into_any() } - fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render_panel_options_menu( + &self, + window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement { let user_store = self.user_store.read(cx); let usage = user_store.model_request_usage(); - let account_url = zed_urls::account_url(cx); let focus_handle = self.focus_handle(cx); - let go_back_button = div().child( - IconButton::new("go-back", IconName::ArrowLeft) - .icon_size(IconSize::Small) - .on_click(cx.listener(|this, _, window, cx| { - this.go_back(&workspace::GoBack, window, cx); - })) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Go Back", - &workspace::GoBack, - &focus_handle, - window, - cx, - ) - } - }), - ); - - let recent_entries_menu = div().child( - PopoverMenu::new("agent-nav-menu") - .trigger_with_tooltip( - IconButton::new("agent-nav-menu", IconName::MenuAlt) - .icon_size(IconSize::Small) - .style(ui::ButtonStyle::Subtle), - { - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Toggle Panel Menu", - &ToggleNavigationMenu, - &focus_handle, - window, - cx, - ) - } - }, - ) - .anchor(Corner::TopLeft) - .with_handle(self.assistant_navigation_menu_handle.clone()) - .menu({ - let menu = self.assistant_navigation_menu.clone(); - move |window, cx| { - if let Some(menu) = menu.as_ref() { - menu.update(cx, |_, cx| { - cx.defer_in(window, |menu, window, cx| { - menu.rebuild(window, cx); - }); - }) - } - menu.clone() - } - }), - ); - let full_screen_label = if self.is_zoomed(window, cx) { "Disable Full Screen" } else { "Enable Full Screen" }; - let active_thread = match &self.active_view { - ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), - ActiveView::ExternalAgentThread { .. } - | ActiveView::TextThread { .. } - | ActiveView::History - | ActiveView::Configuration => None, - }; - - let new_thread_menu = PopoverMenu::new("new_thread_menu") - .trigger_with_tooltip( - IconButton::new("new_thread_menu_btn", IconName::Plus).icon_size(IconSize::Small), - Tooltip::text("New Thread…"), - ) - .anchor(Corner::TopRight) - .with_handle(self.new_thread_menu_handle.clone()) - .menu({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - let active_thread = active_thread.clone(); - Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { - menu = menu - .context(focus_handle.clone()) - .when(cx.has_flag::(), |this| { - this.header("Zed Agent") - }) - .when_some(active_thread, |this, active_thread| { - let thread = active_thread.read(cx); - - if !thread.is_empty() { - let thread_id = thread.id().clone(); - this.item( - ContextMenuEntry::new("New From Summary") - .icon(IconName::ThreadFromSummary) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - Box::new(NewThread { - from_thread_id: Some(thread_id.clone()), - }), - cx, - ); - }), - ) - } else { - this - } - }) - .item( - ContextMenuEntry::new("New Thread") - .icon(IconName::Thread) - .icon_color(Color::Muted) - .action(NewThread::default().boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Text Thread") - .icon(IconName::TextThread) - .icon_color(Color::Muted) - .action(NewTextThread.boxed_clone()) - .handler(move |window, cx| { - window.dispatch_action(NewTextThread.boxed_clone(), cx); - }), - ) - .when(cx.has_flag::(), |this| { - this.separator() - .header("External Agents") - .item( - ContextMenuEntry::new("New Gemini Thread") - .icon(IconName::AiGemini) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Gemini), - } - .boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Claude Code Thread") - .icon(IconName::AiClaude) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::ClaudeCode, - ), - } - .boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Native Agent Thread") - .icon(IconName::ZedAssistant) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::NativeAgent, - ), - } - .boxed_clone(), - cx, - ); - }), - ) - }); - menu - })) - } - }); - - let agent_panel_menu = PopoverMenu::new("agent-options-menu") + PopoverMenu::new("agent-options-menu") .trigger_with_tooltip( IconButton::new("agent-options-menu", IconName::Ellipsis) .icon_size(IconSize::Small), @@ -2086,6 +1971,137 @@ impl AgentPanel { menu })) } + }) + } + + fn render_recent_entries_menu( + &self, + icon: IconName, + cx: &mut Context, + ) -> impl IntoElement { + let focus_handle = self.focus_handle(cx); + + PopoverMenu::new("agent-nav-menu") + .trigger_with_tooltip( + IconButton::new("agent-nav-menu", icon).icon_size(IconSize::Small), + { + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Toggle Panel Menu", + &ToggleNavigationMenu, + &focus_handle, + window, + cx, + ) + } + }, + ) + .anchor(Corner::TopLeft) + .with_handle(self.assistant_navigation_menu_handle.clone()) + .menu({ + let menu = self.assistant_navigation_menu.clone(); + move |window, cx| { + if let Some(menu) = menu.as_ref() { + menu.update(cx, |_, cx| { + cx.defer_in(window, |menu, window, cx| { + menu.rebuild(window, cx); + }); + }) + } + menu.clone() + } + }) + } + + fn render_toolbar_back_button(&self, cx: &mut Context) -> impl IntoElement { + let focus_handle = self.focus_handle(cx); + + IconButton::new("go-back", IconName::ArrowLeft) + .icon_size(IconSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.go_back(&workspace::GoBack, window, cx); + })) + .tooltip({ + let focus_handle = focus_handle.clone(); + + move |window, cx| { + Tooltip::for_action_in("Go Back", &workspace::GoBack, &focus_handle, window, cx) + } + }) + } + + fn render_toolbar_old(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let focus_handle = self.focus_handle(cx); + + let active_thread = match &self.active_view { + ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => None, + }; + + let new_thread_menu = PopoverMenu::new("new_thread_menu") + .trigger_with_tooltip( + IconButton::new("new_thread_menu_btn", IconName::Plus).icon_size(IconSize::Small), + Tooltip::text("New Thread…"), + ) + .anchor(Corner::TopRight) + .with_handle(self.new_thread_menu_handle.clone()) + .menu({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + let active_thread = active_thread.clone(); + Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { + menu = menu + .context(focus_handle.clone()) + .when_some(active_thread, |this, active_thread| { + let thread = active_thread.read(cx); + + if !thread.is_empty() { + let thread_id = thread.id().clone(); + this.item( + ContextMenuEntry::new("New From Summary") + .icon(IconName::ThreadFromSummary) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + Box::new(NewThread { + from_thread_id: Some(thread_id.clone()), + }), + cx, + ); + }), + ) + } else { + this + } + }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::Thread) + .icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::TextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler(move |window, cx| { + window.dispatch_action(NewTextThread.boxed_clone(), cx); + }), + ); + menu + })) + } }); h_flex() @@ -2104,8 +2120,13 @@ impl AgentPanel { .pl_1() .gap_1() .child(match &self.active_view { - ActiveView::History | ActiveView::Configuration => go_back_button, - _ => recent_entries_menu, + ActiveView::History | ActiveView::Configuration => div() + .pl(DynamicSpacing::Base04.rems(cx)) + .child(self.render_toolbar_back_button(cx)) + .into_any_element(), + _ => self + .render_recent_entries_menu(IconName::MenuAlt, cx) + .into_any_element(), }) .child(self.render_title_view(window, cx)), ) @@ -2122,11 +2143,292 @@ impl AgentPanel { .border_l_1() .border_color(cx.theme().colors().border) .child(new_thread_menu) - .child(agent_panel_menu), + .child(self.render_panel_options_menu(window, cx)), ), ) } + fn render_toolbar_new(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let focus_handle = self.focus_handle(cx); + + let active_thread = match &self.active_view { + ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => None, + }; + + let new_thread_menu = PopoverMenu::new("new_thread_menu") + .trigger_with_tooltip( + IconButton::new("new_thread_menu_btn", IconName::Plus).icon_size(IconSize::Small), + { + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "New…", + &ToggleNewThreadMenu, + &focus_handle, + window, + cx, + ) + } + }, + ) + .anchor(Corner::TopLeft) + .with_handle(self.new_thread_menu_handle.clone()) + .menu({ + let focus_handle = focus_handle.clone(); + let workspace = self.workspace.clone(); + + move |window, cx| { + let active_thread = active_thread.clone(); + Some(ContextMenu::build(window, cx, |mut menu, _window, cx| { + menu = menu + .context(focus_handle.clone()) + .header("Zed Agent") + .when_some(active_thread, |this, active_thread| { + let thread = active_thread.read(cx); + + if !thread.is_empty() { + let thread_id = thread.id().clone(); + this.item( + ContextMenuEntry::new("New From Summary") + .icon(IconName::ThreadFromSummary) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + Box::new(NewThread { + from_thread_id: Some(thread_id.clone()), + }), + cx, + ); + }), + ) + } else { + this + } + }) + .item( + ContextMenuEntry::new("New Thread") + .icon(IconName::Thread) + .icon_color(Color::Muted) + .action(NewThread::default().boxed_clone()) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.set_selected_agent( + AgentType::Zed, + cx, + ); + }); + } + }); + } + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); + } + }), + ) + .item( + ContextMenuEntry::new("New Text Thread") + .icon(IconName::TextThread) + .icon_color(Color::Muted) + .action(NewTextThread.boxed_clone()) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.set_selected_agent( + AgentType::TextThread, + cx, + ); + }); + } + }); + } + window.dispatch_action(NewTextThread.boxed_clone(), cx); + } + }), + ) + .item( + ContextMenuEntry::new("New Native Agent Thread") + .icon(IconName::ZedAssistant) + .icon_color(Color::Muted) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.set_selected_agent( + AgentType::NativeAgent, + cx, + ); + }); + } + }); + } + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::NativeAgent), + } + .boxed_clone(), + cx, + ); + } + }), + ) + .separator() + .header("External Agents") + .item( + ContextMenuEntry::new("New Gemini Thread") + .icon(IconName::AiGemini) + .icon_color(Color::Muted) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.set_selected_agent( + AgentType::Gemini, + cx, + ); + }); + } + }); + } + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + } + .boxed_clone(), + cx, + ); + } + }), + ) + .item( + ContextMenuEntry::new("New Claude Code Thread") + .icon(IconName::AiClaude) + .icon_color(Color::Muted) + .handler({ + let workspace = workspace.clone(); + move |window, cx| { + if let Some(workspace) = workspace.upgrade() { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = + workspace.panel::(cx) + { + panel.update(cx, |panel, cx| { + panel.set_selected_agent( + AgentType::ClaudeCode, + cx, + ); + }); + } + }); + } + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::ClaudeCode), + } + .boxed_clone(), + cx, + ); + } + }), + ); + menu + })) + } + }); + + h_flex() + .id("agent-panel-toolbar") + .h(Tab::container_height(cx)) + .max_w_full() + .flex_none() + .justify_between() + .gap_2() + .bg(cx.theme().colors().tab_bar_background) + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + h_flex() + .size_full() + .gap(DynamicSpacing::Base08.rems(cx)) + .child(match &self.active_view { + ActiveView::History | ActiveView::Configuration => div() + .pl(DynamicSpacing::Base04.rems(cx)) + .child(self.render_toolbar_back_button(cx)) + .into_any_element(), + _ => h_flex() + .h_full() + .px(DynamicSpacing::Base04.rems(cx)) + .border_r_1() + .border_color(cx.theme().colors().border) + .child( + h_flex() + .px_0p5() + .gap_1p5() + .child( + Icon::new(self.selected_agent.icon()).color(Color::Muted), + ) + .child(Label::new(self.selected_agent.label())), + ) + .into_any_element(), + }) + .child(self.render_title_view(window, cx)), + ) + .child( + h_flex() + .h_full() + .gap_2() + .children(self.render_token_count(cx)) + .child( + h_flex() + .h_full() + .gap(DynamicSpacing::Base02.rems(cx)) + .pl(DynamicSpacing::Base04.rems(cx)) + .pr(DynamicSpacing::Base06.rems(cx)) + .border_l_1() + .border_color(cx.theme().colors().border) + .child(new_thread_menu) + .child(self.render_recent_entries_menu(IconName::HistoryRerun, cx)) + .child(self.render_panel_options_menu(window, cx)), + ), + ) + } + + fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + if cx.has_flag::() { + self.render_toolbar_new(window, cx).into_any_element() + } else { + self.render_toolbar_old(window, cx).into_any_element() + } + } + fn render_token_count(&self, cx: &App) -> Option { match &self.active_view { ActiveView::Thread { @@ -2575,138 +2877,6 @@ impl AgentPanel { }, )), ) - .child(self.render_empty_state_section_header("Start", None, cx)) - .child( - v_flex() - .p_1() - .gap_2() - .child( - h_flex() - .w_full() - .gap_2() - .child( - NewThreadButton::new( - "new-thread-btn", - "New Thread", - IconName::Thread, - ) - .keybinding(KeyBinding::for_action_in( - &NewThread::default(), - &self.focus_handle(cx), - window, - cx, - )) - .on_click( - |window, cx| { - window.dispatch_action( - NewThread::default().boxed_clone(), - cx, - ) - }, - ), - ) - .child( - NewThreadButton::new( - "new-text-thread-btn", - "New Text Thread", - IconName::TextThread, - ) - .keybinding(KeyBinding::for_action_in( - &NewTextThread, - &self.focus_handle(cx), - window, - cx, - )) - .on_click( - |window, cx| { - window.dispatch_action(Box::new(NewTextThread), cx) - }, - ), - ), - ) - .when(cx.has_flag::(), |this| { - this.child( - h_flex() - .w_full() - .gap_2() - .child( - NewThreadButton::new( - "new-gemini-thread-btn", - "New Gemini Thread", - IconName::AiGemini, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::Gemini, - ), - }), - cx, - ) - }, - ), - ) - .child( - NewThreadButton::new( - "new-claude-thread-btn", - "New Claude Code Thread", - IconName::AiClaude, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::ClaudeCode, - ), - }), - cx, - ) - }, - ), - ) - .child( - NewThreadButton::new( - "new-native-agent-thread-btn", - "New Native Agent Thread", - IconName::ZedAssistant, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::NativeAgent, - ), - }), - cx, - ) - }, - ), - ), - ) - }), - ) .when_some(configuration_error.as_ref(), |this, err| { this.child(self.render_configuration_error(err, &focus_handle, window, cx)) }) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index fceb8f4c45..f25b576886 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -5,7 +5,6 @@ mod agent_diff; mod agent_model_selector; mod agent_panel; mod buffer_codegen; -mod burn_mode_tooltip; mod context_picker; mod context_server_configuration; mod context_strip; @@ -64,6 +63,8 @@ actions!( NewTextThread, /// Toggles the context picker interface for adding files, symbols, or other context. ToggleContextPicker, + /// Toggles the menu to create new agent threads. + ToggleNewThreadMenu, /// Toggles the navigation menu for switching between threads and views. ToggleNavigationMenu, /// Toggles the options menu for agent settings and preferences. @@ -155,11 +156,11 @@ enum ExternalAgent { } impl ExternalAgent { - pub fn server(&self) -> Rc { + pub fn server(&self, fs: Arc) -> Rc { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)), } } } @@ -241,7 +242,6 @@ pub fn init( client.telemetry().clone(), cx, ); - indexed_docs::init(cx); cx.observe_new(move |workspace, window, cx| { ConfigureContextServerModal::register(workspace, language_registry.clone(), window, cx) }) @@ -408,12 +408,6 @@ fn update_slash_commands_from_settings(cx: &mut App) { let slash_command_registry = SlashCommandRegistry::global(cx); let settings = SlashCommandSettings::get_global(cx); - if settings.docs.enabled { - slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true); - } else { - slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand); - } - if settings.cargo_workspace.enabled { slash_command_registry .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true); diff --git a/crates/agent_ui/src/burn_mode_tooltip.rs b/crates/agent_ui/src/burn_mode_tooltip.rs deleted file mode 100644 index 6354c07760..0000000000 --- a/crates/agent_ui/src/burn_mode_tooltip.rs +++ /dev/null @@ -1,61 +0,0 @@ -use gpui::{Context, FontWeight, IntoElement, Render, Window}; -use ui::{prelude::*, tooltip_container}; - -pub struct BurnModeTooltip { - selected: bool, -} - -impl BurnModeTooltip { - pub fn new() -> Self { - Self { selected: false } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } -} - -impl Render for BurnModeTooltip { - fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let (icon, color) = if self.selected { - (IconName::ZedBurnModeOn, Color::Error) - } else { - (IconName::ZedBurnMode, Color::Default) - }; - - let turned_on = h_flex() - .h_4() - .px_1() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().text_accent.opacity(0.1)) - .rounded_sm() - .child( - Label::new("ON") - .size(LabelSize::XSmall) - .weight(FontWeight::SEMIBOLD) - .color(Color::Accent), - ); - - let title = h_flex() - .gap_1p5() - .child(Icon::new(icon).size(IconSize::Small).color(color)) - .child(Label::new("Burn Mode")) - .when(self.selected, |title| title.child(turned_on)); - - tooltip_container(window, cx, |this, _, _| { - this - .child(title) - .child( - div() - .max_w_64() - .child( - Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.") - .size(LabelSize::Small) - .color(Color::Muted) - ) - ) - }) - } -} diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 58f11313e6..131023d249 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -1,18 +1,19 @@ mod completion_provider; -mod fetch_context_picker; +pub(crate) mod fetch_context_picker; pub(crate) mod file_context_picker; -mod rules_context_picker; -mod symbol_context_picker; -mod thread_context_picker; +pub(crate) mod rules_context_picker; +pub(crate) mod symbol_context_picker; +pub(crate) mod thread_context_picker; use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; use anyhow::{Result, anyhow}; +use collections::HashSet; pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId}; -use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; +use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset}; use fetch_context_picker::FetchContextPicker; use file_context_picker::FileContextPicker; use file_context_picker::render_file_context_entry; @@ -45,7 +46,7 @@ use agent::{ }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ContextPickerEntry { +pub(crate) enum ContextPickerEntry { Mode(ContextPickerMode), Action(ContextPickerAction), } @@ -74,7 +75,7 @@ impl ContextPickerEntry { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ContextPickerMode { +pub(crate) enum ContextPickerMode { File, Symbol, Fetch, @@ -83,7 +84,7 @@ enum ContextPickerMode { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ContextPickerAction { +pub(crate) enum ContextPickerAction { AddSelections, } @@ -227,7 +228,7 @@ impl ContextPicker { } fn build_menu(&mut self, window: &mut Window, cx: &mut Context) -> Entity { - let context_picker = cx.entity().clone(); + let context_picker = cx.entity(); let menu = ContextMenu::build(window, cx, move |menu, _window, cx| { let recent = self.recent_entries(cx); @@ -531,7 +532,7 @@ impl ContextPicker { return vec![]; }; - recent_context_picker_entries( + recent_context_picker_entries_with_store( context_store, self.thread_store.clone(), self.text_thread_store.clone(), @@ -585,7 +586,8 @@ impl Render for ContextPicker { }) } } -enum RecentEntry { + +pub(crate) enum RecentEntry { File { project_path: ProjectPath, path_prefix: Arc, @@ -593,7 +595,7 @@ enum RecentEntry { Thread(ThreadContextEntry), } -fn available_context_picker_entries( +pub(crate) fn available_context_picker_entries( prompt_store: &Option>, thread_store: &Option>, workspace: &Entity, @@ -630,24 +632,56 @@ fn available_context_picker_entries( entries } -fn recent_context_picker_entries( +fn recent_context_picker_entries_with_store( context_store: Entity, thread_store: Option>, text_thread_store: Option>, workspace: Entity, exclude_path: Option, cx: &App, +) -> Vec { + let project = workspace.read(cx).project(); + + let mut exclude_paths = context_store.read(cx).file_paths(cx); + exclude_paths.extend(exclude_path); + + let exclude_paths = exclude_paths + .into_iter() + .filter_map(|project_path| project.read(cx).absolute_path(&project_path, cx)) + .collect(); + + let exclude_threads = context_store.read(cx).thread_ids(); + + recent_context_picker_entries( + thread_store, + text_thread_store, + workspace, + &exclude_paths, + exclude_threads, + cx, + ) +} + +pub(crate) fn recent_context_picker_entries( + thread_store: Option>, + text_thread_store: Option>, + workspace: Entity, + exclude_paths: &HashSet, + exclude_threads: &HashSet, + cx: &App, ) -> Vec { let mut recent = Vec::with_capacity(6); - let mut current_files = context_store.read(cx).file_paths(cx); - current_files.extend(exclude_path); let workspace = workspace.read(cx); let project = workspace.project().read(cx); recent.extend( workspace .recent_navigation_history_iter(cx) - .filter(|(path, _)| !current_files.contains(path)) + .filter(|(_, abs_path)| { + abs_path + .as_ref() + .map_or(true, |path| !exclude_paths.contains(path.as_path())) + }) .take(4) .filter_map(|(project_path, _)| { project @@ -659,8 +693,6 @@ fn recent_context_picker_entries( }), ); - let current_threads = context_store.read(cx).thread_ids(); - let active_thread_id = workspace .panel::(cx) .and_then(|panel| Some(panel.read(cx).active_thread(cx)?.read(cx).id())); @@ -672,7 +704,7 @@ fn recent_context_picker_entries( let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx) .filter(|(_, thread)| match thread { ThreadContextEntry::Thread { id, .. } => { - Some(id) != active_thread_id && !current_threads.contains(id) + Some(id) != active_thread_id && !exclude_threads.contains(id) } ThreadContextEntry::Context { .. } => true, }) @@ -710,7 +742,7 @@ fn add_selections_as_context( }) } -fn selection_ranges( +pub(crate) fn selection_ranges( workspace: &Entity, cx: &mut App, ) -> Vec<(Entity, Range)> { @@ -805,42 +837,9 @@ fn render_fold_icon_button( ) -> Arc, &mut App) -> AnyElement> { Arc::new({ move |fold_id, fold_range, cx| { - let is_in_text_selection = editor.upgrade().is_some_and(|editor| { - editor.update(cx, |editor, cx| { - let snapshot = editor - .buffer() - .update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx)); - - let is_in_pending_selection = || { - editor - .selections - .pending - .as_ref() - .is_some_and(|pending_selection| { - pending_selection - .selection - .range() - .includes(&fold_range, &snapshot) - }) - }; - - let mut is_in_complete_selection = || { - editor - .selections - .disjoint_in_range::(fold_range.clone(), cx) - .into_iter() - .any(|selection| { - // This is needed to cover a corner case, if we just check for an existing - // selection in the fold range, having a cursor at the start of the fold - // marks it as selected. Non-empty selections don't cause this. - let length = selection.end - selection.start; - length > 0 - }) - }; - - is_in_pending_selection() || is_in_complete_selection() - }) - }); + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); ButtonLike::new(fold_id) .style(ButtonStyle::Filled) diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index 8123b3437d..962c0df03d 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -35,7 +35,7 @@ use super::symbol_context_picker::search_symbols; use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; use super::{ ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry, - available_context_picker_entries, recent_context_picker_entries, selection_ranges, + available_context_picker_entries, recent_context_picker_entries_with_store, selection_ranges, }; use crate::message_editor::ContextCreasesAddon; @@ -787,7 +787,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { .and_then(|b| b.read(cx).file()) .map(|file| ProjectPath::from_file(file.as_ref(), cx)); - let recent_entries = recent_context_picker_entries( + let recent_entries = recent_context_picker_entries_with_store( context_store.clone(), thread_store.clone(), text_thread_store.clone(), diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 4a4a747899..bbd3595805 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -72,7 +72,7 @@ pub fn init( let Some(window) = window else { return; }; - let workspace = cx.entity().clone(); + let workspace = cx.entity(); InlineAssistant::update_global(cx, |inline_assistant, cx| { inline_assistant.register_workspace(&workspace, window, cx) }); diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7121624c87..bb8514a224 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,5 +1,6 @@ use std::{cmp::Reverse, sync::Arc}; +use cloud_llm_client::Plan; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; @@ -10,7 +11,6 @@ use language_model::{ }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; -use proto::Plan; use ui::{ListItem, ListItemSpacing, prelude::*}; const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; @@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { ) -> Option { use feature_flags::FeatureFlagAppExt; - let plan = proto::Plan::ZedPro; + let plan = Plan::ZedPro; Some( h_flex() @@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { window .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx) }), - Plan::Free | Plan::ZedProTrial => Button::new( + Plan::ZedFree | Plan::ZedProTrial => Button::new( "try-pro", if plan == Plan::ZedProTrial { "Upgrade to Pro" diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 4b6d51c4c1..127e9256be 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread; use crate::agent_model_selector::AgentModelSelector; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use crate::ui::{ - MaxModeTooltip, + BurnModeTooltip, preview::{AgentPreview, UsageCallout}, }; use agent::history_store::HistoryStore; @@ -14,7 +14,7 @@ use agent::{ context::{AgentContextKey, ContextLoadResult, load_context}, context_store::ContextStoreEvent, }; -use agent_settings::{AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; use cloud_llm_client::CompletionIntent; @@ -55,7 +55,7 @@ use zed_actions::agent::ToggleModelSelector; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; -use crate::profile_selector::ProfileSelector; +use crate::profile_selector::{ProfileProvider, ProfileSelector}; use crate::{ ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll, ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode, @@ -152,6 +152,24 @@ pub(crate) fn create_editor( editor } +impl ProfileProvider for Entity { + fn profiles_supported(&self, cx: &App) -> bool { + self.read(cx) + .configured_model() + .map_or(false, |model| model.model.supports_tools()) + } + + fn profile_id(&self, cx: &App) -> AgentProfileId { + self.read(cx).profile().id().clone() + } + + fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) { + self.update(cx, |this, cx| { + this.set_profile(profile_id, cx); + }); + } +} + impl MessageEditor { pub fn new( fs: Arc, @@ -221,8 +239,9 @@ impl MessageEditor { ) }); - let profile_selector = - cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); + let profile_selector = cx.new(|cx| { + ProfileSelector::new(fs, Arc::new(thread.clone()), editor.focus_handle(cx), cx) + }); Self { editor: editor.clone(), @@ -605,7 +624,7 @@ impl MessageEditor { this.toggle_burn_mode(&ToggleBurnMode, window, cx); })) .tooltip(move |_window, cx| { - cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled)) + cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled)) .into() }) .into_any_element(), diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index ddcb44d46b..ce25f531e2 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -1,12 +1,8 @@ use crate::{ManageProfiles, ToggleProfileSelector}; -use agent::{ - Thread, - agent_profile::{AgentProfile, AvailableProfiles}, -}; +use agent::agent_profile::{AgentProfile, AvailableProfiles}; use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; use fs::Fs; -use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*}; -use language_model::LanguageModelRegistry; +use gpui::{Action, Entity, FocusHandle, Subscription, prelude::*}; use settings::{Settings as _, SettingsStore, update_settings_file}; use std::sync::Arc; use ui::{ @@ -14,10 +10,22 @@ use ui::{ prelude::*, }; +/// Trait for types that can provide and manage agent profiles +pub trait ProfileProvider { + /// Get the current profile ID + fn profile_id(&self, cx: &App) -> AgentProfileId; + + /// Set the profile ID + fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App); + + /// Check if profiles are supported in the current context (e.g. if the model that is selected has tool support) + fn profiles_supported(&self, cx: &App) -> bool; +} + pub struct ProfileSelector { profiles: AvailableProfiles, fs: Arc, - thread: Entity, + provider: Arc, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, _subscriptions: Vec, @@ -26,7 +34,7 @@ pub struct ProfileSelector { impl ProfileSelector { pub fn new( fs: Arc, - thread: Entity, + provider: Arc, focus_handle: FocusHandle, cx: &mut Context, ) -> Self { @@ -37,7 +45,7 @@ impl ProfileSelector { Self { profiles: AgentProfile::available_profiles(cx), fs, - thread, + provider, menu_handle: PopoverMenuHandle::default(), focus_handle, _subscriptions: vec![settings_subscription], @@ -113,10 +121,10 @@ impl ProfileSelector { builtin_profiles::MINIMAL => Some("Chat about anything with no tools."), _ => None, }; - let thread_profile_id = self.thread.read(cx).profile().id(); + let thread_profile_id = self.provider.profile_id(cx); let entry = ContextMenuEntry::new(profile_name.clone()) - .toggleable(IconPosition::End, &profile_id == thread_profile_id); + .toggleable(IconPosition::End, profile_id == thread_profile_id); let entry = if let Some(doc_text) = documentation { entry.documentation_aside(documentation_side(settings.dock), move |_| { @@ -128,7 +136,7 @@ impl ProfileSelector { entry.handler({ let fs = self.fs.clone(); - let thread = self.thread.clone(); + let provider = self.provider.clone(); let profile_id = profile_id.clone(); move |_window, cx| { update_settings_file::(fs.clone(), cx, { @@ -138,9 +146,7 @@ impl ProfileSelector { } }); - thread.update(cx, |this, cx| { - this.set_profile(profile_id.clone(), cx); - }); + provider.set_profile(profile_id.clone(), cx); } }) } @@ -149,23 +155,15 @@ impl ProfileSelector { impl Render for ProfileSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = self.thread.read(cx).profile().id(); - let profile = settings.profiles.get(profile_id); + let profile_id = self.provider.profile_id(cx); + let profile = settings.profiles.get(&profile_id); let selected_profile = profile .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let configured_model = self.thread.read(cx).configured_model().or_else(|| { - let model_registry = LanguageModelRegistry::read_global(cx); - model_registry.default_model() - }); - let Some(configured_model) = configured_model else { - return Empty.into_any_element(); - }; - - if configured_model.model.supports_tools() { - let this = cx.entity().clone(); + if self.provider.profiles_supported(cx) { + let this = cx.entity(); let focus_handle = self.focus_handle.clone(); let trigger_button = Button::new("profile-selector-model", selected_profile) .label_size(LabelSize::Small) diff --git a/crates/agent_ui/src/slash_command_settings.rs b/crates/agent_ui/src/slash_command_settings.rs index f254d00ec6..73e5622aa9 100644 --- a/crates/agent_ui/src/slash_command_settings.rs +++ b/crates/agent_ui/src/slash_command_settings.rs @@ -7,22 +7,11 @@ use settings::{Settings, SettingsSources}; /// Settings for slash commands. #[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] pub struct SlashCommandSettings { - /// Settings for the `/docs` slash command. - #[serde(default)] - pub docs: DocsCommandSettings, /// Settings for the `/cargo-workspace` slash command. #[serde(default)] pub cargo_workspace: CargoWorkspaceCommandSettings, } -/// Settings for the `/docs` slash command. -#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] -pub struct DocsCommandSettings { - /// Whether `/docs` is enabled. - #[serde(default)] - pub enabled: bool, -} - /// Settings for the `/cargo-workspace` slash command. #[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] pub struct CargoWorkspaceCommandSettings { diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 49a37002f7..8c1e163eca 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1,14 +1,11 @@ use crate::{ - burn_mode_tooltip::BurnModeTooltip, language_model_selector::{LanguageModelSelector, language_model_selector}, + ui::BurnModeTooltip, }; use agent_settings::{AgentSettings, CompletionMode}; use anyhow::Result; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection, SlashCommandWorkingSet}; -use assistant_slash_commands::{ - DefaultSlashCommand, DocsSlashCommand, DocsSlashCommandArgs, FileSlashCommand, - selections_creases, -}; +use assistant_slash_commands::{DefaultSlashCommand, FileSlashCommand, selections_creases}; use client::{proto, zed_urls}; use collections::{BTreeSet, HashMap, HashSet, hash_map}; use editor::{ @@ -30,7 +27,6 @@ use gpui::{ StatefulInteractiveElement, Styled, Subscription, Task, Transformation, WeakEntity, actions, div, img, percentage, point, prelude::*, pulsating_between, size, }; -use indexed_docs::IndexedDocsStore; use language::{ BufferSnapshot, LspAdapterDelegate, ToOffset, language_settings::{SoftWrap, all_language_settings}, @@ -77,7 +73,7 @@ use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker} use assistant_context::{ AssistantContext, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, MessageStatus, - ParsedSlashCommand, PendingSlashCommandStatus, ThoughtProcessOutputSection, + PendingSlashCommandStatus, ThoughtProcessOutputSection, }; actions!( @@ -701,19 +697,7 @@ impl TextThreadEditor { } }; let render_trailer = { - let command = command.clone(); - move |row, _unfold, _window: &mut Window, cx: &mut App| { - // TODO: In the future we should investigate how we can expose - // this as a hook on the `SlashCommand` trait so that we don't - // need to special-case it here. - if command.name == DocsSlashCommand::NAME { - return render_docs_slash_command_trailer( - row, - command.clone(), - cx, - ); - } - + move |_row, _unfold, _window: &mut Window, _cx: &mut App| { Empty.into_any() } }; @@ -2398,70 +2382,6 @@ fn render_pending_slash_command_gutter_decoration( icon.into_any_element() } -fn render_docs_slash_command_trailer( - row: MultiBufferRow, - command: ParsedSlashCommand, - cx: &mut App, -) -> AnyElement { - if command.arguments.is_empty() { - return Empty.into_any(); - } - let args = DocsSlashCommandArgs::parse(&command.arguments); - - let Some(store) = args - .provider() - .and_then(|provider| IndexedDocsStore::try_global(provider, cx).ok()) - else { - return Empty.into_any(); - }; - - let Some(package) = args.package() else { - return Empty.into_any(); - }; - - let mut children = Vec::new(); - - if store.is_indexing(&package) { - children.push( - div() - .id(("crates-being-indexed", row.0)) - .child(Icon::new(IconName::ArrowCircle).with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(4)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - )) - .tooltip({ - let package = package.clone(); - Tooltip::text(format!("Indexing {package}…")) - }) - .into_any_element(), - ); - } - - if let Some(latest_error) = store.latest_error_for_package(&package) { - children.push( - div() - .id(("latest-error", row.0)) - .child( - Icon::new(IconName::Warning) - .size(IconSize::Small) - .color(Color::Warning), - ) - .tooltip(Tooltip::text(format!("Failed to index: {latest_error}"))) - .into_any_element(), - ) - } - - let is_indexing = store.is_indexing(&package); - let latest_error = store.latest_error_for_package(&package); - - if !is_indexing && latest_error.is_none() { - return Empty.into_any(); - } - - h_flex().gap_2().children(children).into_any_element() -} - #[derive(Debug, Clone, Serialize, Deserialize)] struct CopyMetadata { creases: Vec, diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index b477a8c385..beeaf0c43b 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -2,7 +2,7 @@ mod agent_notification; mod burn_mode_tooltip; mod context_pill; mod end_trial_upsell; -mod new_thread_button; +// mod new_thread_button; mod onboarding_modal; pub mod preview; @@ -10,5 +10,5 @@ pub use agent_notification::*; pub use burn_mode_tooltip::*; pub use context_pill::*; pub use end_trial_upsell::*; -pub use new_thread_button::*; +// pub use new_thread_button::*; pub use onboarding_modal::*; diff --git a/crates/agent_ui/src/ui/burn_mode_tooltip.rs b/crates/agent_ui/src/ui/burn_mode_tooltip.rs index 97f7853a61..72faaa614d 100644 --- a/crates/agent_ui/src/ui/burn_mode_tooltip.rs +++ b/crates/agent_ui/src/ui/burn_mode_tooltip.rs @@ -2,11 +2,11 @@ use crate::ToggleBurnMode; use gpui::{Context, FontWeight, IntoElement, Render, Window}; use ui::{KeyBinding, prelude::*, tooltip_container}; -pub struct MaxModeTooltip { +pub struct BurnModeTooltip { selected: bool, } -impl MaxModeTooltip { +impl BurnModeTooltip { pub fn new() -> Self { Self { selected: false } } @@ -17,7 +17,7 @@ impl MaxModeTooltip { } } -impl Render for MaxModeTooltip { +impl Render for BurnModeTooltip { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let (icon, color) = if self.selected { (IconName::ZedBurnModeOn, Color::Error) diff --git a/crates/agent_ui/src/ui/new_thread_button.rs b/crates/agent_ui/src/ui/new_thread_button.rs index 7764144150..347d6adcaf 100644 --- a/crates/agent_ui/src/ui/new_thread_button.rs +++ b/crates/agent_ui/src/ui/new_thread_button.rs @@ -11,7 +11,7 @@ pub struct NewThreadButton { } impl NewThreadButton { - pub fn new(id: impl Into, label: impl Into, icon: IconName) -> Self { + fn new(id: impl Into, label: impl Into, icon: IconName) -> Self { Self { id: id.into(), label: label.into(), @@ -21,12 +21,12 @@ impl NewThreadButton { } } - pub fn keybinding(mut self, keybinding: Option) -> Self { + fn keybinding(mut self, keybinding: Option) -> Self { self.keybinding = keybinding; self } - pub fn on_click(mut self, handler: F) -> Self + fn on_click(mut self, handler: F) -> Self where F: Fn(&mut Window, &mut App) + 'static, { diff --git a/crates/assets/src/assets.rs b/crates/assets/src/assets.rs index fad0c58b73..5c7e671159 100644 --- a/crates/assets/src/assets.rs +++ b/crates/assets/src/assets.rs @@ -58,9 +58,7 @@ impl Assets { pub fn load_test_fonts(&self, cx: &App) { cx.text_system() .add_fonts(vec![ - self.load("fonts/plex-mono/ZedPlexMono-Regular.ttf") - .unwrap() - .unwrap(), + self.load("fonts/lilex/Lilex-Regular.ttf").unwrap().unwrap(), ]) .unwrap() } diff --git a/crates/assistant_context/Cargo.toml b/crates/assistant_context/Cargo.toml index 8f5ff98790..45c0072418 100644 --- a/crates/assistant_context/Cargo.toml +++ b/crates/assistant_context/Cargo.toml @@ -11,6 +11,9 @@ workspace = true [lib] path = "src/assistant_context.rs" +[features] +test-support = [] + [dependencies] agent_settings.workspace = true anyhow.workspace = true diff --git a/crates/assistant_context/src/context_store.rs b/crates/assistant_context/src/context_store.rs index 3090a7b234..622d8867a7 100644 --- a/crates/assistant_context/src/context_store.rs +++ b/crates/assistant_context/src/context_store.rs @@ -138,6 +138,27 @@ impl ContextStore { }) } + #[cfg(any(test, feature = "test-support"))] + pub fn fake(project: Entity, cx: &mut Context) -> Self { + Self { + contexts: Default::default(), + contexts_metadata: Default::default(), + context_server_slash_command_ids: Default::default(), + host_contexts: Default::default(), + fs: project.read(cx).fs().clone(), + languages: project.read(cx).languages().clone(), + slash_commands: Arc::default(), + telemetry: project.read(cx).client().telemetry().clone(), + _watch_updates: Task::ready(None), + client: project.read(cx).client(), + project, + project_is_shared: false, + client_subscription: None, + _project_subscriptions: Default::default(), + prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()), + } + } + async fn handle_advertise_contexts( this: Entity, envelope: TypedEnvelope, diff --git a/crates/assistant_slash_commands/Cargo.toml b/crates/assistant_slash_commands/Cargo.toml index f703a753f5..c054c3ced8 100644 --- a/crates/assistant_slash_commands/Cargo.toml +++ b/crates/assistant_slash_commands/Cargo.toml @@ -27,7 +27,6 @@ globset.workspace = true gpui.workspace = true html_to_markdown.workspace = true http_client.workspace = true -indexed_docs.workspace = true language.workspace = true project.workspace = true prompt_store.workspace = true diff --git a/crates/assistant_slash_commands/src/assistant_slash_commands.rs b/crates/assistant_slash_commands/src/assistant_slash_commands.rs index fa5dd8b683..fb00a91219 100644 --- a/crates/assistant_slash_commands/src/assistant_slash_commands.rs +++ b/crates/assistant_slash_commands/src/assistant_slash_commands.rs @@ -3,7 +3,6 @@ mod context_server_command; mod default_command; mod delta_command; mod diagnostics_command; -mod docs_command; mod fetch_command; mod file_command; mod now_command; @@ -18,7 +17,6 @@ pub use crate::context_server_command::*; pub use crate::default_command::*; pub use crate::delta_command::*; pub use crate::diagnostics_command::*; -pub use crate::docs_command::*; pub use crate::fetch_command::*; pub use crate::file_command::*; pub use crate::now_command::*; diff --git a/crates/assistant_slash_commands/src/docs_command.rs b/crates/assistant_slash_commands/src/docs_command.rs deleted file mode 100644 index bd87c72849..0000000000 --- a/crates/assistant_slash_commands/src/docs_command.rs +++ /dev/null @@ -1,543 +0,0 @@ -use std::path::Path; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::time::Duration; - -use anyhow::{Context as _, Result, anyhow, bail}; -use assistant_slash_command::{ - ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, - SlashCommandResult, -}; -use gpui::{App, BackgroundExecutor, Entity, Task, WeakEntity}; -use indexed_docs::{ - DocsDotRsProvider, IndexedDocsRegistry, IndexedDocsStore, LocalRustdocProvider, PackageName, - ProviderId, -}; -use language::{BufferSnapshot, LspAdapterDelegate}; -use project::{Project, ProjectPath}; -use ui::prelude::*; -use util::{ResultExt, maybe}; -use workspace::Workspace; - -pub struct DocsSlashCommand; - -impl DocsSlashCommand { - pub const NAME: &'static str = "docs"; - - fn path_to_cargo_toml(project: Entity, cx: &mut App) -> Option> { - let worktree = project.read(cx).worktrees(cx).next()?; - let worktree = worktree.read(cx); - let entry = worktree.entry_for_path("Cargo.toml")?; - let path = ProjectPath { - worktree_id: worktree.id(), - path: entry.path.clone(), - }; - Some(Arc::from( - project.read(cx).absolute_path(&path, cx)?.as_path(), - )) - } - - /// Ensures that the indexed doc providers for Rust are registered. - /// - /// Ideally we would do this sooner, but we need to wait until we're able to - /// access the workspace so we can read the project. - fn ensure_rust_doc_providers_are_registered( - &self, - workspace: Option>, - cx: &mut App, - ) { - let indexed_docs_registry = IndexedDocsRegistry::global(cx); - if indexed_docs_registry - .get_provider_store(LocalRustdocProvider::id()) - .is_none() - { - let index_provider_deps = maybe!({ - let workspace = workspace - .as_ref() - .context("no workspace")? - .upgrade() - .context("workspace dropped")?; - let project = workspace.read(cx).project().clone(); - let fs = project.read(cx).fs().clone(); - let cargo_workspace_root = Self::path_to_cargo_toml(project, cx) - .and_then(|path| path.parent().map(|path| path.to_path_buf())) - .context("no Cargo workspace root found")?; - - anyhow::Ok((fs, cargo_workspace_root)) - }); - - if let Some((fs, cargo_workspace_root)) = index_provider_deps.log_err() { - indexed_docs_registry.register_provider(Box::new(LocalRustdocProvider::new( - fs, - cargo_workspace_root, - ))); - } - } - - if indexed_docs_registry - .get_provider_store(DocsDotRsProvider::id()) - .is_none() - { - let http_client = maybe!({ - let workspace = workspace - .as_ref() - .context("no workspace")? - .upgrade() - .context("workspace was dropped")?; - let project = workspace.read(cx).project().clone(); - anyhow::Ok(project.read(cx).client().http_client()) - }); - - if let Some(http_client) = http_client.log_err() { - indexed_docs_registry - .register_provider(Box::new(DocsDotRsProvider::new(http_client))); - } - } - } - - /// Runs just-in-time indexing for a given package, in case the slash command - /// is run without any entries existing in the index. - fn run_just_in_time_indexing( - store: Arc, - key: String, - package: PackageName, - executor: BackgroundExecutor, - ) -> Task<()> { - executor.clone().spawn(async move { - let (prefix, needs_full_index) = if let Some((prefix, _)) = key.split_once('*') { - // If we have a wildcard in the search, we want to wait until - // we've completely finished indexing so we get a full set of - // results for the wildcard. - (prefix.to_string(), true) - } else { - (key, false) - }; - - // If we already have some entries, we assume that we've indexed the package before - // and don't need to do it again. - let has_any_entries = store - .any_with_prefix(prefix.clone()) - .await - .unwrap_or_default(); - if has_any_entries { - return (); - }; - - let index_task = store.clone().index(package.clone()); - - if needs_full_index { - _ = index_task.await; - } else { - loop { - executor.timer(Duration::from_millis(200)).await; - - if store - .any_with_prefix(prefix.clone()) - .await - .unwrap_or_default() - || !store.is_indexing(&package) - { - break; - } - } - } - }) - } -} - -impl SlashCommand for DocsSlashCommand { - fn name(&self) -> String { - Self::NAME.into() - } - - fn description(&self) -> String { - "insert docs".into() - } - - fn menu_text(&self) -> String { - "Insert Documentation".into() - } - - fn requires_argument(&self) -> bool { - true - } - - fn complete_argument( - self: Arc, - arguments: &[String], - _cancel: Arc, - workspace: Option>, - _: &mut Window, - cx: &mut App, - ) -> Task>> { - self.ensure_rust_doc_providers_are_registered(workspace, cx); - - let indexed_docs_registry = IndexedDocsRegistry::global(cx); - let args = DocsSlashCommandArgs::parse(arguments); - let store = args - .provider() - .context("no docs provider specified") - .and_then(|provider| IndexedDocsStore::try_global(provider, cx)); - cx.background_spawn(async move { - fn build_completions(items: Vec) -> Vec { - items - .into_iter() - .map(|item| ArgumentCompletion { - label: item.clone().into(), - new_text: item.to_string(), - after_completion: assistant_slash_command::AfterCompletion::Run, - replace_previous_arguments: false, - }) - .collect() - } - - match args { - DocsSlashCommandArgs::NoProvider => { - let providers = indexed_docs_registry.list_providers(); - if providers.is_empty() { - return Ok(vec![ArgumentCompletion { - label: "No available docs providers.".into(), - new_text: String::new(), - after_completion: false.into(), - replace_previous_arguments: false, - }]); - } - - Ok(providers - .into_iter() - .map(|provider| ArgumentCompletion { - label: provider.to_string().into(), - new_text: provider.to_string(), - after_completion: false.into(), - replace_previous_arguments: false, - }) - .collect()) - } - DocsSlashCommandArgs::SearchPackageDocs { - provider, - package, - index, - } => { - let store = store?; - - if index { - // We don't need to hold onto this task, as the `IndexedDocsStore` will hold it - // until it completes. - drop(store.clone().index(package.as_str().into())); - } - - let suggested_packages = store.clone().suggest_packages().await?; - let search_results = store.search(package).await; - - let mut items = build_completions(search_results); - let workspace_crate_completions = suggested_packages - .into_iter() - .filter(|package_name| { - !items - .iter() - .any(|item| item.label.text() == package_name.as_ref()) - }) - .map(|package_name| ArgumentCompletion { - label: format!("{package_name} (unindexed)").into(), - new_text: format!("{package_name}"), - after_completion: true.into(), - replace_previous_arguments: false, - }) - .collect::>(); - items.extend(workspace_crate_completions); - - if items.is_empty() { - return Ok(vec![ArgumentCompletion { - label: format!( - "Enter a {package_term} name.", - package_term = package_term(&provider) - ) - .into(), - new_text: provider.to_string(), - after_completion: false.into(), - replace_previous_arguments: false, - }]); - } - - Ok(items) - } - DocsSlashCommandArgs::SearchItemDocs { item_path, .. } => { - let store = store?; - let items = store.search(item_path).await; - Ok(build_completions(items)) - } - } - }) - } - - fn run( - self: Arc, - arguments: &[String], - _context_slash_command_output_sections: &[SlashCommandOutputSection], - _context_buffer: BufferSnapshot, - _workspace: WeakEntity, - _delegate: Option>, - _: &mut Window, - cx: &mut App, - ) -> Task { - if arguments.is_empty() { - return Task::ready(Err(anyhow!("missing an argument"))); - }; - - let args = DocsSlashCommandArgs::parse(arguments); - let executor = cx.background_executor().clone(); - let task = cx.background_spawn({ - let store = args - .provider() - .context("no docs provider specified") - .and_then(|provider| IndexedDocsStore::try_global(provider, cx)); - async move { - let (provider, key) = match args.clone() { - DocsSlashCommandArgs::NoProvider => bail!("no docs provider specified"), - DocsSlashCommandArgs::SearchPackageDocs { - provider, package, .. - } => (provider, package), - DocsSlashCommandArgs::SearchItemDocs { - provider, - item_path, - .. - } => (provider, item_path), - }; - - if key.trim().is_empty() { - bail!( - "no {package_term} name provided", - package_term = package_term(&provider) - ); - } - - let store = store?; - - if let Some(package) = args.package() { - Self::run_just_in_time_indexing(store.clone(), key.clone(), package, executor) - .await; - } - - let (text, ranges) = if let Some((prefix, _)) = key.split_once('*') { - let docs = store.load_many_by_prefix(prefix.to_string()).await?; - - let mut text = String::new(); - let mut ranges = Vec::new(); - - for (key, docs) in docs { - let prev_len = text.len(); - - text.push_str(&docs.0); - text.push_str("\n"); - ranges.push((key, prev_len..text.len())); - text.push_str("\n"); - } - - (text, ranges) - } else { - let item_docs = store.load(key.clone()).await?; - let text = item_docs.to_string(); - let range = 0..text.len(); - - (text, vec![(key, range)]) - }; - - anyhow::Ok((provider, text, ranges)) - } - }); - - cx.foreground_executor().spawn(async move { - let (provider, text, ranges) = task.await?; - Ok(SlashCommandOutput { - text, - sections: ranges - .into_iter() - .map(|(key, range)| SlashCommandOutputSection { - range, - icon: IconName::FileDoc, - label: format!("docs ({provider}): {key}",).into(), - metadata: None, - }) - .collect(), - run_commands_in_text: false, - } - .to_event_stream()) - }) - } -} - -fn is_item_path_delimiter(char: char) -> bool { - !char.is_alphanumeric() && char != '-' && char != '_' -} - -#[derive(Debug, PartialEq, Clone)] -pub enum DocsSlashCommandArgs { - NoProvider, - SearchPackageDocs { - provider: ProviderId, - package: String, - index: bool, - }, - SearchItemDocs { - provider: ProviderId, - package: String, - item_path: String, - }, -} - -impl DocsSlashCommandArgs { - pub fn parse(arguments: &[String]) -> Self { - let Some(provider) = arguments - .get(0) - .cloned() - .filter(|arg| !arg.trim().is_empty()) - else { - return Self::NoProvider; - }; - let provider = ProviderId(provider.into()); - let Some(argument) = arguments.get(1) else { - return Self::NoProvider; - }; - - if let Some((package, rest)) = argument.split_once(is_item_path_delimiter) { - if rest.trim().is_empty() { - Self::SearchPackageDocs { - provider, - package: package.to_owned(), - index: true, - } - } else { - Self::SearchItemDocs { - provider, - package: package.to_owned(), - item_path: argument.to_owned(), - } - } - } else { - Self::SearchPackageDocs { - provider, - package: argument.to_owned(), - index: false, - } - } - } - - pub fn provider(&self) -> Option { - match self { - Self::NoProvider => None, - Self::SearchPackageDocs { provider, .. } | Self::SearchItemDocs { provider, .. } => { - Some(provider.clone()) - } - } - } - - pub fn package(&self) -> Option { - match self { - Self::NoProvider => None, - Self::SearchPackageDocs { package, .. } | Self::SearchItemDocs { package, .. } => { - Some(package.as_str().into()) - } - } - } -} - -/// Returns the term used to refer to a package. -fn package_term(provider: &ProviderId) -> &'static str { - if provider == &DocsDotRsProvider::id() || provider == &LocalRustdocProvider::id() { - return "crate"; - } - - "package" -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_docs_slash_command_args() { - assert_eq!( - DocsSlashCommandArgs::parse(&["".to_string()]), - DocsSlashCommandArgs::NoProvider - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string()]), - DocsSlashCommandArgs::NoProvider - ); - - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("rustdoc".into()), - package: "".into(), - index: false - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["gleam".to_string(), "".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("gleam".into()), - package: "".into(), - index: false - } - ); - - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("rustdoc".into()), - package: "gpui".into(), - index: false, - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("gleam".into()), - package: "gleam_stdlib".into(), - index: false - } - ); - - // Adding an item path delimiter indicates we can start indexing. - assert_eq!( - DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui:".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("rustdoc".into()), - package: "gpui".into(), - index: true, - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib/".to_string()]), - DocsSlashCommandArgs::SearchPackageDocs { - provider: ProviderId("gleam".into()), - package: "gleam_stdlib".into(), - index: true - } - ); - - assert_eq!( - DocsSlashCommandArgs::parse(&[ - "rustdoc".to_string(), - "gpui::foo::bar::Baz".to_string() - ]), - DocsSlashCommandArgs::SearchItemDocs { - provider: ProviderId("rustdoc".into()), - package: "gpui".into(), - item_path: "gpui::foo::bar::Baz".into() - } - ); - assert_eq!( - DocsSlashCommandArgs::parse(&[ - "gleam".to_string(), - "gleam_stdlib/gleam/int".to_string() - ]), - DocsSlashCommandArgs::SearchItemDocs { - provider: ProviderId("gleam".into()), - package: "gleam_stdlib".into(), - item_path: "gleam_stdlib/gleam/int".into() - } - ); - } -} diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml index acbe674b02..c95695052a 100644 --- a/crates/assistant_tool/Cargo.toml +++ b/crates/assistant_tool/Cargo.toml @@ -12,12 +12,10 @@ workspace = true path = "src/assistant_tool.rs" [dependencies] +action_log.workspace = true anyhow.workspace = true -buffer_diff.workspace = true -clock.workspace = true collections.workspace = true derive_more.workspace = true -futures.workspace = true gpui.workspace = true icons.workspace = true language.workspace = true @@ -30,7 +28,6 @@ serde.workspace = true serde_json.workspace = true text.workspace = true util.workspace = true -watch.workspace = true workspace.workspace = true workspace-hack.workspace = true diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 22cbaac3f8..9c5825d0f0 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -1,4 +1,3 @@ -mod action_log; pub mod outline; mod tool_registry; mod tool_schema; @@ -10,6 +9,7 @@ use std::fmt::Formatter; use std::ops::Deref; use std::sync::Arc; +use action_log::ActionLog; use anyhow::Result; use gpui::AnyElement; use gpui::AnyWindowHandle; @@ -25,7 +25,6 @@ use language_model::LanguageModelToolSchemaFormat; use project::Project; use workspace::Workspace; -pub use crate::action_log::*; pub use crate::tool_registry::*; pub use crate::tool_schema::*; pub use crate::tool_working_set::*; diff --git a/crates/assistant_tool/src/outline.rs b/crates/assistant_tool/src/outline.rs index 6af204d79a..4f8bde5456 100644 --- a/crates/assistant_tool/src/outline.rs +++ b/crates/assistant_tool/src/outline.rs @@ -1,4 +1,4 @@ -use crate::ActionLog; +use action_log::ActionLog; use anyhow::{Context as _, Result}; use gpui::{AsyncApp, Entity}; use language::{OutlineItem, ParseStatus}; diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index d4b8fa3afc..5a8ca8a5e9 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -15,6 +15,7 @@ path = "src/assistant_tools.rs" eval = [] [dependencies] +action_log.workspace = true agent_settings.workspace = true anyhow.workspace = true assistant_tool.workspace = true diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index e34ae9ff93..c56a864bd4 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::AnyWindowHandle; use gpui::{App, AppContext, Entity, Task}; use language_model::LanguageModel; diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index 11d969d234..85eea463dc 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::AnyWindowHandle; use gpui::{App, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index 9e69c18b65..b181eeff5c 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use futures::{SinkExt, StreamExt, channel::mpsc}; use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 12ab97f820..4ec794e127 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::{AnyWindowHandle, App, Entity, Task}; use language::{DiagnosticSeverity, OffsetRangeExt}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; @@ -85,7 +86,7 @@ impl Tool for DiagnosticsTool { input: serde_json::Value, _request: Arc, project: Entity, - action_log: Entity, + _action_log: Entity, _model: Arc, _window: Option, cx: &mut App, @@ -158,10 +159,6 @@ impl Tool for DiagnosticsTool { } } - action_log.update(cx, |action_log, _cx| { - action_log.checked_project_diagnostics(); - }); - if has_diagnostics { Task::ready(Ok(output.into())).into() } else { diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index dcb14a48f3..aa321aa8f3 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -5,8 +5,8 @@ mod evals; mod streaming_fuzzy_matcher; use crate::{Template, Templates}; +use action_log::ActionLog; use anyhow::Result; -use assistant_tool::ActionLog; use cloud_llm_client::CompletionIntent; use create_file_parser::{CreateFileParser, CreateFileParserEvent}; pub use edit_parser::EditFormat; @@ -65,7 +65,7 @@ pub enum EditAgentOutputEvent { ResolvingEditRange(Range), UnresolvedEditRange, AmbiguousEditRange(Vec>), - Edited, + Edited(Range), } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -178,7 +178,9 @@ impl EditAgent { ) }); output_events_tx - .unbounded_send(EditAgentOutputEvent::Edited) + .unbounded_send(EditAgentOutputEvent::Edited( + language::Anchor::MIN..language::Anchor::MAX, + )) .ok(); })?; @@ -200,7 +202,9 @@ impl EditAgent { }); })?; output_events_tx - .unbounded_send(EditAgentOutputEvent::Edited) + .unbounded_send(EditAgentOutputEvent::Edited( + language::Anchor::MIN..language::Anchor::MAX, + )) .ok(); } } @@ -336,8 +340,8 @@ impl EditAgent { // Edit the buffer and report edits to the action log as part of the // same effect cycle, otherwise the edit will be reported as if the // user made it. - cx.update(|cx| { - let max_edit_end = buffer.update(cx, |buffer, cx| { + let (min_edit_start, max_edit_end) = cx.update(|cx| { + let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| { buffer.edit(edits.iter().cloned(), None, cx); let max_edit_end = buffer .summaries_for_anchors::( @@ -345,7 +349,16 @@ impl EditAgent { ) .max() .unwrap(); - buffer.anchor_before(max_edit_end) + let min_edit_start = buffer + .summaries_for_anchors::( + edits.iter().map(|(range, _)| &range.start), + ) + .min() + .unwrap(); + ( + buffer.anchor_after(min_edit_start), + buffer.anchor_before(max_edit_end), + ) }); self.action_log .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); @@ -358,9 +371,10 @@ impl EditAgent { cx, ); }); + (min_edit_start, max_edit_end) })?; output_events - .unbounded_send(EditAgentOutputEvent::Edited) + .unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end)) .ok(); } @@ -755,6 +769,7 @@ mod tests { use gpui::{AppContext, TestAppContext}; use indoc::indoc; use language_model::fake_provider::FakeLanguageModel; + use pretty_assertions::assert_matches; use project::{AgentLocation, Project}; use rand::prelude::*; use rand::rngs::StdRng; @@ -992,7 +1007,10 @@ mod tests { model.send_last_completion_stream_text_chunk("abX"); cx.run_until_parked(); - assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited(_)] + ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), "abXc\ndef\nghi\njkl" @@ -1007,7 +1025,10 @@ mod tests { model.send_last_completion_stream_text_chunk("cY"); cx.run_until_parked(); - assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] + ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), "abXcY\ndef\nghi\njkl" @@ -1118,9 +1139,9 @@ mod tests { model.send_last_completion_stream_text_chunk("GHI"); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1165,9 +1186,9 @@ mod tests { ); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited(_)] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1183,9 +1204,9 @@ mod tests { chunks_tx.unbounded_send("```\njkl\n").unwrap(); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1201,9 +1222,9 @@ mod tests { chunks_tx.unbounded_send("mno\n").unwrap(); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited { .. }] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1219,9 +1240,9 @@ mod tests { chunks_tx.unbounded_send("pqr\n```").unwrap(); cx.run_until_parked(); - assert_eq!( - drain_events(&mut events), - vec![EditAgentOutputEvent::Edited] + assert_matches!( + drain_events(&mut events).as_slice(), + [EditAgentOutputEvent::Edited(_)], ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), diff --git a/crates/assistant_tools/src/edit_agent/create_file_parser.rs b/crates/assistant_tools/src/edit_agent/create_file_parser.rs index 07c8fac7b9..0aad9ecb87 100644 --- a/crates/assistant_tools/src/edit_agent/create_file_parser.rs +++ b/crates/assistant_tools/src/edit_agent/create_file_parser.rs @@ -1,10 +1,11 @@ +use std::sync::OnceLock; + use regex::Regex; use smallvec::SmallVec; -use std::cell::LazyCell; use util::debug_panic; -const START_MARKER: LazyCell = LazyCell::new(|| Regex::new(r"\n?```\S*\n").unwrap()); -const END_MARKER: LazyCell = LazyCell::new(|| Regex::new(r"(^|\n)```\s*$").unwrap()); +static START_MARKER: OnceLock = OnceLock::new(); +static END_MARKER: OnceLock = OnceLock::new(); #[derive(Debug)] pub enum CreateFileParserEvent { @@ -43,10 +44,12 @@ impl CreateFileParser { self.buffer.push_str(chunk); let mut edit_events = SmallVec::new(); + let start_marker_regex = START_MARKER.get_or_init(|| Regex::new(r"\n?```\S*\n").unwrap()); + let end_marker_regex = END_MARKER.get_or_init(|| Regex::new(r"(^|\n)```\s*$").unwrap()); loop { match &mut self.state { ParserState::Pending => { - if let Some(m) = START_MARKER.find(&self.buffer) { + if let Some(m) = start_marker_regex.find(&self.buffer) { self.buffer.drain(..m.end()); self.state = ParserState::WithinText; } else { @@ -65,7 +68,7 @@ impl CreateFileParser { break; } ParserState::Finishing => { - if let Some(m) = END_MARKER.find(&self.buffer) { + if let Some(m) = end_marker_regex.find(&self.buffer) { self.buffer.drain(m.start()..); } if !self.buffer.is_empty() { diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 54431ee1d7..e819c51e1e 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -4,11 +4,11 @@ use crate::{ schema::json_schema_for, ui::{COLLAPSED_LINES, ToolOutputPreview}, }; +use action_log::ActionLog; use agent_settings; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ - ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, - ToolUseStatus, + AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey}; @@ -307,7 +307,7 @@ impl Tool for EditFileTool { let mut ambiguous_ranges = Vec::new(); while let Some(event) = events.next().await { match event { - EditAgentOutputEvent::Edited => { + EditAgentOutputEvent::Edited { .. } => { if let Some(card) = card_clone.as_ref() { card.update(cx, |card, cx| card.update_diff(cx))?; } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index a31ec39268..79e205f205 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -3,8 +3,9 @@ use std::sync::Arc; use std::{borrow::Cow, cell::RefCell}; use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow, bail}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use futures::AsyncReadExt as _; use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task}; use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index 6cdf58eac8..6b62638a4c 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -1,7 +1,8 @@ use crate::{schema::json_schema_for, ui::ToolCallCardHeader}; +use action_log::ActionLog; use anyhow::{Result, anyhow}; use assistant_tool::{ - ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, + Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; use editor::Editor; use futures::channel::oneshot::{self, Receiver}; diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index 43c3d1d990..a5ce07823f 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use futures::StreamExt; use gpui::{AnyWindowHandle, App, Entity, Task}; use language::{OffsetRangeExt, ParseStatus, Point}; diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index b1980615d6..5471d8923b 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::{AnyWindowHandle, App, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use project::{Project, WorktreeSettings}; diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index c1cbbf848d..2c065488ce 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use project::Project; diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index b51b91d3d5..f50ad065d1 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use chrono::{Local, Utc}; use gpui::{AnyWindowHandle, App, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 8fddbb0431..6dbf66749b 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use project::Project; diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs index 03487e5419..c65cfd0ca7 100644 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::Result; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::{AnyWindowHandle, App, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use project::Project; diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index ee38273cc0..68b870e40f 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -1,6 +1,7 @@ use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use assistant_tool::{ToolResultContent, outline}; use gpui::{AnyWindowHandle, App, Entity, Task}; use project::{ImageItem, image_store}; @@ -286,7 +287,7 @@ impl Tool for ReadFileTool { Using the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the implementations of symbols in the outline. - + Alternatively, you can fall back to the `grep` tool (if available) to search the file for specific content." } diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 8add60f09a..46227f130d 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -2,9 +2,10 @@ use crate::{ schema::json_schema_for, ui::{COLLAPSED_LINES, ToolOutputPreview}, }; +use action_log::ActionLog; use agent_settings; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus}; +use assistant_tool::{Tool, ToolCard, ToolResult, ToolUseStatus}; use futures::{FutureExt as _, future::Shared}; use gpui::{ Animation, AnimationExt, AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 76c6e6c0ba..17ce4afc2e 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use crate::schema::json_schema_for; +use action_log::ActionLog; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, Tool, ToolResult}; +use assistant_tool::{Tool, ToolResult}; use gpui::{AnyWindowHandle, App, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use project::Project; diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index c6c37de472..47a6958b7a 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -2,9 +2,10 @@ use std::{sync::Arc, time::Duration}; use crate::schema::json_schema_for; use crate::ui::ToolCallCardHeader; +use action_log::ActionLog; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ - ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, + Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, }; use cloud_llm_client::{WebSearchResponse, WebSearchResult}; use futures::{Future, FutureExt, TryFutureExt}; diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index d857a3eb2f..5146396b92 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -18,6 +18,6 @@ collections.workspace = true derive_more.workspace = true gpui.workspace = true parking_lot.workspace = true -rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] } +rodio = { workspace = true, features = [ "wav", "playback", "tracing" ] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index 074aaa6fea..4d0d2d5984 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -59,16 +59,9 @@ pub enum VersionCheckType { pub enum AutoUpdateStatus { Idle, Checking, - Downloading { - version: VersionCheckType, - }, - Installing { - version: VersionCheckType, - }, - Updated { - binary_path: PathBuf, - version: VersionCheckType, - }, + Downloading { version: VersionCheckType }, + Installing { version: VersionCheckType }, + Updated { version: VersionCheckType }, Errored, } @@ -83,6 +76,7 @@ pub struct AutoUpdater { current_version: SemanticVersion, http_client: Arc, pending_poll: Option>>, + quit_subscription: Option, } #[derive(Deserialize, Clone, Debug)] @@ -164,7 +158,7 @@ pub fn init(http_client: Arc, cx: &mut App) { AutoUpdateSetting::register(cx); cx.observe_new(|workspace: &mut Workspace, _window, _cx| { - workspace.register_action(|_, action: &Check, window, cx| check(action, window, cx)); + workspace.register_action(|_, action, window, cx| check(action, window, cx)); workspace.register_action(|_, action, _, cx| { view_release_notes(action, cx); @@ -174,7 +168,7 @@ pub fn init(http_client: Arc, cx: &mut App) { let version = release_channel::AppVersion::global(cx); let auto_updater = cx.new(|cx| { - let updater = AutoUpdater::new(version, http_client); + let updater = AutoUpdater::new(version, http_client, cx); let poll_for_updates = ReleaseChannel::try_global(cx) .map(|channel| channel.poll_for_updates()) @@ -321,12 +315,34 @@ impl AutoUpdater { cx.default_global::().0.clone() } - fn new(current_version: SemanticVersion, http_client: Arc) -> Self { + fn new( + current_version: SemanticVersion, + http_client: Arc, + cx: &mut Context, + ) -> Self { + // On windows, executable files cannot be overwritten while they are + // running, so we must wait to overwrite the application until quitting + // or restarting. When quitting the app, we spawn the auto update helper + // to finish the auto update process after Zed exits. When restarting + // the app after an update, we use `set_restart_path` to run the auto + // update helper instead of the app, so that it can overwrite the app + // and then spawn the new binary. + let quit_subscription = Some(cx.on_app_quit(|_, _| async move { + #[cfg(target_os = "windows")] + finalize_auto_update_on_quit(); + })); + + cx.on_app_restart(|this, _| { + this.quit_subscription.take(); + }) + .detach(); + Self { status: AutoUpdateStatus::Idle, current_version, http_client, pending_poll: None, + quit_subscription, } } @@ -536,6 +552,8 @@ impl AutoUpdater { ) })?; + Self::check_dependencies()?; + this.update(&mut cx, |this, cx| { this.status = AutoUpdateStatus::Checking; cx.notify(); @@ -582,13 +600,15 @@ impl AutoUpdater { cx.notify(); })?; - let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?; + let new_binary_path = Self::install_release(installer_dir, target_path, &cx).await?; + if let Some(new_binary_path) = new_binary_path { + cx.update(|cx| cx.set_restart_path(new_binary_path))?; + } this.update(&mut cx, |this, cx| { this.set_should_show_update_notification(true, cx) .detach_and_log_err(cx); this.status = AutoUpdateStatus::Updated { - binary_path, version: newer_version, }; cx.notify(); @@ -639,6 +659,15 @@ impl AutoUpdater { } } + fn check_dependencies() -> Result<()> { + #[cfg(not(target_os = "windows"))] + anyhow::ensure!( + which::which("rsync").is_ok(), + "Aborting. Could not find rsync which is required for auto-updates." + ); + Ok(()) + } + async fn target_path(installer_dir: &InstallerDir) -> Result { let filename = match OS { "macos" => anyhow::Ok("Zed.dmg"), @@ -647,20 +676,14 @@ impl AutoUpdater { unsupported_os => anyhow::bail!("not supported: {unsupported_os}"), }?; - #[cfg(not(target_os = "windows"))] - anyhow::ensure!( - which::which("rsync").is_ok(), - "Aborting. Could not find rsync which is required for auto-updates." - ); - Ok(installer_dir.path().join(filename)) } - async fn binary_path( + async fn install_release( installer_dir: InstallerDir, target_path: PathBuf, cx: &AsyncApp, - ) -> Result { + ) -> Result> { match OS { "macos" => install_release_macos(&installer_dir, target_path, cx).await, "linux" => install_release_linux(&installer_dir, target_path, cx).await, @@ -801,7 +824,7 @@ async fn install_release_linux( temp_dir: &InstallerDir, downloaded_tar_gz: PathBuf, cx: &AsyncApp, -) -> Result { +) -> Result> { let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?; let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?); let running_app_path = cx.update(|cx| cx.app_path())??; @@ -861,14 +884,14 @@ async fn install_release_linux( String::from_utf8_lossy(&output.stderr) ); - Ok(to.join(expected_suffix)) + Ok(Some(to.join(expected_suffix))) } async fn install_release_macos( temp_dir: &InstallerDir, downloaded_dmg: PathBuf, cx: &AsyncApp, -) -> Result { +) -> Result> { let running_app_path = cx.update(|cx| cx.app_path())??; let running_app_filename = running_app_path .file_name() @@ -910,10 +933,10 @@ async fn install_release_macos( String::from_utf8_lossy(&output.stderr) ); - Ok(running_app_path) + Ok(None) } -async fn install_release_windows(downloaded_installer: PathBuf) -> Result { +async fn install_release_windows(downloaded_installer: PathBuf) -> Result> { let output = Command::new(downloaded_installer) .arg("/verysilent") .arg("/update=true") @@ -926,29 +949,36 @@ async fn install_release_windows(downloaded_installer: PathBuf) -> Result bool { +pub fn finalize_auto_update_on_quit() { let Some(installer_path) = std::env::current_exe() .ok() .and_then(|p| p.parent().map(|p| p.join("updates"))) else { - return false; + return; }; // The installer will create a flag file after it finishes updating let flag_file = installer_path.join("versions.txt"); - if flag_file.exists() { - if let Some(helper) = installer_path + if flag_file.exists() + && let Some(helper) = installer_path .parent() .map(|p| p.join("tools\\auto_update_helper.exe")) - { - let _ = std::process::Command::new(helper).spawn(); - return true; - } + { + let mut command = std::process::Command::new(helper); + command.arg("--launch"); + command.arg("false"); + let _ = command.spawn(); } - false } #[cfg(test)] @@ -1002,7 +1032,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)), }; let fetched_version = SemanticVersion::new(1, 0, 1); @@ -1024,7 +1053,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)), }; let fetched_version = SemanticVersion::new(1, 0, 2); @@ -1090,7 +1118,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "b".to_string(); @@ -1112,7 +1139,6 @@ mod tests { let app_commit_sha = Ok(Some("a".to_string())); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "c".to_string(); @@ -1160,7 +1186,6 @@ mod tests { let app_commit_sha = Ok(None); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "b".to_string(); @@ -1183,7 +1208,6 @@ mod tests { let app_commit_sha = Ok(None); let installed_version = SemanticVersion::new(1, 0, 0); let status = AutoUpdateStatus::Updated { - binary_path: PathBuf::new(), version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())), }; let fetched_sha = "c".to_string(); diff --git a/crates/auto_update_helper/src/auto_update_helper.rs b/crates/auto_update_helper/src/auto_update_helper.rs index 7c810d8724..3aa57094d3 100644 --- a/crates/auto_update_helper/src/auto_update_helper.rs +++ b/crates/auto_update_helper/src/auto_update_helper.rs @@ -18,7 +18,7 @@ fn main() {} #[cfg(target_os = "windows")] mod windows_impl { - use std::path::Path; + use std::{borrow::Cow, path::Path}; use super::dialog::create_dialog_window; use super::updater::perform_update; @@ -37,6 +37,11 @@ mod windows_impl { pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1; pub(crate) const WM_TERMINATE: u32 = WM_USER + 2; + #[derive(Debug, Default)] + struct Args { + launch: bool, + } + pub(crate) fn run() -> Result<()> { let helper_dir = std::env::current_exe()? .parent() @@ -51,8 +56,9 @@ mod windows_impl { log::info!("======= Starting Zed update ======="); let (tx, rx) = std::sync::mpsc::channel(); let hwnd = create_dialog_window(rx)?.0 as isize; + let args = parse_args(std::env::args().skip(1)); std::thread::spawn(move || { - let result = perform_update(app_dir.as_path(), Some(hwnd)); + let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch); tx.send(result).ok(); unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok(); }); @@ -77,6 +83,29 @@ mod windows_impl { Ok(()) } + fn parse_args(input: impl IntoIterator) -> Args { + let mut args: Args = Args { launch: true }; + + let mut input = input.into_iter(); + if let Some(arg) = input.next() { + let launch_arg; + + if arg == "--launch" { + launch_arg = input.next().map(Cow::Owned); + } else if let Some(rest) = arg.strip_prefix("--launch=") { + launch_arg = Some(Cow::Borrowed(rest)); + } else { + launch_arg = None; + } + + if launch_arg.as_deref() == Some("false") { + args.launch = false; + } + } + + args + } + pub(crate) fn show_error(mut content: String) { if content.len() > 600 { content.truncate(600); @@ -91,4 +120,31 @@ mod windows_impl { ) }; } + + #[cfg(test)] + mod tests { + use crate::windows_impl::parse_args; + + #[test] + fn test_parse_args() { + // launch can be specified via two separate arguments + assert_eq!(parse_args(["--launch".into(), "true".into()]).launch, true); + assert_eq!( + parse_args(["--launch".into(), "false".into()]).launch, + false + ); + + // launch can be specified via one single argument + assert_eq!(parse_args(["--launch=true".into()]).launch, true); + assert_eq!(parse_args(["--launch=false".into()]).launch, false); + + // launch defaults to true on no arguments + assert_eq!(parse_args([]).launch, true); + + // launch defaults to true on invalid arguments + assert_eq!(parse_args(["--launch".into()]).launch, true); + assert_eq!(parse_args(["--launch=".into()]).launch, true); + assert_eq!(parse_args(["--launch=invalid".into()]).launch, true); + } + } } diff --git a/crates/auto_update_helper/src/dialog.rs b/crates/auto_update_helper/src/dialog.rs index 010ebb4875..757819df51 100644 --- a/crates/auto_update_helper/src/dialog.rs +++ b/crates/auto_update_helper/src/dialog.rs @@ -72,7 +72,7 @@ pub(crate) fn create_dialog_window(receiver: Receiver>) -> Result) -> Result<()> { +pub(crate) fn perform_update(app_dir: &Path, hwnd: Option, launch: bool) -> Result<()> { let hwnd = hwnd.map(|ptr| HWND(ptr as _)); for job in JOBS.iter() { @@ -145,9 +145,11 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option) -> Result<()> } } } - let _ = std::process::Command::new(app_dir.join("Zed.exe")) - .creation_flags(CREATE_NEW_PROCESS_GROUP.0) - .spawn(); + if launch { + let _ = std::process::Command::new(app_dir.join("Zed.exe")) + .creation_flags(CREATE_NEW_PROCESS_GROUP.0) + .spawn(); + } log::info!("Update completed successfully"); Ok(()) } @@ -159,11 +161,11 @@ mod test { #[test] fn test_perform_update() { let app_dir = std::path::Path::new("C:/"); - assert!(perform_update(app_dir, None).is_ok()); + assert!(perform_update(app_dir, None, false).is_ok()); // Simulate a timeout unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") }; - let ret = perform_update(app_dir, None); + let ret = perform_update(app_dir, None, false); assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out")); } } diff --git a/crates/call/src/call_impl/room.rs b/crates/call/src/call_impl/room.rs index afeee4c924..73cb8518a6 100644 --- a/crates/call/src/call_impl/room.rs +++ b/crates/call/src/call_impl/room.rs @@ -10,10 +10,10 @@ use client::{ }; use collections::{BTreeMap, HashMap, HashSet}; use fs::Fs; -use futures::{FutureExt, StreamExt}; +use futures::StreamExt; use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource, - ScreenCaptureStream, Task, WeakEntity, + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FutureExt as _, + ScreenCaptureSource, ScreenCaptureStream, Task, Timeout, WeakEntity, }; use gpui_tokio::Tokio; use language::LanguageRegistry; @@ -370,57 +370,53 @@ impl Room { })?; // Wait for client to re-establish a connection to the server. - { - let mut reconnection_timeout = - cx.background_executor().timer(RECONNECT_TIMEOUT).fuse(); - let client_reconnection = async { - let mut remaining_attempts = 3; - while remaining_attempts > 0 { - if client_status.borrow().is_connected() { - log::info!("client reconnected, attempting to rejoin room"); + let executor = cx.background_executor().clone(); + let client_reconnection = async { + let mut remaining_attempts = 3; + while remaining_attempts > 0 { + if client_status.borrow().is_connected() { + log::info!("client reconnected, attempting to rejoin room"); - let Some(this) = this.upgrade() else { break }; - match this.update(cx, |this, cx| this.rejoin(cx)) { - Ok(task) => { - if task.await.log_err().is_some() { - return true; - } else { - remaining_attempts -= 1; - } + let Some(this) = this.upgrade() else { break }; + match this.update(cx, |this, cx| this.rejoin(cx)) { + Ok(task) => { + if task.await.log_err().is_some() { + return true; + } else { + remaining_attempts -= 1; } - Err(_app_dropped) => return false, } - } else if client_status.borrow().is_signed_out() { - return false; + Err(_app_dropped) => return false, } - - log::info!( - "waiting for client status change, remaining attempts {}", - remaining_attempts - ); - client_status.next().await; + } else if client_status.borrow().is_signed_out() { + return false; } - false + + log::info!( + "waiting for client status change, remaining attempts {}", + remaining_attempts + ); + client_status.next().await; } - .fuse(); - futures::pin_mut!(client_reconnection); + false + }; - futures::select_biased! { - reconnected = client_reconnection => { - if reconnected { - log::info!("successfully reconnected to room"); - // If we successfully joined the room, go back around the loop - // waiting for future connection status changes. - continue; - } - } - _ = reconnection_timeout => { - log::info!("room reconnection timeout expired"); - } + match client_reconnection + .with_timeout(RECONNECT_TIMEOUT, &executor) + .await + { + Ok(true) => { + log::info!("successfully reconnected to room"); + // If we successfully joined the room, go back around the loop + // waiting for future connection status changes. + continue; + } + Ok(false) => break, + Err(Timeout) => { + log::info!("room reconnection timeout expired"); + break; } } - - break; } } diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 8d6cd2544a..67591167df 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -957,17 +957,14 @@ mod mac_os { ) -> Result<()> { use anyhow::bail; - let app_id_prompt = format!("id of app \"{}\"", channel.display_name()); - let app_id_output = Command::new("osascript") + let app_path_prompt = format!( + "POSIX path of (path to application \"{}\")", + channel.display_name() + ); + let app_path_output = Command::new("osascript") .arg("-e") - .arg(&app_id_prompt) + .arg(&app_path_prompt) .output()?; - if !app_id_output.status.success() { - bail!("Could not determine app id for {}", channel.display_name()); - } - let app_name = String::from_utf8(app_id_output.stdout)?.trim().to_owned(); - let app_path_prompt = format!("kMDItemCFBundleIdentifier == '{app_name}'"); - let app_path_output = Command::new("mdfind").arg(app_path_prompt).output()?; if !app_path_output.status.success() { bail!( "Could not determine app path for {}", diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 43a1a0b7a4..54b3d3f801 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -340,22 +340,35 @@ impl Telemetry { } pub fn log_edit_event(self: &Arc, environment: &'static str, is_via_ssh: bool) { + static LAST_EVENT_TIME: Mutex> = Mutex::new(None); + let mut state = self.state.lock(); let period_data = state.event_coalescer.log_event(environment); drop(state); - if let Some((start, end, environment)) = period_data { - let duration = end - .saturating_duration_since(start) - .min(Duration::from_secs(60 * 60 * 24)) - .as_millis() as i64; + if let Some(mut last_event) = LAST_EVENT_TIME.try_lock() { + let current_time = std::time::Instant::now(); + let last_time = last_event.get_or_insert(current_time); - telemetry::event!( - "Editor Edited", - duration = duration, - environment = environment, - is_via_ssh = is_via_ssh - ); + if current_time.duration_since(*last_time) > Duration::from_secs(60 * 10) { + *last_time = current_time; + } else { + return; + } + + if let Some((start, end, environment)) = period_data { + let duration = end + .saturating_duration_since(start) + .min(Duration::from_secs(60 * 60 * 24)) + .as_millis() as i64; + + telemetry::event!( + "Editor Edited", + duration = duration, + environment = environment, + is_via_ssh = is_via_ssh + ); + } } } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 439fb100d2..3c451fcb01 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,16 +1,12 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; -use chrono::Duration; use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; -use rpc::{ - ConnectionId, Peer, Receipt, TypedEnvelope, - proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, -}; +use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto}; use std::sync::Arc; pub struct FakeServer { @@ -187,50 +183,27 @@ impl FakeServer { pub async fn receive(&self) -> Result> { self.executor.start_waiting(); - loop { - let message = self - .state - .lock() - .incoming - .as_mut() - .expect("not connected") - .next() - .await - .context("other half hung up")?; - self.executor.finish_waiting(); - let type_name = message.payload_type_name(); - let message = message.into_any(); + let message = self + .state + .lock() + .incoming + .as_mut() + .expect("not connected") + .next() + .await + .context("other half hung up")?; + self.executor.finish_waiting(); + let type_name = message.payload_type_name(); + let message = message.into_any(); - if message.is::>() { - return Ok(*message.downcast().unwrap()); - } - - let accepted_tos_at = chrono::Utc::now() - .checked_sub_signed(Duration::hours(5)) - .expect("failed to build accepted_tos_at") - .timestamp() as u64; - - if message.is::>() { - self.respond( - message - .downcast::>() - .unwrap() - .receipt(), - GetPrivateUserInfoResponse { - metrics_id: "the-metrics-id".into(), - staff: false, - flags: Default::default(), - accepted_tos_at: Some(accepted_tos_at), - }, - ); - continue; - } - - panic!( - "fake server received unexpected message type: {:?}", - type_name - ); + if message.is::>() { + return Ok(*message.downcast().unwrap()); } + + panic!( + "fake server received unexpected message type: {:?}", + type_name + ); } pub fn respond(&self, receipt: Receipt, response: T::Response) { diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index faf46945d8..da7f50076b 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -177,7 +177,6 @@ impl UserStore { let (mut current_user_tx, current_user_rx) = watch::channel(); let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); let rpc_subscriptions = vec![ - client.add_message_handler(cx.weak_entity(), Self::handle_update_plan), client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts), client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info), client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts), @@ -343,26 +342,6 @@ impl UserStore { Ok(()) } - async fn handle_update_plan( - this: Entity, - _message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - let client = this - .read_with(&cx, |this, _| this.client.upgrade())? - .context("client was dropped")?; - - let response = client - .cloud_client() - .get_authenticated_user() - .await - .context("failed to fetch authenticated user")?; - - this.update(&mut cx, |this, cx| { - this.update_authenticated_user(response, cx); - }) - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -1019,19 +998,6 @@ impl RequestUsage { } } - pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { - let limit = match limit.variant? { - proto::usage_limit::Variant::Limited(limited) => { - UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, - }; - Some(RequestUsage { - limit, - amount: amount as i32, - }) - } - fn from_headers( limit_name: &str, amount_name: &str, diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index e78957ec49..741945af10 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -263,12 +263,12 @@ pub struct WebSearchBody { pub query: String, } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct WebSearchResponse { pub results: Vec, } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct WebSearchResult { pub title: String, pub url: String, diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9af95317e6..4fccd3be7f 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -19,7 +19,6 @@ test-support = ["sqlite"] [dependencies] anyhow.workspace = true -async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } @@ -30,16 +29,13 @@ axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true chrono.workspace = true clock.workspace = true -cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true -derive_more.workspace = true envy = "0.4.2" futures.workspace = true gpui.workspace = true hex.workspace = true http_client.workspace = true -jsonwebtoken.workspace = true livekit_api.workspace = true log.workspace = true nanoid.workspace = true @@ -65,7 +61,6 @@ subtle.workspace = true supermaven_api.workspace = true telemetry_events.workspace = true text.workspace = true -thiserror.workspace = true time.workspace = true tokio = { workspace = true, features = ["full"] } toml.workspace = true @@ -136,6 +131,3 @@ util.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } zlog.workspace = true - -[package.metadata.cargo-machete] -ignored = ["async-stripe"] diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 45fc018a4a..214b550ac2 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -219,12 +219,6 @@ spec: secretKeyRef: name: slack key: panics_webhook - - name: STRIPE_API_KEY - valueFrom: - secretKeyRef: - name: stripe - key: api_key - optional: true - name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR value: "1000" - name: SUPERMAVEN_ADMIN_API_KEY diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 73d473ab76..170ac7b0a2 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -474,67 +474,6 @@ CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count"); -CREATE TABLE rate_buckets ( - user_id INT NOT NULL, - rate_limit_name VARCHAR(255) NOT NULL, - token_count INT NOT NULL, - last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL, - PRIMARY KEY (user_id, rate_limit_name), - FOREIGN KEY (user_id) REFERENCES users (id) -); - -CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name); - -CREATE TABLE IF NOT EXISTS billing_preferences ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - user_id INTEGER NOT NULL REFERENCES users (id), - max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL, - model_request_overages_enabled bool NOT NULL DEFAULT FALSE, - model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0 -); - -CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id); - -CREATE TABLE IF NOT EXISTS billing_customers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - user_id INTEGER NOT NULL REFERENCES users (id), - has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE, - stripe_customer_id TEXT NOT NULL, - trial_started_at TIMESTAMP -); - -CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id); - -CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id); - -CREATE TABLE IF NOT EXISTS billing_subscriptions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id), - stripe_subscription_id TEXT NOT NULL, - stripe_subscription_status TEXT NOT NULL, - stripe_cancel_at TIMESTAMP, - stripe_cancellation_reason TEXT, - kind TEXT, - stripe_current_period_start BIGINT, - stripe_current_period_end BIGINT -); - -CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id); - -CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id); - -CREATE TABLE IF NOT EXISTS processed_stripe_events ( - stripe_event_id TEXT PRIMARY KEY, - stripe_event_type TEXT NOT NULL, - stripe_event_created_timestamp INTEGER NOT NULL, - processed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -CREATE INDEX "ix_processed_stripe_events_on_stripe_event_created_timestamp" ON processed_stripe_events (stripe_event_created_timestamp); - CREATE TABLE IF NOT EXISTS "breakpoints" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, diff --git a/crates/collab/migrations/20250816124707_make_admin_required_on_users.sql b/crates/collab/migrations/20250816124707_make_admin_required_on_users.sql new file mode 100644 index 0000000000..e372723d6d --- /dev/null +++ b/crates/collab/migrations/20250816124707_make_admin_required_on_users.sql @@ -0,0 +1,2 @@ +alter table users +alter column admin set not null; diff --git a/crates/collab/migrations/20250816133027_add_orb_customer_id_to_billing_customers.sql b/crates/collab/migrations/20250816133027_add_orb_customer_id_to_billing_customers.sql new file mode 100644 index 0000000000..ea5e4de52a --- /dev/null +++ b/crates/collab/migrations/20250816133027_add_orb_customer_id_to_billing_customers.sql @@ -0,0 +1,2 @@ +alter table billing_customers + add column orb_customer_id text; diff --git a/crates/collab/migrations/20250816135346_drop_rate_buckets_table.sql b/crates/collab/migrations/20250816135346_drop_rate_buckets_table.sql new file mode 100644 index 0000000000..f51a33ed30 --- /dev/null +++ b/crates/collab/migrations/20250816135346_drop_rate_buckets_table.sql @@ -0,0 +1 @@ +drop table rate_buckets; diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 6cf3f68f54..0cc7e2b2e9 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,19 +1,11 @@ -pub mod billing; pub mod contributors; pub mod events; pub mod extensions; pub mod ips_file; pub mod slack; -use crate::db::Database; -use crate::{ - AppState, Error, Result, auth, - db::{User, UserId}, - rpc, -}; -use ::rpc::proto; +use crate::{AppState, Error, Result, auth, db::UserId, rpc}; use anyhow::Context as _; -use axum::extract; use axum::{ Extension, Json, Router, body::Body, @@ -25,7 +17,6 @@ use axum::{ routing::{get, post}, }; use axum_extra::response::ErasedJson; -use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::sync::{Arc, OnceLock}; use tower::ServiceBuilder; @@ -100,10 +91,7 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) - .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) - .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(contributors::router()) .layer( @@ -144,99 +132,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct LookUpUserParams { - identifier: String, -} - -#[derive(Debug, Serialize)] -struct LookUpUserResponse { - user: Option, -} - -async fn look_up_user( - Query(params): Query, - Extension(app): Extension>, -) -> Result> { - let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?; - let user = if let Some(user) = user { - match user { - UserOrId::User(user) => Some(user), - UserOrId::Id(id) => app.db.get_user_by_id(id).await?, - } - } else { - None - }; - - Ok(Json(LookUpUserResponse { user })) -} - -enum UserOrId { - User(User), - Id(UserId), -} - -async fn resolve_identifier_to_user( - db: &Arc, - identifier: &str, -) -> Result> { - if let Some(identifier) = identifier.parse::().ok() { - let user = db.get_user_by_id(UserId(identifier)).await?; - - return Ok(user.map(UserOrId::User)); - } - - if identifier.starts_with("cus_") { - let billing_customer = db - .get_billing_customer_by_stripe_customer_id(&identifier) - .await?; - - return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))); - } - - if identifier.starts_with("sub_") { - let billing_subscription = db - .get_billing_subscription_by_stripe_subscription_id(&identifier) - .await?; - - if let Some(billing_subscription) = billing_subscription { - let billing_customer = db - .get_billing_customer_by_id(billing_subscription.billing_customer_id) - .await?; - - return Ok( - billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)) - ); - } else { - return Ok(None); - } - } - - if identifier.contains('@') { - let user = db.get_user_by_email(identifier).await?; - - return Ok(user.map(UserOrId::User)); - } - - if let Some(user) = db.get_user_by_github_login(identifier).await? { - return Ok(Some(UserOrId::User(user))); - } - - Ok(None) -} - -#[derive(Deserialize, Debug)] -struct CreateUserParams { - github_user_id: i32, - github_login: String, - email_address: String, - email_confirmation_code: Option, - #[serde(default)] - admin: bool, - #[serde(default)] - invite_count: i32, -} - async fn get_rpc_server_snapshot( Extension(rpc_server): Extension>, ) -> Result { @@ -295,90 +190,3 @@ async fn create_access_token( encrypted_access_token, })) } - -#[derive(Serialize)] -struct RefreshLlmTokensResponse {} - -async fn refresh_llm_tokens( - Path(user_id): Path, - Extension(rpc_server): Extension>, -) -> Result> { - rpc_server.refresh_llm_tokens_for_user(user_id).await; - - Ok(Json(RefreshLlmTokensResponse {})) -} - -#[derive(Debug, Serialize, Deserialize)] -struct UpdatePlanBody { - pub plan: cloud_llm_client::Plan, - pub subscription_period: SubscriptionPeriod, - pub usage: cloud_llm_client::CurrentUsage, - pub trial_started_at: Option>, - pub is_usage_based_billing_enabled: bool, - pub is_account_too_young: bool, - pub has_overdue_invoices: bool, -} - -#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] -struct SubscriptionPeriod { - pub started_at: DateTime, - pub ended_at: DateTime, -} - -#[derive(Serialize)] -struct UpdatePlanResponse {} - -async fn update_plan( - Path(user_id): Path, - Extension(rpc_server): Extension>, - extract::Json(body): extract::Json, -) -> Result> { - let plan = match body.plan { - cloud_llm_client::Plan::ZedFree => proto::Plan::Free, - cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, - cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, - }; - - let update_user_plan = proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: body - .trial_started_at - .map(|trial_started_at| trial_started_at.timestamp() as u64), - is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled), - usage: Some(proto::SubscriptionUsage { - model_requests_usage_amount: body.usage.model_requests.used, - model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)), - edit_predictions_usage_amount: body.usage.edit_predictions.used, - edit_predictions_usage_limit: Some(usage_limit_to_proto( - body.usage.edit_predictions.limit, - )), - }), - subscription_period: Some(proto::SubscriptionPeriod { - started_at: body.subscription_period.started_at.timestamp() as u64, - ended_at: body.subscription_period.ended_at.timestamp() as u64, - }), - account_too_young: Some(body.is_account_too_young), - has_overdue_invoices: Some(body.has_overdue_invoices), - }; - - rpc_server - .update_plan_for_user(user_id, update_user_plan) - .await?; - - Ok(Json(UpdatePlanResponse {})) -} - -fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit { - proto::UsageLimit { - variant: Some(match limit { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - } -} diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs deleted file mode 100644 index a0325d14c4..0000000000 --- a/crates/collab/src/api/billing.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::sync::Arc; -use stripe::SubscriptionStatus; - -use crate::AppState; -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::{CreateBillingCustomerParams, billing_customer}; -use crate::stripe_client::{StripeClient, StripeCustomerId}; - -impl From for StripeSubscriptionStatus { - fn from(value: SubscriptionStatus) -> Self { - match value { - SubscriptionStatus::Incomplete => Self::Incomplete, - SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired, - SubscriptionStatus::Trialing => Self::Trialing, - SubscriptionStatus::Active => Self::Active, - SubscriptionStatus::PastDue => Self::PastDue, - SubscriptionStatus::Canceled => Self::Canceled, - SubscriptionStatus::Unpaid => Self::Unpaid, - SubscriptionStatus::Paused => Self::Paused, - } - } -} - -/// Finds or creates a billing customer using the provided customer. -pub async fn find_or_create_billing_customer( - app: &Arc, - stripe_client: &dyn StripeClient, - customer_id: &StripeCustomerId, -) -> anyhow::Result> { - // If we already have a billing customer record associated with the Stripe customer, - // there's nothing more we need to do. - if let Some(billing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref()) - .await? - { - return Ok(Some(billing_customer)); - } - - let customer = stripe_client.get_customer(customer_id).await?; - - let Some(email) = customer.email else { - return Ok(None); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - return Ok(None); - }; - - let billing_customer = app - .db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - - Ok(Some(billing_customer)) -} diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 2f34a843a8..cd1dc42e64 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -564,170 +564,10 @@ fn for_snowflake( country_code: Option, checksum_matched: bool, ) -> impl Iterator { - body.events.into_iter().filter_map(move |event| { + body.events.into_iter().map(move |event| { let timestamp = first_event_at + Duration::milliseconds(event.milliseconds_since_first_event); - // We will need to double check, but I believe all of the events that - // are being transformed here are now migrated over to use the - // telemetry::event! macro, as of this commit so this code can go away - // when we feel enough users have upgraded past this point. let (event_type, mut event_properties) = match &event.event { - Event::Editor(e) => ( - match e.operation.as_str() { - "open" => "Editor Opened".to_string(), - "save" => "Editor Saved".to_string(), - _ => format!("Unknown Editor Event: {}", e.operation), - }, - serde_json::to_value(e).unwrap(), - ), - Event::EditPrediction(e) => ( - format!( - "Edit Prediction {}", - if e.suggestion_accepted { - "Accepted" - } else { - "Discarded" - } - ), - serde_json::to_value(e).unwrap(), - ), - Event::EditPredictionRating(e) => ( - "Edit Prediction Rated".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Call(e) => { - let event_type = match e.operation.trim() { - "unshare project" => "Project Unshared".to_string(), - "open channel notes" => "Channel Notes Opened".to_string(), - "share project" => "Project Shared".to_string(), - "join channel" => "Channel Joined".to_string(), - "hang up" => "Call Ended".to_string(), - "accept incoming" => "Incoming Call Accepted".to_string(), - "invite" => "Participant Invited".to_string(), - "disable microphone" => "Microphone Disabled".to_string(), - "enable microphone" => "Microphone Enabled".to_string(), - "enable screen share" => "Screen Share Enabled".to_string(), - "disable screen share" => "Screen Share Disabled".to_string(), - "decline incoming" => "Incoming Call Declined".to_string(), - _ => format!("Unknown Call Event: {}", e.operation), - }; - - (event_type, serde_json::to_value(e).unwrap()) - } - Event::Assistant(e) => ( - match e.phase { - telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(), - telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(), - telemetry_events::AssistantPhase::Accepted => { - "Assistant Response Accepted".to_string() - } - telemetry_events::AssistantPhase::Rejected => { - "Assistant Response Rejected".to_string() - } - }, - serde_json::to_value(e).unwrap(), - ), - Event::Cpu(_) | Event::Memory(_) => return None, - Event::App(e) => { - let mut properties = json!({}); - let event_type = match e.operation.trim() { - // App - "open" => "App Opened".to_string(), - "first open" => "App First Opened".to_string(), - "first open for release channel" => { - "App First Opened For Release Channel".to_string() - } - "close" => "App Closed".to_string(), - - // Project - "open project" => "Project Opened".to_string(), - "open node project" => { - properties["project_type"] = json!("node"); - "Project Opened".to_string() - } - "open pnpm project" => { - properties["project_type"] = json!("pnpm"); - "Project Opened".to_string() - } - "open yarn project" => { - properties["project_type"] = json!("yarn"); - "Project Opened".to_string() - } - - // SSH - "create ssh server" => "SSH Server Created".to_string(), - "create ssh project" => "SSH Project Created".to_string(), - "open ssh project" => "SSH Project Opened".to_string(), - - // Welcome Page - "welcome page: change keymap" => "Welcome Keymap Changed".to_string(), - "welcome page: change theme" => "Welcome Theme Changed".to_string(), - "welcome page: close" => "Welcome Page Closed".to_string(), - "welcome page: edit settings" => "Welcome Settings Edited".to_string(), - "welcome page: install cli" => "Welcome CLI Installed".to_string(), - "welcome page: open" => "Welcome Page Opened".to_string(), - "welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(), - "welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(), - "welcome page: toggle diagnostic telemetry" => { - "Welcome Diagnostic Telemetry Toggled".to_string() - } - "welcome page: toggle metric telemetry" => { - "Welcome Metric Telemetry Toggled".to_string() - } - "welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(), - "welcome page: view docs" => "Welcome Documentation Viewed".to_string(), - - // Extensions - "extensions page: open" => "Extensions Page Opened".to_string(), - "extensions: install extension" => "Extension Installed".to_string(), - "extensions: uninstall extension" => "Extension Uninstalled".to_string(), - - // Misc - "markdown preview: open" => "Markdown Preview Opened".to_string(), - "project diagnostics: open" => "Project Diagnostics Opened".to_string(), - "project search: open" => "Project Search Opened".to_string(), - "repl sessions: open" => "REPL Session Started".to_string(), - - // Feature Upsell - "feature upsell: toggle vim" => { - properties["source"] = json!("Feature Upsell"); - "Vim Mode Toggled".to_string() - } - _ => e - .operation - .strip_prefix("feature upsell: viewed docs (") - .and_then(|s| s.strip_suffix(')')) - .map_or_else( - || format!("Unknown App Event: {}", e.operation), - |docs_url| { - properties["url"] = json!(docs_url); - properties["source"] = json!("Feature Upsell"); - "Documentation Viewed".to_string() - }, - ), - }; - (event_type, properties) - } - Event::Setting(e) => ( - "Settings Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Extension(e) => ( - "Extension Loaded".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Edit(e) => ( - "Editor Edited".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Action(e) => ( - "Action Invoked".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Repl(e) => ( - "Kernel Status Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), Event::Flexible(e) => ( e.event_type.clone(), serde_json::to_value(&e.event_properties).unwrap(), @@ -759,7 +599,7 @@ fn for_snowflake( }) }); - Some(SnowflakeRow { + SnowflakeRow { time: timestamp, user_id: body.metrics_id.clone(), device_id: body.system_id.clone(), @@ -767,7 +607,7 @@ fn for_snowflake( event_properties, user_properties, insert_id: Some(Uuid::new_v4().to_string()), - }) + } }) } diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 2c22ca2069..774eec5d2c 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -41,12 +41,7 @@ use worktree_settings_file::LocalSettingsKind; pub use tests::TestDb; pub use ids::*; -pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams}; -pub use queries::billing_subscriptions::{ - CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams, -}; pub use queries::contributors::ContributorSelector; -pub use queries::processed_stripe_events::CreateProcessedStripeEventParams; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; pub use tables::*; diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 2ba7ec1051..8f116cfd63 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -70,9 +70,6 @@ macro_rules! id_type { } id_type!(AccessTokenId); -id_type!(BillingCustomerId); -id_type!(BillingSubscriptionId); -id_type!(BillingPreferencesId); id_type!(BufferId); id_type!(ChannelBufferCollaboratorId); id_type!(ChannelChatParticipantId); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 64b627e475..95e45dc004 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -1,9 +1,6 @@ use super::*; pub mod access_tokens; -pub mod billing_customers; -pub mod billing_preferences; -pub mod billing_subscriptions; pub mod buffers; pub mod channels; pub mod contacts; @@ -12,7 +9,6 @@ pub mod embeddings; pub mod extensions; pub mod messages; pub mod notifications; -pub mod processed_stripe_events; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs deleted file mode 100644 index ead9e6cd32..0000000000 --- a/crates/collab/src/db/queries/billing_customers.rs +++ /dev/null @@ -1,100 +0,0 @@ -use super::*; - -#[derive(Debug)] -pub struct CreateBillingCustomerParams { - pub user_id: UserId, - pub stripe_customer_id: String, -} - -#[derive(Debug, Default)] -pub struct UpdateBillingCustomerParams { - pub user_id: ActiveValue, - pub stripe_customer_id: ActiveValue, - pub has_overdue_invoices: ActiveValue, - pub trial_started_at: ActiveValue>, -} - -impl Database { - /// Creates a new billing customer. - pub async fn create_billing_customer( - &self, - params: &CreateBillingCustomerParams, - ) -> Result { - self.transaction(|tx| async move { - let customer = billing_customer::Entity::insert(billing_customer::ActiveModel { - user_id: ActiveValue::set(params.user_id), - stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()), - ..Default::default() - }) - .exec_with_returning(&*tx) - .await?; - - Ok(customer) - }) - .await - } - - /// Updates the specified billing customer. - pub async fn update_billing_customer( - &self, - id: BillingCustomerId, - params: &UpdateBillingCustomerParams, - ) -> Result<()> { - self.transaction(|tx| async move { - billing_customer::Entity::update(billing_customer::ActiveModel { - id: ActiveValue::set(id), - user_id: params.user_id.clone(), - stripe_customer_id: params.stripe_customer_id.clone(), - has_overdue_invoices: params.has_overdue_invoices.clone(), - trial_started_at: params.trial_started_at.clone(), - created_at: ActiveValue::not_set(), - }) - .exec(&*tx) - .await?; - - Ok(()) - }) - .await - } - - pub async fn get_billing_customer_by_id( - &self, - id: BillingCustomerId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_customer::Entity::find() - .filter(billing_customer::Column::Id.eq(id)) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns the billing customer for the user with the specified ID. - pub async fn get_billing_customer_by_user_id( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_customer::Entity::find() - .filter(billing_customer::Column::UserId.eq(user_id)) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns the billing customer for the user with the specified Stripe customer ID. - pub async fn get_billing_customer_by_stripe_customer_id( - &self, - stripe_customer_id: &str, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_customer::Entity::find() - .filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id)) - .one(&*tx) - .await?) - }) - .await - } -} diff --git a/crates/collab/src/db/queries/billing_preferences.rs b/crates/collab/src/db/queries/billing_preferences.rs deleted file mode 100644 index f370964ecd..0000000000 --- a/crates/collab/src/db/queries/billing_preferences.rs +++ /dev/null @@ -1,17 +0,0 @@ -use super::*; - -impl Database { - /// Returns the billing preferences for the given user, if they exist. - pub async fn get_billing_preferences( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_preference::Entity::find() - .filter(billing_preference::Column::UserId.eq(user_id)) - .one(&*tx) - .await?) - }) - .await - } -} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs deleted file mode 100644 index 8361d6b4d0..0000000000 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ /dev/null @@ -1,158 +0,0 @@ -use anyhow::Context as _; - -use crate::db::billing_subscription::{ - StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, -}; - -use super::*; - -#[derive(Debug)] -pub struct CreateBillingSubscriptionParams { - pub billing_customer_id: BillingCustomerId, - pub kind: Option, - pub stripe_subscription_id: String, - pub stripe_subscription_status: StripeSubscriptionStatus, - pub stripe_cancellation_reason: Option, - pub stripe_current_period_start: Option, - pub stripe_current_period_end: Option, -} - -#[derive(Debug, Default)] -pub struct UpdateBillingSubscriptionParams { - pub billing_customer_id: ActiveValue, - pub kind: ActiveValue>, - pub stripe_subscription_id: ActiveValue, - pub stripe_subscription_status: ActiveValue, - pub stripe_cancel_at: ActiveValue>, - pub stripe_cancellation_reason: ActiveValue>, - pub stripe_current_period_start: ActiveValue>, - pub stripe_current_period_end: ActiveValue>, -} - -impl Database { - /// Creates a new billing subscription. - pub async fn create_billing_subscription( - &self, - params: &CreateBillingSubscriptionParams, - ) -> Result { - self.transaction(|tx| async move { - let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel { - billing_customer_id: ActiveValue::set(params.billing_customer_id), - kind: ActiveValue::set(params.kind), - stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), - stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), - stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason), - stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start), - stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end), - ..Default::default() - }) - .exec(&*tx) - .await? - .last_insert_id; - - Ok(billing_subscription::Entity::find_by_id(id) - .one(&*tx) - .await? - .context("failed to retrieve inserted billing subscription")?) - }) - .await - } - - /// Updates the specified billing subscription. - pub async fn update_billing_subscription( - &self, - id: BillingSubscriptionId, - params: &UpdateBillingSubscriptionParams, - ) -> Result<()> { - self.transaction(|tx| async move { - billing_subscription::Entity::update(billing_subscription::ActiveModel { - id: ActiveValue::set(id), - billing_customer_id: params.billing_customer_id.clone(), - kind: params.kind.clone(), - stripe_subscription_id: params.stripe_subscription_id.clone(), - stripe_subscription_status: params.stripe_subscription_status.clone(), - stripe_cancel_at: params.stripe_cancel_at.clone(), - stripe_cancellation_reason: params.stripe_cancellation_reason.clone(), - stripe_current_period_start: params.stripe_current_period_start.clone(), - stripe_current_period_end: params.stripe_current_period_end.clone(), - created_at: ActiveValue::not_set(), - }) - .exec(&*tx) - .await?; - - Ok(()) - }) - .await - } - - /// Returns the billing subscription with the specified Stripe subscription ID. - pub async fn get_billing_subscription_by_stripe_subscription_id( - &self, - stripe_subscription_id: &str, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find() - .filter( - billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id), - ) - .one(&*tx) - .await?) - }) - .await - } - - pub async fn get_active_billing_subscription( - &self, - user_id: UserId, - ) -> Result> { - self.transaction(|tx| async move { - Ok(billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter(billing_customer::Column::UserId.eq(user_id)) - .filter( - Condition::all() - .add( - Condition::any() - .add( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active), - ) - .add( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Trialing), - ), - ) - .add(billing_subscription::Column::Kind.is_not_null()), - ) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns whether the user has an active billing subscription. - pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { - Ok(self.count_active_billing_subscriptions(user_id).await? > 0) - } - - /// Returns the count of the active billing subscriptions for the user with the specified ID. - pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result { - self.transaction(|tx| async move { - let count = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter( - billing_customer::Column::UserId.eq(user_id).and( - billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Active) - .or(billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Trialing)), - ), - ) - .count(&*tx) - .await?; - - Ok(count as usize) - }) - .await - } -} diff --git a/crates/collab/src/db/queries/processed_stripe_events.rs b/crates/collab/src/db/queries/processed_stripe_events.rs deleted file mode 100644 index f14ad480e0..0000000000 --- a/crates/collab/src/db/queries/processed_stripe_events.rs +++ /dev/null @@ -1,69 +0,0 @@ -use super::*; - -#[derive(Debug)] -pub struct CreateProcessedStripeEventParams { - pub stripe_event_id: String, - pub stripe_event_type: String, - pub stripe_event_created_timestamp: i64, -} - -impl Database { - /// Creates a new processed Stripe event. - pub async fn create_processed_stripe_event( - &self, - params: &CreateProcessedStripeEventParams, - ) -> Result<()> { - self.transaction(|tx| async move { - processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel { - stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()), - stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()), - stripe_event_created_timestamp: ActiveValue::set( - params.stripe_event_created_timestamp, - ), - ..Default::default() - }) - .exec_without_returning(&*tx) - .await?; - - Ok(()) - }) - .await - } - - /// Returns the processed Stripe event with the specified event ID. - pub async fn get_processed_stripe_event_by_event_id( - &self, - event_id: &str, - ) -> Result> { - self.transaction(|tx| async move { - Ok(processed_stripe_event::Entity::find_by_id(event_id) - .one(&*tx) - .await?) - }) - .await - } - - /// Returns the processed Stripe events with the specified event IDs. - pub async fn get_processed_stripe_events_by_event_ids( - &self, - event_ids: &[&str], - ) -> Result> { - self.transaction(|tx| async move { - Ok(processed_stripe_event::Entity::find() - .filter( - processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()), - ) - .all(&*tx) - .await?) - }) - .await - } - - /// Returns whether the Stripe event with the specified ID has already been processed. - pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result { - Ok(self - .get_processed_stripe_event_by_event_id(event_id) - .await? - .is_some()) - } -} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index d87ab174bd..0082a9fb03 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -1,7 +1,4 @@ pub mod access_token; -pub mod billing_customer; -pub mod billing_preference; -pub mod billing_subscription; pub mod buffer; pub mod buffer_operation; pub mod buffer_snapshot; @@ -23,7 +20,6 @@ pub mod notification; pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; -pub mod processed_stripe_event; pub mod project; pub mod project_collaborator; pub mod project_repository; diff --git a/crates/collab/src/db/tables/billing_customer.rs b/crates/collab/src/db/tables/billing_customer.rs deleted file mode 100644 index e7d4a216e3..0000000000 --- a/crates/collab/src/db/tables/billing_customer.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::db::{BillingCustomerId, UserId}; -use sea_orm::entity::prelude::*; - -/// A billing customer. -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_customers")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingCustomerId, - pub user_id: UserId, - pub stripe_customer_id: String, - pub has_overdue_invoices: bool, - pub trial_started_at: Option, - pub created_at: DateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::user::Entity", - from = "Column::UserId", - to = "super::user::Column::Id" - )] - User, - #[sea_orm(has_many = "super::billing_subscription::Entity")] - BillingSubscription, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::User.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingSubscription.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/billing_preference.rs b/crates/collab/src/db/tables/billing_preference.rs deleted file mode 100644 index c1888d3b2f..0000000000 --- a/crates/collab/src/db/tables/billing_preference.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::db::{BillingPreferencesId, UserId}; -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_preferences")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingPreferencesId, - pub created_at: DateTime, - pub user_id: UserId, - pub max_monthly_llm_usage_spending_in_cents: i32, - pub model_request_overages_enabled: bool, - pub model_request_overages_spend_limit_in_cents: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::user::Entity", - from = "Column::UserId", - to = "super::user::Column::Id" - )] - User, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::User.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs deleted file mode 100644 index 522973dbc9..0000000000 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ /dev/null @@ -1,176 +0,0 @@ -use crate::db::{BillingCustomerId, BillingSubscriptionId}; -use crate::stripe_client; -use chrono::{Datelike as _, NaiveDate, Utc}; -use sea_orm::entity::prelude::*; -use serde::Serialize; - -/// A billing subscription. -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_subscriptions")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingSubscriptionId, - pub billing_customer_id: BillingCustomerId, - pub kind: Option, - pub stripe_subscription_id: String, - pub stripe_subscription_status: StripeSubscriptionStatus, - pub stripe_cancel_at: Option, - pub stripe_cancellation_reason: Option, - pub stripe_current_period_start: Option, - pub stripe_current_period_end: Option, - pub created_at: DateTime, -} - -impl Model { - pub fn current_period_start_at(&self) -> Option { - let period_start = self.stripe_current_period_start?; - chrono::DateTime::from_timestamp(period_start, 0) - } - - pub fn current_period_end_at(&self) -> Option { - let period_end = self.stripe_current_period_end?; - chrono::DateTime::from_timestamp(period_end, 0) - } - - pub fn current_period( - subscription: Option, - is_staff: bool, - ) -> Option<(DateTimeUtc, DateTimeUtc)> { - if is_staff { - let now = Utc::now(); - let year = now.year(); - let month = now.month(); - - let first_day_of_this_month = - NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?; - - let next_month = if month == 12 { 1 } else { month + 1 }; - let next_month_year = if month == 12 { year + 1 } else { year }; - let first_day_of_next_month = - NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?; - - let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1); - - Some(( - first_day_of_this_month.and_utc(), - last_day_of_this_month.and_utc(), - )) - } else { - let subscription = subscription?; - let period_start_at = subscription.current_period_start_at()?; - let period_end_at = subscription.current_period_end_at()?; - - Some((period_start_at, period_end_at)) - } - } -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::billing_customer::Entity", - from = "Column::BillingCustomerId", - to = "super::billing_customer::Column::Id" - )] - BillingCustomer, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingCustomer.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} - -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum SubscriptionKind { - #[sea_orm(string_value = "zed_pro")] - ZedPro, - #[sea_orm(string_value = "zed_pro_trial")] - ZedProTrial, - #[sea_orm(string_value = "zed_free")] - ZedFree, -} - -impl From for cloud_llm_client::Plan { - fn from(value: SubscriptionKind) -> Self { - match value { - SubscriptionKind::ZedPro => Self::ZedPro, - SubscriptionKind::ZedProTrial => Self::ZedProTrial, - SubscriptionKind::ZedFree => Self::ZedFree, - } - } -} - -/// The status of a Stripe subscription. -/// -/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status) -#[derive( - Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize, -)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum StripeSubscriptionStatus { - #[default] - #[sea_orm(string_value = "incomplete")] - Incomplete, - #[sea_orm(string_value = "incomplete_expired")] - IncompleteExpired, - #[sea_orm(string_value = "trialing")] - Trialing, - #[sea_orm(string_value = "active")] - Active, - #[sea_orm(string_value = "past_due")] - PastDue, - #[sea_orm(string_value = "canceled")] - Canceled, - #[sea_orm(string_value = "unpaid")] - Unpaid, - #[sea_orm(string_value = "paused")] - Paused, -} - -impl StripeSubscriptionStatus { - pub fn is_cancelable(&self) -> bool { - match self { - Self::Trialing | Self::Active | Self::PastDue => true, - Self::Incomplete - | Self::IncompleteExpired - | Self::Canceled - | Self::Unpaid - | Self::Paused => false, - } - } -} - -/// The cancellation reason for a Stripe subscription. -/// -/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason) -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum StripeCancellationReason { - #[sea_orm(string_value = "cancellation_requested")] - CancellationRequested, - #[sea_orm(string_value = "payment_disputed")] - PaymentDisputed, - #[sea_orm(string_value = "payment_failed")] - PaymentFailed, -} - -impl From for StripeCancellationReason { - fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self { - match value { - stripe_client::StripeCancellationDetailsReason::CancellationRequested => { - Self::CancellationRequested - } - stripe_client::StripeCancellationDetailsReason::PaymentDisputed => { - Self::PaymentDisputed - } - stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} diff --git a/crates/collab/src/db/tables/processed_stripe_event.rs b/crates/collab/src/db/tables/processed_stripe_event.rs deleted file mode 100644 index 7b6f0cdc31..0000000000 --- a/crates/collab/src/db/tables/processed_stripe_event.rs +++ /dev/null @@ -1,16 +0,0 @@ -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "processed_stripe_events")] -pub struct Model { - #[sea_orm(primary_key)] - pub stripe_event_id: String, - pub stripe_event_type: String, - pub stripe_event_created_timestamp: i64, - pub processed_at: DateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index 49fe3eb58f..af43fe300a 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -29,8 +29,6 @@ pub struct Model { pub enum Relation { #[sea_orm(has_many = "super::access_token::Entity")] AccessToken, - #[sea_orm(has_one = "super::billing_customer::Entity")] - BillingCustomer, #[sea_orm(has_one = "super::room_participant::Entity")] RoomParticipant, #[sea_orm(has_many = "super::project::Entity")] @@ -68,12 +66,6 @@ impl Related for Entity { } } -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingCustomer.def() - } -} - impl Related for Entity { fn to() -> RelationDef { Relation::RoomParticipant.def() diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 6c2f9dc82a..2eb8d377ac 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -8,7 +8,6 @@ mod embedding_tests; mod extension_tests; mod feature_flag_tests; mod message_tests; -mod processed_stripe_event_tests; mod user_tests; use crate::migrations::run_database_migrations; diff --git a/crates/collab/src/db/tests/processed_stripe_event_tests.rs b/crates/collab/src/db/tests/processed_stripe_event_tests.rs deleted file mode 100644 index ad93b5a658..0000000000 --- a/crates/collab/src/db/tests/processed_stripe_event_tests.rs +++ /dev/null @@ -1,38 +0,0 @@ -use std::sync::Arc; - -use crate::test_both_dbs; - -use super::{CreateProcessedStripeEventParams, Database}; - -test_both_dbs!( - test_already_processed_stripe_event, - test_already_processed_stripe_event_postgres, - test_already_processed_stripe_event_sqlite -); - -async fn test_already_processed_stripe_event(db: &Arc) { - let unprocessed_event_id = "evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string(); - let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string(); - - db.create_processed_stripe_event(&CreateProcessedStripeEventParams { - stripe_event_id: processed_event_id.clone(), - stripe_event_type: "customer.created".into(), - stripe_event_created_timestamp: 1722355968, - }) - .await - .unwrap(); - - assert!( - db.already_processed_stripe_event(&processed_event_id) - .await - .unwrap(), - "Expected {processed_event_id} to already be processed" - ); - - assert!( - !db.already_processed_stripe_event(&unprocessed_event_id) - .await - .unwrap(), - "Expected {unprocessed_event_id} to be unprocessed" - ); -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 905859ca69..191025df37 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -7,8 +7,6 @@ pub mod llm; pub mod migrations; pub mod rpc; pub mod seed; -pub mod stripe_billing; -pub mod stripe_client; pub mod user_backfiller; #[cfg(test)] @@ -22,21 +20,16 @@ use axum::{ }; use db::{ChannelId, Database}; use executor::Executor; -use llm::db::LlmDatabase; use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{RealStripeClient, StripeClient}; - pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String, HeaderMap), Database(sea_orm::error::DbErr), Internal(anyhow::Error), - Stripe(stripe::StripeError), } impl From for Error { @@ -51,12 +44,6 @@ impl From for Error { } } -impl From for Error { - fn from(error: stripe::StripeError) -> Self { - Self::Stripe(error) - } -} - impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -104,14 +91,6 @@ impl IntoResponse for Error { ); (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } - Error::Stripe(error) => { - log::error!( - "HTTP error {}: {:?}", - StatusCode::INTERNAL_SERVER_ERROR, - &error - ); - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() - } } } } @@ -122,7 +101,6 @@ impl std::fmt::Debug for Error { Error::Http(code, message, _headers) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -133,7 +111,6 @@ impl std::fmt::Display for Error { Error::Http(code, message, _) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -179,7 +156,6 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, - pub stripe_api_key: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -234,7 +210,6 @@ impl Config { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -266,14 +241,8 @@ impl ServiceMode { pub struct AppState { pub db: Arc, - pub llm_db: Option>, pub livekit_client: Option>, pub blob_store_client: Option, - /// This is a real instance of the Stripe client; we're working to replace references to this with the - /// [`StripeClient`] trait. - pub real_stripe_client: Option>, - pub stripe_client: Option>, - pub stripe_billing: Option>, pub executor: Executor, pub kinesis_client: Option<::aws_sdk_kinesis::Client>, pub config: Config, @@ -286,20 +255,6 @@ impl AppState { let mut db = Database::new(db_options).await?; db.initialize_notification_kinds().await?; - let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config - .llm_database_url - .clone() - .zip(config.llm_database_max_connections) - { - let mut llm_db_options = db::ConnectOptions::new(llm_database_url); - llm_db_options.max_connections(llm_database_max_connections); - let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?; - llm_db.initialize().await?; - Some(Arc::new(llm_db)) - } else { - None - }; - let livekit_client = if let Some(((server, key), secret)) = config .livekit_server .as_ref() @@ -316,18 +271,10 @@ impl AppState { }; let db = Arc::new(db); - let stripe_client = build_stripe_client(&config).map(Arc::new).log_err(); let this = Self { db: db.clone(), - llm_db, livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_billing: stripe_client - .clone() - .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), - real_stripe_client: stripe_client.clone(), - stripe_client: stripe_client - .map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _), executor, kinesis_client: if config.kinesis_access_key.is_some() { build_kinesis_client(&config).await.log_err() @@ -340,14 +287,6 @@ impl AppState { } } -fn build_stripe_client(config: &Config) -> anyhow::Result { - let api_key = config - .stripe_api_key - .as_ref() - .context("missing stripe_api_key")?; - Ok(stripe::Client::new(api_key)) -} - async fn build_blob_store_client(config: &Config) -> anyhow::Result { let keys = aws_sdk_s3::config::Credentials::new( config diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index de74858168..dec10232bd 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,12 +1 @@ pub mod db; -mod token; - -pub use token::*; - -pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; - -/// The name of the feature flag that bypasses the account age check. -pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check"; - -/// The minimum account age an account must have in order to use the LLM service. -pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30); diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 18ad624dab..b15d5a42b5 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -1,30 +1,9 @@ -mod ids; -mod queries; -mod seed; -mod tables; - -#[cfg(test)] -mod tests; - -use cloud_llm_client::LanguageModelProvider; -use collections::HashMap; -pub use ids::*; -pub use seed::*; -pub use tables::*; - -#[cfg(test)] -pub use tests::TestLlmDb; -use usage_measure::UsageMeasure; - use std::future::Future; use std::sync::Arc; use anyhow::Context; pub use sea_orm::ConnectOptions; -use sea_orm::prelude::*; -use sea_orm::{ - ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait, -}; +use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait}; use crate::Result; use crate::db::TransactionHandle; @@ -36,9 +15,6 @@ pub struct LlmDatabase { pool: DatabaseConnection, #[allow(unused)] executor: Executor, - provider_ids: HashMap, - models: HashMap<(LanguageModelProvider, String), model::Model>, - usage_measure_ids: HashMap, #[cfg(test)] runtime: Option, } @@ -51,59 +27,11 @@ impl LlmDatabase { options: options.clone(), pool: sea_orm::Database::connect(options).await?, executor, - provider_ids: HashMap::default(), - models: HashMap::default(), - usage_measure_ids: HashMap::default(), #[cfg(test)] runtime: None, }) } - pub async fn initialize(&mut self) -> Result<()> { - self.initialize_providers().await?; - self.initialize_models().await?; - self.initialize_usage_measures().await?; - Ok(()) - } - - /// Returns the list of all known models, with their [`LanguageModelProvider`]. - pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> { - self.models - .iter() - .map(|((model_provider, _model_name), model)| (*model_provider, model.clone())) - .collect::>() - } - - /// Returns the names of the known models for the given [`LanguageModelProvider`]. - pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec { - self.models - .keys() - .filter_map(|(model_provider, model_name)| { - if model_provider == &provider { - Some(model_name) - } else { - None - } - }) - .cloned() - .collect::>() - } - - pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> { - Ok(self - .models - .get(&(provider, name.to_string())) - .with_context(|| format!("unknown model {provider:?}:{name}"))?) - } - - pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> { - Ok(self - .models - .values() - .find(|model| model.id == id) - .with_context(|| format!("no model for ID {id:?}"))?) - } - pub fn options(&self) -> &ConnectOptions { &self.options } diff --git a/crates/collab/src/llm/db/ids.rs b/crates/collab/src/llm/db/ids.rs deleted file mode 100644 index 03cab6cee0..0000000000 --- a/crates/collab/src/llm/db/ids.rs +++ /dev/null @@ -1,11 +0,0 @@ -use sea_orm::{DbErr, entity::prelude::*}; -use serde::{Deserialize, Serialize}; - -use crate::id_type; - -id_type!(BillingEventId); -id_type!(ModelId); -id_type!(ProviderId); -id_type!(RevokedAccessTokenId); -id_type!(UsageId); -id_type!(UsageMeasureId); diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs deleted file mode 100644 index 0087218b3f..0000000000 --- a/crates/collab/src/llm/db/queries.rs +++ /dev/null @@ -1,5 +0,0 @@ -use super::*; - -pub mod providers; -pub mod subscription_usages; -pub mod usages; diff --git a/crates/collab/src/llm/db/queries/providers.rs b/crates/collab/src/llm/db/queries/providers.rs deleted file mode 100644 index 9c7dbdd184..0000000000 --- a/crates/collab/src/llm/db/queries/providers.rs +++ /dev/null @@ -1,134 +0,0 @@ -use super::*; -use sea_orm::{QueryOrder, sea_query::OnConflict}; -use std::str::FromStr; -use strum::IntoEnumIterator as _; - -pub struct ModelParams { - pub provider: LanguageModelProvider, - pub name: String, - pub max_requests_per_minute: i64, - pub max_tokens_per_minute: i64, - pub max_tokens_per_day: i64, - pub price_per_million_input_tokens: i32, - pub price_per_million_output_tokens: i32, -} - -impl LlmDatabase { - pub async fn initialize_providers(&mut self) -> Result<()> { - self.provider_ids = self - .transaction(|tx| async move { - let existing_providers = provider::Entity::find().all(&*tx).await?; - - let mut new_providers = LanguageModelProvider::iter() - .filter(|provider| { - !existing_providers - .iter() - .any(|p| p.name == provider.to_string()) - }) - .map(|provider| provider::ActiveModel { - name: ActiveValue::set(provider.to_string()), - ..Default::default() - }) - .peekable(); - - if new_providers.peek().is_some() { - provider::Entity::insert_many(new_providers) - .exec(&*tx) - .await?; - } - - let all_providers: HashMap<_, _> = provider::Entity::find() - .all(&*tx) - .await? - .iter() - .filter_map(|provider| { - LanguageModelProvider::from_str(&provider.name) - .ok() - .map(|p| (p, provider.id)) - }) - .collect(); - - Ok(all_providers) - }) - .await?; - Ok(()) - } - - pub async fn initialize_models(&mut self) -> Result<()> { - let all_provider_ids = &self.provider_ids; - self.models = self - .transaction(|tx| async move { - let all_models: HashMap<_, _> = model::Entity::find() - .all(&*tx) - .await? - .into_iter() - .filter_map(|model| { - let provider = all_provider_ids.iter().find_map(|(provider, id)| { - if *id == model.provider_id { - Some(provider) - } else { - None - } - })?; - Some(((*provider, model.name.clone()), model)) - }) - .collect(); - Ok(all_models) - }) - .await?; - Ok(()) - } - - pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> { - let all_provider_ids = &self.provider_ids; - self.transaction(|tx| async move { - model::Entity::insert_many(models.iter().map(|model_params| { - let provider_id = all_provider_ids[&model_params.provider]; - model::ActiveModel { - provider_id: ActiveValue::set(provider_id), - name: ActiveValue::set(model_params.name.clone()), - max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute), - max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute), - max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day), - price_per_million_input_tokens: ActiveValue::set( - model_params.price_per_million_input_tokens, - ), - price_per_million_output_tokens: ActiveValue::set( - model_params.price_per_million_output_tokens, - ), - ..Default::default() - } - })) - .on_conflict( - OnConflict::columns([model::Column::ProviderId, model::Column::Name]) - .update_columns([ - model::Column::MaxRequestsPerMinute, - model::Column::MaxTokensPerMinute, - model::Column::MaxTokensPerDay, - model::Column::PricePerMillionInputTokens, - model::Column::PricePerMillionOutputTokens, - ]) - .to_owned(), - ) - .exec_without_returning(&*tx) - .await?; - Ok(()) - }) - .await?; - self.initialize_models().await - } - - /// Returns the list of LLM providers. - pub async fn list_providers(&self) -> Result> { - self.transaction(|tx| async move { - Ok(provider::Entity::find() - .order_by_asc(provider::Column::Name) - .all(&*tx) - .await? - .into_iter() - .filter_map(|p| LanguageModelProvider::from_str(&p.name).ok()) - .collect()) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs deleted file mode 100644 index 8a51979075..0000000000 --- a/crates/collab/src/llm/db/queries/subscription_usages.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::db::UserId; - -use super::*; - -impl LlmDatabase { - pub async fn get_subscription_usage_for_period( - &self, - user_id: UserId, - period_start_at: DateTimeUtc, - period_end_at: DateTimeUtc, - ) -> Result> { - self.transaction(|tx| async move { - self.get_subscription_usage_for_period_in_tx( - user_id, - period_start_at, - period_end_at, - &tx, - ) - .await - }) - .await - } - - async fn get_subscription_usage_for_period_in_tx( - &self, - user_id: UserId, - period_start_at: DateTimeUtc, - period_end_at: DateTimeUtc, - tx: &DatabaseTransaction, - ) -> Result> { - Ok(subscription_usage::Entity::find() - .filter(subscription_usage::Column::UserId.eq(user_id)) - .filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at)) - .filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at)) - .one(tx) - .await?) - } -} diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs deleted file mode 100644 index a917703f96..0000000000 --- a/crates/collab/src/llm/db/queries/usages.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::str::FromStr; -use strum::IntoEnumIterator as _; - -use super::*; - -impl LlmDatabase { - pub async fn initialize_usage_measures(&mut self) -> Result<()> { - let all_measures = self - .transaction(|tx| async move { - let existing_measures = usage_measure::Entity::find().all(&*tx).await?; - - let new_measures = UsageMeasure::iter() - .filter(|measure| { - !existing_measures - .iter() - .any(|m| m.name == measure.to_string()) - }) - .map(|measure| usage_measure::ActiveModel { - name: ActiveValue::set(measure.to_string()), - ..Default::default() - }) - .collect::>(); - - if !new_measures.is_empty() { - usage_measure::Entity::insert_many(new_measures) - .exec(&*tx) - .await?; - } - - Ok(usage_measure::Entity::find().all(&*tx).await?) - }) - .await?; - - self.usage_measure_ids = all_measures - .into_iter() - .filter_map(|measure| { - UsageMeasure::from_str(&measure.name) - .ok() - .map(|um| (um, measure.id)) - }) - .collect(); - Ok(()) - } -} diff --git a/crates/collab/src/llm/db/seed.rs b/crates/collab/src/llm/db/seed.rs deleted file mode 100644 index 55c6c30cd5..0000000000 --- a/crates/collab/src/llm/db/seed.rs +++ /dev/null @@ -1,45 +0,0 @@ -use super::*; -use crate::{Config, Result}; -use queries::providers::ModelParams; - -pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> { - db.insert_models(&[ - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-5-sonnet".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 20_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 300, // $3.00/MTok - price_per_million_output_tokens: 1500, // $15.00/MTok - }, - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-opus".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 10_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 1500, // $15.00/MTok - price_per_million_output_tokens: 7500, // $75.00/MTok - }, - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-sonnet".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 20_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 1500, // $15.00/MTok - price_per_million_output_tokens: 7500, // $75.00/MTok - }, - ModelParams { - provider: LanguageModelProvider::Anthropic, - name: "claude-3-haiku".into(), - max_requests_per_minute: 5, - max_tokens_per_minute: 25_000, - max_tokens_per_day: 300_000, - price_per_million_input_tokens: 25, // $0.25/MTok - price_per_million_output_tokens: 125, // $1.25/MTok - }, - ]) - .await -} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs deleted file mode 100644 index 75ea8f5140..0000000000 --- a/crates/collab/src/llm/db/tables.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod model; -pub mod provider; -pub mod subscription_usage; -pub mod subscription_usage_meter; -pub mod usage; -pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs deleted file mode 100644 index f0a858b4a6..0000000000 --- a/crates/collab/src/llm/db/tables/model.rs +++ /dev/null @@ -1,48 +0,0 @@ -use sea_orm::entity::prelude::*; - -use crate::llm::db::{ModelId, ProviderId}; - -/// An LLM model. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "models")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: ModelId, - pub provider_id: ProviderId, - pub name: String, - pub max_requests_per_minute: i64, - pub max_tokens_per_minute: i64, - pub max_input_tokens_per_minute: i64, - pub max_output_tokens_per_minute: i64, - pub max_tokens_per_day: i64, - pub price_per_million_input_tokens: i32, - pub price_per_million_cache_creation_input_tokens: i32, - pub price_per_million_cache_read_input_tokens: i32, - pub price_per_million_output_tokens: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::provider::Entity", - from = "Column::ProviderId", - to = "super::provider::Column::Id" - )] - Provider, - #[sea_orm(has_many = "super::usage::Entity")] - Usages, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Provider.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Usages.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/provider.rs b/crates/collab/src/llm/db/tables/provider.rs deleted file mode 100644 index 90838f7c65..0000000000 --- a/crates/collab/src/llm/db/tables/provider.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::llm::db::ProviderId; -use sea_orm::entity::prelude::*; - -/// An LLM provider. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "providers")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: ProviderId, - pub name: String, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm(has_many = "super::model::Entity")] - Models, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Models.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/subscription_usage.rs b/crates/collab/src/llm/db/tables/subscription_usage.rs deleted file mode 100644 index dd93b03d05..0000000000 --- a/crates/collab/src/llm/db/tables/subscription_usage.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::db::UserId; -use crate::db::billing_subscription::SubscriptionKind; -use sea_orm::entity::prelude::*; -use time::PrimitiveDateTime; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "subscription_usages_v2")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: Uuid, - pub user_id: UserId, - pub period_start_at: PrimitiveDateTime, - pub period_end_at: PrimitiveDateTime, - pub plan: SubscriptionKind, - pub model_requests: i32, - pub edit_predictions: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs b/crates/collab/src/llm/db/tables/subscription_usage_meter.rs deleted file mode 100644 index c082cf3bc1..0000000000 --- a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs +++ /dev/null @@ -1,55 +0,0 @@ -use sea_orm::entity::prelude::*; -use serde::Serialize; - -use crate::llm::db::ModelId; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "subscription_usage_meters_v2")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: Uuid, - pub subscription_usage_id: Uuid, - pub model_id: ModelId, - pub mode: CompletionMode, - pub requests: i32, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::subscription_usage::Entity", - from = "Column::SubscriptionUsageId", - to = "super::subscription_usage::Column::Id" - )] - SubscriptionUsage, - #[sea_orm( - belongs_to = "super::model::Entity", - from = "Column::ModelId", - to = "super::model::Column::Id" - )] - Model, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::SubscriptionUsage.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Model.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} - -#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)] -#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")] -#[serde(rename_all = "snake_case")] -pub enum CompletionMode { - #[sea_orm(string_value = "normal")] - Normal, - #[sea_orm(string_value = "max")] - Max, -} diff --git a/crates/collab/src/llm/db/tables/usage.rs b/crates/collab/src/llm/db/tables/usage.rs deleted file mode 100644 index 331c94a8a9..0000000000 --- a/crates/collab/src/llm/db/tables/usage.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::{ - db::UserId, - llm::db::{ModelId, UsageId, UsageMeasureId}, -}; -use sea_orm::entity::prelude::*; - -/// An LLM usage record. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "usages")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: UsageId, - /// The ID of the Zed user. - /// - /// Corresponds to the `users` table in the primary collab database. - pub user_id: UserId, - pub model_id: ModelId, - pub measure_id: UsageMeasureId, - pub timestamp: DateTime, - pub buckets: Vec, - pub is_staff: bool, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::model::Entity", - from = "Column::ModelId", - to = "super::model::Column::Id" - )] - Model, - #[sea_orm( - belongs_to = "super::usage_measure::Entity", - from = "Column::MeasureId", - to = "super::usage_measure::Column::Id" - )] - UsageMeasure, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Model.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::UsageMeasure.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/usage_measure.rs b/crates/collab/src/llm/db/tables/usage_measure.rs deleted file mode 100644 index 4f75577ed4..0000000000 --- a/crates/collab/src/llm/db/tables/usage_measure.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::llm::db::UsageMeasureId; -use sea_orm::entity::prelude::*; - -#[derive( - Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter, -)] -#[strum(serialize_all = "snake_case")] -pub enum UsageMeasure { - RequestsPerMinute, - TokensPerMinute, - InputTokensPerMinute, - OutputTokensPerMinute, - TokensPerDay, -} - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "usage_measures")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: UsageMeasureId, - pub name: String, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm(has_many = "super::usage::Entity")] - Usages, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Usages.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tests.rs b/crates/collab/src/llm/db/tests.rs deleted file mode 100644 index 43a1b8b0d4..0000000000 --- a/crates/collab/src/llm/db/tests.rs +++ /dev/null @@ -1,107 +0,0 @@ -mod provider_tests; - -use gpui::BackgroundExecutor; -use parking_lot::Mutex; -use rand::prelude::*; -use sea_orm::ConnectionTrait; -use sqlx::migrate::MigrateDatabase; -use std::time::Duration; - -use crate::migrations::run_database_migrations; - -use super::*; - -pub struct TestLlmDb { - pub db: Option, - pub connection: Option, -} - -impl TestLlmDb { - pub fn postgres(background: BackgroundExecutor) -> Self { - static LOCK: Mutex<()> = Mutex::new(()); - - let _guard = LOCK.lock(); - let mut rng = StdRng::from_entropy(); - let url = format!( - "postgres://postgres@localhost/zed-llm-test-{}", - rng.r#gen::() - ); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - sqlx::Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let mut options = ConnectOptions::new(url); - options - .max_connections(5) - .idle_timeout(Duration::from_secs(0)); - let db = LlmDatabase::new(options, Executor::Deterministic(background)) - .await - .unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm"); - run_database_migrations(db.options(), migrations_path) - .await - .unwrap(); - db - }); - - db.runtime = Some(runtime); - - Self { - db: Some(db), - connection: None, - } - } - - pub fn db(&mut self) -> &mut LlmDatabase { - self.db.as_mut().unwrap() - } -} - -#[macro_export] -macro_rules! test_llm_db { - ($test_name:ident, $postgres_test_name:ident) => { - #[gpui::test] - async fn $postgres_test_name(cx: &mut gpui::TestAppContext) { - if !cfg!(target_os = "macos") { - return; - } - - let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone()); - $test_name(test_db.db()).await; - } - }; -} - -impl Drop for TestLlmDb { - fn drop(&mut self) { - let db = self.db.take().unwrap(); - if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { - db.runtime.as_ref().unwrap().block_on(async { - use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE - pg_stat_activity.datname = current_database() AND - pid <> pg_backend_pid(); - "; - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - query, - )) - .await - .log_err(); - sqlx::Postgres::drop_database(db.options.get_url()) - .await - .log_err(); - }) - } - } -} diff --git a/crates/collab/src/llm/db/tests/provider_tests.rs b/crates/collab/src/llm/db/tests/provider_tests.rs deleted file mode 100644 index f4e1de40ec..0000000000 --- a/crates/collab/src/llm/db/tests/provider_tests.rs +++ /dev/null @@ -1,31 +0,0 @@ -use cloud_llm_client::LanguageModelProvider; -use pretty_assertions::assert_eq; - -use crate::llm::db::LlmDatabase; -use crate::test_llm_db; - -test_llm_db!( - test_initialize_providers, - test_initialize_providers_postgres -); - -async fn test_initialize_providers(db: &mut LlmDatabase) { - let initial_providers = db.list_providers().await.unwrap(); - assert_eq!(initial_providers, vec![]); - - db.initialize_providers().await.unwrap(); - - // Do it twice, to make sure the operation is idempotent. - db.initialize_providers().await.unwrap(); - - let providers = db.list_providers().await.unwrap(); - - assert_eq!( - providers, - &[ - LanguageModelProvider::Anthropic, - LanguageModelProvider::Google, - LanguageModelProvider::OpenAi, - ] - ) -} diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs deleted file mode 100644 index da01c7f3be..0000000000 --- a/crates/collab/src/llm/token.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::db::billing_subscription::SubscriptionKind; -use crate::db::{billing_customer, billing_subscription, user}; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG}; -use crate::{Config, db::billing_preference}; -use anyhow::{Context as _, Result}; -use chrono::{NaiveDateTime, Utc}; -use cloud_llm_client::Plan; -use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; -use thiserror::Error; -use uuid::Uuid; - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LlmTokenClaims { - pub iat: u64, - pub exp: u64, - pub jti: String, - pub user_id: u64, - pub system_id: Option, - pub metrics_id: Uuid, - pub github_user_login: String, - pub account_created_at: NaiveDateTime, - pub is_staff: bool, - pub has_llm_closed_beta_feature_flag: bool, - pub bypass_account_age_check: bool, - pub use_llm_request_queue: bool, - pub plan: Plan, - pub has_extended_trial: bool, - pub subscription_period: (NaiveDateTime, NaiveDateTime), - pub enable_model_request_overages: bool, - pub model_request_overages_spend_limit_in_cents: u32, - pub can_use_web_search_tool: bool, - #[serde(default)] - pub has_overdue_invoices: bool, -} - -const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); - -impl LlmTokenClaims { - pub fn create( - user: &user::Model, - is_staff: bool, - billing_customer: billing_customer::Model, - billing_preferences: Option, - feature_flags: &Vec, - subscription: billing_subscription::Model, - system_id: Option, - config: &Config, - ) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - let plan = if is_staff { - Plan::ZedPro - } else { - subscription.kind.map_or(Plan::ZedFree, |kind| match kind { - SubscriptionKind::ZedFree => Plan::ZedFree, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) - }; - let subscription_period = - billing_subscription::Model::current_period(Some(subscription), is_staff) - .map(|(start, end)| (start.naive_utc(), end.naive_utc())) - .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?; - - let now = Utc::now(); - let claims = Self { - iat: now.timestamp() as u64, - exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, - jti: uuid::Uuid::new_v4().to_string(), - user_id: user.id.to_proto(), - system_id, - metrics_id: user.metrics_id, - github_user_login: user.github_login.clone(), - account_created_at: user.account_created_at(), - is_staff, - has_llm_closed_beta_feature_flag: feature_flags - .iter() - .any(|flag| flag == "llm-closed-beta"), - bypass_account_age_check: feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG), - can_use_web_search_tool: true, - use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"), - plan, - has_extended_trial: feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period, - enable_model_request_overages: billing_preferences - .as_ref() - .map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: billing_preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents as u32 - }), - has_overdue_invoices: billing_customer.has_overdue_invoices, - }; - - Ok(jsonwebtoken::encode( - &Header::default(), - &claims, - &EncodingKey::from_secret(secret.as_ref()), - )?) - } - - pub fn validate(token: &str, config: &Config) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - match jsonwebtoken::decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::default(), - ) { - Ok(token) => Ok(token.claims), - Err(e) => { - if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature { - Err(ValidateLlmTokenError::Expired) - } else { - Err(ValidateLlmTokenError::JwtError(e)) - } - } - } - } -} - -#[derive(Error, Debug)] -pub enum ValidateLlmTokenError { - #[error("access token is expired")] - Expired, - #[error("access token validation error: {0}")] - JwtError(#[from] jsonwebtoken::errors::Error), - #[error("{0}")] - Other(#[from] anyhow::Error), -} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 20641cb232..cb6f6cad1d 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -62,13 +62,6 @@ async fn main() -> Result<()> { db.initialize_notification_kinds().await?; collab::seed::seed(&config, &db, false).await?; - - if let Some(llm_database_url) = config.llm_database_url.clone() { - let db_options = db::ConnectOptions::new(llm_database_url); - let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?; - db.initialize().await?; - collab::llm::db::seed_database(&config, &mut db, true).await?; - } } Some("serve") => { let mode = match args.next().as_deref() { @@ -102,13 +95,6 @@ async fn main() -> Result<()> { let state = AppState::new(config, Executor::Production).await?; - if let Some(stripe_billing) = state.stripe_billing.clone() { - let executor = state.executor.clone(); - executor.spawn_detached(async move { - stripe_billing.initialize().await.trace_err(); - }); - } - if mode.is_collab() { state.db.purge_old_embeddings().await.trace_err(); @@ -270,9 +256,6 @@ async fn setup_llm_database(config: &Config) -> Result<()> { .llm_database_migrations_path .as_deref() .unwrap_or_else(|| { - #[cfg(feature = "sqlite")] - let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite"); - #[cfg(not(feature = "sqlite"))] let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm"); Path::new(default_migrations) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 18eb1457dc..ef749ac9b7 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,14 +1,6 @@ mod connection_pool; -use crate::api::billing::find_or_create_billing_customer; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; -use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::db::LlmDatabase; -use crate::llm::{ - AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims, - MIN_ACCOUNT_AGE_FOR_LLM_USE, -}; -use crate::stripe_client::StripeCustomerId; use crate::{ AppState, Error, Result, auth, db::{ @@ -37,7 +29,6 @@ use axum::{ response::IntoResponse, routing::get, }; -use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; @@ -148,13 +139,6 @@ pub enum Principal { } impl Principal { - fn user(&self) -> &User { - match self { - Principal::User(user) => user, - Principal::Impersonated { user, .. } => user, - } - } - fn update_span(&self, span: &tracing::Span) { match &self { Principal::User(user) => { @@ -218,6 +202,7 @@ struct Session { /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, + #[allow(unused)] system_id: Option, _executor: Executor, } @@ -463,9 +448,6 @@ impl Server { .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) - .add_request_handler(get_private_user_info) - .add_request_handler(get_llm_api_token) - .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) .add_request_handler(get_supermaven_api_key) @@ -1000,8 +982,6 @@ impl Server { .await?; } - update_user_plan(session).await?; - let contacts = self.app_state.db.get_contacts(user.id).await?; { @@ -1081,53 +1061,6 @@ impl Server { Ok(()) } - pub async fn update_plan_for_user( - self: &Arc, - user_id: UserId, - update_user_plan: proto::UpdateUserPlan, - ) -> Result<()> { - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, update_user_plan.clone()) - .trace_err(); - } - - Ok(()) - } - - /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan` - /// message on the Collab server. - /// - /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint. - pub async fn update_plan_for_user_legacy(self: &Arc, user_id: UserId) -> Result<()> { - let user = self - .app_state - .db - .get_user_by_id(user_id) - .await? - .context("user not found")?; - - let update_user_plan = make_update_user_plan_message( - &user, - user.admin, - &self.app_state.db, - self.app_state.llm_db.clone(), - ) - .await?; - - self.update_plan_for_user(user_id, update_user_plan).await - } - - pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) { - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, proto::RefreshLlmToken {}) - .trace_err(); - } - } - pub async fn snapshot(self: &Arc) -> ServerSnapshot<'_> { ServerSnapshot { connection_pool: ConnectionPoolGuard { @@ -2882,214 +2815,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn current_plan(db: &Arc, user_id: UserId, is_staff: bool) -> Result { - if is_staff { - return Ok(proto::Plan::ZedPro); - } - - let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_kind = subscription.and_then(|subscription| subscription.kind); - - let plan = if let Some(subscription_kind) = subscription_kind { - match subscription_kind { - SubscriptionKind::ZedPro => proto::Plan::ZedPro, - SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial, - SubscriptionKind::ZedFree => proto::Plan::Free, - } - } else { - proto::Plan::Free - }; - - Ok(plan) -} - -async fn make_update_user_plan_message( - user: &User, - is_staff: bool, - db: &Arc, - llm_db: Option>, -) -> Result { - let feature_flags = db.get_user_flags(user.id).await?; - let plan = current_plan(db, user.id, is_staff).await?; - let billing_customer = db.get_billing_customer_by_user_id(user.id).await?; - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let (subscription_period, usage) = if let Some(llm_db) = llm_db { - let subscription = db.get_active_billing_subscription(user.id).await?; - - let subscription_period = - crate::db::billing_subscription::Model::current_period(subscription, is_staff); - - let usage = if let Some((period_start_at, period_end_at)) = subscription_period { - llm_db - .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) - .await? - } else { - None - }; - - (subscription_period, usage) - } else { - (None, None) - }; - - let bypass_account_age_check = feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG); - let account_too_young = !matches!(plan, proto::Plan::ZedPro) - && !bypass_account_age_check - && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE; - - Ok(proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: billing_customer - .as_ref() - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), - is_usage_based_billing_enabled: if is_staff { - Some(true) - } else { - billing_preferences.map(|preferences| preferences.model_request_overages_enabled) - }, - subscription_period: subscription_period.map(|(started_at, ended_at)| { - proto::SubscriptionPeriod { - started_at: started_at.timestamp() as u64, - ended_at: ended_at.timestamp() as u64, - } - }), - account_too_young: Some(account_too_young), - has_overdue_invoices: billing_customer - .map(|billing_customer| billing_customer.has_overdue_invoices), - usage: Some( - usage - .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags)) - .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)), - ), - }) -} - -fn model_requests_limit( - plan: cloud_llm_client::Plan, - feature_flags: &Vec, -) -> cloud_llm_client::UsageLimit { - match plan.model_requests_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == cloud_llm_client::Plan::ZedProTrial - && feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) - { - 1_000 - } else { - limit - }; - - cloud_llm_client::UsageLimit::Limited(limit) - } - cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, - } -} - -fn subscription_usage_to_proto( - plan: proto::Plan, - usage: crate::llm::db::subscription_usage::Model, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: usage.model_requests as u32, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: usage.edit_predictions as u32, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -fn make_default_subscription_usage( - plan: proto::Plan, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: 0, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: 0, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -async fn update_user_plan(session: &Session) -> Result<()> { - let db = session.db().await; - - let update_user_plan = make_update_user_plan_message( - session.principal.user(), - session.is_staff(), - &db.0, - session.app_state.llm_db.clone(), - ) - .await?; - - session - .peer - .send(session.connection_id, update_user_plan) - .trace_err(); - - Ok(()) -} - async fn subscribe_to_channels( _: proto::SubscribeToChannels, session: MessageContext, @@ -4258,139 +3983,6 @@ async fn mark_notification_as_read( Ok(()) } -/// Get the current users information -async fn get_private_user_info( - _request: proto::GetPrivateUserInfo, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let metrics_id = db.get_user_metrics_id(session.user_id()).await?; - let user = db - .get_user_by_id(session.user_id()) - .await? - .context("user not found")?; - let flags = db.get_user_flags(session.user_id()).await?; - - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - flags, - accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64), - })?; - Ok(()) -} - -/// Accept the terms of service (tos) on behalf of the current user -async fn accept_terms_of_service( - _request: proto::AcceptTermsOfService, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let accepted_tos_at = Utc::now(); - db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc())) - .await?; - - response.send(proto::AcceptTermsOfServiceResponse { - accepted_tos_at: accepted_tos_at.timestamp() as u64, - })?; - - // When the user accepts the terms of service, we want to refresh their LLM - // token to grant access. - session - .peer - .send(session.connection_id, proto::RefreshLlmToken {})?; - - Ok(()) -} - -async fn get_llm_api_token( - _request: proto::GetLlmToken, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let flags = db.get_user_flags(session.user_id()).await?; - - let user_id = session.user_id(); - let user = db - .get_user_by_id(user_id) - .await? - .with_context(|| format!("user {user_id} not found"))?; - - if user.accepted_tos_at.is_none() { - Err(anyhow!("terms of service not accepted"))? - } - - let stripe_client = session - .app_state - .stripe_client - .as_ref() - .context("failed to retrieve Stripe client")?; - - let stripe_billing = session - .app_state - .stripe_billing - .as_ref() - .context("failed to retrieve Stripe billing object")?; - - let billing_customer = if let Some(billing_customer) = - db.get_billing_customer_by_user_id(user.id).await? - { - billing_customer - } else { - let customer_id = stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await?; - - find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id) - .await? - .context("billing customer not found")? - }; - - let billing_subscription = - if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { - billing_subscription - } else { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let stripe_subscription = stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - - db.create_billing_subscription(&db::CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: Some(SubscriptionKind::ZedFree), - stripe_subscription_id: stripe_subscription.id.to_string(), - stripe_subscription_status: stripe_subscription.status.into(), - stripe_cancellation_reason: None, - stripe_current_period_start: Some(stripe_subscription.current_period_start), - stripe_current_period_end: Some(stripe_subscription.current_period_end), - }) - .await? - }; - - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let token = LlmTokenClaims::create( - &user, - session.is_staff(), - billing_customer, - billing_preferences, - &flags, - billing_subscription, - session.system_id.clone(), - &session.app_state.config, - )?; - response.send(proto::GetLlmTokenResponse { token })?; - Ok(()) -} - fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result { let message = match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()), diff --git a/crates/collab/src/rpc/connection_pool.rs b/crates/collab/src/rpc/connection_pool.rs index 35290fa697..729e7c8533 100644 --- a/crates/collab/src/rpc/connection_pool.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -30,7 +30,19 @@ impl fmt::Display for ZedVersion { impl ZedVersion { pub fn can_collaborate(&self) -> bool { - self.0 >= SemanticVersion::new(0, 157, 0) + // v0.198.4 is the first version where we no longer connect to Collab automatically. + // We reject any clients older than that to prevent them from connecting to Collab just for authentication. + if self.0 < SemanticVersion::new(0, 198, 4) { + return false; + } + + // Since we hotfixed the changes to no longer connect to Collab automatically to Preview, we also need to reject + // versions in the range [v0.199.0, v0.199.1]. + if self.0 >= SemanticVersion::new(0, 199, 0) && self.0 < SemanticVersion::new(0, 199, 2) { + return false; + } + + true } } diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs deleted file mode 100644 index ef5bef3e7e..0000000000 --- a/crates/collab/src/stripe_billing.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::sync::Arc; - -use anyhow::anyhow; -use collections::HashMap; -use stripe::SubscriptionStatus; -use tokio::sync::RwLock; - -use crate::Result; -use crate::stripe_client::{ - RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, - StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, - StripeSubscription, -}; - -pub struct StripeBilling { - state: RwLock, - client: Arc, -} - -#[derive(Default)] -struct StripeBillingState { - prices_by_lookup_key: HashMap, -} - -impl StripeBilling { - pub fn new(client: Arc) -> Self { - Self { - client: Arc::new(RealStripeClient::new(client.clone())), - state: RwLock::default(), - } - } - - #[cfg(test)] - pub fn test(client: Arc) -> Self { - Self { - client, - state: RwLock::default(), - } - } - - pub fn client(&self) -> &Arc { - &self.client - } - - pub async fn initialize(&self) -> Result<()> { - log::info!("StripeBilling: initializing"); - - let mut state = self.state.write().await; - - let prices = self.client.list_prices().await?; - - for price in prices { - if let Some(lookup_key) = price.lookup_key.clone() { - state.prices_by_lookup_key.insert(lookup_key, price); - } - } - - log::info!("StripeBilling: initialized"); - - Ok(()) - } - - pub async fn zed_pro_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-pro").await - } - - pub async fn zed_free_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-free").await - } - - pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .map(|price| price.id.clone()) - .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) - } - - pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .cloned() - .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) - } - - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does - /// not already exist. - /// - /// Always returns a new Stripe customer if the email address is `None`. - pub async fn find_or_create_customer_by_email( - &self, - email_address: Option<&str>, - ) -> Result { - let existing_customer = if let Some(email) = email_address { - let customers = self.client.list_customers_by_email(email).await?; - - customers.first().cloned() - } else { - None - }; - - let customer_id = if let Some(existing_customer) = existing_customer { - existing_customer.id - } else { - let customer = self - .client - .create_customer(crate::stripe_client::CreateCustomerParams { - email: email_address, - }) - .await?; - - customer.id - }; - - Ok(customer_id) - } - - pub async fn subscribe_to_zed_free( - &self, - customer_id: StripeCustomerId, - ) -> Result { - let zed_free_price_id = self.zed_free_price_id().await?; - - let existing_subscriptions = self - .client - .list_subscriptions_for_customer(&customer_id) - .await?; - - let existing_active_subscription = - existing_subscriptions.into_iter().find(|subscription| { - subscription.status == SubscriptionStatus::Active - || subscription.status == SubscriptionStatus::Trialing - }); - if let Some(subscription) = existing_active_subscription { - return Ok(subscription); - } - - let params = StripeCreateSubscriptionParams { - customer: customer_id, - items: vec![StripeCreateSubscriptionItems { - price: Some(zed_free_price_id), - quantity: Some(1), - }], - automatic_tax: Some(StripeAutomaticTax { enabled: true }), - }; - - let subscription = self.client.create_subscription(params).await?; - - Ok(subscription) - } -} diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs deleted file mode 100644 index 6e75a4d874..0000000000 --- a/crates/collab/src/stripe_client.rs +++ /dev/null @@ -1,285 +0,0 @@ -#[cfg(test)] -mod fake_stripe_client; -mod real_stripe_client; - -use std::collections::HashMap; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; - -#[cfg(test)] -pub use fake_stripe_client::*; -pub use real_stripe_client::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)] -pub struct StripeCustomerId(pub Arc); - -#[derive(Debug, Clone)] -pub struct StripeCustomer { - pub id: StripeCustomerId, - pub email: Option, -} - -#[derive(Debug)] -pub struct CreateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug)] -pub struct UpdateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscription { - pub id: StripeSubscriptionId, - pub customer: StripeCustomerId, - // TODO: Create our own version of this enum. - pub status: stripe::SubscriptionStatus, - pub current_period_end: i64, - pub current_period_start: i64, - pub items: Vec, - pub cancel_at: Option, - pub cancellation_details: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionItemId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionItem { - pub id: StripeSubscriptionItemId, - pub price: Option, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StripeCancellationDetails { - pub reason: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCancellationDetailsReason { - CancellationRequested, - PaymentDisputed, - PaymentFailed, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionParams { - pub customer: StripeCustomerId, - pub items: Vec, - pub automatic_tax: Option, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, Clone)] -pub struct UpdateSubscriptionParams { - pub items: Option>, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct UpdateSubscriptionItems { - pub price: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettings { - pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettingsEndBehavior { - pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { - Cancel, - CreateInvoice, - Pause, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripePriceId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePrice { - pub id: StripePriceId, - pub unit_amount: Option, - pub lookup_key: Option, - pub recurring: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePriceRecurring { - pub meter: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)] -pub struct StripeMeterId(pub Arc); - -#[derive(Debug, Clone, Deserialize)] -pub struct StripeMeter { - pub id: StripeMeterId, - pub event_name: String, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventParams<'a> { - pub identifier: &'a str, - pub event_name: &'a str, - pub payload: StripeCreateMeterEventPayload<'a>, - pub timestamp: Option, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventPayload<'a> { - pub value: u64, - pub stripe_customer_id: &'a StripeCustomerId, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeBillingAddressCollection { - Auto, - Required, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCustomerUpdate { - pub address: Option, - pub name: Option, - pub shipping: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateAddress { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateName { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateShipping { - Auto, - Never, -} - -#[derive(Debug, Default)] -pub struct StripeCreateCheckoutSessionParams<'a> { - pub customer: Option<&'a StripeCustomerId>, - pub client_reference_id: Option<&'a str>, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option<&'a str>, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionMode { - Payment, - Setup, - Subscription, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionLineItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionPaymentMethodCollection { - Always, - IfRequired, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionSubscriptionData { - pub metadata: Option>, - pub trial_period_days: Option, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeTaxIdCollection { - pub enabled: bool, -} - -#[derive(Debug, Clone)] -pub struct StripeAutomaticTax { - pub enabled: bool, -} - -#[derive(Debug)] -pub struct StripeCheckoutSession { - pub url: Option, -} - -#[async_trait] -pub trait StripeClient: Send + Sync { - async fn list_customers_by_email(&self, email: &str) -> Result>; - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result; - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result; - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result>; - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result; - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result; - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()>; - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>; - - async fn list_prices(&self) -> Result>; - - async fn list_meters(&self) -> Result>; - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>; - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result; -} diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs deleted file mode 100644 index 9bb08443ec..0000000000 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Result, anyhow}; -use async_trait::async_trait; -use chrono::{Duration, Utc}; -use collections::HashMap; -use parking_lot::Mutex; -use uuid::Uuid; - -use crate::stripe_client::{ - CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -#[derive(Debug, Clone)] -pub struct StripeCreateMeterEventCall { - pub identifier: Arc, - pub event_name: Arc, - pub value: u64, - pub stripe_customer_id: StripeCustomerId, - pub timestamp: Option, -} - -#[derive(Debug, Clone)] -pub struct StripeCreateCheckoutSessionCall { - pub customer: Option, - pub client_reference_id: Option, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -pub struct FakeStripeClient { - pub customers: Arc>>, - pub subscriptions: Arc>>, - pub update_subscription_calls: - Arc>>, - pub prices: Arc>>, - pub meters: Arc>>, - pub create_meter_event_calls: Arc>>, - pub create_checkout_session_calls: Arc>>, -} - -impl FakeStripeClient { - pub fn new() -> Self { - Self { - customers: Arc::new(Mutex::new(HashMap::default())), - subscriptions: Arc::new(Mutex::new(HashMap::default())), - update_subscription_calls: Arc::new(Mutex::new(Vec::new())), - prices: Arc::new(Mutex::new(HashMap::default())), - meters: Arc::new(Mutex::new(HashMap::default())), - create_meter_event_calls: Arc::new(Mutex::new(Vec::new())), - create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())), - } - } -} - -#[async_trait] -impl StripeClient for FakeStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - Ok(self - .customers - .lock() - .values() - .filter(|customer| customer.email.as_deref() == Some(email)) - .cloned() - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - self.customers - .lock() - .get(customer_id) - .cloned() - .ok_or_else(|| anyhow!("no customer found for {customer_id:?}")) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = StripeCustomer { - id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()), - email: params.email.map(|email| email.to_string()), - }; - - self.customers - .lock() - .insert(customer.id.clone(), customer.clone()); - - Ok(customer) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let mut customers = self.customers.lock(); - if let Some(customer) = customers.get_mut(customer_id) { - if let Some(email) = params.email { - customer.email = Some(email.to_string()); - } - Ok(customer.clone()) - } else { - Err(anyhow!("no customer found for {customer_id:?}")) - } - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let subscriptions = self - .subscriptions - .lock() - .values() - .filter(|subscription| subscription.customer == *customer_id) - .cloned() - .collect(); - - Ok(subscriptions) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - self.subscriptions - .lock() - .get(subscription_id) - .cloned() - .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}")) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let now = Utc::now(); - - let subscription = StripeSubscription { - id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()), - customer: params.customer, - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: params - .items - .into_iter() - .map(|item| StripeSubscriptionItem { - id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()), - price: item - .price - .and_then(|price_id| self.prices.lock().get(&price_id).cloned()), - }) - .collect(), - cancel_at: None, - cancellation_details: None, - }; - - self.subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - Ok(subscription) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription = self.get_subscription(subscription_id).await?; - - self.update_subscription_calls - .lock() - .push((subscription.id, params)); - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - // TODO: Implement fake subscription cancellation. - let _ = subscription_id; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let prices = self.prices.lock().values().cloned().collect(); - - Ok(prices) - } - - async fn list_meters(&self) -> Result> { - let meters = self.meters.lock().values().cloned().collect(); - - Ok(meters) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - self.create_meter_event_calls - .lock() - .push(StripeCreateMeterEventCall { - identifier: params.identifier.into(), - event_name: params.event_name.into(), - value: params.payload.value, - stripe_customer_id: params.payload.stripe_customer_id.clone(), - timestamp: params.timestamp, - }); - - Ok(()) - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - self.create_checkout_session_calls - .lock() - .push(StripeCreateCheckoutSessionCall { - customer: params.customer.cloned(), - client_reference_id: params.client_reference_id.map(|id| id.to_string()), - mode: params.mode, - line_items: params.line_items, - payment_method_collection: params.payment_method_collection, - subscription_data: params.subscription_data, - success_url: params.success_url.map(|url| url.to_string()), - billing_address_collection: params.billing_address_collection, - customer_update: params.customer_update, - tax_id_collection: params.tax_id_collection, - }); - - Ok(StripeCheckoutSession { - url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()), - }) - } -} diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs deleted file mode 100644 index 07c191ff30..0000000000 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::str::FromStr as _; -use std::sync::Arc; - -use anyhow::{Context as _, Result, anyhow}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use stripe::{ - CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode, - CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, - CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price, - PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, - UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, - UpdateSubscriptionTrialSettingsEndBehavior, - UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -}; - -use crate::stripe_client::{ - CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection, - StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping, - StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -pub struct RealStripeClient { - client: Arc, -} - -impl RealStripeClient { - pub fn new(client: Arc) -> Self { - Self { client } - } -} - -#[async_trait] -impl StripeClient for RealStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - let response = Customer::list( - &self.client, - &ListCustomers { - email: Some(email), - ..Default::default() - }, - ) - .await?; - - Ok(response - .data - .into_iter() - .map(StripeCustomer::from) - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - let customer_id = customer_id.try_into()?; - - let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = Customer::create( - &self.client, - CreateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let customer = Customer::update( - &self.client, - &customer_id.try_into()?, - UpdateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let customer_id = customer_id.try_into()?; - - let subscriptions = stripe::Subscription::list( - &self.client, - &stripe::ListSubscriptions { - customer: Some(customer_id), - status: None, - ..Default::default() - }, - ) - .await?; - - Ok(subscriptions - .data - .into_iter() - .map(StripeSubscription::from) - .collect()) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - let subscription_id = subscription_id.try_into()?; - - let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let customer_id = params.customer.try_into()?; - - let mut create_subscription = stripe::CreateSubscription::new(customer_id); - create_subscription.items = Some( - params - .items - .into_iter() - .map(|item| stripe::CreateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - quantity: item.quantity, - ..Default::default() - }) - .collect(), - ); - create_subscription.automatic_tax = params.automatic_tax.map(Into::into); - - let subscription = Subscription::create(&self.client, create_subscription).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - stripe::Subscription::update( - &self.client, - &subscription_id, - stripe::UpdateSubscription { - items: params.items.map(|items| { - items - .into_iter() - .map(|item| UpdateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - ..Default::default() - }) - .collect() - }), - trial_settings: params.trial_settings.map(Into::into), - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - Subscription::cancel( - &self.client, - &subscription_id, - stripe::CancelSubscription { - invoice_now: None, - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let response = stripe::Price::list( - &self.client, - &stripe::ListPrices { - limit: Some(100), - ..Default::default() - }, - ) - .await?; - - Ok(response.data.into_iter().map(StripePrice::from).collect()) - } - - async fn list_meters(&self) -> Result> { - #[derive(Serialize)] - struct Params { - #[serde(skip_serializing_if = "Option::is_none")] - limit: Option, - } - - let response = self - .client - .get_query::, _>( - "/billing/meters", - Params { limit: Some(100) }, - ) - .await?; - - Ok(response.data) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - #[derive(Deserialize)] - struct StripeMeterEvent { - pub identifier: String, - } - - let identifier = params.identifier; - match self - .client - .post_form::("/billing/meter_events", params) - .await - { - Ok(_event) => Ok(()), - Err(stripe::StripeError::Stripe(error)) => { - if error.http_status == 400 - && error - .message - .as_ref() - .map_or(false, |message| message.contains(identifier)) - { - Ok(()) - } else { - Err(anyhow!(stripe::StripeError::Stripe(error))) - } - } - Err(error) => Err(anyhow!("failed to create meter event: {error:?}")), - } - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - let params = params.try_into()?; - let session = CheckoutSession::create(&self.client, params).await?; - - Ok(session.into()) - } -} - -impl From for StripeCustomerId { - fn from(value: CustomerId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl TryFrom<&StripeCustomerId> for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: &StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl From for StripeCustomer { - fn from(value: Customer) -> Self { - StripeCustomer { - id: value.id.into(), - email: value.email, - } - } -} - -impl From for StripeSubscriptionId { - fn from(value: SubscriptionId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom<&StripeSubscriptionId> for SubscriptionId { - type Error = anyhow::Error; - - fn try_from(value: &StripeSubscriptionId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID") - } -} - -impl From for StripeSubscription { - fn from(value: Subscription) -> Self { - Self { - id: value.id.into(), - customer: value.customer.id().into(), - status: value.status, - current_period_start: value.current_period_start, - current_period_end: value.current_period_end, - items: value.items.data.into_iter().map(Into::into).collect(), - cancel_at: value.cancel_at, - cancellation_details: value.cancellation_details.map(Into::into), - } - } -} - -impl From for StripeCancellationDetails { - fn from(value: CancellationDetails) -> Self { - Self { - reason: value.reason.map(Into::into), - } - } -} - -impl From for StripeCancellationDetailsReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - -impl From for StripeSubscriptionItemId { - fn from(value: SubscriptionItemId) -> Self { - Self(value.as_str().into()) - } -} - -impl From for StripeSubscriptionItem { - fn from(value: SubscriptionItem) -> Self { - Self { - id: value.id.into(), - price: value.price.map(Into::into), - } - } -} - -impl From for CreateSubscriptionAutomaticTax { - fn from(value: StripeAutomaticTax) -> Self { - Self { - enabled: value.enabled, - liability: None, - } - } -} - -impl From for UpdateSubscriptionTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripePriceId { - fn from(value: PriceId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for PriceId { - type Error = anyhow::Error; - - fn try_from(value: StripePriceId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID") - } -} - -impl From for StripePrice { - fn from(value: Price) -> Self { - Self { - id: value.id.into(), - unit_amount: value.unit_amount, - lookup_key: value.lookup_key, - recurring: value.recurring.map(StripePriceRecurring::from), - } - } -} - -impl From for StripePriceRecurring { - fn from(value: Recurring) -> Self { - Self { meter: value.meter } - } -} - -impl<'a> TryFrom> for CreateCheckoutSession<'a> { - type Error = anyhow::Error; - - fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result { - Ok(Self { - customer: value - .customer - .map(|customer_id| customer_id.try_into()) - .transpose()?, - client_reference_id: value.client_reference_id, - mode: value.mode.map(Into::into), - line_items: value - .line_items - .map(|line_items| line_items.into_iter().map(Into::into).collect()), - payment_method_collection: value.payment_method_collection.map(Into::into), - subscription_data: value.subscription_data.map(Into::into), - success_url: value.success_url, - billing_address_collection: value.billing_address_collection.map(Into::into), - customer_update: value.customer_update.map(Into::into), - tax_id_collection: value.tax_id_collection.map(Into::into), - ..Default::default() - }) - } -} - -impl From for CheckoutSessionMode { - fn from(value: StripeCheckoutSessionMode) -> Self { - match value { - StripeCheckoutSessionMode::Payment => Self::Payment, - StripeCheckoutSessionMode::Setup => Self::Setup, - StripeCheckoutSessionMode::Subscription => Self::Subscription, - } - } -} - -impl From for CreateCheckoutSessionLineItems { - fn from(value: StripeCreateCheckoutSessionLineItems) -> Self { - Self { - price: value.price, - quantity: value.quantity, - ..Default::default() - } - } -} - -impl From for CheckoutSessionPaymentMethodCollection { - fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self { - match value { - StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always, - StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired, - } - } -} - -impl From for CreateCheckoutSessionSubscriptionData { - fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self { - Self { - trial_period_days: value.trial_period_days, - trial_settings: value.trial_settings.map(Into::into), - metadata: value.metadata, - ..Default::default() - } - } -} - -impl From for CreateCheckoutSessionSubscriptionDataTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripeCheckoutSession { - fn from(value: CheckoutSession) -> Self { - Self { url: value.url } - } -} - -impl From for stripe::CheckoutSessionBillingAddressCollection { - fn from(value: StripeBillingAddressCollection) -> Self { - match value { - StripeBillingAddressCollection::Auto => { - stripe::CheckoutSessionBillingAddressCollection::Auto - } - StripeBillingAddressCollection::Required => { - stripe::CheckoutSessionBillingAddressCollection::Required - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateAddress { - fn from(value: StripeCustomerUpdateAddress) -> Self { - match value { - StripeCustomerUpdateAddress::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto - } - StripeCustomerUpdateAddress::Never => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateName { - fn from(value: StripeCustomerUpdateName) -> Self { - match value { - StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto, - StripeCustomerUpdateName::Never => { - stripe::CreateCheckoutSessionCustomerUpdateName::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateShipping { - fn from(value: StripeCustomerUpdateShipping) -> Self { - match value { - StripeCustomerUpdateShipping::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto - } - StripeCustomerUpdateShipping::Never => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdate { - fn from(value: StripeCustomerUpdate) -> Self { - stripe::CreateCheckoutSessionCustomerUpdate { - address: value.address.map(Into::into), - name: value.name.map(Into::into), - shipping: value.shipping.map(Into::into), - } - } -} - -impl From for stripe::CreateCheckoutSessionTaxIdCollection { - fn from(value: StripeTaxIdCollection) -> Self { - stripe::CreateCheckoutSessionTaxIdCollection { - enabled: value.enabled, - } - } -} diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 8d5d076780..ddf245b06f 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -8,7 +8,6 @@ mod channel_buffer_tests; mod channel_guest_tests; mod channel_message_tests; mod channel_tests; -// mod debug_panel_tests; mod editor_tests; mod following_tests; mod git_tests; @@ -18,7 +17,6 @@ mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; mod remote_editing_collaboration_tests; -mod stripe_billing_tests; mod test_server; use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs deleted file mode 100644 index bb84bedfcf..0000000000 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::sync::Arc; - -use pretty_assertions::assert_eq; - -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; - -fn make_stripe_billing() -> (StripeBilling, Arc) { - let stripe_client = Arc::new(FakeStripeClient::new()); - let stripe_billing = StripeBilling::test(stripe_client.clone()); - - (stripe_billing, stripe_client) -} - -#[gpui::test] -async fn test_initialize() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Add test prices - let price1 = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(1_000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - let price2 = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - let price3 = StripePrice { - id: StripePriceId("price_3".into()), - unit_amount: Some(500), - lookup_key: None, - recurring: Some(StripePriceRecurring { - meter: Some("meter_1".to_string()), - }), - }; - stripe_client - .prices - .lock() - .insert(price1.id.clone(), price1); - stripe_client - .prices - .lock() - .insert(price2.id.clone(), price2); - stripe_client - .prices - .lock() - .insert(price3.id.clone(), price3); - - // Initialize the billing system - stripe_billing.initialize().await.unwrap(); - - // Verify that prices can be found by lookup key - let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap(); - assert_eq!(zed_pro_price_id.to_string(), "price_1"); - - let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap(); - assert_eq!(zed_free_price_id.to_string(), "price_2"); - - // Verify that a price can be found by lookup key - let zed_pro_price = stripe_billing - .find_price_by_lookup_key("zed-pro") - .await - .unwrap(); - assert_eq!(zed_pro_price.id.to_string(), "price_1"); - assert_eq!(zed_pro_price.unit_amount, Some(1_000)); - - // Verify that finding a non-existent lookup key returns an error - let result = stripe_billing - .find_price_by_lookup_key("non-existent") - .await; - assert!(result.is_err()); -} - -#[gpui::test] -async fn test_find_or_create_customer_by_email() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Create a customer with an email that doesn't yet correspond to a customer. - { - let email = "user@example.com"; - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } - - // Create a customer with an email that corresponds to an existing customer. - { - let email = "user2@example.com"; - - let existing_customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - assert_eq!(customer_id, existing_customer_id); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index f5a0e8ea81..07ea1efc9d 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,4 +1,3 @@ -use crate::stripe_client::FakeStripeClient; use crate::{ AppState, Config, db::{NewUserParams, UserId, tests::TestDb}, @@ -566,12 +565,8 @@ impl TestServer { ) -> Arc { Arc::new(AppState { db: test_db.db().clone(), - llm_db: None, livekit_client: Some(Arc::new(livekit_test_server.create_api_client())), blob_store_client: None, - real_stripe_client: None, - stripe_client: Some(Arc::new(FakeStripeClient::new())), - stripe_billing: None, executor, kinesis_client: None, config: Config { @@ -608,7 +603,6 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 51d9f003f8..2bbaa8446c 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -674,7 +674,7 @@ impl ChatPanel { }) }) .when_some(message_id, |el, message_id| { - let this = cx.entity().clone(); + let this = cx.entity(); el.child( self.render_popover_button( diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 430b447580..c2cc6a7ad5 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -95,7 +95,7 @@ pub fn init(cx: &mut App) { .and_then(|room| room.read(cx).channel_id()); if let Some(channel_id) = channel_id { - let workspace = cx.entity().clone(); + let workspace = cx.entity(); window.defer(cx, move |window, cx| { ChannelView::open(channel_id, None, workspace, window, cx) .detach_and_log_err(cx) @@ -1142,7 +1142,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); if !(role == proto::ChannelRole::Guest || role == proto::ChannelRole::Talker || role == proto::ChannelRole::Member) @@ -1272,7 +1272,7 @@ impl CollabPanel { .channel_for_id(clipboard.channel_id) .map(|channel| channel.name.clone()) }); - let this = cx.entity().clone(); + let this = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| { if self.has_subchannels(ix) { @@ -1439,7 +1439,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); let in_room = ActiveCall::global(cx).read(cx).room().is_some(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| { diff --git a/crates/collab_ui/src/collab_panel/channel_modal.rs b/crates/collab_ui/src/collab_panel/channel_modal.rs index c0d3130ee9..e558835dba 100644 --- a/crates/collab_ui/src/collab_panel/channel_modal.rs +++ b/crates/collab_ui/src/collab_panel/channel_modal.rs @@ -586,7 +586,7 @@ impl ChannelModalDelegate { return; }; let user_id = membership.user.id; - let picker = cx.entity().clone(); + let picker = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| { let role = membership.role; diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index 3a280ff667..a3420d603b 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -321,7 +321,7 @@ impl NotificationPanel { .justify_end() .child(Button::new("decline", "Decline").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( @@ -334,7 +334,7 @@ impl NotificationPanel { })) .child(Button::new("accept", "Accept").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 49ae2b9d9c..dcebeae721 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -21,7 +21,7 @@ use language::{ point_from_lsp, point_to_lsp, }; use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use parking_lot::Mutex; use project::DisableAiSettings; use request::StatusNotification; @@ -349,7 +349,11 @@ impl Copilot { this.start_copilot(true, false, cx); cx.observe_global::(move |this, cx| { this.start_copilot(true, false, cx); - this.send_configuration_update(cx); + if let Ok(server) = this.server.as_running() { + notify_did_change_config_to_server(&server.lsp, cx) + .context("copilot setting change: did change configuration") + .log_err(); + } }) .detach(); this @@ -438,43 +442,6 @@ impl Copilot { if env.is_empty() { None } else { Some(env) } } - fn send_configuration_update(&mut self, cx: &mut Context) { - let copilot_settings = all_language_settings(None, cx) - .edit_predictions - .copilot - .clone(); - - let settings = json!({ - "http": { - "proxy": copilot_settings.proxy, - "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false) - }, - "github-enterprise": { - "uri": copilot_settings.enterprise_uri - } - }); - - if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) { - copilot_chat.update(cx, |chat, cx| { - chat.set_configuration( - copilot_chat::CopilotChatConfiguration { - enterprise_uri: copilot_settings.enterprise_uri.clone(), - }, - cx, - ); - }); - } - - if let Ok(server) = self.server.as_running() { - server - .lsp - .notify::( - &lsp::DidChangeConfigurationParams { settings }, - ) - .log_err(); - } - } - #[cfg(any(test, feature = "test-support"))] pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity, lsp::FakeLanguageServer) { use fs::FakeFs; @@ -573,6 +540,9 @@ impl Copilot { })? .await?; + this.update(cx, |_, cx| notify_did_change_config_to_server(&server, cx))? + .context("copilot: did change configuration")?; + let status = server .request::(request::CheckStatusParams { local_checks_only: false, @@ -598,8 +568,6 @@ impl Copilot { }); cx.emit(Event::CopilotLanguageServerStarted); this.update_sign_in_status(status, cx); - // Send configuration now that the LSP is fully started - this.send_configuration_update(cx); } Err(error) => { this.server = CopilotServer::Error(error.to_string().into()); @@ -1156,6 +1124,41 @@ fn uri_for_buffer(buffer: &Entity, cx: &App) -> Result { } } +fn notify_did_change_config_to_server( + server: &Arc, + cx: &mut Context, +) -> std::result::Result<(), anyhow::Error> { + let copilot_settings = all_language_settings(None, cx) + .edit_predictions + .copilot + .clone(); + + if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) { + copilot_chat.update(cx, |chat, cx| { + chat.set_configuration( + copilot_chat::CopilotChatConfiguration { + enterprise_uri: copilot_settings.enterprise_uri.clone(), + }, + cx, + ); + }); + } + + let settings = json!({ + "http": { + "proxy": copilot_settings.proxy, + "proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false) + }, + "github-enterprise": { + "uri": copilot_settings.enterprise_uri + } + }); + + server.notify::(&lsp::DidChangeConfigurationParams { + settings, + }) +} + async fn clear_copilot_dir() { remove_matching(paths::copilot_dir(), |_| true).await } @@ -1181,7 +1184,7 @@ async fn get_copilot_lsp(fs: Arc, node_runtime: NodeRuntime) -> anyhow:: PACKAGE_NAME, &server_path, paths::copilot_dir(), - &latest_version, + VersionStrategy::Latest(&latest_version), ) .await; if should_install { diff --git a/crates/crashes/Cargo.toml b/crates/crashes/Cargo.toml index 641a97765a..2420b499f8 100644 --- a/crates/crashes/Cargo.toml +++ b/crates/crashes/Cargo.toml @@ -10,7 +10,10 @@ crash-handler.workspace = true log.workspace = true minidumper.workspace = true paths.workspace = true +release_channel.workspace = true smol.workspace = true +serde.workspace = true +serde_json.workspace = true workspace-hack.workspace = true [lints] diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs index cfb4b57d5d..ddf6468be8 100644 --- a/crates/crashes/src/crashes.rs +++ b/crates/crashes/src/crashes.rs @@ -1,15 +1,18 @@ use crash_handler::CrashHandler; use log::info; use minidumper::{Client, LoopAction, MinidumpBinary}; +use release_channel::{RELEASE_CHANNEL, ReleaseChannel}; +use serde::{Deserialize, Serialize}; use std::{ env, - fs::File, + fs::{self, File}, io, + panic::Location, path::{Path, PathBuf}, process::{self, Command}, sync::{ - OnceLock, + Arc, OnceLock, atomic::{AtomicBool, Ordering}, }, thread, @@ -17,12 +20,17 @@ use std::{ }; // set once the crash handler has initialized and the client has connected to it -pub static CRASH_HANDLER: AtomicBool = AtomicBool::new(false); +pub static CRASH_HANDLER: OnceLock> = OnceLock::new(); // set when the first minidump request is made to avoid generating duplicate crash reports pub static REQUESTED_MINIDUMP: AtomicBool = AtomicBool::new(false); -const CRASH_HANDLER_TIMEOUT: Duration = Duration::from_secs(60); +const CRASH_HANDLER_PING_TIMEOUT: Duration = Duration::from_secs(60); +const CRASH_HANDLER_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); + +pub async fn init(crash_init: InitCrashHandler) { + if *RELEASE_CHANNEL == ReleaseChannel::Dev && env::var("ZED_GENERATE_MINIDUMPS").is_err() { + return; + } -pub async fn init(id: String) { let exe = env::current_exe().expect("unable to find ourselves"); let zed_pid = process::id(); // TODO: we should be able to get away with using 1 crash-handler process per machine, @@ -53,9 +61,11 @@ pub async fn init(id: String) { smol::Timer::after(retry_frequency).await; } let client = maybe_client.unwrap(); - client.send_message(1, id).unwrap(); // set session id on the server + client + .send_message(1, serde_json::to_vec(&crash_init).unwrap()) + .unwrap(); - let client = std::sync::Arc::new(client); + let client = Arc::new(client); let handler = crash_handler::CrashHandler::attach(unsafe { let client = client.clone(); crash_handler::make_crash_event(move |crash_context: &crash_handler::CrashContext| { @@ -64,7 +74,6 @@ pub async fn init(id: String) { .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) .is_ok() { - client.send_message(2, "mistakes were made").unwrap(); client.ping().unwrap(); client.request_dump(crash_context).is_ok() } else { @@ -79,7 +88,7 @@ pub async fn init(id: String) { { handler.set_ptracer(Some(server_pid)); } - CRASH_HANDLER.store(true, Ordering::Release); + CRASH_HANDLER.set(client.clone()).ok(); std::mem::forget(handler); info!("crash handler registered"); @@ -90,14 +99,43 @@ pub async fn init(id: String) { } pub struct CrashServer { - session_id: OnceLock, + initialization_params: OnceLock, + panic_info: OnceLock, + has_connection: Arc, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct CrashInfo { + pub init: InitCrashHandler, + pub panic: Option, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct InitCrashHandler { + pub session_id: String, + pub zed_version: String, + pub release_channel: String, + pub commit_sha: String, + // pub gpu: String, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct CrashPanic { + pub message: String, + pub span: String, } impl minidumper::ServerHandler for CrashServer { fn create_minidump_file(&self) -> Result<(File, PathBuf), io::Error> { - let err_message = "Need to send a message with the ID upon starting the crash handler"; + let err_message = "Missing initialization data"; let dump_path = paths::logs_dir() - .join(self.session_id.get().expect(err_message)) + .join( + &self + .initialization_params + .get() + .expect(err_message) + .session_id, + ) .with_extension("dmp"); let file = File::create(&dump_path)?; Ok((file, dump_path)) @@ -114,35 +152,71 @@ impl minidumper::ServerHandler for CrashServer { info!("failed to write minidump: {:#}", e); } } + + let crash_info = CrashInfo { + init: self + .initialization_params + .get() + .expect("not initialized") + .clone(), + panic: self.panic_info.get().cloned(), + }; + + let crash_data_path = paths::logs_dir() + .join(&crash_info.init.session_id) + .with_extension("json"); + + fs::write(crash_data_path, serde_json::to_vec(&crash_info).unwrap()).ok(); + LoopAction::Exit } fn on_message(&self, kind: u32, buffer: Vec) { - let message = String::from_utf8(buffer).expect("invalid utf-8"); - info!("kind: {kind}, message: {message}",); - if kind == 1 { - self.session_id - .set(message) - .expect("session id already initialized"); + match kind { + 1 => { + let init_data = + serde_json::from_slice::(&buffer).expect("invalid init data"); + self.initialization_params + .set(init_data) + .expect("already initialized"); + } + 2 => { + let panic_data = + serde_json::from_slice::(&buffer).expect("invalid panic data"); + self.panic_info.set(panic_data).expect("already panicked"); + } + _ => { + panic!("invalid message kind"); + } } } - fn on_client_disconnected(&self, clients: usize) -> LoopAction { - info!("client disconnected, {clients} remaining"); - if clients == 0 { - LoopAction::Exit - } else { - LoopAction::Continue - } + fn on_client_disconnected(&self, _clients: usize) -> LoopAction { + LoopAction::Exit + } + + fn on_client_connected(&self, _clients: usize) -> LoopAction { + self.has_connection.store(true, Ordering::SeqCst); + LoopAction::Continue } } -pub fn handle_panic() { +pub fn handle_panic(message: String, span: Option<&Location>) { + let span = span + .map(|loc| format!("{}:{}", loc.file(), loc.line())) + .unwrap_or_default(); + // wait 500ms for the crash handler process to start up // if it's still not there just write panic info and no minidump let retry_frequency = Duration::from_millis(100); for _ in 0..5 { - if CRASH_HANDLER.load(Ordering::Acquire) { + if let Some(client) = CRASH_HANDLER.get() { + client + .send_message( + 2, + serde_json::to_vec(&CrashPanic { message, span }).unwrap(), + ) + .ok(); log::error!("triggering a crash to generate a minidump..."); #[cfg(target_os = "linux")] CrashHandler.simulate_signal(crash_handler::Signal::Trap as u32); @@ -159,14 +233,30 @@ pub fn crash_server(socket: &Path) { log::info!("Couldn't create socket, there may already be a running crash server"); return; }; - let ab = AtomicBool::new(false); + + let shutdown = Arc::new(AtomicBool::new(false)); + let has_connection = Arc::new(AtomicBool::new(false)); + + std::thread::spawn({ + let shutdown = shutdown.clone(); + let has_connection = has_connection.clone(); + move || { + std::thread::sleep(CRASH_HANDLER_CONNECT_TIMEOUT); + if !has_connection.load(Ordering::SeqCst) { + shutdown.store(true, Ordering::SeqCst); + } + } + }); + server .run( Box::new(CrashServer { - session_id: OnceLock::new(), + initialization_params: OnceLock::new(), + panic_info: OnceLock::new(), + has_connection, }), - &ab, - Some(CRASH_HANDLER_TIMEOUT), + &shutdown, + Some(CRASH_HANDLER_PING_TIMEOUT), ) .expect("failed to run server"); } diff --git a/crates/dap_adapters/src/codelldb.rs b/crates/dap_adapters/src/codelldb.rs index 5b88db4432..842bb264a8 100644 --- a/crates/dap_adapters/src/codelldb.rs +++ b/crates/dap_adapters/src/codelldb.rs @@ -338,8 +338,8 @@ impl DebugAdapter for CodeLldbDebugAdapter { if command.is_none() { delegate.output_to_console(format!("Checking latest version of {}...", self.name())); let adapter_path = paths::debug_adapters_dir().join(&Self::ADAPTER_NAME); - let version_path = - if let Ok(version) = self.fetch_latest_adapter_version(delegate).await { + let version_path = match self.fetch_latest_adapter_version(delegate).await { + Ok(version) => { adapters::download_adapter_from_github( self.name(), version.clone(), @@ -351,10 +351,26 @@ impl DebugAdapter for CodeLldbDebugAdapter { adapter_path.join(format!("{}_{}", Self::ADAPTER_NAME, version.tag_name)); remove_matching(&adapter_path, |entry| entry != version_path).await; version_path - } else { - let mut paths = delegate.fs().read_dir(&adapter_path).await?; - paths.next().await.context("No adapter found")?? - }; + } + Err(e) => { + delegate.output_to_console("Unable to fetch latest version".to_string()); + log::error!("Error fetching latest version of {}: {}", self.name(), e); + delegate.output_to_console(format!( + "Searching for adapters in: {}", + adapter_path.display() + )); + let mut paths = delegate + .fs() + .read_dir(&adapter_path) + .await + .context("No cached adapter directory")?; + paths + .next() + .await + .context("No cached adapter found")? + .context("No cached adapter found")? + } + }; let adapter_dir = version_path.join("extension").join("adapter"); let path = adapter_dir.join("codelldb").to_string_lossy().to_string(); self.path_to_codelldb.set(path.clone()).ok(); diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index 461ce6fbb3..a2bd934311 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -152,6 +152,9 @@ impl PythonDebugAdapter { maybe!(async move { let response = latest_release.filter(|response| response.status().is_success())?; + let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME); + std::fs::create_dir_all(&download_dir).ok()?; + let mut output = String::new(); response .into_body() diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index 91382c74ae..1d44c5c244 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -36,7 +36,7 @@ use settings::Settings; use std::sync::{Arc, LazyLock}; use task::{DebugScenario, TaskContext}; use tree_sitter::{Query, StreamingIterator as _}; -use ui::{ContextMenu, Divider, PopoverMenuHandle, Tooltip, prelude::*}; +use ui::{ContextMenu, Divider, PopoverMenuHandle, Tab, Tooltip, prelude::*}; use util::{ResultExt, debug_panic, maybe}; use workspace::SplitDirection; use workspace::item::SaveOptions; @@ -642,12 +642,14 @@ impl DebugPanel { } }) }; + let documentation_button = || { IconButton::new("debug-open-documentation", IconName::CircleHelp) .icon_size(IconSize::Small) .on_click(move |_, _, cx| cx.open_url("https://zed.dev/docs/debugger")) .tooltip(Tooltip::text("Open Documentation")) }; + let logs_button = || { IconButton::new("debug-open-logs", IconName::Notepad) .icon_size(IconSize::Small) @@ -658,16 +660,18 @@ impl DebugPanel { }; Some( - div.border_b_1() - .border_color(cx.theme().colors().border) - .p_1() + div.w_full() + .py_1() + .px_1p5() .justify_between() - .w_full() + .border_b_1() + .border_color(cx.theme().colors().border) .when(is_side, |this| this.gap_1()) .child( h_flex() + .justify_between() .child( - h_flex().gap_2().w_full().when_some( + h_flex().gap_1().w_full().when_some( active_session .as_ref() .map(|session| session.read(cx).running_state()), @@ -679,6 +683,7 @@ impl DebugPanel { let capabilities = running_state.read(cx).capabilities(cx); let supports_detach = running_state.read(cx).session().read(cx).is_attached(); + this.map(|this| { if thread_status == ThreadStatus::Running { this.child( @@ -686,8 +691,7 @@ impl DebugPanel { "debug-pause", IconName::DebugPause, ) - .icon_size(IconSize::XSmall) - .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, _window, cx| { @@ -698,7 +702,7 @@ impl DebugPanel { let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( - "Pause program", + "Pause Program", &Pause, &focus_handle, window, @@ -713,8 +717,7 @@ impl DebugPanel { "debug-continue", IconName::DebugContinue, ) - .icon_size(IconSize::XSmall) - .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, _window, cx| this.continue_thread(cx), @@ -724,7 +727,7 @@ impl DebugPanel { let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( - "Continue program", + "Continue Program", &Continue, &focus_handle, window, @@ -737,8 +740,7 @@ impl DebugPanel { }) .child( IconButton::new("debug-step-over", IconName::ArrowRight) - .icon_size(IconSize::XSmall) - .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, _window, cx| { @@ -750,7 +752,7 @@ impl DebugPanel { let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( - "Step over", + "Step Over", &StepOver, &focus_handle, window, @@ -764,8 +766,7 @@ impl DebugPanel { "debug-step-into", IconName::ArrowDownRight, ) - .icon_size(IconSize::XSmall) - .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, _window, cx| { @@ -777,7 +778,7 @@ impl DebugPanel { let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( - "Step in", + "Step In", &StepInto, &focus_handle, window, @@ -789,7 +790,6 @@ impl DebugPanel { .child( IconButton::new("debug-step-out", IconName::ArrowUpRight) .icon_size(IconSize::Small) - .shape(ui::IconButtonShape::Square) .on_click(window.listener_for( &running_state, |this, _, _window, cx| { @@ -801,7 +801,7 @@ impl DebugPanel { let focus_handle = focus_handle.clone(); move |window, cx| { Tooltip::for_action_in( - "Step out", + "Step Out", &StepOut, &focus_handle, window, @@ -813,7 +813,7 @@ impl DebugPanel { .child(Divider::vertical()) .child( IconButton::new("debug-restart", IconName::RotateCcw) - .icon_size(IconSize::XSmall) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, window, cx| { @@ -835,7 +835,7 @@ impl DebugPanel { ) .child( IconButton::new("debug-stop", IconName::Power) - .icon_size(IconSize::XSmall) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, _window, cx| { @@ -890,7 +890,7 @@ impl DebugPanel { thread_status != ThreadStatus::Stopped && thread_status != ThreadStatus::Running, ) - .icon_size(IconSize::XSmall) + .icon_size(IconSize::Small) .on_click(window.listener_for( &running_state, |this, _, _, cx| { @@ -915,7 +915,6 @@ impl DebugPanel { }, ), ) - .justify_around() .when(is_side, |this| { this.child(new_session_button()) .child(logs_button()) @@ -924,7 +923,7 @@ impl DebugPanel { ) .child( h_flex() - .gap_2() + .gap_0p5() .when(is_side, |this| this.justify_between()) .child( h_flex().when_some( @@ -954,12 +953,15 @@ impl DebugPanel { ) }) }) - .when(!is_side, |this| this.gap_2().child(Divider::vertical())) + .when(!is_side, |this| { + this.gap_0p5().child(Divider::vertical()) + }) }, ), ) .child( h_flex() + .gap_0p5() .children(self.render_session_menu( self.active_session(), self.running_state(cx), @@ -1702,6 +1704,7 @@ impl Render for DebugPanel { this.child(active_session) } else { let docked_to_bottom = self.position(window, cx) == DockPosition::Bottom; + let welcome_experience = v_flex() .when_else( docked_to_bottom, @@ -1767,54 +1770,58 @@ impl Render for DebugPanel { ); }), ); - let breakpoint_list = - v_flex() - .group("base-breakpoint-list") - .items_start() - .when_else( - docked_to_bottom, - |this| this.min_w_1_3().h_full(), - |this| this.w_full().h_2_3(), - ) - .p_1() - .child( - h_flex() - .pl_1() - .w_full() - .justify_between() - .child(Label::new("Breakpoints").size(LabelSize::Small)) - .child(h_flex().visible_on_hover("base-breakpoint-list").child( + + let breakpoint_list = v_flex() + .group("base-breakpoint-list") + .when_else( + docked_to_bottom, + |this| this.min_w_1_3().h_full(), + |this| this.size_full().h_2_3(), + ) + .child( + h_flex() + .track_focus(&self.breakpoint_list.focus_handle(cx)) + .h(Tab::container_height(cx)) + .p_1p5() + .w_full() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border_variant) + .child(Label::new("Breakpoints").size(LabelSize::Small)) + .child( + h_flex().visible_on_hover("base-breakpoint-list").child( self.breakpoint_list.read(cx).render_control_strip(), - )) - .track_focus(&self.breakpoint_list.focus_handle(cx)), - ) - .child(Divider::horizontal()) - .child(self.breakpoint_list.clone()); + ), + ), + ) + .child(self.breakpoint_list.clone()); + this.child( v_flex() - .h_full() + .size_full() .gap_1() .items_center() .justify_center() - .child( - div() - .when_else(docked_to_bottom, Div::h_flex, Div::v_flex) - .size_full() - .map(|this| { - if docked_to_bottom { - this.items_start() - .child(breakpoint_list) - .child(Divider::vertical()) - .child(welcome_experience) - .child(Divider::vertical()) - } else { - this.items_end() - .child(welcome_experience) - .child(Divider::horizontal()) - .child(breakpoint_list) - } - }), - ), + .map(|this| { + if docked_to_bottom { + this.child( + h_flex() + .size_full() + .child(breakpoint_list) + .child(Divider::vertical()) + .child(welcome_experience) + .child(Divider::vertical()), + ) + } else { + this.child( + v_flex() + .size_full() + .child(welcome_experience) + .child(Divider::horizontal()) + .child(breakpoint_list), + ) + } + }), ) } }) diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index a3e2805e2b..f3117aee07 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -48,10 +48,8 @@ use task::{ }; use terminal_view::TerminalView; use ui::{ - ActiveTheme, AnyElement, App, ButtonCommon as _, Clickable as _, Context, FluentBuilder, - IconButton, IconName, IconSize, InteractiveElement, IntoElement, Label, LabelCommon as _, - ParentElement, Render, SharedString, StatefulInteractiveElement, Styled, Tab, Tooltip, - VisibleOnHover, VisualContext, Window, div, h_flex, v_flex, + FluentBuilder, IntoElement, Render, StatefulInteractiveElement, Tab, Tooltip, VisibleOnHover, + VisualContext, prelude::*, }; use util::ResultExt; use variable_list::VariableList; @@ -293,7 +291,7 @@ pub(crate) fn new_debugger_pane( let Some(project) = project.upgrade() else { return ControlFlow::Break(()); }; - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { @@ -419,13 +417,14 @@ pub(crate) fn new_debugger_pane( .map_or(false, |item| item.read(cx).hovered); h_flex() - .group(pane_group_id.clone()) - .justify_between() - .bg(cx.theme().colors().tab_bar_background) - .border_b_1() - .px_2() - .border_color(cx.theme().colors().border) .track_focus(&focus_handle) + .group(pane_group_id.clone()) + .pl_1p5() + .pr_1() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().tab_bar_background) .on_action(|_: &menu::Cancel, window, cx| { if cx.stop_active_drag(window) { return; @@ -503,7 +502,7 @@ pub(crate) fn new_debugger_pane( .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail: 0, is_active: selected, ix, @@ -514,6 +513,7 @@ pub(crate) fn new_debugger_pane( ) .child({ let zoomed = pane.is_zoomed(); + h_flex() .visible_on_hover(pane_group_id) .when(is_hovered, |this| this.visible()) @@ -537,7 +537,7 @@ pub(crate) fn new_debugger_pane( IconName::Maximize }, ) - .icon_size(IconSize::XSmall) + .icon_size(IconSize::Small) .on_click(cx.listener(move |pane, _, _, cx| { let is_zoomed = pane.is_zoomed(); pane.set_zoomed(!is_zoomed, cx); @@ -592,10 +592,11 @@ impl DebugTerminal { } impl gpui::Render for DebugTerminal { - fn render(&mut self, _window: &mut Window, _: &mut Context) -> impl IntoElement { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { div() - .size_full() .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) .children(self.terminal.clone()) } } diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index 326fb84e20..38108dbfbc 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -23,11 +23,8 @@ use project::{ worktree_store::WorktreeStore, }; use ui::{ - ActiveTheme, AnyElement, App, ButtonCommon, Clickable, Color, Context, Disableable, Div, - Divider, FluentBuilder as _, Icon, IconButton, IconName, IconSize, InteractiveElement, - IntoElement, Label, LabelCommon, LabelSize, ListItem, ParentElement, Render, RenderOnce, - Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable, - Tooltip, Window, div, h_flex, px, v_flex, + Divider, DividerColor, FluentBuilder as _, Indicator, IntoElement, ListItem, Render, Scrollbar, + ScrollbarState, StatefulInteractiveElement, Tooltip, prelude::*, }; use workspace::Workspace; use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint}; @@ -569,6 +566,7 @@ impl BreakpointList { .map(|session| SupportedBreakpointProperties::from(session.read(cx).capabilities())) .unwrap_or_else(SupportedBreakpointProperties::empty); let strip_mode = self.strip_mode; + uniform_list( "breakpoint-list", self.breakpoints.len(), @@ -591,7 +589,7 @@ impl BreakpointList { }), ) .track_scroll(self.scroll_handle.clone()) - .flex_grow() + .flex_1() } fn render_vertical_scrollbar(&self, cx: &mut Context) -> Stateful
{ @@ -630,6 +628,7 @@ impl BreakpointList { pub(crate) fn render_control_strip(&self) -> AnyElement { let selection_kind = self.selection_kind(); let focus_handle = self.focus_handle.clone(); + let remove_breakpoint_tooltip = selection_kind.map(|(kind, _)| match kind { SelectedBreakpointKind::Source => "Remove breakpoint from a breakpoint list", SelectedBreakpointKind::Exception => { @@ -637,6 +636,7 @@ impl BreakpointList { } SelectedBreakpointKind::Data => "Remove data breakpoint from a breakpoint list", }); + let toggle_label = selection_kind.map(|(_, is_enabled)| { if is_enabled { ( @@ -649,13 +649,12 @@ impl BreakpointList { }); h_flex() - .gap_2() .child( IconButton::new( "disable-breakpoint-breakpoint-list", IconName::DebugDisabledBreakpoint, ) - .icon_size(IconSize::XSmall) + .icon_size(IconSize::Small) .when_some(toggle_label, |this, (label, meta)| { this.tooltip({ let focus_handle = focus_handle.clone(); @@ -681,9 +680,8 @@ impl BreakpointList { }), ) .child( - IconButton::new("remove-breakpoint-breakpoint-list", IconName::Close) - .icon_size(IconSize::XSmall) - .icon_color(ui::Color::Error) + IconButton::new("remove-breakpoint-breakpoint-list", IconName::Trash) + .icon_size(IconSize::Small) .when_some(remove_breakpoint_tooltip, |this, tooltip| { this.tooltip({ let focus_handle = focus_handle.clone(); @@ -710,7 +708,6 @@ impl BreakpointList { } }), ) - .mr_2() .into_any_element() } } @@ -791,6 +788,7 @@ impl Render for BreakpointList { .chain(data_breakpoints) .chain(exception_breakpoints), ); + v_flex() .id("breakpoint-list") .key_context("BreakpointList") @@ -806,35 +804,33 @@ impl Render for BreakpointList { .on_action(cx.listener(Self::next_breakpoint_property)) .on_action(cx.listener(Self::previous_breakpoint_property)) .size_full() - .m_0p5() - .child( - v_flex() - .size_full() - .child(self.render_list(cx)) - .child(self.render_vertical_scrollbar(cx)), - ) + .pt_1() + .child(self.render_list(cx)) + .child(self.render_vertical_scrollbar(cx)) .when_some(self.strip_mode, |this, _| { - this.child(Divider::horizontal()).child( - h_flex() - // .w_full() - .m_0p5() - .p_0p5() - .border_1() - .rounded_sm() - .when( - self.input.focus_handle(cx).contains_focused(window, cx), - |this| { - let colors = cx.theme().colors(); - let border = if self.input.read(cx).read_only(cx) { - colors.border_disabled - } else { - colors.border_focused - }; - this.border_color(border) - }, - ) - .child(self.input.clone()), - ) + this.child(Divider::horizontal().color(DividerColor::Border)) + .child( + h_flex() + .p_1() + .rounded_sm() + .bg(cx.theme().colors().editor_background) + .border_1() + .when( + self.input.focus_handle(cx).contains_focused(window, cx), + |this| { + let colors = cx.theme().colors(); + + let border_color = if self.input.read(cx).read_only(cx) { + colors.border_disabled + } else { + colors.border_transparent + }; + + this.border_color(border_color) + }, + ) + .child(self.input.clone()), + ) }) } } @@ -865,12 +861,17 @@ impl LineBreakpoint { let path = self.breakpoint.path.clone(); let row = self.breakpoint.row; let is_enabled = self.breakpoint.state.is_enabled(); + let indicator = div() .id(SharedString::from(format!( "breakpoint-ui-toggle-{:?}/{}:{}", self.dir, self.name, self.line ))) - .cursor_pointer() + .child( + Icon::new(icon_name) + .color(Color::Debugger) + .size(IconSize::XSmall), + ) .tooltip({ let focus_handle = focus_handle.clone(); move |window, cx| { @@ -902,17 +903,14 @@ impl LineBreakpoint { .ok(); } }) - .child( - Icon::new(icon_name) - .color(Color::Debugger) - .size(IconSize::XSmall), - ) .on_mouse_down(MouseButton::Left, move |_, _, _| {}); ListItem::new(SharedString::from(format!( "breakpoint-ui-item-{:?}/{}:{}", self.dir, self.name, self.line ))) + .toggle_state(is_selected) + .inset(true) .on_click({ let weak = weak.clone(); move |_, window, cx| { @@ -922,23 +920,20 @@ impl LineBreakpoint { .ok(); } }) - .start_slot(indicator) - .rounded() .on_secondary_mouse_down(|_, _, cx| { cx.stop_propagation(); }) + .start_slot(indicator) .child( h_flex() - .w_full() - .mr_4() - .py_0p5() - .gap_1() - .min_h(px(26.)) - .justify_between() .id(SharedString::from(format!( "breakpoint-ui-on-click-go-to-line-{:?}/{}:{}", self.dir, self.name, self.line ))) + .w_full() + .gap_1() + .min_h(rems_from_px(26.)) + .justify_between() .on_click({ let weak = weak.clone(); move |_, window, cx| { @@ -949,9 +944,9 @@ impl LineBreakpoint { .ok(); } }) - .cursor_pointer() .child( h_flex() + .id("label-container") .gap_0p5() .child( Label::new(format!("{}:{}", self.name, self.line)) @@ -971,11 +966,13 @@ impl LineBreakpoint { .line_height_style(ui::LineHeightStyle::UiLabel) .truncate(), ) - })), + })) + .when_some(self.dir.as_ref(), |this, parent_dir| { + this.tooltip(Tooltip::text(format!( + "Worktree parent path: {parent_dir}" + ))) + }), ) - .when_some(self.dir.as_ref(), |this, parent_dir| { - this.tooltip(Tooltip::text(format!("Worktree parent path: {parent_dir}"))) - }) .child(BreakpointOptionsStrip { props, breakpoint: BreakpointEntry { @@ -988,15 +985,16 @@ impl LineBreakpoint { index: ix, }), ) - .toggle_state(is_selected) } } + #[derive(Clone, Debug)] struct ExceptionBreakpoint { id: String, data: ExceptionBreakpointsFilter, is_enabled: bool, } + #[derive(Clone, Debug)] struct DataBreakpoint(project::debugger::session::DataBreakpointState); @@ -1017,17 +1015,24 @@ impl DataBreakpoint { }; let is_enabled = self.0.is_enabled; let id = self.0.dap.data_id.clone(); + ListItem::new(SharedString::from(format!( "data-breakpoint-ui-item-{}", self.0.dap.data_id ))) - .rounded() + .toggle_state(is_selected) + .inset(true) .start_slot( div() .id(SharedString::from(format!( "data-breakpoint-ui-item-{}-click-handler", self.0.dap.data_id ))) + .child( + Icon::new(IconName::Binary) + .color(color) + .size(IconSize::Small), + ) .tooltip({ let focus_handle = focus_handle.clone(); move |window, cx| { @@ -1052,25 +1057,18 @@ impl DataBreakpoint { }) .ok(); } - }) - .cursor_pointer() - .child( - Icon::new(IconName::Binary) - .color(color) - .size(IconSize::Small), - ), + }), ) .child( h_flex() .w_full() - .mr_4() - .py_0p5() + .gap_1() + .min_h(rems_from_px(26.)) .justify_between() .child( v_flex() .py_1() .gap_1() - .min_h(px(26.)) .justify_center() .id(("data-breakpoint-label", ix)) .child( @@ -1091,7 +1089,6 @@ impl DataBreakpoint { index: ix, }), ) - .toggle_state(is_selected) } } @@ -1113,10 +1110,13 @@ impl ExceptionBreakpoint { let id = SharedString::from(&self.id); let is_enabled = self.is_enabled; let weak = list.clone(); + ListItem::new(SharedString::from(format!( "exception-breakpoint-ui-item-{}", self.id ))) + .toggle_state(is_selected) + .inset(true) .on_click({ let list = list.clone(); move |_, window, cx| { @@ -1124,7 +1124,6 @@ impl ExceptionBreakpoint { .ok(); } }) - .rounded() .on_secondary_mouse_down(|_, _, cx| { cx.stop_propagation(); }) @@ -1134,6 +1133,11 @@ impl ExceptionBreakpoint { "exception-breakpoint-ui-item-{}-click-handler", self.id ))) + .child( + Icon::new(IconName::Flame) + .color(color) + .size(IconSize::Small), + ) .tooltip({ let focus_handle = focus_handle.clone(); move |window, cx| { @@ -1158,25 +1162,18 @@ impl ExceptionBreakpoint { }) .ok(); } - }) - .cursor_pointer() - .child( - Icon::new(IconName::Flame) - .color(color) - .size(IconSize::Small), - ), + }), ) .child( h_flex() .w_full() - .mr_4() - .py_0p5() + .gap_1() + .min_h(rems_from_px(26.)) .justify_between() .child( v_flex() .py_1() .gap_1() - .min_h(px(26.)) .justify_center() .id(("exception-breakpoint-label", ix)) .child( @@ -1200,7 +1197,6 @@ impl ExceptionBreakpoint { index: ix, }), ) - .toggle_state(is_selected) } } #[derive(Clone, Debug)] @@ -1302,6 +1298,7 @@ impl BreakpointEntry { } } } + bitflags::bitflags! { #[derive(Clone, Copy)] pub struct SupportedBreakpointProperties: u32 { @@ -1360,6 +1357,7 @@ impl BreakpointOptionsStrip { fn is_toggled(&self, expected_mode: ActiveBreakpointStripMode) -> bool { self.is_selected && self.strip_mode == Some(expected_mode) } + fn on_click_callback( &self, mode: ActiveBreakpointStripMode, @@ -1379,7 +1377,8 @@ impl BreakpointOptionsStrip { .ok(); } } - fn add_border( + + fn add_focus_styles( &self, kind: ActiveBreakpointStripMode, available: bool, @@ -1388,22 +1387,25 @@ impl BreakpointOptionsStrip { ) -> impl Fn(Div) -> Div { move |this: Div| { // Avoid layout shifts in case there's no colored border - let this = this.border_2().rounded_sm(); + let this = this.border_1().rounded_sm(); + let color = cx.theme().colors(); + if self.is_selected && self.strip_mode == Some(kind) { - let theme = cx.theme().colors(); if self.focus_handle.is_focused(window) { - this.border_color(theme.border_selected) + this.bg(color.editor_background) + .border_color(color.border_focused) } else { - this.border_color(theme.border_disabled) + this.border_color(color.border) } } else if !available { - this.border_color(cx.theme().colors().border_disabled) + this.border_color(color.border_transparent) } else { this } } } } + impl RenderOnce for BreakpointOptionsStrip { fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement { let id = self.breakpoint.id(); @@ -1426,73 +1428,117 @@ impl RenderOnce for BreakpointOptionsStrip { }; let color_for_toggle = |is_enabled| { if is_enabled { - ui::Color::Default + Color::Default } else { - ui::Color::Muted + Color::Muted } }; h_flex() - .gap_1() + .gap_px() + .mr_3() // Space to avoid overlapping with the scrollbar .child( - div().map(self.add_border(ActiveBreakpointStripMode::Log, supports_logs, window, cx)) + div() + .map(self.add_focus_styles( + ActiveBreakpointStripMode::Log, + supports_logs, + window, + cx, + )) .child( IconButton::new( SharedString::from(format!("{id}-log-toggle")), IconName::Notepad, ) - .icon_size(IconSize::XSmall) + .shape(ui::IconButtonShape::Square) .style(style_for_toggle(ActiveBreakpointStripMode::Log, has_logs)) + .icon_size(IconSize::Small) .icon_color(color_for_toggle(has_logs)) + .when(has_logs, |this| this.indicator(Indicator::dot().color(Color::Info))) .disabled(!supports_logs) .toggle_state(self.is_toggled(ActiveBreakpointStripMode::Log)) - .on_click(self.on_click_callback(ActiveBreakpointStripMode::Log)).tooltip(|window, cx| Tooltip::with_meta("Set Log Message", None, "Set log message to display (instead of stopping) when a breakpoint is hit", window, cx)) + .on_click(self.on_click_callback(ActiveBreakpointStripMode::Log)) + .tooltip(|window, cx| { + Tooltip::with_meta( + "Set Log Message", + None, + "Set log message to display (instead of stopping) when a breakpoint is hit.", + window, + cx, + ) + }), ) .when(!has_logs && !self.is_selected, |this| this.invisible()), ) .child( - div().map(self.add_border( - ActiveBreakpointStripMode::Condition, - supports_condition, - window, cx - )) + div() + .map(self.add_focus_styles( + ActiveBreakpointStripMode::Condition, + supports_condition, + window, + cx, + )) .child( IconButton::new( SharedString::from(format!("{id}-condition-toggle")), IconName::SplitAlt, ) - .icon_size(IconSize::XSmall) + .shape(ui::IconButtonShape::Square) .style(style_for_toggle( ActiveBreakpointStripMode::Condition, - has_condition + has_condition, )) + .icon_size(IconSize::Small) .icon_color(color_for_toggle(has_condition)) + .when(has_condition, |this| this.indicator(Indicator::dot().color(Color::Info))) .disabled(!supports_condition) .toggle_state(self.is_toggled(ActiveBreakpointStripMode::Condition)) .on_click(self.on_click_callback(ActiveBreakpointStripMode::Condition)) - .tooltip(|window, cx| Tooltip::with_meta("Set Condition", None, "Set condition to evaluate when a breakpoint is hit. Program execution will stop only when the condition is met", window, cx)) + .tooltip(|window, cx| { + Tooltip::with_meta( + "Set Condition", + None, + "Set condition to evaluate when a breakpoint is hit. Program execution will stop only when the condition is met.", + window, + cx, + ) + }), ) .when(!has_condition && !self.is_selected, |this| this.invisible()), ) .child( - div().map(self.add_border( - ActiveBreakpointStripMode::HitCondition, - supports_hit_condition,window, cx - )) + div() + .map(self.add_focus_styles( + ActiveBreakpointStripMode::HitCondition, + supports_hit_condition, + window, + cx, + )) .child( IconButton::new( SharedString::from(format!("{id}-hit-condition-toggle")), IconName::ArrowDown10, ) - .icon_size(IconSize::XSmall) .style(style_for_toggle( ActiveBreakpointStripMode::HitCondition, has_hit_condition, )) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) .icon_color(color_for_toggle(has_hit_condition)) + .when(has_hit_condition, |this| this.indicator(Indicator::dot().color(Color::Info))) .disabled(!supports_hit_condition) .toggle_state(self.is_toggled(ActiveBreakpointStripMode::HitCondition)) - .on_click(self.on_click_callback(ActiveBreakpointStripMode::HitCondition)).tooltip(|window, cx| Tooltip::with_meta("Set Hit Condition", None, "Set expression that controls how many hits of the breakpoint are ignored.", window, cx)) + .on_click(self.on_click_callback(ActiveBreakpointStripMode::HitCondition)) + .tooltip(|window, cx| { + Tooltip::with_meta( + "Set Hit Condition", + None, + "Set expression that controls how many hits of the breakpoint are ignored.", + window, + cx, + ) + }), ) .when(!has_hit_condition && !self.is_selected, |this| { this.invisible() diff --git a/crates/debugger_ui/src/session/running/console.rs b/crates/debugger_ui/src/session/running/console.rs index daf4486f81..e6308518e4 100644 --- a/crates/debugger_ui/src/session/running/console.rs +++ b/crates/debugger_ui/src/session/running/console.rs @@ -367,7 +367,7 @@ impl Console { .when_some(keybinding_target.clone(), |el, keybinding_target| { el.context(keybinding_target.clone()) }) - .action("Watch expression", WatchExpression.boxed_clone()) + .action("Watch Expression", WatchExpression.boxed_clone()) })) }) }, @@ -452,18 +452,22 @@ impl Render for Console { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let query_focus_handle = self.query_bar.focus_handle(cx); self.update_output(window, cx); + v_flex() .track_focus(&self.focus_handle) .key_context("DebugConsole") .on_action(cx.listener(Self::evaluate)) .on_action(cx.listener(Self::watch_expression)) .size_full() + .border_2() + .bg(cx.theme().colors().editor_background) .child(self.render_console(cx)) .when(self.is_running(cx), |this| { this.child(Divider::horizontal()).child( h_flex() .on_action(cx.listener(Self::previous_query)) .on_action(cx.listener(Self::next_query)) + .p_1() .gap_1() .bg(cx.theme().colors().editor_background) .child(self.render_query_bar(cx)) @@ -474,6 +478,9 @@ impl Render for Console { .on_click(move |_, window, cx| { window.dispatch_action(Box::new(Confirm), cx) }) + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::Compact) + .child(Label::new("Evaluate")) .tooltip({ let query_focus_handle = query_focus_handle.clone(); @@ -486,10 +493,7 @@ impl Render for Console { cx, ) } - }) - .layer(ui::ElevationIndex::ModalSurface) - .size(ui::ButtonSize::Compact) - .child(Label::new("Evaluate")), + }), self.render_submit_menu( ElementId::Name("split-button-right-confirm-button".into()), Some(query_focus_handle.clone()), @@ -499,7 +503,6 @@ impl Render for Console { )), ) }) - .border_2() } } diff --git a/crates/debugger_ui/src/session/running/memory_view.rs b/crates/debugger_ui/src/session/running/memory_view.rs index 75b8938371..f936d908b1 100644 --- a/crates/debugger_ui/src/session/running/memory_view.rs +++ b/crates/debugger_ui/src/session/running/memory_view.rs @@ -18,10 +18,8 @@ use project::debugger::{MemoryCell, dap_command::DataBreakpointContext, session: use settings::Settings; use theme::ThemeSettings; use ui::{ - ActiveTheme, AnyElement, App, Color, Context, ContextMenu, Div, Divider, DropdownMenu, Element, - FluentBuilder, Icon, IconName, InteractiveElement, IntoElement, Label, LabelCommon, - ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString, - StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex, + ContextMenu, Divider, DropdownMenu, FluentBuilder, IntoElement, PopoverMenuHandle, Render, + Scrollbar, ScrollbarState, StatefulInteractiveElement, Tooltip, prelude::*, }; use workspace::Workspace; diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 8fb223b2cb..5df1b13897 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -971,7 +971,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -1065,7 +1065,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -1239,7 +1239,7 @@ async fn test_diagnostics_with_links(cx: &mut TestAppContext) { } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { @@ -1293,7 +1293,7 @@ async fn test_hover_diagnostic_and_info_popovers(cx: &mut gpui::TestAppContext) fn «test»() { println!(); } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { lsp_store.update_diagnostics( @@ -1450,7 +1450,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {"error warning info hiˇnt"}); diff --git a/crates/docs_preprocessor/src/main.rs b/crates/docs_preprocessor/src/main.rs index 1448f4cb52..17804b4281 100644 --- a/crates/docs_preprocessor/src/main.rs +++ b/crates/docs_preprocessor/src/main.rs @@ -8,7 +8,7 @@ use std::borrow::Cow; use std::collections::{HashMap, HashSet}; use std::io::{self, Read}; use std::process; -use std::sync::LazyLock; +use std::sync::{LazyLock, OnceLock}; use util::paths::PathExt; static KEYMAP_MACOS: LazyLock = LazyLock::new(|| { @@ -388,7 +388,7 @@ fn handle_postprocessing() -> Result<()> { let meta_title = format!("{} | {}", page_title, meta_title); zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir)); let contents = contents.replace("#description#", meta_description); - let contents = TITLE_REGEX + let contents = title_regex() .replace(&contents, |_: ®ex::Captures| { format!("{}", meta_title) }) @@ -404,10 +404,8 @@ fn handle_postprocessing() -> Result<()> { ) -> &'a std::path::Path { &path.strip_prefix(&root).unwrap_or(&path) } - const TITLE_REGEX: std::cell::LazyCell = - std::cell::LazyCell::new(|| Regex::new(r"\s*(.*?)\s*").unwrap()); fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String { - let title_tag_contents = &TITLE_REGEX + let title_tag_contents = &title_regex() .captures(&contents) .with_context(|| format!("Failed to find title in {:?}", pretty_path)) .expect("Page has element")[1]; @@ -420,3 +418,8 @@ fn handle_postprocessing() -> Result<()> { title } } + +fn title_regex() -> &'static Regex { + static TITLE_REGEX: OnceLock<Regex> = OnceLock::new(); + TITLE_REGEX.get_or_init(|| Regex::new(r"<title>\s*(.*?)\s*").unwrap()) +} diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 3d3b43d71b..4632a03daf 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -127,7 +127,7 @@ impl Render for EditPredictionButton { }), ); } - let this = cx.entity().clone(); + let this = cx.entity(); div().child( PopoverMenu::new("copilot") @@ -182,7 +182,7 @@ impl Render for EditPredictionButton { let icon = status.to_icon(); let tooltip_text = status.to_tooltip(); let has_menu = status.has_menu(); - let this = cx.entity().clone(); + let this = cx.entity(); let fs = self.fs.clone(); return div().child( @@ -331,7 +331,7 @@ impl Render for EditPredictionButton { }) }); - let this = cx.entity().clone(); + let this = cx.entity(); let mut popover_menu = PopoverMenu::new("zeta") .menu(move |window, cx| { diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index 39433b3c27..ce02c4d2bf 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -273,6 +273,16 @@ pub enum UuidVersion { V7, } +/// Splits selection into individual lines. +#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct SplitSelectionIntoLines { + /// Keep the text selected after splitting instead of collapsing to cursors. + #[serde(default)] + pub keep_selections: bool, +} + /// Goes to the next diagnostic in the file. #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] #[action(namespace = editor)] @@ -672,8 +682,6 @@ actions!( SortLinesCaseInsensitive, /// Sorts selected lines case-sensitively. SortLinesCaseSensitive, - /// Splits selection into individual lines. - SplitSelectionIntoLines, /// Stops the language server for the current file. StopLanguageServer, /// Switches between source and header files. diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index e25c02432d..c4c9f2004a 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -2290,8 +2290,6 @@ mod tests { fn test_blocks_on_wrapped_lines(cx: &mut gpui::TestAppContext) { cx.update(init_test); - let _font_id = cx.text_system().font_id(&font("Helvetica")).unwrap(); - let text = "one two three\nfour five six\nseven eight"; let buffer = cx.update(|cx| MultiBuffer::build_simple(text, cx)); diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index 269f8f0c40..caa4882a6e 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -1223,7 +1223,7 @@ mod tests { let tab_size = NonZeroU32::new(rng.gen_range(1..=4)).unwrap(); let font = test_font(); - let _font_id = text_system.font_id(&font); + let _font_id = text_system.resolve_font(&font); let font_size = px(14.0); log::info!("Tab size: {}", tab_size); diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index d1bf95c794..0111e91347 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -250,6 +250,24 @@ pub type RenderDiffHunkControlsFn = Arc< ) -> AnyElement, >; +enum ReportEditorEvent { + Saved { auto_saved: bool }, + EditorOpened, + ZetaTosClicked, + Closed, +} + +impl ReportEditorEvent { + pub fn event_type(&self) -> &'static str { + match self { + Self::Saved { .. } => "Editor Saved", + Self::EditorOpened => "Editor Opened", + Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked", + Self::Closed => "Editor Closed", + } + } +} + struct InlineValueCache { enabled: bool, inlays: Vec, @@ -1021,9 +1039,7 @@ pub struct Editor { inline_diagnostics: Vec<(Anchor, InlineDiagnostic)>, soft_wrap_mode_override: Option, hard_wrap: Option, - - // TODO: make this a access method - pub project: Option>, + project: Option>, semantics_provider: Option>, completion_provider: Option>, collaboration_hub: Option>, @@ -2308,7 +2324,7 @@ impl Editor { editor.go_to_active_debug_line(window, cx); if let Some(buffer) = buffer.read(cx).as_singleton() { - if let Some(project) = editor.project.as_ref() { + if let Some(project) = editor.project() { let handle = project.update(cx, |project, cx| { project.register_buffer_with_language_servers(&buffer, cx) }); @@ -2325,7 +2341,7 @@ impl Editor { } if editor.mode.is_full() { - editor.report_editor_event("Editor Opened", None, cx); + editor.report_editor_event(ReportEditorEvent::EditorOpened, None, cx); } editor @@ -2353,6 +2369,34 @@ impl Editor { .is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window)) } + pub fn is_range_selected(&mut self, range: &Range, cx: &mut Context) -> bool { + if self + .selections + .pending + .as_ref() + .is_some_and(|pending_selection| { + let snapshot = self.buffer().read(cx).snapshot(cx); + pending_selection + .selection + .range() + .includes(&range, &snapshot) + }) + { + return true; + } + + self.selections + .disjoint_in_range::(range.clone(), cx) + .into_iter() + .any(|selection| { + // This is needed to cover a corner case, if we just check for an existing + // selection in the fold range, having a cursor at the start of the fold + // marks it as selected. Non-empty selections don't cause this. + let length = selection.end - selection.start; + length > 0 + }) + } + pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { self.key_context_internal(self.has_active_edit_prediction(), window, cx) } @@ -2608,6 +2652,10 @@ impl Editor { &self.buffer } + pub fn project(&self) -> Option<&Entity> { + self.project.as_ref() + } + pub fn workspace(&self) -> Option> { self.workspace.as_ref()?.0.upgrade() } @@ -5194,7 +5242,7 @@ impl Editor { restrict_to_languages: Option<&HashSet>>, cx: &mut Context, ) -> HashMap, clock::Global, Range)> { - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return HashMap::default(); }; let project = project.read(cx); @@ -5276,7 +5324,7 @@ impl Editor { return None; } - let project = self.project.as_ref()?; + let project = self.project()?; let position = self.selections.newest_anchor().head(); let (buffer, buffer_position) = self .buffer @@ -6123,7 +6171,7 @@ impl Editor { cx: &mut App, ) -> Task> { maybe!({ - let project = self.project.as_ref()?; + let project = self.project()?; let dap_store = project.read(cx).dap_store(); let mut scenarios = vec![]; let resolved_tasks = resolved_tasks.as_ref()?; @@ -7889,7 +7937,7 @@ impl Editor { let snapshot = self.snapshot(window, cx); let multi_buffer_snapshot = &snapshot.display_snapshot.buffer_snapshot; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return breakpoint_display_points; }; @@ -9124,7 +9172,7 @@ impl Editor { .on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default()) .on_click(cx.listener(|this, _event, window, cx| { cx.stop_propagation(); - this.report_editor_event("Edit Prediction Provider ToS Clicked", None, cx); + this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx); window.dispatch_action( zed_actions::OpenZedPredictOnboarding.boxed_clone(), cx, @@ -10483,7 +10531,7 @@ impl Editor { ) { if let Some(working_directory) = self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let parent = match &entry.canonical_path { Some(canonical_path) => canonical_path.to_path_buf(), @@ -12158,6 +12206,8 @@ impl Editor { let clipboard_text = Cow::Borrowed(text); self.transact(window, cx, |this, window, cx| { + let had_active_edit_prediction = this.has_active_edit_prediction(); + if let Some(mut clipboard_selections) = clipboard_selections { let old_selections = this.selections.all::(cx); let all_selections_were_entire_line = @@ -12230,6 +12280,11 @@ impl Editor { } else { this.insert(&clipboard_text, window, cx); } + + let trigger_in_words = + this.show_edit_predictions_in_menu() || !had_active_edit_prediction; + + this.trigger_completion_on_input(&text, trigger_in_words, window, cx); }); } @@ -13587,7 +13642,7 @@ impl Editor { pub fn split_selection_into_lines( &mut self, - _: &SplitSelectionIntoLines, + action: &SplitSelectionIntoLines, window: &mut Window, cx: &mut Context, ) { @@ -13604,8 +13659,21 @@ impl Editor { let buffer = self.buffer.read(cx).read(cx); for selection in selections { for row in selection.start.row..selection.end.row { - let cursor = Point::new(row, buffer.line_len(MultiBufferRow(row))); - new_selection_ranges.push(cursor..cursor); + let line_start = Point::new(row, 0); + let line_end = Point::new(row, buffer.line_len(MultiBufferRow(row))); + + if action.keep_selections { + // Keep the selection range for each line + let selection_start = if row == selection.start.row { + selection.start + } else { + line_start + }; + new_selection_ranges.push(selection_start..line_end); + } else { + // Collapse to cursor at end of line + new_selection_ranges.push(line_end..line_end); + } } let is_multiline_selection = selection.start.row != selection.end.row; @@ -13613,7 +13681,16 @@ impl Editor { // so this action feels more ergonomic when paired with other selection operations let should_skip_last = is_multiline_selection && selection.end.column == 0; if !should_skip_last { - new_selection_ranges.push(selection.end..selection.end); + if action.keep_selections { + if is_multiline_selection { + let line_start = Point::new(selection.end.row, 0); + new_selection_ranges.push(line_start..selection.end); + } else { + new_selection_ranges.push(selection.start..selection.end); + } + } else { + new_selection_ranges.push(selection.end..selection.end); + } } } } @@ -14828,7 +14905,7 @@ impl Editor { self.clear_tasks(); return Task::ready(()); } - let project = self.project.as_ref().map(Entity::downgrade); + let project = self.project().map(Entity::downgrade); let task_sources = self.lsp_task_sources(cx); let multi_buffer = self.buffer.downgrade(); cx.spawn_in(window, async move |editor, cx| { @@ -15817,19 +15894,30 @@ impl Editor { let tab_kind = match kind { Some(GotoDefinitionKind::Implementation) => "Implementations", - _ => "Definitions", + Some(GotoDefinitionKind::Symbol) | None => "Definitions", + Some(GotoDefinitionKind::Declaration) => "Declarations", + Some(GotoDefinitionKind::Type) => "Types", }; let title = editor .update_in(acx, |_, _, cx| { - let origin = locations.first().unwrap(); - let buffer = origin.buffer.read(cx); - format!( - "{} for {}", - tab_kind, - buffer - .text_for_range(origin.range.clone()) - .collect::() - ) + let target = locations + .iter() + .map(|location| { + location + .buffer + .read(cx) + .text_for_range(location.range.clone()) + .collect::() + }) + .filter(|text| !text.contains('\n')) + .unique() + .take(3) + .join(", "); + if target.is_empty() { + tab_kind.to_owned() + } else { + format!("{tab_kind} for {target}") + } }) .context("buffer title")?; @@ -16025,19 +16113,24 @@ impl Editor { } workspace.update_in(cx, |workspace, window, cx| { - let title = locations - .first() - .as_ref() + let target = locations + .iter() .map(|location| { - let buffer = location.buffer.read(cx); - format!( - "References to `{}`", - buffer - .text_for_range(location.range.clone()) - .collect::() - ) + location + .buffer + .read(cx) + .text_for_range(location.range.clone()) + .collect::() }) - .unwrap(); + .filter(|text| !text.contains('\n')) + .unique() + .take(3) + .join(", "); + let title = if target.is_empty() { + "References".to_owned() + } else { + format!("References to {target}") + }; Self::open_locations_in_multibuffer( workspace, locations, @@ -17001,7 +17094,7 @@ impl Editor { if !pull_diagnostics_settings.enabled { return None; } - let project = self.project.as_ref()?.downgrade(); + let project = self.project()?.downgrade(); let debounce = Duration::from_millis(pull_diagnostics_settings.debounce_ms); let mut buffers = self.buffer.read(cx).all_buffers(); if let Some(buffer_id) = buffer_id { @@ -17965,7 +18058,7 @@ impl Editor { hunks: impl Iterator, cx: &mut App, ) -> Option<()> { - let project = self.project.as_ref()?; + let project = self.project()?; let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?; let diff = self.buffer.read(cx).diff_for(buffer_id)?; let buffer_snapshot = buffer.read(cx).snapshot(); @@ -18625,7 +18718,7 @@ impl Editor { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let buffer = buffer.read(cx); if let Some(project_path) = buffer.project_path(cx) { - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); project.absolute_path(&project_path, cx) } else { buffer @@ -18638,7 +18731,7 @@ impl Editor { fn target_file_path(&self, cx: &mut Context) -> Option { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let path = entry.path.to_path_buf(); Some(path) @@ -18859,7 +18952,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if let Some(project) = self.project.as_ref() { + if let Some(project) = self.project() { let Some(buffer) = self.buffer().read(cx).as_singleton() else { return; }; @@ -18975,7 +19068,7 @@ impl Editor { return Task::ready(Err(anyhow!("failed to determine buffer and selection"))); }; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return Task::ready(Err(anyhow!("editor does not have project"))); }; @@ -20182,6 +20275,7 @@ impl Editor { ); let old_cursor_shape = self.cursor_shape; + let old_show_breadcrumbs = self.show_breadcrumbs; { let editor_settings = EditorSettings::get_global(cx); @@ -20195,6 +20289,10 @@ impl Editor { cx.emit(EditorEvent::CursorShapeChanged); } + if old_show_breadcrumbs != self.show_breadcrumbs { + cx.emit(EditorEvent::BreadcrumbsChanged); + } + let project_settings = ProjectSettings::get_global(cx); self.serialize_dirty_buffers = !self.mode.is_minimap() && project_settings.session.restore_unsaved_buffers; @@ -20547,7 +20645,7 @@ impl Editor { fn report_editor_event( &self, - event_type: &'static str, + reported_event: ReportEditorEvent, file_extension: Option, cx: &App, ) { @@ -20581,15 +20679,30 @@ impl Editor { .show_edit_predictions; let project = project.read(cx); - telemetry::event!( - event_type, - file_extension, - vim_mode, - copilot_enabled, - copilot_enabled_for_language, - edit_predictions_provider, - is_via_ssh = project.is_via_ssh(), - ); + let event_type = reported_event.event_type(); + + if let ReportEditorEvent::Saved { auto_saved } = reported_event { + telemetry::event!( + event_type, + type = if auto_saved {"autosave"} else {"manual"}, + file_extension, + vim_mode, + copilot_enabled, + copilot_enabled_for_language, + edit_predictions_provider, + is_via_ssh = project.is_via_ssh(), + ); + } else { + telemetry::event!( + event_type, + file_extension, + vim_mode, + copilot_enabled, + copilot_enabled_for_language, + edit_predictions_provider, + is_via_ssh = project.is_via_ssh(), + ); + }; } /// Copy the highlighted chunks to the clipboard as JSON. The format is an array of lines, @@ -20942,7 +21055,7 @@ impl Editor { cx: &mut Context, ) { let workspace = self.workspace(); - let project = self.project.as_ref(); + let project = self.project(); let save_tasks = self.buffer().update(cx, |multi_buffer, cx| { let mut tasks = Vec::new(); for (buffer_id, changes) in revert_changes { @@ -22801,6 +22914,7 @@ pub enum EditorEvent { }, Reloaded, CursorShapeChanged, + BreadcrumbsChanged, PushedToNavHistory { anchor: Anchor, is_deactivate: bool, diff --git a/crates/editor/src/editor_settings.rs b/crates/editor/src/editor_settings.rs index 3d132651b8..d3a21c7642 100644 --- a/crates/editor/src/editor_settings.rs +++ b/crates/editor/src/editor_settings.rs @@ -132,6 +132,10 @@ pub struct StatusBar { /// /// Default: true pub active_language_button: bool, + /// Whether to show the cursor position button in the status bar. + /// + /// Default: true + pub cursor_position_button: bool, } #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -585,6 +589,10 @@ pub struct StatusBarContent { /// /// Default: true pub active_language_button: Option, + /// Whether to show the cursor position button in the status bar. + /// + /// Default: true + pub cursor_position_button: Option, } // Toolbar related settings diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index b31963c9c8..ef2bdc5da3 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -74,7 +74,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let editor1 = cx.add_window({ let events = events.clone(); |window, cx| { - let entity = cx.entity().clone(); + let entity = cx.entity(); cx.subscribe_in( &entity, window, @@ -95,7 +95,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let events = events.clone(); |window, cx| { cx.subscribe_in( - &cx.entity().clone(), + &cx.entity(), window, move |_, _, event: &EditorEvent, _, _| match event { EditorEvent::Edited { .. } => events.borrow_mut().push(("editor2", "edited")), @@ -1901,6 +1901,51 @@ fn test_beginning_of_line_stop_at_indent(cx: &mut TestAppContext) { }); } +#[gpui::test] +fn test_beginning_of_line_with_cursor_between_line_start_and_indent(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let move_to_beg = MoveToBeginningOfLine { + stop_at_soft_wraps: true, + stop_at_indent: true, + }; + + let editor = cx.add_window(|window, cx| { + let buffer = MultiBuffer::build_simple(" hello\nworld", cx); + build_editor(buffer, window, cx) + }); + + _ = editor.update(cx, |editor, window, cx| { + // test cursor between line_start and indent_start + editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { + s.select_display_ranges([ + DisplayPoint::new(DisplayRow(0), 3)..DisplayPoint::new(DisplayRow(0), 3) + ]); + }); + + // cursor should move to line_start + editor.move_to_beginning_of_line(&move_to_beg, window, cx); + assert_eq!( + editor.selections.display_ranges(cx), + &[DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0)] + ); + + // cursor should move to indent_start + editor.move_to_beginning_of_line(&move_to_beg, window, cx); + assert_eq!( + editor.selections.display_ranges(cx), + &[DisplayPoint::new(DisplayRow(0), 4)..DisplayPoint::new(DisplayRow(0), 4)] + ); + + // cursor should move to back to line_start + editor.move_to_beginning_of_line(&move_to_beg, window, cx); + assert_eq!( + editor.selections.display_ranges(cx), + &[DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0)] + ); + }); +} + #[gpui::test] fn test_prev_next_word_boundary(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -6356,7 +6401,7 @@ async fn test_split_selection_into_lines(cx: &mut TestAppContext) { fn test(cx: &mut EditorTestContext, initial_state: &'static str, expected_state: &'static str) { cx.set_state(initial_state); cx.update_editor(|e, window, cx| { - e.split_selection_into_lines(&SplitSelectionIntoLines, window, cx) + e.split_selection_into_lines(&Default::default(), window, cx) }); cx.assert_editor_state(expected_state); } @@ -6444,7 +6489,7 @@ async fn test_split_selection_into_lines_interacting_with_creases(cx: &mut TestA DisplayPoint::new(DisplayRow(4), 4)..DisplayPoint::new(DisplayRow(4), 4), ]) }); - editor.split_selection_into_lines(&SplitSelectionIntoLines, window, cx); + editor.split_selection_into_lines(&Default::default(), window, cx); assert_eq!( editor.display_text(cx), "aaaaa\nbbbbb\nccc⋯eeee\nfffff\nggggg\n⋯i" @@ -6460,7 +6505,7 @@ async fn test_split_selection_into_lines_interacting_with_creases(cx: &mut TestA DisplayPoint::new(DisplayRow(5), 0)..DisplayPoint::new(DisplayRow(0), 1) ]) }); - editor.split_selection_into_lines(&SplitSelectionIntoLines, window, cx); + editor.split_selection_into_lines(&Default::default(), window, cx); assert_eq!( editor.display_text(cx), "aaaaa\nbbbbb\nccccc\nddddd\neeeee\nfffff\nggggg\nhhhhh\niiiii" @@ -15037,7 +15082,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -19589,13 +19634,8 @@ fn test_crease_insertion_and_rendering(cx: &mut TestAppContext) { editor.insert_creases(Some(crease), cx); let snapshot = editor.snapshot(window, cx); - let _div = snapshot.render_crease_toggle( - MultiBufferRow(1), - false, - cx.entity().clone(), - window, - cx, - ); + let _div = + snapshot.render_crease_toggle(MultiBufferRow(1), false, cx.entity(), window, cx); snapshot }) .unwrap(); @@ -22456,7 +22496,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) { ); cx.update(|_, cx| { - workspace::reload(&workspace::Reload::default(), cx); + workspace::reload(cx); }); assert_language_servers_count( 1, diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index a7fd0abf88..5edfd7df30 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -3011,7 +3011,7 @@ impl EditorElement { .icon_color(Color::Custom(cx.theme().colors().editor_line_number)) .selected_icon_color(Color::Custom(cx.theme().colors().editor_foreground)) .icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size()))) - .width(width.into()) + .width(width) .on_click(move |_, window, cx| { editor.update(cx, |editor, cx| { editor.expand_excerpt(excerpt_id, direction, window, cx); @@ -3627,7 +3627,7 @@ impl EditorElement { ButtonLike::new("toggle-buffer-fold") .style(ui::ButtonStyle::Transparent) .height(px(28.).into()) - .width(px(28.).into()) + .width(px(28.)) .children(toggle_chevron_icon) .tooltip({ let focus_handle = focus_handle.clone(); @@ -7815,7 +7815,7 @@ impl Element for EditorElement { min_lines, max_lines, } => { - let editor_handle = cx.entity().clone(); + let editor_handle = cx.entity(); let max_line_number_width = self.max_line_number_width(&editor.snapshot(window, cx), window); window.request_measured_layout( diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index bda229e346..3fc673bad9 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -251,7 +251,7 @@ fn show_hover( let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?; - let language_registry = editor.project.as_ref()?.read(cx).languages().clone(); + let language_registry = editor.project()?.read(cx).languages().clone(); let provider = editor.semantics_provider.clone()?; if !ignore_timeout { diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 231aaa1d00..34533002ff 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -1,7 +1,7 @@ use crate::{ Anchor, Autoscroll, Editor, EditorEvent, EditorSettings, ExcerptId, ExcerptRange, FormatTarget, - MultiBuffer, MultiBufferSnapshot, NavigationData, SearchWithinRange, SelectionEffects, - ToPoint as _, + MultiBuffer, MultiBufferSnapshot, NavigationData, ReportEditorEvent, SearchWithinRange, + SelectionEffects, ToPoint as _, display_map::HighlightKey, editor_settings::SeedQuerySetting, persistence::{DB, SerializedEditor}, @@ -654,6 +654,10 @@ impl Item for Editor { } } + fn suggested_filename(&self, cx: &App) -> SharedString { + self.buffer.read(cx).title(cx).to_string().into() + } + fn tab_icon(&self, _: &Window, cx: &App) -> Option { ItemSettings::get_global(cx) .file_icons @@ -674,7 +678,7 @@ impl Item for Editor { let buffer = buffer.read(cx); let path = buffer.project_path(cx)?; let buffer_id = buffer.remote_id(); - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&path, cx)?; let (repo, repo_path) = project .git_store() @@ -776,6 +780,10 @@ impl Item for Editor { } } + fn on_removed(&self, cx: &App) { + self.report_editor_event(ReportEditorEvent::Closed, None, cx); + } + fn deactivated(&mut self, _: &mut Window, cx: &mut Context) { let selection = self.selections.newest_anchor(); self.push_to_nav_history(selection.head(), None, true, false, cx); @@ -815,9 +823,9 @@ impl Item for Editor { ) -> Task> { // Add meta data tracking # of auto saves if options.autosave { - self.report_editor_event("Editor Autosaved", None, cx); + self.report_editor_event(ReportEditorEvent::Saved { auto_saved: true }, None, cx); } else { - self.report_editor_event("Editor Saved", None, cx); + self.report_editor_event(ReportEditorEvent::Saved { auto_saved: false }, None, cx); } let buffers = self.buffer().clone().read(cx).all_buffers(); @@ -896,7 +904,11 @@ impl Item for Editor { .path .extension() .map(|a| a.to_string_lossy().to_string()); - self.report_editor_event("Editor Saved", file_extension, cx); + self.report_editor_event( + ReportEditorEvent::Saved { auto_saved: false }, + file_extension, + cx, + ); project.update(cx, |project, cx| project.save_buffer_as(buffer, path, cx)) } @@ -997,12 +1009,16 @@ impl Item for Editor { ) { self.workspace = Some((workspace.weak_handle(), workspace.database_id())); if let Some(workspace) = &workspace.weak_handle().upgrade() { - cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| { - if matches!(event, workspace::Event::ModalOpened) { - editor.mouse_context_menu.take(); - editor.inline_blame_popover.take(); - } - }) + cx.subscribe( + &workspace, + |editor, _, event: &workspace::Event, _cx| match event { + workspace::Event::ModalOpened => { + editor.mouse_context_menu.take(); + editor.inline_blame_popover.take(); + } + _ => {} + }, + ) .detach(); } } @@ -1024,6 +1040,10 @@ impl Item for Editor { f(ItemEvent::UpdateBreadcrumbs); } + EditorEvent::BreadcrumbsChanged => { + f(ItemEvent::UpdateBreadcrumbs); + } + EditorEvent::DirtyChanged => { f(ItemEvent::UpdateTab); } diff --git a/crates/editor/src/linked_editing_ranges.rs b/crates/editor/src/linked_editing_ranges.rs index a185de33ca..aaf9032b04 100644 --- a/crates/editor/src/linked_editing_ranges.rs +++ b/crates/editor/src/linked_editing_ranges.rs @@ -51,7 +51,7 @@ pub(super) fn refresh_linked_ranges( if editor.pending_rename.is_some() { return None; } - let project = editor.project.as_ref()?.downgrade(); + let project = editor.project()?.downgrade(); editor.linked_editing_range_task = Some(cx.spawn_in(window, async move |editor, cx| { cx.background_executor().timer(UPDATE_DEBOUNCE).await; diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index a8850984a1..fdda0e82bc 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -230,7 +230,7 @@ pub fn indented_line_beginning( if stop_at_soft_boundaries && soft_line_start > indent_start && display_point != soft_line_start { soft_line_start - } else if stop_at_indent && display_point != indent_start { + } else if stop_at_indent && (display_point > indent_start || display_point == line_start) { indent_start } else { line_start diff --git a/crates/editor/src/signature_help.rs b/crates/editor/src/signature_help.rs index e9f8d2dbd3..e0736a6e9f 100644 --- a/crates/editor/src/signature_help.rs +++ b/crates/editor/src/signature_help.rs @@ -169,7 +169,7 @@ impl Editor { else { return; }; - let Some(lsp_store) = self.project.as_ref().map(|p| p.read(cx).lsp_store()) else { + let Some(lsp_store) = self.project().map(|p| p.read(cx).lsp_store()) else { return; }; let task = lsp_store.update(cx, |lsp_store, cx| { diff --git a/crates/editor/src/test.rs b/crates/editor/src/test.rs index 0a9d5e9535..f328945dbe 100644 --- a/crates/editor/src/test.rs +++ b/crates/editor/src/test.rs @@ -53,7 +53,7 @@ pub fn marked_display_snapshot( let (unmarked_text, markers) = marked_text_offsets(text); let font = Font { - family: "Zed Plex Mono".into(), + family: ".ZedMono".into(), features: FontFeatures::default(), fallbacks: None, weight: FontWeight::default(), diff --git a/crates/editor/src/test/editor_test_context.rs b/crates/editor/src/test/editor_test_context.rs index bdf73da5fb..dbb519c40e 100644 --- a/crates/editor/src/test/editor_test_context.rs +++ b/crates/editor/src/test/editor_test_context.rs @@ -297,9 +297,8 @@ impl EditorTestContext { pub fn set_head_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_head_for_repo( &Self::root_path().join(".git"), @@ -311,18 +310,16 @@ impl EditorTestContext { pub fn clear_index_text(&mut self) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); fs.set_index_for_repo(&Self::root_path().join(".git"), &[]); self.cx.run_until_parked(); } pub fn set_index_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_index_for_repo( &Self::root_path().join(".git"), @@ -333,9 +330,8 @@ impl EditorTestContext { #[track_caller] pub fn assert_index_text(&mut self, expected: Option<&str>) { - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); let mut found = None; fs.with_git_state(&Self::root_path().join(".git"), false, |git_state| { diff --git a/crates/extension/src/extension_host_proxy.rs b/crates/extension/src/extension_host_proxy.rs index 917739759f..6a24e3ba3f 100644 --- a/crates/extension/src/extension_host_proxy.rs +++ b/crates/extension/src/extension_host_proxy.rs @@ -28,7 +28,6 @@ pub struct ExtensionHostProxy { snippet_proxy: RwLock>>, slash_command_proxy: RwLock>>, context_server_proxy: RwLock>>, - indexed_docs_provider_proxy: RwLock>>, debug_adapter_provider_proxy: RwLock>>, } @@ -54,7 +53,6 @@ impl ExtensionHostProxy { snippet_proxy: RwLock::default(), slash_command_proxy: RwLock::default(), context_server_proxy: RwLock::default(), - indexed_docs_provider_proxy: RwLock::default(), debug_adapter_provider_proxy: RwLock::default(), } } @@ -87,14 +85,6 @@ impl ExtensionHostProxy { self.context_server_proxy.write().replace(Arc::new(proxy)); } - pub fn register_indexed_docs_provider_proxy( - &self, - proxy: impl ExtensionIndexedDocsProviderProxy, - ) { - self.indexed_docs_provider_proxy - .write() - .replace(Arc::new(proxy)); - } pub fn register_debug_adapter_proxy(&self, proxy: impl ExtensionDebugAdapterProviderProxy) { self.debug_adapter_provider_proxy .write() @@ -408,30 +398,6 @@ impl ExtensionContextServerProxy for ExtensionHostProxy { } } -pub trait ExtensionIndexedDocsProviderProxy: Send + Sync + 'static { - fn register_indexed_docs_provider(&self, extension: Arc, provider_id: Arc); - - fn unregister_indexed_docs_provider(&self, provider_id: Arc); -} - -impl ExtensionIndexedDocsProviderProxy for ExtensionHostProxy { - fn register_indexed_docs_provider(&self, extension: Arc, provider_id: Arc) { - let Some(proxy) = self.indexed_docs_provider_proxy.read().clone() else { - return; - }; - - proxy.register_indexed_docs_provider(extension, provider_id) - } - - fn unregister_indexed_docs_provider(&self, provider_id: Arc) { - let Some(proxy) = self.indexed_docs_provider_proxy.read().clone() else { - return; - }; - - proxy.unregister_indexed_docs_provider(provider_id) - } -} - pub trait ExtensionDebugAdapterProviderProxy: Send + Sync + 'static { fn register_debug_adapter( &self, diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 5852b3e3fc..f5296198b0 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -84,8 +84,6 @@ pub struct ExtensionManifest { #[serde(default)] pub slash_commands: BTreeMap, SlashCommandManifestEntry>, #[serde(default)] - pub indexed_docs_providers: BTreeMap, IndexedDocsProviderEntry>, - #[serde(default)] pub snippets: Option, #[serde(default)] pub capabilities: Vec, @@ -195,9 +193,6 @@ pub struct SlashCommandManifestEntry { pub requires_argument: bool, } -#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] -pub struct IndexedDocsProviderEntry {} - #[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] pub struct DebugAdapterManifestEntry { pub schema_path: Option, @@ -271,7 +266,6 @@ fn manifest_from_old_manifest( language_servers: Default::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), @@ -304,7 +298,6 @@ mod tests { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: vec![], debug_adapters: Default::default(), diff --git a/crates/extension_cli/src/main.rs b/crates/extension_cli/src/main.rs index ab4a9cddb0..d6c0501efd 100644 --- a/crates/extension_cli/src/main.rs +++ b/crates/extension_cli/src/main.rs @@ -144,10 +144,6 @@ fn extension_provides(manifest: &ExtensionManifest) -> BTreeSet ExtensionManifest { .collect(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: vec![ExtensionCapability::ProcessExec( extension::ProcessExecCapability { diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs index c77e5ecba1..5a2093c1dd 100644 --- a/crates/extension_host/src/capability_granter.rs +++ b/crates/extension_host/src/capability_granter.rs @@ -108,7 +108,6 @@ mod tests { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: vec![], debug_adapters: Default::default(), diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index dc38c244f1..46deacfe69 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -16,9 +16,9 @@ pub use extension::ExtensionManifest; use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder}; use extension::{ ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents, - ExtensionGrammarProxy, ExtensionHostProxy, ExtensionIndexedDocsProviderProxy, - ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, - ExtensionSnippetProxy, ExtensionThemeProxy, + ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy, + ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy, + ExtensionThemeProxy, }; use fs::{Fs, RemoveOptions}; use futures::future::join_all; @@ -1118,15 +1118,17 @@ impl ExtensionStore { extensions_to_unload.len() - reload_count ); - for extension_id in &extensions_to_load { - if let Some(extension) = new_index.extensions.get(extension_id) { - telemetry::event!( - "Extension Loaded", - extension_id, - version = extension.manifest.version - ); - } - } + let extension_ids = extensions_to_load + .iter() + .filter_map(|id| { + Some(( + id.clone(), + new_index.extensions.get(id)?.manifest.version.clone(), + )) + }) + .collect::>(); + + telemetry::event!("Extensions Loaded", id_and_versions = extension_ids); let themes_to_remove = old_index .themes @@ -1190,10 +1192,6 @@ impl ExtensionStore { for (command_name, _) in &extension.manifest.slash_commands { self.proxy.unregister_slash_command(command_name.clone()); } - for (provider_id, _) in &extension.manifest.indexed_docs_providers { - self.proxy - .unregister_indexed_docs_provider(provider_id.clone()); - } } self.wasm_extensions @@ -1397,11 +1395,6 @@ impl ExtensionStore { .register_context_server(extension.clone(), id.clone(), cx); } - for (provider_id, _provider) in &manifest.indexed_docs_providers { - this.proxy - .register_indexed_docs_provider(extension.clone(), provider_id.clone()); - } - for (debug_adapter, meta) in &manifest.debug_adapters { let mut path = root_dir.clone(); path.push(Path::new(manifest.id.as_ref())); diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index c31774c20d..347a610439 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -160,7 +160,6 @@ async fn test_extension_store(cx: &mut TestAppContext) { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), @@ -191,7 +190,6 @@ async fn test_extension_store(cx: &mut TestAppContext) { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), @@ -371,7 +369,6 @@ async fn test_extension_store(cx: &mut TestAppContext) { language_servers: BTreeMap::default(), context_servers: BTreeMap::default(), slash_commands: BTreeMap::default(), - indexed_docs_providers: BTreeMap::default(), snippets: None, capabilities: Vec::new(), debug_adapters: Default::default(), diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index fe3e94f5c2..4915933920 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -703,7 +703,7 @@ impl ExtensionsPage { extension: &ExtensionMetadata, cx: &mut Context, ) -> ExtensionCard { - let this = cx.entity().clone(); + let this = cx.entity(); let status = Self::extension_status(&extension.id, cx); let has_dev_extension = Self::dev_extension_exists(&extension.id, cx); diff --git a/crates/file_icons/src/file_icons.rs b/crates/file_icons/src/file_icons.rs index 2eebb40385..d6fc248f6a 100644 --- a/crates/file_icons/src/file_icons.rs +++ b/crates/file_icons/src/file_icons.rs @@ -33,13 +33,23 @@ impl FileIcons { // TODO: Associate a type with the languages and have the file's language // override these associations - // check if file name is in suffixes - // e.g. catch file named `eslint.config.js` instead of `.eslint.config.js` - if let Some(typ) = path.file_name().and_then(|typ| typ.to_str()) { + if let Some(mut typ) = path.file_name().and_then(|typ| typ.to_str()) { + // check if file name is in suffixes + // e.g. catch file named `eslint.config.js` instead of `.eslint.config.js` let maybe_path = get_icon_from_suffix(typ); if maybe_path.is_some() { return maybe_path; } + + // check if suffix based on first dot is in suffixes + // e.g. consider `module.js` as suffix to angular's module file named `auth.module.js` + while let Some((_, suffix)) = typ.split_once('.') { + let maybe_path = get_icon_from_suffix(suffix); + if maybe_path.is_some() { + return maybe_path; + } + typ = suffix; + } } // handle cases where the file extension is made up of multiple important diff --git a/crates/fs/Cargo.toml b/crates/fs/Cargo.toml index 633fc1fc99..1d4161134e 100644 --- a/crates/fs/Cargo.toml +++ b/crates/fs/Cargo.toml @@ -51,6 +51,7 @@ ashpd.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } +git = { workspace = true, features = ["test-support"] } [features] test-support = ["gpui/test-support", "git/test-support"] diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 04ba656232..f0936d400a 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -1,8 +1,9 @@ -use crate::{FakeFs, Fs}; +use crate::{FakeFs, FakeFsEntry, Fs}; use anyhow::{Context as _, Result}; use collections::{HashMap, HashSet}; use futures::future::{self, BoxFuture, join_all}; use git::{ + Oid, blame::Blame, repository::{ AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository, @@ -10,8 +11,9 @@ use git::{ }, status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus}, }; -use gpui::{AsyncApp, BackgroundExecutor, SharedString}; +use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task}; use ignore::gitignore::GitignoreBuilder; +use parking_lot::Mutex; use rope::Rope; use smol::future::FutureExt as _; use std::{path::PathBuf, sync::Arc}; @@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc}; #[derive(Clone)] pub struct FakeGitRepository { pub(crate) fs: Arc, + pub(crate) checkpoints: Arc>>, pub(crate) executor: BackgroundExecutor, pub(crate) dot_git_path: PathBuf, pub(crate) repository_dir_path: PathBuf, @@ -183,7 +186,7 @@ impl GitRepository for FakeGitRepository { async move { None }.boxed() } - fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result> { + fn status(&self, path_prefixes: &[RepoPath]) -> Task> { let workdir_path = self.dot_git_path.parent().unwrap(); // Load gitignores @@ -311,7 +314,10 @@ impl GitRepository for FakeGitRepository { entries: entries.into(), }) }); - async move { result? }.boxed() + Task::ready(match result { + Ok(result) => result, + Err(e) => Err(e), + }) } fn branches(&self) -> BoxFuture<'_, Result>> { @@ -402,11 +408,11 @@ impl GitRepository for FakeGitRepository { &self, _paths: Vec, _env: Arc>, - ) -> BoxFuture> { + ) -> BoxFuture<'_, Result<()>> { unimplemented!() } - fn stash_pop(&self, _env: Arc>) -> BoxFuture> { + fn stash_pop(&self, _env: Arc>) -> BoxFuture<'_, Result<()>> { unimplemented!() } @@ -466,22 +472,57 @@ impl GitRepository for FakeGitRepository { } fn checkpoint(&self) -> BoxFuture<'static, Result> { - unimplemented!() + let executor = self.executor.clone(); + let fs = self.fs.clone(); + let checkpoints = self.checkpoints.clone(); + let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf(); + async move { + executor.simulate_random_delay().await; + let oid = Oid::random(&mut executor.rng()); + let entry = fs.entry(&repository_dir_path)?; + checkpoints.lock().insert(oid, entry); + Ok(GitRepositoryCheckpoint { commit_sha: oid }) + } + .boxed() } - fn restore_checkpoint( - &self, - _checkpoint: GitRepositoryCheckpoint, - ) -> BoxFuture<'_, Result<()>> { - unimplemented!() + fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> { + let executor = self.executor.clone(); + let fs = self.fs.clone(); + let checkpoints = self.checkpoints.clone(); + let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf(); + async move { + executor.simulate_random_delay().await; + let checkpoints = checkpoints.lock(); + let entry = checkpoints + .get(&checkpoint.commit_sha) + .context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?; + fs.insert_entry(&repository_dir_path, entry.clone())?; + Ok(()) + } + .boxed() } fn compare_checkpoints( &self, - _left: GitRepositoryCheckpoint, - _right: GitRepositoryCheckpoint, + left: GitRepositoryCheckpoint, + right: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result> { - unimplemented!() + let executor = self.executor.clone(); + let checkpoints = self.checkpoints.clone(); + async move { + executor.simulate_random_delay().await; + let checkpoints = checkpoints.lock(); + let left = checkpoints + .get(&left.commit_sha) + .context(format!("invalid left checkpoint: {}", left.commit_sha))?; + let right = checkpoints + .get(&right.commit_sha) + .context(format!("invalid right checkpoint: {}", right.commit_sha))?; + + Ok(left == right) + } + .boxed() } fn diff_checkpoints( @@ -496,3 +537,63 @@ impl GitRepository for FakeGitRepository { unimplemented!() } } + +#[cfg(test)] +mod tests { + use crate::{FakeFs, Fs}; + use gpui::BackgroundExecutor; + use serde_json::json; + use std::path::Path; + use util::path; + + #[gpui::test] + async fn test_checkpoints(executor: BackgroundExecutor) { + let fs = FakeFs::new(executor); + fs.insert_tree( + path!("/"), + json!({ + "bar": { + "baz": "qux" + }, + "foo": { + ".git": {}, + "a": "lorem", + "b": "ipsum", + }, + }), + ) + .await; + fs.with_git_state(Path::new("/foo/.git"), true, |_git| {}) + .unwrap(); + let repository = fs.open_repo(Path::new("/foo/.git")).unwrap(); + + let checkpoint_1 = repository.checkpoint().await.unwrap(); + fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap(); + fs.write(Path::new("/foo/c"), b"dolor").await.unwrap(); + let checkpoint_2 = repository.checkpoint().await.unwrap(); + let checkpoint_3 = repository.checkpoint().await.unwrap(); + + assert!( + repository + .compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone()) + .await + .unwrap() + ); + assert!( + !repository + .compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone()) + .await + .unwrap() + ); + + repository.restore_checkpoint(checkpoint_1).await.unwrap(); + assert_eq!( + fs.files_with_contents(Path::new("")), + [ + (Path::new("/bar/baz").into(), b"qux".into()), + (Path::new("/foo/a").into(), b"lorem".into()), + (Path::new("/foo/b").into(), b"ipsum".into()) + ] + ); + } +} diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index a76ccee2bf..22bfdbcd66 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -12,7 +12,7 @@ use gpui::BackgroundExecutor; use gpui::Global; use gpui::ReadGlobal as _; use std::borrow::Cow; -use util::command::new_std_command; +use util::command::{new_smol_command, new_std_command}; #[cfg(unix)] use std::os::fd::{AsFd, AsRawFd}; @@ -134,6 +134,7 @@ pub trait Fs: Send + Sync { fn home_dir(&self) -> Option; fn open_repo(&self, abs_dot_git: &Path) -> Option>; fn git_init(&self, abs_work_directory: &Path, fallback_branch_name: String) -> Result<()>; + async fn git_clone(&self, repo_url: &str, abs_work_directory: &Path) -> Result<()>; fn is_fake(&self) -> bool; async fn is_case_sensitive(&self) -> Result; @@ -839,6 +840,23 @@ impl Fs for RealFs { Ok(()) } + async fn git_clone(&self, repo_url: &str, abs_work_directory: &Path) -> Result<()> { + let output = new_smol_command("git") + .current_dir(abs_work_directory) + .args(&["clone", repo_url]) + .output() + .await?; + + if !output.status.success() { + anyhow::bail!( + "git clone failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + } + + Ok(()) + } + fn is_fake(&self) -> bool { false } @@ -906,7 +924,7 @@ pub struct FakeFs { #[cfg(any(test, feature = "test-support"))] struct FakeFsState { - root: Arc>, + root: FakeFsEntry, next_inode: u64, next_mtime: SystemTime, git_event_tx: smol::channel::Sender, @@ -921,7 +939,7 @@ struct FakeFsState { } #[cfg(any(test, feature = "test-support"))] -#[derive(Debug)] +#[derive(Clone, Debug)] enum FakeFsEntry { File { inode: u64, @@ -935,7 +953,7 @@ enum FakeFsEntry { inode: u64, mtime: MTime, len: u64, - entries: BTreeMap>>, + entries: BTreeMap, git_repo_state: Option>>, }, Symlink { @@ -943,6 +961,67 @@ enum FakeFsEntry { }, } +#[cfg(any(test, feature = "test-support"))] +impl PartialEq for FakeFsEntry { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::File { + inode: l_inode, + mtime: l_mtime, + len: l_len, + content: l_content, + git_dir_path: l_git_dir_path, + }, + Self::File { + inode: r_inode, + mtime: r_mtime, + len: r_len, + content: r_content, + git_dir_path: r_git_dir_path, + }, + ) => { + l_inode == r_inode + && l_mtime == r_mtime + && l_len == r_len + && l_content == r_content + && l_git_dir_path == r_git_dir_path + } + ( + Self::Dir { + inode: l_inode, + mtime: l_mtime, + len: l_len, + entries: l_entries, + git_repo_state: l_git_repo_state, + }, + Self::Dir { + inode: r_inode, + mtime: r_mtime, + len: r_len, + entries: r_entries, + git_repo_state: r_git_repo_state, + }, + ) => { + let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) { + (Some(l), Some(r)) => Arc::ptr_eq(l, r), + (None, None) => true, + _ => false, + }; + l_inode == r_inode + && l_mtime == r_mtime + && l_len == r_len + && l_entries == r_entries + && same_repo_state + } + (Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => { + l_target == r_target + } + _ => false, + } + } +} + #[cfg(any(test, feature = "test-support"))] impl FakeFsState { fn get_and_increment_mtime(&mut self) -> MTime { @@ -957,25 +1036,9 @@ impl FakeFsState { inode } - fn read_path(&self, target: &Path) -> Result>> { - Ok(self - .try_read_path(target, true) - .ok_or_else(|| { - anyhow!(io::Error::new( - io::ErrorKind::NotFound, - format!("not found: {target:?}") - )) - })? - .0) - } - - fn try_read_path( - &self, - target: &Path, - follow_symlink: bool, - ) -> Option<(Arc>, PathBuf)> { - let mut path = target.to_path_buf(); + fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option { let mut canonical_path = PathBuf::new(); + let mut path = target.to_path_buf(); let mut entry_stack = Vec::new(); 'outer: loop { let mut path_components = path.components().peekable(); @@ -985,7 +1048,7 @@ impl FakeFsState { Component::Prefix(prefix_component) => prefix = Some(prefix_component), Component::RootDir => { entry_stack.clear(); - entry_stack.push(self.root.clone()); + entry_stack.push(&self.root); canonical_path.clear(); match prefix { Some(prefix_component) => { @@ -1002,20 +1065,18 @@ impl FakeFsState { canonical_path.pop(); } Component::Normal(name) => { - let current_entry = entry_stack.last().cloned()?; - let current_entry = current_entry.lock(); - if let FakeFsEntry::Dir { entries, .. } = &*current_entry { - let entry = entries.get(name.to_str().unwrap()).cloned()?; + let current_entry = *entry_stack.last()?; + if let FakeFsEntry::Dir { entries, .. } = current_entry { + let entry = entries.get(name.to_str().unwrap())?; if path_components.peek().is_some() || follow_symlink { - let entry = entry.lock(); - if let FakeFsEntry::Symlink { target, .. } = &*entry { + if let FakeFsEntry::Symlink { target, .. } = entry { let mut target = target.clone(); target.extend(path_components); path = target; continue 'outer; } } - entry_stack.push(entry.clone()); + entry_stack.push(entry); canonical_path = canonical_path.join(name); } else { return None; @@ -1025,19 +1086,72 @@ impl FakeFsState { } break; } - Some((entry_stack.pop()?, canonical_path)) + + if entry_stack.is_empty() { + None + } else { + Some(canonical_path) + } } - fn write_path(&self, path: &Path, callback: Fn) -> Result + fn try_entry( + &mut self, + target: &Path, + follow_symlink: bool, + ) -> Option<(&mut FakeFsEntry, PathBuf)> { + let canonical_path = self.canonicalize(target, follow_symlink)?; + + let mut components = canonical_path.components(); + let Some(Component::RootDir) = components.next() else { + panic!( + "the path {:?} was not canonicalized properly {:?}", + target, canonical_path + ) + }; + + let mut entry = &mut self.root; + for component in components { + match component { + Component::Normal(name) => { + if let FakeFsEntry::Dir { entries, .. } = entry { + entry = entries.get_mut(name.to_str().unwrap())?; + } else { + return None; + } + } + _ => { + panic!( + "the path {:?} was not canonicalized properly {:?}", + target, canonical_path + ) + } + } + } + + Some((entry, canonical_path)) + } + + fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> { + Ok(self + .try_entry(target, true) + .ok_or_else(|| { + anyhow!(io::Error::new( + io::ErrorKind::NotFound, + format!("not found: {target:?}") + )) + })? + .0) + } + + fn write_path(&mut self, path: &Path, callback: Fn) -> Result where - Fn: FnOnce(btree_map::Entry>>) -> Result, + Fn: FnOnce(btree_map::Entry) -> Result, { let path = normalize_path(path); let filename = path.file_name().context("cannot overwrite the root")?; let parent_path = path.parent().unwrap(); - let parent = self.read_path(parent_path)?; - let mut parent = parent.lock(); + let parent = self.entry(parent_path)?; let new_entry = parent .dir_entries(parent_path)? .entry(filename.to_str().unwrap().into()); @@ -1087,13 +1201,13 @@ impl FakeFs { this: this.clone(), executor: executor.clone(), state: Arc::new(Mutex::new(FakeFsState { - root: Arc::new(Mutex::new(FakeFsEntry::Dir { + root: FakeFsEntry::Dir { inode: 0, mtime: MTime(UNIX_EPOCH), len: 0, entries: Default::default(), git_repo_state: None, - })), + }, git_event_tx: tx, next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL, next_inode: 1, @@ -1143,15 +1257,15 @@ impl FakeFs { .write_path(path, move |entry| { match entry { btree_map::Entry::Vacant(e) => { - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode: new_inode, mtime: new_mtime, content: Vec::new(), len: 0, git_dir_path: None, - }))); + }); } - btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() { + btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() { FakeFsEntry::File { mtime, .. } => *mtime = new_mtime, FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime, FakeFsEntry::Symlink { .. } => {} @@ -1170,7 +1284,7 @@ impl FakeFs { pub async fn insert_symlink(&self, path: impl AsRef, target: PathBuf) { let mut state = self.state.lock(); let path = path.as_ref(); - let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + let file = FakeFsEntry::Symlink { target }; state .write_path(path.as_ref(), move |e| match e { btree_map::Entry::Vacant(e) => { @@ -1203,13 +1317,13 @@ impl FakeFs { match entry { btree_map::Entry::Vacant(e) => { kind = Some(PathEventKind::Created); - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode: new_inode, mtime: new_mtime, len: new_len, content: new_content, git_dir_path: None, - }))); + }); } btree_map::Entry::Occupied(mut e) => { kind = Some(PathEventKind::Changed); @@ -1219,7 +1333,7 @@ impl FakeFs { len, content, .. - } = &mut *e.get_mut().lock() + } = e.get_mut() { *mtime = new_mtime; *content = new_content; @@ -1241,9 +1355,8 @@ impl FakeFs { pub fn read_file_sync(&self, path: impl AsRef) -> Result> { let path = path.as_ref(); let path = normalize_path(path); - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); + let mut state = self.state.lock(); + let entry = state.entry(&path)?; entry.file_content(&path).cloned() } @@ -1251,9 +1364,8 @@ impl FakeFs { let path = path.as_ref(); let path = normalize_path(path); self.simulate_random_delay().await; - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); + let mut state = self.state.lock(); + let entry = state.entry(&path)?; entry.file_content(&path).cloned() } @@ -1274,6 +1386,25 @@ impl FakeFs { self.state.lock().flush_events(count); } + pub(crate) fn entry(&self, target: &Path) -> Result { + self.state.lock().entry(target).cloned() + } + + pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> { + let mut state = self.state.lock(); + state.write_path(target, |entry| { + match entry { + btree_map::Entry::Vacant(vacant_entry) => { + vacant_entry.insert(new_entry); + } + btree_map::Entry::Occupied(mut occupied_entry) => { + occupied_entry.insert(new_entry); + } + } + Ok(()) + }) + } + #[must_use] pub fn insert_tree<'a>( &'a self, @@ -1343,20 +1474,19 @@ impl FakeFs { F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T, { let mut state = self.state.lock(); - let entry = state.read_path(dot_git).context("open .git")?; - let mut entry = entry.lock(); + let git_event_tx = state.git_event_tx.clone(); + let entry = state.entry(dot_git).context("open .git")?; - if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry { + if let FakeFsEntry::Dir { git_repo_state, .. } = entry { let repo_state = git_repo_state.get_or_insert_with(|| { log::debug!("insert git state for {dot_git:?}"); - Arc::new(Mutex::new(FakeGitRepositoryState::new( - state.git_event_tx.clone(), - ))) + Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx))) }); let mut repo_state = repo_state.lock(); let result = f(&mut repo_state, dot_git, dot_git); + drop(repo_state); if emit_git_event { state.emit_event([(dot_git, None)]); } @@ -1380,21 +1510,20 @@ impl FakeFs { } } .clone(); - drop(entry); - let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else { + let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else { anyhow::bail!("pointed-to git dir {path:?} not found") }; let FakeFsEntry::Dir { git_repo_state, entries, .. - } = &mut *git_dir_entry.lock() + } = git_dir_entry else { anyhow::bail!("gitfile points to a non-directory") }; let common_dir = if let Some(child) = entries.get("commondir") { Path::new( - std::str::from_utf8(child.lock().file_content("commondir".as_ref())?) + std::str::from_utf8(child.file_content("commondir".as_ref())?) .context("commondir content")?, ) .to_owned() @@ -1402,15 +1531,14 @@ impl FakeFs { canonical_path.clone() }; let repo_state = git_repo_state.get_or_insert_with(|| { - Arc::new(Mutex::new(FakeGitRepositoryState::new( - state.git_event_tx.clone(), - ))) + Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx))) }); let mut repo_state = repo_state.lock(); let result = f(&mut repo_state, &canonical_path, &common_dir); if emit_git_event { + drop(repo_state); state.emit_event([(canonical_path, None)]); } @@ -1637,14 +1765,12 @@ impl FakeFs { pub fn paths(&self, include_dot_git: bool) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + if let FakeFsEntry::Dir { entries, .. } = entry { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } if include_dot_git @@ -1661,14 +1787,12 @@ impl FakeFs { pub fn directories(&self, include_dot_git: bool) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + if let FakeFsEntry::Dir { entries, .. } = entry { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } if include_dot_git || !path @@ -1685,17 +1809,14 @@ impl FakeFs { pub fn files(&self) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - let e = entry.lock(); - match &*e { + match entry { FakeFsEntry::File { .. } => result.push(path), FakeFsEntry::Dir { entries, .. } => { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } FakeFsEntry::Symlink { .. } => {} @@ -1707,13 +1828,10 @@ impl FakeFs { pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec)> { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - let e = entry.lock(); - match &*e { + match entry { FakeFsEntry::File { content, .. } => { if path.starts_with(prefix) { result.push((path, content.clone())); @@ -1721,7 +1839,7 @@ impl FakeFs { } FakeFsEntry::Dir { entries, .. } => { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } FakeFsEntry::Symlink { .. } => {} @@ -1787,10 +1905,7 @@ impl FakeFsEntry { } } - fn dir_entries( - &mut self, - path: &Path, - ) -> Result<&mut BTreeMap>>> { + fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap> { if let Self::Dir { entries, .. } = self { Ok(entries) } else { @@ -1837,12 +1952,12 @@ struct FakeHandle { impl FileHandle for FakeHandle { fn current_path(&self, fs: &Arc) -> Result { let fs = fs.as_fake(); - let state = fs.state.lock(); - let Some(target) = state.moves.get(&self.inode) else { + let mut state = fs.state.lock(); + let Some(target) = state.moves.get(&self.inode).cloned() else { anyhow::bail!("fake fd not moved") }; - if state.try_read_path(&target, false).is_some() { + if state.try_entry(&target, false).is_some() { return Ok(target.clone()); } anyhow::bail!("fake fd target not found") @@ -1870,13 +1985,13 @@ impl Fs for FakeFs { state.write_path(&cur_path, |entry| { entry.or_insert_with(|| { created_dirs.push((cur_path.clone(), Some(PathEventKind::Created))); - Arc::new(Mutex::new(FakeFsEntry::Dir { + FakeFsEntry::Dir { inode, mtime, len: 0, entries: Default::default(), git_repo_state: None, - })) + } }); Ok(()) })? @@ -1891,13 +2006,13 @@ impl Fs for FakeFs { let mut state = self.state.lock(); let inode = state.get_and_increment_inode(); let mtime = state.get_and_increment_mtime(); - let file = Arc::new(Mutex::new(FakeFsEntry::File { + let file = FakeFsEntry::File { inode, mtime, len: 0, content: Vec::new(), git_dir_path: None, - })); + }; let mut kind = Some(PathEventKind::Created); state.write_path(path, |entry| { match entry { @@ -1921,7 +2036,7 @@ impl Fs for FakeFs { async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> { let mut state = self.state.lock(); - let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + let file = FakeFsEntry::Symlink { target }; state .write_path(path.as_ref(), move |e| match e { btree_map::Entry::Vacant(e) => { @@ -1984,7 +2099,7 @@ impl Fs for FakeFs { } })?; - let inode = match *moved_entry.lock() { + let inode = match moved_entry { FakeFsEntry::File { inode, .. } => inode, FakeFsEntry::Dir { inode, .. } => inode, _ => 0, @@ -2033,8 +2148,8 @@ impl Fs for FakeFs { let mut state = self.state.lock(); let mtime = state.get_and_increment_mtime(); let inode = state.get_and_increment_inode(); - let source_entry = state.read_path(&source)?; - let content = source_entry.lock().file_content(&source)?.clone(); + let source_entry = state.entry(&source)?; + let content = source_entry.file_content(&source)?.clone(); let mut kind = Some(PathEventKind::Created); state.write_path(&target, |e| match e { btree_map::Entry::Occupied(e) => { @@ -2048,13 +2163,13 @@ impl Fs for FakeFs { } } btree_map::Entry::Vacant(e) => Ok(Some( - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode, mtime, len: content.len() as u64, content, git_dir_path: None, - }))) + }) .clone(), )), })?; @@ -2070,8 +2185,7 @@ impl Fs for FakeFs { let base_name = path.file_name().context("cannot remove the root")?; let mut state = self.state.lock(); - let parent_entry = state.read_path(parent_path)?; - let mut parent_entry = parent_entry.lock(); + let parent_entry = state.entry(parent_path)?; let entry = parent_entry .dir_entries(parent_path)? .entry(base_name.to_str().unwrap().into()); @@ -2082,15 +2196,14 @@ impl Fs for FakeFs { anyhow::bail!("{path:?} does not exist"); } } - btree_map::Entry::Occupied(e) => { + btree_map::Entry::Occupied(mut entry) => { { - let mut entry = e.get().lock(); - let children = entry.dir_entries(&path)?; + let children = entry.get_mut().dir_entries(&path)?; if !options.recursive && !children.is_empty() { anyhow::bail!("{path:?} is not empty"); } } - e.remove(); + entry.remove(); } } state.emit_event([(path, Some(PathEventKind::Removed))]); @@ -2104,8 +2217,7 @@ impl Fs for FakeFs { let parent_path = path.parent().context("cannot remove the root")?; let base_name = path.file_name().unwrap(); let mut state = self.state.lock(); - let parent_entry = state.read_path(parent_path)?; - let mut parent_entry = parent_entry.lock(); + let parent_entry = state.entry(parent_path)?; let entry = parent_entry .dir_entries(parent_path)? .entry(base_name.to_str().unwrap().into()); @@ -2115,9 +2227,9 @@ impl Fs for FakeFs { anyhow::bail!("{path:?} does not exist"); } } - btree_map::Entry::Occupied(e) => { - e.get().lock().file_content(&path)?; - e.remove(); + btree_map::Entry::Occupied(mut entry) => { + entry.get_mut().file_content(&path)?; + entry.remove(); } } state.emit_event([(path, Some(PathEventKind::Removed))]); @@ -2131,12 +2243,10 @@ impl Fs for FakeFs { async fn open_handle(&self, path: &Path) -> Result> { self.simulate_random_delay().await; - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); - let inode = match *entry { - FakeFsEntry::File { inode, .. } => inode, - FakeFsEntry::Dir { inode, .. } => inode, + let mut state = self.state.lock(); + let inode = match state.entry(&path)? { + FakeFsEntry::File { inode, .. } => *inode, + FakeFsEntry::Dir { inode, .. } => *inode, _ => unreachable!(), }; Ok(Arc::new(FakeHandle { inode })) @@ -2154,6 +2264,9 @@ impl Fs for FakeFs { async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> { self.simulate_random_delay().await; let path = normalize_path(path.as_path()); + if let Some(path) = path.parent() { + self.create_dir(path).await?; + } self.write_file_internal(path, data.into_bytes(), true)?; Ok(()) } @@ -2183,8 +2296,8 @@ impl Fs for FakeFs { let path = normalize_path(path); self.simulate_random_delay().await; let state = self.state.lock(); - let (_, canonical_path) = state - .try_read_path(&path, true) + let canonical_path = state + .canonicalize(&path, true) .with_context(|| format!("path does not exist: {path:?}"))?; Ok(canonical_path) } @@ -2192,9 +2305,9 @@ impl Fs for FakeFs { async fn is_file(&self, path: &Path) -> bool { let path = normalize_path(path); self.simulate_random_delay().await; - let state = self.state.lock(); - if let Some((entry, _)) = state.try_read_path(&path, true) { - entry.lock().is_file() + let mut state = self.state.lock(); + if let Some((entry, _)) = state.try_entry(&path, true) { + entry.is_file() } else { false } @@ -2211,17 +2324,16 @@ impl Fs for FakeFs { let path = normalize_path(path); let mut state = self.state.lock(); state.metadata_call_count += 1; - if let Some((mut entry, _)) = state.try_read_path(&path, false) { - let is_symlink = entry.lock().is_symlink(); + if let Some((mut entry, _)) = state.try_entry(&path, false) { + let is_symlink = entry.is_symlink(); if is_symlink { - if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) { + if let Some(e) = state.try_entry(&path, true).map(|e| e.0) { entry = e; } else { return Ok(None); } } - let entry = entry.lock(); Ok(Some(match &*entry { FakeFsEntry::File { inode, mtime, len, .. @@ -2253,12 +2365,11 @@ impl Fs for FakeFs { async fn read_link(&self, path: &Path) -> Result { self.simulate_random_delay().await; let path = normalize_path(path); - let state = self.state.lock(); + let mut state = self.state.lock(); let (entry, _) = state - .try_read_path(&path, false) + .try_entry(&path, false) .with_context(|| format!("path does not exist: {path:?}"))?; - let entry = entry.lock(); - if let FakeFsEntry::Symlink { target } = &*entry { + if let FakeFsEntry::Symlink { target } = entry { Ok(target.clone()) } else { anyhow::bail!("not a symlink: {path:?}") @@ -2273,8 +2384,7 @@ impl Fs for FakeFs { let path = normalize_path(path); let mut state = self.state.lock(); state.read_dir_call_count += 1; - let entry = state.read_path(&path)?; - let mut entry = entry.lock(); + let entry = state.entry(&path)?; let children = entry.dir_entries(&path)?; let paths = children .keys() @@ -2338,6 +2448,7 @@ impl Fs for FakeFs { dot_git_path: abs_dot_git.to_path_buf(), repository_dir_path: repository_dir_path.to_owned(), common_dir_path: common_dir_path.to_owned(), + checkpoints: Arc::default(), }) as _ }, ) @@ -2352,6 +2463,10 @@ impl Fs for FakeFs { smol::block_on(self.create_dir(&abs_work_directory_path.join(".git"))) } + async fn git_clone(&self, _repo_url: &str, _abs_work_directory: &Path) -> Result<()> { + anyhow::bail!("Git clone is not supported in fake Fs") + } + fn is_fake(&self) -> bool { true } diff --git a/crates/git/Cargo.toml b/crates/git/Cargo.toml index ab2210094d..74656f1d4c 100644 --- a/crates/git/Cargo.toml +++ b/crates/git/Cargo.toml @@ -12,7 +12,7 @@ workspace = true path = "src/git.rs" [features] -test-support = [] +test-support = ["rand"] [dependencies] anyhow.workspace = true @@ -26,6 +26,7 @@ http_client.workspace = true log.workspace = true parking_lot.workspace = true regex.workspace = true +rand = { workspace = true, optional = true } rope.workspace = true schemars.workspace = true serde.workspace = true @@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] } unindent.workspace = true gpui = { workspace = true, features = ["test-support"] } tempfile.workspace = true +rand.workspace = true diff --git a/crates/git/src/blame.rs b/crates/git/src/blame.rs index 2128fa55c3..6f12681ea0 100644 --- a/crates/git/src/blame.rs +++ b/crates/git/src/blame.rs @@ -73,6 +73,7 @@ async fn run_git_blame( .current_dir(working_directory) .arg("blame") .arg("--incremental") + .arg("-w") .arg("--contents") .arg("-") .arg(path.as_os_str()) diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index 553361e673..e84014129c 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -93,6 +93,8 @@ actions!( Init, /// Opens all modified files in the editor. OpenModifiedFiles, + /// Clones a repository. + Clone, ] ); @@ -117,6 +119,13 @@ impl Oid { Ok(Self(oid)) } + #[cfg(any(test, feature = "test-support"))] + pub fn random(rng: &mut impl rand::Rng) -> Self { + let mut bytes = [0; 20]; + rng.fill(&mut bytes); + Self::from_bytes(&bytes).unwrap() + } + pub fn as_bytes(&self) -> &[u8] { self.0.as_bytes() } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index dc7ab0af65..49eee84840 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -6,7 +6,7 @@ use collections::HashMap; use futures::future::BoxFuture; use futures::{AsyncWriteExt, FutureExt as _, select_biased}; use git2::BranchType; -use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString}; +use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString, Task}; use parking_lot::Mutex; use rope::Rope; use schemars::JsonSchema; @@ -338,7 +338,7 @@ pub trait GitRepository: Send + Sync { fn merge_message(&self) -> BoxFuture<'_, Option>; - fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result>; + fn status(&self, path_prefixes: &[RepoPath]) -> Task>; fn branches(&self) -> BoxFuture<'_, Result>>; @@ -399,9 +399,9 @@ pub trait GitRepository: Send + Sync { &self, paths: Vec, env: Arc>, - ) -> BoxFuture>; + ) -> BoxFuture<'_, Result<()>>; - fn stash_pop(&self, env: Arc>) -> BoxFuture>; + fn stash_pop(&self, env: Arc>) -> BoxFuture<'_, Result<()>>; fn push( &self, @@ -953,25 +953,27 @@ impl GitRepository for RealGitRepository { .boxed() } - fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result> { + fn status(&self, path_prefixes: &[RepoPath]) -> Task> { let git_binary_path = self.git_binary_path.clone(); - let working_directory = self.working_directory(); - let path_prefixes = path_prefixes.to_owned(); - self.executor - .spawn(async move { - let output = new_std_command(&git_binary_path) - .current_dir(working_directory?) - .args(git_status_args(&path_prefixes)) - .output()?; - if output.status.success() { - let stdout = String::from_utf8_lossy(&output.stdout); - stdout.parse() - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - anyhow::bail!("git status failed: {stderr}"); - } - }) - .boxed() + let working_directory = match self.working_directory() { + Ok(working_directory) => working_directory, + Err(e) => return Task::ready(Err(e)), + }; + let args = git_status_args(&path_prefixes); + log::debug!("Checking for git status in {path_prefixes:?}"); + self.executor.spawn(async move { + let output = new_std_command(&git_binary_path) + .current_dir(working_directory) + .args(args) + .output()?; + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + stdout.parse() + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!("git status failed: {stderr}"); + } + }) } fn branches(&self) -> BoxFuture<'_, Result>> { @@ -1203,7 +1205,7 @@ impl GitRepository for RealGitRepository { &self, paths: Vec, env: Arc>, - ) -> BoxFuture> { + ) -> BoxFuture<'_, Result<()>> { let working_directory = self.working_directory(); self.executor .spawn(async move { @@ -1227,7 +1229,7 @@ impl GitRepository for RealGitRepository { .boxed() } - fn stash_pop(&self, env: Arc>) -> BoxFuture> { + fn stash_pop(&self, env: Arc>) -> BoxFuture<'_, Result<()>> { let working_directory = self.working_directory(); self.executor .spawn(async move { diff --git a/crates/git_ui/src/conflict_view.rs b/crates/git_ui/src/conflict_view.rs index 0bbb9411be..6482ebb9f8 100644 --- a/crates/git_ui/src/conflict_view.rs +++ b/crates/git_ui/src/conflict_view.rs @@ -112,7 +112,7 @@ fn excerpt_for_buffer_updated( } fn buffer_added(editor: &mut Editor, buffer: Entity, cx: &mut Context) { - let Some(project) = &editor.project else { + let Some(project) = editor.project() else { return; }; let git_store = project.read(cx).git_store().clone(); @@ -469,7 +469,7 @@ pub(crate) fn resolve_conflict( let Some((workspace, project, multibuffer, buffer)) = editor .update(cx, |editor, cx| { let workspace = editor.workspace()?; - let project = editor.project.clone()?; + let project = editor.project()?.clone(); let multibuffer = editor.buffer().clone(); let buffer_id = resolved_conflict.ours.end.buffer_id?; let buffer = multibuffer.read(cx).buffer(buffer_id)?; diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index e4f445858d..70987dd212 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2081,6 +2081,99 @@ impl GitPanel { .detach_and_log_err(cx); } + pub(crate) fn git_clone(&mut self, repo: String, window: &mut Window, cx: &mut Context) { + let path = cx.prompt_for_paths(gpui::PathPromptOptions { + files: false, + directories: true, + multiple: false, + }); + + let workspace = self.workspace.clone(); + + cx.spawn_in(window, async move |this, cx| { + let mut paths = path.await.ok()?.ok()??; + let mut path = paths.pop()?; + let repo_name = repo + .split(std::path::MAIN_SEPARATOR_STR) + .last()? + .strip_suffix(".git")? + .to_owned(); + + let fs = this.read_with(cx, |this, _| this.fs.clone()).ok()?; + + let prompt_answer = match fs.git_clone(&repo, path.as_path()).await { + Ok(_) => cx.update(|window, cx| { + window.prompt( + PromptLevel::Info, + &format!("Git Clone: {}", repo_name), + None, + &["Add repo to project", "Open repo in new project"], + cx, + ) + }), + Err(e) => { + this.update(cx, |this: &mut GitPanel, cx| { + let toast = StatusToast::new(e.to_string(), cx, |this, _| { + this.icon(ToastIcon::new(IconName::XCircle).color(Color::Error)) + .dismiss_button(true) + }); + + this.workspace + .update(cx, |workspace, cx| { + workspace.toggle_status_toast(toast, cx); + }) + .ok(); + }) + .ok()?; + + return None; + } + } + .ok()?; + + path.push(repo_name); + match prompt_answer.await.ok()? { + 0 => { + workspace + .update(cx, |workspace, cx| { + workspace + .project() + .update(cx, |project, cx| { + project.create_worktree(path.as_path(), true, cx) + }) + .detach(); + }) + .ok(); + } + 1 => { + workspace + .update(cx, move |workspace, cx| { + workspace::open_new( + Default::default(), + workspace.app_state().clone(), + cx, + move |workspace, _, cx| { + cx.activate(true); + workspace + .project() + .update(cx, |project, cx| { + project.create_worktree(&path, true, cx) + }) + .detach(); + }, + ) + .detach(); + }) + .ok(); + } + _ => {} + } + + Some(()) + }) + .detach(); + } + pub(crate) fn git_init(&mut self, window: &mut Window, cx: &mut Context) { let worktrees = self .project @@ -3317,7 +3410,7 @@ impl GitPanel { * MAX_PANEL_EDITOR_LINES + gap; - let git_panel = cx.entity().clone(); + let git_panel = cx.entity(); let display_name = SharedString::from(Arc::from( active_repository .read(cx) diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index bde867bcd2..79aa4a6bd0 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -3,21 +3,25 @@ use std::any::Any; use ::settings::Settings; use command_palette_hooks::CommandPaletteFilter; use commit_modal::CommitModal; -use editor::{Editor, actions::DiffClipboardWithSelectionData}; +use editor::{Editor, EditorElement, EditorStyle, actions::DiffClipboardWithSelectionData}; mod blame_ui; use git::{ repository::{Branch, Upstream, UpstreamTracking, UpstreamTrackingStatus}, status::{FileStatus, StatusCode, UnmergedStatus, UnmergedStatusCode}, }; use git_panel_settings::GitPanelSettings; -use gpui::{Action, App, Context, FocusHandle, Window, actions}; +use gpui::{ + Action, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, TextStyle, + Window, actions, +}; use onboarding::GitOnboardingModal; use project_diff::ProjectDiff; +use theme::ThemeSettings; use ui::prelude::*; -use workspace::Workspace; +use workspace::{ModalView, Workspace}; use zed_actions; -use crate::text_diff_view::TextDiffView; +use crate::{git_panel::GitPanel, text_diff_view::TextDiffView}; mod askpass_modal; pub mod branch_picker; @@ -169,6 +173,15 @@ pub fn init(cx: &mut App) { panel.git_init(window, cx); }); }); + workspace.register_action(|workspace, _action: &git::Clone, window, cx| { + let Some(panel) = workspace.panel::(cx) else { + return; + }; + + workspace.toggle_modal(window, cx, |window, cx| { + GitCloneModal::show(panel, window, cx) + }); + }); workspace.register_action(|workspace, _: &git::OpenModifiedFiles, window, cx| { open_modified_files(workspace, window, cx); }); @@ -613,3 +626,98 @@ impl Component for GitStatusIcon { ) } } + +struct GitCloneModal { + panel: Entity, + repo_input: Entity, + focus_handle: FocusHandle, +} + +impl GitCloneModal { + pub fn show(panel: Entity, window: &mut Window, cx: &mut Context) -> Self { + let repo_input = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("Enter repository", cx); + editor + }); + let focus_handle = repo_input.focus_handle(cx); + + window.focus(&focus_handle); + + Self { + panel, + repo_input, + focus_handle, + } + } + + fn render_editor(&self, window: &Window, cx: &App) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let theme = cx.theme(); + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: settings.buffer_font_size(cx).into(), + font_weight: settings.buffer_font.weight, + line_height: relative(settings.buffer_line_height.value()), + background_color: Some(theme.colors().editor_background), + ..Default::default() + }; + + let element = EditorElement::new( + &self.repo_input, + EditorStyle { + background: theme.colors().editor_background, + local_player: theme.players().local(), + text: text_style, + ..Default::default() + }, + ); + + div() + .rounded_md() + .p_1() + .border_1() + .border_color(theme.colors().border_variant) + .when( + self.repo_input + .focus_handle(cx) + .contains_focused(window, cx), + |this| this.border_color(theme.colors().border_focused), + ) + .child(element) + .bg(theme.colors().editor_background) + } +} + +impl Focusable for GitCloneModal { + fn focus_handle(&self, _: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for GitCloneModal { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + div() + .size_full() + .w(rems(34.)) + .elevation_3(cx) + .child(self.render_editor(window, cx)) + .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| { + cx.emit(DismissEvent); + })) + .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| { + let repo = this.repo_input.read(cx).text(cx); + this.panel.update(cx, |panel, cx| { + panel.git_clone(repo, window, cx); + }); + cx.emit(DismissEvent); + })) + } +} + +impl EventEmitter for GitCloneModal {} + +impl ModalView for GitCloneModal {} diff --git a/crates/go_to_line/src/cursor_position.rs b/crates/go_to_line/src/cursor_position.rs index 29064eb29c..af92621378 100644 --- a/crates/go_to_line/src/cursor_position.rs +++ b/crates/go_to_line/src/cursor_position.rs @@ -1,4 +1,4 @@ -use editor::{Editor, MultiBufferSnapshot}; +use editor::{Editor, EditorSettings, MultiBufferSnapshot}; use gpui::{App, Entity, FocusHandle, Focusable, Subscription, Task, WeakEntity}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -209,6 +209,13 @@ impl CursorPosition { impl Render for CursorPosition { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if !EditorSettings::get_global(cx) + .status_bar + .cursor_position_button + { + return div(); + } + div().when_some(self.position, |el, position| { let mut text = format!( "{}{FILE_ROW_COLUMN_DELIMITER}{}", diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 6e5a76d441..6be8c5fd1f 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -209,7 +209,7 @@ xkbcommon = { version = "0.8.0", features = [ "wayland", "x11", ], optional = true } -xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf65a0ea94c70d3c4fd", features = [ +xim = { git = "https://github.com/zed-industries/xim-rs", rev = "c0a70c1bd2ce197364216e5e818a2cb3adb99a8d" , features = [ "x11rb-xcb", "x11rb-client", ], optional = true } @@ -305,3 +305,7 @@ path = "examples/uniform_list.rs" [[example]] name = "window_shadow" path = "examples/window_shadow.rs" + +[[example]] +name = "grid_layout" +path = "examples/grid_layout.rs" diff --git a/crates/gpui/examples/grid_layout.rs b/crates/gpui/examples/grid_layout.rs new file mode 100644 index 0000000000..f285497578 --- /dev/null +++ b/crates/gpui/examples/grid_layout.rs @@ -0,0 +1,80 @@ +use gpui::{ + App, Application, Bounds, Context, Hsla, Window, WindowBounds, WindowOptions, div, prelude::*, + px, rgb, size, +}; + +// https://en.wikipedia.org/wiki/Holy_grail_(web_design) +struct HolyGrailExample {} + +impl Render for HolyGrailExample { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + let block = |color: Hsla| { + div() + .size_full() + .bg(color) + .border_1() + .border_dashed() + .rounded_md() + .border_color(gpui::white()) + .items_center() + }; + + div() + .gap_1() + .grid() + .bg(rgb(0x505050)) + .size(px(500.0)) + .shadow_lg() + .border_1() + .size_full() + .grid_cols(5) + .grid_rows(5) + .child( + block(gpui::white()) + .row_span(1) + .col_span_full() + .child("Header"), + ) + .child( + block(gpui::red()) + .col_span(1) + .h_56() + .child("Table of contents"), + ) + .child( + block(gpui::green()) + .col_span(3) + .row_span(3) + .child("Content"), + ) + .child( + block(gpui::blue()) + .col_span(1) + .row_span(3) + .child("AD :(") + .text_color(gpui::white()), + ) + .child( + block(gpui::black()) + .row_span(1) + .col_span_full() + .text_color(gpui::white()) + .child("Footer"), + ) + } +} + +fn main() { + Application::new().run(|cx: &mut App| { + let bounds = Bounds::centered(None, size(px(500.), px(500.0)), cx); + cx.open_window( + WindowOptions { + window_bounds: Some(WindowBounds::Windowed(bounds)), + ..Default::default() + }, + |_, cx| cx.new(|_| HolyGrailExample {}), + ) + .unwrap(); + cx.activate(true); + }); +} diff --git a/crates/gpui/examples/input.rs b/crates/gpui/examples/input.rs index 52a5b08b96..b0f560e38d 100644 --- a/crates/gpui/examples/input.rs +++ b/crates/gpui/examples/input.rs @@ -595,9 +595,7 @@ impl Render for TextInput { .w_full() .p(px(4.)) .bg(white()) - .child(TextElement { - input: cx.entity().clone(), - }), + .child(TextElement { input: cx.entity() }), ) } } diff --git a/crates/gpui/examples/set_menus.rs b/crates/gpui/examples/set_menus.rs index f53fff7c7f..8a97a8d8a2 100644 --- a/crates/gpui/examples/set_menus.rs +++ b/crates/gpui/examples/set_menus.rs @@ -1,5 +1,6 @@ use gpui::{ - App, Application, Context, Menu, MenuItem, Window, WindowOptions, actions, div, prelude::*, rgb, + App, Application, Context, Menu, MenuItem, SystemMenuType, Window, WindowOptions, actions, div, + prelude::*, rgb, }; struct SetMenus; @@ -27,7 +28,11 @@ fn main() { // Add menu items cx.set_menus(vec![Menu { name: "set_menus".into(), - items: vec![MenuItem::action("Quit", Quit)], + items: vec![ + MenuItem::os_submenu("Services", SystemMenuType::Services), + MenuItem::separator(), + MenuItem::action("Quit", Quit), + ], }]); cx.open_window(WindowOptions::default(), |_, cx| cx.new(|_| SetMenus {})) .unwrap(); diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index ded7bae316..e1df6d0be4 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -277,6 +277,8 @@ pub struct App { pub(crate) release_listeners: SubscriberSet, pub(crate) global_observers: SubscriberSet, pub(crate) quit_observers: SubscriberSet<(), QuitHandler>, + pub(crate) restart_observers: SubscriberSet<(), Handler>, + pub(crate) restart_path: Option, pub(crate) window_closed_observers: SubscriberSet<(), WindowClosedHandler>, pub(crate) layout_id_buffer: Vec, // We recycle this memory across layout requests. pub(crate) propagate_event: bool, @@ -349,6 +351,8 @@ impl App { keyboard_layout_observers: SubscriberSet::new(), global_observers: SubscriberSet::new(), quit_observers: SubscriberSet::new(), + restart_observers: SubscriberSet::new(), + restart_path: None, window_closed_observers: SubscriberSet::new(), layout_id_buffer: Default::default(), propagate_event: true, @@ -812,8 +816,9 @@ impl App { pub fn prompt_for_new_path( &self, directory: &Path, + suggested_name: Option<&str>, ) -> oneshot::Receiver>> { - self.platform.prompt_for_new_path(directory) + self.platform.prompt_for_new_path(directory, suggested_name) } /// Reveals the specified path at the platform level, such as in Finder on macOS. @@ -832,8 +837,16 @@ impl App { } /// Restarts the application. - pub fn restart(&self, binary_path: Option) { - self.platform.restart(binary_path) + pub fn restart(&mut self) { + self.restart_observers + .clone() + .retain(&(), |observer| observer(self)); + self.platform.restart(self.restart_path.take()) + } + + /// Sets the path to use when restarting the application. + pub fn set_restart_path(&mut self, path: PathBuf) { + self.restart_path = Some(path); } /// Returns the HTTP client for the application. @@ -1466,6 +1479,21 @@ impl App { subscription } + /// Register a callback to be invoked when the application is about to restart. + /// + /// These callbacks are called before any `on_app_quit` callbacks. + pub fn on_app_restart(&self, mut on_restart: impl 'static + FnMut(&mut App)) -> Subscription { + let (subscription, activate) = self.restart_observers.insert( + (), + Box::new(move |cx| { + on_restart(cx); + true + }), + ); + activate(); + subscription + } + /// Register a callback to be invoked when a window is closed /// The window is no longer accessible at the point this callback is invoked. pub fn on_window_closed(&self, mut on_closed: impl FnMut(&mut App) + 'static) -> Subscription { diff --git a/crates/gpui/src/app/context.rs b/crates/gpui/src/app/context.rs index 392be2ffe9..68c41592b3 100644 --- a/crates/gpui/src/app/context.rs +++ b/crates/gpui/src/app/context.rs @@ -164,6 +164,20 @@ impl<'a, T: 'static> Context<'a, T> { subscription } + /// Register a callback to be invoked when the application is about to restart. + pub fn on_app_restart( + &self, + mut on_restart: impl FnMut(&mut T, &mut App) + 'static, + ) -> Subscription + where + T: 'static, + { + let handle = self.weak_entity(); + self.app.on_app_restart(move |cx| { + handle.update(cx, |entity, cx| on_restart(entity, cx)).ok(); + }) + } + /// Arrange for the given function to be invoked whenever the application is quit. /// The future returned from this callback will be polled for up to [crate::SHUTDOWN_TIMEOUT] until the app fully quits. pub fn on_app_quit( @@ -175,20 +189,15 @@ impl<'a, T: 'static> Context<'a, T> { T: 'static, { let handle = self.weak_entity(); - let (subscription, activate) = self.app.quit_observers.insert( - (), - Box::new(move |cx| { - let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok(); - async move { - if let Some(future) = future { - future.await; - } + self.app.on_app_quit(move |cx| { + let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok(); + async move { + if let Some(future) = future { + future.await; } - .boxed_local() - }), - ); - activate(); - subscription + } + .boxed_local() + }) } /// Tell GPUI that this entity has changed and observers of it should be notified. diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs index 35e6032671..a96c24432a 100644 --- a/crates/gpui/src/app/test_context.rs +++ b/crates/gpui/src/app/test_context.rs @@ -585,7 +585,7 @@ impl Entity { cx.executor().advance_clock(advance_clock_by); async move { - let notification = crate::util::timeout(duration, rx.recv()) + let notification = crate::util::smol_timeout(duration, rx.recv()) .await .expect("next notification timed out"); drop(subscription); @@ -629,7 +629,7 @@ impl Entity { let handle = self.downgrade(); async move { - crate::util::timeout(Duration::from_secs(1), async move { + crate::util::smol_timeout(Duration::from_secs(1), async move { loop { { let cx = cx.borrow(); diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 3d2d9cd9db..2de3e23ff7 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -9,12 +9,14 @@ use refineable::Refineable; use schemars::{JsonSchema, json_schema}; use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; use std::borrow::Cow; +use std::ops::Range; use std::{ cmp::{self, PartialOrd}, fmt::{self, Display}, hash::Hash, ops::{Add, Div, Mul, MulAssign, Neg, Sub}, }; +use taffy::prelude::{TaffyGridLine, TaffyGridSpan}; use crate::{App, DisplayId}; @@ -3608,6 +3610,37 @@ impl From<()> for Length { } } +/// A location in a grid layout. +#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, JsonSchema, Default)] +pub struct GridLocation { + /// The rows this item uses within the grid. + pub row: Range, + /// The columns this item uses within the grid. + pub column: Range, +} + +/// The placement of an item within a grid layout's column or row. +#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize, JsonSchema, Default)] +pub enum GridPlacement { + /// The grid line index to place this item. + Line(i16), + /// The number of grid lines to span. + Span(u16), + /// Automatically determine the placement, equivalent to Span(1) + #[default] + Auto, +} + +impl From for taffy::GridPlacement { + fn from(placement: GridPlacement) -> Self { + match placement { + GridPlacement::Line(index) => taffy::GridPlacement::from_line_index(index), + GridPlacement::Span(span) => taffy::GridPlacement::from_span(span), + GridPlacement::Auto => taffy::GridPlacement::Auto, + } + } +} + /// Provides a trait for types that can calculate half of their value. /// /// The `Half` trait is used for types that can be evenly divided, returning a new instance of the same type diff --git a/crates/gpui/src/gpui.rs b/crates/gpui/src/gpui.rs index 09799eb910..f0ce04a915 100644 --- a/crates/gpui/src/gpui.rs +++ b/crates/gpui/src/gpui.rs @@ -157,7 +157,7 @@ pub use taffy::{AvailableSpace, LayoutId}; #[cfg(any(test, feature = "test-support"))] pub use test::*; pub use text_system::*; -pub use util::arc_cow::ArcCow; +pub use util::{FutureExt, Timeout, arc_cow::ArcCow}; pub use view::*; pub use window::*; diff --git a/crates/gpui/src/key_dispatch.rs b/crates/gpui/src/key_dispatch.rs index cc6ebb9b08..c3f5d18603 100644 --- a/crates/gpui/src/key_dispatch.rs +++ b/crates/gpui/src/key_dispatch.rs @@ -611,9 +611,17 @@ impl DispatchTree { #[cfg(test)] mod tests { - use std::{cell::RefCell, rc::Rc}; + use crate::{ + self as gpui, Element, ElementId, GlobalElementId, InspectorElementId, LayoutId, Style, + }; + use core::panic; + use std::{cell::RefCell, ops::Range, rc::Rc}; - use crate::{Action, ActionRegistry, DispatchTree, KeyBinding, KeyContext, Keymap}; + use crate::{ + Action, ActionRegistry, App, Bounds, Context, DispatchTree, FocusHandle, InputHandler, + IntoElement, KeyBinding, KeyContext, Keymap, Pixels, Point, Render, TestAppContext, + UTF16Selection, Window, + }; #[derive(PartialEq, Eq)] struct TestAction; @@ -674,4 +682,165 @@ mod tests { assert!(keybinding[0].action.partial_eq(&TestAction)) } + + #[crate::test] + fn test_input_handler_pending(cx: &mut TestAppContext) { + #[derive(Clone)] + struct CustomElement { + focus_handle: FocusHandle, + text: Rc>, + } + impl CustomElement { + fn new(cx: &mut Context) -> Self { + Self { + focus_handle: cx.focus_handle(), + text: Rc::default(), + } + } + } + impl Element for CustomElement { + type RequestLayoutState = (); + + type PrepaintState = (); + + fn id(&self) -> Option { + Some("custom".into()) + } + fn source_location(&self) -> Option<&'static panic::Location<'static>> { + None + } + fn request_layout( + &mut self, + _: Option<&GlobalElementId>, + _: Option<&InspectorElementId>, + window: &mut Window, + cx: &mut App, + ) -> (LayoutId, Self::RequestLayoutState) { + (window.request_layout(Style::default(), [], cx), ()) + } + fn prepaint( + &mut self, + _: Option<&GlobalElementId>, + _: Option<&InspectorElementId>, + _: Bounds, + _: &mut Self::RequestLayoutState, + window: &mut Window, + cx: &mut App, + ) -> Self::PrepaintState { + window.set_focus_handle(&self.focus_handle, cx); + } + fn paint( + &mut self, + _: Option<&GlobalElementId>, + _: Option<&InspectorElementId>, + _: Bounds, + _: &mut Self::RequestLayoutState, + _: &mut Self::PrepaintState, + window: &mut Window, + cx: &mut App, + ) { + let mut key_context = KeyContext::default(); + key_context.add("Terminal"); + window.set_key_context(key_context); + window.handle_input(&self.focus_handle, self.clone(), cx); + window.on_action(std::any::TypeId::of::(), |_, _, _, _| {}); + } + } + impl IntoElement for CustomElement { + type Element = Self; + + fn into_element(self) -> Self::Element { + self + } + } + + impl InputHandler for CustomElement { + fn selected_text_range( + &mut self, + _: bool, + _: &mut Window, + _: &mut App, + ) -> Option { + None + } + + fn marked_text_range(&mut self, _: &mut Window, _: &mut App) -> Option> { + None + } + + fn text_for_range( + &mut self, + _: Range, + _: &mut Option>, + _: &mut Window, + _: &mut App, + ) -> Option { + None + } + + fn replace_text_in_range( + &mut self, + replacement_range: Option>, + text: &str, + _: &mut Window, + _: &mut App, + ) { + if replacement_range.is_some() { + unimplemented!() + } + self.text.borrow_mut().push_str(text) + } + + fn replace_and_mark_text_in_range( + &mut self, + replacement_range: Option>, + new_text: &str, + _: Option>, + _: &mut Window, + _: &mut App, + ) { + if replacement_range.is_some() { + unimplemented!() + } + self.text.borrow_mut().push_str(new_text) + } + + fn unmark_text(&mut self, _: &mut Window, _: &mut App) {} + + fn bounds_for_range( + &mut self, + _: Range, + _: &mut Window, + _: &mut App, + ) -> Option> { + None + } + + fn character_index_for_point( + &mut self, + _: Point, + _: &mut Window, + _: &mut App, + ) -> Option { + None + } + } + impl Render for CustomElement { + fn render(&mut self, _: &mut Window, _: &mut Context) -> impl IntoElement { + self.clone() + } + } + + cx.update(|cx| { + cx.bind_keys([KeyBinding::new("ctrl-b", TestAction, Some("Terminal"))]); + cx.bind_keys([KeyBinding::new("ctrl-b h", TestAction, Some("Terminal"))]); + }); + let (test, cx) = cx.add_window_view(|_, cx| CustomElement::new(cx)); + cx.update(|window, cx| { + window.focus(&test.read(cx).focus_handle); + window.activate_window(); + }); + cx.simulate_keystrokes("ctrl-b ["); + test.update(cx, |test, _| assert_eq!(test.text.borrow().as_str(), "[")) + } } diff --git a/crates/gpui/src/keymap/context.rs b/crates/gpui/src/keymap/context.rs index f4b878ae77..281035fe97 100644 --- a/crates/gpui/src/keymap/context.rs +++ b/crates/gpui/src/keymap/context.rs @@ -461,6 +461,8 @@ fn skip_whitespace(source: &str) -> &str { #[cfg(test)] mod tests { + use core::slice; + use super::*; use crate as gpui; use KeyBindingContextPredicate::*; @@ -674,11 +676,11 @@ mod tests { assert!(predicate.eval(&contexts)); assert!(!predicate.eval(&[])); - assert!(!predicate.eval(&[child_context.clone()])); + assert!(!predicate.eval(slice::from_ref(&child_context))); assert!(!predicate.eval(&[parent_context])); let zany_predicate = KeyBindingContextPredicate::parse("child > child").unwrap(); - assert!(!zany_predicate.eval(&[child_context.clone()])); + assert!(!zany_predicate.eval(slice::from_ref(&child_context))); assert!(zany_predicate.eval(&[child_context.clone(), child_context.clone()])); } @@ -690,13 +692,13 @@ mod tests { let parent_context = KeyContext::try_from("parent").unwrap(); let child_context = KeyContext::try_from("child").unwrap(); - assert!(not_predicate.eval(&[workspace_context.clone()])); - assert!(!not_predicate.eval(&[editor_context.clone()])); + assert!(not_predicate.eval(slice::from_ref(&workspace_context))); + assert!(!not_predicate.eval(slice::from_ref(&editor_context))); assert!(!not_predicate.eval(&[editor_context.clone(), workspace_context.clone()])); assert!(!not_predicate.eval(&[workspace_context.clone(), editor_context.clone()])); let complex_not = KeyBindingContextPredicate::parse("!editor && workspace").unwrap(); - assert!(complex_not.eval(&[workspace_context.clone()])); + assert!(complex_not.eval(slice::from_ref(&workspace_context))); assert!(!complex_not.eval(&[editor_context.clone(), workspace_context.clone()])); let not_mode_predicate = KeyBindingContextPredicate::parse("!(mode == full)").unwrap(); @@ -709,18 +711,18 @@ mod tests { assert!(not_mode_predicate.eval(&[other_mode_context])); let not_descendant = KeyBindingContextPredicate::parse("!(parent > child)").unwrap(); - assert!(not_descendant.eval(&[parent_context.clone()])); - assert!(not_descendant.eval(&[child_context.clone()])); + assert!(not_descendant.eval(slice::from_ref(&parent_context))); + assert!(not_descendant.eval(slice::from_ref(&child_context))); assert!(!not_descendant.eval(&[parent_context.clone(), child_context.clone()])); let not_descendant = KeyBindingContextPredicate::parse("parent > !child").unwrap(); - assert!(!not_descendant.eval(&[parent_context.clone()])); - assert!(!not_descendant.eval(&[child_context.clone()])); + assert!(!not_descendant.eval(slice::from_ref(&parent_context))); + assert!(!not_descendant.eval(slice::from_ref(&child_context))); assert!(!not_descendant.eval(&[parent_context.clone(), child_context.clone()])); let double_not = KeyBindingContextPredicate::parse("!!editor").unwrap(); - assert!(double_not.eval(&[editor_context.clone()])); - assert!(!double_not.eval(&[workspace_context.clone()])); + assert!(double_not.eval(slice::from_ref(&editor_context))); + assert!(!double_not.eval(slice::from_ref(&workspace_context))); // Test complex descendant cases let workspace_context = KeyContext::try_from("Workspace").unwrap(); @@ -754,9 +756,9 @@ mod tests { // !Workspace - shouldn't match when Workspace is in the context let not_workspace = KeyBindingContextPredicate::parse("!Workspace").unwrap(); - assert!(!not_workspace.eval(&[workspace_context.clone()])); - assert!(not_workspace.eval(&[pane_context.clone()])); - assert!(not_workspace.eval(&[editor_context.clone()])); + assert!(!not_workspace.eval(slice::from_ref(&workspace_context))); + assert!(not_workspace.eval(slice::from_ref(&pane_context))); + assert!(not_workspace.eval(slice::from_ref(&editor_context))); assert!(!not_workspace.eval(&workspace_pane_editor)); } } diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index b495d70dfd..bf6ce68703 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -220,7 +220,11 @@ pub(crate) trait Platform: 'static { &self, options: PathPromptOptions, ) -> oneshot::Receiver>>>; - fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver>>; + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> oneshot::Receiver>>; fn can_select_mixed_files_and_dirs(&self) -> bool; fn reveal_path(&self, path: &Path); fn open_with_system(&self, path: &Path); diff --git a/crates/gpui/src/platform/app_menu.rs b/crates/gpui/src/platform/app_menu.rs index 2815cbdd7f..4069fee726 100644 --- a/crates/gpui/src/platform/app_menu.rs +++ b/crates/gpui/src/platform/app_menu.rs @@ -20,6 +20,34 @@ impl Menu { } } +/// OS menus are menus that are recognized by the operating system +/// This allows the operating system to provide specialized items for +/// these menus +pub struct OsMenu { + /// The name of the menu + pub name: SharedString, + + /// The type of menu + pub menu_type: SystemMenuType, +} + +impl OsMenu { + /// Create an OwnedOsMenu from this OsMenu + pub fn owned(self) -> OwnedOsMenu { + OwnedOsMenu { + name: self.name.to_string().into(), + menu_type: self.menu_type, + } + } +} + +/// The type of system menu +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum SystemMenuType { + /// The 'Services' menu in the Application menu on macOS + Services, +} + /// The different kinds of items that can be in a menu pub enum MenuItem { /// A separator between items @@ -28,6 +56,9 @@ pub enum MenuItem { /// A submenu Submenu(Menu), + /// A menu, managed by the system (for example, the Services menu on macOS) + SystemMenu(OsMenu), + /// An action that can be performed Action { /// The name of this menu item @@ -53,6 +84,14 @@ impl MenuItem { Self::Submenu(menu) } + /// Creates a new submenu that is populated by the OS + pub fn os_submenu(name: impl Into, menu_type: SystemMenuType) -> Self { + Self::SystemMenu(OsMenu { + name: name.into(), + menu_type, + }) + } + /// Creates a new menu item that invokes an action pub fn action(name: impl Into, action: impl Action) -> Self { Self::Action { @@ -89,10 +128,23 @@ impl MenuItem { action, os_action, }, + MenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.owned()), } } } +/// OS menus are menus that are recognized by the operating system +/// This allows the operating system to provide specialized items for +/// these menus +#[derive(Clone)] +pub struct OwnedOsMenu { + /// The name of the menu + pub name: SharedString, + + /// The type of menu + pub menu_type: SystemMenuType, +} + /// A menu of the application, either a main menu or a submenu #[derive(Clone)] pub struct OwnedMenu { @@ -111,6 +163,9 @@ pub enum OwnedMenuItem { /// A submenu Submenu(OwnedMenu), + /// A menu, managed by the system (for example, the Services menu on macOS) + SystemMenu(OwnedOsMenu), + /// An action that can be performed Action { /// The name of this menu item @@ -139,6 +194,7 @@ impl Clone for OwnedMenuItem { action: action.boxed_clone(), os_action: *os_action, }, + OwnedMenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.clone()), } } } diff --git a/crates/gpui/src/platform/blade/shaders.wgsl b/crates/gpui/src/platform/blade/shaders.wgsl index b1ffb1812e..95980b54fe 100644 --- a/crates/gpui/src/platform/blade/shaders.wgsl +++ b/crates/gpui/src/platform/blade/shaders.wgsl @@ -1057,6 +1057,9 @@ fn vs_underline(@builtin(vertex_index) vertex_id: u32, @builtin(instance_index) @fragment fn fs_underline(input: UnderlineVarying) -> @location(0) vec4 { + const WAVE_FREQUENCY: f32 = 2.0; + const WAVE_HEIGHT_RATIO: f32 = 0.8; + // Alpha clip first, since we don't have `clip_distance`. if (any(input.clip_distances < vec4(0.0))) { return vec4(0.0); @@ -1069,9 +1072,11 @@ fn fs_underline(input: UnderlineVarying) -> @location(0) vec4 { } let half_thickness = underline.thickness * 0.5; + let st = (input.position.xy - underline.bounds.origin) / underline.bounds.size.y - vec2(0.0, 0.5); - let frequency = M_PI_F * 3.0 * underline.thickness / 3.0; - let amplitude = 1.0 / (4.0 * underline.thickness); + let frequency = M_PI_F * WAVE_FREQUENCY * underline.thickness / underline.bounds.size.y; + let amplitude = (underline.thickness * WAVE_HEIGHT_RATIO) / underline.bounds.size.y; + let sine = sin(st.x * frequency) * amplitude; let dSine = cos(st.x * frequency) * amplitude * frequency; let distance = (st.y - sine) / sqrt(1.0 + dSine * dSine); diff --git a/crates/gpui/src/platform/linux/platform.rs b/crates/gpui/src/platform/linux/platform.rs index fe6a36baa8..31d445be52 100644 --- a/crates/gpui/src/platform/linux/platform.rs +++ b/crates/gpui/src/platform/linux/platform.rs @@ -327,26 +327,35 @@ impl Platform for P { done_rx } - fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver>> { + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> oneshot::Receiver>> { let (done_tx, done_rx) = oneshot::channel(); #[cfg(not(any(feature = "wayland", feature = "x11")))] - let _ = (done_tx.send(Ok(None)), directory); + let _ = (done_tx.send(Ok(None)), directory, suggested_name); #[cfg(any(feature = "wayland", feature = "x11"))] self.foreground_executor() .spawn({ let directory = directory.to_owned(); + let suggested_name = suggested_name.map(|s| s.to_owned()); async move { - let request = match ashpd::desktop::file_chooser::SaveFileRequest::default() - .modal(true) - .title("Save File") - .current_folder(directory) - .expect("pathbuf should not be nul terminated") - .send() - .await - { + let mut request_builder = + ashpd::desktop::file_chooser::SaveFileRequest::default() + .modal(true) + .title("Save File") + .current_folder(directory) + .expect("pathbuf should not be nul terminated"); + + if let Some(suggested_name) = suggested_name { + request_builder = request_builder.current_name(suggested_name.as_str()); + } + + let request = match request_builder.send().await { Ok(request) => request, Err(err) => { let result = match err { diff --git a/crates/gpui/src/platform/linux/text_system.rs b/crates/gpui/src/platform/linux/text_system.rs index e6f6e9a680..f66a2e71d4 100644 --- a/crates/gpui/src/platform/linux/text_system.rs +++ b/crates/gpui/src/platform/linux/text_system.rs @@ -213,11 +213,7 @@ impl CosmicTextSystemState { features: &FontFeatures, ) -> Result> { // TODO: Determine the proper system UI font. - let name = if name == ".SystemUIFont" { - "Zed Plex Sans" - } else { - name - }; + let name = crate::text_system::font_name_with_fallbacks(name, "IBM Plex Sans"); let families = self .font_system diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index 573e4addf7..053cd0387b 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -642,13 +642,7 @@ impl X11Client { let xim_connected = xim_handler.connected; drop(state); - let xim_filtered = match ximc.filter_event(&event, &mut xim_handler) { - Ok(handled) => handled, - Err(err) => { - log::error!("XIMClientError: {}", err); - false - } - }; + let xim_filtered = ximc.filter_event(&event, &mut xim_handler); let xim_callback_event = xim_handler.last_callback_event.take(); let mut state = self.0.borrow_mut(); @@ -659,14 +653,28 @@ impl X11Client { self.handle_xim_callback_event(event); } - if xim_filtered { - continue; - } - - if xim_connected { - self.xim_handle_event(event); - } else { - self.handle_event(event); + match xim_filtered { + Ok(handled) => { + if handled { + continue; + } + if xim_connected { + self.xim_handle_event(event); + } else { + self.handle_event(event); + } + } + Err(err) => { + // this might happen when xim server crashes on one of the events + // we do lose 1-2 keys when crash happens since there is no reliable way to get that info + // luckily, x11 sends us window not found error when xim server crashes upon further key press + // hence we fall back to handle_event + log::error!("XIMClientError: {}", err); + let mut state = self.0.borrow_mut(); + state.take_xim(); + drop(state); + self.handle_event(event); + } } } } diff --git a/crates/gpui/src/platform/mac/platform.rs b/crates/gpui/src/platform/mac/platform.rs index 1d2146cf73..533423229c 100644 --- a/crates/gpui/src/platform/mac/platform.rs +++ b/crates/gpui/src/platform/mac/platform.rs @@ -7,9 +7,9 @@ use super::{ use crate::{ Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString, CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher, - MacDisplay, MacWindow, Menu, MenuItem, OwnedMenu, PathPromptOptions, Platform, PlatformDisplay, - PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result, SemanticVersion, Task, - WindowAppearance, WindowParams, hash, + MacDisplay, MacWindow, Menu, MenuItem, OsMenu, OwnedMenu, PathPromptOptions, Platform, + PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result, + SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, hash, }; use anyhow::{Context as _, anyhow}; use block::ConcreteBlock; @@ -47,7 +47,7 @@ use objc::{ use parking_lot::Mutex; use ptr::null_mut; use std::{ - cell::{Cell, LazyCell}, + cell::Cell, convert::TryInto, ffi::{CStr, OsStr, c_void}, os::{raw::c_char, unix::ffi::OsStrExt}, @@ -56,7 +56,7 @@ use std::{ ptr, rc::Rc, slice, str, - sync::Arc, + sync::{Arc, OnceLock}, }; use strum::IntoEnumIterator; use util::ResultExt; @@ -296,18 +296,7 @@ impl MacPlatform { actions: &mut Vec>, keymap: &Keymap, ) -> id { - const DEFAULT_CONTEXT: LazyCell> = LazyCell::new(|| { - let mut workspace_context = KeyContext::new_with_defaults(); - workspace_context.add("Workspace"); - let mut pane_context = KeyContext::new_with_defaults(); - pane_context.add("Pane"); - let mut editor_context = KeyContext::new_with_defaults(); - editor_context.add("Editor"); - - pane_context.extend(&editor_context); - workspace_context.extend(&pane_context); - vec![workspace_context] - }); + static DEFAULT_CONTEXT: OnceLock> = OnceLock::new(); unsafe { match item { @@ -323,9 +312,20 @@ impl MacPlatform { let keystrokes = keymap .bindings_for_action(action.as_ref()) .find_or_first(|binding| { - binding - .predicate() - .is_none_or(|predicate| predicate.eval(&DEFAULT_CONTEXT)) + binding.predicate().is_none_or(|predicate| { + predicate.eval(DEFAULT_CONTEXT.get_or_init(|| { + let mut workspace_context = KeyContext::new_with_defaults(); + workspace_context.add("Workspace"); + let mut pane_context = KeyContext::new_with_defaults(); + pane_context.add("Pane"); + let mut editor_context = KeyContext::new_with_defaults(); + editor_context.add("Editor"); + + pane_context.extend(&editor_context); + workspace_context.extend(&pane_context); + vec![workspace_context] + })) + }) }) .map(|binding| binding.keystrokes()); @@ -413,9 +413,20 @@ impl MacPlatform { } item.setSubmenu_(submenu); item.setTitle_(ns_string(&name)); - if name == "Services" { - let app: id = msg_send![APP_CLASS, sharedApplication]; - app.setServicesMenu_(item); + item + } + MenuItem::SystemMenu(OsMenu { name, menu_type }) => { + let item = NSMenuItem::new(nil).autorelease(); + let submenu = NSMenu::new(nil).autorelease(); + submenu.setDelegate_(delegate); + item.setSubmenu_(submenu); + item.setTitle_(ns_string(&name)); + + match menu_type { + SystemMenuType::Services => { + let app: id = msg_send![APP_CLASS, sharedApplication]; + app.setServicesMenu_(item); + } } item @@ -726,8 +737,13 @@ impl Platform for MacPlatform { done_rx } - fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver>> { + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> oneshot::Receiver>> { let directory = directory.to_owned(); + let suggested_name = suggested_name.map(|s| s.to_owned()); let (done_tx, done_rx) = oneshot::channel(); self.foreground_executor() .spawn(async move { @@ -737,6 +753,11 @@ impl Platform for MacPlatform { let url = NSURL::fileURLWithPath_isDirectory_(nil, path, true.to_objc()); panel.setDirectoryURL(url); + if let Some(suggested_name) = suggested_name { + let name_string = ns_string(&suggested_name); + let _: () = msg_send![panel, setNameFieldStringValue: name_string]; + } + let done_tx = Cell::new(Some(done_tx)); let block = ConcreteBlock::new(move |response: NSModalResponse| { let mut result = None; diff --git a/crates/gpui/src/platform/mac/shaders.metal b/crates/gpui/src/platform/mac/shaders.metal index f9d5bdbf4c..83c978b853 100644 --- a/crates/gpui/src/platform/mac/shaders.metal +++ b/crates/gpui/src/platform/mac/shaders.metal @@ -567,15 +567,20 @@ vertex UnderlineVertexOutput underline_vertex( fragment float4 underline_fragment(UnderlineFragmentInput input [[stage_in]], constant Underline *underlines [[buffer(UnderlineInputIndex_Underlines)]]) { + const float WAVE_FREQUENCY = 2.0; + const float WAVE_HEIGHT_RATIO = 0.8; + Underline underline = underlines[input.underline_id]; if (underline.wavy) { float half_thickness = underline.thickness * 0.5; float2 origin = float2(underline.bounds.origin.x, underline.bounds.origin.y); + float2 st = ((input.position.xy - origin) / underline.bounds.size.height) - float2(0., 0.5); - float frequency = (M_PI_F * (3. * underline.thickness)) / 8.; - float amplitude = 1. / (2. * underline.thickness); + float frequency = (M_PI_F * WAVE_FREQUENCY * underline.thickness) / underline.bounds.size.height; + float amplitude = (underline.thickness * WAVE_HEIGHT_RATIO) / underline.bounds.size.height; + float sine = sin(st.x * frequency) * amplitude; float dSine = cos(st.x * frequency) * amplitude * frequency; float distance = (st.y - sine) / sqrt(1. + dSine * dSine); diff --git a/crates/gpui/src/platform/mac/text_system.rs b/crates/gpui/src/platform/mac/text_system.rs index c45888bce7..849925c727 100644 --- a/crates/gpui/src/platform/mac/text_system.rs +++ b/crates/gpui/src/platform/mac/text_system.rs @@ -211,11 +211,7 @@ impl MacTextSystemState { features: &FontFeatures, fallbacks: Option<&FontFallbacks>, ) -> Result> { - let name = if name == ".SystemUIFont" { - ".AppleSystemUIFont" - } else { - name - }; + let name = crate::text_system::font_name_with_fallbacks(name, ".AppleSystemUIFont"); let mut font_ids = SmallVec::new(); let family = self diff --git a/crates/gpui/src/platform/test/platform.rs b/crates/gpui/src/platform/test/platform.rs index a26b65576c..69371bc8c4 100644 --- a/crates/gpui/src/platform/test/platform.rs +++ b/crates/gpui/src/platform/test/platform.rs @@ -336,6 +336,7 @@ impl Platform for TestPlatform { fn prompt_for_new_path( &self, directory: &std::path::Path, + _suggested_name: Option<&str>, ) -> oneshot::Receiver>> { let (tx, rx) = oneshot::channel(); self.background_executor() diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 5268d3ccba..77e0ca41bf 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -10,6 +10,7 @@ mod keyboard; mod platform; mod system_settings; mod util; +mod vsync; mod window; mod wrapper; @@ -25,6 +26,7 @@ pub(crate) use keyboard::*; pub(crate) use platform::*; pub(crate) use system_settings::*; pub(crate) use util::*; +pub(crate) use vsync::*; pub(crate) use window::*; pub(crate) use wrapper::*; diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index 587cb7b4a6..75cb50243b 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -498,8 +498,9 @@ impl DirectWriteState { ) .unwrap() } else { + let family = self.system_ui_font_name.clone(); self.find_font_id( - target_font.family.as_ref(), + font_name_with_fallbacks(target_font.family.as_ref(), family.as_ref()), target_font.weight, target_font.style, &target_font.features, @@ -512,7 +513,6 @@ impl DirectWriteState { } #[cfg(not(any(test, feature = "test-support")))] { - let family = self.system_ui_font_name.clone(); log::error!("{} not found, use {} instead.", target_font.family, family); self.get_font_id_from_font_collection( family.as_ref(), diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs index 585b1dab1c..4e72ded534 100644 --- a/crates/gpui/src/platform/windows/directx_renderer.rs +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -4,16 +4,15 @@ use ::util::ResultExt; use anyhow::{Context, Result}; use windows::{ Win32::{ - Foundation::{FreeLibrary, HMODULE, HWND}, + Foundation::{HMODULE, HWND}, Graphics::{ Direct3D::*, Direct3D11::*, DirectComposition::*, Dxgi::{Common::*, *}, }, - System::LibraryLoader::LoadLibraryA, }, - core::{Interface, PCSTR}, + core::Interface, }; use crate::{ @@ -208,7 +207,7 @@ impl DirectXRenderer { fn present(&mut self) -> Result<()> { unsafe { - let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0)); + let result = self.resources.swap_chain.Present(0, DXGI_PRESENT(0)); // Presenting the swap chain can fail if the DirectX device was removed or reset. if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { let reason = self.devices.device.GetDeviceRemovedReason(); @@ -1619,22 +1618,6 @@ pub(crate) mod shader_resources { } } -fn with_dll_library(dll_name: PCSTR, f: F) -> Result -where - F: FnOnce(HMODULE) -> Result, -{ - let library = unsafe { - LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))? - }; - let result = f(library); - unsafe { - FreeLibrary(library) - .with_context(|| format!("Freeing dll: {}", dll_name.display())) - .log_err(); - } - result -} - mod nvidia { use std::{ ffi::CStr, @@ -1644,7 +1627,7 @@ mod nvidia { use anyhow::Result; use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s}; - use crate::platform::windows::directx_renderer::with_dll_library; + use crate::with_dll_library; // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 const NVAPI_SHORT_STRING_MAX: usize = 64; @@ -1711,7 +1694,7 @@ mod amd { use anyhow::Result; use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s}; - use crate::platform::windows::directx_renderer::with_dll_library; + use crate::with_dll_library; // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 01b043a755..c1fb0cabc4 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -32,7 +32,7 @@ use crate::*; pub(crate) struct WindowsPlatform { state: RefCell, - raw_window_handles: RwLock>, + raw_window_handles: Arc>>, // The below members will never change throughout the entire lifecycle of the app. icon: HICON, main_receiver: flume::Receiver, @@ -114,7 +114,7 @@ impl WindowsPlatform { }; let icon = load_icon().unwrap_or_default(); let state = RefCell::new(WindowsPlatformState::new()); - let raw_window_handles = RwLock::new(SmallVec::new()); + let raw_window_handles = Arc::new(RwLock::new(SmallVec::new())); let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { @@ -134,22 +134,12 @@ impl WindowsPlatform { }) } - fn redraw_all(&self) { - for handle in self.raw_window_handles.read().iter() { - unsafe { - RedrawWindow(Some(*handle), None, None, RDW_INVALIDATE | RDW_UPDATENOW) - .ok() - .log_err(); - } - } - } - pub fn window_from_hwnd(&self, hwnd: HWND) -> Option> { self.raw_window_handles .read() .iter() - .find(|entry| *entry == &hwnd) - .and_then(|hwnd| window_from_hwnd(*hwnd)) + .find(|entry| entry.as_raw() == hwnd) + .and_then(|hwnd| window_from_hwnd(hwnd.as_raw())) } #[inline] @@ -158,7 +148,7 @@ impl WindowsPlatform { .read() .iter() .for_each(|handle| unsafe { - PostMessageW(Some(*handle), message, wparam, lparam).log_err(); + PostMessageW(Some(handle.as_raw()), message, wparam, lparam).log_err(); }); } @@ -166,7 +156,7 @@ impl WindowsPlatform { let mut lock = self.raw_window_handles.write(); let index = lock .iter() - .position(|handle| *handle == target_window) + .position(|handle| handle.as_raw() == target_window) .unwrap(); lock.remove(index); @@ -226,19 +216,19 @@ impl WindowsPlatform { } } - // Returns true if the app should quit. - fn handle_events(&self) -> bool { + // Returns if the app should quit. + fn handle_events(&self) { let mut msg = MSG::default(); unsafe { - while PeekMessageW(&mut msg, None, 0, 0, PM_REMOVE).as_bool() { + while GetMessageW(&mut msg, None, 0, 0).as_bool() { match msg.message { - WM_QUIT => return true, + WM_QUIT => return, WM_INPUTLANGCHANGE | WM_GPUI_CLOSE_ONE_WINDOW | WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD | WM_GPUI_DOCK_MENU_ACTION => { if self.handle_gpui_evnets(msg.message, msg.wParam, msg.lParam, &msg) { - return true; + return; } } _ => { @@ -247,7 +237,6 @@ impl WindowsPlatform { } } } - false } // Returns true if the app should quit. @@ -315,8 +304,28 @@ impl WindowsPlatform { self.raw_window_handles .read() .iter() - .find(|&&hwnd| hwnd == active_window_hwnd) - .copied() + .find(|hwnd| hwnd.as_raw() == active_window_hwnd) + .map(|hwnd| hwnd.as_raw()) + } + + fn begin_vsync_thread(&self) { + let all_windows = Arc::downgrade(&self.raw_window_handles); + std::thread::spawn(move || { + let vsync_provider = VSyncProvider::new(); + loop { + vsync_provider.wait_for_vsync(); + let Some(all_windows) = all_windows.upgrade() else { + break; + }; + for hwnd in all_windows.read().iter() { + unsafe { + RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE) + .ok() + .log_err(); + } + } + } + }); } } @@ -347,12 +356,8 @@ impl Platform for WindowsPlatform { fn run(&self, on_finish_launching: Box) { on_finish_launching(); - loop { - if self.handle_events() { - break; - } - self.redraw_all(); - } + self.begin_vsync_thread(); + self.handle_events(); if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { callback(); @@ -365,9 +370,9 @@ impl Platform for WindowsPlatform { .detach(); } - fn restart(&self, _: Option) { + fn restart(&self, binary_path: Option) { let pid = std::process::id(); - let Some(app_path) = self.app_path().log_err() else { + let Some(app_path) = binary_path.or(self.app_path().log_err()) else { return; }; let script = format!( @@ -445,7 +450,7 @@ impl Platform for WindowsPlatform { ) -> Result> { let window = WindowsWindow::new(handle, options, self.generate_creation_info())?; let handle = window.get_raw_handle(); - self.raw_window_handles.write().push(handle); + self.raw_window_handles.write().push(handle.into()); Ok(Box::new(window)) } @@ -485,13 +490,18 @@ impl Platform for WindowsPlatform { rx } - fn prompt_for_new_path(&self, directory: &Path) -> Receiver>> { + fn prompt_for_new_path( + &self, + directory: &Path, + suggested_name: Option<&str>, + ) -> Receiver>> { let directory = directory.to_owned(); + let suggested_name = suggested_name.map(|s| s.to_owned()); let (tx, rx) = oneshot::channel(); let window = self.find_current_active_window(); self.foreground_executor() .spawn(async move { - let _ = tx.send(file_save_dialog(directory, window)); + let _ = tx.send(file_save_dialog(directory, suggested_name, window)); }) .detach(); @@ -799,7 +809,11 @@ fn file_open_dialog( Ok(Some(paths)) } -fn file_save_dialog(directory: PathBuf, window: Option) -> Result> { +fn file_save_dialog( + directory: PathBuf, + suggested_name: Option, + window: Option, +) -> Result> { let dialog: IFileSaveDialog = unsafe { CoCreateInstance(&FileSaveDialog, None, CLSCTX_ALL)? }; if !directory.to_string_lossy().is_empty() { if let Some(full_path) = directory.canonicalize().log_err() { @@ -810,6 +824,11 @@ fn file_save_dialog(directory: PathBuf, window: Option) -> Result(dll_name: PCSTR, f: F) -> Result +where + F: FnOnce(HMODULE) -> Result, +{ + let library = unsafe { + LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))? + }; + let result = f(library); + unsafe { + FreeLibrary(library) + .with_context(|| format!("Freeing dll: {}", dll_name.display())) + .log_err(); + } + result +} diff --git a/crates/gpui/src/platform/windows/vsync.rs b/crates/gpui/src/platform/windows/vsync.rs new file mode 100644 index 0000000000..6d09b0960f --- /dev/null +++ b/crates/gpui/src/platform/windows/vsync.rs @@ -0,0 +1,174 @@ +use std::{ + sync::LazyLock, + time::{Duration, Instant}, +}; + +use anyhow::{Context, Result}; +use util::ResultExt; +use windows::{ + Win32::{ + Foundation::{HANDLE, HWND}, + Graphics::{ + DirectComposition::{ + COMPOSITION_FRAME_ID_COMPLETED, COMPOSITION_FRAME_ID_TYPE, COMPOSITION_FRAME_STATS, + COMPOSITION_TARGET_ID, + }, + Dwm::{DWM_TIMING_INFO, DwmFlush, DwmGetCompositionTimingInfo}, + }, + System::{ + LibraryLoader::{GetModuleHandleA, GetProcAddress}, + Performance::QueryPerformanceFrequency, + Threading::INFINITE, + }, + }, + core::{HRESULT, s}, +}; + +static QPC_TICKS_PER_SECOND: LazyLock = LazyLock::new(|| { + let mut frequency = 0; + // On systems that run Windows XP or later, the function will always succeed and + // will thus never return zero. + unsafe { QueryPerformanceFrequency(&mut frequency).unwrap() }; + frequency as u64 +}); + +const VSYNC_INTERVAL_THRESHOLD: Duration = Duration::from_millis(1); +const DEFAULT_VSYNC_INTERVAL: Duration = Duration::from_micros(16_666); // ~60Hz + +// Here we are using dynamic loading of DirectComposition functions, +// or the app will refuse to start on windows systems that do not support DirectComposition. +type DCompositionGetFrameId = + unsafe extern "system" fn(frameidtype: COMPOSITION_FRAME_ID_TYPE, frameid: *mut u64) -> HRESULT; +type DCompositionGetStatistics = unsafe extern "system" fn( + frameid: u64, + framestats: *mut COMPOSITION_FRAME_STATS, + targetidcount: u32, + targetids: *mut COMPOSITION_TARGET_ID, + actualtargetidcount: *mut u32, +) -> HRESULT; +type DCompositionWaitForCompositorClock = + unsafe extern "system" fn(count: u32, handles: *const HANDLE, timeoutinms: u32) -> u32; + +pub(crate) struct VSyncProvider { + interval: Duration, + f: Box bool>, +} + +impl VSyncProvider { + pub(crate) fn new() -> Self { + if let Some((get_frame_id, get_statistics, wait_for_comp_clock)) = + initialize_direct_composition() + .context("Retrieving DirectComposition functions") + .log_with_level(log::Level::Warn) + { + let interval = get_dwm_interval_from_direct_composition(get_frame_id, get_statistics) + .context("Failed to get DWM interval from DirectComposition") + .log_err() + .unwrap_or(DEFAULT_VSYNC_INTERVAL); + log::info!( + "DirectComposition is supported for VSync, interval: {:?}", + interval + ); + let f = Box::new(move || unsafe { + wait_for_comp_clock(0, std::ptr::null(), INFINITE) == 0 + }); + Self { interval, f } + } else { + let interval = get_dwm_interval() + .context("Failed to get DWM interval") + .log_err() + .unwrap_or(DEFAULT_VSYNC_INTERVAL); + log::info!( + "DirectComposition is not supported for VSync, falling back to DWM, interval: {:?}", + interval + ); + let f = Box::new(|| unsafe { DwmFlush().is_ok() }); + Self { interval, f } + } + } + + pub(crate) fn wait_for_vsync(&self) { + let vsync_start = Instant::now(); + let wait_succeeded = (self.f)(); + let elapsed = vsync_start.elapsed(); + // DwmFlush and DCompositionWaitForCompositorClock returns very early + // instead of waiting until vblank when the monitor goes to sleep or is + // unplugged (nothing to present due to desktop occlusion). We use 1ms as + // a threshhold for the duration of the wait functions and fallback to + // Sleep() if it returns before that. This could happen during normal + // operation for the first call after the vsync thread becomes non-idle, + // but it shouldn't happen often. + if !wait_succeeded || elapsed < VSYNC_INTERVAL_THRESHOLD { + log::trace!("VSyncProvider::wait_for_vsync() took less time than expected"); + std::thread::sleep(self.interval); + } + } +} + +fn initialize_direct_composition() -> Result<( + DCompositionGetFrameId, + DCompositionGetStatistics, + DCompositionWaitForCompositorClock, +)> { + unsafe { + // Load DLL at runtime since older Windows versions don't have dcomp. + let hmodule = GetModuleHandleA(s!("dcomp.dll")).context("Loading dcomp.dll")?; + let get_frame_id_addr = GetProcAddress(hmodule, s!("DCompositionGetFrameId")) + .context("Function DCompositionGetFrameId not found")?; + let get_statistics_addr = GetProcAddress(hmodule, s!("DCompositionGetStatistics")) + .context("Function DCompositionGetStatistics not found")?; + let wait_for_compositor_clock_addr = + GetProcAddress(hmodule, s!("DCompositionWaitForCompositorClock")) + .context("Function DCompositionWaitForCompositorClock not found")?; + let get_frame_id: DCompositionGetFrameId = std::mem::transmute(get_frame_id_addr); + let get_statistics: DCompositionGetStatistics = std::mem::transmute(get_statistics_addr); + let wait_for_compositor_clock: DCompositionWaitForCompositorClock = + std::mem::transmute(wait_for_compositor_clock_addr); + Ok((get_frame_id, get_statistics, wait_for_compositor_clock)) + } +} + +fn get_dwm_interval_from_direct_composition( + get_frame_id: DCompositionGetFrameId, + get_statistics: DCompositionGetStatistics, +) -> Result { + let mut frame_id = 0; + unsafe { get_frame_id(COMPOSITION_FRAME_ID_COMPLETED, &mut frame_id) }.ok()?; + let mut stats = COMPOSITION_FRAME_STATS::default(); + unsafe { + get_statistics( + frame_id, + &mut stats, + 0, + std::ptr::null_mut(), + std::ptr::null_mut(), + ) + } + .ok()?; + Ok(retrieve_duration(stats.framePeriod, *QPC_TICKS_PER_SECOND)) +} + +fn get_dwm_interval() -> Result { + let mut timing_info = DWM_TIMING_INFO { + cbSize: std::mem::size_of::() as u32, + ..Default::default() + }; + unsafe { DwmGetCompositionTimingInfo(HWND::default(), &mut timing_info) }?; + let interval = retrieve_duration(timing_info.qpcRefreshPeriod, *QPC_TICKS_PER_SECOND); + // Check for interval values that are impossibly low. A 29 microsecond + // interval was seen (from a qpcRefreshPeriod of 60). + if interval < VSYNC_INTERVAL_THRESHOLD { + Ok(retrieve_duration( + timing_info.rateRefresh.uiDenominator as u64, + timing_info.rateRefresh.uiNumerator as u64, + )) + } else { + Ok(interval) + } +} + +#[inline] +fn retrieve_duration(counts: u64, ticks_per_second: u64) -> Duration { + let ticks_per_microsecond = ticks_per_second / 1_000_000; + Duration::from_micros(counts / ticks_per_microsecond) +} diff --git a/crates/gpui/src/platform/windows/wrapper.rs b/crates/gpui/src/platform/windows/wrapper.rs index 6015dffdab..60bbc433ca 100644 --- a/crates/gpui/src/platform/windows/wrapper.rs +++ b/crates/gpui/src/platform/windows/wrapper.rs @@ -1,28 +1,6 @@ use std::ops::Deref; -use windows::Win32::{Foundation::HANDLE, UI::WindowsAndMessaging::HCURSOR}; - -#[derive(Debug, Clone, Copy)] -pub(crate) struct SafeHandle { - raw: HANDLE, -} - -unsafe impl Send for SafeHandle {} -unsafe impl Sync for SafeHandle {} - -impl From for SafeHandle { - fn from(value: HANDLE) -> Self { - SafeHandle { raw: value } - } -} - -impl Deref for SafeHandle { - type Target = HANDLE; - - fn deref(&self) -> &Self::Target { - &self.raw - } -} +use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::HCURSOR}; #[derive(Debug, Clone, Copy)] pub(crate) struct SafeCursor { @@ -45,3 +23,31 @@ impl Deref for SafeCursor { &self.raw } } + +#[derive(Debug, Clone, Copy)] +pub(crate) struct SafeHwnd { + raw: HWND, +} + +impl SafeHwnd { + pub(crate) fn as_raw(&self) -> HWND { + self.raw + } +} + +unsafe impl Send for SafeHwnd {} +unsafe impl Sync for SafeHwnd {} + +impl From for SafeHwnd { + fn from(value: HWND) -> Self { + SafeHwnd { raw: value } + } +} + +impl Deref for SafeHwnd { + type Target = HWND; + + fn deref(&self) -> &Self::Target { + &self.raw + } +} diff --git a/crates/gpui/src/scene.rs b/crates/gpui/src/scene.rs index c527dfe750..758d06e597 100644 --- a/crates/gpui/src/scene.rs +++ b/crates/gpui/src/scene.rs @@ -476,7 +476,7 @@ pub(crate) struct Underline { pub content_mask: ContentMask, pub color: Hsla, pub thickness: ScaledPixels, - pub wavy: bool, + pub wavy: u32, } impl From for Primitive { diff --git a/crates/gpui/src/style.rs b/crates/gpui/src/style.rs index 560de7b924..09985722ef 100644 --- a/crates/gpui/src/style.rs +++ b/crates/gpui/src/style.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ AbsoluteLength, App, Background, BackgroundTag, BorderStyle, Bounds, ContentMask, Corners, CornersRefinement, CursorStyle, DefiniteLength, DevicePixels, Edges, EdgesRefinement, Font, - FontFallbacks, FontFeatures, FontStyle, FontWeight, Hsla, Length, Pixels, Point, + FontFallbacks, FontFeatures, FontStyle, FontWeight, GridLocation, Hsla, Length, Pixels, Point, PointRefinement, Rgba, SharedString, Size, SizeRefinement, Styled, TextRun, Window, black, phi, point, quad, rems, size, }; @@ -260,6 +260,17 @@ pub struct Style { /// The opacity of this element pub opacity: Option, + /// The grid columns of this element + /// Equivalent to the Tailwind `grid-cols-` + pub grid_cols: Option, + + /// The row span of this element + /// Equivalent to the Tailwind `grid-rows-` + pub grid_rows: Option, + + /// The grid location of this element + pub grid_location: Option, + /// Whether to draw a red debugging outline around this element #[cfg(debug_assertions)] pub debug: bool, @@ -275,6 +286,13 @@ impl Styled for StyleRefinement { } } +impl StyleRefinement { + /// The grid location of this element + pub fn grid_location_mut(&mut self) -> &mut GridLocation { + self.grid_location.get_or_insert_default() + } +} + /// The value of the visibility property, similar to the CSS property `visibility` #[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] pub enum Visibility { @@ -757,6 +775,9 @@ impl Default for Style { text: TextStyleRefinement::default(), mouse_cursor: None, opacity: None, + grid_rows: None, + grid_cols: None, + grid_location: None, #[cfg(debug_assertions)] debug: false, diff --git a/crates/gpui/src/styled.rs b/crates/gpui/src/styled.rs index b689f32687..c714cac14f 100644 --- a/crates/gpui/src/styled.rs +++ b/crates/gpui/src/styled.rs @@ -1,8 +1,8 @@ use crate::{ self as gpui, AbsoluteLength, AlignContent, AlignItems, BorderStyle, CursorStyle, - DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight, Hsla, - JustifyContent, Length, SharedString, StrikethroughStyle, StyleRefinement, TextAlign, - TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px, relative, rems, + DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontStyle, FontWeight, + GridPlacement, Hsla, JustifyContent, Length, SharedString, StrikethroughStyle, StyleRefinement, + TextAlign, TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px, relative, rems, }; pub use gpui_macros::{ border_style_methods, box_shadow_style_methods, cursor_style_methods, margin_style_methods, @@ -46,6 +46,13 @@ pub trait Styled: Sized { self } + /// Sets the display type of the element to `grid`. + /// [Docs](https://tailwindcss.com/docs/display) + fn grid(mut self) -> Self { + self.style().display = Some(Display::Grid); + self + } + /// Sets the whitespace of the element to `normal`. /// [Docs](https://tailwindcss.com/docs/whitespace#normal) fn whitespace_normal(mut self) -> Self { @@ -640,6 +647,102 @@ pub trait Styled: Sized { self } + /// Sets the grid columns of this element. + fn grid_cols(mut self, cols: u16) -> Self { + self.style().grid_cols = Some(cols); + self + } + + /// Sets the grid rows of this element. + fn grid_rows(mut self, rows: u16) -> Self { + self.style().grid_rows = Some(rows); + self + } + + /// Sets the column start of this element. + fn col_start(mut self, start: i16) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.column.start = GridPlacement::Line(start); + self + } + + /// Sets the column start of this element to auto. + fn col_start_auto(mut self) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.column.start = GridPlacement::Auto; + self + } + + /// Sets the column end of this element. + fn col_end(mut self, end: i16) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.column.end = GridPlacement::Line(end); + self + } + + /// Sets the column end of this element to auto. + fn col_end_auto(mut self) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.column.end = GridPlacement::Auto; + self + } + + /// Sets the column span of this element. + fn col_span(mut self, span: u16) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.column = GridPlacement::Span(span)..GridPlacement::Span(span); + self + } + + /// Sets the row span of this element. + fn col_span_full(mut self) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.column = GridPlacement::Line(1)..GridPlacement::Line(-1); + self + } + + /// Sets the row start of this element. + fn row_start(mut self, start: i16) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.row.start = GridPlacement::Line(start); + self + } + + /// Sets the row start of this element to "auto" + fn row_start_auto(mut self) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.row.start = GridPlacement::Auto; + self + } + + /// Sets the row end of this element. + fn row_end(mut self, end: i16) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.row.end = GridPlacement::Line(end); + self + } + + /// Sets the row end of this element to "auto" + fn row_end_auto(mut self) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.row.end = GridPlacement::Auto; + self + } + + /// Sets the row span of this element. + fn row_span(mut self, span: u16) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.row = GridPlacement::Span(span)..GridPlacement::Span(span); + self + } + + /// Sets the row span of this element. + fn row_span_full(mut self) -> Self { + let grid_location = self.style().grid_location_mut(); + grid_location.row = GridPlacement::Line(1)..GridPlacement::Line(-1); + self + } + /// Draws a debug border around this element. #[cfg(debug_assertions)] fn debug(mut self) -> Self { diff --git a/crates/gpui/src/taffy.rs b/crates/gpui/src/taffy.rs index f7fa54256d..ee21ecd8c4 100644 --- a/crates/gpui/src/taffy.rs +++ b/crates/gpui/src/taffy.rs @@ -3,7 +3,7 @@ use crate::{ }; use collections::{FxHashMap, FxHashSet}; use smallvec::SmallVec; -use std::fmt::Debug; +use std::{fmt::Debug, ops::Range}; use taffy::{ TaffyTree, TraversePartialTree as _, geometry::{Point as TaffyPoint, Rect as TaffyRect, Size as TaffySize}, @@ -251,6 +251,25 @@ trait ToTaffy { impl ToTaffy for Style { fn to_taffy(&self, rem_size: Pixels) -> taffy::style::Style { + use taffy::style_helpers::{fr, length, minmax, repeat}; + + fn to_grid_line( + placement: &Range, + ) -> taffy::Line { + taffy::Line { + start: placement.start.into(), + end: placement.end.into(), + } + } + + fn to_grid_repeat( + unit: &Option, + ) -> Vec> { + // grid-template-columns: repeat(, minmax(0, 1fr)); + unit.map(|count| vec![repeat(count, vec![minmax(length(0.0), fr(1.0))])]) + .unwrap_or_default() + } + taffy::style::Style { display: self.display.into(), overflow: self.overflow.into(), @@ -274,7 +293,19 @@ impl ToTaffy for Style { flex_basis: self.flex_basis.to_taffy(rem_size), flex_grow: self.flex_grow, flex_shrink: self.flex_shrink, - ..Default::default() // Ignore grid properties for now + grid_template_rows: to_grid_repeat(&self.grid_rows), + grid_template_columns: to_grid_repeat(&self.grid_cols), + grid_row: self + .grid_location + .as_ref() + .map(|location| to_grid_line(&location.row)) + .unwrap_or_default(), + grid_column: self + .grid_location + .as_ref() + .map(|location| to_grid_line(&location.column)) + .unwrap_or_default(), + ..Default::default() } } } diff --git a/crates/gpui/src/text_system.rs b/crates/gpui/src/text_system.rs index ed1307c6cd..b48c3a2935 100644 --- a/crates/gpui/src/text_system.rs +++ b/crates/gpui/src/text_system.rs @@ -65,7 +65,7 @@ impl TextSystem { font_runs_pool: Mutex::default(), fallback_font_stack: smallvec![ // TODO: Remove this when Linux have implemented setting fallbacks. - font("Zed Plex Mono"), + font(".ZedMono"), font("Helvetica"), font("Segoe UI"), // Windows font("Cantarell"), // Gnome @@ -96,7 +96,7 @@ impl TextSystem { } /// Get the FontId for the configure font family and style. - pub fn font_id(&self, font: &Font) -> Result { + fn font_id(&self, font: &Font) -> Result { fn clone_font_id_result(font_id: &Result) -> Result { match font_id { Ok(font_id) => Ok(*font_id), @@ -844,3 +844,16 @@ impl FontMetrics { (self.bounding_box / self.units_per_em as f32 * font_size.0).map(px) } } + +#[allow(unused)] +pub(crate) fn font_name_with_fallbacks<'a>(name: &'a str, system: &'a str) -> &'a str { + // Note: the "Zed Plex" fonts were deprecated as we are not allowed to use "Plex" + // in a derived font name. They are essentially indistinguishable from IBM Plex/Lilex, + // and so retained here for backward compatibility. + match name { + ".SystemUIFont" => system, + ".ZedSans" | "Zed Plex Sans" => "IBM Plex Sans", + ".ZedMono" | "Zed Plex Mono" => "Lilex", + _ => name, + } +} diff --git a/crates/gpui/src/text_system/line_wrapper.rs b/crates/gpui/src/text_system/line_wrapper.rs index 5de26511d3..648d714c89 100644 --- a/crates/gpui/src/text_system/line_wrapper.rs +++ b/crates/gpui/src/text_system/line_wrapper.rs @@ -327,7 +327,7 @@ mod tests { fn build_wrapper() -> LineWrapper { let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); let cx = TestAppContext::build(dispatcher, None); - let id = cx.text_system().font_id(&font("Zed Plex Mono")).unwrap(); + let id = cx.text_system().resolve_font(&font(".ZedMono")); LineWrapper::new(id, px(16.), cx.text_system().platform_text_system.clone()) } diff --git a/crates/gpui/src/util.rs b/crates/gpui/src/util.rs index 5e92335fdc..f357034fbf 100644 --- a/crates/gpui/src/util.rs +++ b/crates/gpui/src/util.rs @@ -1,13 +1,11 @@ -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; -#[cfg(any(test, feature = "test-support"))] -use std::time::Duration; - -#[cfg(any(test, feature = "test-support"))] -use futures::Future; - -#[cfg(any(test, feature = "test-support"))] -use smol::future::FutureExt; +use crate::{BackgroundExecutor, Task}; +use std::{ + future::Future, + pin::Pin, + sync::atomic::{AtomicUsize, Ordering::SeqCst}, + task, + time::Duration, +}; pub use util::*; @@ -70,8 +68,59 @@ pub trait FluentBuilder { } } +/// Extensions for Future types that provide additional combinators and utilities. +pub trait FutureExt { + /// Requires a Future to complete before the specified duration has elapsed. + /// Similar to tokio::timeout. + fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout + where + Self: Sized; +} + +impl FutureExt for T { + fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout + where + Self: Sized, + { + WithTimeout { + future: self, + timer: executor.timer(timeout), + } + } +} + +pub struct WithTimeout { + future: T, + timer: Task<()>, +} + +#[derive(Debug, thiserror::Error)] +#[error("Timed out before future resolved")] +/// Error returned by with_timeout when the timeout duration elapsed before the future resolved +pub struct Timeout; + +impl Future for WithTimeout { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll { + // SAFETY: the fields of Timeout are private and we never move the future ourselves + // And its already pinned since we are being polled (all futures need to be pinned to be polled) + let this = unsafe { self.get_unchecked_mut() }; + let future = unsafe { Pin::new_unchecked(&mut this.future) }; + let timer = unsafe { Pin::new_unchecked(&mut this.timer) }; + + if let task::Poll::Ready(output) = future.poll(cx) { + task::Poll::Ready(Ok(output)) + } else if timer.poll(cx).is_ready() { + task::Poll::Ready(Err(Timeout)) + } else { + task::Poll::Pending + } + } +} + #[cfg(any(test, feature = "test-support"))] -pub async fn timeout(timeout: Duration, f: F) -> Result +pub async fn smol_timeout(timeout: Duration, f: F) -> Result where F: Future, { @@ -80,7 +129,7 @@ where Err(()) }; let future = async move { Ok(f.await) }; - timer.race(future).await + smol::future::FutureExt::race(timer, future).await } /// Increment the given atomic counter if it is not zero. diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 40d3845ff9..c0ffd34a0d 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -2814,7 +2814,7 @@ impl Window { content_mask: content_mask.scale(scale_factor), color: style.color.unwrap_or_default().opacity(element_opacity), thickness: style.thickness.scale(scale_factor), - wavy: style.wavy, + wavy: if style.wavy { 1 } else { 0 }, }); } @@ -2845,7 +2845,7 @@ impl Window { content_mask: content_mask.scale(scale_factor), thickness: style.thickness.scale(scale_factor), color: style.color.unwrap_or_default().opacity(opacity), - wavy: false, + wavy: 0, }); } @@ -3688,7 +3688,8 @@ impl Window { ); if !match_result.to_replay.is_empty() { - self.replay_pending_input(match_result.to_replay, cx) + self.replay_pending_input(match_result.to_replay, cx); + cx.propagate_event = true; } if !match_result.pending.is_empty() { diff --git a/crates/gpui_macros/src/test.rs b/crates/gpui_macros/src/test.rs index 2c52149897..adb27f42ea 100644 --- a/crates/gpui_macros/src/test.rs +++ b/crates/gpui_macros/src/test.rs @@ -167,6 +167,7 @@ fn generate_test_function( )); cx_teardowns.extend(quote!( dispatcher.run_until_parked(); + #cx_varname.executor().forbid_parking(); #cx_varname.quit(); dispatcher.run_until_parked(); )); @@ -232,7 +233,7 @@ fn generate_test_function( cx_teardowns.extend(quote!( drop(#cx_varname_lock); dispatcher.run_until_parked(); - #cx_varname.update(|cx| { cx.quit() }); + #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); }); dispatcher.run_until_parked(); )); continue; @@ -247,6 +248,7 @@ fn generate_test_function( )); cx_teardowns.extend(quote!( dispatcher.run_until_parked(); + #cx_varname.executor().forbid_parking(); #cx_varname.quit(); dispatcher.run_until_parked(); )); diff --git a/crates/http_client/src/github.rs b/crates/http_client/src/github.rs index a19c13b0ff..89309ff344 100644 --- a/crates/http_client/src/github.rs +++ b/crates/http_client/src/github.rs @@ -71,11 +71,19 @@ pub async fn latest_github_release( } }; - releases + let mut release = releases .into_iter() .filter(|release| !require_assets || !release.assets.is_empty()) .find(|release| release.pre_release == pre_release) - .context("finding a prerelease") + .context("finding a prerelease")?; + release.assets.iter_mut().for_each(|asset| { + if let Some(digest) = &mut asset.digest { + if let Some(stripped) = digest.strip_prefix("sha256:") { + *digest = stripped.to_owned(); + } + } + }); + Ok(release) } pub async fn get_release_by_tag_name( diff --git a/crates/icons/README.md b/crates/icons/README.md index 5fbd6d4948..71bc5c8545 100644 --- a/crates/icons/README.md +++ b/crates/icons/README.md @@ -3,27 +3,27 @@ ## Guidelines Icons are a big part of Zed, and they're how we convey hundreds of actions without relying on labeled buttons. -When introducing a new icon to the set, it's important to ensure it is consistent with the whole set, which follows a few guidelines: +When introducing a new icon, it's important to ensure consistency with the existing set, which follows these guidelines: 1. The SVG view box should be 16x16. 2. For outlined icons, use a 1.5px stroke width. -3. Not all icons are mathematically aligned; there's quite a bit of optical adjustment. But try to keep the icon within an internal 12x12 bounding box as much as possible while ensuring proper visibility. -4. Use the `filled` and `outlined` terminology when introducing icons that will have the two variants. +3. Not all icons are mathematically aligned; there's quite a bit of optical adjustment. However, try to keep the icon within an internal 12x12 bounding box as much as possible while ensuring proper visibility. +4. Use the `filled` and `outlined` terminology when introducing icons that will have these two variants. 5. Icons that are deeply contextual may have the feature context as their name prefix. For example, `ToolWeb`, `ReplPlay`, `DebugStepInto`, etc. -6. Avoid complex layer structure in the icon SVG, like clipping masks and whatnot. When the shape ends up too complex, we recommend running the SVG in [SVGOMG](https://jakearchibald.github.io/svgomg/) to clean it up a bit. +6. Avoid complex layer structures in the icon SVG, like clipping masks and similar elements. When the shape becomes too complex, we recommend running the SVG through [SVGOMG](https://jakearchibald.github.io/svgomg/) to clean it up. ## Sourcing Most icons are created by sourcing them from [Lucide](https://lucide.dev/). Then, they're modified, adjusted, cleaned up, and simplified depending on their use and overall fit with Zed. -Sometimes, we may use other sources like [Phosphor](https://phosphoricons.com/), but we also design many of them completely from scratch. +Sometimes, we may use other sources like [Phosphor](https://phosphoricons.com/), but we also design many icons completely from scratch. ## Contributing -To introduce a new icon, add the `.svg` file in the `assets/icon` directory and then add its corresponding item in the `icons.rs` file within the `crates` directory. +To introduce a new icon, add the `.svg` file to the `assets/icon` directory and then add its corresponding item to the `icons.rs` file within the `crates` directory. -- SVG files in the assets folder follow a snake case name format. -- Icons in the `icons.rs` file follow the pascal case name format. +- SVG files in the assets folder follow a snake_case name format. +- Icons in the `icons.rs` file follow the PascalCase name format. -Ensure you tag a member of Zed's design team so we can adjust and double-check any newly introduced icon. +Make sure to tag a member of Zed's design team so we can review and adjust any newly introduced icon. diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index f5c2a83fec..8bd76cbecf 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -140,6 +140,7 @@ pub enum IconName { Image, Indicator, Info, + Json, Keyboard, Library, LineHeight, diff --git a/crates/indexed_docs/Cargo.toml b/crates/indexed_docs/Cargo.toml deleted file mode 100644 index eb269ad939..0000000000 --- a/crates/indexed_docs/Cargo.toml +++ /dev/null @@ -1,38 +0,0 @@ -[package] -name = "indexed_docs" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/indexed_docs.rs" - -[dependencies] -anyhow.workspace = true -async-trait.workspace = true -cargo_metadata.workspace = true -collections.workspace = true -derive_more.workspace = true -extension.workspace = true -fs.workspace = true -futures.workspace = true -fuzzy.workspace = true -gpui.workspace = true -heed.workspace = true -html_to_markdown.workspace = true -http_client.workspace = true -indexmap.workspace = true -parking_lot.workspace = true -paths.workspace = true -serde.workspace = true -strum.workspace = true -util.workspace = true -workspace-hack.workspace = true - -[dev-dependencies] -indoc.workspace = true -pretty_assertions.workspace = true diff --git a/crates/indexed_docs/src/extension_indexed_docs_provider.rs b/crates/indexed_docs/src/extension_indexed_docs_provider.rs deleted file mode 100644 index c77ea4066d..0000000000 --- a/crates/indexed_docs/src/extension_indexed_docs_provider.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; -use extension::{Extension, ExtensionHostProxy, ExtensionIndexedDocsProviderProxy}; -use gpui::App; - -use crate::{ - IndexedDocsDatabase, IndexedDocsProvider, IndexedDocsRegistry, PackageName, ProviderId, -}; - -pub fn init(cx: &mut App) { - let proxy = ExtensionHostProxy::default_global(cx); - proxy.register_indexed_docs_provider_proxy(IndexedDocsRegistryProxy { - indexed_docs_registry: IndexedDocsRegistry::global(cx), - }); -} - -struct IndexedDocsRegistryProxy { - indexed_docs_registry: Arc, -} - -impl ExtensionIndexedDocsProviderProxy for IndexedDocsRegistryProxy { - fn register_indexed_docs_provider(&self, extension: Arc, provider_id: Arc) { - self.indexed_docs_registry - .register_provider(Box::new(ExtensionIndexedDocsProvider::new( - extension, - ProviderId(provider_id), - ))); - } - - fn unregister_indexed_docs_provider(&self, provider_id: Arc) { - self.indexed_docs_registry - .unregister_provider(&ProviderId(provider_id)); - } -} - -pub struct ExtensionIndexedDocsProvider { - extension: Arc, - id: ProviderId, -} - -impl ExtensionIndexedDocsProvider { - pub fn new(extension: Arc, id: ProviderId) -> Self { - Self { extension, id } - } -} - -#[async_trait] -impl IndexedDocsProvider for ExtensionIndexedDocsProvider { - fn id(&self) -> ProviderId { - self.id.clone() - } - - fn database_path(&self) -> PathBuf { - let mut database_path = PathBuf::from(self.extension.work_dir().as_ref()); - database_path.push("docs"); - database_path.push(format!("{}.0.mdb", self.id)); - - database_path - } - - async fn suggest_packages(&self) -> Result> { - let packages = self - .extension - .suggest_docs_packages(self.id.0.clone()) - .await?; - - Ok(packages - .into_iter() - .map(|package| PackageName::from(package.as_str())) - .collect()) - } - - async fn index(&self, package: PackageName, database: Arc) -> Result<()> { - self.extension - .index_docs(self.id.0.clone(), package.as_ref().into(), database) - .await - } -} diff --git a/crates/indexed_docs/src/indexed_docs.rs b/crates/indexed_docs/src/indexed_docs.rs deleted file mode 100644 index 97538329d4..0000000000 --- a/crates/indexed_docs/src/indexed_docs.rs +++ /dev/null @@ -1,16 +0,0 @@ -mod extension_indexed_docs_provider; -mod providers; -mod registry; -mod store; - -use gpui::App; - -pub use crate::extension_indexed_docs_provider::ExtensionIndexedDocsProvider; -pub use crate::providers::rustdoc::*; -pub use crate::registry::*; -pub use crate::store::*; - -pub fn init(cx: &mut App) { - IndexedDocsRegistry::init_global(cx); - extension_indexed_docs_provider::init(cx); -} diff --git a/crates/indexed_docs/src/providers.rs b/crates/indexed_docs/src/providers.rs deleted file mode 100644 index c6505a2ab6..0000000000 --- a/crates/indexed_docs/src/providers.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod rustdoc; diff --git a/crates/indexed_docs/src/providers/rustdoc.rs b/crates/indexed_docs/src/providers/rustdoc.rs deleted file mode 100644 index ac6dc3a10b..0000000000 --- a/crates/indexed_docs/src/providers/rustdoc.rs +++ /dev/null @@ -1,291 +0,0 @@ -mod item; -mod to_markdown; - -use cargo_metadata::MetadataCommand; -use futures::future::BoxFuture; -pub use item::*; -use parking_lot::RwLock; -pub use to_markdown::convert_rustdoc_to_markdown; - -use std::collections::BTreeSet; -use std::path::PathBuf; -use std::sync::{Arc, LazyLock}; -use std::time::{Duration, Instant}; - -use anyhow::{Context as _, Result, bail}; -use async_trait::async_trait; -use collections::{HashSet, VecDeque}; -use fs::Fs; -use futures::{AsyncReadExt, FutureExt}; -use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; - -use crate::{IndexedDocsDatabase, IndexedDocsProvider, PackageName, ProviderId}; - -#[derive(Debug)] -struct RustdocItemWithHistory { - pub item: RustdocItem, - #[cfg(debug_assertions)] - pub history: Vec, -} - -pub struct LocalRustdocProvider { - fs: Arc, - cargo_workspace_root: PathBuf, -} - -impl LocalRustdocProvider { - pub fn id() -> ProviderId { - ProviderId("rustdoc".into()) - } - - pub fn new(fs: Arc, cargo_workspace_root: PathBuf) -> Self { - Self { - fs, - cargo_workspace_root, - } - } -} - -#[async_trait] -impl IndexedDocsProvider for LocalRustdocProvider { - fn id(&self) -> ProviderId { - Self::id() - } - - fn database_path(&self) -> PathBuf { - paths::data_dir().join("docs/rust/rustdoc-db.1.mdb") - } - - async fn suggest_packages(&self) -> Result> { - static WORKSPACE_CRATES: LazyLock, Instant)>>> = - LazyLock::new(|| RwLock::new(None)); - - if let Some((crates, fetched_at)) = &*WORKSPACE_CRATES.read() { - if fetched_at.elapsed() < Duration::from_secs(300) { - return Ok(crates.iter().cloned().collect()); - } - } - - let workspace = MetadataCommand::new() - .manifest_path(self.cargo_workspace_root.join("Cargo.toml")) - .exec() - .context("failed to load cargo metadata")?; - - let workspace_crates = workspace - .packages - .into_iter() - .map(|package| PackageName::from(package.name.as_str())) - .collect::>(); - - *WORKSPACE_CRATES.write() = Some((workspace_crates.clone(), Instant::now())); - - Ok(workspace_crates.into_iter().collect()) - } - - async fn index(&self, package: PackageName, database: Arc) -> Result<()> { - index_rustdoc(package, database, { - move |crate_name, item| { - let fs = self.fs.clone(); - let cargo_workspace_root = self.cargo_workspace_root.clone(); - let crate_name = crate_name.clone(); - let item = item.cloned(); - async move { - let target_doc_path = cargo_workspace_root.join("target/doc"); - let mut local_cargo_doc_path = target_doc_path.join(crate_name.as_ref().replace('-', "_")); - - if !fs.is_dir(&local_cargo_doc_path).await { - let cargo_doc_exists_at_all = fs.is_dir(&target_doc_path).await; - if cargo_doc_exists_at_all { - bail!( - "no docs directory for '{crate_name}'. if this is a valid crate name, try running `cargo doc`" - ); - } else { - bail!("no cargo doc directory. run `cargo doc`"); - } - } - - if let Some(item) = item { - local_cargo_doc_path.push(item.url_path()); - } else { - local_cargo_doc_path.push("index.html"); - } - - let Ok(contents) = fs.load(&local_cargo_doc_path).await else { - return Ok(None); - }; - - Ok(Some(contents)) - } - .boxed() - } - }) - .await - } -} - -pub struct DocsDotRsProvider { - http_client: Arc, -} - -impl DocsDotRsProvider { - pub fn id() -> ProviderId { - ProviderId("docs-rs".into()) - } - - pub fn new(http_client: Arc) -> Self { - Self { http_client } - } -} - -#[async_trait] -impl IndexedDocsProvider for DocsDotRsProvider { - fn id(&self) -> ProviderId { - Self::id() - } - - fn database_path(&self) -> PathBuf { - paths::data_dir().join("docs/rust/docs-rs-db.1.mdb") - } - - async fn suggest_packages(&self) -> Result> { - static POPULAR_CRATES: LazyLock> = LazyLock::new(|| { - include_str!("./rustdoc/popular_crates.txt") - .lines() - .filter(|line| !line.starts_with('#')) - .map(|line| PackageName::from(line.trim())) - .collect() - }); - - Ok(POPULAR_CRATES.clone()) - } - - async fn index(&self, package: PackageName, database: Arc) -> Result<()> { - index_rustdoc(package, database, { - move |crate_name, item| { - let http_client = self.http_client.clone(); - let crate_name = crate_name.clone(); - let item = item.cloned(); - async move { - let version = "latest"; - let path = format!( - "{crate_name}/{version}/{crate_name}{item_path}", - item_path = item - .map(|item| format!("/{}", item.url_path())) - .unwrap_or_default() - ); - - let mut response = http_client - .get( - &format!("https://docs.rs/{path}"), - AsyncBody::default(), - true, - ) - .await?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading docs.rs response body")?; - - if response.status().is_client_error() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - Ok(Some(String::from_utf8(body)?)) - } - .boxed() - } - }) - .await - } -} - -async fn index_rustdoc( - package: PackageName, - database: Arc, - fetch_page: impl Fn( - &PackageName, - Option<&RustdocItem>, - ) -> BoxFuture<'static, Result>> - + Send - + Sync, -) -> Result<()> { - let Some(package_root_content) = fetch_page(&package, None).await? else { - return Ok(()); - }; - - let (crate_root_markdown, items) = - convert_rustdoc_to_markdown(package_root_content.as_bytes())?; - - database - .insert(package.to_string(), crate_root_markdown) - .await?; - - let mut seen_items = HashSet::from_iter(items.clone()); - let mut items_to_visit: VecDeque = - VecDeque::from_iter(items.into_iter().map(|item| RustdocItemWithHistory { - item, - #[cfg(debug_assertions)] - history: Vec::new(), - })); - - while let Some(item_with_history) = items_to_visit.pop_front() { - let item = &item_with_history.item; - - let Some(result) = fetch_page(&package, Some(item)).await.with_context(|| { - #[cfg(debug_assertions)] - { - format!( - "failed to fetch {item:?}: {history:?}", - history = item_with_history.history - ) - } - - #[cfg(not(debug_assertions))] - { - format!("failed to fetch {item:?}") - } - })? - else { - continue; - }; - - let (markdown, referenced_items) = convert_rustdoc_to_markdown(result.as_bytes())?; - - database - .insert(format!("{package}::{}", item.display()), markdown) - .await?; - - let parent_item = item; - for mut item in referenced_items { - if seen_items.contains(&item) { - continue; - } - - seen_items.insert(item.clone()); - - item.path.extend(parent_item.path.clone()); - if parent_item.kind == RustdocItemKind::Mod { - item.path.push(parent_item.name.clone()); - } - - items_to_visit.push_back(RustdocItemWithHistory { - #[cfg(debug_assertions)] - history: { - let mut history = item_with_history.history.clone(); - history.push(item.url_path()); - history - }, - item, - }); - } - } - - Ok(()) -} diff --git a/crates/indexed_docs/src/providers/rustdoc/item.rs b/crates/indexed_docs/src/providers/rustdoc/item.rs deleted file mode 100644 index 7d9023ef3e..0000000000 --- a/crates/indexed_docs/src/providers/rustdoc/item.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::sync::Arc; - -use serde::{Deserialize, Serialize}; -use strum::EnumIter; - -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumIter, -)] -#[serde(rename_all = "snake_case")] -pub enum RustdocItemKind { - Mod, - Macro, - Struct, - Enum, - Constant, - Trait, - Function, - TypeAlias, - AttributeMacro, - DeriveMacro, -} - -impl RustdocItemKind { - pub(crate) const fn class(&self) -> &'static str { - match self { - Self::Mod => "mod", - Self::Macro => "macro", - Self::Struct => "struct", - Self::Enum => "enum", - Self::Constant => "constant", - Self::Trait => "trait", - Self::Function => "fn", - Self::TypeAlias => "type", - Self::AttributeMacro => "attr", - Self::DeriveMacro => "derive", - } - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] -pub struct RustdocItem { - pub kind: RustdocItemKind, - /// The item path, up until the name of the item. - pub path: Vec>, - /// The name of the item. - pub name: Arc, -} - -impl RustdocItem { - pub fn display(&self) -> String { - let mut path_segments = self.path.clone(); - path_segments.push(self.name.clone()); - - path_segments.join("::") - } - - pub fn url_path(&self) -> String { - let name = &self.name; - let mut path_components = self.path.clone(); - - match self.kind { - RustdocItemKind::Mod => { - path_components.push(name.clone()); - path_components.push("index.html".into()); - } - RustdocItemKind::Macro - | RustdocItemKind::Struct - | RustdocItemKind::Enum - | RustdocItemKind::Constant - | RustdocItemKind::Trait - | RustdocItemKind::Function - | RustdocItemKind::TypeAlias - | RustdocItemKind::AttributeMacro - | RustdocItemKind::DeriveMacro => { - path_components - .push(format!("{kind}.{name}.html", kind = self.kind.class()).into()); - } - } - - path_components.join("/") - } -} diff --git a/crates/indexed_docs/src/providers/rustdoc/popular_crates.txt b/crates/indexed_docs/src/providers/rustdoc/popular_crates.txt deleted file mode 100644 index ce2c3d51d8..0000000000 --- a/crates/indexed_docs/src/providers/rustdoc/popular_crates.txt +++ /dev/null @@ -1,252 +0,0 @@ -# A list of the most popular Rust crates. -# Sourced from https://lib.rs/std. -serde -serde_json -syn -clap -thiserror -rand -log -tokio -anyhow -regex -quote -proc-macro2 -base64 -itertools -chrono -lazy_static -once_cell -libc -reqwest -futures -bitflags -tracing -url -bytes -toml -tempfile -uuid -indexmap -env_logger -num-traits -async-trait -sha2 -hex -tracing-subscriber -http -parking_lot -cfg-if -futures-util -cc -hashbrown -rayon -hyper -getrandom -semver -strum -flate2 -tokio-util -smallvec -criterion -paste -heck -rand_core -nom -rustls -nix -glob -time -byteorder -strum_macros -serde_yaml -wasm-bindgen -ahash -either -num_cpus -rand_chacha -prost -percent-encoding -pin-project-lite -tokio-stream -bincode -walkdir -bindgen -axum -windows-sys -futures-core -ring -digest -num-bigint -rustls-pemfile -serde_with -crossbeam-channel -tokio-rustls -hmac -fastrand -dirs -zeroize -socket2 -pin-project -tower -derive_more -memchr -toml_edit -static_assertions -pretty_assertions -js-sys -convert_case -unicode-width -pkg-config -itoa -colored -rustc-hash -darling -mime -web-sys -image -bytemuck -which -sha1 -dashmap -arrayvec -fnv -tonic -humantime -libloading -winapi -rustc_version -http-body -indoc -num -home -serde_urlencoded -http-body-util -unicode-segmentation -num-integer -webpki-roots -phf -futures-channel -indicatif -petgraph -ordered-float -strsim -zstd -console -encoding_rs -wasm-bindgen-futures -urlencoding -subtle -crc32fast -slab -rustix -predicates -spin -hyper-rustls -backtrace -rustversion -mio -scopeguard -proc-macro-error -hyper-util -ryu -prost-types -textwrap -memmap2 -zip -zerocopy -generic-array -tar -pyo3 -async-stream -quick-xml -memoffset -csv -crossterm -windows -num_enum -tokio-tungstenite -crossbeam-utils -async-channel -lru -aes -futures-lite -tracing-core -prettyplease -httparse -serde_bytes -tracing-log -tower-service -cargo_metadata -pest -mime_guess -tower-http -data-encoding -native-tls -prost-build -proptest -derivative -serial_test -libm -half -futures-io -bitvec -rustls-native-certs -ureq -object -anstyle -tonic-build -form_urlencoded -num-derive -pest_derive -schemars -proc-macro-crate -rstest -futures-executor -assert_cmd -termcolor -serde_repr -ctrlc -sha3 -clap_complete -flume -mockall -ipnet -aho-corasick -atty -signal-hook -async-std -filetime -num-complex -opentelemetry -cmake -arc-swap -derive_builder -async-recursion -dyn-clone -bumpalo -fs_extra -git2 -sysinfo -shlex -instant -approx -rmp-serde -rand_distr -rustls-pki-types -maplit -sqlx -blake3 -hyper-tls -dotenvy -jsonwebtoken -openssl-sys -crossbeam -camino -winreg -config -rsa -bit-vec -chrono-tz -async-lock -bstr diff --git a/crates/indexed_docs/src/providers/rustdoc/to_markdown.rs b/crates/indexed_docs/src/providers/rustdoc/to_markdown.rs deleted file mode 100644 index 87e3863728..0000000000 --- a/crates/indexed_docs/src/providers/rustdoc/to_markdown.rs +++ /dev/null @@ -1,618 +0,0 @@ -use std::cell::RefCell; -use std::io::Read; -use std::rc::Rc; - -use anyhow::Result; -use html_to_markdown::markdown::{ - HeadingHandler, ListHandler, ParagraphHandler, StyledTextHandler, TableHandler, -}; -use html_to_markdown::{ - HandleTag, HandlerOutcome, HtmlElement, MarkdownWriter, StartTagOutcome, TagHandler, - convert_html_to_markdown, -}; -use indexmap::IndexSet; -use strum::IntoEnumIterator; - -use crate::{RustdocItem, RustdocItemKind}; - -/// Converts the provided rustdoc HTML to Markdown. -pub fn convert_rustdoc_to_markdown(html: impl Read) -> Result<(String, Vec)> { - let item_collector = Rc::new(RefCell::new(RustdocItemCollector::new())); - - let mut handlers: Vec = vec![ - Rc::new(RefCell::new(ParagraphHandler)), - Rc::new(RefCell::new(HeadingHandler)), - Rc::new(RefCell::new(ListHandler)), - Rc::new(RefCell::new(TableHandler::new())), - Rc::new(RefCell::new(StyledTextHandler)), - Rc::new(RefCell::new(RustdocChromeRemover)), - Rc::new(RefCell::new(RustdocHeadingHandler)), - Rc::new(RefCell::new(RustdocCodeHandler)), - Rc::new(RefCell::new(RustdocItemHandler)), - item_collector.clone(), - ]; - - let markdown = convert_html_to_markdown(html, &mut handlers)?; - - let items = item_collector - .borrow() - .items - .iter() - .cloned() - .collect::>(); - - Ok((markdown, items)) -} - -pub struct RustdocHeadingHandler; - -impl HandleTag for RustdocHeadingHandler { - fn should_handle(&self, _tag: &str) -> bool { - // We're only handling text, so we don't need to visit any tags. - false - } - - fn handle_text(&mut self, text: &str, writer: &mut MarkdownWriter) -> HandlerOutcome { - if writer.is_inside("h1") - || writer.is_inside("h2") - || writer.is_inside("h3") - || writer.is_inside("h4") - || writer.is_inside("h5") - || writer.is_inside("h6") - { - let text = text - .trim_matches(|char| char == '\n' || char == '\r') - .replace('\n', " "); - writer.push_str(&text); - - return HandlerOutcome::Handled; - } - - HandlerOutcome::NoOp - } -} - -pub struct RustdocCodeHandler; - -impl HandleTag for RustdocCodeHandler { - fn should_handle(&self, tag: &str) -> bool { - matches!(tag, "pre" | "code") - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - match tag.tag() { - "code" => { - if !writer.is_inside("pre") { - writer.push_str("`"); - } - } - "pre" => { - let classes = tag.classes(); - let is_rust = classes.iter().any(|class| class == "rust"); - let language = is_rust - .then_some("rs") - .or_else(|| { - classes.iter().find_map(|class| { - if let Some((_, language)) = class.split_once("language-") { - Some(language.trim()) - } else { - None - } - }) - }) - .unwrap_or(""); - - writer.push_str(&format!("\n\n```{language}\n")); - } - _ => {} - } - - StartTagOutcome::Continue - } - - fn handle_tag_end(&mut self, tag: &HtmlElement, writer: &mut MarkdownWriter) { - match tag.tag() { - "code" => { - if !writer.is_inside("pre") { - writer.push_str("`"); - } - } - "pre" => writer.push_str("\n```\n"), - _ => {} - } - } - - fn handle_text(&mut self, text: &str, writer: &mut MarkdownWriter) -> HandlerOutcome { - if writer.is_inside("pre") { - writer.push_str(text); - return HandlerOutcome::Handled; - } - - HandlerOutcome::NoOp - } -} - -const RUSTDOC_ITEM_NAME_CLASS: &str = "item-name"; - -pub struct RustdocItemHandler; - -impl RustdocItemHandler { - /// Returns whether we're currently inside of an `.item-name` element, which - /// rustdoc uses to display Rust items in a list. - fn is_inside_item_name(writer: &MarkdownWriter) -> bool { - writer - .current_element_stack() - .iter() - .any(|element| element.has_class(RUSTDOC_ITEM_NAME_CLASS)) - } -} - -impl HandleTag for RustdocItemHandler { - fn should_handle(&self, tag: &str) -> bool { - matches!(tag, "div" | "span") - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - match tag.tag() { - "div" | "span" => { - if Self::is_inside_item_name(writer) && tag.has_class("stab") { - writer.push_str(" ["); - } - } - _ => {} - } - - StartTagOutcome::Continue - } - - fn handle_tag_end(&mut self, tag: &HtmlElement, writer: &mut MarkdownWriter) { - match tag.tag() { - "div" | "span" => { - if tag.has_class(RUSTDOC_ITEM_NAME_CLASS) { - writer.push_str(": "); - } - - if Self::is_inside_item_name(writer) && tag.has_class("stab") { - writer.push_str("]"); - } - } - _ => {} - } - } - - fn handle_text(&mut self, text: &str, writer: &mut MarkdownWriter) -> HandlerOutcome { - if Self::is_inside_item_name(writer) - && !writer.is_inside("span") - && !writer.is_inside("code") - { - writer.push_str(&format!("`{text}`")); - return HandlerOutcome::Handled; - } - - HandlerOutcome::NoOp - } -} - -pub struct RustdocChromeRemover; - -impl HandleTag for RustdocChromeRemover { - fn should_handle(&self, tag: &str) -> bool { - matches!( - tag, - "head" | "script" | "nav" | "summary" | "button" | "a" | "div" | "span" - ) - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - _writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - match tag.tag() { - "head" | "script" | "nav" => return StartTagOutcome::Skip, - "summary" => { - if tag.has_class("hideme") { - return StartTagOutcome::Skip; - } - } - "button" => { - if tag.attr("id").as_deref() == Some("copy-path") { - return StartTagOutcome::Skip; - } - } - "a" => { - if tag.has_any_classes(&["anchor", "doc-anchor", "src"]) { - return StartTagOutcome::Skip; - } - } - "div" | "span" => { - if tag.has_any_classes(&["nav-container", "sidebar-elems", "out-of-band"]) { - return StartTagOutcome::Skip; - } - } - - _ => {} - } - - StartTagOutcome::Continue - } -} - -pub struct RustdocItemCollector { - pub items: IndexSet, -} - -impl RustdocItemCollector { - pub fn new() -> Self { - Self { - items: IndexSet::new(), - } - } - - fn parse_item(tag: &HtmlElement) -> Option { - if tag.tag() != "a" { - return None; - } - - let href = tag.attr("href")?; - if href.starts_with('#') || href.starts_with("https://") || href.starts_with("../") { - return None; - } - - for kind in RustdocItemKind::iter() { - if tag.has_class(kind.class()) { - let mut parts = href.trim_end_matches("/index.html").split('/'); - - if let Some(last_component) = parts.next_back() { - let last_component = match last_component.split_once('#') { - Some((component, _fragment)) => component, - None => last_component, - }; - - let name = last_component - .trim_start_matches(&format!("{}.", kind.class())) - .trim_end_matches(".html"); - - return Some(RustdocItem { - kind, - name: name.into(), - path: parts.map(Into::into).collect(), - }); - } - } - } - - None - } -} - -impl HandleTag for RustdocItemCollector { - fn should_handle(&self, tag: &str) -> bool { - tag == "a" - } - - fn handle_tag_start( - &mut self, - tag: &HtmlElement, - writer: &mut MarkdownWriter, - ) -> StartTagOutcome { - if tag.tag() == "a" { - let is_reexport = writer.current_element_stack().iter().any(|element| { - if let Some(id) = element.attr("id") { - id.starts_with("reexport.") || id.starts_with("method.") - } else { - false - } - }); - - if !is_reexport { - if let Some(item) = Self::parse_item(tag) { - self.items.insert(item); - } - } - } - - StartTagOutcome::Continue - } -} - -#[cfg(test)] -mod tests { - use html_to_markdown::{TagHandler, convert_html_to_markdown}; - use indoc::indoc; - use pretty_assertions::assert_eq; - - use super::*; - - fn rustdoc_handlers() -> Vec { - vec![ - Rc::new(RefCell::new(ParagraphHandler)), - Rc::new(RefCell::new(HeadingHandler)), - Rc::new(RefCell::new(ListHandler)), - Rc::new(RefCell::new(TableHandler::new())), - Rc::new(RefCell::new(StyledTextHandler)), - Rc::new(RefCell::new(RustdocChromeRemover)), - Rc::new(RefCell::new(RustdocHeadingHandler)), - Rc::new(RefCell::new(RustdocCodeHandler)), - Rc::new(RefCell::new(RustdocItemHandler)), - ] - } - - #[test] - fn test_main_heading_buttons_get_removed() { - let html = indoc! {r##" -
-

Crate serde

- - source · - -
- "##}; - let expected = indoc! {" - # Crate serde - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_single_paragraph() { - let html = indoc! {r#" -

In particular, the last point is what sets axum apart from other frameworks. - axum doesn’t have its own middleware system but instead uses - tower::Service. This means axum gets timeouts, tracing, compression, - authorization, and more, for free. It also enables you to share middleware with - applications written using hyper or tonic.

- "#}; - let expected = indoc! {" - In particular, the last point is what sets `axum` apart from other frameworks. `axum` doesn’t have its own middleware system but instead uses `tower::Service`. This means `axum` gets timeouts, tracing, compression, authorization, and more, for free. It also enables you to share middleware with applications written using `hyper` or `tonic`. - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_multiple_paragraphs() { - let html = indoc! {r##" -

§Serde

-

Serde is a framework for serializing and deserializing Rust data - structures efficiently and generically.

-

The Serde ecosystem consists of data structures that know how to serialize - and deserialize themselves along with data formats that know how to - serialize and deserialize other things. Serde provides the layer by which - these two groups interact with each other, allowing any supported data - structure to be serialized and deserialized using any supported data format.

-

See the Serde website https://serde.rs/ for additional documentation and - usage examples.

-

§Design

-

Where many other languages rely on runtime reflection for serializing data, - Serde is instead built on Rust’s powerful trait system. A data structure - that knows how to serialize and deserialize itself is one that implements - Serde’s Serialize and Deserialize traits (or uses Serde’s derive - attribute to automatically generate implementations at compile time). This - avoids any overhead of reflection or runtime type information. In fact in - many situations the interaction between data structure and data format can - be completely optimized away by the Rust compiler, leaving Serde - serialization to perform the same speed as a handwritten serializer for the - specific selection of data structure and data format.

- "##}; - let expected = indoc! {" - ## Serde - - Serde is a framework for _**ser**_ializing and _**de**_serializing Rust data structures efficiently and generically. - - The Serde ecosystem consists of data structures that know how to serialize and deserialize themselves along with data formats that know how to serialize and deserialize other things. Serde provides the layer by which these two groups interact with each other, allowing any supported data structure to be serialized and deserialized using any supported data format. - - See the Serde website https://serde.rs/ for additional documentation and usage examples. - - ### Design - - Where many other languages rely on runtime reflection for serializing data, Serde is instead built on Rust’s powerful trait system. A data structure that knows how to serialize and deserialize itself is one that implements Serde’s `Serialize` and `Deserialize` traits (or uses Serde’s derive attribute to automatically generate implementations at compile time). This avoids any overhead of reflection or runtime type information. In fact in many situations the interaction between data structure and data format can be completely optimized away by the Rust compiler, leaving Serde serialization to perform the same speed as a handwritten serializer for the specific selection of data structure and data format. - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_styled_text() { - let html = indoc! {r#" -

This text is bolded.

-

This text is italicized.

- "#}; - let expected = indoc! {" - This text is **bolded**. - - This text is _italicized_. - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_rust_code_block() { - let html = indoc! {r#" -
use axum::extract::{Path, Query, Json};
-            use std::collections::HashMap;
-
-            // `Path` gives you the path parameters and deserializes them.
-            async fn path(Path(user_id): Path<u32>) {}
-
-            // `Query` gives you the query parameters and deserializes them.
-            async fn query(Query(params): Query<HashMap<String, String>>) {}
-
-            // Buffer the request body and deserialize it as JSON into a
-            // `serde_json::Value`. `Json` supports any type that implements
-            // `serde::Deserialize`.
-            async fn json(Json(payload): Json<serde_json::Value>) {}
- "#}; - let expected = indoc! {" - ```rs - use axum::extract::{Path, Query, Json}; - use std::collections::HashMap; - - // `Path` gives you the path parameters and deserializes them. - async fn path(Path(user_id): Path) {} - - // `Query` gives you the query parameters and deserializes them. - async fn query(Query(params): Query>) {} - - // Buffer the request body and deserialize it as JSON into a - // `serde_json::Value`. `Json` supports any type that implements - // `serde::Deserialize`. - async fn json(Json(payload): Json) {} - ``` - "} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_toml_code_block() { - let html = indoc! {r##" -

§Required dependencies

-

To use axum there are a few dependencies you have to pull in as well:

-
[dependencies]
-            axum = "<latest-version>"
-            tokio = { version = "<latest-version>", features = ["full"] }
-            tower = "<latest-version>"
-            
- "##}; - let expected = indoc! {r#" - ## Required dependencies - - To use axum there are a few dependencies you have to pull in as well: - - ```toml - [dependencies] - axum = "" - tokio = { version = "", features = ["full"] } - tower = "" - - ``` - "#} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_item_table() { - let html = indoc! {r##" -

Structs§

-
    -
  • Errors that can happen when using axum.
  • -
  • Extractor and response for extensions.
  • -
  • Formform
    URL encoded extractor and response.
  • -
  • Jsonjson
    JSON Extractor / Response.
  • -
  • The router type for composing handlers and services.
-

Functions§

-
    -
  • servetokio and (http1 or http2)
    Serve the service with the supplied listener.
  • -
- "##}; - let expected = indoc! {r#" - ## Structs - - - `Error`: Errors that can happen when using axum. - - `Extension`: Extractor and response for extensions. - - `Form` [`form`]: URL encoded extractor and response. - - `Json` [`json`]: JSON Extractor / Response. - - `Router`: The router type for composing handlers and services. - - ## Functions - - - `serve` [`tokio` and (`http1` or `http2`)]: Serve the service with the supplied listener. - "#} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } - - #[test] - fn test_table() { - let html = indoc! {r##" -

§Feature flags

-

axum uses a set of feature flags to reduce the amount of compiled and - optional dependencies.

-

The following optional features are available:

-
- - - - - - - - - - - - - -
NameDescriptionDefault?
http1Enables hyper’s http1 featureYes
http2Enables hyper’s http2 featureNo
jsonEnables the Json type and some similar convenience functionalityYes
macrosEnables optional utility macrosNo
matched-pathEnables capturing of every request’s router path and the MatchedPath extractorYes
multipartEnables parsing multipart/form-data requests with MultipartNo
original-uriEnables capturing of every request’s original URI and the OriginalUri extractorYes
tokioEnables tokio as a dependency and axum::serve, SSE and extract::connect_info types.Yes
tower-logEnables tower’s log featureYes
tracingLog rejections from built-in extractorsYes
wsEnables WebSockets support via extract::wsNo
formEnables the Form extractorYes
queryEnables the Query extractorYes
- "##}; - let expected = indoc! {r#" - ## Feature flags - - axum uses a set of feature flags to reduce the amount of compiled and optional dependencies. - - The following optional features are available: - - | Name | Description | Default? | - | --- | --- | --- | - | `http1` | Enables hyper’s `http1` feature | Yes | - | `http2` | Enables hyper’s `http2` feature | No | - | `json` | Enables the `Json` type and some similar convenience functionality | Yes | - | `macros` | Enables optional utility macros | No | - | `matched-path` | Enables capturing of every request’s router path and the `MatchedPath` extractor | Yes | - | `multipart` | Enables parsing `multipart/form-data` requests with `Multipart` | No | - | `original-uri` | Enables capturing of every request’s original URI and the `OriginalUri` extractor | Yes | - | `tokio` | Enables `tokio` as a dependency and `axum::serve`, `SSE` and `extract::connect_info` types. | Yes | - | `tower-log` | Enables `tower`’s `log` feature | Yes | - | `tracing` | Log rejections from built-in extractors | Yes | - | `ws` | Enables WebSockets support via `extract::ws` | No | - | `form` | Enables the `Form` extractor | Yes | - | `query` | Enables the `Query` extractor | Yes | - "#} - .trim(); - - assert_eq!( - convert_html_to_markdown(html.as_bytes(), &mut rustdoc_handlers()).unwrap(), - expected - ) - } -} diff --git a/crates/indexed_docs/src/registry.rs b/crates/indexed_docs/src/registry.rs deleted file mode 100644 index 6757cd9c1a..0000000000 --- a/crates/indexed_docs/src/registry.rs +++ /dev/null @@ -1,62 +0,0 @@ -use std::sync::Arc; - -use collections::HashMap; -use gpui::{App, BackgroundExecutor, Global, ReadGlobal, UpdateGlobal}; -use parking_lot::RwLock; - -use crate::{IndexedDocsProvider, IndexedDocsStore, ProviderId}; - -struct GlobalIndexedDocsRegistry(Arc); - -impl Global for GlobalIndexedDocsRegistry {} - -pub struct IndexedDocsRegistry { - executor: BackgroundExecutor, - stores_by_provider: RwLock>>, -} - -impl IndexedDocsRegistry { - pub fn global(cx: &App) -> Arc { - GlobalIndexedDocsRegistry::global(cx).0.clone() - } - - pub(crate) fn init_global(cx: &mut App) { - GlobalIndexedDocsRegistry::set_global( - cx, - GlobalIndexedDocsRegistry(Arc::new(Self::new(cx.background_executor().clone()))), - ); - } - - pub fn new(executor: BackgroundExecutor) -> Self { - Self { - executor, - stores_by_provider: RwLock::new(HashMap::default()), - } - } - - pub fn list_providers(&self) -> Vec { - self.stores_by_provider - .read() - .keys() - .cloned() - .collect::>() - } - - pub fn register_provider( - &self, - provider: Box, - ) { - self.stores_by_provider.write().insert( - provider.id(), - Arc::new(IndexedDocsStore::new(provider, self.executor.clone())), - ); - } - - pub fn unregister_provider(&self, provider_id: &ProviderId) { - self.stores_by_provider.write().remove(provider_id); - } - - pub fn get_provider_store(&self, provider_id: ProviderId) -> Option> { - self.stores_by_provider.read().get(&provider_id).cloned() - } -} diff --git a/crates/indexed_docs/src/store.rs b/crates/indexed_docs/src/store.rs deleted file mode 100644 index 1407078efa..0000000000 --- a/crates/indexed_docs/src/store.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; - -use anyhow::{Context as _, Result, anyhow}; -use async_trait::async_trait; -use collections::HashMap; -use derive_more::{Deref, Display}; -use futures::FutureExt; -use futures::future::{self, BoxFuture, Shared}; -use fuzzy::StringMatchCandidate; -use gpui::{App, BackgroundExecutor, Task}; -use heed::Database; -use heed::types::SerdeBincode; -use parking_lot::RwLock; -use serde::{Deserialize, Serialize}; -use util::ResultExt; - -use crate::IndexedDocsRegistry; - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Deref, Display)] -pub struct ProviderId(pub Arc); - -/// The name of a package. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Deref, Display)] -pub struct PackageName(Arc); - -impl From<&str> for PackageName { - fn from(value: &str) -> Self { - Self(value.into()) - } -} - -#[async_trait] -pub trait IndexedDocsProvider { - /// Returns the ID of this provider. - fn id(&self) -> ProviderId; - - /// Returns the path to the database for this provider. - fn database_path(&self) -> PathBuf; - - /// Returns a list of packages as suggestions to be included in the search - /// results. - /// - /// This can be used to provide completions for known packages (e.g., from the - /// local project or a registry) before a package has been indexed. - async fn suggest_packages(&self) -> Result>; - - /// Indexes the package with the given name. - async fn index(&self, package: PackageName, database: Arc) -> Result<()>; -} - -/// A store for indexed docs. -pub struct IndexedDocsStore { - executor: BackgroundExecutor, - database_future: - Shared, Arc>>>, - provider: Box, - indexing_tasks_by_package: - RwLock>>>>>, - latest_errors_by_package: RwLock>>, -} - -impl IndexedDocsStore { - pub fn try_global(provider: ProviderId, cx: &App) -> Result> { - let registry = IndexedDocsRegistry::global(cx); - registry - .get_provider_store(provider.clone()) - .with_context(|| format!("no indexed docs store found for {provider}")) - } - - pub fn new( - provider: Box, - executor: BackgroundExecutor, - ) -> Self { - let database_future = executor - .spawn({ - let executor = executor.clone(); - let database_path = provider.database_path(); - async move { IndexedDocsDatabase::new(database_path, executor) } - }) - .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) - .boxed() - .shared(); - - Self { - executor, - database_future, - provider, - indexing_tasks_by_package: RwLock::new(HashMap::default()), - latest_errors_by_package: RwLock::new(HashMap::default()), - } - } - - pub fn latest_error_for_package(&self, package: &PackageName) -> Option> { - self.latest_errors_by_package.read().get(package).cloned() - } - - /// Returns whether the package with the given name is currently being indexed. - pub fn is_indexing(&self, package: &PackageName) -> bool { - self.indexing_tasks_by_package.read().contains_key(package) - } - - pub async fn load(&self, key: String) -> Result { - self.database_future - .clone() - .await - .map_err(|err| anyhow!(err))? - .load(key) - .await - } - - pub async fn load_many_by_prefix(&self, prefix: String) -> Result> { - self.database_future - .clone() - .await - .map_err(|err| anyhow!(err))? - .load_many_by_prefix(prefix) - .await - } - - /// Returns whether any entries exist with the given prefix. - pub async fn any_with_prefix(&self, prefix: String) -> Result { - self.database_future - .clone() - .await - .map_err(|err| anyhow!(err))? - .any_with_prefix(prefix) - .await - } - - pub fn suggest_packages(self: Arc) -> Task>> { - let this = self.clone(); - self.executor - .spawn(async move { this.provider.suggest_packages().await }) - } - - pub fn index( - self: Arc, - package: PackageName, - ) -> Shared>>> { - if let Some(existing_task) = self.indexing_tasks_by_package.read().get(&package) { - return existing_task.clone(); - } - - let indexing_task = self - .executor - .spawn({ - let this = self.clone(); - let package = package.clone(); - async move { - let _finally = util::defer({ - let this = this.clone(); - let package = package.clone(); - move || { - this.indexing_tasks_by_package.write().remove(&package); - } - }); - - let index_task = { - let package = package.clone(); - async { - let database = this - .database_future - .clone() - .await - .map_err(|err| anyhow!(err))?; - this.provider.index(package, database).await - } - }; - - let result = index_task.await.map_err(Arc::new); - match &result { - Ok(_) => { - this.latest_errors_by_package.write().remove(&package); - } - Err(err) => { - this.latest_errors_by_package - .write() - .insert(package, err.to_string().into()); - } - } - - result - } - }) - .shared(); - - self.indexing_tasks_by_package - .write() - .insert(package, indexing_task.clone()); - - indexing_task - } - - pub fn search(&self, query: String) -> Task> { - let executor = self.executor.clone(); - let database_future = self.database_future.clone(); - self.executor.spawn(async move { - let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { - return Vec::new(); - }; - - let Some(items) = database.keys().await.log_err() else { - return Vec::new(); - }; - - let candidates = items - .iter() - .enumerate() - .map(|(ix, item_path)| StringMatchCandidate::new(ix, &item_path)) - .collect::>(); - - let matches = fuzzy::match_strings( - &candidates, - &query, - false, - true, - 100, - &AtomicBool::default(), - executor, - ) - .await; - - matches - .into_iter() - .map(|mat| items[mat.candidate_id].clone()) - .collect() - }) - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Display, Serialize, Deserialize)] -pub struct MarkdownDocs(pub String); - -pub struct IndexedDocsDatabase { - executor: BackgroundExecutor, - env: heed::Env, - entries: Database, SerdeBincode>, -} - -impl IndexedDocsDatabase { - pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result { - std::fs::create_dir_all(&path)?; - - const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024; - let env = unsafe { - heed::EnvOpenOptions::new() - .map_size(ONE_GB_IN_BYTES) - .max_dbs(1) - .open(path)? - }; - - let mut txn = env.write_txn()?; - let entries = env.create_database(&mut txn, Some("rustdoc_entries"))?; - txn.commit()?; - - Ok(Self { - executor, - env, - entries, - }) - } - - pub fn keys(&self) -> Task>> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - let mut iter = entries.iter(&txn)?; - let mut keys = Vec::new(); - while let Some((key, _value)) = iter.next().transpose()? { - keys.push(key); - } - - Ok(keys) - }) - } - - pub fn load(&self, key: String) -> Task> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - entries - .get(&txn, &key)? - .with_context(|| format!("no docs found for {key}")) - }) - } - - pub fn load_many_by_prefix(&self, prefix: String) -> Task>> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - let results = entries - .iter(&txn)? - .filter_map(|entry| { - let (key, value) = entry.ok()?; - if key.starts_with(&prefix) { - Some((key, value)) - } else { - None - } - }) - .collect::>(); - - Ok(results) - }) - } - - /// Returns whether any entries exist with the given prefix. - pub fn any_with_prefix(&self, prefix: String) -> Task> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let txn = env.read_txn()?; - let any = entries - .iter(&txn)? - .any(|entry| entry.map_or(false, |(key, _value)| key.starts_with(&prefix))); - Ok(any) - }) - } - - pub fn insert(&self, key: String, docs: String) -> Task> { - let env = self.env.clone(); - let entries = self.entries; - - self.executor.spawn(async move { - let mut txn = env.write_txn()?; - entries.put(&mut txn, &key, &MarkdownDocs(docs))?; - txn.commit()?; - Ok(()) - }) - } -} - -impl extension::KeyValueStoreDelegate for IndexedDocsDatabase { - fn insert(&self, key: String, docs: String) -> Task> { - IndexedDocsDatabase::insert(&self, key, docs) - } -} diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 9b0abb1537..1aae0b2f7e 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -987,7 +987,7 @@ pub struct InlayHintSettings { /// Default: false #[serde(default)] pub enabled: bool, - /// Global switch to toggle inline values on and off. + /// Global switch to toggle inline values on and off when debugging. /// /// Default: true #[serde(default = "default_true")] diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 3b4c1fa269..0e10050dae 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -42,6 +42,18 @@ impl fmt::Display for ModelRequestLimitReachedError { } } +#[derive(Error, Debug)] +pub struct ToolUseLimitReachedError; + +impl fmt::Display for ToolUseLimitReachedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use." + ) + } +} + #[derive(Clone, Default)] pub struct LlmApiToken(Arc>>); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index ba110be9c5..c1337399f9 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -941,6 +941,8 @@ impl LanguageModel for CloudLanguageModel { request, model.id(), model.supports_parallel_tool_calls(), + model.supports_prompt_cache_key(), + None, None, ); let llm_api_token = self.llm_api_token.clone(); diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 02e53cb99a..4a0d740334 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -47,6 +47,7 @@ pub struct AvailableModel { pub max_completion_tokens: Option, pub supports_tools: Option, pub supports_images: Option, + pub supports_thinking: Option, } pub struct MistralLanguageModelProvider { @@ -215,6 +216,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider { max_completion_tokens: model.max_completion_tokens, supports_tools: model.supports_tools, supports_images: model.supports_images, + supports_thinking: model.supports_thinking, }, ); } @@ -366,11 +368,7 @@ impl LanguageModel for MistralLanguageModel { LanguageModelCompletionError, >, > { - let request = into_mistral( - request, - self.model.id().to_string(), - self.max_output_tokens(), - ); + let request = into_mistral(request, self.model.clone(), self.max_output_tokens()); let stream = self.stream_completion(request, cx); async move { @@ -384,7 +382,7 @@ impl LanguageModel for MistralLanguageModel { pub fn into_mistral( request: LanguageModelRequest, - model: String, + model: mistral::Model, max_output_tokens: Option, ) -> mistral::Request { let stream = true; @@ -401,13 +399,20 @@ pub fn into_mistral( .push_part(mistral::MessagePart::Text { text: text.clone() }); } MessageContent::Image(image_content) => { - message_content.push_part(mistral::MessagePart::ImageUrl { - image_url: image_content.to_base64_url(), - }); + if model.supports_images() { + message_content.push_part(mistral::MessagePart::ImageUrl { + image_url: image_content.to_base64_url(), + }); + } } MessageContent::Thinking { text, .. } => { - message_content - .push_part(mistral::MessagePart::Text { text: text.clone() }); + if model.supports_thinking() { + message_content.push_part(mistral::MessagePart::Thinking { + thinking: vec![mistral::ThinkingPart::Text { + text: text.clone(), + }], + }); + } } MessageContent::RedactedThinking(_) => {} MessageContent::ToolUse(_) => { @@ -437,12 +442,28 @@ pub fn into_mistral( Role::Assistant => { for content in &message.content { match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + MessageContent::Text(text) => { messages.push(mistral::RequestMessage::Assistant { - content: Some(text.clone()), + content: Some(mistral::MessageContent::Plain { + content: text.clone(), + }), tool_calls: Vec::new(), }); } + MessageContent::Thinking { text, .. } => { + if model.supports_thinking() { + messages.push(mistral::RequestMessage::Assistant { + content: Some(mistral::MessageContent::Multipart { + content: vec![mistral::MessagePart::Thinking { + thinking: vec![mistral::ThinkingPart::Text { + text: text.clone(), + }], + }], + }), + tool_calls: Vec::new(), + }); + } + } MessageContent::RedactedThinking(_) => {} MessageContent::Image(_) => {} MessageContent::ToolUse(tool_use) => { @@ -477,11 +498,26 @@ pub fn into_mistral( Role::System => { for content in &message.content { match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + MessageContent::Text(text) => { messages.push(mistral::RequestMessage::System { - content: text.clone(), + content: mistral::MessageContent::Plain { + content: text.clone(), + }, }); } + MessageContent::Thinking { text, .. } => { + if model.supports_thinking() { + messages.push(mistral::RequestMessage::System { + content: mistral::MessageContent::Multipart { + content: vec![mistral::MessagePart::Thinking { + thinking: vec![mistral::ThinkingPart::Text { + text: text.clone(), + }], + }], + }, + }); + } + } MessageContent::RedactedThinking(_) => {} MessageContent::Image(_) | MessageContent::ToolUse(_) @@ -494,37 +530,8 @@ pub fn into_mistral( } } - // The Mistral API requires that tool messages be followed by assistant messages, - // not user messages. When we have a tool->user sequence in the conversation, - // we need to insert a placeholder assistant message to maintain proper conversation - // flow and prevent API errors. This is a Mistral-specific requirement that differs - // from other language model APIs. - let messages = { - let mut fixed_messages = Vec::with_capacity(messages.len()); - let mut messages_iter = messages.into_iter().peekable(); - - while let Some(message) = messages_iter.next() { - let is_tool_message = matches!(message, mistral::RequestMessage::Tool { .. }); - fixed_messages.push(message); - - // Insert assistant message between tool and user messages - if is_tool_message { - if let Some(next_msg) = messages_iter.peek() { - if matches!(next_msg, mistral::RequestMessage::User { .. }) { - fixed_messages.push(mistral::RequestMessage::Assistant { - content: Some(" ".to_string()), - tool_calls: Vec::new(), - }); - } - } - } - } - - fixed_messages - }; - mistral::Request { - model, + model: model.id().to_string(), messages, stream, max_tokens: max_output_tokens, @@ -595,8 +602,38 @@ impl MistralEventMapper { }; let mut events = Vec::new(); - if let Some(content) = choice.delta.content.clone() { - events.push(Ok(LanguageModelCompletionEvent::Text(content))); + if let Some(content) = choice.delta.content.as_ref() { + match content { + mistral::MessageContentDelta::Text(text) => { + events.push(Ok(LanguageModelCompletionEvent::Text(text.clone()))); + } + mistral::MessageContentDelta::Parts(parts) => { + for part in parts { + match part { + mistral::MessagePart::Text { text } => { + events.push(Ok(LanguageModelCompletionEvent::Text(text.clone()))); + } + mistral::MessagePart::Thinking { thinking } => { + for tp in thinking.iter().cloned() { + match tp { + mistral::ThinkingPart::Text { text } => { + events.push(Ok( + LanguageModelCompletionEvent::Thinking { + text, + signature: None, + }, + )); + } + } + } + } + mistral::MessagePart::ImageUrl { .. } => { + // We currently don't emit a separate event for images in responses. + } + } + } + } + } } if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { @@ -908,7 +945,7 @@ mod tests { thinking_allowed: true, }; - let mistral_request = into_mistral(request, "mistral-small-latest".into(), None); + let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None); assert_eq!(mistral_request.model, "mistral-small-latest"); assert_eq!(mistral_request.temperature, Some(0.5)); @@ -941,7 +978,7 @@ mod tests { thinking_allowed: true, }; - let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None); + let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None); assert_eq!(mistral_request.messages.len(), 1); assert!(matches!( diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 7a6c8e09ed..eaf8d885b3 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -14,7 +14,7 @@ use language_model::{ RateLimiter, Role, StopReason, TokenUsage, }; use menu; -use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion}; +use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -45,6 +45,7 @@ pub struct AvailableModel { pub max_tokens: u64, pub max_output_tokens: Option, pub max_completion_tokens: Option, + pub reasoning_effort: Option, } pub struct OpenAiLanguageModelProvider { @@ -213,6 +214,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, max_completion_tokens: model.max_completion_tokens, + reasoning_effort: model.reasoning_effort.clone(), }, ); } @@ -301,7 +303,25 @@ impl LanguageModel for OpenAiLanguageModel { } fn supports_images(&self) -> bool { - false + use open_ai::Model; + match &self.model { + Model::FourOmni + | Model::FourOmniMini + | Model::FourPointOne + | Model::FourPointOneMini + | Model::FourPointOneNano + | Model::Five + | Model::FiveMini + | Model::FiveNano + | Model::O1 + | Model::O3 + | Model::O4Mini => true, + Model::ThreePointFiveTurbo + | Model::Four + | Model::FourTurbo + | Model::O3Mini + | Model::Custom { .. } => false, + } } fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { @@ -350,7 +370,9 @@ impl LanguageModel for OpenAiLanguageModel { request, self.model.id(), self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), self.max_output_tokens(), + self.model.reasoning_effort(), ); let completions = self.stream_completion(request, cx); async move { @@ -365,7 +387,9 @@ pub fn into_open_ai( request: LanguageModelRequest, model_id: &str, supports_parallel_tool_calls: bool, + supports_prompt_cache_key: bool, max_output_tokens: Option, + reasoning_effort: Option, ) -> open_ai::Request { let stream = !model_id.starts_with("o1-"); @@ -455,6 +479,11 @@ pub fn into_open_ai( } else { None }, + prompt_cache_key: if supports_prompt_cache_key { + request.thread_id + } else { + None + }, tools: request .tools .into_iter() @@ -471,6 +500,7 @@ pub fn into_open_ai( LanguageModelToolChoice::Any => open_ai::ToolChoice::Required, LanguageModelToolChoice::None => open_ai::ToolChoice::None, }), + reasoning_effort, } } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 38bd7cee06..5f546f5219 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -355,7 +355,16 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { LanguageModelCompletionError, >, > { - let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens()); + let supports_parallel_tool_call = true; + let supports_prompt_cache_key = false; + let request = into_open_ai( + request, + &self.model.name, + supports_parallel_tool_call, + supports_prompt_cache_key, + self.max_output_tokens(), + None, + ); let completions = self.stream_completion(request, cx); async move { let mapper = OpenAiEventMapper::new(); diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 037ce467d0..9f447cb68b 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -355,7 +355,9 @@ impl LanguageModel for VercelLanguageModel { request, self.model.id(), self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), self.max_output_tokens(), + None, ); let completions = self.stream_completion(request, cx); async move { diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 5f6034571b..fed6fe92bf 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -359,7 +359,9 @@ impl LanguageModel for XAiLanguageModel { request, self.model.id(), self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), self.max_output_tokens(), + None, ); let completions = self.stream_completion(request, cx); async move { diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 606f3a3f0e..823d59ce12 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -1358,7 +1358,7 @@ impl Render for LspLogToolbarItemView { }) .collect(); - let log_toolbar_view = cx.entity().clone(); + let log_toolbar_view = cx.entity(); let lsp_menu = PopoverMenu::new("LspLogView") .anchor(Corner::TopLeft) diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 50547253a9..3244350a34 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1007,7 +1007,7 @@ impl Render for LspTool { (None, "All Servers Operational") }; - let lsp_tool = cx.entity().clone(); + let lsp_tool = cx.entity(); div().child( PopoverMenu::new("lsp-tool") diff --git a/crates/language_tools/src/syntax_tree_view.rs b/crates/language_tools/src/syntax_tree_view.rs index eadba2c1d2..9946442ec8 100644 --- a/crates/language_tools/src/syntax_tree_view.rs +++ b/crates/language_tools/src/syntax_tree_view.rs @@ -456,7 +456,7 @@ impl SyntaxTreeToolbarItemView { let active_layer = buffer_state.active_layer.clone()?; let active_buffer = buffer_state.buffer.read(cx).snapshot(); - let view = cx.entity().clone(); + let view = cx.entity(); Some( PopoverMenu::new("Syntax Tree") .trigger(Self::render_header(&active_layer)) diff --git a/crates/languages/src/c.rs b/crates/languages/src/c.rs index df93e51760..aee1abee95 100644 --- a/crates/languages/src/c.rs +++ b/crates/languages/src/c.rs @@ -71,13 +71,13 @@ impl super::LspAdapter for CLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result { - let GitHubLspBinaryVersion { name, url, digest } = - &*version.downcast::().unwrap(); + let GitHubLspBinaryVersion { + name, + url, + digest: expected_digest, + } = *version.downcast::().unwrap(); let version_dir = container_dir.join(format!("clangd_{name}")); let binary_path = version_dir.join("bin/clangd"); - let expected_digest = digest - .as_ref() - .and_then(|digest| digest.strip_prefix("sha256:")); let binary = LanguageServerBinary { path: binary_path.clone(), @@ -103,7 +103,7 @@ impl super::LspAdapter for CLspAdapter { }) }; if let (Some(actual_digest), Some(expected_digest)) = - (&metadata.digest, expected_digest) + (&metadata.digest, &expected_digest) { if actual_digest == expected_digest { if validity_check().await.is_ok() { @@ -120,8 +120,8 @@ impl super::LspAdapter for CLspAdapter { } download_server_binary( delegate, - url, - digest.as_deref(), + &url, + expected_digest.as_deref(), &container_dir, AssetKind::Zip, ) @@ -130,7 +130,7 @@ impl super::LspAdapter for CLspAdapter { GithubBinaryMetadata::write_to_file( &GithubBinaryMetadata { metadata_version: 1, - digest: digest.clone(), + digest: expected_digest, }, &metadata_path, ) diff --git a/crates/languages/src/cpp/outline.scm b/crates/languages/src/cpp/outline.scm index 448fe35220..c897366558 100644 --- a/crates/languages/src/cpp/outline.scm +++ b/crates/languages/src/cpp/outline.scm @@ -149,7 +149,9 @@ parameters: (parameter_list "(" @context ")" @context))) - ] - (type_qualifier)? @context) @item + ; Fields declarations may define multiple fields, and so @item is on the + ; declarator so they each get distinct ranges. + ] @item + (type_qualifier)? @context) (comment) @annotation diff --git a/crates/languages/src/css.rs b/crates/languages/src/css.rs index 7725e079be..ffd9006c76 100644 --- a/crates/languages/src/css.rs +++ b/crates/languages/src/css.rs @@ -4,7 +4,7 @@ use futures::StreamExt; use gpui::AsyncApp; use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::json; use smol::fs; @@ -103,7 +103,12 @@ impl LspAdapter for CssLspAdapter { let should_install_language_server = self .node - .should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version) + .should_install_npm_package( + Self::PACKAGE_NAME, + &server_path, + &container_dir, + VersionStrategy::Latest(version), + ) .await; if should_install_language_server { diff --git a/crates/languages/src/github_download.rs b/crates/languages/src/github_download.rs index a3cd0a964b..5b0f1d0729 100644 --- a/crates/languages/src/github_download.rs +++ b/crates/languages/src/github_download.rs @@ -18,9 +18,8 @@ impl GithubBinaryMetadata { let metadata_content = async_fs::read_to_string(metadata_path) .await .with_context(|| format!("reading metadata file at {metadata_path:?}"))?; - let metadata: GithubBinaryMetadata = serde_json::from_str(&metadata_content) - .with_context(|| format!("parsing metadata file at {metadata_path:?}"))?; - Ok(metadata) + serde_json::from_str(&metadata_content) + .with_context(|| format!("parsing metadata file at {metadata_path:?}")) } pub(crate) async fn write_to_file(&self, metadata_path: &Path) -> Result<()> { @@ -62,6 +61,7 @@ pub(crate) async fn download_server_binary( format!("saving archive contents into the temporary file for {url}",) })?; let asset_sha_256 = format!("{:x}", writer.hasher.finalize()); + anyhow::ensure!( asset_sha_256 == expected_sha_256, "{url} asset got SHA-256 mismatch. Expected: {expected_sha_256}, Got: {asset_sha_256}", diff --git a/crates/languages/src/go.rs b/crates/languages/src/go.rs index 16c1b67203..14f646133b 100644 --- a/crates/languages/src/go.rs +++ b/crates/languages/src/go.rs @@ -487,6 +487,8 @@ const GO_MODULE_ROOT_TASK_VARIABLE: VariableName = VariableName::Custom(Cow::Borrowed("GO_MODULE_ROOT")); const GO_SUBTEST_NAME_TASK_VARIABLE: VariableName = VariableName::Custom(Cow::Borrowed("GO_SUBTEST_NAME")); +const GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE: VariableName = + VariableName::Custom(Cow::Borrowed("GO_TABLE_TEST_CASE_NAME")); impl ContextProvider for GoContextProvider { fn build_context( @@ -545,10 +547,19 @@ impl ContextProvider for GoContextProvider { let go_subtest_variable = extract_subtest_name(_subtest_name.unwrap_or("")) .map(|subtest_name| (GO_SUBTEST_NAME_TASK_VARIABLE.clone(), subtest_name)); + let table_test_case_name = variables.get(&VariableName::Custom(Cow::Borrowed( + "_table_test_case_name", + ))); + + let go_table_test_case_variable = table_test_case_name + .and_then(extract_subtest_name) + .map(|case_name| (GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.clone(), case_name)); + Task::ready(Ok(TaskVariables::from_iter( [ go_package_variable, go_subtest_variable, + go_table_test_case_variable, go_module_root_variable, ] .into_iter() @@ -570,6 +581,28 @@ impl ContextProvider for GoContextProvider { let module_cwd = Some(GO_MODULE_ROOT_TASK_VARIABLE.template_value()); Task::ready(Some(TaskTemplates(vec![ + TaskTemplate { + label: format!( + "go test {} -v -run {}/{}", + GO_PACKAGE_TASK_VARIABLE.template_value(), + VariableName::Symbol.template_value(), + GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(), + ), + command: "go".into(), + args: vec![ + "test".into(), + "-v".into(), + "-run".into(), + format!( + "\\^{}\\$/\\^{}\\$", + VariableName::Symbol.template_value(), + GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(), + ), + ], + cwd: package_cwd.clone(), + tags: vec!["go-table-test-case".to_owned()], + ..TaskTemplate::default() + }, TaskTemplate { label: format!( "go test {} -run {}", @@ -842,10 +875,21 @@ mod tests { .collect() }); + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + assert!( - runnables.len() == 2, - "Should find test function and subtest with double quotes, found: {}", - runnables.len() + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-subtest".to_string()), + "Should find go-subtest tag, found: {:?}", + tag_strings ); let buffer = cx.new(|cx| { @@ -860,10 +904,299 @@ mod tests { .collect() }); + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + assert!( - runnables.len() == 2, - "Should find test function and subtest with backticks, found: {}", - runnables.len() + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-subtest".to_string()), + "Should find go-subtest tag, found: {:?}", + tag_strings + ); + } + + #[gpui::test] + fn test_go_table_test_slice_detection(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + import "testing" + + func TestExample(t *testing.T) { + _ = "some random string" + + testCases := []struct{ + name string + anotherStr string + }{ + { + name: "test case 1", + anotherStr: "foo", + }, + { + name: "test case 2", + anotherStr: "bar", + }, + } + + notATableTest := []struct{ + name string + }{ + { + name: "some string", + }, + { + name: "some other string", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // test code here + }) + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings + ); + + let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count(); + let go_table_test_count = tag_strings + .iter() + .filter(|&tag| tag == "go-table-test-case") + .count(); + + assert!( + go_test_count == 1, + "Should find exactly 1 go-test, found: {}", + go_test_count + ); + assert!( + go_table_test_count == 2, + "Should find exactly 2 go-table-test-case, found: {}", + go_table_test_count + ); + } + + #[gpui::test] + fn test_go_table_test_slice_ignored(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + func Example() { + _ = "some random string" + + notATableTest := []struct{ + name string + }{ + { + name: "some string", + }, + { + name: "some other string", + }, + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + !tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + !tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings + ); + } + + #[gpui::test] + fn test_go_table_test_map_detection(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + import "testing" + + func TestExample(t *testing.T) { + _ = "some random string" + + testCases := map[string]struct { + someStr string + fail bool + }{ + "test failure": { + someStr: "foo", + fail: true, + }, + "test success": { + someStr: "bar", + fail: false, + }, + } + + notATableTest := map[string]struct { + someStr string + }{ + "some string": { + someStr: "foo", + }, + "some other string": { + someStr: "bar", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // test code here + }) + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings + ); + + let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count(); + let go_table_test_count = tag_strings + .iter() + .filter(|&tag| tag == "go-table-test-case") + .count(); + + assert!( + go_test_count == 1, + "Should find exactly 1 go-test, found: {}", + go_test_count + ); + assert!( + go_table_test_count == 2, + "Should find exactly 2 go-table-test-case, found: {}", + go_table_test_count + ); + } + + #[gpui::test] + fn test_go_table_test_map_ignored(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + func Example() { + _ = "some random string" + + notATableTest := map[string]struct { + someStr string + }{ + "some string": { + someStr: "foo", + }, + "some other string": { + someStr: "bar", + }, + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + !tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + !tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings ); } diff --git a/crates/languages/src/go/outline.scm b/crates/languages/src/go/outline.scm index e37ae7e572..c745f55aff 100644 --- a/crates/languages/src/go/outline.scm +++ b/crates/languages/src/go/outline.scm @@ -1,4 +1,5 @@ (comment) @annotation + (type_declaration "type" @context [ @@ -42,13 +43,13 @@ (var_declaration "var" @context [ + ; The declaration may define multiple variables, and so @item is on + ; the identifier so they get distinct ranges. (var_spec - name: (identifier) @name) @item + name: (identifier) @name @item) (var_spec_list - "(" (var_spec - name: (identifier) @name) @item - ")" + name: (identifier) @name @item) ) ] ) @@ -60,5 +61,7 @@ "(" @context ")" @context)) @item +; Fields declarations may define multiple fields, and so @item is on the +; declarator so they each get distinct ranges. (field_declaration - name: (_) @name) @item + name: (_) @name @item) diff --git a/crates/languages/src/go/runnables.scm b/crates/languages/src/go/runnables.scm index 6418cd04d8..f56262f799 100644 --- a/crates/languages/src/go/runnables.scm +++ b/crates/languages/src/go/runnables.scm @@ -91,3 +91,103 @@ ) @_ (#set! tag go-main) ) + +; Table test cases - slice and map +( + (short_var_declaration + left: (expression_list (identifier) @_collection_var) + right: (expression_list + (composite_literal + type: [ + (slice_type) + (map_type + key: (type_identifier) @_key_type + (#eq? @_key_type "string") + ) + ] + body: (literal_value + [ + (literal_element + (literal_value + (keyed_element + (literal_element + (identifier) @_field_name + ) + (literal_element + [ + (interpreted_string_literal) @run @_table_test_case_name + (raw_string_literal) @run @_table_test_case_name + ] + ) + ) + ) + ) + (keyed_element + (literal_element + [ + (interpreted_string_literal) @run @_table_test_case_name + (raw_string_literal) @run @_table_test_case_name + ] + ) + ) + ] + ) + ) + ) + ) + (for_statement + (range_clause + left: (expression_list + [ + ( + (identifier) + (identifier) @_loop_var + ) + (identifier) @_loop_var + ] + ) + right: (identifier) @_range_var + (#eq? @_range_var @_collection_var) + ) + body: (block + (expression_statement + (call_expression + function: (selector_expression + operand: (identifier) @_t_var + field: (field_identifier) @_run_method + (#eq? @_run_method "Run") + ) + arguments: (argument_list + . + [ + (selector_expression + operand: (identifier) @_tc_var + (#eq? @_tc_var @_loop_var) + field: (field_identifier) @_field_check + (#eq? @_field_check @_field_name) + ) + (identifier) @_arg_var + (#eq? @_arg_var @_loop_var) + ] + . + (func_literal + parameters: (parameter_list + (parameter_declaration + type: (pointer_type + (qualified_type + package: (package_identifier) @_pkg + name: (type_identifier) @_type + (#eq? @_pkg "testing") + (#eq? @_type "T") + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) @_ + (#set! tag go-table-test-case) +) diff --git a/crates/languages/src/javascript/outline.scm b/crates/languages/src/javascript/outline.scm index 026c71e1f9..ca16c27a27 100644 --- a/crates/languages/src/javascript/outline.scm +++ b/crates/languages/src/javascript/outline.scm @@ -31,12 +31,16 @@ (export_statement (lexical_declaration ["let" "const"] @context + ; Multiple names may be exported - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item))) (program (lexical_declaration ["let" "const"] @context + ; Multiple names may be defined - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) diff --git a/crates/languages/src/json.rs b/crates/languages/src/json.rs index ca82bb2431..484631d01f 100644 --- a/crates/languages/src/json.rs +++ b/crates/languages/src/json.rs @@ -12,7 +12,7 @@ use language::{ LspAdapter, LspAdapterDelegate, }; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; use settings::{KeymapFile, SettingsJsonSchemaParams, SettingsStore}; @@ -340,7 +340,12 @@ impl LspAdapter for JsonLspAdapter { let should_install_language_server = self .node - .should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version) + .should_install_npm_package( + Self::PACKAGE_NAME, + &server_path, + &container_dir, + VersionStrategy::Latest(version), + ) .await; if should_install_language_server { diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index 0524c02fd5..40131089d1 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -13,7 +13,7 @@ use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery}; use language::{Toolchain, WorkspaceFoldersContent}; use lsp::LanguageServerBinary; use lsp::LanguageServerName; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use pet_core::Configuration; use pet_core::os_environment::Environment; use pet_core::python_environment::PythonEnvironmentKind; @@ -205,7 +205,7 @@ impl LspAdapter for PythonLspAdapter { Self::SERVER_NAME.as_ref(), &server_path, &container_dir, - &version, + VersionStrategy::Latest(version), ) .await; diff --git a/crates/languages/src/rust.rs b/crates/languages/src/rust.rs index b52b1e7d55..3baaec1842 100644 --- a/crates/languages/src/rust.rs +++ b/crates/languages/src/rust.rs @@ -23,7 +23,7 @@ use std::{ sync::{Arc, LazyLock}, }; use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName}; -use util::fs::make_file_executable; +use util::fs::{make_file_executable, remove_matching}; use util::merge_json_value_into; use util::{ResultExt, maybe}; @@ -162,13 +162,13 @@ impl LspAdapter for RustLspAdapter { let asset_name = Self::build_asset_name(); let asset = release .assets - .iter() + .into_iter() .find(|asset| asset.name == asset_name) .with_context(|| format!("no asset found matching `{asset_name:?}`"))?; Ok(Box::new(GitHubLspBinaryVersion { name: release.tag_name, - url: asset.browser_download_url.clone(), - digest: asset.digest.clone(), + url: asset.browser_download_url, + digest: asset.digest, })) } @@ -178,11 +178,11 @@ impl LspAdapter for RustLspAdapter { container_dir: PathBuf, delegate: &dyn LspAdapterDelegate, ) -> Result { - let GitHubLspBinaryVersion { name, url, digest } = - &*version.downcast::().unwrap(); - let expected_digest = digest - .as_ref() - .and_then(|digest| digest.strip_prefix("sha256:")); + let GitHubLspBinaryVersion { + name, + url, + digest: expected_digest, + } = *version.downcast::().unwrap(); let destination_path = container_dir.join(format!("rust-analyzer-{name}")); let server_path = match Self::GITHUB_ASSET_KIND { AssetKind::TarGz | AssetKind::Gz => destination_path.clone(), // Tar and gzip extract in place. @@ -213,7 +213,7 @@ impl LspAdapter for RustLspAdapter { }) }; if let (Some(actual_digest), Some(expected_digest)) = - (&metadata.digest, expected_digest) + (&metadata.digest, &expected_digest) { if actual_digest == expected_digest { if validity_check().await.is_ok() { @@ -229,20 +229,20 @@ impl LspAdapter for RustLspAdapter { } } - _ = fs::remove_dir_all(&destination_path).await; download_server_binary( delegate, - url, - expected_digest, + &url, + expected_digest.as_deref(), &destination_path, Self::GITHUB_ASSET_KIND, ) .await?; make_file_executable(&server_path).await?; + remove_matching(&container_dir, |path| path != destination_path).await; GithubBinaryMetadata::write_to_file( &GithubBinaryMetadata { metadata_version: 1, - digest: expected_digest.map(ToString::to_string), + digest: expected_digest, }, &metadata_path, ) @@ -1023,8 +1023,14 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option path.clone(), // Tar and gzip extract in place. + AssetKind::Zip => path.clone().join("rust-analyzer.exe"), // zip contains a .exe + }; + anyhow::Ok(LanguageServerBinary { - path: last.context("no cached binary")?, + path, env: None, arguments: Default::default(), }) diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index a7edbb148c..0d647f07cf 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -5,7 +5,7 @@ use futures::StreamExt; use gpui::AsyncApp; use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; use smol::fs; @@ -108,7 +108,12 @@ impl LspAdapter for TailwindLspAdapter { let should_install_language_server = self .node - .should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version) + .should_install_npm_package( + Self::PACKAGE_NAME, + &server_path, + &container_dir, + VersionStrategy::Latest(version), + ) .await; if should_install_language_server { diff --git a/crates/languages/src/tsx/outline.scm b/crates/languages/src/tsx/outline.scm index 5dafe791e4..f4261b9697 100644 --- a/crates/languages/src/tsx/outline.scm +++ b/crates/languages/src/tsx/outline.scm @@ -34,12 +34,16 @@ (export_statement (lexical_declaration ["let" "const"] @context + ; Multiple names may be exported - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) (program (lexical_declaration ["let" "const"] @context + ; Multiple names may be defined - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) diff --git a/crates/languages/src/typescript.rs b/crates/languages/src/typescript.rs index f976b62614..1877c86dc5 100644 --- a/crates/languages/src/typescript.rs +++ b/crates/languages/src/typescript.rs @@ -10,7 +10,7 @@ use language::{ LspAdapterDelegate, }; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::{Value, json}; use smol::{fs, lock::RwLock, stream::StreamExt}; @@ -588,7 +588,7 @@ impl LspAdapter for TypeScriptLspAdapter { Self::PACKAGE_NAME, &server_path, &container_dir, - version.typescript_version.as_str(), + VersionStrategy::Latest(version.typescript_version.as_str()), ) .await; diff --git a/crates/languages/src/typescript/outline.scm b/crates/languages/src/typescript/outline.scm index 5dafe791e4..f4261b9697 100644 --- a/crates/languages/src/typescript/outline.scm +++ b/crates/languages/src/typescript/outline.scm @@ -34,12 +34,16 @@ (export_statement (lexical_declaration ["let" "const"] @context + ; Multiple names may be exported - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) (program (lexical_declaration ["let" "const"] @context + ; Multiple names may be defined - @item is on the declarator to keep + ; ranges distinct. (variable_declarator name: (_) @name) @item)) diff --git a/crates/languages/src/vtsls.rs b/crates/languages/src/vtsls.rs index 33751f733e..90faf883ba 100644 --- a/crates/languages/src/vtsls.rs +++ b/crates/languages/src/vtsls.rs @@ -4,7 +4,7 @@ use collections::HashMap; use gpui::AsyncApp; use language::{LanguageName, LanguageToolchainStore, LspAdapter, LspAdapterDelegate}; use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::Value; use std::{ @@ -115,7 +115,7 @@ impl LspAdapter for VtslsLspAdapter { Self::PACKAGE_NAME, &server_path, &container_dir, - &latest_version.server_version, + VersionStrategy::Latest(&latest_version.server_version), ) .await { @@ -128,7 +128,7 @@ impl LspAdapter for VtslsLspAdapter { Self::TYPESCRIPT_PACKAGE_NAME, &container_dir.join(Self::TYPESCRIPT_TSDK_PATH), &container_dir, - &latest_version.typescript_version, + VersionStrategy::Latest(&latest_version.typescript_version), ) .await { diff --git a/crates/languages/src/yaml.rs b/crates/languages/src/yaml.rs index 815605d524..15a4d590bc 100644 --- a/crates/languages/src/yaml.rs +++ b/crates/languages/src/yaml.rs @@ -6,7 +6,7 @@ use language::{ LanguageToolchainStore, LspAdapter, LspAdapterDelegate, language_settings::AllLanguageSettings, }; use lsp::{LanguageServerBinary, LanguageServerName}; -use node_runtime::NodeRuntime; +use node_runtime::{NodeRuntime, VersionStrategy}; use project::{Fs, lsp_store::language_server_settings}; use serde_json::Value; use settings::{Settings, SettingsLocation}; @@ -104,7 +104,12 @@ impl LspAdapter for YamlLspAdapter { let should_install_language_server = self .node - .should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version) + .should_install_npm_package( + Self::PACKAGE_NAME, + &server_path, + &container_dir, + VersionStrategy::Latest(version), + ) .await; if should_install_language_server { diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index 821fd5d390..58059967b7 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -39,6 +39,8 @@ tokio-tungstenite.workspace = true util.workspace = true workspace-hack.workspace = true +rodio = { workspace = true, features = ["wav_output"] } + [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } livekit = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ diff --git a/crates/livekit_client/src/lib.rs b/crates/livekit_client/src/lib.rs index 149859fdc8..e3934410e1 100644 --- a/crates/livekit_client/src/lib.rs +++ b/crates/livekit_client/src/lib.rs @@ -1,7 +1,13 @@ +use anyhow::Context as _; use collections::HashMap; mod remote_video_track_view; +use cpal::traits::HostTrait as _; pub use remote_video_track_view::{RemoteVideoTrackView, RemoteVideoTrackViewEvent}; +use rodio::DeviceTrait as _; + +mod record; +pub use record::CaptureInput; #[cfg(not(any( test, @@ -18,6 +24,8 @@ mod livekit_client; )))] pub use livekit_client::*; +// If you need proper LSP in livekit_client you've got to comment out +// the mocks and test #[cfg(any( test, feature = "test-support", @@ -168,3 +176,59 @@ pub enum RoomEvent { Reconnecting, Reconnected, } + +pub(crate) fn default_device( + input: bool, +) -> anyhow::Result<(cpal::Device, cpal::SupportedStreamConfig)> { + let device; + let config; + if input { + device = cpal::default_host() + .default_input_device() + .context("no audio input device available")?; + config = device + .default_input_config() + .context("failed to get default input config")?; + } else { + device = cpal::default_host() + .default_output_device() + .context("no audio output device available")?; + config = device + .default_output_config() + .context("failed to get default output config")?; + } + Ok((device, config)) +} + +pub(crate) fn get_sample_data( + sample_format: cpal::SampleFormat, + data: &cpal::Data, +) -> anyhow::Result> { + match sample_format { + cpal::SampleFormat::I8 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::I16 => Ok(data.as_slice::().unwrap().to_vec()), + cpal::SampleFormat::I24 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::I32 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::I64 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U8 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U16 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U32 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::U64 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::F32 => Ok(convert_sample_data::(data)), + cpal::SampleFormat::F64 => Ok(convert_sample_data::(data)), + _ => anyhow::bail!("Unsupported sample format"), + } +} + +pub(crate) fn convert_sample_data< + TSource: cpal::SizedSample, + TDest: cpal::SizedSample + cpal::FromSample, +>( + data: &cpal::Data, +) -> Vec { + data.as_slice::() + .unwrap() + .iter() + .map(|e| e.to_sample::()) + .collect() +} diff --git a/crates/livekit_client/src/livekit_client.rs b/crates/livekit_client/src/livekit_client.rs index 8f0ac1a456..adeea4f512 100644 --- a/crates/livekit_client/src/livekit_client.rs +++ b/crates/livekit_client/src/livekit_client.rs @@ -8,6 +8,8 @@ use gpui_tokio::Tokio; use playback::capture_local_video_track; mod playback; +#[cfg(feature = "record-microphone")] +mod record; use crate::{LocalTrack, Participant, RemoteTrack, RoomEvent, TrackPublication}; pub use playback::AudioStream; diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index f14e156125..d1eec42f8f 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -1,7 +1,6 @@ use anyhow::{Context as _, Result}; -use cpal::traits::{DeviceTrait, HostTrait, StreamTrait as _}; -use cpal::{Data, FromSample, I24, SampleFormat, SizedSample}; +use cpal::traits::{DeviceTrait, StreamTrait as _}; use futures::channel::mpsc::UnboundedSender; use futures::{Stream, StreamExt as _}; use gpui::{ @@ -166,7 +165,7 @@ impl AudioStack { ) -> Result<()> { loop { let mut device_change_listener = DeviceChangeListener::new(false)?; - let (output_device, output_config) = default_device(false)?; + let (output_device, output_config) = crate::default_device(false)?; let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>(); let mixer = mixer.clone(); let apm = apm.clone(); @@ -238,7 +237,7 @@ impl AudioStack { ) -> Result<()> { loop { let mut device_change_listener = DeviceChangeListener::new(true)?; - let (device, config) = default_device(true)?; + let (device, config) = crate::default_device(true)?; let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>(); let apm = apm.clone(); let frame_tx = frame_tx.clone(); @@ -262,7 +261,7 @@ impl AudioStack { config.sample_format(), move |data, _: &_| { let data = - Self::get_sample_data(config.sample_format(), data).log_err(); + crate::get_sample_data(config.sample_format(), data).log_err(); let Some(data) = data else { return; }; @@ -320,33 +319,6 @@ impl AudioStack { drop(end_on_drop_tx) } } - - fn get_sample_data(sample_format: SampleFormat, data: &Data) -> Result> { - match sample_format { - SampleFormat::I8 => Ok(Self::convert_sample_data::(data)), - SampleFormat::I16 => Ok(data.as_slice::().unwrap().to_vec()), - SampleFormat::I24 => Ok(Self::convert_sample_data::(data)), - SampleFormat::I32 => Ok(Self::convert_sample_data::(data)), - SampleFormat::I64 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U8 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U16 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U32 => Ok(Self::convert_sample_data::(data)), - SampleFormat::U64 => Ok(Self::convert_sample_data::(data)), - SampleFormat::F32 => Ok(Self::convert_sample_data::(data)), - SampleFormat::F64 => Ok(Self::convert_sample_data::(data)), - _ => anyhow::bail!("Unsupported sample format"), - } - } - - fn convert_sample_data>( - data: &Data, - ) -> Vec { - data.as_slice::() - .unwrap() - .iter() - .map(|e| e.to_sample::()) - .collect() - } } use super::LocalVideoTrack; @@ -393,27 +365,6 @@ pub(crate) async fn capture_local_video_track( )) } -fn default_device(input: bool) -> Result<(cpal::Device, cpal::SupportedStreamConfig)> { - let device; - let config; - if input { - device = cpal::default_host() - .default_input_device() - .context("no audio input device available")?; - config = device - .default_input_config() - .context("failed to get default input config")?; - } else { - device = cpal::default_host() - .default_output_device() - .context("no audio output device available")?; - config = device - .default_output_config() - .context("failed to get default output config")?; - } - Ok((device, config)) -} - #[derive(Clone)] struct AudioMixerSource { ssrc: i32, diff --git a/crates/livekit_client/src/record.rs b/crates/livekit_client/src/record.rs new file mode 100644 index 0000000000..925c0d4c67 --- /dev/null +++ b/crates/livekit_client/src/record.rs @@ -0,0 +1,91 @@ +use std::{ + env, + path::{Path, PathBuf}, + sync::{Arc, Mutex}, + time::Duration, +}; + +use anyhow::{Context, Result}; +use cpal::traits::{DeviceTrait, StreamTrait}; +use rodio::{buffer::SamplesBuffer, conversions::SampleTypeConverter}; +use util::ResultExt; + +pub struct CaptureInput { + pub name: String, + config: cpal::SupportedStreamConfig, + samples: Arc>>, + _stream: cpal::Stream, +} + +impl CaptureInput { + pub fn start() -> anyhow::Result { + let (device, config) = crate::default_device(true)?; + let name = device.name().unwrap_or("".to_string()); + log::info!("Using microphone: {}", name); + + let samples = Arc::new(Mutex::new(Vec::new())); + let stream = start_capture(device, config.clone(), samples.clone())?; + + Ok(Self { + name, + _stream: stream, + config, + samples, + }) + } + + pub fn finish(self) -> Result { + let name = self.name; + let mut path = env::current_dir().context("Could not get current dir")?; + path.push(&format!("test_recording_{name}.wav")); + log::info!("Test recording written to: {}", path.display()); + write_out(self.samples, self.config, &path)?; + Ok(path) + } +} + +fn start_capture( + device: cpal::Device, + config: cpal::SupportedStreamConfig, + samples: Arc>>, +) -> Result { + let stream = device + .build_input_stream_raw( + &config.config(), + config.sample_format(), + move |data, _: &_| { + let data = crate::get_sample_data(config.sample_format(), data).log_err(); + let Some(data) = data else { + return; + }; + samples + .try_lock() + .expect("Only locked after stream ends") + .extend_from_slice(&data); + }, + |err| log::error!("error capturing audio track: {:?}", err), + Some(Duration::from_millis(100)), + ) + .context("failed to build input stream")?; + + stream.play()?; + Ok(stream) +} + +fn write_out( + samples: Arc>>, + config: cpal::SupportedStreamConfig, + path: &Path, +) -> Result<()> { + let samples = std::mem::take( + &mut *samples + .try_lock() + .expect("Stream has ended, callback cant hold the lock"), + ); + let samples: Vec = SampleTypeConverter::<_, f32>::new(samples.into_iter()).collect(); + let mut samples = SamplesBuffer::new(config.channels(), config.sample_rate().0, samples); + match rodio::output_to_wav(&mut samples, path) { + Ok(_) => Ok(()), + Err(e) => Err(anyhow::anyhow!("Failed to write wav file: {}", e)), + } +} diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index a92787cd3e..ce9e2fe229 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -318,6 +318,8 @@ impl LanguageServer { } else { root_path.parent().unwrap_or_else(|| Path::new("/")) }; + let root_uri = Url::from_file_path(&working_dir) + .map_err(|()| anyhow!("{working_dir:?} is not a valid URI"))?; log::info!( "starting language server process. binary path: {:?}, working directory: {:?}, args: {:?}", @@ -345,8 +347,6 @@ impl LanguageServer { let stdin = server.stdin.take().unwrap(); let stdout = server.stdout.take().unwrap(); let stderr = server.stderr.take().unwrap(); - let root_uri = Url::from_file_path(&working_dir) - .map_err(|()| anyhow!("{working_dir:?} is not a valid URI"))?; let server = Self::new_internal( server_id, server_name, @@ -651,7 +651,7 @@ impl LanguageServer { capabilities: ClientCapabilities { general: Some(GeneralClientCapabilities { position_encodings: Some(vec![PositionEncodingKind::UTF16]), - ..Default::default() + ..GeneralClientCapabilities::default() }), workspace: Some(WorkspaceClientCapabilities { configuration: Some(true), @@ -665,6 +665,7 @@ impl LanguageServer { workspace_folders: Some(true), symbol: Some(WorkspaceSymbolClientCapabilities { resolve_support: None, + dynamic_registration: Some(true), ..WorkspaceSymbolClientCapabilities::default() }), inlay_hint: Some(InlayHintWorkspaceClientCapabilities { @@ -688,21 +689,21 @@ impl LanguageServer { ..WorkspaceEditClientCapabilities::default() }), file_operations: Some(WorkspaceFileOperationsClientCapabilities { - dynamic_registration: Some(false), + dynamic_registration: Some(true), did_rename: Some(true), will_rename: Some(true), - ..Default::default() + ..WorkspaceFileOperationsClientCapabilities::default() }), apply_edit: Some(true), execute_command: Some(ExecuteCommandClientCapabilities { - dynamic_registration: Some(false), + dynamic_registration: Some(true), }), - ..Default::default() + ..WorkspaceClientCapabilities::default() }), text_document: Some(TextDocumentClientCapabilities { definition: Some(GotoCapability { link_support: Some(true), - dynamic_registration: None, + dynamic_registration: Some(true), }), code_action: Some(CodeActionClientCapabilities { code_action_literal_support: Some(CodeActionLiteralSupport { @@ -725,7 +726,8 @@ impl LanguageServer { "command".to_string(), ], }), - ..Default::default() + dynamic_registration: Some(true), + ..CodeActionClientCapabilities::default() }), completion: Some(CompletionClientCapabilities { completion_item: Some(CompletionItemCapability { @@ -751,7 +753,7 @@ impl LanguageServer { MarkupKind::Markdown, MarkupKind::PlainText, ]), - ..Default::default() + ..CompletionItemCapability::default() }), insert_text_mode: Some(InsertTextMode::ADJUST_INDENTATION), completion_list: Some(CompletionListCapability { @@ -764,18 +766,20 @@ impl LanguageServer { ]), }), context_support: Some(true), - ..Default::default() + dynamic_registration: Some(true), + ..CompletionClientCapabilities::default() }), rename: Some(RenameClientCapabilities { prepare_support: Some(true), prepare_support_default_behavior: Some( PrepareSupportDefaultBehavior::IDENTIFIER, ), - ..Default::default() + dynamic_registration: Some(true), + ..RenameClientCapabilities::default() }), hover: Some(HoverClientCapabilities { content_format: Some(vec![MarkupKind::Markdown]), - dynamic_registration: None, + dynamic_registration: Some(true), }), inlay_hint: Some(InlayHintClientCapabilities { resolve_support: Some(InlayHintResolveClientCapabilities { @@ -787,7 +791,7 @@ impl LanguageServer { "label.command".to_string(), ], }), - dynamic_registration: Some(false), + dynamic_registration: Some(true), }), publish_diagnostics: Some(PublishDiagnosticsClientCapabilities { related_information: Some(true), @@ -818,26 +822,29 @@ impl LanguageServer { }), active_parameter_support: Some(true), }), + dynamic_registration: Some(true), ..SignatureHelpClientCapabilities::default() }), synchronization: Some(TextDocumentSyncClientCapabilities { did_save: Some(true), + dynamic_registration: Some(true), ..TextDocumentSyncClientCapabilities::default() }), code_lens: Some(CodeLensClientCapabilities { - dynamic_registration: Some(false), + dynamic_registration: Some(true), }), document_symbol: Some(DocumentSymbolClientCapabilities { hierarchical_document_symbol_support: Some(true), + dynamic_registration: Some(true), ..DocumentSymbolClientCapabilities::default() }), diagnostic: Some(DiagnosticClientCapabilities { - dynamic_registration: Some(false), + dynamic_registration: Some(true), related_document_support: Some(true), }) .filter(|_| pull_diagnostics), color_provider: Some(DocumentColorClientCapabilities { - dynamic_registration: Some(false), + dynamic_registration: Some(true), }), ..TextDocumentClientCapabilities::default() }), @@ -850,7 +857,7 @@ impl LanguageServer { show_message: Some(ShowMessageRequestClientCapabilities { message_action_item: None, }), - ..Default::default() + ..WindowClientCapabilities::default() }), }, trace: None, @@ -862,8 +869,7 @@ impl LanguageServer { } }), locale: None, - - ..Default::default() + ..InitializeParams::default() } } @@ -1672,7 +1678,7 @@ impl LanguageServer { workspace_symbol_provider: Some(OneOf::Left(true)), implementation_provider: Some(ImplementationProviderCapability::Simple(true)), type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)), - ..Default::default() + ..ServerCapabilities::default() } } } diff --git a/crates/markdown/examples/markdown.rs b/crates/markdown/examples/markdown.rs index bf685bd9ac..c651c7921d 100644 --- a/crates/markdown/examples/markdown.rs +++ b/crates/markdown/examples/markdown.rs @@ -77,16 +77,16 @@ impl Render for MarkdownExample { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let markdown_style = MarkdownStyle { base_text_style: gpui::TextStyle { - font_family: "Zed Plex Sans".into(), + font_family: ".ZedSans".into(), color: cx.theme().colors().terminal_ansi_black, ..Default::default() }, code_block: StyleRefinement::default() - .font_family("Zed Plex Mono") + .font_family(".ZedMono") .m(rems(1.)) .bg(rgb(0xAAAAAAA)), inline_code: gpui::TextStyleRefinement { - font_family: Some("Zed Mono".into()), + font_family: Some(".ZedMono".into()), color: Some(cx.theme().colors().editor_foreground), background_color: Some(cx.theme().colors().editor_background), ..Default::default() diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index dba4bc64b1..a3235a9773 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -1084,7 +1084,9 @@ impl Element for MarkdownElement { self.markdown.clone(), cx, ); - el.child(div().absolute().top_1().right_1().w_5().child(codeblock)) + el.child( + div().absolute().top_1().right_0p5().w_5().child(codeblock), + ) }); } @@ -1312,6 +1314,7 @@ fn render_copy_code_block_button( }, ) .icon_color(Color::Muted) + .icon_size(IconSize::Small) .shape(ui::IconButtonShape::Square) .tooltip(Tooltip::text("Copy Code")) .on_click({ diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index c466a598a0..5b4d05377c 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -86,6 +86,7 @@ pub enum Model { max_completion_tokens: Option, supports_tools: Option, supports_images: Option, + supports_thinking: Option, }, } @@ -214,6 +215,16 @@ impl Model { } => supports_images.unwrap_or(false), } } + + pub fn supports_thinking(&self) -> bool { + match self { + Self::MagistralMediumLatest | Self::MagistralSmallLatest => true, + Self::Custom { + supports_thinking, .. + } => supports_thinking.unwrap_or(false), + _ => false, + } + } } #[derive(Debug, Serialize, Deserialize)] @@ -288,7 +299,9 @@ pub enum ToolChoice { #[serde(tag = "role", rename_all = "lowercase")] pub enum RequestMessage { Assistant { - content: Option, + #[serde(flatten)] + #[serde(default, skip_serializing_if = "Option::is_none")] + content: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, @@ -297,7 +310,8 @@ pub enum RequestMessage { content: MessageContent, }, System { - content: String, + #[serde(flatten)] + content: MessageContent, }, Tool { content: String, @@ -305,7 +319,7 @@ pub enum RequestMessage { }, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] #[serde(untagged)] pub enum MessageContent { #[serde(rename = "content")] @@ -346,11 +360,21 @@ impl MessageContent { } } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] #[serde(tag = "type", rename_all = "snake_case")] pub enum MessagePart { Text { text: String }, ImageUrl { image_url: String }, + Thinking { thinking: Vec }, +} + +// Backwards-compatibility alias for provider code that refers to ContentPart +pub type ContentPart = MessagePart; + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ThinkingPart { + Text { text: String }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -418,24 +442,30 @@ pub struct StreamChoice { pub finish_reason: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct StreamDelta { pub role: Option, - pub content: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] +#[serde(untagged)] +pub enum MessageContentDelta { + Text(String), + Parts(Vec), +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] pub struct ToolCallChunk { pub index: usize, pub id: Option, pub function: Option, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] pub struct FunctionChunk { pub name: Option, pub arguments: Option, diff --git a/crates/node_runtime/src/node_runtime.rs b/crates/node_runtime/src/node_runtime.rs index 08698a1d6c..f92c122e71 100644 --- a/crates/node_runtime/src/node_runtime.rs +++ b/crates/node_runtime/src/node_runtime.rs @@ -29,6 +29,13 @@ pub struct NodeBinaryOptions { pub use_paths: Option<(PathBuf, PathBuf)>, } +pub enum VersionStrategy<'a> { + /// Install if current version doesn't match pinned version + Pin(&'a str), + /// Install if current version is older than latest version + Latest(&'a str), +} + #[derive(Clone)] pub struct NodeRuntime(Arc>); @@ -286,7 +293,7 @@ impl NodeRuntime { package_name: &str, local_executable_path: &Path, local_package_directory: &Path, - latest_version: &str, + version_strategy: VersionStrategy<'_>, ) -> bool { // In the case of the local system not having the package installed, // or in the instances where we fail to parse package.json data, @@ -307,11 +314,21 @@ impl NodeRuntime { let Some(installed_version) = Version::parse(&installed_version).log_err() else { return true; }; - let Some(latest_version) = Version::parse(latest_version).log_err() else { - return true; - }; - installed_version < latest_version + match version_strategy { + VersionStrategy::Pin(pinned_version) => { + let Some(pinned_version) = Version::parse(pinned_version).log_err() else { + return true; + }; + installed_version != pinned_version + } + VersionStrategy::Latest(latest_version) => { + let Some(latest_version) = Version::parse(latest_version).log_err() else { + return true; + }; + installed_version < latest_version + } + } } } diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml index 436c714cf3..4157be3172 100644 --- a/crates/onboarding/Cargo.toml +++ b/crates/onboarding/Cargo.toml @@ -18,14 +18,13 @@ default = [] ai_onboarding.workspace = true anyhow.workspace = true client.workspace = true -command_palette_hooks.workspace = true component.workspace = true db.workspace = true documented.workspace = true editor.workspace = true -feature_flags.workspace = true fs.workspace = true fuzzy.workspace = true +git.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true @@ -37,6 +36,7 @@ project.workspace = true schemars.workspace = true serde.workspace = true settings.workspace = true +telemetry.workspace = true theme.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index 8203f96479..bb1932bdf2 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -188,6 +188,11 @@ fn render_llm_provider_card( workspace .update(cx, |workspace, cx| { workspace.toggle_modal(window, cx, |window, cx| { + telemetry::event!( + "Welcome AI Modal Opened", + provider = provider.name().0, + ); + let modal = AiConfigurationModal::new( provider.clone(), window, @@ -245,16 +250,25 @@ pub(crate) fn render_ai_setup_page( ToggleState::Selected }, |&toggle_state, _, cx| { + let enabled = match toggle_state { + ToggleState::Indeterminate => { + return; + } + ToggleState::Unselected => true, + ToggleState::Selected => false, + }; + + telemetry::event!( + "Welcome AI Enabled", + toggle = if enabled { "on" } else { "off" }, + ); + let fs = ::global(cx); update_settings_file::( fs, cx, move |ai_settings: &mut Option, _| { - *ai_settings = match toggle_state { - ToggleState::Indeterminate => None, - ToggleState::Unselected => Some(true), - ToggleState::Selected => Some(false), - }; + *ai_settings = Some(enabled); }, ); }, diff --git a/crates/welcome/src/base_keymap_picker.rs b/crates/onboarding/src/base_keymap_picker.rs similarity index 99% rename from crates/welcome/src/base_keymap_picker.rs rename to crates/onboarding/src/base_keymap_picker.rs index 92317ca711..0ac07d9a9d 100644 --- a/crates/welcome/src/base_keymap_picker.rs +++ b/crates/onboarding/src/base_keymap_picker.rs @@ -12,7 +12,7 @@ use util::ResultExt; use workspace::{ModalView, Workspace, ui::HighlightedLabel}; actions!( - welcome, + zed, [ /// Toggles the base keymap selector modal. ToggleBaseKeymapSelector diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs index a19a21fddf..86ddc22a86 100644 --- a/crates/onboarding/src/basics_page.rs +++ b/crates/onboarding/src/basics_page.rs @@ -58,7 +58,7 @@ fn render_theme_section(tab_index: &mut isize, cx: &mut App) -> impl IntoElement .tab_index(tab_index) .selected_index(theme_mode as usize) .style(ui::ToggleButtonGroupStyle::Outlined) - .button_width(rems_from_px(64.)), + .width(rems_from_px(3. * 64.)), ), ) .child( @@ -305,8 +305,8 @@ fn render_base_keymap_section(tab_index: &mut isize, cx: &mut App) -> impl IntoE .when_some(base_keymap, |this, base_keymap| { this.selected_index(base_keymap) }) + .full_width() .tab_index(tab_index) - .button_width(rems_from_px(216.)) .size(ui::ToggleButtonGroupSize::Medium) .style(ui::ToggleButtonGroupStyle::Outlined), ); diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs index 13b4f6a5c1..d941a0315a 100644 --- a/crates/onboarding/src/editing_page.rs +++ b/crates/onboarding/src/editing_page.rs @@ -35,6 +35,11 @@ fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { EditorSettings::override_global(curr_settings, cx); update_settings_file::(fs, cx, move |editor_settings, _| { + telemetry::event!( + "Welcome Minimap Clicked", + from = editor_settings.minimap.unwrap_or_default(), + to = show + ); editor_settings.minimap.get_or_insert_default().show = Some(show); }); } @@ -71,7 +76,7 @@ fn read_git_blame(cx: &App) -> bool { ProjectSettings::get_global(cx).git.inline_blame_enabled() } -fn set_git_blame(enabled: bool, cx: &mut App) { +fn write_git_blame(enabled: bool, cx: &mut App) { let fs = ::global(cx); let mut curr_settings = ProjectSettings::get_global(cx).clone(); @@ -95,6 +100,12 @@ fn write_ui_font_family(font: SharedString, cx: &mut App) { let fs = ::global(cx); update_settings_file::(fs, cx, move |theme_settings, _| { + telemetry::event!( + "Welcome Font Changed", + type = "ui font", + old = theme_settings.ui_font_family, + new = font.clone() + ); theme_settings.ui_font_family = Some(FontFamilyName(font.into())); }); } @@ -119,6 +130,13 @@ fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { let fs = ::global(cx); update_settings_file::(fs, cx, move |theme_settings, _| { + telemetry::event!( + "Welcome Font Changed", + type = "editor font", + old = theme_settings.buffer_font_family, + new = font_family.clone() + ); + theme_settings.buffer_font_family = Some(FontFamilyName(font_family.into())); }); } @@ -197,7 +215,7 @@ fn render_setting_import_button( .color(Color::Muted) .size(IconSize::XSmall), ) - .child(Label::new(label)), + .child(Label::new(label.clone())), ) .when(imported, |this| { this.child( @@ -212,7 +230,10 @@ fn render_setting_import_button( ) }), ) - .on_click(move |_, window, cx| window.dispatch_action(action.boxed_clone(), cx)), + .on_click(move |_, window, cx| { + telemetry::event!("Welcome Import Settings", import_source = label,); + window.dispatch_action(action.boxed_clone(), cx); + }), ) } @@ -573,7 +594,7 @@ fn font_picker( ) -> FontPicker { let delegate = FontPickerDelegate::new(current_font, on_font_changed, cx); - Picker::list(delegate, window, cx) + Picker::uniform_list(delegate, window, cx) .show_scrollbar(true) .width(rems_from_px(210.)) .max_height(Some(rems(20.).into())) @@ -605,7 +626,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - write_font_ligatures(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Font Ligature", + options = if enabled { "on" } else { "off" }, + ); + + write_font_ligatures(enabled, cx); }, ) .tab_index({ @@ -625,7 +652,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - write_format_on_save(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Format On Save Changed", + options = if enabled { "on" } else { "off" }, + ); + + write_format_on_save(enabled, cx); }, ) .tab_index({ @@ -644,7 +677,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - write_inlay_hints(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Inlay Hints Changed", + options = if enabled { "on" } else { "off" }, + ); + + write_inlay_hints(enabled, cx); }, ) .tab_index({ @@ -663,7 +702,13 @@ fn render_popular_settings_section( ui::ToggleState::Unselected }, |toggle_state, _, cx| { - set_git_blame(toggle_state == &ToggleState::Selected, cx); + let enabled = toggle_state == &ToggleState::Selected; + telemetry::event!( + "Welcome Git Blame Changed", + options = if enabled { "on" } else { "off" }, + ); + + write_git_blame(enabled, cx); }, ) .tab_index({ @@ -676,7 +721,7 @@ fn render_popular_settings_section( .items_start() .justify_between() .child( - v_flex().child(Label::new("Mini Map")).child( + v_flex().child(Label::new("Minimap")).child( Label::new("See a high-level overview of your source code.") .color(Color::Muted), ), @@ -706,7 +751,7 @@ fn render_popular_settings_section( }) .tab_index(tab_index) .style(ToggleButtonGroupStyle::Outlined) - .button_width(ui::rems_from_px(64.)), + .width(ui::rems_from_px(3. * 64.)), ), ) } diff --git a/crates/welcome/src/multibuffer_hint.rs b/crates/onboarding/src/multibuffer_hint.rs similarity index 100% rename from crates/welcome/src/multibuffer_hint.rs rename to crates/onboarding/src/multibuffer_hint.rs diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index 7ba7ba60cb..e07a8dc9fb 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,8 +1,7 @@ -use crate::welcome::{ShowWelcome, WelcomePage}; +pub use crate::welcome::ShowWelcome; +use crate::{multibuffer_hint::MultibufferHint, welcome::WelcomePage}; use client::{Client, UserStore, zed_urls}; -use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; -use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use fs::Fs; use gpui::{ Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, @@ -27,17 +26,13 @@ use workspace::{ }; mod ai_setup_page; +mod base_keymap_picker; mod basics_page; mod editing_page; +pub mod multibuffer_hint; mod theme_preview; mod welcome; -pub struct OnBoardingFeatureFlag {} - -impl FeatureFlag for OnBoardingFeatureFlag { - const NAME: &'static str = "onboarding"; -} - /// Imports settings from Visual Studio Code. #[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] #[action(namespace = zed)] @@ -57,6 +52,7 @@ pub struct ImportCursorSettings { } pub const FIRST_OPEN: &str = "first_open"; +pub const DOCS_URL: &str = "https://zed.dev/docs/"; actions!( zed, @@ -80,11 +76,19 @@ actions!( /// Sign in while in the onboarding flow. SignIn, /// Open the user account in zed.dev while in the onboarding flow. - OpenAccount + OpenAccount, + /// Resets the welcome screen hints to their initial state. + ResetHints ] ); pub fn init(cx: &mut App) { + cx.observe_new(|workspace: &mut Workspace, _, _cx| { + workspace + .register_action(|_workspace, _: &ResetHints, _, cx| MultibufferHint::set_count(0, cx)); + }) + .detach(); + cx.on_action(|_: &OpenOnboarding, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { workspace @@ -182,38 +186,14 @@ pub fn init(cx: &mut App) { }) .detach(); - cx.observe_new::(|_, window, cx| { - let Some(window) = window else { - return; - }; + base_keymap_picker::init(cx); - let onboarding_actions = [ - std::any::TypeId::of::(), - std::any::TypeId::of::(), - ]; - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&onboarding_actions); - }); - - cx.observe_flag::(window, move |is_enabled, _, _, cx| { - if is_enabled { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.show_action_types(onboarding_actions.iter()); - }); - } else { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&onboarding_actions); - }); - } - }) - .detach(); - }) - .detach(); register_serializable_item::(cx); + register_serializable_item::(cx); } pub fn show_onboarding_view(app_state: Arc, cx: &mut App) -> Task> { + telemetry::event!("Onboarding Page Opened"); open_new( Default::default(), app_state, @@ -242,6 +222,16 @@ enum SelectedPage { AiSetup, } +impl SelectedPage { + fn name(&self) -> &'static str { + match self { + SelectedPage::Basics => "Basics", + SelectedPage::Editing => "Editing", + SelectedPage::AiSetup => "AI Setup", + } + } +} + struct Onboarding { workspace: WeakEntity, focus_handle: FocusHandle, @@ -261,7 +251,21 @@ impl Onboarding { }) } - fn set_page(&mut self, page: SelectedPage, cx: &mut Context) { + fn set_page( + &mut self, + page: SelectedPage, + clicked: Option<&'static str>, + cx: &mut Context, + ) { + if let Some(click) = clicked { + telemetry::event!( + "Welcome Tab Clicked", + from = self.selected_page.name(), + to = page.name(), + clicked = click, + ); + } + self.selected_page = page; cx.notify(); cx.emit(ItemEvent::UpdateTab); @@ -325,8 +329,13 @@ impl Onboarding { gpui::Empty.into_any_element(), IntoElement::into_any_element, )) - .on_click(cx.listener(move |this, _, _, cx| { - this.set_page(page, cx); + .on_click(cx.listener(move |this, click_event, _, cx| { + let click = match click_event { + gpui::ClickEvent::Mouse(_) => "mouse", + gpui::ClickEvent::Keyboard(_) => "keyboard", + }; + + this.set_page(page, Some(click), cx); })) }) } @@ -475,6 +484,7 @@ impl Onboarding { } fn on_finish(_: &Finish, _: &mut Window, cx: &mut App) { + telemetry::event!("Welcome Skip Clicked"); go_to_welcome_page(cx); } @@ -532,13 +542,13 @@ impl Render for Onboarding { .on_action(Self::handle_sign_in) .on_action(Self::handle_open_account) .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { - this.set_page(SelectedPage::Basics, cx); + this.set_page(SelectedPage::Basics, Some("action"), cx); })) .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { - this.set_page(SelectedPage::Editing, cx); + this.set_page(SelectedPage::Editing, Some("action"), cx); })) .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { - this.set_page(SelectedPage::AiSetup, cx); + this.set_page(SelectedPage::AiSetup, Some("action"), cx); })) .on_action(cx.listener(|_, _: &menu::SelectNext, window, cx| { window.focus_next(); @@ -551,6 +561,7 @@ impl Render for Onboarding { .child( h_flex() .max_w(rems_from_px(1100.)) + .max_h(rems_from_px(850.)) .size_full() .m_auto() .py_20() @@ -560,12 +571,14 @@ impl Render for Onboarding { .child(self.render_nav(window, cx)) .child( v_flex() + .id("page-content") + .size_full() .max_w_full() .min_w_0() .pl_12() .border_l_1() .border_color(cx.theme().colors().border_variant.opacity(0.5)) - .size_full() + .overflow_y_scroll() .child(self.render_page(window, cx)), ), ) @@ -803,7 +816,7 @@ impl workspace::SerializableItem for Onboarding { if let Some(page) = page { zlog::info!("Onboarding page {page:?} loaded"); onboarding_page.update(cx, |onboarding_page, cx| { - onboarding_page.set_page(page, cx); + onboarding_page.set_page(page, None, cx); }) } onboarding_page diff --git a/crates/onboarding/src/theme_preview.rs b/crates/onboarding/src/theme_preview.rs index 81eb14ec4b..9f299eb6ea 100644 --- a/crates/onboarding/src/theme_preview.rs +++ b/crates/onboarding/src/theme_preview.rs @@ -1,6 +1,9 @@ #![allow(unused, dead_code)] use gpui::{Hsla, Length}; -use std::sync::Arc; +use std::{ + cell::LazyCell, + sync::{Arc, LazyLock, OnceLock}, +}; use theme::{Theme, ThemeColors, ThemeRegistry}; use ui::{ IntoElement, RenderOnce, component_prelude::Documented, prelude::*, utils::inner_corner_radius, @@ -22,6 +25,15 @@ pub struct ThemePreviewTile { style: ThemePreviewStyle, } +static CHILD_RADIUS: LazyLock = LazyLock::new(|| { + inner_corner_radius( + ThemePreviewTile::ROOT_RADIUS, + ThemePreviewTile::ROOT_BORDER, + ThemePreviewTile::ROOT_PADDING, + ThemePreviewTile::CHILD_BORDER, + ) +}); + impl ThemePreviewTile { pub const SKELETON_HEIGHT_DEFAULT: Pixels = px(2.); pub const SIDEBAR_SKELETON_ITEM_COUNT: usize = 8; @@ -30,14 +42,6 @@ impl ThemePreviewTile { pub const ROOT_BORDER: Pixels = px(2.0); pub const ROOT_PADDING: Pixels = px(2.0); pub const CHILD_BORDER: Pixels = px(1.0); - pub const CHILD_RADIUS: std::cell::LazyCell = std::cell::LazyCell::new(|| { - inner_corner_radius( - Self::ROOT_RADIUS, - Self::ROOT_BORDER, - Self::ROOT_PADDING, - Self::CHILD_BORDER, - ) - }); pub fn new(theme: Arc, seed: f32) -> Self { Self { @@ -222,7 +226,7 @@ impl ThemePreviewTile { .child( div() .size_full() - .rounded(*Self::CHILD_RADIUS) + .rounded(*CHILD_RADIUS) .border(Self::CHILD_BORDER) .border_color(theme.colors().border) .child(Self::render_editor( @@ -250,7 +254,7 @@ impl ThemePreviewTile { h_flex() .size_full() .relative() - .rounded(*Self::CHILD_RADIUS) + .rounded(*CHILD_RADIUS) .border(Self::CHILD_BORDER) .border_color(border_color) .overflow_hidden() diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs index d4d6c3f701..ba0053a3b6 100644 --- a/crates/onboarding/src/welcome.rs +++ b/crates/onboarding/src/welcome.rs @@ -1,6 +1,6 @@ use gpui::{ Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, - NoAction, ParentElement, Render, Styled, Window, actions, + ParentElement, Render, Styled, Task, Window, actions, }; use menu::{SelectNext, SelectPrevious}; use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; @@ -38,8 +38,7 @@ const CONTENT: (Section<4>, Section<3>) = ( SectionEntry { icon: IconName::CloudDownload, title: "Clone a Repo", - // TODO: use proper action - action: &NoAction, + action: &git::Clone, }, SectionEntry { icon: IconName::ListCollapse, @@ -353,3 +352,109 @@ impl Item for WelcomePage { f(*event) } } + +impl workspace::SerializableItem for WelcomePage { + fn serialized_item_kind() -> &'static str { + "WelcomePage" + } + + fn cleanup( + workspace_id: workspace::WorkspaceId, + alive_items: Vec, + _window: &mut Window, + cx: &mut App, + ) -> Task> { + workspace::delete_unloaded_items( + alive_items, + workspace_id, + "welcome_pages", + &persistence::WELCOME_PAGES, + cx, + ) + } + + fn deserialize( + _project: Entity, + _workspace: gpui::WeakEntity, + workspace_id: workspace::WorkspaceId, + item_id: workspace::ItemId, + window: &mut Window, + cx: &mut App, + ) -> Task>> { + if persistence::WELCOME_PAGES + .get_welcome_page(item_id, workspace_id) + .ok() + .is_some_and(|is_open| is_open) + { + window.spawn(cx, async move |cx| cx.update(WelcomePage::new)) + } else { + Task::ready(Err(anyhow::anyhow!("No welcome page to deserialize"))) + } + } + + fn serialize( + &mut self, + workspace: &mut workspace::Workspace, + item_id: workspace::ItemId, + _closing: bool, + _window: &mut Window, + cx: &mut Context, + ) -> Option>> { + let workspace_id = workspace.database_id()?; + Some(cx.background_spawn(async move { + persistence::WELCOME_PAGES + .save_welcome_page(item_id, workspace_id, true) + .await + })) + } + + fn should_serialize(&self, event: &Self::Event) -> bool { + event == &ItemEvent::UpdateTab + } +} + +mod persistence { + use db::{define_connection, query, sqlez_macros::sql}; + use workspace::WorkspaceDb; + + define_connection! { + pub static ref WELCOME_PAGES: WelcomePagesDb = + &[ + sql!( + CREATE TABLE welcome_pages ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + is_open INTEGER DEFAULT FALSE, + + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + ), + ]; + } + + impl WelcomePagesDb { + query! { + pub async fn save_welcome_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId, + is_open: bool + ) -> Result<()> { + INSERT OR REPLACE INTO welcome_pages(item_id, workspace_id, is_open) + VALUES (?, ?, ?) + } + } + + query! { + pub fn get_welcome_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId + ) -> Result { + SELECT is_open + FROM welcome_pages + WHERE item_id = ? AND workspace_id = ? + } + } + } +} diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 2d40cd2735..bae00f0a8e 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -20,6 +20,7 @@ anyhow.workspace = true futures.workspace = true http_client.workspace = true schemars = { workspace = true, optional = true } +log.workspace = true serde.workspace = true serde_json.workspace = true strum.workspace = true diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 4697d71ed3..604e8fe622 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -89,11 +89,13 @@ pub enum Model { max_tokens: u64, max_output_tokens: Option, max_completion_tokens: Option, + reasoning_effort: Option, }, } impl Model { pub fn default_fast() -> Self { + // TODO: Replace with FiveMini since all other models are deprecated Self::FourPointOneMini } @@ -206,6 +208,15 @@ impl Model { } } + pub fn reasoning_effort(&self) -> Option { + match self { + Self::Custom { + reasoning_effort, .. + } => reasoning_effort.to_owned(), + _ => None, + } + } + /// Returns whether the given model supports the `parallel_tool_calls` parameter. /// /// If the model does not support the parameter, do not pass it up, or the API will return an error. @@ -225,6 +236,13 @@ impl Model { Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false, } } + + /// Returns whether the given model supports the `prompt_cache_key` parameter. + /// + /// If the model does not support the parameter, do not pass it up. + pub fn supports_prompt_cache_key(&self) -> bool { + return true; + } } #[derive(Debug, Serialize, Deserialize)] @@ -244,6 +262,10 @@ pub struct Request { pub parallel_tool_calls: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -255,6 +277,16 @@ pub enum ToolChoice { Other(ToolDefinition), } +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ReasoningEffort { + Minimal, + Low, + Medium, + High, +} + #[derive(Clone, Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { @@ -445,7 +477,15 @@ pub async fn stream_completion( Ok(ResponseStreamResult::Err { error }) => { Some(Err(anyhow!(error))) } - Err(error) => Some(Err(anyhow!(error))), + Err(error) => { + log::error!( + "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\ + Response: `{}`", + error, + line, + ); + Some(Err(anyhow!(error))) + } } } } diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 1cda3897ec..004a27b0cf 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -4815,51 +4815,45 @@ impl OutlinePanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |outline_panel, range, _, _| { - let entries = outline_panel.cached_entries.get(range); - if let Some(entries) = entries { - entries.into_iter().map(|item| item.depth).collect() - } else { - smallvec::SmallVec::new() - } - }, - ) - .with_render_fn( - cx.entity().clone(), - move |outline_panel, params, _, _| { - const LEFT_OFFSET: Pixels = px(14.); + .with_compute_indents_fn(cx.entity(), |outline_panel, range, _, _| { + let entries = outline_panel.cached_entries.get(range); + if let Some(entries) = entries { + entries.into_iter().map(|item| item.depth).collect() + } else { + smallvec::SmallVec::new() + } + }) + .with_render_fn(cx.entity(), move |outline_panel, params, _, _| { + const LEFT_OFFSET: Pixels = px(14.); - let indent_size = params.indent_size; - let item_height = params.item_height; - let active_indent_guide_ix = find_active_indent_guide_ix( - outline_panel, - ¶ms.indent_guides, - ); + let indent_size = params.indent_size; + let item_height = params.item_height; + let active_indent_guide_ix = find_active_indent_guide_ix( + outline_panel, + ¶ms.indent_guides, + ); - params - .indent_guides - .into_iter() - .enumerate() - .map(|(ix, layout)| { - let bounds = Bounds::new( - point( - layout.offset.x * indent_size + LEFT_OFFSET, - layout.offset.y * item_height, - ), - size(px(1.), layout.length * item_height), - ); - ui::RenderedIndentGuide { - bounds, - layout, - is_active: active_indent_guide_ix == Some(ix), - hitbox: None, - } - }) - .collect() - }, - ), + params + .indent_guides + .into_iter() + .enumerate() + .map(|(ix, layout)| { + let bounds = Bounds::new( + point( + layout.offset.x * indent_size + LEFT_OFFSET, + layout.offset.y * item_height, + ), + size(px(1.), layout.length * item_height), + ); + ui::RenderedIndentGuide { + bounds, + layout, + is_active: active_indent_guide_ix == Some(ix), + hitbox: None, + } + }) + .collect() + }), ) }) }; diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 01fc987816..3163a10239 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -414,6 +414,7 @@ impl GitStore { pub fn init(client: &AnyProtoClient) { client.add_entity_request_handler(Self::handle_get_remotes); client.add_entity_request_handler(Self::handle_get_branches); + client.add_entity_request_handler(Self::handle_get_default_branch); client.add_entity_request_handler(Self::handle_change_branch); client.add_entity_request_handler(Self::handle_create_branch); client.add_entity_request_handler(Self::handle_git_init); @@ -441,6 +442,7 @@ impl GitStore { client.add_entity_request_handler(Self::handle_blame_buffer); client.add_entity_message_handler(Self::handle_update_repository); client.add_entity_message_handler(Self::handle_remove_repository); + client.add_entity_request_handler(Self::handle_git_clone); } pub fn is_local(&self) -> bool { @@ -1464,6 +1466,45 @@ impl GitStore { } } + pub fn git_clone( + &self, + repo: String, + path: impl Into>, + cx: &App, + ) -> Task> { + let path = path.into(); + match &self.state { + GitStoreState::Local { fs, .. } => { + let fs = fs.clone(); + cx.background_executor() + .spawn(async move { fs.git_clone(&repo, &path).await }) + } + GitStoreState::Ssh { + upstream_client, + upstream_project_id, + .. + } => { + let request = upstream_client.request(proto::GitClone { + project_id: upstream_project_id.0, + abs_path: path.to_string_lossy().to_string(), + remote_repo: repo, + }); + + cx.background_spawn(async move { + let result = request.await?; + + match result.success { + true => Ok(()), + false => Err(anyhow!("Git Clone failed")), + } + }) + } + GitStoreState::Remote { .. } => { + Task::ready(Err(anyhow!("Git Clone isn't supported for remote users"))) + } + } + } + async fn handle_update_repository( this: Entity, envelope: TypedEnvelope, @@ -1550,6 +1591,22 @@ impl GitStore { Ok(proto::Ack {}) } + async fn handle_git_clone( + this: Entity, + envelope: TypedEnvelope, + cx: AsyncApp, + ) -> Result { + let path: Arc = PathBuf::from(envelope.payload.abs_path).into(); + let repo_name = envelope.payload.remote_repo; + let result = cx + .update(|cx| this.read(cx).git_clone(repo_name, path, cx))? + .await; + + Ok(proto::GitCloneResponse { + success: result.is_ok(), + }) + } + async fn handle_fetch( this: Entity, envelope: TypedEnvelope, @@ -1838,6 +1895,23 @@ impl GitStore { .collect::>(), }) } + async fn handle_get_default_branch( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + let branch = repository_handle + .update(&mut cx, |repository_handle, _| { + repository_handle.default_branch() + })? + .await?? + .map(Into::into); + + Ok(proto::GetDefaultBranchResponse { branch }) + } async fn handle_create_branch( this: Entity, envelope: TypedEnvelope, @@ -2275,7 +2349,7 @@ impl GitStore { return None; }; - let mut paths = vec![]; + let mut paths = Vec::new(); // All paths prefixed by a given repo will constitute a continuous range. while let Some(path) = entries.get(ix) && let Some(repo_path) = @@ -2284,7 +2358,11 @@ impl GitStore { paths.push((repo_path, ix)); ix += 1; } - Some((repo, paths)) + if paths.is_empty() { + None + } else { + Some((repo, paths)) + } }); tasks.push_back(task); } @@ -4264,7 +4342,8 @@ impl Repository { bail!("not a local repository") }; let (snapshot, events) = this - .read_with(&mut cx, |this, _| { + .update(&mut cx, |this, _| { + this.paths_needing_status_update.clear(); compute_snapshot( this.id, this.work_directory_abs_path.clone(), @@ -4494,6 +4573,9 @@ impl Repository { }; let paths = changed_paths.iter().cloned().collect::>(); + if paths.is_empty() { + return Ok(()); + } let statuses = backend.status(&paths).await?; let changed_path_statuses = cx diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index d3843bc4ea..196f55171a 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -390,13 +390,17 @@ impl LocalLspStore { delegate.update_status( adapter.name(), BinaryStatus::Failed { - error: format!("{err}\n-- stderr--\n{log}"), + error: if log.is_empty() { + format!("{err:#}") + } else { + format!("{err:#}\n-- stderr --\n{log}") + }, }, ); - let message = - format!("Failed to start language server {server_name:?}: {err:#?}"); - log::error!("{message}"); - log::error!("server stderr: {log}"); + log::error!("Failed to start language server {server_name:?}: {err:?}"); + if !log.is_empty() { + log::error!("server stderr: {log}"); + } None } } @@ -638,139 +642,27 @@ impl LocalLspStore { language_server .on_request::({ - let this = this.clone(); + let lsp_store = this.clone(); move |params, cx| { - let lsp_store = this.clone(); + let lsp_store = lsp_store.clone(); let mut cx = cx.clone(); async move { - for reg in params.registrations { - match reg.method.as_str() { - "workspace/didChangeWatchedFiles" => { - if let Some(options) = reg.register_options { - let options = serde_json::from_value(options)?; - lsp_store.update(&mut cx, |this, cx| { - this.as_local_mut()?.on_lsp_did_change_watched_files( - server_id, ®.id, options, cx, + lsp_store + .update(&mut cx, |lsp_store, cx| { + if lsp_store.as_local().is_some() { + match lsp_store + .register_server_capabilities(server_id, params, cx) + { + Ok(()) => {} + Err(e) => { + log::error!( + "Failed to register server capabilities: {e:#}" ); - Some(()) - })?; - } - } - "textDocument/rangeFormatting" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - let options = reg - .register_options - .map(|options| { - serde_json::from_value::< - lsp::DocumentRangeFormattingOptions, - >( - options - ) - }) - .transpose()?; - let provider = match options { - None => OneOf::Left(true), - Some(options) => OneOf::Right(options), - }; - server.update_capabilities(|capabilities| { - capabilities.document_range_formatting_provider = - Some(provider); - }); - notify_server_capabilities_updated(&server, cx); } - anyhow::Ok(()) - })??; + }; } - "textDocument/onTypeFormatting" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - let options = reg - .register_options - .map(|options| { - serde_json::from_value::< - lsp::DocumentOnTypeFormattingOptions, - >( - options - ) - }) - .transpose()?; - if let Some(options) = options { - server.update_capabilities(|capabilities| { - capabilities - .document_on_type_formatting_provider = - Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } - } - anyhow::Ok(()) - })??; - } - "textDocument/formatting" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - let options = reg - .register_options - .map(|options| { - serde_json::from_value::< - lsp::DocumentFormattingOptions, - >( - options - ) - }) - .transpose()?; - let provider = match options { - None => OneOf::Left(true), - Some(options) => OneOf::Right(options), - }; - server.update_capabilities(|capabilities| { - capabilities.document_formatting_provider = - Some(provider); - }); - notify_server_capabilities_updated(&server, cx); - } - anyhow::Ok(()) - })??; - } - "workspace/didChangeConfiguration" => { - // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. - } - "textDocument/rename" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - let options = reg - .register_options - .map(|options| { - serde_json::from_value::( - options, - ) - }) - .transpose()?; - let options = match options { - None => OneOf::Left(true), - Some(options) => OneOf::Right(options), - }; - - server.update_capabilities(|capabilities| { - capabilities.rename_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); - } - anyhow::Ok(()) - })??; - } - _ => log::warn!("unhandled capability registration: {reg:?}"), - } - } + }) + .ok(); Ok(()) } } @@ -779,79 +671,27 @@ impl LocalLspStore { language_server .on_request::({ - let this = this.clone(); + let lsp_store = this.clone(); move |params, cx| { - let lsp_store = this.clone(); + let lsp_store = lsp_store.clone(); let mut cx = cx.clone(); async move { - for unreg in params.unregisterations.iter() { - match unreg.method.as_str() { - "workspace/didChangeWatchedFiles" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - lsp_store - .as_local_mut()? - .on_lsp_unregister_did_change_watched_files( - server_id, &unreg.id, cx, + lsp_store + .update(&mut cx, |lsp_store, cx| { + if lsp_store.as_local().is_some() { + match lsp_store + .unregister_server_capabilities(server_id, params, cx) + { + Ok(()) => {} + Err(e) => { + log::error!( + "Failed to unregister server capabilities: {e:#}" ); - Some(()) - })?; - } - "workspace/didChangeConfiguration" => { - // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. - } - "textDocument/rename" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - server.update_capabilities(|capabilities| { - capabilities.rename_provider = None - }); - notify_server_capabilities_updated(&server, cx); } - })?; + } } - "textDocument/rangeFormatting" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - server.update_capabilities(|capabilities| { - capabilities.document_range_formatting_provider = - None - }); - notify_server_capabilities_updated(&server, cx); - } - })?; - } - "textDocument/onTypeFormatting" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - server.update_capabilities(|capabilities| { - capabilities.document_on_type_formatting_provider = - None; - }); - notify_server_capabilities_updated(&server, cx); - } - })?; - } - "textDocument/formatting" => { - lsp_store.update(&mut cx, |lsp_store, cx| { - if let Some(server) = - lsp_store.language_server_for_id(server_id) - { - server.update_capabilities(|capabilities| { - capabilities.document_formatting_provider = None; - }); - notify_server_capabilities_updated(&server, cx); - } - })?; - } - _ => log::warn!("unhandled capability unregistration: {unreg:?}"), - } - } + }) + .ok(); Ok(()) } } @@ -3519,6 +3359,16 @@ impl LocalLspStore { Ok(workspace_config) } + + fn language_server_for_id(&self, id: LanguageServerId) -> Option> { + if let Some(LanguageServerState::Running { server, .. }) = self.language_servers.get(&id) { + Some(server.clone()) + } else if let Some((_, server)) = self.supplementary_language_servers.get(&id) { + Some(Arc::clone(server)) + } else { + None + } + } } fn notify_server_capabilities_updated(server: &LanguageServer, cx: &mut Context) { @@ -9434,16 +9284,7 @@ impl LspStore { } pub fn language_server_for_id(&self, id: LanguageServerId) -> Option> { - let local_lsp_store = self.as_local()?; - if let Some(LanguageServerState::Running { server, .. }) = - local_lsp_store.language_servers.get(&id) - { - Some(server.clone()) - } else if let Some((_, server)) = local_lsp_store.supplementary_language_servers.get(&id) { - Some(Arc::clone(server)) - } else { - None - } + self.as_local()?.language_server_for_id(id) } fn on_lsp_progress( @@ -11808,6 +11649,375 @@ impl LspStore { .log_err(); } } + + fn register_server_capabilities( + &mut self, + server_id: LanguageServerId, + params: lsp::RegistrationParams, + cx: &mut Context, + ) -> anyhow::Result<()> { + let server = self + .language_server_for_id(server_id) + .with_context(|| format!("no server {server_id} found"))?; + for reg in params.registrations { + match reg.method.as_str() { + "workspace/didChangeWatchedFiles" => { + if let Some(options) = reg.register_options { + let notify = if let Some(local_lsp_store) = self.as_local_mut() { + let caps = serde_json::from_value(options)?; + local_lsp_store + .on_lsp_did_change_watched_files(server_id, ®.id, caps, cx); + true + } else { + false + }; + if notify { + notify_server_capabilities_updated(&server, cx); + } + } + } + "workspace/didChangeConfiguration" => { + // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. + } + "workspace/symbol" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.workspace_symbol_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "workspace/fileOperations" => { + if let Some(options) = reg.register_options { + let caps = serde_json::from_value(options)?; + server.update_capabilities(|capabilities| { + capabilities + .workspace + .get_or_insert_default() + .file_operations = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "workspace/executeCommand" => { + if let Some(options) = reg.register_options { + let options = serde_json::from_value(options)?; + server.update_capabilities(|capabilities| { + capabilities.execute_command_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/rangeFormatting" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.document_range_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/onTypeFormatting" => { + if let Some(options) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.document_on_type_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/formatting" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.document_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/rename" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.rename_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/inlayHint" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.inlay_hint_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/documentSymbol" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.document_symbol_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/codeAction" => { + if let Some(options) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.code_action_provider = + Some(lsp::CodeActionProviderCapability::Options(options)); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/definition" => { + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.definition_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/completion" => { + if let Some(caps) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.completion_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/hover" => { + if let Some(caps) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.hover_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/signatureHelp" => { + if let Some(caps) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.signature_help_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/didChange" => { + if let Some(sync_kind) = reg + .register_options + .and_then(|opts| opts.get("syncKind").cloned()) + .map(serde_json::from_value::) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.text_document_sync = + Some(lsp::TextDocumentSyncCapability::Kind(sync_kind)); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/codeLens" => { + if let Some(caps) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.code_lens_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/diagnostic" => { + if let Some(caps) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.diagnostic_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + "textDocument/colorProvider" => { + if let Some(caps) = reg + .register_options + .map(serde_json::from_value) + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.color_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } + } + _ => log::warn!("unhandled capability registration: {reg:?}"), + } + } + + Ok(()) + } + + fn unregister_server_capabilities( + &mut self, + server_id: LanguageServerId, + params: lsp::UnregistrationParams, + cx: &mut Context, + ) -> anyhow::Result<()> { + let server = self + .language_server_for_id(server_id) + .with_context(|| format!("no server {server_id} found"))?; + for unreg in params.unregisterations.iter() { + match unreg.method.as_str() { + "workspace/didChangeWatchedFiles" => { + let notify = if let Some(local_lsp_store) = self.as_local_mut() { + local_lsp_store + .on_lsp_unregister_did_change_watched_files(server_id, &unreg.id, cx); + true + } else { + false + }; + if notify { + notify_server_capabilities_updated(&server, cx); + } + } + "workspace/didChangeConfiguration" => { + // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. + } + "workspace/symbol" => { + server.update_capabilities(|capabilities| { + capabilities.workspace_symbol_provider = None + }); + notify_server_capabilities_updated(&server, cx); + } + "workspace/fileOperations" => { + server.update_capabilities(|capabilities| { + capabilities + .workspace + .get_or_insert_with(|| lsp::WorkspaceServerCapabilities { + workspace_folders: None, + file_operations: None, + }) + .file_operations = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "workspace/executeCommand" => { + server.update_capabilities(|capabilities| { + capabilities.execute_command_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/rangeFormatting" => { + server.update_capabilities(|capabilities| { + capabilities.document_range_formatting_provider = None + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/onTypeFormatting" => { + server.update_capabilities(|capabilities| { + capabilities.document_on_type_formatting_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/formatting" => { + server.update_capabilities(|capabilities| { + capabilities.document_formatting_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/rename" => { + server.update_capabilities(|capabilities| capabilities.rename_provider = None); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/codeAction" => { + server.update_capabilities(|capabilities| { + capabilities.code_action_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/definition" => { + server.update_capabilities(|capabilities| { + capabilities.definition_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/completion" => { + server.update_capabilities(|capabilities| { + capabilities.completion_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/hover" => { + server.update_capabilities(|capabilities| { + capabilities.hover_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/signatureHelp" => { + server.update_capabilities(|capabilities| { + capabilities.signature_help_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/didChange" => { + server.update_capabilities(|capabilities| { + capabilities.text_document_sync = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/codeLens" => { + server.update_capabilities(|capabilities| { + capabilities.code_lens_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/diagnostic" => { + server.update_capabilities(|capabilities| { + capabilities.diagnostic_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + "textDocument/colorProvider" => { + server.update_capabilities(|capabilities| { + capabilities.color_provider = None; + }); + notify_server_capabilities_updated(&server, cx); + } + _ => log::warn!("unhandled capability unregistration: {unreg:?}"), + } + } + + Ok(()) + } +} + +// Registration with empty capabilities should be ignored. +// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/formatting.ts#L67-L70 +fn parse_register_capabilities( + reg: lsp::Registration, +) -> anyhow::Result>> { + Ok(reg + .register_options + .map(|options| serde_json::from_value::(options)) + .transpose()? + .map(OneOf::Right)) } fn subscribe_to_binary_statuses( diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index d543e6bf25..27ab55d53e 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -962,14 +962,19 @@ impl settings::Settings for DisableAiSettings { type FileContent = Option; fn load(sources: SettingsSources, _: &mut App) -> Result { - Ok(Self { - disable_ai: sources - .user - .or(sources.server) - .copied() - .flatten() - .unwrap_or(sources.default.ok_or_else(Self::missing_default)?), - }) + // For security reasons, settings can only make AI restrictions MORE strict, not less. + // (For example, if someone is working on a project that contractually + // requires no AI use, that should override the user's setting which + // permits AI use.) + // This also prevents an attacker from using project or server settings to enable AI when it should be disabled. + let disable_ai = sources + .project + .iter() + .chain(sources.user.iter()) + .chain(sources.server.iter()) + .any(|disabled| **disabled == Some(true)); + + Ok(Self { disable_ai }) } fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} @@ -5508,3 +5513,153 @@ fn provide_inline_values( variables } + +#[cfg(test)] +mod disable_ai_settings_tests { + use super::*; + use gpui::TestAppContext; + use settings::{Settings, SettingsSources}; + + #[gpui::test] + async fn test_disable_ai_settings_security(cx: &mut TestAppContext) { + cx.update(|cx| { + // Test 1: Default is false (AI enabled) + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: None, + release_channel: None, + operating_system: None, + profile: None, + server: None, + project: &[], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!(settings.disable_ai, false, "Default should allow AI"); + + // Test 2: Global true, local false -> still disabled (local cannot re-enable) + let global_true = Some(true); + let local_false = Some(false); + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: Some(&global_true), + release_channel: None, + operating_system: None, + profile: None, + server: None, + project: &[&local_false], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!( + settings.disable_ai, true, + "Local false cannot override global true" + ); + + // Test 3: Global false, local true -> disabled (local can make more restrictive) + let global_false = Some(false); + let local_true = Some(true); + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: Some(&global_false), + release_channel: None, + operating_system: None, + profile: None, + server: None, + project: &[&local_true], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!( + settings.disable_ai, true, + "Local true can override global false" + ); + + // Test 4: Server can only make more restrictive (set to true) + let user_false = Some(false); + let server_true = Some(true); + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: Some(&user_false), + release_channel: None, + operating_system: None, + profile: None, + server: Some(&server_true), + project: &[], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!( + settings.disable_ai, true, + "Server can set to true even if user is false" + ); + + // Test 5: Server false cannot override user true + let user_true = Some(true); + let server_false = Some(false); + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: Some(&user_true), + release_channel: None, + operating_system: None, + profile: None, + server: Some(&server_false), + project: &[], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!( + settings.disable_ai, true, + "Server false cannot override user true" + ); + + // Test 6: Multiple local settings, any true disables AI + let global_false = Some(false); + let local_false3 = Some(false); + let local_true2 = Some(true); + let local_false4 = Some(false); + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: Some(&global_false), + release_channel: None, + operating_system: None, + profile: None, + server: None, + project: &[&local_false3, &local_true2, &local_false4], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!( + settings.disable_ai, true, + "Any local true should disable AI" + ); + + // Test 7: All three sources can independently disable AI + let user_false2 = Some(false); + let server_false2 = Some(false); + let local_true3 = Some(true); + let sources = SettingsSources { + default: &Some(false), + global: None, + extensions: None, + user: Some(&user_false2), + release_channel: None, + operating_system: None, + profile: None, + server: Some(&server_false2), + project: &[&local_true3], + }; + let settings = DisableAiSettings::load(sources, cx).unwrap(); + assert_eq!( + settings.disable_ai, true, + "Local can disable even if user and server are false" + ); + }); + } +} diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 41d8c4b2fd..5ea7b87fbe 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -256,7 +256,7 @@ impl Project { let local_path = if is_ssh_terminal { None } else { path.clone() }; - let mut python_venv_activate_command = None; + let mut python_venv_activate_command = Task::ready(None); let (spawn_task, shell) = match kind { TerminalKind::Shell(_) => { @@ -265,6 +265,7 @@ impl Project { python_venv_directory, &settings.detect_venv, &settings.shell, + cx, ); } @@ -419,9 +420,12 @@ impl Project { }) .detach(); - if let Some(activate_command) = python_venv_activate_command { - this.activate_python_virtual_environment(activate_command, &terminal_handle, cx); - } + this.activate_python_virtual_environment( + python_venv_activate_command, + &terminal_handle, + cx, + ); + terminal_handle }) } @@ -539,12 +543,15 @@ impl Project { venv_base_directory: &Path, venv_settings: &VenvSettings, shell: &Shell, - ) -> Option { - let venv_settings = venv_settings.as_option()?; + cx: &mut App, + ) -> Task> { + let Some(venv_settings) = venv_settings.as_option() else { + return Task::ready(None); + }; let activate_keyword = match venv_settings.activate_script { terminal_settings::ActivateScript::Default => match std::env::consts::OS { "windows" => ".", - _ => "source", + _ => ".", }, terminal_settings::ActivateScript::Nushell => "overlay use", terminal_settings::ActivateScript::PowerShell => ".", @@ -589,30 +596,44 @@ impl Project { .join(activate_script_name) .to_string_lossy() .to_string(); - let quoted = shlex::try_quote(&path).ok()?; - smol::block_on(self.fs.metadata(path.as_ref())) - .ok() - .flatten()?; - Some(format!( - "{} {} ; clear{}", - activate_keyword, quoted, line_ending - )) + let is_valid_path = self.resolve_abs_path(path.as_ref(), cx); + cx.background_spawn(async move { + let quoted = shlex::try_quote(&path).ok()?; + if is_valid_path.await.is_some_and(|meta| meta.is_file()) { + Some(format!( + "{} {} ; clear{}", + activate_keyword, quoted, line_ending + )) + } else { + None + } + }) } else { - Some(format!( + Task::ready(Some(format!( "{activate_keyword} {activate_script_name} {name}; clear{line_ending}", name = venv_settings.venv_name - )) + ))) } } fn activate_python_virtual_environment( &self, - command: String, + command: Task>, terminal_handle: &Entity, cx: &mut App, ) { - terminal_handle.update(cx, |terminal, _| terminal.input(command.into_bytes())); + terminal_handle.update(cx, |_, cx| { + cx.spawn(async move |this, cx| { + if let Some(command) = command.await { + this.update(cx, |this, _| { + this.input(command.into_bytes()); + }) + .ok(); + } + }) + .detach() + }); } pub fn local_terminal_handles(&self) -> &Vec> { diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 967df41e23..4d7f2faf62 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -5351,26 +5351,22 @@ impl Render for ProjectPanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |this, range, window, cx| { - let mut items = - SmallVec::with_capacity(range.end - range.start); - this.iter_visible_entries( - range, - window, - cx, - |entry, _, entries, _, _| { - let (depth, _) = - Self::calculate_depth_and_difference( - entry, entries, - ); - items.push(depth); - }, - ); - items - }, - ) + .with_compute_indents_fn(cx.entity(), |this, range, window, cx| { + let mut items = + SmallVec::with_capacity(range.end - range.start); + this.iter_visible_entries( + range, + window, + cx, + |entry, _, entries, _, _| { + let (depth, _) = Self::calculate_depth_and_difference( + entry, entries, + ); + items.push(depth); + }, + ); + items + }) .on_click(cx.listener( |this, active_indent_guide: &IndentGuideLayout, window, cx| { if window.modifiers().secondary() { @@ -5394,7 +5390,7 @@ impl Render for ProjectPanel { } }, )) - .with_render_fn(cx.entity().clone(), move |this, params, _, cx| { + .with_render_fn(cx.entity(), move |this, params, _, cx| { const LEFT_OFFSET: Pixels = px(14.); const PADDING_Y: Pixels = px(4.); const HITBOX_OVERDRAW: Pixels = px(3.); @@ -5447,7 +5443,7 @@ impl Render for ProjectPanel { }) .when(show_sticky_entries, |list| { let sticky_items = ui::sticky_items( - cx.entity().clone(), + cx.entity(), |this, range, window, cx| { let mut items = SmallVec::with_capacity(range.end - range.start); this.iter_visible_entries( @@ -5474,7 +5470,7 @@ impl Render for ProjectPanel { list.with_decoration(if show_indent_guides { sticky_items.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_render_fn(cx.entity().clone(), move |_, params, _, _| { + .with_render_fn(cx.entity(), move |_, params, _, _| { const LEFT_OFFSET: Pixels = px(14.); let indent_size = params.indent_size; diff --git a/crates/project_panel/src/project_panel_tests.rs b/crates/project_panel/src/project_panel_tests.rs index 6c62c8db93..de3316e357 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -740,9 +740,9 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", + " > [EDITOR: ''] <== selected", " > 3", " > 4", - " > [EDITOR: ''] <== selected", " a-different-filename.tar.gz", " > C", " .dockerignore", @@ -765,10 +765,10 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", - " > 3", - " > 4", " > [PROCESSING: 'new-dir']", - " a-different-filename.tar.gz <== selected", + " > 3 <== selected", + " > 4", + " a-different-filename.tar.gz", " > C", " .dockerignore", ] @@ -782,10 +782,10 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", - " > 3", + " > 3 <== selected", " > 4", " > new-dir", - " a-different-filename.tar.gz <== selected", + " a-different-filename.tar.gz", " > C", " .dockerignore", ] @@ -801,10 +801,10 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", - " > 3", + " > [EDITOR: '3'] <== selected", " > 4", " > new-dir", - " [EDITOR: 'a-different-filename.tar.gz'] <== selected", + " a-different-filename.tar.gz", " > C", " .dockerignore", ] @@ -819,10 +819,10 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", - " > 3", + " > 3 <== selected", " > 4", " > new-dir", - " a-different-filename.tar.gz <== selected", + " a-different-filename.tar.gz", " > C", " .dockerignore", ] @@ -837,12 +837,12 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", - " > 3", + " v 3", + " [EDITOR: ''] <== selected", + " Q", " > 4", " > new-dir", - " [EDITOR: ''] <== selected", " a-different-filename.tar.gz", - " > C", ] ); panel.update_in(cx, |panel, window, cx| { @@ -863,12 +863,12 @@ async fn test_editing_files(cx: &mut gpui::TestAppContext) { " > .git", " > a", " v b", - " > 3", + " v 3 <== selected", + " Q", " > 4", " > new-dir", - " a-different-filename.tar.gz <== selected", + " a-different-filename.tar.gz", " > C", - " .dockerignore", ] ); } diff --git a/crates/prompt_store/src/prompt_store.rs b/crates/prompt_store/src/prompt_store.rs index f9cb26ed9a..06a65b97cd 100644 --- a/crates/prompt_store/src/prompt_store.rs +++ b/crates/prompt_store/src/prompt_store.rs @@ -90,6 +90,15 @@ impl From for UserPromptId { } } +impl std::fmt::Display for PromptId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PromptId::User { uuid } => write!(f, "{}", uuid.0), + PromptId::EditWorkflow => write!(f, "Edit workflow"), + } + } +} + pub struct PromptStore { env: heed::Env, metadata_cache: RwLock, diff --git a/crates/proto/proto/ai.proto b/crates/proto/proto/ai.proto index 67c2224387..1064ed2f8d 100644 --- a/crates/proto/proto/ai.proto +++ b/crates/proto/proto/ai.proto @@ -158,14 +158,6 @@ message SynchronizeContextsResponse { repeated ContextVersion contexts = 1; } -message GetLlmToken {} - -message GetLlmTokenResponse { - string token = 1; -} - -message RefreshLlmToken {} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 353f19adb2..66f8da44f2 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -6,62 +6,6 @@ message UpdateInviteInfo { uint32 count = 2; } -message GetPrivateUserInfo {} - -message GetPrivateUserInfoResponse { - string metrics_id = 1; - bool staff = 2; - repeated string flags = 3; - optional uint64 accepted_tos_at = 4; -} - -enum Plan { - Free = 0; - ZedPro = 1; - ZedProTrial = 2; -} - -message UpdateUserPlan { - Plan plan = 1; - optional uint64 trial_started_at = 2; - optional bool is_usage_based_billing_enabled = 3; - optional SubscriptionUsage usage = 4; - optional SubscriptionPeriod subscription_period = 5; - optional bool account_too_young = 6; - optional bool has_overdue_invoices = 7; -} - -message SubscriptionPeriod { - uint64 started_at = 1; - uint64 ended_at = 2; -} - -message SubscriptionUsage { - uint32 model_requests_usage_amount = 1; - UsageLimit model_requests_usage_limit = 2; - uint32 edit_predictions_usage_amount = 3; - UsageLimit edit_predictions_usage_limit = 4; -} - -message UsageLimit { - oneof variant { - Limited limited = 1; - Unlimited unlimited = 2; - } - - message Limited { - uint32 limit = 1; - } - - message Unlimited {} -} - -message AcceptTermsOfService {} - -message AcceptTermsOfServiceResponse { - uint64 accepted_tos_at = 1; -} - message ShutdownRemoteServer {} message Toast { @@ -84,11 +28,13 @@ message GetCrashFiles { message GetCrashFilesResponse { repeated CrashReport crashes = 1; + repeated string legacy_panics = 2; } message CrashReport { - optional string panic_contents = 1; - optional bytes minidump_contents = 2; + reserved 1, 2; + string metadata = 3; + bytes minidump_contents = 4; } message Extension { diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index c32da9b110..f2c388a3a3 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -202,6 +202,16 @@ message GitInit { string fallback_branch_name = 3; } +message GitClone { + uint64 project_id = 1; + string abs_path = 2; + string remote_repo = 3; +} + +message GitCloneResponse { + bool success = 1; +} + message CheckForPushedCommits { uint64 project_id = 1; reserved 2; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index bb97bd500a..310fcf584e 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -135,12 +135,7 @@ message Envelope { FollowResponse follow_response = 99; UpdateFollowers update_followers = 100; Unfollow unfollow = 101; - GetPrivateUserInfo get_private_user_info = 102; - GetPrivateUserInfoResponse get_private_user_info_response = 103; - UpdateUserPlan update_user_plan = 234; UpdateDiffBases update_diff_bases = 104; - AcceptTermsOfService accept_terms_of_service = 239; - AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; OnTypeFormatting on_type_formatting = 105; OnTypeFormattingResponse on_type_formatting_response = 106; @@ -250,10 +245,6 @@ message Envelope { AddWorktree add_worktree = 222; AddWorktreeResponse add_worktree_response = 223; - GetLlmToken get_llm_token = 235; - GetLlmTokenResponse get_llm_token_response = 236; - RefreshLlmToken refresh_llm_token = 259; - LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241; LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242; @@ -399,10 +390,14 @@ message Envelope { GetDefaultBranchResponse get_default_branch_response = 360; GetCrashFiles get_crash_files = 361; - GetCrashFilesResponse get_crash_files_response = 362; // current max + GetCrashFilesResponse get_crash_files_response = 362; + + GitClone git_clone = 363; + GitCloneResponse git_clone_response = 364; // current max } reserved 87 to 88; + reserved 102 to 103; reserved 158 to 161; reserved 164; reserved 166 to 169; @@ -416,10 +411,13 @@ message Envelope { reserved 221; reserved 224 to 229; reserved 230 to 231; + reserved 234 to 236; + reserved 239 to 240; reserved 246; - reserved 270; reserved 247 to 254; reserved 255 to 256; + reserved 259; + reserved 270; reserved 280 to 281; reserved 332 to 333; } diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 9edb041b4b..802db09590 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -20,8 +20,6 @@ pub const SSH_PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 }; pub const SSH_PROJECT_ID: u64 = 0; messages!( - (AcceptTermsOfService, Foreground), - (AcceptTermsOfServiceResponse, Foreground), (Ack, Foreground), (AckBufferOperation, Background), (AckChannelMessage, Background), @@ -105,8 +103,6 @@ messages!( (GetPathMetadataResponse, Background), (GetPermalinkToLine, Foreground), (GetPermalinkToLineResponse, Foreground), - (GetPrivateUserInfo, Foreground), - (GetPrivateUserInfoResponse, Foreground), (GetProjectSymbols, Background), (GetProjectSymbolsResponse, Background), (GetReferences, Background), @@ -119,8 +115,6 @@ messages!( (GetTypeDefinitionResponse, Background), (GetImplementation, Background), (GetImplementationResponse, Background), - (GetLlmToken, Background), - (GetLlmTokenResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -196,7 +190,6 @@ messages!( (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), (RefreshInlayHints, Foreground), - (RefreshLlmToken, Background), (RegisterBufferWithLanguageServers, Background), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -280,7 +273,6 @@ messages!( (UpdateProject, Foreground), (UpdateProjectCollaborator, Foreground), (UpdateUserChannels, Foreground), - (UpdateUserPlan, Foreground), (UpdateWorktree, Foreground), (UpdateWorktreeSettings, Foreground), (UpdateRepository, Foreground), @@ -316,10 +308,11 @@ messages!( (PullWorkspaceDiagnostics, Background), (GetDefaultBranch, Background), (GetDefaultBranchResponse, Background), + (GitClone, Background), + (GitCloneResponse, Background) ); request_messages!( - (AcceptTermsOfService, AcceptTermsOfServiceResponse), (ApplyCodeAction, ApplyCodeActionResponse), ( ApplyCompletionAdditionalEdits, @@ -352,9 +345,7 @@ request_messages!( (GetDocumentHighlights, GetDocumentHighlightsResponse), (GetDocumentSymbols, GetDocumentSymbolsResponse), (GetHover, GetHoverResponse), - (GetLlmToken, GetLlmTokenResponse), (GetNotifications, GetNotificationsResponse), - (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), (GetReferences, GetReferencesResponse), (GetSignatureHelp, GetSignatureHelpResponse), @@ -484,6 +475,7 @@ request_messages!( (GetDocumentDiagnostics, GetDocumentDiagnosticsResponse), (PullWorkspaceDiagnostics, Ack), (GetDefaultBranch, GetDefaultBranchResponse), + (GitClone, GitCloneResponse) ); entity_messages!( @@ -615,7 +607,8 @@ entity_messages!( LogToDebugConsole, GetDocumentDiagnostics, PullWorkspaceDiagnostics, - GetDefaultBranch + GetDefaultBranch, + GitClone ); entity_messages!( diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 354434a7fc..e5e166cb4c 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -1292,7 +1292,7 @@ impl RemoteServerProjects { let connection_string = connection_string.clone(); move |_, _: &menu::Confirm, window, cx| { remove_ssh_server( - cx.entity().clone(), + cx.entity(), server_index, connection_string.clone(), window, @@ -1312,7 +1312,7 @@ impl RemoteServerProjects { .child(Label::new("Remove Server").color(Color::Error)) .on_click(cx.listener(move |_, _, window, cx| { remove_ssh_server( - cx.entity().clone(), + cx.entity(), server_index, connection_string.clone(), window, diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 4306251e44..ea383ac264 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -400,6 +400,7 @@ impl SshSocket { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) + .args(self.connection_options.additional_args()) .args(["-o", "ControlMaster=no", "-o"]) .arg(format!("ControlPath={}", self.socket_path.display())) } @@ -410,6 +411,7 @@ impl SshSocket { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) + .args(self.connection_options.additional_args()) .envs(self.envs.clone()) } @@ -417,22 +419,26 @@ impl SshSocket { // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to #[cfg(not(target_os = "windows"))] fn ssh_args(&self) -> SshArgs { + let mut arguments = self.connection_options.additional_args(); + arguments.extend(vec![ + "-o".to_string(), + "ControlMaster=no".to_string(), + "-o".to_string(), + format!("ControlPath={}", self.socket_path.display()), + self.connection_options.ssh_url(), + ]); SshArgs { - arguments: vec![ - "-o".to_string(), - "ControlMaster=no".to_string(), - "-o".to_string(), - format!("ControlPath={}", self.socket_path.display()), - self.connection_options.ssh_url(), - ], + arguments, envs: None, } } #[cfg(target_os = "windows")] fn ssh_args(&self) -> SshArgs { + let mut arguments = self.connection_options.additional_args(); + arguments.push(self.connection_options.ssh_url()); SshArgs { - arguments: vec![self.connection_options.ssh_url()], + arguments, envs: Some(self.envs.clone()), } } @@ -1484,20 +1490,17 @@ impl RemoteConnection for SshRemoteConnection { identifier = &unique_identifier, ); - if let Some(rust_log) = std::env::var("RUST_LOG").ok() { - start_proxy_command = format!( - "RUST_LOG={} {}", - shlex::try_quote(&rust_log).unwrap(), - start_proxy_command - ) - } - if let Some(rust_backtrace) = std::env::var("RUST_BACKTRACE").ok() { - start_proxy_command = format!( - "RUST_BACKTRACE={} {}", - shlex::try_quote(&rust_backtrace).unwrap(), - start_proxy_command - ) + for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] { + if let Some(value) = std::env::var(env_var).ok() { + start_proxy_command = format!( + "{}={} {} ", + env_var, + shlex::try_quote(&value).unwrap(), + start_proxy_command, + ); + } } + if reconnect { start_proxy_command.push_str(" --reconnect"); } @@ -2069,11 +2072,17 @@ impl SshRemoteConnection { Ok(()) } + let use_musl = !build_remote_server.contains("nomusl"); let triple = format!( "{}-{}", self.ssh_platform.arch, match self.ssh_platform.os { - "linux" => "unknown-linux-musl", + "linux" => + if use_musl { + "unknown-linux-musl" + } else { + "unknown-linux-gnu" + }, "macos" => "apple-darwin", _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform), } @@ -2086,7 +2095,7 @@ impl SshRemoteConnection { String::new() } }; - if self.ssh_platform.os == "linux" { + if self.ssh_platform.os == "linux" && use_musl { rust_flags.push_str(" -C target-feature=+crt-static"); } if build_remote_server.contains("mold") { @@ -2229,8 +2238,7 @@ impl SshRemoteConnection { #[cfg(not(target_os = "windows"))] { - run_cmd(Command::new("gzip").args(["-9", "-f", &bin_path.to_string_lossy()])) - .await?; + run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?; } #[cfg(target_os = "windows")] { @@ -2462,7 +2470,7 @@ impl ChannelClient { }, async { smol::Timer::after(timeout).await; - anyhow::bail!("Timeout detected") + anyhow::bail!("Timed out resyncing remote client") }, ) .await @@ -2476,7 +2484,7 @@ impl ChannelClient { }, async { smol::Timer::after(timeout).await; - anyhow::bail!("Timeout detected") + anyhow::bail!("Timed out pinging remote client") }, ) .await diff --git a/crates/remote_server/Cargo.toml b/crates/remote_server/Cargo.toml index c6a546f345..dcec9f6fe0 100644 --- a/crates/remote_server/Cargo.toml +++ b/crates/remote_server/Cargo.toml @@ -74,6 +74,7 @@ libc.workspace = true minidumper.workspace = true [dev-dependencies] +action_log.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true client = { workspace = true, features = ["test-support"] } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 9730984f26..514e5ce4c0 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -1724,7 +1724,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu .await .unwrap(); - let action_log = cx.new(|_| assistant_tool::ActionLog::new(project.clone())); + let action_log = cx.new(|_| action_log::ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); let request = Arc::new(LanguageModelRequest::default()); diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 9bb5645dc7..dc7fab8c3c 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -34,10 +34,10 @@ use smol::io::AsyncReadExt; use smol::Async; use smol::{net::unix::UnixListener, stream::StreamExt as _}; -use std::collections::HashMap; use std::ffi::OsStr; use std::ops::ControlFlow; use std::str::FromStr; +use std::sync::LazyLock; use std::{env, thread}; use std::{ io::Write, @@ -48,6 +48,13 @@ use std::{ use telemetry_events::LocationData; use util::ResultExt; +pub static VERSION: LazyLock<&str> = LazyLock::new(|| match *RELEASE_CHANNEL { + ReleaseChannel::Stable | ReleaseChannel::Preview => env!("ZED_PKG_VERSION"), + ReleaseChannel::Nightly | ReleaseChannel::Dev => { + option_env!("ZED_COMMIT_SHA").unwrap_or("missing-zed-commit-sha") + } +}); + fn init_logging_proxy() { env_logger::builder() .format(|buf, record| { @@ -113,7 +120,6 @@ fn init_logging_server(log_file_path: PathBuf) -> Result>> { fn init_panic_hook(session_id: String) { std::panic::set_hook(Box::new(move |info| { - crashes::handle_panic(); let payload = info .payload() .downcast_ref::<&str>() @@ -121,6 +127,8 @@ fn init_panic_hook(session_id: String) { .or_else(|| info.payload().downcast_ref::().cloned()) .unwrap_or_else(|| "Box".to_string()); + crashes::handle_panic(payload.clone(), info.location()); + let backtrace = backtrace::Backtrace::new(); let mut backtrace = backtrace .frames() @@ -150,14 +158,6 @@ fn init_panic_hook(session_id: String) { (&backtrace).join("\n") ); - let release_channel = *RELEASE_CHANNEL; - let version = match release_channel { - ReleaseChannel::Stable | ReleaseChannel::Preview => env!("ZED_PKG_VERSION"), - ReleaseChannel::Nightly | ReleaseChannel::Dev => { - option_env!("ZED_COMMIT_SHA").unwrap_or("missing-zed-commit-sha") - } - }; - let panic_data = telemetry_events::Panic { thread: thread_name.into(), payload: payload.clone(), @@ -165,9 +165,9 @@ fn init_panic_hook(session_id: String) { file: location.file().into(), line: location.line(), }), - app_version: format!("remote-server-{version}"), + app_version: format!("remote-server-{}", *VERSION), app_commit_sha: option_env!("ZED_COMMIT_SHA").map(|sha| sha.into()), - release_channel: release_channel.dev_name().into(), + release_channel: RELEASE_CHANNEL.dev_name().into(), target: env!("TARGET").to_owned().into(), os_name: telemetry::os_name(), os_version: Some(telemetry::os_version()), @@ -204,8 +204,8 @@ fn handle_crash_files_requests(project: &Entity, client: &Arc, _cx| async move { + let mut legacy_panics = Vec::new(); let mut crashes = Vec::new(); - let mut minidumps_by_session_id = HashMap::new(); let mut children = smol::fs::read_dir(paths::logs_dir()).await?; while let Some(child) = children.next().await { let child = child?; @@ -227,41 +227,31 @@ fn handle_crash_files_requests(project: &Entity, client: &Arc Result<()> { let server_paths = ServerPaths::new(&identifier)?; let id = std::process::id().to_string(); - smol::spawn(crashes::init(id.clone())).detach(); + smol::spawn(crashes::init(crashes::InitCrashHandler { + session_id: id.clone(), + zed_version: VERSION.to_owned(), + release_channel: release_channel::RELEASE_CHANNEL_NAME.clone(), + commit_sha: option_env!("ZED_COMMIT_SHA").unwrap_or("no_sha").to_owned(), + })) + .detach(); init_panic_hook(id); log::info!("starting proxy process. PID: {}", std::process::id()); diff --git a/crates/repl/src/notebook/notebook_ui.rs b/crates/repl/src/notebook/notebook_ui.rs index b53809dff0..36a0af30d0 100644 --- a/crates/repl/src/notebook/notebook_ui.rs +++ b/crates/repl/src/notebook/notebook_ui.rs @@ -295,7 +295,7 @@ impl NotebookEditor { _cx: &mut Context, ) -> IconButton { let id: ElementId = ElementId::Name(id.into()); - IconButton::new(id, icon).width(px(CONTROL_SIZE).into()) + IconButton::new(id, icon).width(px(CONTROL_SIZE)) } fn render_notebook_controls( diff --git a/crates/repl/src/session.rs b/crates/repl/src/session.rs index 729a616135..f945e5ed9f 100644 --- a/crates/repl/src/session.rs +++ b/crates/repl/src/session.rs @@ -244,7 +244,7 @@ impl Session { repl_session_id = cx.entity_id().to_string(), ); - let session_view = cx.entity().clone(); + let session_view = cx.entity(); let kernel = match self.kernel_specification.clone() { KernelSpecification::Jupyter(kernel_specification) diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index 14703be7a2..189f48e6b6 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -1,22 +1,23 @@ mod registrar; use crate::{ - FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOptions, - SelectAllMatches, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive, ToggleRegex, - ToggleReplace, ToggleSelection, ToggleWholeWord, search_bar::render_nav_button, + FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOption, + SearchOptions, SearchSource, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, + ToggleCaseSensitive, ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord, + search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input}, }; use any_vec::AnyVec; use anyhow::Context as _; use collections::HashMap; use editor::{ - DisplayPoint, Editor, EditorElement, EditorSettings, EditorStyle, + DisplayPoint, Editor, EditorSettings, actions::{Backtab, Tab}, }; use futures::channel::oneshot; use gpui::{ - Action, App, ClickEvent, Context, Entity, EventEmitter, FocusHandle, Focusable, - InteractiveElement as _, IntoElement, KeyContext, ParentElement as _, Render, ScrollHandle, - Styled, Subscription, Task, TextStyle, Window, actions, div, + Action, App, ClickEvent, Context, Entity, EventEmitter, Focusable, InteractiveElement as _, + IntoElement, KeyContext, ParentElement as _, Render, ScrollHandle, Styled, Subscription, Task, + Window, actions, div, }; use language::{Language, LanguageRegistry}; use project::{ @@ -27,7 +28,6 @@ use schemars::JsonSchema; use serde::Deserialize; use settings::Settings; use std::sync::Arc; -use theme::ThemeSettings; use zed_actions::outline::ToggleOutline; use ui::{ @@ -125,46 +125,6 @@ pub struct BufferSearchBar { } impl BufferSearchBar { - fn render_text_input( - &self, - editor: &Entity, - color_override: Option, - cx: &mut Context, - ) -> impl IntoElement { - let (color, use_syntax) = if editor.read(cx).read_only(cx) { - (cx.theme().colors().text_disabled, false) - } else { - match color_override { - Some(color_override) => (color_override.color(cx), false), - None => (cx.theme().colors().text, true), - } - }; - - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color, - font_family: settings.buffer_font.family.clone(), - font_features: settings.buffer_font.features.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_size: rems(0.875).into(), - font_weight: settings.buffer_font.weight, - line_height: relative(1.3), - ..TextStyle::default() - }; - - let mut editor_style = EditorStyle { - background: cx.theme().colors().toolbar_background, - local_player: cx.theme().players().local(), - text: text_style, - ..EditorStyle::default() - }; - if use_syntax { - editor_style.syntax = cx.theme().syntax().clone(); - } - - EditorElement::new(editor, editor_style) - } - pub fn query_editor_focused(&self) -> bool { self.query_editor_focused } @@ -185,7 +145,14 @@ impl Render for BufferSearchBar { let hide_inline_icons = self.editor_needed_width > self.editor_scroll_handle.bounds().size.width - window.rem_size() * 6.; - let supported_options = self.supported_options(cx); + let workspace::searchable::SearchOptions { + case, + word, + regex, + replacement, + selection, + find_in_results, + } = self.supported_options(cx); if self.query_editor.update(cx, |query_editor, _cx| { query_editor.placeholder_text().is_none() @@ -220,268 +187,203 @@ impl Render for BufferSearchBar { } }) .unwrap_or_else(|| "0/0".to_string()); - let should_show_replace_input = self.replace_enabled && supported_options.replacement; + let should_show_replace_input = self.replace_enabled && replacement; let in_replace = self.replacement_editor.focus_handle(cx).is_focused(window); + let theme_colors = cx.theme().colors(); + let query_border = if self.query_error.is_some() { + Color::Error.color(cx) + } else { + theme_colors.border + }; + let replacement_border = theme_colors.border; + + let container_width = window.viewport_size().width; + let input_width = SearchInputWidth::calc_width(container_width); + + let input_base_styles = + |border_color| input_base_styles(border_color, |div| div.w(input_width)); + + let query_column = input_base_styles(query_border) + .id("editor-scroll") + .track_scroll(&self.editor_scroll_handle) + .child(render_text_input(&self.query_editor, color_override, cx)) + .when(!hide_inline_icons, |div| { + div.child( + h_flex() + .gap_1() + .when(case, |div| { + div.child(SearchOption::CaseSensitive.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) + }) + .when(word, |div| { + div.child(SearchOption::WholeWord.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) + }) + .when(regex, |div| { + div.child(SearchOption::Regex.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) + }), + ) + }); + + let mode_column = h_flex() + .gap_1() + .min_w_64() + .when(replacement, |this| { + this.child(render_action_button( + "buffer-search-bar-toggle", + IconName::Replace, + self.replace_enabled.then_some(ActionButtonState::Toggled), + "Toggle Replace", + &ToggleReplace, + focus_handle.clone(), + )) + }) + .when(selection, |this| { + this.child( + IconButton::new( + "buffer-search-bar-toggle-search-selection-button", + IconName::Quote, + ) + .style(ButtonStyle::Subtle) + .shape(IconButtonShape::Square) + .when(self.selection_search_enabled, |button| { + button.style(ButtonStyle::Filled) + }) + .on_click(cx.listener(|this, _: &ClickEvent, window, cx| { + this.toggle_selection(&ToggleSelection, window, cx); + })) + .toggle_state(self.selection_search_enabled) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Toggle Search Selection", + &ToggleSelection, + &focus_handle, + window, + cx, + ) + } + }), + ) + }) + .when(!find_in_results, |el| { + let query_focus = self.query_editor.focus_handle(cx); + let matches_column = h_flex() + .pl_2() + .ml_2() + .border_l_1() + .border_color(theme_colors.border_variant) + .child(render_action_button( + "buffer-search-nav-button", + ui::IconName::ChevronLeft, + self.active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), + "Select Previous Match", + &SelectPreviousMatch, + query_focus.clone(), + )) + .child(render_action_button( + "buffer-search-nav-button", + ui::IconName::ChevronRight, + self.active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), + "Select Next Match", + &SelectNextMatch, + query_focus.clone(), + )) + .when(!narrow_mode, |this| { + this.child(div().ml_2().min_w(rems_from_px(40.)).child( + Label::new(match_text).size(LabelSize::Small).color( + if self.active_match_index.is_some() { + Color::Default + } else { + Color::Disabled + }, + ), + )) + }); + + el.child(render_action_button( + "buffer-search-nav-button", + IconName::SelectAll, + Default::default(), + "Select All Matches", + &SelectAllMatches, + query_focus, + )) + .child(matches_column) + }) + .when(find_in_results, |el| { + el.child(render_action_button( + "buffer-search", + IconName::Close, + Default::default(), + "Close Search Bar", + &Dismiss, + focus_handle.clone(), + )) + }); + + let search_line = h_flex() + .w_full() + .gap_2() + .when(find_in_results, |el| { + el.child(Label::new("Find in results").color(Color::Hint)) + }) + .child(query_column) + .child(mode_column); + + let replace_line = + should_show_replace_input.then(|| { + let replace_column = input_base_styles(replacement_border) + .child(render_text_input(&self.replacement_editor, None, cx)); + let focus_handle = self.replacement_editor.read(cx).focus_handle(cx); + + let replace_actions = h_flex() + .min_w_64() + .gap_1() + .child(render_action_button( + "buffer-search-replace-button", + IconName::ReplaceNext, + Default::default(), + "Replace Next Match", + &ReplaceNext, + focus_handle.clone(), + )) + .child(render_action_button( + "buffer-search-replace-button", + IconName::ReplaceAll, + Default::default(), + "Replace All Matches", + &ReplaceAll, + focus_handle, + )); + h_flex() + .w_full() + .gap_2() + .child(replace_column) + .child(replace_actions) + }); + let mut key_context = KeyContext::new_with_defaults(); key_context.add("BufferSearchBar"); if in_replace { key_context.add("in_replace"); } - let query_border = if self.query_error.is_some() { - Color::Error.color(cx) - } else { - cx.theme().colors().border - }; - let replacement_border = cx.theme().colors().border; - - let container_width = window.viewport_size().width; - let input_width = SearchInputWidth::calc_width(container_width); - - let input_base_styles = |border_color| { - h_flex() - .min_w_32() - .w(input_width) - .h_8() - .pl_2() - .pr_1() - .py_1() - .border_1() - .border_color(border_color) - .rounded_lg() - }; - - let search_line = h_flex() - .gap_2() - .when(supported_options.find_in_results, |el| { - el.child(Label::new("Find in results").color(Color::Hint)) - }) - .child( - input_base_styles(query_border) - .id("editor-scroll") - .track_scroll(&self.editor_scroll_handle) - .child(self.render_text_input(&self.query_editor, color_override, cx)) - .when(!hide_inline_icons, |div| { - div.child( - h_flex() - .gap_1() - .children(supported_options.case.then(|| { - self.render_search_option_button( - SearchOptions::CASE_SENSITIVE, - focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_case_sensitive( - &ToggleCaseSensitive, - window, - cx, - ) - }), - ) - })) - .children(supported_options.word.then(|| { - self.render_search_option_button( - SearchOptions::WHOLE_WORD, - focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_whole_word(&ToggleWholeWord, window, cx) - }), - ) - })) - .children(supported_options.regex.then(|| { - self.render_search_option_button( - SearchOptions::REGEX, - focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_regex(&ToggleRegex, window, cx) - }), - ) - })), - ) - }), - ) - .child( - h_flex() - .gap_1() - .min_w_64() - .when(supported_options.replacement, |this| { - this.child( - IconButton::new( - "buffer-search-bar-toggle-replace-button", - IconName::Replace, - ) - .style(ButtonStyle::Subtle) - .shape(IconButtonShape::Square) - .when(self.replace_enabled, |button| { - button.style(ButtonStyle::Filled) - }) - .on_click(cx.listener(|this, _: &ClickEvent, window, cx| { - this.toggle_replace(&ToggleReplace, window, cx); - })) - .toggle_state(self.replace_enabled) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Toggle Replace", - &ToggleReplace, - &focus_handle, - window, - cx, - ) - } - }), - ) - }) - .when(supported_options.selection, |this| { - this.child( - IconButton::new( - "buffer-search-bar-toggle-search-selection-button", - IconName::Quote, - ) - .style(ButtonStyle::Subtle) - .shape(IconButtonShape::Square) - .when(self.selection_search_enabled, |button| { - button.style(ButtonStyle::Filled) - }) - .on_click(cx.listener(|this, _: &ClickEvent, window, cx| { - this.toggle_selection(&ToggleSelection, window, cx); - })) - .toggle_state(self.selection_search_enabled) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Toggle Search Selection", - &ToggleSelection, - &focus_handle, - window, - cx, - ) - } - }), - ) - }) - .when(!supported_options.find_in_results, |el| { - el.child( - IconButton::new("select-all", ui::IconName::SelectAll) - .on_click(|_, window, cx| { - window.dispatch_action(SelectAllMatches.boxed_clone(), cx) - }) - .shape(IconButtonShape::Square) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Select All Matches", - &SelectAllMatches, - &focus_handle, - window, - cx, - ) - } - }), - ) - .child( - h_flex() - .pl_2() - .ml_1() - .border_l_1() - .border_color(cx.theme().colors().border_variant) - .child(render_nav_button( - ui::IconName::ChevronLeft, - self.active_match_index.is_some(), - "Select Previous Match", - &SelectPreviousMatch, - focus_handle.clone(), - )) - .child(render_nav_button( - ui::IconName::ChevronRight, - self.active_match_index.is_some(), - "Select Next Match", - &SelectNextMatch, - focus_handle.clone(), - )), - ) - .when(!narrow_mode, |this| { - this.child(h_flex().ml_2().min_w(rems_from_px(40.)).child( - Label::new(match_text).size(LabelSize::Small).color( - if self.active_match_index.is_some() { - Color::Default - } else { - Color::Disabled - }, - ), - )) - }) - }) - .when(supported_options.find_in_results, |el| { - el.child( - IconButton::new(SharedString::from("Close"), IconName::Close) - .shape(IconButtonShape::Square) - .tooltip(move |window, cx| { - Tooltip::for_action("Close Search Bar", &Dismiss, window, cx) - }) - .on_click(cx.listener(|this, _: &ClickEvent, window, cx| { - this.dismiss(&Dismiss, window, cx) - })), - ) - }), - ); - - let replace_line = should_show_replace_input.then(|| { - h_flex() - .gap_2() - .child( - input_base_styles(replacement_border).child(self.render_text_input( - &self.replacement_editor, - None, - cx, - )), - ) - .child( - h_flex() - .min_w_64() - .gap_1() - .child( - IconButton::new("search-replace-next", ui::IconName::ReplaceNext) - .shape(IconButtonShape::Square) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Replace Next Match", - &ReplaceNext, - &focus_handle, - window, - cx, - ) - } - }) - .on_click(cx.listener(|this, _, window, cx| { - this.replace_next(&ReplaceNext, window, cx) - })), - ) - .child( - IconButton::new("search-replace-all", ui::IconName::ReplaceAll) - .shape(IconButtonShape::Square) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Replace All Matches", - &ReplaceAll, - &focus_handle, - window, - cx, - ) - } - }) - .on_click(cx.listener(|this, _, window, cx| { - this.replace_all(&ReplaceAll, window, cx) - })), - ), - ) - }); let query_error_line = self.query_error.as_ref().map(|error| { Label::new(error) @@ -491,10 +393,26 @@ impl Render for BufferSearchBar { .ml_2() }); + let search_line = + h_flex() + .relative() + .child(search_line) + .when(!narrow_mode && !find_in_results, |div| { + div.child(h_flex().absolute().right_0().child(render_action_button( + "buffer-search", + IconName::Close, + Default::default(), + "Close Search Bar", + &Dismiss, + focus_handle.clone(), + ))) + .w_full() + }); v_flex() .id("buffer_search") .gap_2() .py(px(1.0)) + .w_full() .track_scroll(&self.scroll_handle) .key_context(key_context) .capture_action(cx.listener(Self::tab)) @@ -509,43 +427,26 @@ impl Render for BufferSearchBar { active_searchable_item.relay_action(Box::new(ToggleOutline), window, cx); } })) - .when(self.supported_options(cx).replacement, |this| { + .when(replacement, |this| { this.on_action(cx.listener(Self::toggle_replace)) .when(in_replace, |this| { this.on_action(cx.listener(Self::replace_next)) .on_action(cx.listener(Self::replace_all)) }) }) - .when(self.supported_options(cx).case, |this| { + .when(case, |this| { this.on_action(cx.listener(Self::toggle_case_sensitive)) }) - .when(self.supported_options(cx).word, |this| { + .when(word, |this| { this.on_action(cx.listener(Self::toggle_whole_word)) }) - .when(self.supported_options(cx).regex, |this| { + .when(regex, |this| { this.on_action(cx.listener(Self::toggle_regex)) }) - .when(self.supported_options(cx).selection, |this| { + .when(selection, |this| { this.on_action(cx.listener(Self::toggle_selection)) }) - .child(h_flex().relative().child(search_line.w_full()).when( - !narrow_mode && !supported_options.find_in_results, - |div| { - div.child( - h_flex().absolute().right_0().child( - IconButton::new(SharedString::from("Close"), IconName::Close) - .shape(IconButtonShape::Square) - .tooltip(move |window, cx| { - Tooltip::for_action("Close Search Bar", &Dismiss, window, cx) - }) - .on_click(cx.listener(|this, _: &ClickEvent, window, cx| { - this.dismiss(&Dismiss, window, cx) - })), - ), - ) - .w_full() - }, - )) + .child(search_line) .children(query_error_line) .children(replace_line) } @@ -792,7 +693,7 @@ impl BufferSearchBar { active_editor.search_bar_visibility_changed(false, window, cx); active_editor.toggle_filtered_search_ranges(false, window, cx); let handle = active_editor.item_focus_handle(cx); - self.focus(&handle, window, cx); + self.focus(&handle, window); } cx.emit(Event::UpdateLocation); cx.emit(ToolbarItemEvent::ChangeLocation( @@ -948,7 +849,7 @@ impl BufferSearchBar { } pub fn focus_replace(&mut self, window: &mut Window, cx: &mut Context) { - self.focus(&self.replacement_editor.focus_handle(cx), window, cx); + self.focus(&self.replacement_editor.focus_handle(cx), window); cx.notify(); } @@ -975,16 +876,6 @@ impl BufferSearchBar { self.update_matches(!updated, window, cx) } - fn render_search_option_button( - &self, - option: SearchOptions, - focus_handle: FocusHandle, - action: Action, - ) -> impl IntoElement + use { - let is_active = self.search_options.contains(option); - option.as_button(is_active, focus_handle, action) - } - pub fn focus_editor(&mut self, _: &FocusEditor, window: &mut Window, cx: &mut Context) { if let Some(active_editor) = self.active_searchable_item.as_ref() { let handle = active_editor.item_focus_handle(cx); @@ -1400,28 +1291,32 @@ impl BufferSearchBar { } fn tab(&mut self, _: &Tab, window: &mut Window, cx: &mut Context) { - // Search -> Replace -> Editor - let focus_handle = if self.replace_enabled && self.query_editor_focused { - self.replacement_editor.focus_handle(cx) - } else if let Some(item) = self.active_searchable_item.as_ref() { - item.item_focus_handle(cx) - } else { - return; - }; - self.focus(&focus_handle, window, cx); - cx.stop_propagation(); + self.cycle_field(Direction::Next, window, cx); } fn backtab(&mut self, _: &Backtab, window: &mut Window, cx: &mut Context) { - // Search -> Replace -> Search - let focus_handle = if self.replace_enabled && self.query_editor_focused { - self.replacement_editor.focus_handle(cx) - } else if self.replacement_editor_focused { - self.query_editor.focus_handle(cx) - } else { - return; + self.cycle_field(Direction::Prev, window, cx); + } + fn cycle_field(&mut self, direction: Direction, window: &mut Window, cx: &mut Context) { + let mut handles = vec![self.query_editor.focus_handle(cx)]; + if self.replace_enabled { + handles.push(self.replacement_editor.focus_handle(cx)); + } + if let Some(item) = self.active_searchable_item.as_ref() { + handles.push(item.item_focus_handle(cx)); + } + let current_index = match handles.iter().position(|focus| focus.is_focused(window)) { + Some(index) => index, + None => return, }; - self.focus(&focus_handle, window, cx); + + let new_index = match direction { + Direction::Next => (current_index + 1) % handles.len(), + Direction::Prev if current_index == 0 => handles.len() - 1, + Direction::Prev => (current_index - 1) % handles.len(), + }; + let next_focus_handle = &handles[new_index]; + self.focus(next_focus_handle, window); cx.stop_propagation(); } @@ -1469,10 +1364,8 @@ impl BufferSearchBar { } } - fn focus(&self, handle: &gpui::FocusHandle, window: &mut Window, cx: &mut Context) { - cx.on_next_frame(window, |_, window, _| { - window.invalidate_character_coordinates(); - }); + fn focus(&self, handle: &gpui::FocusHandle, window: &mut Window) { + window.invalidate_character_coordinates(); window.focus(handle); } @@ -1484,7 +1377,7 @@ impl BufferSearchBar { } else { self.query_editor.focus_handle(cx) }; - self.focus(&handle, window, cx); + self.focus(&handle, window); cx.notify(); } } diff --git a/crates/search/src/mode.rs b/crates/search/src/mode.rs deleted file mode 100644 index 957eb707a5..0000000000 --- a/crates/search/src/mode.rs +++ /dev/null @@ -1,36 +0,0 @@ -use gpui::{Action, SharedString}; - -use crate::{ActivateRegexMode, ActivateTextMode}; - -// TODO: Update the default search mode to get from config -#[derive(Copy, Clone, Debug, Default, PartialEq)] -pub enum SearchMode { - #[default] - Text, - Regex, -} - -impl SearchMode { - pub(crate) fn label(&self) -> &'static str { - match self { - SearchMode::Text => "Text", - SearchMode::Regex => "Regex", - } - } - pub(crate) fn tooltip(&self) -> SharedString { - format!("Activate {} Mode", self.label()).into() - } - pub(crate) fn action(&self) -> Box { - match self { - SearchMode::Text => ActivateTextMode.boxed_clone(), - SearchMode::Regex => ActivateRegexMode.boxed_clone(), - } - } -} - -pub(crate) fn next_mode(mode: &SearchMode) -> SearchMode { - match mode { - SearchMode::Text => SearchMode::Regex, - SearchMode::Regex => SearchMode::Text, - } -} diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 96194cdad2..056c3556ba 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -1,20 +1,23 @@ use crate::{ BufferSearchBar, FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, - SearchOptions, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive, ToggleIncludeIgnored, - ToggleRegex, ToggleReplace, ToggleWholeWord, buffer_search::Deploy, + SearchOption, SearchOptions, SearchSource, SelectNextMatch, SelectPreviousMatch, + ToggleCaseSensitive, ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord, + buffer_search::Deploy, + search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input}, }; use anyhow::Context as _; -use collections::{HashMap, HashSet}; +use collections::HashMap; use editor::{ - Anchor, Editor, EditorElement, EditorEvent, EditorSettings, EditorStyle, MAX_TAB_TITLE_LEN, - MultiBuffer, SelectionEffects, actions::SelectAll, items::active_match_index, + Anchor, Editor, EditorEvent, EditorSettings, MAX_TAB_TITLE_LEN, MultiBuffer, SelectionEffects, + actions::{Backtab, SelectAll, Tab}, + items::active_match_index, }; use futures::{StreamExt, stream::FuturesOrdered}; use gpui::{ Action, AnyElement, AnyView, App, Axis, Context, Entity, EntityId, EventEmitter, FocusHandle, Focusable, Global, Hsla, InteractiveElement, IntoElement, KeyContext, ParentElement, Point, - Render, SharedString, Styled, Subscription, Task, TextStyle, UpdateGlobal, WeakEntity, Window, - actions, div, + Render, SharedString, Styled, Subscription, Task, UpdateGlobal, WeakEntity, Window, actions, + div, }; use language::{Buffer, Language}; use menu::Confirm; @@ -32,7 +35,6 @@ use std::{ pin::pin, sync::Arc, }; -use theme::ThemeSettings; use ui::{ Icon, IconButton, IconButtonShape, IconName, KeyBinding, Label, LabelCommon, LabelSize, Toggleable, Tooltip, h_flex, prelude::*, utils::SearchInputWidth, v_flex, @@ -208,7 +210,7 @@ pub struct ProjectSearchView { replacement_editor: Entity, results_editor: Entity, search_options: SearchOptions, - panels_with_errors: HashSet, + panels_with_errors: HashMap, active_match_index: Option, search_id: usize, included_files_editor: Entity, @@ -218,7 +220,6 @@ pub struct ProjectSearchView { included_opened_only: bool, regex_language: Option>, _subscriptions: Vec, - query_error: Option, } #[derive(Debug, Clone)] @@ -879,7 +880,7 @@ impl ProjectSearchView { query_editor, results_editor, search_options: options, - panels_with_errors: HashSet::default(), + panels_with_errors: HashMap::default(), active_match_index: None, included_files_editor, excluded_files_editor, @@ -888,7 +889,6 @@ impl ProjectSearchView { included_opened_only: false, regex_language: None, _subscriptions: subscriptions, - query_error: None, }; this.entity_changed(window, cx); this @@ -1152,14 +1152,16 @@ impl ProjectSearchView { Ok(included_files) => { let should_unmark_error = self.panels_with_errors.remove(&InputPanel::Include); - if should_unmark_error { + if should_unmark_error.is_some() { cx.notify(); } included_files } - Err(_e) => { - let should_mark_error = self.panels_with_errors.insert(InputPanel::Include); - if should_mark_error { + Err(e) => { + let should_mark_error = self + .panels_with_errors + .insert(InputPanel::Include, e.to_string()); + if should_mark_error.is_none() { cx.notify(); } PathMatcher::default() @@ -1174,15 +1176,17 @@ impl ProjectSearchView { Ok(excluded_files) => { let should_unmark_error = self.panels_with_errors.remove(&InputPanel::Exclude); - if should_unmark_error { + if should_unmark_error.is_some() { cx.notify(); } excluded_files } - Err(_e) => { - let should_mark_error = self.panels_with_errors.insert(InputPanel::Exclude); - if should_mark_error { + Err(e) => { + let should_mark_error = self + .panels_with_errors + .insert(InputPanel::Exclude, e.to_string()); + if should_mark_error.is_none() { cx.notify(); } PathMatcher::default() @@ -1219,19 +1223,19 @@ impl ProjectSearchView { ) { Ok(query) => { let should_unmark_error = self.panels_with_errors.remove(&InputPanel::Query); - if should_unmark_error { + if should_unmark_error.is_some() { cx.notify(); } - self.query_error = None; Some(query) } Err(e) => { - let should_mark_error = self.panels_with_errors.insert(InputPanel::Query); - if should_mark_error { + let should_mark_error = self + .panels_with_errors + .insert(InputPanel::Query, e.to_string()); + if should_mark_error.is_none() { cx.notify(); } - self.query_error = Some(e.to_string()); None } @@ -1249,15 +1253,17 @@ impl ProjectSearchView { ) { Ok(query) => { let should_unmark_error = self.panels_with_errors.remove(&InputPanel::Query); - if should_unmark_error { + if should_unmark_error.is_some() { cx.notify(); } Some(query) } - Err(_e) => { - let should_mark_error = self.panels_with_errors.insert(InputPanel::Query); - if should_mark_error { + Err(e) => { + let should_mark_error = self + .panels_with_errors + .insert(InputPanel::Query, e.to_string()); + if should_mark_error.is_none() { cx.notify(); } @@ -1512,7 +1518,7 @@ impl ProjectSearchView { } fn border_color_for(&self, panel: InputPanel, cx: &App) -> Hsla { - if self.panels_with_errors.contains(&panel) { + if self.panels_with_errors.contains_key(&panel) { Color::Error.color(cx) } else { cx.theme().colors().border @@ -1610,16 +1616,11 @@ impl ProjectSearchBar { } } - fn tab(&mut self, _: &editor::actions::Tab, window: &mut Window, cx: &mut Context) { + fn tab(&mut self, _: &Tab, window: &mut Window, cx: &mut Context) { self.cycle_field(Direction::Next, window, cx); } - fn backtab( - &mut self, - _: &editor::actions::Backtab, - window: &mut Window, - cx: &mut Context, - ) { + fn backtab(&mut self, _: &Backtab, window: &mut Window, cx: &mut Context) { self.cycle_field(Direction::Prev, window, cx); } @@ -1634,29 +1635,22 @@ impl ProjectSearchBar { fn cycle_field(&mut self, direction: Direction, window: &mut Window, cx: &mut Context) { let active_project_search = match &self.active_project_search { Some(active_project_search) => active_project_search, - - None => { - return; - } + None => return, }; active_project_search.update(cx, |project_view, cx| { - let mut views = vec![&project_view.query_editor]; + let mut views = vec![project_view.query_editor.focus_handle(cx)]; if project_view.replace_enabled { - views.push(&project_view.replacement_editor); + views.push(project_view.replacement_editor.focus_handle(cx)); } if project_view.filters_enabled { views.extend([ - &project_view.included_files_editor, - &project_view.excluded_files_editor, + project_view.included_files_editor.focus_handle(cx), + project_view.excluded_files_editor.focus_handle(cx), ]); } - let current_index = match views - .iter() - .enumerate() - .find(|(_, editor)| editor.focus_handle(cx).is_focused(window)) - { - Some((index, _)) => index, + let current_index = match views.iter().position(|focus| focus.is_focused(window)) { + Some(index) => index, None => return, }; @@ -1665,13 +1659,13 @@ impl ProjectSearchBar { Direction::Prev if current_index == 0 => views.len() - 1, Direction::Prev => (current_index - 1) % views.len(), }; - let next_focus_handle = views[new_index].focus_handle(cx); - window.focus(&next_focus_handle); + let next_focus_handle = &views[new_index]; + window.focus(next_focus_handle); cx.stop_propagation(); }); } - fn toggle_search_option( + pub(crate) fn toggle_search_option( &mut self, option: SearchOptions, window: &mut Window, @@ -1788,14 +1782,6 @@ impl ProjectSearchBar { } } - fn is_option_enabled(&self, option: SearchOptions, cx: &App) -> bool { - if let Some(search) = self.active_project_search.as_ref() { - search.read(cx).search_options.contains(option) - } else { - false - } - } - fn next_history_query( &mut self, _: &NextHistoryQuery, @@ -1915,37 +1901,6 @@ impl ProjectSearchBar { }) } } - - fn render_text_input(&self, editor: &Entity, cx: &Context) -> impl IntoElement { - let (color, use_syntax) = if editor.read(cx).read_only(cx) { - (cx.theme().colors().text_disabled, false) - } else { - (cx.theme().colors().text, true) - }; - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color, - font_family: settings.buffer_font.family.clone(), - font_features: settings.buffer_font.features.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_size: rems(0.875).into(), - font_weight: settings.buffer_font.weight, - line_height: relative(1.3), - ..TextStyle::default() - }; - - let mut editor_style = EditorStyle { - background: cx.theme().colors().toolbar_background, - local_player: cx.theme().players().local(), - text: text_style, - ..EditorStyle::default() - }; - if use_syntax { - editor_style.syntax = cx.theme().syntax().clone(); - } - - EditorElement::new(editor, editor_style) - } } impl Render for ProjectSearchBar { @@ -1959,28 +1914,43 @@ impl Render for ProjectSearchBar { let container_width = window.viewport_size().width; let input_width = SearchInputWidth::calc_width(container_width); - enum BaseStyle { - SingleInput, - MultipleInputs, - } - - let input_base_styles = |base_style: BaseStyle, panel: InputPanel| { - h_flex() - .min_w_32() - .map(|div| match base_style { - BaseStyle::SingleInput => div.w(input_width), - BaseStyle::MultipleInputs => div.flex_grow(), - }) - .h_8() - .pl_2() - .pr_1() - .py_1() - .border_1() - .border_color(search.border_color_for(panel, cx)) - .rounded_lg() + let input_base_styles = |panel: InputPanel| { + input_base_styles(search.border_color_for(panel, cx), |div| match panel { + InputPanel::Query | InputPanel::Replacement => div.w(input_width), + InputPanel::Include | InputPanel::Exclude => div.flex_grow(), + }) }; + let theme_colors = cx.theme().colors(); + let project_search = search.entity.read(cx); + let limit_reached = project_search.limit_reached; - let query_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Query) + let color_override = match ( + project_search.no_results, + &project_search.active_query, + &project_search.last_search_query_text, + ) { + (Some(true), Some(q), Some(p)) if q.as_str() == p => Some(Color::Error), + _ => None, + }; + let match_text = search + .active_match_index + .and_then(|index| { + let index = index + 1; + let match_quantity = project_search.match_ranges.len(); + if match_quantity > 0 { + debug_assert!(match_quantity >= index); + if limit_reached { + Some(format!("{index}/{match_quantity}+")) + } else { + Some(format!("{index}/{match_quantity}")) + } + } else { + None + } + }) + .unwrap_or_else(|| "0/0".to_string()); + + let query_column = input_base_styles(InputPanel::Query) .on_action(cx.listener(|this, action, window, cx| this.confirm(action, window, cx))) .on_action(cx.listener(|this, action, window, cx| { this.previous_history_query(action, window, cx) @@ -1988,35 +1958,78 @@ impl Render for ProjectSearchBar { .on_action( cx.listener(|this, action, window, cx| this.next_history_query(action, window, cx)), ) - .child(self.render_text_input(&search.query_editor, cx)) + .child(render_text_input(&search.query_editor, color_override, cx)) .child( h_flex() .gap_1() - .child(SearchOptions::CASE_SENSITIVE.as_button( - self.is_option_enabled(SearchOptions::CASE_SENSITIVE, cx), + .child(SearchOption::CaseSensitive.as_button( + search.search_options, + SearchSource::Project(cx), focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_search_option(SearchOptions::CASE_SENSITIVE, window, cx); - }), )) - .child(SearchOptions::WHOLE_WORD.as_button( - self.is_option_enabled(SearchOptions::WHOLE_WORD, cx), + .child(SearchOption::WholeWord.as_button( + search.search_options, + SearchSource::Project(cx), focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_search_option(SearchOptions::WHOLE_WORD, window, cx); - }), )) - .child(SearchOptions::REGEX.as_button( - self.is_option_enabled(SearchOptions::REGEX, cx), + .child(SearchOption::Regex.as_button( + search.search_options, + SearchSource::Project(cx), focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_search_option(SearchOptions::REGEX, window, cx); - }), )), ); + let query_focus = search.query_editor.focus_handle(cx); + + let matches_column = h_flex() + .pl_2() + .ml_2() + .border_l_1() + .border_color(theme_colors.border_variant) + .child(render_action_button( + "project-search-nav-button", + IconName::ChevronLeft, + search + .active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), + "Select Previous Match", + &SelectPreviousMatch, + query_focus.clone(), + )) + .child(render_action_button( + "project-search-nav-button", + IconName::ChevronRight, + search + .active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), + "Select Next Match", + &SelectNextMatch, + query_focus, + )) + .child( + div() + .id("matches") + .ml_2() + .min_w(rems_from_px(40.)) + .child(Label::new(match_text).size(LabelSize::Small).color( + if search.active_match_index.is_some() { + Color::Default + } else { + Color::Disabled + }, + )) + .when(limit_reached, |el| { + el.tooltip(Tooltip::text( + "Search limits reached.\nTry narrowing your search.", + )) + }), + ); + let mode_column = h_flex() .gap_1() + .min_w_64() .child( IconButton::new("project-search-filter-button", IconName::Filter) .shape(IconButtonShape::Square) @@ -2045,187 +2058,50 @@ impl Render for ProjectSearchBar { } }), ) - .child( - IconButton::new("project-search-toggle-replace", IconName::Replace) - .shape(IconButtonShape::Square) - .on_click(cx.listener(|this, _, window, cx| { - this.toggle_replace(&ToggleReplace, window, cx); - })) - .toggle_state( - self.active_project_search - .as_ref() - .map(|search| search.read(cx).replace_enabled) - .unwrap_or_default(), - ) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Toggle Replace", - &ToggleReplace, - &focus_handle, - window, - cx, - ) - } - }), - ); - - let limit_reached = search.entity.read(cx).limit_reached; - - let match_text = search - .active_match_index - .and_then(|index| { - let index = index + 1; - let match_quantity = search.entity.read(cx).match_ranges.len(); - if match_quantity > 0 { - debug_assert!(match_quantity >= index); - if limit_reached { - Some(format!("{index}/{match_quantity}+")) - } else { - Some(format!("{index}/{match_quantity}")) - } - } else { - None - } - }) - .unwrap_or_else(|| "0/0".to_string()); - - let matches_column = h_flex() - .pl_2() - .ml_2() - .border_l_1() - .border_color(cx.theme().colors().border_variant) - .child( - IconButton::new("project-search-prev-match", IconName::ChevronLeft) - .shape(IconButtonShape::Square) - .disabled(search.active_match_index.is_none()) - .on_click(cx.listener(|this, _, window, cx| { - if let Some(search) = this.active_project_search.as_ref() { - search.update(cx, |this, cx| { - this.select_match(Direction::Prev, window, cx); - }) - } - })) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Go To Previous Match", - &SelectPreviousMatch, - &focus_handle, - window, - cx, - ) - } - }), - ) - .child( - IconButton::new("project-search-next-match", IconName::ChevronRight) - .shape(IconButtonShape::Square) - .disabled(search.active_match_index.is_none()) - .on_click(cx.listener(|this, _, window, cx| { - if let Some(search) = this.active_project_search.as_ref() { - search.update(cx, |this, cx| { - this.select_match(Direction::Next, window, cx); - }) - } - })) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Go To Next Match", - &SelectNextMatch, - &focus_handle, - window, - cx, - ) - } - }), - ) - .child( - div() - .id("matches") - .ml_1() - .child(Label::new(match_text).size(LabelSize::Small).color( - if search.active_match_index.is_some() { - Color::Default - } else { - Color::Disabled - }, - )) - .when(limit_reached, |el| { - el.tooltip(Tooltip::text( - "Search limits reached.\nTry narrowing your search.", - )) - }), - ); + .child(render_action_button( + "project-search", + IconName::Replace, + self.active_project_search + .as_ref() + .map(|search| search.read(cx).replace_enabled) + .and_then(|enabled| enabled.then_some(ActionButtonState::Toggled)), + "Toggle Replace", + &ToggleReplace, + focus_handle.clone(), + )) + .child(matches_column); let search_line = h_flex() .w_full() .gap_2() .child(query_column) - .child(h_flex().min_w_64().child(mode_column).child(matches_column)); + .child(mode_column); let replace_line = search.replace_enabled.then(|| { - let replace_column = input_base_styles(BaseStyle::SingleInput, InputPanel::Replacement) - .child(self.render_text_input(&search.replacement_editor, cx)); + let replace_column = input_base_styles(InputPanel::Replacement) + .child(render_text_input(&search.replacement_editor, None, cx)); let focus_handle = search.replacement_editor.read(cx).focus_handle(cx); - let replace_actions = - h_flex() - .min_w_64() - .gap_1() - .when(search.replace_enabled, |this| { - this.child( - IconButton::new("project-search-replace-next", IconName::ReplaceNext) - .shape(IconButtonShape::Square) - .on_click(cx.listener(|this, _, window, cx| { - if let Some(search) = this.active_project_search.as_ref() { - search.update(cx, |this, cx| { - this.replace_next(&ReplaceNext, window, cx); - }) - } - })) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Replace Next Match", - &ReplaceNext, - &focus_handle, - window, - cx, - ) - } - }), - ) - .child( - IconButton::new("project-search-replace-all", IconName::ReplaceAll) - .shape(IconButtonShape::Square) - .on_click(cx.listener(|this, _, window, cx| { - if let Some(search) = this.active_project_search.as_ref() { - search.update(cx, |this, cx| { - this.replace_all(&ReplaceAll, window, cx); - }) - } - })) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Replace All Matches", - &ReplaceAll, - &focus_handle, - window, - cx, - ) - } - }), - ) - }); + let replace_actions = h_flex() + .min_w_64() + .gap_1() + .child(render_action_button( + "project-search-replace-button", + IconName::ReplaceNext, + Default::default(), + "Replace Next Match", + &ReplaceNext, + focus_handle.clone(), + )) + .child(render_action_button( + "project-search-replace-button", + IconName::ReplaceAll, + Default::default(), + "Replace All Matches", + &ReplaceAll, + focus_handle, + )); h_flex() .w_full() @@ -2235,6 +2111,39 @@ impl Render for ProjectSearchBar { }); let filter_line = search.filters_enabled.then(|| { + let include = input_base_styles(InputPanel::Include) + .on_action(cx.listener(|this, action, window, cx| { + this.previous_history_query(action, window, cx) + })) + .on_action(cx.listener(|this, action, window, cx| { + this.next_history_query(action, window, cx) + })) + .child(render_text_input(&search.included_files_editor, None, cx)); + let exclude = input_base_styles(InputPanel::Exclude) + .on_action(cx.listener(|this, action, window, cx| { + this.previous_history_query(action, window, cx) + })) + .on_action(cx.listener(|this, action, window, cx| { + this.next_history_query(action, window, cx) + })) + .child(render_text_input(&search.excluded_files_editor, None, cx)); + let mode_column = h_flex() + .gap_1() + .min_w_64() + .child( + IconButton::new("project-search-opened-only", IconName::FolderSearch) + .shape(IconButtonShape::Square) + .toggle_state(self.is_opened_only_enabled(cx)) + .tooltip(Tooltip::text("Only Search Open Files")) + .on_click(cx.listener(|this, _, window, cx| { + this.toggle_opened_only(window, cx); + })), + ) + .child(SearchOption::IncludeIgnored.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )); h_flex() .w_full() .gap_2() @@ -2242,62 +2151,14 @@ impl Render for ProjectSearchBar { h_flex() .gap_2() .w(input_width) - .child( - input_base_styles(BaseStyle::MultipleInputs, InputPanel::Include) - .on_action(cx.listener(|this, action, window, cx| { - this.previous_history_query(action, window, cx) - })) - .on_action(cx.listener(|this, action, window, cx| { - this.next_history_query(action, window, cx) - })) - .child(self.render_text_input(&search.included_files_editor, cx)), - ) - .child( - input_base_styles(BaseStyle::MultipleInputs, InputPanel::Exclude) - .on_action(cx.listener(|this, action, window, cx| { - this.previous_history_query(action, window, cx) - })) - .on_action(cx.listener(|this, action, window, cx| { - this.next_history_query(action, window, cx) - })) - .child(self.render_text_input(&search.excluded_files_editor, cx)), - ), - ) - .child( - h_flex() - .min_w_64() - .gap_1() - .child( - IconButton::new("project-search-opened-only", IconName::FolderSearch) - .shape(IconButtonShape::Square) - .toggle_state(self.is_opened_only_enabled(cx)) - .tooltip(Tooltip::text("Only Search Open Files")) - .on_click(cx.listener(|this, _, window, cx| { - this.toggle_opened_only(window, cx); - })), - ) - .child( - SearchOptions::INCLUDE_IGNORED.as_button( - search - .search_options - .contains(SearchOptions::INCLUDE_IGNORED), - focus_handle.clone(), - cx.listener(|this, _, window, cx| { - this.toggle_search_option( - SearchOptions::INCLUDE_IGNORED, - window, - cx, - ); - }), - ), - ), + .child(include) + .child(exclude), ) + .child(mode_column) }); let mut key_context = KeyContext::default(); - key_context.add("ProjectSearchBar"); - if search .replacement_editor .focus_handle(cx) @@ -2306,16 +2167,33 @@ impl Render for ProjectSearchBar { key_context.add("in_replace"); } - let query_error_line = search.query_error.as_ref().map(|error| { - Label::new(error) - .size(LabelSize::Small) - .color(Color::Error) - .mt_neg_1() - .ml_2() - }); + let query_error_line = search + .panels_with_errors + .get(&InputPanel::Query) + .map(|error| { + Label::new(error) + .size(LabelSize::Small) + .color(Color::Error) + .mt_neg_1() + .ml_2() + }); + + let filter_error_line = search + .panels_with_errors + .get(&InputPanel::Include) + .or_else(|| search.panels_with_errors.get(&InputPanel::Exclude)) + .map(|error| { + Label::new(error) + .size(LabelSize::Small) + .color(Color::Error) + .mt_neg_1() + .ml_2() + }); v_flex() + .gap_2() .py(px(1.0)) + .w_full() .key_context(key_context) .on_action(cx.listener(|this, _: &ToggleFocus, window, cx| { this.move_focus_to_results(window, cx) @@ -2323,14 +2201,8 @@ impl Render for ProjectSearchBar { .on_action(cx.listener(|this, _: &ToggleFilters, window, cx| { this.toggle_filters(window, cx); })) - .capture_action(cx.listener(|this, action, window, cx| { - this.tab(action, window, cx); - cx.stop_propagation(); - })) - .capture_action(cx.listener(|this, action, window, cx| { - this.backtab(action, window, cx); - cx.stop_propagation(); - })) + .capture_action(cx.listener(Self::tab)) + .capture_action(cx.listener(Self::backtab)) .on_action(cx.listener(|this, action, window, cx| this.confirm(action, window, cx))) .on_action(cx.listener(|this, action, window, cx| { this.toggle_replace(action, window, cx); @@ -2362,12 +2234,11 @@ impl Render for ProjectSearchBar { }) .on_action(cx.listener(Self::select_next_match)) .on_action(cx.listener(Self::select_prev_match)) - .gap_2() - .w_full() .child(search_line) .children(query_error_line) .children(replace_line) .children(filter_line) + .children(filter_error_line) } } diff --git a/crates/search/src/search.rs b/crates/search/src/search.rs index 5f57bfb4b1..904c74d03c 100644 --- a/crates/search/src/search.rs +++ b/crates/search/src/search.rs @@ -1,7 +1,7 @@ use bitflags::bitflags; pub use buffer_search::BufferSearchBar; use editor::SearchSettings; -use gpui::{Action, App, FocusHandle, IntoElement, actions}; +use gpui::{Action, App, ClickEvent, FocusHandle, IntoElement, actions}; use project::search::SearchQuery; pub use project_search::ProjectSearchView; use ui::{ButtonStyle, IconButton, IconButtonShape}; @@ -9,6 +9,10 @@ use ui::{Tooltip, prelude::*}; use workspace::notifications::NotificationId; use workspace::{Toast, Workspace}; +pub use search_status_button::SEARCH_ICON; + +use crate::project_search::ProjectSearchBar; + pub mod buffer_search; pub mod project_search; pub(crate) mod search_bar; @@ -59,48 +63,108 @@ actions!( bitflags! { #[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] pub struct SearchOptions: u8 { - const NONE = 0b000; - const WHOLE_WORD = 0b001; - const CASE_SENSITIVE = 0b010; - const INCLUDE_IGNORED = 0b100; - const REGEX = 0b1000; - const ONE_MATCH_PER_LINE = 0b100000; + const NONE = 0; + const WHOLE_WORD = 1 << SearchOption::WholeWord as u8; + const CASE_SENSITIVE = 1 << SearchOption::CaseSensitive as u8; + const INCLUDE_IGNORED = 1 << SearchOption::IncludeIgnored as u8; + const REGEX = 1 << SearchOption::Regex as u8; + const ONE_MATCH_PER_LINE = 1 << SearchOption::OneMatchPerLine as u8; /// If set, reverse direction when finding the active match - const BACKWARDS = 0b10000; + const BACKWARDS = 1 << SearchOption::Backwards as u8; } } -impl SearchOptions { +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum SearchOption { + WholeWord = 0, + CaseSensitive, + IncludeIgnored, + Regex, + OneMatchPerLine, + Backwards, +} + +pub(crate) enum SearchSource<'a, 'b> { + Buffer, + Project(&'a Context<'b, ProjectSearchBar>), +} + +impl SearchOption { + pub fn as_options(&self) -> SearchOptions { + SearchOptions::from_bits(1 << *self as u8).unwrap() + } + pub fn label(&self) -> &'static str { - match *self { - SearchOptions::WHOLE_WORD => "Match Whole Words", - SearchOptions::CASE_SENSITIVE => "Match Case Sensitively", - SearchOptions::INCLUDE_IGNORED => "Also search files ignored by configuration", - SearchOptions::REGEX => "Use Regular Expressions", - _ => panic!("{:?} is not a named SearchOption", self), + match self { + SearchOption::WholeWord => "Match Whole Words", + SearchOption::CaseSensitive => "Match Case Sensitively", + SearchOption::IncludeIgnored => "Also search files ignored by configuration", + SearchOption::Regex => "Use Regular Expressions", + SearchOption::OneMatchPerLine => "One Match Per Line", + SearchOption::Backwards => "Search Backwards", } } pub fn icon(&self) -> ui::IconName { - match *self { - SearchOptions::WHOLE_WORD => ui::IconName::WholeWord, - SearchOptions::CASE_SENSITIVE => ui::IconName::CaseSensitive, - SearchOptions::INCLUDE_IGNORED => ui::IconName::Sliders, - SearchOptions::REGEX => ui::IconName::Regex, - _ => panic!("{:?} is not a named SearchOption", self), + match self { + SearchOption::WholeWord => ui::IconName::WholeWord, + SearchOption::CaseSensitive => ui::IconName::CaseSensitive, + SearchOption::IncludeIgnored => ui::IconName::Sliders, + SearchOption::Regex => ui::IconName::Regex, + _ => panic!("{self:?} is not a named SearchOption"), } } - pub fn to_toggle_action(&self) -> Box { + pub fn to_toggle_action(&self) -> &'static dyn Action { match *self { - SearchOptions::WHOLE_WORD => Box::new(ToggleWholeWord), - SearchOptions::CASE_SENSITIVE => Box::new(ToggleCaseSensitive), - SearchOptions::INCLUDE_IGNORED => Box::new(ToggleIncludeIgnored), - SearchOptions::REGEX => Box::new(ToggleRegex), - _ => panic!("{:?} is not a named SearchOption", self), + SearchOption::WholeWord => &ToggleWholeWord, + SearchOption::CaseSensitive => &ToggleCaseSensitive, + SearchOption::IncludeIgnored => &ToggleIncludeIgnored, + SearchOption::Regex => &ToggleRegex, + _ => panic!("{self:?} is not a toggle action"), } } + pub(crate) fn as_button( + &self, + active: SearchOptions, + search_source: SearchSource, + focus_handle: FocusHandle, + ) -> impl IntoElement { + let action = self.to_toggle_action(); + let label = self.label(); + IconButton::new( + (label, matches!(search_source, SearchSource::Buffer) as u32), + self.icon(), + ) + .map(|button| match search_source { + SearchSource::Buffer => { + let focus_handle = focus_handle.clone(); + button.on_click(move |_: &ClickEvent, window, cx| { + if !focus_handle.is_focused(&window) { + window.focus(&focus_handle); + } + window.dispatch_action(action.boxed_clone(), cx); + }) + } + SearchSource::Project(cx) => { + let options = self.as_options(); + button.on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { + this.toggle_search_option(options, window, cx); + })) + } + }) + .style(ButtonStyle::Subtle) + .shape(IconButtonShape::Square) + .toggle_state(active.contains(self.as_options())) + .tooltip({ + move |window, cx| Tooltip::for_action_in(label, action, &focus_handle, window, cx) + }) + } +} + +impl SearchOptions { pub fn none() -> SearchOptions { SearchOptions::NONE } @@ -122,24 +186,6 @@ impl SearchOptions { options.set(SearchOptions::REGEX, settings.regex); options } - - pub fn as_button( - &self, - active: bool, - focus_handle: FocusHandle, - action: Action, - ) -> impl IntoElement + use { - IconButton::new(self.label(), self.icon()) - .on_click(action) - .style(ButtonStyle::Subtle) - .shape(IconButtonShape::Square) - .toggle_state(active) - .tooltip({ - let action = self.to_toggle_action(); - let label = self.label(); - move |window, cx| Tooltip::for_action_in(label, &*action, &focus_handle, window, cx) - }) - } } pub(crate) fn show_no_more_matches(window: &mut Window, cx: &mut App) { diff --git a/crates/search/src/search_bar.rs b/crates/search/src/search_bar.rs index 805664c794..8cc838a8a6 100644 --- a/crates/search/src/search_bar.rs +++ b/crates/search/src/search_bar.rs @@ -1,16 +1,25 @@ -use gpui::{Action, FocusHandle, IntoElement}; +use editor::{Editor, EditorElement, EditorStyle}; +use gpui::{Action, Entity, FocusHandle, Hsla, IntoElement, TextStyle}; +use settings::Settings; +use theme::ThemeSettings; use ui::{IconButton, IconButtonShape}; use ui::{Tooltip, prelude::*}; -pub(super) fn render_nav_button( +pub(super) enum ActionButtonState { + Disabled, + Toggled, +} + +pub(super) fn render_action_button( + id_prefix: &'static str, icon: ui::IconName, - active: bool, + button_state: Option, tooltip: &'static str, action: &'static dyn Action, focus_handle: FocusHandle, ) -> impl IntoElement { IconButton::new( - SharedString::from(format!("search-nav-button-{}", action.name())), + SharedString::from(format!("{id_prefix}-{}", action.name())), icon, ) .shape(IconButtonShape::Square) @@ -24,5 +33,60 @@ pub(super) fn render_nav_button( } }) .tooltip(move |window, cx| Tooltip::for_action_in(tooltip, action, &focus_handle, window, cx)) - .disabled(!active) + .when_some(button_state, |this, state| match state { + ActionButtonState::Toggled => this.toggle_state(true), + ActionButtonState::Disabled => this.disabled(true), + }) +} + +pub(crate) fn input_base_styles(border_color: Hsla, map: impl FnOnce(Div) -> Div) -> Div { + h_flex() + .min_w_32() + .map(map) + .h_8() + .pl_2() + .pr_1() + .py_1() + .border_1() + .border_color(border_color) + .rounded_lg() +} + +pub(crate) fn render_text_input( + editor: &Entity, + color_override: Option, + app: &App, +) -> impl IntoElement { + let (color, use_syntax) = if editor.read(app).read_only(app) { + (app.theme().colors().text_disabled, false) + } else { + match color_override { + Some(color_override) => (color_override.color(app), false), + None => (app.theme().colors().text, true), + } + }; + + let settings = ThemeSettings::get_global(app); + let text_style = TextStyle { + color, + font_family: settings.buffer_font.family.clone(), + font_features: settings.buffer_font.features.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.buffer_font.weight, + line_height: relative(1.3), + ..TextStyle::default() + }; + + let mut editor_style = EditorStyle { + background: app.theme().colors().toolbar_background, + local_player: app.theme().players().local(), + text: text_style, + ..EditorStyle::default() + }; + if use_syntax { + editor_style.syntax = app.theme().syntax().clone(); + } + + EditorElement::new(editor, editor_style) } diff --git a/crates/search/src/search_status_button.rs b/crates/search/src/search_status_button.rs index ff2ee1641d..fcf36e86fa 100644 --- a/crates/search/src/search_status_button.rs +++ b/crates/search/src/search_status_button.rs @@ -3,6 +3,8 @@ use settings::Settings as _; use ui::{ButtonCommon, Clickable, Context, Render, Tooltip, Window, prelude::*}; use workspace::{ItemHandle, StatusItemView}; +pub const SEARCH_ICON: IconName = IconName::MagnifyingGlass; + pub struct SearchButton; impl SearchButton { @@ -20,7 +22,7 @@ impl Render for SearchButton { } button.child( - IconButton::new("project-search-indicator", IconName::MagnifyingGlass) + IconButton::new("project-search-indicator", SEARCH_ICON) .icon_size(IconSize::Small) .tooltip(|window, cx| { Tooltip::for_action( diff --git a/crates/settings/src/base_keymap_setting.rs b/crates/settings/src/base_keymap_setting.rs index 6916d98ae3..91dda03d00 100644 --- a/crates/settings/src/base_keymap_setting.rs +++ b/crates/settings/src/base_keymap_setting.rs @@ -44,7 +44,7 @@ impl BaseKeymap { ("Sublime Text", Self::SublimeText), ("Emacs (beta)", Self::Emacs), ("TextMate", Self::TextMate), - ("Cursor (beta)", Self::Cursor), + ("Cursor", Self::Cursor), ]; #[cfg(not(target_os = "macos"))] @@ -54,7 +54,7 @@ impl BaseKeymap { ("JetBrains", Self::JetBrains), ("Sublime Text", Self::SublimeText), ("Emacs (beta)", Self::Emacs), - ("Cursor (beta)", Self::Cursor), + ("Cursor", Self::Cursor), ]; pub fn asset_path(&self) -> Option<&'static str> { diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index a4c47081c6..8a151359ec 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -42,8 +42,10 @@ tree-sitter-rust.workspace = true ui.workspace = true ui_input.workspace = true util.workspace = true +vim.workspace = true workspace-hack.workspace = true workspace.workspace = true +zed_actions.workspace = true [dev-dependencies] db = {"workspace"= true, "features" = ["test-support"]} diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index a62c669488..b4e871c617 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -23,7 +23,7 @@ use settings::{BaseKeymap, KeybindSource, KeymapFile, Settings as _, SettingsAss use ui::{ ActiveTheme as _, App, Banner, BorrowAppContext, ContextMenu, IconButtonShape, Indicator, Modal, ModalFooter, ModalHeader, ParentElement as _, Render, Section, SharedString, - Styled as _, Tooltip, Window, prelude::*, + Styled as _, Tooltip, Window, prelude::*, right_click_menu, }; use ui_input::SingleLineInput; use util::ResultExt; @@ -1536,6 +1536,33 @@ impl Render for KeymapEditor { .child( h_flex() .gap_2() + .child( + right_click_menu("open-keymap-menu") + .menu(|window, cx| { + ContextMenu::build(window, cx, |menu, _, _| { + menu.header("Open Keymap JSON") + .action("User", zed_actions::OpenKeymap.boxed_clone()) + .action("Zed Default", zed_actions::OpenDefaultKeymap.boxed_clone()) + .action("Vim Default", vim::OpenDefaultKeymap.boxed_clone()) + }) + }) + .anchor(gpui::Corner::TopLeft) + .trigger(|open, _, _| + IconButton::new( + "OpenKeymapJsonButton", + IconName::Json + ) + .shape(ui::IconButtonShape::Square) + .when(!open, |this| + this.tooltip(move |window, cx| { + Tooltip::with_meta("Open Keymap JSON", Some(&zed_actions::OpenKeymap),"Right click to view more options", window, cx) + }) + ) + .on_click(|_, window, cx| { + window.dispatch_action(zed_actions::OpenKeymap.boxed_clone(), cx); + }) + ) + ) .child( div() .key_context({ @@ -2154,6 +2181,7 @@ impl KeybindingEditorModal { let value = action_arguments .as_ref() + .filter(|args| !args.is_empty()) .map(|args| { serde_json::from_str(args).context("Failed to parse action arguments as JSON") }) diff --git a/crates/storybook/src/stories/indent_guides.rs b/crates/storybook/src/stories/indent_guides.rs index e4f9669b1f..db23ea79bd 100644 --- a/crates/storybook/src/stories/indent_guides.rs +++ b/crates/storybook/src/stories/indent_guides.rs @@ -65,7 +65,7 @@ impl Render for IndentGuidesStory { }, ) .with_compute_indents_fn( - cx.entity().clone(), + cx.entity(), |this, range, _cx, _context| { this.depths .iter() diff --git a/crates/storybook/src/storybook.rs b/crates/storybook/src/storybook.rs index 4c5b6272ef..ac01e6c5c8 100644 --- a/crates/storybook/src/storybook.rs +++ b/crates/storybook/src/storybook.rs @@ -128,7 +128,7 @@ impl Render for StoryWrapper { .flex() .flex_col() .size_full() - .font_family("Zed Plex Mono") + .font_family(".ZedMono") .child(self.story.clone()) } } diff --git a/crates/telemetry_events/src/telemetry_events.rs b/crates/telemetry_events/src/telemetry_events.rs index 735a1310ae..12d8d4c04b 100644 --- a/crates/telemetry_events/src/telemetry_events.rs +++ b/crates/telemetry_events/src/telemetry_events.rs @@ -2,7 +2,7 @@ use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt::Display, sync::Arc, time::Duration}; +use std::{collections::HashMap, fmt::Display, time::Duration}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventRequestBody { @@ -93,19 +93,6 @@ impl Display for AssistantPhase { #[serde(tag = "type")] pub enum Event { Flexible(FlexibleEvent), - Editor(EditorEvent), - EditPrediction(EditPredictionEvent), - EditPredictionRating(EditPredictionRatingEvent), - Call(CallEvent), - Assistant(AssistantEventData), - Cpu(CpuEvent), - Memory(MemoryEvent), - App(AppEvent), - Setting(SettingEvent), - Extension(ExtensionEvent), - Edit(EditEvent), - Action(ActionEvent), - Repl(ReplEvent), } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -114,54 +101,12 @@ pub struct FlexibleEvent { pub event_properties: HashMap, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditorEvent { - /// The editor operation performed (open, save) - pub operation: String, - /// The extension of the file that was opened or saved - pub file_extension: Option, - /// Whether the user is in vim mode or not - pub vim_mode: bool, - /// Whether the user has copilot enabled or not - pub copilot_enabled: bool, - /// Whether the user has copilot enabled for the language of the file opened or saved - pub copilot_enabled_for_language: bool, - /// Whether the client is opening/saving a local file or a remote file via SSH - #[serde(default)] - pub is_via_ssh: bool, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditPredictionEvent { - /// Provider of the completion suggestion (e.g. copilot, supermaven) - pub provider: String, - pub suggestion_accepted: bool, - pub file_extension: Option, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum EditPredictionRating { Positive, Negative, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditPredictionRatingEvent { - pub rating: EditPredictionRating, - pub input_events: Arc, - pub input_excerpt: Arc, - pub output_excerpt: Arc, - pub feedback: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CallEvent { - /// Operation performed: invite/join call; begin/end screenshare; share/unshare project; etc - pub operation: String, - pub room_id: Option, - pub channel_id: Option, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AssistantEventData { /// Unique random identifier for each assistant tab (None for inline assist) @@ -180,57 +125,6 @@ pub struct AssistantEventData { pub language_name: Option, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CpuEvent { - pub usage_as_percentage: f32, - pub core_count: u32, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct MemoryEvent { - pub memory_in_bytes: u64, - pub virtual_memory_in_bytes: u64, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ActionEvent { - pub source: String, - pub action: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditEvent { - pub duration: i64, - pub environment: String, - /// Whether the edits occurred locally or remotely via SSH - #[serde(default)] - pub is_via_ssh: bool, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct SettingEvent { - pub setting: String, - pub value: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ExtensionEvent { - pub extension_id: Arc, - pub version: Arc, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct AppEvent { - pub operation: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ReplEvent { - pub kernel_language: String, - pub kernel_status: String, - pub repl_session_id: String, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct BacktraceFrame { pub ip: usize, diff --git a/crates/terminal/Cargo.toml b/crates/terminal/Cargo.toml index 93f61622c8..b1c0dd693f 100644 --- a/crates/terminal/Cargo.toml +++ b/crates/terminal/Cargo.toml @@ -5,6 +5,13 @@ edition.workspace = true publish.workspace = true license = "GPL-3.0-or-later" +[features] +test-support = [ + "collections/test-support", + "gpui/test-support", + "settings/test-support", +] + [lints] workspace = true @@ -39,5 +46,6 @@ workspace-hack.workspace = true windows.workspace = true [dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } rand.workspace = true url.workspace = true diff --git a/crates/terminal/src/terminal.rs b/crates/terminal/src/terminal.rs index d6a09a590f..86728cc11c 100644 --- a/crates/terminal/src/terminal.rs +++ b/crates/terminal/src/terminal.rs @@ -58,7 +58,7 @@ use std::{ path::PathBuf, process::ExitStatus, sync::Arc, - time::{Duration, Instant}, + time::Instant, }; use thiserror::Error; @@ -167,6 +167,7 @@ enum InternalEvent { // Vi mode events ToggleViMode, ViMotion(ViMotion), + MoveViCursorToAlacPoint(AlacPoint), } ///A translation struct for Alacritty to communicate with us from their event loop @@ -408,7 +409,13 @@ impl TerminalBuilder { let terminal_title_override = shell_params.as_ref().and_then(|e| e.title_override.clone()); #[cfg(windows)] - let shell_program = shell_params.as_ref().map(|params| params.program.clone()); + let shell_program = shell_params.as_ref().map(|params| { + use util::ResultExt; + + Self::resolve_path(¶ms.program) + .log_err() + .unwrap_or(params.program.clone()) + }); let pty_options = { let alac_shell = shell_params.map(|params| { @@ -534,10 +541,15 @@ impl TerminalBuilder { 'outer: loop { let mut events = Vec::new(); + + #[cfg(any(test, feature = "test-support"))] + let mut timer = cx.background_executor().simulate_random_delay().fuse(); + #[cfg(not(any(test, feature = "test-support")))] let mut timer = cx .background_executor() - .timer(Duration::from_millis(4)) + .timer(std::time::Duration::from_millis(4)) .fuse(); + let mut wakeup = false; loop { futures::select_biased! { @@ -584,6 +596,24 @@ impl TerminalBuilder { self.terminal } + + #[cfg(windows)] + fn resolve_path(path: &str) -> Result { + use windows::Win32::Storage::FileSystem::SearchPathW; + use windows::core::HSTRING; + + let path = if path.starts_with(r"\\?\") || !path.contains(&['/', '\\']) { + path.to_string() + } else { + r"\\?\".to_string() + path + }; + + let required_length = unsafe { SearchPathW(None, &HSTRING::from(&path), None, None, None) }; + let mut buf = vec![0u16; required_length as usize]; + let size = unsafe { SearchPathW(None, &HSTRING::from(&path), None, Some(&mut buf), None) }; + + Ok(String::from_utf16(&buf[..size as usize])?) + } } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -943,6 +973,10 @@ impl Terminal { term.scroll_to_point(*point); self.refresh_hovered_word(window); } + InternalEvent::MoveViCursorToAlacPoint(point) => { + term.vi_goto_point(*point); + self.refresh_hovered_word(window); + } InternalEvent::ToggleViMode => { self.vi_mode_enabled = !self.vi_mode_enabled; term.toggle_vi_mode(); @@ -1071,12 +1105,21 @@ impl Terminal { pub fn activate_match(&mut self, index: usize) { if let Some(search_match) = self.matches.get(index).cloned() { self.set_selection(Some((make_selection(&search_match), *search_match.end()))); - - self.events - .push_back(InternalEvent::ScrollToAlacPoint(*search_match.start())); + if self.vi_mode_enabled { + self.events + .push_back(InternalEvent::MoveViCursorToAlacPoint(*search_match.end())); + } else { + self.events + .push_back(InternalEvent::ScrollToAlacPoint(*search_match.start())); + } } } + pub fn clear_matches(&mut self) { + self.matches.clear(); + self.set_selection(None); + } + pub fn select_matches(&mut self, matches: &[RangeInclusive]) { let matches_to_select = self .matches @@ -2104,16 +2147,56 @@ pub fn rgba_color(r: u8, g: u8, b: u8) -> Hsla { #[cfg(test)] mod tests { + use super::*; + use crate::{ + IndexedCell, TerminalBounds, TerminalBuilder, TerminalContent, content_index_for_mouse, + rgb_for_index, + }; use alacritty_terminal::{ index::{Column, Line, Point as AlacPoint}, term::cell::Cell, }; - use gpui::{Pixels, Point, bounds, point, size}; + use collections::HashMap; + use gpui::{Pixels, Point, TestAppContext, bounds, point, size}; use rand::{Rng, distributions::Alphanumeric, rngs::ThreadRng, thread_rng}; - use crate::{ - IndexedCell, TerminalBounds, TerminalContent, content_index_for_mouse, rgb_for_index, - }; + #[cfg_attr(windows, ignore = "TODO: fix on windows")] + #[gpui::test] + async fn test_basic_terminal(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let (completion_tx, completion_rx) = smol::channel::unbounded(); + let terminal = cx.new(|cx| { + TerminalBuilder::new( + None, + None, + None, + task::Shell::WithArguments { + program: "echo".into(), + args: vec!["hello".into()], + title_override: None, + }, + HashMap::default(), + CursorShape::default(), + AlternateScroll::On, + None, + false, + 0, + completion_tx, + cx, + ) + .unwrap() + .subscribe(cx) + }); + assert_eq!( + completion_rx.recv().await.unwrap(), + Some(ExitStatus::default()) + ); + assert_eq!( + terminal.update(cx, |term, _| term.get_content()).trim(), + "hello" + ); + } #[test] fn test_rgb_for_index() { diff --git a/crates/terminal_view/src/terminal_element.rs b/crates/terminal_view/src/terminal_element.rs index 083c07de9c..6c1be9d5e7 100644 --- a/crates/terminal_view/src/terminal_element.rs +++ b/crates/terminal_view/src/terminal_element.rs @@ -136,7 +136,7 @@ impl BatchedTextRun { .shape_line( self.text.clone().into(), self.font_size.to_pixels(window.rem_size()), - &[self.style.clone()], + std::slice::from_ref(&self.style), Some(dimensions.cell_width), ) .paint(pos, dimensions.line_height, window, cx); diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index c9528c39b9..568dc1db2e 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -947,7 +947,7 @@ pub fn new_terminal_pane( cx: &mut Context, ) -> Entity { let is_local = project.read(cx).is_local(); - let terminal_panel = cx.entity().clone(); + let terminal_panel = cx.entity(); let pane = cx.new(|cx| { let mut pane = Pane::new( workspace.clone(), @@ -1009,7 +1009,7 @@ pub fn new_terminal_pane( return ControlFlow::Break(()); }; if let Some(tab) = dropped_item.downcast_ref::() { - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 0ec5f816d5..534c0a8051 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -1491,7 +1491,7 @@ impl TerminalView { impl Render for TerminalView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let terminal_handle = self.terminal.clone(); - let terminal_view_handle = cx.entity().clone(); + let terminal_view_handle = cx.entity(); let focused = self.focus_handle.is_focused(window); @@ -1869,7 +1869,7 @@ impl SearchableItem for TerminalView { /// Clear stored matches fn clear_matches(&mut self, _window: &mut Window, cx: &mut Context) { - self.terminal().update(cx, |term, _| term.matches.clear()) + self.terminal().update(cx, |term, _| term.clear_matches()) } /// Store matches returned from find_matches somewhere for rendering diff --git a/crates/text/src/selection.rs b/crates/text/src/selection.rs index 18b82dbb6a..d3c280bde8 100644 --- a/crates/text/src/selection.rs +++ b/crates/text/src/selection.rs @@ -104,6 +104,19 @@ impl Selection { self.goal = new_goal; } + pub fn set_head_tail(&mut self, head: T, tail: T, new_goal: SelectionGoal) { + if head < tail { + self.reversed = true; + self.start = head; + self.end = tail; + } else { + self.reversed = false; + self.start = tail; + self.end = head; + } + self.goal = new_goal; + } + pub fn swap_head_tail(&mut self) { if self.reversed { self.reversed = false; diff --git a/crates/theme/src/default_colors.rs b/crates/theme/src/default_colors.rs index 1c3f48b548..051b7acf10 100644 --- a/crates/theme/src/default_colors.rs +++ b/crates/theme/src/default_colors.rs @@ -54,6 +54,7 @@ impl ThemeColors { element_disabled: neutral().light_alpha().step_3(), element_selection_background: blue().light().step_3().alpha(0.25), drop_target_background: blue().light_alpha().step_2(), + drop_target_border: neutral().light().step_12(), ghost_element_background: system.transparent, ghost_element_hover: neutral().light_alpha().step_3(), ghost_element_active: neutral().light_alpha().step_4(), @@ -179,6 +180,7 @@ impl ThemeColors { element_disabled: neutral().dark_alpha().step_3(), element_selection_background: blue().dark().step_3().alpha(0.25), drop_target_background: blue().dark_alpha().step_2(), + drop_target_border: neutral().dark().step_12(), ghost_element_background: system.transparent, ghost_element_hover: neutral().dark_alpha().step_4(), ghost_element_active: neutral().dark_alpha().step_5(), diff --git a/crates/theme/src/fallback_themes.rs b/crates/theme/src/fallback_themes.rs index 4d77dd5d81..e9e8e2d0db 100644 --- a/crates/theme/src/fallback_themes.rs +++ b/crates/theme/src/fallback_themes.rs @@ -115,6 +115,7 @@ pub(crate) fn zed_default_dark() -> Theme { element_disabled: SystemColors::default().transparent, element_selection_background: player.local().selection.alpha(0.25), drop_target_background: hsla(220.0 / 360., 8.3 / 100., 21.4 / 100., 1.0), + drop_target_border: hsla(221. / 360., 11. / 100., 86. / 100., 1.0), ghost_element_background: SystemColors::default().transparent, ghost_element_hover: hover, ghost_element_active: hsla(220.0 / 360., 11.8 / 100., 20.0 / 100., 1.0), diff --git a/crates/theme/src/schema.rs b/crates/theme/src/schema.rs index bfa2adcedf..425fedbc71 100644 --- a/crates/theme/src/schema.rs +++ b/crates/theme/src/schema.rs @@ -225,6 +225,10 @@ pub struct ThemeColorsContent { #[serde(rename = "drop_target.background")] pub drop_target_background: Option, + /// Border Color. Used for the border that shows where a dragged element will be dropped. + #[serde(rename = "drop_target.border")] + pub drop_target_border: Option, + /// Used for the background of a ghost element that should have the same background as the surface it's on. /// /// Elements might include: Buttons, Inputs, Checkboxes, Radio Buttons... @@ -747,6 +751,10 @@ impl ThemeColorsContent { .drop_target_background .as_ref() .and_then(|color| try_parse_color(color).ok()), + drop_target_border: self + .drop_target_border + .as_ref() + .and_then(|color| try_parse_color(color).ok()), ghost_element_background: self .ghost_element_background .as_ref() diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index f5f1fd5547..df147cfe92 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -19,6 +19,7 @@ use util::ResultExt as _; use util::schemars::replace_subschema; const MIN_FONT_SIZE: Pixels = px(6.0); +const MAX_FONT_SIZE: Pixels = px(100.0); const MIN_LINE_HEIGHT: f32 = 1.0; #[derive( @@ -103,8 +104,8 @@ pub struct ThemeSettings { /// /// The terminal font family can be overridden using it's own setting. pub buffer_font: Font, - /// The agent font size. Determines the size of text in the agent panel. - agent_font_size: Pixels, + /// The agent font size. Determines the size of text in the agent panel. Falls back to the UI font size if unset. + agent_font_size: Option, /// The line height for buffers, and the terminal. /// /// Changing this may affect the spacing of some UI elements. @@ -404,9 +405,9 @@ pub struct ThemeSettingsContent { #[serde(default)] #[schemars(default = "default_font_features")] pub buffer_font_features: Option, - /// The font size for the agent panel. + /// The font size for the agent panel. Falls back to the UI font size if unset. #[serde(default)] - pub agent_font_size: Option, + pub agent_font_size: Option>, /// The name of the Zed theme to use. #[serde(default)] pub theme: Option, @@ -599,13 +600,13 @@ impl ThemeSettings { clamp_font_size(font_size) } - /// Returns the UI font size. + /// Returns the agent panel font size. Falls back to the UI font size if unset. pub fn agent_font_size(&self, cx: &App) -> Pixels { - let font_size = cx - .try_global::() + cx.try_global::() .map(|size| size.0) - .unwrap_or(self.agent_font_size); - clamp_font_size(font_size) + .or(self.agent_font_size) + .map(clamp_font_size) + .unwrap_or_else(|| self.ui_font_size(cx)) } /// Returns the buffer font size, read from the settings. @@ -624,6 +625,14 @@ impl ThemeSettings { self.ui_font_size } + /// Returns the agent font size, read from the settings. + /// + /// The real agent font size is stored in-memory, to support temporary font size changes. + /// Use [`Self::agent_font_size`] to get the real font size. + pub fn agent_font_size_settings(&self) -> Option { + self.agent_font_size + } + // TODO: Rename: `line_height` -> `buffer_line_height` /// Returns the buffer's line height. pub fn line_height(&self) -> f32 { @@ -732,14 +741,12 @@ pub fn adjusted_font_size(size: Pixels, cx: &App) -> Pixels { } /// Adjusts the buffer font size. -pub fn adjust_buffer_font_size(cx: &mut App, mut f: impl FnMut(&mut Pixels)) { +pub fn adjust_buffer_font_size(cx: &mut App, f: impl FnOnce(Pixels) -> Pixels) { let buffer_font_size = ThemeSettings::get_global(cx).buffer_font_size; - let mut adjusted_size = cx + let adjusted_size = cx .try_global::() .map_or(buffer_font_size, |adjusted_size| adjusted_size.0); - - f(&mut adjusted_size); - cx.set_global(BufferFontSize(clamp_font_size(adjusted_size))); + cx.set_global(BufferFontSize(clamp_font_size(f(adjusted_size)))); cx.refresh_windows(); } @@ -765,14 +772,12 @@ pub fn setup_ui_font(window: &mut Window, cx: &mut App) -> gpui::Font { } /// Sets the adjusted UI font size. -pub fn adjust_ui_font_size(cx: &mut App, mut f: impl FnMut(&mut Pixels)) { +pub fn adjust_ui_font_size(cx: &mut App, f: impl FnOnce(Pixels) -> Pixels) { let ui_font_size = ThemeSettings::get_global(cx).ui_font_size(cx); - let mut adjusted_size = cx + let adjusted_size = cx .try_global::() .map_or(ui_font_size, |adjusted_size| adjusted_size.0); - - f(&mut adjusted_size); - cx.set_global(UiFontSize(clamp_font_size(adjusted_size))); + cx.set_global(UiFontSize(clamp_font_size(f(adjusted_size)))); cx.refresh_windows(); } @@ -784,19 +789,17 @@ pub fn reset_ui_font_size(cx: &mut App) { } } -/// Sets the adjusted UI font size. -pub fn adjust_agent_font_size(cx: &mut App, mut f: impl FnMut(&mut Pixels)) { +/// Sets the adjusted agent panel font size. +pub fn adjust_agent_font_size(cx: &mut App, f: impl FnOnce(Pixels) -> Pixels) { let agent_font_size = ThemeSettings::get_global(cx).agent_font_size(cx); - let mut adjusted_size = cx + let adjusted_size = cx .try_global::() .map_or(agent_font_size, |adjusted_size| adjusted_size.0); - - f(&mut adjusted_size); - cx.set_global(AgentFontSize(clamp_font_size(adjusted_size))); + cx.set_global(AgentFontSize(clamp_font_size(f(adjusted_size)))); cx.refresh_windows(); } -/// Resets the UI font size to the default value. +/// Resets the agent panel font size to the default value. pub fn reset_agent_font_size(cx: &mut App) { if cx.has_global::() { cx.remove_global::(); @@ -806,7 +809,7 @@ pub fn reset_agent_font_size(cx: &mut App) { /// Ensures font size is within the valid range. pub fn clamp_font_size(size: Pixels) -> Pixels { - size.max(MIN_FONT_SIZE) + size.clamp(MIN_FONT_SIZE, MAX_FONT_SIZE) } fn clamp_font_weight(weight: f32) -> FontWeight { @@ -860,7 +863,7 @@ impl settings::Settings for ThemeSettings { }, buffer_font_size: defaults.buffer_font_size.unwrap().into(), buffer_line_height: defaults.buffer_line_height.unwrap(), - agent_font_size: defaults.agent_font_size.unwrap().into(), + agent_font_size: defaults.agent_font_size.flatten().map(Into::into), theme_selection: defaults.theme.clone(), active_theme: themes .get(defaults.theme.as_ref().unwrap().theme(*system_appearance)) @@ -959,20 +962,20 @@ impl settings::Settings for ThemeSettings { } } - merge(&mut this.ui_font_size, value.ui_font_size.map(Into::into)); - this.ui_font_size = this.ui_font_size.clamp(px(6.), px(100.)); - + merge( + &mut this.ui_font_size, + value.ui_font_size.map(Into::into).map(clamp_font_size), + ); merge( &mut this.buffer_font_size, - value.buffer_font_size.map(Into::into), + value.buffer_font_size.map(Into::into).map(clamp_font_size), ); - this.buffer_font_size = this.buffer_font_size.clamp(px(6.), px(100.)); - merge( &mut this.agent_font_size, - value.agent_font_size.map(Into::into), + value + .agent_font_size + .map(|value| value.map(Into::into).map(clamp_font_size)), ); - this.agent_font_size = this.agent_font_size.clamp(px(6.), px(100.)); merge(&mut this.buffer_line_height, value.buffer_line_height); diff --git a/crates/theme/src/styles/colors.rs b/crates/theme/src/styles/colors.rs index aab11803f4..198ad97adb 100644 --- a/crates/theme/src/styles/colors.rs +++ b/crates/theme/src/styles/colors.rs @@ -59,6 +59,8 @@ pub struct ThemeColors { pub element_disabled: Hsla, /// Background Color. Used for the area that shows where a dragged element will be dropped. pub drop_target_background: Hsla, + /// Border Color. Used for the border that shows where a dragged element will be dropped. + pub drop_target_border: Hsla, /// Used for the background of a ghost element that should have the same background as the surface it's on. /// /// Elements might include: Buttons, Inputs, Checkboxes, Radio Buttons... @@ -304,6 +306,7 @@ pub enum ThemeColorField { ElementSelected, ElementDisabled, DropTargetBackground, + DropTargetBorder, GhostElementBackground, GhostElementHover, GhostElementActive, @@ -418,6 +421,7 @@ impl ThemeColors { ThemeColorField::ElementSelected => self.element_selected, ThemeColorField::ElementDisabled => self.element_disabled, ThemeColorField::DropTargetBackground => self.drop_target_background, + ThemeColorField::DropTargetBorder => self.drop_target_border, ThemeColorField::GhostElementBackground => self.ghost_element_background, ThemeColorField::GhostElementHover => self.ghost_element_hover, ThemeColorField::GhostElementActive => self.ghost_element_active, diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index f04eeade73..e02324a142 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -107,6 +107,8 @@ pub fn init(themes_to_load: LoadThemes, cx: &mut App) { let mut prev_buffer_font_size_settings = ThemeSettings::get_global(cx).buffer_font_size_settings(); let mut prev_ui_font_size_settings = ThemeSettings::get_global(cx).ui_font_size_settings(); + let mut prev_agent_font_size_settings = + ThemeSettings::get_global(cx).agent_font_size_settings(); cx.observe_global::(move |cx| { let buffer_font_size_settings = ThemeSettings::get_global(cx).buffer_font_size_settings(); if buffer_font_size_settings != prev_buffer_font_size_settings { @@ -119,6 +121,12 @@ pub fn init(themes_to_load: LoadThemes, cx: &mut App) { prev_ui_font_size_settings = ui_font_size_settings; reset_ui_font_size(cx); } + + let agent_font_size_settings = ThemeSettings::get_global(cx).agent_font_size_settings(); + if agent_font_size_settings != prev_agent_font_size_settings { + prev_agent_font_size_settings = agent_font_size_settings; + reset_agent_font_size(cx); + } }) .detach(); } diff --git a/crates/title_bar/src/application_menu.rs b/crates/title_bar/src/application_menu.rs index a5d5f154c9..98f0eeb6cc 100644 --- a/crates/title_bar/src/application_menu.rs +++ b/crates/title_bar/src/application_menu.rs @@ -121,8 +121,16 @@ impl ApplicationMenu { menu.action(name, action) } OwnedMenuItem::Submenu(_) => menu, + OwnedMenuItem::SystemMenu(_) => { + // A system menu doesn't make sense in this context, so ignore it + menu + } }) } + OwnedMenuItem::SystemMenu(_) => { + // A system menu doesn't make sense in this context, so ignore it + menu + } }) }) } diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index d11d3b7081..eb317a5616 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -595,7 +595,7 @@ impl TitleBar { .on_click(|_, window, cx| { if let Some(auto_updater) = auto_update::AutoUpdater::get(cx) { if auto_updater.read(cx).status().is_updated() { - workspace::reload(&Default::default(), cx); + workspace::reload(cx); return; } } diff --git a/crates/toolchain_selector/src/active_toolchain.rs b/crates/toolchain_selector/src/active_toolchain.rs index 631f66a83c..01bd7b0a9c 100644 --- a/crates/toolchain_selector/src/active_toolchain.rs +++ b/crates/toolchain_selector/src/active_toolchain.rs @@ -8,6 +8,7 @@ use gpui::{ use language::{Buffer, BufferEvent, LanguageName, Toolchain}; use project::{Project, ProjectPath, WorktreeId, toolchain_store::ToolchainStoreEvent}; use ui::{Button, ButtonCommon, Clickable, FluentBuilder, LabelSize, SharedString, Tooltip}; +use util::maybe; use workspace::{StatusItemView, Workspace, item::ItemHandle}; use crate::ToolchainSelector; @@ -55,49 +56,61 @@ impl ActiveToolchain { } fn spawn_tracker_task(window: &mut Window, cx: &mut Context) -> Task> { cx.spawn_in(window, async move |this, cx| { - let active_file = this - .read_with(cx, |this, _| { - this.active_buffer - .as_ref() - .map(|(_, buffer, _)| buffer.clone()) - }) - .ok() - .flatten()?; - let workspace = this.read_with(cx, |this, _| this.workspace.clone()).ok()?; - let language_name = active_file - .read_with(cx, |this, _| Some(this.language()?.name())) - .ok() - .flatten()?; - let term = workspace - .update(cx, |workspace, cx| { - let languages = workspace.project().read(cx).languages(); - Project::toolchain_term(languages.clone(), language_name.clone()) - }) - .ok()? - .await?; - let _ = this.update(cx, |this, cx| { - this.term = term; - cx.notify(); - }); - let (worktree_id, path) = active_file - .update(cx, |this, cx| { - this.file().and_then(|file| { - Some(( - file.worktree_id(cx), - Arc::::from(file.path().parent()?), - )) + let did_set_toolchain = maybe!(async { + let active_file = this + .read_with(cx, |this, _| { + this.active_buffer + .as_ref() + .map(|(_, buffer, _)| buffer.clone()) }) + .ok() + .flatten()?; + let workspace = this.read_with(cx, |this, _| this.workspace.clone()).ok()?; + let language_name = active_file + .read_with(cx, |this, _| Some(this.language()?.name())) + .ok() + .flatten()?; + let term = workspace + .update(cx, |workspace, cx| { + let languages = workspace.project().read(cx).languages(); + Project::toolchain_term(languages.clone(), language_name.clone()) + }) + .ok()? + .await?; + let _ = this.update(cx, |this, cx| { + this.term = term; + cx.notify(); + }); + let (worktree_id, path) = active_file + .update(cx, |this, cx| { + this.file().and_then(|file| { + Some(( + file.worktree_id(cx), + Arc::::from(file.path().parent()?), + )) + }) + }) + .ok() + .flatten()?; + let toolchain = + Self::active_toolchain(workspace, worktree_id, path, language_name, cx).await?; + this.update(cx, |this, cx| { + this.active_toolchain = Some(toolchain); + + cx.notify(); }) .ok() - .flatten()?; - let toolchain = - Self::active_toolchain(workspace, worktree_id, path, language_name, cx).await?; - let _ = this.update(cx, |this, cx| { - this.active_toolchain = Some(toolchain); - - cx.notify(); - }); - Some(()) + }) + .await + .is_some(); + if !did_set_toolchain { + this.update(cx, |this, cx| { + this.active_toolchain = None; + cx.notify(); + }) + .ok(); + } + did_set_toolchain.then_some(()) }) } @@ -110,6 +123,17 @@ impl ActiveToolchain { let editor = editor.read(cx); if let Some((_, buffer, _)) = editor.active_excerpt(cx) { if let Some(worktree_id) = buffer.read(cx).file().map(|file| file.worktree_id(cx)) { + if self + .active_buffer + .as_ref() + .is_some_and(|(old_worktree_id, old_buffer, _)| { + (old_worktree_id, old_buffer.entity_id()) + == (&worktree_id, buffer.entity_id()) + }) + { + return; + } + let subscription = cx.subscribe_in( &buffer, window, @@ -231,7 +255,6 @@ impl StatusItemView for ActiveToolchain { cx: &mut Context, ) { if let Some(editor) = active_pane_item.and_then(|item| item.downcast::()) { - self.active_toolchain.take(); self.update_lister(editor, window, cx); } cx.notify(); diff --git a/crates/ui/src/components/button/button.rs b/crates/ui/src/components/button/button.rs index 19f782fb98..cee39ac23b 100644 --- a/crates/ui/src/components/button/button.rs +++ b/crates/ui/src/components/button/button.rs @@ -324,7 +324,7 @@ impl FixedWidth for Button { /// ``` /// /// This sets the button's width to be exactly 100 pixels. - fn width(mut self, width: DefiniteLength) -> Self { + fn width(mut self, width: impl Into) -> Self { self.base = self.base.width(width); self } diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 35c78fbb5d..0b30007e44 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -499,8 +499,8 @@ impl Clickable for ButtonLike { } impl FixedWidth for ButtonLike { - fn width(mut self, width: DefiniteLength) -> Self { - self.width = Some(width); + fn width(mut self, width: impl Into) -> Self { + self.width = Some(width.into()); self } diff --git a/crates/ui/src/components/button/icon_button.rs b/crates/ui/src/components/button/icon_button.rs index 8d8718a634..74fc4851fe 100644 --- a/crates/ui/src/components/button/icon_button.rs +++ b/crates/ui/src/components/button/icon_button.rs @@ -133,7 +133,7 @@ impl Clickable for IconButton { } impl FixedWidth for IconButton { - fn width(mut self, width: DefiniteLength) -> Self { + fn width(mut self, width: impl Into) -> Self { self.base = self.base.width(width); self } @@ -194,7 +194,7 @@ impl RenderOnce for IconButton { .map(|this| match self.shape { IconButtonShape::Square => { let size = self.icon_size.square(window, cx); - this.width(size.into()).height(size.into()) + this.width(size).height(size.into()) } IconButtonShape::Wide => this, }) diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index 91defa730b..2a862f4876 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use gpui::{AnyView, ClickEvent}; +use gpui::{AnyView, ClickEvent, relative}; use crate::{ButtonLike, ButtonLikeRounding, ElevationIndex, TintColor, Tooltip, prelude::*}; @@ -73,8 +73,8 @@ impl SelectableButton for ToggleButton { } impl FixedWidth for ToggleButton { - fn width(mut self, width: DefiniteLength) -> Self { - self.base.width = Some(width); + fn width(mut self, width: impl Into) -> Self { + self.base.width = Some(width.into()); self } @@ -429,7 +429,7 @@ where rows: [[T; COLS]; ROWS], style: ToggleButtonGroupStyle, size: ToggleButtonGroupSize, - button_width: Rems, + group_width: Option, selected_index: usize, tab_index: Option, } @@ -441,7 +441,7 @@ impl ToggleButtonGroup { rows: [buttons], style: ToggleButtonGroupStyle::Transparent, size: ToggleButtonGroupSize::Default, - button_width: rems_from_px(100.), + group_width: None, selected_index: 0, tab_index: None, } @@ -455,7 +455,7 @@ impl ToggleButtonGroup { rows: [first_row, second_row], style: ToggleButtonGroupStyle::Transparent, size: ToggleButtonGroupSize::Default, - button_width: rems_from_px(100.), + group_width: None, selected_index: 0, tab_index: None, } @@ -473,11 +473,6 @@ impl ToggleButtonGroup Self { - self.button_width = button_width; - self - } - pub fn selected_index(mut self, index: usize) -> Self { self.selected_index = index; self @@ -491,6 +486,24 @@ impl ToggleButtonGroup DefiniteLength { + relative(1. / COLS as f32) + } +} + +impl FixedWidth + for ToggleButtonGroup +{ + fn width(mut self, width: impl Into) -> Self { + self.group_width = Some(width.into()); + self + } + + fn full_width(mut self) -> Self { + self.group_width = Some(relative(1.)); + self + } } impl RenderOnce @@ -511,6 +524,7 @@ impl RenderOnce let entry_index = row_index * COLS + col_index; ButtonLike::new((self.group_name, entry_index)) + .full_width() .rounding(None) .when_some(self.tab_index, |this, tab_index| { this.tab_index(tab_index + entry_index as isize) @@ -527,7 +541,7 @@ impl RenderOnce }) .child( h_flex() - .min_w(self.button_width) + .w_full() .gap_1p5() .px_3() .py_1() @@ -561,6 +575,13 @@ impl RenderOnce let is_transparent = self.style == ToggleButtonGroupStyle::Transparent; v_flex() + .map(|this| { + if let Some(width) = self.group_width { + this.w(width) + } else { + this.w_full() + } + }) .rounded_md() .overflow_hidden() .map(|this| { @@ -583,6 +604,8 @@ impl RenderOnce .when(is_outlined_or_filled && !last_item, |this| { this.border_r_1().border_color(border_color) }) + .w(Self::button_width()) + .overflow_hidden() .child(item) })) })) @@ -630,7 +653,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .into_any_element(), ), single_example( @@ -656,7 +678,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .into_any_element(), ), single_example( @@ -675,7 +696,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .into_any_element(), ), single_example( @@ -718,7 +738,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .into_any_element(), ), ], @@ -763,7 +782,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Outlined) .into_any_element(), ), @@ -783,7 +801,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Outlined) .into_any_element(), ), @@ -827,7 +844,6 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Outlined) .into_any_element(), ), @@ -873,7 +889,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Filled) .into_any_element(), ), @@ -893,7 +908,7 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) + .width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Filled) .into_any_element(), ), @@ -937,7 +952,7 @@ impl Component ], ) .selected_index(3) - .button_width(rems_from_px(100.)) + .width(rems_from_px(100.)) .style(ToggleButtonGroupStyle::Filled) .into_any_element(), ), @@ -957,7 +972,6 @@ impl Component ], ) .selected_index(1) - .button_width(rems_from_px(100.)) .into_any_element(), )]) .into_any_element(), diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index 59c056859d..e5f28e3b25 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -420,7 +420,7 @@ pub struct Switch { id: ElementId, toggle_state: ToggleState, disabled: bool, - on_click: Option>, + on_click: Option>, label: Option, key_binding: Option, color: SwitchColor, @@ -459,7 +459,7 @@ impl Switch { mut self, handler: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, ) -> Self { - self.on_click = Some(Box::new(handler)); + self.on_click = Some(Rc::new(handler)); self } @@ -513,10 +513,16 @@ impl RenderOnce for Switch { .when_some( self.tab_index.filter(|_| !self.disabled), |this, tab_index| { - this.tab_index(tab_index).focus(|mut style| { - style.border_color = Some(cx.theme().colors().border_focused); - style - }) + this.tab_index(tab_index) + .focus(|mut style| { + style.border_color = Some(cx.theme().colors().border_focused); + style + }) + .when_some(self.on_click.clone(), |this, on_click| { + this.on_click(move |_, window, cx| { + on_click(&self.toggle_state.inverse(), window, cx) + }) + }) }, ) .child( diff --git a/crates/ui/src/traits/fixed.rs b/crates/ui/src/traits/fixed.rs index 9ba64da090..6ca9c8617f 100644 --- a/crates/ui/src/traits/fixed.rs +++ b/crates/ui/src/traits/fixed.rs @@ -3,7 +3,7 @@ use gpui::DefiniteLength; /// A trait for elements that can have a fixed with. Enables the use of the `width` and `full_width` methods. pub trait FixedWidth { /// Sets the width of the element. - fn width(self, width: DefiniteLength) -> Self; + fn width(self, width: impl Into) -> Self; /// Sets the element's width to the full width of its container. fn full_width(self) -> Self; diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 0a2f7ffff6..7fa578f411 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -1,4 +1,8 @@ -use std::cmp; +use globset::{Glob, GlobSet, GlobSetBuilder}; +use itertools::Itertools; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; use std::path::StripPrefixError; use std::sync::{Arc, OnceLock}; use std::{ @@ -7,13 +11,6 @@ use std::{ sync::LazyLock, }; -use globset::{Glob, GlobSet, GlobSetBuilder}; -use itertools::Itertools; -use regex::Regex; -use serde::{Deserialize, Serialize}; - -use crate::NumericPrefixWithSuffix; - /// Returns the path to the user's home directory. pub fn home_dir() -> &'static PathBuf { static HOME_DIR: OnceLock = OnceLock::new(); @@ -568,17 +565,172 @@ impl PathMatcher { } } +/// Custom character comparison that prioritizes lowercase for same letters +fn compare_chars(a: char, b: char) -> Ordering { + // First compare case-insensitive + match a.to_ascii_lowercase().cmp(&b.to_ascii_lowercase()) { + Ordering::Equal => { + // If same letter, prioritize lowercase (lowercase < uppercase) + match (a.is_ascii_lowercase(), b.is_ascii_lowercase()) { + (true, false) => Ordering::Less, // lowercase comes first + (false, true) => Ordering::Greater, // uppercase comes after + _ => Ordering::Equal, // both same case or both non-ascii + } + } + other => other, + } +} + +/// Compares two sequences of consecutive digits for natural sorting. +/// +/// This function is a core component of natural sorting that handles numeric comparison +/// in a way that feels natural to humans. It extracts and compares consecutive digit +/// sequences from two iterators, handling various cases like leading zeros and very large numbers. +/// +/// # Behavior +/// +/// The function implements the following comparison rules: +/// 1. Different numeric values: Compares by actual numeric value (e.g., "2" < "10") +/// 2. Leading zeros: When values are equal, longer sequence wins (e.g., "002" > "2") +/// 3. Large numbers: Falls back to string comparison for numbers that would overflow u128 +/// +/// # Examples +/// +/// ```text +/// "1" vs "2" -> Less (different values) +/// "2" vs "10" -> Less (numeric comparison) +/// "002" vs "2" -> Greater (leading zeros) +/// "10" vs "010" -> Less (leading zeros) +/// "999..." vs "1000..." -> Less (large number comparison) +/// ``` +/// +/// # Implementation Details +/// +/// 1. Extracts consecutive digits into strings +/// 2. Compares sequence lengths for leading zero handling +/// 3. For equal lengths, compares digit by digit +/// 4. For different lengths: +/// - Attempts numeric comparison first (for numbers up to 2^128 - 1) +/// - Falls back to string comparison if numbers would overflow +/// +/// The function advances both iterators past their respective numeric sequences, +/// regardless of the comparison result. +fn compare_numeric_segments( + a_iter: &mut std::iter::Peekable, + b_iter: &mut std::iter::Peekable, +) -> Ordering +where + I: Iterator, +{ + // Collect all consecutive digits into strings + let mut a_num_str = String::new(); + let mut b_num_str = String::new(); + + while let Some(&c) = a_iter.peek() { + if !c.is_ascii_digit() { + break; + } + + a_num_str.push(c); + a_iter.next(); + } + + while let Some(&c) = b_iter.peek() { + if !c.is_ascii_digit() { + break; + } + + b_num_str.push(c); + b_iter.next(); + } + + // First compare lengths (handle leading zeros) + match a_num_str.len().cmp(&b_num_str.len()) { + Ordering::Equal => { + // Same length, compare digit by digit + match a_num_str.cmp(&b_num_str) { + Ordering::Equal => Ordering::Equal, + ordering => ordering, + } + } + + // Different lengths but same value means leading zeros + ordering => { + // Try parsing as numbers first + if let (Ok(a_val), Ok(b_val)) = (a_num_str.parse::(), b_num_str.parse::()) { + match a_val.cmp(&b_val) { + Ordering::Equal => ordering, // Same value, longer one is greater (leading zeros) + ord => ord, + } + } else { + // If parsing fails (overflow), compare as strings + a_num_str.cmp(&b_num_str) + } + } + } +} + +/// Performs natural sorting comparison between two strings. +/// +/// Natural sorting is an ordering that handles numeric sequences in a way that matches human expectations. +/// For example, "file2" comes before "file10" (unlike standard lexicographic sorting). +/// +/// # Characteristics +/// +/// * Case-sensitive with lowercase priority: When comparing same letters, lowercase comes before uppercase +/// * Numbers are compared by numeric value, not character by character +/// * Leading zeros affect ordering when numeric values are equal +/// * Can handle numbers larger than u128::MAX (falls back to string comparison) +/// +/// # Algorithm +/// +/// The function works by: +/// 1. Processing strings character by character +/// 2. When encountering digits, treating consecutive digits as a single number +/// 3. Comparing numbers by their numeric value rather than lexicographically +/// 4. For non-numeric characters, using case-sensitive comparison with lowercase priority +fn natural_sort(a: &str, b: &str) -> Ordering { + let mut a_iter = a.chars().peekable(); + let mut b_iter = b.chars().peekable(); + + loop { + match (a_iter.peek(), b_iter.peek()) { + (None, None) => return Ordering::Equal, + (None, _) => return Ordering::Less, + (_, None) => return Ordering::Greater, + (Some(&a_char), Some(&b_char)) => { + if a_char.is_ascii_digit() && b_char.is_ascii_digit() { + match compare_numeric_segments(&mut a_iter, &mut b_iter) { + Ordering::Equal => continue, + ordering => return ordering, + } + } else { + match compare_chars(a_char, b_char) { + Ordering::Equal => { + a_iter.next(); + b_iter.next(); + } + ordering => return ordering, + } + } + } + } + } +} + pub fn compare_paths( (path_a, a_is_file): (&Path, bool), (path_b, b_is_file): (&Path, bool), -) -> cmp::Ordering { +) -> Ordering { let mut components_a = path_a.components().peekable(); let mut components_b = path_b.components().peekable(); + loop { match (components_a.next(), components_b.next()) { (Some(component_a), Some(component_b)) => { let a_is_file = components_a.peek().is_none() && a_is_file; let b_is_file = components_b.peek().is_none() && b_is_file; + let ordering = a_is_file.cmp(&b_is_file).then_with(|| { let path_a = Path::new(component_a.as_os_str()); let path_string_a = if a_is_file { @@ -587,9 +739,6 @@ pub fn compare_paths( path_a.file_name() } .map(|s| s.to_string_lossy()); - let num_and_remainder_a = path_string_a - .as_deref() - .map(NumericPrefixWithSuffix::from_numeric_prefixed_str); let path_b = Path::new(component_b.as_os_str()); let path_string_b = if b_is_file { @@ -598,27 +747,32 @@ pub fn compare_paths( path_b.file_name() } .map(|s| s.to_string_lossy()); - let num_and_remainder_b = path_string_b - .as_deref() - .map(NumericPrefixWithSuffix::from_numeric_prefixed_str); - num_and_remainder_a.cmp(&num_and_remainder_b).then_with(|| { + let compare_components = match (path_string_a, path_string_b) { + (Some(a), Some(b)) => natural_sort(&a, &b), + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => Ordering::Equal, + }; + + compare_components.then_with(|| { if a_is_file && b_is_file { let ext_a = path_a.extension().unwrap_or_default(); let ext_b = path_b.extension().unwrap_or_default(); ext_a.cmp(ext_b) } else { - cmp::Ordering::Equal + Ordering::Equal } }) }); + if !ordering.is_eq() { return ordering; } } - (Some(_), None) => break cmp::Ordering::Greater, - (None, Some(_)) => break cmp::Ordering::Less, - (None, None) => break cmp::Ordering::Equal, + (Some(_), None) => break Ordering::Greater, + (None, Some(_)) => break Ordering::Less, + (None, None) => break Ordering::Equal, } } } @@ -1091,4 +1245,335 @@ mod tests { "C:\\Users\\someone\\test_file.rs" ); } + + #[test] + fn test_compare_numeric_segments() { + // Helper function to create peekable iterators and test + fn compare(a: &str, b: &str) -> Ordering { + let mut a_iter = a.chars().peekable(); + let mut b_iter = b.chars().peekable(); + + let result = compare_numeric_segments(&mut a_iter, &mut b_iter); + + // Verify iterators advanced correctly + assert!( + !a_iter.next().map_or(false, |c| c.is_ascii_digit()), + "Iterator a should have consumed all digits" + ); + assert!( + !b_iter.next().map_or(false, |c| c.is_ascii_digit()), + "Iterator b should have consumed all digits" + ); + + result + } + + // Basic numeric comparisons + assert_eq!(compare("0", "0"), Ordering::Equal); + assert_eq!(compare("1", "2"), Ordering::Less); + assert_eq!(compare("9", "10"), Ordering::Less); + assert_eq!(compare("10", "9"), Ordering::Greater); + assert_eq!(compare("99", "100"), Ordering::Less); + + // Leading zeros + assert_eq!(compare("0", "00"), Ordering::Less); + assert_eq!(compare("00", "0"), Ordering::Greater); + assert_eq!(compare("01", "1"), Ordering::Greater); + assert_eq!(compare("001", "1"), Ordering::Greater); + assert_eq!(compare("001", "01"), Ordering::Greater); + + // Same value different representation + assert_eq!(compare("000100", "100"), Ordering::Greater); + assert_eq!(compare("100", "0100"), Ordering::Less); + assert_eq!(compare("0100", "00100"), Ordering::Less); + + // Large numbers + assert_eq!(compare("9999999999", "10000000000"), Ordering::Less); + assert_eq!( + compare( + "340282366920938463463374607431768211455", // u128::MAX + "340282366920938463463374607431768211456" + ), + Ordering::Less + ); + assert_eq!( + compare( + "340282366920938463463374607431768211456", // > u128::MAX + "340282366920938463463374607431768211455" + ), + Ordering::Greater + ); + + // Iterator advancement verification + let mut a_iter = "123abc".chars().peekable(); + let mut b_iter = "456def".chars().peekable(); + + compare_numeric_segments(&mut a_iter, &mut b_iter); + + assert_eq!(a_iter.collect::(), "abc"); + assert_eq!(b_iter.collect::(), "def"); + } + + #[test] + fn test_natural_sort() { + // Basic alphanumeric + assert_eq!(natural_sort("a", "b"), Ordering::Less); + assert_eq!(natural_sort("b", "a"), Ordering::Greater); + assert_eq!(natural_sort("a", "a"), Ordering::Equal); + + // Case sensitivity + assert_eq!(natural_sort("a", "A"), Ordering::Less); + assert_eq!(natural_sort("A", "a"), Ordering::Greater); + assert_eq!(natural_sort("aA", "aa"), Ordering::Greater); + assert_eq!(natural_sort("aa", "aA"), Ordering::Less); + + // Numbers + assert_eq!(natural_sort("1", "2"), Ordering::Less); + assert_eq!(natural_sort("2", "10"), Ordering::Less); + assert_eq!(natural_sort("02", "10"), Ordering::Less); + assert_eq!(natural_sort("02", "2"), Ordering::Greater); + + // Mixed alphanumeric + assert_eq!(natural_sort("a1", "a2"), Ordering::Less); + assert_eq!(natural_sort("a2", "a10"), Ordering::Less); + assert_eq!(natural_sort("a02", "a2"), Ordering::Greater); + assert_eq!(natural_sort("a1b", "a1c"), Ordering::Less); + + // Multiple numeric segments + assert_eq!(natural_sort("1a2", "1a10"), Ordering::Less); + assert_eq!(natural_sort("1a10", "1a2"), Ordering::Greater); + assert_eq!(natural_sort("2a1", "10a1"), Ordering::Less); + + // Special characters + assert_eq!(natural_sort("a-1", "a-2"), Ordering::Less); + assert_eq!(natural_sort("a_1", "a_2"), Ordering::Less); + assert_eq!(natural_sort("a.1", "a.2"), Ordering::Less); + + // Unicode + assert_eq!(natural_sort("文1", "文2"), Ordering::Less); + assert_eq!(natural_sort("文2", "文10"), Ordering::Less); + assert_eq!(natural_sort("🔤1", "🔤2"), Ordering::Less); + + // Empty and special cases + assert_eq!(natural_sort("", ""), Ordering::Equal); + assert_eq!(natural_sort("", "a"), Ordering::Less); + assert_eq!(natural_sort("a", ""), Ordering::Greater); + assert_eq!(natural_sort(" ", " "), Ordering::Less); + + // Mixed everything + assert_eq!(natural_sort("File-1.txt", "File-2.txt"), Ordering::Less); + assert_eq!(natural_sort("File-02.txt", "File-2.txt"), Ordering::Greater); + assert_eq!(natural_sort("File-2.txt", "File-10.txt"), Ordering::Less); + assert_eq!(natural_sort("File_A1", "File_A2"), Ordering::Less); + assert_eq!(natural_sort("File_a1", "File_A1"), Ordering::Less); + } + + #[test] + fn test_compare_paths() { + // Helper function for cleaner tests + fn compare(a: &str, is_a_file: bool, b: &str, is_b_file: bool) -> Ordering { + compare_paths((Path::new(a), is_a_file), (Path::new(b), is_b_file)) + } + + // Basic path comparison + assert_eq!(compare("a", true, "b", true), Ordering::Less); + assert_eq!(compare("b", true, "a", true), Ordering::Greater); + assert_eq!(compare("a", true, "a", true), Ordering::Equal); + + // Files vs Directories + assert_eq!(compare("a", true, "a", false), Ordering::Greater); + assert_eq!(compare("a", false, "a", true), Ordering::Less); + assert_eq!(compare("b", false, "a", true), Ordering::Less); + + // Extensions + assert_eq!(compare("a.txt", true, "a.md", true), Ordering::Greater); + assert_eq!(compare("a.md", true, "a.txt", true), Ordering::Less); + assert_eq!(compare("a", true, "a.txt", true), Ordering::Less); + + // Nested paths + assert_eq!(compare("dir/a", true, "dir/b", true), Ordering::Less); + assert_eq!(compare("dir1/a", true, "dir2/a", true), Ordering::Less); + assert_eq!(compare("dir/sub/a", true, "dir/a", true), Ordering::Less); + + // Case sensitivity in paths + assert_eq!( + compare("Dir/file", true, "dir/file", true), + Ordering::Greater + ); + assert_eq!( + compare("dir/File", true, "dir/file", true), + Ordering::Greater + ); + assert_eq!(compare("dir/file", true, "Dir/File", true), Ordering::Less); + + // Hidden files and special names + assert_eq!(compare(".hidden", true, "visible", true), Ordering::Less); + assert_eq!(compare("_special", true, "normal", true), Ordering::Less); + assert_eq!(compare(".config", false, ".data", false), Ordering::Less); + + // Mixed numeric paths + assert_eq!( + compare("dir1/file", true, "dir2/file", true), + Ordering::Less + ); + assert_eq!( + compare("dir2/file", true, "dir10/file", true), + Ordering::Less + ); + assert_eq!( + compare("dir02/file", true, "dir2/file", true), + Ordering::Greater + ); + + // Root paths + assert_eq!(compare("/a", true, "/b", true), Ordering::Less); + assert_eq!(compare("/", false, "/a", true), Ordering::Less); + + // Complex real-world examples + assert_eq!( + compare("project/src/main.rs", true, "project/src/lib.rs", true), + Ordering::Greater + ); + assert_eq!( + compare( + "project/tests/test_1.rs", + true, + "project/tests/test_2.rs", + true + ), + Ordering::Less + ); + assert_eq!( + compare( + "project/v1.0.0/README.md", + true, + "project/v1.10.0/README.md", + true + ), + Ordering::Less + ); + } + + #[test] + fn test_natural_sort_case_sensitivity() { + // Same letter different case - lowercase should come first + assert_eq!(natural_sort("a", "A"), Ordering::Less); + assert_eq!(natural_sort("A", "a"), Ordering::Greater); + assert_eq!(natural_sort("a", "a"), Ordering::Equal); + assert_eq!(natural_sort("A", "A"), Ordering::Equal); + + // Mixed case strings + assert_eq!(natural_sort("aaa", "AAA"), Ordering::Less); + assert_eq!(natural_sort("AAA", "aaa"), Ordering::Greater); + assert_eq!(natural_sort("aAa", "AaA"), Ordering::Less); + + // Different letters + assert_eq!(natural_sort("a", "b"), Ordering::Less); + assert_eq!(natural_sort("A", "b"), Ordering::Less); + assert_eq!(natural_sort("a", "B"), Ordering::Less); + } + + #[test] + fn test_natural_sort_with_numbers() { + // Basic number ordering + assert_eq!(natural_sort("file1", "file2"), Ordering::Less); + assert_eq!(natural_sort("file2", "file10"), Ordering::Less); + assert_eq!(natural_sort("file10", "file2"), Ordering::Greater); + + // Numbers in different positions + assert_eq!(natural_sort("1file", "2file"), Ordering::Less); + assert_eq!(natural_sort("file1text", "file2text"), Ordering::Less); + assert_eq!(natural_sort("text1file", "text2file"), Ordering::Less); + + // Multiple numbers in string + assert_eq!(natural_sort("file1-2", "file1-10"), Ordering::Less); + assert_eq!(natural_sort("2-1file", "10-1file"), Ordering::Less); + + // Leading zeros + assert_eq!(natural_sort("file002", "file2"), Ordering::Greater); + assert_eq!(natural_sort("file002", "file10"), Ordering::Less); + + // Very large numbers + assert_eq!( + natural_sort("file999999999999999999999", "file999999999999999999998"), + Ordering::Greater + ); + + // u128 edge cases + + // Numbers near u128::MAX (340,282,366,920,938,463,463,374,607,431,768,211,455) + assert_eq!( + natural_sort( + "file340282366920938463463374607431768211454", + "file340282366920938463463374607431768211455" + ), + Ordering::Less + ); + + // Equal length numbers that overflow u128 + assert_eq!( + natural_sort( + "file340282366920938463463374607431768211456", + "file340282366920938463463374607431768211455" + ), + Ordering::Greater + ); + + // Different length numbers that overflow u128 + assert_eq!( + natural_sort( + "file3402823669209384634633746074317682114560", + "file340282366920938463463374607431768211455" + ), + Ordering::Greater + ); + + // Leading zeros with numbers near u128::MAX + assert_eq!( + natural_sort( + "file0340282366920938463463374607431768211455", + "file340282366920938463463374607431768211455" + ), + Ordering::Greater + ); + + // Very large numbers with different lengths (both overflow u128) + assert_eq!( + natural_sort( + "file999999999999999999999999999999999999999999999999", + "file9999999999999999999999999999999999999999999999999" + ), + Ordering::Less + ); + + // Mixed case with numbers + assert_eq!(natural_sort("File1", "file2"), Ordering::Greater); + assert_eq!(natural_sort("file1", "File2"), Ordering::Less); + } + + #[test] + fn test_natural_sort_edge_cases() { + // Empty strings + assert_eq!(natural_sort("", ""), Ordering::Equal); + assert_eq!(natural_sort("", "a"), Ordering::Less); + assert_eq!(natural_sort("a", ""), Ordering::Greater); + + // Special characters + assert_eq!(natural_sort("file-1", "file_1"), Ordering::Less); + assert_eq!(natural_sort("file.1", "file_1"), Ordering::Less); + assert_eq!(natural_sort("file 1", "file_1"), Ordering::Less); + + // Unicode characters + // 9312 vs 9313 + assert_eq!(natural_sort("file①", "file②"), Ordering::Less); + // 9321 vs 9313 + assert_eq!(natural_sort("file⑩", "file②"), Ordering::Greater); + // 28450 vs 23383 + assert_eq!(natural_sort("file漢", "file字"), Ordering::Greater); + + // Mixed alphanumeric with special chars + assert_eq!(natural_sort("file-1a", "file-1b"), Ordering::Less); + assert_eq!(natural_sort("file-1.2", "file-1.10"), Ordering::Less); + assert_eq!(natural_sort("file-1.10", "file-1.2"), Ordering::Greater); + } } diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index 932b519b18..e1b25f4dba 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -669,9 +669,12 @@ where let file = caller.file(); #[cfg(target_os = "windows")] let file = caller.file().replace('\\', "/"); - // In this codebase, the first segment of the file path is - // the 'crates' folder, followed by the crate name. - let target = file.split('/').nth(1); + // In this codebase all crates reside in a `crates` directory, + // so discard the prefix up to that segment to find the crate name + let target = file + .split_once("crates/") + .and_then(|(_, s)| s.split_once('/')) + .map(|(p, _)| p); log::logger().log( &log::Record::builder() @@ -884,10 +887,10 @@ macro_rules! maybe { (|| $block)() }; (async $block:block) => { - (|| async $block)() + (async || $block)() }; (async move $block:block) => { - (|| async move $block)() + (async move || $block)() }; } diff --git a/crates/vercel/src/vercel.rs b/crates/vercel/src/vercel.rs index 1ae22c5fef..8686fda53f 100644 --- a/crates/vercel/src/vercel.rs +++ b/crates/vercel/src/vercel.rs @@ -71,4 +71,8 @@ impl Model { Model::Custom { .. } => false, } } + + pub fn supports_prompt_cache_key(&self) -> bool { + false + } } diff --git a/crates/vim/Cargo.toml b/crates/vim/Cargo.toml index 9fb5c46564..434b14b07c 100644 --- a/crates/vim/Cargo.toml +++ b/crates/vim/Cargo.toml @@ -24,6 +24,7 @@ command_palette.workspace = true command_palette_hooks.workspace = true db.workspace = true editor.workspace = true +env_logger.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true diff --git a/crates/vim/src/change_list.rs b/crates/vim/src/change_list.rs index a59083f7ab..c92ce4720e 100644 --- a/crates/vim/src/change_list.rs +++ b/crates/vim/src/change_list.rs @@ -31,7 +31,7 @@ impl Vim { ) { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { if let Some(selections) = editor .change_list .next_change(count, direction) @@ -49,7 +49,7 @@ impl Vim { } pub(crate) fn push_to_change_list(&mut self, window: &mut Window, cx: &mut Context) { - let Some((new_positions, buffer)) = self.update_editor(window, cx, |vim, editor, _, cx| { + let Some((new_positions, buffer)) = self.update_editor(cx, |vim, editor, cx| { let (map, selections) = editor.selections.all_adjusted_display(cx); let buffer = editor.buffer().clone(); diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 7963db3571..ce5e5a0300 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -241,9 +241,9 @@ impl Deref for WrappedAction { pub fn register(editor: &mut Editor, cx: &mut Context) { // Vim::action(editor, cx, |vim, action: &StartOfLine, window, cx| { - Vim::action(editor, cx, |vim, action: &VimSet, window, cx| { + Vim::action(editor, cx, |vim, action: &VimSet, _, cx| { for option in action.options.iter() { - vim.update_editor(window, cx, |_, editor, _, cx| match option { + vim.update_editor(cx, |_, editor, cx| match option { VimOption::Wrap(true) => { editor .set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx); @@ -298,8 +298,8 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, action: &VimSave, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { - let Some(project) = editor.project.clone() else { + vim.update_editor(cx, |_, editor, cx| { + let Some(project) = editor.project().cloned() else { return; }; let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else { @@ -375,7 +375,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { cx, ); } - vim.update_editor(window, cx, |vim, editor, window, cx| match action { + vim.update_editor(cx, |vim, editor, cx| match action { DeleteMarks::Marks(s) => { if s.starts_with('-') || s.ends_with('-') || s.contains(['\'', '`']) { err(s.clone(), window, cx); @@ -432,11 +432,11 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, action: &VimEdit, window, cx| { - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { let Some(workspace) = vim.workspace(window) else { return; }; - let Some(project) = editor.project.clone() else { + let Some(project) = editor.project().cloned() else { return; }; let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else { @@ -462,11 +462,10 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { .map(|c| Keystroke::parse(&c.to_string()).unwrap()) .collect(); vim.switch_mode(Mode::Normal, true, window, cx); - let initial_selections = vim.update_editor(window, cx, |_, editor, _, _| { - editor.selections.disjoint_anchors() - }); + let initial_selections = + vim.update_editor(cx, |_, editor, _| editor.selections.disjoint_anchors()); if let Some(range) = &action.range { - let result = vim.update_editor(window, cx, |vim, editor, window, cx| { + let result = vim.update_editor(cx, |vim, editor, cx| { let range = range.buffer_range(vim, editor, window, cx)?; editor.change_selections( SelectionEffects::no_scroll().nav_history(false), @@ -498,7 +497,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { cx.spawn_in(window, async move |vim, cx| { task.await; vim.update_in(cx, |vim, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { if had_range { editor.change_selections(SelectionEffects::default(), window, cx, |s| { s.select_anchor_ranges([s.newest_anchor().range()]); @@ -510,7 +509,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { } else { vim.switch_mode(Mode::Normal, true, window, cx); } - vim.update_editor(window, cx, |_, editor, _, cx| { + vim.update_editor(cx, |_, editor, cx| { if let Some(first_sel) = initial_selections { if let Some(tx_id) = editor .buffer() @@ -548,7 +547,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, action: &GoToLine, window, cx| { vim.switch_mode(Mode::Normal, false, window, cx); - let result = vim.update_editor(window, cx, |vim, editor, window, cx| { + let result = vim.update_editor(cx, |vim, editor, cx| { let snapshot = editor.snapshot(window, cx); let buffer_row = action.range.head().buffer_row(vim, editor, window, cx)?; let current = editor.selections.newest::(cx); @@ -573,7 +572,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, action: &YankCommand, window, cx| { - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { let snapshot = editor.snapshot(window, cx); if let Ok(range) = action.range.buffer_range(vim, editor, window, cx) { let end = if range.end < snapshot.buffer_snapshot.max_row() { @@ -600,7 +599,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, action: &WithRange, window, cx| { - let result = vim.update_editor(window, cx, |vim, editor, window, cx| { + let result = vim.update_editor(cx, |vim, editor, cx| { action.range.buffer_range(vim, editor, window, cx) }); @@ -619,7 +618,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { }; let previous_selections = vim - .update_editor(window, cx, |_, editor, window, cx| { + .update_editor(cx, |_, editor, cx| { let selections = action.restore_selection.then(|| { editor .selections @@ -635,7 +634,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { .flatten(); window.dispatch_action(action.action.boxed_clone(), cx); cx.defer_in(window, move |vim, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { if let Some(previous_selections) = previous_selections { s.select_ranges(previous_selections); @@ -1176,8 +1175,10 @@ fn generate_commands(_: &App) -> Vec { VimCommand::str(("ls", ""), "tab_switcher::ToggleAll"), VimCommand::new(("new", ""), workspace::NewFileSplitHorizontal), VimCommand::new(("vne", "w"), workspace::NewFileSplitVertical), - VimCommand::new(("tabe", "dit"), workspace::NewFile), - VimCommand::new(("tabnew", ""), workspace::NewFile), + VimCommand::new(("tabe", "dit"), workspace::NewFile) + .args(|_action, args| Some(VimEdit { filename: args }.boxed_clone())), + VimCommand::new(("tabnew", ""), workspace::NewFile) + .args(|_action, args| Some(VimEdit { filename: args }.boxed_clone())), VimCommand::new(("tabn", "ext"), workspace::ActivateNextItem).count(), VimCommand::new(("tabp", "revious"), workspace::ActivatePreviousItem).count(), VimCommand::new(("tabN", "ext"), workspace::ActivatePreviousItem).count(), @@ -1536,7 +1537,7 @@ impl OnMatchingLines { } pub fn run(&self, vim: &mut Vim, window: &mut Window, cx: &mut Context) { - let result = vim.update_editor(window, cx, |vim, editor, window, cx| { + let result = vim.update_editor(cx, |vim, editor, cx| { self.range.buffer_range(vim, editor, window, cx) }); @@ -1600,7 +1601,7 @@ impl OnMatchingLines { }); }; - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { let snapshot = editor.snapshot(window, cx); let mut row = range.start.0; @@ -1680,7 +1681,7 @@ pub struct ShellExec { impl Vim { pub fn cancel_running_command(&mut self, window: &mut Window, cx: &mut Context) { if self.running_command.take().is_some() { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, _window, _cx| { editor.clear_row_highlights::(); }) @@ -1691,7 +1692,7 @@ impl Vim { fn prepare_shell_command( &mut self, command: &str, - window: &mut Window, + _: &mut Window, cx: &mut Context, ) -> String { let mut ret = String::new(); @@ -1711,7 +1712,7 @@ impl Vim { } match c { '%' => { - self.update_editor(window, cx, |_, editor, _window, cx| { + self.update_editor(cx, |_, editor, cx| { if let Some((_, buffer, _)) = editor.active_excerpt(cx) { if let Some(file) = buffer.read(cx).file() { if let Some(local) = file.as_local() { @@ -1747,7 +1748,7 @@ impl Vim { let Some(workspace) = self.workspace(window) else { return; }; - let command = self.update_editor(window, cx, |_, editor, window, cx| { + let command = self.update_editor(cx, |_, editor, cx| { let snapshot = editor.snapshot(window, cx); let start = editor.selections.newest_display(cx); let text_layout_details = editor.text_layout_details(window); @@ -1794,7 +1795,7 @@ impl Vim { let Some(workspace) = self.workspace(window) else { return; }; - let command = self.update_editor(window, cx, |_, editor, window, cx| { + let command = self.update_editor(cx, |_, editor, cx| { let snapshot = editor.snapshot(window, cx); let start = editor.selections.newest_display(cx); let range = object @@ -1896,7 +1897,7 @@ impl ShellExec { let mut input_snapshot = None; let mut input_range = None; let mut needs_newline_prefix = false; - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); let range = if let Some(range) = self.range.clone() { let Some(range) = range.buffer_range(vim, editor, window, cx).log_err() else { @@ -1990,7 +1991,7 @@ impl ShellExec { } vim.update_in(cx, |vim, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.edit([(range.clone(), text)], cx); let snapshot = editor.buffer().read(cx).snapshot(cx); @@ -2477,4 +2478,110 @@ mod test { "}); // Once ctrl-v to input character literals is added there should be a test for redo } + + #[gpui::test] + async fn test_command_tabnew(cx: &mut TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + + // Create a new file to ensure that, when the filename is used with + // `:tabnew`, it opens the existing file in a new tab. + let fs = cx.workspace(|workspace, _, cx| workspace.project().read(cx).fs().clone()); + fs.as_fake() + .insert_file(path!("/root/dir/file_2.rs"), "file_2".as_bytes().to_vec()) + .await; + + cx.simulate_keystrokes(": tabnew"); + cx.simulate_keystrokes("enter"); + cx.workspace(|workspace, _, cx| assert_eq!(workspace.items(cx).count(), 2)); + + // Assert that the new tab is empty and not associated with any file, as + // no file path was provided to the `:tabnew` command. + cx.workspace(|workspace, _window, cx| { + let active_editor = workspace.active_item_as::(cx).unwrap(); + let buffer = active_editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .unwrap(); + + assert!(&buffer.read(cx).file().is_none()); + }); + + // Leverage the filename as an argument to the `:tabnew` command, + // ensuring that the file, instead of an empty buffer, is opened in a + // new tab. + cx.simulate_keystrokes(": tabnew space dir/file_2.rs"); + cx.simulate_keystrokes("enter"); + + cx.workspace(|workspace, _, cx| assert_eq!(workspace.items(cx).count(), 3)); + cx.workspace(|workspace, _, cx| { + assert_active_item(workspace, path!("/root/dir/file_2.rs"), "file_2", cx); + }); + + // If the `filename` argument provided to the `:tabnew` command is for a + // file that doesn't yet exist, it should still associate the buffer + // with that file path, so that when the buffer contents are saved, the + // file is created. + cx.simulate_keystrokes(": tabnew space dir/file_3.rs"); + cx.simulate_keystrokes("enter"); + + cx.workspace(|workspace, _, cx| assert_eq!(workspace.items(cx).count(), 4)); + cx.workspace(|workspace, _, cx| { + assert_active_item(workspace, path!("/root/dir/file_3.rs"), "", cx); + }); + } + + #[gpui::test] + async fn test_command_tabedit(cx: &mut TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + + // Create a new file to ensure that, when the filename is used with + // `:tabedit`, it opens the existing file in a new tab. + let fs = cx.workspace(|workspace, _, cx| workspace.project().read(cx).fs().clone()); + fs.as_fake() + .insert_file(path!("/root/dir/file_2.rs"), "file_2".as_bytes().to_vec()) + .await; + + cx.simulate_keystrokes(": tabedit"); + cx.simulate_keystrokes("enter"); + cx.workspace(|workspace, _, cx| assert_eq!(workspace.items(cx).count(), 2)); + + // Assert that the new tab is empty and not associated with any file, as + // no file path was provided to the `:tabedit` command. + cx.workspace(|workspace, _window, cx| { + let active_editor = workspace.active_item_as::(cx).unwrap(); + let buffer = active_editor + .read(cx) + .buffer() + .read(cx) + .as_singleton() + .unwrap(); + + assert!(&buffer.read(cx).file().is_none()); + }); + + // Leverage the filename as an argument to the `:tabedit` command, + // ensuring that the file, instead of an empty buffer, is opened in a + // new tab. + cx.simulate_keystrokes(": tabedit space dir/file_2.rs"); + cx.simulate_keystrokes("enter"); + + cx.workspace(|workspace, _, cx| assert_eq!(workspace.items(cx).count(), 3)); + cx.workspace(|workspace, _, cx| { + assert_active_item(workspace, path!("/root/dir/file_2.rs"), "file_2", cx); + }); + + // If the `filename` argument provided to the `:tabedit` command is for a + // file that doesn't yet exist, it should still associate the buffer + // with that file path, so that when the buffer contents are saved, the + // file is created. + cx.simulate_keystrokes(": tabedit space dir/file_3.rs"); + cx.simulate_keystrokes("enter"); + + cx.workspace(|workspace, _, cx| assert_eq!(workspace.items(cx).count(), 4)); + cx.workspace(|workspace, _, cx| { + assert_active_item(workspace, path!("/root/dir/file_3.rs"), "", cx); + }); + } } diff --git a/crates/vim/src/digraph.rs b/crates/vim/src/digraph.rs index 881454392a..c555b781b1 100644 --- a/crates/vim/src/digraph.rs +++ b/crates/vim/src/digraph.rs @@ -56,9 +56,7 @@ impl Vim { self.pop_operator(window, cx); if self.editor_input_enabled() { - self.update_editor(window, cx, |_, editor, window, cx| { - editor.insert(&text, window, cx) - }); + self.update_editor(cx, |_, editor, cx| editor.insert(&text, window, cx)); } else { self.input_ignored(text, window, cx); } @@ -214,9 +212,7 @@ impl Vim { text.push_str(suffix); if self.editor_input_enabled() { - self.update_editor(window, cx, |_, editor, window, cx| { - editor.insert(&text, window, cx) - }); + self.update_editor(cx, |_, editor, cx| editor.insert(&text, window, cx)); } else { self.input_ignored(text.into(), window, cx); } diff --git a/crates/vim/src/helix.rs b/crates/vim/src/helix.rs index ca93c9c1de..0c8c06d8ab 100644 --- a/crates/vim/src/helix.rs +++ b/crates/vim/src/helix.rs @@ -1,9 +1,11 @@ +use editor::display_map::DisplaySnapshot; use editor::{DisplayPoint, Editor, SelectionEffects, ToOffset, ToPoint, movement}; use gpui::{Action, actions}; use gpui::{Context, Window}; use language::{CharClassifier, CharKind}; use text::{Bias, SelectionGoal}; +use crate::motion; use crate::{ Vim, motion::{Motion, right}, @@ -15,6 +17,8 @@ actions!( [ /// Switches to normal mode after the cursor (Helix-style). HelixNormalAfter, + /// Yanks the current selection or character if no selection. + HelixYank, /// Inserts at the beginning of the selection. HelixInsert, /// Appends at the end of the selection. @@ -26,6 +30,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, Vim::helix_normal_after); Vim::action(editor, cx, Vim::helix_insert); Vim::action(editor, cx, Vim::helix_append); + Vim::action(editor, cx, Vim::helix_yank); } impl Vim { @@ -55,6 +60,35 @@ impl Vim { self.helix_move_cursor(motion, times, window, cx); } + /// Updates all selections based on where the cursors are. + fn helix_new_selections( + &mut self, + window: &mut Window, + cx: &mut Context, + mut change: impl FnMut( + // the start of the cursor + DisplayPoint, + &DisplaySnapshot, + ) -> Option<(DisplayPoint, DisplayPoint)>, + ) { + self.update_editor(cx, |_, editor, cx| { + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|map, selection| { + let cursor_start = if selection.reversed || selection.is_empty() { + selection.head() + } else { + movement::left(map, selection.head()) + }; + let Some((head, tail)) = change(cursor_start, map) else { + return; + }; + + selection.set_head_tail(head, tail, SelectionGoal::None); + }); + }); + }); + } + fn helix_find_range_forward( &mut self, times: Option, @@ -62,49 +96,30 @@ impl Vim { cx: &mut Context, mut is_boundary: impl FnMut(char, char, &CharClassifier) -> bool, ) { - self.update_editor(window, cx, |_, editor, window, cx| { - editor.change_selections(Default::default(), window, cx, |s| { - s.move_with(|map, selection| { - let times = times.unwrap_or(1); - let new_goal = SelectionGoal::None; - let mut head = selection.head(); - let mut tail = selection.tail(); + let times = times.unwrap_or(1); + self.helix_new_selections(window, cx, |cursor, map| { + let mut head = movement::right(map, cursor); + let mut tail = cursor; + let classifier = map.buffer_snapshot.char_classifier_at(head.to_point(map)); + if head == map.max_point() { + return None; + } + for _ in 0..times { + let (maybe_next_tail, next_head) = + movement::find_boundary_trail(map, head, |left, right| { + is_boundary(left, right, &classifier) + }); - if head == map.max_point() { - return; - } + if next_head == head && maybe_next_tail.unwrap_or(next_head) == tail { + break; + } - // collapse to block cursor - if tail < head { - tail = movement::left(map, head); - } else { - tail = head; - head = movement::right(map, head); - } - - // create a classifier - let classifier = map.buffer_snapshot.char_classifier_at(head.to_point(map)); - - for _ in 0..times { - let (maybe_next_tail, next_head) = - movement::find_boundary_trail(map, head, |left, right| { - is_boundary(left, right, &classifier) - }); - - if next_head == head && maybe_next_tail.unwrap_or(next_head) == tail { - break; - } - - head = next_head; - if let Some(next_tail) = maybe_next_tail { - tail = next_tail; - } - } - - selection.set_tail(tail, new_goal); - selection.set_head(head, new_goal); - }); - }); + head = next_head; + if let Some(next_tail) = maybe_next_tail { + tail = next_tail; + } + } + Some((head, tail)) }); } @@ -115,56 +130,33 @@ impl Vim { cx: &mut Context, mut is_boundary: impl FnMut(char, char, &CharClassifier) -> bool, ) { - self.update_editor(window, cx, |_, editor, window, cx| { - editor.change_selections(Default::default(), window, cx, |s| { - s.move_with(|map, selection| { - let times = times.unwrap_or(1); - let new_goal = SelectionGoal::None; - let mut head = selection.head(); - let mut tail = selection.tail(); + let times = times.unwrap_or(1); + self.helix_new_selections(window, cx, |cursor, map| { + let mut head = cursor; + // The original cursor was one character wide, + // but the search starts from the left side of it, + // so to include that space the selection must end one character to the right. + let mut tail = movement::right(map, cursor); + let classifier = map.buffer_snapshot.char_classifier_at(head.to_point(map)); + if head == DisplayPoint::zero() { + return None; + } + for _ in 0..times { + let (maybe_next_tail, next_head) = + movement::find_preceding_boundary_trail(map, head, |left, right| { + is_boundary(left, right, &classifier) + }); - if head == DisplayPoint::zero() { - return; - } + if next_head == head && maybe_next_tail.unwrap_or(next_head) == tail { + break; + } - // collapse to block cursor - if tail < head { - tail = movement::left(map, head); - } else { - tail = head; - head = movement::right(map, head); - } - - selection.set_head(head, new_goal); - selection.set_tail(tail, new_goal); - // flip the selection - selection.swap_head_tail(); - head = selection.head(); - tail = selection.tail(); - - // create a classifier - let classifier = map.buffer_snapshot.char_classifier_at(head.to_point(map)); - - for _ in 0..times { - let (maybe_next_tail, next_head) = - movement::find_preceding_boundary_trail(map, head, |left, right| { - is_boundary(left, right, &classifier) - }); - - if next_head == head && maybe_next_tail.unwrap_or(next_head) == tail { - break; - } - - head = next_head; - if let Some(next_tail) = maybe_next_tail { - tail = next_tail; - } - } - - selection.set_tail(tail, new_goal); - selection.set_head(head, new_goal); - }); - }) + head = next_head; + if let Some(next_tail) = maybe_next_tail { + tail = next_tail; + } + } + Some((head, tail)) }); } @@ -175,7 +167,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { @@ -252,67 +244,103 @@ impl Vim { found }) } - Motion::FindForward { .. } => { - self.update_editor(window, cx, |_, editor, window, cx| { - let text_layout_details = editor.text_layout_details(window); - editor.change_selections(Default::default(), window, cx, |s| { - s.move_with(|map, selection| { - let goal = selection.goal; - let cursor = if selection.is_empty() || selection.reversed { - selection.head() - } else { - movement::left(map, selection.head()) - }; - - let (point, goal) = motion - .move_point( - map, - cursor, - selection.goal, - times, - &text_layout_details, - ) - .unwrap_or((cursor, goal)); - selection.set_tail(selection.head(), goal); - selection.set_head(movement::right(map, point), goal); - }) - }); + Motion::FindForward { + before, + char, + mode, + smartcase, + } => { + self.helix_new_selections(window, cx, |cursor, map| { + let start = cursor; + let mut last_boundary = start; + for _ in 0..times.unwrap_or(1) { + last_boundary = movement::find_boundary( + map, + movement::right(map, last_boundary), + mode, + |left, right| { + let current_char = if before { right } else { left }; + motion::is_character_match(char, current_char, smartcase) + }, + ); + } + Some((last_boundary, start)) }); } - Motion::FindBackward { .. } => { - self.update_editor(window, cx, |_, editor, window, cx| { - let text_layout_details = editor.text_layout_details(window); - editor.change_selections(Default::default(), window, cx, |s| { - s.move_with(|map, selection| { - let goal = selection.goal; - let cursor = if selection.is_empty() || selection.reversed { - selection.head() - } else { - movement::left(map, selection.head()) - }; - - let (point, goal) = motion - .move_point( - map, - cursor, - selection.goal, - times, - &text_layout_details, - ) - .unwrap_or((cursor, goal)); - selection.set_tail(selection.head(), goal); - selection.set_head(point, goal); - }) - }); + Motion::FindBackward { + after, + char, + mode, + smartcase, + } => { + self.helix_new_selections(window, cx, |cursor, map| { + let start = cursor; + let mut last_boundary = start; + for _ in 0..times.unwrap_or(1) { + last_boundary = movement::find_preceding_boundary_display_point( + map, + last_boundary, + mode, + |left, right| { + let current_char = if after { left } else { right }; + motion::is_character_match(char, current_char, smartcase) + }, + ); + } + // The original cursor was one character wide, + // but the search started from the left side of it, + // so to include that space the selection must end one character to the right. + Some((last_boundary, movement::right(map, start))) }); } _ => self.helix_move_and_collapse(motion, times, window, cx), } } + pub fn helix_yank(&mut self, _: &HelixYank, window: &mut Window, cx: &mut Context) { + self.update_editor(cx, |vim, editor, cx| { + let has_selection = editor + .selections + .all_adjusted(cx) + .iter() + .any(|selection| !selection.is_empty()); + + if !has_selection { + // If no selection, expand to current character (like 'v' does) + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|map, selection| { + let head = selection.head(); + let new_head = movement::saturating_right(map, head); + selection.set_tail(head, SelectionGoal::None); + selection.set_head(new_head, SelectionGoal::None); + }); + }); + vim.yank_selections_content( + editor, + crate::motion::MotionKind::Exclusive, + window, + cx, + ); + editor.change_selections(Default::default(), window, cx, |s| { + s.move_with(|_map, selection| { + selection.collapse_to(selection.start, SelectionGoal::None); + }); + }); + } else { + // Yank the selection(s) + vim.yank_selections_content( + editor, + crate::motion::MotionKind::Exclusive, + window, + cx, + ); + } + }); + } + fn helix_insert(&mut self, _: &HelixInsert, window: &mut Window, cx: &mut Context) { self.start_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|_map, selection| { // In helix normal mode, move cursor to start of selection and collapse @@ -328,7 +356,7 @@ impl Vim { fn helix_append(&mut self, _: &HelixAppend, window: &mut Window, cx: &mut Context) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { let point = if selection.is_empty() { @@ -343,7 +371,7 @@ impl Vim { } pub fn helix_replace(&mut self, text: &str, window: &mut Window, cx: &mut Context) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let (map, selections) = editor.selections.all_display(cx); @@ -586,13 +614,33 @@ mod test { Mode::HelixNormal, ); - cx.simulate_keystrokes("2 T r"); + cx.simulate_keystrokes("F e F e"); cx.assert_state( indoc! {" - The quick br«ˇown - fox jumps over - the laz»y dog."}, + The quick brown + fox jumps ov«ˇer + the» lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("e 2 F e"); + + cx.assert_state( + indoc! {" + Th«ˇe quick brown + fox jumps over» + the lazy dog."}, + Mode::HelixNormal, + ); + + cx.simulate_keystrokes("t r t r"); + + cx.assert_state( + indoc! {" + The quick «brown + fox jumps oveˇ»r + the lazy dog."}, Mode::HelixNormal, ); } @@ -703,4 +751,29 @@ mod test { cx.assert_state("«xxˇ»", Mode::HelixNormal); } + + #[gpui::test] + async fn test_helix_yank(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + cx.enable_helix(); + + // Test yanking current character with no selection + cx.set_state("hello ˇworld", Mode::HelixNormal); + cx.simulate_keystrokes("y"); + + // Test cursor remains at the same position after yanking single character + cx.assert_state("hello ˇworld", Mode::HelixNormal); + cx.shared_clipboard().assert_eq("w"); + + // Move cursor and yank another character + cx.simulate_keystrokes("l"); + cx.simulate_keystrokes("y"); + cx.shared_clipboard().assert_eq("o"); + + // Test yanking with existing selection + cx.set_state("hello «worlˇ»d", Mode::HelixNormal); + cx.simulate_keystrokes("y"); + cx.shared_clipboard().assert_eq("worl"); + cx.assert_state("hello «worlˇ»d", Mode::HelixNormal); + } } diff --git a/crates/vim/src/indent.rs b/crates/vim/src/indent.rs index 75b1857a5b..7ef204de0f 100644 --- a/crates/vim/src/indent.rs +++ b/crates/vim/src/indent.rs @@ -31,7 +31,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let original_positions = vim.save_selection_starts(editor, cx); for _ in 0..count { @@ -50,7 +50,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let original_positions = vim.save_selection_starts(editor, cx); for _ in 0..count { @@ -69,7 +69,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let original_positions = vim.save_selection_starts(editor, cx); for _ in 0..count { @@ -95,7 +95,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { let mut selection_starts: HashMap<_, _> = Default::default(); @@ -137,7 +137,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let mut original_positions: HashMap<_, _> = Default::default(); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { diff --git a/crates/vim/src/insert.rs b/crates/vim/src/insert.rs index 0a370e16ba..8ef1cd7811 100644 --- a/crates/vim/src/insert.rs +++ b/crates/vim/src/insert.rs @@ -3,7 +3,9 @@ use editor::{Bias, Editor}; use gpui::{Action, Context, Window, actions}; use language::SelectionGoal; use settings::Settings; +use text::Point; use vim_mode_setting::HelixModeSetting; +use workspace::searchable::Direction; actions!( vim, @@ -11,13 +13,23 @@ actions!( /// Switches to normal mode with cursor positioned before the current character. NormalBefore, /// Temporarily switches to normal mode for one command. - TemporaryNormal + TemporaryNormal, + /// Inserts the next character from the line above into the current line. + InsertFromAbove, + /// Inserts the next character from the line below into the current line. + InsertFromBelow ] ); pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, Vim::normal_before); Vim::action(editor, cx, Vim::temporary_normal); + Vim::action(editor, cx, |vim, _: &InsertFromAbove, window, cx| { + vim.insert_around(Direction::Prev, window, cx) + }); + Vim::action(editor, cx, |vim, _: &InsertFromBelow, window, cx| { + vim.insert_around(Direction::Next, window, cx) + }) } impl Vim { @@ -38,7 +50,7 @@ impl Vim { if count <= 1 || Vim::globals(cx).dot_replaying { self.create_mark("^".into(), window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.dismiss_menus_and_popups(false, window, cx); if !HelixModeSetting::get_global(cx).0 { @@ -71,6 +83,29 @@ impl Vim { self.switch_mode(Mode::Normal, true, window, cx); self.temp_mode = true; } + + fn insert_around(&mut self, direction: Direction, _: &mut Window, cx: &mut Context) { + self.update_editor(cx, |_, editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let mut edits = Vec::new(); + for selection in editor.selections.all::(cx) { + let point = selection.head(); + let new_row = match direction { + Direction::Next => point.row + 1, + Direction::Prev if point.row > 0 => point.row - 1, + _ => continue, + }; + let source = snapshot.clip_point(Point::new(new_row, point.column), Bias::Left); + if let Some(c) = snapshot.chars_at(source).next() + && c != '\n' + { + edits.push((point..point, c.to_string())) + } + } + + editor.edit(edits, cx); + }); + } } #[cfg(test)] @@ -156,4 +191,13 @@ mod test { .await; cx.shared_state().await.assert_eq("hehello\nˇllo\n"); } + + #[gpui::test] + async fn test_insert_ctrl_y(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state("hello\nˇ\nworld").await; + cx.simulate_shared_keystrokes("i ctrl-y ctrl-e").await; + cx.shared_state().await.assert_eq("hello\nhoˇ\nworld"); + } } diff --git a/crates/vim/src/mode_indicator.rs b/crates/vim/src/mode_indicator.rs index d54b270074..714b74f239 100644 --- a/crates/vim/src/mode_indicator.rs +++ b/crates/vim/src/mode_indicator.rs @@ -20,7 +20,7 @@ impl ModeIndicator { }) .detach(); - let handle = cx.entity().clone(); + let handle = cx.entity(); let window_handle = window.window_handle(); cx.observe_new::(move |_, window, cx| { let Some(window) = window else { @@ -29,7 +29,7 @@ impl ModeIndicator { if window.window_handle() != window_handle { return; } - let vim = cx.entity().clone(); + let vim = cx.entity(); handle.update(cx, |_, cx| { cx.subscribe(&vim, |mode_indicator, vim, event, cx| match event { VimEvent::Focused => { diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index 0e487f4410..a6a07e7b2f 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -679,7 +679,7 @@ impl Vim { match self.mode { Mode::Visual | Mode::VisualLine | Mode::VisualBlock => { if !prior_selections.is_empty() { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.select_ranges(prior_selections.iter().cloned()) }) @@ -2639,7 +2639,8 @@ fn find_backward( } } -fn is_character_match(target: char, other: char, smartcase: bool) -> bool { +/// Returns true if one char is equal to the other or its uppercase variant (if smartcase is true). +pub fn is_character_match(target: char, other: char, smartcase: bool) -> bool { if smartcase { if target.is_uppercase() { target == other diff --git a/crates/vim/src/normal.rs b/crates/vim/src/normal.rs index 13128e7b40..b74d85b7c5 100644 --- a/crates/vim/src/normal.rs +++ b/crates/vim/src/normal.rs @@ -132,7 +132,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &HelixDelete, window, cx| { vim.record_current_action(cx); - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.move_with(|map, selection| { if selection.is_empty() { @@ -146,7 +146,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, _: &HelixCollapseSelection, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.move_with(|map, selection| { let mut point = selection.head(); @@ -198,7 +198,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &Undo, window, cx| { let times = Vim::take_count(cx); Vim::take_forced_motion(cx); - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { for _ in 0..times.unwrap_or(1) { editor.undo(&editor::actions::Undo, window, cx); } @@ -207,7 +207,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &Redo, window, cx| { let times = Vim::take_count(cx); Vim::take_forced_motion(cx); - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { for _ in 0..times.unwrap_or(1) { editor.redo(&editor::actions::Redo, window, cx); } @@ -215,7 +215,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, _: &UndoLastLine, window, cx| { Vim::take_forced_motion(cx); - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { let snapshot = editor.buffer().read(cx).snapshot(cx); let Some(last_change) = editor.change_list.last_before_grouping() else { return; @@ -526,7 +526,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.change_selections( SelectionEffects::default().nav_history(motion.push_to_jump_list()), @@ -546,7 +546,7 @@ impl Vim { fn insert_after(&mut self, _: &InsertAfter, window: &mut Window, cx: &mut Context) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_cursors_with(|map, cursor, _| (right(map, cursor, 1), SelectionGoal::None)); }); @@ -557,7 +557,7 @@ impl Vim { self.start_recording(cx); if self.mode.is_visual() { let current_mode = self.mode; - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { if current_mode == Mode::VisualLine { @@ -581,7 +581,7 @@ impl Vim { ) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_cursors_with(|map, cursor, _| { ( @@ -601,7 +601,7 @@ impl Vim { ) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_cursors_with(|map, cursor, _| { (next_line_end(map, cursor, 1), SelectionGoal::None) @@ -618,7 +618,7 @@ impl Vim { ) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let Some(Mark::Local(marks)) = vim.get_mark("^", editor, window, cx) else { return; }; @@ -637,7 +637,7 @@ impl Vim { ) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let selections = editor.selections.all::(cx); let snapshot = editor.buffer().read(cx).snapshot(cx); @@ -678,7 +678,7 @@ impl Vim { ) { self.start_recording(cx); self.switch_mode(Mode::Insert, false, window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { let selections = editor.selections.all::(cx); @@ -725,7 +725,7 @@ impl Vim { self.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, _, cx| { let selections = editor.selections.all::(cx); @@ -754,7 +754,7 @@ impl Vim { self.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let selections = editor.selections.all::(cx); let snapshot = editor.buffer().read(cx).snapshot(cx); @@ -804,7 +804,7 @@ impl Vim { times -= 1; } - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { for _ in 0..times { editor.join_lines_impl(insert_whitespace, window, cx) @@ -828,10 +828,10 @@ impl Vim { ) } - fn show_location(&mut self, _: &ShowLocation, window: &mut Window, cx: &mut Context) { + fn show_location(&mut self, _: &ShowLocation, _: &mut Window, cx: &mut Context) { let count = Vim::take_count(cx); Vim::take_forced_motion(cx); - self.update_editor(window, cx, |vim, editor, _window, cx| { + self.update_editor(cx, |vim, editor, cx| { let selection = editor.selections.newest_anchor(); let Some((buffer, point, _)) = editor .buffer() @@ -875,7 +875,7 @@ impl Vim { fn toggle_comments(&mut self, _: &ToggleComments, window: &mut Window, cx: &mut Context) { self.record_current_action(cx); self.store_visual_marks(window, cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let original_positions = vim.save_selection_starts(editor, cx); editor.toggle_comments(&Default::default(), window, cx); @@ -897,7 +897,7 @@ impl Vim { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); let (map, display_selections) = editor.selections.all_display(cx); diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index c1bc7a70ae..fcd36dd7ee 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -34,7 +34,7 @@ impl Vim { } else { None }; - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { // We are swapping to insert mode anyway. Just set the line end clipping behavior now @@ -111,7 +111,7 @@ impl Vim { cx: &mut Context, ) { let mut objects_found = false; - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { // We are swapping to insert mode anyway. Just set the line end clipping behavior now editor.set_clip_at_line_ends(false, cx); editor.transact(window, cx, |editor, window, cx| { diff --git a/crates/vim/src/normal/convert.rs b/crates/vim/src/normal/convert.rs index cf9498bec9..4b9c3fc8f7 100644 --- a/crates/vim/src/normal/convert.rs +++ b/crates/vim/src/normal/convert.rs @@ -31,7 +31,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.set_clip_at_line_ends(false, cx); let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { @@ -87,7 +87,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); let mut original_positions: HashMap<_, _> = Default::default(); @@ -195,7 +195,7 @@ impl Vim { let count = Vim::take_count(cx).unwrap_or(1) as u32; Vim::take_forced_motion(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let mut ranges = Vec::new(); let mut cursor_positions = Vec::new(); let snapshot = editor.buffer().read(cx).snapshot(cx); diff --git a/crates/vim/src/normal/delete.rs b/crates/vim/src/normal/delete.rs index 2cf40292cf..1b7557371a 100644 --- a/crates/vim/src/normal/delete.rs +++ b/crates/vim/src/normal/delete.rs @@ -22,7 +22,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); @@ -96,7 +96,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); // Emulates behavior in vim where if we expanded backwards to include a newline diff --git a/crates/vim/src/normal/increment.rs b/crates/vim/src/normal/increment.rs index 51f6e4a0f9..007514e472 100644 --- a/crates/vim/src/normal/increment.rs +++ b/crates/vim/src/normal/increment.rs @@ -53,7 +53,7 @@ impl Vim { cx: &mut Context, ) { self.store_visual_marks(window, cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let mut edits = Vec::new(); let mut new_anchors = Vec::new(); diff --git a/crates/vim/src/normal/mark.rs b/crates/vim/src/normal/mark.rs index 57a6108841..1d6264d593 100644 --- a/crates/vim/src/normal/mark.rs +++ b/crates/vim/src/normal/mark.rs @@ -19,7 +19,7 @@ use crate::{ impl Vim { pub fn create_mark(&mut self, text: Arc, window: &mut Window, cx: &mut Context) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let anchors = editor .selections .disjoint_anchors() @@ -49,7 +49,7 @@ impl Vim { let mut ends = vec![]; let mut reversed = vec![]; - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let (map, selections) = editor.selections.all_display(cx); for selection in selections { let end = movement::saturating_left(&map, selection.end); @@ -190,7 +190,7 @@ impl Vim { self.pop_operator(window, cx); } let mark = self - .update_editor(window, cx, |vim, editor, window, cx| { + .update_editor(cx, |vim, editor, cx| { vim.get_mark(&text, editor, window, cx) }) .flatten(); @@ -209,7 +209,7 @@ impl Vim { let Some(mut anchors) = anchors else { return }; - self.update_editor(window, cx, |_, editor, _, cx| { + self.update_editor(cx, |_, editor, cx| { editor.create_nav_history_entry(cx); }); let is_active_operator = self.active_operator().is_some(); @@ -231,7 +231,7 @@ impl Vim { || self.mode == Mode::VisualLine || self.mode == Mode::VisualBlock; - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let map = editor.snapshot(window, cx); let mut ranges: Vec> = Vec::new(); for mut anchor in anchors { diff --git a/crates/vim/src/normal/paste.rs b/crates/vim/src/normal/paste.rs index 07712fbedd..0fd17f310e 100644 --- a/crates/vim/src/normal/paste.rs +++ b/crates/vim/src/normal/paste.rs @@ -32,7 +32,7 @@ impl Vim { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); @@ -236,7 +236,7 @@ impl Vim { ) { self.stop_recording(cx); let selected_register = self.selected_register.take(); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { @@ -273,7 +273,7 @@ impl Vim { ) { self.stop_recording(cx); let selected_register = self.selected_register.take(); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); diff --git a/crates/vim/src/normal/scroll.rs b/crates/vim/src/normal/scroll.rs index e2ae74b52b..af13bc0fd0 100644 --- a/crates/vim/src/normal/scroll.rs +++ b/crates/vim/src/normal/scroll.rs @@ -97,7 +97,7 @@ impl Vim { let amount = by(Vim::take_count(cx).map(|c| c as f32)); Vim::take_forced_motion(cx); self.exit_temporary_normal(window, cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { scroll_editor(editor, move_cursor, &amount, window, cx) }); } diff --git a/crates/vim/src/normal/search.rs b/crates/vim/src/normal/search.rs index 24f2cf751f..4054c552ae 100644 --- a/crates/vim/src/normal/search.rs +++ b/crates/vim/src/normal/search.rs @@ -251,7 +251,7 @@ impl Vim { // If the active editor has changed during a search, don't panic. if prior_selections.iter().any(|s| { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { !s.start .is_valid(&editor.snapshot(window, cx).buffer_snapshot) }) @@ -332,7 +332,7 @@ impl Vim { Vim::take_forced_motion(cx); let prior_selections = self.editor_selections(window, cx); let cursor_word = self.editor_cursor_word(window, cx); - let vim = cx.entity().clone(); + let vim = cx.entity(); let searched = pane.update(cx, |pane, cx| { self.search.direction = direction; @@ -457,7 +457,7 @@ impl Vim { else { return; }; - if let Some(result) = self.update_editor(window, cx, |vim, editor, window, cx| { + if let Some(result) = self.update_editor(cx, |vim, editor, cx| { let range = action.range.buffer_range(vim, editor, window, cx)?; let snapshot = &editor.snapshot(window, cx).buffer_snapshot; let end_point = Point::new(range.end.0, snapshot.line_len(range.end)); diff --git a/crates/vim/src/normal/substitute.rs b/crates/vim/src/normal/substitute.rs index a9752f2887..889d487170 100644 --- a/crates/vim/src/normal/substitute.rs +++ b/crates/vim/src/normal/substitute.rs @@ -45,7 +45,7 @@ impl Vim { cx: &mut Context, ) { self.store_visual_marks(window, cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.set_clip_at_line_ends(false, cx); editor.transact(window, cx, |editor, window, cx| { let text_layout_details = editor.text_layout_details(window); diff --git a/crates/vim/src/normal/toggle_comments.rs b/crates/vim/src/normal/toggle_comments.rs index 636ea9eec2..17c3b2d363 100644 --- a/crates/vim/src/normal/toggle_comments.rs +++ b/crates/vim/src/normal/toggle_comments.rs @@ -14,7 +14,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { let mut selection_starts: HashMap<_, _> = Default::default(); @@ -51,7 +51,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let mut original_positions: HashMap<_, _> = Default::default(); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { diff --git a/crates/vim/src/normal/yank.rs b/crates/vim/src/normal/yank.rs index 847eba3143..fe8180ffff 100644 --- a/crates/vim/src/normal/yank.rs +++ b/crates/vim/src/normal/yank.rs @@ -25,7 +25,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); @@ -70,7 +70,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); let mut start_positions: HashMap<_, _> = Default::default(); diff --git a/crates/vim/src/replace.rs b/crates/vim/src/replace.rs index aa857ef73e..eaa9fd5062 100644 --- a/crates/vim/src/replace.rs +++ b/crates/vim/src/replace.rs @@ -49,7 +49,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); let map = editor.snapshot(window, cx); @@ -94,7 +94,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); let map = editor.snapshot(window, cx); @@ -148,7 +148,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.set_clip_at_line_ends(false, cx); let mut selection = editor.selections.newest_display(cx); let snapshot = editor.snapshot(window, cx); @@ -167,7 +167,7 @@ impl Vim { pub fn exchange_visual(&mut self, window: &mut Window, cx: &mut Context) { self.stop_recording(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let selection = editor.selections.newest_anchor(); let new_range = selection.start..selection.end; let snapshot = editor.snapshot(window, cx); @@ -178,7 +178,7 @@ impl Vim { pub fn clear_exchange(&mut self, window: &mut Window, cx: &mut Context) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, _, cx| { + self.update_editor(cx, |_, editor, cx| { editor.clear_background_highlights::(cx); }); self.clear_operator(window, cx); @@ -193,7 +193,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.set_clip_at_line_ends(false, cx); let text_layout_details = editor.text_layout_details(window); let mut selection = editor.selections.newest_display(cx); diff --git a/crates/vim/src/rewrap.rs b/crates/vim/src/rewrap.rs index 4cd9449bfa..85e1967af0 100644 --- a/crates/vim/src/rewrap.rs +++ b/crates/vim/src/rewrap.rs @@ -18,7 +18,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::take_count(cx); Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); - vim.update_editor(window, cx, |vim, editor, window, cx| { + vim.update_editor(cx, |vim, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let mut positions = vim.save_selection_starts(editor, cx); editor.rewrap_impl( @@ -55,7 +55,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { let mut selection_starts: HashMap<_, _> = Default::default(); @@ -100,7 +100,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let mut original_positions: HashMap<_, _> = Default::default(); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { diff --git a/crates/vim/src/surrounds.rs b/crates/vim/src/surrounds.rs index 1f77ebda4a..63cd21e88c 100644 --- a/crates/vim/src/surrounds.rs +++ b/crates/vim/src/surrounds.rs @@ -29,7 +29,7 @@ impl Vim { let count = Vim::take_count(cx); let forced_motion = Vim::take_forced_motion(cx); let mode = self.mode; - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let text_layout_details = editor.text_layout_details(window); editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); @@ -140,7 +140,7 @@ impl Vim { }; let surround = pair.end != *text; - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); @@ -228,7 +228,7 @@ impl Vim { ) { if let Some(will_replace_pair) = object_to_bracket_pair(target) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); @@ -344,7 +344,7 @@ impl Vim { ) -> bool { let mut valid = false; if let Some(pair) = object_to_bracket_pair(object) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { editor.set_clip_at_line_ends(false, cx); let (display_map, selections) = editor.selections.all_adjusted_display(cx); diff --git a/crates/vim/src/test/vim_test_context.rs b/crates/vim/src/test/vim_test_context.rs index b8988b1d1f..5b6cb55e8c 100644 --- a/crates/vim/src/test/vim_test_context.rs +++ b/crates/vim/src/test/vim_test_context.rs @@ -15,6 +15,7 @@ impl VimTestContext { if cx.has_global::() { return; } + env_logger::try_init().ok(); cx.update(|cx| { let settings = SettingsStore::test(cx); cx.set_global(settings); @@ -142,6 +143,16 @@ impl VimTestContext { }) } + pub fn enable_helix(&mut self) { + self.cx.update(|_, cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |s| { + *s = Some(true) + }); + }); + }) + } + pub fn mode(&mut self) -> Mode { self.update_editor(|editor, _, cx| editor.addon::().unwrap().entity.read(cx).mode) } @@ -209,6 +220,26 @@ impl VimTestContext { assert_eq!(self.mode(), Mode::Normal, "{}", self.assertion_context()); assert_eq!(self.active_operator(), None, "{}", self.assertion_context()); } + + pub fn shared_clipboard(&mut self) -> VimClipboard { + VimClipboard { + editor: self + .read_from_clipboard() + .map(|item| item.text().unwrap().to_string()) + .unwrap_or_default(), + } + } +} + +pub struct VimClipboard { + editor: String, +} + +impl VimClipboard { + #[track_caller] + pub fn assert_eq(&self, expected: &str) { + assert_eq!(self.editor, expected); + } } impl Deref for VimTestContext { diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 72edbe77ed..44d9b8f456 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -284,9 +284,7 @@ pub fn init(cx: &mut App) { let count = Vim::take_count(cx).unwrap_or(1) as f32; Vim::take_forced_motion(cx); let theme = ThemeSettings::get_global(cx); - let Ok(font_id) = window.text_system().font_id(&theme.buffer_font) else { - return; - }; + let font_id = window.text_system().resolve_font(&theme.buffer_font); let Ok(width) = window .text_system() .advance(font_id, theme.buffer_font_size(cx), 'm') @@ -300,9 +298,7 @@ pub fn init(cx: &mut App) { let count = Vim::take_count(cx).unwrap_or(1) as f32; Vim::take_forced_motion(cx); let theme = ThemeSettings::get_global(cx); - let Ok(font_id) = window.text_system().font_id(&theme.buffer_font) else { - return; - }; + let font_id = window.text_system().resolve_font(&theme.buffer_font); let Ok(width) = window .text_system() .advance(font_id, theme.buffer_font_size(cx), 'm') @@ -406,7 +402,7 @@ impl Vim { const NAMESPACE: &'static str = "vim"; pub fn new(window: &mut Window, cx: &mut Context) -> Entity { - let editor = cx.entity().clone(); + let editor = cx.entity(); let mut initial_mode = VimSettings::get_global(cx).default_mode; if initial_mode == Mode::Normal && HelixModeSetting::get_global(cx).0 { @@ -748,7 +744,7 @@ impl Vim { editor, cx, |vim, action: &editor::actions::AcceptEditPrediction, window, cx| { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.accept_edit_prediction(action, window, cx); }); // In non-insertion modes, predictions will be hidden and instead a jump will be @@ -847,7 +843,7 @@ impl Vim { if let Some(action) = keystroke_event.action.as_ref() { // Keystroke is handled by the vim system, so continue forward if action.name().starts_with("vim::") { - self.update_editor(window, cx, |_, editor, _, cx| { + self.update_editor(cx, |_, editor, cx| { editor.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx) }); return; @@ -909,7 +905,7 @@ impl Vim { anchor, is_deactivate, } => { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let mark = if *is_deactivate { "\"".to_string() } else { @@ -972,7 +968,7 @@ impl Vim { if mode == Mode::Normal || mode != last_mode { self.current_tx.take(); self.current_anchor.take(); - self.update_editor(window, cx, |_, editor, _, _| { + self.update_editor(cx, |_, editor, _| { editor.clear_selection_drag_state(); }); } @@ -988,7 +984,7 @@ impl Vim { && self.mode != self.last_mode && (self.mode == Mode::Insert || self.last_mode == Mode::Insert) { - self.update_editor(window, cx, |vim, editor, _, cx| { + self.update_editor(cx, |vim, editor, cx| { let is_relative = vim.mode != Mode::Insert; editor.set_relative_line_number(Some(is_relative), cx) }); @@ -1003,7 +999,7 @@ impl Vim { } // Adjust selections - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { if last_mode != Mode::VisualBlock && last_mode.is_visual() && mode == Mode::VisualBlock { vim.visual_block_motion(true, editor, window, cx, |_, point, goal| { @@ -1214,7 +1210,7 @@ impl Vim { if preserve_selection { self.switch_mode(Mode::Visual, true, window, cx); } else { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.set_clip_at_line_ends(false, cx); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.move_with(|_, selection| { @@ -1232,18 +1228,18 @@ impl Vim { if let Some(old_vim) = Vim::globals(cx).focused_vim() { if old_vim.entity_id() != cx.entity().entity_id() { old_vim.update(cx, |vim, cx| { - vim.update_editor(window, cx, |_, editor, _, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.set_relative_line_number(None, cx) }); }); - self.update_editor(window, cx, |vim, editor, _, cx| { + self.update_editor(cx, |vim, editor, cx| { let is_relative = vim.mode != Mode::Insert; editor.set_relative_line_number(Some(is_relative), cx) }); } } else { - self.update_editor(window, cx, |vim, editor, _, cx| { + self.update_editor(cx, |vim, editor, cx| { let is_relative = vim.mode != Mode::Insert; editor.set_relative_line_number(Some(is_relative), cx) }); @@ -1256,35 +1252,30 @@ impl Vim { self.stop_recording_immediately(NormalBefore.boxed_clone(), cx); self.store_visual_marks(window, cx); self.clear_operator(window, cx); - self.update_editor(window, cx, |vim, editor, _, cx| { + self.update_editor(cx, |vim, editor, cx| { if vim.cursor_shape(cx) == CursorShape::Block { editor.set_cursor_shape(CursorShape::Hollow, cx); } }); } - fn cursor_shape_changed(&mut self, window: &mut Window, cx: &mut Context) { - self.update_editor(window, cx, |vim, editor, _, cx| { + fn cursor_shape_changed(&mut self, _: &mut Window, cx: &mut Context) { + self.update_editor(cx, |vim, editor, cx| { editor.set_cursor_shape(vim.cursor_shape(cx), cx); }); } fn update_editor( &mut self, - window: &mut Window, cx: &mut Context, - update: impl FnOnce(&mut Self, &mut Editor, &mut Window, &mut Context) -> S, + update: impl FnOnce(&mut Self, &mut Editor, &mut Context) -> S, ) -> Option { let editor = self.editor.upgrade()?; - Some(editor.update(cx, |editor, cx| update(self, editor, window, cx))) + Some(editor.update(cx, |editor, cx| update(self, editor, cx))) } - fn editor_selections( - &mut self, - window: &mut Window, - cx: &mut Context, - ) -> Vec> { - self.update_editor(window, cx, |_, editor, _, _| { + fn editor_selections(&mut self, _: &mut Window, cx: &mut Context) -> Vec> { + self.update_editor(cx, |_, editor, _| { editor .selections .disjoint_anchors() @@ -1300,7 +1291,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) -> Option { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let selection = editor.selections.newest::(cx); let snapshot = &editor.snapshot(window, cx).buffer_snapshot; @@ -1489,7 +1480,7 @@ impl Vim { ) { match self.mode { Mode::VisualLine | Mode::VisualBlock | Mode::Visual => { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let original_mode = vim.undo_modes.get(transaction_id); editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { match original_mode { @@ -1520,7 +1511,7 @@ impl Vim { self.switch_mode(Mode::Normal, true, window, cx) } Mode::Normal => { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { s.move_with(|map, selection| { selection @@ -1547,7 +1538,7 @@ impl Vim { self.current_anchor = Some(newest); } else if self.current_anchor.as_ref().unwrap() != &newest { if let Some(tx_id) = self.current_tx.take() { - self.update_editor(window, cx, |_, editor, _, cx| { + self.update_editor(cx, |_, editor, cx| { editor.group_until_transaction(tx_id, cx) }); } @@ -1694,7 +1685,7 @@ impl Vim { } Some(Operator::Register) => match self.mode { Mode::Insert => { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { if let Some(register) = Vim::update_globals(cx, |globals, cx| { globals.read_register(text.chars().next(), Some(editor), cx) }) { @@ -1720,7 +1711,7 @@ impl Vim { } if self.mode == Mode::Normal { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.accept_edit_prediction( &editor::actions::AcceptEditPrediction {}, window, @@ -1733,7 +1724,7 @@ impl Vim { } fn sync_vim_settings(&mut self, window: &mut Window, cx: &mut Context) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { editor.set_cursor_shape(vim.cursor_shape(cx), cx); editor.set_clip_at_line_ends(vim.clip_at_line_ends(), cx); editor.set_collapse_matches(true); diff --git a/crates/vim/src/visual.rs b/crates/vim/src/visual.rs index ca8734ba8b..7bfd8dc8be 100644 --- a/crates/vim/src/visual.rs +++ b/crates/vim/src/visual.rs @@ -104,7 +104,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); for _ in 0..count { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.select_larger_syntax_node(&Default::default(), window, cx); }); } @@ -117,7 +117,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { let count = Vim::take_count(cx).unwrap_or(1); Vim::take_forced_motion(cx); for _ in 0..count { - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.select_smaller_syntax_node(&Default::default(), window, cx); }); } @@ -129,7 +129,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { return; }; let marks = vim - .update_editor(window, cx, |vim, editor, window, cx| { + .update_editor(cx, |vim, editor, cx| { vim.get_mark("<", editor, window, cx) .zip(vim.get_mark(">", editor, window, cx)) }) @@ -148,7 +148,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { vim.create_visual_marks(vim.mode, window, cx); } - vim.update_editor(window, cx, |_, editor, window, cx| { + vim.update_editor(cx, |_, editor, cx| { editor.set_clip_at_line_ends(false, cx); editor.change_selections(Default::default(), window, cx, |s| { let map = s.display_map(); @@ -189,7 +189,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let text_layout_details = editor.text_layout_details(window); if vim.mode == Mode::VisualBlock && !matches!( @@ -397,7 +397,7 @@ impl Vim { self.switch_mode(target_mode, true, window, cx); } - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|map, selection| { let mut mut_selection = selection.clone(); @@ -475,7 +475,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.split_selection_into_lines(&Default::default(), window, cx); editor.change_selections(Default::default(), window, cx, |s| { s.move_cursors_with(|map, cursor, _| { @@ -493,7 +493,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.split_selection_into_lines(&Default::default(), window, cx); editor.change_selections(Default::default(), window, cx, |s| { s.move_cursors_with(|map, cursor, _| { @@ -517,7 +517,7 @@ impl Vim { } pub fn other_end(&mut self, _: &OtherEnd, window: &mut Window, cx: &mut Context) { - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|_, selection| { selection.reversed = !selection.reversed; @@ -533,7 +533,7 @@ impl Vim { cx: &mut Context, ) { let mode = self.mode; - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.change_selections(Default::default(), window, cx, |s| { s.move_with(|_, selection| { selection.reversed = !selection.reversed; @@ -547,7 +547,7 @@ impl Vim { pub fn visual_delete(&mut self, line_mode: bool, window: &mut Window, cx: &mut Context) { self.store_visual_marks(window, cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let mut original_columns: HashMap<_, _> = Default::default(); let line_mode = line_mode || editor.selections.line_mode; editor.selections.line_mode = false; @@ -631,7 +631,7 @@ impl Vim { pub fn visual_yank(&mut self, line_mode: bool, window: &mut Window, cx: &mut Context) { self.store_visual_marks(window, cx); - self.update_editor(window, cx, |vim, editor, window, cx| { + self.update_editor(cx, |vim, editor, cx| { let line_mode = line_mode || editor.selections.line_mode; // For visual line mode, adjust selections to avoid yanking the next line when on \n @@ -679,7 +679,7 @@ impl Vim { cx: &mut Context, ) { self.stop_recording(cx); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.transact(window, cx, |editor, window, cx| { let (display_map, selections) = editor.selections.all_adjusted_display(cx); @@ -722,7 +722,7 @@ impl Vim { Vim::take_forced_motion(cx); let count = Vim::take_count(cx).unwrap_or_else(|| if self.mode.is_visual() { 1 } else { 2 }); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { editor.set_clip_at_line_ends(false, cx); for _ in 0..count { if editor @@ -745,7 +745,7 @@ impl Vim { Vim::take_forced_motion(cx); let count = Vim::take_count(cx).unwrap_or_else(|| if self.mode.is_visual() { 1 } else { 2 }); - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { for _ in 0..count { if editor .select_previous(&Default::default(), window, cx) @@ -773,7 +773,7 @@ impl Vim { let mut start_selection = 0usize; let mut end_selection = 0usize; - self.update_editor(window, cx, |_, editor, _, _| { + self.update_editor(cx, |_, editor, _| { editor.set_collapse_matches(false); }); if vim_is_normal { @@ -791,7 +791,7 @@ impl Vim { } }); } - self.update_editor(window, cx, |_, editor, _, cx| { + self.update_editor(cx, |_, editor, cx| { let latest = editor.selections.newest::(cx); start_selection = latest.start; end_selection = latest.end; @@ -812,7 +812,7 @@ impl Vim { self.stop_replaying(cx); return; } - self.update_editor(window, cx, |_, editor, window, cx| { + self.update_editor(cx, |_, editor, cx| { let latest = editor.selections.newest::(cx); if vim_is_normal { start_selection = latest.start; diff --git a/crates/vim/test_data/test_insert_ctrl_y.json b/crates/vim/test_data/test_insert_ctrl_y.json new file mode 100644 index 0000000000..09b707a198 --- /dev/null +++ b/crates/vim/test_data/test_insert_ctrl_y.json @@ -0,0 +1,5 @@ +{"Put":{"state":"hello\nˇ\nworld"}} +{"Key":"i"} +{"Key":"ctrl-y"} +{"Key":"ctrl-e"} +{"Get":{"state":"hello\nhoˇ\nworld","mode":"Insert"}} diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml deleted file mode 100644 index acb3fe0f84..0000000000 --- a/crates/welcome/Cargo.toml +++ /dev/null @@ -1,40 +0,0 @@ -[package] -name = "welcome" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/welcome.rs" - -[features] -test-support = [] - -[dependencies] -anyhow.workspace = true -client.workspace = true -component.workspace = true -db.workspace = true -documented.workspace = true -fuzzy.workspace = true -gpui.workspace = true -install_cli.workspace = true -language.workspace = true -picker.workspace = true -project.workspace = true -serde.workspace = true -settings.workspace = true -telemetry.workspace = true -ui.workspace = true -util.workspace = true -vim_mode_setting.workspace = true -workspace-hack.workspace = true -workspace.workspace = true -zed_actions.workspace = true - -[dev-dependencies] -editor = { workspace = true, features = ["test-support"] } diff --git a/crates/welcome/LICENSE-GPL b/crates/welcome/LICENSE-GPL deleted file mode 120000 index 89e542f750..0000000000 --- a/crates/welcome/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs deleted file mode 100644 index b0a1c316f4..0000000000 --- a/crates/welcome/src/welcome.rs +++ /dev/null @@ -1,446 +0,0 @@ -use client::{TelemetrySettings, telemetry::Telemetry}; -use db::kvp::KEY_VALUE_STORE; -use gpui::{ - Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, - ParentElement, Render, Styled, Subscription, Task, WeakEntity, Window, actions, svg, -}; -use language::language_settings::{EditPredictionProvider, all_language_settings}; -use project::DisableAiSettings; -use settings::{Settings, SettingsStore}; -use std::sync::Arc; -use ui::{CheckboxWithLabel, ElevationIndex, Tooltip, prelude::*}; -use util::ResultExt; -use vim_mode_setting::VimModeSetting; -use workspace::{ - AppState, Welcome, Workspace, WorkspaceId, - dock::DockPosition, - item::{Item, ItemEvent}, - open_new, -}; - -pub use multibuffer_hint::*; - -mod base_keymap_picker; -mod multibuffer_hint; - -actions!( - welcome, - [ - /// Resets the welcome screen hints to their initial state. - ResetHints - ] -); - -pub const FIRST_OPEN: &str = "first_open"; -pub const DOCS_URL: &str = "https://zed.dev/docs/"; - -pub fn init(cx: &mut App) { - cx.observe_new(|workspace: &mut Workspace, _, _cx| { - workspace.register_action(|workspace, _: &Welcome, window, cx| { - let welcome_page = WelcomePage::new(workspace, cx); - workspace.add_item_to_active_pane(Box::new(welcome_page), None, true, window, cx) - }); - workspace - .register_action(|_workspace, _: &ResetHints, _, cx| MultibufferHint::set_count(0, cx)); - }) - .detach(); - - base_keymap_picker::init(cx); -} - -pub fn show_welcome_view(app_state: Arc, cx: &mut App) -> Task> { - open_new( - Default::default(), - app_state, - cx, - |workspace, window, cx| { - workspace.toggle_dock(DockPosition::Left, window, cx); - let welcome_page = WelcomePage::new(workspace, cx); - workspace.add_item_to_center(Box::new(welcome_page.clone()), window, cx); - - window.focus(&welcome_page.focus_handle(cx)); - - cx.notify(); - - db::write_and_log(cx, || { - KEY_VALUE_STORE.write_kvp(FIRST_OPEN.to_string(), "false".to_string()) - }); - }, - ) -} - -pub struct WelcomePage { - workspace: WeakEntity, - focus_handle: FocusHandle, - telemetry: Arc, - _settings_subscription: Subscription, -} - -impl Render for WelcomePage { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let edit_prediction_provider_is_zed = - all_language_settings(None, cx).edit_predictions.provider - == EditPredictionProvider::Zed; - - let edit_prediction_label = if edit_prediction_provider_is_zed { - "Edit Prediction Enabled" - } else { - "Try Edit Prediction" - }; - - h_flex() - .size_full() - .bg(cx.theme().colors().editor_background) - .key_context("Welcome") - .track_focus(&self.focus_handle(cx)) - .child( - v_flex() - .gap_8() - .mx_auto() - .child( - v_flex() - .w_full() - .child( - svg() - .path("icons/logo_96.svg") - .text_color(cx.theme().colors().icon_disabled) - .w(px(40.)) - .h(px(40.)) - .mx_auto() - .mb_4(), - ) - .child( - h_flex() - .w_full() - .justify_center() - .child(Headline::new("Welcome to Zed")), - ) - .child( - h_flex().w_full().justify_center().child( - Label::new("The editor for what's next") - .color(Color::Muted) - .italic(), - ), - ), - ) - .child( - h_flex() - .items_start() - .gap_8() - .child( - v_flex() - .gap_2() - .pr_8() - .border_r_1() - .border_color(cx.theme().colors().border_variant) - .child( - self.section_label( cx).child( - Label::new("Get Started") - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ) - .child( - Button::new("choose-theme", "Choose a Theme") - .icon(IconName::SwatchBook) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - telemetry::event!("Welcome Theme Changed"); - this.workspace - .update(cx, |_workspace, cx| { - window.dispatch_action(zed_actions::theme_selector::Toggle::default().boxed_clone(), cx); - }) - .ok(); - })), - ) - .child( - Button::new("choose-keymap", "Choose a Keymap") - .icon(IconName::Keyboard) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - telemetry::event!("Welcome Keymap Changed"); - this.workspace - .update(cx, |workspace, cx| { - base_keymap_picker::toggle( - workspace, - &Default::default(), - window, cx, - ) - }) - .ok(); - })), - ) - .when(!DisableAiSettings::get_global(cx).disable_ai, |parent| { - parent.child( - Button::new( - "edit_prediction_onboarding", - edit_prediction_label, - ) - .disabled(edit_prediction_provider_is_zed) - .icon(IconName::ZedPredict) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click( - cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Screen Try Edit Prediction clicked"); - window.dispatch_action(zed_actions::OpenZedPredictOnboarding.boxed_clone(), cx); - }), - ), - ) - }) - .child( - Button::new("edit settings", "Edit Settings") - .icon(IconName::Settings) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Settings Edited"); - window.dispatch_action(Box::new( - zed_actions::OpenSettings, - ), cx); - })), - ) - - ) - .child( - v_flex() - .gap_2() - .child( - self.section_label(cx).child( - Label::new("Resources") - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ) - .when(cfg!(target_os = "macos"), |el| { - el.child( - Button::new("install-cli", "Install the CLI") - .icon(IconName::Terminal) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - telemetry::event!("Welcome CLI Installed"); - this.workspace.update(cx, |_, cx|{ - install_cli::install_cli(window, cx); - }).log_err(); - })), - ) - }) - .child( - Button::new("view-docs", "View Documentation") - .icon(IconName::FileCode) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|_, _, _, cx| { - telemetry::event!("Welcome Documentation Viewed"); - cx.open_url(DOCS_URL); - })), - ) - .child( - Button::new("explore-extensions", "Explore Extensions") - .icon(IconName::Blocks) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .on_click(cx.listener(|_, _, window, cx| { - telemetry::event!("Welcome Extensions Page Opened"); - window.dispatch_action(Box::new( - zed_actions::Extensions::default(), - ), cx); - })), - ) - ), - ) - .child( - v_container() - .px_2() - .gap_2() - .child( - h_flex() - .justify_between() - .child( - CheckboxWithLabel::new( - "enable-vim", - Label::new("Enable Vim Mode"), - if VimModeSetting::get_global(cx).0 { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - cx.listener(move |this, selection, _window, cx| { - telemetry::event!("Welcome Vim Mode Toggled"); - this.update_settings::( - selection, - cx, - |setting, value| *setting = Some(value), - ); - }), - ) - .fill() - .elevation(ElevationIndex::ElevatedSurface), - ) - .child( - IconButton::new("vim-mode", IconName::Info) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .tooltip( - Tooltip::text( - "You can also toggle Vim Mode via the command palette or Editor Controls menu.") - ), - ), - ) - .child( - CheckboxWithLabel::new( - "enable-crash", - Label::new("Send Crash Reports"), - if TelemetrySettings::get_global(cx).diagnostics { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - cx.listener(move |this, selection, _window, cx| { - telemetry::event!("Welcome Diagnostic Telemetry Toggled"); - this.update_settings::(selection, cx, { - move |settings, value| { - settings.diagnostics = Some(value); - telemetry::event!( - "Settings Changed", - setting = "diagnostic telemetry", - value - ); - } - }); - }), - ) - .fill() - .elevation(ElevationIndex::ElevatedSurface), - ) - .child( - CheckboxWithLabel::new( - "enable-telemetry", - Label::new("Send Telemetry"), - if TelemetrySettings::get_global(cx).metrics { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - cx.listener(move |this, selection, _window, cx| { - telemetry::event!("Welcome Metric Telemetry Toggled"); - this.update_settings::(selection, cx, { - move |settings, value| { - settings.metrics = Some(value); - telemetry::event!( - "Settings Changed", - setting = "metric telemetry", - value - ); - } - }); - }), - ) - .fill() - .elevation(ElevationIndex::ElevatedSurface), - ), - ), - ) - } -} - -impl WelcomePage { - pub fn new(workspace: &Workspace, cx: &mut Context) -> Entity { - let this = cx.new(|cx| { - cx.on_release(|_: &mut Self, _| { - telemetry::event!("Welcome Page Closed"); - }) - .detach(); - - WelcomePage { - focus_handle: cx.focus_handle(), - workspace: workspace.weak_handle(), - telemetry: workspace.client().telemetry().clone(), - _settings_subscription: cx - .observe_global::(move |_, cx| cx.notify()), - } - }); - - this - } - - fn section_label(&self, cx: &mut App) -> Div { - div() - .pl_1() - .font_buffer(cx) - .text_color(Color::Muted.color(cx)) - } - - fn update_settings( - &mut self, - selection: &ToggleState, - cx: &mut Context, - callback: impl 'static + Send + Fn(&mut T::FileContent, bool), - ) { - if let Some(workspace) = self.workspace.upgrade() { - let fs = workspace.read(cx).app_state().fs.clone(); - let selection = *selection; - settings::update_settings_file::(fs, cx, move |settings, _| { - let value = match selection { - ToggleState::Unselected => false, - ToggleState::Selected => true, - _ => return, - }; - - callback(settings, value) - }); - } - } -} - -impl EventEmitter for WelcomePage {} - -impl Focusable for WelcomePage { - fn focus_handle(&self, _: &App) -> gpui::FocusHandle { - self.focus_handle.clone() - } -} - -impl Item for WelcomePage { - type Event = ItemEvent; - - fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { - "Welcome".into() - } - - fn telemetry_event_text(&self) -> Option<&'static str> { - Some("Welcome Page Opened") - } - - fn show_toolbar(&self) -> bool { - false - } - - fn clone_on_split( - &self, - _workspace_id: Option, - _: &mut Window, - cx: &mut Context, - ) -> Option> { - Some(cx.new(|cx| WelcomePage { - focus_handle: cx.focus_handle(), - workspace: self.workspace.clone(), - telemetry: self.telemetry.clone(), - _settings_subscription: cx.observe_global::(move |_, cx| cx.notify()), - })) - } - - fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { - f(*event) - } -} diff --git a/crates/workspace/src/dock.rs b/crates/workspace/src/dock.rs index ca63d3e553..ae72df3971 100644 --- a/crates/workspace/src/dock.rs +++ b/crates/workspace/src/dock.rs @@ -253,7 +253,7 @@ impl Dock { cx: &mut Context, ) -> Entity { let focus_handle = cx.focus_handle(); - let workspace = cx.entity().clone(); + let workspace = cx.entity(); let dock = cx.new(|cx| { let focus_subscription = cx.on_focus(&focus_handle, window, |dock: &mut Dock, window, cx| { diff --git a/crates/workspace/src/item.rs b/crates/workspace/src/item.rs index c8ebe4550b..0c5543650e 100644 --- a/crates/workspace/src/item.rs +++ b/crates/workspace/src/item.rs @@ -270,6 +270,12 @@ pub trait Item: Focusable + EventEmitter + Render + Sized { /// Returns the textual contents of the tab. fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString; + /// Returns the suggested filename for saving this item. + /// By default, returns the tab content text. + fn suggested_filename(&self, cx: &App) -> SharedString { + self.tab_content_text(0, cx) + } + fn tab_icon(&self, _window: &Window, _cx: &App) -> Option { None } @@ -293,6 +299,7 @@ pub trait Item: Focusable + EventEmitter + Render + Sized { fn deactivated(&mut self, _window: &mut Window, _: &mut Context) {} fn discarded(&self, _project: Entity, _window: &mut Window, _cx: &mut Context) {} + fn on_removed(&self, _cx: &App) {} fn workspace_deactivated(&mut self, _window: &mut Window, _: &mut Context) {} fn navigate(&mut self, _: Box, _window: &mut Window, _: &mut Context) -> bool { false @@ -496,6 +503,7 @@ pub trait ItemHandle: 'static + Send { ) -> gpui::Subscription; fn tab_content(&self, params: TabContentParams, window: &Window, cx: &App) -> AnyElement; fn tab_content_text(&self, detail: usize, cx: &App) -> SharedString; + fn suggested_filename(&self, cx: &App) -> SharedString; fn tab_icon(&self, window: &Window, cx: &App) -> Option; fn tab_tooltip_text(&self, cx: &App) -> Option; fn tab_tooltip_content(&self, cx: &App) -> Option; @@ -532,6 +540,7 @@ pub trait ItemHandle: 'static + Send { ); fn deactivated(&self, window: &mut Window, cx: &mut App); fn discarded(&self, project: Entity, window: &mut Window, cx: &mut App); + fn on_removed(&self, cx: &App); fn workspace_deactivated(&self, window: &mut Window, cx: &mut App); fn navigate(&self, data: Box, window: &mut Window, cx: &mut App) -> bool; fn item_id(&self) -> EntityId; @@ -629,6 +638,10 @@ impl ItemHandle for Entity { self.read(cx).tab_content_text(detail, cx) } + fn suggested_filename(&self, cx: &App) -> SharedString { + self.read(cx).suggested_filename(cx) + } + fn tab_icon(&self, window: &Window, cx: &App) -> Option { self.read(cx).tab_icon(window, cx) } @@ -968,6 +981,10 @@ impl ItemHandle for Entity { self.update(cx, |this, cx| this.deactivated(window, cx)); } + fn on_removed(&self, cx: &App) { + self.read(cx).on_removed(cx); + } + fn workspace_deactivated(&self, window: &mut Window, cx: &mut App) { self.update(cx, |this, cx| this.workspace_deactivated(window, cx)); } diff --git a/crates/workspace/src/notifications.rs b/crates/workspace/src/notifications.rs index 96966435e1..1356322a5c 100644 --- a/crates/workspace/src/notifications.rs +++ b/crates/workspace/src/notifications.rs @@ -6,6 +6,7 @@ use gpui::{ Task, svg, }; use parking_lot::Mutex; + use std::ops::Deref; use std::sync::{Arc, LazyLock}; use std::{any::TypeId, time::Duration}; @@ -189,6 +190,7 @@ impl Workspace { cx.notify(); } + /// Hide all notifications matching the given ID pub fn suppress_notification(&mut self, id: &NotificationId, cx: &mut Context) { self.dismiss_notification(id, cx); self.suppressed_notifications.insert(id.clone()); @@ -344,7 +346,7 @@ impl Render for LanguageServerPrompt { ) .child(Label::new(request.message.to_string()).size(LabelSize::Small)) .children(request.actions.iter().enumerate().map(|(ix, action)| { - let this_handle = cx.entity().clone(); + let this_handle = cx.entity(); Button::new(ix, action.title.clone()) .size(ButtonSize::Large) .on_click(move |_, window, cx| { @@ -462,16 +464,144 @@ impl EventEmitter for ErrorMessagePrompt {} impl Notification for ErrorMessagePrompt {} +#[derive(IntoElement, RegisterComponent)] +pub struct NotificationFrame { + title: Option, + show_suppress_button: bool, + show_close_button: bool, + close: Option>, + contents: Option, + suffix: Option, +} + +impl NotificationFrame { + pub fn new() -> Self { + Self { + title: None, + contents: None, + suffix: None, + show_suppress_button: true, + show_close_button: true, + close: None, + } + } + + pub fn with_title(mut self, title: Option>) -> Self { + self.title = title.map(Into::into); + self + } + + pub fn with_content(self, content: impl IntoElement) -> Self { + Self { + contents: Some(content.into_any_element()), + ..self + } + } + + /// Determines whether the given notification ID should be suppressible + /// Suppressed motifications will not be shown anymore + pub fn show_suppress_button(mut self, show: bool) -> Self { + self.show_suppress_button = show; + self + } + + pub fn show_close_button(mut self, show: bool) -> Self { + self.show_close_button = show; + self + } + + pub fn on_close(self, on_close: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self { + Self { + close: Some(Box::new(on_close)), + ..self + } + } + + pub fn with_suffix(mut self, suffix: impl IntoElement) -> Self { + self.suffix = Some(suffix.into_any_element()); + self + } +} + +impl RenderOnce for NotificationFrame { + fn render(mut self, window: &mut Window, cx: &mut App) -> impl IntoElement { + let entity = window.current_view(); + let show_suppress_button = self.show_suppress_button; + let suppress = show_suppress_button && window.modifiers().shift; + let (close_id, close_icon) = if suppress { + ("suppress", IconName::Minimize) + } else { + ("close", IconName::Close) + }; + + v_flex() + .occlude() + .p_3() + .gap_2() + .elevation_3(cx) + .child( + h_flex() + .gap_4() + .justify_between() + .items_start() + .child( + v_flex() + .gap_0p5() + .when_some(self.title.clone(), |div, title| { + div.child(Label::new(title)) + }) + .child(div().max_w_96().children(self.contents)), + ) + .when(self.show_close_button, |this| { + this.on_modifiers_changed(move |_, _, cx| cx.notify(entity)) + .child( + IconButton::new(close_id, close_icon) + .tooltip(move |window, cx| { + if suppress { + Tooltip::for_action( + "Suppress.\nClose with click.", + &SuppressNotification, + window, + cx, + ) + } else if show_suppress_button { + Tooltip::for_action( + "Close.\nSuppress with shift-click.", + &menu::Cancel, + window, + cx, + ) + } else { + Tooltip::for_action("Close", &menu::Cancel, window, cx) + } + }) + .on_click({ + let close = self.close.take(); + move |_, window, cx| { + if let Some(close) = &close { + close(&suppress, window, cx) + } + } + }), + ) + }), + ) + .children(self.suffix) + } +} + +impl Component for NotificationFrame {} + pub mod simple_message_notification { use std::sync::Arc; use gpui::{ - AnyElement, ClickEvent, DismissEvent, EventEmitter, FocusHandle, Focusable, ParentElement, - Render, SharedString, Styled, div, + AnyElement, DismissEvent, EventEmitter, FocusHandle, Focusable, ParentElement, Render, + SharedString, Styled, }; - use ui::{Tooltip, prelude::*}; + use ui::prelude::*; - use crate::SuppressNotification; + use crate::notifications::NotificationFrame; use super::{Notification, SuppressEvent}; @@ -631,6 +761,8 @@ pub mod simple_message_notification { self } + /// Determines whether the given notification ID should be supressable + /// Suppressed motifications will not be shown anymor pub fn show_suppress_button(mut self, show: bool) -> Self { self.show_suppress_button = show; self @@ -647,71 +779,19 @@ pub mod simple_message_notification { impl Render for MessageNotification { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let show_suppress_button = self.show_suppress_button; - let suppress = show_suppress_button && window.modifiers().shift; - let (close_id, close_icon) = if suppress { - ("suppress", IconName::Minimize) - } else { - ("close", IconName::Close) - }; - - v_flex() - .occlude() - .p_3() - .gap_2() - .elevation_3(cx) - .child( - h_flex() - .gap_4() - .justify_between() - .items_start() - .child( - v_flex() - .gap_0p5() - .when_some(self.title.clone(), |element, title| { - element.child(Label::new(title)) - }) - .child(div().max_w_96().child((self.build_content)(window, cx))), - ) - .when(self.show_close_button, |this| { - this.on_modifiers_changed(cx.listener(|_, _, _, cx| cx.notify())) - .child( - IconButton::new(close_id, close_icon) - .tooltip(move |window, cx| { - if suppress { - Tooltip::for_action( - "Suppress.\nClose with click.", - &SuppressNotification, - window, - cx, - ) - } else if show_suppress_button { - Tooltip::for_action( - "Close.\nSuppress with shift-click.", - &menu::Cancel, - window, - cx, - ) - } else { - Tooltip::for_action( - "Close", - &menu::Cancel, - window, - cx, - ) - } - }) - .on_click(cx.listener(move |_, _: &ClickEvent, _, cx| { - if suppress { - cx.emit(SuppressEvent); - } else { - cx.emit(DismissEvent); - } - })), - ) - }), - ) - .child( + NotificationFrame::new() + .with_title(self.title.clone()) + .with_content((self.build_content)(window, cx)) + .show_close_button(self.show_close_button) + .show_suppress_button(self.show_suppress_button) + .on_close(cx.listener(|_, suppress, _, cx| { + if *suppress { + cx.emit(SuppressEvent); + } else { + cx.emit(DismissEvent); + } + })) + .with_suffix( h_flex() .gap_1() .children(self.primary_message.iter().map(|message| { diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 0c35752165..860a57c21f 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -1829,6 +1829,7 @@ impl Pane { let mode = self.nav_history.mode(); self.nav_history.set_mode(NavigationMode::ClosingItem); item.deactivated(window, cx); + item.on_removed(cx); self.nav_history.set_mode(mode); if self.is_active_preview_item(item.item_id()) { @@ -2061,6 +2062,8 @@ impl Pane { })? .await?; } else if can_save_as && is_singleton { + let suggested_name = + cx.update(|_window, cx| item.suggested_filename(cx).to_string())?; let new_path = pane.update_in(cx, |pane, window, cx| { pane.activate_item(item_ix, true, true, window, cx); pane.workspace.update(cx, |workspace, cx| { @@ -2072,7 +2075,7 @@ impl Pane { } else { DirectoryLister::Project(workspace.project().clone()) }; - workspace.prompt_for_new_path(lister, window, cx) + workspace.prompt_for_new_path(lister, Some(suggested_name), window, cx) }) })??; let Some(new_path) = new_path.await.ok().flatten().into_iter().flatten().next() @@ -2195,7 +2198,7 @@ impl Pane { fn update_status_bar(&mut self, window: &mut Window, cx: &mut Context) { let workspace = self.workspace.clone(); - let pane = cx.entity().clone(); + let pane = cx.entity(); window.defer(cx, move |window, cx| { let Ok(status_bar) = @@ -2276,7 +2279,7 @@ impl Pane { cx: &mut Context, ) { maybe!({ - let pane = cx.entity().clone(); + let pane = cx.entity(); let destination_index = match operation { PinOperation::Pin => self.pinned_tab_count.min(ix), @@ -2470,15 +2473,26 @@ impl Pane { .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail, is_active, ix, }, |tab, _, _, cx| cx.new(|_| tab.clone()), ) - .drag_over::(|tab, _, _, cx| { - tab.bg(cx.theme().colors().drop_target_background) + .drag_over::(move |tab, dragged_tab: &DraggedTab, _, cx| { + let mut styled_tab = tab + .bg(cx.theme().colors().drop_target_background) + .border_color(cx.theme().colors().drop_target_border) + .border_0(); + + if ix < dragged_tab.ix { + styled_tab = styled_tab.border_l_2(); + } else if ix > dragged_tab.ix { + styled_tab = styled_tab.border_r_2(); + } + + styled_tab }) .drag_over::(|tab, _, _, cx| { tab.bg(cx.theme().colors().drop_target_background) @@ -2818,7 +2832,7 @@ impl Pane { let navigate_backward = IconButton::new("navigate_backward", IconName::ArrowLeft) .icon_size(IconSize::Small) .on_click({ - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, window, cx| { entity.update(cx, |pane, cx| pane.navigate_backward(window, cx)) } @@ -2834,7 +2848,7 @@ impl Pane { let navigate_forward = IconButton::new("navigate_forward", IconName::ArrowRight) .icon_size(IconSize::Small) .on_click({ - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, window, cx| entity.update(cx, |pane, cx| pane.navigate_forward(window, cx)) }) .disabled(!self.can_navigate_forward()) @@ -3040,7 +3054,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let split_direction = self.drag_split_direction; let item_id = dragged_tab.item.item_id(); if let Some(preview_item_id) = self.preview_item_id { @@ -3149,7 +3163,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let split_direction = self.drag_split_direction; let project_entry_id = *project_entry_id; self.workspace @@ -3225,7 +3239,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let mut split_direction = self.drag_split_direction; let paths = paths.paths().to_vec(); let is_remote = self diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 6fa5c969e7..b2d1340a7b 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -542,6 +542,20 @@ define_connection! { ALTER TABLE breakpoints ADD COLUMN condition TEXT; ALTER TABLE breakpoints ADD COLUMN hit_condition TEXT; ), + sql!(CREATE TABLE toolchains2 ( + workspace_id INTEGER, + worktree_id INTEGER, + language_name TEXT NOT NULL, + name TEXT NOT NULL, + path TEXT NOT NULL, + raw_json TEXT NOT NULL, + relative_worktree_path TEXT NOT NULL, + PRIMARY KEY (workspace_id, worktree_id, language_name, relative_worktree_path)) STRICT; + INSERT INTO toolchains2 + SELECT * FROM toolchains; + DROP TABLE toolchains; + ALTER TABLE toolchains2 RENAME TO toolchains; + ) ]; } @@ -1428,12 +1442,12 @@ impl WorkspaceDb { self.write(move |conn| { let mut insert = conn .exec_bound(sql!( - INSERT INTO toolchains(workspace_id, worktree_id, relative_worktree_path, language_name, name, path) VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO toolchains(workspace_id, worktree_id, relative_worktree_path, language_name, name, path, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT DO UPDATE SET name = ?5, - path = ?6 - + path = ?6, + raw_json = ?7 )) .context("Preparing insertion")?; @@ -1444,6 +1458,7 @@ impl WorkspaceDb { toolchain.language_name.as_ref(), toolchain.name.as_ref(), toolchain.path.as_ref(), + toolchain.as_json.to_string(), ))?; Ok(()) diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index aab8a36f45..1eaa125ba5 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -15,6 +15,8 @@ mod toast_layer; mod toolbar; mod workspace_settings; +pub use crate::notifications::NotificationFrame; +pub use dock::Panel; pub use toast_layer::{ToastAction, ToastLayer, ToastView}; use anyhow::{Context as _, Result, anyhow}; @@ -24,7 +26,6 @@ use client::{ proto::{self, ErrorCode, PanelId, PeerId}, }; use collections::{HashMap, HashSet, hash_map}; -pub use dock::Panel; use dock::{Dock, DockPosition, PanelButtons, PanelHandle, RESIZE_HANDLE_SIZE}; use futures::{ Future, FutureExt, StreamExt, @@ -224,6 +225,8 @@ actions!( ResetActiveDockSize, /// Resets all open docks to their default sizes. ResetOpenDocksSize, + /// Reloads the application + Reload, /// Saves the current file with a new name. SaveAs, /// Saves without formatting. @@ -246,8 +249,6 @@ actions!( ToggleZoom, /// Stops following a collaborator. Unfollow, - /// Shows the welcome screen. - Welcome, /// Restores the banner. RestoreBanner, /// Toggles expansion of the selected item. @@ -340,14 +341,6 @@ pub struct CloseInactiveTabsAndPanes { #[action(namespace = workspace)] pub struct SendKeystrokes(pub String); -/// Reloads the active item or workspace. -#[derive(Clone, Deserialize, PartialEq, Default, JsonSchema, Action)] -#[action(namespace = workspace)] -#[serde(deny_unknown_fields)] -pub struct Reload { - pub binary_path: Option, -} - actions!( project_symbols, [ @@ -555,8 +548,8 @@ pub fn init(app_state: Arc, cx: &mut App) { toast_layer::init(cx); history_manager::init(cx); - cx.on_action(Workspace::close_global); - cx.on_action(reload); + cx.on_action(|_: &CloseWindow, cx| Workspace::close_global(cx)); + cx.on_action(|_: &Reload, cx| reload(cx)); cx.on_action({ let app_state = Arc::downgrade(&app_state); @@ -2074,6 +2067,7 @@ impl Workspace { pub fn prompt_for_new_path( &mut self, lister: DirectoryLister, + suggested_name: Option, window: &mut Window, cx: &mut Context, ) -> oneshot::Receiver>> { @@ -2101,7 +2095,7 @@ impl Workspace { }) .or_else(std::env::home_dir) .unwrap_or_else(|| PathBuf::from("")); - cx.prompt_for_new_path(&relative_to) + cx.prompt_for_new_path(&relative_to, suggested_name.as_deref()) })?; let abs_path = match abs_path.await? { Ok(path) => path, @@ -2184,7 +2178,7 @@ impl Workspace { } } - pub fn close_global(_: &CloseWindow, cx: &mut App) { + pub fn close_global(cx: &mut App) { cx.defer(|cx| { cx.windows().iter().find(|window| { window @@ -3886,9 +3880,7 @@ impl Workspace { local, focus_changed, } => { - cx.on_next_frame(window, |_, window, _| { - window.invalidate_character_coordinates(); - }); + window.invalidate_character_coordinates(); pane.update(cx, |pane, _| { pane.track_alternate_file_items(); @@ -3929,9 +3921,7 @@ impl Workspace { } } pane::Event::Focus => { - cx.on_next_frame(window, |_, window, _| { - window.invalidate_character_coordinates(); - }); + window.invalidate_character_coordinates(); self.handle_pane_focused(pane.clone(), window, cx); } pane::Event::ZoomIn => { @@ -6348,7 +6338,7 @@ impl Render for Workspace { .border_b_1() .border_color(colors.border) .child({ - let this = cx.entity().clone(); + let this = cx.entity(); canvas( move |bounds, window, cx| { this.update(cx, |this, cx| { @@ -6672,25 +6662,15 @@ impl Render for Workspace { } }) .children(self.zoomed.as_ref().and_then(|view| { - let zoomed_view = view.upgrade()?; - let div = div() + Some(div() .occlude() .absolute() .overflow_hidden() .border_color(colors.border) .bg(colors.background) - .child(zoomed_view) + .child(view.upgrade()?) .inset_0() - .shadow_lg(); - - Some(match self.zoomed_position { - Some(DockPosition::Left) => div.right_2().border_r_1(), - Some(DockPosition::Right) => div.left_2().border_l_1(), - Some(DockPosition::Bottom) => div.top_2().border_t_1(), - None => { - div.top_2().bottom_2().left_2().right_2().border_1() - } - }) + .shadow_lg()) })) .children(self.render_notifications(window, cx)), ) @@ -7642,7 +7622,7 @@ pub fn join_in_room_project( }) } -pub fn reload(reload: &Reload, cx: &mut App) { +pub fn reload(cx: &mut App) { let should_confirm = WorkspaceSettings::get_global(cx).confirm_quit; let mut workspace_windows = cx .windows() @@ -7669,7 +7649,6 @@ pub fn reload(reload: &Reload, cx: &mut App) { .ok(); } - let binary_path = reload.binary_path.clone(); cx.spawn(async move |cx| { if let Some(prompt) = prompt { let answer = prompt.await?; @@ -7688,8 +7667,7 @@ pub fn reload(reload: &Reload, cx: &mut App) { } } } - - cx.update(|cx| cx.restart(binary_path)) + cx.update(|cx| cx.restart()) }) .detach_and_log_err(cx); } diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs index ac116b2f8f..23cd5b9320 100644 --- a/crates/x_ai/src/x_ai.rs +++ b/crates/x_ai/src/x_ai.rs @@ -105,6 +105,10 @@ impl Model { } } + pub fn supports_prompt_cache_key(&self) -> bool { + false + } + pub fn supports_tool(&self) -> bool { match self { Self::Grok2Vision diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 5997e43864..d69efaf6c0 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.200.0" +version = "0.201.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] @@ -82,6 +82,7 @@ inspector_ui.workspace = true install_cli.workspace = true jj_ui.workspace = true journal.workspace = true +livekit_client.workspace = true language.workspace = true language_extension.workspace = true language_model.workspace = true @@ -157,7 +158,6 @@ vim_mode_setting.workspace = true watch.workspace = true web_search.workspace = true web_search_providers.workspace = true -welcome.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true diff --git a/crates/zed/resources/windows/app-icon-nightly.ico b/crates/zed/resources/windows/app-icon-nightly.ico index 165e4ce1f7..875c0d7b35 100644 Binary files a/crates/zed/resources/windows/app-icon-nightly.ico and b/crates/zed/resources/windows/app-icon-nightly.ico differ diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index e4a14b5d32..2a82f81b5b 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -8,6 +8,7 @@ use cli::FORCE_CLI_MODE_ENV_VAR_NAME; use client::{Client, ProxySettings, UserStore, parse_zed_link}; use collab_ui::channel_view::ChannelView; use collections::HashMap; +use crashes::InitCrashHandler; use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE}; use editor::Editor; use extension::ExtensionHostProxy; @@ -20,6 +21,7 @@ use gpui::{App, AppContext as _, Application, AsyncApp, Focusable as _, UpdateGl use gpui_tokio::Tokio; use http_client::{Url, read_proxy_from_env}; use language::LanguageRegistry; +use onboarding::{FIRST_OPEN, show_onboarding_view}; use prompt_store::PromptBuilder; use reqwest_client::ReqwestClient; @@ -44,7 +46,6 @@ use theme::{ }; use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; -use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ AppState, SerializedWorkspaceLocation, Toast, Workspace, WorkspaceSettings, WorkspaceStore, notifications::NotificationId, @@ -201,16 +202,6 @@ pub fn main() { return; } - // Check if there is a pending installer - // If there is, run the installer and exit - // And we don't want to run the installer if we are not the first instance - #[cfg(target_os = "windows")] - let is_first_instance = crate::zed::windows_only_instance::is_first_instance(); - #[cfg(target_os = "windows")] - if is_first_instance && auto_update::check_pending_installation() { - return; - } - if args.dump_all_actions { dump_all_gpui_actions(); return; @@ -261,7 +252,15 @@ pub fn main() { return; } - log::info!("========== starting zed =========="); + log::info!( + "========== starting zed version {}, sha {} ==========", + app_version, + app_commit_sha + .as_ref() + .map(|sha| sha.short()) + .as_deref() + .unwrap_or("unknown"), + ); let app = Application::new().with_assets(Assets); @@ -271,7 +270,15 @@ pub fn main() { let session = app.background_executor().block(Session::new()); app.background_executor() - .spawn(crashes::init(session_id.clone())) + .spawn(crashes::init(InitCrashHandler { + session_id: session_id.clone(), + zed_version: app_version.to_string(), + release_channel: release_channel::RELEASE_CHANNEL_NAME.clone(), + commit_sha: app_commit_sha + .as_ref() + .map(|sha| sha.full()) + .unwrap_or_else(|| "no sha".to_owned()), + })) .detach(); reliability::init_panic_hook( app_version, @@ -283,30 +290,27 @@ pub fn main() { let (open_listener, mut open_rx) = OpenListener::new(); - let failed_single_instance_check = - if *db::ZED_STATELESS || *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev { - false - } else { - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - { - crate::zed::listen_for_cli_connections(open_listener.clone()).is_err() - } + let failed_single_instance_check = if *db::ZED_STATELESS + || *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev + { + false + } else { + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + { + crate::zed::listen_for_cli_connections(open_listener.clone()).is_err() + } - #[cfg(target_os = "windows")] - { - !crate::zed::windows_only_instance::handle_single_instance( - open_listener.clone(), - &args, - is_first_instance, - ) - } + #[cfg(target_os = "windows")] + { + !crate::zed::windows_only_instance::handle_single_instance(open_listener.clone(), &args) + } - #[cfg(target_os = "macos")] - { - use zed::mac_only_instance::*; - ensure_only_instance() != IsOnlyInstance::Yes - } - }; + #[cfg(target_os = "macos")] + { + use zed::mac_only_instance::*; + ensure_only_instance() != IsOnlyInstance::Yes + } + }; if failed_single_instance_check { println!("zed is already running"); return; @@ -628,7 +632,6 @@ pub fn main() { feedback::init(cx); markdown_preview::init(cx); svg_preview::init(cx); - welcome::init(cx); onboarding::init(cx); settings_ui::init(cx); extensions_ui::init(cx); @@ -1049,7 +1052,7 @@ async fn restore_or_create_workspace(app_state: Arc, cx: &mut AsyncApp } } } else if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) { - cx.update(|cx| show_welcome_view(app_state, cx))?.await?; + cx.update(|cx| show_onboarding_view(app_state, cx))?.await?; } else { cx.update(|cx| { workspace::open_new( diff --git a/crates/zed/src/reliability.rs b/crates/zed/src/reliability.rs index 53539699cc..c27f4cb0a8 100644 --- a/crates/zed/src/reliability.rs +++ b/crates/zed/src/reliability.rs @@ -12,6 +12,7 @@ use gpui::{App, AppContext as _, SemanticVersion}; use http_client::{self, HttpClient, HttpClientWithUrl, HttpRequestExt, Method}; use paths::{crashes_dir, crashes_retired_dir}; use project::Project; +use proto::{CrashReport, GetCrashFilesResponse}; use release_channel::{AppCommitSha, RELEASE_CHANNEL, ReleaseChannel}; use reqwest::multipart::{Form, Part}; use settings::Settings; @@ -51,10 +52,6 @@ pub fn init_panic_hook( thread::yield_now(); } } - crashes::handle_panic(); - - let thread = thread::current(); - let thread_name = thread.name().unwrap_or(""); let payload = info .payload() @@ -63,6 +60,11 @@ pub fn init_panic_hook( .or_else(|| info.payload().downcast_ref::().cloned()) .unwrap_or_else(|| "Box".to_string()); + crashes::handle_panic(payload.clone(), info.location()); + + let thread = thread::current(); + let thread_name = thread.name().unwrap_or(""); + if *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev { let location = info.location().unwrap(); let backtrace = Backtrace::new(); @@ -214,45 +216,53 @@ pub fn init( let installation_id = installation_id.clone(); let system_id = system_id.clone(); - if let Some(ssh_client) = project.ssh_client() { - ssh_client.update(cx, |client, cx| { - if TelemetrySettings::get_global(cx).diagnostics { - let request = client.proto_client().request(proto::GetCrashFiles {}); - cx.background_spawn(async move { - let crash_files = request.await?; - for crash in crash_files.crashes { - let mut panic: Option = crash - .panic_contents - .and_then(|s| serde_json::from_str(&s).log_err()); + let Some(ssh_client) = project.ssh_client() else { + return; + }; + ssh_client.update(cx, |client, cx| { + if !TelemetrySettings::get_global(cx).diagnostics { + return; + } + let request = client.proto_client().request(proto::GetCrashFiles {}); + cx.background_spawn(async move { + let GetCrashFilesResponse { + legacy_panics, + crashes, + } = request.await?; - if let Some(panic) = panic.as_mut() { - panic.session_id = session_id.clone(); - panic.system_id = system_id.clone(); - panic.installation_id = installation_id.clone(); - } - - if let Some(minidump) = crash.minidump_contents { - upload_minidump( - http_client.clone(), - minidump.clone(), - panic.as_ref(), - ) - .await - .log_err(); - } - - if let Some(panic) = panic { - upload_panic(&http_client, &panic_report_url, panic, &mut None) - .await?; - } - } - - anyhow::Ok(()) - }) - .detach_and_log_err(cx); + for panic in legacy_panics { + if let Some(mut panic) = serde_json::from_str::(&panic).log_err() { + panic.session_id = session_id.clone(); + panic.system_id = system_id.clone(); + panic.installation_id = installation_id.clone(); + upload_panic(&http_client, &panic_report_url, panic, &mut None).await?; + } } + + let Some(endpoint) = MINIDUMP_ENDPOINT.as_ref() else { + return Ok(()); + }; + for CrashReport { + metadata, + minidump_contents, + } in crashes + { + if let Some(metadata) = serde_json::from_str(&metadata).log_err() { + upload_minidump( + http_client.clone(), + endpoint, + minidump_contents, + &metadata, + ) + .await + .log_err(); + } + } + + anyhow::Ok(()) }) - } + .detach_and_log_err(cx); + }) }) .detach(); } @@ -466,16 +476,18 @@ fn upload_panics_and_crashes( installation_id: Option, cx: &App, ) { - let telemetry_settings = *client::TelemetrySettings::get_global(cx); + if !client::TelemetrySettings::get_global(cx).diagnostics { + return; + } cx.background_spawn(async move { - let most_recent_panic = - upload_previous_panics(http.clone(), &panic_report_url, telemetry_settings) - .await - .log_err() - .flatten(); - upload_previous_crashes(http, most_recent_panic, installation_id, telemetry_settings) + upload_previous_minidumps(http.clone()).await.warn_on_err(); + let most_recent_panic = upload_previous_panics(http.clone(), &panic_report_url) .await .log_err() + .flatten(); + upload_previous_crashes(http, most_recent_panic, installation_id) + .await + .log_err(); }) .detach() } @@ -484,7 +496,6 @@ fn upload_panics_and_crashes( async fn upload_previous_panics( http: Arc, panic_report_url: &Url, - telemetry_settings: client::TelemetrySettings, ) -> anyhow::Result> { let mut children = smol::fs::read_dir(paths::logs_dir()).await?; @@ -507,58 +518,41 @@ async fn upload_previous_panics( continue; } - if telemetry_settings.diagnostics { - let panic_file_content = smol::fs::read_to_string(&child_path) - .await - .context("error reading panic file")?; + let panic_file_content = smol::fs::read_to_string(&child_path) + .await + .context("error reading panic file")?; - let panic: Option = serde_json::from_str(&panic_file_content) - .log_err() - .or_else(|| { - panic_file_content - .lines() - .next() - .and_then(|line| serde_json::from_str(line).ok()) - }) - .unwrap_or_else(|| { - log::error!("failed to deserialize panic file {:?}", panic_file_content); - None - }); + let panic: Option = serde_json::from_str(&panic_file_content) + .log_err() + .or_else(|| { + panic_file_content + .lines() + .next() + .and_then(|line| serde_json::from_str(line).ok()) + }) + .unwrap_or_else(|| { + log::error!("failed to deserialize panic file {:?}", panic_file_content); + None + }); - if let Some(panic) = panic { - let minidump_path = paths::logs_dir() - .join(&panic.session_id) - .with_extension("dmp"); - if minidump_path.exists() { - let minidump = smol::fs::read(&minidump_path) - .await - .context("Failed to read minidump")?; - if upload_minidump(http.clone(), minidump, Some(&panic)) - .await - .log_err() - .is_some() - { - fs::remove_file(minidump_path).ok(); - } - } - - if !upload_panic(&http, &panic_report_url, panic, &mut most_recent_panic).await? { - continue; - } - } + if let Some(panic) = panic + && upload_panic(&http, &panic_report_url, panic, &mut most_recent_panic).await? + { + // We've done what we can, delete the file + fs::remove_file(child_path) + .context("error removing panic") + .log_err(); } - - // We've done what we can, delete the file - fs::remove_file(child_path) - .context("error removing panic") - .log_err(); } - if MINIDUMP_ENDPOINT.is_none() { - return Ok(most_recent_panic); - } + Ok(most_recent_panic) +} + +pub async fn upload_previous_minidumps(http: Arc) -> anyhow::Result<()> { + let Some(minidump_endpoint) = MINIDUMP_ENDPOINT.as_ref() else { + return Err(anyhow::anyhow!("Minidump endpoint not set")); + }; - // loop back over the directory again to upload any minidumps that are missing panics let mut children = smol::fs::read_dir(paths::logs_dir()).await?; while let Some(child) = children.next().await { let child = child?; @@ -566,33 +560,35 @@ async fn upload_previous_panics( if child_path.extension() != Some(OsStr::new("dmp")) { continue; } - if upload_minidump( - http.clone(), - smol::fs::read(&child_path) - .await - .context("Failed to read minidump")?, - None, - ) - .await - .log_err() - .is_some() - { - fs::remove_file(child_path).ok(); + let mut json_path = child_path.clone(); + json_path.set_extension("json"); + if let Ok(metadata) = serde_json::from_slice(&smol::fs::read(&json_path).await?) { + if upload_minidump( + http.clone(), + &minidump_endpoint, + smol::fs::read(&child_path) + .await + .context("Failed to read minidump")?, + &metadata, + ) + .await + .log_err() + .is_some() + { + fs::remove_file(child_path).ok(); + fs::remove_file(json_path).ok(); + } } } - - Ok(most_recent_panic) + Ok(()) } async fn upload_minidump( http: Arc, + endpoint: &str, minidump: Vec, - panic: Option<&Panic>, + metadata: &crashes::CrashInfo, ) -> Result<()> { - let minidump_endpoint = MINIDUMP_ENDPOINT - .to_owned() - .ok_or_else(|| anyhow::anyhow!("Minidump endpoint not set"))?; - let mut form = Form::new() .part( "upload_file_minidump", @@ -600,18 +596,22 @@ async fn upload_minidump( .file_name("minidump.dmp") .mime_str("application/octet-stream")?, ) + .text( + "sentry[tags][channel]", + metadata.init.release_channel.clone(), + ) + .text("sentry[tags][version]", metadata.init.zed_version.clone()) + .text("sentry[release]", metadata.init.commit_sha.clone()) .text("platform", "rust"); - if let Some(panic) = panic { - form = form - .text( - "sentry[release]", - format!("{}-{}", panic.release_channel, panic.app_version), - ) - .text("sentry[logentry][formatted]", panic.payload.clone()); + if let Some(panic_info) = metadata.panic.as_ref() { + form = form.text("sentry[logentry][formatted]", panic_info.message.clone()); + form = form.text("span", panic_info.span.clone()); + // TODO: add gpu-context, feature-flag-context, and more of device-context like gpu + // name, screen resolution, available ram, device model, etc } let mut response_text = String::new(); - let mut response = http.send_multipart_form(&minidump_endpoint, form).await?; + let mut response = http.send_multipart_form(endpoint, form).await?; response .body_mut() .read_to_string(&mut response_text) @@ -661,11 +661,7 @@ async fn upload_previous_crashes( http: Arc, most_recent_panic: Option<(i64, String)>, installation_id: Option, - telemetry_settings: client::TelemetrySettings, ) -> Result<()> { - if !telemetry_settings.diagnostics { - return Ok(()); - } let last_uploaded = KEY_VALUE_STORE .read_kvp(LAST_CRASH_UPLOADED)? .unwrap_or("zed-2024-01-17-221900.ips".to_string()); // don't upload old crash reports from before we had this. diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 8c89a7d85a..cfafbb70f0 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -31,9 +31,12 @@ use gpui::{ px, retain_all, }; use image_viewer::ImageInfo; +use language::Capability; use language_tools::lsp_tool::{self, LspTool}; use migrate::{MigrationBanner, MigrationEvent, MigrationNotification, MigrationType}; use migrator::{migrate_keymap, migrate_settings}; +use onboarding::DOCS_URL; +use onboarding::multibuffer_hint::MultibufferHint; pub use open_listener::*; use outline_panel::OutlinePanel; use paths::{ @@ -54,6 +57,7 @@ use settings::{ initial_local_debug_tasks_content, initial_project_settings_content, initial_tasks_content, update_settings_file, }; +use std::time::{Duration, Instant}; use std::{ borrow::Cow, path::{Path, PathBuf}, @@ -67,14 +71,17 @@ use util::markdown::MarkdownString; use util::{ResultExt, asset_str}; use uuid::Uuid; use vim_mode_setting::VimModeSetting; -use welcome::{DOCS_URL, MultibufferHint}; -use workspace::notifications::{NotificationId, dismiss_app_notification, show_app_notification}; +use workspace::notifications::{ + NotificationId, SuppressEvent, dismiss_app_notification, show_app_notification, +}; use workspace::{ AppState, NewFile, NewWindow, OpenLog, Toast, Workspace, WorkspaceSettings, create_and_open_local_file, notifications::simple_message_notification::MessageNotification, open_new, }; -use workspace::{CloseIntent, CloseWindow, RestoreBanner, with_active_or_new_workspace}; +use workspace::{ + CloseIntent, CloseWindow, NotificationFrame, RestoreBanner, with_active_or_new_workspace, +}; use workspace::{Pane, notifications::DetachAndPromptErr}; use zed_actions::{ OpenAccountSettings, OpenBrowser, OpenDocs, OpenServerSettings, OpenSettings, OpenZedUrl, Quit, @@ -116,6 +123,14 @@ actions!( ] ); +actions!( + dev, + [ + /// Record 10s of audio from your current microphone + CaptureAudio + ] +); + pub fn init(cx: &mut App) { #[cfg(target_os = "macos")] cx.on_action(|_: &Hide, cx| cx.hide()); @@ -305,7 +320,7 @@ pub fn initialize_workspace( return; }; - let workspace_handle = cx.entity().clone(); + let workspace_handle = cx.entity(); let center_pane = workspace.active_pane().clone(); initialize_pane(workspace, ¢er_pane, window, cx); @@ -701,9 +716,7 @@ fn register_actions( .insert(theme::clamp_font_size(ui_font_size).0); }); } else { - theme::adjust_ui_font_size(cx, |size| { - *size += px(1.0); - }); + theme::adjust_ui_font_size(cx, |size| size + px(1.0)); } } }) @@ -718,9 +731,7 @@ fn register_actions( .insert(theme::clamp_font_size(ui_font_size).0); }); } else { - theme::adjust_ui_font_size(cx, |size| { - *size -= px(1.0); - }); + theme::adjust_ui_font_size(cx, |size| size - px(1.0)); } } }) @@ -748,9 +759,7 @@ fn register_actions( .insert(theme::clamp_font_size(buffer_font_size).0); }); } else { - theme::adjust_buffer_font_size(cx, |size| { - *size += px(1.0); - }); + theme::adjust_buffer_font_size(cx, |size| size + px(1.0)); } } }) @@ -766,9 +775,7 @@ fn register_actions( .insert(theme::clamp_font_size(buffer_font_size).0); }); } else { - theme::adjust_buffer_font_size(cx, |size| { - *size -= px(1.0); - }); + theme::adjust_buffer_font_size(cx, |size| size - px(1.0)); } } }) @@ -896,7 +903,11 @@ fn register_actions( .detach(); } } + }) + .register_action(|workspace, _: &CaptureAudio, window, cx| { + capture_audio(workspace, window, cx); }); + if workspace.project().read(cx).is_via_ssh() { workspace.register_action({ move |workspace, _: &OpenServerSettings, window, cx| { @@ -1746,7 +1757,11 @@ fn open_bundled_file( workspace.with_local_workspace(window, cx, |workspace, window, cx| { let project = workspace.project(); let buffer = project.update(cx, move |project, cx| { - project.create_local_buffer(text.as_ref(), language, cx) + let buffer = project.create_local_buffer(text.as_ref(), language, cx); + buffer.update(cx, |buffer, cx| { + buffer.set_capability(Capability::ReadOnly, cx); + }); + buffer }); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx).with_title(title.into())); @@ -1805,6 +1820,107 @@ fn open_settings_file( .detach_and_log_err(cx); } +fn capture_audio(workspace: &mut Workspace, _: &mut Window, cx: &mut Context) { + #[derive(Default)] + enum State { + Recording(livekit_client::CaptureInput), + Failed(String), + Finished(PathBuf), + // Used during state switch. Should never occur naturally. + #[default] + Invalid, + } + + struct CaptureAudioNotification { + focus_handle: gpui::FocusHandle, + start_time: Instant, + state: State, + } + + impl gpui::EventEmitter for CaptureAudioNotification {} + impl gpui::EventEmitter for CaptureAudioNotification {} + impl gpui::Focusable for CaptureAudioNotification { + fn focus_handle(&self, _cx: &App) -> gpui::FocusHandle { + self.focus_handle.clone() + } + } + impl workspace::notifications::Notification for CaptureAudioNotification {} + + const AUDIO_RECORDING_TIME_SECS: u64 = 10; + + impl Render for CaptureAudioNotification { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let elapsed = self.start_time.elapsed().as_secs(); + let message = match &self.state { + State::Recording(capture) => format!( + "Recording {} seconds of audio from input: '{}'", + AUDIO_RECORDING_TIME_SECS - elapsed, + capture.name, + ), + State::Failed(e) => format!("Error capturing audio: {e}"), + State::Finished(path) => format!("Audio recorded to {}", path.display()), + State::Invalid => "Error invalid state".to_string(), + }; + + NotificationFrame::new() + .with_title(Some("Recording Audio")) + .show_suppress_button(false) + .on_close(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })) + .with_content(message) + } + } + + impl CaptureAudioNotification { + fn finish(&mut self) { + let state = std::mem::take(&mut self.state); + self.state = if let State::Recording(capture) = state { + match capture.finish() { + Ok(path) => State::Finished(path), + Err(e) => State::Failed(e.to_string()), + } + } else { + state + }; + } + + fn new(cx: &mut Context) -> Self { + cx.spawn(async move |this, cx| { + for _ in 0..10 { + cx.background_executor().timer(Duration::from_secs(1)).await; + this.update(cx, |_, cx| { + cx.notify(); + })?; + } + + this.update(cx, |this, cx| { + this.finish(); + cx.notify(); + })?; + + anyhow::Ok(()) + }) + .detach(); + + let state = match livekit_client::CaptureInput::start() { + Ok(capture_input) => State::Recording(capture_input), + Err(err) => State::Failed(format!("Error starting audio capture: {}", err)), + }; + + Self { + focus_handle: cx.focus_handle(), + start_time: Instant::now(), + state, + } + } + } + + workspace.show_notification(NotificationId::unique::(), cx, |cx| { + cx.new(CaptureAudioNotification::new) + }); +} + #[cfg(test)] mod tests { use super::*; @@ -3975,7 +4091,6 @@ mod tests { client::init(&app_state.client, cx); language::init(cx); workspace::init(app_state.clone(), cx); - welcome::init(cx); onboarding::init(cx); Project::init_settings(cx); app_state @@ -4380,7 +4495,6 @@ mod tests { "toolchain", "variable_list", "vim", - "welcome", "workspace", "zed", "zed_predict_onboarding", @@ -4402,11 +4516,11 @@ mod tests { cx.text_system() .add_fonts(vec![ Assets - .load("fonts/plex-mono/ZedPlexMono-Regular.ttf") + .load("fonts/lilex/Lilex-Regular.ttf") .unwrap() .unwrap(), Assets - .load("fonts/plex-sans/ZedPlexSans-Regular.ttf") + .load("fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf") .unwrap() .unwrap(), ]) @@ -4426,6 +4540,43 @@ mod tests { assert!(has_default_theme); } + #[gpui::test] + async fn test_bundled_files_editor(cx: &mut TestAppContext) { + let app_state = init_test(cx); + cx.update(init); + + let project = Project::test(app_state.fs.clone(), [], cx).await; + let _window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx)); + + cx.update(|cx| { + cx.dispatch_action(&OpenDefaultSettings); + }); + cx.run_until_parked(); + + assert_eq!(cx.read(|cx| cx.windows().len()), 1); + + let workspace = cx.windows()[0].downcast::().unwrap(); + let active_editor = workspace + .update(cx, |workspace, _, cx| { + workspace.active_item_as::(cx) + }) + .unwrap(); + assert!( + active_editor.is_some(), + "Settings action should have opened an editor with the default file contents" + ); + + let active_editor = active_editor.unwrap(); + assert!( + active_editor.read_with(cx, |editor, cx| editor.read_only(cx)), + "Default settings should be readonly" + ); + assert!( + active_editor.read_with(cx, |editor, cx| editor.buffer().read(cx).read_only()), + "The underlying buffer should also be readonly for the shipped default settings" + ); + } + #[gpui::test] async fn test_bundled_languages(cx: &mut TestAppContext) { env_logger::builder().is_test(true).try_init().ok(); diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 15d5659f03..6c7ab0b374 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -35,10 +35,8 @@ pub fn app_menus() -> Vec { ], }), MenuItem::separator(), - MenuItem::submenu(Menu { - name: "Services".into(), - items: vec![], - }), + #[cfg(target_os = "macos")] + MenuItem::os_submenu("Services", gpui::SystemMenuType::Services), MenuItem::separator(), MenuItem::action("Extensions", zed_actions::Extensions::default()), MenuItem::action("Install CLI", install_cli::Install), @@ -252,7 +250,7 @@ pub fn app_menus() -> Vec { ), MenuItem::action("View Telemetry", zed_actions::OpenTelemetryLog), MenuItem::action("View Dependency Licenses", zed_actions::OpenLicenses), - MenuItem::action("Show Welcome", workspace::Welcome), + MenuItem::action("Show Welcome", onboarding::ShowWelcome), MenuItem::action("Give Feedback...", zed_actions::feedback::GiveFeedback), MenuItem::separator(), MenuItem::action( diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index ac889a7ad9..4609ecce9b 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -761,7 +761,7 @@ impl Render for ComponentPreview { ) .track_scroll(self.nav_scroll_handle.clone()) .p_2p5() - .w(px(229.)) + .w(px(231.)) // Matches perfectly with the size of the "Component Preview" tab, if that's the first one in the pane .h_full() .flex_1(), ) diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index da4b6e78c6..5b0826413b 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -229,8 +229,7 @@ fn assign_edit_prediction_provider( if let Some(file) = buffer.read(cx).file() { let id = file.worktree_id(cx); if let Some(inner_worktree) = editor - .project - .as_ref() + .project() .and_then(|project| project.read(cx).worktree_for_id(id, cx)) { worktree = Some(inner_worktree); diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index 2fd9b0a68c..82d3795e94 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -15,6 +15,8 @@ use futures::{FutureExt, SinkExt, StreamExt}; use git_ui::file_diff_view::FileDiffView; use gpui::{App, AsyncApp, Global, WindowHandle}; use language::Point; +use onboarding::FIRST_OPEN; +use onboarding::show_onboarding_view; use recent_projects::{SshSettings, open_ssh_project}; use remote::SshConnectionOptions; use settings::Settings; @@ -24,7 +26,6 @@ use std::thread; use std::time::Duration; use util::ResultExt; use util::paths::PathWithPosition; -use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::item::ItemHandle; use workspace::{AppState, OpenOptions, SerializedWorkspaceLocation, Workspace}; @@ -378,7 +379,7 @@ async fn open_workspaces( if grouped_locations.is_empty() { // If we have no paths to open, show the welcome screen if this is the first launch if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) { - cx.update(|cx| show_welcome_view(app_state, cx).detach()) + cx.update(|cx| show_onboarding_view(app_state, cx).detach()) .log_err(); } // If not the first launch, show an empty window with empty editor diff --git a/crates/zed/src/zed/quick_action_bar.rs b/crates/zed/src/zed/quick_action_bar.rs index e76bef59a3..2b7c38f997 100644 --- a/crates/zed/src/zed/quick_action_bar.rs +++ b/crates/zed/src/zed/quick_action_bar.rs @@ -140,7 +140,7 @@ impl Render for QuickActionBar { let search_button = editor.is_singleton(cx).then(|| { QuickActionBarButton::new( "toggle buffer search", - IconName::MagnifyingGlass, + search::SEARCH_ICON, !self.buffer_search_bar.read(cx).is_dismissed(), Box::new(buffer_search::Deploy::find()), focus_handle.clone(), diff --git a/crates/zed/src/zed/quick_action_bar/repl_menu.rs b/crates/zed/src/zed/quick_action_bar/repl_menu.rs index 5d1a6c8887..ca180dccdd 100644 --- a/crates/zed/src/zed/quick_action_bar/repl_menu.rs +++ b/crates/zed/src/zed/quick_action_bar/repl_menu.rs @@ -216,7 +216,7 @@ impl QuickActionBar { .size(IconSize::XSmall) .color(Color::Muted), ) - .width(rems(1.).into()) + .width(rems(1.)) .disabled(menu_state.popover_disabled), Tooltip::text("REPL Menu"), ); diff --git a/crates/zed/src/zed/windows_only_instance.rs b/crates/zed/src/zed/windows_only_instance.rs index 277e8ee724..bd62dea75a 100644 --- a/crates/zed/src/zed/windows_only_instance.rs +++ b/crates/zed/src/zed/windows_only_instance.rs @@ -25,7 +25,8 @@ use windows::{ use crate::{Args, OpenListener, RawOpenRequest}; -pub fn is_first_instance() -> bool { +#[inline] +fn is_first_instance() -> bool { unsafe { CreateMutexW( None, @@ -37,7 +38,8 @@ pub fn is_first_instance() -> bool { unsafe { GetLastError() != ERROR_ALREADY_EXISTS } } -pub fn handle_single_instance(opener: OpenListener, args: &Args, is_first_instance: bool) -> bool { +pub fn handle_single_instance(opener: OpenListener, args: &Args) -> bool { + let is_first_instance = is_first_instance(); if is_first_instance { // We are the first instance, listen for messages sent from other instances std::thread::spawn(move || { diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 64891b6973..9455369e9a 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -285,10 +285,6 @@ pub mod agent { ResetOnboarding, /// Starts a chat conversation with the agent. Chat, - /// Displays the previous message in the history. - PreviousHistoryMessage, - /// Displays the next message in the history. - NextHistoryMessage, /// Toggles the language model selector dropdown. #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])] ToggleModelSelector diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 9f1d02b790..ee76308ff3 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -26,6 +26,7 @@ collections.workspace = true command_palette_hooks.workspace = true copilot.workspace = true db.workspace = true +edit_prediction.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true @@ -33,13 +34,13 @@ futures.workspace = true gpui.workspace = true http_client.workspace = true indoc.workspace = true -edit_prediction.workspace = true language.workspace = true language_model.workspace = true log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true +rand.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true diff --git a/crates/zeta/src/input_excerpt.rs b/crates/zeta/src/input_excerpt.rs index 5949e713e9..8ca6d39407 100644 --- a/crates/zeta/src/input_excerpt.rs +++ b/crates/zeta/src/input_excerpt.rs @@ -9,7 +9,6 @@ use std::{fmt::Write, ops::Range}; pub struct InputExcerpt { pub editable_range: Range, pub prompt: String, - pub speculated_output: String, } pub fn excerpt_for_cursor_position( @@ -46,7 +45,6 @@ pub fn excerpt_for_cursor_position( let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); let mut prompt = String::new(); - let mut speculated_output = String::new(); writeln!(&mut prompt, "```{path}").unwrap(); if context_range.start == Point::zero() { @@ -58,12 +56,6 @@ pub fn excerpt_for_cursor_position( } push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); - push_editable_range( - position, - snapshot, - editable_range.clone(), - &mut speculated_output, - ); for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { prompt.push_str(chunk.text); @@ -73,7 +65,6 @@ pub fn excerpt_for_cursor_position( InputExcerpt { editable_range, prompt, - speculated_output, } } diff --git a/crates/zeta/src/license_detection.rs b/crates/zeta/src/license_detection.rs index c55f8d5d08..fa1eabf524 100644 --- a/crates/zeta/src/license_detection.rs +++ b/crates/zeta/src/license_detection.rs @@ -14,7 +14,7 @@ use util::ResultExt as _; use worktree::ChildEntriesOptions; /// Matches the most common license locations, with US and UK English spelling. -const LICENSE_FILE_NAME_REGEX: LazyLock = LazyLock::new(|| { +static LICENSE_FILE_NAME_REGEX: LazyLock = LazyLock::new(|| { regex::bytes::RegexBuilder::new( "^ \ (?: license | licence) \ @@ -29,7 +29,7 @@ const LICENSE_FILE_NAME_REGEX: LazyLock = LazyLock::new(|| }); fn is_license_eligible_for_data_collection(license: &str) -> bool { - const LICENSE_REGEXES: LazyLock> = LazyLock::new(|| { + static LICENSE_REGEXES: LazyLock> = LazyLock::new(|| { [ include_str!("license_detection/apache.regex"), include_str!("license_detection/isc.regex"), @@ -47,7 +47,7 @@ fn is_license_eligible_for_data_collection(license: &str) -> bool { /// Canonicalizes the whitespace of license text and license regexes. fn canonicalize_license_text(license: &str) -> String { - const PARAGRAPH_SEPARATOR_REGEX: LazyLock = + static PARAGRAPH_SEPARATOR_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\s*\n\s*\n\s*").unwrap()); PARAGRAPH_SEPARATOR_REGEX diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 1ddbd25cb8..1a6a8c2934 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -37,7 +37,6 @@ use release_channel::AppVersion; use settings::WorktreeId; use std::str::FromStr; use std::{ - borrow::Cow, cmp, fmt::Write, future::Future, @@ -66,6 +65,7 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch const MAX_CONTEXT_TOKENS: usize = 150; const MAX_REWRITE_TOKENS: usize = 350; const MAX_EVENT_TOKENS: usize = 500; +const MAX_DIAGNOSTIC_GROUPS: usize = 10; /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; @@ -429,6 +429,7 @@ impl Zeta { body, editable_range, } = gather_task.await?; + let done_gathering_context_at = Instant::now(); log::debug!( "Events:\n{}\nExcerpt:\n{:?}", @@ -481,6 +482,7 @@ impl Zeta { } }; + let received_response_at = Instant::now(); log::debug!("completion response: {}", &response.output_excerpt); if let Some(usage) = usage { @@ -492,7 +494,7 @@ impl Zeta { .ok(); } - Self::process_completion_response( + let edit_prediction = Self::process_completion_response( response, buffer, &snapshot, @@ -505,7 +507,25 @@ impl Zeta { buffer_snapshotted_at, &cx, ) - .await + .await; + + let finished_at = Instant::now(); + + // record latency for ~1% of requests + if rand::random::() <= 2 { + telemetry::event!( + "Edit Prediction Request", + context_latency = done_gathering_context_at + .duration_since(buffer_snapshotted_at) + .as_millis(), + request_latency = received_response_at + .duration_since(done_gathering_context_at) + .as_millis(), + process_latency = finished_at.duration_since(received_response_at).as_millis() + ); + } + + edit_prediction }) } @@ -1175,7 +1195,9 @@ pub fn gather_context( cx.background_spawn({ let snapshot = snapshot.clone(); async move { - let diagnostic_groups = if diagnostic_groups.is_empty() { + let diagnostic_groups = if diagnostic_groups.is_empty() + || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS + { None } else { Some(diagnostic_groups) @@ -1189,21 +1211,16 @@ pub fn gather_context( MAX_CONTEXT_TOKENS, ); let input_events = make_events_prompt(); - let input_outline = if can_collect_data { - prompt_for_outline(&snapshot) - } else { - String::new() - }; let editable_range = input_excerpt.editable_range.to_offset(&snapshot); let body = PredictEditsBody { input_events, input_excerpt: input_excerpt.prompt, - speculated_output: Some(input_excerpt.speculated_output), - outline: Some(input_outline), can_collect_data, diagnostic_groups, git_info, + outline: None, + speculated_output: None, }; Ok(GatherContextOutput { @@ -1214,32 +1231,6 @@ pub fn gather_context( }) } -fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { - let mut input_outline = String::new(); - - writeln!( - input_outline, - "```{}", - snapshot - .file() - .map_or(Cow::Borrowed("untitled"), |file| file - .path() - .to_string_lossy()) - ) - .unwrap(); - - if let Some(outline) = snapshot.outline(None) { - for item in &outline.items { - let spacing = " ".repeat(item.depth); - writeln!(input_outline, "{}{}", spacing, item.text).unwrap(); - } - } - - writeln!(input_outline, "```").unwrap(); - - input_outline -} - fn prompt_for_events(events: &VecDeque, mut remaining_tokens: usize) -> String { let mut result = String::new(); for event in events.iter().rev() { diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index fc936d6bd0..c7af36f431 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -21,6 +21,7 @@ - [Icon Themes](./icon-themes.md) - [Visual Customization](./visual-customization.md) - [Vim Mode](./vim.md) +- [Helix Mode](./helix.md) diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index 8fdb7ea325..58c9230760 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -6,29 +6,29 @@ You can do that by either subscribing to [one of Zed's plans](./plans-and-usage. ## Use Your Own Keys {#use-your-own-keys} -If you already have an API key for an existing LLM provider—say Anthropic or OpenAI, for example—you can insert them in Zed and use the Agent Panel **_for free_**. +If you already have an API key for an existing LLM provider—say Anthropic or OpenAI, for example—you can insert them into Zed and use the full power of the Agent Panel **_for free_**. -You can add your API key to a given provider either via the Agent Panel's settings UI or directly via the `settings.json` through the `language_models` key. +To add an existing API key to a given provider, go to the Agent Panel settings (`agent: open settings`), look for the desired provider, paste the key into the input, and hit enter. + +> Note: API keys are _not_ stored as plain text in your `settings.json`, but rather in your OS's secure credential storage. ## Supported Providers Here's all the supported LLM providers for which you can use your own API keys: -| Provider | -| ----------------------------------------------- | -| [Amazon Bedrock](#amazon-bedrock) | -| [Anthropic](#anthropic) | -| [DeepSeek](#deepseek) | -| [GitHub Copilot Chat](#github-copilot-chat) | -| [Google AI](#google-ai) | -| [LM Studio](#lmstudio) | -| [Mistral](#mistral) | -| [Ollama](#ollama) | -| [OpenAI](#openai) | -| [OpenAI API Compatible](#openai-api-compatible) | -| [OpenRouter](#openrouter) | -| [Vercel](#vercel-v0) | -| [xAI](#xai) | +- [Amazon Bedrock](#amazon-bedrock) +- [Anthropic](#anthropic) +- [DeepSeek](#deepseek) +- [GitHub Copilot Chat](#github-copilot-chat) +- [Google AI](#google-ai) +- [LM Studio](#lmstudio) +- [Mistral](#mistral) +- [Ollama](#ollama) +- [OpenAI](#openai) +- [OpenAI API Compatible](#openai-api-compatible) +- [OpenRouter](#openrouter) +- [Vercel](#vercel-v0) +- [xAI](#xai) ### Amazon Bedrock {#amazon-bedrock} @@ -391,27 +391,27 @@ Zed will also use the `OPENAI_API_KEY` environment variable if it's defined. #### Custom Models {#openai-custom-models} -The Zed agent comes pre-configured to use the latest version for common models (GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, GPT-4o, GPT-4o mini). -To use alternate models, perhaps a preview release or a dated model release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`: +The Zed agent comes pre-configured to use the latest version for common models (GPT-5, GPT-5 mini, o4-mini, GPT-4.1, and others). +To use alternate models, perhaps a preview release, or if you wish to control the request parameters, you can do so by adding the following to your Zed `settings.json`: ```json { "language_models": { "openai": { "available_models": [ + { + "name": "gpt-5", + "display_name": "gpt-5 high", + "reasoning_effort": "high", + "max_tokens": 272000, + "max_completion_tokens": 20000 + }, { "name": "gpt-4o-2024-08-06", "display_name": "GPT 4o Summer 2024", "max_tokens": 128000 - }, - { - "name": "o1-mini", - "display_name": "o1-mini", - "max_tokens": 128000, - "max_completion_tokens": 20000 } - ], - "version": "1" + ] } } } diff --git a/docs/src/ai/models.md b/docs/src/ai/models.md index b40f17b77f..8d46d0b8d1 100644 --- a/docs/src/ai/models.md +++ b/docs/src/ai/models.md @@ -12,8 +12,10 @@ We’re working hard to expand the models supported by Zed’s subscription offe | Claude Sonnet 4 | Anthropic | ✅ | 200k | N/A | $0.05 | | Claude Opus 4 | Anthropic | ❌ | 120k | $0.20 | N/A | | Claude Opus 4 | Anthropic | ✅ | 200k | N/A | $0.25 | +| Claude Opus 4.1 | Anthropic | ❌ | 120k | $0.20 | N/A | +| Claude Opus 4.1 | Anthropic | ✅ | 200k | N/A | $0.25 | -> Note: Because of the 5x token cost for [Opus relative to Sonnet](https://www.anthropic.com/pricing#api), each Opus prompt consumes 5 prompts against your billing meter +> Note: Because of the 5x token cost for [Opus relative to Sonnet](https://www.anthropic.com/pricing#api), each Opus 4 and 4.1 prompt consumes 5 prompts against your billing meter ## Usage {#usage} diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index 1996e1c4ee..9d56130256 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -294,11 +294,11 @@ Define extensions which should be installed (`true`) or never installed (`false` - Description: The name of a font to use for rendering text in the editor. - Setting: `buffer_font_family` -- Default: `Zed Plex Mono` +- Default: `.ZedMono`. This currently aliases to [Lilex](https://lilex.myrt.co). **Options** -The name of any font family installed on the user's system +The name of any font family installed on the user's system, or `".ZedMono"`. ## Buffer Font Features @@ -1284,6 +1284,7 @@ Each option controls displaying of a particular toolbar element. If all elements ```json "status_bar": { "active_language_button": true, + "cursor_position_button": true }, ``` @@ -3195,10 +3196,16 @@ Run the `theme selector: toggle` action in the command palette to see a current ## Vim -- Description: Whether or not to enable vim mode (work in progress). +- Description: Whether or not to enable vim mode. See the [Vim documentation](./vim.md) for more details on configuration. - Setting: `vim_mode` - Default: `false` +## Helix Mode + +- Description: Whether or not to enable Helix mode. Enabling `helix_mode` also enables `vim_mode`. See the [Helix documentation](./helix.md) for more details. +- Setting: `helix_mode` +- Default: `false` + ## Project Panel - Description: Customize project panel @@ -3505,11 +3512,11 @@ Float values between `0.0` and `0.9`, where: - Description: The name of the font to use for text in the UI. - Setting: `ui_font_family` -- Default: `Zed Plex Sans` +- Default: `.ZedSans`. This currently aliases to [IBM Plex](https://www.ibm.com/plex/). **Options** -The name of any font family installed on the system. +The name of any font family installed on the system, `".ZedSans"` to use the Zed-provided default, or `".SystemUIFont"` to use the system's default UI font (on macOS and Windows). ## UI Font Features @@ -3597,7 +3604,7 @@ For example, to use `Nerd Font` as a fallback, add the following to your setting "soft_wrap": "none", "buffer_font_size": 18, - "buffer_font_family": "Zed Plex Mono", + "buffer_font_family": ".ZedMono", "autosave": "on_focus_change", "format_on_save": "off", diff --git a/docs/src/development/windows.md b/docs/src/development/windows.md index ac38e4d7d6..551d5f9f21 100644 --- a/docs/src/development/windows.md +++ b/docs/src/development/windows.md @@ -114,19 +114,19 @@ cargo test --workspace ## Installing from msys2 -[MSYS2](https://msys2.org/) distribution provides Zed as a package [mingw-w64-zed](https://packages.msys2.org/base/mingw-w64-zed). The package is available for UCRT64, MINGW64 and CLANG64 repositories. To download it, run +[MSYS2](https://msys2.org/) distribution provides Zed as a package [mingw-w64-zed](https://packages.msys2.org/base/mingw-w64-zed). The package is available for UCRT64, CLANG64 and CLANGARM64 repositories. To download it, run ```sh pacman -Syu pacman -S $MINGW_PACKAGE_PREFIX-zed ``` -then you can run `zeditor` CLI. Editor executable is installed under `$MINGW_PREFIX/lib/zed` directory - You can see the [build script](https://github.com/msys2/MINGW-packages/blob/master/mingw-w64-zed/PKGBUILD) for more details on build process. > Please, report any issue in [msys2/MINGW-packages/issues](https://github.com/msys2/MINGW-packages/issues?q=is%3Aissue+is%3Aopen+zed) first. +See also MSYS2 [documentation page](https://www.msys2.org/docs/ides-editors). + Note that `collab` is not supported for MSYS2. ## Troubleshooting diff --git a/docs/src/fonts.md b/docs/src/fonts.md deleted file mode 100644 index 93c687b134..0000000000 --- a/docs/src/fonts.md +++ /dev/null @@ -1,56 +0,0 @@ -# Fonts - - - -Zed ships two fonts: Zed Plex Mono and Zed Plex Sans. These are based on IBM Plex Mono and IBM Plex Sans, respectively. - - - -## Settings - - - -- Buffer fonts - - `buffer-font-family` - - `buffer-font-features` - - `buffer-font-size` - - `buffer-line-height` -- UI fonts - - `ui_font_family` - - `ui_font_fallbacks` - - `ui_font_features` - - `ui_font_weight` - - `ui_font_size` -- Terminal fonts - - `terminal.font-size` - - `terminal.font-family` - - `terminal.font-features` - -## Old Zed Fonts - -Previously, Zed shipped with `Zed Mono` and `Zed Sans`, customized versions of the [Iosevka](https://typeof.net/Iosevka/) typeface. You can find more about them in the [zed-fonts](https://github.com/zed-industries/zed-fonts/) repository. - -Here's how you can use the old Zed fonts instead of `Zed Plex Mono` and `Zed Plex Sans`: - -1. Download [zed-app-fonts-1.2.0.zip](https://github.com/zed-industries/zed-fonts/releases/download/1.2.0/zed-app-fonts-1.2.0.zip) from the [zed-fonts releases](https://github.com/zed-industries/zed-fonts/releases) page. -2. Open macOS `Font Book.app` -3. Unzip the file and drag the `ttf` files into the Font Book app. -4. Update your settings `ui_font_family` and `buffer_font_family` to use `Zed Mono` or `Zed Sans` in your `settings.json` file. - -```json -{ - "ui_font_family": "Zed Sans Extended", - "buffer_font_family": "Zed Mono Extend", - "terminal": { - "font-family": "Zed Mono Extended" - } -} -``` - -5. Note there will be red squiggles under the font name. (this is a bug, but harmless.) diff --git a/docs/src/helix.md b/docs/src/helix.md new file mode 100644 index 0000000000..ddf997d3f0 --- /dev/null +++ b/docs/src/helix.md @@ -0,0 +1,11 @@ +# Helix Mode + +_Work in progress! Not all Helix keybindings are implemented yet._ + +Zed's Helix mode is an emulation layer that brings Helix-style keybindings and modal editing to Zed. It builds upon Zed's [Vim mode](./vim.md), so much of the core functionality is shared. Enabling `helix_mode` will also enable `vim_mode`. + +For a guide on Vim-related features that are also available in Helix mode, please refer to our [Vim mode documentation](./vim.md). + +To check the current status of Helix mode, or to request a missing Helix feature, checkout out the ["Are we Helix yet?" discussion](https://github.com/zed-industries/zed/discussions/33580). + +For a detailed list of Helix's default keybindings, please visit the [official Helix documentation](https://docs.helix-editor.com/keymap.html). diff --git a/docs/src/key-bindings.md b/docs/src/key-bindings.md index feed912787..9fc94840b7 100644 --- a/docs/src/key-bindings.md +++ b/docs/src/key-bindings.md @@ -14,7 +14,7 @@ If you're used to a specific editor's defaults you can set a `base_keymap` in yo - TextMate - None (disables _all_ key bindings) -You can also enable `vim_mode`, which adds vim bindings too. +You can also enable `vim_mode` or `helix_mode`, which add modal bindings. For more information, see the documentation for [Vim mode](./vim.md) and [Helix mode](./helix.md). ## User keymaps @@ -119,7 +119,7 @@ It's worth noting that attributes are only available on the node they are define Note: Before Zed v0.197.x, the ! operator only looked at one node at a time, and `>` meant "parent" not "ancestor". This meant that `!Editor` would match the context `Workspace > Pane > Editor`, because (confusingly) the Pane matches `!Editor`, and that `os=macos > Editor` did not match the context `Workspace > Pane > Editor` because of the intermediate `Pane` node. -If you're using Vim mode, we have information on how [vim modes influence the context](./vim.md#contexts) +If you're using Vim mode, we have information on how [vim modes influence the context](./vim.md#contexts). Helix mode is built on top of Vim mode and uses the same contexts. ### Actions diff --git a/docs/src/languages/emmet.md b/docs/src/languages/emmet.md index 1a76291ad4..73e34c209f 100644 --- a/docs/src/languages/emmet.md +++ b/docs/src/languages/emmet.md @@ -1,5 +1,7 @@ # Emmet +Emmet support is available through the [Emmet extension](https://github.com/zed-extensions/emmet). + [Emmet](https://emmet.io/) is a web-developer’s toolkit that can greatly improve your HTML & CSS workflow. - Language Server: [olrtg/emmet-language-server](https://github.com/olrtg/emmet-language-server) diff --git a/docs/src/languages/rust.md b/docs/src/languages/rust.md index 1ee25a37b5..7695280275 100644 --- a/docs/src/languages/rust.md +++ b/docs/src/languages/rust.md @@ -326,7 +326,7 @@ When you use `cargo build` or `cargo test` as the build command, Zed can infer t [ { "label": "Build & Debug native binary", - "adapter": "CodeLLDB" + "adapter": "CodeLLDB", "build": { "command": "cargo", "args": ["build"] diff --git a/docs/src/telemetry.md b/docs/src/telemetry.md index 107aef5a96..46c39a88ae 100644 --- a/docs/src/telemetry.md +++ b/docs/src/telemetry.md @@ -4,7 +4,8 @@ Zed collects anonymous telemetry data to help the team understand how people are ## Configuring Telemetry Settings -You have full control over what data is sent out by Zed. To enable or disable some or all telemetry types, open your `settings.json` file via {#action zed::OpenSettings}({#kb zed::OpenSettings}) from the command palette. +You have full control over what data is sent out by Zed. +To enable or disable some or all telemetry types, open your `settings.json` file via {#action zed::OpenSettings}({#kb zed::OpenSettings}) from the command palette. Insert and tweak the following: @@ -15,8 +16,6 @@ Insert and tweak the following: }, ``` -The telemetry settings can also be configured via the welcome screen, which can be invoked via the {#action workspace::Welcome} action in the command palette. - ## Dataflow Telemetry is sent from the application to our servers. Data is proxied through our servers to enable us to easily switch analytics services. We currently use: diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 46de078d89..6e598f4436 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -39,13 +39,15 @@ If you would like to use distinct themes for light mode/dark mode that can be se ## Fonts ```json - // UI Font. Use ".SystemUIFont" to use the default system font (SF Pro on macOS) - "ui_font_family": "Zed Plex Sans", + // UI Font. Use ".SystemUIFont" to use the default system font (SF Pro on macOS), + // or ".ZedSans" for the bundled default (currently IBM Plex) + "ui_font_family": ".SystemUIFont", "ui_font_weight": 400, // Font weight in standard CSS units from 100 to 900. "ui_font_size": 16, // Buffer Font - Used by editor buffers - "buffer_font_family": "Zed Plex Mono", // Font name for editor buffers + // use ".ZedMono" for the bundled default monospace (currently Lilex) + "buffer_font_family": "Berkeley Mono", // Font name for editor buffers "buffer_font_size": 15, // Font size for editor buffers "buffer_font_weight": 400, // Font weight in CSS units [100-900] // Line height "comfortable" (1.618), "standard" (1.3) or custom: `{ "custom": 2 }` @@ -53,7 +55,7 @@ If you would like to use distinct themes for light mode/dark mode that can be se // Terminal Font Settings "terminal": { - "font_family": "Zed Plex Mono", + "font_family": "", "font_size": 15, // Terminal line height: comfortable (1.618), standard(1.3) or `{ "custom": 2 }` "line_height": "comfortable", @@ -314,6 +316,10 @@ TBD: Centered layout related settings // Clicking the button brings up the language selector. // Defaults to true. "active_language_button": true, + // Show/hide a button that displays the cursor's position. + // Clicking the button brings up an input for jumping to a line and column. + // Defaults to true. + "cursor_position_button": true, }, ``` @@ -473,7 +479,7 @@ See [Zed AI Documentation](./ai/overview.md) for additional non-visual AI settin "show": null // Show/hide: (auto, system, always, never) }, // Terminal Font Settings - "font_family": "Zed Plex Mono", + "font_family": "Fira Code", "font_size": 15, "font_weight": 400, // Terminal line height: comfortable (1.618), standard(1.3) or `{ "custom": 2 }` diff --git a/extensions/emmet/.gitignore b/extensions/emmet/.gitignore deleted file mode 100644 index 62c0add260..0000000000 --- a/extensions/emmet/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -*.wasm -grammars -target diff --git a/extensions/emmet/Cargo.toml b/extensions/emmet/Cargo.toml deleted file mode 100644 index 9d72a6c5c4..0000000000 --- a/extensions/emmet/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "zed_emmet" -version = "0.0.4" -edition.workspace = true -publish.workspace = true -license = "Apache-2.0" - -[lints] -workspace = true - -[lib] -path = "src/emmet.rs" -crate-type = ["cdylib"] - -[dependencies] -zed_extension_api = "0.1.0" diff --git a/extensions/emmet/LICENSE-APACHE b/extensions/emmet/LICENSE-APACHE deleted file mode 120000 index 1cd601d0a3..0000000000 --- a/extensions/emmet/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-APACHE \ No newline at end of file diff --git a/extensions/emmet/extension.toml b/extensions/emmet/extension.toml deleted file mode 100644 index 9fa14d091f..0000000000 --- a/extensions/emmet/extension.toml +++ /dev/null @@ -1,24 +0,0 @@ -id = "emmet" -name = "Emmet" -description = "Emmet support" -version = "0.0.4" -schema_version = 1 -authors = ["Piotr Osiewicz "] -repository = "https://github.com/zed-industries/zed" - -[language_servers.emmet-language-server] -name = "Emmet Language Server" -language = "HTML" -languages = ["HTML", "PHP", "ERB", "HTML/ERB", "JavaScript", "TSX", "CSS", "HEEX", "Elixir", "Vue.js"] - -[language_servers.emmet-language-server.language_ids] -"HTML" = "html" -"PHP" = "php" -"ERB" = "eruby" -"HTML/ERB" = "eruby" -"JavaScript" = "javascriptreact" -"TSX" = "typescriptreact" -"CSS" = "css" -"HEEX" = "heex" -"Elixir" = "heex" -"Vue.js" = "vue" diff --git a/extensions/emmet/src/emmet.rs b/extensions/emmet/src/emmet.rs deleted file mode 100644 index 83fe809c34..0000000000 --- a/extensions/emmet/src/emmet.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::{env, fs}; -use zed_extension_api::{self as zed, Result}; - -struct EmmetExtension { - did_find_server: bool, -} - -const SERVER_PATH: &str = "node_modules/.bin/emmet-language-server"; -const PACKAGE_NAME: &str = "@olrtg/emmet-language-server"; - -impl EmmetExtension { - fn server_exists(&self) -> bool { - fs::metadata(SERVER_PATH).map_or(false, |stat| stat.is_file()) - } - - fn server_script_path(&mut self, language_server_id: &zed::LanguageServerId) -> Result { - let server_exists = self.server_exists(); - if self.did_find_server && server_exists { - return Ok(SERVER_PATH.to_string()); - } - - zed::set_language_server_installation_status( - language_server_id, - &zed::LanguageServerInstallationStatus::CheckingForUpdate, - ); - let version = zed::npm_package_latest_version(PACKAGE_NAME)?; - - if !server_exists - || zed::npm_package_installed_version(PACKAGE_NAME)?.as_ref() != Some(&version) - { - zed::set_language_server_installation_status( - language_server_id, - &zed::LanguageServerInstallationStatus::Downloading, - ); - let result = zed::npm_install_package(PACKAGE_NAME, &version); - match result { - Ok(()) => { - if !self.server_exists() { - Err(format!( - "installed package '{PACKAGE_NAME}' did not contain expected path '{SERVER_PATH}'", - ))?; - } - } - Err(error) => { - if !self.server_exists() { - Err(error)?; - } - } - } - } - - self.did_find_server = true; - Ok(SERVER_PATH.to_string()) - } -} - -impl zed::Extension for EmmetExtension { - fn new() -> Self { - Self { - did_find_server: false, - } - } - - fn language_server_command( - &mut self, - language_server_id: &zed::LanguageServerId, - _worktree: &zed::Worktree, - ) -> Result { - let server_path = self.server_script_path(language_server_id)?; - Ok(zed::Command { - command: zed::node_binary_path()?, - args: vec![ - env::current_dir() - .unwrap() - .join(&server_path) - .to_string_lossy() - .to_string(), - "--stdio".to_string(), - ], - env: Default::default(), - }) - } -} - -zed::register_extension!(EmmetExtension); diff --git a/flake.lock b/flake.lock index fa0d51d90d..80022f7b55 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "crane": { "locked": { - "lastModified": 1750266157, - "narHash": "sha256-tL42YoNg9y30u7zAqtoGDNdTyXTi8EALDeCB13FtbQA=", + "lastModified": 1754269165, + "narHash": "sha256-0tcS8FHd4QjbCVoxN9jI+PjHgA4vc/IjkUSp+N3zy0U=", "owner": "ipetkov", "repo": "crane", - "rev": "e37c943371b73ed87faf33f7583860f81f1d5a48", + "rev": "444e81206df3f7d92780680e45858e31d2f07a08", "type": "github" }, "original": { @@ -33,10 +33,10 @@ "nixpkgs": { "locked": { "lastModified": 315532800, - "narHash": "sha256-j+zO+IHQ7VwEam0pjPExdbLT2rVioyVS3iq4bLO3GEc=", - "rev": "61c0f513911459945e2cb8bf333dc849f1b976ff", + "narHash": "sha256-5VYevX3GccubYeccRGAXvCPA1ktrGmIX1IFC0icX07g=", + "rev": "a683adc19ff5228af548c6539dbc3440509bfed3", "type": "tarball", - "url": "https://releases.nixos.org/nixpkgs/nixpkgs-25.11pre821324.61c0f5139114/nixexprs.tar.xz" + "url": "https://releases.nixos.org/nixpkgs/nixpkgs-25.11pre840248.a683adc19ff5/nixexprs.tar.xz" }, "original": { "type": "tarball", @@ -58,11 +58,11 @@ ] }, "locked": { - "lastModified": 1750964660, - "narHash": "sha256-YQ6EyFetjH1uy5JhdhRdPe6cuNXlYpMAQePFfZj4W7M=", + "lastModified": 1754575663, + "narHash": "sha256-afOx8AG0KYtw7mlt6s6ahBBy7eEHZwws3iCRoiuRQS4=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "04f0fcfb1a50c63529805a798b4b5c21610ff390", + "rev": "6db0fb0e9cec2e9729dc52bf4898e6c135bb8a0f", "type": "github" }, "original": { diff --git a/nix/build.nix b/nix/build.nix index 70b4f76932..03403cc1c9 100644 --- a/nix/build.nix +++ b/nix/build.nix @@ -171,8 +171,8 @@ let ZSTD_SYS_USE_PKG_CONFIG = true; FONTCONFIG_FILE = makeFontsConf { fontDirectories = [ - ../assets/fonts/plex-mono - ../assets/fonts/plex-sans + ../assets/fonts/lilex + ../assets/fonts/ibm-plex-sans ]; }; ZED_UPDATE_EXPLANATION = "Zed has been installed using Nix. Auto-updates have thus been disabled."; diff --git a/nix/shell.nix b/nix/shell.nix index b78eb5c001..b6f1efd366 100644 --- a/nix/shell.nix +++ b/nix/shell.nix @@ -46,8 +46,8 @@ # outside the nix store instead of to `$src` FONTCONFIG_FILE = makeFontsConf { fontDirectories = [ - "./assets/fonts/plex-mono" - "./assets/fonts/plex-sans" + "./assets/fonts/lilex" + "./assets/fonts/ibm-plex-sans" ]; }; PROTOC = "${protobuf}/bin/protoc"; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f80eab8fbc..2c909e0e1e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,13 +1,8 @@ [toolchain] -channel = "1.88" +channel = "1.89" profile = "minimal" components = [ "rustfmt", "clippy" ] targets = [ - "x86_64-apple-darwin", - "aarch64-apple-darwin", - "x86_64-unknown-freebsd", - "x86_64-unknown-linux-gnu", - "x86_64-pc-windows-msvc", "wasm32-wasip2", # extensions "x86_64-unknown-linux-musl", # remote server ] diff --git a/script/bundle-mac b/script/bundle-mac index b2be573235..f2a5bf313d 100755 --- a/script/bundle-mac +++ b/script/bundle-mac @@ -207,7 +207,7 @@ function prepare_binaries() { rm -f target/${architecture}/${target_dir}/Zed.dwarf.gz echo "Gzipping dSYMs for $architecture" - gzip -f target/${architecture}/${target_dir}/Zed.dwarf + gzip -kf target/${architecture}/${target_dir}/Zed.dwarf echo "Uploading dSYMs${architecture} for $architecture to by-uuid/${uuid}.dwarf.gz" upload_to_blob_store_public \ @@ -367,19 +367,25 @@ else gzip -f --stdout --best target/aarch64-apple-darwin/release/remote_server > target/zed-remote-server-macos-aarch64.gz fi -# Upload debug info to sentry.io -if ! command -v sentry-cli >/dev/null 2>&1; then - echo "sentry-cli not found. skipping sentry upload." - echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" -else +function upload_debug_info() { + architecture=$1 if [[ -n "${SENTRY_AUTH_TOKEN:-}" ]]; then echo "Uploading zed debug symbols to sentry..." # note: this uploads the unstripped binary which is needed because it contains # .eh_frame data for stack unwinindg. see https://github.com/getsentry/symbolic/issues/783 sentry-cli debug-files upload --include-sources --wait -p zed -o zed-dev \ - "target/x86_64-apple-darwin/${target_dir}/" \ - "target/aarch64-apple-darwin/${target_dir}/" + "target/${architecture}/${target_dir}/zed" \ + "target/${architecture}/${target_dir}/remote_server" \ + "target/${architecture}/${target_dir}/zed.dwarf" else echo "missing SENTRY_AUTH_TOKEN. skipping sentry upload." fi +} + +if command -v sentry-cli >/dev/null 2>&1; then + upload_debug_info "aarch64-apple-darwin" + upload_debug_info "x86_64-apple-darwin" +else + echo "sentry-cli not found. skipping sentry upload." + echo "install with: 'curl -sL https://sentry.io/get-cli | bash'" fi diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 338985ed95..054e757056 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -6,9 +6,9 @@ [package] name = "workspace-hack" version = "0.1.0" -edition = "2021" description = "workspace-hack package, managed by hakari" -publish = false +edition.workspace = true +publish.workspace = true # The parts of the file between the BEGIN HAKARI SECTION and END HAKARI SECTION comments # are managed by hakari. diff --git a/typos.toml b/typos.toml index 336a829a44..e5f02b6415 100644 --- a/typos.toml +++ b/typos.toml @@ -16,9 +16,6 @@ extend-exclude = [ "crates/google_ai/src/supported_countries.rs", "crates/open_ai/src/supported_countries.rs", - # Some crate names are flagged as typos. - "crates/indexed_docs/src/providers/rustdoc/popular_crates.txt", - # Some mock data is flagged as typos. "crates/assistant_tools/src/web_search_tool.rs",