- Resolved conflicts with open split direction

This commit is contained in:
Cesar 2025-08-08 11:56:09 -04:00
commit d6b89a4c30
267 changed files with 13106 additions and 4891 deletions

View file

@ -5,25 +5,25 @@ self-hosted-runner:
# GitHub-hosted Runners
- github-8vcpu-ubuntu-2404
- github-16vcpu-ubuntu-2404
- github-32vcpu-ubuntu-2404
- github-8vcpu-ubuntu-2204
- github-16vcpu-ubuntu-2204
- github-32vcpu-ubuntu-2204
- github-16vcpu-ubuntu-2204-arm
- windows-2025-16
- windows-2025-32
- windows-2025-64
# Buildjet Ubuntu 20.04 - AMD x86_64
- buildjet-2vcpu-ubuntu-2004
- buildjet-4vcpu-ubuntu-2004
- buildjet-8vcpu-ubuntu-2004
- buildjet-16vcpu-ubuntu-2004
- buildjet-32vcpu-ubuntu-2004
# Buildjet Ubuntu 22.04 - AMD x86_64
- buildjet-2vcpu-ubuntu-2204
- buildjet-4vcpu-ubuntu-2204
- buildjet-8vcpu-ubuntu-2204
- buildjet-16vcpu-ubuntu-2204
- buildjet-32vcpu-ubuntu-2204
# Buildjet Ubuntu 22.04 - Graviton aarch64
- buildjet-8vcpu-ubuntu-2204-arm
- buildjet-16vcpu-ubuntu-2204-arm
- buildjet-32vcpu-ubuntu-2204-arm
# Namespace Ubuntu 20.04 (Release builds)
- namespace-profile-16x32-ubuntu-2004
- namespace-profile-32x64-ubuntu-2004
- namespace-profile-16x32-ubuntu-2004-arm
- namespace-profile-32x64-ubuntu-2004-arm
# Namespace Ubuntu 22.04 (Everything else)
- namespace-profile-2x4-ubuntu-2204
- namespace-profile-4x8-ubuntu-2204
- namespace-profile-8x16-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
- namespace-profile-32x64-ubuntu-2204
# Self Hosted Runners
- self-mini-macos
- self-32vcpu-windows-2022

View file

@ -13,7 +13,7 @@ runs:
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
# cache-provider: "buildjet"
- name: Install Linux dependencies
shell: bash -euxo pipefail {0}

View file

@ -16,7 +16,7 @@ jobs:
bump_patch_version:
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4

View file

@ -137,7 +137,7 @@ jobs:
github.repository_owner == 'zed-industries' &&
needs.job_spec.outputs.run_tests == 'true'
runs-on:
- buildjet-8vcpu-ubuntu-2204
- namespace-profile-8x16-ubuntu-2204
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@ -168,7 +168,7 @@ jobs:
needs: [job_spec]
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-8vcpu-ubuntu-2204
- namespace-profile-4x8-ubuntu-2204
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@ -221,7 +221,7 @@ jobs:
github.repository_owner == 'zed-industries' &&
(needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true')
runs-on:
- buildjet-8vcpu-ubuntu-2204
- namespace-profile-8x16-ubuntu-2204
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@ -328,7 +328,7 @@ jobs:
github.repository_owner == 'zed-industries' &&
needs.job_spec.outputs.run_tests == 'true'
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
@ -342,7 +342,7 @@ jobs:
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
# cache-provider: "buildjet"
- name: Install Linux dependencies
run: ./script/linux
@ -380,7 +380,7 @@ jobs:
github.repository_owner == 'zed-industries' &&
needs.job_spec.outputs.run_tests == 'true'
runs-on:
- buildjet-8vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
@ -394,7 +394,7 @@ jobs:
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
# cache-provider: "buildjet"
- name: Install Clang & Mold
run: ./script/remote-server && ./script/install-mold 2.34.0
@ -597,7 +597,7 @@ jobs:
timeout-minutes: 60
name: Linux x86_x64 release bundle
runs-on:
- buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc
- namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
if: |
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
@ -650,7 +650,7 @@ jobs:
timeout-minutes: 60
name: Linux arm64 release bundle
runs-on:
- buildjet-32vcpu-ubuntu-2204-arm
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
if: |
startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')

View file

@ -9,7 +9,7 @@ jobs:
deploy-docs:
name: Deploy Docs
if: github.repository_owner == 'zed-industries'
runs-on: buildjet-16vcpu-ubuntu-2204
runs-on: namespace-profile-16x32-ubuntu-2204
steps:
- name: Checkout repo

View file

@ -61,7 +61,7 @@ jobs:
- style
- tests
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Install doctl
uses: digitalocean/action-doctl@v2
@ -94,7 +94,7 @@ jobs:
needs:
- publish
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Checkout repo
@ -137,12 +137,14 @@ jobs:
export ZED_SERVICE_NAME=collab
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT
export DATABASE_MAX_CONNECTIONS=850
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
export ZED_SERVICE_NAME=api
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT
export DATABASE_MAX_CONNECTIONS=60
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"

View file

@ -32,7 +32,7 @@ jobs:
github.repository_owner == 'zed-industries' &&
(github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval'))
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
@ -46,7 +46,7 @@ jobs:
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
# cache-provider: "buildjet"
- name: Install Linux dependencies
run: ./script/linux

View file

@ -20,7 +20,7 @@ jobs:
matrix:
system:
- os: x86 Linux
runner: buildjet-16vcpu-ubuntu-2204
runner: namespace-profile-16x32-ubuntu-2204
install_nix: true
- os: arm Mac
runner: [macOS, ARM64, test]

View file

@ -20,7 +20,7 @@ jobs:
name: Run randomized tests
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Install Node
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4

View file

@ -128,7 +128,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for x86
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-16vcpu-ubuntu-2004
- namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
needs: tests
steps:
- name: Checkout repo
@ -168,7 +168,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for ARM
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-32vcpu-ubuntu-2204-arm
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
needs: tests
steps:
- name: Checkout repo

View file

@ -23,7 +23,7 @@ jobs:
timeout-minutes: 60
name: Run unit evals
runs-on:
- buildjet-16vcpu-ubuntu-2204
- namespace-profile-16x32-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH"
@ -37,7 +37,7 @@ jobs:
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
# cache-provider: "buildjet"
- name: Install Linux dependencies
run: ./script/linux

125
Cargo.lock generated
View file

@ -17,6 +17,7 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"markdown",
"parking_lot",
"project",
@ -137,9 +138,9 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
version = "0.0.18"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8e4c1dccb35e69d32566f0d11948d902f9942fc3f038821816c1150cf5925f4"
checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8"
dependencies = [
"anyhow",
"futures 0.3.31",
@ -150,6 +151,54 @@ dependencies = [
"serde_json",
]
[[package]]
name = "agent2"
version = "0.1.0"
dependencies = [
"acp_thread",
"agent-client-protocol",
"agent_servers",
"agent_settings",
"anyhow",
"assistant_tool",
"assistant_tools",
"client",
"clock",
"cloud_llm_client",
"collections",
"ctor",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"log",
"lsp",
"paths",
"pretty_assertions",
"project",
"prompt_store",
"reqwest_client",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"smol",
"ui",
"util",
"uuid",
"watch",
"workspace-hack",
"worktree",
]
[[package]]
name = "agent_servers"
version = "0.1.0"
@ -214,6 +263,7 @@ dependencies = [
"acp_thread",
"agent",
"agent-client-protocol",
"agent2",
"agent_servers",
"agent_settings",
"ai_onboarding",
@ -1371,7 +1421,7 @@ dependencies = [
"anyhow",
"arrayvec",
"log",
"nom",
"nom 7.1.3",
"num-rational",
"v_frame",
]
@ -2745,7 +2795,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
"nom 7.1.3",
]
[[package]]
@ -3031,17 +3081,22 @@ dependencies = [
"anyhow",
"cloud_api_types",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"http_client",
"parking_lot",
"serde_json",
"workspace-hack",
"yawc",
]
[[package]]
name = "cloud_api_types"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"ciborium",
"cloud_llm_client",
"pretty_assertions",
"serde",
@ -7457,9 +7512,9 @@ dependencies = [
[[package]]
name = "grid"
version = "0.17.0"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa"
checksum = "12101ecc8225ea6d675bc70263074eab6169079621c2186fe0c66590b2df9681"
[[package]]
name = "group"
@ -9078,6 +9133,7 @@ dependencies = [
"anyhow",
"base64 0.22.1",
"client",
"cloud_api_types",
"cloud_llm_client",
"collections",
"futures 0.3.31",
@ -9208,6 +9264,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-compression",
"async-fs",
"async-tar",
"async-trait",
"chrono",
@ -9239,9 +9296,11 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"settings",
"sha2",
"smol",
"snippet_provider",
"task",
"tempfile",
"text",
"theme",
"toml 0.8.20",
@ -10537,6 +10596,15 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nom"
version = "8.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405"
dependencies = [
"memchr",
]
[[package]]
name = "noop_proc_macro"
version = "0.3.0"
@ -12575,6 +12643,7 @@ dependencies = [
"editor",
"file_icons",
"git",
"git_ui",
"gpui",
"indexmap",
"language",
@ -12588,6 +12657,7 @@ dependencies = [
"serde_json",
"settings",
"smallvec",
"telemetry",
"theme",
"ui",
"util",
@ -15358,7 +15428,7 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790"
dependencies = [
"nom",
"nom 7.1.3",
"unicode_categories",
]
@ -16158,9 +16228,9 @@ dependencies = [
[[package]]
name = "taffy"
version = "0.8.3"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c"
checksum = "a13e5d13f79d558b5d353a98072ca8ca0e99da429467804de959aa8c83c9a004"
dependencies = [
"arrayvec",
"grid",
@ -16561,9 +16631,8 @@ dependencies = [
[[package]]
name = "tiktoken-rs"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625"
version = "0.8.0"
source = "git+https://github.com/zed-industries/tiktoken-rs?rev=30c32a4522751699adeda0d5840c71c3b75ae73d#30c32a4522751699adeda0d5840c71c3b75ae73d"
dependencies = [
"anyhow",
"base64 0.22.1",
@ -19934,7 +20003,7 @@ dependencies = [
"naga",
"nix 0.28.0",
"nix 0.29.0",
"nom",
"nom 7.1.3",
"num-bigint",
"num-bigint-dig",
"num-integer",
@ -20269,6 +20338,34 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "yawc"
version = "0.2.4"
source = "git+https://github.com/deviant-forks/yawc?rev=1899688f3e69ace4545aceb97b2a13881cf26142#1899688f3e69ace4545aceb97b2a13881cf26142"
dependencies = [
"base64 0.22.1",
"bytes 1.10.1",
"flate2",
"futures 0.3.31",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"js-sys",
"nom 8.0.0",
"pin-project",
"rand 0.8.5",
"sha1",
"thiserror 1.0.69",
"tokio",
"tokio-rustls 0.26.2",
"tokio-util",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots",
]
[[package]]
name = "yazi"
version = "0.2.1"
@ -20375,7 +20472,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.199.0"
version = "0.200.0"
dependencies = [
"activity_indicator",
"agent",

View file

@ -4,6 +4,7 @@ members = [
"crates/acp_thread",
"crates/activity_indicator",
"crates/agent",
"crates/agent2",
"crates/agent_servers",
"crates/agent_settings",
"crates/agent_ui",
@ -229,6 +230,7 @@ edition = "2024"
acp_thread = { path = "crates/acp_thread" }
agent = { path = "crates/agent" }
agent2 = { path = "crates/agent2" }
activity_indicator = { path = "crates/activity_indicator" }
agent_ui = { path = "crates/agent_ui" }
agent_settings = { path = "crates/agent_settings" }
@ -423,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.18"
agent-client-protocol = "0.0.23"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@ -459,6 +461,7 @@ bytes = "1.0"
cargo_metadata = "0.19"
cargo_toml = "0.21"
chrono = { version = "0.4", features = ["serde"] }
ciborium = "0.2"
circular-buffer = "1.0"
clap = { version = "4.4", features = ["derive"] }
cocoa = "0.26"
@ -598,7 +601,7 @@ sysinfo = "0.31.0"
take-until = "0.2.0"
tempfile = "3.20.0"
thiserror = "2.0.12"
tiktoken-rs = "0.7.0"
tiktoken-rs = { git = "https://github.com/zed-industries/tiktoken-rs", rev = "30c32a4522751699adeda0d5840c71c3b75ae73d" }
time = { version = "0.3", features = [
"macros",
"parsing",
@ -658,6 +661,9 @@ which = "6.0.0"
windows-core = "0.61"
wit-component = "0.221"
workspace-hack = "0.1.0"
# We can switch back to the published version once https://github.com/infinitefield/yawc/pull/16 is merged and a new
# version is released.
yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
zstd = "0.11"
[workspace.dependencies.async-stripe]

View file

@ -1,3 +1,4 @@
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
cloud: cd ../cloud; cargo make dev
livekit: livekit-server --dev
blob_store: ./script/run-local-minio

View file

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="17" height="16" fill="none"><path fill="#000" d="M12.5 9.639V6.354H9.683L8.18 5.007V2.5H4.5v3.285h2.817l1.51 1.347v1.736l-1.51 1.347H4.5V13.5h3.681v-2.507l1.51-1.347H12.5v-.007ZM5.727 3.595h1.227V4.69H5.727V3.595Zm1.227 8.803H5.727v-1.095h1.227v1.095Z"/></svg>

After

Width:  |  Height:  |  Size: 308 B

View file

Before

Width:  |  Height:  |  Size: 776 B

After

Width:  |  Height:  |  Size: 776 B

Before After
Before After

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 5.3 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 5.5 KiB

View file

@ -332,7 +332,9 @@
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff"
"shift-ctrl-r": "agent::OpenAgentDiff",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"
}
},
{
@ -846,6 +848,7 @@
"ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }],
"alt-ctrl-r": "project_panel::RevealInFileManager",
"ctrl-shift-enter": "project_panel::OpenWithSystem",
"alt-d": "project_panel::CompareMarkedFiles",
"shift-find": "project_panel::NewSearchInDirectory",
"ctrl-alt-shift-f": "project_panel::NewSearchInDirectory",
"shift-down": "menu::SelectNext",
@ -1100,6 +1103,13 @@
"ctrl-enter": "menu::Confirm"
}
},
{
"context": "OnboardingAiConfigurationModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
@ -1176,7 +1186,8 @@
"ctrl-1": "onboarding::ActivateBasicsPage",
"ctrl-2": "onboarding::ActivateEditingPage",
"ctrl-3": "onboarding::ActivateAISetupPage",
"ctrl-escape": "onboarding::Finish"
"ctrl-escape": "onboarding::Finish",
"alt-tab": "onboarding::SignIn"
}
}
]

View file

@ -384,7 +384,9 @@
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff"
"shift-ctrl-r": "agent::OpenAgentDiff",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"
}
},
{
@ -905,6 +907,7 @@
"cmd-delete": ["project_panel::Delete", { "skip_prompt": false }],
"alt-cmd-r": "project_panel::RevealInFileManager",
"ctrl-shift-enter": "project_panel::OpenWithSystem",
"alt-d": "project_panel::CompareMarkedFiles",
"cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }],
"cmd-alt-shift-f": "project_panel::NewSearchInDirectory",
"shift-down": "menu::SelectNext",
@ -1202,6 +1205,13 @@
"cmd-enter": "menu::Confirm"
}
},
{
"context": "OnboardingAiConfigurationModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
@ -1278,7 +1288,8 @@
"cmd-1": "onboarding::ActivateBasicsPage",
"cmd-2": "onboarding::ActivateEditingPage",
"cmd-3": "onboarding::ActivateAISetupPage",
"cmd-escape": "onboarding::Finish"
"cmd-escape": "onboarding::Finish",
"alt-tab": "onboarding::SignIn"
}
}
]

View file

@ -815,6 +815,7 @@
"ctrl-x": "project_panel::OpenSplitUp",
"x": "project_panel::RevealInFileManager",
"s": "project_panel::OpenWithSystem",
"z d": "project_panel::CompareMarkedFiles",
"] c": "project_panel::SelectNextGitEntry",
"[ c": "project_panel::SelectPrevGitEntry",
"] d": "project_panel::SelectNextDiagnostic",

View file

@ -596,6 +596,8 @@
// when a corresponding project entry becomes active.
// Gitignored entries are never auto revealed.
"auto_reveal_entries": true,
// Whether the project panel should open on startup.
"starts_open": true,
// Whether to fold directories automatically and show compact folders
// (e.g. "a/b/c" ) when a directory has only one subdirectory inside.
"auto_fold_dirs": true,
@ -1171,6 +1173,9 @@
// Sets a delay after which the inline blame information is shown.
// Delay is restarted with every cursor movement.
"delay_ms": 0,
// The amount of padding between the end of the source line and the start
// of the inline blame in units of em widths.
"padding": 7,
// Whether or not to display the git commit summary on the same line.
"show_commit_summary": false,
// The minimum column number to show the inline blame information at
@ -1233,6 +1238,11 @@
// 2. hour24
"hour_format": "hour12"
},
// Status bar-related settings.
"status_bar": {
// Whether to show the active language button in the status bar.
"active_language_button": true
},
// Settings specific to the terminal
"terminal": {
// What shell to use when opening a terminal. May take 3 values:

View file

@ -25,6 +25,7 @@ futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
markdown.workspace = true
project.workspace = true
serde.workspace = true

View file

@ -1,23 +1,23 @@
mod connection;
mod diff;
pub use connection::*;
pub use diff::*;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use assistant_tool::ActionLog;
use buffer_diff::BufferDiff;
use editor::{Bias, MultiBuffer, PathKey};
use editor::Bias;
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use itertools::Itertools;
use language::{
Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
text_diff,
};
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
use markdown::Markdown;
use project::{AgentLocation, Project};
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Formatter;
use std::process::ExitStatus;
use std::rc::Rc;
use std::{
fmt::Display,
@ -139,7 +139,7 @@ impl AgentThreadEntry {
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
if let AgentThreadEntry::ToolCall(call) = self {
itertools::Either::Left(call.diffs())
} else {
@ -165,6 +165,7 @@ pub struct ToolCall {
pub status: ToolCallStatus,
pub locations: Vec<acp::ToolCallLocation>,
pub raw_input: Option<serde_json::Value>,
pub raw_output: Option<serde_json::Value>,
}
impl ToolCall {
@ -193,10 +194,11 @@ impl ToolCall {
locations: tool_call.locations,
status,
raw_input: tool_call.raw_input,
raw_output: tool_call.raw_output,
}
}
fn update(
fn update_fields(
&mut self,
fields: acp::ToolCallUpdateFields,
language_registry: Arc<LanguageRegistry>,
@ -209,6 +211,7 @@ impl ToolCall {
content,
locations,
raw_input,
raw_output,
} = fields;
if let Some(kind) = kind {
@ -220,7 +223,9 @@ impl ToolCall {
}
if let Some(title) = title {
self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
self.label.update(cx, |label, cx| {
label.replace(title, cx);
});
}
if let Some(content) = content {
@ -237,9 +242,13 @@ impl ToolCall {
if let Some(raw_input) = raw_input {
self.raw_input = Some(raw_input);
}
if let Some(raw_output) = raw_output {
self.raw_output = Some(raw_output);
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
self.content.iter().filter_map(|content| match content {
ToolCallContent::ContentBlock { .. } => None,
ToolCallContent::Diff { diff } => Some(diff),
@ -379,7 +388,7 @@ impl ContentBlock {
#[derive(Debug)]
pub enum ToolCallContent {
ContentBlock { content: ContentBlock },
Diff { diff: Diff },
Diff { diff: Entity<Diff> },
}
impl ToolCallContent {
@ -393,7 +402,7 @@ impl ToolCallContent {
content: ContentBlock::new(content, &language_registry, cx),
},
acp::ToolCallContent::Diff { diff } => Self::Diff {
diff: Diff::from_acp(diff, language_registry, cx),
diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
},
}
}
@ -401,109 +410,44 @@ impl ToolCallContent {
pub fn to_markdown(&self, cx: &App) -> String {
match self {
Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
Self::Diff { diff } => diff.to_markdown(cx),
Self::Diff { diff } => diff.read(cx).to_markdown(cx),
}
}
}
#[derive(Debug)]
pub struct Diff {
pub multibuffer: Entity<MultiBuffer>,
pub path: PathBuf,
pub new_buffer: Entity<Buffer>,
pub old_buffer: Entity<Buffer>,
_task: Task<Result<()>>,
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
UpdateFields(acp::ToolCallUpdate),
UpdateDiff(ToolCallUpdateDiff),
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
let old_buffer_snapshot = old_buffer.read(cx).snapshot();
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
let diff_task = buffer_diff.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry.clone()),
new_buffer_snapshot,
cx,
)
});
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
let new_buffer = new_buffer.clone();
async move |cx| {
diff_task.await?;
multibuffer
.update(cx, |multibuffer, cx| {
let hunk_ranges = {
let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx);
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>()
};
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&new_buffer, cx),
new_buffer.clone(),
hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff.clone(), cx);
})
.log_err();
if let Some(language) = language_registry
.language_for_file_path(&path)
.await
.log_err()
{
new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
}
anyhow::Ok(())
}
});
Self {
multibuffer,
path,
new_buffer,
old_buffer,
_task: task,
impl ToolCallUpdate {
fn id(&self) -> &acp::ToolCallId {
match self {
Self::UpdateFields(update) => &update.id,
Self::UpdateDiff(diff) => &diff.id,
}
}
}
fn to_markdown(&self, cx: &App) -> String {
let buffer_text = self
.multibuffer
.read(cx)
.all_buffers()
.iter()
.map(|buffer| buffer.read(cx).text())
.join("\n");
format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
impl From<acp::ToolCallUpdate> for ToolCallUpdate {
fn from(update: acp::ToolCallUpdate) -> Self {
Self::UpdateFields(update)
}
}
impl From<ToolCallUpdateDiff> for ToolCallUpdate {
fn from(diff: ToolCallUpdateDiff) -> Self {
Self::UpdateDiff(diff)
}
}
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
pub id: acp::ToolCallId,
pub diff: Entity<Diff>,
}
#[derive(Debug, Default)]
pub struct Plan {
pub entries: Vec<PlanEntry>,
@ -556,7 +500,7 @@ pub struct PlanEntry {
impl PlanEntry {
pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
Self {
content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
priority: entry.priority,
status: entry.status,
}
@ -581,6 +525,7 @@ pub enum AcpThreadEvent {
ToolAuthorizationRequired,
Stopped,
Error,
ServerExited(ExitStatus),
}
impl EventEmitter<AcpThreadEvent> for AcpThread {}
@ -654,6 +599,10 @@ impl AcpThread {
&self.entries
}
pub fn session_id(&self) -> &acp::SessionId {
&self.session_id
}
pub fn status(&self) -> ThreadStatus {
if self.send_task.is_some() {
if self.waiting_for_tool_confirmation() {
@ -794,15 +743,26 @@ impl AcpThread {
pub fn update_tool_call(
&mut self,
update: acp::ToolCallUpdate,
update: impl Into<ToolCallUpdate>,
cx: &mut Context<Self>,
) -> Result<()> {
let update = update.into();
let languages = self.project.read(cx).languages().clone();
let (ix, current_call) = self
.tool_call_mut(&update.id)
.tool_call_mut(update.id())
.context("Tool call not found")?;
current_call.update(update.fields, languages, cx);
match update {
ToolCallUpdate::UpdateFields(update) => {
current_call.update_fields(update.fields, languages, cx);
}
ToolCallUpdate::UpdateDiff(update) => {
current_call.content.clear();
current_call
.content
.push(ToolCallContent::Diff { diff: update.diff });
}
}
cx.emit(AcpThreadEvent::EntryUpdated(ix));
@ -890,7 +850,7 @@ impl AcpThread {
});
}
pub fn request_tool_call_permission(
pub fn request_tool_call_authorization(
&mut self,
tool_call: acp::ToolCall,
options: Vec<acp::PermissionOption>,
@ -965,13 +925,26 @@ impl AcpThread {
}
pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
self.plan = Plan {
entries: request
.entries
.into_iter()
.map(|entry| PlanEntry::from_acp(entry, cx))
.collect(),
};
let new_entries_len = request.entries.len();
let mut new_entries = request.entries.into_iter();
// Reuse existing markdown to prevent flickering
for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
let PlanEntry {
content,
priority,
status,
} = old;
content.update(cx, |old, cx| {
old.replace(new.content, cx);
});
*priority = new.priority;
*status = new.status;
}
for new in new_entries {
self.plan.entries.push(PlanEntry::from_acp(new, cx))
}
self.plan.entries.truncate(new_entries_len);
cx.notify();
}
@ -1032,8 +1005,9 @@ impl AcpThread {
)
})?
.await;
tx.send(result).log_err();
this.update(cx, |this, _cx| this.send_task.take())?;
anyhow::Ok(())
}
.await
@ -1046,7 +1020,23 @@ impl AcpThread {
.log_err();
Err(e)?
}
_ => {
result => {
let cancelled = matches!(
result,
Ok(Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled
}))
);
// We only take the task if the current prompt wasn't cancelled.
//
// This prompt may have been cancelled because another one was sent
// while it was still generating. In these cases, dropping `send_task`
// would cause the next generation to be cancelled.
if !cancelled {
this.update(cx, |this, _cx| this.send_task.take()).ok();
}
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
.log_err();
Ok(())
@ -1229,6 +1219,10 @@ impl AcpThread {
pub fn to_markdown(&self, cx: &App) -> String {
self.entries.iter().map(|e| e.to_markdown(cx)).collect()
}
pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::ServerExited(status));
}
}
#[cfg(test)]
@ -1371,6 +1365,9 @@ mod tests {
cx,
)
.unwrap();
})?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
}
.boxed_local()
@ -1443,7 +1440,9 @@ mod tests {
.unwrap()
.await
.unwrap();
Ok(())
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
}
.boxed_local()
},
@ -1510,13 +1509,16 @@ mod tests {
content: vec![],
locations: vec![],
raw_input: None,
raw_output: None,
}),
cx,
)
})
.unwrap()
.unwrap();
Ok(())
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
}
.boxed_local()
}
@ -1620,13 +1622,16 @@ mod tests {
}],
locations: vec![],
raw_input: None,
raw_output: None,
}),
cx,
)
})
.unwrap()
.unwrap();
Ok(())
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
}
.boxed_local()
}
@ -1680,7 +1685,7 @@ mod tests {
acp::PromptRequest,
WeakEntity<AcpThread>,
AsyncApp,
) -> LocalBoxFuture<'static, Result<()>>
) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
+ 'static,
>,
>,
@ -1707,7 +1712,7 @@ mod tests {
acp::PromptRequest,
WeakEntity<AcpThread>,
AsyncApp,
) -> LocalBoxFuture<'static, Result<()>>
) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
+ 'static,
) -> Self {
self.on_user_message.replace(Rc::new(handler));
@ -1749,7 +1754,11 @@ mod tests {
}
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
fn prompt(
&self,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
if let Some(handler) = &self.on_user_message {
@ -1757,7 +1766,9 @@ mod tests {
let thread = thread.clone();
cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
} else {
Task::ready(Ok(()))
Task::ready(Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
}))
}
}

