diff --git a/.github/actions/build_docs/action.yml b/.github/actions/build_docs/action.yml index 9a2d7e1ec7..a7effad247 100644 --- a/.github/actions/build_docs/action.yml +++ b/.github/actions/build_docs/action.yml @@ -19,7 +19,7 @@ runs: shell: bash -euxo pipefail {0} run: ./script/linux - - name: Check for broken links + - name: Check for broken links (in MD) uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 with: args: --no-progress --exclude '^http' './docs/src/**/*' @@ -30,3 +30,9 @@ runs: run: | mkdir -p target/deploy mdbook build ./docs --dest-dir=../target/deploy/docs/ + + - name: Check for broken links (in HTML) + uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1 + with: + args: --no-progress --exclude '^http' 'target/deploy/docs/' + fail: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 009fcc8337..7dfc33e0d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -771,7 +771,8 @@ jobs: timeout-minutes: 120 name: Create a Windows installer runs-on: [self-hosted, Windows, X64] - if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') + # if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling')) needs: [windows_tests] env: AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }} diff --git a/Cargo.lock b/Cargo.lock index f31ecdef99..767d9a4c9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,7 +114,6 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "proto", "rand 0.8.5", "ref-cast", "rope", @@ -357,10 +356,10 @@ name = "ai_onboarding" version = "0.1.0" dependencies = [ "client", + "cloud_llm_client", "component", "gpui", "language_model", - "proto", "serde", "smallvec", "telemetry", @@ -1077,17 +1076,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "async-recursion" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "async-recursion" version = "1.1.1" @@ -2973,11 +2961,11 @@ name = "client" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 0.3.2", "async-tungstenite", "base64 0.22.1", "chrono", "clock", + "cloud_api_client", "cloud_llm_client", "cocoa 0.26.0", "collections", @@ -3033,6 +3021,31 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "cloud_api_client" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_api_types", + "futures 0.3.31", + "http_client", + "parking_lot", + "serde_json", + "workspace-hack", +] + +[[package]] +name = "cloud_api_types" +version = "0.1.0" +dependencies = [ + "chrono", + "cloud_llm_client", + "pretty_assertions", + "serde", + "serde_json", + "workspace-hack", +] + [[package]] name = "cloud_llm_client" version = "0.1.0" @@ -4271,41 +4284,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.101", -] - -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core", - "quote", - "syn 2.0.101", -] - [[package]] name = "dashmap" version = "5.5.3" @@ -4521,37 +4499,6 @@ dependencies = [ "serde", ] -[[package]] -name = "derive_builder" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn 2.0.101", -] - -[[package]] -name = "derive_builder_macro" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" -dependencies = [ - "derive_builder_core", - "syn 2.0.101", -] - [[package]] name = "derive_more" version = "0.99.19" @@ -4962,6 +4909,7 @@ dependencies = [ "theme", "time", "tree-sitter-bash", + "tree-sitter-c", "tree-sitter-html", "tree-sitter-python", "tree-sitter-rust", @@ -5930,7 +5878,7 @@ dependencies = [ "ignore", "libc", "log", - "notify", + "notify 8.0.0", "objc", "parking_lot", "paths", @@ -7369,9 +7317,8 @@ dependencies = [ "wayland-backend", "wayland-client", "wayland-cursor", - "wayland-protocols 0.31.2", + "wayland-protocols", "wayland-protocols-plasma", - "wayland-protocols-wlr", "windows 0.61.1", "windows-core 0.61.0", "windows-numerics", @@ -7488,18 +7435,16 @@ dependencies = [ [[package]] name = "handlebars" -version = "6.3.2" +version = "5.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759e2d5aea3287cb1190c8ec394f42866cb5bf74fcbf213f354e3c856ea26098" +checksum = "d08485b96a0e6393e9e4d1b8d48cf74ad6c063cd905eb33f42c1ce3f0377539b" dependencies = [ - "derive_builder", "log", - "num-order", "pest", "pest_derive", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 1.0.69", ] [[package]] @@ -7680,12 +7625,6 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" -[[package]] -name = "hex-literal" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcaaec4551594c969335c98c903c1397853d4198408ea609190f420500f6be71" - [[package]] name = "hexf-parse" version = "0.2.1" @@ -7863,6 +7802,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "log", + "parking_lot", "serde", "serde_json", "url", @@ -8170,12 +8110,6 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "1.0.3" @@ -8394,6 +8328,17 @@ dependencies = [ "zeta", ] +[[package]] +name = "inotify" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" +dependencies = [ + "bitflags 1.3.2", + "inotify-sys", + "libc", +] + [[package]] name = "inotify" version = "0.11.0" @@ -8547,7 +8492,7 @@ dependencies = [ "fnv", "lazy_static", "libc", - "mio", + "mio 1.0.3", "rand 0.8.5", "serde", "tempfile", @@ -9129,7 +9074,6 @@ dependencies = [ "open_router", "partial-json-fixer", "project", - "proto", "release_channel", "schemars", "serde", @@ -9401,7 +9345,7 @@ dependencies = [ [[package]] name = "libwebrtc" version = "0.3.10" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cxx", "jni", @@ -9481,7 +9425,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "livekit" version = "0.7.8" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "chrono", "futures-util", @@ -9504,7 +9448,7 @@ dependencies = [ [[package]] name = "livekit-api" version = "0.4.2" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "http 0.2.12", @@ -9528,7 +9472,7 @@ dependencies = [ [[package]] name = "livekit-protocol" version = "0.3.9" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "futures-util", "livekit-runtime", @@ -9545,7 +9489,7 @@ dependencies = [ [[package]] name = "livekit-runtime" version = "0.4.0" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "tokio", "tokio-stream", @@ -9867,7 +9811,7 @@ name = "markdown_preview" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 1.1.1", + "async-recursion", "collections", "editor", "fs", @@ -9987,9 +9931,9 @@ dependencies = [ [[package]] name = "mdbook" -version = "0.4.48" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6fbb4ac2d9fd7aa987c3510309ea3c80004a968d063c42f0d34fea070817c1" +checksum = "b45a38e19bd200220ef07c892b0157ad3d2365e5b5a267ca01ad12182491eea5" dependencies = [ "ammonia", "anyhow", @@ -9999,12 +9943,11 @@ dependencies = [ "elasticlunr-rs", "env_logger 0.11.8", "futures-util", - "handlebars 6.3.2", - "hex", + "handlebars 5.1.2", "ignore", "log", "memchr", - "notify", + "notify 6.1.1", "notify-debouncer-mini", "once_cell", "opener", @@ -10013,7 +9956,6 @@ dependencies = [ "regex", "serde", "serde_json", - "sha2", "shlex", "tempfile", "tokio", @@ -10156,6 +10098,18 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e53debba6bda7a793e5f99b8dacf19e626084f525f7829104ba9898f367d85ff" +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.3" @@ -10502,6 +10456,25 @@ dependencies = [ "zed_actions", ] +[[package]] +name = "notify" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" +dependencies = [ + "bitflags 2.9.0", + "crossbeam-channel", + "filetime", + "fsevent-sys 4.1.0", + "inotify 0.9.6", + "kqueue", + "libc", + "log", + "mio 0.8.11", + "walkdir", + "windows-sys 0.48.0", +] + [[package]] name = "notify" version = "8.0.0" @@ -10510,11 +10483,11 @@ dependencies = [ "bitflags 2.9.0", "filetime", "fsevent-sys 4.1.0", - "inotify", + "inotify 0.11.0", "kqueue", "libc", "log", - "mio", + "mio 1.0.3", "notify-types", "walkdir", "windows-sys 0.59.0", @@ -10522,14 +10495,13 @@ dependencies = [ [[package]] name = "notify-debouncer-mini" -version = "0.6.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a689eb4262184d9a1727f9087cd03883ea716682ab03ed24efec57d7716dccb8" +checksum = "5d40b221972a1fc5ef4d858a2f671fb34c75983eb385463dff3780eeff6a9d43" dependencies = [ + "crossbeam-channel", "log", - "notify", - "notify-types", - "tempfile", + "notify 6.1.1", ] [[package]] @@ -10669,21 +10641,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-modular" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f" - -[[package]] -name = "num-order" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6" -dependencies = [ - "num-modular", -] - [[package]] name = "num-rational" version = "0.4.2" @@ -10954,21 +10911,33 @@ dependencies = [ name = "onboarding" version = "0.1.0" dependencies = [ + "ai_onboarding", "anyhow", + "client", "command_palette_hooks", + "component", "db", + "documented", "editor", "feature_flags", "fs", "gpui", + "itertools 0.14.0", "language", + "language_model", + "menu", "project", + "schemars", + "serde", "settings", "theme", "ui", + "util", + "vim_mode_setting", "workspace", "workspace-hack", "zed_actions", + "zlog", ] [[package]] @@ -14732,6 +14701,27 @@ dependencies = [ "zlog", ] +[[package]] +name = "settings_profile_selector" +version = "0.1.0" +dependencies = [ + "client", + "editor", + "fuzzy", + "gpui", + "language", + "menu", + "picker", + "project", + "serde_json", + "settings", + "theme", + "ui", + "workspace", + "workspace-hack", + "zed_actions", +] + [[package]] name = "settings_ui" version = "0.1.0" @@ -14754,7 +14744,6 @@ dependencies = [ "notifications", "paths", "project", - "schemars", "search", "serde", "serde_json", @@ -16191,7 +16180,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "async-recursion 1.1.1", + "async-recursion", "breadcrumbs", "client", "collections", @@ -16540,6 +16529,7 @@ dependencies = [ "call", "chrono", "client", + "cloud_llm_client", "collections", "db", "gpui", @@ -16575,7 +16565,7 @@ dependencies = [ "backtrace", "bytes 1.10.1", "libc", - "mio", + "mio 1.0.3", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -18388,9 +18378,9 @@ dependencies = [ [[package]] name = "wayland-backend" -version = "0.3.10" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe770181423e5fc79d3e2a7f4410b7799d5aab1de4372853de3c6aa13ca24121" +checksum = "b7208998eaa3870dad37ec8836979581506e0c5c64c20c9e79e9d2a10d6f47bf" dependencies = [ "cc", "downcast-rs", @@ -18402,9 +18392,9 @@ dependencies = [ [[package]] name = "wayland-client" -version = "0.31.10" +version = "0.31.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978fa7c67b0847dbd6a9f350ca2569174974cd4082737054dbb7fbb79d7d9a61" +checksum = "c2120de3d33638aaef5b9f4472bff75f07c56379cf76ea320bd3a3d65ecaf73f" dependencies = [ "bitflags 2.9.0", "rustix 0.38.44", @@ -18435,18 +18425,6 @@ dependencies = [ "wayland-scanner", ] -[[package]] -name = "wayland-protocols" -version = "0.32.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "779075454e1e9a521794fed15886323ea0feda3f8b0fc1390f5398141310422a" -dependencies = [ - "bitflags 2.9.0", - "wayland-backend", - "wayland-client", - "wayland-scanner", -] - [[package]] name = "wayland-protocols-plasma" version = "0.2.0" @@ -18456,20 +18434,7 @@ dependencies = [ "bitflags 2.9.0", "wayland-backend", "wayland-client", - "wayland-protocols 0.31.2", - "wayland-scanner", -] - -[[package]] -name = "wayland-protocols-wlr" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cb6cdc73399c0e06504c437fe3cf886f25568dd5454473d565085b36d6a8bbf" -dependencies = [ - "bitflags 2.9.0", - "wayland-backend", - "wayland-client", - "wayland-protocols 0.32.8", + "wayland-protocols", "wayland-scanner", ] @@ -18578,7 +18543,7 @@ dependencies = [ [[package]] name = "webrtc-sys" version = "0.3.7" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "cc", "cxx", @@ -18591,15 +18556,13 @@ dependencies = [ [[package]] name = "webrtc-sys-build" version = "0.3.6" -source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd" +source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d" dependencies = [ "fs2", - "hex-literal", "regex", "reqwest 0.11.27", "scratch", "semver", - "sha2", "zip", ] @@ -18628,7 +18591,6 @@ dependencies = [ "serde", "settings", "telemetry", - "theme", "ui", "util", "vim_mode_setting", @@ -19643,7 +19605,7 @@ version = "0.1.0" dependencies = [ "any_vec", "anyhow", - "async-recursion 1.1.1", + "async-recursion", "bincode", "call", "client", @@ -19777,7 +19739,7 @@ dependencies = [ "md-5", "memchr", "miniz_oxide", - "mio", + "mio 1.0.3", "naga", "nix 0.29.0", "nom", @@ -20168,7 +20130,7 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion 1.1.1", + "async-recursion", "async-task", "async-trait", "blocking", @@ -20221,7 +20183,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.198.0" +version = "0.199.0" dependencies = [ "activity_indicator", "agent", @@ -20324,6 +20286,7 @@ dependencies = [ "serde_json", "session", "settings", + "settings_profile_selector", "settings_ui", "shellexpand 2.1.2", "smol", @@ -20600,6 +20563,7 @@ dependencies = [ "call", "client", "clock", + "cloud_api_types", "cloud_llm_client", "collections", "command_palette_hooks", @@ -20620,7 +20584,6 @@ dependencies = [ "menu", "postage", "project", - "proto", "regex", "release_channel", "reqwest_client", @@ -20645,6 +20608,42 @@ dependencies = [ "zlog", ] +[[package]] +name = "zeta_cli" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "client", + "debug_adapter_extension", + "extension", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "language", + "language_extension", + "language_model", + "language_models", + "languages", + "node_runtime", + "paths", + "project", + "prompt_store", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "settings", + "shellexpand 2.1.2", + "smol", + "terminal_view", + "util", + "watch", + "workspace-hack", + "zeta", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index df1e1d7467..902c7e8d19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [workspace] resolver = "2" members = [ - "crates/activity_indicator", "crates/acp_thread", - "crates/agent_ui", + "crates/activity_indicator", "crates/agent", - "crates/agent_settings", - "crates/ai_onboarding", "crates/agent_servers", + "crates/agent_settings", + "crates/agent_ui", + "crates/ai_onboarding", "crates/anthropic", "crates/askpass", "crates/assets", @@ -29,6 +29,8 @@ members = [ "crates/cli", "crates/client", "crates/clock", + "crates/cloud_api_client", + "crates/cloud_api_types", "crates/cloud_llm_client", "crates/collab", "crates/collab_ui", @@ -49,8 +51,8 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/editor", - "crates/explorer_command_injector", "crates/eval", + "crates/explorer_command_injector", "crates/extension", "crates/extension_api", "crates/extension_cli", @@ -99,7 +101,6 @@ members = [ "crates/markdown_preview", "crates/media", "crates/menu", - "crates/svg_preview", "crates/migrator", "crates/mistral", "crates/multi_buffer", @@ -140,6 +141,7 @@ members = [ "crates/semantic_version", "crates/session", "crates/settings", + "crates/settings_profile_selector", "crates/settings_ui", "crates/snippet", "crates/snippet_provider", @@ -152,6 +154,7 @@ members = [ "crates/sum_tree", "crates/supermaven", "crates/supermaven_api", + "crates/svg_preview", "crates/tab_switcher", "crates/task", "crates/tasks_ui", @@ -186,6 +189,7 @@ members = [ "crates/zed", "crates/zed_actions", "crates/zeta", + "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -251,6 +255,8 @@ channel = { path = "crates/channel" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } +cloud_api_client = { path = "crates/cloud_api_client" } +cloud_api_types = { path = "crates/cloud_api_types" } cloud_llm_client = { path = "crates/cloud_llm_client" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } @@ -338,6 +344,7 @@ picker = { path = "crates/picker" } plugin = { path = "crates/plugin" } plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } +settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } project_panel = { path = "crates/project_panel" } project_symbols = { path = "crates/project_symbols" } @@ -672,14 +679,16 @@ features = [ "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Globalization", - "Win32_Graphics_Direct2D", - "Win32_Graphics_Direct2D_Common", + "Win32_Graphics_Direct3D", + "Win32_Graphics_Direct3D11", + "Win32_Graphics_Direct3D_Fxc", + "Win32_Graphics_DirectComposition", "Win32_Graphics_DirectWrite", "Win32_Graphics_Dwm", + "Win32_Graphics_Dxgi", "Win32_Graphics_Dxgi_Common", "Win32_Graphics_Gdi", "Win32_Graphics_Imaging", - "Win32_Graphics_Imaging_D2D", "Win32_Networking_WinSock", "Win32_Security", "Win32_Security_Credentials", diff --git a/assets/icons/editor_atom.svg b/assets/icons/editor_atom.svg new file mode 100644 index 0000000000..cc5fa83843 --- /dev/null +++ b/assets/icons/editor_atom.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_cursor.svg b/assets/icons/editor_cursor.svg new file mode 100644 index 0000000000..338697be8a --- /dev/null +++ b/assets/icons/editor_cursor.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/assets/icons/editor_emacs.svg b/assets/icons/editor_emacs.svg new file mode 100644 index 0000000000..951d7b2be1 --- /dev/null +++ b/assets/icons/editor_emacs.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/assets/icons/editor_jet_brains.svg b/assets/icons/editor_jet_brains.svg new file mode 100644 index 0000000000..7d9cf0c65c --- /dev/null +++ b/assets/icons/editor_jet_brains.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/editor_sublime.svg b/assets/icons/editor_sublime.svg new file mode 100644 index 0000000000..95a04f6b54 --- /dev/null +++ b/assets/icons/editor_sublime.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/assets/icons/editor_vs_code.svg b/assets/icons/editor_vs_code.svg new file mode 100644 index 0000000000..2a71ad52af --- /dev/null +++ b/assets/icons/editor_vs_code.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/shield_check.svg b/assets/icons/shield_check.svg new file mode 100644 index 0000000000..6e58c31468 --- /dev/null +++ b/assets/icons/shield_check.svg @@ -0,0 +1,4 @@ + + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index e36e093e22..ef5354e82d 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -232,7 +232,7 @@ "ctrl-n": "agent::NewThread", "ctrl-alt-n": "agent::NewTextThread", "ctrl-shift-h": "agent::OpenHistory", - "ctrl-alt-c": "agent::OpenConfiguration", + "ctrl-alt-c": "agent::OpenSettings", "ctrl-alt-p": "agent::OpenRulesLibrary", "ctrl-i": "agent::ToggleProfileSelector", "ctrl-alt-/": "agent::ToggleModelSelector", @@ -598,6 +598,7 @@ "ctrl-shift-t": "pane::ReopenClosedItem", "ctrl-k ctrl-s": "zed::OpenKeymapEditor", "ctrl-k ctrl-t": "theme_selector::Toggle", + "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -1167,5 +1168,14 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "ctrl-1": "onboarding::ActivateBasicsPage", + "ctrl-2": "onboarding::ActivateEditingPage", + "ctrl-3": "onboarding::ActivateAISetupPage" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 0114e2da1d..3287e50acb 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -272,7 +272,7 @@ "cmd-n": "agent::NewThread", "cmd-alt-n": "agent::NewTextThread", "cmd-shift-h": "agent::OpenHistory", - "cmd-alt-c": "agent::OpenConfiguration", + "cmd-alt-c": "agent::OpenSettings", "cmd-alt-p": "agent::OpenRulesLibrary", "cmd-i": "agent::ToggleProfileSelector", "cmd-alt-/": "agent::ToggleModelSelector", @@ -665,6 +665,7 @@ "cmd-shift-t": "pane::ReopenClosedItem", "cmd-k cmd-s": "zed::OpenKeymapEditor", "cmd-k cmd-t": "theme_selector::Toggle", + "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", "ctrl-tab": "tab_switcher::Toggle", @@ -1269,5 +1270,14 @@ "up": "menu::SelectPrevious", "down": "menu::SelectNext" } + }, + { + "context": "Onboarding", + "use_key_equivalents": true, + "bindings": { + "cmd-1": "onboarding::ActivateBasicsPage", + "cmd-2": "onboarding::ActivateEditingPage", + "cmd-3": "onboarding::ActivateAISetupPage" + } } ] diff --git a/assets/keymaps/linux/cursor.json b/assets/keymaps/linux/cursor.json index 347b7885fc..1c381b0cf0 100644 --- a/assets/keymaps/linux/cursor.json +++ b/assets/keymaps/linux/cursor.json @@ -8,7 +8,7 @@ "ctrl-shift-i": "agent::ToggleFocus", "ctrl-l": "agent::ToggleFocus", "ctrl-shift-l": "agent::ToggleFocus", - "ctrl-shift-j": "agent::OpenConfiguration" + "ctrl-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/linux/jetbrains.json b/assets/keymaps/linux/jetbrains.json index f81f363ae0..9bc1f24bfb 100644 --- a/assets/keymaps/linux/jetbrains.json +++ b/assets/keymaps/linux/jetbrains.json @@ -95,7 +95,7 @@ "ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], "alt-shift-f10": "task::Spawn", "ctrl-e": "file_finder::Toggle", - "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor + // "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "ctrl-shift-n": "file_finder::Toggle", "ctrl-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", diff --git a/assets/keymaps/macos/cursor.json b/assets/keymaps/macos/cursor.json index b1d39bef9e..fdf9c437cf 100644 --- a/assets/keymaps/macos/cursor.json +++ b/assets/keymaps/macos/cursor.json @@ -8,7 +8,7 @@ "cmd-shift-i": "agent::ToggleFocus", "cmd-l": "agent::ToggleFocus", "cmd-shift-l": "agent::ToggleFocus", - "cmd-shift-j": "agent::OpenConfiguration" + "cmd-shift-j": "agent::OpenSettings" } }, { diff --git a/assets/keymaps/macos/jetbrains.json b/assets/keymaps/macos/jetbrains.json index 5795d2ac7e..b1cd51a338 100644 --- a/assets/keymaps/macos/jetbrains.json +++ b/assets/keymaps/macos/jetbrains.json @@ -97,7 +97,7 @@ "cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }], "ctrl-alt-r": "task::Spawn", "cmd-e": "file_finder::Toggle", - "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor + // "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor "cmd-shift-o": "file_finder::Toggle", "cmd-shift-a": "command_palette::Toggle", "shift shift": "command_palette::Toggle", diff --git a/assets/settings/default.json b/assets/settings/default.json index 3a7a48efc2..4734b5d118 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1877,5 +1877,25 @@ "save_breakpoints": true, "dock": "bottom", "button": true - } + }, + // Configures any number of settings profiles that are temporarily applied on + // top of your existing user settings when selected from + // `settings profile selector: toggle`. + // Examples: + // "profiles": { + // "Presenting": { + // "agent_font_size": 20.0, + // "buffer_font_size": 20.0, + // "theme": "One Light", + // "ui_font_size": 20.0 + // }, + // "Python (ty)": { + // "languages": { + // "Python": { + // "language_servers": ["ty"] + // } + // } + // } + // } + "profiles": [] } diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 3bf6134862..d10fecdb28 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -580,6 +580,9 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), + ToolAuthorizationRequired, + Stopped, + Error, } impl EventEmitter for AcpThread {} @@ -677,6 +680,18 @@ impl AcpThread { false } + pub fn used_tools_since_last_user_message(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(..) => return false, + AgentThreadEntry::AssistantMessage(..) => continue, + AgentThreadEntry::ToolCall(..) => return true, + } + } + + false + } + pub fn handle_session_update( &mut self, update: acp::SessionUpdate, @@ -880,6 +895,7 @@ impl AcpThread { }; self.upsert_tool_call_inner(tool_call, status, cx); + cx.emit(AcpThreadEvent::ToolAuthorizationRequired); rx } @@ -1015,12 +1031,18 @@ impl AcpThread { .log_err(); })); - async move { - match rx.await { - Ok(Err(e)) => Err(e)?, - _ => Ok(()), + cx.spawn(async move |this, cx| match rx.await { + Ok(Err(e)) => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error)) + .log_err(); + Err(e)? } - } + _ => { + this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) + .log_err(); + Ok(()) + } + }) .boxed() } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index c89a7f3303..7bc0e82cad 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -47,7 +47,6 @@ paths.workspace = true postage.workspace = true project.workspace = true prompt_store.workspace = true -proto.workspace = true ref-cast.workspace = true rope.workspace = true schemars.workspace = true diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 0e5da2d43b..8558dd528d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -13,7 +13,7 @@ use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; -use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; @@ -37,7 +37,6 @@ use project::{ git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, }; use prompt_store::{ModelContext, PromptBuilder}; -use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -3255,8 +3254,10 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project.update(cx, |project, cx| { - project.user_store().update(cx, |user_store, cx| { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { user_store.update_model_request_usage( ModelRequestUsage(RequestUsage { amount: amount as i32, @@ -3264,8 +3265,7 @@ impl Thread { }), cx, ) - }) - }); + }); } pub fn deny_tool_use( diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 17575e42db..167f7e6136 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,5 +1,7 @@ use acp_thread::{AgentConnection, Plan}; use agent_servers::AgentServer; +use agent_settings::{AgentSettings, NotifyWhenAgentWaiting}; +use audio::{Audio, Sound}; use std::cell::RefCell; use std::collections::BTreeMap; use std::path::Path; @@ -18,10 +20,10 @@ use editor::{ use file_icons::FileIcons; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, - FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, - Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, - Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, - pulsating_between, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString, + StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation, + UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient, + list, percentage, point, prelude::*, pulsating_between, }; use language::language_settings::SoftWrap; use language::{Buffer, Language}; @@ -45,7 +47,10 @@ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSe use crate::acp::message_history::MessageHistory; use crate::agent_diff::AgentDiff; use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES}; -use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll}; +use crate::ui::{AgentNotification, AgentNotificationEvent}; +use crate::{ + AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll, +}; const RESPONSE_PADDING_X: Pixels = px(19.); @@ -59,6 +64,8 @@ pub struct AcpThreadView { message_set_from_history: bool, _message_editor_subscription: Subscription, mention_set: Arc>, + notifications: Vec>, + notification_subscriptions: HashMap, Vec>, last_error: Option>, list_state: ListState, auth_task: Option>, @@ -174,6 +181,8 @@ impl AcpThreadView { message_set_from_history: false, _message_editor_subscription: message_editor_subscription, mention_set, + notifications: Vec::new(), + notification_subscriptions: HashMap::default(), diff_editors: Default::default(), list_state: list_state, last_error: None, @@ -382,7 +391,9 @@ impl AcpThreadView { return; } - let Some(thread) = self.thread() else { return }; + let Some(thread) = self.thread() else { + return; + }; let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); cx.spawn(async move |this, cx| { @@ -565,6 +576,30 @@ impl AcpThreadView { self.sync_thread_entry_view(index, window, cx); self.list_state.splice(index..index + 1, 1); } + AcpThreadEvent::ToolAuthorizationRequired => { + self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); + } + AcpThreadEvent::Stopped => { + let used_tools = thread.read(cx).used_tools_since_last_user_message(); + self.notify_with_sound( + if used_tools { + "Finished running tools" + } else { + "New message" + }, + IconName::ZedAssistant, + window, + cx, + ); + } + AcpThreadEvent::Error => { + self.notify_with_sound( + "Agent stopped due to an error", + IconName::Warning, + window, + cx, + ); + } } cx.notify(); } @@ -2166,6 +2201,154 @@ impl AcpThreadView { self.list_state.scroll_to(ListOffset::default()); cx.notify(); } + + fn notify_with_sound( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + self.play_notification_sound(window, cx); + self.show_notification(caption, icon, window, cx); + } + + fn play_notification_sound(&self, window: &Window, cx: &mut App) { + let settings = AgentSettings::get_global(cx); + if settings.play_sound_when_agent_done && !window.is_window_active() { + Audio::play_sound(Sound::AgentDone, cx); + } + } + + fn show_notification( + &mut self, + caption: impl Into, + icon: IconName, + window: &mut Window, + cx: &mut Context, + ) { + if window.is_window_active() || !self.notifications.is_empty() { + return; + } + + let title = self.title(cx); + + match AgentSettings::get_global(cx).notify_when_agent_waiting { + NotifyWhenAgentWaiting::PrimaryScreen => { + if let Some(primary) = cx.primary_display() { + self.pop_up(icon, caption.into(), title, window, primary, cx); + } + } + NotifyWhenAgentWaiting::AllScreens => { + let caption = caption.into(); + for screen in cx.displays() { + self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx); + } + } + NotifyWhenAgentWaiting::Never => { + // Don't show anything + } + } + } + + fn pop_up( + &mut self, + icon: IconName, + caption: SharedString, + title: SharedString, + window: &mut Window, + screen: Rc, + cx: &mut Context, + ) { + let options = AgentNotification::window_options(screen, cx); + + let project_name = self.workspace.upgrade().and_then(|workspace| { + workspace + .read(cx) + .project() + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).root_name().to_string()) + }); + + if let Some(screen_window) = cx + .open_window(options, |_, cx| { + cx.new(|_| { + AgentNotification::new(title.clone(), caption.clone(), icon, project_name) + }) + }) + .log_err() + { + if let Some(pop_up) = screen_window.entity(cx).log_err() { + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push(cx.subscribe_in(&pop_up, window, { + |this, _, event, window, cx| match event { + AgentNotificationEvent::Accepted => { + let handle = window.window_handle(); + cx.activate(true); + + let workspace_handle = this.workspace.clone(); + + // If there are multiple Zed windows, activate the correct one. + cx.defer(move |cx| { + handle + .update(cx, |_view, window, _cx| { + window.activate_window(); + + if let Some(workspace) = workspace_handle.upgrade() { + workspace.update(_cx, |workspace, cx| { + workspace.focus_panel::(window, cx); + }); + } + }) + .log_err(); + }); + + this.dismiss_notifications(cx); + } + AgentNotificationEvent::Dismissed => { + this.dismiss_notifications(cx); + } + } + })); + + self.notifications.push(screen_window); + + // If the user manually refocuses the original window, dismiss the popup. + self.notification_subscriptions + .entry(screen_window) + .or_insert_with(Vec::new) + .push({ + let pop_up_weak = pop_up.downgrade(); + + cx.observe_window_activation(window, move |_, window, cx| { + if window.is_window_active() { + if let Some(pop_up) = pop_up_weak.upgrade() { + pop_up.update(cx, |_, cx| { + cx.emit(AgentNotificationEvent::Dismissed); + }); + } + } + }) + }); + } + } + } + + fn dismiss_notifications(&mut self, cx: &mut Context) { + for window in self.notifications.drain(..) { + window + .update(cx, |_, window, _| { + window.remove_window(); + }) + .ok(); + + self.notification_subscriptions.remove(&window); + } + } } impl Focusable for AcpThreadView { @@ -2451,3 +2634,331 @@ fn plan_label_markdown_style( ..default_md_style } } + +#[cfg(test)] +mod tests { + use agent_client_protocol::SessionId; + use editor::EditorSettings; + use fs::FakeFs; + use futures::future::try_join_all; + use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use rand::Rng; + use settings::SettingsStore; + + use super::*; + + #[gpui::test] + async fn test_notification_for_stop_event(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_error(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + #[gpui::test] + async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { + init_test(cx); + + let tool_call_id = acp::ToolCallId("1".into()); + let tool_call = acp::ToolCall { + id: tool_call_id.clone(), + label: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Pending, + content: vec!["hi".into()], + locations: vec![], + raw_input: None, + }; + let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) + .with_permission_requests(HashMap::from_iter([( + tool_call_id, + vec![acp::PermissionOption { + id: acp::PermissionOptionId("1".into()), + label: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }], + )])); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Hello", window, cx); + }); + + cx.deactivate_window(); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.chat(&Chat, window, cx); + }); + + cx.run_until_parked(); + + assert!( + cx.windows() + .iter() + .any(|window| window.downcast::().is_some()) + ); + } + + async fn setup_thread_view( + agent: impl AgentServer + 'static, + cx: &mut TestAppContext, + ) -> (Entity, &mut VisualTestContext) { + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(agent), + workspace.downgrade(), + project, + Rc::new(RefCell::new(MessageHistory::default())), + 1, + None, + window, + cx, + ) + }) + }); + cx.run_until_parked(); + (thread_view, cx) + } + + struct StubAgentServer { + connection: C, + } + + impl StubAgentServer { + fn new(connection: C) -> Self { + Self { connection } + } + } + + impl StubAgentServer { + fn default() -> Self { + Self::new(StubAgentConnection::default()) + } + } + + impl AgentServer for StubAgentServer + where + C: 'static + AgentConnection + Send + Clone, + { + fn logo(&self) -> ui::IconName { + unimplemented!() + } + + fn name(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_headline(&self) -> &'static str { + unimplemented!() + } + + fn empty_state_message(&self) -> &'static str { + unimplemented!() + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Ok(Rc::new(self.connection.clone()))) + } + } + + #[derive(Clone, Default)] + struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + updates: Vec, + } + + impl StubAgentConnection { + fn new(updates: Vec) -> Self { + Self { + updates, + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + } + + impl AgentConnection for StubAgentConnection { + fn name(&self) -> &'static str { + "StubAgentConnection" + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + unimplemented!() + } + + fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in &self.updates { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_permission( + tool_call.clone(), + options.clone(), + cx, + ) + })?; + permission.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(()) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + #[derive(Clone)] + struct SaboteurAgentConnection; + + impl AgentConnection for SaboteurAgentConnection { + fn name(&self) -> &'static str { + "SaboteurAgentConnection" + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + Task::ready(Ok(cx + .new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx)) + .unwrap())) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + unimplemented!() + } + + fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow::anyhow!("Error prompting"))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index fae04188eb..dad930be9e 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -7,6 +7,7 @@ use std::{sync::Arc, time::Duration}; use agent_settings::AgentSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; +use cloud_llm_client::Plan; use collections::HashMap; use context_server::ContextServerId; use extension::ExtensionManifest; @@ -25,7 +26,6 @@ use project::{ context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; -use proto::Plan; use settings::{Settings, update_settings_file}; use ui::{ Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, @@ -180,7 +180,7 @@ impl AgentConfiguration { let current_plan = if is_zed_provider { self.workspace .upgrade() - .and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) + .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan()) } else { None }; @@ -406,7 +406,9 @@ impl AgentConfiguration { SwitchField::new( "always-allow-tool-actions-switch", "Allow running commands without asking for confirmation", - "The agent can perform potentially destructive actions without asking for your confirmation.", + Some( + "The agent can perform potentially destructive actions without asking for your confirmation.".into(), + ), always_allow_tool_actions, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -424,7 +426,7 @@ impl AgentConfiguration { SwitchField::new( "single-file-review", "Enable single-file agent reviews", - "Agent edits are also displayed in single-file editors for review.", + Some("Agent edits are also displayed in single-file editors for review.".into()), single_file_review, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -442,7 +444,9 @@ impl AgentConfiguration { SwitchField::new( "sound-notification", "Play sound when finished generating", - "Hear a notification sound when the agent is done generating changes or needs your input.", + Some( + "Hear a notification sound when the agent is done generating changes or needs your input.".into(), + ), play_sound_when_agent_done, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -460,7 +464,9 @@ impl AgentConfiguration { SwitchField::new( "modifier-send", "Use modifier to submit a message", - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.", + Some( + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + ), use_modifier_to_send, move |state, _window, cx| { let allow = state == &ToggleState::Selected; @@ -502,7 +508,7 @@ impl AgentConfiguration { .blend(cx.theme().colors().text_accent.opacity(0.2)); let (plan_name, label_color, bg_color) = match plan { - Plan::Free => ("Free", Color::Default, free_chip_bg), + Plan::ZedFree => ("Free", Color::Default, free_chip_bg), Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), }; diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index ec0a11f86b..c4dc359093 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1521,6 +1521,9 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } + AcpThreadEvent::Stopped + | AcpThreadEvent::ToolAuthorizationRequired + | AcpThreadEvent::Error => {} } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 875320372d..a09c669769 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -44,7 +44,7 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; use client::{DisableAiSettings, UserStore, zed_urls}; -use cloud_llm_client::{CompletionIntent, UsageLimit}; +use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; @@ -60,7 +60,6 @@ use language_model::{ }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; -use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; @@ -78,7 +77,7 @@ use workspace::{ }; use zed_actions::{ DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize, - agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector}, + agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector}, assistant::{OpenRulesLibrary, ToggleFocus}, }; @@ -105,7 +104,7 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.open_history(window, cx)); } }) - .register_action(|workspace, _: &OpenConfiguration, window, cx| { + .register_action(|workspace, _: &OpenSettings, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| panel.open_configuration(window, cx)); @@ -579,7 +578,6 @@ impl AgentPanel { MessageEditor::new( fs.clone(), workspace.clone(), - user_store.clone(), message_editor_context_store.clone(), prompt_store.clone(), thread_store.downgrade(), @@ -848,7 +846,6 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1122,7 +1119,6 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), @@ -2074,7 +2070,7 @@ impl AgentPanel { menu = menu .action("Rules…", Box::new(OpenRulesLibrary::default())) - .action("Settings", Box::new(OpenConfiguration)) + .action("Settings", Box::new(OpenSettings)) .action(zoom_in_label, Box::new(ToggleZoom)); menu })) @@ -2279,10 +2275,10 @@ impl AgentPanel { | ActiveView::Configuration => return false, } - let plan = self.user_store.read(cx).current_plan(); + let plan = self.user_store.read(cx).plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - matches!(plan, Some(Plan::Free)) && has_previous_trial + matches!(plan, Some(Plan::ZedFree)) && has_previous_trial } fn should_render_onboarding(&self, cx: &mut Context) -> bool { @@ -2468,14 +2464,14 @@ impl AgentPanel { .icon_color(Color::Muted) .full_width() .key_binding(KeyBinding::for_action_in( - &OpenConfiguration, + &OpenSettings, &focus_handle, window, cx, )) .on_click(|_event, window, cx| { window.dispatch_action( - OpenConfiguration.boxed_clone(), + OpenSettings.boxed_clone(), cx, ) }), @@ -2680,16 +2676,11 @@ impl AgentPanel { .style(ButtonStyle::Tinted(ui::TintColor::Warning)) .label_size(LabelSize::Small) .key_binding( - KeyBinding::for_action_in( - &OpenConfiguration, - &focus_handle, - window, - cx, - ) - .map(|kb| kb.size(rems_from_px(12.))), + KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))), ) .on_click(|_event, window, cx| { - window.dispatch_action(OpenConfiguration.boxed_clone(), cx) + window.dispatch_action(OpenSettings.boxed_clone(), cx) }), ), ConfigurationError::ProviderPendingTermsAcceptance(provider) => { @@ -2883,7 +2874,7 @@ impl AgentPanel { ) -> AnyElement { let error_message = match plan { Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", + Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", }; let icon = Icon::new(IconName::XCircle) @@ -3193,7 +3184,7 @@ impl Render for AgentPanel { .on_action(cx.listener(|this, _: &OpenHistory, window, cx| { this.open_history(window, cx); })) - .on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| { + .on_action(cx.listener(|this, _: &OpenSettings, window, cx| { this.open_configuration(window, cx); })) .on_action(cx.listener(Self::open_active_thread_as_markdown)) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 6ae78585de..c5574c2371 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -263,8 +263,8 @@ fn update_command_palette_filter(cx: &mut App) { filter.hide_namespace("agent"); filter.hide_namespace("assistant"); filter.hide_namespace("copilot"); + filter.hide_namespace("supermaven"); filter.hide_namespace("zed_predict_onboarding"); - filter.hide_namespace("edit_prediction"); use editor::actions::{ diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 44ec050ae2..ffa654d12b 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _}; use ui::prelude::*; use util::{RangeExt, ResultExt, maybe}; use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId}; -use zed_actions::agent::OpenConfiguration; +use zed_actions::agent::OpenSettings; pub fn init( fs: Arc, @@ -345,7 +345,7 @@ impl InlineAssistant { if let Some(answer) = answer { if answer == 0 { cx.update(|window, cx| { - window.dispatch_action(Box::new(OpenConfiguration), cx) + window.dispatch_action(Box::new(OpenSettings), cx) }) .ok(); } diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 655e87d7cd..7121624c87 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), + zed_actions::agent::OpenSettings.boxed_clone(), cx, ); }), diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 082d1dfb51..2185885347 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -17,7 +17,6 @@ use agent::{ use agent_settings::{AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use client::UserStore; use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; @@ -43,7 +42,6 @@ use language_model::{ use multi_buffer; use project::Project; use prompt_store::PromptStore; -use proto::Plan; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -79,7 +77,6 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, - user_store: Entity, context_store: Entity, prompt_store: Option>, history_store: Option>, @@ -159,7 +156,6 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, - user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, @@ -231,7 +227,6 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), - user_store, thread, incompatible_tools_state: incompatible_tools.clone(), workspace, @@ -1287,24 +1282,12 @@ impl MessageEditor { return None; } - let user_store = self.user_store.read(cx); - - let ubb_enable = user_store - .usage_based_billing_enabled() - .map_or(false, |enabled| enabled); - - if ubb_enable { + let user_store = self.project.read(cx).user_store().read(cx); + if user_store.is_usage_based_billing_enabled() { return None; } - let plan = user_store - .current_plan() - .map(|plan| match plan { - Plan::Free => cloud_llm_client::Plan::ZedFree, - Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }) - .unwrap_or(cloud_llm_client::Plan::ZedFree); + let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); let usage = user_store.model_request_usage()?; @@ -1769,7 +1752,6 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); - let user_store = workspace.read(cx).app_state().user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1782,7 +1764,6 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), - user_store, context_store, None, thread_store.downgrade(), diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml index 9031e14e29..95a45b1a6f 100644 --- a/crates/ai_onboarding/Cargo.toml +++ b/crates/ai_onboarding/Cargo.toml @@ -16,10 +16,10 @@ default = [] [dependencies] client.workspace = true +cloud_llm_client.workspace = true component.workspace = true gpui.workspace = true language_model.workspace = true -proto.workspace = true serde.workspace = true smallvec.workspace = true telemetry.workspace = true diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index 5f56e4d26e..e86568fe7a 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -136,10 +136,7 @@ impl RenderOnce for ApiKeysWithoutProviders { .full_width() .style(ButtonStyle::Outlined) .on_click(move |_, window, cx| { - window.dispatch_action( - zed_actions::agent::OpenConfiguration.boxed_clone(), - cx, - ); + window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx); }), ) } diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index e8a62f7ff2..f1629eeff8 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use client::{Client, UserStore}; +use cloud_llm_client::Plan; use gpui::{Entity, IntoElement, ParentElement}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::prelude::*; @@ -56,15 +57,8 @@ impl AgentPanelOnboarding { impl Render for AgentPanelOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let enrolled_in_trial = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedProTrial) - ); - - let is_pro_user = matches!( - self.user_store.read(cx).current_plan(), - Some(proto::Plan::ZedPro) - ); + let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); + let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); AgentPanelOnboardingCard::new() .child( diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index 3aec9c62cd..c252b65f20 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -9,6 +9,7 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use ai_upsell_card::AiUpsellCard; +use cloud_llm_client::Plan; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use young_account_banner::YoungAccountBanner; @@ -79,7 +80,7 @@ impl From for SignInStatus { pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, pub has_accepted_terms_of_service: bool, - pub plan: Option, + pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, @@ -99,8 +100,8 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), - plan: store.current_plan(), + has_accepted_terms_of_service: store.has_accepted_terms_of_service(), + plan: store.plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, accept_terms_of_service: Arc::new({ @@ -113,11 +114,9 @@ impl ZedAiOnboarding { sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), dismiss_onboarding: None, } @@ -411,9 +410,9 @@ impl RenderOnce for ZedAiOnboarding { if matches!(self.sign_in_status, SignInStatus::SignedIn) { if self.has_accepted_terms_of_service { match self.plan { - None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), - Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), - Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), + None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), + Some(Plan::ZedProTrial) => self.render_trial_state(cx), + Some(Plan::ZedPro) => self.render_pro_plan_state(cx), } } else { self.render_accept_terms_of_service() @@ -433,7 +432,7 @@ impl Component for ZedAiOnboarding { fn onboarding( sign_in_status: SignInStatus, has_accepted_terms_of_service: bool, - plan: Option, + plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { @@ -468,25 +467,15 @@ impl Component for ZedAiOnboarding { ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), ), single_example( "Pro Trial", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedProTrial), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), ), single_example( "Pro Plan", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedPro), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), ), ]) .into_any_element(), diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs index 041e0d87ec..2408b6aa37 100644 --- a/crates/ai_onboarding/src/ai_upsell_card.rs +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use client::{Client, zed_urls}; +use cloud_llm_client::Plan; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use ui::{Divider, List, Vector, VectorName, prelude::*}; @@ -10,22 +11,22 @@ use crate::{BulletItem, SignInStatus}; pub struct AiUpsellCard { pub sign_in_status: SignInStatus, pub sign_in: Arc, + pub user_plan: Option, } impl AiUpsellCard { - pub fn new(client: Arc) -> Self { + pub fn new(client: Arc, user_plan: Option) -> Self { let status = *client.status().borrow(); Self { + user_plan, sign_in_status: status.into(), sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), } } @@ -34,6 +35,7 @@ impl AiUpsellCard { impl RenderOnce for AiUpsellCard { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { let pro_section = v_flex() + .flex_grow() .w_full() .gap_1() .child( @@ -56,6 +58,7 @@ impl RenderOnce for AiUpsellCard { ); let free_section = v_flex() + .flex_grow() .w_full() .gap_1() .child( @@ -71,7 +74,7 @@ impl RenderOnce for AiUpsellCard { ) .child( List::new() - .child(BulletItem::new("50 prompts with the Claude models")) + .child(BulletItem::new("50 prompts with Claude models")) .child(BulletItem::new("2,000 accepted edit predictions")), ); @@ -132,22 +135,28 @@ impl RenderOnce for AiUpsellCard { v_flex() .relative() - .p_6() - .pt_4() + .p_4() + .pt_3() .border_1() .border_color(cx.theme().colors().border) .rounded_lg() .overflow_hidden() .child(grid_bg) .child(gradient_bg) - .child(Headline::new("Try Zed AI")) - .child(Label::new(DESCRIPTION).color(Color::Muted).mb_2()) + .child(Label::new("Try Zed AI").size(LabelSize::Large)) + .child( + div() + .max_w_3_4() + .mb_2() + .child(Label::new(DESCRIPTION).color(Color::Muted)), + ) .child( h_flex() + .w_full() .mt_1p5() .mb_2p5() .items_start() - .gap_12() + .gap_6() .child(free_section) .child(pro_section), ) @@ -183,6 +192,7 @@ impl Component for AiUpsellCard { AiUpsellCard { sign_in_status: SignInStatus::SignedOut, sign_in: Arc::new(|_, _| {}), + user_plan: None, } .into_any_element(), ), @@ -191,6 +201,7 @@ impl Component for AiUpsellCard { AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), + user_plan: None, } .into_any_element(), ), diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index b7ba811421..4ad156b9fb 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -126,7 +126,7 @@ impl ChannelMembership { proto::channel_member::Kind::Member => 0, proto::channel_member::Kind::Invitee => 1, }, - username_order: self.user.github_login.as_str(), + username_order: &self.user.github_login, } } } diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index f8f5de3c39..c92226eeeb 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) { assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); }); - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - name: None, - }], - }, - ); - // Join a channel and populate its existing messages. let channel = channel_store.update(cx, |store, cx| { let channel_id = store.ordered_channels().next().unwrap().1.id; @@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "a".into()), + ("user-5".into(), "a".into()), ("maxbrunsfeld".into(), "b".into()) ] ); @@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "y".into()), + ("user-5".into(), "y".into()), ("maxbrunsfeld".into(), "z".into()) ] ); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index dd97bd9ca4..365625b445 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -17,11 +17,11 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup [dependencies] anyhow.workspace = true -async-recursion = "0.3" async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true +cloud_api_client.workspace = true cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index e0f4a70b15..b9b20aa4f2 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -6,22 +6,21 @@ pub mod telemetry; pub mod user; pub mod zed_urls; -use anyhow::{Context as _, Result, anyhow, bail}; -use async_recursion::async_recursion; +use anyhow::{Context as _, Result, anyhow}; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, http::{HeaderValue, Request, StatusCode}, }; -use chrono::{DateTime, Utc}; use clock::SystemClock; +use cloud_api_client::CloudApiClient; use credentials_provider::CredentialsProvider; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; -use http_client::{AsyncBody, HttpClient, HttpClientWithUrl, http}; +use http_client::{HttpClient, HttpClientWithUrl, http}; use parking_lot::RwLock; use postage::watch; use proxy::connect_proxy_stream; @@ -31,7 +30,6 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use std::pin::Pin; use std::{ any::TypeId, convert::TryFrom, @@ -45,6 +43,7 @@ use std::{ }, time::{Duration, Instant}, }; +use std::{cmp, pin::Pin}; use telemetry::Telemetry; use thiserror::Error; use tokio::net::TcpStream; @@ -78,7 +77,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock = LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty())); pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500); -pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10); +pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30); pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); actions!( @@ -162,20 +161,8 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn( - async move |cx| match client.authenticate_and_connect(true, &cx).await { - ConnectionResult::Timeout => { - log::error!("Initial authentication timed out"); - } - ConnectionResult::ConnectionReset => { - log::error!("Initial authentication connection reset"); - } - ConnectionResult::Result(r) => { - r.log_err(); - } - }, - ) - .detach(); + cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) + .detach_and_log_err(cx); } } }); @@ -213,6 +200,7 @@ pub struct Client { id: AtomicU64, peer: Arc, http: Arc, + cloud_client: Arc, telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, @@ -283,6 +271,8 @@ pub enum Status { SignedOut, UpgradeRequired, Authenticating, + Authenticated, + AuthenticationError, Connecting, ConnectionError, Connected { @@ -586,6 +576,7 @@ impl Client { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), + cloud_client: Arc::new(CloudApiClient::new(http.clone())), http, credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), @@ -618,6 +609,10 @@ impl Client { self.http.clone() } + pub fn cloud_client(&self) -> Arc { + self.cloud_client.clone() + } + pub fn set_id(&self, id: u64) -> &Self { self.id.store(id, Ordering::SeqCst); self @@ -704,7 +699,7 @@ impl Client { let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.authenticate_and_connect(true, &cx).await { + match client.connect(true, &cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -727,11 +722,10 @@ impl Client { }, &cx, ); - cx.background_executor().timer(delay).await; - delay = delay - .mul_f32(rng.gen_range(0.5..=2.5)) - .max(INITIAL_RECONNECTION_DELAY) - .min(MAX_RECONNECTION_DELAY); + let jitter = + Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64)); + cx.background_executor().timer(delay + jitter).await; + delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY); } else { break; } @@ -875,17 +869,122 @@ impl Client { .is_some() } - #[async_recursion(?Send)] - pub async fn authenticate_and_connect( + pub async fn sign_in( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result { + if self.status().borrow().is_signed_out() { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx); + } + + let mut credentials = None; + + let old_credentials = self.state.read().credentials.clone(); + if let Some(old_credentials) = old_credentials { + self.cloud_client.set_credentials( + old_credentials.user_id as u32, + old_credentials.access_token.clone(), + ); + + // Fetch the authenticated user with the old credentials, to ensure they are still valid. + if self.cloud_client.get_authenticated_user().await.is_ok() { + credentials = Some(old_credentials); + } + } + + if credentials.is_none() && try_provider { + if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { + self.cloud_client.set_credentials( + stored_credentials.user_id as u32, + stored_credentials.access_token.clone(), + ); + + // Fetch the authenticated user with the stored credentials, and + // clear them from the credentials provider if that fails. + if self.cloud_client.get_authenticated_user().await.is_ok() { + credentials = Some(stored_credentials); + } else { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); + } + } + } + + if credentials.is_none() { + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! { + authenticate = self.authenticate(cx).fuse() => { + match authenticate { + Ok(creds) => { + if IMPERSONATE_LOGIN.is_none() { + self.credentials_provider + .write_credentials(creds.user_id, creds.access_token.clone(), cx) + .await + .log_err(); + } + + credentials = Some(creds); + }, + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + return Err(err); + } + } + } + _ = status_rx.next().fuse() => { + return Err(anyhow!("authentication canceled")); + } + } + } + + let credentials = credentials.unwrap(); + self.set_id(credentials.user_id); + self.cloud_client + .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); + self.state.write().credentials = Some(credentials.clone()); + self.set_status(Status::Authenticated, cx); + + Ok(credentials) + } + + /// Performs a sign-in and also connects to Collab. + /// + /// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls + /// to `sign_in` when we're ready to remove auto-connection to Collab. + pub async fn sign_in_with_optional_connect( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result<()> { + let credentials = self.sign_in(try_provider, cx).await?; + + let connect_result = match self.connect_with_credentials(credentials, cx).await { + ConnectionResult::Timeout => Err(anyhow!("connection timed out")), + ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), + ConnectionResult::Result(result) => result.context("client auth and connect"), + }; + connect_result.log_err(); + + Ok(()) + } + + pub async fn connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> ConnectionResult<()> { let was_disconnected = match *self.status().borrow() { - Status::SignedOut => true, + Status::SignedOut | Status::Authenticated => true, Status::ConnectionError | Status::ConnectionLost | Status::Authenticating { .. } + | Status::AuthenticationError | Status::Reauthenticating { .. } | Status::ReconnectionError { .. } => false, Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { @@ -898,39 +997,10 @@ impl Client { ); } }; - if was_disconnected { - self.set_status(Status::Authenticating, cx); - } else { - self.set_status(Status::Reauthenticating, cx) - } - - let mut read_from_provider = false; - let mut credentials = self.state.read().credentials.clone(); - if credentials.is_none() && try_provider { - credentials = self.credentials_provider.read_credentials(cx).await; - read_from_provider = credentials.is_some(); - } - - if credentials.is_none() { - let mut status_rx = self.status(); - let _ = status_rx.next().await; - futures::select_biased! { - authenticate = self.authenticate(cx).fuse() => { - match authenticate { - Ok(creds) => credentials = Some(creds), - Err(err) => { - self.set_status(Status::ConnectionError, cx); - return ConnectionResult::Result(Err(err)); - } - } - } - _ = status_rx.next().fuse() => { - return ConnectionResult::Result(Err(anyhow!("authentication canceled"))); - } - } - } - let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); + let credentials = match self.sign_in(try_provider, cx).await { + Ok(credentials) => credentials, + Err(err) => return ConnectionResult::Result(Err(err)), + }; if was_disconnected { self.set_status(Status::Connecting, cx); @@ -938,17 +1008,20 @@ impl Client { self.set_status(Status::Reconnecting, cx); } + self.connect_with_credentials(credentials, cx).await + } + + async fn connect_with_credentials( + self: &Arc, + credentials: Credentials, + cx: &AsyncApp, + ) -> ConnectionResult<()> { let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { - self.state.write().credentials = Some(credentials.clone()); - if !read_from_provider && IMPERSONATE_LOGIN.is_none() { - self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); - } - futures::select_biased! { result = self.set_connection(conn, cx).fuse() => { match result.context("client auth and connect") { @@ -966,15 +1039,8 @@ impl Client { } } Err(EstablishConnectionError::Unauthorized) => { - self.state.write().credentials.take(); - if read_from_provider { - self.credentials_provider.delete_credentials(cx).await.log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(false, cx).await - } else { - self.set_status(Status::ConnectionError, cx); - ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) - } + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); @@ -1369,96 +1435,31 @@ impl Client { self: &Arc, http: Arc, login: String, - mut api_token: String, + api_token: String, ) -> Result { - #[derive(Deserialize)] - struct AuthenticatedUserResponse { - user: User, + #[derive(Serialize)] + struct ImpersonateUserBody { + github_login: String, } #[derive(Deserialize)] - struct User { - id: u64, + struct ImpersonateUserResponse { + user_id: u64, + access_token: String, } - let github_user = { - #[derive(Deserialize)] - struct GithubUser { - id: i32, - login: String, - created_at: DateTime, - } - - let request = { - let mut request_builder = - Request::get(&format!("https://api.github.com/users/{login}")); - if let Ok(github_token) = std::env::var("GITHUB_TOKEN") { - request_builder = - request_builder.header("Authorization", format!("Bearer {}", github_token)); - } - - request_builder.body(AsyncBody::empty())? - }; - - let mut response = http - .send(request) - .await - .context("error fetching GitHub user")?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading GitHub user")?; - - if !response.status().is_success() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - serde_json::from_slice::(body.as_slice()).map_err(|err| { - log::error!("Error deserializing: {:?}", err); - log::error!( - "GitHub API response text: {:?}", - String::from_utf8_lossy(body.as_slice()) - ); - anyhow!("error deserializing GitHub user") - })? - }; - - let query_params = [ - ("github_login", &github_user.login), - ("github_user_id", &github_user.id.to_string()), - ( - "github_user_created_at", - &github_user.created_at.to_rfc3339(), - ), - ]; - - // Use the collab server's admin API to retrieve the ID - // of the impersonated user. - let mut url = self.rpc_url(http.clone(), None).await?; - url.set_path("/user"); - url.set_query(Some( - &query_params - .iter() - .map(|(key, value)| { - format!( - "{}={}", - key, - url::form_urlencoded::byte_serialize(value.as_bytes()).collect::() - ) - }) - .collect::>() - .join("&"), - )); - let request: http_client::Request = Request::get(url.as_str()) - .header("Authorization", format!("token {api_token}")) - .body("".into())?; + let url = self + .http + .build_zed_cloud_url("/internal/users/impersonate", &[])?; + let request = Request::post(url.as_str()) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {api_token}")) + .body( + serde_json::to_string(&ImpersonateUserBody { + github_login: login, + })? + .into(), + )?; let mut response = http.send(request).await?; let mut body = String::new(); @@ -1469,18 +1470,17 @@ impl Client { response.status().as_u16(), body, ); - let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + let response: ImpersonateUserResponse = serde_json::from_str(&body)?; - // Use the admin API token to authenticate as the impersonated user. - api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { - user_id: response.user.id, - access_token: api_token, + user_id: response.user_id, + access_token: response.access_token, }) } pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; + self.cloud_client.clear_credentials(); self.disconnect(cx); if self.has_credentials(cx).await { @@ -1790,7 +1790,7 @@ mod tests { }); let auth_and_connect = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); @@ -1867,7 +1867,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - move |cx| async move { client.authenticate_and_connect(false, &cx).await } + move |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); @@ -1875,7 +1875,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 6ce79fa9c5..439fb100d2 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,8 +1,11 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; use chrono::Duration; +use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; +use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; use rpc::{ ConnectionId, Peer, Receipt, TypedEnvelope, @@ -39,6 +42,44 @@ impl FakeServer { executor: cx.executor(), }; + client.http_client().as_fake().replace_handler({ + let state = server.state.clone(); + move |old_handler, req| { + let state = state.clone(); + let old_handler = old_handler.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + client_user_id as i32, + format!("user-{client_user_id}"), + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => old_handler(req).await, + } + } + } + }); client .override_authenticate({ let state = Arc::downgrade(&server.state); @@ -105,7 +146,7 @@ impl FakeServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -223,3 +264,54 @@ impl Drop for FakeServer { self.disconnect(); } } + +pub fn parse_authorization_header(req: &Request) -> Option { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION)? + .to_str() + .ok()? + .split_whitespace(); + let user_id = auth_header.next()?.parse().ok()?; + let access_token = auth_header.next()?; + Some(Credentials { + user_id, + access_token: access_token.to_string(), + }) +} + +pub fn make_get_authenticated_user_response( + user_id: i32, + github_login: String, +) -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: user_id, + metrics_id: format!("metrics-id-{user_id}"), + avatar_url: "".to_string(), + github_login, + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } +} diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index a7dab2a8d3..3c125a0882 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,7 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; +use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{ EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, @@ -20,7 +21,7 @@ use std::{ sync::{Arc, Weak}, }; use text::ReplicaId; -use util::{TryFutureExt as _, maybe}; +use util::{ResultExt, TryFutureExt as _}; pub type UserId = u64; @@ -55,7 +56,7 @@ pub struct ParticipantIndex(pub u32); #[derive(Default, Debug)] pub struct User { pub id: UserId, - pub github_login: String, + pub github_login: SharedString, pub avatar_uri: SharedUri, pub name: Option, } @@ -107,19 +108,14 @@ pub enum ContactRequestStatus { pub struct UserStore { users: HashMap>, - by_github_login: HashMap, + by_github_login: HashMap, participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, - current_plan: Option, - subscription_period: Option<(DateTime, DateTime)>, - trial_started_at: Option>, model_request_usage: Option, edit_prediction_usage: Option, - is_usage_based_billing_enabled: Option, - account_too_young: Option, - has_overdue_invoices: Option, + plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>>, + accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -145,6 +141,7 @@ pub enum Event { ShowContacts, ParticipantIndicesChanged, PrivateUserInfoUpdated, + PlanUpdated, } #[derive(Clone, Copy)] @@ -188,14 +185,9 @@ impl UserStore { users: Default::default(), by_github_login: Default::default(), current_user: current_user_rx, - current_plan: None, - subscription_period: None, - trial_started_at: None, + plan_info: None, model_request_usage: None, edit_prediction_usage: None, - is_usage_based_billing_enabled: None, - account_too_young: None, - has_overdue_invoices: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -225,53 +217,30 @@ impl UserStore { return Ok(()); }; match status { - Status::Connected { .. } => { + Status::Authenticated | Status::Connected { .. } => { if let Some(user_id) = client.user_id() { - let fetch_user = if let Ok(fetch_user) = - this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) - { - fetch_user - } else { - break; - }; - let fetch_private_user_info = - client.request(proto::GetPrivateUserInfo {}).log_err(); - let (user, info) = - futures::join!(fetch_user, fetch_private_user_info); - + let response = client.cloud_client().get_authenticated_user().await; + let mut current_user = None; cx.update(|cx| { - if let Some(info) = info { - let staff = - info.staff && !*feature_flags::ZED_DISABLE_STAFF; - cx.update_flags(staff, info.flags); - client.telemetry.set_authenticated_user_info( - Some(info.metrics_id.clone()), - staff, - ); - + if let Some(response) = response.log_err() { + let user = Arc::new(User { + id: user_id, + github_login: response.user.github_login.clone().into(), + avatar_uri: response.user.avatar_url.clone().into(), + name: response.user.name.clone(), + }); + current_user = Some(user.clone()); this.update(cx, |this, cx| { - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() - { - None - } else { - info.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - info.accepted_tos_at - }; - - this.set_current_user_accepted_tos_at(accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); + this.by_github_login + .insert(user.github_login.clone(), user_id); + this.users.insert(user_id, user); + this.update_authenticated_user(response, cx) }) } else { anyhow::Ok(()) } })??; - - current_user_tx.send(user).await.ok(); + current_user_tx.send(current_user).await.ok(); this.update(cx, |_, cx| cx.notify())?; } @@ -352,59 +321,22 @@ impl UserStore { async fn handle_update_plan( this: Entity, - message: TypedEnvelope, + _message: TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { + let client = this + .read_with(&cx, |this, _| this.client.upgrade())? + .context("client was dropped")?; + + let response = client + .cloud_client() + .get_authenticated_user() + .await + .context("failed to fetch authenticated user")?; + this.update(&mut cx, |this, cx| { - this.current_plan = Some(message.payload.plan()); - this.subscription_period = maybe!({ - let period = message.payload.subscription_period?; - let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?; - let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?; - - Some((started_at, ended_at)) - }); - this.trial_started_at = message - .payload - .trial_started_at - .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); - this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; - this.account_too_young = message.payload.account_too_young; - this.has_overdue_invoices = message.payload.has_overdue_invoices; - - if let Some(usage) = message.payload.usage { - // limits are always present even though they are wrapped in Option - this.model_request_usage = usage - .model_requests_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(ModelRequestUsage); - this.edit_prediction_usage = usage - .edit_predictions_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(EditPredictionUsage); - } - - cx.notify(); - })?; - Ok(()) - } - - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); + this.update_authenticated_user(response, cx); + }) } fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { @@ -763,59 +695,131 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn current_plan(&self) -> Option { + pub fn plan(&self) -> Option { #[cfg(debug_assertions)] if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { return match plan.as_str() { - "free" => Some(proto::Plan::Free), - "trial" => Some(proto::Plan::ZedProTrial), - "pro" => Some(proto::Plan::ZedPro), + "free" => Some(cloud_llm_client::Plan::ZedFree), + "trial" => Some(cloud_llm_client::Plan::ZedProTrial), + "pro" => Some(cloud_llm_client::Plan::ZedPro), _ => { panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); } }; } - self.current_plan + self.plan_info.as_ref().map(|info| info.plan) } pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { - self.subscription_period + self.plan_info + .as_ref() + .and_then(|plan| plan.subscription_period) + .map(|subscription_period| { + ( + subscription_period.started_at.0, + subscription_period.ended_at.0, + ) + }) } pub fn trial_started_at(&self) -> Option> { - self.trial_started_at + self.plan_info + .as_ref() + .and_then(|plan| plan.trial_started_at) + .map(|trial_started_at| trial_started_at.0) } - pub fn usage_based_billing_enabled(&self) -> Option { - self.is_usage_based_billing_enabled + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn is_usage_based_billing_enabled(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_usage_based_billing_enabled) + .unwrap_or_default() } pub fn model_request_usage(&self) -> Option { self.model_request_usage } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + pub fn edit_prediction_usage(&self) -> Option { self.edit_prediction_usage } + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + + fn update_authenticated_user( + &mut self, + response: GetAuthenticatedUserResponse, + cx: &mut Context, + ) { + let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF; + cx.update_flags(staff, response.feature_flags); + if let Some(client) = self.client.upgrade() { + client + .telemetry + .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); + } + + let accepted_tos_at = { + #[cfg(debug_assertions)] + if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { + None + } else { + response.user.accepted_tos_at + } + + #[cfg(not(debug_assertions))] + response.user.accepted_tos_at + }; + + self.accepted_tos_at = Some(accepted_tos_at); + self.model_request_usage = Some(ModelRequestUsage(RequestUsage { + limit: response.plan.usage.model_requests.limit, + amount: response.plan.usage.model_requests.used as i32, + })); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); + self.plan_info = Some(response.plan); + cx.emit(Event::PrivateUserInfoUpdated); + } + pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.account_too_young.unwrap_or(false) - } - - /// Returns whether the current user has overdue invoices and usage should be blocked. - pub fn has_overdue_invoices(&self) -> bool { - self.has_overdue_invoices.unwrap_or(false) - } - - pub fn current_user_has_accepted_terms(&self) -> Option { + pub fn has_accepted_terms_of_service(&self) -> bool { self.accepted_tos_at - .map(|accepted_tos_at| accepted_tos_at.is_some()) + .map_or(false, |accepted_tos_at| accepted_tos_at.is_some()) } pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { @@ -827,23 +831,18 @@ impl UserStore { cx.spawn(async move |this, cx| -> anyhow::Result<()> { let client = client.upgrade().context("client not found")?; let response = client - .request(proto::AcceptTermsOfService {}) + .cloud_client() + .accept_terms_of_service() .await .context("error accepting tos")?; this.update(cx, |this, cx| { - this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); + this.accepted_tos_at = Some(response.user.accepted_tos_at); cx.emit(Event::PrivateUserInfoUpdated); })?; Ok(()) }) } - fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { - self.accepted_tos_at = Some( - accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), - ); - } - fn load_users( &self, request: impl RequestMessage, @@ -902,7 +901,7 @@ impl UserStore { let mut missing_user_ids = Vec::new(); for id in user_ids { if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) { - ret.insert(id, github_login.into()); + ret.insert(id, github_login); } else { missing_user_ids.push(id) } @@ -923,7 +922,7 @@ impl User { fn new(message: proto::User) -> Arc { Arc::new(User { id: message.id, - github_login: message.github_login, + github_login: message.github_login.into(), avatar_uri: message.avatar_url.into(), name: message.name, }) diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml new file mode 100644 index 0000000000..d56aa94c6e --- /dev/null +++ b/crates/cloud_api_client/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "cloud_api_client" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_client.rs" + +[dependencies] +anyhow.workspace = true +cloud_api_types.workspace = true +futures.workspace = true +http_client.workspace = true +parking_lot.workspace = true +serde_json.workspace = true +workspace-hack.workspace = true diff --git a/crates/cloud_api_client/LICENSE-APACHE b/crates/cloud_api_client/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_api_client/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs new file mode 100644 index 0000000000..6689475dae --- /dev/null +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -0,0 +1,155 @@ +use std::sync::Arc; + +use anyhow::{Result, anyhow}; +pub use cloud_api_types::*; +use futures::AsyncReadExt as _; +use http_client::http::request; +use http_client::{AsyncBody, HttpClientWithUrl, Method, Request}; +use parking_lot::RwLock; + +struct Credentials { + user_id: u32, + access_token: String, +} + +pub struct CloudApiClient { + credentials: RwLock>, + http_client: Arc, +} + +impl CloudApiClient { + pub fn new(http_client: Arc) -> Self { + Self { + credentials: RwLock::new(None), + http_client, + } + } + + pub fn has_credentials(&self) -> bool { + self.credentials.read().is_some() + } + + pub fn set_credentials(&self, user_id: u32, access_token: String) { + *self.credentials.write() = Some(Credentials { + user_id, + access_token, + }); + } + + pub fn clear_credentials(&self) { + *self.credentials.write() = None; + } + + fn authorization_header(&self) -> Result { + let guard = self.credentials.read(); + let credentials = guard + .as_ref() + .ok_or_else(|| anyhow!("No credentials provided"))?; + + Ok(format!( + "{} {}", + credentials.user_id, credentials.access_token + )) + } + + fn build_request( + &self, + req: request::Builder, + body: impl Into, + ) -> Result> { + Ok(req + .header("Content-Type", "application/json") + .header("Authorization", self.authorization_header()?) + .body(body.into())?) + } + + pub async fn get_authenticated_user(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn accept_terms_of_service(&self) -> Result { + let request = self.build_request( + Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/terms_of_service/accept", &[])? + .as_ref(), + ), + AsyncBody::default(), + )?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } + + pub async fn create_llm_token( + &self, + system_id: Option, + ) -> Result { + let mut request_builder = Request::builder().method(Method::POST).uri( + self.http_client + .build_zed_cloud_url("/client/llm_tokens", &[])? + .as_ref(), + ); + + if let Some(system_id) = system_id { + request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id); + } + + let request = self.build_request(request_builder, AsyncBody::default())?; + + let mut response = self.http_client.send(request).await?; + + if !response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::bail!( + "Failed to create LLM token.\nStatus: {:?}\nBody: {body}", + response.status() + ) + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(serde_json::from_str(&body)?) + } +} diff --git a/crates/cloud_api_types/Cargo.toml b/crates/cloud_api_types/Cargo.toml new file mode 100644 index 0000000000..868797df3b --- /dev/null +++ b/crates/cloud_api_types/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "cloud_api_types" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "Apache-2.0" + +[lints] +workspace = true + +[lib] +path = "src/cloud_api_types.rs" + +[dependencies] +chrono.workspace = true +cloud_llm_client.workspace = true +serde.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true +serde_json.workspace = true diff --git a/crates/cloud_api_types/LICENSE-APACHE b/crates/cloud_api_types/LICENSE-APACHE new file mode 120000 index 0000000000..1cd601d0a3 --- /dev/null +++ b/crates/cloud_api_types/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cloud_api_types/src/cloud_api_types.rs b/crates/cloud_api_types/src/cloud_api_types.rs new file mode 100644 index 0000000000..b38b38cde1 --- /dev/null +++ b/crates/cloud_api_types/src/cloud_api_types.rs @@ -0,0 +1,55 @@ +mod timestamp; + +use serde::{Deserialize, Serialize}; + +pub use crate::timestamp::Timestamp; + +pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id"; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct GetAuthenticatedUserResponse { + pub user: AuthenticatedUser, + pub feature_flags: Vec, + pub plan: PlanInfo, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AuthenticatedUser { + pub id: i32, + pub metrics_id: String, + pub avatar_url: String, + pub github_login: String, + pub name: Option, + pub is_staff: bool, + pub accepted_tos_at: Option, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct PlanInfo { + pub plan: cloud_llm_client::Plan, + pub subscription_period: Option, + pub usage: cloud_llm_client::CurrentUsage, + pub trial_started_at: Option, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +pub struct SubscriptionPeriod { + pub started_at: Timestamp, + pub ended_at: Timestamp, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct AcceptTermsOfServiceResponse { + pub user: AuthenticatedUser, +} + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct LlmToken(pub String); + +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +pub struct CreateLlmTokenResponse { + pub token: LlmToken, +} diff --git a/crates/cloud_api_types/src/timestamp.rs b/crates/cloud_api_types/src/timestamp.rs new file mode 100644 index 0000000000..1f055d58ef --- /dev/null +++ b/crates/cloud_api_types/src/timestamp.rs @@ -0,0 +1,166 @@ +use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A timestamp with a serialized representation in RFC 3339 format. +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct Timestamp(pub DateTime); + +impl Timestamp { + pub fn new(datetime: DateTime) -> Self { + Self(datetime) + } +} + +impl From> for Timestamp { + fn from(value: DateTime) -> Self { + Self(value) + } +} + +impl From for Timestamp { + fn from(value: NaiveDateTime) -> Self { + Self(value.and_utc()) + } +} + +impl Serialize for Timestamp { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true); + serializer.serialize_str(&rfc3339_string) + } +} + +impl<'de> Deserialize<'de> for Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = String::deserialize(deserializer)?; + let datetime = DateTime::parse_from_rfc3339(&value) + .map_err(serde::de::Error::custom)? + .to_utc(); + Ok(Self(datetime)) + } +} + +#[cfg(test)] +mod tests { + use chrono::NaiveDate; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_timestamp_serialization() { + let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + let timestamp = Timestamp::new(datetime); + + let json = serde_json::to_string(×tamp).unwrap(); + assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); + } + + #[test] + fn test_timestamp_deserialization() { + let json = "\"2023-12-25T14:30:45.123Z\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_roundtrip() { + let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + let timestamp = Timestamp::new(original); + let json = serde_json::to_string(×tamp).unwrap(); + let deserialized: Timestamp = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.0, original); + } + + #[test] + fn test_timestamp_from_datetime_utc() { + let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z") + .unwrap() + .to_utc(); + + let timestamp = Timestamp::from(datetime); + assert_eq!(timestamp.0, datetime); + } + + #[test] + fn test_timestamp_from_naive_datetime() { + let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(14, 30, 45, 123) + .unwrap(); + + let timestamp = Timestamp::from(naive_dt); + let expected = naive_dt.and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_serialization_with_microseconds() { + // Test that microseconds are truncated to milliseconds + let datetime = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_micro_opt(14, 30, 45, 123456) + .unwrap() + .and_utc(); + + let timestamp = Timestamp::new(datetime); + let json = serde_json::to_string(×tamp).unwrap(); + + // Should be truncated to milliseconds + assert_eq!(json, "\"2023-12-25T14:30:45.123Z\""); + } + + #[test] + fn test_timestamp_deserialization_without_milliseconds() { + let json = "\"2023-12-25T14:30:45Z\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_opt(14, 30, 45) + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_timezone() { + let json = "\"2023-12-25T14:30:45.123+05:30\""; + let timestamp: Timestamp = serde_json::from_str(json).unwrap(); + + // Should be converted to UTC + let expected = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30 + .unwrap() + .and_utc(); + + assert_eq!(timestamp.0, expected); + } + + #[test] + fn test_timestamp_deserialization_with_invalid_format() { + let json = "\"invalid-date\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } +} diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 2488088a49..171c923154 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -308,13 +308,13 @@ pub struct GetSubscriptionResponse { pub usage: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct CurrentUsage { pub model_requests: UsageData, pub edit_predictions: UsageData, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct UsageData { pub used: u32, pub limit: UsageLimit, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 5c35394e1d..e648617fe1 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -42,7 +42,7 @@ use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use reqwest_client::ReqwestClient; -use rpc::proto::split_repository_update; +use rpc::proto::{MultiLspQuery, split_repository_update}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ @@ -374,7 +374,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) - .add_request_handler(forward_mutating_project_request::) + .add_request_handler(multi_lsp_query) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) @@ -838,7 +838,7 @@ impl Server { // This arrangement ensures we will attempt to process earlier messages first, but fall // back to processing messages arrived later in the spirit of making progress. let mut foreground_message_handlers = FuturesUnordered::new(); - let concurrent_handlers = Arc::new(Semaphore::new(512)); + let concurrent_handlers = Arc::new(Semaphore::new(256)); loop { let next_message = async { let permit = concurrent_handlers.clone().acquire_owned().await.unwrap(); @@ -865,6 +865,7 @@ impl Server { user_id=field::Empty, login=field::Empty, impersonator=field::Empty, + multi_lsp_query_request=field::Empty, ); principal.update_span(&span); let span_enter = span.enter(); @@ -2329,6 +2330,15 @@ where Ok(()) } +async fn multi_lsp_query( + request: MultiLspQuery, + response: Response, + session: Session, +) -> Result<()> { + tracing::Span::current().record("multi_lsp_query_request", request.request_str()); + forward_mutating_project_request(request, response, session).await +} + /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 19e410de5b..8d5d076780 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -38,12 +38,12 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic let mut remote = room .remote_participants() .values() - .map(|participant| participant.user.github_login.clone()) + .map(|participant| participant.user.github_login.clone().to_string()) .collect::>(); let mut pending = room .pending_participants() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect::>(); remote.sort(); pending.sort(); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 9795c27574..aea359d75b 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room( // Allow user A to reconnect to the server. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); // Call user B again from client A. active_call_a @@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections( client_b1.disconnect(&cx_b1.to_async()); executor.advance_clock(RECEIVE_TIMEOUT); client_b1 - .authenticate_and_connect(false, &cx_b1.to_async()) + .connect(false, &cx_b1.to_async()) .await .into_response() .unwrap(); @@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections( // User A reconnects automatically, then calls user B again. server.allow_connections(); - executor.advance_clock(RECEIVE_TIMEOUT); + executor.advance_clock(RECONNECT_TIMEOUT); active_call_a .update(cx_a, |call, cx| { call.invite(client_b1.user_id().unwrap(), None, cx) @@ -1667,7 +1667,7 @@ async fn test_project_reconnect( // Client A reconnects. Their project is re-shared, and client B re-joins it. server.allow_connections(); client_a - .authenticate_and_connect(false, &cx_a.to_async()) + .connect(false, &cx_a.to_async()) .await .into_response() .unwrap(); @@ -1796,7 +1796,7 @@ async fn test_project_reconnect( // Client B reconnects. They re-join the room and the remaining shared project. server.allow_connections(); client_b - .authenticate_and_connect(false, &cx_b.to_async()) + .connect(false, &cx_b.to_async()) .await .into_response() .unwrap(); @@ -1881,7 +1881,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_a.user_id().unwrap(), - github_login: "user_a".to_string(), + github_login: "user_a".into(), avatar_uri: "avatar_a".into(), name: None, }), @@ -1900,7 +1900,7 @@ async fn test_active_call_events( vec![room::Event::RemoteProjectShared { owner: Arc::new(User { id: client_b.user_id().unwrap(), - github_login: "user_b".to_string(), + github_login: "user_b".into(), avatar_uri: "avatar_b".into(), name: None, }), @@ -5738,7 +5738,7 @@ async fn test_contacts( server.allow_connections(); client_c - .authenticate_and_connect(false, &cx_c.to_async()) + .connect(false, &cx_c.to_async()) .await .into_response() .unwrap(); @@ -6079,7 +6079,7 @@ async fn test_contacts( .iter() .map(|contact| { ( - contact.user.github_login.clone(), + contact.user.github_login.clone().to_string(), if contact.online { "online" } else { "offline" }, if contact.busy { "busy" } else { "free" }, ) @@ -6269,7 +6269,7 @@ async fn test_contact_requests( client.disconnect(&cx.to_async()); client.clear_contacts(cx).await; client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs index 4e64b5526b..9bf906694e 100644 --- a/crates/collab/src/tests/notification_tests.rs +++ b/crates/collab/src/tests/notification_tests.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use gpui::{BackgroundExecutor, TestAppContext}; use notifications::NotificationEvent; use parking_lot::Mutex; +use pretty_assertions::assert_eq; use rpc::{Notification, proto}; use crate::tests::TestServer; @@ -17,6 +18,9 @@ async fn test_notifications( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + // Wait for authentication/connection to Collab to be established. + executor.run_until_parked(); + let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new())); client_a.notification_store().update(cx_a, |_, cx| { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 5192db16a7..5fcc622fc1 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -8,6 +8,7 @@ use crate::{ use anyhow::anyhow; use call::ActiveCall; use channel::{ChannelBuffer, ChannelStore}; +use client::test::{make_get_authenticated_user_response, parse_authorization_header}; use client::{ self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, proto::PeerId, @@ -20,7 +21,7 @@ use fs::FakeFs; use futures::{StreamExt as _, channel::oneshot}; use git::GitHostingProviderRegistry; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; -use http_client::FakeHttpClient; +use http_client::{FakeHttpClient, Method}; use language::LanguageRegistry; use node_runtime::NodeRuntime; use notifications::NotificationStore; @@ -161,6 +162,8 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + const ACCESS_TOKEN: &str = "the-token"; + let fs = FakeFs::new(cx.executor()); cx.update(|cx| { @@ -175,7 +178,7 @@ impl TestServer { }); let clock = Arc::new(FakeSystemClock::new()); - let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id @@ -197,6 +200,47 @@ impl TestServer { .expect("creating user failed") .user_id }; + + let http = FakeHttpClient::create({ + let name = name.to_string(); + move |req| { + let name = name.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: user_id.to_proto(), + access_token: ACCESS_TOKEN.into(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + user_id.0, name, + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + let client_name = name.to_string(); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); @@ -208,11 +252,10 @@ impl TestServer { .unwrap() .set_id(user_id.to_proto()) .override_authenticate(move |cx| { - let access_token = "the-token".to_string(); cx.spawn(async move |_| { Ok(Credentials { user_id: user_id.to_proto(), - access_token, + access_token: ACCESS_TOKEN.into(), }) }) }) @@ -221,7 +264,7 @@ impl TestServer { credentials, &Credentials { user_id: user_id.0 as u64, - access_token: "the-token".into() + access_token: ACCESS_TOKEN.into(), } ); @@ -319,7 +362,7 @@ impl TestServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -692,17 +735,17 @@ impl TestClient { current: store .contacts() .iter() - .map(|contact| contact.user.github_login.clone()) + .map(|contact| contact.user.github_login.clone().to_string()) .collect(), outgoing_requests: store .outgoing_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), incoming_requests: store .incoming_contact_requests() .iter() - .map(|user| user.github_login.clone()) + .map(|user| user.github_login.clone().to_string()) .collect(), }) } diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 4d5973481e..54077303a1 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -940,7 +940,7 @@ impl CollabPanel { room.read(cx).local_participant().role == proto::ChannelRole::Admin }); - ListItem::new(SharedString::from(user.github_login.clone())) + ListItem::new(user.github_login.clone()) .start_slot(Avatar::new(user.avatar_uri.clone())) .child(Label::new(user.github_login.clone())) .toggle_state(is_selected) @@ -2331,7 +2331,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .authenticate_and_connect(true, &cx) + .connect(true, &cx) .await .into_response() .notify_async_err(cx); @@ -2583,7 +2583,7 @@ impl CollabPanel { ) -> impl IntoElement { let online = contact.online; let busy = contact.busy || calling; - let github_login = SharedString::from(contact.user.github_login.clone()); + let github_login = contact.user.github_login.clone(); let item = ListItem::new(github_login.clone()) .indent_level(1) .indent_step_size(px(20.)) @@ -2662,7 +2662,7 @@ impl CollabPanel { is_selected: bool, cx: &mut Context, ) -> impl IntoElement { - let github_login = SharedString::from(user.github_login.clone()); + let github_login = user.github_login.clone(); let user_id = user.id; let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user); let color = if is_response_pending { diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index fba8f66c2d..c3e834b645 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -634,13 +634,13 @@ impl Render for NotificationPanel { .child(Icon::new(IconName::Envelope)), ) .map(|this| { - if self.client.user_id().is_none() { + if !self.client.status().borrow().is_connected() { this.child( v_flex() .gap_2() .p_4() .child( - Button::new("sign_in_prompt_button", "Sign in") + Button::new("connect_prompt_button", "Connect") .icon_color(Color::Muted) .icon(IconName::Github) .icon_position(IconPosition::Start) @@ -652,10 +652,7 @@ impl Render for NotificationPanel { let client = client.clone(); window .spawn(cx, async move |cx| { - match client - .authenticate_and_connect(true, &cx) - .await - { + match client.connect(true, &cx).await { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -673,7 +670,7 @@ impl Render for NotificationPanel { ) .child( div().flex().w_full().items_center().child( - Label::new("Sign in to view notifications.") + Label::new("Connect to view notifications.") .color(Color::Muted) .size(LabelSize::Small), ), diff --git a/crates/dap/src/client.rs b/crates/dap/src/client.rs index 86a15b2d8a..7b791450ec 100644 --- a/crates/dap/src/client.rs +++ b/crates/dap/src/client.rs @@ -295,7 +295,7 @@ mod tests { request: dap_types::StartDebuggingRequestArgumentsRequest::Launch, }, }, - Box::new(|_| panic!("Did not expect to hit this code path")), + Box::new(|_| {}), &mut cx.to_async(), ) .await diff --git a/crates/dap/src/transport.rs b/crates/dap/src/transport.rs index 6dadf1cf35..f9fbbfc842 100644 --- a/crates/dap/src/transport.rs +++ b/crates/dap/src/transport.rs @@ -883,6 +883,7 @@ impl FakeTransport { break Err(anyhow!("exit in response to request")); } }; + let success = response.success; let message = serde_json::to_string(&Message::Response(response)).unwrap(); @@ -893,6 +894,25 @@ impl FakeTransport { ) .await .unwrap(); + + if request.command == dap_types::requests::Initialize::COMMAND + && success + { + let message = serde_json::to_string(&Message::Event(Box::new( + dap_types::messages::Events::Initialized(Some( + Default::default(), + )), + ))) + .unwrap(); + writer + .write_all( + TransportDelegate::build_rpc_message(message) + .as_bytes(), + ) + .await + .unwrap(); + } + writer.flush().await.unwrap(); } } diff --git a/crates/docs_preprocessor/Cargo.toml b/crates/docs_preprocessor/Cargo.toml index a9eff17fa1..e46ceb18db 100644 --- a/crates/docs_preprocessor/Cargo.toml +++ b/crates/docs_preprocessor/Cargo.toml @@ -9,7 +9,9 @@ license = "GPL-3.0-or-later" anyhow.workspace = true command_palette.workspace = true gpui.workspace = true -mdbook = "0.4.40" +# We are specifically pinning this version of mdbook, as later versions introduce issues with double-nested subdirectories. +# Ask @maxdeviant about this before bumping. +mdbook = "= 0.4.40" regex.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 0692c7fbe6..ab2d1c8ecb 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -22,6 +22,7 @@ test-support = [ "theme/test-support", "util/test-support", "workspace/test-support", + "tree-sitter-c", "tree-sitter-rust", "tree-sitter-typescript", "tree-sitter-html", @@ -76,6 +77,7 @@ telemetry.workspace = true text.workspace = true time.workspace = true theme.workspace = true +tree-sitter-c = { workspace = true, optional = true } tree-sitter-html = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } @@ -106,6 +108,7 @@ settings = { workspace = true, features = ["test-support"] } tempfile.workspace = true text = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } +tree-sitter-c.workspace = true tree-sitter-html.workspace = true tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index a2f2310144..3516eff45c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -56,7 +56,7 @@ use aho_corasick::AhoCorasick; use anyhow::{Context as _, Result, anyhow}; use blink_manager::BlinkManager; use buffer_diff::DiffHunkStatus; -use client::{Collaborator, ParticipantIndex}; +use client::{Collaborator, DisableAiSettings, ParticipantIndex}; use clock::{AGENT_REPLICA_ID, ReplicaId}; use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; @@ -1305,6 +1305,7 @@ impl Default for SelectionHistoryMode { /// /// Similarly, you might want to disable scrolling if you don't want the viewport to /// move. +#[derive(Clone)] pub struct SelectionEffects { nav_history: Option, completions: bool, @@ -2944,10 +2945,12 @@ impl Editor { } } + let selection_anchors = self.selections.disjoint_anchors(); + if self.focus_handle.is_focused(window) && self.leader_id.is_none() { self.buffer.update(cx, |buffer, cx| { buffer.set_active_selections( - &self.selections.disjoint_anchors(), + &selection_anchors, self.selections.line_mode, self.cursor_shape, cx, @@ -2964,9 +2967,8 @@ impl Editor { self.select_next_state = None; self.select_prev_state = None; self.select_syntax_node_history.try_clear(); - self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), buffer); - self.snippet_stack - .invalidate(&self.selections.disjoint_anchors(), buffer); + self.invalidate_autoclose_regions(&selection_anchors, buffer); + self.snippet_stack.invalidate(&selection_anchors, buffer); self.take_rename(false, window, cx); let newest_selection = self.selections.newest_anchor(); @@ -4047,7 +4049,8 @@ impl Editor { // then don't insert that closing bracket again; just move the selection // past the closing bracket. let should_skip = selection.end == region.range.end.to_point(&snapshot) - && text.as_ref() == region.pair.end.as_str(); + && text.as_ref() == region.pair.end.as_str() + && snapshot.contains_str_at(region.range.end, text.as_ref()); if should_skip { let anchor = snapshot.anchor_after(selection.end); new_selections @@ -4973,13 +4976,17 @@ impl Editor { }) } - /// Remove any autoclose regions that no longer contain their selection. + /// Remove any autoclose regions that no longer contain their selection or have invalid anchors in ranges. fn invalidate_autoclose_regions( &mut self, mut selections: &[Selection], buffer: &MultiBufferSnapshot, ) { self.autoclose_regions.retain(|state| { + if !state.range.start.is_valid(buffer) || !state.range.end.is_valid(buffer) { + return false; + } + let mut i = 0; while let Some(selection) = selections.get(i) { if selection.end.cmp(&state.range.start, buffer).is_lt() { @@ -5891,18 +5898,20 @@ impl Editor { text: new_text[common_prefix_len..].into(), }); - self.transact(window, cx, |this, window, cx| { + self.transact(window, cx, |editor, window, cx| { if let Some(mut snippet) = snippet { snippet.text = new_text.to_string(); - this.insert_snippet(&ranges, snippet, window, cx).log_err(); + editor + .insert_snippet(&ranges, snippet, window, cx) + .log_err(); } else { - this.buffer.update(cx, |buffer, cx| { + editor.buffer.update(cx, |multi_buffer, cx| { let auto_indent = match completion.insert_text_mode { Some(InsertTextMode::AS_IS) => None, - _ => this.autoindent_mode.clone(), + _ => editor.autoindent_mode.clone(), }; let edits = ranges.into_iter().map(|range| (range, new_text.as_str())); - buffer.edit(edits, auto_indent, cx); + multi_buffer.edit(edits, auto_indent, cx); }); } for (buffer, edits) in linked_edits { @@ -5921,8 +5930,9 @@ impl Editor { }) } - this.refresh_inline_completion(true, false, window, cx); + editor.refresh_inline_completion(true, false, window, cx); }); + self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), &snapshot); let show_new_completions_on_confirm = completion .confirm @@ -7048,7 +7058,7 @@ impl Editor { } pub fn update_edit_prediction_settings(&mut self, cx: &mut Context) { - if self.edit_prediction_provider.is_none() { + if self.edit_prediction_provider.is_none() || DisableAiSettings::get_global(cx).disable_ai { self.edit_prediction_settings = EditPredictionSettings::Disabled; } else { let selection = self.selections.newest_anchor(); @@ -9562,27 +9572,46 @@ impl Editor { // Check whether the just-entered snippet ends with an auto-closable bracket. if self.autoclose_regions.is_empty() { let snapshot = self.buffer.read(cx).snapshot(cx); - for selection in &mut self.selections.all::(cx) { + let mut all_selections = self.selections.all::(cx); + for selection in &mut all_selections { let selection_head = selection.head(); let Some(scope) = snapshot.language_scope_at(selection_head) else { continue; }; let mut bracket_pair = None; - let next_chars = snapshot.chars_at(selection_head).collect::(); - let prev_chars = snapshot - .reversed_chars_at(selection_head) - .collect::(); - for (pair, enabled) in scope.brackets() { - if enabled - && pair.close - && prev_chars.starts_with(pair.start.as_str()) - && next_chars.starts_with(pair.end.as_str()) - { - bracket_pair = Some(pair.clone()); - break; + let max_lookup_length = scope + .brackets() + .map(|(pair, _)| { + pair.start + .as_str() + .chars() + .count() + .max(pair.end.as_str().chars().count()) + }) + .max(); + if let Some(max_lookup_length) = max_lookup_length { + let next_text = snapshot + .chars_at(selection_head) + .take(max_lookup_length) + .collect::(); + let prev_text = snapshot + .reversed_chars_at(selection_head) + .take(max_lookup_length) + .collect::(); + + for (pair, enabled) in scope.brackets() { + if enabled + && pair.close + && prev_text.starts_with(pair.start.as_str()) + && next_text.starts_with(pair.end.as_str()) + { + bracket_pair = Some(pair.clone()); + break; + } } } + if let Some(pair) = bracket_pair { let snapshot_settings = snapshot.language_settings_at(selection_head, cx); let autoclose_enabled = diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index a13708c580..1a4f444275 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -8612,6 +8612,7 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) { cx.language_registry().add(html_language.clone()); cx.language_registry().add(javascript_language.clone()); + cx.executor().run_until_parked(); cx.update_buffer(|buffer, cx| { buffer.set_language(Some(html_language), cx); @@ -13400,6 +13401,178 @@ async fn test_as_is_completions(cx: &mut TestAppContext) { cx.assert_editor_state("fn a() {}\n unsafeˇ"); } +#[gpui::test] +async fn test_panic_during_c_completions(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let language = + Arc::try_unwrap(languages::language("c", tree_sitter_c::LANGUAGE.into())).unwrap(); + let mut cx = EditorLspTestContext::new( + language, + lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + cx.set_state( + "#ifndef BAR_H +#define BAR_H + +#include + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +ˇ", + ); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("#", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("i", window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.handle_input("n", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#inˇ", + ); + + cx.lsp + .set_request_handler::(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: false, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::SNIPPET), + label_details: Some(lsp::CompletionItemLabelDetails { + detail: Some("header".to_string()), + description: None, + }), + label: " include".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 1, + }, + end: lsp::Position { + line: 8, + character: 1, + }, + }, + new_text: "include \"$0\"".to_string(), + })), + sort_text: Some("40b67681include".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), + filter_text: Some("include".to_string()), + insert_text: Some("include \"$0\"".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + "#ifndef BAR_H +#define BAR_H + +#include + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include \"ˇ\"", + ); + + cx.lsp + .set_request_handler::(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList { + is_incomplete: true, + item_defaults: None, + items: vec![lsp::CompletionItem { + kind: Some(lsp::CompletionItemKind::FILE), + label: "AGL/".to_string(), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range { + start: lsp::Position { + line: 8, + character: 10, + }, + end: lsp::Position { + line: 8, + character: 11, + }, + }, + new_text: "AGL/".to_string(), + })), + sort_text: Some("40b67681AGL/".to_string()), + insert_text_format: Some(lsp::InsertTextFormat::PLAIN_TEXT), + filter_text: Some("AGL/".to_string()), + insert_text: Some("AGL/".to_string()), + ..lsp::CompletionItem::default() + }], + }))) + }); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + cx.executor().run_until_parked(); + cx.update_editor(|editor, window, cx| { + editor.confirm_completion(&ConfirmCompletion::default(), window, cx) + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/ˇ"##, + ); + + cx.update_editor(|editor, window, cx| { + editor.handle_input("\"", window, cx); + }); + cx.executor().run_until_parked(); + cx.assert_editor_state( + r##"#ifndef BAR_H +#define BAR_H + +#include + +int fn_branch(bool do_branch1, bool do_branch2); + +#endif // BAR_H +#include "AGL/"ˇ"##, + ); +} + #[gpui::test] async fn test_no_duplicated_completion_requests(cx: &mut TestAppContext) { init_test(cx, |_| {}); diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index a02b4a7f0b..d638ac171f 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -18,7 +18,7 @@ use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; use gpui::http_client::read_proxy_from_env; -use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal}; use gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel}; @@ -337,7 +337,8 @@ pub struct AgentAppState { } pub fn init(cx: &mut App) -> Arc { - release_channel::init(SemanticVersion::default(), cx); + let app_version = AppVersion::global(cx); + release_channel::init(app_version, cx); gpui_tokio::init(cx); let mut settings_store = SettingsStore::new(cx); @@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc { // Set User-Agent so we can download language servers from GitHub let user_agent = format!( "Zed/{} ({}; {})", - AppVersion::global(cx), + app_version, std::env::consts::OS, std::env::consts::ARCH ); diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index e196a5b139..ee74ac4d54 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2416,7 +2416,7 @@ impl GitPanel { .committer_name .clone() .or_else(|| participant.user.name.clone()) - .unwrap_or_else(|| participant.user.github_login.clone()); + .unwrap_or_else(|| participant.user.github_login.clone().to_string()); new_co_authors.push((name.clone(), email.clone())) } } @@ -2436,7 +2436,7 @@ impl GitPanel { .name .clone() .or_else(|| user.name.clone()) - .unwrap_or_else(|| user.github_login.clone()); + .unwrap_or_else(|| user.github_login.clone().to_string()); Some((name, email)) } diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 4023ddf2dc..2bf49fa7d8 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -47,7 +47,6 @@ wayland = [ "wayland-cursor", "wayland-protocols", "wayland-protocols-plasma", - "wayland-protocols-wlr", "filedescriptor", "xkbcommon", "open", @@ -194,9 +193,6 @@ wayland-protocols = { version = "0.31.2", features = [ wayland-protocols-plasma = { version = "0.2.0", features = [ "client", ], optional = true } -wayland-protocols-wlr = { version = "0.3.8", features = [ - "client" -], optional = true} # X11 as-raw-xcb-connection = { version = "1", optional = true } @@ -220,10 +216,6 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf x11-clipboard = { version = "0.9.3", optional = true } [target.'cfg(target_os = "windows")'.dependencies] -blade-util.workspace = true -bytemuck = "1" -blade-graphics.workspace = true -blade-macros.workspace = true flume = "0.11" rand.workspace = true windows.workspace = true @@ -244,7 +236,6 @@ util = { workspace = true, features = ["test-support"] } [target.'cfg(target_os = "windows")'.build-dependencies] embed-resource = "3.0" -naga.workspace = true [target.'cfg(target_os = "macos")'.build-dependencies] bindgen = "0.71" diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index 7ab44a73f5..93a1c15c41 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -9,7 +9,10 @@ fn main() { let target = env::var("CARGO_CFG_TARGET_OS"); println!("cargo::rustc-check-cfg=cfg(gles)"); - #[cfg(any(not(target_os = "macos"), feature = "macos-blade"))] + #[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") + ))] check_wgsl_shaders(); match target.as_deref() { @@ -17,21 +20,18 @@ fn main() { #[cfg(target_os = "macos")] macos::build(); } - #[cfg(all(target_os = "windows", feature = "windows-manifest"))] Ok("windows") => { - let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); - let rc_file = std::path::Path::new("resources/windows/gpui.rc"); - println!("cargo:rerun-if-changed={}", manifest.display()); - println!("cargo:rerun-if-changed={}", rc_file.display()); - embed_resource::compile(rc_file, embed_resource::NONE) - .manifest_required() - .unwrap(); + #[cfg(target_os = "windows")] + windows::build(); } _ => (), }; } -#[allow(dead_code)] +#[cfg(any( + not(any(target_os = "macos", target_os = "windows")), + all(target_os = "macos", feature = "macos-blade") +))] fn check_wgsl_shaders() { use std::path::PathBuf; use std::process; @@ -243,3 +243,215 @@ mod macos { } } } + +#[cfg(target_os = "windows")] +mod windows { + use std::{ + fs, + io::Write, + path::{Path, PathBuf}, + process::{self, Command}, + }; + + pub(super) fn build() { + // Compile HLSL shaders + #[cfg(not(debug_assertions))] + compile_shaders(); + + // Embed the Windows manifest and resource file + #[cfg(feature = "windows-manifest")] + embed_resource(); + } + + #[cfg(feature = "windows-manifest")] + fn embed_resource() { + let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml"); + let rc_file = std::path::Path::new("resources/windows/gpui.rc"); + println!("cargo:rerun-if-changed={}", manifest.display()); + println!("cargo:rerun-if-changed={}", rc_file.display()); + embed_resource::compile(rc_file, embed_resource::NONE) + .manifest_required() + .unwrap(); + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. + fn compile_shaders() { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/shaders.hlsl"); + let out_dir = std::env::var("OUT_DIR").unwrap(); + + println!("cargo:rerun-if-changed={}", shader_path.display()); + + // Check if fxc.exe is available + let fxc_path = find_fxc_compiler(); + + // Define all modules + let modules = [ + "quad", + "shadow", + "path_rasterization", + "path_sprite", + "underline", + "monochrome_sprite", + "polychrome_sprite", + ]; + + let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir); + if Path::new(&rust_binding_path).exists() { + fs::remove_file(&rust_binding_path) + .expect("Failed to remove existing Rust binding file"); + } + for module in modules { + compile_shader_for_module( + module, + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + + { + let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("src/platform/windows/color_text_raster.hlsl"); + compile_shader_for_module( + "emoji_rasterization", + &out_dir, + &fxc_path, + shader_path.to_str().unwrap(), + &rust_binding_path, + ); + } + } + + /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler. + fn find_fxc_compiler() -> String { + // Check environment variable + if let Ok(path) = std::env::var("GPUI_FXC_PATH") { + if Path::new(&path).exists() { + return path; + } + } + + // Try to find in PATH + // NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe` + if let Ok(output) = std::process::Command::new("where.exe") + .arg("fxc.exe") + .output() + { + if output.status.success() { + let path = String::from_utf8_lossy(&output.stdout); + return path.trim().to_string(); + } + } + + // Check the default path + if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe") + .exists() + { + return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe" + .to_string(); + } + + panic!("Failed to find fxc.exe"); + } + + fn compile_shader_for_module( + module: &str, + out_dir: &str, + fxc_path: &str, + shader_path: &str, + rust_binding_path: &str, + ) { + // Compile vertex shader + let output_file = format!("{}/{}_vs.h", out_dir, module); + let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_vertex"), + &output_file, + &const_name, + shader_path, + "vs_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + + // Compile fragment shader + let output_file = format!("{}/{}_ps.h", out_dir, module); + let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase()); + compile_shader_impl( + fxc_path, + &format!("{module}_fragment"), + &output_file, + &const_name, + shader_path, + "ps_4_1", + ); + generate_rust_binding(&const_name, &output_file, &rust_binding_path); + } + + fn compile_shader_impl( + fxc_path: &str, + entry_point: &str, + output_path: &str, + var_name: &str, + shader_path: &str, + target: &str, + ) { + let output = Command::new(fxc_path) + .args([ + "/T", + target, + "/E", + entry_point, + "/Fh", + output_path, + "/Vn", + var_name, + "/O3", + shader_path, + ]) + .output(); + + match output { + Ok(result) => { + if result.status.success() { + return; + } + eprintln!( + "Shader compilation failed for {}:\n{}", + entry_point, + String::from_utf8_lossy(&result.stderr) + ); + process::exit(1); + } + Err(e) => { + eprintln!("Failed to run fxc for {}: {}", entry_point, e); + process::exit(1); + } + } + } + + fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) { + let header_content = fs::read_to_string(head_file).expect("Failed to read header file"); + let const_definition = { + let global_var_start = header_content.find("const BYTE").unwrap(); + let global_var = &header_content[global_var_start..]; + let equal = global_var.find('=').unwrap(); + global_var[equal + 1..].trim() + }; + let rust_binding = format!( + "const {}: &[u8] = &{}\n", + const_name, + const_definition.replace('{', "[").replace('}', "]") + ); + let mut options = fs::OpenOptions::new() + .create(true) + .append(true) + .open(output_path) + .expect("Failed to open Rust binding file"); + options + .write_all(rust_binding.as_bytes()) + .expect("Failed to write Rust binding file"); + } +} diff --git a/crates/gpui/examples/text.rs b/crates/gpui/examples/text.rs index 19214aebde..1166bb2795 100644 --- a/crates/gpui/examples/text.rs +++ b/crates/gpui/examples/text.rs @@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid { "χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р", "У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*", "_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ", - "❮", "<=", "!=", "==", "--", "++", "=>", "->", + "❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎", ]; let columns = 11; diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index a16c8f46be..639c84c101 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -35,6 +35,7 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) { /// An RGBA color #[derive(PartialEq, Clone, Copy, Default)] +#[repr(C)] pub struct Rgba { /// The red component of the color, in the range 0.0 to 1.0 pub r: f32, diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index febf294e48..b495d70dfd 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -13,8 +13,7 @@ mod mac; any(target_os = "linux", target_os = "freebsd"), any(feature = "x11", feature = "wayland") ), - target_os = "windows", - feature = "macos-blade" + all(target_os = "macos", feature = "macos-blade") ))] mod blade; @@ -448,6 +447,8 @@ impl Tiling { #[derive(Debug, Copy, Clone, Eq, PartialEq, Default)] pub(crate) struct RequestFrameOptions { pub(crate) require_presentation: bool, + /// Force refresh of all rendering states when true + pub(crate) force_render: bool, } pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { @@ -1216,10 +1217,6 @@ pub enum WindowKind { /// A window that appears above all other windows, usually used for alerts or popups /// use sparingly! PopUp, - /// An overlay such as a notification window, a launcher, ... - /// - /// Only supported on wayland - Overlay, } /// The appearance of the window, as defined by the operating system. diff --git a/crates/gpui/src/platform/linux/wayland/client.rs b/crates/gpui/src/platform/linux/wayland/client.rs index 33b22e7ce5..72e4477ecf 100644 --- a/crates/gpui/src/platform/linux/wayland/client.rs +++ b/crates/gpui/src/platform/linux/wayland/client.rs @@ -61,7 +61,6 @@ use wayland_protocols::xdg::decoration::zv1::client::{ }; use wayland_protocols::xdg::shell::client::{xdg_surface, xdg_toplevel, xdg_wm_base}; use wayland_protocols_plasma::blur::client::{org_kde_kwin_blur, org_kde_kwin_blur_manager}; -use wayland_protocols_wlr::layer_shell::v1::client::{zwlr_layer_shell_v1, zwlr_layer_surface_v1}; use xkbcommon::xkb::ffi::XKB_KEYMAP_FORMAT_TEXT_V1; use xkbcommon::xkb::{self, KEYMAP_COMPILE_NO_FLAGS, Keycode}; @@ -115,7 +114,6 @@ pub struct Globals { pub fractional_scale_manager: Option, pub decoration_manager: Option, - pub layer_shell: Option, pub blur_manager: Option, pub text_input_manager: Option, pub executor: ForegroundExecutor, @@ -153,7 +151,6 @@ impl Globals { viewporter: globals.bind(&qh, 1..=1, ()).ok(), fractional_scale_manager: globals.bind(&qh, 1..=1, ()).ok(), decoration_manager: globals.bind(&qh, 1..=1, ()).ok(), - layer_shell: globals.bind(&qh, 1..=1, ()).ok(), blur_manager: globals.bind(&qh, 1..=1, ()).ok(), text_input_manager: globals.bind(&qh, 1..=1, ()).ok(), executor, @@ -932,7 +929,6 @@ delegate_noop!(WaylandClientStatePtr: ignore wl_buffer::WlBuffer); delegate_noop!(WaylandClientStatePtr: ignore wl_region::WlRegion); delegate_noop!(WaylandClientStatePtr: ignore wp_fractional_scale_manager_v1::WpFractionalScaleManagerV1); delegate_noop!(WaylandClientStatePtr: ignore zxdg_decoration_manager_v1::ZxdgDecorationManagerV1); -delegate_noop!(WaylandClientStatePtr: ignore zwlr_layer_shell_v1::ZwlrLayerShellV1); delegate_noop!(WaylandClientStatePtr: ignore org_kde_kwin_blur_manager::OrgKdeKwinBlurManager); delegate_noop!(WaylandClientStatePtr: ignore zwp_text_input_manager_v3::ZwpTextInputManagerV3); delegate_noop!(WaylandClientStatePtr: ignore org_kde_kwin_blur::OrgKdeKwinBlur); @@ -1078,31 +1074,6 @@ impl Dispatch for WaylandClientStatePtr { } } -impl Dispatch for WaylandClientStatePtr { - fn event( - this: &mut Self, - _: &zwlr_layer_surface_v1::ZwlrLayerSurfaceV1, - event: ::Event, - surface_id: &ObjectId, - _: &Connection, - _: &QueueHandle, - ) { - let client = this.get_client(); - let mut state = client.borrow_mut(); - let Some(window) = get_window(&mut state, surface_id) else { - return; - }; - drop(state); - - let should_close = window.handle_layersurface_event(event); - - if should_close { - // The close logic will be handled in drop_window() - window.close(); - } - } -} - impl Dispatch for WaylandClientStatePtr { fn event( _: &mut Self, diff --git a/crates/gpui/src/platform/linux/wayland/window.rs b/crates/gpui/src/platform/linux/wayland/window.rs index 33c908d1b2..2b2207e22c 100644 --- a/crates/gpui/src/platform/linux/wayland/window.rs +++ b/crates/gpui/src/platform/linux/wayland/window.rs @@ -1,6 +1,3 @@ -use blade_graphics as gpu; -use collections::HashMap; -use futures::channel::oneshot::Receiver; use std::{ cell::{Ref, RefCell, RefMut}, ffi::c_void, @@ -9,14 +6,9 @@ use std::{ sync::Arc, }; -use crate::{ - Capslock, - platform::{ - PlatformAtlas, PlatformInputHandler, PlatformWindow, - blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}, - linux::wayland::{display::WaylandDisplay, serial::SerialKind}, - }, -}; +use blade_graphics as gpu; +use collections::HashMap; +use futures::channel::oneshot::Receiver; use raw_window_handle as rwh; use wayland_backend::client::ObjectId; @@ -28,8 +20,6 @@ use wayland_protocols::xdg::decoration::zv1::client::zxdg_toplevel_decoration_v1 use wayland_protocols::xdg::shell::client::xdg_surface; use wayland_protocols::xdg::shell::client::xdg_toplevel::{self}; use wayland_protocols_plasma::blur::client::org_kde_kwin_blur; -use wayland_protocols_wlr::layer_shell::v1::client::zwlr_layer_shell_v1::Layer; -use wayland_protocols_wlr::layer_shell::v1::client::zwlr_layer_surface_v1; use crate::scene::Scene; use crate::{ @@ -37,7 +27,15 @@ use crate::{ PlatformDisplay, PlatformInput, Point, PromptButton, PromptLevel, RequestFrameOptions, ResizeEdge, ScaledPixels, Size, Tiling, WaylandClientStatePtr, WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControlArea, WindowControls, WindowDecorations, - WindowKind, WindowParams, px, size, + WindowParams, px, size, +}; +use crate::{ + Capslock, + platform::{ + PlatformAtlas, PlatformInputHandler, PlatformWindow, + blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}, + linux::wayland::{display::WaylandDisplay, serial::SerialKind}, + }, }; #[derive(Default)] @@ -83,12 +81,14 @@ struct InProgressConfigure { } pub struct WaylandWindowState { - surface_state: WaylandSurfaceState, + xdg_surface: xdg_surface::XdgSurface, acknowledged_first_configure: bool, pub surface: wl_surface::WlSurface, + decoration: Option, app_id: Option, appearance: WindowAppearance, blur: Option, + toplevel: xdg_toplevel::XdgToplevel, viewport: Option, outputs: HashMap, display: Option<(ObjectId, Output)>, @@ -114,78 +114,6 @@ pub struct WaylandWindowState { client_inset: Option, } -pub enum WaylandSurfaceState { - Xdg(WaylandXdgSurfaceState), - LayerShell(WaylandLayerSurfaceState), -} - -pub struct WaylandXdgSurfaceState { - xdg_surface: xdg_surface::XdgSurface, - toplevel: xdg_toplevel::XdgToplevel, - decoration: Option, -} - -pub struct WaylandLayerSurfaceState { - layer_surface: zwlr_layer_surface_v1::ZwlrLayerSurfaceV1, -} - -impl WaylandSurfaceState { - fn ack_configure(&self, serial: u32) { - match self { - WaylandSurfaceState::Xdg(WaylandXdgSurfaceState { xdg_surface, .. }) => { - xdg_surface.ack_configure(serial); - } - WaylandSurfaceState::LayerShell(WaylandLayerSurfaceState { layer_surface, .. }) => { - layer_surface.ack_configure(serial); - } - } - } - - fn decoration(&self) -> Option<&zxdg_toplevel_decoration_v1::ZxdgToplevelDecorationV1> { - if let WaylandSurfaceState::Xdg(WaylandXdgSurfaceState { decoration, .. }) = self { - decoration.as_ref() - } else { - None - } - } - - fn toplevel(&self) -> Option<&xdg_toplevel::XdgToplevel> { - if let WaylandSurfaceState::Xdg(WaylandXdgSurfaceState { toplevel, .. }) = self { - Some(toplevel) - } else { - None - } - } - - fn set_geometry(&self, x: i32, y: i32, width: i32, height: i32) { - match self { - WaylandSurfaceState::Xdg(WaylandXdgSurfaceState { xdg_surface, .. }) => { - xdg_surface.set_window_geometry(x, y, width, height); - } - WaylandSurfaceState::LayerShell(WaylandLayerSurfaceState { layer_surface, .. }) => { - // cannot set window position of a layer surface - layer_surface.set_size(width as u32, height as u32); - } - } - } - - fn destroy(&mut self) { - match self { - WaylandSurfaceState::Xdg(WaylandXdgSurfaceState { - xdg_surface, - toplevel, - decoration: _decoration, - }) => { - toplevel.destroy(); - xdg_surface.destroy(); - } - WaylandSurfaceState::LayerShell(WaylandLayerSurfaceState { layer_surface }) => { - layer_surface.destroy(); - } - } - } -} - #[derive(Clone)] pub struct WaylandWindowStatePtr { state: Rc>, @@ -196,7 +124,9 @@ impl WaylandWindowState { pub(crate) fn new( handle: AnyWindowHandle, surface: wl_surface::WlSurface, - surface_state: WaylandSurfaceState, + xdg_surface: xdg_surface::XdgSurface, + toplevel: xdg_toplevel::XdgToplevel, + decoration: Option, appearance: WindowAppearance, viewport: Option, client: WaylandClientStatePtr, @@ -226,11 +156,13 @@ impl WaylandWindowState { }; Ok(Self { - surface_state, + xdg_surface, acknowledged_first_configure: false, surface, + decoration, app_id: None, blur: None, + toplevel, viewport, globals, outputs: HashMap::default(), @@ -303,16 +235,17 @@ impl Drop for WaylandWindow { let client = state.client.clone(); state.renderer.destroy(); - if let Some(decoration) = &state.surface_state.decoration() { + if let Some(decoration) = &state.decoration { decoration.destroy(); } if let Some(blur) = &state.blur { blur.release(); } - state.surface_state.destroy(); + state.toplevel.destroy(); if let Some(viewport) = &state.viewport { viewport.destroy(); } + state.xdg_surface.destroy(); state.surface.destroy(); let state_ptr = self.0.clone(); @@ -346,65 +279,27 @@ impl WaylandWindow { appearance: WindowAppearance, ) -> anyhow::Result<(Self, ObjectId)> { let surface = globals.compositor.create_surface(&globals.qh, ()); + let xdg_surface = globals + .wm_base + .get_xdg_surface(&surface, &globals.qh, surface.id()); + let toplevel = xdg_surface.get_toplevel(&globals.qh, surface.id()); - let surface_state = match (params.kind, globals.layer_shell.as_ref()) { - // Matching on layer_shell here means that if kind is Overlay, but the compositor doesn't support layer_shell, - // we end up defaulting to xdg_surface anyway - (WindowKind::Overlay, Some(layer_shell)) => { - let layer_surface = layer_shell.get_layer_surface( - &surface, - None, - Layer::Overlay, - "".to_string(), - &globals.qh, - surface.id(), - ); - - let width = params.bounds.size.width.0; - let height = params.bounds.size.height.0; - layer_surface.set_size(width as u32, height as u32); - layer_surface.set_keyboard_interactivity( - zwlr_layer_surface_v1::KeyboardInteractivity::OnDemand, - ); - - WaylandSurfaceState::LayerShell(WaylandLayerSurfaceState { layer_surface }) - } - _ => { - let xdg_surface = - globals - .wm_base - .get_xdg_surface(&surface, &globals.qh, surface.id()); - - let toplevel = xdg_surface.get_toplevel(&globals.qh, surface.id()); - - if let Some(size) = params.window_min_size { - toplevel.set_min_size(size.width.0 as i32, size.height.0 as i32); - } - - // Attempt to set up window decorations based on the requested configuration - let decoration = globals - .decoration_manager - .as_ref() - .map(|decoration_manager| { - decoration_manager.get_toplevel_decoration( - &toplevel, - &globals.qh, - surface.id(), - ) - }); - - WaylandSurfaceState::Xdg(WaylandXdgSurfaceState { - xdg_surface, - toplevel, - decoration, - }) - } - }; + if let Some(size) = params.window_min_size { + toplevel.set_min_size(size.width.0 as i32, size.height.0 as i32); + } if let Some(fractional_scale_manager) = globals.fractional_scale_manager.as_ref() { fractional_scale_manager.get_fractional_scale(&surface, &globals.qh, surface.id()); } + // Attempt to set up window decorations based on the requested configuration + let decoration = globals + .decoration_manager + .as_ref() + .map(|decoration_manager| { + decoration_manager.get_toplevel_decoration(&toplevel, &globals.qh, surface.id()) + }); + let viewport = globals .viewporter .as_ref() @@ -414,7 +309,9 @@ impl WaylandWindow { state: Rc::new(RefCell::new(WaylandWindowState::new( handle, surface.clone(), - surface_state, + xdg_surface, + toplevel, + decoration, appearance, viewport, client, @@ -506,7 +403,7 @@ impl WaylandWindowStatePtr { } } let mut state = self.state.borrow_mut(); - state.surface_state.ack_configure(serial); + state.xdg_surface.ack_configure(serial); let window_geometry = inset_by_tiling( state.bounds.map_origin(|_| px(0.0)), @@ -516,7 +413,7 @@ impl WaylandWindowStatePtr { .map(|v| v.0 as i32) .map_size(|v| if v <= 0 { 1 } else { v }); - state.surface_state.set_geometry( + state.xdg_surface.set_window_geometry( window_geometry.origin.x, window_geometry.origin.y, window_geometry.size.width, @@ -681,42 +578,6 @@ impl WaylandWindowStatePtr { } } - pub fn handle_layersurface_event(&self, event: zwlr_layer_surface_v1::Event) -> bool { - match event { - zwlr_layer_surface_v1::Event::Configure { - width, - height, - serial, - } => { - let mut size = if width == 0 || height == 0 { - None - } else { - Some(size(px(width as f32), px(height as f32))) - }; - - let mut state = self.state.borrow_mut(); - state.in_progress_configure = Some(InProgressConfigure { - size, - fullscreen: false, - maximized: false, - resizing: false, - tiling: Tiling::default(), - }); - drop(state); - - // just do the same thing we'd do as an xdg_surface - self.handle_xdg_surface_event(xdg_surface::Event::Configure { serial }); - - false - } - zwlr_layer_surface_v1::Event::Closed => { - // unlike xdg, we don't have a choice here: the surface is closing. - true - } - _ => false, - } - } - #[allow(clippy::mutable_key_type)] pub fn handle_surface_event( &self, @@ -979,7 +840,7 @@ impl PlatformWindow for WaylandWindow { let state_ptr = self.0.clone(); let dp_size = size.to_device_pixels(self.scale_factor()); - state.surface_state.set_geometry( + state.xdg_surface.set_window_geometry( state.bounds.origin.x.0 as i32, state.bounds.origin.y.0 as i32, dp_size.width.0, @@ -1073,16 +934,12 @@ impl PlatformWindow for WaylandWindow { } fn set_title(&mut self, title: &str) { - if let Some(toplevel) = self.borrow().surface_state.toplevel() { - toplevel.set_title(title.to_string()); - } + self.borrow().toplevel.set_title(title.to_string()); } fn set_app_id(&mut self, app_id: &str) { let mut state = self.borrow_mut(); - if let Some(toplevel) = self.borrow().surface_state.toplevel() { - toplevel.set_app_id(app_id.to_owned()); - } + state.toplevel.set_app_id(app_id.to_owned()); state.app_id = Some(app_id.to_owned()); } @@ -1093,30 +950,24 @@ impl PlatformWindow for WaylandWindow { } fn minimize(&self) { - if let Some(toplevel) = self.borrow().surface_state.toplevel() { - toplevel.set_minimized(); - } + self.borrow().toplevel.set_minimized(); } fn zoom(&self) { let state = self.borrow(); - if let Some(toplevel) = state.surface_state.toplevel() { - if !state.maximized { - toplevel.set_maximized(); - } else { - toplevel.unset_maximized(); - } + if !state.maximized { + state.toplevel.set_maximized(); + } else { + state.toplevel.unset_maximized(); } } fn toggle_fullscreen(&self) { - let mut state = self.borrow(); - if let Some(toplevel) = state.surface_state.toplevel() { - if !state.fullscreen { - toplevel.set_fullscreen(None); - } else { - toplevel.unset_fullscreen(); - } + let mut state = self.borrow_mut(); + if !state.fullscreen { + state.toplevel.set_fullscreen(None); + } else { + state.toplevel.unset_fullscreen(); } } @@ -1181,33 +1032,27 @@ impl PlatformWindow for WaylandWindow { fn show_window_menu(&self, position: Point) { let state = self.borrow(); let serial = state.client.get_serial(SerialKind::MousePress); - if let Some(toplevel) = state.surface_state.toplevel() { - toplevel.show_window_menu( - &state.globals.seat, - serial, - position.x.0 as i32, - position.y.0 as i32, - ); - } + state.toplevel.show_window_menu( + &state.globals.seat, + serial, + position.x.0 as i32, + position.y.0 as i32, + ); } fn start_window_move(&self) { let state = self.borrow(); let serial = state.client.get_serial(SerialKind::MousePress); - if let Some(toplevel) = state.surface_state.toplevel() { - toplevel._move(&state.globals.seat, serial); - } + state.toplevel._move(&state.globals.seat, serial); } fn start_window_resize(&self, edge: crate::ResizeEdge) { let state = self.borrow(); - if let Some(toplevel) = state.surface_state.toplevel() { - toplevel.resize( - &state.globals.seat, - state.client.get_serial(SerialKind::MousePress), - edge.to_xdg(), - ) - } + state.toplevel.resize( + &state.globals.seat, + state.client.get_serial(SerialKind::MousePress), + edge.to_xdg(), + ) } fn window_decorations(&self) -> Decorations { @@ -1223,7 +1068,7 @@ impl PlatformWindow for WaylandWindow { fn request_decorations(&self, decorations: WindowDecorations) { let mut state = self.borrow_mut(); state.decorations = decorations; - if let Some(decoration) = state.surface_state.decoration() { + if let Some(decoration) = state.decoration.as_ref() { decoration.set_mode(decorations.to_xdg()); update_window(state); } diff --git a/crates/gpui/src/platform/linux/x11/client.rs b/crates/gpui/src/platform/linux/x11/client.rs index d1cb7d00cc..573e4addf7 100644 --- a/crates/gpui/src/platform/linux/x11/client.rs +++ b/crates/gpui/src/platform/linux/x11/client.rs @@ -1004,12 +1004,13 @@ impl X11Client { let mut keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); - // should be called after key_get_one_sym - state.xkb.update_key(code, xkbc::KeyDirection::Down); - if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Down); + if let Some(mut compose_state) = state.compose_state.take() { compose_state.feed(keysym); match compose_state.status() { @@ -1067,12 +1068,13 @@ impl X11Client { let keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code); let keysym = state.xkb.key_get_one_sym(code); - // should be called after key_get_one_sym - state.xkb.update_key(code, xkbc::KeyDirection::Up); - if keysym.is_modifier_key() { return Some(()); } + + // should be called after key_get_one_sym + state.xkb.update_key(code, xkbc::KeyDirection::Up); + keystroke }; drop(state); @@ -1793,6 +1795,7 @@ impl X11ClientState { drop(state); window.refresh(RequestFrameOptions { require_presentation: expose_event_received, + force_render: false, }); } xcb_connection diff --git a/crates/gpui/src/platform/mac/window.rs b/crates/gpui/src/platform/mac/window.rs index f01d33147b..aedf131909 100644 --- a/crates/gpui/src/platform/mac/window.rs +++ b/crates/gpui/src/platform/mac/window.rs @@ -559,7 +559,7 @@ impl MacWindow { } let native_window: id = match kind { - WindowKind::Normal | WindowKind::Overlay => msg_send![WINDOW_CLASS, alloc], + WindowKind::Normal => msg_send![WINDOW_CLASS, alloc], WindowKind::PopUp => { style_mask |= NSWindowStyleMaskNonactivatingPanel; msg_send![PANEL_CLASS, alloc] @@ -711,7 +711,7 @@ impl MacWindow { native_window.makeFirstResponder_(native_view); match kind { - WindowKind::Normal | WindowKind::Overlay => { + WindowKind::Normal => { native_window.setLevel_(NSNormalWindowLevel); native_window.setAcceptsMouseMovedEvents_(YES); } diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 4bdf42080d..5268d3ccba 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -1,6 +1,8 @@ mod clipboard; mod destination_list; mod direct_write; +mod directx_atlas; +mod directx_renderer; mod dispatcher; mod display; mod events; @@ -14,6 +16,8 @@ mod wrapper; pub(crate) use clipboard::*; pub(crate) use destination_list::*; pub(crate) use direct_write::*; +pub(crate) use directx_atlas::*; +pub(crate) use directx_renderer::*; pub(crate) use dispatcher::*; pub(crate) use display::*; pub(crate) use events::*; diff --git a/crates/gpui/src/platform/windows/color_text_raster.hlsl b/crates/gpui/src/platform/windows/color_text_raster.hlsl new file mode 100644 index 0000000000..ccc5fa26f0 --- /dev/null +++ b/crates/gpui/src/platform/windows/color_text_raster.hlsl @@ -0,0 +1,39 @@ +struct RasterVertexOutput { + float4 position : SV_Position; + float2 texcoord : TEXCOORD0; +}; + +RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID) +{ + RasterVertexOutput output; + output.texcoord = float2((vertexID << 1) & 2, vertexID & 2); + output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f); + output.position.y = -output.position.y; + + return output; +} + +struct PixelInput { + float4 position: SV_Position; + float2 texcoord : TEXCOORD0; +}; + +struct Bounds { + int2 origin; + int2 size; +}; + +Texture2D t_layer : register(t0); +SamplerState s_layer : register(s0); + +cbuffer GlyphLayerTextureParams : register(b0) { + Bounds bounds; + float4 run_color; +}; + +float4 emoji_rasterization_fragment(PixelInput input): SV_Target { + float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb; + float alpha = (sampled.r + sampled.g + sampled.b) / 3; + + return float4(run_color.rgb, alpha); +} diff --git a/crates/gpui/src/platform/windows/direct_write.rs b/crates/gpui/src/platform/windows/direct_write.rs index ada306c15c..587cb7b4a6 100644 --- a/crates/gpui/src/platform/windows/direct_write.rs +++ b/crates/gpui/src/platform/windows/direct_write.rs @@ -10,10 +10,11 @@ use windows::{ Foundation::*, Globalization::GetUserDefaultLocaleName, Graphics::{ - Direct2D::{Common::*, *}, + Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + Direct3D11::*, DirectWrite::*, Dxgi::Common::*, - Gdi::LOGFONTW, + Gdi::{IsRectEmpty, LOGFONTW}, Imaging::*, }, System::SystemServices::LOCALE_NAME_MAX_LENGTH, @@ -40,16 +41,21 @@ struct DirectWriteComponent { locale: String, factory: IDWriteFactory5, bitmap_factory: AgileReference, - d2d1_factory: ID2D1Factory, in_memory_loader: IDWriteInMemoryFontFileLoader, builder: IDWriteFontSetBuilder1, text_renderer: Arc, - render_context: GlyphRenderContext, + + render_params: IDWriteRenderingParams3, + gpu_state: GPUState, } -struct GlyphRenderContext { - params: IDWriteRenderingParams3, - dc_target: ID2D1DeviceContext4, +struct GPUState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + sampler: [Option; 1], + blend_state: ID3D11BlendState, + vertex_shader: ID3D11VertexShader, + pixel_shader: ID3D11PixelShader, } struct DirectWriteState { @@ -70,12 +76,11 @@ struct FontIdentifier { } impl DirectWriteComponent { - pub fn new(bitmap_factory: &IWICImagingFactory) -> Result { + pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result { + // todo: ideally this would not be a large unsafe block but smaller isolated ones for easier auditing unsafe { let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?; let bitmap_factory = AgileReference::new(bitmap_factory)?; - let d2d1_factory: ID2D1Factory = - D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, None)?; // The `IDWriteInMemoryFontFileLoader` here is supported starting from // Windows 10 Creators Update, which consequently requires the entire // `DirectWriteTextSystem` to run on `win10 1703`+. @@ -86,60 +91,132 @@ impl DirectWriteComponent { GetUserDefaultLocaleName(&mut locale_vec); let locale = String::from_utf16_lossy(&locale_vec); let text_renderer = Arc::new(TextRendererWrapper::new(&locale)); - let render_context = GlyphRenderContext::new(&factory, &d2d1_factory)?; + + let render_params = { + let default_params: IDWriteRenderingParams3 = + factory.CreateRenderingParams()?.cast()?; + let gamma = default_params.GetGamma(); + let enhanced_contrast = default_params.GetEnhancedContrast(); + let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); + let cleartype_level = default_params.GetClearTypeLevel(); + let grid_fit_mode = default_params.GetGridFitMode(); + + factory.CreateCustomRenderingParams( + gamma, + enhanced_contrast, + gray_contrast, + cleartype_level, + DWRITE_PIXEL_GEOMETRY_RGB, + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + grid_fit_mode, + )? + }; + + let gpu_state = GPUState::new(gpu_context)?; Ok(DirectWriteComponent { locale, factory, bitmap_factory, - d2d1_factory, in_memory_loader, builder, text_renderer, - render_context, + render_params, + gpu_state, }) } } } -impl GlyphRenderContext { - pub fn new(factory: &IDWriteFactory5, d2d1_factory: &ID2D1Factory) -> Result { - unsafe { - let default_params: IDWriteRenderingParams3 = - factory.CreateRenderingParams()?.cast()?; - let gamma = default_params.GetGamma(); - let enhanced_contrast = default_params.GetEnhancedContrast(); - let gray_contrast = default_params.GetGrayscaleEnhancedContrast(); - let cleartype_level = default_params.GetClearTypeLevel(); - let grid_fit_mode = default_params.GetGridFitMode(); +impl GPUState { + fn new(gpu_context: &DirectXDevices) -> Result { + let device = gpu_context.device.clone(); + let device_context = gpu_context.device_context.clone(); - let params = factory.CreateCustomRenderingParams( - gamma, - enhanced_contrast, - gray_contrast, - cleartype_level, - DWRITE_PIXEL_GEOMETRY_RGB, - DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, - grid_fit_mode, - )?; - let dc_target = { - let target = d2d1_factory.CreateDCRenderTarget(&get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ))?; - let target = target.cast::()?; - target.SetTextRenderingParams(¶ms); - target + let blend_state = { + let mut blend_state = None; + let desc = D3D11_BLEND_DESC { + AlphaToCoverageEnable: false.into(), + IndependentBlendEnable: false.into(), + RenderTarget: [ + D3D11_RENDER_TARGET_BLEND_DESC { + BlendEnable: true.into(), + SrcBlend: D3D11_BLEND_SRC_ALPHA, + DestBlend: D3D11_BLEND_INV_SRC_ALPHA, + BlendOp: D3D11_BLEND_OP_ADD, + SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA, + DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA, + BlendOpAlpha: D3D11_BLEND_OP_ADD, + RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8, + }, + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ], }; + unsafe { device.CreateBlendState(&desc, Some(&mut blend_state)) }?; + blend_state.unwrap() + }; - Ok(Self { params, dc_target }) - } + let sampler = { + let mut sampler = None; + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_POINT, + AddressU: D3D11_TEXTURE_ADDRESS_BORDER, + AddressV: D3D11_TEXTURE_ADDRESS_BORDER, + AddressW: D3D11_TEXTURE_ADDRESS_BORDER, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0, 0.0, 0.0, 0.0], + MinLOD: 0.0, + MaxLOD: 0.0, + }; + unsafe { device.CreateSamplerState(&desc, Some(&mut sampler)) }?; + [sampler] + }; + + let vertex_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Vertex, + )?; + let mut shader = None; + unsafe { device.CreateVertexShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + let pixel_shader = { + let source = shader_resources::RawShaderBytes::new( + shader_resources::ShaderModule::EmojiRasterization, + shader_resources::ShaderTarget::Fragment, + )?; + let mut shader = None; + unsafe { device.CreatePixelShader(source.as_bytes(), None, Some(&mut shader)) }?; + shader.unwrap() + }; + + Ok(Self { + device, + device_context, + sampler, + blend_state, + vertex_shader, + pixel_shader, + }) } } impl DirectWriteTextSystem { - pub(crate) fn new(bitmap_factory: &IWICImagingFactory) -> Result { - let components = DirectWriteComponent::new(bitmap_factory)?; + pub(crate) fn new( + gpu_context: &DirectXDevices, + bitmap_factory: &IWICImagingFactory, + ) -> Result { + let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?; let system_font_collection = unsafe { let mut result = std::mem::zeroed(); components @@ -648,15 +725,13 @@ impl DirectWriteState { } } - fn raster_bounds(&self, params: &RenderGlyphParams) -> Result> { - let render_target = &self.components.render_context.dc_target; - unsafe { - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi(96.0 * params.scale_factor, 96.0 * params.scale_factor); - } + fn create_glyph_run_analysis( + &self, + params: &RenderGlyphParams, + ) -> Result { let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; - let advance = [0.0f32]; + let advance = [0.0]; let offset = [DWRITE_GLYPH_OFFSET::default()]; let glyph_run = DWRITE_GLYPH_RUN { fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, @@ -668,44 +743,87 @@ impl DirectWriteState { isSideways: BOOL(0), bidiLevel: 0, }; - let bounds = unsafe { - render_target.GetGlyphRunWorldBounds( - Vector2 { X: 0.0, Y: 0.0 }, - &glyph_run, - DWRITE_MEASURING_MODE_NATURAL, - )? + let transform = DWRITE_MATRIX { + m11: params.scale_factor, + m12: 0.0, + m21: 0.0, + m22: params.scale_factor, + dx: 0.0, + dy: 0.0, }; - // todo(windows) - // This is a walkaround, deleted when figured out. - let y_offset; - let extra_height; - if params.is_emoji { - y_offset = 0; - extra_height = 0; - } else { - // make some room for scaler. - y_offset = -1; - extra_height = 2; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin_x = subpixel_shift.x / params.scale_factor; + let baseline_origin_y = subpixel_shift.y / params.scale_factor; + + let mut rendering_mode = DWRITE_RENDERING_MODE1::default(); + let mut grid_fit_mode = DWRITE_GRID_FIT_MODE::default(); + unsafe { + font.font_face.GetRecommendedRenderingMode( + params.font_size.0, + // The dpi here seems that it has the same effect with `Some(&transform)` + 1.0, + 1.0, + Some(&transform), + false, + DWRITE_OUTLINE_THRESHOLD_ANTIALIASED, + DWRITE_MEASURING_MODE_NATURAL, + &self.components.render_params, + &mut rendering_mode, + &mut grid_fit_mode, + )?; } - if bounds.right < bounds.left { + let glyph_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &glyph_run, + Some(&transform), + rendering_mode, + DWRITE_MEASURING_MODE_NATURAL, + grid_fit_mode, + // We're using cleartype not grayscale for monochrome is because it provides better quality + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; + Ok(glyph_analysis) + } + + fn raster_bounds(&self, params: &RenderGlyphParams) -> Result> { + let glyph_analysis = self.create_glyph_run_analysis(params)?; + + let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? }; + // Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case + // GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet. + if !unsafe { IsRectEmpty(&bounds) }.as_bool() { Ok(Bounds { - origin: point(0.into(), 0.into()), - size: size(0.into(), 0.into()), + origin: point(bounds.left.into(), bounds.top.into()), + size: size( + (bounds.right - bounds.left).into(), + (bounds.bottom - bounds.top).into(), + ), }) } else { - Ok(Bounds { - origin: point( - ((bounds.left * params.scale_factor).ceil() as i32).into(), - ((bounds.top * params.scale_factor).ceil() as i32 + y_offset).into(), - ), - size: size( - (((bounds.right - bounds.left) * params.scale_factor).ceil() as i32).into(), - (((bounds.bottom - bounds.top) * params.scale_factor).ceil() as i32 - + extra_height) - .into(), - ), - }) + // If it's empty, retry with grayscale AA. + let bounds = + unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? }; + + if bounds.right < bounds.left { + Ok(Bounds { + origin: point(0.into(), 0.into()), + size: size(0.into(), 0.into()), + }) + } else { + Ok(Bounds { + origin: point(bounds.left.into(), bounds.top.into()), + size: size( + (bounds.right - bounds.left).into(), + (bounds.bottom - bounds.top).into(), + ), + }) + } } } @@ -731,7 +849,95 @@ impl DirectWriteState { anyhow::bail!("glyph bounds are empty"); } - let font_info = &self.fonts[params.font_id.0]; + let bitmap_data = if params.is_emoji { + if let Ok(color) = self.rasterize_color(¶ms, glyph_bounds) { + color + } else { + let monochrome = self.rasterize_monochrome(params, glyph_bounds)?; + monochrome + .into_iter() + .flat_map(|pixel| [0, 0, 0, pixel]) + .collect::>() + } + } else { + self.rasterize_monochrome(params, glyph_bounds)? + }; + + Ok((glyph_bounds.size, bitmap_data)) + } + + fn rasterize_monochrome( + &self, + params: &RenderGlyphParams, + glyph_bounds: Bounds, + ) -> Result> { + let mut bitmap_data = + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3]; + + let glyph_analysis = self.create_glyph_run_analysis(params)?; + unsafe { + glyph_analysis.CreateAlphaTexture( + // We're using cleartype not grayscale for monochrome is because it provides better quality + DWRITE_TEXTURE_CLEARTYPE_3x1, + &RECT { + left: glyph_bounds.origin.x.0, + top: glyph_bounds.origin.y.0, + right: glyph_bounds.size.width.0 + glyph_bounds.origin.x.0, + bottom: glyph_bounds.size.height.0 + glyph_bounds.origin.y.0, + }, + &mut bitmap_data, + )?; + } + + let bitmap_factory = self.components.bitmap_factory.resolve()?; + let bitmap = unsafe { + bitmap_factory.CreateBitmapFromMemory( + glyph_bounds.size.width.0 as u32, + glyph_bounds.size.height.0 as u32, + &GUID_WICPixelFormat24bppRGB, + glyph_bounds.size.width.0 as u32 * 3, + &bitmap_data, + ) + }?; + + let grayscale_bitmap = + unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?; + + let mut bitmap_data = + vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize]; + unsafe { + grayscale_bitmap.CopyPixels( + std::ptr::null() as _, + glyph_bounds.size.width.0 as u32, + &mut bitmap_data, + ) + }?; + + Ok(bitmap_data) + } + + fn rasterize_color( + &self, + params: &RenderGlyphParams, + glyph_bounds: Bounds, + ) -> Result> { + let bitmap_size = glyph_bounds.size; + let subpixel_shift = params + .subpixel_variant + .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); + let baseline_origin_x = subpixel_shift.x / params.scale_factor; + let baseline_origin_y = subpixel_shift.y / params.scale_factor; + + let transform = DWRITE_MATRIX { + m11: params.scale_factor, + m12: 0.0, + m21: 0.0, + m22: params.scale_factor, + dx: 0.0, + dy: 0.0, + }; + + let font = &self.fonts[params.font_id.0]; let glyph_id = [params.glyph_id.0 as u16]; let advance = [glyph_bounds.size.width.0 as f32]; let offset = [DWRITE_GLYPH_OFFSET { @@ -739,7 +945,7 @@ impl DirectWriteState { ascenderOffset: glyph_bounds.origin.y.0 as f32 / params.scale_factor, }]; let glyph_run = DWRITE_GLYPH_RUN { - fontFace: unsafe { std::mem::transmute_copy(&font_info.font_face) }, + fontFace: unsafe { std::mem::transmute_copy(&font.font_face) }, fontEmSize: params.font_size.0, glyphCount: 1, glyphIndices: glyph_id.as_ptr(), @@ -749,160 +955,254 @@ impl DirectWriteState { bidiLevel: 0, }; - // Add an extra pixel when the subpixel variant isn't zero to make room for anti-aliasing. - let mut bitmap_size = glyph_bounds.size; - if params.subpixel_variant.x > 0 { - bitmap_size.width += DevicePixels(1); - } - if params.subpixel_variant.y > 0 { - bitmap_size.height += DevicePixels(1); - } - let bitmap_size = bitmap_size; + // todo: support formats other than COLR + let color_enumerator = unsafe { + self.components.factory.TranslateColorGlyphRun( + Vector2::new(baseline_origin_x, baseline_origin_y), + &glyph_run, + None, + DWRITE_GLYPH_IMAGE_FORMATS_COLR, + DWRITE_MEASURING_MODE_NATURAL, + Some(&transform), + 0, + ) + }?; - let total_bytes; - let bitmap_format; - let render_target_property; - let bitmap_width; - let bitmap_height; - let bitmap_stride; - let bitmap_dpi; - if params.is_emoji { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize * 4; - bitmap_format = &GUID_WICPixelFormat32bppPBGRA; - render_target_property = get_render_target_property( - DXGI_FORMAT_B8G8R8A8_UNORM, - D2D1_ALPHA_MODE_PREMULTIPLIED, - ); - bitmap_width = bitmap_size.width.0 as u32; - bitmap_height = bitmap_size.height.0 as u32; - bitmap_stride = bitmap_size.width.0 as u32 * 4; - bitmap_dpi = 96.0; - } else { - total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize; - bitmap_format = &GUID_WICPixelFormat8bppAlpha; - render_target_property = - get_render_target_property(DXGI_FORMAT_A8_UNORM, D2D1_ALPHA_MODE_STRAIGHT); - bitmap_width = bitmap_size.width.0 as u32 * 2; - bitmap_height = bitmap_size.height.0 as u32 * 2; - bitmap_stride = bitmap_size.width.0 as u32; - bitmap_dpi = 192.0; + let mut glyph_layers = Vec::new(); + loop { + let color_run = unsafe { color_enumerator.GetCurrentRun() }?; + let color_run = unsafe { &*color_run }; + let image_format = color_run.glyphImageFormat & !DWRITE_GLYPH_IMAGE_FORMATS_TRUETYPE; + if image_format == DWRITE_GLYPH_IMAGE_FORMATS_COLR { + let color_analysis = unsafe { + self.components.factory.CreateGlyphRunAnalysis( + &color_run.Base.glyphRun as *const _, + Some(&transform), + DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC, + DWRITE_MEASURING_MODE_NATURAL, + DWRITE_GRID_FIT_MODE_DEFAULT, + DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE, + baseline_origin_x, + baseline_origin_y, + ) + }?; + + let color_bounds = + unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?; + + let color_size = size( + color_bounds.right - color_bounds.left, + color_bounds.bottom - color_bounds.top, + ); + if color_size.width > 0 && color_size.height > 0 { + let mut alpha_data = + vec![0u8; (color_size.width * color_size.height * 3) as usize]; + unsafe { + color_analysis.CreateAlphaTexture( + DWRITE_TEXTURE_CLEARTYPE_3x1, + &color_bounds, + &mut alpha_data, + ) + }?; + + let run_color = { + let run_color = color_run.Base.runColor; + Rgba { + r: run_color.r, + g: run_color.g, + b: run_color.b, + a: run_color.a, + } + }; + let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size); + let alpha_data = alpha_data + .chunks_exact(3) + .flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255]) + .collect::>(); + glyph_layers.push(GlyphLayerTexture::new( + &self.components.gpu_state, + run_color, + bounds, + &alpha_data, + )?); + } + } + + let has_next = unsafe { color_enumerator.MoveNext() } + .map(|e| e.as_bool()) + .unwrap_or(false); + if !has_next { + break; + } } - let bitmap_factory = self.components.bitmap_factory.resolve()?; - unsafe { - let bitmap = bitmap_factory.CreateBitmap( - bitmap_width, - bitmap_height, - bitmap_format, - WICBitmapCacheOnLoad, - )?; - let render_target = self - .components - .d2d1_factory - .CreateWicBitmapRenderTarget(&bitmap, &render_target_property)?; - let brush = render_target.CreateSolidColorBrush(&BRUSH_COLOR, None)?; - let subpixel_shift = params - .subpixel_variant - .map(|v| v as f32 / SUBPIXEL_VARIANTS as f32); - let baseline_origin = Vector2 { - X: subpixel_shift.x / params.scale_factor, - Y: subpixel_shift.y / params.scale_factor, + let gpu_state = &self.components.gpu_state; + let params_buffer = { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + StructureByteStride: 0, }; - // This `cast()` action here should never fail since we are running on Win10+, and - // ID2D1DeviceContext4 requires Win8+ - let render_target = render_target.cast::().unwrap(); - render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS); - render_target.SetDpi( - bitmap_dpi * params.scale_factor, - bitmap_dpi * params.scale_factor, - ); - render_target.SetTextRenderingParams(&self.components.render_context.params); - render_target.BeginDraw(); + let mut buffer = None; + unsafe { + gpu_state + .device + .CreateBuffer(&desc, None, Some(&mut buffer)) + }?; + [buffer] + }; - if params.is_emoji { - // WARN: only DWRITE_GLYPH_IMAGE_FORMATS_COLR has been tested - let enumerator = self.components.factory.TranslateColorGlyphRun( - baseline_origin, - &glyph_run as _, - None, - DWRITE_GLYPH_IMAGE_FORMATS_COLR - | DWRITE_GLYPH_IMAGE_FORMATS_SVG - | DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8, - DWRITE_MEASURING_MODE_NATURAL, - None, + let render_target_texture = { + let mut texture = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: bitmap_size.width.0 as u32, + Height: bitmap_size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture)) + }?; + texture.unwrap() + }; + + let render_target_view = { + let desc = D3D11_RENDER_TARGET_VIEW_DESC { + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + ViewDimension: D3D11_RTV_DIMENSION_TEXTURE2D, + Anonymous: D3D11_RENDER_TARGET_VIEW_DESC_0 { + Texture2D: D3D11_TEX2D_RTV { MipSlice: 0 }, + }, + }; + let mut rtv = None; + unsafe { + gpu_state.device.CreateRenderTargetView( + &render_target_texture, + Some(&desc), + Some(&mut rtv), + ) + }?; + [rtv] + }; + + let staging_texture = { + let mut texture = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: bitmap_size.width.0 as u32, + Height: bitmap_size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_B8G8R8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_STAGING, + BindFlags: 0, + CPUAccessFlags: D3D11_CPU_ACCESS_READ.0 as u32, + MiscFlags: 0, + }; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture)) + }?; + texture.unwrap() + }; + + let device_context = &gpu_state.device_context; + unsafe { device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP) }; + unsafe { device_context.VSSetShader(&gpu_state.vertex_shader, None) }; + unsafe { device_context.PSSetShader(&gpu_state.pixel_shader, None) }; + unsafe { device_context.VSSetConstantBuffers(0, Some(¶ms_buffer)) }; + unsafe { device_context.PSSetConstantBuffers(0, Some(¶ms_buffer)) }; + unsafe { device_context.OMSetRenderTargets(Some(&render_target_view), None) }; + unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) }; + unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) }; + + for layer in glyph_layers { + let params = GlyphLayerTextureParams { + run_color: layer.run_color, + bounds: layer.bounds, + }; + unsafe { + let mut dest = std::mem::zeroed(); + gpu_state.device_context.Map( + params_buffer[0].as_ref().unwrap(), 0, + D3D11_MAP_WRITE_DISCARD, + 0, + Some(&mut dest), )?; - while enumerator.MoveNext().is_ok() { - let Ok(color_glyph) = enumerator.GetCurrentRun() else { - break; - }; - let color_glyph = &*color_glyph; - let brush_color = translate_color(&color_glyph.Base.runColor); - brush.SetColor(&brush_color); - match color_glyph.glyphImageFormat { - DWRITE_GLYPH_IMAGE_FORMATS_PNG - | DWRITE_GLYPH_IMAGE_FORMATS_JPEG - | DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8 => render_target - .DrawColorBitmapGlyphRun( - color_glyph.glyphImageFormat, - baseline_origin, - &color_glyph.Base.glyphRun, - color_glyph.measuringMode, - D2D1_COLOR_BITMAP_GLYPH_SNAP_OPTION_DEFAULT, - ), - DWRITE_GLYPH_IMAGE_FORMATS_SVG => render_target.DrawSvgGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - &brush, - None, - color_glyph.Base.paletteIndex as u32, - color_glyph.measuringMode, - ), - _ => render_target.DrawGlyphRun( - baseline_origin, - &color_glyph.Base.glyphRun, - Some(color_glyph.Base.glyphRunDescription as *const _), - &brush, - color_glyph.measuringMode, - ), - } - } - } else { - render_target.DrawGlyphRun( - baseline_origin, - &glyph_run, - None, - &brush, - DWRITE_MEASURING_MODE_NATURAL, - ); - } - render_target.EndDraw(None, None)?; + std::ptr::copy_nonoverlapping(¶ms as *const _, dest.pData as *mut _, 1); + gpu_state + .device_context + .Unmap(params_buffer[0].as_ref().unwrap(), 0); + }; - let mut raw_data = vec![0u8; total_bytes]; - if params.is_emoji { - bitmap.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - // Convert from BGRA with premultiplied alpha to BGRA with straight alpha. - for pixel in raw_data.chunks_exact_mut(4) { - let a = pixel[3] as f32 / 255.; - pixel[0] = (pixel[0] as f32 / a) as u8; - pixel[1] = (pixel[1] as f32 / a) as u8; - pixel[2] = (pixel[2] as f32 / a) as u8; - } - } else { - let scaler = bitmap_factory.CreateBitmapScaler()?; - scaler.Initialize( - &bitmap, - bitmap_size.width.0 as u32, - bitmap_size.height.0 as u32, - WICBitmapInterpolationModeHighQualityCubic, - )?; - scaler.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?; - } - Ok((bitmap_size, raw_data)) + let texture = [Some(layer.texture_view)]; + unsafe { device_context.PSSetShaderResources(0, Some(&texture)) }; + + let viewport = [D3D11_VIEWPORT { + TopLeftX: layer.bounds.origin.x as f32, + TopLeftY: layer.bounds.origin.y as f32, + Width: layer.bounds.size.width as f32, + Height: layer.bounds.size.height as f32, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + + unsafe { device_context.Draw(4, 0) }; } + + unsafe { device_context.CopyResource(&staging_texture, &render_target_texture) }; + + let mapped_data = { + let mut mapped_data = D3D11_MAPPED_SUBRESOURCE::default(); + unsafe { + device_context.Map( + &staging_texture, + 0, + D3D11_MAP_READ, + 0, + Some(&mut mapped_data), + ) + }?; + mapped_data + }; + let mut rasterized = + vec![0u8; (bitmap_size.width.0 as u32 * bitmap_size.height.0 as u32 * 4) as usize]; + + for y in 0..bitmap_size.height.0 as usize { + let width = bitmap_size.width.0 as usize; + unsafe { + std::ptr::copy_nonoverlapping::( + (mapped_data.pData as *const u8).byte_add(mapped_data.RowPitch as usize * y), + rasterized + .as_mut_ptr() + .byte_add(width * y * std::mem::size_of::()), + width * std::mem::size_of::(), + ) + }; + } + + Ok(rasterized) } fn get_typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result> { @@ -976,6 +1276,84 @@ impl Drop for DirectWriteState { } } +struct GlyphLayerTexture { + run_color: Rgba, + bounds: Bounds, + texture_view: ID3D11ShaderResourceView, + // holding on to the texture to not RAII drop it + _texture: ID3D11Texture2D, +} + +impl GlyphLayerTexture { + pub fn new( + gpu_state: &GPUState, + run_color: Rgba, + bounds: Bounds, + alpha_data: &[u8], + ) -> Result { + let texture_size = bounds.size; + + let desc = D3D11_TEXTURE2D_DESC { + Width: texture_size.width as u32, + Height: texture_size.height as u32, + MipLevels: 1, + ArraySize: 1, + Format: DXGI_FORMAT_R8G8B8A8_UNORM, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + + let texture = { + let mut texture: Option = None; + unsafe { + gpu_state + .device + .CreateTexture2D(&desc, None, Some(&mut texture))? + }; + texture.unwrap() + }; + let texture_view = { + let mut view: Option = None; + unsafe { + gpu_state + .device + .CreateShaderResourceView(&texture, None, Some(&mut view))? + }; + view.unwrap() + }; + + unsafe { + gpu_state.device_context.UpdateSubresource( + &texture, + 0, + None, + alpha_data.as_ptr() as _, + (texture_size.width * 4) as u32, + 0, + ) + }; + + Ok(GlyphLayerTexture { + run_color, + bounds, + texture_view, + _texture: texture, + }) + } +} + +#[repr(C)] +struct GlyphLayerTextureParams { + bounds: Bounds, + run_color: Rgba, +} + struct TextRendererWrapper(pub IDWriteTextRenderer); impl TextRendererWrapper { @@ -1470,16 +1848,6 @@ fn get_name(string: IDWriteLocalizedStrings, locale: &str) -> Result { Ok(String::from_utf16_lossy(&name_vec[..name_length])) } -#[inline] -fn translate_color(color: &DWRITE_COLOR_F) -> D2D1_COLOR_F { - D2D1_COLOR_F { - r: color.r, - g: color.g, - b: color.b, - a: color.a, - } -} - fn get_system_ui_font_name() -> SharedString { unsafe { let mut info: LOGFONTW = std::mem::zeroed(); @@ -1504,24 +1872,6 @@ fn get_system_ui_font_name() -> SharedString { } } -#[inline] -fn get_render_target_property( - pixel_format: DXGI_FORMAT, - alpha_mode: D2D1_ALPHA_MODE, -) -> D2D1_RENDER_TARGET_PROPERTIES { - D2D1_RENDER_TARGET_PROPERTIES { - r#type: D2D1_RENDER_TARGET_TYPE_DEFAULT, - pixelFormat: D2D1_PIXEL_FORMAT { - format: pixel_format, - alphaMode: alpha_mode, - }, - dpiX: 96.0, - dpiY: 96.0, - usage: D2D1_RENDER_TARGET_USAGE_NONE, - minLevel: D2D1_FEATURE_LEVEL_DEFAULT, - } -} - // One would think that with newer DirectWrite method: IDWriteFontFace4::GetGlyphImageFormats // but that doesn't seem to work for some glyphs, say ❤ fn is_color_glyph( @@ -1561,12 +1911,6 @@ fn is_color_glyph( } const DEFAULT_LOCALE_NAME: PCWSTR = windows::core::w!("en-US"); -const BRUSH_COLOR: D2D1_COLOR_F = D2D1_COLOR_F { - r: 1.0, - g: 1.0, - b: 1.0, - a: 1.0, -}; #[cfg(test)] mod tests { diff --git a/crates/gpui/src/platform/windows/directx_atlas.rs b/crates/gpui/src/platform/windows/directx_atlas.rs new file mode 100644 index 0000000000..6bced4c11d --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_atlas.rs @@ -0,0 +1,309 @@ +use collections::FxHashMap; +use etagere::BucketedAtlasAllocator; +use parking_lot::Mutex; +use windows::Win32::Graphics::{ + Direct3D11::{ + D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC, + D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView, + ID3D11Texture2D, + }, + Dxgi::Common::*, +}; + +use crate::{ + AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas, + Point, Size, platform::AtlasTextureList, +}; + +pub(crate) struct DirectXAtlas(Mutex); + +struct DirectXAtlasState { + device: ID3D11Device, + device_context: ID3D11DeviceContext, + monochrome_textures: AtlasTextureList, + polychrome_textures: AtlasTextureList, + tiles_by_key: FxHashMap, +} + +struct DirectXAtlasTexture { + id: AtlasTextureId, + bytes_per_pixel: u32, + allocator: BucketedAtlasAllocator, + texture: ID3D11Texture2D, + view: [Option; 1], + live_atlas_keys: u32, +} + +impl DirectXAtlas { + pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self { + DirectXAtlas(Mutex::new(DirectXAtlasState { + device: device.clone(), + device_context: device_context.clone(), + monochrome_textures: Default::default(), + polychrome_textures: Default::default(), + tiles_by_key: Default::default(), + })) + } + + pub(crate) fn get_texture_view( + &self, + id: AtlasTextureId, + ) -> [Option; 1] { + let lock = self.0.lock(); + let tex = lock.texture(id); + tex.view.clone() + } + + pub(crate) fn handle_device_lost( + &self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + ) { + let mut lock = self.0.lock(); + lock.device = device.clone(); + lock.device_context = device_context.clone(); + lock.monochrome_textures = AtlasTextureList::default(); + lock.polychrome_textures = AtlasTextureList::default(); + lock.tiles_by_key.clear(); + } +} + +impl PlatformAtlas for DirectXAtlas { + fn get_or_insert_with<'a>( + &self, + key: &AtlasKey, + build: &mut dyn FnMut() -> anyhow::Result< + Option<(Size, std::borrow::Cow<'a, [u8]>)>, + >, + ) -> anyhow::Result> { + let mut lock = self.0.lock(); + if let Some(tile) = lock.tiles_by_key.get(key) { + Ok(Some(tile.clone())) + } else { + let Some((size, bytes)) = build()? else { + return Ok(None); + }; + let tile = lock + .allocate(size, key.texture_kind()) + .ok_or_else(|| anyhow::anyhow!("failed to allocate"))?; + let texture = lock.texture(tile.texture_id); + texture.upload(&lock.device_context, tile.bounds, &bytes); + lock.tiles_by_key.insert(key.clone(), tile.clone()); + Ok(Some(tile)) + } + } + + fn remove(&self, key: &AtlasKey) { + let mut lock = self.0.lock(); + + let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else { + return; + }; + + let textures = match id.kind { + AtlasTextureKind::Monochrome => &mut lock.monochrome_textures, + AtlasTextureKind::Polychrome => &mut lock.polychrome_textures, + }; + + let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else { + return; + }; + + if let Some(mut texture) = texture_slot.take() { + texture.decrement_ref_count(); + if texture.is_unreferenced() { + textures.free_list.push(texture.id.index as usize); + lock.tiles_by_key.remove(key); + } else { + *texture_slot = Some(texture); + } + } + } +} + +impl DirectXAtlasState { + fn allocate( + &mut self, + size: Size, + texture_kind: AtlasTextureKind, + ) -> Option { + { + let textures = match texture_kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + + if let Some(tile) = textures + .iter_mut() + .rev() + .find_map(|texture| texture.allocate(size)) + { + return Some(tile); + } + } + + let texture = self.push_texture(size, texture_kind)?; + texture.allocate(size) + } + + fn push_texture( + &mut self, + min_size: Size, + kind: AtlasTextureKind, + ) -> Option<&mut DirectXAtlasTexture> { + const DEFAULT_ATLAS_SIZE: Size = Size { + width: DevicePixels(1024), + height: DevicePixels(1024), + }; + // Max texture size for DirectX. See: + // https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits + const MAX_ATLAS_SIZE: Size = Size { + width: DevicePixels(16384), + height: DevicePixels(16384), + }; + let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE); + let pixel_format; + let bind_flag; + let bytes_per_pixel; + match kind { + AtlasTextureKind::Monochrome => { + pixel_format = DXGI_FORMAT_R8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 1; + } + AtlasTextureKind::Polychrome => { + pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM; + bind_flag = D3D11_BIND_SHADER_RESOURCE; + bytes_per_pixel = 4; + } + } + let texture_desc = D3D11_TEXTURE2D_DESC { + Width: size.width.0 as u32, + Height: size.height.0 as u32, + MipLevels: 1, + ArraySize: 1, + Format: pixel_format, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: bind_flag.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: 0, + }; + let mut texture: Option = None; + unsafe { + // This only returns None if the device is lost, which we will recreate later. + // So it's ok to return None here. + self.device + .CreateTexture2D(&texture_desc, None, Some(&mut texture)) + .ok()?; + } + let texture = texture.unwrap(); + + let texture_list = match kind { + AtlasTextureKind::Monochrome => &mut self.monochrome_textures, + AtlasTextureKind::Polychrome => &mut self.polychrome_textures, + }; + let index = texture_list.free_list.pop(); + let view = unsafe { + let mut view = None; + self.device + .CreateShaderResourceView(&texture, None, Some(&mut view)) + .ok()?; + [view] + }; + let atlas_texture = DirectXAtlasTexture { + id: AtlasTextureId { + index: index.unwrap_or(texture_list.textures.len()) as u32, + kind, + }, + bytes_per_pixel, + allocator: etagere::BucketedAtlasAllocator::new(size.into()), + texture, + view, + live_atlas_keys: 0, + }; + if let Some(ix) = index { + texture_list.textures[ix] = Some(atlas_texture); + texture_list.textures.get_mut(ix).unwrap().as_mut() + } else { + texture_list.textures.push(Some(atlas_texture)); + texture_list.textures.last_mut().unwrap().as_mut() + } + } + + fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture { + let textures = match id.kind { + crate::AtlasTextureKind::Monochrome => &self.monochrome_textures, + crate::AtlasTextureKind::Polychrome => &self.polychrome_textures, + }; + textures[id.index as usize].as_ref().unwrap() + } +} + +impl DirectXAtlasTexture { + fn allocate(&mut self, size: Size) -> Option { + let allocation = self.allocator.allocate(size.into())?; + let tile = AtlasTile { + texture_id: self.id, + tile_id: allocation.id.into(), + bounds: Bounds { + origin: allocation.rectangle.min.into(), + size, + }, + padding: 0, + }; + self.live_atlas_keys += 1; + Some(tile) + } + + fn upload( + &self, + device_context: &ID3D11DeviceContext, + bounds: Bounds, + bytes: &[u8], + ) { + unsafe { + device_context.UpdateSubresource( + &self.texture, + 0, + Some(&D3D11_BOX { + left: bounds.left().0 as u32, + top: bounds.top().0 as u32, + front: 0, + right: bounds.right().0 as u32, + bottom: bounds.bottom().0 as u32, + back: 1, + }), + bytes.as_ptr() as _, + bounds.size.width.to_bytes(self.bytes_per_pixel as u8), + 0, + ); + } + } + + fn decrement_ref_count(&mut self) { + self.live_atlas_keys -= 1; + } + + fn is_unreferenced(&mut self) -> bool { + self.live_atlas_keys == 0 + } +} + +impl From> for etagere::Size { + fn from(size: Size) -> Self { + etagere::Size::new(size.width.into(), size.height.into()) + } +} + +impl From for Point { + fn from(value: etagere::Point) -> Self { + Point { + x: DevicePixels::from(value.x), + y: DevicePixels::from(value.y), + } + } +} diff --git a/crates/gpui/src/platform/windows/directx_renderer.rs b/crates/gpui/src/platform/windows/directx_renderer.rs new file mode 100644 index 0000000000..72cc12a5b4 --- /dev/null +++ b/crates/gpui/src/platform/windows/directx_renderer.rs @@ -0,0 +1,1807 @@ +use std::{mem::ManuallyDrop, sync::Arc}; + +use ::util::ResultExt; +use anyhow::{Context, Result}; +use windows::{ + Win32::{ + Foundation::{HMODULE, HWND}, + Graphics::{ + Direct3D::*, + Direct3D11::*, + DirectComposition::*, + Dxgi::{Common::*, *}, + }, + }, + core::Interface, +}; + +use crate::{ + platform::windows::directx_renderer::shader_resources::{ + RawShaderBytes, ShaderModule, ShaderTarget, + }, + *, +}; + +pub(crate) const DISABLE_DIRECT_COMPOSITION: &str = "GPUI_DISABLE_DIRECT_COMPOSITION"; +const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM; +// This configuration is used for MSAA rendering on paths only, and it's guaranteed to be supported by DirectX 11. +const PATH_MULTISAMPLE_COUNT: u32 = 4; + +pub(crate) struct DirectXRenderer { + hwnd: HWND, + atlas: Arc, + devices: ManuallyDrop, + resources: ManuallyDrop, + globals: DirectXGlobalElements, + pipelines: DirectXRenderPipelines, + direct_composition: Option, +} + +/// Direct3D objects +#[derive(Clone)] +pub(crate) struct DirectXDevices { + adapter: IDXGIAdapter1, + dxgi_factory: IDXGIFactory6, + pub(crate) device: ID3D11Device, + pub(crate) device_context: ID3D11DeviceContext, + dxgi_device: Option, +} + +struct DirectXResources { + // Direct3D rendering objects + swap_chain: IDXGISwapChain1, + render_target: ManuallyDrop, + render_target_view: [Option; 1], + + // Path intermediate textures (with MSAA) + path_intermediate_texture: ID3D11Texture2D, + path_intermediate_srv: [Option; 1], + path_intermediate_msaa_texture: ID3D11Texture2D, + path_intermediate_msaa_view: [Option; 1], + + // Cached window size and viewport + width: u32, + height: u32, + viewport: [D3D11_VIEWPORT; 1], +} + +struct DirectXRenderPipelines { + shadow_pipeline: PipelineState, + quad_pipeline: PipelineState, + path_rasterization_pipeline: PipelineState, + path_sprite_pipeline: PipelineState, + underline_pipeline: PipelineState, + mono_sprites: PipelineState, + poly_sprites: PipelineState, +} + +struct DirectXGlobalElements { + global_params_buffer: [Option; 1], + sampler: [Option; 1], +} + +struct DirectComposition { + comp_device: IDCompositionDevice, + comp_target: IDCompositionTarget, + comp_visual: IDCompositionVisual, +} + +impl DirectXDevices { + pub(crate) fn new(disable_direct_composition: bool) -> Result> { + let debug_layer_available = check_debug_layer_available(); + let dxgi_factory = + get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?; + let adapter = + get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?; + let (device, device_context) = { + let mut device: Option = None; + let mut context: Option = None; + let mut feature_level = D3D_FEATURE_LEVEL::default(); + get_device( + &adapter, + Some(&mut device), + Some(&mut context), + Some(&mut feature_level), + debug_layer_available, + ) + .context("Creating Direct3D device")?; + match feature_level { + D3D_FEATURE_LEVEL_11_1 => { + log::info!("Created device with Direct3D 11.1 feature level.") + } + D3D_FEATURE_LEVEL_11_0 => { + log::info!("Created device with Direct3D 11.0 feature level.") + } + D3D_FEATURE_LEVEL_10_1 => { + log::info!("Created device with Direct3D 10.1 feature level.") + } + _ => unreachable!(), + } + (device.unwrap(), context.unwrap()) + }; + let dxgi_device = if disable_direct_composition { + None + } else { + Some(device.cast().context("Creating DXGI device")?) + }; + + Ok(ManuallyDrop::new(Self { + adapter, + dxgi_factory, + dxgi_device, + device, + device_context, + })) + } +} + +impl DirectXRenderer { + pub(crate) fn new(hwnd: HWND, disable_direct_composition: bool) -> Result { + if disable_direct_composition { + log::info!("Direct Composition is disabled."); + } + + let devices = + DirectXDevices::new(disable_direct_composition).context("Creating DirectX devices")?; + let atlas = Arc::new(DirectXAtlas::new(&devices.device, &devices.device_context)); + + let resources = DirectXResources::new(&devices, 1, 1, hwnd, disable_direct_composition) + .context("Creating DirectX resources")?; + let globals = DirectXGlobalElements::new(&devices.device) + .context("Creating DirectX global elements")?; + let pipelines = DirectXRenderPipelines::new(&devices.device) + .context("Creating DirectX render pipelines")?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), hwnd) + .context("Creating DirectComposition")?; + composition + .set_swap_chain(&resources.swap_chain) + .context("Setting swap chain for DirectComposition")?; + Some(composition) + }; + + Ok(DirectXRenderer { + hwnd, + atlas, + devices, + resources, + globals, + pipelines, + direct_composition, + }) + } + + pub(crate) fn sprite_atlas(&self) -> Arc { + self.atlas.clone() + } + + fn pre_draw(&self) -> Result<()> { + update_buffer( + &self.devices.device_context, + self.globals.global_params_buffer[0].as_ref().unwrap(), + &[GlobalParams { + viewport_size: [ + self.resources.viewport[0].Width, + self.resources.viewport[0].Height, + ], + _pad: 0, + }], + )?; + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.render_target_view[0].as_ref().unwrap(), + &[0.0; 4], + ); + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + self.devices + .device_context + .RSSetViewports(Some(&self.resources.viewport)); + } + Ok(()) + } + + fn present(&mut self) -> Result<()> { + unsafe { + let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0)); + // Presenting the swap chain can fail if the DirectX device was removed or reset. + if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when drawing. Reason: {:?}", + reason + ); + self.handle_device_lost()?; + } else { + result.ok()?; + } + } + Ok(()) + } + + fn handle_device_lost(&mut self) -> Result<()> { + // Here we wait a bit to ensure the the system has time to recover from the device lost state. + // If we don't wait, the final drawing result will be blank. + std::thread::sleep(std::time::Duration::from_millis(300)); + let disable_direct_composition = self.direct_composition.is_none(); + + unsafe { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device) + .context("Failed to report live objects after device lost") + .log_err(); + + ManuallyDrop::drop(&mut self.resources); + self.devices.device_context.OMSetRenderTargets(None, None); + self.devices.device_context.ClearState(); + self.devices.device_context.Flush(); + + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device) + .context("Failed to report live objects after device lost") + .log_err(); + + drop(self.direct_composition.take()); + ManuallyDrop::drop(&mut self.devices); + } + + let devices = DirectXDevices::new(disable_direct_composition) + .context("Recreating DirectX devices")?; + let resources = DirectXResources::new( + &devices, + self.resources.width, + self.resources.height, + self.hwnd, + disable_direct_composition, + )?; + let globals = DirectXGlobalElements::new(&devices.device)?; + let pipelines = DirectXRenderPipelines::new(&devices.device)?; + + let direct_composition = if disable_direct_composition { + None + } else { + let composition = + DirectComposition::new(devices.dxgi_device.as_ref().unwrap(), self.hwnd)?; + composition.set_swap_chain(&resources.swap_chain)?; + Some(composition) + }; + + self.atlas + .handle_device_lost(&devices.device, &devices.device_context); + self.devices = devices; + self.resources = resources; + self.globals = globals; + self.pipelines = pipelines; + self.direct_composition = direct_composition; + + unsafe { + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + pub(crate) fn draw(&mut self, scene: &Scene) -> Result<()> { + self.pre_draw()?; + for batch in scene.batches() { + match batch { + PrimitiveBatch::Shadows(shadows) => self.draw_shadows(shadows), + PrimitiveBatch::Quads(quads) => self.draw_quads(quads), + PrimitiveBatch::Paths(paths) => { + self.draw_paths_to_intermediate(paths)?; + self.draw_paths_from_intermediate(paths) + } + PrimitiveBatch::Underlines(underlines) => self.draw_underlines(underlines), + PrimitiveBatch::MonochromeSprites { + texture_id, + sprites, + } => self.draw_monochrome_sprites(texture_id, sprites), + PrimitiveBatch::PolychromeSprites { + texture_id, + sprites, + } => self.draw_polychrome_sprites(texture_id, sprites), + PrimitiveBatch::Surfaces(surfaces) => self.draw_surfaces(surfaces), + }.context(format!("scene too large: {} paths, {} shadows, {} quads, {} underlines, {} mono, {} poly, {} surfaces", + scene.paths.len(), + scene.shadows.len(), + scene.quads.len(), + scene.underlines.len(), + scene.monochrome_sprites.len(), + scene.polychrome_sprites.len(), + scene.surfaces.len(),))?; + } + self.present() + } + + pub(crate) fn resize(&mut self, new_size: Size) -> Result<()> { + let width = new_size.width.0.max(1) as u32; + let height = new_size.height.0.max(1) as u32; + if self.resources.width == width && self.resources.height == height { + return Ok(()); + } + unsafe { + // Clear the render target before resizing + self.devices.device_context.OMSetRenderTargets(None, None); + ManuallyDrop::drop(&mut self.resources.render_target); + drop(self.resources.render_target_view[0].take().unwrap()); + + let result = self.resources.swap_chain.ResizeBuffers( + BUFFER_COUNT as u32, + width, + height, + RENDER_TARGET_FORMAT, + DXGI_SWAP_CHAIN_FLAG(0), + ); + // Resizing the swap chain requires a call to the underlying DXGI adapter, which can return the device removed error. + // The app might have moved to a monitor that's attached to a different graphics device. + // When a graphics device is removed or reset, the desktop resolution often changes, resulting in a window size change. + match result { + Ok(_) => {} + Err(e) => { + if e.code() == DXGI_ERROR_DEVICE_REMOVED || e.code() == DXGI_ERROR_DEVICE_RESET + { + let reason = self.devices.device.GetDeviceRemovedReason(); + log::error!( + "DirectX device removed or reset when resizing. Reason: {:?}", + reason + ); + self.resources.width = width; + self.resources.height = height; + self.handle_device_lost()?; + return Ok(()); + } else { + log::error!("Failed to resize swap chain: {:?}", e); + return Err(e.into()); + } + } + } + + self.resources + .recreate_resources(&self.devices, width, height)?; + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + Ok(()) + } + + fn draw_shadows(&mut self, shadows: &[Shadow]) -> Result<()> { + if shadows.is_empty() { + return Ok(()); + } + self.pipelines.shadow_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + shadows, + )?; + self.pipelines.shadow_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + shadows.len() as u32, + ) + } + + fn draw_quads(&mut self, quads: &[Quad]) -> Result<()> { + if quads.is_empty() { + return Ok(()); + } + self.pipelines.quad_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + quads, + )?; + self.pipelines.quad_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + quads.len() as u32, + ) + } + + fn draw_paths_to_intermediate(&mut self, paths: &[Path]) -> Result<()> { + if paths.is_empty() { + return Ok(()); + } + + // Clear intermediate MSAA texture + unsafe { + self.devices.device_context.ClearRenderTargetView( + self.resources.path_intermediate_msaa_view[0] + .as_ref() + .unwrap(), + &[0.0; 4], + ); + // Set intermediate MSAA texture as render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.path_intermediate_msaa_view), None); + } + + // Collect all vertices and sprites for a single draw call + let mut vertices = Vec::new(); + + for path in paths { + vertices.extend(path.vertices.iter().map(|v| PathRasterizationSprite { + xy_position: v.xy_position, + st_position: v.st_position, + color: path.color, + bounds: path.bounds.intersect(&path.content_mask.bounds), + })); + } + + self.pipelines.path_rasterization_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &vertices, + )?; + self.pipelines.path_rasterization_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST, + vertices.len() as u32, + 1, + )?; + + // Resolve MSAA to non-MSAA intermediate texture + unsafe { + self.devices.device_context.ResolveSubresource( + &self.resources.path_intermediate_texture, + 0, + &self.resources.path_intermediate_msaa_texture, + 0, + RENDER_TARGET_FORMAT, + ); + // Restore main render target + self.devices + .device_context + .OMSetRenderTargets(Some(&self.resources.render_target_view), None); + } + + Ok(()) + } + + fn draw_paths_from_intermediate(&mut self, paths: &[Path]) -> Result<()> { + let Some(first_path) = paths.first() else { + return Ok(()); + }; + + // When copying paths from the intermediate texture to the drawable, + // each pixel must only be copied once, in case of transparent paths. + // + // If all paths have the same draw order, then their bounds are all + // disjoint, so we can copy each path's bounds individually. If this + // batch combines different draw orders, we perform a single copy + // for a minimal spanning rect. + let sprites = if paths.last().unwrap().order == first_path.order { + paths + .iter() + .map(|path| PathSprite { + bounds: path.bounds, + }) + .collect::>() + } else { + let mut bounds = first_path.bounds; + for path in paths.iter().skip(1) { + bounds = bounds.union(&path.bounds); + } + vec![PathSprite { bounds }] + }; + + self.pipelines.path_sprite_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + &sprites, + )?; + + // Draw the sprites with the path texture + self.pipelines.path_sprite_pipeline.draw_with_texture( + &self.devices.device_context, + &self.resources.path_intermediate_srv, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_underlines(&mut self, underlines: &[Underline]) -> Result<()> { + if underlines.is_empty() { + return Ok(()); + } + self.pipelines.underline_pipeline.update_buffer( + &self.devices.device, + &self.devices.device_context, + underlines, + )?; + self.pipelines.underline_pipeline.draw( + &self.devices.device_context, + &self.resources.viewport, + &self.globals.global_params_buffer, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + 4, + underlines.len() as u32, + ) + } + + fn draw_monochrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[MonochromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.mono_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.mono_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_polychrome_sprites( + &mut self, + texture_id: AtlasTextureId, + sprites: &[PolychromeSprite], + ) -> Result<()> { + if sprites.is_empty() { + return Ok(()); + } + self.pipelines.poly_sprites.update_buffer( + &self.devices.device, + &self.devices.device_context, + sprites, + )?; + let texture_view = self.atlas.get_texture_view(texture_id); + self.pipelines.poly_sprites.draw_with_texture( + &self.devices.device_context, + &texture_view, + &self.resources.viewport, + &self.globals.global_params_buffer, + &self.globals.sampler, + sprites.len() as u32, + ) + } + + fn draw_surfaces(&mut self, surfaces: &[PaintSurface]) -> Result<()> { + if surfaces.is_empty() { + return Ok(()); + } + Ok(()) + } + + pub(crate) fn gpu_specs(&self) -> Result { + let desc = unsafe { self.devices.adapter.GetDesc1() }?; + let is_software_emulated = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE.0 as u32) != 0; + let device_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + let driver_name = match desc.VendorId { + 0x10DE => "NVIDIA Corporation".to_string(), + 0x1002 => "AMD Corporation".to_string(), + 0x8086 => "Intel Corporation".to_string(), + id => format!("Unknown Vendor (ID: {:#X})", id), + }; + let driver_version = match desc.VendorId { + 0x10DE => nvidia::get_driver_version(), + 0x1002 => amd::get_driver_version(), + // For Intel and other vendors, we use the DXGI API to get the driver version. + _ => dxgi::get_driver_version(&self.devices.adapter), + } + .context("Failed to get gpu driver info") + .log_err() + .unwrap_or("Unknown Driver".to_string()); + Ok(GpuSpecs { + is_software_emulated, + device_name, + driver_name, + driver_info: driver_version, + }) + } +} + +impl DirectXResources { + pub fn new( + devices: &DirectXDevices, + width: u32, + height: u32, + hwnd: HWND, + disable_direct_composition: bool, + ) -> Result> { + let swap_chain = if disable_direct_composition { + create_swap_chain(&devices.dxgi_factory, &devices.device, hwnd, width, height)? + } else { + create_swap_chain_for_composition( + &devices.dxgi_factory, + &devices.device, + width, + height, + )? + }; + + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &swap_chain, width, height)?; + set_rasterizer_state(&devices.device, &devices.device_context)?; + + Ok(ManuallyDrop::new(Self { + swap_chain, + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + path_intermediate_srv, + viewport, + width, + height, + })) + } + + #[inline] + fn recreate_resources( + &mut self, + devices: &DirectXDevices, + width: u32, + height: u32, + ) -> Result<()> { + let ( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + ) = create_resources(devices, &self.swap_chain, width, height)?; + self.render_target = render_target; + self.render_target_view = render_target_view; + self.path_intermediate_texture = path_intermediate_texture; + self.path_intermediate_msaa_texture = path_intermediate_msaa_texture; + self.path_intermediate_msaa_view = path_intermediate_msaa_view; + self.path_intermediate_srv = path_intermediate_srv; + self.viewport = viewport; + self.width = width; + self.height = height; + Ok(()) + } +} + +impl DirectXRenderPipelines { + pub fn new(device: &ID3D11Device) -> Result { + let shadow_pipeline = PipelineState::new( + device, + "shadow_pipeline", + ShaderModule::Shadow, + 4, + create_blend_state(device)?, + )?; + let quad_pipeline = PipelineState::new( + device, + "quad_pipeline", + ShaderModule::Quad, + 64, + create_blend_state(device)?, + )?; + let path_rasterization_pipeline = PipelineState::new( + device, + "path_rasterization_pipeline", + ShaderModule::PathRasterization, + 32, + create_blend_state_for_path_rasterization(device)?, + )?; + let path_sprite_pipeline = PipelineState::new( + device, + "path_sprite_pipeline", + ShaderModule::PathSprite, + 4, + create_blend_state_for_path_sprite(device)?, + )?; + let underline_pipeline = PipelineState::new( + device, + "underline_pipeline", + ShaderModule::Underline, + 4, + create_blend_state(device)?, + )?; + let mono_sprites = PipelineState::new( + device, + "monochrome_sprite_pipeline", + ShaderModule::MonochromeSprite, + 512, + create_blend_state(device)?, + )?; + let poly_sprites = PipelineState::new( + device, + "polychrome_sprite_pipeline", + ShaderModule::PolychromeSprite, + 16, + create_blend_state(device)?, + )?; + + Ok(Self { + shadow_pipeline, + quad_pipeline, + path_rasterization_pipeline, + path_sprite_pipeline, + underline_pipeline, + mono_sprites, + poly_sprites, + }) + } +} + +impl DirectComposition { + pub fn new(dxgi_device: &IDXGIDevice, hwnd: HWND) -> Result { + let comp_device = get_comp_device(&dxgi_device)?; + let comp_target = unsafe { comp_device.CreateTargetForHwnd(hwnd, true) }?; + let comp_visual = unsafe { comp_device.CreateVisual() }?; + + Ok(Self { + comp_device, + comp_target, + comp_visual, + }) + } + + pub fn set_swap_chain(&self, swap_chain: &IDXGISwapChain1) -> Result<()> { + unsafe { + self.comp_visual.SetContent(swap_chain)?; + self.comp_target.SetRoot(&self.comp_visual)?; + self.comp_device.Commit()?; + } + Ok(()) + } +} + +impl DirectXGlobalElements { + pub fn new(device: &ID3D11Device) -> Result { + let global_params_buffer = unsafe { + let desc = D3D11_BUFFER_DESC { + ByteWidth: std::mem::size_of::() as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + ..Default::default() + }; + let mut buffer = None; + device.CreateBuffer(&desc, None, Some(&mut buffer))?; + [buffer] + }; + + let sampler = unsafe { + let desc = D3D11_SAMPLER_DESC { + Filter: D3D11_FILTER_MIN_MAG_MIP_LINEAR, + AddressU: D3D11_TEXTURE_ADDRESS_WRAP, + AddressV: D3D11_TEXTURE_ADDRESS_WRAP, + AddressW: D3D11_TEXTURE_ADDRESS_WRAP, + MipLODBias: 0.0, + MaxAnisotropy: 1, + ComparisonFunc: D3D11_COMPARISON_ALWAYS, + BorderColor: [0.0; 4], + MinLOD: 0.0, + MaxLOD: D3D11_FLOAT32_MAX, + }; + let mut output = None; + device.CreateSamplerState(&desc, Some(&mut output))?; + [output] + }; + + Ok(Self { + global_params_buffer, + sampler, + }) + } +} + +#[derive(Debug, Default)] +#[repr(C)] +struct GlobalParams { + viewport_size: [f32; 2], + _pad: u64, +} + +struct PipelineState { + label: &'static str, + vertex: ID3D11VertexShader, + fragment: ID3D11PixelShader, + buffer: ID3D11Buffer, + buffer_size: usize, + view: [Option; 1], + blend_state: ID3D11BlendState, + _marker: std::marker::PhantomData, +} + +impl PipelineState { + fn new( + device: &ID3D11Device, + label: &'static str, + shader_module: ShaderModule, + buffer_size: usize, + blend_state: ID3D11BlendState, + ) -> Result { + let vertex = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Vertex)?; + create_vertex_shader(device, raw_shader.as_bytes())? + }; + let fragment = { + let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Fragment)?; + create_fragment_shader(device, raw_shader.as_bytes())? + }; + let buffer = create_buffer(device, std::mem::size_of::(), buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + + Ok(PipelineState { + label, + vertex, + fragment, + buffer, + buffer_size, + view, + blend_state, + _marker: std::marker::PhantomData, + }) + } + + fn update_buffer( + &mut self, + device: &ID3D11Device, + device_context: &ID3D11DeviceContext, + data: &[T], + ) -> Result<()> { + if self.buffer_size < data.len() { + let new_buffer_size = data.len().next_power_of_two(); + log::info!( + "Updating {} buffer size from {} to {}", + self.label, + self.buffer_size, + new_buffer_size + ); + let buffer = create_buffer(device, std::mem::size_of::(), new_buffer_size)?; + let view = create_buffer_view(device, &buffer)?; + self.buffer = buffer; + self.view = view; + self.buffer_size = new_buffer_size; + } + update_buffer(device_context, &self.buffer, data) + } + + fn draw( + &self, + device_context: &ID3D11DeviceContext, + viewport: &[D3D11_VIEWPORT], + global_params: &[Option], + topology: D3D_PRIMITIVE_TOPOLOGY, + vertex_count: u32, + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + topology, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.DrawInstanced(vertex_count, instance_count, 0, 0); + } + Ok(()) + } + + fn draw_with_texture( + &self, + device_context: &ID3D11DeviceContext, + texture: &[Option], + viewport: &[D3D11_VIEWPORT], + global_params: &[Option], + sampler: &[Option], + instance_count: u32, + ) -> Result<()> { + set_pipeline_state( + device_context, + &self.view, + D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP, + viewport, + &self.vertex, + &self.fragment, + global_params, + &self.blend_state, + ); + unsafe { + device_context.PSSetSamplers(0, Some(sampler)); + device_context.VSSetShaderResources(0, Some(texture)); + device_context.PSSetShaderResources(0, Some(texture)); + + device_context.DrawInstanced(4, instance_count, 0, 0); + } + Ok(()) + } +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathRasterizationSprite { + xy_position: Point, + st_position: Point, + color: Background, + bounds: Bounds, +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct PathSprite { + bounds: Bounds, +} + +impl Drop for DirectXRenderer { + fn drop(&mut self) { + #[cfg(debug_assertions)] + report_live_objects(&self.devices.device).ok(); + unsafe { + ManuallyDrop::drop(&mut self.devices); + ManuallyDrop::drop(&mut self.resources); + } + } +} + +impl Drop for DirectXResources { + fn drop(&mut self) { + unsafe { + ManuallyDrop::drop(&mut self.render_target); + } + } +} + +#[inline] +fn check_debug_layer_available() -> bool { + #[cfg(debug_assertions)] + { + unsafe { DXGIGetDebugInterface1::(0) } + .log_err() + .is_some() + } + #[cfg(not(debug_assertions))] + { + false + } +} + +#[inline] +fn get_dxgi_factory(debug_layer_available: bool) -> Result { + let factory_flag = if debug_layer_available { + DXGI_CREATE_FACTORY_DEBUG + } else { + #[cfg(debug_assertions)] + log::warn!( + "Failed to get DXGI debug interface. DirectX debugging features will be disabled." + ); + DXGI_CREATE_FACTORY_FLAGS::default() + }; + unsafe { Ok(CreateDXGIFactory2(factory_flag)?) } +} + +fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result { + for adapter_index in 0.. { + let adapter: IDXGIAdapter1 = unsafe { + dxgi_factory + .EnumAdapterByGpuPreference(adapter_index, DXGI_GPU_PREFERENCE_MINIMUM_POWER) + }?; + if let Ok(desc) = unsafe { adapter.GetDesc1() } { + let gpu_name = String::from_utf16_lossy(&desc.Description) + .trim_matches(char::from(0)) + .to_string(); + log::info!("Using GPU: {}", gpu_name); + } + // Check to see whether the adapter supports Direct3D 11, but don't + // create the actual device yet. + if get_device(&adapter, None, None, None, debug_layer_available) + .log_err() + .is_some() + { + return Ok(adapter); + } + } + + unreachable!() +} + +fn get_device( + adapter: &IDXGIAdapter1, + device: Option<*mut Option>, + context: Option<*mut Option>, + feature_level: Option<*mut D3D_FEATURE_LEVEL>, + debug_layer_available: bool, +) -> Result<()> { + let device_flags = if debug_layer_available { + D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG + } else { + D3D11_CREATE_DEVICE_BGRA_SUPPORT + }; + unsafe { + D3D11CreateDevice( + adapter, + D3D_DRIVER_TYPE_UNKNOWN, + HMODULE::default(), + device_flags, + // 4x MSAA is required for Direct3D Feature Level 10.1 or better + Some(&[ + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_10_1, + ]), + D3D11_SDK_VERSION, + device, + feature_level, + context, + )?; + } + Ok(()) +} + +#[inline] +fn get_comp_device(dxgi_device: &IDXGIDevice) -> Result { + Ok(unsafe { DCompositionCreateDevice(dxgi_device)? }) +} + +fn create_swap_chain_for_composition( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result { + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + // Composition SwapChains only support the DXGI_SCALING_STRETCH Scaling. + Scaling: DXGI_SCALING_STRETCH, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_PREMULTIPLIED, + Flags: 0, + }; + Ok(unsafe { dxgi_factory.CreateSwapChainForComposition(device, &desc, None)? }) +} + +fn create_swap_chain( + dxgi_factory: &IDXGIFactory6, + device: &ID3D11Device, + hwnd: HWND, + width: u32, + height: u32, +) -> Result { + use windows::Win32::Graphics::Dxgi::DXGI_MWA_NO_ALT_ENTER; + + let desc = DXGI_SWAP_CHAIN_DESC1 { + Width: width, + Height: height, + Format: RENDER_TARGET_FORMAT, + Stereo: false.into(), + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + BufferUsage: DXGI_USAGE_RENDER_TARGET_OUTPUT, + BufferCount: BUFFER_COUNT as u32, + Scaling: DXGI_SCALING_NONE, + SwapEffect: DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL, + AlphaMode: DXGI_ALPHA_MODE_IGNORE, + Flags: 0, + }; + let swap_chain = + unsafe { dxgi_factory.CreateSwapChainForHwnd(device, hwnd, &desc, None, None) }?; + unsafe { dxgi_factory.MakeWindowAssociation(hwnd, DXGI_MWA_NO_ALT_ENTER) }?; + Ok(swap_chain) +} + +#[inline] +fn create_resources( + devices: &DirectXDevices, + swap_chain: &IDXGISwapChain1, + width: u32, + height: u32, +) -> Result<( + ManuallyDrop, + [Option; 1], + ID3D11Texture2D, + [Option; 1], + ID3D11Texture2D, + [Option; 1], + [D3D11_VIEWPORT; 1], +)> { + let (render_target, render_target_view) = + create_render_target_and_its_view(&swap_chain, &devices.device)?; + let (path_intermediate_texture, path_intermediate_srv) = + create_path_intermediate_texture(&devices.device, width, height)?; + let (path_intermediate_msaa_texture, path_intermediate_msaa_view) = + create_path_intermediate_msaa_texture_and_view(&devices.device, width, height)?; + let viewport = set_viewport(&devices.device_context, width as f32, height as f32); + Ok(( + render_target, + render_target_view, + path_intermediate_texture, + path_intermediate_srv, + path_intermediate_msaa_texture, + path_intermediate_msaa_view, + viewport, + )) +} + +#[inline] +fn create_render_target_and_its_view( + swap_chain: &IDXGISwapChain1, + device: &ID3D11Device, +) -> Result<( + ManuallyDrop, + [Option; 1], +)> { + let render_target: ID3D11Texture2D = unsafe { swap_chain.GetBuffer(0) }?; + let mut render_target_view = None; + unsafe { device.CreateRenderTargetView(&render_target, None, Some(&mut render_target_view))? }; + Ok(( + ManuallyDrop::new(render_target), + [Some(render_target_view.unwrap())], + )) +} + +#[inline] +fn create_path_intermediate_texture( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option; 1])> { + let texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: (D3D11_BIND_RENDER_TARGET.0 | D3D11_BIND_SHADER_RESOURCE.0) as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + + let mut shader_resource_view = None; + unsafe { device.CreateShaderResourceView(&texture, None, Some(&mut shader_resource_view))? }; + + Ok((texture, [Some(shader_resource_view.unwrap())])) +} + +#[inline] +fn create_path_intermediate_msaa_texture_and_view( + device: &ID3D11Device, + width: u32, + height: u32, +) -> Result<(ID3D11Texture2D, [Option; 1])> { + let msaa_texture = unsafe { + let mut output = None; + let desc = D3D11_TEXTURE2D_DESC { + Width: width, + Height: height, + MipLevels: 1, + ArraySize: 1, + Format: RENDER_TARGET_FORMAT, + SampleDesc: DXGI_SAMPLE_DESC { + Count: PATH_MULTISAMPLE_COUNT, + Quality: D3D11_STANDARD_MULTISAMPLE_PATTERN.0 as u32, + }, + Usage: D3D11_USAGE_DEFAULT, + BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32, + CPUAccessFlags: 0, + MiscFlags: 0, + }; + device.CreateTexture2D(&desc, None, Some(&mut output))?; + output.unwrap() + }; + let mut msaa_view = None; + unsafe { device.CreateRenderTargetView(&msaa_texture, None, Some(&mut msaa_view))? }; + Ok((msaa_texture, [Some(msaa_view.unwrap())])) +} + +#[inline] +fn set_viewport( + device_context: &ID3D11DeviceContext, + width: f32, + height: f32, +) -> [D3D11_VIEWPORT; 1] { + let viewport = [D3D11_VIEWPORT { + TopLeftX: 0.0, + TopLeftY: 0.0, + Width: width, + Height: height, + MinDepth: 0.0, + MaxDepth: 1.0, + }]; + unsafe { device_context.RSSetViewports(Some(&viewport)) }; + viewport +} + +#[inline] +fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Result<()> { + let desc = D3D11_RASTERIZER_DESC { + FillMode: D3D11_FILL_SOLID, + CullMode: D3D11_CULL_NONE, + FrontCounterClockwise: false.into(), + DepthBias: 0, + DepthBiasClamp: 0.0, + SlopeScaledDepthBias: 0.0, + DepthClipEnable: true.into(), + ScissorEnable: false.into(), + MultisampleEnable: true.into(), + AntialiasedLineEnable: false.into(), + }; + let rasterizer_state = unsafe { + let mut state = None; + device.CreateRasterizerState(&desc, Some(&mut state))?; + state.unwrap() + }; + unsafe { device_context.RSSetState(&rasterizer_state) }; + Ok(()) +} + +// https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_blend_desc +#[inline] +fn create_blend_state(device: &ID3D11Device) -> Result { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_SRC_ALPHA; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_rasterization(device: &ID3D11Device) -> Result { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_blend_state_for_path_sprite(device: &ID3D11Device) -> Result { + // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display + // device performs the blend in linear space, which is ideal. + let mut desc = D3D11_BLEND_DESC::default(); + desc.RenderTarget[0].BlendEnable = true.into(); + desc.RenderTarget[0].BlendOp = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].BlendOpAlpha = D3D11_BLEND_OP_ADD; + desc.RenderTarget[0].SrcBlend = D3D11_BLEND_ONE; + desc.RenderTarget[0].SrcBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].DestBlend = D3D11_BLEND_INV_SRC_ALPHA; + desc.RenderTarget[0].DestBlendAlpha = D3D11_BLEND_ONE; + desc.RenderTarget[0].RenderTargetWriteMask = D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8; + unsafe { + let mut state = None; + device.CreateBlendState(&desc, Some(&mut state))?; + Ok(state.unwrap()) + } +} + +#[inline] +fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result { + unsafe { + let mut shader = None; + device.CreateVertexShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result { + unsafe { + let mut shader = None; + device.CreatePixelShader(bytes, None, Some(&mut shader))?; + Ok(shader.unwrap()) + } +} + +#[inline] +fn create_buffer( + device: &ID3D11Device, + element_size: usize, + buffer_size: usize, +) -> Result { + let desc = D3D11_BUFFER_DESC { + ByteWidth: (element_size * buffer_size) as u32, + Usage: D3D11_USAGE_DYNAMIC, + BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32, + CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32, + MiscFlags: D3D11_RESOURCE_MISC_BUFFER_STRUCTURED.0 as u32, + StructureByteStride: element_size as u32, + }; + let mut buffer = None; + unsafe { device.CreateBuffer(&desc, None, Some(&mut buffer)) }?; + Ok(buffer.unwrap()) +} + +#[inline] +fn create_buffer_view( + device: &ID3D11Device, + buffer: &ID3D11Buffer, +) -> Result<[Option; 1]> { + let mut view = None; + unsafe { device.CreateShaderResourceView(buffer, None, Some(&mut view)) }?; + Ok([view]) +} + +#[inline] +fn update_buffer( + device_context: &ID3D11DeviceContext, + buffer: &ID3D11Buffer, + data: &[T], +) -> Result<()> { + unsafe { + let mut dest = std::mem::zeroed(); + device_context.Map(buffer, 0, D3D11_MAP_WRITE_DISCARD, 0, Some(&mut dest))?; + std::ptr::copy_nonoverlapping(data.as_ptr(), dest.pData as _, data.len()); + device_context.Unmap(buffer, 0); + } + Ok(()) +} + +#[inline] +fn set_pipeline_state( + device_context: &ID3D11DeviceContext, + buffer_view: &[Option], + topology: D3D_PRIMITIVE_TOPOLOGY, + viewport: &[D3D11_VIEWPORT], + vertex_shader: &ID3D11VertexShader, + fragment_shader: &ID3D11PixelShader, + global_params: &[Option], + blend_state: &ID3D11BlendState, +) { + unsafe { + device_context.VSSetShaderResources(1, Some(buffer_view)); + device_context.PSSetShaderResources(1, Some(buffer_view)); + device_context.IASetPrimitiveTopology(topology); + device_context.RSSetViewports(Some(viewport)); + device_context.VSSetShader(vertex_shader, None); + device_context.PSSetShader(fragment_shader, None); + device_context.VSSetConstantBuffers(0, Some(global_params)); + device_context.PSSetConstantBuffers(0, Some(global_params)); + device_context.OMSetBlendState(blend_state, None, 0xFFFFFFFF); + } +} + +#[cfg(debug_assertions)] +fn report_live_objects(device: &ID3D11Device) -> Result<()> { + let debug_device: ID3D11Debug = device.cast()?; + unsafe { + debug_device.ReportLiveDeviceObjects(D3D11_RLDO_DETAIL)?; + } + Ok(()) +} + +const BUFFER_COUNT: usize = 3; + +pub(crate) mod shader_resources { + use anyhow::Result; + + #[cfg(debug_assertions)] + use windows::{ + Win32::Graphics::Direct3D::{ + Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile}, + ID3DBlob, + }, + core::{HSTRING, PCSTR}, + }; + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderModule { + Quad, + Shadow, + Underline, + PathRasterization, + PathSprite, + MonochromeSprite, + PolychromeSprite, + EmojiRasterization, + } + + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(crate) enum ShaderTarget { + Vertex, + Fragment, + } + + pub(crate) struct RawShaderBytes<'t> { + inner: &'t [u8], + + #[cfg(debug_assertions)] + _blob: ID3DBlob, + } + + impl<'t> RawShaderBytes<'t> { + pub(crate) fn new(module: ShaderModule, target: ShaderTarget) -> Result { + #[cfg(not(debug_assertions))] + { + Ok(Self::from_bytes(module, target)) + } + #[cfg(debug_assertions)] + { + let blob = build_shader_blob(module, target)?; + let inner = unsafe { + std::slice::from_raw_parts( + blob.GetBufferPointer() as *const u8, + blob.GetBufferSize(), + ) + }; + Ok(Self { inner, _blob: blob }) + } + } + + pub(crate) fn as_bytes(&'t self) -> &'t [u8] { + self.inner + } + + #[cfg(not(debug_assertions))] + fn from_bytes(module: ShaderModule, target: ShaderTarget) -> Self { + let bytes = match module { + ShaderModule::Quad => match target { + ShaderTarget::Vertex => QUAD_VERTEX_BYTES, + ShaderTarget::Fragment => QUAD_FRAGMENT_BYTES, + }, + ShaderModule::Shadow => match target { + ShaderTarget::Vertex => SHADOW_VERTEX_BYTES, + ShaderTarget::Fragment => SHADOW_FRAGMENT_BYTES, + }, + ShaderModule::Underline => match target { + ShaderTarget::Vertex => UNDERLINE_VERTEX_BYTES, + ShaderTarget::Fragment => UNDERLINE_FRAGMENT_BYTES, + }, + ShaderModule::PathRasterization => match target { + ShaderTarget::Vertex => PATH_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_RASTERIZATION_FRAGMENT_BYTES, + }, + ShaderModule::PathSprite => match target { + ShaderTarget::Vertex => PATH_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => PATH_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::MonochromeSprite => match target { + ShaderTarget::Vertex => MONOCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => MONOCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::PolychromeSprite => match target { + ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES, + ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES, + }, + ShaderModule::EmojiRasterization => match target { + ShaderTarget::Vertex => EMOJI_RASTERIZATION_VERTEX_BYTES, + ShaderTarget::Fragment => EMOJI_RASTERIZATION_FRAGMENT_BYTES, + }, + }; + Self { inner: bytes } + } + } + + #[cfg(debug_assertions)] + pub(super) fn build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result { + unsafe { + let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) { + "color_text_raster.hlsl" + } else { + "shaders.hlsl" + }; + + let entry = format!( + "{}_{}\0", + entry.as_str(), + match target { + ShaderTarget::Vertex => "vertex", + ShaderTarget::Fragment => "fragment", + } + ); + let target = match target { + ShaderTarget::Vertex => "vs_4_1\0", + ShaderTarget::Fragment => "ps_4_1\0", + }; + + let mut compile_blob = None; + let mut error_blob = None; + let shader_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(&format!("src/platform/windows/{}", shader_name)) + .canonicalize()?; + + let entry_point = PCSTR::from_raw(entry.as_ptr()); + let target_cstr = PCSTR::from_raw(target.as_ptr()); + + let ret = D3DCompileFromFile( + &HSTRING::from(shader_path.to_str().unwrap()), + None, + None, + entry_point, + target_cstr, + D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION, + 0, + &mut compile_blob, + Some(&mut error_blob), + ); + if ret.is_err() { + let Some(error_blob) = error_blob else { + return Err(anyhow::anyhow!("{ret:?}")); + }; + + let error_string = + std::ffi::CStr::from_ptr(error_blob.GetBufferPointer() as *const i8) + .to_string_lossy(); + log::error!("Shader compile error: {}", error_string); + return Err(anyhow::anyhow!("Compile error: {}", error_string)); + } + Ok(compile_blob.unwrap()) + } + } + + #[cfg(not(debug_assertions))] + include!(concat!(env!("OUT_DIR"), "/shaders_bytes.rs")); + + #[cfg(debug_assertions)] + impl ShaderModule { + pub fn as_str(&self) -> &str { + match self { + ShaderModule::Quad => "quad", + ShaderModule::Shadow => "shadow", + ShaderModule::Underline => "underline", + ShaderModule::PathRasterization => "path_rasterization", + ShaderModule::PathSprite => "path_sprite", + ShaderModule::MonochromeSprite => "monochrome_sprite", + ShaderModule::PolychromeSprite => "polychrome_sprite", + ShaderModule::EmojiRasterization => "emoji_rasterization", + } + } + } +} + +mod nvidia { + use std::{ + ffi::CStr, + os::raw::{c_char, c_int, c_uint}, + }; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180 + const NVAPI_SHORT_STRING_MAX: usize = 64; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L235 + #[allow(non_camel_case_types)] + type NvAPI_ShortString = [c_char; NVAPI_SHORT_STRING_MAX]; + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L447 + #[allow(non_camel_case_types)] + type NvAPI_SYS_GetDriverAndBranchVersion_t = unsafe extern "C" fn( + driver_version: *mut c_uint, + build_branch_string: *mut NvAPI_ShortString, + ) -> c_int; + + pub(super) fn get_driver_version() -> Result { + unsafe { + // Try to load the NVIDIA driver DLL + #[cfg(target_pointer_width = "64")] + let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?; + #[cfg(target_pointer_width = "32")] + let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?; + + let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface")) + .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?; + let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr); + + // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_interface.h#L41 + let nvapi_get_driver_version_ptr = nvapi_query(0x2926aaad); + if nvapi_get_driver_version_ptr.is_null() { + anyhow::bail!("Failed to get NVIDIA driver version function pointer"); + } + let nvapi_get_driver_version: NvAPI_SYS_GetDriverAndBranchVersion_t = + std::mem::transmute(nvapi_get_driver_version_ptr); + + let mut driver_version: c_uint = 0; + let mut build_branch_string: NvAPI_ShortString = [0; NVAPI_SHORT_STRING_MAX]; + let result = nvapi_get_driver_version( + &mut driver_version as *mut c_uint, + &mut build_branch_string as *mut NvAPI_ShortString, + ); + + if result != 0 { + anyhow::bail!( + "Failed to get NVIDIA driver version, error code: {}", + result + ); + } + let major = driver_version / 100; + let minor = driver_version % 100; + let branch_string = CStr::from_ptr(build_branch_string.as_ptr()); + Ok(format!( + "{}.{} {}", + major, + minor, + branch_string.to_string_lossy() + )) + } + } +} + +mod amd { + use std::os::raw::{c_char, c_int, c_void}; + + use anyhow::{Context, Result}; + use windows::{ + Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA}, + core::s, + }; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145 + const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12); + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L204 + // This is an opaque type, using struct to represent it properly for FFI + #[repr(C)] + struct AGSContext { + _private: [u8; 0], + } + + #[repr(C)] + pub struct AGSGPUInfo { + pub driver_version: *const c_char, + pub radeon_software_version: *const c_char, + pub num_devices: c_int, + pub devices: *mut c_void, + } + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L429 + #[allow(non_camel_case_types)] + type agsInitialize_t = unsafe extern "C" fn( + version: c_int, + config: *const c_void, + context: *mut *mut AGSContext, + gpu_info: *mut AGSGPUInfo, + ) -> c_int; + + // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L436 + #[allow(non_camel_case_types)] + type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int; + + pub(super) fn get_driver_version() -> Result { + unsafe { + #[cfg(target_pointer_width = "64")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?; + #[cfg(target_pointer_width = "32")] + let amd_dll = + LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?; + + let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?; + let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize")) + .ok_or_else(|| anyhow::anyhow!("Failed to get agsDeInitialize address"))?; + + let ags_initialize: agsInitialize_t = std::mem::transmute(ags_initialize_addr); + let ags_deinitialize: agsDeInitialize_t = std::mem::transmute(ags_deinitialize_addr); + + let mut context: *mut AGSContext = std::ptr::null_mut(); + let mut gpu_info: AGSGPUInfo = AGSGPUInfo { + driver_version: std::ptr::null(), + radeon_software_version: std::ptr::null(), + num_devices: 0, + devices: std::ptr::null_mut(), + }; + + let result = ags_initialize( + AGS_CURRENT_VERSION, + std::ptr::null(), + &mut context, + &mut gpu_info, + ); + if result != 0 { + anyhow::bail!("Failed to initialize AMD AGS, error code: {}", result); + } + + // Vulkan acctually returns this as the driver version + let software_version = if !gpu_info.radeon_software_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.radeon_software_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Software Version".to_string() + }; + + let driver_version = if !gpu_info.driver_version.is_null() { + std::ffi::CStr::from_ptr(gpu_info.driver_version) + .to_string_lossy() + .into_owned() + } else { + "Unknown Radeon Driver Version".to_string() + }; + + ags_deinitialize(context); + Ok(format!("{} ({})", software_version, driver_version)) + } + } +} + +mod dxgi { + use windows::{ + Win32::Graphics::Dxgi::{IDXGIAdapter1, IDXGIDevice}, + core::Interface, + }; + + pub(super) fn get_driver_version(adapter: &IDXGIAdapter1) -> anyhow::Result { + let number = unsafe { adapter.CheckInterfaceSupport(&IDXGIDevice::IID as _) }?; + Ok(format!( + "{}.{}.{}.{}", + number >> 48, + (number >> 32) & 0xFFFF, + (number >> 16) & 0xFFFF, + number & 0xFFFF + )) + } +} diff --git a/crates/gpui/src/platform/windows/events.rs b/crates/gpui/src/platform/windows/events.rs index 839fd10375..61f410a8c6 100644 --- a/crates/gpui/src/platform/windows/events.rs +++ b/crates/gpui/src/platform/windows/events.rs @@ -23,6 +23,7 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1; pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2; pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3; pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4; +pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5; const SIZE_MOVE_LOOP_TIMER_ID: usize = 1; const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1; @@ -37,6 +38,7 @@ pub(crate) fn handle_msg( let handled = match msg { WM_ACTIVATE => handle_activate_msg(wparam, state_ptr), WM_CREATE => handle_create_msg(handle, state_ptr), + WM_DEVICECHANGE => handle_device_change_msg(handle, wparam, state_ptr), WM_MOVE => handle_move_msg(handle, lparam, state_ptr), WM_SIZE => handle_size_msg(wparam, lparam, state_ptr), WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr), @@ -48,7 +50,7 @@ pub(crate) fn handle_msg( WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr), WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr), WM_PAINT => handle_paint_msg(handle, state_ptr), - WM_CLOSE => handle_close_msg(handle, state_ptr), + WM_CLOSE => handle_close_msg(state_ptr), WM_DESTROY => handle_destroy_msg(handle, state_ptr), WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr), WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr), @@ -96,6 +98,7 @@ pub(crate) fn handle_msg( WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr), WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr), WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr), + WM_GPUI_FORCE_UPDATE_WINDOW => draw_window(handle, true, state_ptr), _ => None, }; if let Some(n) = handled { @@ -181,11 +184,9 @@ fn handle_size_msg( let new_size = size(DevicePixels(width), DevicePixels(height)); let scale_factor = lock.scale_factor; if lock.restore_from_minimized.is_some() { - lock.renderer - .update_drawable_size_even_if_unchanged(new_size); lock.callbacks.request_frame = lock.restore_from_minimized.take(); } else { - lock.renderer.update_drawable_size(new_size); + lock.renderer.resize(new_size).log_err(); } let new_size = new_size.to_pixels(scale_factor); lock.logical_size = new_size; @@ -238,40 +239,14 @@ fn handle_timer_msg( } fn handle_paint_msg(handle: HWND, state_ptr: Rc) -> Option { - let mut lock = state_ptr.state.borrow_mut(); - if let Some(mut request_frame) = lock.callbacks.request_frame.take() { - drop(lock); - request_frame(Default::default()); - state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); - } - unsafe { ValidateRect(Some(handle), None).ok().log_err() }; - Some(0) + draw_window(handle, false, state_ptr) } -fn handle_close_msg(handle: HWND, state_ptr: Rc) -> Option { - let mut lock = state_ptr.state.borrow_mut(); - let output = if let Some(mut callback) = lock.callbacks.should_close.take() { - drop(lock); - let should_close = callback(); - state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); - if should_close { None } else { Some(0) } - } else { - None - }; - - // Workaround as window close animation is not played with `WS_EX_LAYERED` enabled. - if output.is_none() { - unsafe { - let current_style = get_window_long(handle, GWL_EXSTYLE); - set_window_long( - handle, - GWL_EXSTYLE, - current_style & !WS_EX_LAYERED.0 as isize, - ); - } - } - - output +fn handle_close_msg(state_ptr: Rc) -> Option { + let mut callback = state_ptr.state.borrow_mut().callbacks.should_close.take()?; + let should_close = callback(); + state_ptr.state.borrow_mut().callbacks.should_close = Some(callback); + if should_close { None } else { Some(0) } } fn handle_destroy_msg(handle: HWND, state_ptr: Rc) -> Option { @@ -1223,6 +1198,53 @@ fn handle_input_language_changed( Some(0) } +fn handle_device_change_msg( + handle: HWND, + wparam: WPARAM, + state_ptr: Rc, +) -> Option { + if wparam.0 == DBT_DEVNODES_CHANGED as usize { + // The reason for sending this message is to actually trigger a redraw of the window. + unsafe { + PostMessageW( + Some(handle), + WM_GPUI_FORCE_UPDATE_WINDOW, + WPARAM(0), + LPARAM(0), + ) + .log_err(); + } + // If the GPU device is lost, this redraw will take care of recreating the device context. + // The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after + // the device context has been recreated. + draw_window(handle, true, state_ptr) + } else { + // Other device change messages are not handled. + None + } +} + +#[inline] +fn draw_window( + handle: HWND, + force_render: bool, + state_ptr: Rc, +) -> Option { + let mut request_frame = state_ptr + .state + .borrow_mut() + .callbacks + .request_frame + .take()?; + request_frame(RequestFrameOptions { + require_presentation: false, + force_render, + }); + state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame); + unsafe { ValidateRect(Some(handle), None).ok().log_err() }; + Some(0) +} + #[inline] fn parse_char_message(wparam: WPARAM, state_ptr: &Rc) -> Option { let code_point = wparam.loword(); diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 401ecdeffe..bc09cc199d 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -28,13 +28,12 @@ use windows::{ core::*, }; -use crate::{platform::blade::BladeContext, *}; +use crate::*; pub(crate) struct WindowsPlatform { state: RefCell, raw_window_handles: RwLock>, // The below members will never change throughout the entire lifecycle of the app. - gpu_context: BladeContext, icon: HICON, main_receiver: flume::Receiver, background_executor: BackgroundExecutor, @@ -45,6 +44,7 @@ pub(crate) struct WindowsPlatform { drop_target_helper: IDropTargetHelper, validation_number: usize, main_thread_id_win32: u32, + disable_direct_composition: bool, } pub(crate) struct WindowsPlatformState { @@ -94,14 +94,18 @@ impl WindowsPlatform { main_thread_id_win32, validation_number, )); + let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION) + .is_ok_and(|value| value == "true" || value == "1"); let background_executor = BackgroundExecutor::new(dispatcher.clone()); let foreground_executor = ForegroundExecutor::new(dispatcher); + let directx_devices = DirectXDevices::new(disable_direct_composition) + .context("Unable to init directx devices.")?; let bitmap_factory = ManuallyDrop::new(unsafe { CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER) .context("Error creating bitmap factory.")? }); let text_system = Arc::new( - DirectWriteTextSystem::new(&bitmap_factory) + DirectWriteTextSystem::new(&directx_devices, &bitmap_factory) .context("Error creating DirectWriteTextSystem")?, ); let drop_target_helper: IDropTargetHelper = unsafe { @@ -111,18 +115,17 @@ impl WindowsPlatform { let icon = load_icon().unwrap_or_default(); let state = RefCell::new(WindowsPlatformState::new()); let raw_window_handles = RwLock::new(SmallVec::new()); - let gpu_context = BladeContext::new().context("Unable to init GPU context")?; let windows_version = WindowsVersion::new().context("Error retrieve windows version")?; Ok(Self { state, raw_window_handles, - gpu_context, icon, main_receiver, background_executor, foreground_executor, text_system, + disable_direct_composition, windows_version, bitmap_factory, drop_target_helper, @@ -187,6 +190,7 @@ impl WindowsPlatform { validation_number: self.validation_number, main_receiver: self.main_receiver.clone(), main_thread_id_win32: self.main_thread_id_win32, + disable_direct_composition: self.disable_direct_composition, } } @@ -343,27 +347,11 @@ impl Platform for WindowsPlatform { fn run(&self, on_finish_launching: Box) { on_finish_launching(); - let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) }; - begin_vsync(*vsync_event); - 'a: loop { - let wait_result = unsafe { - MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT) - }; - - match wait_result { - // compositor clock ticked so we should draw a frame - WAIT_EVENT(0) => self.redraw_all(), - // Windows thread messages are posted - WAIT_EVENT(1) => { - if self.handle_events() { - break 'a; - } - } - _ => { - log::error!("Something went wrong while waiting {:?}", wait_result); - break; - } + loop { + if self.handle_events() { + break; } + self.redraw_all(); } if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit { @@ -455,12 +443,7 @@ impl Platform for WindowsPlatform { handle: AnyWindowHandle, options: WindowParams, ) -> Result> { - let window = WindowsWindow::new( - handle, - options, - self.generate_creation_info(), - &self.gpu_context, - )?; + let window = WindowsWindow::new(handle, options, self.generate_creation_info())?; let handle = window.get_raw_handle(); self.raw_window_handles.write().push(handle); @@ -739,6 +722,7 @@ pub(crate) struct WindowCreationInfo { pub(crate) validation_number: usize, pub(crate) main_receiver: flume::Receiver, pub(crate) main_thread_id_win32: u32, + pub(crate) disable_direct_composition: bool, } fn open_target(target: &str) { @@ -846,16 +830,6 @@ fn file_save_dialog(directory: PathBuf, window: Option) -> Result Result { let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? }; let handle = unsafe { diff --git a/crates/gpui/src/platform/windows/shaders.hlsl b/crates/gpui/src/platform/windows/shaders.hlsl new file mode 100644 index 0000000000..25830e4b6c --- /dev/null +++ b/crates/gpui/src/platform/windows/shaders.hlsl @@ -0,0 +1,1159 @@ +cbuffer GlobalParams: register(b0) { + float2 global_viewport_size; + uint2 _pad; +}; + +Texture2D t_sprite: register(t0); +SamplerState s_sprite: register(s0); + +struct Bounds { + float2 origin; + float2 size; +}; + +struct Corners { + float top_left; + float top_right; + float bottom_right; + float bottom_left; +}; + +struct Edges { + float top; + float right; + float bottom; + float left; +}; + +struct Hsla { + float h; + float s; + float l; + float a; +}; + +struct LinearColorStop { + Hsla color; + float percentage; +}; + +struct Background { + // 0u is Solid + // 1u is LinearGradient + // 2u is PatternSlash + uint tag; + // 0u is sRGB linear color + // 1u is Oklab color + uint color_space; + Hsla solid; + float gradient_angle_or_pattern_height; + LinearColorStop colors[2]; + uint pad; +}; + +struct GradientColor { + float4 solid; + float4 color0; + float4 color1; +}; + +struct AtlasTextureId { + uint index; + uint kind; +}; + +struct AtlasBounds { + int2 origin; + int2 size; +}; + +struct AtlasTile { + AtlasTextureId texture_id; + uint tile_id; + uint padding; + AtlasBounds bounds; +}; + +struct TransformationMatrix { + float2x2 rotation_scale; + float2 translation; +}; + +static const float M_PI_F = 3.141592653f; +static const float3 GRAYSCALE_FACTORS = float3(0.2126f, 0.7152f, 0.0722f); + +float4 to_device_position_impl(float2 position) { + float2 device_position = position / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0., 1.); +} + +float4 to_device_position(float2 unit_vertex, Bounds bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return to_device_position_impl(position); +} + +float4 distance_from_clip_rect_impl(float2 position, Bounds clip_bounds) { + float2 tl = position - clip_bounds.origin; + float2 br = clip_bounds.origin + clip_bounds.size - position; + return float4(tl.x, br.x, tl.y, br.y); +} + +float4 distance_from_clip_rect(float2 unit_vertex, Bounds bounds, Bounds clip_bounds) { + float2 position = unit_vertex * bounds.size + bounds.origin; + return distance_from_clip_rect_impl(position, clip_bounds); +} + +// Convert linear RGB to sRGB +float3 linear_to_srgb(float3 color) { + return pow(color, float3(2.2, 2.2, 2.2)); +} + +// Convert sRGB to linear RGB +float3 srgb_to_linear(float3 color) { + return pow(color, float3(1.0 / 2.2, 1.0 / 2.2, 1.0 / 2.2)); +} + +/// Hsla to linear RGBA conversion. +float4 hsla_to_rgba(Hsla hsla) { + float h = hsla.h * 6.0; // Now, it's an angle but scaled in [0, 6) range + float s = hsla.s; + float l = hsla.l; + float a = hsla.a; + + float c = (1.0 - abs(2.0 * l - 1.0)) * s; + float x = c * (1.0 - abs(fmod(h, 2.0) - 1.0)); + float m = l - c / 2.0; + + float r = 0.0; + float g = 0.0; + float b = 0.0; + + if (h >= 0.0 && h < 1.0) { + r = c; + g = x; + b = 0.0; + } else if (h >= 1.0 && h < 2.0) { + r = x; + g = c; + b = 0.0; + } else if (h >= 2.0 && h < 3.0) { + r = 0.0; + g = c; + b = x; + } else if (h >= 3.0 && h < 4.0) { + r = 0.0; + g = x; + b = c; + } else if (h >= 4.0 && h < 5.0) { + r = x; + g = 0.0; + b = c; + } else { + r = c; + g = 0.0; + b = x; + } + + float4 rgba; + rgba.x = (r + m); + rgba.y = (g + m); + rgba.z = (b + m); + rgba.w = a; + return rgba; +} + +// Converts a sRGB color to the Oklab color space. +// Reference: https://bottosson.github.io/posts/oklab/#converting-from-linear-srgb-to-oklab +float4 srgb_to_oklab(float4 color) { + // Convert non-linear sRGB to linear sRGB + color = float4(srgb_to_linear(color.rgb), color.a); + + float l = 0.4122214708 * color.r + 0.5363325363 * color.g + 0.0514459929 * color.b; + float m = 0.2119034982 * color.r + 0.6806995451 * color.g + 0.1073969566 * color.b; + float s = 0.0883024619 * color.r + 0.2817188376 * color.g + 0.6299787005 * color.b; + + float l_ = pow(l, 1.0/3.0); + float m_ = pow(m, 1.0/3.0); + float s_ = pow(s, 1.0/3.0); + + return float4( + 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, + 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, + 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, + color.a + ); +} + +// Converts an Oklab color to the sRGB color space. +float4 oklab_to_srgb(float4 color) { + float l_ = color.r + 0.3963377774 * color.g + 0.2158037573 * color.b; + float m_ = color.r - 0.1055613458 * color.g - 0.0638541728 * color.b; + float s_ = color.r - 0.0894841775 * color.g - 1.2914855480 * color.b; + + float l = l_ * l_ * l_; + float m = m_ * m_ * m_; + float s = s_ * s_ * s_; + + float3 linear_rgb = float3( + 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s, + -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s, + -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s + ); + + // Convert linear sRGB to non-linear sRGB + return float4(linear_to_srgb(linear_rgb), color.a); +} + +// This approximates the error function, needed for the gaussian integral +float2 erf(float2 x) { + float2 s = sign(x); + float2 a = abs(x); + x = 1. + (0.278393 + (0.230389 + 0.078108 * (a * a)) * a) * a; + x *= x; + return s - s / (x * x); +} + +float blur_along_x(float x, float y, float sigma, float corner, float2 half_size) { + float delta = min(half_size.y - corner - abs(y), 0.); + float curved = half_size.x - corner + sqrt(max(0., corner * corner - delta * delta)); + float2 integral = 0.5 + 0.5 * erf((x + float2(-curved, curved)) * (sqrt(0.5) / sigma)); + return integral.y - integral.x; +} + +// A standard gaussian function, used for weighting samples +float gaussian(float x, float sigma) { + return exp(-(x * x) / (2. * sigma * sigma)) / (sqrt(2. * M_PI_F) * sigma); +} + +float4 over(float4 below, float4 above) { + float4 result; + float alpha = above.a + below.a * (1.0 - above.a); + result.rgb = (above.rgb * above.a + below.rgb * below.a * (1.0 - above.a)) / alpha; + result.a = alpha; + return result; +} + +float2 to_tile_position(float2 unit_vertex, AtlasTile tile) { + float2 atlas_size; + t_sprite.GetDimensions(atlas_size.x, atlas_size.y); + return (float2(tile.bounds.origin) + unit_vertex * float2(tile.bounds.size)) / atlas_size; +} + +// Selects corner radius based on quadrant. +float pick_corner_radius(float2 center_to_point, Corners corner_radii) { + if (center_to_point.x < 0.) { + if (center_to_point.y < 0.) { + return corner_radii.top_left; + } else { + return corner_radii.bottom_left; + } + } else { + if (center_to_point.y < 0.) { + return corner_radii.top_right; + } else { + return corner_radii.bottom_right; + } + } +} + +float4 to_device_position_transformed(float2 unit_vertex, Bounds bounds, + TransformationMatrix transformation) { + float2 position = unit_vertex * bounds.size + bounds.origin; + float2 transformed = mul(position, transformation.rotation_scale) + transformation.translation; + float2 device_position = transformed / global_viewport_size * float2(2.0, -2.0) + float2(-1.0, 1.0); + return float4(device_position, 0.0, 1.0); +} + +// Implementation of quad signed distance field +float quad_sdf_impl(float2 corner_center_to_point, float corner_radius) { + if (corner_radius == 0.0) { + // Fast path for unrounded corners + return max(corner_center_to_point.x, corner_center_to_point.y); + } else { + // Signed distance of the point from a quad that is inset by corner_radius + // It is negative inside this quad, and positive outside + float signed_distance_to_inset_quad = + // 0 inside the inset quad, and positive outside + length(max(float2(0.0, 0.0), corner_center_to_point)) + + // 0 outside the inset quad, and negative inside + min(0.0, max(corner_center_to_point.x, corner_center_to_point.y)); + + return signed_distance_to_inset_quad - corner_radius; + } +} + +float quad_sdf(float2 pt, Bounds bounds, Corners corner_radii) { + float2 half_size = bounds.size / 2.; + float2 center = bounds.origin + half_size; + float2 center_to_point = pt - center; + float corner_radius = pick_corner_radius(center_to_point, corner_radii); + float2 corner_to_point = abs(center_to_point) - half_size; + float2 corner_center_to_point = corner_to_point + corner_radius; + return quad_sdf_impl(corner_center_to_point, corner_radius); +} + +GradientColor prepare_gradient_color(uint tag, uint color_space, Hsla solid, LinearColorStop colors[2]) { + GradientColor output; + if (tag == 0 || tag == 2) { + output.solid = hsla_to_rgba(solid); + } else if (tag == 1) { + output.color0 = hsla_to_rgba(colors[0].color); + output.color1 = hsla_to_rgba(colors[1].color); + + // Prepare color space in vertex for avoid conversion + // in fragment shader for performance reasons + if (color_space == 1) { + // Oklab + output.color0 = srgb_to_oklab(output.color0); + output.color1 = srgb_to_oklab(output.color1); + } + } + + return output; +} + +float2x2 rotate2d(float angle) { + float s = sin(angle); + float c = cos(angle); + return float2x2(c, -s, s, c); +} + +float4 gradient_color(Background background, + float2 position, + Bounds bounds, + float4 solid_color, float4 color0, float4 color1) { + float4 color; + + switch (background.tag) { + case 0: + color = solid_color; + break; + case 1: { + // -90 degrees to match the CSS gradient angle. + float gradient_angle = background.gradient_angle_or_pattern_height; + float radians = (fmod(gradient_angle, 360.0) - 90.0) * (M_PI_F / 180.0); + float2 direction = float2(cos(radians), sin(radians)); + + // Expand the short side to be the same as the long side + if (bounds.size.x > bounds.size.y) { + direction.y *= bounds.size.y / bounds.size.x; + } else { + direction.x *= bounds.size.x / bounds.size.y; + } + + // Get the t value for the linear gradient with the color stop percentages. + float2 half_size = bounds.size * 0.5; + float2 center = bounds.origin + half_size; + float2 center_to_point = position - center; + float t = dot(center_to_point, direction) / length(direction); + // Check the direct to determine the use x or y + if (abs(direction.x) > abs(direction.y)) { + t = (t + half_size.x) / bounds.size.x; + } else { + t = (t + half_size.y) / bounds.size.y; + } + + // Adjust t based on the stop percentages + t = (t - background.colors[0].percentage) + / (background.colors[1].percentage + - background.colors[0].percentage); + t = clamp(t, 0.0, 1.0); + + switch (background.color_space) { + case 0: + color = lerp(color0, color1, t); + break; + case 1: { + float4 oklab_color = lerp(color0, color1, t); + color = oklab_to_srgb(oklab_color); + break; + } + } + break; + } + case 2: { + float gradient_angle_or_pattern_height = background.gradient_angle_or_pattern_height; + float pattern_width = (gradient_angle_or_pattern_height / 65535.0f) / 255.0f; + float pattern_interval = fmod(gradient_angle_or_pattern_height, 65535.0f) / 255.0f; + float pattern_height = pattern_width + pattern_interval; + float stripe_angle = M_PI_F / 4.0; + float pattern_period = pattern_height * sin(stripe_angle); + float2x2 rotation = rotate2d(stripe_angle); + float2 relative_position = position - bounds.origin; + float2 rotated_point = mul(rotation, relative_position); + float pattern = fmod(rotated_point.x, pattern_period); + float distance = min(pattern, pattern_period - pattern) - pattern_period * (pattern_width / pattern_height) / 2.0f; + color = solid_color; + color.a *= saturate(0.5 - distance); + break; + } + } + + return color; +} + +// Returns the dash velocity of a corner given the dash velocity of the two +// sides, by returning the slower velocity (larger dashes). +// +// Since 0 is used for dash velocity when the border width is 0 (instead of +// +inf), this returns the other dash velocity in that case. +// +// An alternative to this might be to appropriately interpolate the dash +// velocity around the corner, but that seems overcomplicated. +float corner_dash_velocity(float dv1, float dv2) { + if (dv1 == 0.0) { + return dv2; + } else if (dv2 == 0.0) { + return dv1; + } else { + return min(dv1, dv2); + } +} + +// Returns alpha used to render antialiased dashes. +// `t` is within the dash when `fmod(t, period) < length`. +float dash_alpha( + float t, float period, float length, float dash_velocity, + float antialias_threshold +) { + float half_period = period / 2.0; + float half_length = length / 2.0; + // Value in [-half_period, half_period] + // The dash is in [-half_length, half_length] + float centered = fmod(t + half_period - half_length, period) - half_period; + // Signed distance for the dash, negative values are inside the dash + float signed_distance = abs(centered) - half_length; + // Antialiased alpha based on the signed distance + return saturate(antialias_threshold - signed_distance / dash_velocity); +} + +// This approximates distance to the nearest point to a quarter ellipse in a way +// that is sufficient for anti-aliasing when the ellipse is not very eccentric. +// The components of `point` are expected to be positive. +// +// Negative on the outside and positive on the inside. +float quarter_ellipse_sdf(float2 pt, float2 radii) { + // Scale the space to treat the ellipse like a unit circle + float2 circle_vec = pt / radii; + float unit_circle_sdf = length(circle_vec) - 1.0; + // Approximate up-scaling of the length by using the average of the radii. + // + // TODO: A better solution would be to use the gradient of the implicit + // function for an ellipse to approximate a scaling factor. + return unit_circle_sdf * (radii.x + radii.y) * -0.5; +} + +/* +** +** Quads +** +*/ + +struct Quad { + uint order; + uint border_style; + Bounds bounds; + Bounds content_mask; + Background background; + Hsla border_color; + Corners corner_radii; + Edges border_widths; +}; + +struct QuadVertexOutput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; + float4 clip_distance: SV_ClipDistance; +}; + +struct QuadFragmentInput { + nointerpolation uint quad_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 border_color: COLOR0; + nointerpolation float4 background_solid: COLOR1; + nointerpolation float4 background_color0: COLOR2; + nointerpolation float4 background_color1: COLOR3; +}; + +StructuredBuffer quads: register(t1); + +QuadVertexOutput quad_vertex(uint vertex_id: SV_VertexID, uint quad_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Quad quad = quads[quad_id]; + float4 device_position = to_device_position(unit_vertex, quad.bounds); + + GradientColor gradient = prepare_gradient_color( + quad.background.tag, + quad.background.color_space, + quad.background.solid, + quad.background.colors + ); + float4 clip_distance = distance_from_clip_rect(unit_vertex, quad.bounds, quad.content_mask); + float4 border_color = hsla_to_rgba(quad.border_color); + + QuadVertexOutput output; + output.position = device_position; + output.border_color = border_color; + output.quad_id = quad_id; + output.background_solid = gradient.solid; + output.background_color0 = gradient.color0; + output.background_color1 = gradient.color1; + output.clip_distance = clip_distance; + return output; +} + +float4 quad_fragment(QuadFragmentInput input): SV_Target { + Quad quad = quads[input.quad_id]; + float4 background_color = gradient_color(quad.background, input.position.xy, quad.bounds, + input.background_solid, input.background_color0, input.background_color1); + + bool unrounded = quad.corner_radii.top_left == 0.0 && + quad.corner_radii.top_right == 0.0 && + quad.corner_radii.bottom_left == 0.0 && + quad.corner_radii.bottom_right == 0.0; + + // Fast path when the quad is not rounded and doesn't have any border + if (quad.border_widths.top == 0.0 && + quad.border_widths.left == 0.0 && + quad.border_widths.right == 0.0 && + quad.border_widths.bottom == 0.0 && + unrounded) { + return background_color; + } + + float2 size = quad.bounds.size; + float2 half_size = size / 2.; + float2 the_point = input.position.xy - quad.bounds.origin; + float2 center_to_point = the_point - half_size; + + // Signed distance field threshold for inclusion of pixels. 0.5 is the + // minimum distance between the center of the pixel and the edge. + const float antialias_threshold = 0.5; + + // Radius of the nearest corner + float corner_radius = pick_corner_radius(center_to_point, quad.corner_radii); + + float2 border = float2( + center_to_point.x < 0.0 ? quad.border_widths.left : quad.border_widths.right, + center_to_point.y < 0.0 ? quad.border_widths.top : quad.border_widths.bottom + ); + + // 0-width borders are reduced so that `inner_sdf >= antialias_threshold`. + // The purpose of this is to not draw antialiasing pixels in this case. + float2 reduced_border = float2( + border.x == 0.0 ? -antialias_threshold : border.x, + border.y == 0.0 ? -antialias_threshold : border.y + ); + + // Vector from the corner of the quad bounds to the point, after mirroring + // the point into the bottom right quadrant. Both components are <= 0. + float2 corner_to_point = abs(center_to_point) - half_size; + + // Vector from the point to the center of the rounded corner's circle, also + // mirrored into bottom right quadrant. + float2 corner_center_to_point = corner_to_point + corner_radius; + + // Whether the nearest point on the border is rounded + bool is_near_rounded_corner = + corner_center_to_point.x >= 0.0 && + corner_center_to_point.y >= 0.0; + + // Vector from straight border inner corner to point. + // + // 0-width borders are turned into width -1 so that inner_sdf is > 1.0 near + // the border. Without this, antialiasing pixels would be drawn. + float2 straight_border_inner_corner_to_point = corner_to_point + reduced_border; + + // Whether the point is beyond the inner edge of the straight border + bool is_beyond_inner_straight_border = + straight_border_inner_corner_to_point.x > 0.0 || + straight_border_inner_corner_to_point.y > 0.0; + + // Whether the point is far enough inside the quad, such that the pixels are + // not affected by the straight border. + bool is_within_inner_straight_border = + straight_border_inner_corner_to_point.x < -antialias_threshold && + straight_border_inner_corner_to_point.y < -antialias_threshold; + + // Fast path for points that must be part of the background + if (is_within_inner_straight_border && !is_near_rounded_corner) { + return background_color; + } + + // Signed distance of the point to the outside edge of the quad's border + float outer_sdf = quad_sdf_impl(corner_center_to_point, corner_radius); + + // Approximate signed distance of the point to the inside edge of the quad's + // border. It is negative outside this edge (within the border), and + // positive inside. + // + // This is not always an accurate signed distance: + // * The rounded portions with varying border width use an approximation of + // nearest-point-on-ellipse. + // * When it is quickly known to be outside the edge, -1.0 is used. + float inner_sdf = 0.0; + if (corner_center_to_point.x <= 0.0 || corner_center_to_point.y <= 0.0) { + // Fast paths for straight borders + inner_sdf = -max(straight_border_inner_corner_to_point.x, + straight_border_inner_corner_to_point.y); + } else if (is_beyond_inner_straight_border) { + // Fast path for points that must be outside the inner edge + inner_sdf = -1.0; + } else if (reduced_border.x == reduced_border.y) { + // Fast path for circular inner edge. + inner_sdf = -(outer_sdf + reduced_border.x); + } else { + float2 ellipse_radii = max(float2(0.0, 0.0), float2(corner_radius, corner_radius) - reduced_border); + inner_sdf = quarter_ellipse_sdf(corner_center_to_point, ellipse_radii); + } + + // Negative when inside the border + float border_sdf = max(inner_sdf, outer_sdf); + + float4 color = background_color; + if (border_sdf < antialias_threshold) { + float4 border_color = input.border_color; + // Dashed border logic when border_style == 1 + if (quad.border_style == 1) { + // Position along the perimeter in "dash space", where each dash + // period has length 1 + float t = 0.0; + + // Total number of dash periods, so that the dash spacing can be + // adjusted to evenly divide it + float max_t = 0.0; + + // Border width is proportional to dash size. This is the behavior + // used by browsers, but also avoids dashes from different segments + // overlapping when dash size is smaller than the border width. + // + // Dash pattern: (2 * border width) dash, (1 * border width) gap + const float dash_length_per_width = 2.0; + const float dash_gap_per_width = 1.0; + const float dash_period_per_width = dash_length_per_width + dash_gap_per_width; + + // Since the dash size is determined by border width, the density of + // dashes varies. Multiplying a pixel distance by this returns a + // position in dash space - it has units (dash period / pixels). So + // a dash velocity of (1 / 10) is 1 dash every 10 pixels. + float dash_velocity = 0.0; + + // Dividing this by the border width gives the dash velocity + const float dv_numerator = 1.0 / dash_period_per_width; + + if (unrounded) { + // When corners aren't rounded, the dashes are separately laid + // out on each straight line, rather than around the whole + // perimeter. This way each line starts and ends with a dash. + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + float border_width = is_horizontal ? border.x : border.y; + dash_velocity = dv_numerator / border_width; + t = is_horizontal ? the_point.x : the_point.y; + t *= dash_velocity; + max_t = is_horizontal ? size.x : size.y; + max_t *= dash_velocity; + } else { + // When corners are rounded, the dashes are laid out clockwise + // around the whole perimeter. + + float r_tr = quad.corner_radii.top_right; + float r_br = quad.corner_radii.bottom_right; + float r_bl = quad.corner_radii.bottom_left; + float r_tl = quad.corner_radii.top_left; + + float w_t = quad.border_widths.top; + float w_r = quad.border_widths.right; + float w_b = quad.border_widths.bottom; + float w_l = quad.border_widths.left; + + // Straight side dash velocities + float dv_t = w_t <= 0.0 ? 0.0 : dv_numerator / w_t; + float dv_r = w_r <= 0.0 ? 0.0 : dv_numerator / w_r; + float dv_b = w_b <= 0.0 ? 0.0 : dv_numerator / w_b; + float dv_l = w_l <= 0.0 ? 0.0 : dv_numerator / w_l; + + // Straight side lengths in dash space + float s_t = (size.x - r_tl - r_tr) * dv_t; + float s_r = (size.y - r_tr - r_br) * dv_r; + float s_b = (size.x - r_br - r_bl) * dv_b; + float s_l = (size.y - r_bl - r_tl) * dv_l; + + float corner_dash_velocity_tr = corner_dash_velocity(dv_t, dv_r); + float corner_dash_velocity_br = corner_dash_velocity(dv_b, dv_r); + float corner_dash_velocity_bl = corner_dash_velocity(dv_b, dv_l); + float corner_dash_velocity_tl = corner_dash_velocity(dv_t, dv_l); + + // Corner lengths in dash space + float c_tr = r_tr * (M_PI_F / 2.0) * corner_dash_velocity_tr; + float c_br = r_br * (M_PI_F / 2.0) * corner_dash_velocity_br; + float c_bl = r_bl * (M_PI_F / 2.0) * corner_dash_velocity_bl; + float c_tl = r_tl * (M_PI_F / 2.0) * corner_dash_velocity_tl; + + // Cumulative dash space upto each segment + float upto_tr = s_t; + float upto_r = upto_tr + c_tr; + float upto_br = upto_r + s_r; + float upto_b = upto_br + c_br; + float upto_bl = upto_b + s_b; + float upto_l = upto_bl + c_bl; + float upto_tl = upto_l + s_l; + max_t = upto_tl + c_tl; + + if (is_near_rounded_corner) { + float radians = atan2(corner_center_to_point.y, corner_center_to_point.x); + float corner_t = radians * corner_radius; + + if (center_to_point.x >= 0.0) { + if (center_to_point.y < 0.0) { + dash_velocity = corner_dash_velocity_tr; + // Subtracted because radians is pi/2 to 0 when + // going clockwise around the top right corner, + // since the y axis has been flipped + t = upto_r - corner_t * dash_velocity; + } else { + dash_velocity = corner_dash_velocity_br; + // Added because radians is 0 to pi/2 when going + // clockwise around the bottom-right corner + t = upto_br + corner_t * dash_velocity; + } + } else { + if (center_to_point.y >= 0.0) { + dash_velocity = corner_dash_velocity_bl; + // Subtracted because radians is pi/1 to 0 when + // going clockwise around the bottom-left corner, + // since the x axis has been flipped + t = upto_l - corner_t * dash_velocity; + } else { + dash_velocity = corner_dash_velocity_tl; + // Added because radians is 0 to pi/2 when going + // clockwise around the top-left corner, since both + // axis were flipped + t = upto_tl + corner_t * dash_velocity; + } + } + } else { + // Straight borders + bool is_horizontal = corner_center_to_point.x < corner_center_to_point.y; + if (is_horizontal) { + if (center_to_point.y < 0.0) { + dash_velocity = dv_t; + t = (the_point.x - r_tl) * dash_velocity; + } else { + dash_velocity = dv_b; + t = upto_bl - (the_point.x - r_bl) * dash_velocity; + } + } else { + if (center_to_point.x < 0.0) { + dash_velocity = dv_l; + t = upto_tl - (the_point.y - r_tl) * dash_velocity; + } else { + dash_velocity = dv_r; + t = upto_r + (the_point.y - r_tr) * dash_velocity; + } + } + } + } + float dash_length = dash_length_per_width / dash_period_per_width; + float desired_dash_gap = dash_gap_per_width / dash_period_per_width; + + // Straight borders should start and end with a dash, so max_t is + // reduced to cause this. + max_t -= unrounded ? dash_length : 0.0; + if (max_t >= 1.0) { + // Adjust dash gap to evenly divide max_t + float dash_count = floor(max_t); + float dash_period = max_t / dash_count; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } else if (unrounded) { + // When there isn't enough space for the full gap between the + // two start / end dashes of a straight border, reduce gap to + // make them fit. + float dash_gap = max_t - dash_length; + if (dash_gap > 0.0) { + float dash_period = dash_length + dash_gap; + border_color.a *= dash_alpha(t, dash_period, dash_length, dash_velocity, antialias_threshold); + } + } + } + + // Blend the border on top of the background and then linearly interpolate + // between the two as we slide inside the background. + float4 blended_border = over(background_color, border_color); + color = lerp(background_color, blended_border, + saturate(antialias_threshold - inner_sdf)); + } + + return color * float4(1.0, 1.0, 1.0, saturate(antialias_threshold - outer_sdf)); +} + +/* +** +** Shadows +** +*/ + +struct Shadow { + uint order; + float blur_radius; + Bounds bounds; + Corners corner_radii; + Bounds content_mask; + Hsla color; +}; + +struct ShadowVertexOutput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct ShadowFragmentInput { + nointerpolation uint shadow_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer shadows: register(t1); + +ShadowVertexOutput shadow_vertex(uint vertex_id: SV_VertexID, uint shadow_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Shadow shadow = shadows[shadow_id]; + + float margin = 3.0 * shadow.blur_radius; + Bounds bounds = shadow.bounds; + bounds.origin -= margin; + bounds.size += 2.0 * margin; + + float4 device_position = to_device_position(unit_vertex, bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, bounds, shadow.content_mask); + float4 color = hsla_to_rgba(shadow.color); + + ShadowVertexOutput output; + output.position = device_position; + output.color = color; + output.shadow_id = shadow_id; + output.clip_distance = clip_distance; + + return output; +} + +float4 shadow_fragment(ShadowFragmentInput input): SV_TARGET { + Shadow shadow = shadows[input.shadow_id]; + + float2 half_size = shadow.bounds.size / 2.; + float2 center = shadow.bounds.origin + half_size; + float2 point0 = input.position.xy - center; + float corner_radius = pick_corner_radius(point0, shadow.corner_radii); + + // The signal is only non-zero in a limited range, so don't waste samples + float low = point0.y - half_size.y; + float high = point0.y + half_size.y; + float start = clamp(-3. * shadow.blur_radius, low, high); + float end = clamp(3. * shadow.blur_radius, low, high); + + // Accumulate samples (we can get away with surprisingly few samples) + float step = (end - start) / 4.; + float y = start + step * 0.5; + float alpha = 0.; + for (int i = 0; i < 4; i++) { + alpha += blur_along_x(point0.x, point0.y - y, shadow.blur_radius, + corner_radius, half_size) * + gaussian(y, shadow.blur_radius) * step; + y += step; + } + + return input.color * float4(1., 1., 1., alpha); +} + +/* +** +** Path Rasterization +** +*/ + +struct PathRasterizationSprite { + float2 xy_position; + float2 st_position; + Background color; + Bounds bounds; +}; + +StructuredBuffer path_rasterization_sprites: register(t1); + +struct PathVertexOutput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; + float4 clip_distance: SV_ClipDistance; +}; + +struct PathFragmentInput { + float4 position: SV_Position; + float2 st_position: TEXCOORD0; + nointerpolation uint vertex_id: TEXCOORD1; +}; + +PathVertexOutput path_rasterization_vertex(uint vertex_id: SV_VertexID) { + PathRasterizationSprite sprite = path_rasterization_sprites[vertex_id]; + + PathVertexOutput output; + output.position = to_device_position_impl(sprite.xy_position); + output.st_position = sprite.st_position; + output.vertex_id = vertex_id; + output.clip_distance = distance_from_clip_rect_impl(sprite.xy_position, sprite.bounds); + + return output; +} + +float4 path_rasterization_fragment(PathFragmentInput input): SV_Target { + float2 dx = ddx(input.st_position); + float2 dy = ddy(input.st_position); + PathRasterizationSprite sprite = path_rasterization_sprites[input.vertex_id]; + + Background background = sprite.color; + Bounds bounds = sprite.bounds; + + float alpha; + if (length(float2(dx.x, dy.x))) { + alpha = 1.0; + } else { + float2 gradient = 2.0 * input.st_position.xx * float2(dx.x, dy.x) - float2(dx.y, dy.y); + float f = input.st_position.x * input.st_position.x - input.st_position.y; + float distance = f / length(gradient); + alpha = saturate(0.5 - distance); + } + + GradientColor gradient = prepare_gradient_color( + background.tag, background.color_space, background.solid, background.colors); + + float4 color = gradient_color(background, input.position.xy, bounds, + gradient.solid, gradient.color0, gradient.color1); + return float4(color.rgb * color.a * alpha, alpha * color.a); +} + +/* +** +** Path Sprites +** +*/ + +struct PathSprite { + Bounds bounds; +}; + +struct PathSpriteVertexOutput { + float4 position: SV_Position; + float2 texture_coords: TEXCOORD0; +}; + +StructuredBuffer path_sprites: register(t1); + +PathSpriteVertexOutput path_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PathSprite sprite = path_sprites[sprite_id]; + + // Don't apply content mask because it was already accounted for when rasterizing the path + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + + float2 screen_position = sprite.bounds.origin + unit_vertex * sprite.bounds.size; + float2 texture_coords = screen_position / global_viewport_size; + + PathSpriteVertexOutput output; + output.position = device_position; + output.texture_coords = texture_coords; + return output; +} + +float4 path_sprite_fragment(PathSpriteVertexOutput input): SV_Target { + return t_sprite.Sample(s_sprite, input.texture_coords); +} + +/* +** +** Underlines +** +*/ + +struct Underline { + uint order; + uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + float thickness; + uint wavy; +}; + +struct UnderlineVertexOutput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct UnderlineFragmentInput { + nointerpolation uint underline_id: TEXCOORD0; + float4 position: SV_Position; + nointerpolation float4 color: COLOR; +}; + +StructuredBuffer underlines: register(t1); + +UnderlineVertexOutput underline_vertex(uint vertex_id: SV_VertexID, uint underline_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + Underline underline = underlines[underline_id]; + float4 device_position = to_device_position(unit_vertex, underline.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, underline.bounds, + underline.content_mask); + float4 color = hsla_to_rgba(underline.color); + + UnderlineVertexOutput output; + output.position = device_position; + output.color = color; + output.underline_id = underline_id; + output.clip_distance = clip_distance; + return output; +} + +float4 underline_fragment(UnderlineFragmentInput input): SV_Target { + Underline underline = underlines[input.underline_id]; + if (underline.wavy) { + float half_thickness = underline.thickness * 0.5; + float2 origin = underline.bounds.origin; + float2 st = ((input.position.xy - origin) / underline.bounds.size.y) - float2(0., 0.5); + float frequency = (M_PI_F * (3. * underline.thickness)) / 8.; + float amplitude = 1. / (2. * underline.thickness); + float sine = sin(st.x * frequency) * amplitude; + float dSine = cos(st.x * frequency) * amplitude * frequency; + float distance = (st.y - sine) / sqrt(1. + dSine * dSine); + float distance_in_pixels = distance * underline.bounds.size.y; + float distance_from_top_border = distance_in_pixels - half_thickness; + float distance_from_bottom_border = distance_in_pixels + half_thickness; + float alpha = saturate( + 0.5 - max(-distance_from_bottom_border, distance_from_top_border)); + return input.color * float4(1., 1., 1., alpha); + } else { + return input.color; + } +} + +/* +** +** Monochrome sprites +** +*/ + +struct MonochromeSprite { + uint order; + uint pad; + Bounds bounds; + Bounds content_mask; + Hsla color; + AtlasTile tile; + TransformationMatrix transformation; +}; + +struct MonochromeSpriteVertexOutput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +struct MonochromeSpriteFragmentInput { + float4 position: SV_Position; + float2 tile_position: POSITION; + nointerpolation float4 color: COLOR; + float4 clip_distance: SV_ClipDistance; +}; + +StructuredBuffer mono_sprites: register(t1); + +MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + MonochromeSprite sprite = mono_sprites[sprite_id]; + float4 device_position = + to_device_position_transformed(unit_vertex, sprite.bounds, sprite.transformation); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + float4 color = hsla_to_rgba(sprite.color); + + MonochromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.color = color; + output.clip_distance = clip_distance; + return output; +} + +float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target { + float sample = t_sprite.Sample(s_sprite, input.tile_position).r; + return float4(input.color.rgb, input.color.a * sample); +} + +/* +** +** Polychrome sprites +** +*/ + +struct PolychromeSprite { + uint order; + uint pad; + uint grayscale; + float opacity; + Bounds bounds; + Bounds content_mask; + Corners corner_radii; + AtlasTile tile; +}; + +struct PolychromeSpriteVertexOutput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; + float4 clip_distance: SV_ClipDistance; +}; + +struct PolychromeSpriteFragmentInput { + nointerpolation uint sprite_id: TEXCOORD0; + float4 position: SV_Position; + float2 tile_position: POSITION; +}; + +StructuredBuffer poly_sprites: register(t1); + +PolychromeSpriteVertexOutput polychrome_sprite_vertex(uint vertex_id: SV_VertexID, uint sprite_id: SV_InstanceID) { + float2 unit_vertex = float2(float(vertex_id & 1u), 0.5 * float(vertex_id & 2u)); + PolychromeSprite sprite = poly_sprites[sprite_id]; + float4 device_position = to_device_position(unit_vertex, sprite.bounds); + float4 clip_distance = distance_from_clip_rect(unit_vertex, sprite.bounds, + sprite.content_mask); + float2 tile_position = to_tile_position(unit_vertex, sprite.tile); + + PolychromeSpriteVertexOutput output; + output.position = device_position; + output.tile_position = tile_position; + output.sprite_id = sprite_id; + output.clip_distance = clip_distance; + return output; +} + +float4 polychrome_sprite_fragment(PolychromeSpriteFragmentInput input): SV_Target { + PolychromeSprite sprite = poly_sprites[input.sprite_id]; + float4 sample = t_sprite.Sample(s_sprite, input.tile_position); + float distance = quad_sdf(input.position.xy, sprite.bounds, sprite.corner_radii); + + float4 color = sample; + if ((sprite.grayscale & 0xFFu) != 0u) { + float3 grayscale = dot(color.rgb, GRAYSCALE_FACTORS); + color = float4(grayscale, sample.a); + } + color.a *= sprite.opacity * saturate(0.5 - distance); + return color; +} diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index 5703a82815..68b667569b 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -26,7 +26,6 @@ use windows::{ core::*, }; -use crate::platform::blade::{BladeContext, BladeRenderer}; use crate::*; pub(crate) struct WindowsWindow(pub Rc); @@ -49,7 +48,7 @@ pub struct WindowsWindowState { pub system_key_handled: bool, pub hovered: bool, - pub renderer: BladeRenderer, + pub renderer: DirectXRenderer, pub click_state: ClickState, pub system_settings: WindowsSystemSettings, @@ -80,13 +79,12 @@ pub(crate) struct WindowsWindowStatePtr { impl WindowsWindowState { fn new( hwnd: HWND, - transparent: bool, cs: &CREATESTRUCTW, current_cursor: Option, display: WindowsDisplay, - gpu_context: &BladeContext, min_size: Option>, appearance: WindowAppearance, + disable_direct_composition: bool, ) -> Result { let scale_factor = { let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32; @@ -103,7 +101,8 @@ impl WindowsWindowState { }; let border_offset = WindowBorderOffset::default(); let restore_from_minimized = None; - let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?; + let renderer = DirectXRenderer::new(hwnd, disable_direct_composition) + .context("Creating DirectX renderer")?; let callbacks = Callbacks::default(); let input_handler = None; let pending_surrogate = None; @@ -206,13 +205,12 @@ impl WindowsWindowStatePtr { fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result> { let state = RefCell::new(WindowsWindowState::new( hwnd, - context.transparent, cs, context.current_cursor, context.display, - context.gpu_context, context.min_size, context.appearance, + context.disable_direct_composition, )?); Ok(Rc::new_cyclic(|this| Self { @@ -329,12 +327,11 @@ pub(crate) struct Callbacks { pub(crate) appearance_changed: Option>, } -struct WindowCreateContext<'a> { +struct WindowCreateContext { inner: Option>>, handle: AnyWindowHandle, hide_title_bar: bool, display: WindowsDisplay, - transparent: bool, is_movable: bool, min_size: Option>, executor: ForegroundExecutor, @@ -343,9 +340,9 @@ struct WindowCreateContext<'a> { drop_target_helper: IDropTargetHelper, validation_number: usize, main_receiver: flume::Receiver, - gpu_context: &'a BladeContext, main_thread_id_win32: u32, appearance: WindowAppearance, + disable_direct_composition: bool, } impl WindowsWindow { @@ -353,7 +350,6 @@ impl WindowsWindow { handle: AnyWindowHandle, params: WindowParams, creation_info: WindowCreationInfo, - gpu_context: &BladeContext, ) -> Result { let WindowCreationInfo { icon, @@ -364,6 +360,7 @@ impl WindowsWindow { validation_number, main_receiver, main_thread_id_win32, + disable_direct_composition, } = creation_info; let classname = register_wnd_class(icon); let hide_title_bar = params @@ -379,14 +376,18 @@ impl WindowsWindow { .map(|title| title.as_ref()) .unwrap_or(""), ); - let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp { - (WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0)) + + let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp { + (WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0)) } else { ( - WS_EX_APPWINDOW | WS_EX_LAYERED, + WS_EX_APPWINDOW, WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX, ) }; + if !disable_direct_composition { + dwexstyle |= WS_EX_NOREDIRECTIONBITMAP; + } let hinstance = get_module_handle(); let display = if let Some(display_id) = params.display_id { @@ -401,7 +402,6 @@ impl WindowsWindow { handle, hide_title_bar, display, - transparent: true, is_movable: params.is_movable, min_size: params.window_min_size, executor, @@ -410,9 +410,9 @@ impl WindowsWindow { drop_target_helper, validation_number, main_receiver, - gpu_context, main_thread_id_win32, appearance, + disable_direct_composition, }; let lpparam = Some(&context as *const _ as *const _); let creation_result = unsafe { @@ -453,14 +453,6 @@ impl WindowsWindow { state: WindowOpenState::Windowed, }); } - // The render pipeline will perform compositing on the GPU when the - // swapchain is configured correctly (see downstream of - // update_transparency). - // The following configuration is a one-time setup to ensure that the - // window is going to be composited with per-pixel alpha, but the render - // pipeline is responsible for effectively calling UpdateLayeredWindow - // at the appropriate time. - unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? }; Ok(Self(state_ptr)) } @@ -485,7 +477,6 @@ impl rwh::HasDisplayHandle for WindowsWindow { impl Drop for WindowsWindow { fn drop(&mut self) { - self.0.state.borrow_mut().renderer.destroy(); // clone this `Rc` to prevent early release of the pointer let this = self.0.clone(); self.0 @@ -705,24 +696,21 @@ impl PlatformWindow for WindowsWindow { } fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) { - let mut window_state = self.0.state.borrow_mut(); - window_state - .renderer - .update_transparency(background_appearance != WindowBackgroundAppearance::Opaque); + let hwnd = self.0.hwnd; match background_appearance { WindowBackgroundAppearance::Opaque => { // ACCENT_DISABLED - set_window_composition_attribute(window_state.hwnd, None, 0); + set_window_composition_attribute(hwnd, None, 0); } WindowBackgroundAppearance::Transparent => { // Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background - set_window_composition_attribute(window_state.hwnd, None, 2); + set_window_composition_attribute(hwnd, None, 2); } WindowBackgroundAppearance::Blurred => { // Enable acrylic blur // ACCENT_ENABLE_ACRYLICBLURBEHIND - set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4); + set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4); } } } @@ -794,11 +782,11 @@ impl PlatformWindow for WindowsWindow { } fn draw(&self, scene: &Scene) { - self.0.state.borrow_mut().renderer.draw(scene) + self.0.state.borrow_mut().renderer.draw(scene).log_err(); } fn sprite_atlas(&self) -> Arc { - self.0.state.borrow().renderer.sprite_atlas().clone() + self.0.state.borrow().renderer.sprite_atlas() } fn get_raw_handle(&self) -> HWND { @@ -806,11 +794,11 @@ impl PlatformWindow for WindowsWindow { } fn gpu_specs(&self) -> Option { - Some(self.0.state.borrow().renderer.gpu_specs()) + self.0.state.borrow().renderer.gpu_specs().log_err() } fn update_ime_position(&self, _bounds: Bounds) { - // todo(windows) + // There is no such thing on Windows. } } @@ -1306,52 +1294,6 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option, state: u32 } } -mod windows_renderer { - use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig}; - use raw_window_handle as rwh; - use std::num::NonZeroIsize; - use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE}; - - use crate::{get_window_long, show_error}; - - pub(super) fn init( - context: &BladeContext, - hwnd: HWND, - transparent: bool, - ) -> anyhow::Result { - let raw = RawWindow { hwnd }; - let config = BladeSurfaceConfig { - size: Default::default(), - transparent, - }; - BladeRenderer::new(context, &raw, config) - .inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string())) - } - - struct RawWindow { - hwnd: HWND, - } - - impl rwh::HasWindowHandle for RawWindow { - fn window_handle(&self) -> Result, rwh::HandleError> { - Ok(unsafe { - let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize); - let mut handle = rwh::Win32WindowHandle::new(hwnd); - let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE); - handle.hinstance = NonZeroIsize::new(hinstance); - rwh::WindowHandle::borrow_raw(handle.into()) - }) - } - } - - impl rwh::HasDisplayHandle for RawWindow { - fn display_handle(&self) -> Result, rwh::HandleError> { - let handle = rwh::WindowsDisplayHandle::new(); - Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) }) - } - } -} - #[cfg(test)] mod tests { use super::ClickState; diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index 01fbfff1c5..6ebb1cac40 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -1020,7 +1020,7 @@ impl Window { || (active.get() && last_input_timestamp.get().elapsed() < Duration::from_secs(1)); - if invalidator.is_dirty() { + if invalidator.is_dirty() || request_frame_options.force_render { measure("frame duration", || { handle .update(&mut cx, |_, window, cx| { diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 2045708ff2..3f51cc5a23 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -23,6 +23,7 @@ futures.workspace = true http.workspace = true http-body.workspace = true log.workspace = true +parking_lot.workspace = true serde.workspace = true serde_json.workspace = true url.workspace = true diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 434bd74fc8..d33bbefc06 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -9,12 +9,10 @@ pub use http::{self, Method, Request, Response, StatusCode, Uri}; use futures::future::BoxFuture; use http::request::Builder; +use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Arc}; pub use url::Url; #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] @@ -86,6 +84,11 @@ pub trait HttpClient: 'static + Send + Sync { } fn proxy(&self) -> Option<&Url>; + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::()) + } } /// An [`HttpClient`] that may have a proxy. @@ -132,6 +135,11 @@ impl HttpClient for HttpClientWithProxy { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } impl HttpClient for Arc { @@ -153,6 +161,11 @@ impl HttpClient for Arc { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } /// An [`HttpClient`] that has a base URL. @@ -199,20 +212,13 @@ impl HttpClientWithUrl { /// Returns the base URL. pub fn base_url(&self) -> String { - self.base_url - .lock() - .map_or_else(|_| Default::default(), |url| url.clone()) + self.base_url.lock().clone() } /// Sets the base URL. pub fn set_base_url(&self, base_url: impl Into) { let base_url = base_url.into(); - self.base_url - .lock() - .map(|mut url| { - *url = base_url; - }) - .ok(); + *self.base_url.lock() = base_url; } /// Builds a URL using the given path. @@ -236,6 +242,22 @@ impl HttpClientWithUrl { )?) } + /// Builds a Zed Cloud URL using the given path. + pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result { + let base_url = self.base_url(); + let base_api_url = match base_url.as_ref() { + "https://zed.dev" => "https://cloud.zed.dev", + "https://staging.zed.dev" => "https://cloud.zed.dev", + "http://localhost:3000" => "http://localhost:8787", + other => other, + }; + + Ok(Url::parse_with_params( + &format!("{}{}", base_api_url, path), + query, + )?) + } + /// Builds a Zed LLM URL using the given path. pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result { let base_url = self.base_url(); @@ -272,6 +294,11 @@ impl HttpClient for Arc { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } impl HttpClient for HttpClientWithUrl { @@ -293,6 +320,11 @@ impl HttpClient for HttpClientWithUrl { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } pub fn read_proxy_from_env() -> Option { @@ -344,10 +376,15 @@ impl HttpClient for BlockedHttpClient { fn type_name(&self) -> &'static str { type_name::() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::()) + } } #[cfg(feature = "test-support")] -type FakeHttpHandler = Box< +type FakeHttpHandler = Arc< dyn Fn(Request) -> BoxFuture<'static, anyhow::Result>> + Send + Sync @@ -356,7 +393,7 @@ type FakeHttpHandler = Box< #[cfg(feature = "test-support")] pub struct FakeHttpClient { - handler: FakeHttpHandler, + handler: Mutex>, user_agent: HeaderValue, } @@ -371,7 +408,7 @@ impl FakeHttpClient { base_url: Mutex::new("http://test.example".into()), client: HttpClientWithProxy { client: Arc::new(Self { - handler: Box::new(move |req| Box::pin(handler(req))), + handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))), user_agent: HeaderValue::from_static(type_name::()), }), proxy: None, @@ -396,6 +433,18 @@ impl FakeHttpClient { .unwrap()) }) } + + pub fn replace_handler(&self, new_handler: F) + where + Fut: futures::Future>> + Send + 'static, + F: Fn(FakeHttpHandler, Request) -> Fut + Send + Sync + 'static, + { + let mut handler = self.handler.lock(); + let old_handler = handler.take().unwrap(); + *handler = Some(Arc::new(move |req| { + Box::pin(new_handler(old_handler.clone(), req)) + })); + } } #[cfg(feature = "test-support")] @@ -411,7 +460,7 @@ impl HttpClient for FakeHttpClient { &self, req: Request, ) -> BoxFuture<'static, anyhow::Result>> { - let future = (self.handler)(req); + let future = (self.handler.lock().as_ref().unwrap())(req); future } @@ -426,4 +475,8 @@ impl HttpClient for FakeHttpClient { fn type_name(&self) -> &'static str { type_name::() } + + fn as_fake(&self) -> &FakeHttpClient { + self + } } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 7552060be4..fe68cdd2d6 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -107,6 +107,12 @@ pub enum IconName { Disconnected, DocumentText, Download, + EditorAtom, + EditorCursor, + EditorEmacs, + EditorJetBrains, + EditorSublime, + EditorVsCode, Ellipsis, EllipsisVertical, Envelope, @@ -229,6 +235,7 @@ pub enum IconName { Server, Settings, SettingsAlt, + ShieldCheck, Shift, Slash, SlashSquare, diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs index 81d9181cfc..2d7f211942 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -246,12 +246,15 @@ impl Render for InlineCompletionButton { }; if zeta::should_show_upsell_modal(&self.user_store, cx) { - let tooltip_meta = - match self.user_store.read(cx).current_user_has_accepted_terms() { - Some(true) => "Choose a Plan", - Some(false) => "Accept the Terms of Service", - None => "Sign In", - }; + let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { + if self.user_store.read(cx).has_accepted_terms_of_service() { + "Choose a Plan" + } else { + "Accept the Terms of Service" + } + } else { + "Sign In" + }; return div().child( IconButton::new("zed-predict-pending-button", zeta_icon) @@ -387,9 +390,9 @@ impl InlineCompletionButton { language: None, file: None, edit_prediction_provider: None, + user_store, popover_menu_handle, fs, - user_store, } } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 72b7132c60..8ae5893410 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_llm_client::Plan; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use proto::{Plan, TypedEnvelope}; +use proto::TypedEnvelope; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,7 +31,7 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", + Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedPro => { "Model request limit reached. Upgrade to usage-based billing for more requests." } @@ -64,9 +65,14 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, ) -> Result { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) + let system_id = client + .telemetry() + .system_id() + .map(|system_id| system_id.to_string()); + + let response = client.cloud_client().create_llm_token(system_id).await?; + *lock = Some(response.token.0.clone()); + Ok(response.token.0.clone()) } } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 208b0d99c9..b5bfb870f6 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -44,7 +44,6 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true -proto.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 3de135c5a2..2108547c4f 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -6,7 +6,7 @@ use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, - EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, + EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; @@ -27,7 +27,6 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; -use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -137,11 +136,10 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store, + user_store: user_store.clone(), status, accept_terms_of_service_task: None, models: Vec::new(), @@ -154,8 +152,9 @@ impl State { .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; loop { - let status = this.read_with(cx, |this, _cx| this.status)?; - if matches!(status, client::Status::Connected { .. }) { + let is_authenticated = user_store + .read_with(cx, |user_store, _cx| user_store.current_user().is_some())?; + if is_authenticated { break; } @@ -194,26 +193,20 @@ impl State { } } - fn is_signed_out(&self) -> bool { - self.status.is_signed_out() + fn is_signed_out(&self, cx: &App) -> bool { + self.user_store.read(cx).current_user().is_none() } fn authenticate(&self, cx: &mut Context) -> Task> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.sign_in_with_optional_connect(true, &cx).await?; state.update(cx, |_, cx| cx.notify()) }) } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false) + self.user_store.read(cx).has_accepted_terms_of_service() } fn accept_terms_of_service(&mut self, cx: &mut Context) { @@ -398,7 +391,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out() && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) } fn authenticate(&self, _cx: &mut App) -> Task> { @@ -613,11 +606,6 @@ impl CloudLanguageModel { .and_then(|plan| plan.to_str().ok()) .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) { - let plan = match plan { - cloud_llm_client::Plan::ZedFree => Plan::Free, - cloud_llm_client::Plan::ZedPro => Plan::ZedPro, - cloud_llm_client::Plan::ZedProTrial => Plan::ZedProTrial, - }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } @@ -1118,7 +1106,7 @@ fn response_lines( #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, - plan: Option, + plan: Option, subscription_period: Option<(DateTime, DateTime)>, eligible_for_trial: bool, has_accepted_terms_of_service: bool, @@ -1132,15 +1120,15 @@ impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(proto::Plan::ZedPro); + let is_pro = self.plan == Some(Plan::ZedPro); let subscription_text = match (self.plan, self.subscription_period) { - (Some(proto::Plan::ZedPro), Some(_)) => { + (Some(Plan::ZedPro), Some(_)) => { "You have access to Zed's hosted models through your Pro subscription." } - (Some(proto::Plan::ZedProTrial), Some(_)) => { + (Some(Plan::ZedProTrial), Some(_)) => { "You have access to Zed's hosted models through your Pro trial." } - (Some(proto::Plan::Free), Some(_)) => { + (Some(Plan::ZedFree), Some(_)) => { "You have basic access to Zed's hosted models through the Free plan." } _ => { @@ -1265,8 +1253,8 @@ impl Render for ConfigurationView { let user_store = state.user_store.read(cx); ZedAiConfiguration { - is_connected: !state.is_signed_out(), - plan: user_store.current_plan(), + is_connected: !state.is_signed_out(cx), + plan: user_store.plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), @@ -1286,7 +1274,7 @@ impl Component for ZedAiConfiguration { fn preview(_window: &mut Window, _cx: &mut App) -> Option { fn configuration( is_connected: bool, - plan: Option, + plan: Option, eligible_for_trial: bool, account_too_young: bool, has_accepted_terms_of_service: bool, @@ -1330,15 +1318,15 @@ impl Component for ZedAiConfiguration { ), single_example( "Free Plan", - configuration(true, Some(proto::Plan::Free), true, false, true), + configuration(true, Some(Plan::ZedFree), true, false, true), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(proto::Plan::ZedProTrial), true, false, true), + configuration(true, Some(Plan::ZedProTrial), true, false, true), ), single_example( "Zed Pro Plan", - configuration(true, Some(proto::Plan::ZedPro), true, false, true), + configuration(true, Some(Plan::ZedPro), true, false, true), ), ]) .into_any_element(), diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 9e95ed4673..a339f3b941 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1015,7 +1015,7 @@ impl Render for LspTool { .anchor(Corner::BottomLeft) .with_handle(self.popover_menu_handle.clone()) .trigger_with_tooltip( - IconButton::new("zed-lsp-tool-button", IconName::BoltFilledAlt) + IconButton::new("zed-lsp-tool-button", IconName::Bolt) .when_some(indicator, IconButton::indicator) .icon_size(IconSize::Small) .indicator_border_color(Some(cx.theme().colors().status_bar_background)), diff --git a/crates/livekit_client/Cargo.toml b/crates/livekit_client/Cargo.toml index c367e03bb7..821fd5d390 100644 --- a/crates/livekit_client/Cargo.toml +++ b/crates/livekit_client/Cargo.toml @@ -40,8 +40,8 @@ util.workspace = true workspace-hack.workspace = true [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies] -libwebrtc = { rev = "383e5377f8b7de1f8627ee16f0cf11c5293337bd", git = "https://github.com/zed-industries/livekit-rust-sdks" } -livekit = { rev = "383e5377f8b7de1f8627ee16f0cf11c5293337bd", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ +libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" } +livekit = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks", features = [ "__rustls-tls" ] } diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index ccb39ab8a2..b9701a83d2 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -421,14 +421,14 @@ impl LanguageServer { .map(|stderr| { let io_handlers = io_handlers.clone(); let stderr_captures = stderr_capture.clone(); - cx.spawn(async move |_| { + cx.background_spawn(async move { Self::handle_stderr(stderr, io_handlers, stderr_captures) .log_err() .await }) }) .unwrap_or_else(|| Task::ready(None)); - let input_task = cx.spawn(async move |_| { + let input_task = cx.background_spawn(async move { let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); stdout.or(stderr) }); @@ -846,7 +846,7 @@ impl LanguageServer { configuration: Arc, cx: &App, ) -> Task>> { - cx.spawn(async move |_| { + cx.background_spawn(async move { let response = self .request::(params) .await diff --git a/crates/multi_buffer/src/anchor.rs b/crates/multi_buffer/src/anchor.rs index 9e28295c56..1305328d38 100644 --- a/crates/multi_buffer/src/anchor.rs +++ b/crates/multi_buffer/src/anchor.rs @@ -167,10 +167,10 @@ impl Anchor { if *self == Anchor::min() || *self == Anchor::max() { true } else if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) { - excerpt.contains(self) - && (self.text_anchor == excerpt.range.context.start - || self.text_anchor == excerpt.range.context.end - || self.text_anchor.is_valid(&excerpt.buffer)) + (self.text_anchor == excerpt.range.context.start + || self.text_anchor == excerpt.range.context.end + || self.text_anchor.is_valid(&excerpt.buffer)) + && excerpt.contains(self) } else { false } diff --git a/crates/onboarding/Cargo.toml b/crates/onboarding/Cargo.toml index da009b4e4e..8f684dd1b8 100644 --- a/crates/onboarding/Cargo.toml +++ b/crates/onboarding/Cargo.toml @@ -16,17 +16,29 @@ default = [] [dependencies] anyhow.workspace = true +ai_onboarding.workspace = true +client.workspace = true command_palette_hooks.workspace = true +component.workspace = true +documented.workspace = true db.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true +itertools.workspace = true language.workspace = true +language_model.workspace = true +menu.workspace = true project.workspace = true +schemars.workspace = true +serde.workspace = true settings.workspace = true theme.workspace = true ui.workspace = true +util.workspace = true +vim_mode_setting.workspace = true workspace-hack.workspace = true workspace.workspace = true zed_actions.workspace = true +zlog.workspace = true diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs new file mode 100644 index 0000000000..2f031e7bb8 --- /dev/null +++ b/crates/onboarding/src/ai_setup_page.rs @@ -0,0 +1,359 @@ +use std::sync::Arc; + +use ai_onboarding::{AiUpsellCard, SignInStatus}; +use client::DisableAiSettings; +use fs::Fs; +use gpui::{ + Action, AnyView, App, DismissEvent, EventEmitter, FocusHandle, Focusable, Window, prelude::*, +}; +use itertools; + +use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; +use settings::{Settings, update_settings_file}; +use ui::{ + Badge, ButtonLike, Divider, Modal, ModalFooter, ModalHeader, Section, SwitchField, ToggleState, + prelude::*, +}; +use workspace::ModalView; + +use util::ResultExt; +use zed_actions::agent::OpenSettings; + +use crate::Onboarding; + +const FEATURED_PROVIDERS: [&'static str; 4] = ["anthropic", "google", "openai", "ollama"]; + +fn render_llm_provider_section( + onboarding: &Onboarding, + disabled: bool, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Or use other LLM providers").size(LabelSize::Large)) + .child( + Label::new("Bring your API keys to use the available providers with Zed's UI for free.") + .color(Color::Muted), + ), + ) + .child(render_llm_provider_card(onboarding, disabled, window, cx)) +} + +fn render_privacy_card(disabled: bool, cx: &mut App) -> impl IntoElement { + let privacy_badge = || Badge::new("Privacy").icon(IconName::ShieldCheck); + + v_flex() + .relative() + .pt_2() + .pb_2p5() + .pl_3() + .pr_2() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.5)) + .bg(cx.theme().colors().surface_background.opacity(0.3)) + .rounded_lg() + .overflow_hidden() + .map(|this| { + if disabled { + this.child( + h_flex() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Label::new("AI is disabled across Zed")) + .child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ), + ) + .child(privacy_badge()), + ) + .child( + Label::new("Re-enable it any time in Settings.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + this.child( + h_flex() + .gap_2() + .justify_between() + .child(Label::new("We don't train models using your data")) + .child( + h_flex().gap_1().child(privacy_badge()).child( + Button::new("learn_more", "Learn More") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(|_, _, cx| { + cx.open_url("https://zed.dev/docs/ai/privacy-and-security"); + }), + ), + ), + ) + .child( + Label::new( + "Feel confident in the security and privacy of your projects using Zed.", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) +} + +fn render_llm_provider_card( + onboarding: &Onboarding, + disabled: bool, + _: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let registry = LanguageModelRegistry::read_global(cx); + + v_flex() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().surface_background.opacity(0.5)) + .rounded_lg() + .overflow_hidden() + .children(itertools::intersperse_with( + FEATURED_PROVIDERS + .into_iter() + .flat_map(|provider_name| { + registry.provider(&LanguageModelProviderId::new(provider_name)) + }) + .enumerate() + .map(|(index, provider)| { + let group_name = SharedString::new(format!("onboarding-hover-group-{}", index)); + let is_authenticated = provider.is_authenticated(cx); + + ButtonLike::new(("onboarding-ai-setup-buttons", index)) + .size(ButtonSize::Large) + .child( + h_flex() + .group(&group_name) + .px_0p5() + .w_full() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1() + .child( + Icon::new(provider.icon()) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new(provider.name().0)), + ) + .child( + h_flex() + .gap_1() + .when(!is_authenticated, |el| { + el.visible_on_hover(group_name.clone()) + .child( + Icon::new(IconName::Settings) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configure") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }) + .when(is_authenticated && !disabled, |el| { + el.child( + Icon::new(IconName::Check) + .color(Color::Success) + .size(IconSize::XSmall), + ) + .child( + Label::new("Configured") + .color(Color::Muted) + .size(LabelSize::Small), + ) + }), + ), + ) + .on_click({ + let workspace = onboarding.workspace.clone(); + move |_, window, cx| { + workspace + .update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + let modal = AiConfigurationModal::new( + provider.clone(), + window, + cx, + ); + window.focus(&modal.focus_handle(cx)); + modal + }); + }) + .log_err(); + } + }) + .into_any_element() + }), + || Divider::horizontal().into_any_element(), + )) + .child(Divider::horizontal()) + .child( + Button::new("agent_settings", "Add Many Others") + .size(ButtonSize::Large) + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .on_click(|_event, window, cx| { + window.dispatch_action(OpenSettings.boxed_clone(), cx) + }), + ) +} + +pub(crate) fn render_ai_setup_page( + onboarding: &Onboarding, + window: &mut Window, + cx: &mut App, +) -> impl IntoElement { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + + let backdrop = div() + .id("backdrop") + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().editor_background) + .opacity(0.8) + .block_mouse_except_scroll(); + + v_flex() + .gap_2() + .child(SwitchField::new( + "enable_ai", + "Enable AI features", + None, + if is_ai_disabled { + ToggleState::Unselected + } else { + ToggleState::Selected + }, + |toggle_state, _, cx| { + let enabled = match toggle_state { + ToggleState::Indeterminate => { + return; + } + ToggleState::Unselected => false, + ToggleState::Selected => true, + }; + + let fs = ::global(cx); + update_settings_file::( + fs, + cx, + move |ai_settings: &mut Option, _| { + *ai_settings = Some(!enabled); + }, + ); + }, + )) + .child(render_privacy_card(is_ai_disabled, cx)) + .child( + v_flex() + .mt_2() + .gap_6() + .child(AiUpsellCard { + sign_in_status: SignInStatus::SignedIn, + sign_in: Arc::new(|_, _| {}), + user_plan: onboarding.user_store.read(cx).plan(), + }) + .child(render_llm_provider_section( + onboarding, + is_ai_disabled, + window, + cx, + )) + .when(is_ai_disabled, |this| this.child(backdrop)), + ) +} + +struct AiConfigurationModal { + focus_handle: FocusHandle, + selected_provider: Arc, + configuration_view: AnyView, +} + +impl AiConfigurationModal { + fn new( + selected_provider: Arc, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let focus_handle = cx.focus_handle(); + let configuration_view = selected_provider.configuration_view(window, cx); + + Self { + focus_handle, + configuration_view, + selected_provider, + } + } +} + +impl ModalView for AiConfigurationModal {} + +impl EventEmitter for AiConfigurationModal {} + +impl Focusable for AiConfigurationModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for AiConfigurationModal { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .w(rems(34.)) + .elevation_3(cx) + .track_focus(&self.focus_handle) + .child( + Modal::new("onboarding-ai-setup-modal", None) + .header( + ModalHeader::new() + .icon( + Icon::new(self.selected_provider.icon()) + .color(Color::Muted) + .size(IconSize::Small), + ) + .headline(self.selected_provider.name().0), + ) + .section(Section::new().child(self.configuration_view.clone())) + .footer( + ModalFooter::new().end_slot( + h_flex() + .gap_1() + .child( + Button::new("onboarding-closing-cancel", "Cancel") + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + .child(Button::new("save-btn", "Done").on_click(cx.listener( + |_, _, window, cx| { + window.dispatch_action(menu::Confirm.boxed_clone(), cx); + cx.emit(DismissEvent); + }, + ))), + ), + ), + ) + } +} diff --git a/crates/onboarding/src/basics_page.rs b/crates/onboarding/src/basics_page.rs new file mode 100644 index 0000000000..82688e6220 --- /dev/null +++ b/crates/onboarding/src/basics_page.rs @@ -0,0 +1,351 @@ +use client::TelemetrySettings; +use fs::Fs; +use gpui::{App, Entity, IntoElement, Window}; +use settings::{BaseKeymap, Settings, update_settings_file}; +use theme::{Appearance, ThemeMode, ThemeName, ThemeRegistry, ThemeSelection, ThemeSettings}; +use ui::{ + ParentElement as _, StatefulInteractiveElement, SwitchField, ToggleButtonGroup, + ToggleButtonSimple, ToggleButtonWithIcon, prelude::*, rems_from_px, +}; +use vim_mode_setting::VimModeSetting; + +use crate::theme_preview::ThemePreviewTile; + +/// separates theme "mode" ("dark" | "light" | "system") into two separate states +/// - appearance = "dark" | "light" +/// - "system" true/false +/// when system selected: +/// - toggling between light and dark does not change theme.mode, just which variant will be changed +/// when system not selected: +/// - toggling between light and dark does change theme.mode +/// selecting a theme preview will always change theme.["light" | "dark"] to the selected theme, +/// +/// this allows for selecting a dark and light theme option regardless of whether the mode is set to system or not +/// it does not support setting theme to a static value +fn render_theme_section(window: &mut Window, cx: &mut App) -> impl IntoElement { + let theme_selection = ThemeSettings::get_global(cx).theme_selection.clone(); + let system_appearance = theme::SystemAppearance::global(cx); + let appearance_state = window.use_state(cx, |_, _cx| { + theme_selection + .as_ref() + .and_then(|selection| selection.mode()) + .and_then(|mode| match mode { + ThemeMode::System => None, + ThemeMode::Light => Some(Appearance::Light), + ThemeMode::Dark => Some(Appearance::Dark), + }) + .unwrap_or(*system_appearance) + }); + let appearance = *appearance_state.read(cx); + let theme_selection = theme_selection.unwrap_or_else(|| ThemeSelection::Dynamic { + mode: match *system_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }, + light: ThemeName("One Light".into()), + dark: ThemeName("One Dark".into()), + }); + let theme_registry = ThemeRegistry::global(cx); + + let current_theme_name = theme_selection.theme(appearance); + let theme_mode = theme_selection.mode().unwrap_or_default(); + + // let theme_mode = theme_selection.mode(); + // TODO: Clean this up once the "System" button inside the + // toggle button group is done + + let selected_index = match appearance { + Appearance::Light => 0, + Appearance::Dark => 1, + }; + + let theme_seed = 0xBEEF as f32; + + const LIGHT_THEMES: [&'static str; 3] = ["One Light", "Ayu Light", "Gruvbox Light"]; + const DARK_THEMES: [&'static str; 3] = ["One Dark", "Ayu Dark", "Gruvbox Dark"]; + + let theme_names = match appearance { + Appearance::Light => LIGHT_THEMES, + Appearance::Dark => DARK_THEMES, + }; + let themes = theme_names + .map(|theme_name| theme_registry.get(theme_name)) + .map(Result::unwrap); + + let theme_previews = themes.map(|theme| { + let is_selected = theme.name == current_theme_name; + let name = theme.name.clone(); + let colors = cx.theme().colors(); + + v_flex() + .id(name.clone()) + .w_full() + .items_center() + .gap_1() + .child( + div() + .w_full() + .border_2() + .border_color(colors.border_transparent) + .rounded(ThemePreviewTile::CORNER_RADIUS) + .map(|this| { + if is_selected { + this.border_color(colors.border_selected) + } else { + this.opacity(0.8).hover(|s| s.border_color(colors.border)) + } + }) + .child(ThemePreviewTile::new(theme.clone(), theme_seed)), + ) + .child(Label::new(name).color(Color::Muted).size(LabelSize::Small)) + .on_click({ + let theme_name = theme.name.clone(); + move |_, _, cx| { + let fs = ::global(cx); + let theme_name = theme_name.clone(); + update_settings_file::(fs, cx, move |settings, _| { + settings.set_theme(theme_name, appearance); + }); + } + }) + }); + + return v_flex() + .gap_2() + .child( + h_flex().justify_between().child(Label::new("Theme")).child( + ToggleButtonGroup::single_row( + "theme-selector-onboarding-dark-light", + [ + ToggleButtonSimple::new("Light", { + let appearance_state = appearance_state.clone(); + move |_, _, cx| { + write_appearance_change(&appearance_state, Appearance::Light, cx); + } + }), + ToggleButtonSimple::new("Dark", { + let appearance_state = appearance_state.clone(); + move |_, _, cx| { + write_appearance_change(&appearance_state, Appearance::Dark, cx); + } + }), + // TODO: Properly put the System back as a button within this group + // Currently, given "System" is not an option in the Appearance enum, + // this button doesn't get selected + ToggleButtonSimple::new("System", { + let theme = theme_selection.clone(); + move |_, _, cx| { + toggle_system_theme_mode(theme.clone(), appearance, cx); + } + }) + .selected(theme_mode == ThemeMode::System), + ], + ) + .selected_index(selected_index) + .style(ui::ToggleButtonGroupStyle::Outlined) + .button_width(rems_from_px(64.)), + ), + ) + .child(h_flex().gap_4().justify_between().children(theme_previews)); + + fn write_appearance_change( + appearance_state: &Entity, + new_appearance: Appearance, + cx: &mut App, + ) { + let fs = ::global(cx); + appearance_state.write(cx, new_appearance); + + update_settings_file::(fs, cx, move |settings, _| { + if settings.theme.as_ref().and_then(ThemeSelection::mode) == Some(ThemeMode::System) { + return; + } + let new_mode = match new_appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }; + settings.set_mode(new_mode); + }); + } + + fn toggle_system_theme_mode( + theme_selection: ThemeSelection, + appearance: Appearance, + cx: &mut App, + ) { + let fs = ::global(cx); + + update_settings_file::(fs, cx, move |settings, _| { + settings.theme = Some(match theme_selection { + ThemeSelection::Static(theme_name) => ThemeSelection::Dynamic { + mode: ThemeMode::System, + light: theme_name.clone(), + dark: theme_name.clone(), + }, + ThemeSelection::Dynamic { + mode: ThemeMode::System, + light, + dark, + } => { + let mode = match appearance { + Appearance::Light => ThemeMode::Light, + Appearance::Dark => ThemeMode::Dark, + }; + ThemeSelection::Dynamic { mode, light, dark } + } + ThemeSelection::Dynamic { + mode: _, + light, + dark, + } => ThemeSelection::Dynamic { + mode: ThemeMode::System, + light, + dark, + }, + }); + }); + } +} + +fn write_keymap_base(keymap_base: BaseKeymap, cx: &App) { + let fs = ::global(cx); + + update_settings_file::(fs, cx, move |setting, _| { + *setting = Some(keymap_base); + }); +} + +fn render_telemetry_section(cx: &App) -> impl IntoElement { + let fs = ::global(cx); + + v_flex() + .gap_4() + .child(Label::new("Telemetry").size(LabelSize::Large)) + .child(SwitchField::new( + "onboarding-telemetry-metrics", + "Help Improve Zed", + Some("Sending anonymous usage data helps us build the right features and create the best experience.".into()), + if TelemetrySettings::get_global(cx).metrics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::( + fs.clone(), + cx, + move |setting, _| setting.metrics = Some(enabled), + ); + }}, + )) + .child(SwitchField::new( + "onboarding-telemetry-crash-reports", + "Help Fix Zed", + Some("Send crash reports so we can fix critical issues fast.".into()), + if TelemetrySettings::get_global(cx).diagnostics { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = fs.clone(); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::( + fs.clone(), + cx, + move |setting, _| setting.diagnostics = Some(enabled), + ); + } + } + )) +} + +pub(crate) fn render_basics_page(window: &mut Window, cx: &mut App) -> impl IntoElement { + let base_keymap = match BaseKeymap::get_global(cx) { + BaseKeymap::VSCode => Some(0), + BaseKeymap::JetBrains => Some(1), + BaseKeymap::SublimeText => Some(2), + BaseKeymap::Atom => Some(3), + BaseKeymap::Emacs => Some(4), + BaseKeymap::Cursor => Some(5), + BaseKeymap::TextMate | BaseKeymap::None => None, + }; + + v_flex() + .gap_6() + .child(render_theme_section(window, cx)) + .child( + v_flex().gap_2().child(Label::new("Base Keymap")).child( + ToggleButtonGroup::two_rows( + "multiple_row_test", + [ + ToggleButtonWithIcon::new("VS Code", IconName::EditorVsCode, |_, _, cx| { + write_keymap_base(BaseKeymap::VSCode, cx); + }), + ToggleButtonWithIcon::new("Jetbrains", IconName::EditorJetBrains, |_, _, cx| { + write_keymap_base(BaseKeymap::JetBrains, cx); + }), + ToggleButtonWithIcon::new("Sublime Text", IconName::EditorSublime, |_, _, cx| { + write_keymap_base(BaseKeymap::SublimeText, cx); + }), + ], + [ + ToggleButtonWithIcon::new("Atom", IconName::EditorAtom, |_, _, cx| { + write_keymap_base(BaseKeymap::Atom, cx); + }), + ToggleButtonWithIcon::new("Emacs", IconName::EditorEmacs, |_, _, cx| { + write_keymap_base(BaseKeymap::Emacs, cx); + }), + ToggleButtonWithIcon::new("Cursor (Beta)", IconName::EditorCursor, |_, _, cx| { + write_keymap_base(BaseKeymap::Cursor, cx); + }), + ], + ) + .when_some(base_keymap, |this, base_keymap| this.selected_index(base_keymap)) + .button_width(rems_from_px(216.)) + .size(ui::ToggleButtonGroupSize::Medium) + .style(ui::ToggleButtonGroupStyle::Outlined) + ), + ) + .child(SwitchField::new( + "onboarding-vim-mode", + "Vim Mode", + Some("Coming from Neovim? Zed's first-class implementation of Vim Mode has got your back.".into()), + if VimModeSetting::get_global(cx).0 { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + { + let fs = ::global(cx); + move |selection, _, cx| { + let enabled = match selection { + ToggleState::Selected => true, + ToggleState::Unselected => false, + ToggleState::Indeterminate => { return; }, + }; + + update_settings_file::( + fs.clone(), + cx, + move |setting, _| *setting = Some(enabled), + ); + } + }, + )) + .child(render_telemetry_section(cx)) +} diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs index c07d8fef4d..33d0955d19 100644 --- a/crates/onboarding/src/editing_page.rs +++ b/crates/onboarding/src/editing_page.rs @@ -1,16 +1,17 @@ use editor::{EditorSettings, ShowMinimap}; use fs::Fs; -use gpui::{App, IntoElement, Pixels, Window}; +use gpui::{Action, App, IntoElement, Pixels, Window}; use language::language_settings::AllLanguageSettings; use project::project_settings::ProjectSettings; use settings::{Settings as _, update_settings_file}; use theme::{FontFamilyCache, FontFamilyName, ThemeSettings}; use ui::{ - ContextMenu, DropdownMenu, IconButton, Label, LabelCommon, LabelSize, NumericStepper, - ParentElement, SharedString, Styled, SwitchColor, SwitchField, ToggleButtonGroup, - ToggleButtonGroupStyle, ToggleButtonSimple, ToggleState, div, h_flex, px, v_flex, + ButtonLike, ContextMenu, DropdownMenu, NumericStepper, SwitchField, ToggleButtonGroup, + ToggleButtonGroupStyle, ToggleButtonSimple, ToggleState, prelude::*, }; +use crate::{ImportCursorSettings, ImportVsCodeSettings}; + fn read_show_mini_map(cx: &App) -> ShowMinimap { editor::EditorSettings::get_global(cx).minimap.show } @@ -18,6 +19,14 @@ fn read_show_mini_map(cx: &App) -> ShowMinimap { fn write_show_mini_map(show: ShowMinimap, cx: &mut App) { let fs = ::global(cx); + // This is used to speed up the UI + // the UI reads the current values to get what toggle state to show on buttons + // there's a slight delay if we just call update_settings_file so we manually set + // the value here then call update_settings file to get around the delay + let mut curr_settings = EditorSettings::get_global(cx).clone(); + curr_settings.minimap.show = show; + EditorSettings::override_global(curr_settings, cx); + update_settings_file::(fs, cx, move |editor_settings, _| { editor_settings.minimap.get_or_insert_default().show = Some(show); }); @@ -33,6 +42,10 @@ fn read_inlay_hints(cx: &App) -> bool { fn write_inlay_hints(enabled: bool, cx: &mut App) { let fs = ::global(cx); + let mut curr_settings = AllLanguageSettings::get_global(cx).clone(); + curr_settings.defaults.inlay_hints.enabled = enabled; + AllLanguageSettings::override_global(curr_settings, cx); + update_settings_file::(fs, cx, move |all_language_settings, cx| { all_language_settings .defaults @@ -54,6 +67,14 @@ fn read_git_blame(cx: &App) -> bool { fn set_git_blame(enabled: bool, cx: &mut App) { let fs = ::global(cx); + let mut curr_settings = ProjectSettings::get_global(cx).clone(); + curr_settings + .git + .inline_blame + .get_or_insert_default() + .enabled = enabled; + ProjectSettings::override_global(curr_settings, cx); + update_settings_file::(fs, cx, move |project_settings, _| { project_settings .git @@ -95,139 +116,212 @@ fn write_buffer_font_family(font_family: SharedString, cx: &mut App) { }); } -pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { +fn render_import_settings_section() -> impl IntoElement { + v_flex() + .gap_4() + .child( + v_flex() + .child(Label::new("Import Settings").size(LabelSize::Large)) + .child( + Label::new("Automatically pull your settings from other editors.") + .color(Color::Muted), + ), + ) + .child( + h_flex() + .w_full() + .gap_4() + .child( + h_flex().w_full().child( + ButtonLike::new("import_vs_code") + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Large) + .child( + h_flex() + .w_full() + .gap_1p5() + .px_1() + .child( + Icon::new(IconName::EditorVsCode) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new("VS Code")), + ) + .on_click(|_, window, cx| { + window.dispatch_action( + ImportVsCodeSettings::default().boxed_clone(), + cx, + ) + }), + ), + ) + .child( + h_flex().w_full().child( + ButtonLike::new("import_cursor") + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Large) + .child( + h_flex() + .w_full() + .gap_1p5() + .px_1() + .child( + Icon::new(IconName::EditorCursor) + .color(Color::Muted) + .size(IconSize::XSmall), + ) + .child(Label::new("Cursor")), + ) + .on_click(|_, window, cx| { + window.dispatch_action( + ImportCursorSettings::default().boxed_clone(), + cx, + ) + }), + ), + ), + ) +} + +fn render_font_customization_section(window: &mut Window, cx: &mut App) -> impl IntoElement { let theme_settings = ThemeSettings::get_global(cx); let ui_font_size = theme_settings.ui_font_size(cx); let font_family = theme_settings.buffer_font.family.clone(); let buffer_font_size = theme_settings.buffer_font_size(cx); - v_flex() + h_flex() + .w_full() .gap_4() - .child(Label::new("Import Settings").size(LabelSize::Large)) .child( - Label::new("Automatically pull your settings from other editors.") - .size(LabelSize::Small), - ) - .child( - h_flex() - .child(IconButton::new( - "import-vs-code-settings", - ui::IconName::Code, - )) - .child(IconButton::new( - "import-cursor-settings", - ui::IconName::CursorIBeam, - )), - ) - .child(Label::new("Popular Settings").size(LabelSize::Large)) - .child( - h_flex() - .gap_4() - .justify_between() + v_flex() + .w_full() + .gap_1() + .child(Label::new("UI Font")) .child( - v_flex() + h_flex() + .w_full() .justify_between() - .gap_1() - .child(Label::new("UI Font")) + .gap_2() .child( - h_flex() - .justify_between() - .gap_2() - .child(div().min_w(px(120.)).child(DropdownMenu::new( - "ui-font-family", - theme_settings.ui_font.family.clone(), - ContextMenu::build(window, cx, |mut menu, _, cx| { - let font_family_cache = FontFamilyCache::global(cx); + DropdownMenu::new( + "ui-font-family", + theme_settings.ui_font.family.clone(), + ContextMenu::build(window, cx, |mut menu, _, cx| { + let font_family_cache = FontFamilyCache::global(cx); - for font_name in font_family_cache.list_font_families(cx) { - menu = menu.custom_entry( - { - let font_name = font_name.clone(); - move |_window, _cx| { - Label::new(font_name.clone()) - .into_any_element() - } - }, - { - let font_name = font_name.clone(); - move |_window, cx| { - write_ui_font_family(font_name.clone(), cx); - } - }, - ) - } + for font_name in font_family_cache.list_font_families(cx) { + menu = menu.custom_entry( + { + let font_name = font_name.clone(); + move |_window, _cx| { + Label::new(font_name.clone()).into_any_element() + } + }, + { + let font_name = font_name.clone(); + move |_window, cx| { + write_ui_font_family(font_name.clone(), cx); + } + }, + ) + } - menu - }), - ))) - .child(NumericStepper::new( - "ui-font-size", - ui_font_size.to_string(), - move |_, _, cx| { - write_ui_font_size(ui_font_size - px(1.), cx); - }, - move |_, _, cx| { - write_ui_font_size(ui_font_size + px(1.), cx); - }, - )), - ), - ) - .child( - v_flex() - .justify_between() - .gap_1() - .child(Label::new("Editor Font")) + menu + }), + ) + .style(ui::DropdownStyle::Outlined) + .full_width(true), + ) .child( - h_flex() - .justify_between() - .gap_2() - .child(DropdownMenu::new( - "buffer-font-family", - font_family, - ContextMenu::build(window, cx, |mut menu, _, cx| { - let font_family_cache = FontFamilyCache::global(cx); - - for font_name in font_family_cache.list_font_families(cx) { - menu = menu.custom_entry( - { - let font_name = font_name.clone(); - move |_window, _cx| { - Label::new(font_name.clone()) - .into_any_element() - } - }, - { - let font_name = font_name.clone(); - move |_window, cx| { - write_buffer_font_family( - font_name.clone(), - cx, - ); - } - }, - ) - } - - menu - }), - )) - .child(NumericStepper::new( - "buffer-font-size", - buffer_font_size.to_string(), - move |_, _, cx| { - write_buffer_font_size(buffer_font_size - px(1.), cx); - }, - move |_, _, cx| { - write_buffer_font_size(buffer_font_size + px(1.), cx); - }, - )), + NumericStepper::new( + "ui-font-size", + ui_font_size.to_string(), + move |_, _, cx| { + write_ui_font_size(ui_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_ui_font_size(ui_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined), ), ), ) + .child( + v_flex() + .w_full() + .gap_1() + .child(Label::new("Editor Font")) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + DropdownMenu::new( + "buffer-font-family", + font_family, + ContextMenu::build(window, cx, |mut menu, _, cx| { + let font_family_cache = FontFamilyCache::global(cx); + + for font_name in font_family_cache.list_font_families(cx) { + menu = menu.custom_entry( + { + let font_name = font_name.clone(); + move |_window, _cx| { + Label::new(font_name.clone()).into_any_element() + } + }, + { + let font_name = font_name.clone(); + move |_window, cx| { + write_buffer_font_family(font_name.clone(), cx); + } + }, + ) + } + + menu + }), + ) + .style(ui::DropdownStyle::Outlined) + .full_width(true), + ) + .child( + NumericStepper::new( + "buffer-font-size", + buffer_font_size.to_string(), + move |_, _, cx| { + write_buffer_font_size(buffer_font_size - px(1.), cx); + }, + move |_, _, cx| { + write_buffer_font_size(buffer_font_size + px(1.), cx); + }, + ) + .style(ui::NumericStepperStyle::Outlined), + ), + ), + ) +} + +fn render_popular_settings_section(window: &mut Window, cx: &mut App) -> impl IntoElement { + v_flex() + .gap_5() + .child(Label::new("Popular Settings").size(LabelSize::Large).mt_8()) + .child(render_font_customization_section(window, cx)) .child( h_flex() + .items_start() .justify_between() - .child(Label::new("Mini Map")) + .child( + v_flex().child(Label::new("Mini Map")).child( + Label::new("See a high-level overview of your source code.") + .color(Color::Muted), + ), + ) .child( ToggleButtonGroup::single_row( "onboarding-show-mini-map", @@ -252,36 +346,37 @@ pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl Int .button_width(ui::rems_from_px(64.)), ), ) - .child( - SwitchField::new( - "onboarding-enable-inlay-hints", - "Inlay Hints", - "See parameter names for function and method calls inline.", - if read_inlay_hints(cx) { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |toggle_state, _, cx| { - write_inlay_hints(toggle_state == &ToggleState::Selected, cx); - }, - ) - .color(SwitchColor::Accent), - ) - .child( - SwitchField::new( - "onboarding-git-blame-switch", - "Git Blame", - "See who committed each line on a given file.", - if read_git_blame(cx) { - ui::ToggleState::Selected - } else { - ui::ToggleState::Unselected - }, - |toggle_state, _, cx| { - set_git_blame(toggle_state == &ToggleState::Selected, cx); - }, - ) - .color(SwitchColor::Accent), - ) + .child(SwitchField::new( + "onboarding-enable-inlay-hints", + "Inlay Hints", + Some("See parameter names for function and method calls inline.".into()), + if read_inlay_hints(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + write_inlay_hints(toggle_state == &ToggleState::Selected, cx); + }, + )) + .child(SwitchField::new( + "onboarding-git-blame-switch", + "Git Blame", + Some("See who committed each line on a given file.".into()), + if read_git_blame(cx) { + ui::ToggleState::Selected + } else { + ui::ToggleState::Unselected + }, + |toggle_state, _, cx| { + set_git_blame(toggle_state == &ToggleState::Selected, cx); + }, + )) +} + +pub(crate) fn render_editing_page(window: &mut Window, cx: &mut App) -> impl IntoElement { + v_flex() + .gap_4() + .child(render_import_settings_section()) + .child(render_popular_settings_section(window, cx)) } diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index cc0c47ca71..f7e76f2f34 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,27 +1,34 @@ use crate::welcome::{ShowWelcome, WelcomePage}; +use client::{Client, UserStore}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use fs::Fs; use gpui::{ - AnyElement, App, AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, - IntoElement, Render, SharedString, Subscription, Task, WeakEntity, Window, actions, + Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, + FocusHandle, Focusable, IntoElement, KeyContext, Render, SharedString, Subscription, Task, + WeakEntity, Window, actions, }; -use settings::{Settings, SettingsStore, update_settings_file}; +use schemars::JsonSchema; +use serde::Deserialize; +use settings::{SettingsStore, VsCodeSettingsSource}; use std::sync::Arc; -use theme::{ThemeMode, ThemeSettings}; use ui::{ - Divider, FluentBuilder, Headline, KeyBinding, ParentElement as _, StatefulInteractiveElement, - ToggleButtonGroup, ToggleButtonSimple, Vector, VectorName, prelude::*, rems_from_px, + Avatar, ButtonLike, FluentBuilder, Headline, KeyBinding, ParentElement as _, + StatefulInteractiveElement, Vector, VectorName, prelude::*, rems_from_px, }; use workspace::{ AppState, Workspace, WorkspaceId, dock::DockPosition, item::{Item, ItemEvent}, - open_new, with_active_or_new_workspace, + notifications::NotifyResultExt as _, + open_new, register_serializable_item, with_active_or_new_workspace, }; +mod ai_setup_page; +mod basics_page; mod editing_page; +mod theme_preview; mod welcome; pub struct OnBoardingFeatureFlag {} @@ -30,6 +37,24 @@ impl FeatureFlag for OnBoardingFeatureFlag { const NAME: &'static str = "onboarding"; } +/// Imports settings from Visual Studio Code. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportVsCodeSettings { + #[serde(default)] + pub skip_prompt: bool, +} + +/// Imports settings from Cursor editor. +#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = zed)] +#[serde(deny_unknown_fields)] +pub struct ImportCursorSettings { + #[serde(default)] + pub skip_prompt: bool, +} + pub const FIRST_OPEN: &str = "first_open"; actions!( @@ -40,6 +65,18 @@ actions!( ] ); +actions!( + onboarding, + [ + /// Activates the Basics page. + ActivateBasicsPage, + /// Activates the Editing page. + ActivateEditingPage, + /// Activates the AI Setup page. + ActivateAISetupPage, + ] +); + pub fn init(cx: &mut App) { cx.on_action(|_: &OpenOnboarding, cx| { with_active_or_new_workspace(cx, |workspace, window, cx| { @@ -54,7 +91,7 @@ pub fn init(cx: &mut App) { if let Some(existing) = existing { workspace.activate_item(&existing, true, true, window, cx); } else { - let settings_page = Onboarding::new(workspace.weak_handle(), cx); + let settings_page = Onboarding::new(workspace, cx); workspace.add_item_to_active_pane( Box::new(settings_page), None, @@ -81,7 +118,7 @@ pub fn init(cx: &mut App) { if let Some(existing) = existing { workspace.activate_item(&existing, true, true, window, cx); } else { - let settings_page = WelcomePage::new(cx); + let settings_page = WelcomePage::new(window, cx); workspace.add_item_to_active_pane( Box::new(settings_page), None, @@ -95,6 +132,43 @@ pub fn init(cx: &mut App) { }); }); + cx.observe_new(|workspace: &mut Workspace, _window, _cx| { + workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { + let fs = ::global(cx); + let action = *action; + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + VsCodeSettingsSource::VsCode, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + + workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { + let fs = ::global(cx); + let action = *action; + + window + .spawn(cx, async move |cx: &mut AsyncWindowContext| { + handle_import_vscode_settings( + VsCodeSettingsSource::Cursor, + action.skip_prompt, + fs, + cx, + ) + .await + }) + .detach(); + }); + }) + .detach(); + cx.observe_new::(|_, window, cx| { let Some(window) = window else { return; @@ -123,6 +197,7 @@ pub fn init(cx: &mut App) { .detach(); }) .detach(); + register_serializable_item::(cx); } pub fn show_onboarding_view(app_state: Arc, cx: &mut App) -> Task> { @@ -133,7 +208,7 @@ pub fn show_onboarding_view(app_state: Arc, cx: &mut App) -> Task, cx: &mut App) -> Task ThemeMode { - let settings = ThemeSettings::get_global(cx); - settings - .theme_selection - .as_ref() - .and_then(|selection| selection.mode()) - .unwrap_or_default() -} - -fn write_theme_selection(theme_mode: ThemeMode, cx: &App) { - let fs = ::global(cx); - - update_settings_file::(fs, cx, move |settings, _| { - settings.set_mode(theme_mode); - }); -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SelectedPage { Basics, @@ -175,163 +233,278 @@ struct Onboarding { workspace: WeakEntity, focus_handle: FocusHandle, selected_page: SelectedPage, + user_store: Entity, _settings_subscription: Subscription, } impl Onboarding { - fn new(workspace: WeakEntity, cx: &mut App) -> Entity { + fn new(workspace: &Workspace, cx: &mut App) -> Entity { cx.new(|cx| Self { - workspace, + workspace: workspace.weak_handle(), focus_handle: cx.focus_handle(), selected_page: SelectedPage::Basics, + user_store: workspace.user_store().clone(), _settings_subscription: cx.observe_global::(move |_, cx| cx.notify()), }) } - fn render_page_nav( + fn set_page(&mut self, page: SelectedPage, cx: &mut Context) { + self.selected_page = page; + cx.notify(); + cx.emit(ItemEvent::UpdateTab); + } + + fn render_nav_buttons( &mut self, - page: SelectedPage, - _: &mut Window, + window: &mut Window, cx: &mut Context, - ) -> impl IntoElement { - let text = match page { - SelectedPage::Basics => "Basics", - SelectedPage::Editing => "Editing", - SelectedPage::AiSetup => "AI Setup", - }; - let binding = match page { - SelectedPage::Basics => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-1").unwrap()], cx) - } - SelectedPage::Editing => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-2").unwrap()], cx) - } - SelectedPage::AiSetup => { - KeyBinding::new(vec![gpui::Keystroke::parse("cmd-3").unwrap()], cx) - } - }; - let selected = self.selected_page == page; - h_flex() - .id(text) - .rounded_sm() - .child(text) - .child(binding) - .h_8() - .gap_2() - .px_2() - .py_0p5() - .w_full() + ) -> [impl IntoElement; 3] { + let pages = [ + SelectedPage::Basics, + SelectedPage::Editing, + SelectedPage::AiSetup, + ]; + + let text = ["Basics", "Editing", "AI Setup"]; + + let actions: [&dyn Action; 3] = [ + &ActivateBasicsPage, + &ActivateEditingPage, + &ActivateAISetupPage, + ]; + + let mut binding = actions.map(|action| { + KeyBinding::for_action_in(action, &self.focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))) + }); + + pages.map(|page| { + let i = page as usize; + let selected = self.selected_page == page; + h_flex() + .id(text[i]) + .relative() + .w_full() + .gap_2() + .px_2() + .py_0p5() + .justify_between() + .rounded_sm() + .when(selected, |this| { + this.child( + div() + .h_4() + .w_px() + .bg(cx.theme().colors().text_accent) + .absolute() + .left_0(), + ) + }) + .hover(|style| style.bg(cx.theme().colors().element_hover)) + .child(Label::new(text[i]).map(|this| { + if selected { + this.color(Color::Default) + } else { + this.color(Color::Muted) + } + })) + .child(binding[i].take().map_or( + gpui::Empty.into_any_element(), + IntoElement::into_any_element, + )) + .on_click(cx.listener(move |this, _, _, cx| { + this.set_page(page, cx); + })) + }) + } + + fn render_nav(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .h_full() + .w(rems_from_px(220.)) + .flex_shrink_0() + .gap_4() .justify_between() - .map(|this| { - if selected { - this.bg(Color::Selected.color(cx)) - .border_l_1() - .border_color(Color::Accent.color(cx)) + .child( + v_flex() + .gap_6() + .child( + h_flex() + .px_2() + .gap_4() + .child(Vector::square(VectorName::ZedLogo, rems(2.5))) + .child( + v_flex() + .child( + Headline::new("Welcome to Zed").size(HeadlineSize::Small), + ) + .child( + Label::new("The editor for what's next") + .color(Color::Muted) + .size(LabelSize::Small) + .italic(), + ), + ), + ) + .child( + v_flex() + .gap_4() + .child( + v_flex() + .py_4() + .border_y_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .gap_1() + .children(self.render_nav_buttons(window, cx)), + ) + .child( + ButtonLike::new("skip_all") + .child(Label::new("Skip All").ml_1()) + .on_click(|_, _, cx| { + with_active_or_new_workspace( + cx, + |workspace, window, cx| { + let Some((onboarding_id, onboarding_idx)) = + workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = + item.downcast::()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map( + |(idx, item)| { + let _ = + item.downcast::()?; + Some(idx) + }, + ); + + if let Some(idx) = idx { + pane.activate_item( + idx, true, true, window, cx, + ); + } else { + let item = + Box::new(WelcomePage::new(window, cx)); + pane.add_item( + item, + true, + true, + Some(onboarding_idx), + window, + cx, + ); + } + + pane.remove_item( + onboarding_id, + false, + false, + window, + cx, + ); + }); + }, + ); + }), + ), + ), + ) + .child( + if let Some(user) = self.user_store.read(cx).current_user() { + h_flex() + .pl_1p5() + .gap_2() + .child(Avatar::new(user.avatar_uri.clone())) + .child(Label::new(user.github_login.clone())) + .into_any_element() } else { - this.text_color(Color::Muted.color(cx)) - } - }) - .hover(|style| { - if selected { - style.bg(Color::Selected.color(cx).opacity(0.6)) - } else { - style.bg(Color::Selected.color(cx).opacity(0.3)) - } - }) - .on_click(cx.listener(move |this, _, _, cx| { - this.selected_page = page; - cx.notify(); - })) + Button::new("sign_in", "Sign In") + .style(ButtonStyle::Outlined) + .full_width() + .on_click(|_, window, cx| { + let client = Client::global(cx); + window + .spawn(cx, async move |cx| { + client + .sign_in_with_optional_connect(true, &cx) + .await + .notify_async_err(cx); + }) + .detach(); + }) + .into_any_element() + }, + ) } fn render_page(&mut self, window: &mut Window, cx: &mut Context) -> AnyElement { match self.selected_page { - SelectedPage::Basics => self.render_basics_page(window, cx).into_any_element(), + SelectedPage::Basics => { + crate::basics_page::render_basics_page(window, cx).into_any_element() + } SelectedPage::Editing => { crate::editing_page::render_editing_page(window, cx).into_any_element() } - SelectedPage::AiSetup => self.render_ai_setup_page(window, cx).into_any_element(), + SelectedPage::AiSetup => { + crate::ai_setup_page::render_ai_setup_page(&self, window, cx).into_any_element() + } } } - - fn render_basics_page(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let theme_mode = read_theme_selection(cx); - - v_flex().child( - h_flex().justify_between().child(Label::new("Theme")).child( - ToggleButtonGroup::single_row( - "theme-selector-onboarding", - [ - ToggleButtonSimple::new("Light", |_, _, cx| { - write_theme_selection(ThemeMode::Light, cx) - }), - ToggleButtonSimple::new("Dark", |_, _, cx| { - write_theme_selection(ThemeMode::Dark, cx) - }), - ToggleButtonSimple::new("System", |_, _, cx| { - write_theme_selection(ThemeMode::System, cx) - }), - ], - ) - .selected_index(match theme_mode { - ThemeMode::Light => 0, - ThemeMode::Dark => 1, - ThemeMode::System => 2, - }) - .style(ui::ToggleButtonGroupStyle::Outlined) - .button_width(rems_from_px(64.)), - ), - ) - } - - fn render_ai_setup_page(&mut self, _: &mut Window, _: &mut Context) -> impl IntoElement { - div().child("ai setup page") - } } impl Render for Onboarding { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { h_flex() .image_cache(gpui::retain_all("onboarding-page")) - .key_context("onboarding-page") - .px_24() - .py_12() - .items_start() + .key_context({ + let mut ctx = KeyContext::new_with_defaults(); + ctx.add("Onboarding"); + ctx + }) + .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) + .on_action(cx.listener(|this, _: &ActivateBasicsPage, _, cx| { + this.set_page(SelectedPage::Basics, cx); + })) + .on_action(cx.listener(|this, _: &ActivateEditingPage, _, cx| { + this.set_page(SelectedPage::Editing, cx); + })) + .on_action(cx.listener(|this, _: &ActivateAISetupPage, _, cx| { + this.set_page(SelectedPage::AiSetup, cx); + })) .child( - v_flex() - .w_1_3() - .h_full() + h_flex() + .max_w(rems_from_px(1100.)) + .size_full() + .m_auto() + .py_20() + .px_12() + .items_start() + .gap_12() + .child(self.render_nav(window, cx)) .child( - h_flex() - .pt_0p5() - .child(Vector::square(VectorName::ZedLogo, rems(2.))) - .child( - v_flex() - .left_1() - .items_center() - .child(Headline::new("Welcome to Zed")) - .child( - Label::new("The editor for what's next") - .color(Color::Muted) - .italic(), - ), - ), - ) - .p_1() - .child(Divider::horizontal_dashed()) - .child( - v_flex().gap_1().children([ - self.render_page_nav(SelectedPage::Basics, window, cx) - .into_element(), - self.render_page_nav(SelectedPage::Editing, window, cx) - .into_element(), - self.render_page_nav(SelectedPage::AiSetup, window, cx) - .into_element(), - ]), + v_flex() + .max_w_full() + .min_w_0() + .pl_12() + .border_l_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .size_full() + .child(self.render_page(window, cx)), ), ) - // .child(Divider::vertical_dashed()) - .child(div().w_2_3().h_full().child(self.render_page(window, cx))) } } @@ -364,10 +537,185 @@ impl Item for Onboarding { _: &mut Window, cx: &mut Context, ) -> Option> { - Some(Onboarding::new(self.workspace.clone(), cx)) + self.workspace + .update(cx, |workspace, cx| Onboarding::new(workspace, cx)) + .ok() } fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { f(*event) } } + +pub async fn handle_import_vscode_settings( + source: VsCodeSettingsSource, + skip_prompt: bool, + fs: Arc, + cx: &mut AsyncWindowContext, +) { + use util::truncate_and_remove_front; + + let vscode_settings = + match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { + Ok(vscode_settings) => vscode_settings, + Err(err) => { + zlog::error!("{err}"); + let _ = cx.prompt( + gpui::PromptLevel::Info, + &format!("Could not find or load a {source} settings file"), + None, + &["Ok"], + ); + return; + } + }; + + if !skip_prompt { + let prompt = cx.prompt( + gpui::PromptLevel::Warning, + &format!( + "Importing {} settings may overwrite your existing settings. \ + Will import settings from {}", + vscode_settings.source, + truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), + ), + None, + &["Ok", "Cancel"], + ); + let result = cx.spawn(async move |_| prompt.await.ok()).await; + if result != Some(0) { + return; + } + }; + + cx.update(|_, cx| { + let source = vscode_settings.source; + let path = vscode_settings.path.clone(); + cx.global::() + .import_vscode_settings(fs, vscode_settings); + zlog::info!("Imported {source} settings from {}", path.display()); + }) + .ok(); +} + +impl workspace::SerializableItem for Onboarding { + fn serialized_item_kind() -> &'static str { + "OnboardingPage" + } + + fn cleanup( + workspace_id: workspace::WorkspaceId, + alive_items: Vec, + _window: &mut Window, + cx: &mut App, + ) -> gpui::Task> { + workspace::delete_unloaded_items( + alive_items, + workspace_id, + "onboarding_pages", + &persistence::ONBOARDING_PAGES, + cx, + ) + } + + fn deserialize( + _project: Entity, + workspace: WeakEntity, + workspace_id: workspace::WorkspaceId, + item_id: workspace::ItemId, + window: &mut Window, + cx: &mut App, + ) -> gpui::Task>> { + window.spawn(cx, async move |cx| { + if let Some(page_number) = + persistence::ONBOARDING_PAGES.get_onboarding_page(item_id, workspace_id)? + { + let page = match page_number { + 0 => Some(SelectedPage::Basics), + 1 => Some(SelectedPage::Editing), + 2 => Some(SelectedPage::AiSetup), + _ => None, + }; + workspace.update(cx, |workspace, cx| { + let onboarding_page = Onboarding::new(workspace, cx); + if let Some(page) = page { + zlog::info!("Onboarding page {page:?} loaded"); + onboarding_page.update(cx, |onboarding_page, cx| { + onboarding_page.set_page(page, cx); + }) + } + onboarding_page + }) + } else { + Err(anyhow::anyhow!("No onboarding page to deserialize")) + } + }) + } + + fn serialize( + &mut self, + workspace: &mut Workspace, + item_id: workspace::ItemId, + _closing: bool, + _window: &mut Window, + cx: &mut ui::Context, + ) -> Option>> { + let workspace_id = workspace.database_id()?; + let page_number = self.selected_page as u16; + Some(cx.background_spawn(async move { + persistence::ONBOARDING_PAGES + .save_onboarding_page(item_id, workspace_id, page_number) + .await + })) + } + + fn should_serialize(&self, event: &Self::Event) -> bool { + event == &ItemEvent::UpdateTab + } +} + +mod persistence { + use db::{define_connection, query, sqlez_macros::sql}; + use workspace::WorkspaceDb; + + define_connection! { + pub static ref ONBOARDING_PAGES: OnboardingPagesDb = + &[ + sql!( + CREATE TABLE onboarding_pages ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + page_number INTEGER, + + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + ), + ]; + } + + impl OnboardingPagesDb { + query! { + pub async fn save_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId, + page_number: u16 + ) -> Result<()> { + INSERT OR REPLACE INTO onboarding_pages(item_id, workspace_id, page_number) + VALUES (?, ?, ?) + } + } + + query! { + pub fn get_onboarding_page( + item_id: workspace::ItemId, + workspace_id: workspace::WorkspaceId + ) -> Result> { + SELECT page_number + FROM onboarding_pages + WHERE item_id = ? AND workspace_id = ? + } + } + } +} diff --git a/crates/welcome/src/welcome_ui/theme_preview.rs b/crates/onboarding/src/theme_preview.rs similarity index 72% rename from crates/welcome/src/welcome_ui/theme_preview.rs rename to crates/onboarding/src/theme_preview.rs index b3a80c74c3..d51511b7f4 100644 --- a/crates/welcome/src/welcome_ui/theme_preview.rs +++ b/crates/onboarding/src/theme_preview.rs @@ -11,22 +11,14 @@ use ui::{ #[derive(IntoElement, RegisterComponent, Documented)] pub struct ThemePreviewTile { theme: Arc, - selected: bool, seed: f32, } impl ThemePreviewTile { - pub fn new(theme: Arc, selected: bool, seed: f32) -> Self { - Self { - theme, - selected, - seed, - } - } + pub const CORNER_RADIUS: Pixels = px(8.0); - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self + pub fn new(theme: Arc, seed: f32) -> Self { + Self { theme, seed } } } @@ -34,7 +26,7 @@ impl RenderOnce for ThemePreviewTile { fn render(self, _window: &mut ui::Window, _cx: &mut ui::App) -> impl IntoElement { let color = self.theme.colors(); - let root_radius = px(8.0); + let root_radius = Self::CORNER_RADIUS; let root_border = px(2.0); let root_padding = px(2.0); let child_border = px(1.0); @@ -43,7 +35,7 @@ impl RenderOnce for ThemePreviewTile { let item_skeleton = |w: Length, h: Pixels, bg: Hsla| div().w(w).h(h).rounded_full().bg(bg); - let skeleton_height = px(4.); + let skeleton_height = px(2.); let sidebar_seeded_width = |seed: f32, index: usize| { let value = (seed * 1000.0 + index as f32 * 10.0).sin() * 0.5 + 0.5; @@ -70,12 +62,10 @@ impl RenderOnce for ThemePreviewTile { .border_color(color.border_transparent) .bg(color.panel_background) .child( - div() + v_flex() .p_2() - .flex() - .flex_col() .size_full() - .gap(px(4.)) + .gap_1() .children(sidebar_skeleton), ); @@ -151,32 +141,19 @@ impl RenderOnce for ThemePreviewTile { v_flex() .size_full() .p_1() - .gap(px(6.)) + .gap_1p5() .children(lines) .into_any_element() }; - let pane = div() - .h_full() - .flex_grow() - .flex() - .flex_col() - // .child( - // div() - // .w_full() - // .border_color(color.border) - // .border_b(px(1.)) - // .h(relative(0.1)) - // .bg(color.tab_bar_background), - // ) - .child( - div() - .size_full() - .overflow_hidden() - .bg(color.editor_background) - .p_2() - .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), - ); + let pane = v_flex().h_full().flex_grow().child( + div() + .size_full() + .overflow_hidden() + .bg(color.editor_background) + .p_2() + .child(pseudo_code_skeleton(self.theme.clone(), self.seed)), + ); let content = div().size_full().flex().child(sidebar).child(pane); @@ -184,11 +161,6 @@ impl RenderOnce for ThemePreviewTile { .size_full() .rounded(root_radius) .p(root_padding) - .border(root_border) - .border_color(color.border_transparent) - .when(self.selected, |this| { - this.border_color(color.border_selected) - }) .child( div() .size_full() @@ -230,24 +202,14 @@ impl Component for ThemePreviewTile { .p_4() .children({ if let Some(one_dark) = one_dark.ok() { - vec![example_group(vec![ - single_example( - "Default", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark.clone(), false, 0.42)) - .into_any_element(), - ), - single_example( - "Selected", - div() - .w(px(240.)) - .h(px(180.)) - .child(ThemePreviewTile::new(one_dark, true, 0.42)) - .into_any_element(), - ), - ])] + vec![example_group(vec![single_example( + "Default", + div() + .w(px(240.)) + .h(px(180.)) + .child(ThemePreviewTile::new(one_dark.clone(), 0.42)) + .into_any_element(), + )])] } else { vec![] } @@ -261,12 +223,11 @@ impl Component for ThemePreviewTile { themes_to_preview .iter() .enumerate() - .map(|(i, theme)| { - div().w(px(200.)).h(px(140.)).child(ThemePreviewTile::new( - theme.clone(), - false, - 0.42, - )) + .map(|(_, theme)| { + div() + .w(px(200.)) + .h(px(140.)) + .child(ThemePreviewTile::new(theme.clone(), 0.42)) }) .collect::>(), ) diff --git a/crates/onboarding/src/welcome.rs b/crates/onboarding/src/welcome.rs index 2ea120e021..3d2c034367 100644 --- a/crates/onboarding/src/welcome.rs +++ b/crates/onboarding/src/welcome.rs @@ -4,10 +4,13 @@ use gpui::{ }; use ui::{ButtonLike, Divider, DividerColor, KeyBinding, Vector, VectorName, prelude::*}; use workspace::{ - NewFile, Open, Workspace, WorkspaceId, + NewFile, Open, WorkspaceId, item::{Item, ItemEvent}, + with_active_or_new_workspace, }; -use zed_actions::{Extensions, OpenSettings, command_palette}; +use zed_actions::{Extensions, OpenSettings, agent, command_palette}; + +use crate::{Onboarding, OpenOnboarding}; actions!( zed, @@ -55,8 +58,7 @@ const CONTENT: (Section<4>, Section<3>) = ( SectionEntry { icon: IconName::ZedAssistant, title: "View AI Settings", - // TODO: use proper action - action: &NoAction, + action: &agent::OpenSettings, }, SectionEntry { icon: IconName::Blocks, @@ -217,7 +219,64 @@ impl Render for WelcomePage { div().child( Button::new("welcome-exit", "Return to Setup") .full_width() - .label_size(LabelSize::XSmall), + .label_size(LabelSize::XSmall) + .on_click(|_, window, cx| { + window.dispatch_action( + OpenOnboarding.boxed_clone(), + cx, + ); + + with_active_or_new_workspace(cx, |workspace, window, cx| { + let Some((welcome_id, welcome_idx)) = workspace + .active_pane() + .read(cx) + .items() + .enumerate() + .find_map(|(idx, item)| { + let _ = item.downcast::()?; + Some((item.item_id(), idx)) + }) + else { + return; + }; + + workspace.active_pane().update(cx, |pane, cx| { + // Get the index here to get around the borrow checker + let idx = pane.items().enumerate().find_map( + |(idx, item)| { + let _ = + item.downcast::()?; + Some(idx) + }, + ); + + if let Some(idx) = idx { + pane.activate_item( + idx, true, true, window, cx, + ); + } else { + let item = + Box::new(Onboarding::new(workspace, cx)); + pane.add_item( + item, + true, + true, + Some(welcome_idx), + window, + cx, + ); + } + + pane.remove_item( + welcome_id, + false, + false, + window, + cx, + ); + }); + }); + }), ), ), ), @@ -228,12 +287,14 @@ impl Render for WelcomePage { } impl WelcomePage { - pub fn new(cx: &mut Context) -> Entity { - let this = cx.new(|cx| WelcomePage { - focus_handle: cx.focus_handle(), - }); + pub fn new(window: &mut Window, cx: &mut App) -> Entity { + cx.new(|cx| { + let focus_handle = cx.focus_handle(); + cx.on_focus(&focus_handle, window, |_, _, cx| cx.notify()) + .detach(); - this + WelcomePage { focus_handle } + }) } } diff --git a/crates/project/src/debugger/test.rs b/crates/project/src/debugger/test.rs index 3b9425e369..53b88323e6 100644 --- a/crates/project/src/debugger/test.rs +++ b/crates/project/src/debugger/test.rs @@ -1,7 +1,7 @@ use std::{path::Path, sync::Arc}; use dap::client::DebugAdapterClient; -use gpui::{App, AppContext, Subscription}; +use gpui::{App, Subscription}; use super::session::{Session, SessionStateEvent}; @@ -19,14 +19,6 @@ pub fn intercept_debug_sessions) + 'static>( let client = session.adapter_client().unwrap(); register_default_handlers(session, &client, cx); configure(&client); - cx.background_spawn(async move { - client - .fake_event(dap::messages::Events::Initialized( - Some(Default::default()), - )) - .await - }) - .detach(); } }) .detach(); diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index a2f6de44c9..958921a0e6 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -2269,7 +2269,7 @@ impl LspCommand for GetCompletions { // the range based on the syntax tree. None => { if self.position != clipped_position { - log::info!("completion out of expected range"); + log::info!("completion out of expected range "); return false; } @@ -2483,7 +2483,9 @@ pub(crate) fn parse_completion_text_edit( let start = snapshot.clip_point_utf16(range.start, Bias::Left); let end = snapshot.clip_point_utf16(range.end, Bias::Left); if start != range.start.0 || end != range.end.0 { - log::info!("completion out of expected range"); + log::info!( + "completion out of expected range, start: {start:?}, end: {end:?}, range: {range:?}" + ); return None; } snapshot.anchor_before(start)..snapshot.anchor_after(end) diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 5a8cc05d7d..af3df72c29 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -424,7 +424,7 @@ impl LocalLspStore { if settings.as_ref().is_some_and(|b| b.path.is_some()) { let settings = settings.unwrap(); - return cx.spawn(async move |_| { + return cx.background_spawn(async move { let mut env = delegate.shell_env().await; env.extend(settings.env.unwrap_or_default()); @@ -4911,7 +4911,7 @@ impl LspStore { language_server_id: server_id.0 as u64, hint: Some(InlayHints::project_to_proto_hint(hint.clone())), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let response = upstream_client .request(request) .await @@ -5125,7 +5125,7 @@ impl LspStore { trigger, version: serialize_version(&buffer.read(cx).version()), }; - cx.spawn(async move |_, _| { + cx.background_spawn(async move { client .request(request) .await? @@ -5284,7 +5284,7 @@ impl LspStore { GetDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(definitions_task .await .into_iter() @@ -5357,7 +5357,7 @@ impl LspStore { GetDeclarations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(declarations_task .await .into_iter() @@ -5430,7 +5430,7 @@ impl LspStore { GetTypeDefinitions { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(type_definitions_task .await .into_iter() @@ -5503,7 +5503,7 @@ impl LspStore { GetImplementations { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(implementations_task .await .into_iter() @@ -5576,7 +5576,7 @@ impl LspStore { GetReferences { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(references_task .await .into_iter() @@ -5660,7 +5660,7 @@ impl LspStore { }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(all_actions_task .await .into_iter() @@ -6854,7 +6854,7 @@ impl LspStore { } else { let document_colors_task = self.request_multiple_lsp_locally(buffer, None::, GetDocumentColor, cx); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { Ok(document_colors_task .await .into_iter() @@ -6933,7 +6933,7 @@ impl LspStore { GetSignatureHelp { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -7010,7 +7010,7 @@ impl LspStore { GetHover { position }, cx, ); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { all_actions_task .await .into_iter() @@ -8013,7 +8013,7 @@ impl LspStore { }) .collect::>(); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let mut responses = Vec::with_capacity(response_results.len()); while let Some((server_id, response_result)) = response_results.next().await { if let Some(response) = response_result.log_err() { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 6b943216b3..623f48d3c9 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1362,10 +1362,7 @@ impl Project { fs: Arc, cx: AsyncApp, ) -> Result> { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.connect(true, &cx).await.into_response()?; let subscriptions = [ EntitySubscription::Project(client.subscribe_to_entity::(remote_id)?), @@ -3372,7 +3369,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.definitions(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3390,7 +3387,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.declarations(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3408,7 +3405,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.type_definitions(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3426,7 +3423,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.implementations(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3444,7 +3441,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.references(buffer, position, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result @@ -3996,7 +3993,7 @@ impl Project { let task = self.lsp_store.update(cx, |lsp_store, cx| { lsp_store.request_lsp(buffer_handle, server, request, cx) }); - cx.spawn(async move |_, _| { + cx.background_spawn(async move { let result = task.await; drop(guard); result diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 9f586a7839..83e5a77c86 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -784,6 +784,25 @@ pub fn split_repository_update( }]) } +impl MultiLspQuery { + pub fn request_str(&self) -> &str { + match self.request { + Some(multi_lsp_query::Request::GetHover(_)) => "GetHover", + Some(multi_lsp_query::Request::GetCodeActions(_)) => "GetCodeActions", + Some(multi_lsp_query::Request::GetSignatureHelp(_)) => "GetSignatureHelp", + Some(multi_lsp_query::Request::GetCodeLens(_)) => "GetCodeLens", + Some(multi_lsp_query::Request::GetDocumentDiagnostics(_)) => "GetDocumentDiagnostics", + Some(multi_lsp_query::Request::GetDocumentColor(_)) => "GetDocumentColor", + Some(multi_lsp_query::Request::GetDefinition(_)) => "GetDefinition", + Some(multi_lsp_query::Request::GetDeclaration(_)) => "GetDeclaration", + Some(multi_lsp_query::Request::GetTypeDefinition(_)) => "GetTypeDefinition", + Some(multi_lsp_query::Request::GetImplementation(_)) => "GetImplementation", + Some(multi_lsp_query::Request::GetReferences(_)) => "GetReferences", + None => "", + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index e31d3dcfd5..4306251e44 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1742,7 +1742,7 @@ impl SshRemoteConnection { } }); - cx.spawn(async move |_| { + cx.background_spawn(async move { let result = futures::select! { result = stdin_task.fuse() => { result.context("stdin") diff --git a/crates/settings/src/settings.rs b/crates/settings/src/settings.rs index 4e6bd94d92..afd4ea0890 100644 --- a/crates/settings/src/settings.rs +++ b/crates/settings/src/settings.rs @@ -7,7 +7,7 @@ mod settings_json; mod settings_store; mod vscode_import; -use gpui::App; +use gpui::{App, Global}; use rust_embed::RustEmbed; use std::{borrow::Cow, fmt, str}; use util::asset_str; @@ -27,6 +27,11 @@ pub use settings_store::{ }; pub use vscode_import::{VsCodeSettings, VsCodeSettingsSource}; +#[derive(Clone, Debug, PartialEq)] +pub struct ActiveSettingsProfileName(pub String); + +impl Global for ActiveSettingsProfileName {} + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)] pub struct WorktreeId(usize); @@ -74,6 +79,7 @@ pub fn init(cx: &mut App) { .unwrap(); cx.set_global(settings); BaseKeymap::register(cx); + SettingsStore::observe_active_settings_profile_name(cx).detach(); } pub fn default_settings() -> Cow<'static, str> { diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 0d23385a68..7f6437dac8 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -26,8 +26,8 @@ use util::{ pub type EditorconfigProperties = ec4rs::Properties; use crate::{ - ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, WorktreeId, - parse_json_with_comments, update_value_in_json_text, + ActiveSettingsProfileName, ParameterizedJsonSchema, SettingsJsonSchemaParams, VsCodeSettings, + WorktreeId, parse_json_with_comments, update_value_in_json_text, }; /// A value that can be defined as a user setting. @@ -122,6 +122,8 @@ pub struct SettingsSources<'a, T> { pub user: Option<&'a T>, /// The user settings for the current release channel. pub release_channel: Option<&'a T>, + /// The settings associated with an enabled settings profile + pub profile: Option<&'a T>, /// The server's settings. pub server: Option<&'a T>, /// The project settings, ordered from least specific to most specific. @@ -141,6 +143,7 @@ impl<'a, T: Serialize> SettingsSources<'a, T> { .chain(self.extensions) .chain(self.user) .chain(self.release_channel) + .chain(self.profile) .chain(self.server) .chain(self.project.iter().copied()) } @@ -282,6 +285,14 @@ impl SettingsStore { } } + pub fn observe_active_settings_profile_name(cx: &mut App) -> gpui::Subscription { + cx.observe_global::(|cx| { + Self::update_global(cx, |store, cx| { + store.recompute_values(None, cx).log_err(); + }); + }) + } + pub fn update(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R where C: BorrowAppContext, @@ -321,6 +332,17 @@ impl SettingsStore { .log_err(); } + let mut profile_value = None; + if let Some(active_profile) = cx.try_global::() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_settings) = profiles.get(&active_profile.0) { + profile_value = setting_value + .deserialize_setting(profile_settings) + .log_err(); + } + } + } + let server_value = self .raw_server_settings .as_ref() @@ -340,6 +362,7 @@ impl SettingsStore { extensions: extension_value.as_ref(), user: user_value.as_ref(), release_channel: release_channel_value.as_ref(), + profile: profile_value.as_ref(), server: server_value.as_ref(), project: &[], }, @@ -402,6 +425,16 @@ impl SettingsStore { &self.raw_user_settings } + /// Get the configured settings profile names. + pub fn configured_settings_profiles(&self) -> impl Iterator { + self.raw_user_settings + .get("profiles") + .and_then(|v| v.as_object()) + .into_iter() + .flat_map(|obj| obj.keys()) + .map(|s| s.as_str()) + } + /// Access the raw JSON value of the global settings. pub fn raw_global_settings(&self) -> Option<&Value> { self.raw_global_settings.as_ref() @@ -532,7 +565,9 @@ impl SettingsStore { })) .ok(); } +} +impl SettingsStore { /// Updates the value of a setting in a JSON file, returning the new text /// for that JSON file. pub fn new_text_for_update( @@ -1001,18 +1036,18 @@ impl SettingsStore { const ZED_SETTINGS: &str = "ZedSettings"; let zed_settings_ref = add_new_subschema(&mut generator, ZED_SETTINGS, combined_schema); - // add `ZedReleaseStageSettings` which is the same as `ZedSettings` except that unknown - // fields are rejected. - let mut zed_release_stage_settings = zed_settings_ref.clone(); - zed_release_stage_settings.insert("unevaluatedProperties".to_string(), false.into()); - let zed_release_stage_settings_ref = add_new_subschema( + // add `ZedSettingsOverride` which is the same as `ZedSettings` except that unknown + // fields are rejected. This is used for release stage settings and profiles. + let mut zed_settings_override = zed_settings_ref.clone(); + zed_settings_override.insert("unevaluatedProperties".to_string(), false.into()); + let zed_settings_override_ref = add_new_subschema( &mut generator, - "ZedReleaseStageSettings", - zed_release_stage_settings.to_value(), + "ZedSettingsOverride", + zed_settings_override.to_value(), ); // Remove `"additionalProperties": false` added by `DefaultDenyUnknownFields` so that - // unknown fields can be handled by the root schema and `ZedReleaseStageSettings`. + // unknown fields can be handled by the root schema and `ZedSettingsOverride`. let mut definitions = generator.take_definitions(true); definitions .get_mut(ZED_SETTINGS) @@ -1032,15 +1067,20 @@ impl SettingsStore { "$schema": meta_schema, "title": "Zed Settings", "unevaluatedProperties": false, - // ZedSettings + settings overrides for each release stage + // ZedSettings + settings overrides for each release stage / profiles "allOf": [ zed_settings_ref, { "properties": { - "dev": zed_release_stage_settings_ref, - "nightly": zed_release_stage_settings_ref, - "stable": zed_release_stage_settings_ref, - "preview": zed_release_stage_settings_ref, + "dev": zed_settings_override_ref, + "nightly": zed_settings_override_ref, + "stable": zed_settings_override_ref, + "preview": zed_settings_override_ref, + "profiles": { + "type": "object", + "description": "Configures any number of settings profiles.", + "additionalProperties": zed_settings_override_ref + } } } ], @@ -1099,6 +1139,16 @@ impl SettingsStore { } } + let mut profile_settings = None; + if let Some(active_profile) = cx.try_global::() { + if let Some(profiles) = self.raw_user_settings.get("profiles") { + if let Some(profile_json) = profiles.get(&active_profile.0) { + profile_settings = + setting_value.deserialize_setting(profile_json).log_err(); + } + } + } + // If the global settings file changed, reload the global value for the field. if changed_local_path.is_none() { if let Some(value) = setting_value @@ -1109,6 +1159,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &[], }, @@ -1161,6 +1212,7 @@ impl SettingsStore { extensions: extension_settings.as_ref(), user: user_settings.as_ref(), release_channel: release_channel_settings.as_ref(), + profile: profile_settings.as_ref(), server: server_settings.as_ref(), project: &project_settings_stack.iter().collect::>(), }, @@ -1286,6 +1338,9 @@ impl AnySettingValue for SettingValue { release_channel: values .release_channel .map(|value| value.0.downcast_ref::().unwrap()), + profile: values + .profile + .map(|value| value.0.downcast_ref::().unwrap()), server: values .server .map(|value| value.0.downcast_ref::().unwrap()), diff --git a/crates/settings_profile_selector/Cargo.toml b/crates/settings_profile_selector/Cargo.toml new file mode 100644 index 0000000000..189272e54b --- /dev/null +++ b/crates/settings_profile_selector/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "settings_profile_selector" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/settings_profile_selector.rs" +doctest = false + +[dependencies] +fuzzy.workspace = true +gpui.workspace = true +picker.workspace = true +settings.workspace = true +ui.workspace = true +workspace-hack.workspace = true +workspace.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +client = { workspace = true, features = ["test-support"] } +editor = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +menu.workspace = true +project = { workspace = true, features = ["test-support"] } +serde_json.workspace = true +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/settings_profile_selector/LICENSE-GPL b/crates/settings_profile_selector/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/settings_profile_selector/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/settings_profile_selector/src/settings_profile_selector.rs b/crates/settings_profile_selector/src/settings_profile_selector.rs new file mode 100644 index 0000000000..8a34c12051 --- /dev/null +++ b/crates/settings_profile_selector/src/settings_profile_selector.rs @@ -0,0 +1,581 @@ +use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; +use gpui::{ + App, Context, DismissEvent, Entity, EventEmitter, Focusable, Render, Task, WeakEntity, Window, +}; +use picker::{Picker, PickerDelegate}; +use settings::{ActiveSettingsProfileName, SettingsStore}; +use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*}; +use workspace::{ModalView, Workspace}; + +pub fn init(cx: &mut App) { + cx.on_action(|_: &zed_actions::settings_profile_selector::Toggle, cx| { + workspace::with_active_or_new_workspace(cx, |workspace, window, cx| { + toggle_settings_profile_selector(workspace, window, cx); + }); + }); +} + +fn toggle_settings_profile_selector( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context, +) { + workspace.toggle_modal(window, cx, |window, cx| { + let delegate = SettingsProfileSelectorDelegate::new(cx.entity().downgrade(), window, cx); + SettingsProfileSelector::new(delegate, window, cx) + }); +} + +pub struct SettingsProfileSelector { + picker: Entity>, +} + +impl ModalView for SettingsProfileSelector {} + +impl EventEmitter for SettingsProfileSelector {} + +impl Focusable for SettingsProfileSelector { + fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl Render for SettingsProfileSelector { + fn render(&mut self, _: &mut Window, _: &mut Context) -> impl IntoElement { + v_flex().w(rems(22.)).child(self.picker.clone()) + } +} + +impl SettingsProfileSelector { + pub fn new( + delegate: SettingsProfileSelectorDelegate, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); + Self { picker } + } +} + +pub struct SettingsProfileSelectorDelegate { + matches: Vec, + profile_names: Vec>, + original_profile_name: Option, + selected_profile_name: Option, + selected_index: usize, + selection_completed: bool, + selector: WeakEntity, +} + +impl SettingsProfileSelectorDelegate { + fn new( + selector: WeakEntity, + _: &mut Window, + cx: &mut Context, + ) -> Self { + let settings_store = cx.global::(); + let mut profile_names: Vec> = settings_store + .configured_settings_profiles() + .map(|s| Some(s.to_string())) + .collect(); + profile_names.insert(0, None); + + let matches = profile_names + .iter() + .enumerate() + .map(|(ix, profile_name)| StringMatch { + candidate_id: ix, + score: 0.0, + positions: Default::default(), + string: display_name(profile_name), + }) + .collect(); + + let profile_name = cx + .try_global::() + .map(|p| p.0.clone()); + + let mut this = Self { + matches, + profile_names, + original_profile_name: profile_name.clone(), + selected_profile_name: None, + selected_index: 0, + selection_completed: false, + selector, + }; + + if let Some(profile_name) = profile_name { + this.select_if_matching(&profile_name); + } + + this + } + + fn select_if_matching(&mut self, profile_name: &str) { + self.selected_index = self + .matches + .iter() + .position(|mat| mat.string == profile_name) + .unwrap_or(self.selected_index); + } + + fn set_selected_profile( + &self, + cx: &mut Context>, + ) -> Option { + let mat = self.matches.get(self.selected_index)?; + let profile_name = self.profile_names.get(mat.candidate_id)?; + return Self::update_active_profile_name_global(profile_name.clone(), cx); + } + + fn update_active_profile_name_global( + profile_name: Option, + cx: &mut Context>, + ) -> Option { + if let Some(profile_name) = profile_name { + cx.set_global(ActiveSettingsProfileName(profile_name.clone())); + return Some(profile_name.clone()); + } + + if cx.has_global::() { + cx.remove_global::(); + } + + None + } +} + +impl PickerDelegate for SettingsProfileSelectorDelegate { + type ListItem = ListItem; + + fn placeholder_text(&self, _: &mut Window, _: &mut App) -> std::sync::Arc { + "Select a settings profile...".into() + } + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _: &mut Window, + cx: &mut Context>, + ) { + self.selected_index = ix; + self.selected_profile_name = self.set_selected_profile(cx); + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + let background = cx.background_executor().clone(); + let candidates = self + .profile_names + .iter() + .enumerate() + .map(|(id, profile_name)| StringMatchCandidate::new(id, &display_name(profile_name))) + .collect::>(); + + cx.spawn_in(window, async move |this, cx| { + let matches = if query.is_empty() { + candidates + .into_iter() + .enumerate() + .map(|(index, candidate)| StringMatch { + candidate_id: index, + string: candidate.string, + positions: Vec::new(), + score: 0.0, + }) + .collect() + } else { + match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + background, + ) + .await + }; + + this.update_in(cx, |this, _, cx| { + this.delegate.matches = matches; + this.delegate.selected_index = this + .delegate + .selected_index + .min(this.delegate.matches.len().saturating_sub(1)); + this.delegate.selected_profile_name = this.delegate.set_selected_profile(cx); + }) + .ok(); + }) + } + + fn confirm( + &mut self, + _: bool, + _: &mut Window, + cx: &mut Context>, + ) { + self.selection_completed = true; + self.selector + .update(cx, |_, cx| { + cx.emit(DismissEvent); + }) + .ok(); + } + + fn dismissed( + &mut self, + _: &mut Window, + cx: &mut Context>, + ) { + if !self.selection_completed { + SettingsProfileSelectorDelegate::update_active_profile_name_global( + self.original_profile_name.clone(), + cx, + ); + } + self.selector.update(cx, |_, cx| cx.emit(DismissEvent)).ok(); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + _: &mut Context>, + ) -> Option { + let mat = &self.matches[ix]; + let profile_name = &self.profile_names[mat.candidate_id]; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .child(HighlightedLabel::new( + display_name(profile_name), + mat.positions.clone(), + )), + ) + } +} + +fn display_name(profile_name: &Option) -> String { + profile_name.clone().unwrap_or("Disabled".into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use client; + use editor; + use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; + use language; + use menu::{Cancel, Confirm, SelectNext, SelectPrevious}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::Settings; + use theme::{self, ThemeSettings}; + use workspace::{self, AppState}; + use zed_actions::settings_profile_selector; + + async fn init_test( + profiles_json: serde_json::Value, + cx: &mut TestAppContext, + ) -> (Entity, &mut VisualTestContext) { + cx.update(|cx| { + let state = AppState::test(cx); + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + settings::init(cx); + theme::init(theme::LoadThemes::JustBase, cx); + ThemeSettings::register(cx); + client::init_settings(cx); + language::init(cx); + super::init(cx); + editor::init(cx); + workspace::init_settings(cx); + Project::init_settings(cx); + state + }); + + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + let settings_json = json!({ + "buffer_font_size": 10.0, + "profiles": profiles_json, + }); + + store + .set_user_settings(&settings_json.to_string(), cx) + .unwrap(); + }); + }); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, ["/test".as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + cx.update(|_, cx| { + assert!(!cx.has_global::()); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + (workspace, cx) + } + + #[track_caller] + fn active_settings_profile_picker( + workspace: &Entity, + cx: &mut VisualTestContext, + ) -> Entity> { + workspace.update(cx, |workspace, cx| { + workspace + .active_modal::(cx) + .expect("settings profile selector is not open") + .read(cx) + .picker + .clone() + }) + } + + #[gpui::test] + async fn test_settings_profile_selector_state(cx: &mut TestAppContext) { + let classroom_and_streaming_profile_name = "Classroom / Streaming".to_string(); + let demo_videos_profile_name = "Demo Videos".to_string(); + + let profiles_json = json!({ + classroom_and_streaming_profile_name.clone(): { + "buffer_font_size": 20.0, + }, + demo_videos_profile_name.clone(): { + "buffer_font_size": 15.0 + } + }); + let (workspace, cx) = init_test(profiles_json.clone(), cx).await; + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.matches.len(), 3); + assert_eq!(picker.delegate.matches[0].string, display_name(&None)); + assert_eq!( + picker.delegate.matches[1].string, + classroom_and_streaming_profile_name + ); + assert_eq!(picker.delegate.matches[2].string, demo_videos_profile_name); + assert_eq!(picker.delegate.matches.get(3), None); + + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!(cx.try_global::(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::(), None); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectNext); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(Cancel); + + cx.update(|_, cx| { + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(settings_profile_selector::Toggle); + let picker = active_settings_profile_picker(&workspace, cx); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 2); + assert_eq!( + picker.delegate.selected_profile_name, + Some(demo_videos_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(demo_videos_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 15.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 1); + assert_eq!( + picker.delegate.selected_profile_name, + Some(classroom_and_streaming_profile_name.clone()) + ); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + Some(classroom_and_streaming_profile_name) + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 20.0); + }); + + cx.dispatch_action(SelectPrevious); + + picker.read_with(cx, |picker, cx| { + assert_eq!(picker.delegate.selected_index, 0); + assert_eq!(picker.delegate.selected_profile_name, None); + + assert_eq!( + cx.try_global::() + .map(|p| p.0.clone()), + None + ); + + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + + cx.dispatch_action(Confirm); + + cx.update(|_, cx| { + assert_eq!(cx.try_global::(), None); + assert_eq!(ThemeSettings::get_global(cx).buffer_font_size(cx).0, 10.0); + }); + } +} diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index e8434c1a32..a4c47081c6 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -30,7 +30,6 @@ menu.workspace = true notifications.workspace = true paths.workspace = true project.workspace = true -schemars.workspace = true search.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 2f0abb4789..3022cc7142 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -1,20 +1,12 @@ mod appearance_settings_controls; use std::any::TypeId; -use std::sync::Arc; use command_palette_hooks::CommandPaletteFilter; use editor::EditorSettingsControls; use feature_flags::{FeatureFlag, FeatureFlagViewExt}; -use fs::Fs; -use gpui::{ - Action, App, AsyncWindowContext, Entity, EventEmitter, FocusHandle, Focusable, Task, actions, -}; -use schemars::JsonSchema; -use serde::Deserialize; -use settings::{SettingsStore, VsCodeSettingsSource}; +use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, actions}; use ui::prelude::*; -use util::truncate_and_remove_front; use workspace::item::{Item, ItemEvent}; use workspace::{Workspace, with_active_or_new_workspace}; @@ -29,23 +21,6 @@ impl FeatureFlag for SettingsUiFeatureFlag { const NAME: &'static str = "settings-ui"; } -/// Imports settings from Visual Studio Code. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportVsCodeSettings { - #[serde(default)] - pub skip_prompt: bool, -} - -/// Imports settings from Cursor editor. -#[derive(Copy, Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)] -#[action(namespace = zed)] -#[serde(deny_unknown_fields)] -pub struct ImportCursorSettings { - #[serde(default)] - pub skip_prompt: bool, -} actions!( zed, [ @@ -72,45 +47,11 @@ pub fn init(cx: &mut App) { }); }); - cx.observe_new(|workspace: &mut Workspace, window, cx| { + cx.observe_new(|_workspace: &mut Workspace, window, cx| { let Some(window) = window else { return; }; - workspace.register_action(|_workspace, action: &ImportVsCodeSettings, window, cx| { - let fs = ::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::VsCode, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - - workspace.register_action(|_workspace, action: &ImportCursorSettings, window, cx| { - let fs = ::global(cx); - let action = *action; - - window - .spawn(cx, async move |cx: &mut AsyncWindowContext| { - handle_import_vscode_settings( - VsCodeSettingsSource::Cursor, - action.skip_prompt, - fs, - cx, - ) - .await - }) - .detach(); - }); - let settings_ui_actions = [TypeId::of::()]; CommandPaletteFilter::update_global(cx, |filter, _cx| { @@ -138,57 +79,6 @@ pub fn init(cx: &mut App) { keybindings::init(cx); } -async fn handle_import_vscode_settings( - source: VsCodeSettingsSource, - skip_prompt: bool, - fs: Arc, - cx: &mut AsyncWindowContext, -) { - let vscode_settings = - match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await { - Ok(vscode_settings) => vscode_settings, - Err(err) => { - log::error!("{err}"); - let _ = cx.prompt( - gpui::PromptLevel::Info, - &format!("Could not find or load a {source} settings file"), - None, - &["Ok"], - ); - return; - } - }; - - let prompt = if skip_prompt { - Task::ready(Some(0)) - } else { - let prompt = cx.prompt( - gpui::PromptLevel::Warning, - &format!( - "Importing {} settings may overwrite your existing settings. \ - Will import settings from {}", - vscode_settings.source, - truncate_and_remove_front(&vscode_settings.path.to_string_lossy(), 128), - ), - None, - &["Ok", "Cancel"], - ); - cx.spawn(async move |_| prompt.await.ok()) - }; - if prompt.await != Some(0) { - return; - } - - cx.update(|_, cx| { - let source = vscode_settings.source; - let path = vscode_settings.path.clone(); - cx.global::() - .import_vscode_settings(fs, vscode_settings); - log::info!("Imported {source} settings from {}", path.display()); - }) - .ok(); -} - pub struct SettingsPage { focus_handle: FocusHandle, } diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs index 5807d3aae0..bf17336f9d 100644 --- a/crates/text/src/anchor.rs +++ b/crates/text/src/anchor.rs @@ -99,7 +99,9 @@ impl Anchor { } else if self.buffer_id != Some(buffer.remote_id) { false } else { - let fragment_id = buffer.fragment_id_for_anchor(self); + let Some(fragment_id) = buffer.try_fragment_id_for_anchor(self) else { + return false; + }; let mut fragment_cursor = buffer.fragments.cursor::<(Option<&Locator>, usize)>(&None); fragment_cursor.seek(&Some(fragment_id), Bias::Left); fragment_cursor diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index c1da0649da..aded03d46a 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -2330,10 +2330,19 @@ impl BufferSnapshot { } fn fragment_id_for_anchor(&self, anchor: &Anchor) -> &Locator { + self.try_fragment_id_for_anchor(anchor).unwrap_or_else(|| { + panic!( + "invalid anchor {:?}. buffer id: {}, version: {:?}", + anchor, self.remote_id, self.version, + ) + }) + } + + fn try_fragment_id_for_anchor(&self, anchor: &Anchor) -> Option<&Locator> { if *anchor == Anchor::MIN { - Locator::min_ref() + Some(Locator::min_ref()) } else if *anchor == Anchor::MAX { - Locator::max_ref() + Some(Locator::max_ref()) } else { let anchor_key = InsertionFragmentKey { timestamp: anchor.timestamp, @@ -2354,20 +2363,12 @@ impl BufferSnapshot { insertion_cursor.prev(); } - let Some(insertion) = insertion_cursor.item().filter(|insertion| { - if cfg!(debug_assertions) { - insertion.timestamp == anchor.timestamp - } else { - true - } - }) else { - panic!( - "invalid anchor {:?}. buffer id: {}, version: {:?}", - anchor, self.remote_id, self.version - ); - }; - - &insertion.fragment_id + insertion_cursor + .item() + .filter(|insertion| { + !cfg!(debug_assertions) || insertion.timestamp == anchor.timestamp + }) + .map(|insertion| &insertion.fragment_id) } } diff --git a/crates/theme/src/settings.rs b/crates/theme/src/settings.rs index 1c4c90a475..20c837f287 100644 --- a/crates/theme/src/settings.rs +++ b/crates/theme/src/settings.rs @@ -438,7 +438,7 @@ fn default_font_fallbacks() -> Option { impl ThemeSettingsContent { /// Sets the theme for the given appearance to the theme with the specified name. - pub fn set_theme(&mut self, theme_name: String, appearance: Appearance) { + pub fn set_theme(&mut self, theme_name: impl Into>, appearance: Appearance) { if let Some(selection) = self.theme.as_mut() { let theme_to_update = match selection { ThemeSelection::Static(theme) => theme, @@ -867,6 +867,7 @@ impl settings::Settings for ThemeSettings { .user .into_iter() .chain(sources.release_channel) + .chain(sources.profile) .chain(sources.server) { if let Some(value) = value.ui_density { diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index 8e95c6f79f..cf178e2850 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -32,6 +32,7 @@ auto_update.workspace = true call.workspace = true chrono.workspace = true client.workspace = true +cloud_llm_client.workspace = true db.workspace = true gpui = { workspace = true, features = ["screen-capture"] } notifications.workspace = true diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 17c4c85b6d..426d87ad13 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -21,6 +21,7 @@ use crate::application_menu::{ use auto_update::AutoUpdateStatus; use call::ActiveCall; use client::{Client, UserStore, zed_urls}; +use cloud_llm_client::Plan; use gpui::{ Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, IntoElement, MouseButton, ParentElement, Render, StatefulInteractiveElement, Styled, @@ -28,7 +29,6 @@ use gpui::{ }; use onboarding_banner::OnboardingBanner; use project::Project; -use rpc::proto; use settings::Settings as _; use settings_ui::keybindings; use std::sync::Arc; @@ -179,24 +179,23 @@ impl Render for TitleBar { children.push(self.banner.clone().into_any_element()) } + let status = self.client.status(); + let status = &*status.borrow(); + let user = self.user_store.read(cx).current_user(); + children.push( h_flex() .gap_1() .pr_1() .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .children(self.render_call_controls(window, cx)) - .map(|el| { - let status = self.client.status(); - let status = &*status.borrow(); - if matches!(status, client::Status::Connected { .. }) { - el.child(self.render_user_menu_button(cx)) - } else { - el.children(self.render_connection_status(status, cx)) - .when(TitleBarSettings::get_global(cx).show_sign_in, |el| { - el.child(self.render_sign_in_button(cx)) - }) - .child(self.render_user_menu_button(cx)) - } + .children(self.render_connection_status(status, cx)) + .when( + user.is_none() && TitleBarSettings::get_global(cx).show_sign_in, + |el| el.child(self.render_sign_in_button(cx)), + ) + .when(user.is_some(), |parent| { + parent.child(self.render_user_menu_button(cx)) }) .into_any_element(), ); @@ -618,9 +617,8 @@ impl TitleBar { window .spawn(cx, async move |cx| { client - .authenticate_and_connect(true, &cx) + .sign_in_with_optional_connect(true, &cx) .await - .into_response() .notify_async_err(cx); }) .detach(); @@ -630,8 +628,8 @@ impl TitleBar { pub fn render_user_menu_button(&mut self, cx: &mut Context) -> impl Element { let user_store = self.user_store.read(cx); if let Some(user) = user_store.current_user() { - let has_subscription_period = self.user_store.read(cx).subscription_period().is_some(); - let plan = self.user_store.read(cx).current_plan().filter(|_| { + let has_subscription_period = user_store.subscription_period().is_some(); + let plan = user_store.plan().filter(|_| { // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. has_subscription_period }); @@ -658,13 +656,9 @@ impl TitleBar { let user_login = user.github_login.clone(); let (plan_name, label_color, bg_color) = match plan { - None | Some(proto::Plan::Free) => { - ("Free", Color::Default, free_chip_bg) - } - Some(proto::Plan::ZedProTrial) => { - ("Pro Trial", Color::Accent, pro_chip_bg) - } - Some(proto::Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), + None | Some(Plan::ZedFree) => ("Free", Color::Default, free_chip_bg), + Some(Plan::ZedProTrial) => ("Pro Trial", Color::Accent, pro_chip_bg), + Some(Plan::ZedPro) => ("Pro", Color::Accent, pro_chip_bg), }; menu.custom_entry( diff --git a/crates/ui/src/components.rs b/crates/ui/src/components.rs index 9c2961c55f..486673e733 100644 --- a/crates/ui/src/components.rs +++ b/crates/ui/src/components.rs @@ -1,4 +1,5 @@ mod avatar; +mod badge; mod banner; mod button; mod callout; @@ -41,6 +42,7 @@ mod tooltip; mod stories; pub use avatar::*; +pub use badge::*; pub use banner::*; pub use button::*; pub use callout::*; diff --git a/crates/ui/src/components/badge.rs b/crates/ui/src/components/badge.rs new file mode 100644 index 0000000000..2eee084bbb --- /dev/null +++ b/crates/ui/src/components/badge.rs @@ -0,0 +1,66 @@ +use crate::Divider; +use crate::DividerColor; +use crate::component_prelude::*; +use crate::prelude::*; +use gpui::{AnyElement, IntoElement, SharedString, Window}; + +#[derive(IntoElement, RegisterComponent)] +pub struct Badge { + label: SharedString, + icon: IconName, +} + +impl Badge { + pub fn new(label: impl Into) -> Self { + Self { + label: label.into(), + icon: IconName::Check, + } + } + + pub fn icon(mut self, icon: IconName) -> Self { + self.icon = icon; + self + } +} + +impl RenderOnce for Badge { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + h_flex() + .h_full() + .gap_1() + .pl_1() + .pr_2() + .border_1() + .border_color(cx.theme().colors().border.opacity(0.6)) + .bg(cx.theme().colors().element_background) + .rounded_sm() + .overflow_hidden() + .child( + Icon::new(self.icon) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child(Label::new(self.label.clone()).size(LabelSize::Small).ml_1()) + } +} + +impl Component for Badge { + fn scope() -> ComponentScope { + ComponentScope::DataDisplay + } + + fn description() -> Option<&'static str> { + Some( + "A compact, labeled component with optional icon for displaying status, categories, or metadata.", + ) + } + + fn preview(_window: &mut Window, _cx: &mut App) -> Option { + Some( + single_example("Basic Badge", Badge::new("Default").into_any_element()) + .into_any_element(), + ) + } +} diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 135ecdfe62..03f7964f35 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -358,6 +358,7 @@ impl ButtonStyle { #[derive(Default, PartialEq, Clone, Copy)] pub enum ButtonSize { Large, + Medium, #[default] Default, Compact, @@ -368,6 +369,7 @@ impl ButtonSize { pub fn rems(self) -> Rems { match self { ButtonSize::Large => rems_from_px(32.), + ButtonSize::Medium => rems_from_px(28.), ButtonSize::Default => rems_from_px(22.), ButtonSize::Compact => rems_from_px(18.), ButtonSize::None => rems_from_px(16.), @@ -573,7 +575,7 @@ impl RenderOnce for ButtonLike { }) .gap(DynamicSpacing::Base04.rems(cx)) .map(|this| match self.size { - ButtonSize::Large => this.px(DynamicSpacing::Base06.rems(cx)), + ButtonSize::Large | ButtonSize::Medium => this.px(DynamicSpacing::Base06.rems(cx)), ButtonSize::Default | ButtonSize::Compact => { this.px(DynamicSpacing::Base04.rems(cx)) } diff --git a/crates/ui/src/components/button/toggle_button.rs b/crates/ui/src/components/button/toggle_button.rs index 30683e60f3..a1e4d65a24 100644 --- a/crates/ui/src/components/button/toggle_button.rs +++ b/crates/ui/src/components/button/toggle_button.rs @@ -295,6 +295,7 @@ pub struct ButtonConfiguration { label: SharedString, icon: Option, on_click: Box, + selected: bool, } mod private { @@ -308,6 +309,7 @@ pub trait ButtonBuilder: 'static + private::ToggleButtonStyle { pub struct ToggleButtonSimple { label: SharedString, on_click: Box, + selected: bool, } impl ToggleButtonSimple { @@ -318,8 +320,14 @@ impl ToggleButtonSimple { Self { label: label.into(), on_click: Box::new(on_click), + selected: false, } } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } } impl private::ToggleButtonStyle for ToggleButtonSimple {} @@ -330,6 +338,7 @@ impl ButtonBuilder for ToggleButtonSimple { label: self.label, icon: None, on_click: self.on_click, + selected: self.selected, } } } @@ -338,6 +347,7 @@ pub struct ToggleButtonWithIcon { label: SharedString, icon: IconName, on_click: Box, + selected: bool, } impl ToggleButtonWithIcon { @@ -350,8 +360,14 @@ impl ToggleButtonWithIcon { label: label.into(), icon, on_click: Box::new(on_click), + selected: false, } } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } } impl private::ToggleButtonStyle for ToggleButtonWithIcon {} @@ -362,6 +378,7 @@ impl ButtonBuilder for ToggleButtonWithIcon { label: self.label, icon: Some(self.icon), on_click: self.on_click, + selected: self.selected, } } } @@ -373,6 +390,12 @@ pub enum ToggleButtonGroupStyle { Outlined, } +#[derive(Clone, Copy, PartialEq)] +pub enum ToggleButtonGroupSize { + Default, + Medium, +} + #[derive(IntoElement)] pub struct ToggleButtonGroup where @@ -381,6 +404,7 @@ where group_name: &'static str, rows: [[T; COLS]; ROWS], style: ToggleButtonGroupStyle, + size: ToggleButtonGroupSize, button_width: Rems, selected_index: usize, } @@ -391,6 +415,7 @@ impl ToggleButtonGroup { group_name, rows: [buttons], style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, button_width: rems_from_px(100.), selected_index: 0, } @@ -403,6 +428,7 @@ impl ToggleButtonGroup { group_name, rows: [first_row, second_row], style: ToggleButtonGroupStyle::Transparent, + size: ToggleButtonGroupSize::Default, button_width: rems_from_px(100.), selected_index: 0, } @@ -415,6 +441,11 @@ impl ToggleButtonGroup Self { + self.size = size; + self + } + pub fn button_width(mut self, button_width: Rems) -> Self { self.button_width = button_width; self @@ -430,47 +461,56 @@ impl RenderOnce for ToggleButtonGroup { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let entries = self.rows.into_iter().enumerate().map(|(row_index, row)| { - row.into_iter().enumerate().map(move |(index, button)| { - let ButtonConfiguration { - label, - icon, - on_click, - } = button.into_configuration(); + let entries = + self.rows.into_iter().enumerate().map(|(row_index, row)| { + row.into_iter().enumerate().map(move |(col_index, button)| { + let ButtonConfiguration { + label, + icon, + on_click, + selected, + } = button.into_configuration(); - ButtonLike::new((self.group_name, row_index * COLS + index)) - .when(index == self.selected_index, |this| { - this.toggle_state(true) - .selected_style(ButtonStyle::Tinted(TintColor::Accent)) - }) - .rounding(None) - .when(self.style == ToggleButtonGroupStyle::Filled, |button| { - button.style(ButtonStyle::Filled) - }) - .child( - h_flex() - .min_w(self.button_width) - .gap_1p5() - .justify_center() - .when_some(icon, |this, icon| { - this.child(Icon::new(icon).size(IconSize::XSmall).map(|this| { - if index == self.selected_index { - this.color(Color::Accent) - } else { - this.color(Color::Muted) - } - })) - }) - .child( - Label::new(label).when(index == self.selected_index, |this| { - this.color(Color::Accent) - }), - ), - ) - .on_click(on_click) - .into_any_element() - }) - }); + let entry_index = row_index * COLS + col_index; + + ButtonLike::new((self.group_name, entry_index)) + .when(entry_index == self.selected_index || selected, |this| { + this.toggle_state(true) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + }) + .rounding(None) + .when(self.style == ToggleButtonGroupStyle::Filled, |button| { + button.style(ButtonStyle::Filled) + }) + .when(self.size == ToggleButtonGroupSize::Medium, |button| { + button.size(ButtonSize::Medium) + }) + .child( + h_flex() + .min_w(self.button_width) + .gap_1p5() + .px_3() + .py_1() + .justify_center() + .when_some(icon, |this, icon| { + this.py_2() + .child(Icon::new(icon).size(IconSize::XSmall).map(|this| { + if entry_index == self.selected_index || selected { + this.color(Color::Accent) + } else { + this.color(Color::Muted) + } + })) + }) + .child(Label::new(label).size(LabelSize::Small).when( + entry_index == self.selected_index || selected, + |this| this.color(Color::Accent), + )), + ) + .on_click(on_click) + .into_any_element() + }) + }); let border_color = cx.theme().colors().border.opacity(0.6); let is_outlined_or_filled = self.style == ToggleButtonGroupStyle::Outlined diff --git a/crates/ui/src/components/dropdown_menu.rs b/crates/ui/src/components/dropdown_menu.rs index 189fac930f..cdb98086ca 100644 --- a/crates/ui/src/components/dropdown_menu.rs +++ b/crates/ui/src/components/dropdown_menu.rs @@ -8,6 +8,7 @@ use super::PopoverMenuHandle; pub enum DropdownStyle { #[default] Solid, + Outlined, Ghost, } @@ -147,6 +148,23 @@ impl Component for DropdownMenu { ), ], ), + example_group_with_title( + "Styles", + vec![ + single_example( + "Outlined", + DropdownMenu::new("outlined", "Outlined Dropdown", menu.clone()) + .style(DropdownStyle::Outlined) + .into_any_element(), + ), + single_example( + "Ghost", + DropdownMenu::new("ghost", "Ghost Dropdown", menu.clone()) + .style(DropdownStyle::Ghost) + .into_any_element(), + ), + ], + ), example_group_with_title( "States", vec![single_example( @@ -170,10 +188,13 @@ pub struct DropdownTriggerStyle { impl DropdownTriggerStyle { pub fn for_style(style: DropdownStyle, cx: &App) -> Self { let colors = cx.theme().colors(); + let bg = match style { DropdownStyle::Solid => colors.editor_background, + DropdownStyle::Outlined => colors.surface_background, DropdownStyle::Ghost => colors.ghost_element_background, }; + Self { bg } } } @@ -244,17 +265,24 @@ impl RenderOnce for DropdownMenuTrigger { let disabled = self.disabled; let style = DropdownTriggerStyle::for_style(self.style, cx); + let is_outlined = matches!(self.style, DropdownStyle::Outlined); h_flex() .id("dropdown-menu-trigger") - .justify_between() - .rounded_sm() - .bg(style.bg) + .min_w_20() .pl_2() .pr_1p5() .py_0p5() .gap_2() - .min_w_20() + .justify_between() + .rounded_sm() + .bg(style.bg) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .when(is_outlined, |this| { + this.border_1() + .border_color(cx.theme().colors().border) + .overflow_hidden() + }) .map(|el| { if self.full_width { el.w_full() diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs index 2145b34ef2..a70f5e1ea5 100644 --- a/crates/ui/src/components/modal.rs +++ b/crates/ui/src/components/modal.rs @@ -1,5 +1,5 @@ use crate::{ - Clickable, Color, DynamicSpacing, Headline, HeadlineSize, IconButton, IconButtonShape, + Clickable, Color, DynamicSpacing, Headline, HeadlineSize, Icon, IconButton, IconButtonShape, IconName, Label, LabelCommon, LabelSize, h_flex, v_flex, }; use gpui::{prelude::FluentBuilder, *}; @@ -92,6 +92,7 @@ impl RenderOnce for Modal { #[derive(IntoElement)] pub struct ModalHeader { + icon: Option, headline: Option, description: Option, children: SmallVec<[AnyElement; 2]>, @@ -108,6 +109,7 @@ impl Default for ModalHeader { impl ModalHeader { pub fn new() -> Self { Self { + icon: None, headline: None, description: None, children: SmallVec::new(), @@ -116,6 +118,11 @@ impl ModalHeader { } } + pub fn icon(mut self, icon: Icon) -> Self { + self.icon = Some(icon); + self + } + /// Set the headline of the modal. /// /// This will insert the headline as the first item @@ -179,12 +186,17 @@ impl RenderOnce for ModalHeader { ) }) .child( - v_flex().flex_1().children(children).when_some( - self.description, - |this, description| { + v_flex() + .flex_1() + .child( + h_flex() + .gap_1() + .when_some(self.icon, |this, icon| this.child(icon)) + .children(children), + ) + .when_some(self.description, |this, description| { this.child(Label::new(description).color(Color::Muted).mb_2()) - }, - ), + }), ) .when(self.show_dismiss_button, |this| { this.child( diff --git a/crates/ui/src/components/numeric_stepper.rs b/crates/ui/src/components/numeric_stepper.rs index 05d368f427..5a84633d1b 100644 --- a/crates/ui/src/components/numeric_stepper.rs +++ b/crates/ui/src/components/numeric_stepper.rs @@ -2,10 +2,18 @@ use gpui::ClickEvent; use crate::{IconButtonShape, prelude::*}; +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NumericStepperStyle { + Outlined, + #[default] + Ghost, +} + #[derive(IntoElement, RegisterComponent)] pub struct NumericStepper { id: ElementId, value: SharedString, + style: NumericStepperStyle, on_decrement: Box, on_increment: Box, /// Whether to reserve space for the reset button. @@ -23,6 +31,7 @@ impl NumericStepper { Self { id: id.into(), value: value.into(), + style: NumericStepperStyle::default(), on_decrement: Box::new(on_decrement), on_increment: Box::new(on_increment), reserve_space_for_reset: false, @@ -30,6 +39,11 @@ impl NumericStepper { } } + pub fn style(mut self, style: NumericStepperStyle) -> Self { + self.style = style; + self + } + pub fn reserve_space_for_reset(mut self, reserve_space_for_reset: bool) -> Self { self.reserve_space_for_reset = reserve_space_for_reset; self @@ -49,6 +63,8 @@ impl RenderOnce for NumericStepper { let shape = IconButtonShape::Square; let icon_size = IconSize::Small; + let is_outlined = matches!(self.style, NumericStepperStyle::Outlined); + h_flex() .id(self.id) .gap_1() @@ -74,22 +90,65 @@ impl RenderOnce for NumericStepper { .child( h_flex() .gap_1() - .px_1() - .rounded_xs() - .bg(cx.theme().colors().editor_background) - .child( - IconButton::new("decrement", IconName::Dash) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_decrement), - ) - .child(Label::new(self.value)) - .child( - IconButton::new("increment", IconName::Plus) - .shape(shape) - .icon_size(icon_size) - .on_click(self.on_increment), - ), + .rounded_sm() + .map(|this| { + if is_outlined { + this.overflow_hidden() + .bg(cx.theme().colors().surface_background) + .border_1() + .border_color(cx.theme().colors().border) + } else { + this.px_1().bg(cx.theme().colors().editor_background) + } + }) + .map(|decrement| { + if is_outlined { + decrement.child( + h_flex() + .id("decrement_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_r_1() + .border_color(cx.theme().colors().border) + .child(Icon::new(IconName::Dash).size(IconSize::Small)) + .on_click(self.on_decrement), + ) + } else { + decrement.child( + IconButton::new("decrement", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .on_click(self.on_decrement), + ) + } + }) + .when(is_outlined, |this| this) + .child(Label::new(self.value).mx_3()) + .map(|increment| { + if is_outlined { + increment.child( + h_flex() + .id("increment_button") + .p_1p5() + .size_full() + .justify_center() + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .border_l_1() + .border_color(cx.theme().colors().border) + .child(Icon::new(IconName::Plus).size(IconSize::Small)) + .on_click(self.on_increment), + ) + } else { + increment.child( + IconButton::new("increment", IconName::Dash) + .shape(shape) + .icon_size(icon_size) + .on_click(self.on_increment), + ) + } + }), ) } } @@ -100,7 +159,7 @@ impl Component for NumericStepper { } fn name() -> &'static str { - "NumericStepper" + "Numeric Stepper" } fn sort_name() -> &'static str { @@ -108,18 +167,39 @@ impl Component for NumericStepper { } fn description() -> Option<&'static str> { - Some("A button used to increment or decrement a numeric value. ") + Some("A button used to increment or decrement a numeric value.") } fn preview(_window: &mut Window, _cx: &mut App) -> Option { Some( - div() - .child(NumericStepper::new( - "numeric-stepper-component-preview", - "10", - move |_, _, _| {}, - move |_, _, _| {}, - )) + v_flex() + .gap_6() + .children(vec![example_group_with_title( + "Styles", + vec![ + single_example( + "Default", + NumericStepper::new( + "numeric-stepper-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .into_any_element(), + ), + single_example( + "Outlined", + NumericStepper::new( + "numeric-stepper-with-border-component-preview", + "10", + move |_, _, _| {}, + move |_, _, _| {}, + ) + .style(NumericStepperStyle::Outlined) + .into_any_element(), + ), + ], + )]) .into_any_element(), ) } diff --git a/crates/ui/src/components/toggle.rs b/crates/ui/src/components/toggle.rs index cf2a56b1c9..0d8f5c4107 100644 --- a/crates/ui/src/components/toggle.rs +++ b/crates/ui/src/components/toggle.rs @@ -566,7 +566,7 @@ impl RenderOnce for Switch { pub struct SwitchField { id: ElementId, label: SharedString, - description: SharedString, + description: Option, toggle_state: ToggleState, on_click: Arc, disabled: bool, @@ -577,14 +577,14 @@ impl SwitchField { pub fn new( id: impl Into, label: impl Into, - description: impl Into, + description: Option, toggle_state: impl Into, on_click: impl Fn(&ToggleState, &mut Window, &mut App) + 'static, ) -> Self { Self { id: id.into(), label: label.into(), - description: description.into(), + description: description, toggle_state: toggle_state.into(), on_click: Arc::new(on_click), disabled: false, @@ -592,6 +592,11 @@ impl SwitchField { } } + pub fn description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + pub fn disabled(mut self, disabled: bool) -> Self { self.disabled = disabled; self @@ -609,17 +614,22 @@ impl RenderOnce for SwitchField { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { h_flex() .id(SharedString::from(format!("{}-container", self.id))) + .when(!self.disabled, |this| { + this.hover(|this| this.cursor_pointer()) + }) .w_full() .gap_4() .justify_between() .flex_wrap() - .child( - v_flex() + .child(match &self.description { + Some(description) => v_flex() .gap_0p5() .max_w_5_6() - .child(Label::new(self.label)) - .child(Label::new(self.description).color(Color::Muted)), - ) + .child(Label::new(self.label.clone())) + .child(Label::new(description.clone()).color(Color::Muted)) + .into_any_element(), + None => Label::new(self.label.clone()).into_any_element(), + }) .child( Switch::new( SharedString::from(format!("{}-switch", self.id)), @@ -668,7 +678,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_unselected", "Enable notifications", - "Receive notifications when new messages arrive.", + Some("Receive notifications when new messages arrive.".into()), ToggleState::Unselected, |_, _, _| {}, ) @@ -679,7 +689,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_selected", "Enable notifications", - "Receive notifications when new messages arrive.", + Some("Receive notifications when new messages arrive.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -695,7 +705,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_default", "Default color", - "This uses the default switch color.", + Some("This uses the default switch color.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -706,7 +716,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_accent", "Accent color", - "This uses the accent color scheme.", + Some("This uses the accent color scheme.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -722,7 +732,7 @@ impl Component for SwitchField { SwitchField::new( "switch_field_disabled", "Disabled field", - "This field is disabled and cannot be toggled.", + Some("This field is disabled and cannot be toggled.".into()), ToggleState::Selected, |_, _, _| {}, ) @@ -730,6 +740,20 @@ impl Component for SwitchField { .into_any_element(), )], ), + example_group_with_title( + "No Description", + vec![single_example( + "No Description", + SwitchField::new( + "switch_field_disabled", + "Disabled field", + None, + ToggleState::Selected, + |_, _, _| {}, + ) + .into_any_element(), + )], + ), ]) .into_any_element(), ) diff --git a/crates/ui_prompt/src/ui_prompt.rs b/crates/ui_prompt/src/ui_prompt.rs index 2b6a030f26..fe6dc5b3f4 100644 --- a/crates/ui_prompt/src/ui_prompt.rs +++ b/crates/ui_prompt/src/ui_prompt.rs @@ -43,7 +43,7 @@ fn zed_prompt_renderer( let renderer = cx.new({ |cx| ZedPromptRenderer { _level: level, - message: message.to_string(), + message: cx.new(|cx| Markdown::new(SharedString::new(message), None, None, cx)), actions: actions.iter().map(|a| a.label().to_string()).collect(), focus: cx.focus_handle(), active_action_id: 0, @@ -58,7 +58,7 @@ fn zed_prompt_renderer( pub struct ZedPromptRenderer { _level: PromptLevel, - message: String, + message: Entity, actions: Vec, focus: FocusHandle, active_action_id: usize, @@ -114,7 +114,7 @@ impl ZedPromptRenderer { impl Render for ZedPromptRenderer { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); - let font_family = settings.ui_font.family.clone(); + let font_size = settings.ui_font_size(cx).into(); let prompt = v_flex() .key_context("Prompt") .cursor_default() @@ -130,24 +130,38 @@ impl Render for ZedPromptRenderer { .overflow_hidden() .p_4() .gap_4() - .font_family(font_family) + .font_family(settings.ui_font.family.clone()) .child( div() .w_full() - .font_weight(FontWeight::BOLD) - .child(self.message.clone()) - .text_color(ui::Color::Default.color(cx)), + .child(MarkdownElement::new(self.message.clone(), { + let mut base_text_style = window.text_style(); + base_text_style.refine(&TextStyleRefinement { + font_family: Some(settings.ui_font.family.clone()), + font_size: Some(font_size), + font_weight: Some(FontWeight::BOLD), + color: Some(ui::Color::Default.color(cx)), + ..Default::default() + }); + MarkdownStyle { + base_text_style, + selection_background_color: cx + .theme() + .colors() + .element_selection_background, + ..Default::default() + } + })), ) .children(self.detail.clone().map(|detail| { div() .w_full() .text_xs() .child(MarkdownElement::new(detail, { - let settings = ThemeSettings::get_global(cx); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(settings.ui_font.family.clone()), - font_size: Some(settings.ui_font_size(cx).into()), + font_size: Some(font_size), color: Some(ui::Color::Muted.color(cx)), ..Default::default() }); @@ -176,24 +190,28 @@ impl Render for ZedPromptRenderer { }), )); - div().size_full().occlude().child( - div() - .size_full() - .absolute() - .top_0() - .left_0() - .flex() - .flex_col() - .justify_around() - .child( - div() - .w_full() - .flex() - .flex_row() - .justify_around() - .child(prompt), - ), - ) + div() + .size_full() + .occlude() + .bg(gpui::black().opacity(0.2)) + .child( + div() + .size_full() + .absolute() + .top_0() + .left_0() + .flex() + .flex_col() + .justify_around() + .child( + div() + .w_full() + .flex() + .flex_row() + .justify_around() + .child(prompt), + ), + ) } } diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index a50b238cc5..c22cf0ef00 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -987,7 +987,7 @@ impl Motion { SelectionGoal::None, ), NextWordEnd { ignore_punctuation } => ( - next_word_end(map, point, *ignore_punctuation, times, true), + next_word_end(map, point, *ignore_punctuation, times, true, true), SelectionGoal::None, ), PreviousWordStart { ignore_punctuation } => ( @@ -1723,14 +1723,19 @@ pub(crate) fn next_word_end( ignore_punctuation: bool, times: usize, allow_cross_newline: bool, + always_advance: bool, ) -> DisplayPoint { let classifier = map .buffer_snapshot .char_classifier_at(point.to_point(map)) .ignore_punctuation(ignore_punctuation); for _ in 0..times { - let new_point = next_char(map, point, allow_cross_newline); let mut need_next_char = false; + let new_point = if always_advance { + next_char(map, point, allow_cross_newline) + } else { + point + }; let new_point = movement::find_boundary_exclusive( map, new_point, diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 9485f17477..135cdd687f 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -51,6 +51,7 @@ impl Vim { ignore_punctuation, &text_layout_details, motion == Motion::NextSubwordStart { ignore_punctuation }, + !matches!(motion, Motion::NextWordStart { .. }), ) } _ => { @@ -148,6 +149,7 @@ fn expand_changed_word_selection( ignore_punctuation: bool, text_layout_details: &TextLayoutDetails, use_subword: bool, + always_advance: bool, ) -> Option { let is_in_word = || { let classifier = map @@ -173,8 +175,14 @@ fn expand_changed_word_selection( selection.end = motion::next_subword_end(map, selection.end, ignore_punctuation, 1, false); } else { - selection.end = - motion::next_word_end(map, selection.end, ignore_punctuation, 1, false); + selection.end = motion::next_word_end( + map, + selection.end, + ignore_punctuation, + 1, + false, + always_advance, + ); } selection.end = motion::next_char(map, selection.end, false); } @@ -271,6 +279,10 @@ mod test { cx.simulate("c shift-w", "Test teˇst-test test") .await .assert_matches(); + + // on last character of word, `cw` doesn't eat subsequent punctuation + // see https://github.com/zed-industries/zed/issues/35269 + cx.simulate("c w", "tesˇt-test").await.assert_matches(); } #[gpui::test] diff --git a/crates/vim/test_data/test_change_w.json b/crates/vim/test_data/test_change_w.json index 27be543532..149dac8420 100644 --- a/crates/vim/test_data/test_change_w.json +++ b/crates/vim/test_data/test_change_w.json @@ -30,3 +30,7 @@ {"Key":"c"} {"Key":"shift-w"} {"Get":{"state":"Test teˇ test","mode":"Insert"}} +{"Put":{"state":"tesˇt-test"}} +{"Key":"c"} +{"Key":"w"} +{"Get":{"state":"tesˇ-test","mode":"Insert"}} diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index 769dd8d6aa..acb3fe0f84 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -29,7 +29,6 @@ project.workspace = true serde.workspace = true settings.workspace = true telemetry.workspace = true -theme.workspace = true ui.workspace = true util.workspace = true vim_mode_setting.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index 49bf2031ab..352118eee8 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -21,7 +21,6 @@ pub use multibuffer_hint::*; mod base_keymap_picker; mod multibuffer_hint; -mod welcome_ui; actions!( welcome, diff --git a/crates/welcome/src/welcome_ui.rs b/crates/welcome/src/welcome_ui.rs deleted file mode 100644 index 622b6f448d..0000000000 --- a/crates/welcome/src/welcome_ui.rs +++ /dev/null @@ -1 +0,0 @@ -mod theme_preview; diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 3f8b098203..6fa5c969e7 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -939,6 +939,26 @@ impl WorkspaceDb { } } + query! { + pub async fn update_ssh_project_paths_query(ssh_project_id: u64, paths: String) -> Result> { + UPDATE ssh_projects + SET paths = ?2 + WHERE id = ?1 + RETURNING id, host, port, paths, user + } + } + + pub(crate) async fn update_ssh_project_paths( + &self, + ssh_project_id: SshProjectId, + new_paths: Vec, + ) -> Result { + let paths = serde_json::to_string(&new_paths)?; + self.update_ssh_project_paths_query(ssh_project_id.0, paths) + .await? + .context("failed to update ssh project paths") + } + query! { pub async fn next_id() -> Result { INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id @@ -2624,4 +2644,56 @@ mod tests { assert_eq!(workspace.center_group, new_workspace.center_group); } + + #[gpui::test] + async fn test_update_ssh_project_paths() { + zlog::init_test(); + + let db = WorkspaceDb::open_test_db("test_update_ssh_project_paths").await; + + let (host, port, initial_paths, user) = ( + "example.com".to_string(), + Some(22_u16), + vec!["/home/user".to_string(), "/etc/nginx".to_string()], + Some("user".to_string()), + ); + + let project = db + .get_or_create_ssh_project(host.clone(), port, initial_paths.clone(), user.clone()) + .await + .unwrap(); + + assert_eq!(project.host, host); + assert_eq!(project.paths, initial_paths); + assert_eq!(project.user, user); + + let new_paths = vec![ + "/home/user".to_string(), + "/etc/nginx".to_string(), + "/var/log".to_string(), + "/opt/app".to_string(), + ]; + + let updated_project = db + .update_ssh_project_paths(project.id, new_paths.clone()) + .await + .unwrap(); + + assert_eq!(updated_project.id, project.id); + assert_eq!(updated_project.paths, new_paths); + + let retrieved_project = db + .get_ssh_project( + host.clone(), + port, + serde_json::to_string(&new_paths).unwrap(), + user.clone(), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(retrieved_project.id, project.id); + assert_eq!(retrieved_project.paths, new_paths); + } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index e58014e7b8..6f7db668dd 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1090,7 +1090,8 @@ pub struct Workspace { _subscriptions: Vec, _apply_leader_updates: Task>, _observe_current_user: Task>, - _schedule_serialize: Option>, + _schedule_serialize_workspace: Option>, + _schedule_serialize_ssh_paths: Option>, pane_history_timestamp: Arc, bounds: Bounds, pub centered_layout: bool, @@ -1149,6 +1150,8 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); + this.update_ssh_paths(cx); + this.serialize_ssh_paths(window, cx); this.serialize_workspace(window, cx); // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. this.update_history(cx); @@ -1416,7 +1419,8 @@ impl Workspace { app_state, _observe_current_user, _apply_leader_updates, - _schedule_serialize: None, + _schedule_serialize_workspace: None, + _schedule_serialize_ssh_paths: None, leader_updates_tx, _subscriptions: subscriptions, pane_history_timestamp, @@ -5073,6 +5077,46 @@ impl Workspace { } } + fn update_ssh_paths(&mut self, cx: &App) { + let project = self.project().read(cx); + if !project.is_local() { + let paths: Vec = project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string()) + .collect(); + if let Some(ssh_project) = &mut self.serialized_ssh_project { + ssh_project.paths = paths; + } + } + } + + fn serialize_ssh_paths(&mut self, window: &mut Window, cx: &mut Context) { + if self._schedule_serialize_ssh_paths.is_none() { + self._schedule_serialize_ssh_paths = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + let task = if let Some(ssh_project) = &this.serialized_ssh_project { + let ssh_project_id = ssh_project.id; + let ssh_project_paths = ssh_project.paths.clone(); + window.spawn(cx, async move |_| { + persistence::DB + .update_ssh_project_paths(ssh_project_id, ssh_project_paths) + .await + }) + } else { + Task::ready(Err(anyhow::anyhow!("No SSH project to serialize"))) + }; + task.detach(); + this._schedule_serialize_ssh_paths.take(); + }) + .log_err(); + })); + } + } + fn remove_panes(&mut self, member: Member, window: &mut Window, cx: &mut Context) { match member { Member::Axis(PaneAxis { members, .. }) => { @@ -5116,17 +5160,18 @@ impl Workspace { } fn serialize_workspace(&mut self, window: &mut Window, cx: &mut Context) { - if self._schedule_serialize.is_none() { - self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; - this.update_in(cx, |this, window, cx| { - this.serialize_workspace_internal(window, cx).detach(); - this._schedule_serialize.take(); - }) - .log_err(); - })); + if self._schedule_serialize_workspace.is_none() { + self._schedule_serialize_workspace = + Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(SERIALIZATION_THROTTLE_TIME) + .await; + this.update_in(cx, |this, window, cx| { + this.serialize_workspace_internal(window, cx).detach(); + this._schedule_serialize_workspace.take(); + }) + .log_err(); + })); } } @@ -5689,7 +5734,6 @@ impl Workspace { let client = project.read(cx).client(); let user_store = project.read(cx).user_store(); - let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); window.activate_window(); @@ -6894,10 +6938,13 @@ async fn join_channel_internal( match status { Status::Connecting | Status::Authenticating + | Status::Authenticated | Status::Reconnecting | Status::Reauthenticating => continue, Status::Connected { .. } => break 'outer, - Status::SignedOut => return Err(ErrorCode::SignedOut.into()), + Status::SignedOut | Status::AuthenticationError => { + return Err(ErrorCode::SignedOut.into()); + } Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { return Err(ErrorCode::Disconnected.into()); diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index a864ece683..536af7b7b9 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.198.0" +version = "0.199.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] @@ -106,6 +106,7 @@ outline_panel.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true +settings_profile_selector.workspace = true profiling.workspace = true project.workspace = true project_panel.workspace = true diff --git a/crates/zed/resources/windows/zed.iss b/crates/zed/resources/windows/zed.iss index 51c1dd096e..2e76f35a0b 100644 --- a/crates/zed/resources/windows/zed.iss +++ b/crates/zed/resources/windows/zed.iss @@ -62,6 +62,7 @@ Source: "{#ResourcesDir}\Zed.exe"; DestDir: "{code:GetInstallDir}"; Flags: ignor Source: "{#ResourcesDir}\bin\*"; DestDir: "{code:GetInstallDir}\bin"; Flags: ignoreversion Source: "{#ResourcesDir}\tools\*"; DestDir: "{app}\tools"; Flags: ignoreversion Source: "{#ResourcesDir}\appx\*"; DestDir: "{app}\appx"; BeforeInstall: RemoveAppxPackage; AfterInstall: AddAppxPackage; Flags: ignoreversion; Check: IsWindows11OrLater +Source: "{#ResourcesDir}\amd_ags_x64.dll"; DestDir: "{app}"; Flags: ignoreversion [Icons] Name: "{group}\{#AppName}"; Filename: "{app}\{#AppExeName}.exe"; AppUserModelID: "{#AppUserId}" diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index d0b9c53397..c264135e5c 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -42,7 +42,7 @@ use theme::{ ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ThemeSettings, }; -use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; +use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ @@ -613,6 +613,7 @@ pub fn main() { language_selector::init(cx); toolchain_selector::init(cx); theme_selector::init(cx); + settings_profile_selector::init(cx); language_tools::init(cx); call::init(app_state.client.clone(), app_state.user_store.clone(), cx); notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx); @@ -681,17 +682,9 @@ pub fn main() { cx.spawn({ let client = app_state.client.clone(); - async move |cx| match authenticate(client, &cx).await { - ConnectionResult::Timeout => log::error!("Timeout during initial auth"), - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during initial auth") - } - ConnectionResult::Result(r) => { - r.log_err(); - } - } + async move |cx| authenticate(client, &cx).await }) - .detach(); + .detach_and_log_err(cx); let urls: Vec<_> = args .paths_or_urls @@ -841,15 +834,7 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut let client = app_state.client.clone(); // we continue even if authentication fails as join_channel/ open channel notes will // show a visible error message. - match authenticate(client, &cx).await { - ConnectionResult::Timeout => { - log::error!("Timeout during open request handling") - } - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during open request handling") - } - ConnectionResult::Result(r) => r?, - }; + authenticate(client, &cx).await.log_err(); if let Some(channel_id) = request.join_channel { cx.update(|cx| { @@ -899,18 +884,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut } } -async fn authenticate(client: Arc, cx: &AsyncApp) -> ConnectionResult<()> { +async fn authenticate(client: Arc, cx: &AsyncApp) -> Result<()> { if stdout_is_a_pty() { if client::IMPERSONATE_LOGIN.is_some() { - return client.authenticate_and_connect(false, cx).await; + client.sign_in_with_optional_connect(false, cx).await?; } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } - ConnectionResult::Result(Ok(())) + Ok(()) } async fn system_id() -> Result { diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index c72fe39d2d..af317edeee 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -4354,6 +4354,7 @@ mod tests { "menu", "notebook", "notification_panel", + "onboarding", "outline", "outline_panel", "pane", @@ -4366,6 +4367,7 @@ mod tests { "repl", "rules_library", "search", + "settings_profile_selector", "snippets", "supermaven", "svg", diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index 78532b10b4..15d5659f03 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -24,6 +24,10 @@ pub fn app_menus() -> Vec { zed_actions::OpenDefaultKeymap, ), MenuItem::action("Open Project Settings", super::OpenProjectSettings), + MenuItem::action( + "Select Settings Profile...", + zed_actions::settings_profile_selector::Toggle, + ), MenuItem::action( "Select Theme...", zed_actions::theme_selector::Toggle::default(), diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index 2e57152c62..480505338b 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -139,8 +139,7 @@ impl ComponentPreview { let project_clone = project.clone(); cx.spawn_in(window, async move |entity, cx| { - let thread_store_future = - load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx); + let thread_store_future = load_preview_thread_store(project_clone.clone(), cx); let text_thread_store_future = load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs index 825744572d..de98106fae 100644 --- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs +++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs @@ -12,21 +12,19 @@ use ui::{App, Window}; use workspace::Workspace; pub fn load_preview_thread_store( - workspace: WeakEntity, project: Entity, cx: &mut AsyncApp, ) -> Task>> { - workspace - .update(cx, |_, cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) + cx.update(|cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) } pub fn load_preview_text_thread_store( diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index f2e9d21b96..55dbea4fe1 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -1,10 +1,10 @@ -use client::{Client, UserStore}; +use client::{Client, DisableAiSettings, UserStore}; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; -use settings::SettingsStore; +use settings::{Settings as _, SettingsStore}; use smol::stream::StreamExt; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; @@ -90,10 +90,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); + let tos_accepted = user_store.read(cx).has_accepted_terms_of_service(); telemetry::event!( "Edit Prediction Provider Changed", @@ -195,16 +192,18 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context| { - editor.accept_partial_inline_completion(&Default::default(), window, cx); - }, - )) - .detach(); + if !DisableAiSettings::get_global(cx).disable_ai { + editor + .register_action(cx.listener( + |editor, + _: &editor::actions::AcceptPartialCopilotSuggestion, + window: &mut Window, + cx: &mut Context| { + editor.accept_partial_inline_completion(&Default::default(), window, cx); + }, + )) + .detach(); + } } fn assign_edit_prediction_provider( @@ -242,7 +241,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if client.status().borrow().is_connected() { + if user_store.read(cx).current_user().is_some() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 4b4bf016c4..64891b6973 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -260,14 +260,25 @@ pub mod icon_theme_selector { } } +pub mod settings_profile_selector { + use gpui::Action; + use schemars::JsonSchema; + use serde::Deserialize; + + #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] + #[action(namespace = settings_profile_selector)] + pub struct Toggle; +} + pub mod agent { use gpui::actions; actions!( agent, [ - /// Opens the agent configuration panel. - OpenConfiguration, + /// Opens the agent settings panel. + #[action(deprecated_aliases = ["agent::OpenConfiguration"])] + OpenSettings, /// Opens the agent onboarding modal. OpenOnboardingModal, /// Resets the agent onboarding state. diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 294d95aefd..26eeda3f22 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -40,7 +40,6 @@ log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true -proto.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true @@ -59,9 +58,11 @@ worktree.workspace = true zed_actions.workspace = true [dev-dependencies] -collections = { workspace = true, features = ["test-support"] } +call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } +cloud_api_types.workspace = true +collections = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } @@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } -call = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index d5c6be278b..f130c3a965 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -121,9 +121,10 @@ impl Dismissable for ZedPredictUpsell { } pub fn should_show_upsell_modal(user_store: &Entity, cx: &App) -> bool { - match user_store.read(cx).current_user_has_accepted_terms() { - Some(true) => !ZedPredictUpsell::dismissed(), - Some(false) | None => true, + if user_store.read(cx).has_accepted_terms_of_service() { + !ZedPredictUpsell::dismissed() + } else { + true } } @@ -145,14 +146,14 @@ pub struct InlineCompletion { input_events: Arc, input_excerpt: Arc, output_excerpt: Arc, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, response_received_at: Instant, } impl InlineCompletion { fn latency(&self) -> Duration { self.response_received_at - .duration_since(self.request_sent_at) + .duration_since(self.buffer_snapshotted_at) } fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { @@ -226,12 +227,9 @@ pub struct Zeta { data_collection_choice: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - /// Whether the terms of service have been accepted. - tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, user_store: Entity, - _user_store_subscription: Subscription, license_detection_watchers: HashMap>, } @@ -306,22 +304,7 @@ impl Zeta { .detach_and_log_err(cx); }, ), - tos_accepted: user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false), update_required: false, - _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { - match event { - client::user::Event::PrivateUserInfoUpdated => { - this.tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); - } - _ => {} - } - }), license_detection_watchers: HashMap::default(), user_store, } @@ -408,104 +391,48 @@ impl Zeta { + Send + 'static, { + let buffer = buffer.clone(); + let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, cx); - let diagnostic_groups = snapshot.diagnostic_groups(None); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let events = self.events.clone(); - let path: Arc = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let zeta = cx.entity(); + let events = self.events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let buffer = buffer.clone(); - - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); - let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store { - Some( - diagnostic_groups - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - - Some(( - language_server.name(), - diagnostic_group.resolve::(&snapshot), - )) - }) - .collect::>(), - ) - } else { - None - }; + let full_path: Arc = snapshot + .file() + .map(|f| Arc::from(f.full_path(cx).as_path())) + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let full_path_str = full_path.to_string_lossy().to_string(); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); + let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); + let gather_task = gather_context( + project, + full_path_str, + &snapshot, + cursor_point, + make_events_prompt, + can_collect_data, + cx, + ); cx.spawn(async move |this, cx| { - let request_sent_at = Instant::now(); - - struct BackgroundValues { - input_events: String, - input_excerpt: String, - speculated_output: String, - editable_range: Range, - input_outline: String, - } - - let values = cx - .background_spawn({ - let snapshot = snapshot.clone(); - let path = path.clone(); - async move { - let path = path.to_string_lossy(); - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &path, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS); - let input_outline = prompt_for_outline(&snapshot); - - anyhow::Ok(BackgroundValues { - input_events, - input_excerpt: input_excerpt.prompt, - speculated_output: input_excerpt.speculated_output, - editable_range: input_excerpt.editable_range.to_offset(&snapshot), - input_outline, - }) - } - }) - .await?; + let GatherContextOutput { + body, + editable_range, + } = gather_task.await?; log::debug!( "Events:\n{}\nExcerpt:\n{:?}", - values.input_events, - values.input_excerpt + body.input_events, + body.input_excerpt ); - let body = PredictEditsBody { - input_events: values.input_events.clone(), - input_excerpt: values.input_excerpt.clone(), - speculated_output: Some(values.speculated_output), - outline: Some(values.input_outline.clone()), - can_collect_data, - diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| { - diagnostic_groups - .into_iter() - .map(|(name, diagnostic_group)| { - Ok((name.to_string(), serde_json::to_value(diagnostic_group)?)) - }) - .collect::>>() - .log_err() - }), - }; + let input_outline = body.outline.clone().unwrap_or_default(); + let input_events = body.input_events.clone(); + let input_excerpt = body.input_excerpt.clone(); let response = perform_predict_edits(PerformPredictEditsParams { client, @@ -563,13 +490,13 @@ impl Zeta { response, buffer, &snapshot, - values.editable_range, + editable_range, cursor_offset, - path, - values.input_outline, - values.input_events, - values.input_excerpt, - request_sent_at, + full_path, + input_outline, + input_events, + input_excerpt, + buffer_snapshotted_at, &cx, ) .await @@ -768,7 +695,7 @@ and then another ) } - fn perform_predict_edits( + pub fn perform_predict_edits( params: PerformPredictEditsParams, ) -> impl Future)>> { async move { @@ -923,7 +850,7 @@ and then another input_outline: String, input_events: String, input_excerpt: String, - request_sent_at: Instant, + buffer_snapshotted_at: Instant, cx: &AsyncApp, ) -> Task>> { let snapshot = snapshot.clone(); @@ -969,7 +896,7 @@ and then another input_events: input_events.into(), input_excerpt: input_excerpt.into(), output_excerpt, - request_sent_at, + buffer_snapshotted_at, response_received_at: Instant::now(), })) }) @@ -1153,7 +1080,7 @@ and then another } } -struct PerformPredictEditsParams { +pub struct PerformPredictEditsParams { pub client: Arc, pub llm_token: LlmApiToken, pub app_version: SemanticVersion, @@ -1228,6 +1155,77 @@ fn common_prefix, T2: Iterator>(a: T1, b: .sum() } +pub struct GatherContextOutput { + pub body: PredictEditsBody, + pub editable_range: Range, +} + +pub fn gather_context( + project: Option<&Entity>, + full_path_str: String, + snapshot: &BufferSnapshot, + cursor_point: language::Point, + make_events_prompt: impl FnOnce() -> String + Send + 'static, + can_collect_data: bool, + cx: &App, +) -> Task> { + let local_lsp_store = + project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let diagnostic_groups: Vec<(String, serde_json::Value)> = + if let Some(local_lsp_store) = local_lsp_store { + snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(language_server_id, diagnostic_group)| { + let language_server = + local_lsp_store.running_language_server_for_id(language_server_id)?; + let diagnostic_group = diagnostic_group.resolve::(&snapshot); + let language_server_name = language_server.name().to_string(); + let serialized = serde_json::to_value(diagnostic_group).unwrap(); + Some((language_server_name, serialized)) + }) + .collect::>() + } else { + Vec::new() + }; + + cx.background_spawn({ + let snapshot = snapshot.clone(); + async move { + let diagnostic_groups = if diagnostic_groups.is_empty() { + None + } else { + Some(diagnostic_groups) + }; + + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &full_path_str, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let input_events = make_events_prompt(); + let input_outline = prompt_for_outline(&snapshot); + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + + let body = PredictEditsBody { + input_events, + input_excerpt: input_excerpt.prompt, + speculated_output: Some(input_excerpt.speculated_output), + outline: Some(input_outline), + can_collect_data, + diagnostic_groups, + }; + + Ok(GatherContextOutput { + body, + editable_range, + }) + } + }) +} + fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { let mut input_outline = String::new(); @@ -1278,7 +1276,7 @@ struct RegisteredBuffer { } #[derive(Clone)] -enum Event { +pub enum Event { BufferChange { old_snapshot: BufferSnapshot, new_snapshot: BufferSnapshot, @@ -1573,7 +1571,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self.zeta.read(cx).tos_accepted + !self + .zeta + .read(cx) + .user_store + .read(cx) + .has_accepted_terms_of_service() } fn is_refreshing(&self) -> bool { @@ -1588,7 +1591,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context, ) { - if !self.zeta.read(cx).tos_accepted { + if self.needs_terms_acceptance(cx) { return; } @@ -1600,7 +1603,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider .zeta .read(cx) .user_store - .read_with(cx, |user_store, _| { + .read_with(cx, |user_store, _cx| { user_store.account_too_young() || user_store.has_overdue_invoices() }) { @@ -1817,13 +1820,14 @@ fn tokens_for_bytes(bytes: usize) -> usize { #[cfg(test)] mod tests { + use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; + use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use rpc::proto; use settings::SettingsStore; use super::*; @@ -1856,7 +1860,7 @@ mod tests { input_events: "".into(), input_excerpt: "".into(), output_excerpt: "".into(), - request_sent_at: Instant::now(), + buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), }; @@ -2027,28 +2031,45 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |_| async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()) + let http_client = FakeHttpClient::create(move |req| async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2056,13 +2077,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::().await.unwrap(); - let token_request = server.receive::().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2079,20 +2093,36 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |_| { + let http_client = FakeHttpClient::create(move |req| { let completion = completion_response.clone(); async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()) + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } }); @@ -2100,9 +2130,10 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2111,13 +2142,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::().await.unwrap(); - let token_request = server.receive::().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); completion .edits diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml new file mode 100644 index 0000000000..e77351c219 --- /dev/null +++ b/crates/zeta_cli/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "zeta_cli" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[[bin]] +name = "zeta" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +clap.workspace = true +client.workspace = true +debug_adapter_extension.workspace = true +extension.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +language.workspace = true +language_extension.workspace = true +language_model.workspace = true +language_models.workspace = true +languages = { workspace = true, features = ["load-grammars"] } +node_runtime.workspace = true +paths.workspace = true +project.workspace = true +prompt_store.workspace = true +release_channel.workspace = true +reqwest_client.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +shellexpand.workspace = true +terminal_view.workspace = true +util.workspace = true +watch.workspace = true +workspace-hack.workspace = true +zeta.workspace = true +smol.workspace = true diff --git a/crates/zeta_cli/LICENSE-GPL b/crates/zeta_cli/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/zeta_cli/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta_cli/build.rs b/crates/zeta_cli/build.rs new file mode 100644 index 0000000000..ccbb54c5b4 --- /dev/null +++ b/crates/zeta_cli/build.rs @@ -0,0 +1,14 @@ +fn main() { + let cargo_toml = + std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml"); + let version = cargo_toml + .lines() + .find(|line| line.starts_with("version = ")) + .expect("Version not found in crates/zed/Cargo.toml") + .split('=') + .nth(1) + .expect("Invalid version format") + .trim() + .trim_matches('"'); + println!("cargo:rustc-env=ZED_PKG_VERSION={}", version); +} diff --git a/crates/zeta_cli/src/headless.rs b/crates/zeta_cli/src/headless.rs new file mode 100644 index 0000000000..959bb91a8f --- /dev/null +++ b/crates/zeta_cli/src/headless.rs @@ -0,0 +1,128 @@ +use client::{Client, ProxySettings, UserStore}; +use extension::ExtensionHostProxy; +use fs::RealFs; +use gpui::http_client::read_proxy_from_env; +use gpui::{App, AppContext, Entity}; +use gpui_tokio::Tokio; +use language::LanguageRegistry; +use language_extension::LspAccess; +use node_runtime::{NodeBinaryOptions, NodeRuntime}; +use project::Project; +use project::project_settings::ProjectSettings; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use settings::{Settings, SettingsStore}; +use std::path::PathBuf; +use std::sync::Arc; +use util::ResultExt as _; + +/// Headless subset of `workspace::AppState`. +pub struct ZetaCliAppState { + pub languages: Arc, + pub client: Arc, + pub user_store: Entity, + pub fs: Arc, + pub node_runtime: NodeRuntime, +} + +// TODO: dedupe with crates/eval/src/eval.rs +pub fn init(cx: &mut App) -> ZetaCliAppState { + let app_version = AppVersion::load(env!("ZED_PKG_VERSION")); + release_channel::init(app_version, cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + + // Set User-Agent so we can download language servers from GitHub + let user_agent = format!( + "Zed/{} ({}; {})", + app_version, + std::env::consts::OS, + std::env::consts::ARCH + ); + let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); + let proxy_url = proxy_str + .as_ref() + .and_then(|input| input.parse().ok()) + .or_else(read_proxy_from_env); + let http = { + let _guard = Tokio::handle(cx).enter(); + + ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) + .expect("could not start HTTP client") + }; + cx.set_http_client(Arc::new(http)); + + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new( + git_binary_path, + cx.background_executor().clone(), + )); + + let mut languages = LanguageRegistry::new(cx.background_executor().clone()); + languages.set_language_server_download_dir(paths::languages_dir().clone()); + let languages = Arc::new(languages); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + extension::init(cx); + + let (mut tx, rx) = watch::channel(None); + cx.observe_global::(move |cx| { + let settings = &ProjectSettings::get_global(cx).node; + let options = NodeBinaryOptions { + allow_path_lookup: !settings.ignore_system_version, + allow_binary_download: true, + use_paths: settings.path.as_ref().map(|node_path| { + let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); + let npm_path = settings + .npm_path + .as_ref() + .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); + ( + node_path.clone(), + npm_path.unwrap_or_else(|| { + let base_path = PathBuf::new(); + node_path.parent().unwrap_or(&base_path).join("npm") + }), + ) + }), + }; + tx.send(Some(options)).log_err(); + }) + .detach(); + let node_runtime = NodeRuntime::new(client.http_client(), None, rx); + + let extension_host_proxy = ExtensionHostProxy::global(cx); + + language::init(cx); + debug_adapter_extension::init(extension_host_proxy.clone(), cx); + language_extension::init( + LspAccess::Noop, + extension_host_proxy.clone(), + languages.clone(), + ); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + languages::init(languages.clone(), node_runtime.clone(), cx); + prompt_store::init(cx); + terminal_view::init(cx); + + ZetaCliAppState { + languages, + client, + user_store, + fs, + node_runtime, + } +} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs new file mode 100644 index 0000000000..c5374b56c9 --- /dev/null +++ b/crates/zeta_cli/src/main.rs @@ -0,0 +1,376 @@ +mod headless; + +use anyhow::{Result, anyhow}; +use clap::{Args, Parser, Subcommand}; +use futures::channel::mpsc; +use futures::{FutureExt as _, StreamExt as _}; +use gpui::{AppContext, Application, AsyncApp}; +use gpui::{Entity, Task}; +use language::Bias; +use language::Buffer; +use language::Point; +use language_model::LlmApiToken; +use project::{Project, ProjectPath}; +use release_channel::AppVersion; +use reqwest_client::ReqwestClient; +use std::path::{Path, PathBuf}; +use std::process::exit; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; + +use crate::headless::ZetaCliAppState; + +#[derive(Parser, Debug)] +#[command(name = "zeta")] +struct ZetaCliArgs { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +enum Commands { + Context(ContextArgs), + Predict { + #[arg(long)] + predict_edits_body: Option, + #[clap(flatten)] + context_args: Option, + }, +} + +#[derive(Debug, Args)] +#[group(requires = "worktree")] +struct ContextArgs { + #[arg(long)] + worktree: PathBuf, + #[arg(long)] + cursor: CursorPosition, + #[arg(long)] + use_language_server: bool, + #[arg(long)] + events: Option, +} + +#[derive(Debug, Clone)] +enum FileOrStdin { + File(PathBuf), + Stdin, +} + +impl FileOrStdin { + async fn read_to_string(&self) -> Result { + match self { + FileOrStdin::File(path) => smol::fs::read_to_string(path).await, + FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await, + } + } +} + +impl FromStr for FileOrStdin { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + match s { + "-" => Ok(Self::Stdin), + _ => Ok(Self::File(PathBuf::from_str(s)?)), + } + } +} + +#[derive(Debug, Clone)] +struct CursorPosition { + path: PathBuf, + point: Point, +} + +impl FromStr for CursorPosition { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != 3 { + return Err(anyhow!( + "Invalid cursor format. Expected 'file.rs:line:column', got '{}'", + s + )); + } + + let path = PathBuf::from(parts[0]); + let line: u32 = parts[1] + .parse() + .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?; + let column: u32 = parts[2] + .parse() + .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?; + + // Convert from 1-based to 0-based indexing + let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); + + Ok(CursorPosition { path, point }) + } +} + +async fn get_context( + args: ContextArgs, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result { + let ContextArgs { + worktree: worktree_path, + cursor, + use_language_server, + events, + } = args; + + let worktree_path = worktree_path.canonicalize()?; + if cursor.path.is_absolute() { + return Err(anyhow!("Absolute paths are not supported in --cursor")); + } + + let (project, _lsp_open_handle, buffer) = if use_language_server { + let (project, lsp_open_handle, buffer) = + open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?; + (Some(project), Some(lsp_open_handle), buffer) + } else { + let abs_path = worktree_path.join(&cursor.path); + let content = smol::fs::read_to_string(&abs_path).await?; + let buffer = cx.new(|cx| Buffer::local(content, cx))?; + (None, None, buffer) + }; + + let worktree_name = worktree_path + .file_name() + .ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?; + let full_path_str = PathBuf::from(worktree_name) + .join(&cursor.path) + .to_string_lossy() + .to_string(); + + let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?; + let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left); + if clipped_cursor != cursor.point { + let max_row = snapshot.max_point().row; + if cursor.point.row < max_row { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (line length is {})", + cursor.point, + snapshot.line_len(cursor.point.row) + )); + } else { + return Err(anyhow!( + "Cursor position {:?} is out of bounds (max row is {})", + cursor.point, + max_row + )); + } + } + + let events = match events { + Some(events) => events.read_to_string().await?, + None => String::new(), + }; + let can_collect_data = false; + cx.update(|cx| { + gather_context( + project.as_ref(), + full_path_str, + &snapshot, + clipped_cursor, + move || events, + can_collect_data, + cx, + ) + })? + .await +} + +pub async fn open_buffer_with_language_server( + worktree_path: &Path, + path: &Path, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result<(Entity, Entity>, Entity)> { + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(worktree_path, true, cx) + })? + .await?; + + let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { + worktree_id: worktree.id(), + path: path.to_path_buf().into(), + })?; + + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + + let lsp_open_handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + })?; + + let log_prefix = path.to_string_lossy().to_string(); + wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; + + Ok((project, lsp_open_handle, buffer)) +} + +// TODO: Dedupe with similar function in crates/eval/src/instance.rs +pub fn wait_for_lang_server( + project: &Entity, + buffer: &Entity, + log_prefix: String, + cx: &mut AsyncApp, +) -> Task> { + println!("{}⏵ Waiting for language server", log_prefix); + + let (mut tx, mut rx) = mpsc::channel(1); + + let lsp_store = project + .read_with(cx, |project, _| project.lsp_store()) + .unwrap(); + + let has_lang_server = buffer + .update(cx, |buffer, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .language_servers_for_local_buffer(&buffer, cx) + .next() + .is_some() + }) + }) + .unwrap_or(false); + + if has_lang_server { + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .unwrap() + .detach(); + } + + let subscriptions = [ + cx.subscribe(&lsp_store, { + let log_prefix = log_prefix.clone(); + move |_, event, _| match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + client::proto::LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{}⟲ {message}", log_prefix), + _ => {} + } + }), + cx.subscribe(&project, { + let buffer = buffer.clone(); + move |project, event, cx| match event { + project::Event::LanguageServerAdded(_, _, _) => { + let buffer = buffer.clone(); + project + .update(cx, |project, cx| project.save_buffer(buffer, cx)) + .detach(); + } + project::Event::DiskBasedDiagnosticsFinished { .. } => { + tx.try_send(()).ok(); + } + _ => {} + } + }), + ]; + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! { + _ = rx.next() => { + println!("{}⚑ Language server idle", log_prefix); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + anyhow::bail!("LSP wait timed out after 5 minutes"); + } + }; + drop(subscriptions); + result + }) +} + +fn main() { + let args = ZetaCliArgs::parse(); + let http_client = Arc::new(ReqwestClient::new()); + let app = Application::headless().with_http_client(http_client); + + app.run(move |cx| { + let app_state = Arc::new(headless::init(cx)); + cx.spawn(async move |cx| { + let result = match args.command { + Commands::Context(context_args) => get_context(context_args, &app_state, cx) + .await + .map(|output| serde_json::to_string_pretty(&output.body).unwrap()), + Commands::Predict { + predict_edits_body, + context_args, + } => { + cx.spawn(async move |cx| { + let app_version = cx.update(|cx| AppVersion::global(cx))?; + app_state.client.sign_in(true, cx).await?; + let llm_token = LlmApiToken::default(); + llm_token.refresh(&app_state.client).await?; + + let predict_edits_body = + if let Some(predict_edits_body) = predict_edits_body { + serde_json::from_str(&predict_edits_body.read_to_string().await?)? + } else if let Some(context_args) = context_args { + get_context(context_args, &app_state, cx).await?.body + } else { + return Err(anyhow!( + "Expected either --predict-edits-body-file \ + or the required args of the `context` command." + )); + }; + + let (response, _usage) = + Zeta::perform_predict_edits(PerformPredictEditsParams { + client: app_state.client.clone(), + llm_token, + app_version, + body: predict_edits_body, + }) + .await?; + + Ok(response.output_excerpt) + }) + .await + } + }; + match result { + Ok(output) => { + println!("{}", output); + let _ = cx.update(|cx| cx.quit()); + } + Err(e) => { + eprintln!("Failed: {:?}", e); + exit(1); + } + } + }) + .detach(); + }); +} diff --git a/docs/src/ai/agent-settings.md b/docs/src/ai/agent-settings.md index 315ae21929..ff97bcb8ee 100644 --- a/docs/src/ai/agent-settings.md +++ b/docs/src/ai/agent-settings.md @@ -108,7 +108,7 @@ Specify a custom temperature for a provider and/or model: ## Agent Panel Settings {#agent-panel-settings} -Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open configuration` action or by the dropdown menu on the top-right corner of the panel. +Note that some of these settings are also surfaced in the Agent Panel's settings UI, which you can access either via the `agent: open settings` action or by the dropdown menu on the top-right corner of the panel. ### Default View diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index cb55c1c94e..bd208e94ac 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -86,7 +86,7 @@ To do this: 1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). 2. Create security credentials for that User, save them and keep them secure. -3. Open the Agent Configuration with (`agent: open configuration`) and go to the Amazon Bedrock section +3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section 4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields. #### Cross-Region Inference @@ -113,7 +113,7 @@ You can use Anthropic models by choosing them via the model dropdown in the Agen 1. Sign up for Anthropic and [create an API key](https://console.anthropic.com/settings/keys) 2. Make sure that your Anthropic account has credits -3. Open the settings view (`agent: open configuration`) and go to the Anthropic section +3. Open the settings view (`agent: open settings`) and go to the Anthropic section 4. Enter your Anthropic API key Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API. @@ -168,7 +168,7 @@ You can configure a model to use [extended thinking](https://docs.anthropic.com/ > ✅ Supports tool use 1. Visit the DeepSeek platform and [create an API key](https://platform.deepseek.com/api_keys) -2. Open the settings view (`agent: open configuration`) and go to the DeepSeek section +2. Open the settings view (`agent: open settings`) and go to the DeepSeek section 3. Enter your DeepSeek API key The DeepSeek API key will be saved in your keychain. @@ -213,7 +213,7 @@ You can also modify the `api_url` to use a custom endpoint if needed. You can use GitHub Copilot Chat with the Zed agent by choosing it via the model dropdown in the Agent Panel. -1. Open the settings view (`agent: open configuration`) and go to the GitHub Copilot Chat section +1. Open the settings view (`agent: open settings`) and go to the GitHub Copilot Chat section 2. Click on `Sign in to use GitHub Copilot`, follow the steps shown in the modal. Alternatively, you can provide an OAuth token via the `GH_COPILOT_TOKEN` environment variable. @@ -229,7 +229,7 @@ To use Copilot Enterprise with Zed (for both agent and inline completions), you You can use Gemini models with the Zed agent by choosing it via the model dropdown in the Agent Panel. 1. Go to the Google AI Studio site and [create an API key](https://aistudio.google.com/app/apikey). -2. Open the settings view (`agent: open configuration`) and go to the Google AI section +2. Open the settings view (`agent: open settings`) and go to the Google AI section 3. Enter your Google AI API key and press enter. The Google AI API key will be saved in your keychain. @@ -288,7 +288,7 @@ Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless# > ✅ Supports tool use 1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/) -2. Open the configuration view (`agent: open configuration`) and navigate to the Mistral section +2. Open the configuration view (`agent: open settings`) and navigate to the Mistral section 3. Enter your Mistral API key The Mistral API key will be saved in your keychain. @@ -399,7 +399,7 @@ If the model is tagged with `vision` in the Ollama catalog, set this option and 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) 2. Make sure that your OpenAI account has credits -3. Open the settings view (`agent: open configuration`) and go to the OpenAI section +3. Open the settings view (`agent: open settings`) and go to the OpenAI section 4. Enter your OpenAI API key The OpenAI API key will be saved in your keychain. @@ -480,7 +480,7 @@ OpenRouter provides access to multiple AI models through a single API. It suppor 1. Visit [OpenRouter](https://openrouter.ai) and create an account 2. Generate an API key from your [OpenRouter keys page](https://openrouter.ai/keys) -3. Open the settings view (`agent: open configuration`) and go to the OpenRouter section +3. Open the settings view (`agent: open settings`) and go to the OpenRouter section 4. Enter your OpenRouter API key The OpenRouter API key will be saved in your keychain. @@ -551,7 +551,7 @@ You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. 1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) -2. Open the settings view (`agent: open configuration`) and go to the **xAI** section +2. Open the settings view (`agent: open settings`) and go to the **xAI** section 3. Enter your xAI API key The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. diff --git a/docs/src/ai/mcp.md b/docs/src/ai/mcp.md index 5aef3d3d72..dfe3e4bdb9 100644 --- a/docs/src/ai/mcp.md +++ b/docs/src/ai/mcp.md @@ -50,7 +50,7 @@ You can connect them by adding their commands directly to your `settings.json`, } ``` -Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open configuration` action). +Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open settings` action). From there, you can add it through the modal that appears when you click the "Add Custom Server" button. ## Using MCP Servers diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index 556bad22b4..5fd27abad6 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -2588,6 +2588,7 @@ List of `integer` column numbers "font_features": null, "font_size": null, "line_height": "comfortable", + "minimum_contrast": 45, "option_as_meta": false, "button": true, "shell": "system", @@ -2883,6 +2884,30 @@ See Buffer Font Features } ``` +### Terminal: Minimum Contrast + +- Description: Controls the minimum contrast between foreground and background colors in the terminal. Uses the APCA (Accessible Perceptual Contrast Algorithm) for color adjustments. Set this to 0 to disable this feature. +- Setting: `minimum_contrast` +- Default: `45` + +**Options** + +`integer` values from 0 to 106. Common recommended values: + +- `0`: No contrast adjustment +- `45`: Minimum for large fluent text (default) +- `60`: Minimum for other content text +- `75`: Minimum for body text +- `90`: Preferred for body text + +```json +{ + "terminal": { + "minimum_contrast": 45 + } +} +``` + ### Terminal: Option As Meta - Description: Re-interprets the option keys to act like a 'meta' key, like in Emacs. diff --git a/script/bundle-windows.ps1 b/script/bundle-windows.ps1 index 2f751f1d10..8ae0212491 100644 --- a/script/bundle-windows.ps1 +++ b/script/bundle-windows.ps1 @@ -136,11 +136,22 @@ function SignZedAndItsFriends { & "$innoDir\sign.ps1" $files } +function DownloadAMDGpuServices { + # If you update the AGS SDK version, please also update the version in `crates/gpui/src/platform/windows/directx_renderer.rs` + $url = "https://codeload.github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/zip/refs/tags/v6.3.0" + $zipPath = ".\AGS_SDK_v6.3.0.zip" + # Download the AGS SDK zip file + Invoke-WebRequest -Uri $url -OutFile $zipPath + # Extract the AGS SDK zip file + Expand-Archive -Path $zipPath -DestinationPath "." -Force +} + function CollectFiles { Move-Item -Path "$innoDir\zed_explorer_command_injector.appx" -Destination "$innoDir\appx\zed_explorer_command_injector.appx" -Force Move-Item -Path "$innoDir\zed_explorer_command_injector.dll" -Destination "$innoDir\appx\zed_explorer_command_injector.dll" -Force Move-Item -Path "$innoDir\cli.exe" -Destination "$innoDir\bin\zed.exe" -Force Move-Item -Path "$innoDir\auto_update_helper.exe" -Destination "$innoDir\tools\auto_update_helper.exe" -Force + Move-Item -Path ".\AGS_SDK-6.3.0\ags_lib\lib\amd_ags_x64.dll" -Destination "$innoDir\amd_ags_x64.dll" -Force } function BuildInstaller { @@ -211,7 +222,6 @@ function BuildInstaller { # Windows runner 2022 default has iscc in PATH, https://github.com/actions/runner-images/blob/main/images/windows/Windows2022-Readme.md # Currently, we are using Windows 2022 runner. # Windows runner 2025 doesn't have iscc in PATH for now, https://github.com/actions/runner-images/issues/11228 - # $innoSetupPath = "iscc.exe" $innoSetupPath = "C:\Program Files (x86)\Inno Setup 6\ISCC.exe" $definitions = @{ @@ -268,6 +278,7 @@ BuildZedAndItsFriends MakeAppx SignZedAndItsFriends ZipZedAndItsFriendsDebug +DownloadAMDGpuServices CollectFiles BuildInstaller diff --git a/script/zed-local b/script/zed-local index 2568931246..99d9308232 100755 --- a/script/zed-local +++ b/script/zed-local @@ -213,7 +213,7 @@ setTimeout(() => { platform === "win32" ? "http://127.0.0.1:8080/rpc" : "http://localhost:8080/rpc", - ZED_ADMIN_API_TOKEN: "secret", + ZED_ADMIN_API_TOKEN: "internal-api-key-secret", ZED_WINDOW_SIZE: size, ZED_CLIENT_CHECKSUM_SEED: "development-checksum-seed", RUST_LOG: process.env.RUST_LOG || "info", diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index e5123d5ab3..4196696f47 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -558,7 +558,6 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } scopeguard = { version = "1" } @@ -572,7 +571,7 @@ windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-pc-windows-msvc.build-dependencies] codespan-reporting = { version = "0.12" } @@ -582,7 +581,6 @@ getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-f getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -naga = { version = "25", features = ["spv-out", "wgsl-in"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event"] } @@ -597,7 +595,7 @@ windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } -windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } +windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } [target.x86_64-unknown-linux-musl.dependencies] aes = { version = "0.8", default-features = false, features = ["zeroize"] } diff --git a/typos.toml b/typos.toml index 7f1c6e04f1..336a829a44 100644 --- a/typos.toml +++ b/typos.toml @@ -71,6 +71,10 @@ extend-ignore-re = [ # Not an actual typo but an intentionally invalid color, in `color_extractor` "#fof", # Stripped version of reserved keyword `type` - "typ" + "typ", + # AMD GPU Services + "ags", + # AMD GPU Services + "AGS" ] check-filename = true