View file

@ -1,13 +1,61 @@
use std::{error::Error, fmt, path::Path, rc::Rc};
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use agent_client_protocol::{self as acp};
use anyhow::Result;
use gpui::{AsyncApp, Entity, Task};
use language_model::LanguageModel;
use project::Project;
use ui::App;
use crate::AcpThread;
/// Trait for agents that support listing, selecting, and querying language models.
///
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
pub trait ModelSelector: 'static {
/// Lists all available language models for this agent.
///
/// # Parameters
/// - `cx`: The GPUI app context for async operations and global access.
///
/// # Returns
/// A task resolving to the list of models or an error (e.g., if no models are configured).
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
/// Selects a model for a specific session (thread).
///
/// This sets the default model for future interactions in the session.
/// If the session doesn't exist or the model is invalid, it returns an error.
///
/// # Parameters
/// - `session_id`: The ID of the session (thread) to apply the model to.
/// - `model`: The model to select (should be one from [list_models]).
/// - `cx`: The GPUI app context.
///
/// # Returns
/// A task resolving to `Ok(())` on success or an error.
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>>;
/// Retrieves the currently selected model for a specific session (thread).
///
/// # Parameters
/// - `session_id`: The ID of the session (thread) to query.
/// - `cx`: The GPUI app context.
///
/// # Returns
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>>;
}
pub trait AgentConnection {
fn new_thread(
self: Rc<Self>,
@ -20,9 +68,18 @@ pub trait AgentConnection {
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
-> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
None // Default impl for agents that don't support it
}
}
#[derive(Debug)]

View file

@ -0,0 +1,388 @@
use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task};
use itertools::Itertools;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer,
};
use std::{
cmp::Reverse,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use util::ResultExt;
pub enum Diff {
Pending(PendingDiff),
Finalized(FinalizedDiff),
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
async move |_, cx| {
let language = language_registry
.language_for_file_path(&path)
.await
.log_err();
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
buffer.set_language(language, cx);
buffer.snapshot()
})?;
buffer_diff
.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry),
new_buffer_snapshot,
cx,
)
})?
.await?;
multibuffer
.update(cx, |multibuffer, cx| {
let hunk_ranges = {
let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx);
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>()
};
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&new_buffer, cx),
new_buffer.clone(),
hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
})
.log_err();
anyhow::Ok(())
}
});
Self::Finalized(FinalizedDiff {
multibuffer,
path,
_update_diff: task,
})
}
pub fn new(buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Self {
let buffer_snapshot = buffer.read(cx).snapshot();
let base_text = buffer_snapshot.text();
let language_registry = buffer.read(cx).language_registry();
let text_snapshot = buffer.read(cx).text_snapshot();
let buffer_diff = cx.new(|cx| {
let mut diff = BufferDiff::new(&text_snapshot, cx);
let _ = diff.set_base_text(
buffer_snapshot.clone(),
language_registry,
text_snapshot,
cx,
);
diff
});
let multibuffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly);
multibuffer.add_diff(buffer_diff.clone(), cx);
multibuffer
});
Self::Pending(PendingDiff {
multibuffer,
base_text: Arc::new(base_text),
_subscription: cx.observe(&buffer, |this, _, cx| {
if let Diff::Pending(diff) = this {
diff.update(cx);
}
}),
buffer,
diff: buffer_diff,
revealed_ranges: Vec::new(),
update_diff: Task::ready(Ok(())),
})
}
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Self>) {
if let Self::Pending(diff) = self {
diff.reveal_range(range, cx);
}
}
pub fn finalize(&mut self, cx: &mut Context<Self>) {
if let Self::Pending(diff) = self {
*self = Self::Finalized(diff.finalize(cx));
}
}
pub fn multibuffer(&self) -> &Entity<MultiBuffer> {
match self {
Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer,
Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer,
}
}
pub fn to_markdown(&self, cx: &App) -> String {
let buffer_text = self
.multibuffer()
.read(cx)
.all_buffers()
.iter()
.map(|buffer| buffer.read(cx).text())
.join("\n");
let path = match self {
Diff::Pending(PendingDiff { buffer, .. }) => {
buffer.read(cx).file().map(|file| file.path().as_ref())
}
Diff::Finalized(FinalizedDiff { path, .. }) => Some(path.as_path()),
};
format!(
"Diff: {}\n```\n{}\n```\n",
path.unwrap_or(Path::new("untitled")).display(),
buffer_text
)
}
}
pub struct PendingDiff {
multibuffer: Entity<MultiBuffer>,
base_text: Arc<String>,
buffer: Entity<Buffer>,
diff: Entity<BufferDiff>,
revealed_ranges: Vec<Range<Anchor>>,
_subscription: Subscription,
update_diff: Task<Result<()>>,
}
impl PendingDiff {
pub fn update(&mut self, cx: &mut Context<Diff>) {
let buffer = self.buffer.clone();
let buffer_diff = self.diff.clone();
let base_text = self.base_text.clone();
self.update_diff = cx.spawn(async move |diff, cx| {
let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?;
let diff_snapshot = BufferDiff::update_diff(
buffer_diff.clone(),
text_snapshot.clone(),
Some(base_text),
false,
false,
None,
None,
cx,
)
.await?;
buffer_diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
})?;
diff.update(cx, |diff, cx| {
if let Diff::Pending(diff) = diff {
diff.update_visible_ranges(cx);
}
})
});
}
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Diff>) {
self.revealed_ranges.push(range);
self.update_visible_ranges(cx);
}
fn finalize(&self, cx: &mut Context<Diff>) -> FinalizedDiff {
let ranges = self.excerpt_ranges(cx);
let base_text = self.base_text.clone();
let language_registry = self.buffer.read(cx).language_registry().clone();
let path = self
.buffer
.read(cx)
.file()
.map(|file| file.path().as_ref())
.unwrap_or(Path::new("untitled"))
.into();
// Replace the buffer in the multibuffer with the snapshot
let buffer = cx.new(|cx| {
let language = self.buffer.read(cx).language().cloned();
let buffer = TextBuffer::new_normalized(
0,
cx.entity_id().as_non_zero_u64().into(),
self.buffer.read(cx).line_ending(),
self.buffer.read(cx).as_rope().clone(),
);
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
buffer.set_language(language, cx);
buffer
});
let buffer_diff = cx.spawn({
let buffer = buffer.clone();
let language_registry = language_registry.clone();
async move |_this, cx| {
build_buffer_diff(base_text, &buffer, language_registry, cx).await
}
});
let update_diff = cx.spawn(async move |this, cx| {
let buffer_diff = buffer_diff.await?;
this.update(cx, |this, cx| {
this.multibuffer().update(cx, |multibuffer, cx| {
let path_key = PathKey::for_buffer(&buffer, cx);
multibuffer.clear(cx);
multibuffer.set_excerpts_for_path(
path_key,
buffer,
ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff.clone(), cx);
});
cx.notify();
})
});
FinalizedDiff {
path,
multibuffer: self.multibuffer.clone(),
_update_diff: update_diff,
}
}
fn update_visible_ranges(&mut self, cx: &mut Context<Diff>) {
let ranges = self.excerpt_ranges(cx);
self.multibuffer.update(cx, |multibuffer, cx| {
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&self.buffer, cx),
self.buffer.clone(),
ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
let end = multibuffer.len(cx);
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
});
cx.notify();
}
fn excerpt_ranges(&self, cx: &App) -> Vec<Range<Point>> {
let buffer = self.buffer.read(cx);
let diff = self.diff.read(cx);
let mut ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>();
ranges.extend(
self.revealed_ranges
.iter()
.map(|range| range.to_point(&buffer)),
);
ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
// Merge adjacent ranges
let mut ranges = ranges.into_iter().peekable();
let mut merged_ranges = Vec::new();
while let Some(mut range) = ranges.next() {
while let Some(next_range) = ranges.peek() {
if range.end >= next_range.start {
range.end = range.end.max(next_range.end);
ranges.next();
} else {
break;
}
}
merged_ranges.push(range);
}
merged_ranges
}
}
pub struct FinalizedDiff {
path: PathBuf,
multibuffer: Entity<MultiBuffer>,
_update_diff: Task<Result<()>>,
}
async fn build_buffer_diff(
old_text: Arc<String>,
buffer: &Entity<Buffer>,
language_registry: Option<Arc<LanguageRegistry>>,
cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
let old_text_rope = cx
.background_spawn({
let old_text = old_text.clone();
async move { Rope::from(old_text.as_str()) }
})
.await;
let base_buffer = cx
.update(|cx| {
Buffer::build_snapshot(
old_text_rope,
buffer.language().cloned(),
language_registry,
cx,
)
})?
.await;
let diff_snapshot = cx
.update(|cx| {
BufferDiffSnapshot::new_with_base_buffer(
buffer.text.clone(),
Some(old_text),
base_buffer,
cx,
)
})?
.await;
let secondary_diff = cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer, cx);
diff.set_snapshot(diff_snapshot.clone(), &buffer, cx);
diff
})?;
cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer.text, cx);
diff.set_snapshot(diff_snapshot, &buffer, cx);
diff.set_secondary_diff(secondary_diff);
diff
})
}

View file

@ -212,7 +212,16 @@ impl HistoryStore {
fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<Vec<HistoryEntryId>>> {
cx.background_spawn(async move {
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
let contents = smol::fs::read_to_string(path).await?;
let contents = match smol::fs::read_to_string(path).await {
Ok(it) => it,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return Ok(Vec::new());
}
Err(e) => {
return Err(e)
.context("deserializing persisted agent panel navigation history");
}
};
let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&contents)
.context("deserializing persisted agent panel navigation history")?
.into_iter()

View file

@ -8,7 +8,7 @@ use crate::{
},
tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
@ -2112,12 +2112,10 @@ impl Thread {
return;
}
let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
let request = self.to_summarize_request(
&model.model,
CompletionIntent::ThreadSummarization,
added_user_message.into(),
SUMMARIZE_THREAD_PROMPT.into(),
cx,
);
@ -4047,8 +4045,8 @@ fn main() {{
});
cx.run_until_parked();
fake_model.stream_last_completion_response("Brief");
fake_model.stream_last_completion_response(" Introduction");
fake_model.send_last_completion_stream_text_chunk("Brief");
fake_model.send_last_completion_stream_text_chunk(" Introduction");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@ -4141,7 +4139,7 @@ fn main() {{
});
cx.run_until_parked();
fake_model.stream_last_completion_response("A successful summary");
fake_model.send_last_completion_stream_text_chunk("A successful summary");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@ -4774,7 +4772,7 @@ fn main() {{
!pending.is_empty(),
"Should have a pending completion after retry"
);
fake_model.stream_completion_response(&pending[0], "Success!");
fake_model.send_completion_stream_text_chunk(&pending[0], "Success!");
fake_model.end_completion_stream(&pending[0]);
cx.run_until_parked();
@ -4942,7 +4940,7 @@ fn main() {{
// Check for pending completions and complete them
if let Some(pending) = inner_fake.pending_completions().first() {
inner_fake.stream_completion_response(pending, "Success!");
inner_fake.send_completion_stream_text_chunk(pending, "Success!");
inner_fake.end_completion_stream(pending);
}
cx.run_until_parked();
@ -5427,7 +5425,7 @@ fn main() {{
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.stream_last_completion_response("Assistant response");
fake_model.send_last_completion_stream_text_chunk("Assistant response");
fake_model.end_last_completion_stream();
cx.run_until_parked();
}

64
crates/agent2/Cargo.toml Normal file
View file

@ -0,0 +1,64 @@
[package]
name = "agent2"
version = "0.1.0"
edition = "2021"
license = "GPL-3.0-or-later"
publish = false
[lib]
path = "src/agent2.rs"
[lints]
workspace = true
[dependencies]
acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
indoc.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true
paths.workspace = true
project.workspace = true
prompt_store.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
lsp = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }
worktree = { workspace = true, "features" = ["test-support"] }
pretty_assertions.workspace = true

1
crates/agent2/LICENSE-GPL Symbolic link
View file

@ -0,0 +1 @@
../../LICENSE-GPL

696
crates/agent2/src/agent.rs Normal file
View file

@ -0,0 +1,696 @@
use crate::{templates::Templates, AgentResponseEvent, Thread};
use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
use futures::{future, StreamExt};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt;
const RULES_FILE_NAMES: [&'static str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
"AGENT.md",
"AGENTS.md",
"GEMINI.md",
];
pub struct RulesLoadingError {
pub message: SharedString,
}
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
/// The internal thread that processes messages
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: WeakEntity<acp_thread::AcpThread>,
_subscription: Subscription,
}
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
/// Shared project context for all threads
project_context: Rc<RefCell<ProjectContext>>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
/// Shared templates for all threads
templates: Arc<Templates>,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
impl NativeAgent {
pub async fn new(
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
let project_context = cx
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
.await;
cx.new(|cx| {
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
templates,
project,
prompt_store,
_subscriptions: subscriptions,
}
})
}
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
cx: &mut AsyncApp,
) -> Result<()> {
while needs_refresh.changed().await.is_ok() {
let project_context = this
.update(cx, |this, cx| {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})?
.await;
this.update(cx, |this, _| this.project_context.replace(project_context))?;
}
Ok(())
}
fn build_project_context(
project: &Entity<Project>,
prompt_store: Option<&Entity<PromptStore>>,
cx: &mut App,
) -> Task<ProjectContext> {
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
let worktree_tasks = worktrees
.into_iter()
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
})
.collect::<Vec<_>>();
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
prompt_store.read_with(cx, |prompt_store, cx| {
let prompts = prompt_store.default_prompt_metadata();
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
let contents = prompt_store.load(prompt_metadata.id, cx);
async move { (contents.await, prompt_metadata) }
});
cx.background_spawn(future::join_all(load_tasks))
})
} else {
Task::ready(vec![])
};
cx.spawn(async move |_cx| {
let (worktrees, default_user_rules) =
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
let worktrees = worktrees
.into_iter()
.map(|(worktree, _rules_error)| {
// TODO: show error message
// if let Some(rules_error) = rules_error {
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
// }
worktree
})
.collect::<Vec<_>>();
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(UserRulesContext {
uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None,
},
title: prompt_metadata.title.map(|title| title.to_string()),
contents,
}),
Err(_err) => {
// TODO: show error message
// this.update(cx, |_, cx| {
// cx.emit(RulesLoadingError {
// message: format!("{err:?}").into(),
// });
// })
// .ok();
None
}
})
.collect::<Vec<_>>();
ProjectContext::new(worktrees, default_user_rules)
})
}
fn load_worktree_info_for_system_prompt(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let tree = worktree.read(cx);
let root_name = tree.root_name().into();
let abs_path = tree.abs_path();
let mut context = WorktreeContext {
root_name,
abs_path,
rules_file: None,
};
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
let Some(rules_task) = rules_task else {
return Task::ready((context, None));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
};
context.rules_file = rules_file;
(context, rules_file_error)
})
}
fn load_worktree_rules_file(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Option<Task<Result<RulesFileContext>>> {
let worktree = worktree.read(cx);
let worktree_id = worktree.id();
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| entry.path.clone())
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|path_in_worktree| {
let project_path = ProjectPath {
worktree_id,
path: path_in_worktree.clone(),
};
let buffer_task =
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
let rope_task = cx.spawn(async move |cx| {
buffer_task.await?.read_with(cx, |buffer, cx| {
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
})?
});
// Build a string from the rope on a background thread.
cx.background_spawn(async move {
let (project_entry_id, rope) = rope_task.await?;
anyhow::Ok(RulesFileContext {
path_in_worktree,
text: rope.to_string().trim().to_string(),
project_entry_id: project_entry_id.to_usize(),
})
})
})
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
_cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.project_context_needs_refresh.send(()).ok();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
RULES_FILE_NAMES
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
self.project_context_needs_refresh.send(()).ok();
}
}
_ => {}
}
}
fn handle_prompts_updated_event(
&mut self,
_prompt_store: Entity<PromptStore>,
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
self.project_context_needs_refresh.send(()).ok();
}
}
/// Wrapper struct that implements the AgentConnection trait
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl ModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
log::debug!("NativeAgentConnection::list_models called");
cx.spawn(async move |cx| {
cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
log::info!("Found {} available models", models.len());
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
})?
})
}
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
log::info!(
"Setting model for session {}: {:?}",
session_id,
model.name()
);
let agent = self.0.clone();
cx.spawn(async move |cx| {
agent.update(cx, |agent, cx| {
if let Some(session) = agent.sessions.get(&session_id) {
session.thread.update(cx, |thread, _cx| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow!("Session not found"))
}
})?
})
}
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
let session_id = session_id.clone();
cx.spawn(async move |cx| {
let thread = agent
.read_with(cx, |agent, _| {
agent
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
})?
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
Ok(selected)
})
}
}
impl acp_thread::AgentConnection for NativeAgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let agent = self.0.clone();
log::info!("Creating new thread for project at: {:?}", cwd);
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
let default_model = registry
.default_model()
.map(|configured| {
log::info!(
"Using configured default model: {:?} from provider: {:?}",
configured.model.name(),
configured.provider.name()
);
configured.model
})
.ok_or_else(|| {
log::warn!("No default model configured in settings");
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
let thread = cx.new(|cx| {
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(cx.entity()));
thread
});
Ok(thread)
},
)??;
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
session_id,
Session {
thread,
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
})
},
);
})?;
Ok(acp_thread)
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[] // No auth for in-process
}
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
}
fn prompt(
&self,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
cx.spawn(async move |cx| {
// Get session
let (thread, acp_thread) = agent
.update(cx, |agent, _| {
agent
.sessions
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
})?
.ok_or_else(|| {
log::error!("Session not found: {}", session_id);
anyhow::anyhow!("Session not found")
})?;
log::debug!("Found session for: {}", session_id);
// Convert prompt to message
let message = convert_prompt_to_message(params.prompt);
log::info!("Converted prompt to message: {} chars", message.len());
log::debug!("Message content: {}", message);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let mut response_stream =
thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
false,
cx,
)
})?;
}
AgentResponseEvent::Thinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
text,
annotations: None,
}),
true,
cx,
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})?;
}
AgentResponseEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
}
}
Err(e) => {
log::error!("Error in model response stream: {:?}", e);
// TODO: Consider sending an error message to the UI
break;
}
}
}
log::info!("Response stream completed");
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel());
}
});
}
}
/// Convert ACP content blocks to a message string
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
log::debug!("Converting {} content blocks to message", blocks.len());
let mut message = String::new();
for block in blocks {
match block {
acp::ContentBlock::Text(text) => {
log::trace!("Processing text block: {} chars", text.text.len());
message.push_str(&text.text);
}
acp::ContentBlock::ResourceLink(link) => {
log::trace!("Processing resource link: {}", link.uri);
message.push_str(&format!(" @{} ", link.uri));
}
acp::ContentBlock::Image(_) => {
log::trace!("Processing image block");
message.push_str(" [image] ");
}
acp::ContentBlock::Audio(_) => {
log::trace!("Processing audio block");
message.push_str(" [audio] ");
}
acp::ContentBlock::Resource(resource) => {
log::trace!("Processing resource block: {:?}", resource.resource);
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
}
}
}
message
}
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use gpui::TestAppContext;
use serde_json::json;
use settings::SettingsStore;
#[gpui::test]
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"a": {}
}),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
let worktree = project
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
.await
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, _| {
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: None
}]
)
});
// Creating `/a/.rules` updates the project context.
fs.insert_file("/a/.rules", Vec::new()).await;
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: Some(RulesFileContext {
path_in_worktree: Path::new(".rules").into(),
text: "".into(),
project_entry_id: rules_entry.id.to_usize()
})
}]
)
});
}
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
}
}

View file

@ -0,0 +1,14 @@
mod agent;
mod native_agent_server;
mod templates;
mod thread;
mod tools;
#[cfg(test)]
mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
pub use tools::*;

View file

@ -0,0 +1,60 @@
use std::path::Path;
use std::rc::Rc;
use agent_servers::AgentServer;
use anyhow::Result;
use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
#[derive(Clone)]
pub struct NativeAgentServer;
impl AgentServer for NativeAgentServer {
fn name(&self) -> &'static str {
"Native Agent"
}
fn empty_state_headline(&self) -> &'static str {
"Native Agent"
}
fn empty_state_message(&self) -> &'static str {
"How can I help you today?"
}
fn logo(&self) -> ui::IconName {
// Using the ZedAssistant icon as it's the native built-in agent
ui::IconName::ZedAssistant
}
fn connect(
&self,
_root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
log::info!(
"NativeAgentServer::connect called for path: {:?}",
_root_dir
);
let project = project.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
let templates = Templates::new();
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);
log::info!("NativeAgentServer connection established successfully");
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
})
}
}

View file

@ -0,0 +1,87 @@
use anyhow::Result;
use gpui::SharedString;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.set_strict_mode(true);
handlebars.register_helper("contains", Box::new(contains));
handlebars.register_embed_templates::<Assets>().unwrap();
Arc::new(Self(handlebars))
}
}
pub trait Template: Sized {
const TEMPLATE_NAME: &'static str;
fn render(&self, templates: &Templates) -> Result<String>
where
Self: Serialize + Sized,
{
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
}
}
#[derive(Serialize)]
pub struct SystemPromptTemplate<'a> {
#[serde(flatten)]
pub project: &'a prompt_store::ProjectContext,
pub available_tools: Vec<SharedString>,
}
impl Template for SystemPromptTemplate<'_> {
const TEMPLATE_NAME: &'static str = "system_prompt.hbs";
}
/// Handlebars helper for checking if an item is in a list
fn contains(
h: &handlebars::Helper,
_: &handlebars::Handlebars,
_: &handlebars::Context,
_: &mut handlebars::RenderContext,
out: &mut dyn handlebars::Output,
) -> handlebars::HelperResult {
let list = h
.param(0)
.and_then(|v| v.value().as_array())
.ok_or_else(|| {
handlebars::RenderError::new("contains: missing or invalid list parameter")
})?;
let query = h.param(1).map(|v| v.value()).ok_or_else(|| {
handlebars::RenderError::new("contains: missing or invalid query parameter")
})?;
if list.contains(&query) {
out.write("true")?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_system_prompt_template() {
let project = prompt_store::ProjectContext::default();
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(rendered.contains("## Fixing Diagnostics"));
}
}

View file

@ -0,0 +1,178 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
## Communication
1. Be conversational but professional.
2. Refer to the user in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
{{#if (gt (len available_tools) 0)}}
## Tool Use
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers.
7. Avoid HTML entity escaping - use plain characters instead.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
- `{{abs_path}}`
{{/each}}
- Bias towards not asking the user for help if you can find the answer yourself.
- When providing paths to tools, the path should always start with the name of a project root directory listed above.
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
{{# if (contains available_tools 'grep') }}
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
{{/if}}
{{else}}
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
{{/if}}
## Code Block Formatting
Whenever you mention a code block, you MUST use ONLY use the following format:
```path/to/Something.blah#L123-456
(code goes here)
```
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
is a path in the project. (If there is no valid path in the project, then you can use
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
does not understand the more common ```language syntax, or bare ``` blocks. It only
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
You have made a mistake. You can only ever put paths after triple backticks!
<example>
Based on all the information I've gathered, here's a summary of how this system works:
1. The README file is loaded into the system.
2. The system finds the first two headers, including everything in between. In this case, that would be:
```path/to/README.md#L8-12
# First Header
This is the info under the first header.
## Sub-header
```
3. Then the system finds the last header in the README:
```path/to/README.md#L27-29
## Last Header
This is the last header in the README.
```
4. Finally, it passes this information on to the next process.
</example>
<example>
In Markdown, hash marks signify headings. For example:
```/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</example>
Here are examples of ways you must never render code blocks:
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it does not include the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it has the language instead of the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
# Level 1 heading
## Level 2 heading
### Level 3 heading
</bad_example_do_not_do_this>
This example is unacceptable because it uses indentation to mark the code block
instead of backticks with a path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
{{#if (gt (len available_tools) 0)}}
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
## Debugging
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
{{/if}}
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
Operating System: {{os}}
Default Shell: {{shell}}
{{#if (or has_rules has_user_rules)}}
## User's Custom Instructions
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}.
{{#if has_rules}}
There are project rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
`{{root_name}}/{{rules_file.path_in_worktree}}`:
``````
{{{rules_file.text}}}
``````
{{/if}}
{{/each}}
{{/if}}
{{#if has_user_rules}}
The user has specified the following rules that should be applied:
{{#each user_rules}}
{{#if title}}
Rules title: {{title}}
{{/if}}
``````
{{contents}}}
``````
{{/each}}
{{/if}}
{{/if}}

View file

@ -0,0 +1,846 @@
use super::*;
use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use assistant_tool::ActionLog;
use client::{Client, UserStore};
use fs::FakeFs;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
use indoc::indoc;
use language_model::{
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use smol::stream::StreamExt;
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
mod test_tools;
use test_tools::*;
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
let events = thread
.update(cx, |thread, cx| {
thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
})
.collect()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
let events = thread
.update(cx, |thread, cx| {
thread.send(
model.clone(),
indoc! {"
Testing:
Generate a thinking step where you just think the word 'Think',
and have your final answer be 'Hello'
"},
cx,
)
})
.collect()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages().last().unwrap().to_markdown(),
indoc! {"
## assistant
<think>Think</think>
Hello
"}
)
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
let ThreadTest {
model,
thread,
project_context,
..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
pending_completions.len(),
1,
"unexpected pending completions: {:?}",
pending_completions
);
let pending_completion = pending_completions.pop().unwrap();
assert_eq!(pending_completion.messages[0].role, Role::System);
let system_message = &pending_completion.messages[0];
let system_prompt = system_message.content[0].to_str().unwrap();
assert!(
system_prompt.contains("test-shell"),
"unexpected system message: {:?}",
system_message
);
assert!(
system_prompt.contains("## Fixing Diagnostics"),
"unexpected system message: {:?}",
system_message
);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
// Test a tool call that's likely to complete *before* streaming stops.
let events = thread
.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(
model.clone(),
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
cx,
)
})
.collect()
.await;
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
// Test a tool calls that's likely to complete *after* streaming stops.
let events = thread
.update(cx, |thread, cx| {
thread.remove_tool(&AgentTool::name(&EchoTool));
thread.add_tool(DelayTool);
thread.send(
model.clone(),
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
cx,
)
})
.collect()
.await;
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
thread.update(cx, |thread, _cx| {
assert!(thread
.messages()
.last()
.unwrap()
.content
.iter()
.any(|content| {
if let MessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
}
}));
});
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
// Test a tool call that's likely to complete *before* streaming stops.
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(WordListTool);
thread.send(model.clone(), "Test the word_list tool.", cx)
});
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let last_content = thread.messages().last().unwrap().content.last().unwrap();
if let MessageContent::ToolUse(last_tool_use) = last_content {
assert_eq!(last_tool_use.name.as_ref(), "word_list");
if tool_call.status == acp::ToolCallStatus::Pending {
if !last_tool_use.is_input_complete
&& last_tool_use.input.get("g").is_none()
{
saw_partial_tool_use = true;
}
} else {
last_tool_use
.input
.get("a")
.expect("'a' has streamed because input is now complete");
last_tool_use
.input
.get("g")
.expect("'g' has streamed because input is now complete");
}
} else {
panic!("last content should be a tool use");
}
});
}
}
assert!(
saw_partial_tool_use,
"should see at least one partially streamed tool use in the history"
);
}
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
thread.send(model.clone(), "abc", cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_1".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_2".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
// Approve the first
tool_call_auth_1
.response
.send(tool_call_auth_1.options[1].id.clone())
.unwrap();
cx.run_until_parked();
// Reject the second
tool_call_auth_2
.response
.send(tool_call_auth_1.options[2].id.clone())
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: true,
content: "Permission to run tool denied by user".into(),
output: None
})
]
);
}
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_1".into(),
name: "nonexistent_tool".into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
let tool_call = expect_tool_call(&mut events).await;
assert_eq!(tool_call.title, "nonexistent_tool");
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
let update = expect_tool_call_update_fields(&mut events).await;
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> acp::ToolCall {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
event => {
panic!("Unexpected event {event:?}");
}
}
}
async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
return update
}
event => {
panic!("Unexpected event {event:?}");
}
}
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> ToolCallAuthorization {
loop {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
.map(|o| o.kind)
.collect::<Vec<_>>();
assert_eq!(
permission_kinds,
vec![
acp::PermissionOptionKind::AllowAlways,
acp::PermissionOptionKind::AllowOnce,
acp::PermissionOptionKind::RejectOnce,
]
);
return tool_call_authorization;
}
}
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
// Test concurrent tool calls with different delay times
let events = thread
.update(cx, |thread, cx| {
thread.add_tool(DelayTool);
thread.send(
model.clone(),
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
cx,
)
})
.collect()
.await;
let stop_reasons = stop_events(events);
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
thread.update(cx, |thread, _cx| {
let last_message = thread.messages().last().unwrap();
let text = last_message
.content
.iter()
.filter_map(|content| {
if let MessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
}
})
.collect::<String>();
assert!(text.contains("Ding"));
});
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(InfiniteTool);
thread.add_tool(EchoTool);
thread.send(
model.clone(),
"Call the echo tool and then call the infinite tool, then explain their output",
cx,
)
});
// Wait until both tools are called.
let mut expected_tools = vec!["Echo", "Infinite Tool"];
let mut echo_id = None;
let mut echo_completed = false;
while let Some(event) = events.next().await {
match event.unwrap() {
AgentResponseEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" {
echo_id = Some(tool_call.id);
}
}
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate {
id,
fields:
acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
..
},
},
)) if Some(&id) == echo_id.as_ref() => {
echo_completed = true;
}
_ => {}
}
if expected_tools.is_empty() && echo_completed {
break;
}
}
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel());
events.collect::<Vec<_>>().await;
// Ensure we can still send a new message after cancellation.
let events = thread
.update(cx, |thread, cx| {
thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
})
.collect::<Vec<_>>()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## user
Hello
"}
);
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## user
Hello
## assistant
Hey!
"}
);
});
// If the model refuses to continue, the thread should remove all the messages after the last user message.
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
let events = events.collect::<Vec<_>>().await;
assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), "");
});
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
let templates = Templates::new();
// Initialize language model system with test provider
cx.update(|cx| {
gpui_tokio::init(cx);
client::init_settings(cx);
let http_client = FakeHttpClient::with_404_response();
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
Project::init_settings(cx);
LanguageModelRegistry::test(cx);
});
cx.executor().forbid_parking();
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
let selector_opt = connection.model_selector();
assert!(
selector_opt.is_some(),
"agent2 should always support ModelSelector"
);
let selector = selector_opt.unwrap();
// Test list_models
let listed_models = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.list_models(&mut async_cx)
})
.await
.expect("list_models should succeed");
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
// Create a thread using new_thread
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
let mut async_cx = cx.to_async();
connection_rc.new_thread(project, cwd, &mut async_cx)
})
.await
.expect("new_thread should succeed");
// Get the session_id from the AcpThread
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
// Test selected_model returns the default
let model = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.await
.expect("selected_model should succeed");
let model = model.as_fake();
assert_eq!(model.id().0, "fake", "should return default model");
let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
cx.run_until_parked();
model.send_last_completion_stream_text_chunk("def");
cx.run_until_parked();
acp_thread.read_with(cx, |thread, cx| {
assert_eq!(
thread.to_markdown(cx),
indoc! {"
## User
abc
## Assistant
def
"}
)
});
// Test cancel
cx.update(|cx| connection.cancel(&session_id, cx));
request.await.expect("prompt should fail gracefully");
// Ensure that dropping the ACP thread causes the native thread to be
// dropped as well.
cx.update(|_| drop(acp_thread));
let result = cx
.update(|cx| {
connection.prompt(
acp::PromptRequest {
session_id: session_id.clone(),
prompt: vec!["ghi".into()],
},
cx,
)
})
.await;
assert_eq!(
result.as_ref().unwrap_err().to_string(),
"Session not found",
"unexpected result: {:?}",
result
);
}
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
cx.run_until_parked();
// Simulate streaming partial input.
let input = json!({});
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "1".into(),
name: ThinkingTool.name().into(),
raw_input: input.to_string(),
input,
is_input_complete: false,
},
));
// Input streaming completed
let input = json!({ "content": "Thinking hard!" });
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "1".into(),
name: "thinking".into(),
raw_input: input.to_string(),
input,
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let tool_call = expect_tool_call(&mut events).await;
assert_eq!(
tool_call,
acp::ToolCall {
id: acp::ToolCallId("1".into()),
title: "Thinking".into(),
kind: acp::ToolKind::Think,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(json!({})),
raw_output: None,
}
);
let update = expect_tool_call_update_fields(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
title: Some("Thinking".into()),
kind: Some(acp::ToolKind::Think),
raw_input: Some(json!({ "content": "Thinking hard!" })),
..Default::default()
},
}
);
let update = expect_tool_call_update_fields(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
},
}
);
let update = expect_tool_call_update_fields(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
content: Some(vec!["Thinking hard!".into()]),
..Default::default()
},
}
);
let update = expect_tool_call_update_fields(&mut events).await;
assert_eq!(
update,
acp::ToolCallUpdate {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
..Default::default()
},
}
);
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
}
struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
}
enum TestModel {
Sonnet4,
Sonnet4Thinking,
Fake,
}
impl TestModel {
fn id(&self) -> LanguageModelId {
match self {
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
TestModel::Fake => unreachable!(),
}
}
}
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.executor().allow_parking();
cx.update(|cx| {
settings::init(cx);
Project::init_settings(cx);
});
let templates = Templates::new();
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
let model = cx
.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
let model_id = model.id();
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| model.id() == model_id)
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
let authenticated = provider.authenticate(cx);
cx.spawn(async move |_cx| {
authenticated.await.unwrap();
model
})
}
})
.await;
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_| {
Thread::new(
project,
project_context.clone(),
action_log,
templates,
model.clone(),
)
});
ThreadTest {
model,
thread,
project_context,
}
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

View file

@ -0,0 +1,201 @@
use super::*;
use anyhow::Result;
use gpui::{App, SharedString, Task};
use std::future;
/// A tool that echoes its input
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
/// The text to echo.
text: String,
}
pub struct EchoTool;
impl AgentTool for EchoTool {
type Input = EchoToolInput;
type Output = String;
fn name(&self) -> SharedString {
"echo".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
"Echo".into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
Task::ready(Ok(input.text))
}
}
/// A tool that waits for a specified delay
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct DelayToolInput {
/// The delay in milliseconds.
ms: u64,
}
pub struct DelayTool;
impl AgentTool for DelayTool {
type Input = DelayToolInput;
type Output = String;
fn name(&self) -> SharedString {
"delay".into()
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
if let Ok(input) = input {
format!("Delay {}ms", input.ms).into()
} else {
"Delay".into()
}
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>
where
Self: Sized,
{
cx.foreground_executor().spawn(async move {
smol::Timer::after(Duration::from_millis(input.ms)).await;
Ok("Ding".to_string())
})
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct ToolRequiringPermissionInput {}
pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput;
type Output = String;
fn name(&self) -> SharedString {
"tool_requiring_permission".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
"This tool requires permission".into()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let auth_check = event_stream.authorize("Authorize?".into());
cx.foreground_executor().spawn(async move {
auth_check.await?;
Ok("Allowed".to_string())
})
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct InfiniteToolInput {}
pub struct InfiniteTool;
impl AgentTool for InfiniteTool {
type Input = InfiniteToolInput;
type Output = String;
fn name(&self) -> SharedString {
"infinite".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
"Infinite Tool".into()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
cx.foreground_executor().spawn(async move {
future::pending::<()>().await;
unreachable!()
})
}
}
/// A tool that takes an object with map from letters to random words starting with that letter.
/// All fiealds are required! Pass a word for every letter!
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct WordListInput {
/// Provide a random word that starts with A.
a: Option<String>,
/// Provide a random word that starts with B.
b: Option<String>,
/// Provide a random word that starts with C.
c: Option<String>,
/// Provide a random word that starts with D.
d: Option<String>,
/// Provide a random word that starts with E.
e: Option<String>,
/// Provide a random word that starts with F.
f: Option<String>,
/// Provide a random word that starts with G.
g: Option<String>,
}
pub struct WordListTool;
impl AgentTool for WordListTool {
type Input = WordListInput;
type Output = String;
fn name(&self) -> SharedString {
"word_list".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
"List of random words".into()
}
fn run(
self: Arc<Self>,
_input: Self::Input,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
Task::ready(Ok("ok".to_string()))
}
}

1026
crates/agent2/src/thread.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,9 @@
mod edit_file_tool;
mod find_path_tool;
mod read_file_tool;
mod thinking_tool;
pub use edit_file_tool::*;
pub use find_path_tool::*;
pub use read_file_tool::*;
pub use thinking_tool::*;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,248 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
/// Fast file path pattern matching tool that works with any codebase size
///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
/// - Returns matching file paths sorted alphabetically
/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths.
/// - Use this tool when you need to find files by name patterns
/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
/// The glob to match against every path in the project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1/a/something.txt
/// - directory2/a/things.txt
/// - directory3/a/other.txt
///
/// You can get back the first two paths by providing a glob of "*thing*.txt"
/// </example>
pub glob: String,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: usize,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FindPathToolOutput {
offset: usize,
current_matches_page: Vec<PathBuf>,
all_matches_len: usize,
}
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
fn from(output: FindPathToolOutput) -> Self {
if output.current_matches_page.is_empty() {
"No matches found".into()
} else {
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
if output.all_matches_len > RESULTS_PER_PAGE {
write!(
&mut llm_output,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
output.offset + 1,
output.offset + output.current_matches_page.len()
)
.unwrap();
}
for mat in output.current_matches_page {
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
}
llm_output.into()
}
}
}
const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool {
project: Entity<Project>,
}
impl FindPathTool {
pub fn new(project: Entity<Project>) -> Self {
Self { project }
}
}
impl AgentTool for FindPathTool {
type Input = FindPathToolInput;
type Output = FindPathToolOutput;
fn name(&self) -> SharedString {
"find_path".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Search
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
let mut title = "Find paths".to_string();
if let Ok(input) = input {
title.push_str(&format!(" matching “`{}`”", input.glob));
}
title.into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<FindPathToolOutput>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move {
let matches = search_paths_task.await?;
let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(if paginated_matches.len() == 0 {
"No matches".into()
} else if paginated_matches.len() == 1 {
"1 match".into()
} else {
format!("{} matches", paginated_matches.len())
}),
content: Some(
paginated_matches
.iter()
.map(|path| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: format!("file://{}", path.display()),
name: path.to_string_lossy().into(),
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}),
})
.collect(),
),
raw_output: Some(serde_json::json!({
"paths": &matches,
})),
..Default::default()
});
Ok(FindPathToolOutput {
offset: input.offset,
current_matches_page: paginated_matches.to_vec(),
all_matches_len: matches.len(),
})
})
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
};
let snapshots: Vec<_> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
#[cfg(test)]
mod test {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_find_path_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
serde_json::json!({
"apple": {
"banana": {
"carrot": "1",
},
"bandana": {
"carbonara": "2",
},
"endive": "3"
}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let matches = cx
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
let matches = cx
.update(|cx| search_paths("**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View file

@ -0,0 +1,951 @@
use agent_client_protocol::{self as acp};
use anyhow::{anyhow, Context, Result};
use assistant_tool::{outline, ActionLog};
use gpui::{Entity, Task};
use indoc::formatdoc;
use language::{Anchor, Point};
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
use ui::{App, SharedString};
use crate::{AgentTool, ToolCallEventStream};
/// Reads the content of the given file in the project.
///
/// - Never attempt to read a path that hasn't been previously mentioned.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
/// The relative path of the file to read.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.
///
/// <example>
/// If the project has the following root directories:
///
/// - /a/b/directory1
/// - /c/d/directory2
///
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
/// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
/// </example>
pub path: String,
/// Optional line number to start reading on (1-based index)
#[serde(default)]
pub start_line: Option<u32>,
/// Optional line number to end reading on (1-based index, inclusive)
#[serde(default)]
pub end_line: Option<u32>,
}
pub struct ReadFileTool {
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl ReadFileTool {
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
Self {
project,
action_log,
}
}
}
impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput;
type Output = LanguageModelToolResultContent;
fn name(&self) -> SharedString {
"read_file".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Read
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
if let Ok(input) = input {
let path = &input.path;
match (input.start_line, input.end_line) {
(Some(start), Some(end)) => {
format!(
"[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))",
path, start, end, path, start, end
)
}
(Some(start), None) => {
format!(
"[Read file `{}` (from line {})](@selection:{}:({}-{}))",
path, start, path, start, start
)
}
_ => format!("[Read file `{}`](@file:{})", path, path),
}
.into()
} else {
"Read file".into()
}
}
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
};
// Error out if this path is either excluded or private in global settings
let global_settings = WorktreeSettings::get_global(cx);
if global_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the global `file_scan_exclusions` setting: {}",
&input.path
)));
}
if global_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the global `private_files` setting: {}",
&input.path
)));
}
// Error out if this path is either excluded or private in worktree settings
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
if worktree_settings.is_path_excluded(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}",
&input.path
)));
}
if worktree_settings.is_path_private(&project_path.path) {
return Task::ready(Err(anyhow!(
"Cannot read file because its path matches the worktree `private_files` setting: {}",
&input.path
)));
}
let file_path = input.path.clone();
if image_store::is_image_file(&self.project, &project_path, cx) {
return cx.spawn(async move |cx| {
let image_entity: Entity<ImageItem> = cx
.update(|cx| {
self.project.update(cx, |project, cx| {
project.open_image(project_path.clone(), cx)
})
})?
.await?;
let image =
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
let language_model_image = cx
.update(|cx| LanguageModelImage::from_image(image, cx))?
.await
.context("processing image")?;
Ok(language_model_image.into())
});
}
let project = self.project.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if buffer.read_with(cx, |buffer, _| {
buffer
.file()
.as_ref()
.map_or(true, |file| !file.disk_state().exists())
})? {
anyhow::bail!("{file_path} not found");
}
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: Anchor::MIN,
}),
cx,
);
})?;
// Check if specific line ranges are provided
if input.start_line.is_some() || input.end_line.is_some() {
let mut anchor = None;
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
let start = input.start_line.unwrap_or(1).max(1);
let start_row = start - 1;
if start_row <= buffer.max_point().row {
let column = buffer.line_indent_for_row(start_row).raw_len();
anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
}
let lines = text.split('\n').skip(start_row as usize);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
itertools::intersperse(lines.take(count as usize), "\n").collect::<String>()
} else {
itertools::intersperse(lines, "\n").collect::<String>()
}
})?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx);
})?;
if let Some(anchor) = anchor {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor,
}),
cx,
);
})?;
}
Ok(result.into())
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= outline::AUTO_OUTLINE_SIZE {
// File is small enough, so return its contents.
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
Ok(result.into())
} else {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
let outline =
outline::file_outline(project, file_path, action_log, None, cx).await?;
Ok(formatdoc! {"
This file was too big to read all at once.
Here is an outline of its symbols:
{outline}
Using the line numbers in this outline, you can call this tool again
while specifying the start_line and end_line fields to see the
implementations of symbols in the outline.
Alternatively, you can fall back to the `grep` tool (if available)
to search the file for specific content."
}
.into())
}
}
})
}
}
#[cfg(test)]
mod test {
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/root"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let (event_stream, _) = ToolCallEventStream::test();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/nonexistent_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.run(input, event_stream, cx)
})
.await;
assert_eq!(
result.unwrap_err().to_string(),
"root/nonexistent_file.txt not found"
);
}
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"small_file.txt": "This is a small file content"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/small_file.txt".into(),
start_line: None,
end_line: None,
};
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "This is a small file content".into());
}
#[gpui::test]
async fn test_read_large_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".into(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
let content = result.to_str().unwrap();
assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(),
vec![
"struct Test0 [L1-4]",
" a [L2]",
" b [L3]",
"struct Test1 [L5-8]",
" a [L6]",
" b [L7]",
]
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".into(),
start_line: None,
end_line: None,
};
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
let content = result.to_str().unwrap();
let expected_content = (0..1000)
.flat_map(|i| {
vec![
format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
format!(" a [L{}]", i * 4 + 2),
format!(" b [L{}]", i * 4 + 3),
]
})
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(
content
.lines()
.skip(4)
.take(expected_content.len())
.collect::<Vec<_>>(),
expected_content
);
}
#[gpui::test]
async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(2),
end_line: Some(4),
};
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
}
#[gpui::test]
async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/root"),
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(0),
end_line: Some(2),
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(1),
end_line: Some(0),
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1".into());
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/multiline.txt".to_string(),
start_line: Some(3),
end_line: Some(2),
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 3".into());
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(
r#"
(line_comment) @annotation
(struct_item
"struct" @context
name: (_) @name) @item
(enum_item
"enum" @context
name: (_) @name) @item
(enum_variant
name: (_) @name) @item
(field_declaration
name: (_) @name) @item
(impl_item
"impl" @context
trait: (_)? @name
"for"? @context
type: (_) @name
body: (_ "{" (_)* "}")) @item
(function_item
"fn" @context
name: (_) @name) @item
(mod_item
"mod" @context
name: (_) @name) @item
"#,
)
.unwrap()
}
#[gpui::test]
async fn test_read_file_security(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/"),
json!({
"project_root": {
"allowed_file.txt": "This file is in the project",
".mysecrets": "SECRET_KEY=abc123",
".secretdir": {
"config": "special configuration"
},
".mymetadata": "custom metadata",
"subdir": {
"normal_file.txt": "Normal file content",
"special.privatekey": "private key content",
"data.mysensitive": "sensitive data"
}
},
"outside_project": {
"sensitive_file.txt": "This file is outside the project"
}
}),
)
.await;
cx.update(|cx| {
use gpui::UpdateGlobal;
use project::WorktreeSettings;
use settings::SettingsStore;
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions = Some(vec![
"**/.secretdir".to_string(),
"**/.mymetadata".to_string(),
]);
settings.private_files = Some(vec![
"**/.mysecrets".to_string(),
"**/*.privatekey".to_string(),
"**/*.mysensitive".to_string(),
]);
});
});
});
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
// Reading a file outside the project worktree should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "/outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read an absolute path outside a worktree"
);
// Reading a file within the project should succeed
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/allowed_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_ok(),
"read_file_tool should be able to read files inside worktrees"
);
// Reading files that match file_scan_exclusions should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.secretdir/config".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)"
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.mymetadata".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)"
);
// Reading private files should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/.mysecrets".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mysecrets (private_files)"
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/special.privatekey".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .privatekey files (private_files)"
);
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/data.mysensitive".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read .mysensitive files (private_files)"
);
// Reading a normal file should still work, even with private_files configured
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/subdir/normal_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
assert_eq!(result.unwrap(), "Normal file content".into());
// Path traversal attempts with .. should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
start_line: None,
end_line: None,
};
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
result.is_err(),
"read_file_tool should error when attempting to read a relative path that resolves to outside a worktree"
);
}
#[gpui::test]
async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
// Create first worktree with its own private_files setting
fs.insert_tree(
path!("/worktree1"),
json!({
"src": {
"main.rs": "fn main() { println!(\"Hello from worktree1\"); }",
"secret.rs": "const API_KEY: &str = \"secret_key_1\";",
"config.toml": "[database]\nurl = \"postgres://localhost/db1\""
},
"tests": {
"test.rs": "mod tests { fn test_it() {} }",
"fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));"
},
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/fixture.*"],
"private_files": ["**/secret.rs", "**/config.toml"]
}"#
}
}),
)
.await;
// Create second worktree with different private_files setting
fs.insert_tree(
path!("/worktree2"),
json!({
"lib": {
"public.js": "export function greet() { return 'Hello from worktree2'; }",
"private.js": "const SECRET_TOKEN = \"private_token_2\";",
"data.json": "{\"api_key\": \"json_secret_key\"}"
},
"docs": {
"README.md": "# Public Documentation",
"internal.md": "# Internal Secrets and Configuration"
},
".zed": {
"settings.json": r#"{
"file_scan_exclusions": ["**/internal.*"],
"private_files": ["**/private.js", "**/data.json"]
}"#
}
}),
)
.await;
// Set global settings
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
settings.file_scan_exclusions =
Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]);
settings.private_files = Some(vec!["**/.env".to_string()]);
});
});
});
let project = Project::test(
fs.clone(),
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
cx,
)
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
// Test reading allowed files in worktree1
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/src/main.rs".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
assert_eq!(
result,
"fn main() { println!(\"Hello from worktree1\"); }".into()
);
// Test reading private file in worktree1 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/src/secret.rs".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Error should mention worktree private_files setting"
);
// Test reading excluded file in worktree1 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/tests/fixture.sql".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `file_scan_exclusions` setting"),
"Error should mention worktree file_scan_exclusions setting"
);
// Test reading allowed files in worktree2
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree2/lib/public.js".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
assert_eq!(
result,
"export function greet() { return 'Hello from worktree2'; }".into()
);
// Test reading private file in worktree2 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree2/lib/private.js".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Error should mention worktree private_files setting"
);
// Test reading excluded file in worktree2 should fail
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree2/docs/internal.md".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `file_scan_exclusions` setting"),
"Error should mention worktree file_scan_exclusions setting"
);
// Test that files allowed in one worktree but not in another are handled correctly
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "worktree1/src/config.toml".to_string(),
start_line: None,
end_line: None,
};
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("worktree `private_files` setting"),
"Config.toml should be blocked by worktree1's private_files setting"
);
}
}

View file

@ -0,0 +1,49 @@
use agent_client_protocol as acp;
use anyhow::Result;
use gpui::{App, SharedString, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::{AgentTool, ToolCallEventStream};
/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions.
/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ThinkingToolInput {
/// Content to think about. This should be a description of what to think about or
/// a problem to solve.
content: String,
}
pub struct ThinkingTool;
impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput;
type Output = String;
fn name(&self) -> SharedString {
"thinking".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Think
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
"Thinking".into()
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![input.content.into()]),
..Default::default()
});
Task::ready(Ok("Finished thinking.".to_string()))
}
}

View file

@ -135,7 +135,7 @@ impl acp_old::Client for OldAcpClientDelegate {
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_permission(tool_call, acp_options, cx)
thread.request_tool_call_authorization(tool_call, acp_options, cx)
})
})?
.context("Failed to update thread")?
@ -280,6 +280,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams)
.map(into_new_tool_call_location)
.collect(),
raw_input: None,
raw_output: None,
}
}
@ -380,6 +381,7 @@ impl AcpConnection {
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
log::trace!("Spawned (pid: {})", child.id());
let foreground_executor = cx.foreground_executor().clone();
@ -463,7 +465,11 @@ impl AgentConnection for AcpConnection {
})
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
fn prompt(
&self,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let chunks = params
.prompt
.into_iter()
@ -483,7 +489,9 @@ impl AgentConnection for AcpConnection {
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
anyhow::Ok(())
anyhow::Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}

View file

@ -19,7 +19,6 @@ pub struct AcpConnection {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
_io_task: Task<Result<()>>,
_child: smol::process::Child,
}
pub struct AcpSession {
@ -47,6 +46,7 @@ impl AcpConnection {
let stdout = child.stdout.take().expect("Failed to take stdout");
let stdin = child.stdin.take().expect("Failed to take stdin");
log::trace!("Spawned (pid: {})", child.id());
let sessions = Rc::new(RefCell::new(HashMap::default()));
@ -63,6 +63,23 @@ impl AcpConnection {
let io_task = cx.background_spawn(io_task);
cx.spawn({
let sessions = sessions.clone();
async move |cx| {
let status = child.status().await?;
for session in sessions.borrow().values() {
session
.thread
.update(cx, |thread, cx| thread.emit_server_exited(status, cx))
.ok();
}
anyhow::Ok(())
}
})
.detach();
let response = connection
.initialize(acp::InitializeRequest {
protocol_version: acp::VERSION,
@ -84,7 +101,6 @@ impl AcpConnection {
connection: connection.into(),
server_name,
sessions,
_child: child,
_io_task: io_task,
})
}
@ -153,10 +169,16 @@ impl AgentConnection for AcpConnection {
})
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
fn prompt(
&self,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let conn = self.connection.clone();
cx.foreground_executor()
.spawn(async move { Ok(conn.prompt(params).await?) })
cx.foreground_executor().spawn(async move {
let response = conn.prompt(params).await?;
Ok(response)
})
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
@ -188,7 +210,7 @@ impl acp::Client for ClientDelegate {
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
})?;
let result = rx.await;

View file

@ -89,6 +89,7 @@ impl AgentServerCommand {
pub(crate) async fn resolve(
path_bin_name: &'static str,
extra_args: &[&'static str],
fallback_path: Option<&Path>,
settings: Option<AgentServerSettings>,
project: &Entity<Project>,
cx: &mut AsyncApp,
@ -105,13 +106,24 @@ impl AgentServerCommand {
env: agent_settings.command.env,
});
} else {
find_bin_in_path(path_bin_name, project, cx)
.await
.map(|path| Self {
match find_bin_in_path(path_bin_name, project, cx).await {
Some(path) => Some(Self {
path,
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
env: None,
})
}),
None => fallback_path.and_then(|path| {
if path.exists() {
Some(Self {
path: path.to_path_buf(),
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
env: None,
})
} else {
None
}
}),
}
}
}
}

View file

@ -24,7 +24,7 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use serde::{Deserialize, Serialize};
use util::ResultExt;
use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
@ -101,8 +101,15 @@ impl AgentConnection for ClaudeAgentConnection {
settings.get::<AllAgentServersSettings>(None).claude.clone()
})?;
let Some(command) =
AgentServerCommand::resolve("claude", &[], settings, &project, cx).await
let Some(command) = AgentServerCommand::resolve(
"claude",
&[],
Some(&util::paths::home_dir().join(".claude/local/claude")),
settings,
&project,
cx,
)
.await
else {
anyhow::bail!("Failed to find claude binary");
};
@ -114,53 +121,63 @@ impl AgentConnection for ClaudeAgentConnection {
log::trace!("Starting session with id: {}", session_id);
cx.background_spawn({
let session_id = session_id.clone();
async move {
let mut outgoing_rx = Some(outgoing_rx);
let mut child = spawn_claude(
&command,
ClaudeSessionMode::Start,
session_id.clone(),
&mcp_config_path,
&cwd,
)?;
let mut child = spawn_claude(
&command,
ClaudeSessionMode::Start,
session_id.clone(),
&mcp_config_path,
&cwd,
)
.await?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let pid = child.id();
log::trace!("Spawned (pid: {})", pid);
let pid = child.id();
log::trace!("Spawned (pid: {})", pid);
ClaudeAgentSession::handle_io(
outgoing_rx.take().unwrap(),
incoming_message_tx.clone(),
child.stdin.take().unwrap(),
child.stdout.take().unwrap(),
)
.await?;
cx.background_spawn(async move {
let mut outgoing_rx = Some(outgoing_rx);
log::trace!("Stopped (pid: {})", pid);
ClaudeAgentSession::handle_io(
outgoing_rx.take().unwrap(),
incoming_message_tx.clone(),
stdin,
stdout,
)
.await?;
drop(mcp_config_path);
anyhow::Ok(())
}
log::trace!("Stopped (pid: {})", pid);
drop(mcp_config_path);
anyhow::Ok(())
})
.detach();
let end_turn_tx = Rc::new(RefCell::new(None));
let turn_state = Rc::new(RefCell::new(TurnState::None));
let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone();
let thread_rx = thread_rx.clone();
let turn_state = turn_state.clone();
let mut thread_rx = thread_rx.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message(
thread_rx.clone(),
message,
end_turn_tx.clone(),
turn_state.clone(),
cx,
)
.await
}
if let Some(status) = child.status().await.log_err() {
if let Some(thread) = thread_rx.recv().await.ok() {
thread
.update(cx, |thread, cx| {
thread.emit_server_exited(status, cx);
})
.ok();
}
}
}
});
@ -172,7 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
turn_state,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
@ -191,7 +208,11 @@ impl AgentConnection for ClaudeAgentConnection {
Task::ready(Err(anyhow!("Authentication not supported")))
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
fn prompt(
&self,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(&params.session_id) else {
return Task::ready(Err(anyhow!(
@ -200,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
)));
};
let (tx, rx) = oneshot::channel();
session.end_turn_tx.borrow_mut().replace(tx);
let (end_tx, end_rx) = oneshot::channel();
session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new();
for chunk in params.prompt {
@ -235,10 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err)));
}
cx.foreground_executor().spawn(async move {
rx.await??;
Ok(())
})
cx.foreground_executor().spawn(async move { end_rx.await? })
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@ -248,9 +266,26 @@ impl AgentConnection for ClaudeAgentConnection {
return;
};
let request_id = new_request_id();
let turn_state = session.turn_state.take();
let TurnState::InProgress { end_tx } = turn_state else {
// Already cancelled or idle, put it back
session.turn_state.replace(turn_state);
return;
};
session.turn_state.replace(TurnState::CancelRequested {
end_tx,
request_id: request_id.clone(),
});
session
.outgoing_tx
.unbounded_send(SdkMessage::new_interrupt_message())
.unbounded_send(SdkMessage::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
})
.log_err();
}
}
@ -262,7 +297,7 @@ enum ClaudeSessionMode {
Resume,
}
async fn spawn_claude(
fn spawn_claude(
command: &AgentServerCommand,
mode: ClaudeSessionMode,
session_id: acp::SessionId,
@ -313,26 +348,139 @@ async fn spawn_claude(
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
turn_state: Rc<RefCell<TurnState>>,
_mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
#[derive(Debug, Default)]
enum TurnState {
#[default]
None,
InProgress {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
CancelRequested {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
request_id: String,
},
CancelConfirmed {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
}
impl TurnState {
fn is_cancelled(&self) -> bool {
matches!(self, TurnState::CancelConfirmed { .. })
}
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
match self {
TurnState::None => None,
TurnState::InProgress { end_tx, .. } => Some(end_tx),
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
}
}
fn confirm_cancellation(self, id: &str) -> Self {
match self {
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
TurnState::CancelConfirmed { end_tx }
}
_ => self,
}
}
}
impl ClaudeAgentSession {
async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
turn_state: Rc<RefCell<TurnState>>,
cx: &mut AsyncApp,
) {
match message {
// we should only be sending these out, they don't need to be in the thread
SdkMessage::ControlRequest { .. } => {}
SdkMessage::Assistant {
SdkMessage::User {
message,
session_id: _,
} => {
let Some(thread) = thread_rx
.recv()
.await
.log_err()
.and_then(|entity| entity.upgrade())
else {
log::error!("Received an SDK message but thread is gone");
return;
};
for chunk in message.content.chunks() {
match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx)
})
.log_err();
}
}
ContentChunk::ToolResult {
content,
tool_use_id,
} => {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: if turn_state.borrow().is_cancelled() {
// Do not set to completed if turn was cancelled
None
} else {
Some(acp::ToolCallStatus::Completed)
},
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
},
},
cx,
)
})
.log_err();
}
ContentChunk::Thinking { .. }
| ContentChunk::RedactedThinking
| ContentChunk::ToolUse { .. } => {
debug_panic!(
"Should not get {:?} with role: assistant. should we handle this?",
chunk
);
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::WebSearchToolResult => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
format!("Unsupported content: {:?}", chunk).into(),
false,
cx,
)
})
.log_err();
}
}
}
}
| SdkMessage::User {
SdkMessage::Assistant {
message,
session_id: _,
} => {
@ -355,6 +503,24 @@ impl ClaudeAgentSession {
})
.log_err();
}
ContentChunk::Thinking { thinking } => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(thinking.into(), true, cx)
})
.log_err();
}
ContentChunk::RedactedThinking => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
"[REDACTED]".into(),
true,
cx,
)
})
.log_err();
}
ContentChunk::ToolUse { id, name, input } => {
let claude_tool = ClaudeTool::infer(&name, input);
@ -380,33 +546,12 @@ impl ClaudeAgentSession {
})
.log_err();
}
ContentChunk::ToolResult {
content,
tool_use_id,
} => {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
},
},
cx,
)
})
.log_err();
ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
debug_panic!(
"Should not get tool results with role: assistant. should we handle this?"
);
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::Thinking
| ContentChunk::RedactedThinking
| ContentChunk::WebSearchToolResult => {
ContentChunk::Image | ContentChunk::Document => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
@ -426,20 +571,41 @@ impl ClaudeAgentSession {
result,
..
} => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
if is_error {
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else {
end_turn_tx.send(Ok(())).ok();
}
let turn_state = turn_state.take();
let was_cancelled = turn_state.is_cancelled();
let Some(end_turn_tx) = turn_state.end_tx() else {
debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
return;
};
if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
{
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
}
}
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) {
let new_state = turn_state.take().confirm_cancellation(&response.request_id);
turn_state.replace(new_state);
} else {
log::error!("Control response error: {:?}", response);
}
}
SdkMessage::System { .. } => {}
}
}
@ -548,11 +714,13 @@ enum ContentChunk {
content: Content,
tool_use_id: String,
},
Thinking {
thinking: String,
},
RedactedThinking,
// TODO
Image,
Document,
Thinking,
RedactedThinking,
WebSearchToolResult,
#[serde(untagged)]
UntaggedText(String),
@ -562,12 +730,12 @@ impl Display for ContentChunk {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ContentChunk::Text { text } => write!(f, "{}", text),
ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
ContentChunk::UntaggedText(text) => write!(f, "{}", text),
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::Thinking
| ContentChunk::RedactedThinking
| ContentChunk::ToolUse { .. }
| ContentChunk::WebSearchToolResult => {
write!(f, "\n{:?}\n", &self)
@ -660,7 +828,7 @@ struct ControlResponse {
subtype: ResultErrorType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
Success,
@ -678,22 +846,15 @@ impl Display for ResultErrorType {
}
}
impl SdkMessage {
fn new_interrupt_message() -> Self {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
let request_id = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect();
Self::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
}
}
fn new_request_id() -> String {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -714,6 +875,8 @@ enum PermissionMode {
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::e2e_tests;
use gpui::TestAppContext;
use serde_json::json;
crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
@ -726,6 +889,68 @@ pub(crate) mod tests {
}
}
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_todo_plan(cx: &mut TestAppContext) {
let fs = e2e_tests::init_test(cx).await;
let project = Project::test(fs, [], cx).await;
let thread =
e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
thread
.update(cx, |thread, cx| {
thread.send_raw(
"Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
cx,
)
})
.await
.unwrap();
let mut entries_len = 0;
thread.read_with(cx, |thread, _| {
entries_len = thread.plan().entries.len();
assert!(thread.plan().entries.len() > 0, "Empty plan");
});
thread
.update(cx, |thread, cx| {
thread.send_raw(
"Mark the first entry status as in progress without acting on it.",
cx,
)
})
.await
.unwrap();
thread.read_with(cx, |thread, _| {
assert!(matches!(
thread.plan().entries[0].status,
acp::PlanEntryStatus::InProgress
));
assert_eq!(thread.plan().entries.len(), entries_len);
});
thread
.update(cx, |thread, cx| {
thread.send_raw(
"Now mark the first entry as completed without acting on it.",
cx,
)
})
.await
.unwrap();
thread.read_with(cx, |thread, _| {
assert!(matches!(
thread.plan().entries[0].status,
acp::PlanEntryStatus::Completed
));
assert_eq!(thread.plan().entries.len(), entries_len);
});
}
#[test]
fn test_deserialize_content_untagged_text() {
let json = json!("Hello, world!");

View file

@ -153,7 +153,7 @@ impl McpServerTool for PermissionTool {
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id),
vec![
acp::PermissionOption {

View file

@ -143,25 +143,6 @@ impl ClaudeTool {
Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
Self::WebSearch(Some(params)) => vec![params.to_string().into()],
Self::TodoWrite(Some(params)) => vec![
params
.todos
.iter()
.map(|todo| {
format!(
"- {} {}: {}",
match todo.status {
TodoStatus::Completed => "",
TodoStatus::InProgress => "🚧",
TodoStatus::Pending => "",
},
todo.priority,
todo.content
)
})
.join("\n")
.into(),
],
Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
@ -193,6 +174,10 @@ impl ClaudeTool {
})
.unwrap_or_default()
}
Self::TodoWrite(Some(_)) => {
// These are mapped to plan updates later
vec![]
}
Self::Task(None)
| Self::NotebookRead(None)
| Self::NotebookEdit(None)
@ -312,6 +297,7 @@ impl ClaudeTool {
content: self.content(),
locations: self.locations(),
raw_input: None,
raw_output: None,
}
}
}
@ -488,10 +474,11 @@ impl std::fmt::Display for GrepToolParams {
}
}
#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TodoPriority {
High,
#[default]
Medium,
Low,
}
@ -526,14 +513,13 @@ impl Into<acp::PlanEntryStatus> for TodoStatus {
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
pub struct Todo {
/// Unique identifier
pub id: String,
/// Task description
pub content: String,
/// Priority level of the todo
pub priority: TodoPriority,
/// Current status of the todo
pub status: TodoStatus,
/// Priority level of the todo
#[serde(default)]
pub priority: TodoPriority,
}
impl Into<acp::PlanEntry> for Todo {

View file

@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
let _ = thread.update(cx, |thread, cx| {
thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
cx,
@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
id.clone()
});
let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
full_turn.await.unwrap();
thread.read_with(cx, |thread, _| {
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
status: ToolCallStatus::Canceled,
..
@ -311,6 +310,27 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
});
}
pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
let fs = init_test(cx).await;
let project = Project::test(fs, [], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
thread
.update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
.await
.unwrap();
thread.read_with(cx, |thread, _| {
assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
});
let weak_thread = thread.downgrade();
drop(thread);
cx.executor().run_until_parked();
assert!(!weak_thread.is_upgradable());
}
#[macro_export]
macro_rules! common_e2e_tests {
($server:expr, allow_option_id = $allow_option_id:expr) => {
@ -351,6 +371,12 @@ macro_rules! common_e2e_tests {
async fn cancel(cx: &mut ::gpui::TestAppContext) {
$crate::e2e_tests::test_cancel($server, cx).await;
}
#[::gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
$crate::e2e_tests::test_thread_drop($server, cx).await;
}
}
};
}

View file

@ -2,7 +2,7 @@ use std::path::Path;
use std::rc::Rc;
use crate::{AgentServer, AgentServerCommand};
use acp_thread::AgentConnection;
use acp_thread::{AgentConnection, LoadError};
use anyhow::Result;
use gpui::{Entity, Task};
use project::Project;
@ -48,12 +48,42 @@ impl AgentServer for Gemini {
})?;
let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
else {
anyhow::bail!("Failed to find gemini binary");
};
crate::acp::connect(server_name, command, &root_dir, cx).await
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
if result.is_err() {
let version_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--version")
.kill_on_drop(true)
.output();
let help_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--help")
.kill_on_drop(true)
.output();
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
let current_version = String::from_utf8(version_output?.stdout)?;
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
if !supported {
return Err(LoadError::Unsupported {
error_message: format!(
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
current_version
).into(),
upgrade_message: "Upgrade Gemini to Latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
}.into())
}
}
result
})
}
}

View file

@ -13,6 +13,9 @@ use std::borrow::Cow;
pub use crate::agent_profile::*;
pub const SUMMARIZE_THREAD_PROMPT: &str =
include_str!("../../agent/src/prompts/summarize_thread_prompt.txt");
pub fn init(cx: &mut App) {
AgentSettings::register(cx);
}

View file

@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"]
acp_thread.workspace = true
agent-client-protocol.workspace = true
agent.workspace = true
agent2.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
ai_onboarding.workspace = true

View file

@ -45,6 +45,11 @@ impl<T> MessageHistory<T> {
None
})
}
#[cfg(test)]
pub fn items(&self) -> &[T] {
&self.items
}
}
#[cfg(test)]
mod tests {

View file

@ -5,6 +5,7 @@ use audio::{Audio, Sound};
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::path::Path;
use std::process::ExitStatus;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
@ -20,26 +21,28 @@ use editor::{
use file_icons::FileIcons;
use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
list, percentage, point, prelude::*, pulsating_between,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
};
use language::language_settings::SoftWrap;
use language::{Buffer, Language};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use parking_lot::Mutex;
use project::Project;
use settings::Settings as _;
use settings::{Settings as _, SettingsStore};
use text::{Anchor, BufferSnapshot};
use theme::ThemeSettings;
use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*};
use ui::{
Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*,
};
use util::ResultExt;
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
use ::acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
};
@ -68,6 +71,7 @@ pub struct AcpThreadView {
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
last_error: Option<Entity<Markdown>>,
list_state: ListState,
scrollbar_state: ScrollbarState,
auth_task: Option<Task<()>>,
expanded_tool_calls: HashSet<acp::ToolCallId>,
expanded_thinking_blocks: HashSet<(usize, usize)>,
@ -76,6 +80,7 @@ pub struct AcpThreadView {
editor_expanded: bool,
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
_cancel_task: Option<Task<()>>,
_subscriptions: [Subscription; 1],
}
enum ThreadState {
@ -90,6 +95,9 @@ enum ThreadState {
Unauthenticated {
connection: Rc<dyn AgentConnection>,
},
ServerExited {
status: ExitStatus,
},
}
impl AcpThreadView {
@ -169,22 +177,9 @@ impl AcpThreadView {
let mention_set = mention_set.clone();
let list_state = ListState::new(
0,
gpui::ListAlignment::Bottom,
px(2048.0),
cx.processor({
move |this: &mut Self, index: usize, window, cx| {
let Some((entry, len)) = this.thread().and_then(|thread| {
let entries = &thread.read(cx).entries();
Some((entries.get(index)?, entries.len()))
}) else {
return Empty.into_any();
};
this.render_entry(index, len, entry, window, cx)
}
}),
);
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
let subscription = cx.observe_global_in::<SettingsStore>(window, Self::settings_changed);
Self {
agent: agent.clone(),
@ -198,7 +193,8 @@ impl AcpThreadView {
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
diff_editors: Default::default(),
list_state: list_state,
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
last_error: None,
auth_task: None,
expanded_tool_calls: HashSet::default(),
@ -207,6 +203,7 @@ impl AcpThreadView {
plan_expanded: false,
editor_expanded: false,
message_history,
_subscriptions: [subscription],
_cancel_task: None,
}
}
@ -228,7 +225,7 @@ impl AcpThreadView {
let connect_task = agent.connect(&root_dir, &project, cx);
let load_task = cx.spawn_in(window, async move |this, cx| {
let connection = match connect_task.await {
Ok(thread) => thread,
Ok(connection) => connection,
Err(err) => {
this.update(cx, |this, cx| {
this.handle_load_error(err, cx);
@ -239,6 +236,20 @@ impl AcpThreadView {
}
};
// this.update_in(cx, |_this, _window, cx| {
// let status = connection.exit_status(cx);
// cx.spawn(async move |this, cx| {
// let status = status.await.ok();
// this.update(cx, |this, cx| {
// this.thread_state = ThreadState::ServerExited { status };
// cx.notify();
// })
// .ok();
// })
// .detach();
// })
// .ok();
let result = match connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
@ -307,7 +318,8 @@ impl AcpThreadView {
ThreadState::Ready { thread, .. } => Some(thread),
ThreadState::Unauthenticated { .. }
| ThreadState::Loading { .. }
| ThreadState::LoadError(..) => None,
| ThreadState::LoadError(..)
| ThreadState::ServerExited { .. } => None,
}
}
@ -317,6 +329,7 @@ impl AcpThreadView {
ThreadState::Loading { .. } => "Loading…".into(),
ThreadState::LoadError(_) => "Failed to load".into(),
ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
}
}
@ -368,6 +381,11 @@ impl AcpThreadView {
editor.display_map.update(cx, |map, cx| {
let snapshot = map.snapshot(cx);
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
// Skip creases that have been edited out of the message buffer.
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
continue;
}
if let Some(project_path) =
self.mention_set.lock().path_for_crease_id(crease_id)
{
@ -646,6 +664,9 @@ impl AcpThreadView {
cx,
);
}
AcpThreadEvent::ServerExited(status) => {
self.thread_state = ThreadState::ServerExited { status: *status };
}
}
cx.notify();
}
@ -692,15 +713,7 @@ impl AcpThreadView {
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
});
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
@ -719,7 +732,11 @@ impl AcpThreadView {
cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
Some(
entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone()),
)
}
fn authenticate(
@ -785,7 +802,7 @@ impl AcpThreadView {
window: &mut Window,
cx: &Context<Self>,
) -> AnyElement {
match &entry {
let primary = match &entry {
AgentThreadEntry::UserMessage(message) => div()
.py_4()
.px_2()
@ -846,10 +863,25 @@ impl AcpThreadView {
.into_any()
}
AgentThreadEntry::ToolCall(tool_call) => div()
.w_full()
.py_1p5()
.px_5()
.child(self.render_tool_call(index, tool_call, window, cx))
.into_any(),
};
let Some(thread) = self.thread() else {
return primary;
};
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
if index == total_entries - 1 && !is_generating {
v_flex()
.w_full()
.child(primary)
.child(self.render_thread_controls(cx))
.into_any_element()
} else {
primary
}
}
@ -877,6 +909,7 @@ impl AcpThreadView {
cx: &Context<Self>,
) -> AnyElement {
let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix));
let card_header_id = SharedString::from("inner-card-header");
let key = (entry_ix, chunk_ix);
let is_open = self.expanded_thinking_blocks.contains(&key);
@ -884,41 +917,53 @@ impl AcpThreadView {
.child(
h_flex()
.id(header_id)
.group("disclosure-header")
.group(&card_header_id)
.relative()
.w_full()
.justify_between()
.gap_1p5()
.opacity(0.8)
.hover(|style| style.opacity(1.))
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::ToolBulb)
.size(IconSize::Small)
.color(Color::Muted),
)
.size_4()
.justify_center()
.child(
div()
.text_size(self.tool_name_font_size())
.child("Thinking"),
.group_hover(&card_header_id, |s| s.invisible().w_0())
.child(
Icon::new(IconName::ToolThink)
.size(IconSize::Small)
.color(Color::Muted),
),
)
.child(
h_flex()
.absolute()
.inset_0()
.invisible()
.justify_center()
.group_hover(&card_header_id, |s| s.visible())
.child(
Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronRight)
.on_click(cx.listener({
move |this, _event, _window, cx| {
if is_open {
this.expanded_thinking_blocks.remove(&key);
} else {
this.expanded_thinking_blocks.insert(key);
}
cx.notify();
}
})),
),
),
)
.child(
div().visible_on_hover("disclosure-header").child(
Disclosure::new("thinking-disclosure", is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
move |this, _event, _window, cx| {
if is_open {
this.expanded_thinking_blocks.remove(&key);
} else {
this.expanded_thinking_blocks.insert(key);
}
cx.notify();
}
})),
),
div()
.text_size(self.tool_name_font_size())
.child("Thinking"),
)
.on_click(cx.listener({
move |this, _event, _window, cx| {
@ -949,6 +994,67 @@ impl AcpThreadView {
.into_any_element()
}
fn render_tool_call_icon(
&self,
group_name: SharedString,
entry_ix: usize,
is_collapsible: bool,
is_open: bool,
tool_call: &ToolCall,
cx: &Context<Self>,
) -> Div {
let tool_icon = Icon::new(match tool_call.kind {
acp::ToolKind::Read => IconName::ToolRead,
acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolThink,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
.size(IconSize::Small)
.color(Color::Muted);
if is_collapsible {
h_flex()
.size_4()
.justify_center()
.child(
div()
.group_hover(&group_name, |s| s.invisible().w_0())
.child(tool_icon),
)
.child(
h_flex()
.absolute()
.inset_0()
.invisible()
.justify_center()
.group_hover(&group_name, |s| s.visible())
.child(
Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronRight)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
})),
),
)
} else {
div().child(tool_icon)
}
}
fn render_tool_call(
&self,
entry_ix: usize,
@ -956,7 +1062,8 @@ impl AcpThreadView {
window: &Window,
cx: &Context<Self>,
) -> Div {
let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix));
let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix));
let card_header_id = SharedString::from("inner-tool-call-header");
let status_icon = match &tool_call.status {
ToolCallStatus::Allowed {
@ -1005,6 +1112,21 @@ impl AcpThreadView {
let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation;
let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id);
let gradient_color = cx.theme().colors().panel_background;
let gradient_overlay = {
div()
.absolute()
.top_0()
.right_0()
.w_12()
.h_full()
.bg(linear_gradient(
90.,
linear_color_stop(gradient_color, 1.),
linear_color_stop(gradient_color.opacity(0.2), 0.),
))
};
v_flex()
.when(needs_confirmation, |this| {
this.rounded_lg()
@ -1021,43 +1143,38 @@ impl AcpThreadView {
.justify_between()
.map(|this| {
if needs_confirmation {
this.px_2()
this.pl_2()
.pr_1()
.py_1()
.rounded_t_md()
.bg(self.tool_card_header_bg(cx))
.border_b_1()
.border_color(self.tool_card_border_color(cx))
.bg(self.tool_card_header_bg(cx))
} else {
this.opacity(0.8).hover(|style| style.opacity(1.))
}
})
.child(
h_flex()
.id("tool-call-header")
.overflow_x_scroll()
.group(&card_header_id)
.relative()
.w_full()
.map(|this| {
if needs_confirmation {
this.text_xs()
if tool_call.locations.len() == 1 {
this.gap_0()
} else {
this.text_size(self.tool_name_font_size())
this.gap_1p5()
}
})
.gap_1p5()
.child(
Icon::new(match tool_call.kind {
acp::ToolKind::Read => IconName::ToolRead,
acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolBulb,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
.size(IconSize::Small)
.color(Color::Muted),
)
.text_size(self.tool_name_font_size())
.child(self.render_tool_call_icon(
card_header_id,
entry_ix,
is_collapsible,
is_open,
tool_call,
cx,
))
.child(if tool_call.locations.len() == 1 {
let name = tool_call.locations[0]
.path
@ -1068,13 +1185,11 @@ impl AcpThreadView {
h_flex()
.id(("open-tool-call-location", entry_ix))
.child(name)
.w_full()
.max_w_full()
.pr_1()
.gap_0p5()
.cursor_pointer()
.px_1p5()
.rounded_sm()
.overflow_x_scroll()
.opacity(0.8)
.hover(|label| {
label.opacity(1.).bg(cx
@ -1083,53 +1198,49 @@ impl AcpThreadView {
.element_hover
.opacity(0.5))
})
.child(name)
.tooltip(Tooltip::text("Jump to File"))
.on_click(cx.listener(move |this, _, window, cx| {
this.open_tool_call_location(entry_ix, 0, window, cx);
}))
.into_any_element()
} else {
self.render_markdown(
tool_call.label.clone(),
default_markdown_style(needs_confirmation, window, cx),
)
.into_any()
h_flex()
.id("non-card-label-container")
.w_full()
.relative()
.overflow_hidden()
.child(
h_flex()
.id("non-card-label")
.pr_8()
.w_full()
.overflow_x_scroll()
.child(self.render_markdown(
tool_call.label.clone(),
default_markdown_style(
needs_confirmation,
window,
cx,
),
)),
)
.child(gradient_overlay)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
}))
.into_any()
}),
)
.child(
h_flex()
.gap_0p5()
.when(is_collapsible, |this| {
this.child(
Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
})),
)
})
.children(status_icon),
)
.on_click(cx.listener({
let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open {
this.expanded_tool_calls.remove(&id);
} else {
this.expanded_tool_calls.insert(id.clone());
}
cx.notify();
}
})),
.children(status_icon),
)
.when(is_open, |this| {
this.child(
@ -1207,10 +1318,9 @@ impl AcpThreadView {
Empty.into_any_element()
}
}
ToolCallContent::Diff {
diff: Diff { multibuffer, .. },
..
} => self.render_diff_editor(multibuffer),
ToolCallContent::Diff { diff, .. } => {
self.render_diff_editor(&diff.read(cx).multibuffer())
}
}
}
@ -1223,8 +1333,7 @@ impl AcpThreadView {
cx: &Context<Self>,
) -> Div {
h_flex()
.py_1p5()
.px_1p5()
.p_1p5()
.gap_1()
.justify_end()
.when(!empty_content, |this| {
@ -1250,6 +1359,7 @@ impl AcpThreadView {
})
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.label_size(LabelSize::Small)
.on_click(cx.listener({
let tool_call_id = tool_call_id.clone();
let option_id = option.id.clone();
@ -1369,7 +1479,29 @@ impl AcpThreadView {
.into_any()
}
fn render_error_state(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
fn render_server_exited(&self, status: ExitStatus, _cx: &Context<Self>) -> AnyElement {
v_flex()
.items_center()
.justify_center()
.child(self.render_error_agent_logo())
.child(
v_flex()
.mt_4()
.mb_2()
.gap_0p5()
.text_center()
.items_center()
.child(Headline::new("Server exited unexpectedly").size(HeadlineSize::Medium))
.child(
Label::new(format!("Exit status: {}", status.code().unwrap_or(-127)))
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.into_any_element()
}
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
let mut container = v_flex()
.items_center()
.justify_center()
@ -1477,7 +1609,7 @@ impl AcpThreadView {
})
})
.when(!changed_buffers.is_empty(), |this| {
this.child(Divider::horizontal())
this.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_edits_summary(
action_log,
&changed_buffers,
@ -1507,6 +1639,7 @@ impl AcpThreadView {
{
h_flex()
.w_full()
.cursor_default()
.gap_1()
.text_xs()
.text_color(cx.theme().colors().text_muted)
@ -1536,7 +1669,7 @@ impl AcpThreadView {
let status_label = if stats.pending == 0 {
"All Done".to_string()
} else if stats.completed == 0 {
format!("{}", plan.entries.len())
format!("{} Tasks", plan.entries.len())
} else {
format!("{}/{}", stats.completed, plan.entries.len())
};
@ -1650,7 +1783,6 @@ impl AcpThreadView {
.child(
h_flex()
.id("edits-container")
.cursor_pointer()
.w_full()
.gap_1()
.child(Disclosure::new("edits-disclosure", expanded))
@ -2404,7 +2536,7 @@ impl AcpThreadView {
}
}
fn render_thread_controls(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
fn render_thread_controls(&self, cx: &Context<Self>) -> impl IntoElement {
let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileText)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
@ -2425,9 +2557,9 @@ impl AcpThreadView {
}));
h_flex()
.mt_1()
.w_full()
.mr_1()
.py_2()
.pb_2()
.px(RESPONSE_PADDING_X)
.opacity(0.4)
.hover(|style| style.opacity(1.))
@ -2436,6 +2568,48 @@ impl AcpThreadView {
.child(open_as_markdown)
.child(scroll_to_top)
}
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
div()
.id("acp-thread-scrollbar")
.occlude()
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
}
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
for diff_editor in self.diff_editors.values() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
cx.notify();
})
}
}
}
impl Focusable for AcpThreadView {
@ -2481,24 +2655,36 @@ impl Render for AcpThreadView {
.flex_1()
.items_center()
.justify_center()
.child(self.render_error_state(e, cx)),
.child(self.render_load_error(e, cx)),
ThreadState::ServerExited { status } => v_flex()
.p_2()
.flex_1()
.items_center()
.justify_center()
.child(self.render_server_exited(*status, cx)),
ThreadState::Ready { thread, .. } => {
let thread_clone = thread.clone();
v_flex().flex_1().map(|this| {
if self.list_state.item_count() > 0 {
let is_generating =
matches!(thread_clone.read(cx).status(), ThreadStatus::Generating);
this.child(
list(self.list_state.clone())
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
.flex_grow()
.into_any(),
list(
self.list_state.clone(),
cx.processor(|this, index: usize, window, cx| {
let Some((entry, len)) = this.thread().and_then(|thread| {
let entries = &thread.read(cx).entries();
Some((entries.get(index)?, entries.len()))
}) else {
return Empty.into_any();
};
this.render_entry(index, len, entry, window, cx)
}),
)
.with_sizing_behavior(gpui::ListSizingBehavior::Auto)
.flex_grow()
.into_any(),
)
.when(!is_generating, |this| {
this.child(self.render_thread_controls(cx))
})
.child(self.render_vertical_scrollbar(cx))
.children(match thread_clone.read(cx).status() {
ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => {
None
@ -2701,6 +2887,18 @@ fn plan_label_markdown_style(
}
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
}
}
#[cfg(test)]
mod tests {
use agent_client_protocol::SessionId;
@ -2708,11 +2906,25 @@ mod tests {
use fs::FakeFs;
use futures::future::try_join_all;
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
use lsp::{CompletionContext, CompletionTriggerKind};
use project::CompletionIntent;
use rand::Rng;
use serde_json::json;
use settings::SettingsStore;
use util::path;
use super::*;
#[gpui::test]
async fn test_drop(cx: &mut TestAppContext) {
init_test(cx);
let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await;
let weak_view = thread_view.downgrade();
drop(thread_view);
assert!(!weak_view.is_upgradable());
}
#[gpui::test]
async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
init_test(cx);
@ -2779,6 +2991,7 @@ mod tests {
content: vec!["hi".into()],
locations: vec![],
raw_input: None,
raw_output: None,
};
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
.with_permission_requests(HashMap::from_iter([(
@ -2811,6 +3024,109 @@ mod tests {
);
}
#[gpui::test]
async fn test_crease_removal(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({"file": ""})).await;
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
let agent = StubAgentServer::default();
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(agent),
workspace.downgrade(),
project,
Rc::new(RefCell::new(MessageHistory::default())),
1,
None,
window,
cx,
)
})
});
cx.run_until_parked();
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
let excerpt_id = message_editor.update(cx, |editor, cx| {
editor
.buffer()
.read(cx)
.excerpt_ids()
.into_iter()
.next()
.unwrap()
});
let completions = message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello @", window, cx);
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
let completion_provider = editor.completion_provider().unwrap();
completion_provider.completions(
excerpt_id,
&buffer,
Anchor::MAX,
CompletionContext {
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
trigger_character: Some("@".into()),
},
window,
cx,
)
});
let [_, completion]: [_; 2] = completions
.await
.unwrap()
.into_iter()
.flat_map(|response| response.completions)
.collect::<Vec<_>>()
.try_into()
.unwrap();
message_editor.update_in(cx, |editor, window, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let start = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
.unwrap();
let end = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
.unwrap();
editor.edit([(start..end, completion.new_text)], cx);
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
});
cx.run_until_parked();
// Backspace over the inserted crease (and the following space).
message_editor.update_in(cx, |editor, window, cx| {
editor.backspace(&Default::default(), window, cx);
editor.backspace(&Default::default(), window, cx);
});
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.chat(&Chat, window, cx);
});
cx.run_until_parked();
let content = thread_view.update_in(cx, |thread_view, _window, _cx| {
thread_view
.message_history
.borrow()
.items()
.iter()
.flatten()
.cloned()
.collect::<Vec<_>>()
});
// We don't send a resource link for the deleted crease.
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
}
async fn setup_thread_view(
agent: impl AgentServer + 'static,
cx: &mut TestAppContext,
@ -2943,7 +3259,11 @@ mod tests {
unimplemented!()
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
fn prompt(
&self,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![];
@ -2960,7 +3280,7 @@ mod tests {
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_permission(
thread.request_tool_call_authorization(
tool_call.clone(),
options.clone(),
cx,
@ -2977,7 +3297,9 @@ mod tests {
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(())
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
@ -3021,7 +3343,11 @@ mod tests {
unimplemented!()
}
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
fn prompt(
&self,
_params: acp::PromptRequest,
_cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
Task::ready(Err(anyhow::anyhow!("Error prompting")))
}

View file

@ -69,8 +69,6 @@ pub struct ActiveThread {
messages: Vec<MessageId>,
list_state: ListState,
scrollbar_state: ScrollbarState,
show_scrollbar: bool,
hide_scrollbar_task: Option<Task<()>>,
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
editing_message: Option<(MessageId, EditingMessageState)>,
@ -780,13 +778,7 @@ impl ActiveThread {
cx.observe_global::<SettingsStore>(|_, cx| cx.notify()),
];
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
let this = cx.entity().downgrade();
move |ix, window: &mut Window, cx: &mut App| {
this.update(cx, |this, cx| this.render_message(ix, window, cx))
.unwrap()
}
});
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.));
let workspace_subscription = if let Some(workspace) = workspace.upgrade() {
Some(cx.observe_release(&workspace, |this, _, cx| {
@ -811,9 +803,7 @@ impl ActiveThread {
expanded_thinking_segments: HashMap::default(),
expanded_code_blocks: HashMap::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state),
show_scrollbar: false,
hide_scrollbar_task: None,
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
editing_message: None,
last_error: None,
copied_code_block_ids: HashSet::default(),
@ -1846,7 +1836,12 @@ impl ActiveThread {
)))
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
fn render_message(
&mut self,
ix: usize,
window: &mut Window,
cx: &mut Context<Self>,
) -> AnyElement {
let message_id = self.messages[ix];
let workspace = self.workspace.clone();
let thread = self.thread.read(cx);
@ -2629,7 +2624,7 @@ impl ActiveThread {
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::ToolBulb)
Icon::new(IconName::ToolThink)
.size(IconSize::Small)
.color(Color::Muted),
)
@ -3503,60 +3498,37 @@ impl ActiveThread {
}
}
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
if !self.show_scrollbar && !self.scrollbar_state.is_dragging() {
return None;
}
Some(
div()
.occlude()
.id("active-thread-scrollbar")
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
div()
.occlude()
.id("active-thread-scrollbar")
.on_mouse_move(cx.listener(|_, _, _, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _, cx| {
cx.stop_propagation();
})
.on_mouse_up(
MouseButton::Left,
cx.listener(|_, _, _, cx| {
cx.stop_propagation();
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
)
}
fn hide_scrollbar_later(&mut self, cx: &mut Context<Self>) {
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| {
cx.background_executor()
.timer(SCROLLBAR_SHOW_INTERVAL)
.await;
thread
.update(cx, |thread, cx| {
if !thread.scrollbar_state.is_dragging() {
thread.show_scrollbar = false;
cx.notify();
}
})
.log_err();
}))
}),
)
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
cx.notify();
}))
.h_full()
.absolute()
.right_1()
.top_1()
.bottom_0()
.w(px(12.))
.cursor_default()
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
}
pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
@ -3597,26 +3569,8 @@ impl Render for ActiveThread {
.size_full()
.relative()
.bg(cx.theme().colors().panel_background)
.on_mouse_move(cx.listener(|this, _, _, cx| {
this.show_scrollbar = true;
this.hide_scrollbar_later(cx);
cx.notify();
}))
.on_scroll_wheel(cx.listener(|this, _, _, cx| {
this.show_scrollbar = true;
this.hide_scrollbar_later(cx);
cx.notify();
}))
.on_mouse_up(
MouseButton::Left,
cx.listener(|this, _, _, cx| {
this.hide_scrollbar_later(cx);
}),
)
.child(list(self.list_state.clone()).flex_grow())
.when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| {
this.child(scrollbar)
})
.child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow())
.child(self.render_vertical_scrollbar(cx))
}
}

View file

@ -1523,7 +1523,8 @@ impl AgentDiff {
}
AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error => {}
| AcpThreadEvent::Error
| AcpThreadEvent::ServerExited(_) => {}
}
}

View file

@ -970,13 +970,7 @@ impl AgentPanel {
)
});
this.set_active_view(
ActiveView::ExternalAgentThread {
thread_view: thread_view.clone(),
},
window,
cx,
);
this.set_active_view(ActiveView::ExternalAgentThread { thread_view }, window, cx);
})
})
.detach_and_log_err(cx);
@ -1987,6 +1981,22 @@ impl AgentPanel {
);
}),
)
.item(
ContextMenuEntry::new("New Native Agent Thread")
.icon(IconName::ZedAssistant)
.icon_color(Color::Muted)
.handler(move |window, cx| {
window.dispatch_action(
NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::NativeAgent,
),
}
.boxed_clone(),
cx,
);
}),
)
});
menu
}))
@ -2333,6 +2343,16 @@ impl AgentPanel {
)
}
fn render_backdrop(&self, cx: &mut Context<Self>) -> impl IntoElement {
div()
.size_full()
.absolute()
.inset_0()
.bg(cx.theme().colors().panel_background)
.opacity(0.8)
.block_mouse_except_scroll()
}
fn render_trial_end_upsell(
&self,
_window: &mut Window,
@ -2342,15 +2362,24 @@ impl AgentPanel {
return None;
}
Some(EndTrialUpsell::new(Arc::new({
let this = cx.entity();
move |_, cx| {
this.update(cx, |_this, cx| {
TrialEndUpsell::set_dismissed(true, cx);
cx.notify();
});
}
})))
Some(
v_flex()
.absolute()
.inset_0()
.size_full()
.bg(cx.theme().colors().panel_background)
.opacity(0.85)
.block_mouse_except_scroll()
.child(EndTrialUpsell::new(Arc::new({
let this = cx.entity();
move |_, cx| {
this.update(cx, |_this, cx| {
TrialEndUpsell::set_dismissed(true, cx);
cx.notify();
});
}
}))),
)
}
fn render_empty_state_section_header(
@ -2649,6 +2678,31 @@ impl AgentPanel {
)
},
),
)
.child(
NewThreadButton::new(
"new-native-agent-thread-btn",
"New Native Agent Thread",
IconName::ZedAssistant,
)
// .keybinding(KeyBinding::for_action_in(
// &OpenHistory,
// &self.focus_handle(cx),
// window,
// cx,
// ))
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::NativeAgent,
),
}),
cx,
)
},
),
),
)
}),
@ -3175,9 +3229,10 @@ impl Render for AgentPanel {
// - Scrolling in all views works as expected
// - Files can be dropped into the panel
let content = v_flex()
.key_context(self.key_context())
.justify_between()
.relative()
.size_full()
.justify_between()
.key_context(self.key_context())
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(|this, action: &NewThread, window, cx| {
this.new_thread(action, window, cx);
@ -3220,14 +3275,12 @@ impl Render for AgentPanel {
.on_action(cx.listener(Self::toggle_burn_mode))
.child(self.render_toolbar(window, cx))
.children(self.render_onboarding(window, cx))
.children(self.render_trial_end_upsell(window, cx))
.map(|parent| match &self.active_view {
ActiveView::Thread {
thread,
message_editor,
..
} => parent
.relative()
.child(
if thread.read(cx).is_empty() && !self.should_render_onboarding(cx) {
self.render_thread_empty_state(window, cx)
@ -3264,21 +3317,10 @@ impl Render for AgentPanel {
})
.child(h_flex().relative().child(message_editor.clone()).when(
!LanguageModelRegistry::read_global(cx).has_authenticated_provider(cx),
|this| {
this.child(
div()
.size_full()
.absolute()
.inset_0()
.bg(cx.theme().colors().panel_background)
.opacity(0.8)
.block_mouse_except_scroll(),
)
},
|this| this.child(self.render_backdrop(cx)),
))
.child(self.render_drag_target(cx)),
ActiveView::ExternalAgentThread { thread_view, .. } => parent
.relative()
.child(thread_view.clone())
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
@ -3317,7 +3359,8 @@ impl Render for AgentPanel {
))
}
ActiveView::Configuration => parent.children(self.configuration.clone()),
});
})
.children(self.render_trial_end_upsell(window, cx));
match self.active_view.which_font_size_used() {
WhichFontSize::AgentFont => {

View file

@ -151,6 +151,7 @@ enum ExternalAgent {
#[default]
Gemini,
ClaudeCode,
NativeAgent,
}
impl ExternalAgent {
@ -158,6 +159,7 @@ impl ExternalAgent {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
}
}
}

View file

@ -504,7 +504,7 @@ impl Render for ContextStrip {
)
.on_click({
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
if event.click_count() > 1 {
this.open_context(&context, window, cx);
} else {
this.focused_index = Some(i);

View file

@ -1,9 +1,9 @@
use std::sync::Arc;
use ai_onboarding::{AgentPanelOnboardingCard, BulletItem};
use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions};
use client::zed_urls;
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Divider, List, Tooltip, prelude::*};
use ui::{Divider, Tooltip, prelude::*};
#[derive(IntoElement, RegisterComponent)]
pub struct EndTrialUpsell {
@ -18,6 +18,8 @@ impl EndTrialUpsell {
impl RenderOnce for EndTrialUpsell {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let plan_definitions = PlanDefinitions;
let pro_section = v_flex()
.gap_1()
.child(
@ -31,13 +33,7 @@ impl RenderOnce for EndTrialUpsell {
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("500 prompts with Claude models"))
.child(BulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
)),
)
.child(plan_definitions.pro_plan(false))
.child(
Button::new("cta-button", "Upgrade to Zed Pro")
.full_width()
@ -68,11 +64,7 @@ impl RenderOnce for EndTrialUpsell {
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("50 prompts with the Claude models"))
.child(BulletItem::new("2,000 accepted edit predictions")),
);
.child(plan_definitions.free_plan());
AgentPanelOnboardingCard::new()
.child(Headline::new("Your Zed Pro Trial has expired"))
@ -102,18 +94,20 @@ impl RenderOnce for EndTrialUpsell {
impl Component for EndTrialUpsell {
fn scope() -> ComponentScope {
ComponentScope::Agent
ComponentScope::Onboarding
}
fn name() -> &'static str {
"End of Trial Upsell Banner"
}
fn sort_name() -> &'static str {
"AgentEndTrialUpsell"
"End of Trial Upsell Banner"
}
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
Some(
v_flex()
.p_4()
.gap_4()
.child(EndTrialUpsell {
dismiss_upsell: Arc::new(|_, _| {}),
})

View file

@ -1,8 +1,6 @@
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use ui::{Divider, List, prelude::*};
use crate::BulletItem;
use ui::{Divider, List, ListBulletItem, prelude::*};
pub struct ApiKeysWithProviders {
configured_providers: Vec<(IconName, SharedString)>,
@ -128,7 +126,7 @@ impl RenderOnce for ApiKeysWithoutProviders {
)
.child(Divider::horizontal()),
)
.child(List::new().child(BulletItem::new(
.child(List::new().child(ListBulletItem::new(
"Add your own keys to use AI without signing in.",
)))
.child(

View file

@ -3,6 +3,7 @@ mod agent_panel_onboarding_card;
mod agent_panel_onboarding_content;
mod ai_upsell_card;
mod edit_prediction_onboarding_content;
mod plan_definitions;
mod young_account_banner;
pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProviders};
@ -11,51 +12,14 @@ pub use agent_panel_onboarding_content::AgentPanelOnboarding;
pub use ai_upsell_card::AiUpsellCard;
use cloud_llm_client::Plan;
pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
pub use plan_definitions::PlanDefinitions;
pub use young_account_banner::YoungAccountBanner;
use std::sync::Arc;
use client::{Client, UserStore, zed_urls};
use gpui::{AnyElement, Entity, IntoElement, ParentElement, SharedString};
use ui::{Divider, List, ListItem, RegisterComponent, TintColor, Tooltip, prelude::*};
#[derive(IntoElement)]
pub struct BulletItem {
label: SharedString,
}
impl BulletItem {
pub fn new(label: impl Into<SharedString>) -> Self {
Self {
label: label.into(),
}
}
}
impl RenderOnce for BulletItem {
fn render(self, window: &mut Window, _cx: &mut App) -> impl IntoElement {
let line_height = 0.85 * window.line_height();
ListItem::new("list-item")
.selectable(false)
.child(
h_flex()
.w_full()
.min_w_0()
.gap_1()
.items_start()
.child(
h_flex().h(line_height).justify_center().child(
Icon::new(IconName::Dash)
.size(IconSize::XSmall)
.color(Color::Hidden),
),
)
.child(div().w_full().min_w_0().child(Label::new(self.label))),
)
.into_any_element()
}
}
use gpui::{AnyElement, Entity, IntoElement, ParentElement};
use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*};
#[derive(PartialEq)]
pub enum SignInStatus {
@ -130,107 +94,6 @@ impl ZedAiOnboarding {
self
}
fn free_plan_definition(&self, cx: &mut App) -> impl IntoElement {
v_flex()
.mt_2()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Free")
.size(LabelSize::Small)
.color(Color::Muted)
.buffer_font(cx),
)
.child(
Label::new("(Current Plan)")
.size(LabelSize::Small)
.color(Color::Custom(cx.theme().colors().text_muted.opacity(0.6)))
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("50 prompts per month with Claude models"))
.child(BulletItem::new(
"2,000 accepted edit predictions with Zeta, our open-source model",
)),
)
}
fn pro_trial_definition(&self) -> impl IntoElement {
List::new()
.child(BulletItem::new("150 prompts with Claude models"))
.child(BulletItem::new(
"Unlimited accepted edit predictions with Zeta, our open-source model",
))
}
fn pro_plan_definition(&self, cx: &mut App) -> impl IntoElement {
v_flex().mt_2().gap_1().map(|this| {
if self.account_too_young {
this.child(
h_flex()
.gap_2()
.child(
Label::new("Pro")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("500 prompts per month with Claude models"))
.child(BulletItem::new(
"Unlimited accepted edit predictions with Zeta, our open-source model",
))
.child(BulletItem::new("$20 USD per month")),
)
.child(
Button::new("pro", "Get Started")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!("Upgrade To Pro Clicked", state = "young-account");
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
}),
)
} else {
this.child(
h_flex()
.gap_2()
.child(
Label::new("Pro Trial")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(self.pro_trial_definition())
.child(BulletItem::new(
"Try it out for 14 days for free, no credit card required",
)),
)
.child(
Button::new("pro", "Start Free Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
cx.open_url(&zed_urls::start_trial_url(cx))
}),
)
}
})
}
fn render_accept_terms_of_service(&self) -> AnyElement {
v_flex()
.gap_1()
@ -269,6 +132,7 @@ impl ZedAiOnboarding {
fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement {
let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn);
let plan_definitions = PlanDefinitions;
v_flex()
.gap_1()
@ -278,7 +142,7 @@ impl ZedAiOnboarding {
.color(Color::Muted)
.mb_2(),
)
.child(self.pro_trial_definition())
.child(plan_definitions.pro_plan(false))
.child(
Button::new("sign_in", "Try Zed Pro for Free")
.disabled(signing_in)
@ -297,43 +161,132 @@ impl ZedAiOnboarding {
fn render_free_plan_state(&self, cx: &mut App) -> AnyElement {
let young_account_banner = YoungAccountBanner;
let plan_definitions = PlanDefinitions;
v_flex()
.relative()
.gap_1()
.child(Headline::new("Welcome to Zed AI"))
.map(|this| {
if self.account_too_young {
this.child(young_account_banner)
} else {
this.child(self.free_plan_definition(cx)).when_some(
self.dismiss_onboarding.as_ref(),
|this, dismiss_callback| {
let callback = dismiss_callback.clone();
if self.account_too_young {
v_flex()
.relative()
.max_w_full()
.gap_1()
.child(Headline::new("Welcome to Zed AI"))
.child(young_account_banner)
.child(
v_flex()
.mt_2()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Pro")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(true))
.child(
Button::new("pro", "Get Started")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!(
"Upgrade To Pro Clicked",
state = "young-account"
);
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
}),
),
)
.into_any_element()
} else {
v_flex()
.relative()
.gap_1()
.child(Headline::new("Welcome to Zed AI"))
.child(
v_flex()
.mt_2()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Free")
.size(LabelSize::Small)
.color(Color::Muted)
.buffer_font(cx),
)
.child(
Label::new("(Current Plan)")
.size(LabelSize::Small)
.color(Color::Custom(
cx.theme().colors().text_muted.opacity(0.6),
))
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(plan_definitions.free_plan()),
)
.when_some(
self.dismiss_onboarding.as_ref(),
|this, dismiss_callback| {
let callback = dismiss_callback.clone();
this.child(
h_flex().absolute().top_0().right_0().child(
IconButton::new("dismiss_onboarding", IconName::Close)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Dismiss"))
.on_click(move |_, window, cx| {
telemetry::event!(
"Banner Dismissed",
source = "AI Onboarding",
);
callback(window, cx)
}),
),
)
},
)
}
})
.child(self.pro_plan_definition(cx))
.into_any_element()
this.child(
h_flex().absolute().top_0().right_0().child(
IconButton::new("dismiss_onboarding", IconName::Close)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Dismiss"))
.on_click(move |_, window, cx| {
telemetry::event!(
"Banner Dismissed",
source = "AI Onboarding",
);
callback(window, cx)
}),
),
)
},
)
.child(
v_flex()
.mt_2()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Pro Trial")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_trial(true))
.child(
Button::new("pro", "Start Free Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!(
"Start Trial Clicked",
state = "post-sign-in"
);
cx.open_url(&zed_urls::start_trial_url(cx))
}),
),
)
.into_any_element()
}
}
fn render_trial_state(&self, _cx: &mut App) -> AnyElement {
let plan_definitions = PlanDefinitions;
v_flex()
.relative()
.gap_1()
@ -343,13 +296,7 @@ impl ZedAiOnboarding {
.color(Color::Muted)
.mb_2(),
)
.child(
List::new()
.child(BulletItem::new("150 prompts with Claude models"))
.child(BulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
)),
)
.child(plan_definitions.pro_trial(false))
.when_some(
self.dismiss_onboarding.as_ref(),
|this, dismiss_callback| {
@ -374,6 +321,8 @@ impl ZedAiOnboarding {
}
fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement {
let plan_definitions = PlanDefinitions;
v_flex()
.gap_1()
.child(Headline::new("Welcome to Zed Pro"))
@ -382,13 +331,7 @@ impl ZedAiOnboarding {
.color(Color::Muted)
.mb_2(),
)
.child(
List::new()
.child(BulletItem::new("500 prompts with Claude models"))
.child(BulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
)),
)
.child(plan_definitions.pro_plan(false))
.child(
Button::new("pro", "Continue with Zed Pro")
.full_width()
@ -425,7 +368,15 @@ impl RenderOnce for ZedAiOnboarding {
impl Component for ZedAiOnboarding {
fn scope() -> ComponentScope {
ComponentScope::Agent
ComponentScope::Onboarding
}
fn name() -> &'static str {
"Agent Panel Banners"
}
fn sort_name() -> &'static str {
"Agent Panel Banners"
}
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
@ -450,8 +401,9 @@ impl Component for ZedAiOnboarding {
Some(
v_flex()
.p_4()
.gap_4()
.items_center()
.max_w_4_5()
.children(vec![
single_example(
"Not Signed-in",
@ -462,8 +414,8 @@ impl Component for ZedAiOnboarding {
onboarding(SignInStatus::SignedIn, false, None, false),
),
single_example(
"Account too young",
onboarding(SignInStatus::SignedIn, false, None, true),
"Young Account",
onboarding(SignInStatus::SignedIn, true, None, true),
),
single_example(
"Free Plan",

View file

@ -1,23 +1,33 @@
use std::sync::Arc;
use std::{sync::Arc, time::Duration};
use client::{Client, zed_urls};
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::Plan;
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Divider, List, Vector, VectorName, prelude::*};
use gpui::{
Animation, AnimationExt, AnyElement, App, Entity, IntoElement, RenderOnce, Transformation,
Window, percentage,
};
use ui::{Divider, Vector, VectorName, prelude::*};
use crate::{BulletItem, SignInStatus};
use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions};
#[derive(IntoElement, RegisterComponent)]
pub struct AiUpsellCard {
pub sign_in_status: SignInStatus,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
pub account_too_young: bool,
pub user_plan: Option<Plan>,
pub tab_index: Option<isize>,
}
impl AiUpsellCard {
pub fn new(client: Arc<Client>, user_plan: Option<Plan>) -> Self {
pub fn new(
client: Arc<Client>,
user_store: &Entity<UserStore>,
user_plan: Option<Plan>,
cx: &mut App,
) -> Self {
let status = *client.status().borrow();
let store = user_store.read(cx);
Self {
user_plan,
@ -29,6 +39,7 @@ impl AiUpsellCard {
})
.detach_and_log_err(cx);
}),
account_too_young: store.account_too_young(),
tab_index: None,
}
}
@ -36,6 +47,9 @@ impl AiUpsellCard {
impl RenderOnce for AiUpsellCard {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let plan_definitions = PlanDefinitions;
let young_account_banner = YoungAccountBanner;
let pro_section = v_flex()
.flex_grow()
.w_full()
@ -51,13 +65,7 @@ impl RenderOnce for AiUpsellCard {
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("500 prompts with Claude models"))
.child(BulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
)),
);
.child(plan_definitions.pro_plan(false));
let free_section = v_flex()
.flex_grow()
@ -74,11 +82,7 @@ impl RenderOnce for AiUpsellCard {
)
.child(Divider::horizontal()),
)
.child(
List::new()
.child(BulletItem::new("50 prompts with Claude models"))
.child(BulletItem::new("2,000 accepted edit predictions")),
);
.child(plan_definitions.free_plan());
let grid_bg = h_flex().absolute().inset_0().w_full().h(px(240.)).child(
Vector::new(VectorName::Grid, rems_from_px(500.), rems_from_px(240.))
@ -101,44 +105,11 @@ impl RenderOnce for AiUpsellCard {
),
));
const DESCRIPTION: &str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
let description = PlanDefinitions::AI_DESCRIPTION;
let footer_buttons = match self.sign_in_status {
SignInStatus::SignedIn => v_flex()
.items_center()
.gap_1()
.child(
Button::new("sign_in", "Start 14-day Free Pro Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!("Start Trial Clicked", state = "post-sign-in");
cx.open_url(&zed_urls::start_trial_url(cx))
})
.when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index)),
)
.child(
Label::new("No credit card required")
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any_element(),
_ => Button::new("sign_in", "Sign In")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index))
.on_click({
let callback = self.sign_in.clone();
move |_, window, cx| {
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
callback(window, cx)
}
})
.into_any_element(),
};
v_flex()
let card = v_flex()
.relative()
.flex_grow()
.p_4()
.pt_3()
.border_1()
@ -146,31 +117,169 @@ impl RenderOnce for AiUpsellCard {
.rounded_lg()
.overflow_hidden()
.child(grid_bg)
.child(gradient_bg)
.child(Label::new("Try Zed AI").size(LabelSize::Large))
.child(gradient_bg);
let plans_section = h_flex()
.w_full()
.mt_1p5()
.mb_2p5()
.items_start()
.gap_6()
.child(free_section)
.child(pro_section);
let footer_container = v_flex().items_center().gap_1();
let certified_user_stamp = div()
.absolute()
.top_2()
.right_2()
.size(rems_from_px(72.))
.child(
div()
.max_w_3_4()
.mb_2()
.child(Label::new(DESCRIPTION).color(Color::Muted)),
)
Vector::new(
VectorName::ProUserStamp,
rems_from_px(72.),
rems_from_px(72.),
)
.color(Color::Custom(cx.theme().colors().text_accent.alpha(0.3)))
.with_animation(
"loading_stamp",
Animation::new(Duration::from_secs(10)).repeat(),
|this, delta| this.transform(Transformation::rotate(percentage(delta))),
),
);
let pro_trial_stamp = div()
.absolute()
.top_2()
.right_2()
.size(rems_from_px(72.))
.child(
h_flex()
.w_full()
.mt_1p5()
.mb_2p5()
.items_start()
.gap_6()
.child(free_section)
.child(pro_section),
)
.child(footer_buttons)
Vector::new(
VectorName::ProTrialStamp,
rems_from_px(72.),
rems_from_px(72.),
)
.color(Color::Custom(cx.theme().colors().text.alpha(0.2))),
);
match self.sign_in_status {
SignInStatus::SignedIn => match self.user_plan {
None | Some(Plan::ZedFree) => card
.child(Label::new("Try Zed AI").size(LabelSize::Large))
.map(|this| {
if self.account_too_young {
this.child(young_account_banner).child(
v_flex()
.mt_2()
.gap_1()
.child(
h_flex()
.gap_2()
.child(
Label::new("Pro")
.size(LabelSize::Small)
.color(Color::Accent)
.buffer_font(cx),
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(true))
.child(
Button::new("pro", "Get Started")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(move |_, _window, cx| {
telemetry::event!(
"Upgrade To Pro Clicked",
state = "young-account"
);
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
}),
),
)
} else {
this.child(
div()
.max_w_3_4()
.mb_2()
.child(Label::new(description).color(Color::Muted)),
)
.child(plans_section)
.child(
footer_container
.child(
Button::new("start_trial", "Start 14-day Free Pro Trial")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.when_some(self.tab_index, |this, tab_index| {
this.tab_index(tab_index)
})
.on_click(move |_, _window, cx| {
telemetry::event!(
"Start Trial Clicked",
state = "post-sign-in"
);
cx.open_url(&zed_urls::start_trial_url(cx))
}),
)
.child(
Label::new("No credit card required")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}
}),
Some(Plan::ZedProTrial) => card
.child(pro_trial_stamp)
.child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large))
.child(
Label::new("Here's what you get for the next 14 days:")
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_trial(false)),
Some(Plan::ZedPro) => card
.child(certified_user_stamp)
.child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large))
.child(
Label::new("Here's what you get:")
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_plan(false)),
},
// Signed Out State
_ => card
.child(Label::new("Try Zed AI").size(LabelSize::Large))
.child(
div()
.max_w_3_4()
.mb_2()
.child(Label::new(description).color(Color::Muted)),
)
.child(plans_section)
.child(
Button::new("sign_in", "Sign In")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.when_some(self.tab_index, |this, tab_index| this.tab_index(tab_index))
.on_click({
let callback = self.sign_in.clone();
move |_, window, cx| {
telemetry::event!("Start Trial Clicked", state = "pre-sign-in");
callback(window, cx)
}
}),
),
}
}
}
impl Component for AiUpsellCard {
fn scope() -> ComponentScope {
ComponentScope::Agent
ComponentScope::Onboarding
}
fn name() -> &'static str {
@ -188,30 +297,69 @@ impl Component for AiUpsellCard {
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
Some(
v_flex()
.p_4()
.gap_4()
.children(vec![example_group(vec![
single_example(
"Signed Out State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedOut,
sign_in: Arc::new(|_, _| {}),
user_plan: None,
tab_index: Some(0),
}
.into_any_element(),
),
single_example(
"Signed In State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
user_plan: None,
tab_index: Some(1),
}
.into_any_element(),
),
])])
.items_center()
.max_w_4_5()
.child(single_example(
"Signed Out State",
AiUpsellCard {
sign_in_status: SignInStatus::SignedOut,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: None,
tab_index: Some(0),
}
.into_any_element(),
))
.child(example_group_with_title(
"Signed In States",
vec![
single_example(
"Free Plan",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: Some(Plan::ZedFree),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Free Plan but Young Account",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: true,
user_plan: Some(Plan::ZedFree),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Pro Trial",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: Some(Plan::ZedProTrial),
tab_index: Some(1),
}
.into_any_element(),
),
single_example(
"Pro Plan",
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
account_too_young: false,
user_plan: Some(Plan::ZedPro),
tab_index: Some(1),
}
.into_any_element(),
),
],
))
.into_any_element(),
)
}

View file

@ -0,0 +1,39 @@
use gpui::{IntoElement, ParentElement};
use ui::{List, ListBulletItem, prelude::*};
/// Centralized definitions for Zed AI plans
pub struct PlanDefinitions;
impl PlanDefinitions {
pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
pub fn free_plan(&self) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("50 prompts with Claude models"))
.child(ListBulletItem::new("2,000 accepted edit predictions"))
}
pub fn pro_trial(&self, period: bool) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("150 prompts with Claude models"))
.child(ListBulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
))
.when(period, |this| {
this.child(ListBulletItem::new(
"Try it out for 14 days for free, no credit card required",
))
})
}
pub fn pro_plan(&self, price: bool) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("500 prompts with Claude models"))
.child(ListBulletItem::new(
"Unlimited edit predictions with Zeta, our open-source model",
))
.when(price, |this| {
this.child(ListBulletItem::new("$20 USD per month"))
})
}
}

View file

@ -15,6 +15,7 @@ impl RenderOnce for YoungAccountBanner {
.child(YOUNG_ACCOUNT_DISCLAIMER);
div()
.max_w_full()
.my_1()
.child(Banner::new().severity(ui::Severity::Warning).child(label))
}

View file

@ -2,16 +2,16 @@
mod assistant_context_tests;
mod context_store;
use agent_settings::AgentSettings;
use agent_settings::{AgentSettings, SUMMARIZE_THREAD_PROMPT};
use anyhow::{Context as _, Result, bail};
use assistant_slash_command::{
SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
SlashCommandResult, SlashCommandWorkingSet,
};
use assistant_slash_commands::FileCommandMetadata;
use client::{self, Client, proto, telemetry::Telemetry};
use client::{self, Client, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
use clock::ReplicaId;
use cloud_llm_client::CompletionIntent;
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
@ -2080,7 +2080,18 @@ impl AssistantContext {
});
match event {
LanguageModelCompletionEvent::StatusUpdate { .. } => {}
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
match status_update {
CompletionRequestStatus::UsageUpdated { amount, limit } => {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
_ => {}
}
}
LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@ -2677,10 +2688,7 @@ impl AssistantContext {
let mut request = self.to_completion_request(Some(&model.model), cx);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
"Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
.into(),
],
content: vec![SUMMARIZE_THREAD_PROMPT.into()],
cache: false,
});
@ -2956,6 +2964,21 @@ impl AssistantContext {
summary.text = custom_summary;
cx.emit(ContextEvent::SummaryChanged);
}
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) {
let Some(project) = &self.project else {
return;
};
project.read(cx).user_store().update(cx, |user_store, cx| {
user_store.update_model_request_usage(
ModelRequestUsage(RequestUsage {
amount: amount as i32,
limit,
}),
cx,
)
});
}
}
#[derive(Debug, Default)]

View file

@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.stream_last_completion_response("Brief");
fake_model.stream_last_completion_response(" Introduction");
fake_model.send_last_completion_stream_text_chunk("Brief");
fake_model.send_last_completion_stream_text_chunk(" Introduction");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.stream_last_completion_response("A successful summary");
fake_model.send_last_completion_stream_text_chunk("A successful summary");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model(
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.stream_last_completion_response("Assistant response");
fake_model.send_last_completion_stream_text_chunk("Assistant response");
fake_model.end_last_completion_stream();
cx.run_until_parked();
}

View file

@ -2,7 +2,7 @@ mod copy_path_tool;
mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
pub mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@ -14,7 +14,7 @@ mod open_tool;
mod project_notifications_tool;
mod read_file_tool;
mod schema;
mod templates;
pub mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@ -36,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::thinking_tool::ThinkingTool;
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
pub use find_path_tool::FindPathToolInput;
pub use find_path_tool::*;
pub use grep_tool::{GrepTool, GrepToolInput};
pub use open_tool::OpenTool;
pub use project_notifications_tool::ProjectNotificationsTool;

View file

@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize};
use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
use util::debug_panic;
#[derive(Serialize)]
struct CreateFilePromptTemplate {
@ -682,11 +681,6 @@ impl EditAgent {
if last_message.content.is_empty() {
conversation.messages.pop();
}
} else {
debug_panic!(
"Last message must be an Assistant tool calling! Got {:?}",
last_message.content
);
}
}
@ -962,7 +956,7 @@ mod tests {
);
cx.run_until_parked();
model.stream_last_completion_response("<old_text>a");
model.send_last_completion_stream_text_chunk("<old_text>a");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
@ -974,7 +968,7 @@ mod tests {
None
);
model.stream_last_completion_response("bc</old_text>");
model.send_last_completion_stream_text_chunk("bc</old_text>");
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
@ -996,7 +990,7 @@ mod tests {
})
);
model.stream_last_completion_response("<new_text>abX");
model.send_last_completion_stream_text_chunk("<new_text>abX");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
assert_eq!(
@ -1011,7 +1005,7 @@ mod tests {
})
);
model.stream_last_completion_response("cY");
model.send_last_completion_stream_text_chunk("cY");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
assert_eq!(
@ -1026,8 +1020,8 @@ mod tests {
})
);
model.stream_last_completion_response("</new_text>");
model.stream_last_completion_response("<old_text>hall");
model.send_last_completion_stream_text_chunk("</new_text>");
model.send_last_completion_stream_text_chunk("<old_text>hall");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
@ -1042,8 +1036,8 @@ mod tests {
})
);
model.stream_last_completion_response("ucinated old</old_text>");
model.stream_last_completion_response("<new_text>");
model.send_last_completion_stream_text_chunk("ucinated old</old_text>");
model.send_last_completion_stream_text_chunk("<new_text>");
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
@ -1061,8 +1055,8 @@ mod tests {
})
);
model.stream_last_completion_response("hallucinated new</new_");
model.stream_last_completion_response("text>");
model.send_last_completion_stream_text_chunk("hallucinated new</new_");
model.send_last_completion_stream_text_chunk("text>");
cx.run_until_parked();
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
@ -1077,7 +1071,7 @@ mod tests {
})
);
model.stream_last_completion_response("<old_text>\nghi\nj");
model.send_last_completion_stream_text_chunk("<old_text>\nghi\nj");
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
@ -1099,8 +1093,8 @@ mod tests {
})
);
model.stream_last_completion_response("kl</old_text>");
model.stream_last_completion_response("<new_text>");
model.send_last_completion_stream_text_chunk("kl</old_text>");
model.send_last_completion_stream_text_chunk("<new_text>");
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
@ -1122,7 +1116,7 @@ mod tests {
})
);
model.stream_last_completion_response("GHI</new_text>");
model.send_last_completion_stream_text_chunk("GHI</new_text>");
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
@ -1367,7 +1361,9 @@ mod tests {
cx.background_spawn(async move {
for chunk in chunks {
executor.simulate_random_delay().await;
model.as_fake().stream_last_completion_response(chunk);
model
.as_fake()
.send_last_completion_stream_text_chunk(chunk);
}
model.as_fake().end_last_completion_stream();
})

View file

@ -1577,7 +1577,7 @@ mod tests {
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
@ -1641,7 +1641,7 @@ mod tests {
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
@ -1720,7 +1720,9 @@ mod tests {
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.send_last_completion_stream_text_chunk(
CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
);
model.end_last_completion_stream();
edit_task.await
@ -1777,7 +1779,9 @@ mod tests {
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.send_last_completion_stream_text_chunk(
CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
);
model.end_last_completion_stream();
edit_task.await

View file

@ -225,7 +225,6 @@ impl Tool for TerminalTool {
env,
..Default::default()
}),
window,
cx,
)
})?

View file

@ -37,7 +37,7 @@ impl Tool for ThinkingTool {
}
fn icon(&self) -> IconName {
IconName::ToolBulb
IconName::ToolThink
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {

View file

@ -134,10 +134,15 @@ impl Settings for AutoUpdateSetting {
type FileContent = Option<AutoUpdateSettingContent>;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let auto_update = [sources.server, sources.release_channel, sources.user]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
let auto_update = [
sources.server,
sources.release_channel,
sources.operating_system,
sources.user,
]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
Ok(Self(auto_update.0))
}

View file

@ -400,7 +400,6 @@ mod linux {
os::unix::net::{SocketAddr, UnixDatagram},
path::{Path, PathBuf},
process::{self, ExitStatus},
sync::LazyLock,
thread,
time::Duration,
};
@ -411,9 +410,6 @@ mod linux {
use crate::{Detect, InstalledApp};
static RELEASE_CHANNEL: LazyLock<String> =
LazyLock::new(|| include_str!("../../zed/RELEASE_CHANNEL").trim().to_string());
struct App(PathBuf);
impl Detect {
@ -444,10 +440,10 @@ mod linux {
fn zed_version_string(&self) -> String {
format!(
"Zed {}{}{} {}",
if *RELEASE_CHANNEL == "stable" {
if *release_channel::RELEASE_CHANNEL_NAME == "stable" {
"".to_string()
} else {
format!("{} ", *RELEASE_CHANNEL)
format!("{} ", *release_channel::RELEASE_CHANNEL_NAME)
},
option_env!("RELEASE_VERSION").unwrap_or_default(),
match option_env!("ZED_COMMIT_SHA") {
@ -459,7 +455,10 @@ mod linux {
}
fn launch(&self, ipc_url: String) -> anyhow::Result<()> {
let sock_path = paths::data_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL));
let sock_path = paths::data_dir().join(format!(
"zed-{}.sock",
*release_channel::RELEASE_CHANNEL_NAME
));
let sock = UnixDatagram::unbound()?;
if sock.connect(&sock_path).is_err() {
self.boot_background(ipc_url)?;

View file

@ -14,7 +14,9 @@ use async_tungstenite::tungstenite::{
};
use clock::SystemClock;
use cloud_api_client::CloudApiClient;
use cloud_api_client::websocket_protocol::MessageToClient;
use credentials_provider::CredentialsProvider;
use feature_flags::FeatureFlagAppExt as _;
use futures::{
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
channel::oneshot, future::BoxFuture,
@ -191,6 +193,8 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
});
}
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
struct GlobalClient(Arc<Client>);
impl Global for GlobalClient {}
@ -204,6 +208,7 @@ pub struct Client {
credentials_provider: ClientCredentialsProvider,
state: RwLock<ClientState>,
handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,
#[allow(clippy::type_complexity)]
#[cfg(any(test, feature = "test-support"))]
@ -553,6 +558,7 @@ impl Client {
credentials_provider: ClientCredentialsProvider::new(cx),
state: Default::default(),
handler_set: Default::default(),
message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
#[cfg(any(test, feature = "test-support"))]
authenticate: Default::default(),
@ -933,23 +939,77 @@ impl Client {
}
}
/// Performs a sign-in and also connects to Collab.
/// Establishes a WebSocket connection with Cloud for receiving updates from the server.
async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
let connect_task = cx.update({
let cloud_client = self.cloud_client.clone();
move |cx| cloud_client.connect(cx)
})??;
let connection = connect_task.await?;
let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
task.detach();
cx.spawn({
let this = self.clone();
async move |cx| {
while let Some(message) = messages.next().await {
if let Some(message) = message.log_err() {
this.handle_message_to_client(message, cx);
}
}
}
})
.detach();
Ok(())
}
/// Performs a sign-in and also (optionally) connects to Collab.
///
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
/// to `sign_in` when we're ready to remove auto-connection to Collab.
/// Only Zed staff automatically connect to Collab.
pub async fn sign_in_with_optional_connect(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<()> {
let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
let mut is_staff_tx = Some(is_staff_tx);
cx.update(|cx| {
cx.on_flags_ready(move |state, _cx| {
if let Some(is_staff_tx) = is_staff_tx.take() {
is_staff_tx.send(state.is_staff).log_err();
}
})
.detach();
})
.log_err();
let credentials = self.sign_in(try_provider, cx).await?;
let connect_result = match self.connect_with_credentials(credentials, cx).await {
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
ConnectionResult::Result(result) => result.context("client auth and connect"),
};
connect_result.log_err();
self.connect_to_cloud(cx).await.log_err();
cx.update(move |cx| {
cx.spawn({
let client = self.clone();
async move |cx| {
let is_staff = is_staff_rx.await?;
if is_staff {
match client.connect_with_credentials(credentials, cx).await {
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
ConnectionResult::Result(result) => {
result.context("client auth and connect")
}
}
} else {
Ok(())
}
}
})
.detach_and_log_err(cx);
})
.log_err();
Ok(())
}
@ -1622,6 +1682,24 @@ impl Client {
}
}
pub fn add_message_to_client_handler(
self: &Arc<Client>,
handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
) {
self.message_to_client_handlers
.lock()
.push(Box::new(handler));
}
fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
cx.update(|cx| {
for handler in self.message_to_client_handlers.lock().iter() {
handler(&message, cx);
}
})
.ok();
}
pub fn telemetry(&self) -> &Arc<Telemetry> {
&self.telemetry
}

View file

@ -1,6 +1,7 @@
use super::{Client, Status, TypedEnvelope, proto};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use cloud_api_client::websocket_protocol::MessageToClient;
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
@ -181,6 +182,12 @@ impl UserStore {
client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info),
client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
];
client.add_message_to_client_handler({
let this = cx.weak_entity();
move |message, cx| Self::handle_message_to_client(this.clone(), message, cx)
});
Self {
users: Default::default(),
by_github_login: Default::default(),
@ -813,6 +820,32 @@ impl UserStore {
cx.emit(Event::PrivateUserInfoUpdated);
}
fn handle_message_to_client(this: WeakEntity<Self>, message: &MessageToClient, cx: &App) {
cx.spawn(async move |cx| {
match message {
MessageToClient::UserUpdated => {
let cloud_client = cx
.update(|cx| {
this.read_with(cx, |this, _cx| {
this.client.upgrade().map(|client| client.cloud_client())
})
})??
.ok_or(anyhow::anyhow!("Failed to get Cloud client"))?;
let response = cloud_client.get_authenticated_user().await?;
cx.update(|cx| {
this.update(cx, |this, cx| {
this.update_authenticated_user(response, cx);
})
})??;
}
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}

View file

@ -15,7 +15,10 @@ path = "src/cloud_api_client.rs"
anyhow.workspace = true
cloud_api_types.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
parking_lot.workspace = true
serde_json.workspace = true
workspace-hack.workspace = true
yawc.workspace = true

View file

@ -1,11 +1,19 @@
mod websocket;
use std::sync::Arc;
use anyhow::{Context, Result, anyhow};
use cloud_api_types::websocket_protocol::{PROTOCOL_VERSION, PROTOCOL_VERSION_HEADER_NAME};
pub use cloud_api_types::*;
use futures::AsyncReadExt as _;
use gpui::{App, Task};
use gpui_tokio::Tokio;
use http_client::http::request;
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode};
use parking_lot::RwLock;
use yawc::WebSocket;
use crate::websocket::Connection;
struct Credentials {
user_id: u32,
@ -78,6 +86,41 @@ impl CloudApiClient {
Ok(serde_json::from_str(&body)?)
}
pub fn connect(&self, cx: &App) -> Result<Task<Result<Connection>>> {
let mut connect_url = self
.http_client
.build_zed_cloud_url("/client/users/connect", &[])?;
connect_url
.set_scheme(match connect_url.scheme() {
"https" => "wss",
"http" => "ws",
scheme => Err(anyhow!("invalid URL scheme: {scheme}"))?,
})
.map_err(|_| anyhow!("failed to set URL scheme"))?;
let credentials = self.credentials.read();
let credentials = credentials.as_ref().context("no credentials provided")?;
let authorization_header = format!("{} {}", credentials.user_id, credentials.access_token);
Ok(cx.spawn(async move |cx| {
let handle = cx
.update(|cx| Tokio::handle(cx))
.ok()
.context("failed to get Tokio handle")?;
let _guard = handle.enter();
let ws = WebSocket::connect(connect_url)
.with_request(
request::Builder::new()
.header("Authorization", authorization_header)
.header(PROTOCOL_VERSION_HEADER_NAME, PROTOCOL_VERSION.to_string()),
)
.await?;
Ok(Connection::new(ws))
}))
}
pub async fn accept_terms_of_service(&self) -> Result<AcceptTermsOfServiceResponse> {
let request = self.build_request(
Request::builder().method(Method::POST).uri(

View file

@ -0,0 +1,73 @@
use std::pin::Pin;
use std::time::Duration;
use anyhow::Result;
use cloud_api_types::websocket_protocol::MessageToClient;
use futures::channel::mpsc::unbounded;
use futures::stream::{SplitSink, SplitStream};
use futures::{FutureExt as _, SinkExt as _, Stream, StreamExt as _, TryStreamExt as _, pin_mut};
use gpui::{App, BackgroundExecutor, Task};
use yawc::WebSocket;
use yawc::frame::{FrameView, OpCode};
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
pub type MessageStream = Pin<Box<dyn Stream<Item = Result<MessageToClient>>>>;
pub struct Connection {
tx: SplitSink<WebSocket, FrameView>,
rx: SplitStream<WebSocket>,
}
impl Connection {
pub fn new(ws: WebSocket) -> Self {
let (tx, rx) = ws.split();
Self { tx, rx }
}
pub fn spawn(self, cx: &App) -> (MessageStream, Task<()>) {
let (mut tx, rx) = (self.tx, self.rx);
let (message_tx, message_rx) = unbounded();
let handle_io = |executor: BackgroundExecutor| async move {
// Send messages on this frequency so the connection isn't closed.
let keepalive_timer = executor.timer(KEEPALIVE_INTERVAL).fuse();
futures::pin_mut!(keepalive_timer);
let rx = rx.fuse();
pin_mut!(rx);
loop {
futures::select_biased! {
_ = keepalive_timer => {
let _ = tx.send(FrameView::ping(Vec::new())).await;
keepalive_timer.set(executor.timer(KEEPALIVE_INTERVAL).fuse());
}
frame = rx.next() => {
let Some(frame) = frame else {
break;
};
match frame.opcode {
OpCode::Binary => {
let message_result = MessageToClient::deserialize(&frame.payload);
message_tx.unbounded_send(message_result).ok();
}
OpCode::Close => {
break;
}
_ => {}
}
}
}
}
};
let task = cx.spawn(async move |cx| handle_io(cx.background_executor().clone()).await);
(message_rx.into_stream().boxed(), task)
}
}

View file

@ -12,7 +12,9 @@ workspace = true
path = "src/cloud_api_types.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
ciborium.workspace = true
cloud_llm_client.workspace = true
serde.workspace = true
workspace-hack.workspace = true

View file

@ -1,4 +1,5 @@
mod timestamp;
pub mod websocket_protocol;
use serde::{Deserialize, Serialize};

View file

@ -0,0 +1,28 @@
use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
/// The version of the Cloud WebSocket protocol.
pub const PROTOCOL_VERSION: u32 = 0;
/// The name of the header used to indicate the protocol version in use.
pub const PROTOCOL_VERSION_HEADER_NAME: &str = "x-zed-protocol-version";
/// A message from Cloud to the Zed client.
#[derive(Debug, Serialize, Deserialize)]
pub enum MessageToClient {
/// The user was updated and should be refreshed.
UserUpdated,
}
impl MessageToClient {
pub fn serialize(&self) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
ciborium::into_writer(self, &mut buffer).context("failed to serialize message")?;
Ok(buffer)
}
pub fn deserialize(data: &[u8]) -> Result<Self> {
ciborium::from_reader(data).context("failed to deserialize message")
}
}

View file

@ -2,5 +2,6 @@ ZED_ENVIRONMENT=production
RUST_LOG=info
INVITE_LINK_PREFIX=https://zed.dev/invites/
AUTO_JOIN_CHANNEL_ID=283
DATABASE_MAX_CONNECTIONS=250
# Set DATABASE_MAX_CONNECTIONS max connections in the `deploy_collab.yml`:
# https://github.com/zed-industries/zed/blob/main/.github/workflows/deploy_collab.yml
LLM_DATABASE_MAX_CONNECTIONS=25

View file

@ -1,485 +1,10 @@
use anyhow::{Context as _, bail};
use chrono::{DateTime, Utc};
use cloud_llm_client::LanguageModelProvider;
use collections::{HashMap, HashSet};
use sea_orm::ActiveValue;
use std::{sync::Arc, time::Duration};
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
use util::{ResultExt, maybe};
use std::sync::Arc;
use stripe::SubscriptionStatus;
use crate::AppState;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
StripeSubscriptionId,
};
use crate::{db::UserId, llm::db::LlmDatabase};
use crate::{
db::{
CreateBillingCustomerParams, CreateBillingSubscriptionParams,
CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingSubscriptionParams, billing_customer,
},
stripe_billing::StripeBilling,
};
/// The amount of time we wait in between each poll of Stripe events.
///
/// This value should strike a balance between:
/// 1. Being short enough that we update quickly when something in Stripe changes
/// 2. Being long enough that we don't eat into our rate limits.
///
/// As a point of reference, the Sequin folks say they have this at **500ms**:
///
/// > We poll the Stripe /events endpoint every 500ms per account
/// >
/// > — https://blog.sequinstream.com/events-not-webhooks/
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
/// The maximum number of events to return per page.
///
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
///
/// > Limit can range between 1 and 100, and the default is 10.
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
/// The number of pages consisting entirely of already-processed events that we
/// will see before we stop retrieving events.
///
/// This is used to prevent over-fetching the Stripe events API for events we've
/// already seen and processed.
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let executor = app.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
.await
.log_err();
executor.sleep(POLL_EVENTS_INTERVAL).await;
}
}
});
}
async fn poll_stripe_events(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &Arc<dyn StripeClient>,
real_stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_events_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
if sync_events_using_cloud {
return Ok(());
}
fn event_type_to_string(event_type: EventType) -> String {
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
// so we need to unquote it.
event_type.to_string().trim_matches('"').to_string()
}
let event_types = [
EventType::CustomerCreated,
EventType::CustomerUpdated,
EventType::CustomerSubscriptionCreated,
EventType::CustomerSubscriptionUpdated,
EventType::CustomerSubscriptionPaused,
EventType::CustomerSubscriptionResumed,
EventType::CustomerSubscriptionDeleted,
]
.into_iter()
.map(event_type_to_string)
.collect::<Vec<_>>();
let mut pages_of_already_processed_events = 0;
let mut unprocessed_events = Vec::new();
log::info!(
"Stripe events: starting retrieval for {}",
event_types.join(", ")
);
let mut params = ListEvents::new();
params.types = Some(event_types.clone());
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
.await?
.paginate(params);
loop {
let processed_event_ids = {
let event_ids = event_pages
.page
.data
.iter()
.map(|event| event.id.as_str())
.collect::<Vec<_>>();
app.db
.get_processed_stripe_events_by_event_ids(&event_ids)
.await?
.into_iter()
.map(|event| event.stripe_event_id)
.collect::<Vec<_>>()
};
let mut processed_events_in_page = 0;
let events_in_page = event_pages.page.data.len();
for event in &event_pages.page.data {
if processed_event_ids.contains(&event.id.to_string()) {
processed_events_in_page += 1;
log::debug!("Stripe events: already processed '{}', skipping", event.id);
} else {
unprocessed_events.push(event.clone());
}
}
if processed_events_in_page == events_in_page {
pages_of_already_processed_events += 1;
}
if event_pages.page.has_more {
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
{
log::info!(
"Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
);
break;
} else {
log::info!("Stripe events: retrieving next page");
event_pages = event_pages.next(&real_stripe_client).await?;
}
} else {
break;
}
}
log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
// Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
for event in unprocessed_events {
let event_id = event.id.clone();
let processed_event_params = CreateProcessedStripeEventParams {
stripe_event_id: event.id.to_string(),
stripe_event_type: event_type_to_string(event.type_),
stripe_event_created_timestamp: event.created,
};
// If the event has happened too far in the past, we don't want to
// process it and risk overwriting other more-recent updates.
//
// 1 day was chosen arbitrarily. This could be made longer or shorter.
let one_day = Duration::from_secs(24 * 60 * 60);
let a_day_ago = Utc::now() - one_day;
if a_day_ago.timestamp() > event.created {
log::info!(
"Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
event_id
);
app.db
.create_processed_stripe_event(&processed_event_params)
.await?;
continue;
}
let process_result = match event.type_ {
EventType::CustomerCreated | EventType::CustomerUpdated => {
handle_customer_event(app, real_stripe_client, event).await
}
EventType::CustomerSubscriptionCreated
| EventType::CustomerSubscriptionUpdated
| EventType::CustomerSubscriptionPaused
| EventType::CustomerSubscriptionResumed
| EventType::CustomerSubscriptionDeleted => {
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
}
_ => Ok(()),
};
if let Some(()) = process_result
.with_context(|| format!("failed to process event {event_id} successfully"))
.log_err()
{
app.db
.create_processed_stripe_event(&processed_event_params)
.await?;
}
}
Ok(())
}
async fn handle_customer_event(
app: &Arc<AppState>,
_stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Customer(customer) = event.data.object else {
bail!("unexpected event payload for {}", event.id);
};
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let Some(email) = customer.email else {
log::info!("Stripe customer has no email: skipping");
return Ok(());
};
let Some(user) = app.db.get_user_by_email(&email).await? else {
log::info!("no user found for email: skipping");
return Ok(());
};
if let Some(existing_customer) = app
.db
.get_billing_customer_by_stripe_customer_id(&customer.id)
.await?
{
app.db
.update_billing_customer(
existing_customer.id,
&UpdateBillingCustomerParams {
// For now we just leave the information as-is, as it is not
// likely to change.
..Default::default()
},
)
.await?;
} else {
app.db
.create_billing_customer(&CreateBillingCustomerParams {
user_id: user.id,
stripe_customer_id: customer.id.to_string(),
})
.await?;
}
Ok(())
}
async fn sync_subscription(
app: &Arc<AppState>,
stripe_client: &Arc<dyn StripeClient>,
subscription: StripeSubscription,
) -> anyhow::Result<billing_customer::Model> {
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
stripe_billing
.determine_subscription_kind(&subscription)
.await
} else {
None
};
let billing_customer =
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
.await?
.context("billing customer not found")?;
if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
if subscription.status == SubscriptionStatus::Trialing {
let current_period_start =
DateTime::from_timestamp(subscription.current_period_start, 0)
.context("No trial subscription period start")?;
app.db
.update_billing_customer(
billing_customer.id,
&UpdateBillingCustomerParams {
trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
..Default::default()
},
)
.await?;
}
}
let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
&& subscription
.cancellation_details
.as_ref()
.and_then(|details| details.reason)
.map_or(false, |reason| {
reason == StripeCancellationDetailsReason::PaymentFailed
});
if was_canceled_due_to_payment_failure {
app.db
.update_billing_customer(
billing_customer.id,
&UpdateBillingCustomerParams {
has_overdue_invoices: ActiveValue::set(true),
..Default::default()
},
)
.await?;
}
if let Some(existing_subscription) = app
.db
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
.await?
{
app.db
.update_billing_subscription(
existing_subscription.id,
&UpdateBillingSubscriptionParams {
billing_customer_id: ActiveValue::set(billing_customer.id),
kind: ActiveValue::set(subscription_kind),
stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
stripe_subscription_status: ActiveValue::set(subscription.status.into()),
stripe_cancel_at: ActiveValue::set(
subscription
.cancel_at
.and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
.map(|time| time.naive_utc()),
),
stripe_cancellation_reason: ActiveValue::set(
subscription
.cancellation_details
.and_then(|details| details.reason)
.map(|reason| reason.into()),
),
stripe_current_period_start: ActiveValue::set(Some(
subscription.current_period_start,
)),
stripe_current_period_end: ActiveValue::set(Some(
subscription.current_period_end,
)),
},
)
.await?;
} else {
if let Some(existing_subscription) = app
.db
.get_active_billing_subscription(billing_customer.user_id)
.await?
{
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
{
let stripe_subscription_id = StripeSubscriptionId(
existing_subscription.stripe_subscription_id.clone().into(),
);
stripe_client
.cancel_subscription(&stripe_subscription_id)
.await?;
} else {
// If the user already has an active billing subscription, ignore the
// event and return an `Ok` to signal that it was processed
// successfully.
//
// There is the possibility that this could cause us to not create a
// subscription in the following scenario:
//
// 1. User has an active subscription A
// 2. User cancels subscription A
// 3. User creates a new subscription B
// 4. We process the new subscription B before the cancellation of subscription A
// 5. User ends up with no subscriptions
//
// In theory this situation shouldn't arise as we try to process the events in the order they occur.
log::info!(
"user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
user_id = billing_customer.user_id,
subscription_id = subscription.id
);
return Ok(billing_customer);
}
}
app.db
.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: billing_customer.id,
kind: subscription_kind,
stripe_subscription_id: subscription.id.to_string(),
stripe_subscription_status: subscription.status.into(),
stripe_cancellation_reason: subscription
.cancellation_details
.and_then(|details| details.reason)
.map(|reason| reason.into()),
stripe_current_period_start: Some(subscription.current_period_start),
stripe_current_period_end: Some(subscription.current_period_end),
})
.await?;
}
if let Some(stripe_billing) = app.stripe_billing.as_ref() {
if subscription.status == SubscriptionStatus::Canceled
|| subscription.status == SubscriptionStatus::Paused
{
let already_has_active_billing_subscription = app
.db
.has_active_billing_subscription(billing_customer.user_id)
.await?;
if !already_has_active_billing_subscription {
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
.await?;
}
}
}
Ok(billing_customer)
}
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &Arc<dyn StripeClient>,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Subscription(subscription) = event.data.object else {
bail!("unexpected event payload for {}", event.id);
};
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
// When the user's subscription changes, push down any changes to their plan.
rpc_server
.update_plan_for_user_legacy(billing_customer.user_id)
.await
.trace_err();
// When the user's subscription changes, we want to refresh their LLM tokens
// to either grant/revoke access.
rpc_server
.refresh_llm_tokens_for_user(billing_customer.user_id)
.await;
Ok(())
}
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::{CreateBillingCustomerParams, billing_customer};
use crate::stripe_client::{StripeClient, StripeCustomerId};
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
fn from(value: SubscriptionStatus) -> Self {
@ -496,16 +21,6 @@ impl From<SubscriptionStatus> for StripeSubscriptionStatus {
}
}
impl From<CancellationDetailsReason> for StripeCancellationReason {
fn from(value: CancellationDetailsReason) -> Self {
match value {
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}
/// Finds or creates a billing customer using the provided customer.
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
@ -542,194 +57,3 @@ pub async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::warn!("failed to retrieve Stripe billing object");
return;
};
let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve LLM database");
return;
};
let executor = app.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
.await
.context("failed to sync LLM request usage to Stripe")
.trace_err();
executor
.sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
.await;
}
}
});
}
async fn sync_model_request_usage_with_stripe(
app: &Arc<AppState>,
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_model_request_usage_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
if sync_model_request_usage_using_cloud {
return Ok(());
}
log::info!("Stripe usage sync: Starting");
let started_at = Utc::now();
let staff_users = app.db.get_staff_users().await?;
let staff_user_ids = staff_users
.iter()
.map(|user| user.id)
.collect::<HashSet<UserId>>();
let usage_meters = llm_db
.get_current_subscription_usage_meters(Utc::now())
.await?;
let mut usage_meters_by_user_id =
HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
for (usage_meter, usage) in usage_meters {
let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
meters.push(usage_meter);
}
log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
let get_zed_pro_subscriptions_started_at = Utc::now();
let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
log::info!(
"Stripe usage sync: Retrieved {} Zed Pro subscriptions in {}",
billing_subscriptions.len(),
Utc::now() - get_zed_pro_subscriptions_started_at
);
let claude_sonnet_4 = stripe_billing
.find_price_by_lookup_key("claude-sonnet-4-requests")
.await?;
let claude_sonnet_4_max = stripe_billing
.find_price_by_lookup_key("claude-sonnet-4-requests-max")
.await?;
let claude_opus_4 = stripe_billing
.find_price_by_lookup_key("claude-opus-4-requests")
.await?;
let claude_opus_4_max = stripe_billing
.find_price_by_lookup_key("claude-opus-4-requests-max")
.await?;
let claude_3_5_sonnet = stripe_billing
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
.await?;
let claude_3_7_sonnet = stripe_billing
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
.await?;
let claude_3_7_sonnet_max = stripe_billing
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
let model_mode_combinations = [
("claude-opus-4", CompletionMode::Max),
("claude-opus-4", CompletionMode::Normal),
("claude-sonnet-4", CompletionMode::Max),
("claude-sonnet-4", CompletionMode::Normal),
("claude-3-7-sonnet", CompletionMode::Max),
("claude-3-7-sonnet", CompletionMode::Normal),
("claude-3-5-sonnet", CompletionMode::Normal),
];
let billing_subscription_count = billing_subscriptions.len();
log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
maybe!(async {
if staff_user_ids.contains(&user_id) {
return anyhow::Ok(());
}
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription_id =
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
let usage_meters = usage_meters_by_user_id.get(&user_id);
for (model, mode) in &model_mode_combinations {
let Ok(model) =
llm_db.model(LanguageModelProvider::Anthropic, model)
else {
log::warn!("Failed to load model for user {user_id}: {model}");
continue;
};
let (price, meter_event_name) = match model.name.as_str() {
"claude-opus-4" => match mode {
CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
},
"claude-sonnet-4" => match mode {
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
CompletionMode::Max => {
(&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
}
},
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => match mode {
CompletionMode::Normal => {
(&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
}
CompletionMode::Max => {
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
}
},
model_name => {
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
}
};
let model_requests = usage_meters
.and_then(|usage_meters| {
usage_meters
.iter()
.find(|meter| meter.model_id == model.id && meter.mode == *mode)
})
.map(|usage_meter| usage_meter.requests)
.unwrap_or(0);
if model_requests > 0 {
stripe_billing
.subscribe_to_price(&stripe_subscription_id, price)
.await?;
}
stripe_billing
.bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
.await
.with_context(|| {
format!(
"Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
)
})?;
}
Ok(())
})
.await
.log_err();
}
log::info!(
"Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
Utc::now() - started_at
);
Ok(())
}

View file

@ -85,19 +85,6 @@ impl Database {
.await
}
/// Returns the billing subscription with the specified ID.
pub async fn get_billing_subscription_by_id(
&self,
id: BillingSubscriptionId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?)
})
.await
}
/// Returns the billing subscription with the specified Stripe subscription ID.
pub async fn get_billing_subscription_by_stripe_subscription_id(
&self,
@ -143,119 +130,6 @@ impl Database {
.await
}
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
/// If you're wanting to check if a use has an active billing subscription,
/// use `get_active_billing_subscriptions` instead.
pub async fn get_billing_subscriptions(
&self,
user_id: UserId,
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
.order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;
Ok(subscriptions)
})
.await
}
pub async fn get_active_billing_subscriptions(
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(billing_customer::Column::UserId.is_in(user_ids))
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.is_null())
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
}
Ok(subscriptions)
}
})
.await
}
pub async fn get_active_zed_pro_billing_subscriptions(
&self,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
}
Ok(subscriptions)
})
.await
}
pub async fn get_active_zed_pro_billing_subscriptions_for_users(
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(billing_customer::Column::UserId.is_in(user_ids))
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
}
Ok(subscriptions)
}
})
.await
}
/// Returns whether the user has an active billing subscription.
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)

View file

@ -699,7 +699,10 @@ impl Database {
language_server::Column::ProjectId,
language_server::Column::Id,
])
.update_column(language_server::Column::Name)
.update_columns([
language_server::Column::Name,
language_server::Column::Capabilities,
])
.to_owned(),
)
.exec(&*tx)

View file

@ -1,4 +1,3 @@
mod billing_subscription_tests;
mod buffer_tests;
mod channel_tests;
mod contributor_tests;

View file

@ -1,96 +0,0 @@
use std::sync::Arc;
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::tests::new_test_user;
use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
use crate::test_both_dbs;
use super::Database;
test_both_dbs!(
test_get_active_billing_subscriptions,
test_get_active_billing_subscriptions_postgres,
test_get_active_billing_subscriptions_sqlite
);
async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
// A user with no subscription has no active billing subscriptions.
{
let user_id = new_test_user(db, "no-subscription-user@example.com").await;
let subscription_count = db
.count_active_billing_subscriptions(user_id)
.await
.unwrap();
assert_eq!(subscription_count, 0);
}
// A user with an active subscription has one active billing subscription.
{
let user_id = new_test_user(db, "active-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_active_user".into(),
})
.await
.unwrap();
assert_eq!(customer.stripe_customer_id, "cus_active_user".to_string());
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
kind: None,
stripe_subscription_id: "sub_active_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Active,
stripe_cancellation_reason: None,
stripe_current_period_start: None,
stripe_current_period_end: None,
})
.await
.unwrap();
let subscriptions = db.get_billing_subscriptions(user_id).await.unwrap();
assert_eq!(subscriptions.len(), 1);
let subscription = &subscriptions[0];
assert_eq!(
subscription.stripe_subscription_id,
"sub_active_user".to_string()
);
assert_eq!(
subscription.stripe_subscription_status,
StripeSubscriptionStatus::Active
);
}
// A user with a past-due subscription has no active billing subscriptions.
{
let user_id = new_test_user(db, "past-due-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_past_due_user".into(),
})
.await
.unwrap();
assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
kind: None,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
stripe_cancellation_reason: None,
stripe_current_period_start: None,
stripe_current_period_end: None,
})
.await
.unwrap();
let subscription_count = db
.count_active_billing_subscriptions(user_id)
.await
.unwrap();
assert_eq!(subscription_count, 0);
}
}

View file

@ -1,6 +1,5 @@
use super::*;
pub mod providers;
pub mod subscription_usage_meters;
pub mod subscription_usages;
pub mod usages;

View file

@ -1,72 +0,0 @@
use crate::db::UserId;
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
use super::*;
impl LlmDatabase {
/// Returns all current subscription usage meters as of the given timestamp.
pub async fn get_current_subscription_usage_meters(
&self,
now: DateTimeUtc,
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
let now = convert_chrono_to_time(now)?;
self.transaction(|tx| async move {
let result = subscription_usage_meter::Entity::find()
.inner_join(subscription_usage::Entity)
.filter(
subscription_usage::Column::PeriodStartAt
.lte(now)
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
)
.select_also(subscription_usage::Entity)
.all(&*tx)
.await?;
let result = result
.into_iter()
.filter_map(|(meter, usage)| {
let usage = usage?;
Some((meter, usage))
})
.collect();
Ok(result)
})
.await
}
/// Returns all current subscription usage meters for the given user as of the given timestamp.
pub async fn get_current_subscription_usage_meters_for_user(
&self,
user_id: UserId,
now: DateTimeUtc,
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
let now = convert_chrono_to_time(now)?;
self.transaction(|tx| async move {
let result = subscription_usage_meter::Entity::find()
.inner_join(subscription_usage::Entity)
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(
subscription_usage::Column::PeriodStartAt
.lte(now)
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
)
.select_also(subscription_usage::Entity)
.all(&*tx)
.await?;
let result = result
.into_iter()
.filter_map(|(meter, usage)| {
let usage = usage?;
Some((meter, usage))
})
.collect();
Ok(result)
})
.await
}
}

View file

@ -1,28 +1,7 @@
use time::PrimitiveDateTime;
use crate::db::UserId;
use super::*;
pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
use chrono::{Datelike as _, Timelike as _};
let date = time::Date::from_calendar_date(
datetime.year(),
time::Month::try_from(datetime.month() as u8).unwrap(),
datetime.day() as u8,
)?;
let time = time::Time::from_hms_nano(
datetime.hour() as u8,
datetime.minute() as u8,
datetime.second() as u8,
datetime.nanosecond(),
)?;
Ok(PrimitiveDateTime::new(date, time))
}
impl LlmDatabase {
pub async fn get_subscription_usage_for_period(
&self,

View file

@ -7,8 +7,8 @@ use axum::{
routing::get,
};
use collab::ServiceMode;
use collab::api::CloudflareIpCountryHeader;
use collab::api::billing::sync_llm_request_usage_with_stripe_periodically;
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
@ -16,7 +16,6 @@ use collab::{
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor, rpc::ResultExt,
};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
use std::{
env::args,
@ -31,7 +30,7 @@ use tower_http::trace::TraceLayer;
use tracing_subscriber::{
Layer, filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt,
};
use util::{ResultExt as _, maybe};
use util::ResultExt as _;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@ -120,8 +119,6 @@ async fn main() -> Result<()> {
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
app = app
.merge(collab::api::routes(rpc_server.clone()))
.merge(collab::rpc::routes(rpc_server.clone()));
@ -133,29 +130,6 @@ async fn main() -> Result<()> {
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());
let llm_db = maybe!(async {
let database_url = state
.config
.llm_database_url
.as_ref()
.context("missing LLM_DATABASE_URL")?;
let max_connections = state
.config
.llm_database_max_connections
.context("missing LLM_DATABASE_MAX_CONNECTIONS")?;
let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
LlmDatabase::new(db_options, state.executor.clone()).await
})
.await
.trace_err();
if let Some(mut llm_db) = llm_db {
llm_db.initialize().await?;
sync_llm_request_usage_with_stripe_periodically(state.clone());
}
app = app
.merge(collab::api::events::router())
.merge(collab::api::extensions::router())

View file

@ -41,9 +41,11 @@ use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use futures::TryFutureExt as _;
use reqwest_client::ReqwestClient;
use rpc::proto::{MultiLspQuery, split_repository_update};
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use tracing::Span;
use futures::{
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512;
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
type MessageHandler =
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
pub struct ConnectionGuard;
@ -163,6 +170,42 @@ impl Principal {
}
}
#[derive(Clone)]
struct MessageContext {
session: Session,
span: tracing::Span,
}
impl Deref for MessageContext {
type Target = Session;
fn deref(&self) -> &Self::Target {
&self.session
}
}
impl MessageContext {
pub fn forward_request<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = anyhow::Result<T::Response>> {
let request_start_time = Instant::now();
let span = self.span.clone();
tracing::info!("start forwarding request");
self.peer
.forward_request(self.connection_id, receiver_id, request)
.inspect(move |_| {
span.record(
HOST_WAITING_MS,
request_start_time.elapsed().as_micros() as f64 / 1000.0,
);
})
.inspect_err(|_| tracing::error!("error forwarding request"))
.inspect_ok(|_| tracing::info!("finished forwarding request"))
}
}
#[derive(Clone)]
struct Session {
principal: Principal,
@ -340,9 +383,6 @@ impl Server {
.add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
.add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
.add_request_handler(
forward_read_only_project_request::<proto::LanguageServerIdForName>,
)
.add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
.add_request_handler(
forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
@ -649,40 +689,37 @@ impl Server {
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage,
{
let prev_handler = self.handlers.insert(
TypeId::of::<M>(),
Box::new(move |envelope, session| {
Box::new(move |envelope, session, span| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
let received_at = envelope.received_at;
tracing::info!("message received");
let start_time = Instant::now();
let future = (handler)(*envelope, session);
let future = (handler)(
*envelope,
MessageContext {
session,
span: span.clone(),
},
);
async move {
let result = future.await;
let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
let queue_duration_ms = total_duration_ms - processing_duration_ms;
span.record(TOTAL_DURATION_MS, total_duration_ms);
span.record(PROCESSING_DURATION_MS, processing_duration_ms);
span.record(QUEUE_DURATION_MS, queue_duration_ms);
match result {
Err(error) => {
tracing::error!(
?error,
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"error handling message"
)
tracing::error!(?error, "error handling message")
}
Ok(()) => tracing::info!(
total_duration_ms,
processing_duration_ms,
queue_duration_ms,
"finished handling message"
),
Ok(()) => tracing::info!("finished handling message"),
}
}
.boxed()
@ -696,7 +733,7 @@ impl Server {
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, Session) -> Fut,
F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
Fut: 'static + Send + Future<Output = Result<()>>,
M: EnvelopedMessage,
{
@ -706,7 +743,7 @@ impl Server {
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
@ -749,6 +786,7 @@ impl Server {
address: String,
principal: Principal,
zed_version: ZedVersion,
release_channel: Option<String>,
user_agent: Option<String>,
geoip_country_code: Option<String>,
system_id: Option<String>,
@ -763,12 +801,16 @@ impl Server {
login=field::Empty,
impersonator=field::Empty,
user_agent=field::Empty,
geoip_country_code=field::Empty
geoip_country_code=field::Empty,
release_channel=field::Empty,
);
principal.update_span(&span);
if let Some(user_agent) = user_agent {
span.record("user_agent", user_agent);
}
if let Some(release_channel) = release_channel {
span.record("release_channel", release_channel);
}
if let Some(country_code) = geoip_country_code.as_ref() {
span.record("geoip_country_code", country_code);
@ -887,12 +929,17 @@ impl Server {
login=field::Empty,
impersonator=field::Empty,
multi_lsp_query_request=field::Empty,
release_channel=field::Empty,
{ TOTAL_DURATION_MS }=field::Empty,
{ PROCESSING_DURATION_MS }=field::Empty,
{ QUEUE_DURATION_MS }=field::Empty,
{ HOST_WAITING_MS }=field::Empty
);
principal.update_span(&span);
let span_enter = span.enter();
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
let is_background = message.is_background();
let handle_message = (handler)(message, session.clone());
let handle_message = (handler)(message, session.clone(), span.clone());
drop(span_enter);
let handle_message = async move {
@ -1184,6 +1231,35 @@ impl Header for AppVersionHeader {
}
}
#[derive(Debug)]
pub struct ReleaseChannelHeader(String);
impl Header for ReleaseChannelHeader {
fn name() -> &'static HeaderName {
static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
}
fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
Ok(Self(
values
.next()
.ok_or_else(axum::headers::Error::invalid)?
.to_str()
.map_err(|_| axum::headers::Error::invalid())?
.to_owned(),
))
}
fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
values.extend([self.0.parse().unwrap()]);
}
}
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
Router::new()
.route("/rpc", get(handle_websocket_request))
@ -1199,6 +1275,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
pub async fn handle_websocket_request(
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
app_version_header: Option<TypedHeader<AppVersionHeader>>,
release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
@ -1223,6 +1300,8 @@ pub async fn handle_websocket_request(
.into_response();
};
let release_channel = release_channel_header.map(|header| header.0.0);
if !version.can_collaborate() {
return (
StatusCode::UPGRADE_REQUIRED,
@ -1258,6 +1337,7 @@ pub async fn handle_websocket_request(
socket_address,
principal,
version,
release_channel,
user_agent.map(|header| header.to_string()),
country_code_header.map(|header| header.to_string()),
system_id_header.map(|header| header.to_string()),
@ -1351,7 +1431,11 @@ async fn connection_lost(
}
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
async fn ping(
_: proto::Ping,
response: Response<proto::Ping>,
_session: MessageContext,
) -> Result<()> {
response.send(proto::Ack {})?;
Ok(())
}
@ -1360,7 +1444,7 @@ async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session
async fn create_room(
_request: proto::CreateRoom,
response: Response<proto::CreateRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
let livekit_room = nanoid::nanoid!(30);
@ -1400,7 +1484,7 @@ async fn create_room(
async fn join_room(
request: proto::JoinRoom,
response: Response<proto::JoinRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.id);
@ -1467,7 +1551,7 @@ async fn join_room(
async fn rejoin_room(
request: proto::RejoinRoom,
response: Response<proto::RejoinRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room;
let channel;
@ -1595,15 +1679,15 @@ fn notify_rejoined_projects(
}
// Stream this worktree's diagnostics.
for summary in worktree.diagnostic_summaries {
session.peer.send(
session.connection_id,
proto::UpdateDiagnosticSummary {
project_id: project.id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
},
)?;
let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
if let Some(summary) = worktree_diagnostics.next() {
let message = proto::UpdateDiagnosticSummary {
project_id: project.id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
more_summaries: worktree_diagnostics.collect(),
};
session.peer.send(session.connection_id, message)?;
}
for settings_file in worktree.settings_files {
@ -1644,7 +1728,7 @@ fn notify_rejoined_projects(
async fn leave_room(
_: proto::LeaveRoom,
response: Response<proto::LeaveRoom>,
session: Session,
session: MessageContext,
) -> Result<()> {
leave_room_for_session(&session, session.connection_id).await?;
response.send(proto::Ack {})?;
@ -1655,7 +1739,7 @@ async fn leave_room(
async fn set_room_participant_role(
request: proto::SetRoomParticipantRole,
response: Response<proto::SetRoomParticipantRole>,
session: Session,
session: MessageContext,
) -> Result<()> {
let user_id = UserId::from_proto(request.user_id);
let role = ChannelRole::from(request.role());
@ -1703,7 +1787,7 @@ async fn set_room_participant_role(
async fn call(
request: proto::Call,
response: Response<proto::Call>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let calling_user_id = session.user_id();
@ -1772,7 +1856,7 @@ async fn call(
async fn cancel_call(
request: proto::CancelCall,
response: Response<proto::CancelCall>,
session: Session,
session: MessageContext,
) -> Result<()> {
let called_user_id = UserId::from_proto(request.called_user_id);
let room_id = RoomId::from_proto(request.room_id);
@ -1807,7 +1891,7 @@ async fn cancel_call(
}
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(message.room_id);
{
let room = session
@ -1842,7 +1926,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<(
async fn update_participant_location(
request: proto::UpdateParticipantLocation,
response: Response<proto::UpdateParticipantLocation>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let location = request.location.context("invalid location")?;
@ -1861,7 +1945,7 @@ async fn update_participant_location(
async fn share_project(
request: proto::ShareProject,
response: Response<proto::ShareProject>,
session: Session,
session: MessageContext,
) -> Result<()> {
let (project_id, room) = &*session
.db()
@ -1882,7 +1966,7 @@ async fn share_project(
}
/// Unshare a project from the room.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id);
unshare_project_internal(project_id, session.connection_id, &session).await
}
@ -1929,7 +2013,7 @@ async fn unshare_project_internal(
async fn join_project(
request: proto::JoinProject,
response: Response<proto::JoinProject>,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
@ -2025,15 +2109,15 @@ async fn join_project(
}
// Stream this worktree's diagnostics.
for summary in worktree.diagnostic_summaries {
session.peer.send(
session.connection_id,
proto::UpdateDiagnosticSummary {
project_id: project_id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
},
)?;
let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
if let Some(summary) = worktree_diagnostics.next() {
let message = proto::UpdateDiagnosticSummary {
project_id: project.id.to_proto(),
worktree_id: worktree.id,
summary: Some(summary),
more_summaries: worktree_diagnostics.collect(),
};
session.peer.send(session.connection_id, message)?;
}
for settings_file in worktree.settings_files {
@ -2076,7 +2160,7 @@ async fn join_project(
}
/// Leave someone elses shared project.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
let sender_id = session.connection_id;
let project_id = ProjectId::from_proto(request.project_id);
let db = session.db().await;
@ -2099,7 +2183,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
async fn update_project(
request: proto::UpdateProject,
response: Response<proto::UpdateProject>,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let (room, guest_connection_ids) = &*session
@ -2128,7 +2212,7 @@ async fn update_project(
async fn update_worktree(
request: proto::UpdateWorktree,
response: Response<proto::UpdateWorktree>,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@ -2152,7 +2236,7 @@ async fn update_worktree(
async fn update_repository(
request: proto::UpdateRepository,
response: Response<proto::UpdateRepository>,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@ -2176,7 +2260,7 @@ async fn update_repository(
async fn remove_repository(
request: proto::RemoveRepository,
response: Response<proto::RemoveRepository>,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@ -2200,7 +2284,7 @@ async fn remove_repository(
/// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary(
message: proto::UpdateDiagnosticSummary,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@ -2224,7 +2308,7 @@ async fn update_diagnostic_summary(
/// Updates other participants with changes to the worktree settings
async fn update_worktree_settings(
message: proto::UpdateWorktreeSettings,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@ -2248,7 +2332,7 @@ async fn update_worktree_settings(
/// Notify other participants that a language server has started.
async fn start_language_server(
request: proto::StartLanguageServer,
session: Session,
session: MessageContext,
) -> Result<()> {
let guest_connection_ids = session
.db()
@ -2271,7 +2355,7 @@ async fn start_language_server(
/// Notify other participants that a language server has changed.
async fn update_language_server(
request: proto::UpdateLanguageServer,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let db = session.db().await;
@ -2304,7 +2388,7 @@ async fn update_language_server(
async fn forward_read_only_project_request<T>(
request: T,
response: Response<T>,
session: Session,
session: MessageContext,
) -> Result<()>
where
T: EntityMessage + RequestMessage,
@ -2315,10 +2399,7 @@ where
.await
.host_for_read_only_project_request(project_id, session.connection_id)
.await?;
let payload = session
.peer
.forward_request(session.connection_id, host_connection_id, request)
.await?;
let payload = session.forward_request(host_connection_id, request).await?;
response.send(payload)?;
Ok(())
}
@ -2328,7 +2409,7 @@ where
async fn forward_mutating_project_request<T>(
request: T,
response: Response<T>,
session: Session,
session: MessageContext,
) -> Result<()>
where
T: EntityMessage + RequestMessage,
@ -2340,10 +2421,7 @@ where
.await
.host_for_mutating_project_request(project_id, session.connection_id)
.await?;
let payload = session
.peer
.forward_request(session.connection_id, host_connection_id, request)
.await?;
let payload = session.forward_request(host_connection_id, request).await?;
response.send(payload)?;
Ok(())
}
@ -2351,7 +2429,7 @@ where
async fn multi_lsp_query(
request: MultiLspQuery,
response: Response<MultiLspQuery>,
session: Session,
session: MessageContext,
) -> Result<()> {
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
tracing::info!("multi_lsp_query message received");
@ -2361,7 +2439,7 @@ async fn multi_lsp_query(
/// Notify other participants that a new buffer has been created
async fn create_buffer_for_peer(
request: proto::CreateBufferForPeer,
session: Session,
session: MessageContext,
) -> Result<()> {
session
.db()
@ -2383,7 +2461,7 @@ async fn create_buffer_for_peer(
async fn update_buffer(
request: proto::UpdateBuffer,
response: Response<proto::UpdateBuffer>,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let mut capability = Capability::ReadOnly;
@ -2418,17 +2496,14 @@ async fn update_buffer(
};
if host != session.connection_id {
session
.peer
.forward_request(session.connection_id, host, request.clone())
.await?;
session.forward_request(host, request.clone()).await?;
}
response.send(proto::Ack {})?;
Ok(())
}
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id);
let operation = message.operation.as_ref().context("invalid operation")?;
@ -2473,7 +2548,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu
/// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
request: T,
session: Session,
session: MessageContext,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.remote_entity_id());
let project_connection_ids = session
@ -2498,7 +2573,7 @@ async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProj
async fn follow(
request: proto::Follow,
response: Response<proto::Follow>,
session: Session,
session: MessageContext,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto);
@ -2511,10 +2586,7 @@ async fn follow(
.check_room_participants(room_id, leader_id, session.connection_id)
.await?;
let response_payload = session
.peer
.forward_request(session.connection_id, leader_id, request)
.await?;
let response_payload = session.forward_request(leader_id, request).await?;
response.send(response_payload)?;
if let Some(project_id) = project_id {
@ -2530,7 +2602,7 @@ async fn follow(
}
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto);
let leader_id = request.leader_id.context("invalid leader id")?.into();
@ -2559,7 +2631,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
}
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let database = session.db.lock().await;
@ -2594,7 +2666,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) ->
async fn get_users(
request: proto::GetUsers,
response: Response<proto::GetUsers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let user_ids = request
.user_ids
@ -2622,7 +2694,7 @@ async fn get_users(
async fn fuzzy_search_users(
request: proto::FuzzySearchUsers,
response: Response<proto::FuzzySearchUsers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let query = request.query;
let users = match query.len() {
@ -2654,7 +2726,7 @@ async fn fuzzy_search_users(
async fn request_contact(
request: proto::RequestContact,
response: Response<proto::RequestContact>,
session: Session,
session: MessageContext,
) -> Result<()> {
let requester_id = session.user_id();
let responder_id = UserId::from_proto(request.responder_id);
@ -2701,7 +2773,7 @@ async fn request_contact(
async fn respond_to_contact_request(
request: proto::RespondToContactRequest,
response: Response<proto::RespondToContactRequest>,
session: Session,
session: MessageContext,
) -> Result<()> {
let responder_id = session.user_id();
let requester_id = UserId::from_proto(request.requester_id);
@ -2759,7 +2831,7 @@ async fn respond_to_contact_request(
async fn remove_contact(
request: proto::RemoveContact,
response: Response<proto::RemoveContact>,
session: Session,
session: MessageContext,
) -> Result<()> {
let requester_id = session.user_id();
let responder_id = UserId::from_proto(request.user_id);
@ -3018,7 +3090,10 @@ async fn update_user_plan(session: &Session) -> Result<()> {
Ok(())
}
async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
async fn subscribe_to_channels(
_: proto::SubscribeToChannels,
session: MessageContext,
) -> Result<()> {
subscribe_user_to_channels(session.user_id(), &session).await?;
Ok(())
}
@ -3044,7 +3119,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul
async fn create_channel(
request: proto::CreateChannel,
response: Response<proto::CreateChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@ -3099,7 +3174,7 @@ async fn create_channel(
async fn delete_channel(
request: proto::DeleteChannel,
response: Response<proto::DeleteChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@ -3127,7 +3202,7 @@ async fn delete_channel(
async fn invite_channel_member(
request: proto::InviteChannelMember,
response: Response<proto::InviteChannelMember>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3164,7 +3239,7 @@ async fn invite_channel_member(
async fn remove_channel_member(
request: proto::RemoveChannelMember,
response: Response<proto::RemoveChannelMember>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3208,7 +3283,7 @@ async fn remove_channel_member(
async fn set_channel_visibility(
request: proto::SetChannelVisibility,
response: Response<proto::SetChannelVisibility>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3253,7 +3328,7 @@ async fn set_channel_visibility(
async fn set_channel_member_role(
request: proto::SetChannelMemberRole,
response: Response<proto::SetChannelMemberRole>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3301,7 +3376,7 @@ async fn set_channel_member_role(
async fn rename_channel(
request: proto::RenameChannel,
response: Response<proto::RenameChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3333,7 +3408,7 @@ async fn rename_channel(
async fn move_channel(
request: proto::MoveChannel,
response: Response<proto::MoveChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let to = ChannelId::from_proto(request.to);
@ -3375,7 +3450,7 @@ async fn move_channel(
async fn reorder_channel(
request: proto::ReorderChannel,
response: Response<proto::ReorderChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let direction = request.direction();
@ -3421,7 +3496,7 @@ async fn reorder_channel(
async fn get_channel_members(
request: proto::GetChannelMembers,
response: Response<proto::GetChannelMembers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3441,7 +3516,7 @@ async fn get_channel_members(
async fn respond_to_channel_invite(
request: proto::RespondToChannelInvite,
response: Response<proto::RespondToChannelInvite>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3482,7 +3557,7 @@ async fn respond_to_channel_invite(
async fn join_channel(
request: proto::JoinChannel,
response: Response<proto::JoinChannel>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
join_channel_internal(channel_id, Box::new(response), session).await
@ -3505,7 +3580,7 @@ impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
async fn join_channel_internal(
channel_id: ChannelId,
response: Box<impl JoinChannelInternalResponse>,
session: Session,
session: MessageContext,
) -> Result<()> {
let joined_room = {
let mut db = session.db().await;
@ -3600,7 +3675,7 @@ async fn join_channel_internal(
async fn join_channel_buffer(
request: proto::JoinChannelBuffer,
response: Response<proto::JoinChannelBuffer>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3631,7 +3706,7 @@ async fn join_channel_buffer(
/// Edit the channel notes
async fn update_channel_buffer(
request: proto::UpdateChannelBuffer,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3683,7 +3758,7 @@ async fn update_channel_buffer(
async fn rejoin_channel_buffers(
request: proto::RejoinChannelBuffers,
response: Response<proto::RejoinChannelBuffers>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let buffers = db
@ -3718,7 +3793,7 @@ async fn rejoin_channel_buffers(
async fn leave_channel_buffer(
request: proto::LeaveChannelBuffer,
response: Response<proto::LeaveChannelBuffer>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
@ -3780,7 +3855,7 @@ fn send_notifications(
async fn send_channel_message(
request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>,
session: Session,
session: MessageContext,
) -> Result<()> {
// Validate the message body.
let body = request.body.trim().to_string();
@ -3873,7 +3948,7 @@ async fn send_channel_message(
async fn remove_channel_message(
request: proto::RemoveChannelMessage,
response: Response<proto::RemoveChannelMessage>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
@ -3908,7 +3983,7 @@ async fn remove_channel_message(
async fn update_channel_message(
request: proto::UpdateChannelMessage,
response: Response<proto::UpdateChannelMessage>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
@ -3992,7 +4067,7 @@ async fn update_channel_message(
/// Mark a channel message as read
async fn acknowledge_channel_message(
request: proto::AckChannelMessage,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
@ -4012,7 +4087,7 @@ async fn acknowledge_channel_message(
/// Mark a buffer version as synced
async fn acknowledge_buffer_version(
request: proto::AckBufferOperation,
session: Session,
session: MessageContext,
) -> Result<()> {
let buffer_id = BufferId::from_proto(request.buffer_id);
session
@ -4032,7 +4107,7 @@ async fn acknowledge_buffer_version(
async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey,
response: Response<proto::GetSupermavenApiKey>,
session: Session,
session: MessageContext,
) -> Result<()> {
let user_id: String = session.user_id().to_string();
if !session.is_staff() {
@ -4061,7 +4136,7 @@ async fn get_supermaven_api_key(
async fn join_channel_chat(
request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
@ -4079,7 +4154,10 @@ async fn join_channel_chat(
}
/// Stop receiving chat updates for a channel
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
async fn leave_channel_chat(
request: proto::LeaveChannelChat,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
session
.db()
@ -4093,7 +4171,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session)
async fn get_channel_messages(
request: proto::GetChannelMessages,
response: Response<proto::GetChannelMessages>,
session: Session,
session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let messages = session
@ -4117,7 +4195,7 @@ async fn get_channel_messages(
async fn get_channel_messages_by_id(
request: proto::GetChannelMessagesById,
response: Response<proto::GetChannelMessagesById>,
session: Session,
session: MessageContext,
) -> Result<()> {
let message_ids = request
.message_ids
@ -4140,7 +4218,7 @@ async fn get_channel_messages_by_id(
async fn get_notifications(
request: proto::GetNotifications,
response: Response<proto::GetNotifications>,
session: Session,
session: MessageContext,
) -> Result<()> {
let notifications = session
.db()
@ -4162,7 +4240,7 @@ async fn get_notifications(
async fn mark_notification_as_read(
request: proto::MarkNotificationRead,
response: Response<proto::MarkNotificationRead>,
session: Session,
session: MessageContext,
) -> Result<()> {
let database = &session.db().await;
let notifications = database
@ -4184,7 +4262,7 @@ async fn mark_notification_as_read(
async fn get_private_user_info(
_request: proto::GetPrivateUserInfo,
response: Response<proto::GetPrivateUserInfo>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@ -4208,7 +4286,7 @@ async fn get_private_user_info(
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
@ -4232,7 +4310,7 @@ async fn accept_terms_of_service(
async fn get_llm_api_token(
_request: proto::GetLlmToken,
response: Response<proto::GetLlmToken>,
session: Session,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;

View file

@ -1,21 +1,15 @@
use std::sync::Arc;
use anyhow::anyhow;
use chrono::Utc;
use collections::HashMap;
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::stripe_client::{
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
StripeSubscription,
};
pub struct StripeBilling {
@ -94,30 +88,6 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
pub async fn determine_subscription_kind(
&self,
subscription: &StripeSubscription,
) -> Option<SubscriptionKind> {
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
subscription.items.iter().find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_pro_price_id {
Some(if subscription.status == SubscriptionStatus::Trialing {
SubscriptionKind::ZedProTrial
} else {
SubscriptionKind::ZedPro
})
} else if price.id == zed_free_price_id {
Some(SubscriptionKind::ZedFree)
} else {
None
}
})
}
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
/// not already exist.
///
@ -150,65 +120,6 @@ impl StripeBilling {
Ok(customer_id)
}
pub async fn subscribe_to_price(
&self,
subscription_id: &StripeSubscriptionId,
price: &StripePrice,
) -> Result<()> {
let subscription = self.client.get_subscription(subscription_id).await?;
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
}
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
let price_per_unit = price.unit_amount.unwrap_or_default();
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
self.client
.update_subscription(
subscription_id,
UpdateSubscriptionParams {
items: Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone()),
}]),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
},
}),
},
)
.await?;
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &StripeCustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
self.client
.create_meter_event(StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
value: requests as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
})
.await?;
Ok(())
}
pub async fn subscribe_to_zed_free(
&self,
customer_id: StripeCustomerId,
@ -243,14 +154,3 @@ impl StripeBilling {
Ok(subscription)
}
}
fn subscription_contains_price(
subscription: &StripeSubscription,
price_id: &StripePriceId,
) -> bool {
subscription.items.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)
})
}

View file

@ -24,10 +24,7 @@ use language::{
};
use project::{
ProjectPath, SERVER_PROGRESS_THROTTLE_TIMEOUT,
lsp_store::{
lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
rust_analyzer_ext::RUST_ANALYZER_NAME,
},
lsp_store::lsp_ext_command::{ExpandedMacro, LspExtExpandMacro},
project_settings::{InlineBlameSettings, ProjectSettings},
};
use recent_projects::disconnected_overlay::DisconnectedOverlay;
@ -3104,9 +3101,7 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
// Turn inline-blame-off by default so no state is transferred without us explicitly doing so
let inline_blame_off_settings = Some(InlineBlameSettings {
enabled: false,
delay_ms: None,
min_column: None,
show_commit_summary: false,
..Default::default()
});
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
@ -3786,11 +3781,18 @@ async fn test_client_can_query_lsp_ext(cx_a: &mut TestAppContext, cx_b: &mut Tes
cx_b.update(editor::init);
client_a.language_registry().add(rust_lang());
client_b.language_registry().add(rust_lang());
let mut fake_language_servers = client_a.language_registry().register_fake_lsp(
"Rust",
FakeLspAdapter {
name: RUST_ANALYZER_NAME,
name: "rust-analyzer",
..FakeLspAdapter::default()
},
);
client_b.language_registry().add(rust_lang());
client_b.language_registry().register_fake_lsp_adapter(
"Rust",
FakeLspAdapter {
name: "rust-analyzer",
..FakeLspAdapter::default()
},
);

View file

@ -1,14 +1,9 @@
use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionItems,
};
use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
@ -21,24 +16,6 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test meters
let meter1 = StripeMeter {
id: StripeMeterId("meter_1".into()),
event_name: "event_1".to_string(),
};
let meter2 = StripeMeter {
id: StripeMeterId("meter_2".into()),
event_name: "event_2".to_string(),
};
stripe_client
.meters
.lock()
.insert(meter1.id.clone(), meter1);
stripe_client
.meters
.lock()
.insert(meter2.id.clone(), meter2);
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
@ -144,217 +121,3 @@ async fn test_find_or_create_customer_by_email() {
assert_eq!(customer.email.as_deref(), Some(email));
}
}
#[gpui::test]
async fn test_subscribe_to_price() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let price = StripePrice {
id: StripePriceId("price_test".into()),
unit_amount: Some(2000),
lookup_key: Some("test-price".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![],
cancel_at: None,
cancellation_details: None,
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
let update_subscription_calls = stripe_client
.update_subscription_calls
.lock()
.iter()
.map(|(id, params)| (id.clone(), params.clone()))
.collect::<Vec<_>>();
assert_eq!(update_subscription_calls.len(), 1);
assert_eq!(update_subscription_calls[0].0, subscription.id);
assert_eq!(
update_subscription_calls[0].1.items,
Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone())
}])
);
// Subscribing to a price that is already on the subscription is a no-op.
{
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
}
}
#[gpui::test]
async fn test_subscribe_to_zed_free() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let zed_pro_price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(0),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
let zed_free_price = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_free_price.id.clone(), zed_free_price.clone());
stripe_billing.initialize().await.unwrap();
// Customer is subscribed to Zed Free when not already subscribed to a plan.
{
let customer_id = StripeCustomerId("cus_no_plan".into());
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
}
// Customer is not subscribed to Zed Free when they already have an active subscription.
{
let customer_id = StripeCustomerId("cus_active_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_active".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
// Customer is not subscribed to Zed Free when they already have a trial subscription.
{
let customer_id = StripeCustomerId("cus_trial_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_trial".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Trialing,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(14)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
}
#[gpui::test]
async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
stripe_billing
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
.await
.unwrap();
let create_meter_event_calls = stripe_client
.create_meter_event_calls
.lock()
.iter()
.cloned()
.collect::<Vec<_>>();
assert_eq!(create_meter_event_calls.len(), 1);
assert!(
create_meter_event_calls[0]
.identifier
.starts_with("model_requests/")
);
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
assert_eq!(
create_meter_event_calls[0].event_name.as_ref(),
"some_model/requests"
);
assert_eq!(create_meter_event_calls[0].value, 73);
}

View file

@ -297,6 +297,7 @@ impl TestServer {
client_name,
Principal::User(user),
ZedVersion(SemanticVersion::new(1, 0, 0)),
Some("test".to_string()),
None,
None,
None,

View file

@ -103,28 +103,16 @@ impl ChatPanel {
});
cx.new(|cx| {
let entity = cx.entity().downgrade();
let message_list = ListState::new(
0,
gpui::ListAlignment::Bottom,
px(1000.),
move |ix, window, cx| {
if let Some(entity) = entity.upgrade() {
entity.update(cx, |this: &mut Self, cx| {
this.render_message(ix, window, cx).into_any_element()
})
} else {
div().into_any()
}
},
);
let message_list = ListState::new(0, gpui::ListAlignment::Bottom, px(1000.));
message_list.set_scroll_handler(cx.listener(|this, event: &ListScrollEvent, _, cx| {
if event.visible_range.start < MESSAGE_LOADING_THRESHOLD {
this.load_more_messages(cx);
}
this.is_scrolled_to_bottom = !event.is_scrolled;
}));
message_list.set_scroll_handler(cx.listener(
|this: &mut Self, event: &ListScrollEvent, _, cx| {
if event.visible_range.start < MESSAGE_LOADING_THRESHOLD {
this.load_more_messages(cx);
}
this.is_scrolled_to_bottom = !event.is_scrolled;
},
));
let local_offset = chrono::Local::now().offset().local_minus_utc();
let mut this = Self {
@ -399,7 +387,7 @@ impl ChatPanel {
ix: usize,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
) -> AnyElement {
let active_chat = &self.active_chat.as_ref().unwrap().0;
let (message, is_continuation_from_previous, is_admin) =
active_chat.update(cx, |active_chat, cx| {
@ -582,6 +570,7 @@ impl ChatPanel {
self.render_popover_buttons(message_id, can_delete_message, can_edit_message, cx)
.mt_neg_2p5(),
)
.into_any_element()
}
fn has_open_menu(&self, message_id: Option<u64>) -> bool {
@ -979,7 +968,13 @@ impl Render for ChatPanel {
)
.child(div().flex_grow().px_2().map(|this| {
if self.active_chat.is_some() {
this.child(list(self.message_list.clone()).size_full())
this.child(
list(
self.message_list.clone(),
cx.processor(Self::render_message),
)
.size_full(),
)
} else {
this.child(
div()

View file

@ -324,20 +324,6 @@ impl CollabPanel {
)
.detach();
let entity = cx.entity().downgrade();
let list_state = ListState::new(
0,
gpui::ListAlignment::Top,
px(1000.),
move |ix, window, cx| {
if let Some(entity) = entity.upgrade() {
entity.update(cx, |this, cx| this.render_list_entry(ix, window, cx))
} else {
div().into_any()
}
},
);
let mut this = Self {
width: None,
focus_handle: cx.focus_handle(),
@ -345,7 +331,7 @@ impl CollabPanel {
fs: workspace.app_state().fs.clone(),
pending_serialization: Task::ready(None),
context_menu: None,
list_state,
list_state: ListState::new(0, gpui::ListAlignment::Top, px(1000.)),
channel_name_editor,
filter_editor,
entries: Vec::default(),
@ -2431,7 +2417,13 @@ impl CollabPanel {
});
v_flex()
.size_full()
.child(list(self.list_state.clone()).size_full())
.child(
list(
self.list_state.clone(),
cx.processor(Self::render_list_entry),
)
.size_full(),
)
.child(
v_flex()
.child(div().mx_2().border_primary(cx).border_t_1())
@ -2605,7 +2597,7 @@ impl CollabPanel {
let contact = contact.clone();
move |this, event: &ClickEvent, window, cx| {
this.deploy_contact_context_menu(
event.down.position,
event.position(),
contact.clone(),
window,
cx,
@ -3061,7 +3053,7 @@ impl Render for CollabPanel {
.on_action(cx.listener(CollabPanel::move_channel_down))
.track_focus(&self.focus_handle)
.size_full()
.child(if self.user_store.read(cx).current_user().is_none() {
.child(if !self.client.status().borrow().is_connected() {
self.render_signed_out(cx)
} else {
self.render_signed_in(window, cx)

Some files were not shown because too many files have changed in this diff Show more