PR to get zed2 into main.

Because we have taken the approach of porting crates by renaming them to
`-2` we will need to manually reapply any changes that were made to
ported crates since the `zed2` branch diverged from master.

I think this is the list of PRs that may need changes ported manually.
Any changes to the following crates may need to be moved from crate `x`
to `x2` for each of the following crates: `audio call client copilot db
feature_flags fs fuzzy gpui install_cli language lsp prettier project
rpc settings storybook terminal theme ui zed`.

- [x] f75eb3f62 Conrad Irwin (origin/main, origin/HEAD, main) Merge
branch 'more-signing' (17 hours ago)
- [x] 832026a0a Julia Limit language server reinstallation attempts
(#3177) (18 hours ago)
- [x] 4539cef6d Julia Capture language server stderr during startup/init
and log if failure (#3175) (21 hours ago)
- [x] e6f2288a0 Conrad Irwin Don't use function_name in vim tests
(#3171) (2 days ago)
- [x] f67f42779 Mikayla Maki Rename IIFE to maybe (#3165) (2 days ago)
- [ ] 90f65ec9f Max Brunsfeld Remove logic for multiple channel parents
(#3162) (2 days ago)
- [ ] 4f859e025 Conrad Irwin link to channel notes (#3167) (2 days ago)
- [ ] b8bd070a8 Conrad Irwin Fix panic by disallowing multiple room
joins (#3149) (3 days ago)
- [ ] cc9e92857 Max Brunsfeld Guest roles (#3140) (3 days ago)
- [x] b090cefdd Kirill Bulatov Rework prettier tests (#3160) (3 days
ago)
- [ ] ff497810d Kyle Caverly move keychain access into semantic index as
opposed to on init (#3158) (3 days ago)
- [x] 2b95db087 Conrad Irwin Fix infinite loop in select all (#3154) (3
days ago)
- [ ] a5836b033 Max Brunsfeld Add chat mentions and a notifications
panel (#3121) (4 days ago)
- [ ] ef1a69156 Kyle Caverly update semantic search to use keychain as
fallback (#3151) (6 days ago)
- [x] 26638748b Kirill Bulatov Move prettier parsers data into languages
from LSP adapters (#3150) (6 days ago)
- [ ] 0dae0f602 Conrad Irwin pixel columns (#3052) (7 days ago)
- [x] cc7df91cc Julia Whoops (#3146) (7 days ago)
- [x] 808976ee2 Julia Magic incantations for Tailwind autocomplete in
more languages (#3141) (7 days ago)
- [ ] cc390ba86 Conrad Irwin Start writing role to database (#3120) (10
days ago)
- [ ] 2795091f0 Kyle Caverly Introduce Context Retrieval in Inline
Assistant (#3097) (10 days ago)
- [x] b168bded1 Conrad Irwin New entitlements: (#3118) (10 days ago)
- [x] 247cdb1e1 Joseph T. Lyons Fix telemetry-related crash on start up
(#3131) (11 days ago)
- [ ] 2323fd17b Julia Autocomplete docs (#3126) (2 weeks ago)
- [x] 16d9d77d8 Kirill Bulatov Update diagnostics indicator when
diagnostics are udpated (#3128) (2 weeks ago)
- [ ] 634202340 Kirill Bulatov Remove zed -> ... -> semantic_index ->
zed Cargo dependency cycle (#3127) (2 weeks ago)

Note: this list does not include any PRs that did not change crates that
have been converted; it also does not include any commits that were
pushed directly to master.

### To figure out what needs migrating, run:

```
git diff COMMIT^..COMMIT -- crates/audio crates/call crates/client crates/copilot crates/db crates/feature_flags crates/fs crates/fuzzy crates/gpui crates/install_cli crates/language crates/lsp crates/prettier crates/project crates/rpc crates/settings crates/storybook crates/terminal crates/theme crates/ui crates/zed
```
This commit is contained in:
Max Brunsfeld 2023-10-31 10:12:17 -07:00 committed by GitHub
commit ed5f1d3bdd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
526 changed files with 124188 additions and 2315 deletions

750
Cargo.lock generated
View file

@ -108,6 +108,33 @@ dependencies = [
"util",
]
[[package]]
name = "ai2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bincode",
"futures 0.3.28",
"gpui2",
"isahc",
"language2",
"lazy_static",
"log",
"matrixmultiply",
"ordered-float 2.10.0",
"parking_lot 0.11.2",
"parse_duration",
"postage",
"rand 0.8.5",
"regex",
"rusqlite",
"serde",
"serde_json",
"tiktoken-rs",
"util",
]
[[package]]
name = "alacritty_config"
version = "0.1.2-dev"
@ -659,6 +686,20 @@ dependencies = [
"util",
]
[[package]]
name = "audio2"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"futures 0.3.28",
"gpui2",
"log",
"parking_lot 0.11.2",
"rodio",
"util",
]
[[package]]
name = "auto_update"
version = "0.1.0"
@ -776,6 +817,17 @@ dependencies = [
"rustc-demangle",
]
[[package]]
name = "backtrace-on-stack-overflow"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fd2d70527f3737a1ad17355e260706c1badebabd1fa06a7a053407380df841b"
dependencies = [
"backtrace",
"libc",
"nix 0.23.2",
]
[[package]]
name = "base64"
version = "0.13.1"
@ -1104,6 +1156,32 @@ dependencies = [
"util",
]
[[package]]
name = "call2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-broadcast",
"audio2",
"client2",
"collections",
"fs2",
"futures 0.3.28",
"gpui2",
"language2",
"live_kit_client",
"log",
"media",
"postage",
"project2",
"schemars",
"serde",
"serde_derive",
"serde_json",
"settings2",
"util",
]
[[package]]
name = "cap-fs-ext"
version = "0.24.4"
@ -1176,6 +1254,25 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6"
[[package]]
name = "cbindgen"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da6bc11b07529f16944307272d5bd9b22530bc7d05751717c9d416586cedab49"
dependencies = [
"clap 3.2.25",
"heck 0.4.1",
"indexmap 1.9.3",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 1.0.109",
"tempfile",
"toml 0.5.11",
]
[[package]]
name = "cc"
version = "1.0.83"
@ -1422,6 +1519,43 @@ dependencies = [
"uuid 1.4.1",
]
[[package]]
name = "client2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-recursion 0.3.2",
"async-tungstenite",
"collections",
"db2",
"feature_flags2",
"futures 0.3.28",
"gpui2",
"image",
"lazy_static",
"log",
"parking_lot 0.11.2",
"postage",
"rand 0.8.5",
"rpc2",
"schemars",
"serde",
"serde_derive",
"settings",
"settings2",
"smol",
"sum_tree",
"sysinfo",
"tempfile",
"text",
"thiserror",
"time",
"tiny_http",
"url",
"util",
"uuid 1.4.1",
]
[[package]]
name = "clock"
version = "0.1.0"
@ -1724,6 +1858,33 @@ dependencies = [
"util",
]
[[package]]
name = "copilot2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-compression",
"async-tar",
"clock",
"collections",
"context_menu",
"fs",
"futures 0.3.28",
"gpui2",
"language2",
"log",
"lsp2",
"node_runtime",
"parking_lot 0.11.2",
"rpc",
"serde",
"serde_derive",
"settings2",
"smol",
"theme",
"util",
]
[[package]]
name = "copilot_button"
version = "0.1.0"
@ -2154,6 +2315,28 @@ dependencies = [
"util",
]
[[package]]
name = "db2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"collections",
"env_logger 0.9.3",
"gpui2",
"indoc",
"lazy_static",
"log",
"parking_lot 0.11.2",
"serde",
"serde_derive",
"smol",
"sqlez",
"sqlez_macros",
"tempdir",
"util",
]
[[package]]
name = "deflate"
version = "0.8.6"
@ -2621,6 +2804,14 @@ dependencies = [
"gpui",
]
[[package]]
name = "feature_flags2"
version = "0.1.0"
dependencies = [
"anyhow",
"gpui2",
]
[[package]]
name = "feedback"
version = "0.1.0"
@ -2866,6 +3057,34 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "fs2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"collections",
"fsevent",
"futures 0.3.28",
"git2",
"gpui2",
"lazy_static",
"libc",
"log",
"parking_lot 0.11.2",
"regex",
"rope",
"serde",
"serde_derive",
"serde_json",
"smol",
"sum_tree",
"tempfile",
"text",
"time",
"util",
]
[[package]]
name = "fsevent"
version = "2.0.2"
@ -3044,6 +3263,14 @@ dependencies = [
"util",
]
[[package]]
name = "fuzzy2"
version = "0.1.0"
dependencies = [
"gpui2",
"util",
]
[[package]]
name = "fxhash"
version = "0.2.1"
@ -3258,20 +3485,64 @@ name = "gpui2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-task",
"backtrace",
"bindgen 0.65.1",
"bitflags 2.4.0",
"block",
"cbindgen",
"cocoa",
"collections",
"core-foundation",
"core-graphics",
"core-text",
"ctor",
"derive_more",
"dhat",
"env_logger 0.9.3",
"etagere",
"font-kit",
"foreign-types",
"futures 0.3.28",
"gpui",
"gpui2_macros",
"gpui_macros",
"image",
"itertools 0.10.5",
"lazy_static",
"log",
"media",
"metal",
"num_cpus",
"objc",
"ordered-float 2.10.0",
"parking",
"parking_lot 0.11.2",
"pathfinder_geometry",
"plane-split",
"png",
"postage",
"rand 0.8.5",
"refineable",
"rust-embed",
"resvg",
"schemars",
"seahash",
"serde",
"settings",
"serde_derive",
"serde_json",
"simplelog",
"slotmap",
"smallvec",
"theme",
"smol",
"sqlez",
"sum_tree",
"taffy",
"thiserror",
"time",
"tiny-skia",
"usvg",
"util",
"uuid 1.4.1",
"waker-fn",
]
[[package]]
@ -3698,6 +3969,17 @@ dependencies = [
"util",
]
[[package]]
name = "install_cli2"
version = "0.1.0"
dependencies = [
"anyhow",
"gpui2",
"log",
"smol",
"util",
]
[[package]]
name = "instant"
version = "0.1.12"
@ -3916,6 +4198,24 @@ dependencies = [
"workspace",
]
[[package]]
name = "journal2"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"dirs 4.0.0",
"editor",
"gpui2",
"log",
"schemars",
"serde",
"settings2",
"shellexpand",
"util",
"workspace",
]
[[package]]
name = "jpeg-decoder"
version = "0.1.22"
@ -4032,6 +4332,59 @@ dependencies = [
"util",
]
[[package]]
name = "language2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-broadcast",
"async-trait",
"client2",
"clock",
"collections",
"ctor",
"env_logger 0.9.3",
"futures 0.3.28",
"fuzzy2",
"git",
"globset",
"gpui2",
"indoc",
"lazy_static",
"log",
"lsp2",
"parking_lot 0.11.2",
"postage",
"rand 0.8.5",
"regex",
"rpc2",
"schemars",
"serde",
"serde_derive",
"serde_json",
"settings2",
"similar",
"smallvec",
"smol",
"sum_tree",
"text",
"theme2",
"tree-sitter",
"tree-sitter-elixir",
"tree-sitter-embedded-template",
"tree-sitter-heex",
"tree-sitter-html",
"tree-sitter-json 0.20.0",
"tree-sitter-markdown",
"tree-sitter-python",
"tree-sitter-ruby",
"tree-sitter-rust",
"tree-sitter-typescript",
"unicase",
"unindent",
"util",
]
[[package]]
name = "language_selector"
version = "0.1.0"
@ -4310,6 +4663,29 @@ dependencies = [
"url",
]
[[package]]
name = "lsp2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-pipe",
"collections",
"ctor",
"env_logger 0.9.3",
"futures 0.3.28",
"gpui2",
"log",
"lsp-types",
"parking_lot 0.11.2",
"postage",
"serde",
"serde_derive",
"serde_json",
"smol",
"unindent",
"util",
]
[[package]]
name = "mach"
version = "0.3.2"
@ -4446,6 +4822,13 @@ dependencies = [
"gpui",
]
[[package]]
name = "menu2"
version = "0.1.0"
dependencies = [
"gpui2",
]
[[package]]
name = "metal"
version = "0.21.0"
@ -4738,6 +5121,19 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "nix"
version = "0.23.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f3790c00a0150112de0f4cd161e3d7fc4b2d8a5542ffc35f099a2562aecb35c"
dependencies = [
"bitflags 1.3.2",
"cc",
"cfg-if 1.0.0",
"libc",
"memoffset 0.6.5",
]
[[package]]
name = "nix"
version = "0.24.3"
@ -4769,7 +5165,6 @@ dependencies = [
"async-tar",
"async-trait",
"futures 0.3.28",
"gpui",
"log",
"parking_lot 0.11.2",
"serde",
@ -5491,6 +5886,17 @@ version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
[[package]]
name = "plane-split"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c1f7d82649829ecdef8e258790b0587acf0a8403f0ce963473d8e918acc1643"
dependencies = [
"euclid",
"log",
"smallvec",
]
[[package]]
name = "plist"
version = "1.5.0"
@ -5621,6 +6027,27 @@ dependencies = [
"util",
]
[[package]]
name = "prettier2"
version = "0.1.0"
dependencies = [
"anyhow",
"client2",
"collections",
"fs2",
"futures 0.3.28",
"gpui2",
"language2",
"log",
"lsp2",
"node_runtime",
"parking_lot 0.11.2",
"serde",
"serde_derive",
"serde_json",
"util",
]
[[package]]
name = "pretty_assertions"
version = "1.4.0"
@ -5759,6 +6186,61 @@ dependencies = [
"util",
]
[[package]]
name = "project2"
version = "0.1.0"
dependencies = [
"aho-corasick",
"anyhow",
"async-trait",
"backtrace",
"client2",
"clock",
"collections",
"copilot2",
"ctor",
"db2",
"env_logger 0.9.3",
"fs2",
"fsevent",
"futures 0.3.28",
"fuzzy2",
"git",
"git2",
"globset",
"gpui2",
"ignore",
"itertools 0.10.5",
"language2",
"lazy_static",
"log",
"lsp2",
"node_runtime",
"parking_lot 0.11.2",
"postage",
"prettier2",
"pretty_assertions",
"rand 0.8.5",
"regex",
"rpc2",
"schemars",
"serde",
"serde_derive",
"serde_json",
"settings2",
"sha2 0.10.7",
"similar",
"smol",
"sum_tree",
"tempdir",
"terminal2",
"text",
"thiserror",
"toml 0.5.11",
"unindent",
"util",
]
[[package]]
name = "project_panel"
version = "0.1.0"
@ -6494,6 +6976,35 @@ dependencies = [
"zstd",
]
[[package]]
name = "rpc2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-lock",
"async-tungstenite",
"base64 0.13.1",
"clock",
"collections",
"ctor",
"env_logger 0.9.3",
"futures 0.3.28",
"gpui2",
"parking_lot 0.11.2",
"prost 0.8.0",
"prost-build",
"rand 0.8.5",
"rsa 0.4.0",
"serde",
"serde_derive",
"smol",
"smol-timeout",
"tempdir",
"tracing",
"util",
"zstd",
]
[[package]]
name = "rsa"
version = "0.4.0"
@ -7198,6 +7709,36 @@ dependencies = [
"util",
]
[[package]]
name = "settings2"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"feature_flags2",
"fs",
"fs2",
"futures 0.3.28",
"gpui2",
"indoc",
"lazy_static",
"postage",
"pretty_assertions",
"rust-embed",
"schemars",
"serde",
"serde_derive",
"serde_json",
"serde_json_lenient",
"smallvec",
"sqlez",
"toml 0.5.11",
"tree-sitter",
"tree-sitter-json 0.19.0",
"unindent",
"util",
]
[[package]]
name = "sha-1"
version = "0.9.8"
@ -7766,6 +8307,29 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "storybook2"
version = "0.1.0"
dependencies = [
"anyhow",
"backtrace-on-stack-overflow",
"chrono",
"clap 4.4.4",
"gpui2",
"itertools 0.11.0",
"log",
"rust-embed",
"serde",
"settings2",
"simplelog",
"smallvec",
"strum",
"theme",
"theme2",
"ui2",
"util",
]
[[package]]
name = "stringprep"
version = "0.1.4"
@ -8075,6 +8639,35 @@ dependencies = [
"util",
]
[[package]]
name = "terminal2"
version = "0.1.0"
dependencies = [
"alacritty_terminal",
"anyhow",
"db2",
"dirs 4.0.0",
"futures 0.3.28",
"gpui2",
"itertools 0.10.5",
"lazy_static",
"libc",
"mio-extras",
"ordered-float 2.10.0",
"procinfo",
"rand 0.8.5",
"schemars",
"serde",
"serde_derive",
"settings2",
"shellexpand",
"smallvec",
"smol",
"theme2",
"thiserror",
"util",
]
[[package]]
name = "terminal_view"
version = "0.1.0"
@ -8157,6 +8750,39 @@ dependencies = [
"util",
]
[[package]]
name = "theme2"
version = "0.1.0"
dependencies = [
"anyhow",
"fs",
"gpui2",
"indexmap 1.9.3",
"parking_lot 0.11.2",
"schemars",
"serde",
"serde_derive",
"serde_json",
"settings2",
"toml 0.5.11",
"util",
]
[[package]]
name = "theme_converter"
version = "0.1.0"
dependencies = [
"anyhow",
"clap 4.4.4",
"convert_case 0.6.0",
"gpui2",
"log",
"rust-embed",
"serde",
"simplelog",
"theme2",
]
[[package]]
name = "theme_selector"
version = "0.1.0"
@ -8964,6 +9590,21 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "ui2"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"gpui2",
"itertools 0.11.0",
"rand 0.8.5",
"serde",
"smallvec",
"strum",
"theme2",
]
[[package]]
name = "unicase"
version = "2.7.0"
@ -10285,6 +10926,105 @@ dependencies = [
"serde",
]
[[package]]
name = "zed2"
version = "0.109.0"
dependencies = [
"ai2",
"anyhow",
"async-compression",
"async-recursion 0.3.2",
"async-tar",
"async-trait",
"backtrace",
"call2",
"chrono",
"cli",
"client2",
"collections",
"copilot2",
"ctor",
"db2",
"env_logger 0.9.3",
"feature_flags2",
"fs2",
"fsevent",
"futures 0.3.28",
"fuzzy",
"gpui2",
"ignore",
"image",
"indexmap 1.9.3",
"install_cli",
"isahc",
"journal2",
"language2",
"language_tools",
"lazy_static",
"libc",
"log",
"lsp2",
"node_runtime",
"num_cpus",
"parking_lot 0.11.2",
"postage",
"project2",
"rand 0.8.5",
"regex",
"rpc2",
"rsa 0.4.0",
"rust-embed",
"schemars",
"serde",
"serde_derive",
"serde_json",
"settings2",
"shellexpand",
"simplelog",
"smallvec",
"smol",
"sum_tree",
"tempdir",
"text",
"theme2",
"thiserror",
"tiny_http",
"toml 0.5.11",
"tree-sitter",
"tree-sitter-bash",
"tree-sitter-c",
"tree-sitter-cpp",
"tree-sitter-css",
"tree-sitter-elixir",
"tree-sitter-elm",
"tree-sitter-embedded-template",
"tree-sitter-glsl",
"tree-sitter-go",
"tree-sitter-heex",
"tree-sitter-html",
"tree-sitter-json 0.20.0",
"tree-sitter-lua",
"tree-sitter-markdown",
"tree-sitter-nix",
"tree-sitter-nu",
"tree-sitter-php",
"tree-sitter-python",
"tree-sitter-racket",
"tree-sitter-ruby",
"tree-sitter-rust",
"tree-sitter-scheme",
"tree-sitter-svelte",
"tree-sitter-toml",
"tree-sitter-typescript",
"tree-sitter-vue",
"tree-sitter-yaml",
"unindent",
"url",
"urlencoding",
"util",
"uuid 1.4.1",
]
[[package]]
name = "zeroize"
version = "1.6.0"

View file

@ -4,12 +4,15 @@ members = [
"crates/ai",
"crates/assistant",
"crates/audio",
"crates/audio2",
"crates/auto_update",
"crates/breadcrumbs",
"crates/call",
"crates/call2",
"crates/channel",
"crates/cli",
"crates/client",
"crates/client2",
"crates/clock",
"crates/collab",
"crates/collab_ui",
@ -18,18 +21,24 @@ members = [
"crates/component_test",
"crates/context_menu",
"crates/copilot",
"crates/copilot2",
"crates/copilot_button",
"crates/db",
"crates/db2",
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/diagnostics",
"crates/drag_and_drop",
"crates/editor",
"crates/feature_flags",
"crates/feature_flags2",
"crates/feedback",
"crates/file_finder",
"crates/fs",
"crates/fs2",
"crates/fsevent",
"crates/fuzzy",
"crates/fuzzy2",
"crates/git",
"crates/go_to_line",
"crates/gpui",
@ -37,15 +46,20 @@ members = [
"crates/gpui2",
"crates/gpui2_macros",
"crates/install_cli",
"crates/install_cli2",
"crates/journal",
"crates/journal2",
"crates/language",
"crates/language2",
"crates/language_selector",
"crates/language_tools",
"crates/live_kit_client",
"crates/live_kit_server",
"crates/lsp",
"crates/lsp2",
"crates/media",
"crates/menu",
"crates/menu2",
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
@ -55,24 +69,32 @@ members = [
"crates/plugin_macros",
"crates/plugin_runtime",
"crates/prettier",
"crates/prettier2",
"crates/project",
"crates/project2",
"crates/project_panel",
"crates/project_symbols",
"crates/recent_projects",
"crates/rope",
"crates/rpc",
"crates/rpc2",
"crates/search",
"crates/settings",
"crates/settings2",
"crates/snippet",
"crates/sqlez",
"crates/sqlez_macros",
"crates/feature_flags",
"crates/rich_text",
"crates/storybook2",
"crates/sum_tree",
"crates/terminal",
"crates/terminal2",
"crates/text",
"crates/theme",
"crates/theme2",
"crates/theme_converter",
"crates/theme_selector",
"crates/ui2",
"crates/util",
"crates/semantic_index",
"crates/vim",
@ -81,6 +103,7 @@ members = [
"crates/welcome",
"crates/xtask",
"crates/zed",
"crates/zed2",
"crates/zed-actions"
]
default-members = ["crates/zed"]

38
crates/Cargo.toml Normal file
View file

@ -0,0 +1,38 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

38
crates/ai2/Cargo.toml Normal file
View file

@ -0,0 +1,38 @@
[package]
name = "ai2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai2.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui2 = { path = "../gpui2" }
util = { path = "../util" }
language2 = { path = "../language2" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui2 = { path = "../gpui2", features = ["test-support"] }

8
crates/ai2/src/ai2.rs Normal file
View file

@ -0,0 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

17
crates/ai2/src/auth.rs Normal file
View file

@ -0,0 +1,17 @@
use async_trait::async_trait;
use gpui2::AppContext;
#[derive(Clone, Debug)]
pub enum ProviderCredential {
Credentials { api_key: String },
NoCredentials,
NotNeeded,
}
#[async_trait]
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
async fn delete_credentials(&self, cx: &mut AppContext);
}

View file

@ -0,0 +1,23 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::{auth::CredentialProvider, models::LanguageModel};
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}

123
crates/ai2/src/embedding.rs Normal file
View file

@ -0,0 +1,123 @@
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
OrderedFloat(result)
}
}
#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui2::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}

16
crates/ai2/src/models.rs Normal file
View file

@ -0,0 +1,16 @@
pub enum TruncationDirection {
Start,
End,
}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

View file

@ -0,0 +1,330 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language2::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
Code,
}
// TODO: Set this up to manage for defaults well
pub struct PromptArguments {
pub model: Arc<dyn LanguageModel>,
pub user_prompt: Option<String>,
pub language_name: Option<String>,
pub project_name: Option<String>,
pub snippets: Vec<PromptCodeSnippet>,
pub reserved_tokens: usize,
pub buffer: Option<BufferSnapshot>,
pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
pub(crate) fn get_file_type(&self) -> PromptFileType {
if self
.language_name
.as_ref()
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
.unwrap_or(true)
{
PromptFileType::Code
} else {
PromptFileType::Text
}
}
}
pub trait PromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)>;
}
#[repr(i8)]
#[derive(PartialEq, Eq, Ord)]
pub enum PromptPriority {
Mandatory, // Ignores truncation
Ordered { order: usize }, // Truncates based on priority
}
impl PartialOrd for PromptPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
match (self, other) {
(Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
(Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
(Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
}
}
}
pub struct PromptChain {
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
pub fn new(
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> Self {
PromptChain { args, templates }
}
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
// Argsort based on Prompt Priority
let seperator = "\n";
let seperator_tokens = self.args.model.count_tokens(seperator)?;
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
// If Truncate
let mut tokens_outstanding = if truncate {
Some(self.args.model.capacity()? - self.args.reserved_tokens)
} else {
None
};
let mut prompts = vec!["".to_string(); sorted_indices.len()];
for idx in sorted_indices {
let (_, template) = &self.templates[idx];
if let Some((template_prompt, prompt_token_count)) =
template.generate(&self.args, tokens_outstanding).log_err()
{
if template_prompt != "" {
prompts[idx] = template_prompt;
if let Some(remaining_tokens) = tokens_outstanding {
let new_tokens = prompt_token_count + seperator_tokens;
tokens_outstanding = if remaining_tokens > new_tokens {
Some(remaining_tokens - new_tokens)
} else {
Some(0)
};
}
}
}
}
prompts.retain(|x| x != "");
let full_prompt = prompts.join(seperator);
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
anyhow::Ok((prompts.join(seperator), total_token_count))
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*;
#[test]
pub fn test_prompt_chain() {
struct TestPromptTemplate {}
impl PromptTemplate for TestPromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
struct TestLowPriorityTemplate {}
impl PromptTemplate for TestLowPriorityTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a low priority test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 2 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(prompt, "This is a test promp".to_string());
assert_eq!(token_count, capacity);
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Mandatory,
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(
prompt,
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
.to_string()
);
assert_eq!(token_count, capacity - reserved_tokens);
}
}

View file

@ -0,0 +1,164 @@
use anyhow::anyhow;
use language2::BufferSnapshot;
use language2::ToOffset;
use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we dont have a selected range, include entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
pub struct FileContext {}
impl PromptTemplate for FileContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
if let Some(buffer) = &args.buffer {
let mut prompt = String::new();
// Add Initial Preamble
// TODO: Do we want to add the path in here?
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
let language_name = args
.language_name
.clone()
.unwrap_or("".to_string())
.to_lowercase();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if truncated {
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
}
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
} else {
Err(anyhow!("no buffer provided to retrieve file context from"))
}
}
}

View file

@ -0,0 +1,99 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
pub fn capitalize(s: &str) -> String {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let Some(user_prompt) = &args.user_prompt else {
return Err(anyhow!("user prompt not provided"));
};
let file_type = args.get_file_type();
let content_type = match &file_type {
PromptFileType::Code => "code",
PromptFileType::Text => "text",
};
let mut prompt = String::new();
if let Some(selected_range) = &args.selected_range {
if selected_range.start == selected_range.end {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
capitalize(content_type)
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
}
} else {
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}"
)
.unwrap();
}
if let Some(language_name) = &args.language_name {
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
match file_type {
PromptFileType::Code => {
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
}
_ => {}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}

View file

@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;

View file

@ -0,0 +1,52 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut prompts = Vec::new();
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
"You are an expert {}engineer.",
args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
prompts.push("You are an expert engineer.".to_string());
}
}
if let Some(project_name) = args.project_name.clone() {
prompts.push(format!(
"You are currently working inside the '{project_name}' project in code editor Zed."
));
}
if let Some(mut remaining_tokens) = max_token_length {
let mut prompt = String::new();
let mut total_count = 0;
for prompt_piece in prompts {
let prompt_token_count =
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
if remaining_tokens > prompt_token_count {
writeln!(prompt, "{prompt_piece}").unwrap();
remaining_tokens -= prompt_token_count;
total_count += prompt_token_count;
}
}
anyhow::Ok((prompt, total_count))
} else {
let prompt = prompts.join("\n");
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}
}

View file

@ -0,0 +1,98 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui2::{AsyncAppContext, Model};
use language2::{Anchor, Buffer};
#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(
buffer: Model<Buffer>,
range: Range<Anchor>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Self> {
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
let language_name = buffer
.language()
.and_then(|language| Some(language.name().to_string().to_lowercase()));
let file_path = buffer
.file()
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
})?;
anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
})
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.and_then(|path| Some(path.to_string_lossy().to_string()))
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
let mut prompt = String::new();
let mut remaining_tokens = max_token_length.clone();
let seperator_token_length = args.model.count_tokens("\n")?;
for snippet in &args.snippets {
let mut snippet_prompt = template.to_string();
let content = snippet.to_string();
writeln!(snippet_prompt, "{content}").unwrap();
let token_count = args.model.count_tokens(&snippet_prompt)?;
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
if let Some(tokens_left) = remaining_tokens {
if tokens_left >= token_count {
writeln!(prompt, "{snippet_prompt}").unwrap();
remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
{
Some(tokens_left - token_count - seperator_token_length)
} else {
Some(0)
};
}
} else {
writeln!(prompt, "{snippet_prompt}").unwrap();
}
}
}
let total_token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, total_token_count))
}
}

View file

@ -0,0 +1 @@
pub mod open_ai;

View file

@ -0,0 +1,306 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui2::{AppContext, Executor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
credential: ProviderCredential,
executor: Arc<Executor>,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = request.data()?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[derive(Clone)]
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: Arc<Executor>,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
model,
credential,
executor,
}
}
}
#[async_trait]
impl CredentialProvider for OpenAICompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = cx
.run_on_main(move |cx| match existing_credential {
ProviderCredential::Credentials { .. } => {
return existing_credential.clone();
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
return ProviderCredential::Credentials { api_key };
}
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
return ProviderCredential::Credentials { api_key };
} else {
return ProviderCredential::NoCredentials;
}
} else {
return ProviderCredential::NoCredentials;
}
}
})
.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
let credential = credential.clone();
cx.run_on_main(move |cx| match credential {
ProviderCredential::Credentials { api_key } => {
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
})
.await;
}
async fn delete_credentials(&self, cx: &mut AppContext) {
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
.await;
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the langauge model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View file

@ -0,0 +1,313 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui2::Executor;
use gpui2::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Executor>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAIEmbeddingProvider {
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
#[async_trait]
impl CredentialProvider for OpenAIEmbeddingProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = cx
.run_on_main(move |cx| match existing_credential {
ProviderCredential::Credentials { .. } => {
return existing_credential.clone();
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
return ProviderCredential::Credentials { api_key };
}
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
return ProviderCredential::Credentials { api_key };
} else {
return ProviderCredential::NoCredentials;
}
} else {
return ProviderCredential::NoCredentials;
}
}
})
.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
let credential = credential.clone();
cx.run_on_main(move |cx| match credential {
ProviderCredential::Credentials { api_key } => {
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
})
.await;
}
async fn delete_credentials(&self, cx: &mut AppContext) {
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
.await;
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}

View file

@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

View file

@ -0,0 +1,57 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection};
#[derive(Clone)]
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
match direction {
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
}
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View file

@ -0,0 +1,11 @@
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

193
crates/ai2/src/test.rs Normal file
View file

@ -0,0 +1,193 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui2::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
#[derive(Clone)]
pub struct FakeLanguageModel {
pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
anyhow::Ok(match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
})
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
}
}
}
impl FakeEmbeddingProvider {
pub fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
pub fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
#[async_trait]
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

24
crates/audio2/Cargo.toml Normal file
View file

@ -0,0 +1,24 @@
[package]
name = "audio2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/audio2.rs"
doctest = false
[dependencies]
gpui2 = { path = "../gpui2" }
collections = { path = "../collections" }
util = { path = "../util" }
rodio ={version = "0.17.1", default-features=false, features = ["wav"]}
log.workspace = true
futures.workspace = true
anyhow.workspace = true
parking_lot.workspace = true
[dev-dependencies]

View file

@ -0,0 +1,23 @@
[package]
name = "audio"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/audio.rs"
doctest = false
[dependencies]
gpui = { path = "../gpui" }
collections = { path = "../collections" }
util = { path = "../util" }
rodio ={version = "0.17.1", default-features=false, features = ["wav"]}
log.workspace = true
anyhow.workspace = true
parking_lot.workspace = true
[dev-dependencies]

View file

@ -0,0 +1,44 @@
use std::{io::Cursor, sync::Arc};
use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AssetSource};
use rodio::{
source::{Buffered, SamplesConverter},
Decoder, Source,
};
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
pub struct SoundRegistry {
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
assets: Box<dyn AssetSource>,
}
impl SoundRegistry {
pub fn new(source: impl AssetSource) -> Arc<Self> {
Arc::new(Self {
cache: Default::default(),
assets: Box::new(source),
})
}
pub fn global(cx: &AppContext) -> Arc<Self> {
cx.global::<Arc<Self>>().clone()
}
pub fn get(&self, name: &str) -> Result<impl Source<Item = f32>> {
if let Some(wav) = self.cache.lock().get(name) {
return Ok(wav.clone());
}
let path = format!("sounds/{}.wav", name);
let bytes = self.assets.load(&path)?.into_owned();
let cursor = Cursor::new(bytes);
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
self.cache.lock().insert(name.to_string(), source.clone());
Ok(source)
}
}

View file

@ -0,0 +1,81 @@
use assets::SoundRegistry;
use gpui::{AppContext, AssetSource};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;
mod assets;
pub fn init(source: impl AssetSource, cx: &mut AppContext) {
cx.set_global(SoundRegistry::new(source));
cx.set_global(Audio::new());
}
pub enum Sound {
Joined,
Leave,
Mute,
Unmute,
StartScreenshare,
StopScreenshare,
}
impl Sound {
fn file(&self) -> &'static str {
match self {
Self::Joined => "joined_call",
Self::Leave => "leave_call",
Self::Mute => "mute",
Self::Unmute => "unmute",
Self::StartScreenshare => "start_screenshare",
Self::StopScreenshare => "stop_screenshare",
}
}
}
pub struct Audio {
_output_stream: Option<OutputStream>,
output_handle: Option<OutputStreamHandle>,
}
impl Audio {
pub fn new() -> Self {
Self {
_output_stream: None,
output_handle: None,
}
}
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
if self.output_handle.is_none() {
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
self.output_handle = output_handle;
self._output_stream = _output_stream;
}
self.output_handle.as_ref()
}
pub fn play_sound(sound: Sound, cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
cx.update_global::<Self, _, _>(|this, cx| {
let output_handle = this.ensure_output_exists()?;
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
output_handle.play_raw(source).log_err()?;
Some(())
});
}
pub fn end_call(cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
cx.update_global::<Self, _, _>(|this, _| {
this._output_stream.take();
this.output_handle.take();
});
}
}

View file

@ -0,0 +1,44 @@
use std::{io::Cursor, sync::Arc};
use anyhow::Result;
use collections::HashMap;
use gpui2::{AppContext, AssetSource};
use rodio::{
source::{Buffered, SamplesConverter},
Decoder, Source,
};
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
pub struct SoundRegistry {
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
assets: Box<dyn AssetSource>,
}
impl SoundRegistry {
pub fn new(source: impl AssetSource) -> Arc<Self> {
Arc::new(Self {
cache: Default::default(),
assets: Box::new(source),
})
}
pub fn global(cx: &AppContext) -> Arc<Self> {
cx.global::<Arc<Self>>().clone()
}
pub fn get(&self, name: &str) -> Result<impl Source<Item = f32>> {
if let Some(wav) = self.cache.lock().get(name) {
return Ok(wav.clone());
}
let path = format!("sounds/{}.wav", name);
let bytes = self.assets.load(&path)?.into_owned();
let cursor = Cursor::new(bytes);
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
self.cache.lock().insert(name.to_string(), source.clone());
Ok(source)
}
}

111
crates/audio2/src/audio2.rs Normal file
View file

@ -0,0 +1,111 @@
use assets::SoundRegistry;
use futures::{channel::mpsc, StreamExt};
use gpui2::{AppContext, AssetSource, Executor};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;
mod assets;
pub fn init(source: impl AssetSource, cx: &mut AppContext) {
cx.set_global(Audio::new(cx.executor()));
cx.set_global(SoundRegistry::new(source));
}
pub enum Sound {
Joined,
Leave,
Mute,
Unmute,
StartScreenshare,
StopScreenshare,
}
impl Sound {
fn file(&self) -> &'static str {
match self {
Self::Joined => "joined_call",
Self::Leave => "leave_call",
Self::Mute => "mute",
Self::Unmute => "unmute",
Self::StartScreenshare => "start_screenshare",
Self::StopScreenshare => "stop_screenshare",
}
}
}
pub struct Audio {
tx: mpsc::UnboundedSender<Box<dyn FnOnce(&mut AudioState) + Send>>,
}
struct AudioState {
_output_stream: Option<OutputStream>,
output_handle: Option<OutputStreamHandle>,
}
impl AudioState {
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
if self.output_handle.is_none() {
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
self.output_handle = output_handle;
self._output_stream = _output_stream;
}
self.output_handle.as_ref()
}
fn take(&mut self) {
self._output_stream.take();
self.output_handle.take();
}
}
impl Audio {
pub fn new(executor: &Executor) -> Self {
let (tx, mut rx) = mpsc::unbounded::<Box<dyn FnOnce(&mut AudioState) + Send>>();
executor
.spawn_on_main(|| async move {
let mut audio = AudioState {
_output_stream: None,
output_handle: None,
};
while let Some(f) = rx.next().await {
(f)(&mut audio);
}
})
.detach();
Self { tx }
}
pub fn play_sound(sound: Sound, cx: &mut AppContext) {
if !cx.has_global::<Self>() {
return;
}
let Some(source) = SoundRegistry::global(cx).get(sound.file()).log_err() else {
return;
};
let this = cx.global::<Self>();
this.tx
.unbounded_send(Box::new(move |state| {
if let Some(output_handle) = state.ensure_output_exists() {
output_handle.play_raw(source).log_err();
}
}))
.ok();
}
pub fn end_call(cx: &AppContext) {
if !cx.has_global::<Self>() {
return;
}
let this = cx.global::<Self>();
this.tx
.unbounded_send(Box::new(move |state| state.take()))
.ok();
}
}

View file

@ -1252,7 +1252,7 @@ impl Room {
.read_with(&cx, |this, _| {
this.live_kit
.as_ref()
.map(|live_kit| live_kit.room.publish_audio_track(&track))
.map(|live_kit| live_kit.room.publish_audio_track(track))
})
.ok_or_else(|| anyhow!("live-kit was not initialized"))?
.await
@ -1338,7 +1338,7 @@ impl Room {
.read_with(&cx, |this, _| {
this.live_kit
.as_ref()
.map(|live_kit| live_kit.room.publish_video_track(&track))
.map(|live_kit| live_kit.room.publish_video_track(track))
})
.ok_or_else(|| anyhow!("live-kit was not initialized"))?
.await

52
crates/call2/Cargo.toml Normal file
View file

@ -0,0 +1,52 @@
[package]
name = "call2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/call2.rs"
doctest = false
[features]
test-support = [
"client2/test-support",
"collections/test-support",
"gpui2/test-support",
"live_kit_client/test-support",
"project2/test-support",
"util/test-support"
]
[dependencies]
audio2 = { path = "../audio2" }
client2 = { path = "../client2" }
collections = { path = "../collections" }
gpui2 = { path = "../gpui2" }
log.workspace = true
live_kit_client = { path = "../live_kit_client" }
fs2 = { path = "../fs2" }
language2 = { path = "../language2" }
media = { path = "../media" }
project2 = { path = "../project2" }
settings2 = { path = "../settings2" }
util = { path = "../util" }
anyhow.workspace = true
async-broadcast = "0.4"
futures.workspace = true
postage.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_derive.workspace = true
[dev-dependencies]
client2 = { path = "../client2", features = ["test-support"] }
fs2 = { path = "../fs2", features = ["test-support"] }
language2 = { path = "../language2", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] }
live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
project2 = { path = "../project2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

461
crates/call2/src/call2.rs Normal file
View file

@ -0,0 +1,461 @@
pub mod call_settings;
pub mod participant;
pub mod room;
use anyhow::{anyhow, Result};
use audio2::Audio;
use call_settings::CallSettings;
use client2::{
proto, ClickhouseEvent, Client, TelemetrySettings, TypedEnvelope, User, UserStore,
ZED_ALWAYS_ACTIVE,
};
use collections::HashSet;
use futures::{future::Shared, FutureExt};
use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
WeakModel,
};
use postage::watch;
use project2::Project;
use settings2::Settings;
use std::sync::Arc;
pub use participant::ParticipantLocation;
pub use room::Room;
pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
CallSettings::register(cx);
let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx));
cx.set_global(active_call);
}
#[derive(Clone)]
pub struct IncomingCall {
pub room_id: u64,
pub calling_user: Arc<User>,
pub participants: Vec<Arc<User>>,
pub initial_project: Option<proto::ParticipantProject>,
}
/// Singleton global maintaining the user's participation in a room across workspaces.
pub struct ActiveCall {
room: Option<(Model<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
location: Option<WeakModel<Project>>,
pending_invites: HashSet<u64>,
incoming_call: (
watch::Sender<Option<IncomingCall>>,
watch::Receiver<Option<IncomingCall>>,
),
client: Arc<Client>,
user_store: Model<UserStore>,
_subscriptions: Vec<client2::Subscription>,
}
impl EventEmitter for ActiveCall {
type Event = room::Event;
}
impl ActiveCall {
fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
Self {
room: None,
pending_room_creation: None,
location: None,
pending_invites: Default::default(),
incoming_call: watch::channel(),
_subscriptions: vec![
client.add_request_handler(cx.weak_model(), Self::handle_incoming_call),
client.add_message_handler(cx.weak_model(), Self::handle_call_canceled),
],
client,
user_store,
}
}
pub fn channel_id(&self, cx: &AppContext) -> Option<u64> {
self.room()?.read(cx).channel_id()
}
async fn handle_incoming_call(
this: Model<Self>,
envelope: TypedEnvelope<proto::IncomingCall>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<proto::Ack> {
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
let call = IncomingCall {
room_id: envelope.payload.room_id,
participants: user_store
.update(&mut cx, |user_store, cx| {
user_store.get_users(envelope.payload.participant_user_ids, cx)
})?
.await?,
calling_user: user_store
.update(&mut cx, |user_store, cx| {
user_store.get_user(envelope.payload.calling_user_id, cx)
})?
.await?,
initial_project: envelope.payload.initial_project,
};
this.update(&mut cx, |this, _| {
*this.incoming_call.0.borrow_mut() = Some(call);
})?;
Ok(proto::Ack {})
}
async fn handle_call_canceled(
this: Model<Self>,
envelope: TypedEnvelope<proto::CallCanceled>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, _| {
let mut incoming_call = this.incoming_call.0.borrow_mut();
if incoming_call
.as_ref()
.map_or(false, |call| call.room_id == envelope.payload.room_id)
{
incoming_call.take();
}
})?;
Ok(())
}
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<Model<Self>>().clone()
}
pub fn invite(
&mut self,
called_user_id: u64,
initial_project: Option<Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.pending_invites.insert(called_user_id) {
return Task::ready(Err(anyhow!("user was already invited")));
}
cx.notify();
let room = if let Some(room) = self.room().cloned() {
Some(Task::ready(Ok(room)).shared())
} else {
self.pending_room_creation.clone()
};
let invite = if let Some(room) = room {
cx.spawn(move |_, mut cx| async move {
let room = room.await.map_err(|err| anyhow!("{:?}", err))?;
let initial_project_id = if let Some(initial_project) = initial_project {
Some(
room.update(&mut cx, |room, cx| room.share_project(initial_project, cx))?
.await?,
)
} else {
None
};
room.update(&mut cx, move |room, cx| {
room.call(called_user_id, initial_project_id, cx)
})?
.await?;
anyhow::Ok(())
})
} else {
let client = self.client.clone();
let user_store = self.user_store.clone();
let room = cx
.spawn(move |this, mut cx| async move {
let create_room = async {
let room = cx
.update(|cx| {
Room::create(
called_user_id,
initial_project,
client,
user_store,
cx,
)
})?
.await?;
this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))?
.await?;
anyhow::Ok(room)
};
let room = create_room.await;
this.update(&mut cx, |this, _| this.pending_room_creation = None)?;
room.map_err(Arc::new)
})
.shared();
self.pending_room_creation = Some(room.clone());
cx.executor().spawn(async move {
room.await.map_err(|err| anyhow!("{:?}", err))?;
anyhow::Ok(())
})
};
cx.spawn(move |this, mut cx| async move {
let result = invite.await;
if result.is_ok() {
this.update(&mut cx, |this, cx| this.report_call_event("invite", cx))?;
} else {
// TODO: Resport collaboration error
}
this.update(&mut cx, |this, cx| {
this.pending_invites.remove(&called_user_id);
cx.notify();
})?;
result
})
}
pub fn cancel_invite(
&mut self,
called_user_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let room_id = if let Some(room) = self.room() {
room.read(cx).id()
} else {
return Task::ready(Err(anyhow!("no active call")));
};
let client = self.client.clone();
cx.executor().spawn(async move {
client
.request(proto::CancelCall {
room_id,
called_user_id,
})
.await?;
anyhow::Ok(())
})
}
pub fn incoming(&self) -> watch::Receiver<Option<IncomingCall>> {
self.incoming_call.1.clone()
}
pub fn accept_incoming(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
if self.room.is_some() {
return Task::ready(Err(anyhow!("cannot join while on another call")));
}
let call = if let Some(call) = self.incoming_call.1.borrow().clone() {
call
} else {
return Task::ready(Err(anyhow!("no incoming call")));
};
let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx);
cx.spawn(|this, mut cx| async move {
let room = join.await?;
this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.report_call_event("accept incoming", cx)
})?;
Ok(())
})
}
pub fn decline_incoming(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
let call = self
.incoming_call
.0
.borrow_mut()
.take()
.ok_or_else(|| anyhow!("no incoming call"))?;
report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx);
self.client.send(proto::DeclineCall {
room_id: call.room_id,
})?;
Ok(())
}
pub fn join_channel(
&mut self,
channel_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<Model<Room>>> {
if let Some(room) = self.room().cloned() {
if room.read(cx).channel_id() == Some(channel_id) {
return Task::ready(Ok(room));
} else {
room.update(cx, |room, cx| room.clear_state(cx));
}
}
let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx);
cx.spawn(|this, mut cx| async move {
let room = join.await?;
this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.report_call_event("join channel", cx)
})?;
Ok(room)
})
}
pub fn hang_up(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
cx.notify();
self.report_call_event("hang up", cx);
Audio::end_call(cx);
if let Some((room, _)) = self.room.take() {
room.update(cx, |room, cx| room.leave(cx))
} else {
Task::ready(Ok(()))
}
}
pub fn share_project(
&mut self,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> {
if let Some((room, _)) = self.room.as_ref() {
self.report_call_event("share project", cx);
room.update(cx, |room, cx| room.share_project(project, cx))
} else {
Task::ready(Err(anyhow!("no active call")))
}
}
pub fn unshare_project(
&mut self,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
if let Some((room, _)) = self.room.as_ref() {
self.report_call_event("unshare project", cx);
room.update(cx, |room, cx| room.unshare_project(project, cx))
} else {
Err(anyhow!("no active call"))
}
}
pub fn location(&self) -> Option<&WeakModel<Project>> {
self.location.as_ref()
}
pub fn set_location(
&mut self,
project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
self.location = project.map(|project| project.downgrade());
if let Some((room, _)) = self.room.as_ref() {
return room.update(cx, |room, cx| room.set_location(project, cx));
}
}
Task::ready(Ok(()))
}
fn set_room(
&mut self,
room: Option<Model<Room>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
cx.notify();
if let Some(room) = room {
if room.read(cx).status().is_offline() {
self.room = None;
Task::ready(Ok(()))
} else {
let subscriptions = vec![
cx.observe(&room, |this, room, cx| {
if room.read(cx).status().is_offline() {
this.set_room(None, cx).detach_and_log_err(cx);
}
cx.notify();
}),
cx.subscribe(&room, |_, _, event, cx| cx.emit(event.clone())),
];
self.room = Some((room.clone(), subscriptions));
let location = self
.location
.as_ref()
.and_then(|location| location.upgrade());
room.update(cx, |room, cx| room.set_location(location.as_ref(), cx))
}
} else {
self.room = None;
Task::ready(Ok(()))
}
} else {
Task::ready(Ok(()))
}
}
pub fn room(&self) -> Option<&Model<Room>> {
self.room.as_ref().map(|(room, _)| room)
}
pub fn client(&self) -> Arc<Client> {
self.client.clone()
}
pub fn pending_invites(&self) -> &HashSet<u64> {
&self.pending_invites
}
pub fn report_call_event(&self, operation: &'static str, cx: &AppContext) {
if let Some(room) = self.room() {
let room = room.read(cx);
report_call_event_for_room(operation, room.id(), room.channel_id(), &self.client, cx);
}
}
}
pub fn report_call_event_for_room(
operation: &'static str,
room_id: u64,
channel_id: Option<u64>,
client: &Arc<Client>,
cx: &AppContext,
) {
let telemetry = client.telemetry();
let telemetry_settings = *TelemetrySettings::get_global(cx);
let event = ClickhouseEvent::Call {
operation,
room_id: Some(room_id),
channel_id,
};
telemetry.report_clickhouse_event(event, telemetry_settings);
}
pub fn report_call_event_for_channel(
operation: &'static str,
channel_id: u64,
client: &Arc<Client>,
cx: &AppContext,
) {
let room = ActiveCall::global(cx).read(cx).room();
let telemetry = client.telemetry();
let telemetry_settings = *TelemetrySettings::get_global(cx);
let event = ClickhouseEvent::Call {
operation,
room_id: room.map(|r| r.read(cx).id()),
channel_id: Some(channel_id),
};
telemetry.report_clickhouse_event(event, telemetry_settings);
}

View file

@ -0,0 +1,32 @@
use anyhow::Result;
use gpui2::AppContext;
use schemars::JsonSchema;
use serde_derive::{Deserialize, Serialize};
use settings2::Settings;
#[derive(Deserialize, Debug)]
pub struct CallSettings {
pub mute_on_join: bool,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct CallSettingsContent {
pub mute_on_join: Option<bool>,
}
impl Settings for CallSettings {
const KEY: Option<&'static str> = Some("calls");
type FileContent = CallSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_cx: &mut AppContext,
) -> Result<Self>
where
Self: Sized,
{
Self::load_via_json_merge(default_value, user_values)
}
}

View file

@ -0,0 +1,71 @@
use anyhow::{anyhow, Result};
use client2::ParticipantIndex;
use client2::{proto, User};
use gpui2::WeakModel;
pub use live_kit_client::Frame;
use project2::Project;
use std::{fmt, sync::Arc};
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ParticipantLocation {
SharedProject { project_id: u64 },
UnsharedProject,
External,
}
impl ParticipantLocation {
pub fn from_proto(location: Option<proto::ParticipantLocation>) -> Result<Self> {
match location.and_then(|l| l.variant) {
Some(proto::participant_location::Variant::SharedProject(project)) => {
Ok(Self::SharedProject {
project_id: project.id,
})
}
Some(proto::participant_location::Variant::UnsharedProject(_)) => {
Ok(Self::UnsharedProject)
}
Some(proto::participant_location::Variant::External(_)) => Ok(Self::External),
None => Err(anyhow!("participant location was not provided")),
}
}
}
#[derive(Clone, Default)]
pub struct LocalParticipant {
pub projects: Vec<proto::ParticipantProject>,
pub active_project: Option<WeakModel<Project>>,
}
#[derive(Clone, Debug)]
pub struct RemoteParticipant {
pub user: Arc<User>,
pub peer_id: proto::PeerId,
pub projects: Vec<proto::ParticipantProject>,
pub location: ParticipantLocation,
pub participant_index: ParticipantIndex,
pub muted: bool,
pub speaking: bool,
// pub video_tracks: HashMap<live_kit_client::Sid, Arc<RemoteVideoTrack>>,
// pub audio_tracks: HashMap<live_kit_client::Sid, Arc<RemoteAudioTrack>>,
}
#[derive(Clone)]
pub struct RemoteVideoTrack {
pub(crate) live_kit_track: Arc<live_kit_client::RemoteVideoTrack>,
}
unsafe impl Send for RemoteVideoTrack {}
// todo!("remove this sync because it's not legit")
unsafe impl Sync for RemoteVideoTrack {}
impl fmt::Debug for RemoteVideoTrack {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RemoteVideoTrack").finish()
}
}
impl RemoteVideoTrack {
pub fn frames(&self) -> async_broadcast::Receiver<Frame> {
self.live_kit_track.frames()
}
}

1622
crates/call2/src/room.rs Normal file

File diff suppressed because it is too large Load diff

52
crates/client2/Cargo.toml Normal file
View file

@ -0,0 +1,52 @@
[package]
name = "client2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/client2.rs"
doctest = false
[features]
test-support = ["collections/test-support", "gpui2/test-support", "rpc2/test-support"]
[dependencies]
collections = { path = "../collections" }
db2 = { path = "../db2" }
gpui2 = { path = "../gpui2" }
util = { path = "../util" }
rpc2 = { path = "../rpc2" }
text = { path = "../text" }
settings2 = { path = "../settings2" }
feature_flags2 = { path = "../feature_flags2" }
sum_tree = { path = "../sum_tree" }
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { version = "0.16", features = ["async-tls"] }
futures.workspace = true
image = "0.23"
lazy_static.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
schemars.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol.workspace = true
sysinfo.workspace = true
tempfile = "3"
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
uuid.workspace = true
url = "2.2"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] }
rpc2 = { path = "../rpc2", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,333 @@
use crate::{TelemetrySettings, ZED_SECRET_CLIENT_TOKEN, ZED_SERVER_URL};
use gpui2::{serde_json, AppContext, AppMetadata, Executor, Task};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use serde::Serialize;
use settings2::Settings;
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
};
use tempfile::NamedTempFile;
use util::http::HttpClient;
use util::{channel::ReleaseChannel, TryFutureExt};
pub struct Telemetry {
http_client: Arc<dyn HttpClient>,
executor: Executor,
state: Mutex<TelemetryState>,
}
struct TelemetryState {
metrics_id: Option<Arc<str>>, // Per logged-in user
installation_id: Option<Arc<str>>, // Per app installation (different for dev, preview, and stable)
session_id: Option<Arc<str>>, // Per app launch
release_channel: Option<&'static str>,
app_metadata: AppMetadata,
architecture: &'static str,
clickhouse_events_queue: Vec<ClickhouseEventWrapper>,
flush_clickhouse_events_task: Option<Task<()>>,
log_file: Option<NamedTempFile>,
is_staff: Option<bool>,
}
const CLICKHOUSE_EVENTS_URL_PATH: &'static str = "/api/events";
lazy_static! {
static ref CLICKHOUSE_EVENTS_URL: String =
format!("{}{}", *ZED_SERVER_URL, CLICKHOUSE_EVENTS_URL_PATH);
}
#[derive(Serialize, Debug)]
struct ClickhouseEventRequestBody {
token: &'static str,
installation_id: Option<Arc<str>>,
session_id: Option<Arc<str>>,
is_staff: Option<bool>,
app_version: Option<String>,
os_name: &'static str,
os_version: Option<String>,
architecture: &'static str,
release_channel: Option<&'static str>,
events: Vec<ClickhouseEventWrapper>,
}
#[derive(Serialize, Debug)]
struct ClickhouseEventWrapper {
signed_in: bool,
#[serde(flatten)]
event: ClickhouseEvent,
}
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum AssistantKind {
Panel,
Inline,
}
#[derive(Serialize, Debug)]
#[serde(tag = "type")]
pub enum ClickhouseEvent {
Editor {
operation: &'static str,
file_extension: Option<String>,
vim_mode: bool,
copilot_enabled: bool,
copilot_enabled_for_language: bool,
},
Copilot {
suggestion_id: Option<String>,
suggestion_accepted: bool,
file_extension: Option<String>,
},
Call {
operation: &'static str,
room_id: Option<u64>,
channel_id: Option<u64>,
},
Assistant {
conversation_id: Option<String>,
kind: AssistantKind,
model: &'static str,
},
Cpu {
usage_as_percentage: f32,
core_count: u32,
},
Memory {
memory_in_bytes: u64,
virtual_memory_in_bytes: u64,
},
}
#[cfg(debug_assertions)]
const MAX_QUEUE_LEN: usize = 1;
#[cfg(not(debug_assertions))]
const MAX_QUEUE_LEN: usize = 10;
#[cfg(debug_assertions)]
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(1);
#[cfg(not(debug_assertions))]
const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(30);
impl Telemetry {
pub fn new(client: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
let release_channel = if cx.has_global::<ReleaseChannel>() {
Some(cx.global::<ReleaseChannel>().display_name())
} else {
None
};
// TODO: Replace all hardware stuff with nested SystemSpecs json
let this = Arc::new(Self {
http_client: client,
executor: cx.executor().clone(),
state: Mutex::new(TelemetryState {
app_metadata: cx.app_metadata(),
architecture: env::consts::ARCH,
release_channel,
installation_id: None,
metrics_id: None,
session_id: None,
clickhouse_events_queue: Default::default(),
flush_clickhouse_events_task: Default::default(),
log_file: None,
is_staff: None,
}),
});
this
}
pub fn log_file_path(&self) -> Option<PathBuf> {
Some(self.state.lock().log_file.as_ref()?.path().to_path_buf())
}
pub fn start(
self: &Arc<Self>,
installation_id: Option<String>,
session_id: String,
cx: &mut AppContext,
) {
let mut state = self.state.lock();
state.installation_id = installation_id.map(|id| id.into());
state.session_id = Some(session_id.into());
let has_clickhouse_events = !state.clickhouse_events_queue.is_empty();
drop(state);
if has_clickhouse_events {
self.flush_clickhouse_events();
}
let this = self.clone();
cx.spawn(|cx| async move {
// Avoiding calling `System::new_all()`, as there have been crashes related to it
let refresh_kind = RefreshKind::new()
.with_memory() // For memory usage
.with_processes(ProcessRefreshKind::everything()) // For process usage
.with_cpu(CpuRefreshKind::everything()); // For core count
let mut system = System::new_with_specifics(refresh_kind);
// Avoiding calling `refresh_all()`, just update what we need
system.refresh_specifics(refresh_kind);
loop {
// Waiting some amount of time before the first query is important to get a reasonable value
// https://docs.rs/sysinfo/0.29.10/sysinfo/trait.ProcessExt.html#tymethod.cpu_usage
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
system.refresh_specifics(refresh_kind);
let current_process = Pid::from_u32(std::process::id());
let Some(process) = system.processes().get(&current_process) else {
let process = current_process;
log::error!("Failed to find own process {process:?} in system process table");
// TODO: Fire an error telemetry event
return;
};
let memory_event = ClickhouseEvent::Memory {
memory_in_bytes: process.memory(),
virtual_memory_in_bytes: process.virtual_memory(),
};
let cpu_event = ClickhouseEvent::Cpu {
usage_as_percentage: process.cpu_usage(),
core_count: system.cpus().len() as u32,
};
let telemetry_settings = if let Ok(telemetry_settings) =
cx.update(|cx| *TelemetrySettings::get_global(cx))
{
telemetry_settings
} else {
break;
};
this.report_clickhouse_event(memory_event, telemetry_settings);
this.report_clickhouse_event(cpu_event, telemetry_settings);
}
})
.detach();
}
pub fn set_authenticated_user_info(
self: &Arc<Self>,
metrics_id: Option<String>,
is_staff: bool,
cx: &AppContext,
) {
if !TelemetrySettings::get_global(cx).metrics {
return;
}
let mut state = self.state.lock();
let metrics_id: Option<Arc<str>> = metrics_id.map(|id| id.into());
state.metrics_id = metrics_id.clone();
state.is_staff = Some(is_staff);
drop(state);
}
pub fn report_clickhouse_event(
self: &Arc<Self>,
event: ClickhouseEvent,
telemetry_settings: TelemetrySettings,
) {
if !telemetry_settings.metrics {
return;
}
let mut state = self.state.lock();
let signed_in = state.metrics_id.is_some();
state
.clickhouse_events_queue
.push(ClickhouseEventWrapper { signed_in, event });
if state.installation_id.is_some() {
if state.clickhouse_events_queue.len() >= MAX_QUEUE_LEN {
drop(state);
self.flush_clickhouse_events();
} else {
let this = self.clone();
let executor = self.executor.clone();
state.flush_clickhouse_events_task = Some(self.executor.spawn(async move {
executor.timer(DEBOUNCE_INTERVAL).await;
this.flush_clickhouse_events();
}));
}
}
}
pub fn metrics_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().metrics_id.clone()
}
pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().installation_id.clone()
}
pub fn is_staff(self: &Arc<Self>) -> Option<bool> {
self.state.lock().is_staff
}
fn flush_clickhouse_events(self: &Arc<Self>) {
let mut state = self.state.lock();
let mut events = mem::take(&mut state.clickhouse_events_queue);
state.flush_clickhouse_events_task.take();
drop(state);
let this = self.clone();
self.executor
.spawn(
async move {
let mut json_bytes = Vec::new();
if let Some(file) = &mut this.state.lock().log_file {
let file = file.as_file_mut();
for event in &mut events {
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, event)?;
file.write_all(&json_bytes)?;
file.write(b"\n")?;
}
}
{
let state = this.state.lock();
let request_body = ClickhouseEventRequestBody {
token: ZED_SECRET_CLIENT_TOKEN,
installation_id: state.installation_id.clone(),
session_id: state.session_id.clone(),
is_staff: state.is_staff.clone(),
app_version: state
.app_metadata
.app_version
.map(|version| version.to_string()),
os_name: state.app_metadata.os_name,
os_version: state
.app_metadata
.os_version
.map(|version| version.to_string()),
architecture: state.architecture,
release_channel: state.release_channel,
events,
};
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, &request_body)?;
}
this.http_client
.post_json(CLICKHOUSE_EVENTS_URL.as_str(), json_bytes.into())
.await?;
anyhow::Ok(())
}
.log_err(),
)
.detach();
}
}

216
crates/client2/src/test.rs Normal file
View file

@ -0,0 +1,216 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
use futures::{stream::BoxStream, StreamExt};
use gpui2::{Context, Executor, Model, TestAppContext};
use parking_lot::Mutex;
use rpc2::{
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
ConnectionId, Peer, Receipt, TypedEnvelope,
};
use std::sync::Arc;
use util::http::FakeHttpClient;
pub struct FakeServer {
peer: Arc<Peer>,
state: Arc<Mutex<FakeServerState>>,
user_id: u64,
executor: Executor,
}
#[derive(Default)]
struct FakeServerState {
incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
connection_id: Option<ConnectionId>,
forbid_connections: bool,
auth_count: usize,
access_token: usize,
}
impl FakeServer {
pub async fn for_client(
client_user_id: u64,
client: &Arc<Client>,
cx: &TestAppContext,
) -> Self {
let server = Self {
peer: Peer::new(0),
state: Default::default(),
user_id: client_user_id,
executor: cx.executor().clone(),
};
client
.override_authenticate({
let state = Arc::downgrade(&server.state);
move |cx| {
let state = state.clone();
cx.spawn(move |_| async move {
let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
let mut state = state.lock();
state.auth_count += 1;
let access_token = state.access_token.to_string();
Ok(Credentials {
user_id: client_user_id,
access_token,
})
})
}
})
.override_establish_connection({
let peer = Arc::downgrade(&server.peer);
let state = Arc::downgrade(&server.state);
move |credentials, cx| {
let peer = peer.clone();
let state = state.clone();
let credentials = credentials.clone();
cx.spawn(move |cx| async move {
let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
if state.lock().forbid_connections {
Err(EstablishConnectionError::Other(anyhow!(
"server is forbidding connections"
)))?
}
assert_eq!(credentials.user_id, client_user_id);
if credentials.access_token != state.lock().access_token.to_string() {
Err(EstablishConnectionError::Unauthorized)?
}
let (client_conn, server_conn, _) =
Connection::in_memory(cx.executor().clone());
let (connection_id, io, incoming) =
peer.add_test_connection(server_conn, cx.executor().clone());
cx.executor().spawn(io).detach();
{
let mut state = state.lock();
state.connection_id = Some(connection_id);
state.incoming = Some(incoming);
}
peer.send(
connection_id,
proto::Hello {
peer_id: Some(connection_id.into()),
},
)
.unwrap();
Ok(client_conn)
})
}
});
client
.authenticate_and_connect(false, &cx.to_async())
.await
.unwrap();
server
}
pub fn disconnect(&self) {
if self.state.lock().connection_id.is_some() {
self.peer.disconnect(self.connection_id());
let mut state = self.state.lock();
state.connection_id.take();
state.incoming.take();
}
}
pub fn auth_count(&self) -> usize {
self.state.lock().auth_count
}
pub fn roll_access_token(&self) {
self.state.lock().access_token += 1;
}
pub fn forbid_connections(&self) {
self.state.lock().forbid_connections = true;
}
pub fn allow_connections(&self) {
self.state.lock().forbid_connections = false;
}
pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id(), message).unwrap();
}
#[allow(clippy::await_holding_lock)]
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
self.executor.start_waiting();
loop {
let message = self
.state
.lock()
.incoming
.as_mut()
.expect("not connected")
.next()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
self.executor.finish_waiting();
let type_name = message.payload_type_name();
let message = message.into_any();
if message.is::<TypedEnvelope<M>>() {
return Ok(*message.downcast().unwrap());
}
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
.downcast::<TypedEnvelope<GetPrivateUserInfo>>()
.unwrap()
.receipt(),
GetPrivateUserInfoResponse {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
},
);
continue;
}
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}
}
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
self.peer.respond(receipt, response).unwrap()
}
fn connection_id(&self) -> ConnectionId {
self.state.lock().connection_id.expect("not connected")
}
pub async fn build_user_store(
&self,
client: Arc<Client>,
cx: &mut TestAppContext,
) -> Model<UserStore> {
let http_client = FakeHttpClient::with_404_response();
let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx));
assert_eq!(
self.receive::<proto::GetUsers>()
.await
.unwrap()
.payload
.user_ids,
&[self.user_id]
);
user_store
}
}
impl Drop for FakeServer {
fn drop(&mut self) {
self.disconnect();
}
}

739
crates/client2/src/user.rs Normal file
View file

@ -0,0 +1,739 @@
use super::{proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Context, Result};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags2::FeatureFlagAppExt;
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task};
use postage::{sink::Sink, watch};
use rpc2::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak};
use text::ReplicaId;
use util::http::HttpClient;
use util::TryFutureExt as _;
pub type UserId = u64;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParticipantIndex(pub u32);
#[derive(Default, Debug)]
pub struct User {
pub id: UserId,
pub github_login: String,
pub avatar: Option<Arc<ImageData>>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Collaborator {
pub peer_id: proto::PeerId,
pub replica_id: ReplicaId,
pub user_id: UserId,
}
impl PartialOrd for User {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for User {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.github_login.cmp(&other.github_login)
}
}
impl PartialEq for User {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.github_login == other.github_login
}
}
impl Eq for User {}
#[derive(Debug, PartialEq)]
pub struct Contact {
pub user: Arc<User>,
pub online: bool,
pub busy: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContactRequestStatus {
None,
RequestSent,
RequestReceived,
RequestAccepted,
}
pub struct UserStore {
users: HashMap<u64, Arc<User>>,
participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_user: watch::Receiver<Option<Arc<User>>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
pending_contact_requests: HashMap<u64, usize>,
invite_info: Option<InviteInfo>,
client: Weak<Client>,
http: Arc<dyn HttpClient>,
_maintain_contacts: Task<()>,
_maintain_current_user: Task<Result<()>>,
}
#[derive(Clone)]
pub struct InviteInfo {
pub count: u32,
pub url: Arc<str>,
}
pub enum Event {
Contact {
user: Arc<User>,
kind: ContactEventKind,
},
ShowContacts,
ParticipantIndicesChanged,
}
#[derive(Clone, Copy)]
pub enum ContactEventKind {
Requested,
Accepted,
Cancelled,
}
impl EventEmitter for UserStore {
type Event = Event;
}
enum UpdateContacts {
Update(proto::UpdateContacts),
Wait(postage::barrier::Sender),
Clear(postage::barrier::Sender),
}
impl UserStore {
pub fn new(
client: Arc<Client>,
http: Arc<dyn HttpClient>,
cx: &mut ModelContext<Self>,
) -> Self {
let (mut current_user_tx, current_user_rx) = watch::channel();
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
let rpc_subscriptions = vec![
client.add_message_handler(cx.weak_model(), Self::handle_update_contacts),
client.add_message_handler(cx.weak_model(), Self::handle_update_invite_info),
client.add_message_handler(cx.weak_model(), Self::handle_show_contacts),
];
Self {
users: Default::default(),
current_user: current_user_rx,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
participant_indices: Default::default(),
outgoing_contact_requests: Default::default(),
invite_info: None,
client: Arc::downgrade(&client),
update_contacts_tx,
http,
_maintain_contacts: cx.spawn(|this, mut cx| async move {
let _subscriptions = rpc_subscriptions;
while let Some(message) = update_contacts_rx.next().await {
if let Ok(task) =
this.update(&mut cx, |this, cx| this.update_contacts(message, cx))
{
task.log_err().await;
} else {
break;
}
}
}),
_maintain_current_user: cx.spawn(|this, mut cx| async move {
let mut status = client.status();
while let Some(status) = status.next().await {
match status {
Status::Connected { .. } => {
if let Some(user_id) = client.user_id() {
let fetch_user = if let Ok(fetch_user) = this
.update(&mut cx, |this, cx| {
this.get_user(user_id, cx).log_err()
}) {
fetch_user
} else {
break;
};
let fetch_metrics_id =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
cx.update(|cx| {
if let Some(info) = info {
cx.update_flags(info.staff, info.flags);
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
info.staff,
cx,
)
}
})?;
current_user_tx.send(user).await.ok();
this.update(&mut cx, |_, cx| cx.notify())?;
}
}
Status::SignedOut => {
current_user_tx.send(None).await.ok();
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})?
.await;
}
Status::ConnectionLost => {
this.update(&mut cx, |this, cx| {
cx.notify();
this.clear_contacts()
})?
.await;
}
_ => {}
}
}
Ok(())
}),
pending_contact_requests: Default::default(),
}
}
#[cfg(feature = "test-support")]
pub fn clear_cache(&mut self) {
self.users.clear();
}
async fn handle_update_invite_info(
this: Model<Self>,
message: TypedEnvelope<proto::UpdateInviteInfo>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.invite_info = Some(InviteInfo {
url: Arc::from(message.payload.url),
count: message.payload.count,
});
cx.notify();
})?;
Ok(())
}
async fn handle_show_contacts(
this: Model<Self>,
_: TypedEnvelope<proto::ShowContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts))?;
Ok(())
}
pub fn invite_info(&self) -> Option<&InviteInfo> {
self.invite_info.as_ref()
}
async fn handle_update_contacts(
this: Model<Self>,
message: TypedEnvelope<proto::UpdateContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, _| {
this.update_contacts_tx
.unbounded_send(UpdateContacts::Update(message.payload))
.unwrap();
})?;
Ok(())
}
fn update_contacts(
&mut self,
message: UpdateContacts,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
match message {
UpdateContacts::Wait(barrier) => {
drop(barrier);
Task::ready(Ok(()))
}
UpdateContacts::Clear(barrier) => {
self.contacts.clear();
self.incoming_contact_requests.clear();
self.outgoing_contact_requests.clear();
drop(barrier);
Task::ready(Ok(()))
}
UpdateContacts::Update(message) => {
let mut user_ids = HashSet::default();
for contact in &message.contacts {
user_ids.insert(contact.user_id);
}
user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id));
user_ids.extend(message.outgoing_requests.iter());
let load_users = self.get_users(user_ids.into_iter().collect(), cx);
cx.spawn(|this, mut cx| async move {
load_users.await?;
// Users are fetched in parallel above and cached in call to get_users
// No need to paralellize here
let mut updated_contacts = Vec::new();
let this = this
.upgrade()
.ok_or_else(|| anyhow!("can't upgrade user store handle"))?;
for contact in message.contacts {
let should_notify = contact.should_notify;
updated_contacts.push((
Arc::new(Contact::from_proto(contact, &this, &mut cx).await?),
should_notify,
));
}
let mut incoming_requests = Vec::new();
for request in message.incoming_requests {
incoming_requests.push({
let user = this
.update(&mut cx, |this, cx| {
this.get_user(request.requester_id, cx)
})?
.await?;
(user, request.should_notify)
});
}
let mut outgoing_requests = Vec::new();
for requested_user_id in message.outgoing_requests {
outgoing_requests.push(
this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))?
.await?,
);
}
let removed_contacts =
HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
let removed_incoming_requests =
HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
let removed_outgoing_requests =
HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());
this.update(&mut cx, |this, cx| {
// Remove contacts
this.contacts
.retain(|contact| !removed_contacts.contains(&contact.user.id));
// Update existing contacts and insert new ones
for (updated_contact, should_notify) in updated_contacts {
if should_notify {
cx.emit(Event::Contact {
user: updated_contact.user.clone(),
kind: ContactEventKind::Accepted,
});
}
match this.contacts.binary_search_by_key(
&&updated_contact.user.github_login,
|contact| &contact.user.github_login,
) {
Ok(ix) => this.contacts[ix] = updated_contact,
Err(ix) => this.contacts.insert(ix, updated_contact),
}
}
// Remove incoming contact requests
this.incoming_contact_requests.retain(|user| {
if removed_incoming_requests.contains(&user.id) {
cx.emit(Event::Contact {
user: user.clone(),
kind: ContactEventKind::Cancelled,
});
false
} else {
true
}
});
// Update existing incoming requests and insert new ones
for (user, should_notify) in incoming_requests {
if should_notify {
cx.emit(Event::Contact {
user: user.clone(),
kind: ContactEventKind::Requested,
});
}
match this
.incoming_contact_requests
.binary_search_by_key(&&user.github_login, |contact| {
&contact.github_login
}) {
Ok(ix) => this.incoming_contact_requests[ix] = user,
Err(ix) => this.incoming_contact_requests.insert(ix, user),
}
}
// Remove outgoing contact requests
this.outgoing_contact_requests
.retain(|user| !removed_outgoing_requests.contains(&user.id));
// Update existing incoming requests and insert new ones
for request in outgoing_requests {
match this
.outgoing_contact_requests
.binary_search_by_key(&&request.github_login, |contact| {
&contact.github_login
}) {
Ok(ix) => this.outgoing_contact_requests[ix] = request,
Err(ix) => this.outgoing_contact_requests.insert(ix, request),
}
}
cx.notify();
})?;
Ok(())
})
}
}
}
pub fn contacts(&self) -> &[Arc<Contact>] {
&self.contacts
}
pub fn has_contact(&self, user: &Arc<User>) -> bool {
self.contacts
.binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
.is_ok()
}
pub fn incoming_contact_requests(&self) -> &[Arc<User>] {
&self.incoming_contact_requests
}
pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
&self.outgoing_contact_requests
}
pub fn is_contact_request_pending(&self, user: &User) -> bool {
self.pending_contact_requests.contains_key(&user.id)
}
pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus {
if self
.contacts
.binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
.is_ok()
{
ContactRequestStatus::RequestAccepted
} else if self
.outgoing_contact_requests
.binary_search_by_key(&&user.github_login, |user| &user.github_login)
.is_ok()
{
ContactRequestStatus::RequestSent
} else if self
.incoming_contact_requests
.binary_search_by_key(&&user.github_login, |user| &user.github_login)
.is_ok()
{
ContactRequestStatus::RequestReceived
} else {
ContactRequestStatus::None
}
}
pub fn request_contact(
&mut self,
responder_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx)
}
pub fn remove_contact(
&mut self,
user_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
}
pub fn respond_to_contact_request(
&mut self,
requester_id: u64,
accept: bool,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
self.perform_contact_request(
requester_id,
proto::RespondToContactRequest {
requester_id,
response: if accept {
proto::ContactRequestResponse::Accept
} else {
proto::ContactRequestResponse::Decline
} as i32,
},
cx,
)
}
pub fn dismiss_contact_request(
&mut self,
requester_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.upgrade();
cx.spawn(move |_, _| async move {
client
.ok_or_else(|| anyhow!("can't upgrade client reference"))?
.request(proto::RespondToContactRequest {
requester_id,
response: proto::ContactRequestResponse::Dismiss as i32,
})
.await?;
Ok(())
})
}
fn perform_contact_request<T: RequestMessage>(
&mut self,
user_id: u64,
request: T,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.upgrade();
*self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
cx.notify();
cx.spawn(move |this, mut cx| async move {
let response = client
.ok_or_else(|| anyhow!("can't upgrade client reference"))?
.request(request)
.await;
this.update(&mut cx, |this, cx| {
if let Entry::Occupied(mut request_count) =
this.pending_contact_requests.entry(user_id)
{
*request_count.get_mut() -= 1;
if *request_count.get() == 0 {
request_count.remove();
}
}
cx.notify();
})?;
response?;
Ok(())
})
}
pub fn clear_contacts(&mut self) -> impl Future<Output = ()> {
let (tx, mut rx) = postage::barrier::channel();
self.update_contacts_tx
.unbounded_send(UpdateContacts::Clear(tx))
.unwrap();
async move {
rx.next().await;
}
}
pub fn contact_updates_done(&mut self) -> impl Future<Output = ()> {
let (tx, mut rx) = postage::barrier::channel();
self.update_contacts_tx
.unbounded_send(UpdateContacts::Wait(tx))
.unwrap();
async move {
rx.next().await;
}
}
pub fn get_users(
&mut self,
user_ids: Vec<u64>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
let mut user_ids_to_fetch = user_ids.clone();
user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
cx.spawn(|this, mut cx| async move {
if !user_ids_to_fetch.is_empty() {
this.update(&mut cx, |this, cx| {
this.load_users(
proto::GetUsers {
user_ids: user_ids_to_fetch,
},
cx,
)
})?
.await?;
}
this.update(&mut cx, |this, _| {
user_ids
.iter()
.map(|user_id| {
this.users
.get(user_id)
.cloned()
.ok_or_else(|| anyhow!("user {} not found", user_id))
})
.collect()
})?
})
}
pub fn fuzzy_search_users(
&mut self,
query: String,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
self.load_users(proto::FuzzySearchUsers { query }, cx)
}
pub fn get_cached_user(&self, user_id: u64) -> Option<Arc<User>> {
self.users.get(&user_id).cloned()
}
pub fn get_user(
&mut self,
user_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<Arc<User>>> {
if let Some(user) = self.users.get(&user_id).cloned() {
return Task::ready(Ok(user));
}
let load_users = self.get_users(vec![user_id], cx);
cx.spawn(move |this, mut cx| async move {
load_users.await?;
this.update(&mut cx, |this, _| {
this.users
.get(&user_id)
.cloned()
.ok_or_else(|| anyhow!("server responded with no users"))
})?
})
}
pub fn current_user(&self) -> Option<Arc<User>> {
self.current_user.borrow().clone()
}
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}
fn load_users(
&mut self,
request: impl RequestMessage<Response = UsersResponse>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Arc<User>>>> {
let client = self.client.clone();
let http = self.http.clone();
cx.spawn(|this, mut cx| async move {
if let Some(rpc) = client.upgrade() {
let response = rpc.request(request).await.context("error loading users")?;
let users = future::join_all(
response
.users
.into_iter()
.map(|user| User::new(user, http.as_ref())),
)
.await;
this.update(&mut cx, |this, _| {
for user in &users {
this.users.insert(user.id, user.clone());
}
})
.ok();
Ok(users)
} else {
Ok(Vec::new())
}
})
}
pub fn set_participant_indices(
&mut self,
participant_indices: HashMap<u64, ParticipantIndex>,
cx: &mut ModelContext<Self>,
) {
if participant_indices != self.participant_indices {
self.participant_indices = participant_indices;
cx.emit(Event::ParticipantIndicesChanged);
}
}
pub fn participant_indices(&self) -> &HashMap<u64, ParticipantIndex> {
&self.participant_indices
}
}
impl User {
async fn new(message: proto::User, http: &dyn HttpClient) -> Arc<Self> {
Arc::new(User {
id: message.id,
github_login: message.github_login,
avatar: fetch_avatar(http, &message.avatar_url).warn_on_err().await,
})
}
}
impl Contact {
async fn from_proto(
contact: proto::Contact,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Self> {
let user = user_store
.update(cx, |user_store, cx| {
user_store.get_user(contact.user_id, cx)
})?
.await?;
Ok(Self {
user,
online: contact.online,
busy: contact.busy,
})
}
}
impl Collaborator {
pub fn from_proto(message: proto::Collaborator) -> Result<Self> {
Ok(Self {
peer_id: message.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?,
replica_id: message.replica_id as ReplicaId,
user_id: message.user_id as UserId,
})
}
}
// todo!("we probably don't need this now that we fetch")
async fn fetch_avatar(http: &dyn HttpClient, url: &str) -> Result<Arc<ImageData>> {
let mut response = http
.get(url, Default::default(), true)
.await
.map_err(|e| anyhow!("failed to send user avatar request: {}", e))?;
if !response.status().is_success() {
return Err(anyhow!("avatar request failed {:?}", response.status()));
}
let mut body = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.map_err(|e| anyhow!("failed to read user avatar response body: {}", e))?;
let format = image::guess_format(&body)?;
let image = image::load_from_memory_with_format(&body, format)?.into_bgra8();
Ok(Arc::new(ImageData::new(image)))
}

View file

@ -0,0 +1,50 @@
[package]
name = "copilot2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/copilot2.rs"
doctest = false
[features]
test-support = [
"collections/test-support",
"gpui2/test-support",
"language2/test-support",
"lsp2/test-support",
"settings2/test-support",
"util/test-support",
]
[dependencies]
collections = { path = "../collections" }
context_menu = { path = "../context_menu" }
gpui2 = { path = "../gpui2" }
language2 = { path = "../language2" }
settings2 = { path = "../settings2" }
theme = { path = "../theme" }
lsp2 = { path = "../lsp2" }
node_runtime = { path = "../node_runtime"}
util = { path = "../util" }
async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] }
async-tar = "0.4.2"
anyhow.workspace = true
log.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol.workspace = true
futures.workspace = true
parking_lot.workspace = true
[dev-dependencies]
clock = { path = "../clock" }
collections = { path = "../collections", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] }
language2 = { path = "../language2", features = ["test-support"] }
lsp2 = { path = "../lsp2", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings2 = { path = "../settings2", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,225 @@
use serde::{Deserialize, Serialize};
pub enum CheckStatus {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CheckStatusParams {
pub local_checks_only: bool,
}
impl lsp2::request::Request for CheckStatus {
type Params = CheckStatusParams;
type Result = SignInStatus;
const METHOD: &'static str = "checkStatus";
}
pub enum SignInInitiate {}
#[derive(Debug, Serialize, Deserialize)]
pub struct SignInInitiateParams {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum SignInInitiateResult {
AlreadySignedIn { user: String },
PromptUserDeviceFlow(PromptUserDeviceFlow),
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptUserDeviceFlow {
pub user_code: String,
pub verification_uri: String,
}
impl lsp2::request::Request for SignInInitiate {
type Params = SignInInitiateParams;
type Result = SignInInitiateResult;
const METHOD: &'static str = "signInInitiate";
}
pub enum SignInConfirm {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SignInConfirmParams {
pub user_code: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum SignInStatus {
#[serde(rename = "OK")]
Ok {
user: String,
},
MaybeOk {
user: String,
},
AlreadySignedIn {
user: String,
},
NotAuthorized {
user: String,
},
NotSignedIn,
}
impl lsp2::request::Request for SignInConfirm {
type Params = SignInConfirmParams;
type Result = SignInStatus;
const METHOD: &'static str = "signInConfirm";
}
pub enum SignOut {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SignOutParams {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SignOutResult {}
impl lsp2::request::Request for SignOut {
type Params = SignOutParams;
type Result = SignOutResult;
const METHOD: &'static str = "signOut";
}
pub enum GetCompletions {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetCompletionsParams {
pub doc: GetCompletionsDocument,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetCompletionsDocument {
pub tab_size: u32,
pub indent_size: u32,
pub insert_spaces: bool,
pub uri: lsp2::Url,
pub relative_path: String,
pub position: lsp2::Position,
pub version: usize,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetCompletionsResult {
pub completions: Vec<Completion>,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Completion {
pub text: String,
pub position: lsp2::Position,
pub uuid: String,
pub range: lsp2::Range,
pub display_text: String,
}
impl lsp2::request::Request for GetCompletions {
type Params = GetCompletionsParams;
type Result = GetCompletionsResult;
const METHOD: &'static str = "getCompletions";
}
pub enum GetCompletionsCycling {}
impl lsp2::request::Request for GetCompletionsCycling {
type Params = GetCompletionsParams;
type Result = GetCompletionsResult;
const METHOD: &'static str = "getCompletionsCycling";
}
pub enum LogMessage {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LogMessageParams {
pub level: u8,
pub message: String,
pub metadata_str: String,
pub extra: Vec<String>,
}
impl lsp2::notification::Notification for LogMessage {
type Params = LogMessageParams;
const METHOD: &'static str = "LogMessage";
}
pub enum StatusNotification {}
#[derive(Debug, Serialize, Deserialize)]
pub struct StatusNotificationParams {
pub message: String,
pub status: String, // One of Normal/InProgress
}
impl lsp2::notification::Notification for StatusNotification {
type Params = StatusNotificationParams;
const METHOD: &'static str = "statusNotification";
}
pub enum SetEditorInfo {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SetEditorInfoParams {
pub editor_info: EditorInfo,
pub editor_plugin_info: EditorPluginInfo,
}
impl lsp2::request::Request for SetEditorInfo {
type Params = SetEditorInfoParams;
type Result = String;
const METHOD: &'static str = "setEditorInfo";
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EditorInfo {
pub name: String,
pub version: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EditorPluginInfo {
pub name: String,
pub version: String,
}
pub enum NotifyAccepted {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotifyAcceptedParams {
pub uuid: String,
}
impl lsp2::request::Request for NotifyAccepted {
type Params = NotifyAcceptedParams;
type Result = String;
const METHOD: &'static str = "notifyAccepted";
}
pub enum NotifyRejected {}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NotifyRejectedParams {
pub uuids: Vec<String>,
}
impl lsp2::request::Request for NotifyRejected {
type Params = NotifyRejectedParams;
type Result = String;
const METHOD: &'static str = "notifyRejected";
}

View file

@ -0,0 +1,376 @@
// TODO add logging in
// use crate::{request::PromptUserDeviceFlow, Copilot, Status};
// use gpui::{
// elements::*,
// geometry::rect::RectF,
// platform::{WindowBounds, WindowKind, WindowOptions},
// AnyElement, AnyViewHandle, AppContext, ClipboardItem, Element, Entity, View, ViewContext,
// WindowHandle,
// };
// use theme::ui::modal;
// #[derive(PartialEq, Eq, Debug, Clone)]
// struct CopyUserCode;
// #[derive(PartialEq, Eq, Debug, Clone)]
// struct OpenGithub;
// const COPILOT_SIGN_UP_URL: &'static str = "https://github.com/features/copilot";
// pub fn init(cx: &mut AppContext) {
// if let Some(copilot) = Copilot::global(cx) {
// let mut verification_window: Option<WindowHandle<CopilotCodeVerification>> = None;
// cx.observe(&copilot, move |copilot, cx| {
// let status = copilot.read(cx).status();
// match &status {
// crate::Status::SigningIn { prompt } => {
// if let Some(window) = verification_window.as_mut() {
// let updated = window
// .root(cx)
// .map(|root| {
// root.update(cx, |verification, cx| {
// verification.set_status(status.clone(), cx);
// cx.activate_window();
// })
// })
// .is_some();
// if !updated {
// verification_window = Some(create_copilot_auth_window(cx, &status));
// }
// } else if let Some(_prompt) = prompt {
// verification_window = Some(create_copilot_auth_window(cx, &status));
// }
// }
// Status::Authorized | Status::Unauthorized => {
// if let Some(window) = verification_window.as_ref() {
// if let Some(verification) = window.root(cx) {
// verification.update(cx, |verification, cx| {
// verification.set_status(status, cx);
// cx.platform().activate(true);
// cx.activate_window();
// });
// }
// }
// }
// _ => {
// if let Some(code_verification) = verification_window.take() {
// code_verification.update(cx, |cx| cx.remove_window());
// }
// }
// }
// })
// .detach();
// }
// }
// fn create_copilot_auth_window(
// cx: &mut AppContext,
// status: &Status,
// ) -> WindowHandle<CopilotCodeVerification> {
// let window_size = theme::current(cx).copilot.modal.dimensions();
// let window_options = WindowOptions {
// bounds: WindowBounds::Fixed(RectF::new(Default::default(), window_size)),
// titlebar: None,
// center: true,
// focus: true,
// show: true,
// kind: WindowKind::Normal,
// is_movable: true,
// screen: None,
// };
// cx.add_window(window_options, |_cx| {
// CopilotCodeVerification::new(status.clone())
// })
// }
// pub struct CopilotCodeVerification {
// status: Status,
// connect_clicked: bool,
// }
// impl CopilotCodeVerification {
// pub fn new(status: Status) -> Self {
// Self {
// status,
// connect_clicked: false,
// }
// }
// pub fn set_status(&mut self, status: Status, cx: &mut ViewContext<Self>) {
// self.status = status;
// cx.notify();
// }
// fn render_device_code(
// data: &PromptUserDeviceFlow,
// style: &theme::Copilot,
// cx: &mut ViewContext<Self>,
// ) -> impl IntoAnyElement<Self> {
// let copied = cx
// .read_from_clipboard()
// .map(|item| item.text() == &data.user_code)
// .unwrap_or(false);
// let device_code_style = &style.auth.prompting.device_code;
// MouseEventHandler::new::<Self, _>(0, cx, |state, _cx| {
// Flex::row()
// .with_child(
// Label::new(data.user_code.clone(), device_code_style.text.clone())
// .aligned()
// .contained()
// .with_style(device_code_style.left_container)
// .constrained()
// .with_width(device_code_style.left),
// )
// .with_child(
// Label::new(
// if copied { "Copied!" } else { "Copy" },
// device_code_style.cta.style_for(state).text.clone(),
// )
// .aligned()
// .contained()
// .with_style(*device_code_style.right_container.style_for(state))
// .constrained()
// .with_width(device_code_style.right),
// )
// .contained()
// .with_style(device_code_style.cta.style_for(state).container)
// })
// .on_click(gpui::platform::MouseButton::Left, {
// let user_code = data.user_code.clone();
// move |_, _, cx| {
// cx.platform()
// .write_to_clipboard(ClipboardItem::new(user_code.clone()));
// cx.notify();
// }
// })
// .with_cursor_style(gpui::platform::CursorStyle::PointingHand)
// }
// fn render_prompting_modal(
// connect_clicked: bool,
// data: &PromptUserDeviceFlow,
// style: &theme::Copilot,
// cx: &mut ViewContext<Self>,
// ) -> AnyElement<Self> {
// enum ConnectButton {}
// Flex::column()
// .with_child(
// Flex::column()
// .with_children([
// Label::new(
// "Enable Copilot by connecting",
// style.auth.prompting.subheading.text.clone(),
// )
// .aligned(),
// Label::new(
// "your existing license.",
// style.auth.prompting.subheading.text.clone(),
// )
// .aligned(),
// ])
// .align_children_center()
// .contained()
// .with_style(style.auth.prompting.subheading.container),
// )
// .with_child(Self::render_device_code(data, &style, cx))
// .with_child(
// Flex::column()
// .with_children([
// Label::new(
// "Paste this code into GitHub after",
// style.auth.prompting.hint.text.clone(),
// )
// .aligned(),
// Label::new(
// "clicking the button below.",
// style.auth.prompting.hint.text.clone(),
// )
// .aligned(),
// ])
// .align_children_center()
// .contained()
// .with_style(style.auth.prompting.hint.container.clone()),
// )
// .with_child(theme::ui::cta_button::<ConnectButton, _, _, _>(
// if connect_clicked {
// "Waiting for connection..."
// } else {
// "Connect to GitHub"
// },
// style.auth.content_width,
// &style.auth.cta_button,
// cx,
// {
// let verification_uri = data.verification_uri.clone();
// move |_, verification, cx| {
// cx.platform().open_url(&verification_uri);
// verification.connect_clicked = true;
// }
// },
// ))
// .align_children_center()
// .into_any()
// }
// fn render_enabled_modal(
// style: &theme::Copilot,
// cx: &mut ViewContext<Self>,
// ) -> AnyElement<Self> {
// enum DoneButton {}
// let enabled_style = &style.auth.authorized;
// Flex::column()
// .with_child(
// Label::new("Copilot Enabled!", enabled_style.subheading.text.clone())
// .contained()
// .with_style(enabled_style.subheading.container)
// .aligned(),
// )
// .with_child(
// Flex::column()
// .with_children([
// Label::new(
// "You can update your settings or",
// enabled_style.hint.text.clone(),
// )
// .aligned(),
// Label::new(
// "sign out from the Copilot menu in",
// enabled_style.hint.text.clone(),
// )
// .aligned(),
// Label::new("the status bar.", enabled_style.hint.text.clone()).aligned(),
// ])
// .align_children_center()
// .contained()
// .with_style(enabled_style.hint.container),
// )
// .with_child(theme::ui::cta_button::<DoneButton, _, _, _>(
// "Done",
// style.auth.content_width,
// &style.auth.cta_button,
// cx,
// |_, _, cx| cx.remove_window(),
// ))
// .align_children_center()
// .into_any()
// }
// fn render_unauthorized_modal(
// style: &theme::Copilot,
// cx: &mut ViewContext<Self>,
// ) -> AnyElement<Self> {
// let unauthorized_style = &style.auth.not_authorized;
// Flex::column()
// .with_child(
// Flex::column()
// .with_children([
// Label::new(
// "Enable Copilot by connecting",
// unauthorized_style.subheading.text.clone(),
// )
// .aligned(),
// Label::new(
// "your existing license.",
// unauthorized_style.subheading.text.clone(),
// )
// .aligned(),
// ])
// .align_children_center()
// .contained()
// .with_style(unauthorized_style.subheading.container),
// )
// .with_child(
// Flex::column()
// .with_children([
// Label::new(
// "You must have an active copilot",
// unauthorized_style.warning.text.clone(),
// )
// .aligned(),
// Label::new(
// "license to use it in Zed.",
// unauthorized_style.warning.text.clone(),
// )
// .aligned(),
// ])
// .align_children_center()
// .contained()
// .with_style(unauthorized_style.warning.container),
// )
// .with_child(theme::ui::cta_button::<Self, _, _, _>(
// "Subscribe on GitHub",
// style.auth.content_width,
// &style.auth.cta_button,
// cx,
// |_, _, cx| {
// cx.remove_window();
// cx.platform().open_url(COPILOT_SIGN_UP_URL)
// },
// ))
// .align_children_center()
// .into_any()
// }
// }
// impl Entity for CopilotCodeVerification {
// type Event = ();
// }
// impl View for CopilotCodeVerification {
// fn ui_name() -> &'static str {
// "CopilotCodeVerification"
// }
// fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
// cx.notify()
// }
// fn focus_out(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
// cx.notify()
// }
// fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
// enum ConnectModal {}
// let style = theme::current(cx).clone();
// modal::<ConnectModal, _, _, _, _>(
// "Connect Copilot to Zed",
// &style.copilot.modal,
// cx,
// |cx| {
// Flex::column()
// .with_children([
// theme::ui::icon(&style.copilot.auth.header).into_any(),
// match &self.status {
// Status::SigningIn {
// prompt: Some(prompt),
// } => Self::render_prompting_modal(
// self.connect_clicked,
// &prompt,
// &style.copilot,
// cx,
// ),
// Status::Unauthorized => {
// self.connect_clicked = false;
// Self::render_unauthorized_modal(&style.copilot, cx)
// }
// Status::Authorized => {
// self.connect_clicked = false;
// Self::render_enabled_modal(&style.copilot, cx)
// }
// _ => Empty::new().into_any(),
// },
// ])
// .align_children_center()
// },
// )
// .into_any()
// }
// }

33
crates/db2/Cargo.toml Normal file
View file

@ -0,0 +1,33 @@
[package]
name = "db2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/db2.rs"
doctest = false
[features]
test-support = []
[dependencies]
collections = { path = "../collections" }
gpui2 = { path = "../gpui2" }
sqlez = { path = "../sqlez" }
sqlez_macros = { path = "../sqlez_macros" }
util = { path = "../util" }
anyhow.workspace = true
indoc.workspace = true
async-trait.workspace = true
lazy_static.workspace = true
log.workspace = true
parking_lot.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol.workspace = true
[dev-dependencies]
gpui2 = { path = "../gpui2", features = ["test-support"] }
env_logger.workspace = true
tempdir.workspace = true

5
crates/db2/README.md Normal file
View file

@ -0,0 +1,5 @@
# Building Queries
First, craft your test data. The examples folder shows a template for building a test-db, and can be ran with `cargo run --example [your-example]`.
To actually use and test your queries, import the generated DB file into https://sqliteonline.com/

327
crates/db2/src/db2.rs Normal file
View file

@ -0,0 +1,327 @@
pub mod kvp;
pub mod query;
// Re-export
pub use anyhow;
use anyhow::Context;
use gpui2::AppContext;
pub use indoc::indoc;
pub use lazy_static;
pub use smol;
pub use sqlez;
pub use sqlez_macros;
pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
pub use util::paths::DB_DIR;
use sqlez::domain::Migrator;
use sqlez::thread_safe_connection::ThreadSafeConnection;
use sqlez_macros::sql;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use util::channel::ReleaseChannel;
use util::{async_maybe, ResultExt};
const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
PRAGMA foreign_keys=TRUE;
);
const DB_INITIALIZE_QUERY: &'static str = sql!(
PRAGMA journal_mode=WAL;
PRAGMA busy_timeout=1;
PRAGMA case_sensitive_like=TRUE;
PRAGMA synchronous=NORMAL;
);
const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
const DB_FILE_NAME: &'static str = "db.sqlite";
lazy_static::lazy_static! {
pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
}
/// Open or create a database at the given directory path.
/// This will retry a couple times if there are failures. If opening fails once, the db directory
/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
/// In either case, static variables are set so that the user can be notified.
pub async fn open_db<M: Migrator + 'static>(
db_dir: &Path,
release_channel: &ReleaseChannel,
) -> ThreadSafeConnection<M> {
if *ZED_STATELESS {
return open_fallback_db().await;
}
let release_channel_name = release_channel.dev_name();
let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
let connection = async_maybe!({
smol::fs::create_dir_all(&main_db_dir)
.await
.context("Could not create db directory")
.log_err()?;
let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
open_main_db(&db_path).await
})
.await;
if let Some(connection) = connection {
return connection;
}
// Set another static ref so that we can escalate the notification
ALL_FILE_DB_FAILED.store(true, Ordering::Release);
// If still failed, create an in memory db with a known name
open_fallback_db().await
}
async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
log::info!("Opening main db");
ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
.build()
.await
.log_err()
}
async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
log::info!("Opening fallback db");
ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
.build()
.await
.expect(
"Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
)
}
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
use sqlez::thread_safe_connection::locking_queue;
ThreadSafeConnection::<M>::builder(db_name, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
// Serialize queued writes via a mutex and run them synchronously
.with_write_queue_constructor(locking_queue())
.build()
.await
.unwrap()
}
/// Implements a basic DB wrapper for a given domain
#[macro_export]
macro_rules! define_connection {
(pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
}
#[cfg(not(any(test, feature = "test-support")))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
}
};
(pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
}
#[cfg(not(any(test, feature = "test-support")))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
}
};
}
pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
where
F: Future<Output = anyhow::Result<()>> + Send,
{
cx.executor()
.spawn(async move { db_write().await.log_err() })
.detach()
}
// #[cfg(test)]
// mod tests {
// use std::thread;
// use sqlez::domain::Domain;
// use sqlez_macros::sql;
// use tempdir::TempDir;
// use crate::open_db;
// // Test bad migration panics
// #[gpui::test]
// #[should_panic]
// async fn test_bad_migration_panics() {
// enum BadDB {}
// impl Domain for BadDB {
// fn name() -> &'static str {
// "db_tests"
// }
// fn migrations() -> &'static [&'static str] {
// &[
// sql!(CREATE TABLE test(value);),
// // failure because test already exists
// sql!(CREATE TABLE test(value);),
// ]
// }
// }
// let tempdir = TempDir::new("DbTests").unwrap();
// let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
// }
// /// Test that DB exists but corrupted (causing recreate)
// #[gpui::test]
// async fn test_db_corruption() {
// enum CorruptedDB {}
// impl Domain for CorruptedDB {
// fn name() -> &'static str {
// "db_tests"
// }
// fn migrations() -> &'static [&'static str] {
// &[sql!(CREATE TABLE test(value);)]
// }
// }
// enum GoodDB {}
// impl Domain for GoodDB {
// fn name() -> &'static str {
// "db_tests" //Notice same name
// }
// fn migrations() -> &'static [&'static str] {
// &[sql!(CREATE TABLE test2(value);)] //But different migration
// }
// }
// let tempdir = TempDir::new("DbTests").unwrap();
// {
// let corrupt_db =
// open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
// assert!(corrupt_db.persistent());
// }
// let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
// assert!(
// good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
// .unwrap()
// .is_none()
// );
// }
// /// Test that DB exists but corrupted (causing recreate)
// #[gpui::test(iterations = 30)]
// async fn test_simultaneous_db_corruption() {
// enum CorruptedDB {}
// impl Domain for CorruptedDB {
// fn name() -> &'static str {
// "db_tests"
// }
// fn migrations() -> &'static [&'static str] {
// &[sql!(CREATE TABLE test(value);)]
// }
// }
// enum GoodDB {}
// impl Domain for GoodDB {
// fn name() -> &'static str {
// "db_tests" //Notice same name
// }
// fn migrations() -> &'static [&'static str] {
// &[sql!(CREATE TABLE test2(value);)] //But different migration
// }
// }
// let tempdir = TempDir::new("DbTests").unwrap();
// {
// // Setup the bad database
// let corrupt_db =
// open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
// assert!(corrupt_db.persistent());
// }
// // Try to connect to it a bunch of times at once
// let mut guards = vec![];
// for _ in 0..10 {
// let tmp_path = tempdir.path().to_path_buf();
// let guard = thread::spawn(move || {
// let good_db = smol::block_on(open_db::<GoodDB>(
// tmp_path.as_path(),
// &util::channel::ReleaseChannel::Dev,
// ));
// assert!(
// good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
// .unwrap()
// .is_none()
// );
// });
// guards.push(guard);
// }
// for guard in guards.into_iter() {
// assert!(guard.join().is_ok());
// }
// }
// }

62
crates/db2/src/kvp.rs Normal file
View file

@ -0,0 +1,62 @@
use sqlez_macros::sql;
use crate::{define_connection, query};
define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
&[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)];
);
impl KeyValueStore {
query! {
pub fn read_kvp(key: &str) -> Result<Option<String>> {
SELECT value FROM kv_store WHERE key = (?)
}
}
query! {
pub async fn write_kvp(key: String, value: String) -> Result<()> {
INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))
}
}
query! {
pub async fn delete_kvp(key: String) -> Result<()> {
DELETE FROM kv_store WHERE key = (?)
}
}
}
// #[cfg(test)]
// mod tests {
// use crate::kvp::KeyValueStore;
// #[gpui::test]
// async fn test_kvp() {
// let db = KeyValueStore(crate::open_test_db("test_kvp").await);
// assert_eq!(db.read_kvp("key-1").unwrap(), None);
// db.write_kvp("key-1".to_string(), "one".to_string())
// .await
// .unwrap();
// assert_eq!(db.read_kvp("key-1").unwrap(), Some("one".to_string()));
// db.write_kvp("key-1".to_string(), "one-2".to_string())
// .await
// .unwrap();
// assert_eq!(db.read_kvp("key-1").unwrap(), Some("one-2".to_string()));
// db.write_kvp("key-2".to_string(), "two".to_string())
// .await
// .unwrap();
// assert_eq!(db.read_kvp("key-2").unwrap(), Some("two".to_string()));
// db.delete_kvp("key-1".to_string()).await.unwrap();
// assert_eq!(db.read_kvp("key-1").unwrap(), None);
// }
// }

314
crates/db2/src/query.rs Normal file
View file

@ -0,0 +1,314 @@
#[macro_export]
macro_rules! query {
($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.exec(sql_stmt)?().context(::std::format!(
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt,
))
}
};
($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec(sql_stmt)?().context(::std::format!(
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<$arg_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
pub async fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(indoc! { $sql })?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
}

View file

@ -0,0 +1,12 @@
[package]
name = "feature_flags2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/feature_flags2.rs"
[dependencies]
gpui2 = { path = "../gpui2" }
anyhow.workspace = true

View file

@ -0,0 +1,80 @@
use gpui2::{AppContext, Subscription, ViewContext};
#[derive(Default)]
struct FeatureFlags {
flags: Vec<String>,
staff: bool,
}
impl FeatureFlags {
fn has_flag(&self, flag: &str) -> bool {
self.staff || self.flags.iter().find(|f| f.as_str() == flag).is_some()
}
}
pub trait FeatureFlag {
const NAME: &'static str;
}
pub enum ChannelsAlpha {}
impl FeatureFlag for ChannelsAlpha {
const NAME: &'static str = "channels_alpha";
}
pub trait FeatureFlagViewExt<V: 'static> {
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
F: Fn(bool, &mut V, &mut ViewContext<V>) + Send + Sync + 'static;
}
impl<V> FeatureFlagViewExt<V> for ViewContext<'_, '_, V>
where
V: 'static + Send + Sync,
{
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
F: Fn(bool, &mut V, &mut ViewContext<V>) + Send + Sync + 'static,
{
self.observe_global::<FeatureFlags>(move |v, cx| {
let feature_flags = cx.global::<FeatureFlags>();
callback(feature_flags.has_flag(<T as FeatureFlag>::NAME), v, cx);
})
}
}
pub trait FeatureFlagAppExt {
fn update_flags(&mut self, staff: bool, flags: Vec<String>);
fn set_staff(&mut self, staff: bool);
fn has_flag<T: FeatureFlag>(&self) -> bool;
fn is_staff(&self) -> bool;
}
impl FeatureFlagAppExt for AppContext {
fn update_flags(&mut self, staff: bool, flags: Vec<String>) {
let feature_flags = self.default_global::<FeatureFlags>();
feature_flags.staff = staff;
feature_flags.flags = flags;
}
fn set_staff(&mut self, staff: bool) {
let feature_flags = self.default_global::<FeatureFlags>();
feature_flags.staff = staff;
}
fn has_flag<T: FeatureFlag>(&self) -> bool {
if self.has_global::<FeatureFlags>() {
self.global::<FeatureFlags>().has_flag(T::NAME)
} else {
false
}
}
fn is_staff(&self) -> bool {
if self.has_global::<FeatureFlags>() {
return self.global::<FeatureFlags>().staff;
} else {
false
}
}
}

40
crates/fs2/Cargo.toml Normal file
View file

@ -0,0 +1,40 @@
[package]
name = "fs2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/fs2.rs"
[dependencies]
collections = { path = "../collections" }
rope = { path = "../rope" }
text = { path = "../text" }
util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
anyhow.workspace = true
async-trait.workspace = true
futures.workspace = true
tempfile = "3"
fsevent = { path = "../fsevent" }
lazy_static.workspace = true
parking_lot.workspace = true
smol.workspace = true
regex.workspace = true
git2.workspace = true
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
log.workspace = true
libc = "0.2"
time.workspace = true
gpui2 = { path = "../gpui2", optional = true}
[dev-dependencies]
gpui2 = { path = "../gpui2", features = ["test-support"] }
[features]
test-support = ["gpui2/test-support"]

1278
crates/fs2/src/fs2.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,417 @@
use anyhow::Result;
use collections::HashMap;
use git2::{BranchType, StatusShow};
use parking_lot::Mutex;
use serde_derive::{Deserialize, Serialize};
use std::{
cmp::Ordering,
ffi::OsStr,
os::unix::prelude::OsStrExt,
path::{Component, Path, PathBuf},
sync::Arc,
time::SystemTime,
};
use sum_tree::{MapSeekTarget, TreeMap};
use util::ResultExt;
pub use git2::Repository as LibGitRepository;
#[derive(Clone, Debug, Hash, PartialEq)]
pub struct Branch {
pub name: Box<str>,
/// Timestamp of most recent commit, normalized to Unix Epoch format.
pub unix_timestamp: Option<i64>,
}
#[async_trait::async_trait]
pub trait GitRepository: Send {
fn reload_index(&self);
fn load_index_text(&self, relative_file_path: &Path) -> Option<String>;
fn branch_name(&self) -> Option<String>;
/// Get the statuses of all of the files in the index that start with the given
/// path and have changes with resepect to the HEAD commit. This is fast because
/// the index stores hashes of trees, so that unchanged directories can be skipped.
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus>;
/// Get the status of a given file in the working directory with respect to
/// the index. In the common case, when there are no changes, this only requires
/// an index lookup. The index stores the mtime of each file when it was added,
/// so there's no work to do if the mtime matches.
fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus>;
/// Get the status of a given file in the working directory with respect to
/// the HEAD commit. In the common case, when there are no changes, this only
/// requires an index lookup and blob comparison between the index and the HEAD
/// commit. The index stores the mtime of each file when it was added, so there's
/// no need to consider the working directory file if the mtime matches.
fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus>;
fn branches(&self) -> Result<Vec<Branch>>;
fn change_branch(&self, _: &str) -> Result<()>;
fn create_branch(&self, _: &str) -> Result<()>;
}
impl std::fmt::Debug for dyn GitRepository {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("dyn GitRepository<...>").finish()
}
}
impl GitRepository for LibGitRepository {
fn reload_index(&self) {
if let Ok(mut index) = self.index() {
_ = index.read(false);
}
}
fn load_index_text(&self, relative_file_path: &Path) -> Option<String> {
fn logic(repo: &LibGitRepository, relative_file_path: &Path) -> Result<Option<String>> {
const STAGE_NORMAL: i32 = 0;
let index = repo.index()?;
// This check is required because index.get_path() unwraps internally :(
check_path_to_repo_path_errors(relative_file_path)?;
let oid = match index.get_path(&relative_file_path, STAGE_NORMAL) {
Some(entry) => entry.id,
None => return Ok(None),
};
let content = repo.find_blob(oid)?.content().to_owned();
Ok(Some(String::from_utf8(content)?))
}
match logic(&self, relative_file_path) {
Ok(value) => return value,
Err(err) => log::error!("Error loading head text: {:?}", err),
}
None
}
fn branch_name(&self) -> Option<String> {
let head = self.head().log_err()?;
let branch = String::from_utf8_lossy(head.shorthand_bytes());
Some(branch.to_string())
}
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus> {
let mut map = TreeMap::default();
let mut options = git2::StatusOptions::new();
options.pathspec(path_prefix);
options.show(StatusShow::Index);
if let Some(statuses) = self.statuses(Some(&mut options)).log_err() {
for status in statuses.iter() {
let path = RepoPath(PathBuf::from(OsStr::from_bytes(status.path_bytes())));
let status = status.status();
if !status.contains(git2::Status::IGNORED) {
if let Some(status) = read_status(status) {
map.insert(path, status)
}
}
}
}
map
}
fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus> {
// If the file has not changed since it was added to the index, then
// there can't be any changes.
if matches_index(self, path, mtime) {
return None;
}
let mut options = git2::StatusOptions::new();
options.pathspec(&path.0);
options.disable_pathspec_match(true);
options.include_untracked(true);
options.recurse_untracked_dirs(true);
options.include_unmodified(true);
options.show(StatusShow::Workdir);
let statuses = self.statuses(Some(&mut options)).log_err()?;
let status = statuses.get(0).and_then(|s| read_status(s.status()));
status
}
fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option<GitFileStatus> {
let mut options = git2::StatusOptions::new();
options.pathspec(&path.0);
options.disable_pathspec_match(true);
options.include_untracked(true);
options.recurse_untracked_dirs(true);
options.include_unmodified(true);
// If the file has not changed since it was added to the index, then
// there's no need to examine the working directory file: just compare
// the blob in the index to the one in the HEAD commit.
if matches_index(self, path, mtime) {
options.show(StatusShow::Index);
}
let statuses = self.statuses(Some(&mut options)).log_err()?;
let status = statuses.get(0).and_then(|s| read_status(s.status()));
status
}
fn branches(&self) -> Result<Vec<Branch>> {
let local_branches = self.branches(Some(BranchType::Local))?;
let valid_branches = local_branches
.filter_map(|branch| {
branch.ok().and_then(|(branch, _)| {
let name = branch.name().ok().flatten().map(Box::from)?;
let timestamp = branch.get().peel_to_commit().ok()?.time();
let unix_timestamp = timestamp.seconds();
let timezone_offset = timestamp.offset_minutes();
let utc_offset =
time::UtcOffset::from_whole_seconds(timezone_offset * 60).ok()?;
let unix_timestamp =
time::OffsetDateTime::from_unix_timestamp(unix_timestamp).ok()?;
Some(Branch {
name,
unix_timestamp: Some(unix_timestamp.to_offset(utc_offset).unix_timestamp()),
})
})
})
.collect();
Ok(valid_branches)
}
fn change_branch(&self, name: &str) -> Result<()> {
let revision = self.find_branch(name, BranchType::Local)?;
let revision = revision.get();
let as_tree = revision.peel_to_tree()?;
self.checkout_tree(as_tree.as_object(), None)?;
self.set_head(
revision
.name()
.ok_or_else(|| anyhow::anyhow!("Branch name could not be retrieved"))?,
)?;
Ok(())
}
fn create_branch(&self, name: &str) -> Result<()> {
let current_commit = self.head()?.peel_to_commit()?;
self.branch(name, &current_commit, false)?;
Ok(())
}
}
fn matches_index(repo: &LibGitRepository, path: &RepoPath, mtime: SystemTime) -> bool {
if let Some(index) = repo.index().log_err() {
if let Some(entry) = index.get_path(&path, 0) {
if let Some(mtime) = mtime.duration_since(SystemTime::UNIX_EPOCH).log_err() {
if entry.mtime.seconds() == mtime.as_secs() as i32
&& entry.mtime.nanoseconds() == mtime.subsec_nanos()
{
return true;
}
}
}
}
false
}
fn read_status(status: git2::Status) -> Option<GitFileStatus> {
if status.contains(git2::Status::CONFLICTED) {
Some(GitFileStatus::Conflict)
} else if status.intersects(
git2::Status::WT_MODIFIED
| git2::Status::WT_RENAMED
| git2::Status::INDEX_MODIFIED
| git2::Status::INDEX_RENAMED,
) {
Some(GitFileStatus::Modified)
} else if status.intersects(git2::Status::WT_NEW | git2::Status::INDEX_NEW) {
Some(GitFileStatus::Added)
} else {
None
}
}
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepository {
state: Arc<Mutex<FakeGitRepositoryState>>,
}
#[derive(Debug, Clone, Default)]
pub struct FakeGitRepositoryState {
pub index_contents: HashMap<PathBuf, String>,
pub worktree_statuses: HashMap<RepoPath, GitFileStatus>,
pub branch_name: Option<String>,
}
impl FakeGitRepository {
pub fn open(state: Arc<Mutex<FakeGitRepositoryState>>) -> Arc<Mutex<dyn GitRepository>> {
Arc::new(Mutex::new(FakeGitRepository { state }))
}
}
#[async_trait::async_trait]
impl GitRepository for FakeGitRepository {
fn reload_index(&self) {}
fn load_index_text(&self, path: &Path) -> Option<String> {
let state = self.state.lock();
state.index_contents.get(path).cloned()
}
fn branch_name(&self) -> Option<String> {
let state = self.state.lock();
state.branch_name.clone()
}
fn staged_statuses(&self, path_prefix: &Path) -> TreeMap<RepoPath, GitFileStatus> {
let mut map = TreeMap::default();
let state = self.state.lock();
for (repo_path, status) in state.worktree_statuses.iter() {
if repo_path.0.starts_with(path_prefix) {
map.insert(repo_path.to_owned(), status.to_owned());
}
}
map
}
fn unstaged_status(&self, _path: &RepoPath, _mtime: SystemTime) -> Option<GitFileStatus> {
None
}
fn status(&self, path: &RepoPath, _mtime: SystemTime) -> Option<GitFileStatus> {
let state = self.state.lock();
state.worktree_statuses.get(path).cloned()
}
fn branches(&self) -> Result<Vec<Branch>> {
Ok(vec![])
}
fn change_branch(&self, name: &str) -> Result<()> {
let mut state = self.state.lock();
state.branch_name = Some(name.to_owned());
Ok(())
}
fn create_branch(&self, name: &str) -> Result<()> {
let mut state = self.state.lock();
state.branch_name = Some(name.to_owned());
Ok(())
}
}
fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
match relative_file_path.components().next() {
None => anyhow::bail!("repo path should not be empty"),
Some(Component::Prefix(_)) => anyhow::bail!(
"repo path `{}` should be relative, not a windows prefix",
relative_file_path.to_string_lossy()
),
Some(Component::RootDir) => {
anyhow::bail!(
"repo path `{}` should be relative",
relative_file_path.to_string_lossy()
)
}
Some(Component::CurDir) => {
anyhow::bail!(
"repo path `{}` should not start with `.`",
relative_file_path.to_string_lossy()
)
}
Some(Component::ParentDir) => {
anyhow::bail!(
"repo path `{}` should not start with `..`",
relative_file_path.to_string_lossy()
)
}
_ => Ok(()),
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GitFileStatus {
Added,
Modified,
Conflict,
}
impl GitFileStatus {
pub fn merge(
this: Option<GitFileStatus>,
other: Option<GitFileStatus>,
prefer_other: bool,
) -> Option<GitFileStatus> {
if prefer_other {
return other;
} else {
match (this, other) {
(Some(GitFileStatus::Conflict), _) | (_, Some(GitFileStatus::Conflict)) => {
Some(GitFileStatus::Conflict)
}
(Some(GitFileStatus::Modified), _) | (_, Some(GitFileStatus::Modified)) => {
Some(GitFileStatus::Modified)
}
(Some(GitFileStatus::Added), _) | (_, Some(GitFileStatus::Added)) => {
Some(GitFileStatus::Added)
}
_ => None,
}
}
}
}
#[derive(Clone, Debug, Ord, Hash, PartialOrd, Eq, PartialEq)]
pub struct RepoPath(pub PathBuf);
impl RepoPath {
pub fn new(path: PathBuf) -> Self {
debug_assert!(path.is_relative(), "Repo paths must be relative");
RepoPath(path)
}
}
impl From<&Path> for RepoPath {
fn from(value: &Path) -> Self {
RepoPath::new(value.to_path_buf())
}
}
impl From<PathBuf> for RepoPath {
fn from(value: PathBuf) -> Self {
RepoPath::new(value)
}
}
impl Default for RepoPath {
fn default() -> Self {
RepoPath(PathBuf::new())
}
}
impl AsRef<Path> for RepoPath {
fn as_ref(&self) -> &Path {
self.0.as_ref()
}
}
impl std::ops::Deref for RepoPath {
type Target = PathBuf;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug)]
pub struct RepoPathDescendants<'a>(pub &'a Path);
impl<'a> MapSeekTarget<RepoPath> for RepoPathDescendants<'a> {
fn cmp_cursor(&self, key: &RepoPath) -> Ordering {
if key.starts_with(&self.0) {
Ordering::Greater
} else {
self.0.cmp(key)
}
}
}

13
crates/fuzzy2/Cargo.toml Normal file
View file

@ -0,0 +1,13 @@
[package]
name = "fuzzy2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/fuzzy2.rs"
doctest = false
[dependencies]
gpui2 = { path = "../gpui2" }
util = { path = "../util" }

View file

@ -0,0 +1,63 @@
use std::iter::FromIterator;
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct CharBag(u64);
impl CharBag {
pub fn is_superset(self, other: CharBag) -> bool {
self.0 & other.0 == other.0
}
fn insert(&mut self, c: char) {
let c = c.to_ascii_lowercase();
if ('a'..='z').contains(&c) {
let mut count = self.0;
let idx = c as u8 - b'a';
count >>= idx * 2;
count = ((count << 1) | 1) & 3;
count <<= idx * 2;
self.0 |= count;
} else if ('0'..='9').contains(&c) {
let idx = c as u8 - b'0';
self.0 |= 1 << (idx + 52);
} else if c == '-' {
self.0 |= 1 << 62;
}
}
}
impl Extend<char> for CharBag {
fn extend<T: IntoIterator<Item = char>>(&mut self, iter: T) {
for c in iter {
self.insert(c);
}
}
}
impl FromIterator<char> for CharBag {
fn from_iter<T: IntoIterator<Item = char>>(iter: T) -> Self {
let mut result = Self::default();
result.extend(iter);
result
}
}
impl From<&str> for CharBag {
fn from(s: &str) -> Self {
let mut bag = Self(0);
for c in s.chars() {
bag.insert(c);
}
bag
}
}
impl From<&[char]> for CharBag {
fn from(chars: &[char]) -> Self {
let mut bag = Self(0);
for c in chars {
bag.insert(*c);
}
bag
}
}

View file

@ -0,0 +1,10 @@
mod char_bag;
mod matcher;
mod paths;
mod strings;
pub use char_bag::CharBag;
pub use paths::{
match_fixed_path_set, match_path_sets, PathMatch, PathMatchCandidate, PathMatchCandidateSet,
};
pub use strings::{match_strings, StringMatch, StringMatchCandidate};

View file

@ -0,0 +1,464 @@
use std::{
borrow::Cow,
sync::atomic::{self, AtomicBool},
};
use crate::CharBag;
const BASE_DISTANCE_PENALTY: f64 = 0.6;
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
const MIN_DISTANCE_PENALTY: f64 = 0.2;
pub struct Matcher<'a> {
query: &'a [char],
lowercase_query: &'a [char],
query_char_bag: CharBag,
smart_case: bool,
max_results: usize,
min_score: f64,
match_positions: Vec<usize>,
last_positions: Vec<usize>,
score_matrix: Vec<Option<f64>>,
best_position_matrix: Vec<usize>,
}
pub trait Match: Ord {
fn score(&self) -> f64;
fn set_positions(&mut self, positions: Vec<usize>);
}
pub trait MatchCandidate {
fn has_chars(&self, bag: CharBag) -> bool;
fn to_string(&self) -> Cow<'_, str>;
}
impl<'a> Matcher<'a> {
pub fn new(
query: &'a [char],
lowercase_query: &'a [char],
query_char_bag: CharBag,
smart_case: bool,
max_results: usize,
) -> Self {
Self {
query,
lowercase_query,
query_char_bag,
min_score: 0.0,
last_positions: vec![0; query.len()],
match_positions: vec![0; query.len()],
score_matrix: Vec::new(),
best_position_matrix: Vec::new(),
smart_case,
max_results,
}
}
pub fn match_candidates<C: MatchCandidate, R, F>(
&mut self,
prefix: &[char],
lowercase_prefix: &[char],
candidates: impl Iterator<Item = C>,
results: &mut Vec<R>,
cancel_flag: &AtomicBool,
build_match: F,
) where
R: Match,
F: Fn(&C, f64) -> R,
{
let mut candidate_chars = Vec::new();
let mut lowercase_candidate_chars = Vec::new();
for candidate in candidates {
if !candidate.has_chars(self.query_char_bag) {
continue;
}
if cancel_flag.load(atomic::Ordering::Relaxed) {
break;
}
candidate_chars.clear();
lowercase_candidate_chars.clear();
for c in candidate.to_string().chars() {
candidate_chars.push(c);
lowercase_candidate_chars.push(c.to_ascii_lowercase());
}
if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
continue;
}
let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
self.score_matrix.clear();
self.score_matrix.resize(matrix_len, None);
self.best_position_matrix.clear();
self.best_position_matrix.resize(matrix_len, 0);
let score = self.score_match(
&candidate_chars,
&lowercase_candidate_chars,
prefix,
lowercase_prefix,
);
if score > 0.0 {
let mut mat = build_match(&candidate, score);
if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) {
if results.len() < self.max_results {
mat.set_positions(self.match_positions.clone());
results.insert(i, mat);
} else if i < results.len() {
results.pop();
mat.set_positions(self.match_positions.clone());
results.insert(i, mat);
}
if results.len() == self.max_results {
self.min_score = results.last().unwrap().score();
}
}
}
}
}
fn find_last_positions(
&mut self,
lowercase_prefix: &[char],
lowercase_candidate: &[char],
) -> bool {
let mut lowercase_prefix = lowercase_prefix.iter();
let mut lowercase_candidate = lowercase_candidate.iter();
for (i, char) in self.lowercase_query.iter().enumerate().rev() {
if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
self.last_positions[i] = j + lowercase_prefix.len();
} else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
self.last_positions[i] = j;
} else {
return false;
}
}
true
}
fn score_match(
&mut self,
path: &[char],
path_cased: &[char],
prefix: &[char],
lowercase_prefix: &[char],
) -> f64 {
let score = self.recursive_score_match(
path,
path_cased,
prefix,
lowercase_prefix,
0,
0,
self.query.len() as f64,
) * self.query.len() as f64;
if score <= 0.0 {
return 0.0;
}
let path_len = prefix.len() + path.len();
let mut cur_start = 0;
let mut byte_ix = 0;
let mut char_ix = 0;
for i in 0..self.query.len() {
let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
while char_ix < match_char_ix {
let ch = prefix
.get(char_ix)
.or_else(|| path.get(char_ix - prefix.len()))
.unwrap();
byte_ix += ch.len_utf8();
char_ix += 1;
}
cur_start = match_char_ix + 1;
self.match_positions[i] = byte_ix;
}
score
}
#[allow(clippy::too_many_arguments)]
fn recursive_score_match(
&mut self,
path: &[char],
path_cased: &[char],
prefix: &[char],
lowercase_prefix: &[char],
query_idx: usize,
path_idx: usize,
cur_score: f64,
) -> f64 {
if query_idx == self.query.len() {
return 1.0;
}
let path_len = prefix.len() + path.len();
if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
return memoized;
}
let mut score = 0.0;
let mut best_position = 0;
let query_char = self.lowercase_query[query_idx];
let limit = self.last_positions[query_idx];
let mut last_slash = 0;
for j in path_idx..=limit {
let path_char = if j < prefix.len() {
lowercase_prefix[j]
} else {
path_cased[j - prefix.len()]
};
let is_path_sep = path_char == '/' || path_char == '\\';
if query_idx == 0 && is_path_sep {
last_slash = j;
}
if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
let curr = if j < prefix.len() {
prefix[j]
} else {
path[j - prefix.len()]
};
let mut char_score = 1.0;
if j > path_idx {
let last = if j - 1 < prefix.len() {
prefix[j - 1]
} else {
path[j - 1 - prefix.len()]
};
if last == '/' {
char_score = 0.9;
} else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
|| (last.is_lowercase() && curr.is_uppercase())
{
char_score = 0.8;
} else if last == '.' {
char_score = 0.7;
} else if query_idx == 0 {
char_score = BASE_DISTANCE_PENALTY;
} else {
char_score = MIN_DISTANCE_PENALTY.max(
BASE_DISTANCE_PENALTY
- (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
);
}
}
// Apply a severe penalty if the case doesn't match.
// This will make the exact matches have higher score than the case-insensitive and the
// path insensitive matches.
if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
char_score *= 0.001;
}
let mut multiplier = char_score;
// Scale the score based on how deep within the path we found the match.
if query_idx == 0 {
multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
}
let mut next_score = 1.0;
if self.min_score > 0.0 {
next_score = cur_score * multiplier;
// Scores only decrease. If we can't pass the previous best, bail
if next_score < self.min_score {
// Ensure that score is non-zero so we use it in the memo table.
if score == 0.0 {
score = 1e-18;
}
continue;
}
}
let new_score = self.recursive_score_match(
path,
path_cased,
prefix,
lowercase_prefix,
query_idx + 1,
j + 1,
next_score,
) * multiplier;
if new_score > score {
score = new_score;
best_position = j;
// Optimization: can't score better than 1.
if new_score == 1.0 {
break;
}
}
}
}
if best_position != 0 {
self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
}
self.score_matrix[query_idx * path_len + path_idx] = Some(score);
score
}
}
#[cfg(test)]
mod tests {
use crate::{PathMatch, PathMatchCandidate};
use super::*;
use std::{
path::{Path, PathBuf},
sync::Arc,
};
#[test]
fn test_get_last_positions() {
let mut query: &[char] = &['d', 'c'];
let mut matcher = Matcher::new(query, query, query.into(), false, 10);
let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
assert!(!result);
query = &['c', 'd'];
let mut matcher = Matcher::new(query, query, query.into(), false, 10);
let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
assert!(result);
assert_eq!(matcher.last_positions, vec![2, 4]);
query = &['z', '/', 'z', 'f'];
let mut matcher = Matcher::new(query, query, query.into(), false, 10);
let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
assert!(result);
assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
}
#[test]
fn test_match_path_entries() {
let paths = vec![
"",
"a",
"ab",
"abC",
"abcd",
"alphabravocharlie",
"AlphaBravoCharlie",
"thisisatestdir",
"/////ThisIsATestDir",
"/this/is/a/test/dir",
"/test/tiatd",
];
assert_eq!(
match_single_path_query("abc", false, &paths),
vec![
("abC", vec![0, 1, 2]),
("abcd", vec![0, 1, 2]),
("AlphaBravoCharlie", vec![0, 5, 10]),
("alphabravocharlie", vec![4, 5, 10]),
]
);
assert_eq!(
match_single_path_query("t/i/a/t/d", false, &paths),
vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
);
assert_eq!(
match_single_path_query("tiatd", false, &paths),
vec![
("/test/tiatd", vec![6, 7, 8, 9, 10]),
("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
("thisisatestdir", vec![0, 2, 6, 7, 11]),
]
);
}
#[test]
fn test_match_multibyte_path_entries() {
let paths = vec!["aαbβ/cγ", "αβγδ/bcde", "c1⃣2⃣3⃣/d4⃣5⃣6⃣/e7⃣8⃣9⃣/f", "/d/🆒/h"];
assert_eq!("1".len(), 7);
assert_eq!(
match_single_path_query("bcd", false, &paths),
vec![
("αβγδ/bcde", vec![9, 10, 11]),
("aαbβ/cγ", vec![3, 7, 10]),
]
);
assert_eq!(
match_single_path_query("cde", false, &paths),
vec![
("αβγδ/bcde", vec![10, 11, 12]),
("c1⃣2⃣3⃣/d4⃣5⃣6⃣/e7⃣8⃣9⃣/f", vec![0, 23, 46]),
]
);
}
fn match_single_path_query<'a>(
query: &str,
smart_case: bool,
paths: &[&'a str],
) -> Vec<(&'a str, Vec<usize>)> {
let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
let query = query.chars().collect::<Vec<_>>();
let query_chars = CharBag::from(&lowercase_query[..]);
let path_arcs: Vec<Arc<Path>> = paths
.iter()
.map(|path| Arc::from(PathBuf::from(path)))
.collect::<Vec<_>>();
let mut path_entries = Vec::new();
for (i, path) in paths.iter().enumerate() {
let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
let char_bag = CharBag::from(lowercase_path.as_slice());
path_entries.push(PathMatchCandidate {
char_bag,
path: &path_arcs[i],
});
}
let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
let cancel_flag = AtomicBool::new(false);
let mut results = Vec::new();
matcher.match_candidates(
&[],
&[],
path_entries.into_iter(),
&mut results,
&cancel_flag,
|candidate, score| PathMatch {
score,
worktree_id: 0,
positions: Vec::new(),
path: Arc::from(candidate.path),
path_prefix: "".into(),
distance_to_relative_ancestor: usize::MAX,
},
);
results
.into_iter()
.map(|result| {
(
paths
.iter()
.copied()
.find(|p| result.path.as_ref() == Path::new(p))
.unwrap(),
result.positions,
)
})
.collect()
}
}

257
crates/fuzzy2/src/paths.rs Normal file
View file

@ -0,0 +1,257 @@
use gpui2::Executor;
use std::{
borrow::Cow,
cmp::{self, Ordering},
path::Path,
sync::{atomic::AtomicBool, Arc},
};
use crate::{
matcher::{Match, MatchCandidate, Matcher},
CharBag,
};
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
pub path: &'a Path,
pub char_bag: CharBag,
}
#[derive(Clone, Debug)]
pub struct PathMatch {
pub score: f64,
pub positions: Vec<usize>,
pub worktree_id: usize,
pub path: Arc<Path>,
pub path_prefix: Arc<str>,
/// Number of steps removed from a shared parent with the relative path
/// Used to order closer paths first in the search list
pub distance_to_relative_ancestor: usize,
}
pub trait PathMatchCandidateSet<'a>: Send + Sync {
type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
fn id(&self) -> usize;
fn len(&self) -> usize;
fn is_empty(&self) -> bool {
self.len() == 0
}
fn prefix(&self) -> Arc<str>;
fn candidates(&'a self, start: usize) -> Self::Candidates;
}
impl Match for PathMatch {
fn score(&self) -> f64 {
self.score
}
fn set_positions(&mut self, positions: Vec<usize>) {
self.positions = positions;
}
}
impl<'a> MatchCandidate for PathMatchCandidate<'a> {
fn has_chars(&self, bag: CharBag) -> bool {
self.char_bag.is_superset(bag)
}
fn to_string(&self) -> Cow<'a, str> {
self.path.to_string_lossy()
}
}
impl PartialEq for PathMatch {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for PathMatch {}
impl PartialOrd for PathMatch {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PathMatch {
fn cmp(&self, other: &Self) -> Ordering {
self.score
.partial_cmp(&other.score)
.unwrap_or(Ordering::Equal)
.then_with(|| self.worktree_id.cmp(&other.worktree_id))
.then_with(|| {
other
.distance_to_relative_ancestor
.cmp(&self.distance_to_relative_ancestor)
})
.then_with(|| self.path.cmp(&other.path))
}
}
pub fn match_fixed_path_set(
candidates: Vec<PathMatchCandidate>,
worktree_id: usize,
query: &str,
smart_case: bool,
max_results: usize,
) -> Vec<PathMatch> {
let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
let query = query.chars().collect::<Vec<_>>();
let query_char_bag = CharBag::from(&lowercase_query[..]);
let mut matcher = Matcher::new(
&query,
&lowercase_query,
query_char_bag,
smart_case,
max_results,
);
let mut results = Vec::new();
matcher.match_candidates(
&[],
&[],
candidates.into_iter(),
&mut results,
&AtomicBool::new(false),
|candidate, score| PathMatch {
score,
worktree_id,
positions: Vec::new(),
path: Arc::from(candidate.path),
path_prefix: Arc::from(""),
distance_to_relative_ancestor: usize::MAX,
},
);
results
}
pub async fn match_path_sets<'a, Set: PathMatchCandidateSet<'a>>(
candidate_sets: &'a [Set],
query: &str,
relative_to: Option<Arc<Path>>,
smart_case: bool,
max_results: usize,
cancel_flag: &AtomicBool,
executor: Executor,
) -> Vec<PathMatch> {
let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
if path_count == 0 {
return Vec::new();
}
let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
let query = query.chars().collect::<Vec<_>>();
let lowercase_query = &lowercase_query;
let query = &query;
let query_char_bag = CharBag::from(&lowercase_query[..]);
let num_cpus = executor.num_cpus().min(path_count);
let segment_size = (path_count + num_cpus - 1) / num_cpus;
let mut segment_results = (0..num_cpus)
.map(|_| Vec::with_capacity(max_results))
.collect::<Vec<_>>();
executor
.scoped(|scope| {
for (segment_idx, results) in segment_results.iter_mut().enumerate() {
let relative_to = relative_to.clone();
scope.spawn(async move {
let segment_start = segment_idx * segment_size;
let segment_end = segment_start + segment_size;
let mut matcher = Matcher::new(
query,
lowercase_query,
query_char_bag,
smart_case,
max_results,
);
let mut tree_start = 0;
for candidate_set in candidate_sets {
let tree_end = tree_start + candidate_set.len();
if tree_start < segment_end && segment_start < tree_end {
let start = cmp::max(tree_start, segment_start) - tree_start;
let end = cmp::min(tree_end, segment_end) - tree_start;
let candidates = candidate_set.candidates(start).take(end - start);
let worktree_id = candidate_set.id();
let prefix = candidate_set.prefix().chars().collect::<Vec<_>>();
let lowercase_prefix = prefix
.iter()
.map(|c| c.to_ascii_lowercase())
.collect::<Vec<_>>();
matcher.match_candidates(
&prefix,
&lowercase_prefix,
candidates,
results,
cancel_flag,
|candidate, score| PathMatch {
score,
worktree_id,
positions: Vec::new(),
path: Arc::from(candidate.path),
path_prefix: candidate_set.prefix(),
distance_to_relative_ancestor: relative_to.as_ref().map_or(
usize::MAX,
|relative_to| {
distance_between_paths(
candidate.path.as_ref(),
relative_to.as_ref(),
)
},
),
},
);
}
if tree_end >= segment_end {
break;
}
tree_start = tree_end;
}
})
}
})
.await;
let mut results = Vec::new();
for segment_result in segment_results {
if results.is_empty() {
results = segment_result;
} else {
util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a));
}
}
results
}
/// Compute the distance from a given path to some other path
/// If there is no shared path, returns usize::MAX
fn distance_between_paths(path: &Path, relative_to: &Path) -> usize {
let mut path_components = path.components();
let mut relative_components = relative_to.components();
while path_components
.next()
.zip(relative_components.next())
.map(|(path_component, relative_component)| path_component == relative_component)
.unwrap_or_default()
{}
path_components.count() + relative_components.count() + 1
}
#[cfg(test)]
mod tests {
use std::path::Path;
use super::distance_between_paths;
#[test]
fn test_distance_between_paths_empty() {
distance_between_paths(Path::new(""), Path::new(""));
}
}

View file

@ -0,0 +1,159 @@
use crate::{
matcher::{Match, MatchCandidate, Matcher},
CharBag,
};
use gpui2::Executor;
use std::{
borrow::Cow,
cmp::{self, Ordering},
sync::atomic::AtomicBool,
};
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
pub id: usize,
pub string: String,
pub char_bag: CharBag,
}
impl Match for StringMatch {
fn score(&self) -> f64 {
self.score
}
fn set_positions(&mut self, positions: Vec<usize>) {
self.positions = positions;
}
}
impl StringMatchCandidate {
pub fn new(id: usize, string: String) -> Self {
Self {
id,
char_bag: CharBag::from(string.as_str()),
string,
}
}
}
impl<'a> MatchCandidate for &'a StringMatchCandidate {
fn has_chars(&self, bag: CharBag) -> bool {
self.char_bag.is_superset(bag)
}
fn to_string(&self) -> Cow<'a, str> {
self.string.as_str().into()
}
}
#[derive(Clone, Debug)]
pub struct StringMatch {
pub candidate_id: usize,
pub score: f64,
pub positions: Vec<usize>,
pub string: String,
}
impl PartialEq for StringMatch {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
impl Eq for StringMatch {}
impl PartialOrd for StringMatch {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for StringMatch {
fn cmp(&self, other: &Self) -> Ordering {
self.score
.partial_cmp(&other.score)
.unwrap_or(Ordering::Equal)
.then_with(|| self.candidate_id.cmp(&other.candidate_id))
}
}
pub async fn match_strings(
candidates: &[StringMatchCandidate],
query: &str,
smart_case: bool,
max_results: usize,
cancel_flag: &AtomicBool,
executor: Executor,
) -> Vec<StringMatch> {
if candidates.is_empty() || max_results == 0 {
return Default::default();
}
if query.is_empty() {
return candidates
.iter()
.map(|candidate| StringMatch {
candidate_id: candidate.id,
score: 0.,
positions: Default::default(),
string: candidate.string.clone(),
})
.collect();
}
let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
let query = query.chars().collect::<Vec<_>>();
let lowercase_query = &lowercase_query;
let query = &query;
let query_char_bag = CharBag::from(&lowercase_query[..]);
let num_cpus = executor.num_cpus().min(candidates.len());
let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
let mut segment_results = (0..num_cpus)
.map(|_| Vec::with_capacity(max_results.min(candidates.len())))
.collect::<Vec<_>>();
executor
.scoped(|scope| {
for (segment_idx, results) in segment_results.iter_mut().enumerate() {
let cancel_flag = &cancel_flag;
scope.spawn(async move {
let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
let segment_end = cmp::min(segment_start + segment_size, candidates.len());
let mut matcher = Matcher::new(
query,
lowercase_query,
query_char_bag,
smart_case,
max_results,
);
matcher.match_candidates(
&[],
&[],
candidates[segment_start..segment_end].iter(),
results,
cancel_flag,
|candidate, score| StringMatch {
candidate_id: candidate.id,
score,
positions: Vec::new(),
string: candidate.string.to_string(),
},
);
});
}
})
.await;
let mut results = Vec::new();
for segment_result in segment_results {
if results.is_empty() {
results = segment_result;
} else {
util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a));
}
}
results
}

View file

@ -3607,7 +3607,7 @@ impl<V> BorrowWindowContext for EventContext<'_, '_, '_, V> {
}
}
pub(crate) enum Reference<'a, T> {
pub enum Reference<'a, T> {
Immutable(&'a T),
Mutable(&'a mut T),
}

View file

@ -154,6 +154,11 @@ impl Refineable for TextStyleRefinement {
self.underline = refinement.underline;
}
}
fn refined(mut self, refinement: Self::Refinement) -> Self {
self.refine(&refinement);
self
}
}
#[derive(JsonSchema)]

View file

@ -84,7 +84,6 @@ impl ImageCache {
let format = image::guess_format(&body)?;
let image =
image::load_from_memory_with_format(&body, format)?.into_bgra8();
Ok(ImageData::new(image))
}
}

View file

@ -109,6 +109,7 @@ impl AtlasAllocator {
};
descriptor.set_width(size.x() as u64);
descriptor.set_height(size.y() as u64);
self.device.new_texture(&descriptor)
} else {
self.device.new_texture(&self.texture_descriptor)

View file

@ -632,6 +632,7 @@ impl Renderer {
) {
// Snap sprite to pixel grid.
let origin = (glyph.origin * scale_factor).floor() + sprite.offset.to_f32();
sprites_by_atlas
.entry(sprite.atlas_id)
.or_insert_with(Vec::new)

View file

@ -82,6 +82,7 @@ const NSWindowAnimationBehaviorUtilityWindow: NSInteger = 4;
#[ctor]
unsafe fn build_classes() {
::util::gpui1_loaded();
WINDOW_CLASS = build_window_class("GPUIWindow", class!(NSWindow));
PANEL_CLASS = build_window_class("GPUIPanel", class!(NSPanel));
VIEW_CLASS = {

View file

@ -22,8 +22,8 @@ use std::{
};
pub struct TextLayoutCache {
prev_frame: Mutex<HashMap<CacheKeyValue, Arc<LineLayout>>>,
curr_frame: RwLock<HashMap<CacheKeyValue, Arc<LineLayout>>>,
prev_frame: Mutex<HashMap<OwnedCacheKey, Arc<LineLayout>>>,
curr_frame: RwLock<HashMap<OwnedCacheKey, Arc<LineLayout>>>,
fonts: Arc<dyn platform::FontSystem>,
}
@ -56,7 +56,7 @@ impl TextLayoutCache {
font_size: f32,
runs: &'a [(usize, RunStyle)],
) -> Line {
let key = &CacheKeyRef {
let key = &BorrowedCacheKey {
text,
font_size: OrderedFloat(font_size),
runs,
@ -72,7 +72,7 @@ impl TextLayoutCache {
Line::new(layout, runs)
} else {
let layout = Arc::new(self.fonts.layout_line(text, font_size, runs));
let key = CacheKeyValue {
let key = OwnedCacheKey {
text: text.into(),
font_size: OrderedFloat(font_size),
runs: SmallVec::from(runs),
@ -84,7 +84,7 @@ impl TextLayoutCache {
}
trait CacheKey {
fn key(&self) -> CacheKeyRef;
fn key(&self) -> BorrowedCacheKey;
}
impl<'a> PartialEq for (dyn CacheKey + 'a) {
@ -102,15 +102,15 @@ impl<'a> Hash for (dyn CacheKey + 'a) {
}
#[derive(Eq)]
struct CacheKeyValue {
struct OwnedCacheKey {
text: String,
font_size: OrderedFloat<f32>,
runs: SmallVec<[(usize, RunStyle); 1]>,
}
impl CacheKey for CacheKeyValue {
fn key(&self) -> CacheKeyRef {
CacheKeyRef {
impl CacheKey for OwnedCacheKey {
fn key(&self) -> BorrowedCacheKey {
BorrowedCacheKey {
text: self.text.as_str(),
font_size: self.font_size,
runs: self.runs.as_slice(),
@ -118,38 +118,38 @@ impl CacheKey for CacheKeyValue {
}
}
impl PartialEq for CacheKeyValue {
impl PartialEq for OwnedCacheKey {
fn eq(&self, other: &Self) -> bool {
self.key().eq(&other.key())
}
}
impl Hash for CacheKeyValue {
impl Hash for OwnedCacheKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.key().hash(state);
}
}
impl<'a> Borrow<dyn CacheKey + 'a> for CacheKeyValue {
impl<'a> Borrow<dyn CacheKey + 'a> for OwnedCacheKey {
fn borrow(&self) -> &(dyn CacheKey + 'a) {
self as &dyn CacheKey
}
}
#[derive(Copy, Clone)]
struct CacheKeyRef<'a> {
struct BorrowedCacheKey<'a> {
text: &'a str,
font_size: OrderedFloat<f32>,
runs: &'a [(usize, RunStyle)],
}
impl<'a> CacheKey for CacheKeyRef<'a> {
fn key(&self) -> CacheKeyRef {
impl<'a> CacheKey for BorrowedCacheKey<'a> {
fn key(&self) -> BorrowedCacheKey {
*self
}
}
impl<'a> PartialEq for CacheKeyRef<'a> {
impl<'a> PartialEq for BorrowedCacheKey<'a> {
fn eq(&self, other: &Self) -> bool {
self.text == other.text
&& self.font_size == other.font_size
@ -162,7 +162,7 @@ impl<'a> PartialEq for CacheKeyRef<'a> {
}
}
impl<'a> Hash for CacheKeyRef<'a> {
impl<'a> Hash for BorrowedCacheKey<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.text.hash(state);
self.font_size.hash(state);

View file

@ -2,31 +2,86 @@
name = "gpui2"
version = "0.1.0"
edition = "2021"
authors = ["Nathan Sobo <nathan@zed.dev>"]
description = "The next version of Zed's GPU-accelerated UI framework"
publish = false
[lib]
name = "gpui2"
path = "src/gpui2.rs"
[features]
test-support = ["gpui/test-support"]
test-support = ["backtrace", "dhat", "env_logger", "collections/test-support", "util/test-support"]
[lib]
path = "src/gpui2.rs"
doctest = false
[dependencies]
anyhow.workspace = true
derive_more.workspace = true
gpui = { path = "../gpui" }
log.workspace = true
futures.workspace = true
collections = { path = "../collections" }
gpui_macros = { path = "../gpui_macros" }
gpui2_macros = { path = "../gpui2_macros" }
parking_lot.workspace = true
refineable.workspace = true
rust-embed.workspace = true
serde.workspace = true
settings = { path = "../settings" }
simplelog = "0.9"
smallvec.workspace = true
theme = { path = "../theme" }
util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
sqlez = { path = "../sqlez" }
async-task = "4.0.3"
backtrace = { version = "0.3", optional = true }
ctor.workspace = true
derive_more.workspace = true
dhat = { version = "0.3", optional = true }
env_logger = { version = "0.9", optional = true }
etagere = "0.2"
futures.workspace = true
image = "0.23"
itertools = "0.10"
lazy_static.workspace = true
log.workspace = true
num_cpus = "1.13"
ordered-float.workspace = true
parking = "2.0.0"
parking_lot.workspace = true
pathfinder_geometry = "0.5"
postage.workspace = true
rand.workspace = true
refineable.workspace = true
resvg = "0.14"
seahash = "4.1"
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
smallvec.workspace = true
smol.workspace = true
taffy = { git = "https://github.com/DioxusLabs/taffy", rev = "4fb530bdd71609bb1d3f76c6a8bde1ba82805d5e" }
thiserror.workspace = true
time.workspace = true
tiny-skia = "0.5"
usvg = { version = "0.14", features = [] }
uuid = { version = "1.1.2", features = ["v4"] }
waker-fn = "1.1.0"
slotmap = "1.0.6"
schemars.workspace = true
plane-split = "0.18.0"
bitflags = "2.4.0"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
backtrace = "0.3"
collections = { path = "../collections", features = ["test-support"] }
dhat = "0.3"
env_logger.workspace = true
png = "0.16"
simplelog = "0.9"
util = { path = "../util", features = ["test-support"] }
[build-dependencies]
bindgen = "0.65.1"
cbindgen = "0.26.0"
[target.'cfg(target_os = "macos")'.dependencies]
media = { path = "../media" }
anyhow.workspace = true
block = "0.1"
cocoa = "0.24"
core-foundation = { version = "0.9.3", features = ["with-uuid"] }
core-graphics = "0.22.3"
core-text = "19.2"
font-kit = { git = "https://github.com/zed-industries/font-kit", rev = "b2f77d56f450338aa4f7dd2f0197d8c9acb0cf18" }
foreign-types = "0.3"
log.workspace = true
metal = "0.21.0"
objc = "0.2"

134
crates/gpui2/build.rs Normal file
View file

@ -0,0 +1,134 @@
use std::{
env,
path::{Path, PathBuf},
process::{self, Command},
};
use cbindgen::Config;
fn main() {
generate_dispatch_bindings();
let header_path = generate_shader_bindings();
compile_metal_shaders(&header_path);
}
fn generate_dispatch_bindings() {
println!("cargo:rustc-link-lib=framework=System");
println!("cargo:rerun-if-changed=src/platform/mac/dispatch.h");
let bindings = bindgen::Builder::default()
.header("src/platform/mac/dispatch.h")
.allowlist_var("_dispatch_main_q")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT")
.allowlist_function("dispatch_get_global_queue")
.allowlist_function("dispatch_async_f")
.allowlist_function("dispatch_after_f")
.allowlist_function("dispatch_time")
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.layout_tests(false)
.generate()
.expect("unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("dispatch_sys.rs"))
.expect("couldn't write dispatch bindings");
}
fn generate_shader_bindings() -> PathBuf {
let output_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("scene.h");
let crate_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
let mut config = Config::default();
config.include_guard = Some("SCENE_H".into());
config.language = cbindgen::Language::C;
config.export.include.extend([
"Bounds".into(),
"Corners".into(),
"Edges".into(),
"Size".into(),
"Pixels".into(),
"PointF".into(),
"Hsla".into(),
"ContentMask".into(),
"Uniforms".into(),
"AtlasTile".into(),
"PathRasterizationInputIndex".into(),
"PathVertex_ScaledPixels".into(),
"ShadowInputIndex".into(),
"Shadow".into(),
"QuadInputIndex".into(),
"Underline".into(),
"UnderlineInputIndex".into(),
"Quad".into(),
"SpriteInputIndex".into(),
"MonochromeSprite".into(),
"PolychromeSprite".into(),
"PathSprite".into(),
]);
config.no_includes = true;
config.enumeration.prefix_with_name = true;
cbindgen::Builder::new()
.with_src(crate_dir.join("src/scene.rs"))
.with_src(crate_dir.join("src/geometry.rs"))
.with_src(crate_dir.join("src/color.rs"))
.with_src(crate_dir.join("src/window.rs"))
.with_src(crate_dir.join("src/platform.rs"))
.with_src(crate_dir.join("src/platform/mac/metal_renderer.rs"))
.with_config(config)
.generate()
.expect("Unable to generate bindings")
.write_to_file(&output_path);
output_path
}
fn compile_metal_shaders(header_path: &Path) {
let shader_path = "./src/platform/mac/shaders.metal";
let air_output_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("shaders.air");
let metallib_output_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("shaders.metallib");
println!("cargo:rerun-if-changed={}", header_path.display());
println!("cargo:rerun-if-changed={}", shader_path);
let output = Command::new("xcrun")
.args([
"-sdk",
"macosx",
"metal",
"-gline-tables-only",
"-mmacosx-version-min=10.15.7",
"-MO",
"-c",
shader_path,
"-include",
&header_path.to_str().unwrap(),
"-o",
])
.arg(&air_output_path)
.output()
.unwrap();
if !output.status.success() {
eprintln!(
"metal shader compilation failed:\n{}",
String::from_utf8_lossy(&output.stderr)
);
process::exit(1);
}
let output = Command::new("xcrun")
.args(["-sdk", "macosx", "metallib"])
.arg(air_output_path)
.arg("-o")
.arg(metallib_output_path)
.output()
.unwrap();
if !output.status.success() {
eprintln!(
"metallib compilation failed:\n{}",
String::from_utf8_lossy(&output.stderr)
);
process::exit(1);
}
}

432
crates/gpui2/src/action.rs Normal file
View file

@ -0,0 +1,432 @@
use crate::SharedString;
use anyhow::{anyhow, Context, Result};
use collections::{HashMap, HashSet};
use serde::Deserialize;
use std::any::{type_name, Any};
pub trait Action: Any + Send {
fn qualified_name() -> SharedString
where
Self: Sized;
fn build(value: Option<serde_json::Value>) -> Result<Box<dyn Action>>
where
Self: Sized;
fn partial_eq(&self, action: &dyn Action) -> bool;
fn boxed_clone(&self) -> Box<dyn Action>;
fn as_any(&self) -> &dyn Any;
}
impl<A> Action for A
where
A: for<'a> Deserialize<'a> + PartialEq + Any + Send + Clone + Default,
{
fn qualified_name() -> SharedString {
type_name::<A>().into()
}
fn build(params: Option<serde_json::Value>) -> Result<Box<dyn Action>>
where
Self: Sized,
{
let action = if let Some(params) = params {
serde_json::from_value(params).context("failed to deserialize action")?
} else {
Self::default()
};
Ok(Box::new(action))
}
fn partial_eq(&self, action: &dyn Action) -> bool {
action
.as_any()
.downcast_ref::<Self>()
.map_or(false, |a| self == a)
}
fn boxed_clone(&self) -> Box<dyn Action> {
Box::new(self.clone())
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DispatchContext {
set: HashSet<SharedString>,
map: HashMap<SharedString, SharedString>,
}
impl<'a> TryFrom<&'a str> for DispatchContext {
type Error = anyhow::Error;
fn try_from(value: &'a str) -> Result<Self> {
Self::parse(value)
}
}
impl DispatchContext {
pub fn parse(source: &str) -> Result<Self> {
let mut context = Self::default();
let source = skip_whitespace(source);
Self::parse_expr(&source, &mut context)?;
Ok(context)
}
fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> {
if source.is_empty() {
return Ok(());
}
let key = source
.chars()
.take_while(|c| is_identifier_char(*c))
.collect::<String>();
source = skip_whitespace(&source[key.len()..]);
if let Some(suffix) = source.strip_prefix('=') {
source = skip_whitespace(suffix);
let value = source
.chars()
.take_while(|c| is_identifier_char(*c))
.collect::<String>();
source = skip_whitespace(&source[value.len()..]);
context.set(key, value);
} else {
context.insert(key);
}
Self::parse_expr(source, context)
}
pub fn is_empty(&self) -> bool {
self.set.is_empty() && self.map.is_empty()
}
pub fn clear(&mut self) {
self.set.clear();
self.map.clear();
}
pub fn extend(&mut self, other: &Self) {
for v in &other.set {
self.set.insert(v.clone());
}
for (k, v) in &other.map {
self.map.insert(k.clone(), v.clone());
}
}
pub fn insert<I: Into<SharedString>>(&mut self, identifier: I) {
self.set.insert(identifier.into());
}
pub fn set<S1: Into<SharedString>, S2: Into<SharedString>>(&mut self, key: S1, value: S2) {
self.map.insert(key.into(), value.into());
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum DispatchContextPredicate {
Identifier(SharedString),
Equal(SharedString, SharedString),
NotEqual(SharedString, SharedString),
Child(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
Not(Box<DispatchContextPredicate>),
And(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
Or(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
}
impl DispatchContextPredicate {
pub fn parse(source: &str) -> Result<Self> {
let source = skip_whitespace(source);
let (predicate, rest) = Self::parse_expr(source, 0)?;
if let Some(next) = rest.chars().next() {
Err(anyhow!("unexpected character {next:?}"))
} else {
Ok(predicate)
}
}
pub fn eval(&self, contexts: &[&DispatchContext]) -> bool {
let Some(context) = contexts.last() else {
return false;
};
match self {
Self::Identifier(name) => context.set.contains(name),
Self::Equal(left, right) => context
.map
.get(left)
.map(|value| value == right)
.unwrap_or(false),
Self::NotEqual(left, right) => context
.map
.get(left)
.map(|value| value != right)
.unwrap_or(true),
Self::Not(pred) => !pred.eval(contexts),
Self::Child(parent, child) => {
parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts)
}
Self::And(left, right) => left.eval(contexts) && right.eval(contexts),
Self::Or(left, right) => left.eval(contexts) || right.eval(contexts),
}
}
fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> {
type Op = fn(
DispatchContextPredicate,
DispatchContextPredicate,
) -> Result<DispatchContextPredicate>;
let (mut predicate, rest) = Self::parse_primary(source)?;
source = rest;
'parse: loop {
for (operator, precedence, constructor) in [
(">", PRECEDENCE_CHILD, Self::new_child as Op),
("&&", PRECEDENCE_AND, Self::new_and as Op),
("||", PRECEDENCE_OR, Self::new_or as Op),
("==", PRECEDENCE_EQ, Self::new_eq as Op),
("!=", PRECEDENCE_EQ, Self::new_neq as Op),
] {
if source.starts_with(operator) && precedence >= min_precedence {
source = skip_whitespace(&source[operator.len()..]);
let (right, rest) = Self::parse_expr(source, precedence + 1)?;
predicate = constructor(predicate, right)?;
source = rest;
continue 'parse;
}
}
break;
}
Ok((predicate, source))
}
fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> {
let next = source
.chars()
.next()
.ok_or_else(|| anyhow!("unexpected eof"))?;
match next {
'(' => {
source = skip_whitespace(&source[1..]);
let (predicate, rest) = Self::parse_expr(source, 0)?;
if rest.starts_with(')') {
source = skip_whitespace(&rest[1..]);
Ok((predicate, source))
} else {
Err(anyhow!("expected a ')'"))
}
}
'!' => {
let source = skip_whitespace(&source[1..]);
let (predicate, source) = Self::parse_expr(&source, PRECEDENCE_NOT)?;
Ok((DispatchContextPredicate::Not(Box::new(predicate)), source))
}
_ if is_identifier_char(next) => {
let len = source
.find(|c: char| !is_identifier_char(c))
.unwrap_or(source.len());
let (identifier, rest) = source.split_at(len);
source = skip_whitespace(rest);
Ok((
DispatchContextPredicate::Identifier(identifier.to_string().into()),
source,
))
}
_ => Err(anyhow!("unexpected character {next:?}")),
}
}
fn new_or(self, other: Self) -> Result<Self> {
Ok(Self::Or(Box::new(self), Box::new(other)))
}
fn new_and(self, other: Self) -> Result<Self> {
Ok(Self::And(Box::new(self), Box::new(other)))
}
fn new_child(self, other: Self) -> Result<Self> {
Ok(Self::Child(Box::new(self), Box::new(other)))
}
fn new_eq(self, other: Self) -> Result<Self> {
if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
Ok(Self::Equal(left, right))
} else {
Err(anyhow!("operands must be identifiers"))
}
}
fn new_neq(self, other: Self) -> Result<Self> {
if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
Ok(Self::NotEqual(left, right))
} else {
Err(anyhow!("operands must be identifiers"))
}
}
}
const PRECEDENCE_CHILD: u32 = 1;
const PRECEDENCE_OR: u32 = 2;
const PRECEDENCE_AND: u32 = 3;
const PRECEDENCE_EQ: u32 = 4;
const PRECEDENCE_NOT: u32 = 5;
fn is_identifier_char(c: char) -> bool {
c.is_alphanumeric() || c == '_' || c == '-'
}
fn skip_whitespace(source: &str) -> &str {
let len = source
.find(|c: char| !c.is_whitespace())
.unwrap_or(source.len());
&source[len..]
}
#[cfg(test)]
mod tests {
use super::*;
use DispatchContextPredicate::*;
#[test]
fn test_parse_context() {
let mut expected = DispatchContext::default();
expected.set("foo", "bar");
expected.insert("baz");
assert_eq!(DispatchContext::parse("baz foo=bar").unwrap(), expected);
assert_eq!(DispatchContext::parse("foo = bar baz").unwrap(), expected);
assert_eq!(
DispatchContext::parse(" baz foo = bar baz").unwrap(),
expected
);
assert_eq!(DispatchContext::parse(" foo = bar baz").unwrap(), expected);
}
#[test]
fn test_parse_identifiers() {
// Identifiers
assert_eq!(
DispatchContextPredicate::parse("abc12").unwrap(),
Identifier("abc12".into())
);
assert_eq!(
DispatchContextPredicate::parse("_1a").unwrap(),
Identifier("_1a".into())
);
}
#[test]
fn test_parse_negations() {
assert_eq!(
DispatchContextPredicate::parse("!abc").unwrap(),
Not(Box::new(Identifier("abc".into())))
);
assert_eq!(
DispatchContextPredicate::parse(" ! ! abc").unwrap(),
Not(Box::new(Not(Box::new(Identifier("abc".into())))))
);
}
#[test]
fn test_parse_equality_operators() {
assert_eq!(
DispatchContextPredicate::parse("a == b").unwrap(),
Equal("a".into(), "b".into())
);
assert_eq!(
DispatchContextPredicate::parse("c!=d").unwrap(),
NotEqual("c".into(), "d".into())
);
assert_eq!(
DispatchContextPredicate::parse("c == !d")
.unwrap_err()
.to_string(),
"operands must be identifiers"
);
}
#[test]
fn test_parse_boolean_operators() {
assert_eq!(
DispatchContextPredicate::parse("a || b").unwrap(),
Or(
Box::new(Identifier("a".into())),
Box::new(Identifier("b".into()))
)
);
assert_eq!(
DispatchContextPredicate::parse("a || !b && c").unwrap(),
Or(
Box::new(Identifier("a".into())),
Box::new(And(
Box::new(Not(Box::new(Identifier("b".into())))),
Box::new(Identifier("c".into()))
))
)
);
assert_eq!(
DispatchContextPredicate::parse("a && b || c&&d").unwrap(),
Or(
Box::new(And(
Box::new(Identifier("a".into())),
Box::new(Identifier("b".into()))
)),
Box::new(And(
Box::new(Identifier("c".into())),
Box::new(Identifier("d".into()))
))
)
);
assert_eq!(
DispatchContextPredicate::parse("a == b && c || d == e && f").unwrap(),
Or(
Box::new(And(
Box::new(Equal("a".into(), "b".into())),
Box::new(Identifier("c".into()))
)),
Box::new(And(
Box::new(Equal("d".into(), "e".into())),
Box::new(Identifier("f".into()))
))
)
);
assert_eq!(
DispatchContextPredicate::parse("a && b && c && d").unwrap(),
And(
Box::new(And(
Box::new(And(
Box::new(Identifier("a".into())),
Box::new(Identifier("b".into()))
)),
Box::new(Identifier("c".into())),
)),
Box::new(Identifier("d".into()))
),
);
}
#[test]
fn test_parse_parenthesized_expressions() {
assert_eq!(
DispatchContextPredicate::parse("a && (b == c || d != e)").unwrap(),
And(
Box::new(Identifier("a".into())),
Box::new(Or(
Box::new(Equal("b".into(), "c".into())),
Box::new(NotEqual("d".into(), "e".into())),
)),
),
);
assert_eq!(
DispatchContextPredicate::parse(" ( a || b ) ").unwrap(),
Or(
Box::new(Identifier("a".into())),
Box::new(Identifier("b".into())),
)
);
}
}

View file

@ -1,76 +0,0 @@
use crate::ViewContext;
use gpui::{geometry::rect::RectF, LayoutEngine, LayoutId};
use util::ResultExt;
/// Makes a new, gpui2-style element into a legacy element.
pub struct AdapterElement<V>(pub(crate) crate::element::AnyElement<V>);
impl<V: 'static> gpui::Element<V> for AdapterElement<V> {
type LayoutState = Option<(LayoutEngine, LayoutId)>;
type PaintState = ();
fn layout(
&mut self,
constraint: gpui::SizeConstraint,
view: &mut V,
cx: &mut gpui::ViewContext<V>,
) -> (gpui::geometry::vector::Vector2F, Self::LayoutState) {
cx.push_layout_engine(LayoutEngine::new());
let mut cx = ViewContext::new(cx);
let layout_id = self.0.layout(view, &mut cx).log_err();
if let Some(layout_id) = layout_id {
cx.layout_engine()
.unwrap()
.compute_layout(layout_id, constraint.max)
.log_err();
}
let layout_engine = cx.pop_layout_engine();
debug_assert!(layout_engine.is_some(),
"unexpected layout stack state. is there an unmatched pop_layout_engine in the called code?"
);
(constraint.max, layout_engine.zip(layout_id))
}
fn paint(
&mut self,
bounds: RectF,
_visible_bounds: RectF,
layout_data: &mut Option<(LayoutEngine, LayoutId)>,
view: &mut V,
cx: &mut gpui::ViewContext<V>,
) -> Self::PaintState {
let (layout_engine, layout_id) = layout_data.take().unwrap();
cx.push_layout_engine(layout_engine);
self.0
.paint(view, bounds.origin(), &mut ViewContext::new(cx));
*layout_data = cx.pop_layout_engine().zip(Some(layout_id));
debug_assert!(layout_data.is_some());
}
fn rect_for_text_range(
&self,
_range_utf16: std::ops::Range<usize>,
_bounds: RectF,
_visible_bounds: RectF,
_layout: &Self::LayoutState,
_paint: &Self::PaintState,
_view: &V,
_cx: &gpui::ViewContext<V>,
) -> Option<RectF> {
todo!("implement before merging to main")
}
fn debug(
&self,
_bounds: RectF,
_layout: &Self::LayoutState,
_paint: &Self::PaintState,
_view: &V,
_cx: &gpui::ViewContext<V>,
) -> gpui::serde_json::Value {
todo!("implement before merging to main")
}
}

919
crates/gpui2/src/app.rs Normal file
View file

@ -0,0 +1,919 @@
mod async_context;
mod entity_map;
mod model_context;
#[cfg(any(test, feature = "test-support"))]
mod test_context;
pub use async_context::*;
pub use entity_map::*;
pub use model_context::*;
use refineable::Refineable;
use smallvec::SmallVec;
#[cfg(any(test, feature = "test-support"))]
pub use test_context::*;
use crate::{
current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AppMetadata, AssetSource,
ClipboardItem, Context, DispatchPhase, DisplayId, Executor, FocusEvent, FocusHandle, FocusId,
KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly, Pixels, Platform, Point, Render,
SharedString, SubscriberSet, Subscription, SvgRenderer, Task, TextStyle, TextStyleRefinement,
TextSystem, View, Window, WindowContext, WindowHandle, WindowId,
};
use anyhow::{anyhow, Result};
use collections::{HashMap, HashSet, VecDeque};
use futures::{future::BoxFuture, Future};
use parking_lot::Mutex;
use slotmap::SlotMap;
use std::{
any::{type_name, Any, TypeId},
borrow::Borrow,
marker::PhantomData,
mem,
ops::{Deref, DerefMut},
path::PathBuf,
sync::{atomic::Ordering::SeqCst, Arc, Weak},
time::Duration,
};
use util::http::{self, HttpClient};
pub struct App(Arc<Mutex<AppContext>>);
/// Represents an application before it is fully launched. Once your app is
/// configured, you'll start the app with `App::run`.
impl App {
/// Builds an app with the given asset source.
pub fn production(asset_source: Arc<dyn AssetSource>) -> Self {
Self(AppContext::new(
current_platform(),
asset_source,
http::client(),
))
}
/// Start the application. The provided callback will be called once the
/// app is fully launched.
pub fn run<F>(self, on_finish_launching: F)
where
F: 'static + FnOnce(&mut MainThread<AppContext>),
{
let this = self.0.clone();
let platform = self.0.lock().platform.clone();
platform.borrow_on_main_thread().run(Box::new(move || {
let cx = &mut *this.lock();
let cx = unsafe { mem::transmute::<&mut AppContext, &mut MainThread<AppContext>>(cx) };
on_finish_launching(cx);
}));
}
/// Register a handler to be invoked when the platform instructs the application
/// to open one or more URLs.
pub fn on_open_urls<F>(&self, mut callback: F) -> &Self
where
F: 'static + FnMut(Vec<String>, &mut AppContext),
{
let this = Arc::downgrade(&self.0);
self.0
.lock()
.platform
.borrow_on_main_thread()
.on_open_urls(Box::new(move |urls| {
if let Some(app) = this.upgrade() {
callback(urls, &mut app.lock());
}
}));
self
}
pub fn on_reopen<F>(&self, mut callback: F) -> &Self
where
F: 'static + FnMut(&mut AppContext),
{
let this = Arc::downgrade(&self.0);
self.0
.lock()
.platform
.borrow_on_main_thread()
.on_reopen(Box::new(move || {
if let Some(app) = this.upgrade() {
callback(&mut app.lock());
}
}));
self
}
pub fn metadata(&self) -> AppMetadata {
self.0.lock().app_metadata.clone()
}
pub fn executor(&self) -> Executor {
self.0.lock().executor.clone()
}
pub fn text_system(&self) -> Arc<TextSystem> {
self.0.lock().text_system.clone()
}
}
type ActionBuilder = fn(json: Option<serde_json::Value>) -> anyhow::Result<Box<dyn Action>>;
type FrameCallback = Box<dyn FnOnce(&mut WindowContext) + Send>;
type Handler = Box<dyn FnMut(&mut AppContext) -> bool + Send + 'static>;
type Listener = Box<dyn FnMut(&dyn Any, &mut AppContext) -> bool + Send + 'static>;
type QuitHandler = Box<dyn FnMut(&mut AppContext) -> BoxFuture<'static, ()> + Send + 'static>;
type ReleaseListener = Box<dyn FnMut(&mut dyn Any, &mut AppContext) + Send + 'static>;
pub struct AppContext {
this: Weak<Mutex<AppContext>>,
pub(crate) platform: MainThreadOnly<dyn Platform>,
app_metadata: AppMetadata,
text_system: Arc<TextSystem>,
flushing_effects: bool,
pending_updates: usize,
pub(crate) active_drag: Option<AnyDrag>,
pub(crate) next_frame_callbacks: HashMap<DisplayId, Vec<FrameCallback>>,
pub(crate) executor: Executor,
pub(crate) svg_renderer: SvgRenderer,
asset_source: Arc<dyn AssetSource>,
pub(crate) image_cache: ImageCache,
pub(crate) text_style_stack: Vec<TextStyleRefinement>,
pub(crate) globals_by_type: HashMap<TypeId, AnyBox>,
pub(crate) entities: EntityMap,
pub(crate) windows: SlotMap<WindowId, Option<Window>>,
pub(crate) keymap: Arc<Mutex<Keymap>>,
pub(crate) global_action_listeners:
HashMap<TypeId, Vec<Box<dyn Fn(&dyn Action, DispatchPhase, &mut Self) + Send>>>,
action_builders: HashMap<SharedString, ActionBuilder>,
pending_effects: VecDeque<Effect>,
pub(crate) pending_notifications: HashSet<EntityId>,
pub(crate) pending_global_notifications: HashSet<TypeId>,
pub(crate) observers: SubscriberSet<EntityId, Handler>,
pub(crate) event_listeners: SubscriberSet<EntityId, Listener>,
pub(crate) release_listeners: SubscriberSet<EntityId, ReleaseListener>,
pub(crate) global_observers: SubscriberSet<TypeId, Handler>,
pub(crate) quit_observers: SubscriberSet<(), QuitHandler>,
pub(crate) layout_id_buffer: Vec<LayoutId>, // We recycle this memory across layout requests.
pub(crate) propagate_event: bool,
}
impl AppContext {
pub(crate) fn new(
platform: Arc<dyn Platform>,
asset_source: Arc<dyn AssetSource>,
http_client: Arc<dyn HttpClient>,
) -> Arc<Mutex<Self>> {
let executor = platform.executor();
assert!(
executor.is_main_thread(),
"must construct App on main thread"
);
let text_system = Arc::new(TextSystem::new(platform.text_system()));
let entities = EntityMap::new();
let app_metadata = AppMetadata {
os_name: platform.os_name(),
os_version: platform.os_version().ok(),
app_version: platform.app_version().ok(),
};
Arc::new_cyclic(|this| {
Mutex::new(AppContext {
this: this.clone(),
text_system,
platform: MainThreadOnly::new(platform, executor.clone()),
app_metadata,
flushing_effects: false,
pending_updates: 0,
next_frame_callbacks: Default::default(),
executor,
svg_renderer: SvgRenderer::new(asset_source.clone()),
asset_source,
image_cache: ImageCache::new(http_client),
text_style_stack: Vec::new(),
globals_by_type: HashMap::default(),
entities,
windows: SlotMap::with_key(),
keymap: Arc::new(Mutex::new(Keymap::default())),
global_action_listeners: HashMap::default(),
action_builders: HashMap::default(),
pending_effects: VecDeque::new(),
pending_notifications: HashSet::default(),
pending_global_notifications: HashSet::default(),
observers: SubscriberSet::new(),
event_listeners: SubscriberSet::new(),
release_listeners: SubscriberSet::new(),
global_observers: SubscriberSet::new(),
quit_observers: SubscriberSet::new(),
layout_id_buffer: Default::default(),
propagate_event: true,
active_drag: None,
})
})
}
/// Quit the application gracefully. Handlers registered with `ModelContext::on_app_quit`
/// will be given 100ms to complete before exiting.
pub fn quit(&mut self) {
let mut futures = Vec::new();
self.quit_observers.clone().retain(&(), |observer| {
futures.push(observer(self));
true
});
self.windows.clear();
self.flush_effects();
let futures = futures::future::join_all(futures);
if self
.executor
.block_with_timeout(Duration::from_millis(100), futures)
.is_err()
{
log::error!("timed out waiting on app_will_quit");
}
self.globals_by_type.clear();
}
pub fn app_metadata(&self) -> AppMetadata {
self.app_metadata.clone()
}
/// Schedules all windows in the application to be redrawn. This can be called
/// multiple times in an update cycle and still result in a single redraw.
pub fn refresh(&mut self) {
self.pending_effects.push_back(Effect::Refresh);
}
pub(crate) fn update<R>(&mut self, update: impl FnOnce(&mut Self) -> R) -> R {
self.pending_updates += 1;
let result = update(self);
if !self.flushing_effects && self.pending_updates == 1 {
self.flushing_effects = true;
self.flush_effects();
self.flushing_effects = false;
}
self.pending_updates -= 1;
result
}
pub(crate) fn read_window<R>(
&mut self,
id: WindowId,
read: impl FnOnce(&WindowContext) -> R,
) -> Result<R> {
let window = self
.windows
.get(id)
.ok_or_else(|| anyhow!("window not found"))?
.as_ref()
.unwrap();
Ok(read(&WindowContext::immutable(self, &window)))
}
pub(crate) fn update_window<R>(
&mut self,
id: WindowId,
update: impl FnOnce(&mut WindowContext) -> R,
) -> Result<R> {
self.update(|cx| {
let mut window = cx
.windows
.get_mut(id)
.ok_or_else(|| anyhow!("window not found"))?
.take()
.unwrap();
let result = update(&mut WindowContext::mutable(cx, &mut window));
cx.windows
.get_mut(id)
.ok_or_else(|| anyhow!("window not found"))?
.replace(window);
Ok(result)
})
}
pub(crate) fn push_effect(&mut self, effect: Effect) {
match &effect {
Effect::Notify { emitter } => {
if !self.pending_notifications.insert(*emitter) {
return;
}
}
Effect::NotifyGlobalObservers { global_type } => {
if !self.pending_global_notifications.insert(*global_type) {
return;
}
}
_ => {}
};
self.pending_effects.push_back(effect);
}
/// Called at the end of AppContext::update to complete any side effects
/// such as notifying observers, emitting events, etc. Effects can themselves
/// cause effects, so we continue looping until all effects are processed.
fn flush_effects(&mut self) {
loop {
self.release_dropped_entities();
self.release_dropped_focus_handles();
if let Some(effect) = self.pending_effects.pop_front() {
match effect {
Effect::Notify { emitter } => {
self.apply_notify_effect(emitter);
}
Effect::Emit { emitter, event } => self.apply_emit_effect(emitter, event),
Effect::FocusChanged { window_id, focused } => {
self.apply_focus_changed_effect(window_id, focused);
}
Effect::Refresh => {
self.apply_refresh_effect();
}
Effect::NotifyGlobalObservers { global_type } => {
self.apply_notify_global_observers_effect(global_type);
}
Effect::Defer { callback } => {
self.apply_defer_effect(callback);
}
}
} else {
break;
}
}
let dirty_window_ids = self
.windows
.iter()
.filter_map(|(window_id, window)| {
let window = window.as_ref().unwrap();
if window.dirty {
Some(window_id)
} else {
None
}
})
.collect::<SmallVec<[_; 8]>>();
for dirty_window_id in dirty_window_ids {
self.update_window(dirty_window_id, |cx| cx.draw()).unwrap();
}
}
/// Repeatedly called during `flush_effects` to release any entities whose
/// reference count has become zero. We invoke any release observers before dropping
/// each entity.
fn release_dropped_entities(&mut self) {
loop {
let dropped = self.entities.take_dropped();
if dropped.is_empty() {
break;
}
for (entity_id, mut entity) in dropped {
self.observers.remove(&entity_id);
self.event_listeners.remove(&entity_id);
for mut release_callback in self.release_listeners.remove(&entity_id) {
release_callback(&mut entity, self);
}
}
}
}
/// Repeatedly called during `flush_effects` to handle a focused handle being dropped.
/// For now, we simply blur the window if this happens, but we may want to support invoking
/// a window blur handler to restore focus to some logical element.
fn release_dropped_focus_handles(&mut self) {
let window_ids = self.windows.keys().collect::<SmallVec<[_; 8]>>();
for window_id in window_ids {
self.update_window(window_id, |cx| {
let mut blur_window = false;
let focus = cx.window.focus;
cx.window.focus_handles.write().retain(|handle_id, count| {
if count.load(SeqCst) == 0 {
if focus == Some(handle_id) {
blur_window = true;
}
false
} else {
true
}
});
if blur_window {
cx.blur();
}
})
.unwrap();
}
}
fn apply_notify_effect(&mut self, emitter: EntityId) {
self.pending_notifications.remove(&emitter);
self.observers
.clone()
.retain(&emitter, |handler| handler(self));
}
fn apply_emit_effect(&mut self, emitter: EntityId, event: Box<dyn Any>) {
self.event_listeners
.clone()
.retain(&emitter, |handler| handler(event.as_ref(), self));
}
fn apply_focus_changed_effect(&mut self, window_id: WindowId, focused: Option<FocusId>) {
self.update_window(window_id, |cx| {
if cx.window.focus == focused {
let mut listeners = mem::take(&mut cx.window.focus_listeners);
let focused =
focused.map(|id| FocusHandle::for_id(id, &cx.window.focus_handles).unwrap());
let blurred = cx
.window
.last_blur
.take()
.unwrap()
.and_then(|id| FocusHandle::for_id(id, &cx.window.focus_handles));
if focused.is_some() || blurred.is_some() {
let event = FocusEvent { focused, blurred };
for listener in &listeners {
listener(&event, cx);
}
}
listeners.extend(cx.window.focus_listeners.drain(..));
cx.window.focus_listeners = listeners;
}
})
.ok();
}
fn apply_refresh_effect(&mut self) {
for window in self.windows.values_mut() {
if let Some(window) = window.as_mut() {
window.dirty = true;
}
}
}
fn apply_notify_global_observers_effect(&mut self, type_id: TypeId) {
self.pending_global_notifications.remove(&type_id);
self.global_observers
.clone()
.retain(&type_id, |observer| observer(self));
}
fn apply_defer_effect(&mut self, callback: Box<dyn FnOnce(&mut Self) + Send + 'static>) {
callback(self);
}
/// Creates an `AsyncAppContext`, which can be cloned and has a static lifetime
/// so it can be held across `await` points.
pub fn to_async(&self) -> AsyncAppContext {
AsyncAppContext {
app: unsafe { mem::transmute(self.this.clone()) },
executor: self.executor.clone(),
}
}
/// Obtains a reference to the executor, which can be used to spawn futures.
pub fn executor(&self) -> &Executor {
&self.executor
}
/// Runs the given closure on the main thread, where interaction with the platform
/// is possible. The given closure will be invoked with a `MainThread<AppContext>`, which
/// has platform-specific methods that aren't present on `AppContext`.
pub fn run_on_main<R>(
&mut self,
f: impl FnOnce(&mut MainThread<AppContext>) -> R + Send + 'static,
) -> Task<R>
where
R: Send + 'static,
{
if self.executor.is_main_thread() {
Task::ready(f(unsafe {
mem::transmute::<&mut AppContext, &mut MainThread<AppContext>>(self)
}))
} else {
let this = self.this.upgrade().unwrap();
self.executor.run_on_main(move || {
let cx = &mut *this.lock();
cx.update(|cx| f(unsafe { mem::transmute::<&mut Self, &mut MainThread<Self>>(cx) }))
})
}
}
/// Spawns the future returned by the given function on the main thread, where interaction with
/// the platform is possible. The given closure will be invoked with a `MainThread<AsyncAppContext>`,
/// which has platform-specific methods that aren't present on `AsyncAppContext`. The future will be
/// polled exclusively on the main thread.
// todo!("I think we need somehow to prevent the MainThread<AsyncAppContext> from implementing Send")
pub fn spawn_on_main<F, R>(
&self,
f: impl FnOnce(MainThread<AsyncAppContext>) -> F + Send + 'static,
) -> Task<R>
where
F: Future<Output = R> + 'static,
R: Send + 'static,
{
let cx = self.to_async();
self.executor.spawn_on_main(move || f(MainThread(cx)))
}
/// Spawns the future returned by the given function on the thread pool. The closure will be invoked
/// with AsyncAppContext, which allows the application state to be accessed across await points.
pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static) -> Task<R>
where
Fut: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
let cx = self.to_async();
self.executor.spawn(async move {
let future = f(cx);
future.await
})
}
/// Schedules the given function to be run at the end of the current effect cycle, allowing entities
/// that are currently on the stack to be returned to the app.
pub fn defer(&mut self, f: impl FnOnce(&mut AppContext) + 'static + Send) {
self.push_effect(Effect::Defer {
callback: Box::new(f),
});
}
/// Accessor for the application's asset source, which is provided when constructing the `App`.
pub fn asset_source(&self) -> &Arc<dyn AssetSource> {
&self.asset_source
}
/// Accessor for the text system.
pub fn text_system(&self) -> &Arc<TextSystem> {
&self.text_system
}
/// The current text style. Which is composed of all the style refinements provided to `with_text_style`.
pub fn text_style(&self) -> TextStyle {
let mut style = TextStyle::default();
for refinement in &self.text_style_stack {
style.refine(refinement);
}
style
}
/// Check whether a global of the given type has been assigned.
pub fn has_global<G: 'static>(&self) -> bool {
self.globals_by_type.contains_key(&TypeId::of::<G>())
}
/// Access the global of the given type. Panics if a global for that type has not been assigned.
pub fn global<G: 'static>(&self) -> &G {
self.globals_by_type
.get(&TypeId::of::<G>())
.map(|any_state| any_state.downcast_ref::<G>().unwrap())
.ok_or_else(|| anyhow!("no state of type {} exists", type_name::<G>()))
.unwrap()
}
/// Access the global of the given type if a value has been assigned.
pub fn try_global<G: 'static>(&self) -> Option<&G> {
self.globals_by_type
.get(&TypeId::of::<G>())
.map(|any_state| any_state.downcast_ref::<G>().unwrap())
}
/// Access the global of the given type mutably. Panics if a global for that type has not been assigned.
pub fn global_mut<G: 'static>(&mut self) -> &mut G {
let global_type = TypeId::of::<G>();
self.push_effect(Effect::NotifyGlobalObservers { global_type });
self.globals_by_type
.get_mut(&global_type)
.and_then(|any_state| any_state.downcast_mut::<G>())
.ok_or_else(|| anyhow!("no state of type {} exists", type_name::<G>()))
.unwrap()
}
/// Access the global of the given type mutably. A default value is assigned if a global of this type has not
/// yet been assigned.
pub fn default_global<G: 'static + Default + Send>(&mut self) -> &mut G {
let global_type = TypeId::of::<G>();
self.push_effect(Effect::NotifyGlobalObservers { global_type });
self.globals_by_type
.entry(global_type)
.or_insert_with(|| Box::new(G::default()))
.downcast_mut::<G>()
.unwrap()
}
/// Set the value of the global of the given type.
pub fn set_global<G: Any + Send>(&mut self, global: G) {
let global_type = TypeId::of::<G>();
self.push_effect(Effect::NotifyGlobalObservers { global_type });
self.globals_by_type.insert(global_type, Box::new(global));
}
/// Update the global of the given type with a closure. Unlike `global_mut`, this method provides
/// your closure with mutable access to the `AppContext` and the global simultaneously.
pub fn update_global<G: 'static, R>(&mut self, f: impl FnOnce(&mut G, &mut Self) -> R) -> R {
let mut global = self.lease_global::<G>();
let result = f(&mut global, self);
self.end_global_lease(global);
result
}
/// Register a callback to be invoked when a global of the given type is updated.
pub fn observe_global<G: 'static>(
&mut self,
mut f: impl FnMut(&mut Self) + Send + 'static,
) -> Subscription {
self.global_observers.insert(
TypeId::of::<G>(),
Box::new(move |cx| {
f(cx);
true
}),
)
}
pub fn all_action_names<'a>(&'a self) -> impl Iterator<Item = SharedString> + 'a {
self.action_builders.keys().cloned()
}
/// Move the global of the given type to the stack.
pub(crate) fn lease_global<G: 'static>(&mut self) -> GlobalLease<G> {
GlobalLease::new(
self.globals_by_type
.remove(&TypeId::of::<G>())
.ok_or_else(|| anyhow!("no global registered of type {}", type_name::<G>()))
.unwrap(),
)
}
/// Restore the global of the given type after it is moved to the stack.
pub(crate) fn end_global_lease<G: 'static>(&mut self, lease: GlobalLease<G>) {
let global_type = TypeId::of::<G>();
self.push_effect(Effect::NotifyGlobalObservers { global_type });
self.globals_by_type.insert(global_type, lease.global);
}
pub(crate) fn push_text_style(&mut self, text_style: TextStyleRefinement) {
self.text_style_stack.push(text_style);
}
pub(crate) fn pop_text_style(&mut self) {
self.text_style_stack.pop();
}
/// Register key bindings.
pub fn bind_keys(&mut self, bindings: impl IntoIterator<Item = KeyBinding>) {
self.keymap.lock().add_bindings(bindings);
self.pending_effects.push_back(Effect::Refresh);
}
/// Register a global listener for actions invoked via the keyboard.
pub fn on_action<A: Action>(&mut self, listener: impl Fn(&A, &mut Self) + Send + 'static) {
self.global_action_listeners
.entry(TypeId::of::<A>())
.or_default()
.push(Box::new(move |action, phase, cx| {
if phase == DispatchPhase::Bubble {
let action = action.as_any().downcast_ref().unwrap();
listener(action, cx)
}
}));
}
/// Register an action type to allow it to be referenced in keymaps.
pub fn register_action_type<A: Action>(&mut self) {
self.action_builders.insert(A::qualified_name(), A::build);
}
/// Construct an action based on its name and parameters.
pub fn build_action(
&mut self,
name: &str,
params: Option<serde_json::Value>,
) -> Result<Box<dyn Action>> {
let build = self
.action_builders
.get(name)
.ok_or_else(|| anyhow!("no action type registered for {}", name))?;
(build)(params)
}
/// Halt propagation of a mouse event, keyboard event, or action. This prevents listeners
/// that have not yet been invoked from receiving the event.
pub fn stop_propagation(&mut self) {
self.propagate_event = false;
}
}
impl Context for AppContext {
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T;
/// Build an entity that is owned by the application. The given function will be invoked with
/// a `ModelContext` and must return an object representing the entity. A `Model` will be returned
/// which can be used to access the entity in a context.
fn build_model<T: 'static + Send>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Model<T> {
self.update(|cx| {
let slot = cx.entities.reserve();
let entity = build_model(&mut ModelContext::mutable(cx, slot.downgrade()));
cx.entities.insert(slot, entity)
})
}
/// Update the entity referenced by the given model. The function is passed a mutable reference to the
/// entity along with a `ModelContext` for the entity.
fn update_model<T: 'static, R>(
&mut self,
model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
self.update(|cx| {
let mut entity = cx.entities.lease(model);
let result = update(
&mut entity,
&mut ModelContext::mutable(cx, model.downgrade()),
);
cx.entities.end_lease(entity);
result
})
}
}
impl<C> MainThread<C>
where
C: Borrow<AppContext>,
{
pub(crate) fn platform(&self) -> &dyn Platform {
self.0.borrow().platform.borrow_on_main_thread()
}
/// Instructs the platform to activate the application by bringing it to the foreground.
pub fn activate(&self, ignoring_other_apps: bool) {
self.platform().activate(ignoring_other_apps);
}
/// Writes data to the platform clipboard.
pub fn write_to_clipboard(&self, item: ClipboardItem) {
self.platform().write_to_clipboard(item)
}
/// Reads data from the platform clipboard.
pub fn read_from_clipboard(&self) -> Option<ClipboardItem> {
self.platform().read_from_clipboard()
}
/// Writes credentials to the platform keychain.
pub fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()> {
self.platform().write_credentials(url, username, password)
}
/// Reads credentials from the platform keychain.
pub fn read_credentials(&self, url: &str) -> Result<Option<(String, Vec<u8>)>> {
self.platform().read_credentials(url)
}
/// Deletes credentials from the platform keychain.
pub fn delete_credentials(&self, url: &str) -> Result<()> {
self.platform().delete_credentials(url)
}
/// Directs the platform's default browser to open the given URL.
pub fn open_url(&self, url: &str) {
self.platform().open_url(url);
}
pub fn path_for_auxiliary_executable(&self, name: &str) -> Result<PathBuf> {
self.platform().path_for_auxiliary_executable(name)
}
}
impl MainThread<AppContext> {
fn update<R>(&mut self, update: impl FnOnce(&mut Self) -> R) -> R {
self.0.update(|cx| {
update(unsafe {
std::mem::transmute::<&mut AppContext, &mut MainThread<AppContext>>(cx)
})
})
}
pub(crate) fn update_window<R>(
&mut self,
id: WindowId,
update: impl FnOnce(&mut MainThread<WindowContext>) -> R,
) -> Result<R> {
self.0.update_window(id, |cx| {
update(unsafe {
std::mem::transmute::<&mut WindowContext, &mut MainThread<WindowContext>>(cx)
})
})
}
/// Opens a new window with the given option and the root view returned by the given function.
/// The function is invoked with a `WindowContext`, which can be used to interact with window-specific
/// functionality.
pub fn open_window<V: Render>(
&mut self,
options: crate::WindowOptions,
build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static,
) -> WindowHandle<V> {
self.update(|cx| {
let id = cx.windows.insert(None);
let handle = WindowHandle::new(id);
let mut window = Window::new(handle.into(), options, cx);
let root_view = build_root_view(&mut WindowContext::mutable(cx, &mut window));
window.root_view.replace(root_view.into());
cx.windows.get_mut(id).unwrap().replace(window);
handle
})
}
/// Update the global of the given type with a closure. Unlike `global_mut`, this method provides
/// your closure with mutable access to the `MainThread<AppContext>` and the global simultaneously.
pub fn update_global<G: 'static + Send, R>(
&mut self,
update: impl FnOnce(&mut G, &mut MainThread<AppContext>) -> R,
) -> R {
self.0.update_global(|global, cx| {
let cx = unsafe { mem::transmute::<&mut AppContext, &mut MainThread<AppContext>>(cx) };
update(global, cx)
})
}
}
/// These effects are processed at the end of each application update cycle.
pub(crate) enum Effect {
Notify {
emitter: EntityId,
},
Emit {
emitter: EntityId,
event: Box<dyn Any + Send + 'static>,
},
FocusChanged {
window_id: WindowId,
focused: Option<FocusId>,
},
Refresh,
NotifyGlobalObservers {
global_type: TypeId,
},
Defer {
callback: Box<dyn FnOnce(&mut AppContext) + Send + 'static>,
},
}
/// Wraps a global variable value during `update_global` while the value has been moved to the stack.
pub(crate) struct GlobalLease<G: 'static> {
global: AnyBox,
global_type: PhantomData<G>,
}
impl<G: 'static> GlobalLease<G> {
fn new(global: AnyBox) -> Self {
GlobalLease {
global,
global_type: PhantomData,
}
}
}
impl<G: 'static> Deref for GlobalLease<G> {
type Target = G;
fn deref(&self) -> &Self::Target {
self.global.downcast_ref().unwrap()
}
}
impl<G: 'static> DerefMut for GlobalLease<G> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.global.downcast_mut().unwrap()
}
}
/// Contains state associated with an active drag operation, started by dragging an element
/// within the window or by dragging into the app from the underlying platform.
pub(crate) struct AnyDrag {
pub view: AnyView,
pub cursor_offset: Point<Pixels>,
}
#[cfg(test)]
mod tests {
use super::AppContext;
#[test]
fn test_app_context_send_sync() {
// This will not compile if `AppContext` does not implement `Send`
fn assert_send<T: Send>() {}
assert_send::<AppContext>();
}
}

View file

@ -0,0 +1,252 @@
use crate::{
AnyWindowHandle, AppContext, Context, Executor, MainThread, Model, ModelContext, Result, Task,
WindowContext,
};
use anyhow::anyhow;
use derive_more::{Deref, DerefMut};
use parking_lot::Mutex;
use std::{future::Future, sync::Weak};
#[derive(Clone)]
pub struct AsyncAppContext {
pub(crate) app: Weak<Mutex<AppContext>>,
pub(crate) executor: Executor,
}
impl Context for AsyncAppContext {
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = Result<T>;
fn build_model<T: 'static>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send,
{
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut lock = app.lock(); // Need this to compile
Ok(lock.build_model(build_model))
}
fn update_model<T: 'static, R>(
&mut self,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut lock = app.lock(); // Need this to compile
Ok(lock.update_model(handle, update))
}
}
impl AsyncAppContext {
pub fn refresh(&mut self) -> Result<()> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut lock = app.lock(); // Need this to compile
lock.refresh();
Ok(())
}
pub fn executor(&self) -> &Executor {
&self.executor
}
pub fn update<R>(&self, f: impl FnOnce(&mut AppContext) -> R) -> Result<R> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut lock = app.lock();
Ok(f(&mut *lock))
}
pub fn read_window<R>(
&self,
handle: AnyWindowHandle,
update: impl FnOnce(&WindowContext) -> R,
) -> Result<R> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut app_context = app.lock();
app_context.read_window(handle.id, update)
}
pub fn update_window<R>(
&self,
handle: AnyWindowHandle,
update: impl FnOnce(&mut WindowContext) -> R,
) -> Result<R> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut app_context = app.lock();
app_context.update_window(handle.id, update)
}
pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static) -> Task<R>
where
Fut: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
let this = self.clone();
self.executor.spawn(async move { f(this).await })
}
pub fn spawn_on_main<Fut, R>(
&self,
f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static,
) -> Task<R>
where
Fut: Future<Output = R> + 'static,
R: Send + 'static,
{
let this = self.clone();
self.executor.spawn_on_main(|| f(this))
}
pub fn run_on_main<R>(
&self,
f: impl FnOnce(&mut MainThread<AppContext>) -> R + Send + 'static,
) -> Result<Task<R>>
where
R: Send + 'static,
{
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut app_context = app.lock();
Ok(app_context.run_on_main(f))
}
pub fn has_global<G: 'static>(&self) -> Result<bool> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let lock = app.lock(); // Need this to compile
Ok(lock.has_global::<G>())
}
pub fn read_global<G: 'static, R>(&self, read: impl FnOnce(&G, &AppContext) -> R) -> Result<R> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let lock = app.lock(); // Need this to compile
Ok(read(lock.global(), &lock))
}
pub fn try_read_global<G: 'static, R>(
&self,
read: impl FnOnce(&G, &AppContext) -> R,
) -> Option<R> {
let app = self.app.upgrade()?;
let lock = app.lock(); // Need this to compile
Some(read(lock.try_global()?, &lock))
}
pub fn update_global<G: 'static, R>(
&mut self,
update: impl FnOnce(&mut G, &mut AppContext) -> R,
) -> Result<R> {
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut lock = app.lock(); // Need this to compile
Ok(lock.update_global(update))
}
}
#[derive(Clone, Deref, DerefMut)]
pub struct AsyncWindowContext {
#[deref]
#[deref_mut]
app: AsyncAppContext,
window: AnyWindowHandle,
}
impl AsyncWindowContext {
pub(crate) fn new(app: AsyncAppContext, window: AnyWindowHandle) -> Self {
Self { app, window }
}
pub fn update<R>(&self, update: impl FnOnce(&mut WindowContext) -> R) -> Result<R> {
self.app.update_window(self.window, update)
}
pub fn on_next_frame(&mut self, f: impl FnOnce(&mut WindowContext) + Send + 'static) {
self.app
.update_window(self.window, |cx| cx.on_next_frame(f))
.ok();
}
pub fn read_global<G: 'static, R>(
&self,
read: impl FnOnce(&G, &WindowContext) -> R,
) -> Result<R> {
self.app
.read_window(self.window, |cx| read(cx.global(), cx))
}
pub fn update_global<G, R>(
&mut self,
update: impl FnOnce(&mut G, &mut WindowContext) -> R,
) -> Result<R>
where
G: 'static,
{
self.app
.update_window(self.window, |cx| cx.update_global(update))
}
}
impl Context for AsyncWindowContext {
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = Result<T>;
fn build_model<T>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Result<Model<T>>
where
T: 'static + Send,
{
self.app
.update_window(self.window, |cx| cx.build_model(build_model))
}
fn update_model<T: 'static, R>(
&mut self,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Result<R> {
self.app
.update_window(self.window, |cx| cx.update_model(handle, update))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_async_app_context_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<AsyncAppContext>();
}
}

View file

@ -0,0 +1,501 @@
use crate::{private::Sealed, AnyBox, AppContext, Context, Entity};
use anyhow::{anyhow, Result};
use derive_more::{Deref, DerefMut};
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
use slotmap::{SecondaryMap, SlotMap};
use std::{
any::{type_name, TypeId},
fmt::{self, Display},
hash::{Hash, Hasher},
marker::PhantomData,
mem,
sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Arc, Weak,
},
};
slotmap::new_key_type! { pub struct EntityId; }
impl EntityId {
pub fn as_u64(self) -> u64 {
self.0.as_ffi()
}
}
impl Display for EntityId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self)
}
}
pub(crate) struct EntityMap {
entities: SecondaryMap<EntityId, AnyBox>,
ref_counts: Arc<RwLock<EntityRefCounts>>,
}
struct EntityRefCounts {
counts: SlotMap<EntityId, AtomicUsize>,
dropped_entity_ids: Vec<EntityId>,
}
impl EntityMap {
pub fn new() -> Self {
Self {
entities: SecondaryMap::new(),
ref_counts: Arc::new(RwLock::new(EntityRefCounts {
counts: SlotMap::with_key(),
dropped_entity_ids: Vec::new(),
})),
}
}
/// Reserve a slot for an entity, which you can subsequently use with `insert`.
pub fn reserve<T: 'static>(&self) -> Slot<T> {
let id = self.ref_counts.write().counts.insert(1.into());
Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
}
/// Insert an entity into a slot obtained by calling `reserve`.
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
where
T: 'static + Send,
{
let model = slot.0;
self.entities.insert(model.entity_id, Box::new(entity));
model
}
/// Move an entity to the stack.
pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
self.assert_valid_context(model);
let entity = Some(
self.entities
.remove(model.entity_id)
.expect("Circular entity lease. Is the entity already being updated?"),
);
Lease {
model,
entity,
entity_type: PhantomData,
}
}
/// Return an entity after moving it to the stack.
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
self.entities
.insert(lease.model.entity_id, lease.entity.take().unwrap());
}
pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
self.assert_valid_context(model);
self.entities[model.entity_id].downcast_ref().unwrap()
}
fn assert_valid_context(&self, model: &AnyModel) {
debug_assert!(
Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
"used a model with the wrong context"
);
}
pub fn take_dropped(&mut self) -> Vec<(EntityId, AnyBox)> {
let mut ref_counts = self.ref_counts.write();
let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids);
dropped_entity_ids
.into_iter()
.map(|entity_id| {
ref_counts.counts.remove(entity_id);
(entity_id, self.entities.remove(entity_id).unwrap())
})
.collect()
}
}
pub struct Lease<'a, T> {
entity: Option<AnyBox>,
pub model: &'a Model<T>,
entity_type: PhantomData<T>,
}
impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.entity.as_ref().unwrap().downcast_ref().unwrap()
}
}
impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.entity.as_mut().unwrap().downcast_mut().unwrap()
}
}
impl<'a, T> Drop for Lease<'a, T> {
fn drop(&mut self) {
if self.entity.is_some() {
// We don't panic here, because other panics can cause us to drop the lease without ending it cleanly.
log::error!("Leases must be ended with EntityMap::end_lease")
}
}
}
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Model<T>);
pub struct AnyModel {
pub(crate) entity_id: EntityId,
pub(crate) entity_type: TypeId,
entity_map: Weak<RwLock<EntityRefCounts>>,
}
impl AnyModel {
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
Self {
entity_id: id,
entity_type,
entity_map,
}
}
pub fn entity_id(&self) -> EntityId {
self.entity_id
}
pub fn downgrade(&self) -> AnyWeakModel {
AnyWeakModel {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_ref_counts: self.entity_map.clone(),
}
}
pub fn downcast<T: 'static>(self) -> Result<Model<T>, AnyModel> {
if TypeId::of::<T>() == self.entity_type {
Ok(Model {
any_model: self,
entity_type: PhantomData,
})
} else {
Err(self)
}
}
}
impl Clone for AnyModel {
fn clone(&self) -> Self {
if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.read();
let count = entity_map
.counts
.get(self.entity_id)
.expect("detected over-release of a model");
let prev_count = count.fetch_add(1, SeqCst);
assert_ne!(prev_count, 0, "Detected over-release of a model.");
}
Self {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_map: self.entity_map.clone(),
}
}
}
impl Drop for AnyModel {
fn drop(&mut self) {
if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.upgradable_read();
let count = entity_map
.counts
.get(self.entity_id)
.expect("Detected over-release of a model.");
let prev_count = count.fetch_sub(1, SeqCst);
assert_ne!(prev_count, 0, "Detected over-release of a model.");
if prev_count == 1 {
// We were the last reference to this entity, so we can remove it.
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
entity_map.dropped_entity_ids.push(self.entity_id);
}
}
}
}
impl<T> From<Model<T>> for AnyModel {
fn from(model: Model<T>) -> Self {
model.any_model
}
}
impl Hash for AnyModel {
fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state);
}
}
impl PartialEq for AnyModel {
fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id
}
}
impl Eq for AnyModel {}
impl std::fmt::Debug for AnyModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AnyModel")
.field("entity_id", &self.entity_id.as_u64())
.finish()
}
}
#[derive(Deref, DerefMut)]
pub struct Model<T> {
#[deref]
#[deref_mut]
pub(crate) any_model: AnyModel,
pub(crate) entity_type: PhantomData<T>,
}
unsafe impl<T> Send for Model<T> {}
unsafe impl<T> Sync for Model<T> {}
impl<T> Sealed for Model<T> {}
impl<T: 'static> Entity<T> for Model<T> {
type Weak = WeakModel<T>;
fn entity_id(&self) -> EntityId {
self.any_model.entity_id
}
fn downgrade(&self) -> Self::Weak {
WeakModel {
any_model: self.any_model.downgrade(),
entity_type: self.entity_type,
}
}
fn upgrade_from(weak: &Self::Weak) -> Option<Self>
where
Self: Sized,
{
Some(Model {
any_model: weak.any_model.upgrade()?,
entity_type: weak.entity_type,
})
}
}
impl<T: 'static> Model<T> {
fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
where
T: 'static,
{
Self {
any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
entity_type: PhantomData,
}
}
/// Downgrade the this to a weak model reference
pub fn downgrade(&self) -> WeakModel<T> {
// Delegate to the trait implementation to keep behavior in one place.
// This method was included to improve method resolution in the presence of
// the Model's deref
Entity::downgrade(self)
}
/// Convert this into a dynamically typed model.
pub fn into_any(self) -> AnyModel {
self.any_model
}
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
cx.entities.read(self)
}
/// Update the entity referenced by this model with the given function.
///
/// The update function receives a context appropriate for its environment.
/// When updating in an `AppContext`, it receives a `ModelContext`.
/// When updating an a `WindowContext`, it receives a `ViewContext`.
pub fn update<C, R>(
&self,
cx: &mut C,
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
) -> C::Result<R>
where
C: Context,
{
cx.update_model(self, update)
}
}
impl<T> Clone for Model<T> {
fn clone(&self) -> Self {
Self {
any_model: self.any_model.clone(),
entity_type: self.entity_type,
}
}
}
impl<T> std::fmt::Debug for Model<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Model {{ entity_id: {:?}, entity_type: {:?} }}",
self.any_model.entity_id,
type_name::<T>()
)
}
}
impl<T> Hash for Model<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.any_model.hash(state);
}
}
impl<T> PartialEq for Model<T> {
fn eq(&self, other: &Self) -> bool {
self.any_model == other.any_model
}
}
impl<T> Eq for Model<T> {}
impl<T> PartialEq<WeakModel<T>> for Model<T> {
fn eq(&self, other: &WeakModel<T>) -> bool {
self.any_model.entity_id() == other.entity_id()
}
}
#[derive(Clone)]
pub struct AnyWeakModel {
pub(crate) entity_id: EntityId,
entity_type: TypeId,
entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
impl AnyWeakModel {
pub fn entity_id(&self) -> EntityId {
self.entity_id
}
pub fn is_upgradable(&self) -> bool {
let ref_count = self
.entity_ref_counts
.upgrade()
.and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
.unwrap_or(0);
ref_count > 0
}
pub fn upgrade(&self) -> Option<AnyModel> {
let entity_map = self.entity_ref_counts.upgrade()?;
entity_map
.read()
.counts
.get(self.entity_id)?
.fetch_add(1, SeqCst);
Some(AnyModel {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_map: self.entity_ref_counts.clone(),
})
}
}
impl<T> From<WeakModel<T>> for AnyWeakModel {
fn from(model: WeakModel<T>) -> Self {
model.any_model
}
}
impl Hash for AnyWeakModel {
fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state);
}
}
impl PartialEq for AnyWeakModel {
fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id
}
}
impl Eq for AnyWeakModel {}
#[derive(Deref, DerefMut)]
pub struct WeakModel<T> {
#[deref]
#[deref_mut]
any_model: AnyWeakModel,
entity_type: PhantomData<T>,
}
unsafe impl<T> Send for WeakModel<T> {}
unsafe impl<T> Sync for WeakModel<T> {}
impl<T> Clone for WeakModel<T> {
fn clone(&self) -> Self {
Self {
any_model: self.any_model.clone(),
entity_type: self.entity_type,
}
}
}
impl<T: 'static> WeakModel<T> {
/// Upgrade this weak model reference into a strong model reference
pub fn upgrade(&self) -> Option<Model<T>> {
// Delegate to the trait implementation to keep behavior in one place.
Model::upgrade_from(self)
}
/// Update the entity referenced by this model with the given function if
/// the referenced entity still exists. Returns an error if the entity has
/// been released.
///
/// The update function receives a context appropriate for its environment.
/// When updating in an `AppContext`, it receives a `ModelContext`.
/// When updating an a `WindowContext`, it receives a `ViewContext`.
pub fn update<C, R>(
&self,
cx: &mut C,
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
) -> Result<R>
where
C: Context,
Result<C::Result<R>>: crate::Flatten<R>,
{
crate::Flatten::flatten(
self.upgrade()
.ok_or_else(|| anyhow!("entity release"))
.map(|this| cx.update_model(&this, update)),
)
}
}
impl<T> Hash for WeakModel<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.any_model.hash(state);
}
}
impl<T> PartialEq for WeakModel<T> {
fn eq(&self, other: &Self) -> bool {
self.any_model == other.any_model
}
}
impl<T> Eq for WeakModel<T> {}
impl<T> PartialEq<Model<T>> for WeakModel<T> {
fn eq(&self, other: &Model<T>) -> bool {
self.entity_id() == other.any_model.entity_id()
}
}

View file

@ -0,0 +1,266 @@
use crate::{
AppContext, AsyncAppContext, Context, Effect, Entity, EntityId, EventEmitter, MainThread,
Model, Reference, Subscription, Task, WeakModel,
};
use derive_more::{Deref, DerefMut};
use futures::FutureExt;
use std::{
any::{Any, TypeId},
borrow::{Borrow, BorrowMut},
future::Future,
};
#[derive(Deref, DerefMut)]
pub struct ModelContext<'a, T> {
#[deref]
#[deref_mut]
app: Reference<'a, AppContext>,
model_state: WeakModel<T>,
}
impl<'a, T: 'static> ModelContext<'a, T> {
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel<T>) -> Self {
Self {
app: Reference::Mutable(app),
model_state,
}
}
pub fn entity_id(&self) -> EntityId {
self.model_state.entity_id
}
pub fn handle(&self) -> Model<T> {
self.weak_model()
.upgrade()
.expect("The entity must be alive if we have a model context")
}
pub fn weak_model(&self) -> WeakModel<T> {
self.model_state.clone()
}
pub fn observe<T2, E>(
&mut self,
entity: &E,
mut on_notify: impl FnMut(&mut T, E, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
T: 'static + Send,
T2: 'static,
E: Entity<T2>,
{
let this = self.weak_model();
let entity_id = entity.entity_id();
let handle = entity.downgrade();
self.app.observers.insert(
entity_id,
Box::new(move |cx| {
if let Some((this, handle)) = this.upgrade().zip(E::upgrade_from(&handle)) {
this.update(cx, |this, cx| on_notify(this, handle, cx));
true
} else {
false
}
}),
)
}
pub fn subscribe<T2, E>(
&mut self,
entity: &E,
mut on_event: impl FnMut(&mut T, E, &T2::Event, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
T: 'static + Send,
T2: 'static + EventEmitter,
E: Entity<T2>,
{
let this = self.weak_model();
let entity_id = entity.entity_id();
let entity = entity.downgrade();
self.app.event_listeners.insert(
entity_id,
Box::new(move |event, cx| {
let event: &T2::Event = event.downcast_ref().expect("invalid event type");
if let Some((this, handle)) = this.upgrade().zip(E::upgrade_from(&entity)) {
this.update(cx, |this, cx| on_event(this, handle, event, cx));
true
} else {
false
}
}),
)
}
pub fn on_release(
&mut self,
mut on_release: impl FnMut(&mut T, &mut AppContext) + Send + 'static,
) -> Subscription
where
T: 'static,
{
self.app.release_listeners.insert(
self.model_state.entity_id,
Box::new(move |this, cx| {
let this = this.downcast_mut().expect("invalid entity type");
on_release(this, cx);
}),
)
}
pub fn observe_release<T2, E>(
&mut self,
entity: &E,
mut on_release: impl FnMut(&mut T, &mut T2, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
T: Any + Send,
T2: 'static,
E: Entity<T2>,
{
let entity_id = entity.entity_id();
let this = self.weak_model();
self.app.release_listeners.insert(
entity_id,
Box::new(move |entity, cx| {
let entity = entity.downcast_mut().expect("invalid entity type");
if let Some(this) = this.upgrade() {
this.update(cx, |this, cx| on_release(this, entity, cx));
}
}),
)
}
pub fn observe_global<G: 'static>(
&mut self,
mut f: impl FnMut(&mut T, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
T: 'static + Send,
{
let handle = self.weak_model();
self.global_observers.insert(
TypeId::of::<G>(),
Box::new(move |cx| handle.update(cx, |view, cx| f(view, cx)).is_ok()),
)
}
pub fn on_app_quit<Fut>(
&mut self,
mut on_quit: impl FnMut(&mut T, &mut ModelContext<T>) -> Fut + Send + 'static,
) -> Subscription
where
Fut: 'static + Future<Output = ()> + Send,
T: 'static + Send,
{
let handle = self.weak_model();
self.app.quit_observers.insert(
(),
Box::new(move |cx| {
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
async move {
if let Some(future) = future {
future.await;
}
}
.boxed()
}),
)
}
pub fn notify(&mut self) {
if self
.app
.pending_notifications
.insert(self.model_state.entity_id)
{
self.app.pending_effects.push_back(Effect::Notify {
emitter: self.model_state.entity_id,
});
}
}
pub fn update_global<G, R>(&mut self, f: impl FnOnce(&mut G, &mut Self) -> R) -> R
where
G: 'static + Send,
{
let mut global = self.app.lease_global::<G>();
let result = f(&mut global, self);
self.app.end_global_lease(global);
result
}
pub fn spawn<Fut, R>(
&self,
f: impl FnOnce(WeakModel<T>, AsyncAppContext) -> Fut + Send + 'static,
) -> Task<R>
where
T: 'static,
Fut: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
let this = self.weak_model();
self.app.spawn(|cx| f(this, cx))
}
pub fn spawn_on_main<Fut, R>(
&self,
f: impl FnOnce(WeakModel<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
) -> Task<R>
where
Fut: Future<Output = R> + 'static,
R: Send + 'static,
{
let this = self.weak_model();
self.app.spawn_on_main(|cx| f(this, cx))
}
}
impl<'a, T> ModelContext<'a, T>
where
T: EventEmitter,
T::Event: Send,
{
pub fn emit(&mut self, event: T::Event) {
self.app.pending_effects.push_back(Effect::Emit {
emitter: self.model_state.entity_id,
event: Box::new(event),
});
}
}
impl<'a, T> Context for ModelContext<'a, T> {
type ModelContext<'b, U> = ModelContext<'b, U>;
type Result<U> = U;
fn build_model<U>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, U>) -> U,
) -> Model<U>
where
U: 'static + Send,
{
self.app.build_model(build_model)
}
fn update_model<U: 'static, R>(
&mut self,
handle: &Model<U>,
update: impl FnOnce(&mut U, &mut Self::ModelContext<'_, U>) -> R,
) -> R {
self.app.update_model(handle, update)
}
}
impl<T> Borrow<AppContext> for ModelContext<'_, T> {
fn borrow(&self) -> &AppContext {
&self.app
}
}
impl<T> BorrowMut<AppContext> for ModelContext<'_, T> {
fn borrow_mut(&mut self) -> &mut AppContext {
&mut self.app
}
}

View file

@ -0,0 +1,152 @@
use crate::{
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, MainThread, Model,
ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext,
};
use parking_lot::Mutex;
use std::{future::Future, sync::Arc};
#[derive(Clone)]
pub struct TestAppContext {
pub app: Arc<Mutex<AppContext>>,
pub executor: Executor,
}
impl Context for TestAppContext {
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T;
fn build_model<T: 'static>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send,
{
let mut lock = self.app.lock();
lock.build_model(build_model)
}
fn update_model<T: 'static, R>(
&mut self,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> {
let mut lock = self.app.lock();
lock.update_model(handle, update)
}
}
impl TestAppContext {
pub fn new(dispatcher: TestDispatcher) -> Self {
let executor = Executor::new(Arc::new(dispatcher));
let platform = Arc::new(TestPlatform::new(executor.clone()));
let asset_source = Arc::new(());
let http_client = util::http::FakeHttpClient::with_404_response();
Self {
app: AppContext::new(platform, asset_source, http_client),
executor,
}
}
pub fn quit(&self) {
self.app.lock().quit();
}
pub fn refresh(&mut self) -> Result<()> {
let mut lock = self.app.lock();
lock.refresh();
Ok(())
}
pub fn executor(&self) -> &Executor {
&self.executor
}
pub fn update<R>(&self, f: impl FnOnce(&mut AppContext) -> R) -> R {
let mut lock = self.app.lock();
f(&mut *lock)
}
pub fn read_window<R>(
&self,
handle: AnyWindowHandle,
read: impl FnOnce(&WindowContext) -> R,
) -> R {
let mut app_context = self.app.lock();
app_context.read_window(handle.id, read).unwrap()
}
pub fn update_window<R>(
&self,
handle: AnyWindowHandle,
update: impl FnOnce(&mut WindowContext) -> R,
) -> R {
let mut app = self.app.lock();
app.update_window(handle.id, update).unwrap()
}
pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static) -> Task<R>
where
Fut: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
let cx = self.to_async();
self.executor.spawn(async move { f(cx).await })
}
pub fn spawn_on_main<Fut, R>(
&self,
f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static,
) -> Task<R>
where
Fut: Future<Output = R> + 'static,
R: Send + 'static,
{
let cx = self.to_async();
self.executor.spawn_on_main(|| f(cx))
}
pub fn run_on_main<R>(
&self,
f: impl FnOnce(&mut MainThread<AppContext>) -> R + Send + 'static,
) -> Task<R>
where
R: Send + 'static,
{
let mut app_context = self.app.lock();
app_context.run_on_main(f)
}
pub fn has_global<G: 'static>(&self) -> bool {
let lock = self.app.lock();
lock.has_global::<G>()
}
pub fn read_global<G: 'static, R>(&self, read: impl FnOnce(&G, &AppContext) -> R) -> R {
let lock = self.app.lock();
read(lock.global(), &lock)
}
pub fn try_read_global<G: 'static, R>(
&self,
read: impl FnOnce(&G, &AppContext) -> R,
) -> Option<R> {
let lock = self.app.lock();
Some(read(lock.try_global()?, &lock))
}
pub fn update_global<G: 'static, R>(
&mut self,
update: impl FnOnce(&mut G, &mut AppContext) -> R,
) -> R {
let mut lock = self.app.lock();
lock.update_global(update)
}
pub fn to_async(&self) -> AsyncAppContext {
AsyncAppContext {
app: Arc::downgrade(&self.app),
executor: self.executor.clone(),
}
}
}

View file

@ -0,0 +1,64 @@
use crate::{size, DevicePixels, Result, SharedString, Size};
use anyhow::anyhow;
use image::{Bgra, ImageBuffer};
use std::{
borrow::Cow,
fmt,
hash::Hash,
sync::atomic::{AtomicUsize, Ordering::SeqCst},
};
pub trait AssetSource: 'static + Send + Sync {
fn load(&self, path: &str) -> Result<Cow<[u8]>>;
fn list(&self, path: &str) -> Result<Vec<SharedString>>;
}
impl AssetSource for () {
fn load(&self, path: &str) -> Result<Cow<[u8]>> {
Err(anyhow!(
"get called on empty asset provider with \"{}\"",
path
))
}
fn list(&self, _path: &str) -> Result<Vec<SharedString>> {
Ok(vec![])
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ImageId(usize);
pub struct ImageData {
pub id: ImageId,
data: ImageBuffer<Bgra<u8>, Vec<u8>>,
}
impl ImageData {
pub fn new(data: ImageBuffer<Bgra<u8>, Vec<u8>>) -> Self {
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
Self {
id: ImageId(NEXT_ID.fetch_add(1, SeqCst)),
data,
}
}
pub fn as_bytes(&self) -> &[u8] {
&self.data
}
pub fn size(&self) -> Size<DevicePixels> {
let (width, height) = self.data.dimensions();
size(width.into(), height.into())
}
}
impl fmt::Debug for ImageData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ImageData")
.field("id", &self.id)
.field("size", &self.data.dimensions())
.finish()
}
}

View file

@ -1,9 +1,8 @@
#![allow(dead_code)]
use serde::de::{self, Deserialize, Deserializer, Visitor};
use smallvec::SmallVec;
use std::fmt;
use std::{num::ParseIntError, ops::Range};
use std::num::ParseIntError;
pub fn rgb<C: From<Rgba>>(hex: u32) -> C {
let r = ((hex >> 16) & 0xFF) as f32 / 255.0;
@ -12,7 +11,15 @@ pub fn rgb<C: From<Rgba>>(hex: u32) -> C {
Rgba { r, g, b, a: 1.0 }.into()
}
#[derive(Clone, Copy, Default, Debug)]
pub fn rgba(hex: u32) -> Rgba {
let r = ((hex >> 24) & 0xFF) as f32 / 255.0;
let g = ((hex >> 16) & 0xFF) as f32 / 255.0;
let b = ((hex >> 8) & 0xFF) as f32 / 255.0;
let a = (hex & 0xFF) as f32 / 255.0;
Rgba { r, g, b, a }
}
#[derive(Clone, Copy, Default)]
pub struct Rgba {
pub r: f32,
pub g: f32,
@ -20,6 +27,39 @@ pub struct Rgba {
pub a: f32,
}
impl fmt::Debug for Rgba {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "rgba({:#010x})", u32::from(*self))
}
}
impl Rgba {
pub fn blend(&self, other: Rgba) -> Self {
if other.a >= 1.0 {
return other;
} else if other.a <= 0.0 {
return *self;
} else {
return Rgba {
r: (self.r * (1.0 - other.a)) + (other.r * other.a),
g: (self.g * (1.0 - other.a)) + (other.g * other.a),
b: (self.b * (1.0 - other.a)) + (other.b * other.a),
a: self.a,
};
}
}
}
impl From<Rgba> for u32 {
fn from(rgba: Rgba) -> Self {
let r = (rgba.r * 255.0) as u32;
let g = (rgba.g * 255.0) as u32;
let b = (rgba.b * 255.0) as u32;
let a = (rgba.a * 255.0) as u32;
(r << 24) | (g << 16) | (b << 8) | a
}
}
struct RgbaVisitor;
impl<'de> Visitor<'de> for RgbaVisitor {
@ -54,33 +94,6 @@ impl<'de> Deserialize<'de> for Rgba {
}
}
pub trait Lerp {
fn lerp(&self, level: f32) -> Hsla;
}
impl Lerp for Range<Hsla> {
fn lerp(&self, level: f32) -> Hsla {
let level = level.clamp(0., 1.);
Hsla {
h: self.start.h + (level * (self.end.h - self.start.h)),
s: self.start.s + (level * (self.end.s - self.start.s)),
l: self.start.l + (level * (self.end.l - self.start.l)),
a: self.start.a + (level * (self.end.a - self.start.a)),
}
}
}
impl From<gpui::color::Color> for Rgba {
fn from(value: gpui::color::Color) -> Self {
Self {
r: value.0.r as f32 / 255.0,
g: value.0.g as f32 / 255.0,
b: value.0.b as f32 / 255.0,
a: value.0.a as f32 / 255.0,
}
}
}
impl From<Hsla> for Rgba {
fn from(color: Hsla) -> Self {
let h = color.h;
@ -128,13 +141,8 @@ impl TryFrom<&'_ str> for Rgba {
}
}
impl Into<gpui::color::Color> for Rgba {
fn into(self) -> gpui::color::Color {
gpui::color::rgba(self.r, self.g, self.b, self.a)
}
}
#[derive(Default, Copy, Clone, Debug, PartialEq)]
#[repr(C)]
pub struct Hsla {
pub h: f32,
pub s: f32,
@ -142,6 +150,14 @@ pub struct Hsla {
pub a: f32,
}
impl Hsla {
pub fn to_rgb(self) -> Rgba {
self.into()
}
}
impl Eq for Hsla {}
pub fn hsla(h: f32, s: f32, l: f32, a: f32) -> Hsla {
Hsla {
h: h.clamp(0., 1.),
@ -160,6 +176,73 @@ pub fn black() -> Hsla {
}
}
pub fn white() -> Hsla {
Hsla {
h: 0.,
s: 0.,
l: 1.,
a: 1.,
}
}
pub fn red() -> Hsla {
Hsla {
h: 0.,
s: 1.,
l: 0.5,
a: 1.,
}
}
impl Hsla {
/// Returns true if the HSLA color is fully transparent, false otherwise.
pub fn is_transparent(&self) -> bool {
self.a == 0.0
}
/// Blends `other` on top of `self` based on `other`'s alpha value. The resulting color is a combination of `self`'s and `other`'s colors.
///
/// If `other`'s alpha value is 1.0 or greater, `other` color is fully opaque, thus `other` is returned as the output color.
/// If `other`'s alpha value is 0.0 or less, `other` color is fully transparent, thus `self` is returned as the output color.
/// Else, the output color is calculated as a blend of `self` and `other` based on their weighted alpha values.
///
/// Assumptions:
/// - Alpha values are contained in the range [0, 1], with 1 as fully opaque and 0 as fully transparent.
/// - The relative contributions of `self` and `other` is based on `self`'s alpha value (`self.a`) and `other`'s alpha value (`other.a`), `self` contributing `self.a * (1.0 - other.a)` and `other` contributing it's own alpha value.
/// - RGB color components are contained in the range [0, 1].
/// - If `self` and `other` colors are out of the valid range, the blend operation's output and behavior is undefined.
pub fn blend(self, other: Hsla) -> Hsla {
let alpha = other.a;
if alpha >= 1.0 {
return other;
} else if alpha <= 0.0 {
return self;
} else {
let converted_self = Rgba::from(self);
let converted_other = Rgba::from(other);
let blended_rgb = converted_self.blend(converted_other);
return Hsla::from(blended_rgb);
}
}
/// Fade out the color by a given factor. This factor should be between 0.0 and 1.0.
/// Where 0.0 will leave the color unchanged, and 1.0 will completely fade out the color.
pub fn fade_out(&mut self, factor: f32) {
self.a *= 1.0 - factor.clamp(0., 1.);
}
}
// impl From<Hsla> for Rgba {
// fn from(value: Hsla) -> Self {
// let h = value.h;
// let s = value.s;
// let l = value.l;
// let c = (1 - |2L - 1|) X s
// }
// }
impl From<Rgba> for Hsla {
fn from(color: Rgba) -> Self {
let r = color.r;
@ -198,62 +281,6 @@ impl From<Rgba> for Hsla {
}
}
impl Hsla {
/// Scales the saturation and lightness by the given values, clamping at 1.0.
pub fn scale_sl(mut self, s: f32, l: f32) -> Self {
self.s = (self.s * s).clamp(0., 1.);
self.l = (self.l * l).clamp(0., 1.);
self
}
/// Increases the saturation of the color by a certain amount, with a max
/// value of 1.0.
pub fn saturate(mut self, amount: f32) -> Self {
self.s += amount;
self.s = self.s.clamp(0.0, 1.0);
self
}
/// Decreases the saturation of the color by a certain amount, with a min
/// value of 0.0.
pub fn desaturate(mut self, amount: f32) -> Self {
self.s -= amount;
self.s = self.s.max(0.0);
if self.s < 0.0 {
self.s = 0.0;
}
self
}
/// Brightens the color by increasing the lightness by a certain amount,
/// with a max value of 1.0.
pub fn brighten(mut self, amount: f32) -> Self {
self.l += amount;
self.l = self.l.clamp(0.0, 1.0);
self
}
/// Darkens the color by decreasing the lightness by a certain amount,
/// with a max value of 0.0.
pub fn darken(mut self, amount: f32) -> Self {
self.l -= amount;
self.l = self.l.clamp(0.0, 1.0);
self
}
}
impl From<gpui::color::Color> for Hsla {
fn from(value: gpui::color::Color) -> Self {
Rgba::from(value).into()
}
}
impl Into<gpui::color::Color> for Hsla {
fn into(self) -> gpui::color::Color {
Rgba::from(self).into()
}
}
impl<'de> Deserialize<'de> for Hsla {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@ -266,59 +293,3 @@ impl<'de> Deserialize<'de> for Hsla {
Ok(Hsla::from(rgba))
}
}
pub struct ColorScale {
colors: SmallVec<[Hsla; 2]>,
positions: SmallVec<[f32; 2]>,
}
pub fn scale<I, C>(colors: I) -> ColorScale
where
I: IntoIterator<Item = C>,
C: Into<Hsla>,
{
let mut scale = ColorScale {
colors: colors.into_iter().map(Into::into).collect(),
positions: SmallVec::new(),
};
let num_colors: f32 = scale.colors.len() as f32 - 1.0;
scale.positions = (0..scale.colors.len())
.map(|i| i as f32 / num_colors)
.collect();
scale
}
impl ColorScale {
fn at(&self, t: f32) -> Hsla {
// Ensure that the input is within [0.0, 1.0]
debug_assert!(
0.0 <= t && t <= 1.0,
"t value {} is out of range. Expected value in range 0.0 to 1.0",
t
);
let position = match self
.positions
.binary_search_by(|a| a.partial_cmp(&t).unwrap())
{
Ok(index) | Err(index) => index,
};
let lower_bound = position.saturating_sub(1);
let upper_bound = position.min(self.colors.len() - 1);
let lower_color = &self.colors[lower_bound];
let upper_color = &self.colors[upper_bound];
match upper_bound.checked_sub(lower_bound) {
Some(0) | None => *lower_color,
Some(_) => {
let interval_t = (t - self.positions[lower_bound])
/ (self.positions[upper_bound] - self.positions[lower_bound]);
let h = lower_color.h + interval_t * (upper_color.h - lower_color.h);
let s = lower_color.s + interval_t * (upper_color.s - lower_color.s);
let l = lower_color.l + interval_t * (upper_color.l - lower_color.l);
let a = lower_color.a + interval_t * (upper_color.a - lower_color.a);
Hsla { h, s, l, a }
}
}
}
}

View file

@ -1,51 +1,205 @@
pub use crate::ViewContext;
use anyhow::Result;
use gpui::geometry::vector::Vector2F;
pub use gpui::{Layout, LayoutId};
use smallvec::SmallVec;
use crate::{BorrowWindow, Bounds, ElementId, LayoutId, Pixels, ViewContext};
use derive_more::{Deref, DerefMut};
pub(crate) use smallvec::SmallVec;
use std::{any::Any, mem};
pub trait Element<V: 'static>: 'static + IntoElement<V> {
type PaintState;
pub trait Element<V: 'static> {
type ElementState: 'static + Send;
fn id(&self) -> Option<ElementId>;
/// Called to initialize this element for the current frame. If this
/// element had state in a previous frame, it will be passed in for the 3rd argument.
fn initialize(
&mut self,
view_state: &mut V,
element_state: Option<Self::ElementState>,
cx: &mut ViewContext<V>,
) -> Self::ElementState;
fn layout(
&mut self,
view: &mut V,
view_state: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) -> Result<(LayoutId, Self::PaintState)>
where
Self: Sized;
) -> LayoutId;
fn paint(
&mut self,
view: &mut V,
parent_origin: Vector2F,
layout: &Layout,
state: &mut Self::PaintState,
bounds: Bounds<Pixels>,
view_state: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) where
Self: Sized;
);
}
fn into_any(self) -> AnyElement<V>
#[derive(Deref, DerefMut, Default, Clone, Debug, Eq, PartialEq, Hash)]
pub struct GlobalElementId(SmallVec<[ElementId; 32]>);
pub trait ParentElement<V: 'static> {
fn children_mut(&mut self) -> &mut SmallVec<[AnyElement<V>; 2]>;
fn child(mut self, child: impl Component<V>) -> Self
where
Self: 'static + Sized,
Self: Sized,
{
AnyElement(Box::new(StatefulElement {
element: self,
phase: ElementPhase::Init,
}))
self.children_mut().push(child.render());
self
}
/// Applies a given function `then` to the current element if `condition` is true.
/// This function is used to conditionally modify the element based on a given condition.
/// If `condition` is false, it just returns the current element as it is.
///
/// # Parameters
/// - `self`: The current element
/// - `condition`: The boolean condition based on which `then` is applied to the element.
/// - `then`: A function that takes in the current element and returns a possibly modified element.
///
/// # Return
/// It returns the potentially modified element.
fn children(mut self, iter: impl IntoIterator<Item = impl Component<V>>) -> Self
where
Self: Sized,
{
self.children_mut()
.extend(iter.into_iter().map(|item| item.render()));
self
}
}
trait ElementObject<V> {
fn initialize(&mut self, view_state: &mut V, cx: &mut ViewContext<V>);
fn layout(&mut self, view_state: &mut V, cx: &mut ViewContext<V>) -> LayoutId;
fn paint(&mut self, view_state: &mut V, cx: &mut ViewContext<V>);
}
struct RenderedElement<V: 'static, E: Element<V>> {
element: E,
phase: ElementRenderPhase<E::ElementState>,
}
#[derive(Default)]
enum ElementRenderPhase<V> {
#[default]
Start,
Initialized {
frame_state: Option<V>,
},
LayoutRequested {
layout_id: LayoutId,
frame_state: Option<V>,
},
Painted,
}
/// Internal struct that wraps an element to store Layout and ElementState after the element is rendered.
/// It's allocated as a trait object to erase the element type and wrapped in AnyElement<E::State> for
/// improved usability.
impl<V, E: Element<V>> RenderedElement<V, E> {
fn new(element: E) -> Self {
RenderedElement {
element,
phase: ElementRenderPhase::Start,
}
}
}
impl<V, E> ElementObject<V> for RenderedElement<V, E>
where
E: Element<V>,
E::ElementState: 'static + Send,
{
fn initialize(&mut self, view_state: &mut V, cx: &mut ViewContext<V>) {
let frame_state = if let Some(id) = self.element.id() {
cx.with_element_state(id, |element_state, cx| {
let element_state = self.element.initialize(view_state, element_state, cx);
((), element_state)
});
None
} else {
let frame_state = self.element.initialize(view_state, None, cx);
Some(frame_state)
};
self.phase = ElementRenderPhase::Initialized { frame_state };
}
fn layout(&mut self, state: &mut V, cx: &mut ViewContext<V>) -> LayoutId {
let layout_id;
let mut frame_state;
match mem::take(&mut self.phase) {
ElementRenderPhase::Initialized {
frame_state: initial_frame_state,
} => {
frame_state = initial_frame_state;
if let Some(id) = self.element.id() {
layout_id = cx.with_element_state(id, |element_state, cx| {
let mut element_state = element_state.unwrap();
let layout_id = self.element.layout(state, &mut element_state, cx);
(layout_id, element_state)
});
} else {
layout_id = self
.element
.layout(state, frame_state.as_mut().unwrap(), cx);
}
}
_ => panic!("must call initialize before layout"),
};
self.phase = ElementRenderPhase::LayoutRequested {
layout_id,
frame_state,
};
layout_id
}
fn paint(&mut self, view_state: &mut V, cx: &mut ViewContext<V>) {
self.phase = match mem::take(&mut self.phase) {
ElementRenderPhase::LayoutRequested {
layout_id,
mut frame_state,
} => {
let bounds = cx.layout_bounds(layout_id);
if let Some(id) = self.element.id() {
cx.with_element_state(id, |element_state, cx| {
let mut element_state = element_state.unwrap();
self.element
.paint(bounds, view_state, &mut element_state, cx);
((), element_state)
});
} else {
self.element
.paint(bounds, view_state, frame_state.as_mut().unwrap(), cx);
}
ElementRenderPhase::Painted
}
_ => panic!("must call layout before paint"),
};
}
}
pub struct AnyElement<V>(Box<dyn ElementObject<V> + Send>);
unsafe impl<V> Send for AnyElement<V> {}
impl<V> AnyElement<V> {
pub fn new<E>(element: E) -> Self
where
V: 'static,
E: 'static + Element<V> + Send,
E::ElementState: Any + Send,
{
AnyElement(Box::new(RenderedElement::new(element)))
}
pub fn initialize(&mut self, view_state: &mut V, cx: &mut ViewContext<V>) {
self.0.initialize(view_state, cx);
}
pub fn layout(&mut self, view_state: &mut V, cx: &mut ViewContext<V>) -> LayoutId {
self.0.layout(view_state, cx)
}
pub fn paint(&mut self, view_state: &mut V, cx: &mut ViewContext<V>) {
self.0.paint(view_state, cx)
}
}
pub trait Component<V> {
fn render(self) -> AnyElement<V>;
fn when(mut self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self
where
Self: Sized,
@ -57,176 +211,74 @@ pub trait Element<V: 'static>: 'static + IntoElement<V> {
}
}
/// Used to make ElementState<V, E> into a trait object, so we can wrap it in AnyElement<V>.
trait AnyStatefulElement<V> {
fn layout(&mut self, view: &mut V, cx: &mut ViewContext<V>) -> Result<LayoutId>;
fn paint(&mut self, view: &mut V, parent_origin: Vector2F, cx: &mut ViewContext<V>);
}
/// A wrapper around an element that stores its layout state.
struct StatefulElement<V: 'static, E: Element<V>> {
element: E,
phase: ElementPhase<V, E>,
}
enum ElementPhase<V: 'static, E: Element<V>> {
Init,
PostLayout {
layout_id: LayoutId,
paint_state: E::PaintState,
},
#[allow(dead_code)]
PostPaint {
layout: Layout,
paint_state: E::PaintState,
},
Error(String),
}
impl<V: 'static, E: Element<V>> std::fmt::Debug for ElementPhase<V, E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ElementPhase::Init => write!(f, "Init"),
ElementPhase::PostLayout { layout_id, .. } => {
write!(f, "PostLayout with layout id: {:?}", layout_id)
}
ElementPhase::PostPaint { layout, .. } => {
write!(f, "PostPaint with layout: {:?}", layout)
}
ElementPhase::Error(err) => write!(f, "Error: {}", err),
}
}
}
impl<V: 'static, E: Element<V>> Default for ElementPhase<V, E> {
fn default() -> Self {
Self::Init
}
}
/// We blanket-implement the object-safe ElementStateObject interface to make ElementStates into trait objects
impl<V, E: Element<V>> AnyStatefulElement<V> for StatefulElement<V, E> {
fn layout(&mut self, view: &mut V, cx: &mut ViewContext<V>) -> Result<LayoutId> {
let result;
self.phase = match self.element.layout(view, cx) {
Ok((layout_id, paint_state)) => {
result = Ok(layout_id);
ElementPhase::PostLayout {
layout_id,
paint_state,
}
}
Err(error) => {
let message = error.to_string();
result = Err(error);
ElementPhase::Error(message)
}
};
result
}
fn paint(&mut self, view: &mut V, parent_origin: Vector2F, cx: &mut ViewContext<V>) {
self.phase = match std::mem::take(&mut self.phase) {
ElementPhase::PostLayout {
layout_id,
mut paint_state,
} => match cx.computed_layout(layout_id) {
Ok(layout) => {
self.element
.paint(view, parent_origin, &layout, &mut paint_state, cx);
ElementPhase::PostPaint {
layout,
paint_state,
}
}
Err(error) => ElementPhase::Error(error.to_string()),
},
ElementPhase::PostPaint {
layout,
mut paint_state,
} => {
self.element
.paint(view, parent_origin, &layout, &mut paint_state, cx);
ElementPhase::PostPaint {
layout,
paint_state,
}
}
phase @ ElementPhase::Error(_) => phase,
phase @ _ => {
panic!("invalid element phase to call paint: {:?}", phase);
}
};
}
}
/// A dynamic element.
pub struct AnyElement<V>(Box<dyn AnyStatefulElement<V>>);
impl<V> AnyElement<V> {
pub fn layout(&mut self, view: &mut V, cx: &mut ViewContext<V>) -> Result<LayoutId> {
self.0.layout(view, cx)
}
pub fn paint(&mut self, view: &mut V, parent_origin: Vector2F, cx: &mut ViewContext<V>) {
self.0.paint(view, parent_origin, cx)
}
}
pub trait ParentElement<V: 'static> {
fn children_mut(&mut self) -> &mut SmallVec<[AnyElement<V>; 2]>;
fn child(mut self, child: impl IntoElement<V>) -> Self
where
Self: Sized,
{
self.children_mut().push(child.into_element().into_any());
self
}
fn children<I, E>(mut self, children: I) -> Self
where
I: IntoIterator<Item = E>,
E: IntoElement<V>,
Self: Sized,
{
self.children_mut().extend(
children
.into_iter()
.map(|child| child.into_element().into_any()),
);
self
}
// HACK: This is a temporary hack to get children working for the purposes
// of building UI on top of the current version of gpui2.
//
// We'll (hopefully) be moving away from this in the future.
fn children_any<I>(mut self, children: I) -> Self
where
I: IntoIterator<Item = AnyElement<V>>,
Self: Sized,
{
self.children_mut().extend(children.into_iter());
self
}
// HACK: This is a temporary hack to get children working for the purposes
// of building UI on top of the current version of gpui2.
//
// We'll (hopefully) be moving away from this in the future.
fn child_any(mut self, children: AnyElement<V>) -> Self
where
Self: Sized,
{
self.children_mut().push(children);
impl<V> Component<V> for AnyElement<V> {
fn render(self) -> AnyElement<V> {
self
}
}
pub trait IntoElement<V: 'static> {
type Element: Element<V>;
impl<V, E, F> Element<V> for Option<F>
where
V: 'static,
E: 'static + Component<V> + Send,
F: FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
{
type ElementState = AnyElement<V>;
fn into_element(self) -> Self::Element;
fn id(&self) -> Option<ElementId> {
None
}
fn initialize(
&mut self,
view_state: &mut V,
_rendered_element: Option<Self::ElementState>,
cx: &mut ViewContext<V>,
) -> Self::ElementState {
let render = self.take().unwrap();
let mut rendered_element = (render)(view_state, cx).render();
rendered_element.initialize(view_state, cx);
rendered_element
}
fn layout(
&mut self,
view_state: &mut V,
rendered_element: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) -> LayoutId {
rendered_element.layout(view_state, cx)
}
fn paint(
&mut self,
_bounds: Bounds<Pixels>,
view_state: &mut V,
rendered_element: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) {
rendered_element.paint(view_state, cx)
}
}
impl<V, E, F> Component<V> for Option<F>
where
V: 'static,
E: 'static + Component<V> + Send,
F: FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
{
fn render(self) -> AnyElement<V> {
AnyElement::new(self)
}
}
impl<V, E, F> Component<V> for F
where
V: 'static,
E: 'static + Component<V> + Send,
F: FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
{
fn render(self) -> AnyElement<V> {
AnyElement::new(Some(self))
}
}

View file

@ -1,10 +1,9 @@
pub mod div;
pub mod hoverable;
mod div;
mod img;
pub mod pressable;
pub mod svg;
pub mod text;
mod svg;
mod text;
pub use div::div;
pub use img::img;
pub use svg::svg;
pub use div::*;
pub use img::*;
pub use svg::*;
pub use text::*;

View file

@ -1,320 +1,354 @@
use std::{cell::Cell, rc::Rc};
use crate::{
element::{AnyElement, Element, IntoElement, Layout, ParentElement},
hsla,
style::{CornerRadii, Overflow, Style, StyleHelpers, Styleable},
InteractionHandlers, Interactive, ViewContext,
point, AnyElement, BorrowWindow, Bounds, Component, Element, ElementFocus, ElementId,
ElementInteraction, FocusDisabled, FocusEnabled, FocusHandle, FocusListeners, Focusable,
GlobalElementId, GroupBounds, InteractiveElementState, LayoutId, Overflow, ParentElement,
Pixels, Point, SharedString, StatefulInteraction, StatefulInteractive, StatelessInteraction,
StatelessInteractive, Style, StyleRefinement, Styled, ViewContext,
};
use anyhow::Result;
use gpui::{
geometry::{rect::RectF, vector::Vector2F, Point},
platform::{MouseButton, MouseButtonEvent, MouseMovedEvent, ScrollWheelEvent},
scene::{self},
LayoutId,
};
use refineable::{Refineable, RefinementCascade};
use refineable::Refineable;
use smallvec::SmallVec;
use util::ResultExt;
pub struct Div<V: 'static> {
styles: RefinementCascade<Style>,
handlers: InteractionHandlers<V>,
pub struct Div<
V: 'static,
I: ElementInteraction<V> = StatelessInteraction<V>,
F: ElementFocus<V> = FocusDisabled,
> {
interaction: I,
focus: F,
children: SmallVec<[AnyElement<V>; 2]>,
scroll_state: Option<ScrollState>,
group: Option<SharedString>,
base_style: StyleRefinement,
}
pub fn div<V>() -> Div<V> {
pub fn div<V: 'static>() -> Div<V, StatelessInteraction<V>, FocusDisabled> {
Div {
styles: Default::default(),
handlers: Default::default(),
children: Default::default(),
scroll_state: None,
interaction: StatelessInteraction::default(),
focus: FocusDisabled,
children: SmallVec::new(),
group: None,
base_style: StyleRefinement::default(),
}
}
impl<V: 'static> Element<V> for Div<V> {
type PaintState = Vec<LayoutId>;
fn layout(
&mut self,
view: &mut V,
cx: &mut ViewContext<V>,
) -> Result<(LayoutId, Self::PaintState)>
where
Self: Sized,
{
let style = self.computed_style();
let pop_text_style = style.text_style(cx).map_or(false, |style| {
cx.push_text_style(&style).log_err().is_some()
});
let children = self
.children
.iter_mut()
.map(|child| child.layout(view, cx))
.collect::<Result<Vec<LayoutId>>>()?;
if pop_text_style {
cx.pop_text_style();
}
Ok((cx.add_layout_node(style, children.clone())?, children))
}
fn paint(
&mut self,
view: &mut V,
parent_origin: Vector2F,
layout: &Layout,
child_layouts: &mut Vec<LayoutId>,
cx: &mut ViewContext<V>,
) where
Self: Sized,
{
let order = layout.order;
let bounds = layout.bounds + parent_origin;
let style = self.computed_style();
let pop_text_style = style.text_style(cx).map_or(false, |style| {
cx.push_text_style(&style).log_err().is_some()
});
style.paint_background(bounds, cx);
self.interaction_handlers().paint(order, bounds, cx);
let scrolled_origin = bounds.origin() - self.scroll_offset(&style.overflow);
// TODO: Support only one dimension being hidden
let mut pop_layer = false;
if style.overflow.y != Overflow::Visible || style.overflow.x != Overflow::Visible {
cx.scene().push_layer(Some(bounds));
pop_layer = true;
}
for child in &mut self.children {
child.paint(view, scrolled_origin, cx);
}
if pop_layer {
cx.scene().pop_layer();
}
style.paint_foreground(bounds, cx);
if pop_text_style {
cx.pop_text_style();
}
self.handle_scroll(order, bounds, style.overflow.clone(), child_layouts, cx);
if cx.is_inspector_enabled() {
self.paint_inspector(parent_origin, layout, cx);
impl<V, F> Div<V, StatelessInteraction<V>, F>
where
V: 'static,
F: ElementFocus<V>,
{
pub fn id(self, id: impl Into<ElementId>) -> Div<V, StatefulInteraction<V>, F> {
Div {
interaction: id.into().into(),
focus: self.focus,
children: self.children,
group: self.group,
base_style: self.base_style,
}
}
}
impl<V: 'static> Div<V> {
impl<V, I, F> Div<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
pub fn group(mut self, group: impl Into<SharedString>) -> Self {
self.group = Some(group.into());
self
}
pub fn z_index(mut self, z_index: u32) -> Self {
self.base_style.z_index = Some(z_index);
self
}
pub fn overflow_hidden(mut self) -> Self {
self.declared_style().overflow.x = Some(Overflow::Hidden);
self.declared_style().overflow.y = Some(Overflow::Hidden);
self.base_style.overflow.x = Some(Overflow::Hidden);
self.base_style.overflow.y = Some(Overflow::Hidden);
self
}
pub fn overflow_hidden_x(mut self) -> Self {
self.declared_style().overflow.x = Some(Overflow::Hidden);
self.base_style.overflow.x = Some(Overflow::Hidden);
self
}
pub fn overflow_hidden_y(mut self) -> Self {
self.declared_style().overflow.y = Some(Overflow::Hidden);
self.base_style.overflow.y = Some(Overflow::Hidden);
self
}
pub fn overflow_scroll(mut self, scroll_state: ScrollState) -> Self {
self.scroll_state = Some(scroll_state);
self.declared_style().overflow.x = Some(Overflow::Scroll);
self.declared_style().overflow.y = Some(Overflow::Scroll);
self
}
pub fn overflow_x_scroll(mut self, scroll_state: ScrollState) -> Self {
self.scroll_state = Some(scroll_state);
self.declared_style().overflow.x = Some(Overflow::Scroll);
self
}
pub fn overflow_y_scroll(mut self, scroll_state: ScrollState) -> Self {
self.scroll_state = Some(scroll_state);
self.declared_style().overflow.y = Some(Overflow::Scroll);
self
}
fn scroll_offset(&self, overflow: &Point<Overflow>) -> Vector2F {
let mut offset = Vector2F::zero();
if overflow.y == Overflow::Scroll {
offset.set_y(self.scroll_state.as_ref().unwrap().y());
}
if overflow.x == Overflow::Scroll {
offset.set_x(self.scroll_state.as_ref().unwrap().x());
}
offset
}
fn handle_scroll(
fn with_element_id<R>(
&mut self,
order: u32,
bounds: RectF,
overflow: Point<Overflow>,
child_layout_ids: &[LayoutId],
cx: &mut ViewContext<V>,
f: impl FnOnce(&mut Self, Option<GlobalElementId>, &mut ViewContext<V>) -> R,
) -> R {
if let Some(id) = self.id() {
cx.with_element_id(id, |global_id, cx| f(self, Some(global_id), cx))
} else {
f(self, None, cx)
}
}
pub fn compute_style(
&self,
bounds: Bounds<Pixels>,
element_state: &DivState,
cx: &mut ViewContext<V>,
) -> Style {
let mut computed_style = Style::default();
computed_style.refine(&self.base_style);
self.focus.refine_style(&mut computed_style, cx);
self.interaction
.refine_style(&mut computed_style, bounds, &element_state.interactive, cx);
computed_style
}
}
impl<V: 'static> Div<V, StatefulInteraction<V>, FocusDisabled> {
pub fn focusable(self) -> Div<V, StatefulInteraction<V>, FocusEnabled<V>> {
Div {
interaction: self.interaction,
focus: FocusEnabled::new(),
children: self.children,
group: self.group,
base_style: self.base_style,
}
}
pub fn track_focus(
self,
handle: &FocusHandle,
) -> Div<V, StatefulInteraction<V>, FocusEnabled<V>> {
Div {
interaction: self.interaction,
focus: FocusEnabled::tracked(handle),
children: self.children,
group: self.group,
base_style: self.base_style,
}
}
pub fn overflow_scroll(mut self) -> Self {
self.base_style.overflow.x = Some(Overflow::Scroll);
self.base_style.overflow.y = Some(Overflow::Scroll);
self
}
pub fn overflow_x_scroll(mut self) -> Self {
self.base_style.overflow.x = Some(Overflow::Scroll);
self
}
pub fn overflow_y_scroll(mut self) -> Self {
self.base_style.overflow.y = Some(Overflow::Scroll);
self
}
}
impl<V: 'static> Div<V, StatelessInteraction<V>, FocusDisabled> {
pub fn track_focus(
self,
handle: &FocusHandle,
) -> Div<V, StatefulInteraction<V>, FocusEnabled<V>> {
Div {
interaction: self.interaction.into_stateful(handle),
focus: handle.clone().into(),
children: self.children,
group: self.group,
base_style: self.base_style,
}
}
}
impl<V, I> Focusable<V> for Div<V, I, FocusEnabled<V>>
where
V: 'static,
I: ElementInteraction<V>,
{
fn focus_listeners(&mut self) -> &mut FocusListeners<V> {
&mut self.focus.focus_listeners
}
fn set_focus_style(&mut self, style: StyleRefinement) {
self.focus.focus_style = style;
}
fn set_focus_in_style(&mut self, style: StyleRefinement) {
self.focus.focus_in_style = style;
}
fn set_in_focus_style(&mut self, style: StyleRefinement) {
self.focus.in_focus_style = style;
}
}
#[derive(Default)]
pub struct DivState {
interactive: InteractiveElementState,
focus_handle: Option<FocusHandle>,
child_layout_ids: SmallVec<[LayoutId; 4]>,
}
impl<V, I, F> Element<V> for Div<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
type ElementState = DivState;
fn id(&self) -> Option<ElementId> {
self.interaction
.as_stateful()
.map(|identified| identified.id.clone())
}
fn initialize(
&mut self,
view_state: &mut V,
element_state: Option<Self::ElementState>,
cx: &mut ViewContext<V>,
) -> Self::ElementState {
let mut element_state = element_state.unwrap_or_default();
self.focus
.initialize(element_state.focus_handle.take(), cx, |focus_handle, cx| {
element_state.focus_handle = focus_handle;
self.interaction.initialize(cx, |cx| {
for child in &mut self.children {
child.initialize(view_state, cx);
}
})
});
element_state
}
fn layout(
&mut self,
view_state: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) -> LayoutId {
let style = self.compute_style(Bounds::default(), element_state, cx);
style.apply_text_style(cx, |cx| {
self.with_element_id(cx, |this, _global_id, cx| {
let layout_ids = this
.children
.iter_mut()
.map(|child| child.layout(view_state, cx))
.collect::<SmallVec<_>>();
element_state.child_layout_ids = layout_ids.clone();
cx.request_layout(&style, layout_ids)
})
})
}
fn paint(
&mut self,
bounds: Bounds<Pixels>,
view_state: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) {
if overflow.y == Overflow::Scroll || overflow.x == Overflow::Scroll {
let mut scroll_max = Vector2F::zero();
for child_layout_id in child_layout_ids {
if let Some(child_layout) = cx
.layout_engine()
.unwrap()
.computed_layout(*child_layout_id)
.log_err()
{
scroll_max = scroll_max.max(child_layout.bounds.lower_right());
}
self.with_element_id(cx, |this, _global_id, cx| {
if let Some(group) = this.group.clone() {
GroupBounds::push(group, bounds, cx);
}
scroll_max -= bounds.size();
let scroll_state = self.scroll_state.as_ref().unwrap().clone();
cx.on_event(order, move |_, event: &ScrollWheelEvent, cx| {
if bounds.contains_point(event.position) {
let scroll_delta = match event.delta {
gpui::platform::ScrollDelta::Pixels(delta) => delta,
gpui::platform::ScrollDelta::Lines(delta) => {
delta * cx.text_style().font_size
}
};
if overflow.x == Overflow::Scroll {
scroll_state.set_x(
(scroll_state.x() - scroll_delta.x())
.max(0.)
.min(scroll_max.x()),
);
}
if overflow.y == Overflow::Scroll {
scroll_state.set_y(
(scroll_state.y() - scroll_delta.y())
.max(0.)
.min(scroll_max.y()),
);
}
cx.repaint();
} else {
cx.bubble_event();
let style = this.compute_style(bounds, element_state, cx);
let z_index = style.z_index.unwrap_or(0);
let mut child_min = point(Pixels::MAX, Pixels::MAX);
let mut child_max = Point::default();
let content_size = if element_state.child_layout_ids.is_empty() {
bounds.size
} else {
for child_layout_id in &element_state.child_layout_ids {
let child_bounds = cx.layout_bounds(*child_layout_id);
child_min = child_min.min(&child_bounds.origin);
child_max = child_max.max(&child_bounds.lower_right());
}
})
}
}
(child_max - child_min).into()
};
fn paint_inspector(&self, parent_origin: Vector2F, layout: &Layout, cx: &mut ViewContext<V>) {
let style = self.styles.merged();
let bounds = layout.bounds + parent_origin;
cx.stack(z_index, |cx| {
cx.stack(0, |cx| {
style.paint(bounds, cx);
this.focus.paint(bounds, cx);
this.interaction.paint(
bounds,
content_size,
style.overflow,
&mut element_state.interactive,
cx,
);
});
cx.stack(1, |cx| {
style.apply_text_style(cx, |cx| {
style.apply_overflow(bounds, cx, |cx| {
let scroll_offset = element_state.interactive.scroll_offset();
cx.with_element_offset(scroll_offset, |cx| {
for child in &mut this.children {
child.paint(view_state, cx);
}
});
})
})
});
});
let hovered = bounds.contains_point(cx.mouse_position());
if hovered {
let rem_size = cx.rem_size();
cx.scene().push_quad(scene::Quad {
bounds,
background: Some(hsla(0., 0., 1., 0.05).into()),
border: gpui::Border {
color: hsla(0., 0., 1., 0.2).into(),
top: 1.,
right: 1.,
bottom: 1.,
left: 1.,
},
corner_radii: CornerRadii::default()
.refined(&style.corner_radii)
.to_gpui(bounds.size(), rem_size),
})
}
let pressed = Cell::new(hovered && cx.is_mouse_down(MouseButton::Left));
cx.on_event(layout.order, move |_, event: &MouseButtonEvent, _| {
if bounds.contains_point(event.position) {
if event.is_down {
pressed.set(true);
} else if pressed.get() {
pressed.set(false);
eprintln!("clicked div {:?} {:#?}", bounds, style);
}
if let Some(group) = this.group.as_ref() {
GroupBounds::pop(group, cx);
}
});
let hovered = Cell::new(hovered);
cx.on_event(layout.order, move |_, event: &MouseMovedEvent, cx| {
cx.bubble_event();
let hovered_now = bounds.contains_point(event.position);
if hovered.get() != hovered_now {
hovered.set(hovered_now);
cx.repaint();
}
});
})
}
}
impl<V> Styleable for Div<V> {
type Style = Style;
fn style_cascade(&mut self) -> &mut RefinementCascade<Self::Style> {
&mut self.styles
}
fn declared_style(&mut self) -> &mut <Self::Style as Refineable>::Refinement {
self.styles.base()
impl<V, I, F> Component<V> for Div<V, I, F>
where
// V: Any + Send + Sync,
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn render(self) -> AnyElement<V> {
AnyElement::new(self)
}
}
impl<V> StyleHelpers for Div<V> {}
impl<V> Interactive<V> for Div<V> {
fn interaction_handlers(&mut self) -> &mut InteractionHandlers<V> {
&mut self.handlers
}
}
impl<V: 'static> ParentElement<V> for Div<V> {
impl<V, I, F> ParentElement<V> for Div<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn children_mut(&mut self) -> &mut SmallVec<[AnyElement<V>; 2]> {
&mut self.children
}
}
impl<V: 'static> IntoElement<V> for Div<V> {
type Element = Self;
fn into_element(self) -> Self::Element {
self
impl<V, I, F> Styled for Div<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn style(&mut self) -> &mut StyleRefinement {
&mut self.base_style
}
}
#[derive(Default, Clone)]
pub struct ScrollState(Rc<Cell<Vector2F>>);
impl ScrollState {
pub fn x(&self) -> f32 {
self.0.get().x()
}
pub fn set_x(&self, value: f32) {
let mut current_value = self.0.get();
current_value.set_x(value);
self.0.set(current_value);
}
pub fn y(&self) -> f32 {
self.0.get().y()
}
pub fn set_y(&self, value: f32) {
let mut current_value = self.0.get();
current_value.set_y(value);
self.0.set(current_value);
impl<V, I, F> StatelessInteractive<V> for Div<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn stateless_interaction(&mut self) -> &mut StatelessInteraction<V> {
self.interaction.as_stateless_mut()
}
}
impl<V, F> StatefulInteractive<V> for Div<V, StatefulInteraction<V>, F>
where
F: ElementFocus<V>,
{
fn stateful_interaction(&mut self) -> &mut StatefulInteraction<V> {
&mut self.interaction
}
}

View file

@ -1,105 +0,0 @@
use crate::{
element::{AnyElement, Element, IntoElement, Layout, ParentElement},
interactive::{InteractionHandlers, Interactive},
style::{Style, StyleHelpers, Styleable},
ViewContext,
};
use anyhow::Result;
use gpui::{geometry::vector::Vector2F, platform::MouseMovedEvent, LayoutId};
use refineable::{CascadeSlot, Refineable, RefinementCascade};
use smallvec::SmallVec;
use std::{cell::Cell, rc::Rc};
pub struct Hoverable<E: Styleable> {
hovered: Rc<Cell<bool>>,
cascade_slot: CascadeSlot,
hovered_style: <E::Style as Refineable>::Refinement,
child: E,
}
pub fn hoverable<E: Styleable>(mut child: E) -> Hoverable<E> {
Hoverable {
hovered: Rc::new(Cell::new(false)),
cascade_slot: child.style_cascade().reserve(),
hovered_style: Default::default(),
child,
}
}
impl<E: Styleable> Styleable for Hoverable<E> {
type Style = E::Style;
fn style_cascade(&mut self) -> &mut RefinementCascade<Self::Style> {
self.child.style_cascade()
}
fn declared_style(&mut self) -> &mut <Self::Style as Refineable>::Refinement {
&mut self.hovered_style
}
}
impl<V: 'static, E: Element<V> + Styleable> Element<V> for Hoverable<E> {
type PaintState = E::PaintState;
fn layout(
&mut self,
view: &mut V,
cx: &mut ViewContext<V>,
) -> Result<(LayoutId, Self::PaintState)>
where
Self: Sized,
{
Ok(self.child.layout(view, cx)?)
}
fn paint(
&mut self,
view: &mut V,
parent_origin: Vector2F,
layout: &Layout,
paint_state: &mut Self::PaintState,
cx: &mut ViewContext<V>,
) where
Self: Sized,
{
let bounds = layout.bounds + parent_origin;
self.hovered.set(bounds.contains_point(cx.mouse_position()));
let slot = self.cascade_slot;
let style = self.hovered.get().then_some(self.hovered_style.clone());
self.style_cascade().set(slot, style);
let hovered = self.hovered.clone();
cx.on_event(layout.order, move |_view, _: &MouseMovedEvent, cx| {
cx.bubble_event();
if bounds.contains_point(cx.mouse_position()) != hovered.get() {
cx.repaint();
}
});
self.child
.paint(view, parent_origin, layout, paint_state, cx);
}
}
impl<E: Styleable<Style = Style>> StyleHelpers for Hoverable<E> {}
impl<V: 'static, E: Interactive<V> + Styleable> Interactive<V> for Hoverable<E> {
fn interaction_handlers(&mut self) -> &mut InteractionHandlers<V> {
self.child.interaction_handlers()
}
}
impl<V: 'static, E: ParentElement<V> + Styleable> ParentElement<V> for Hoverable<E> {
fn children_mut(&mut self) -> &mut SmallVec<[AnyElement<V>; 2]> {
self.child.children_mut()
}
}
impl<V: 'static, E: Element<V> + Styleable> IntoElement<V> for Hoverable<E> {
type Element = Self;
fn into_element(self) -> Self::Element {
self
}
}

View file

@ -1,110 +1,184 @@
use crate as gpui2;
use crate::{
style::{Style, StyleHelpers, Styleable},
Element,
div, AnyElement, BorrowWindow, Bounds, Component, Div, DivState, Element, ElementFocus,
ElementId, ElementInteraction, FocusDisabled, FocusEnabled, FocusListeners, Focusable,
LayoutId, Pixels, SharedString, StatefulInteraction, StatefulInteractive, StatelessInteraction,
StatelessInteractive, StyleRefinement, Styled, ViewContext,
};
use futures::FutureExt;
use gpui::geometry::vector::Vector2F;
use gpui::scene;
use gpui2_macros::IntoElement;
use refineable::RefinementCascade;
use util::arc_cow::ArcCow;
use util::ResultExt;
#[derive(IntoElement)]
pub struct Img {
style: RefinementCascade<Style>,
uri: Option<ArcCow<'static, str>>,
pub struct Img<
V: 'static,
I: ElementInteraction<V> = StatelessInteraction<V>,
F: ElementFocus<V> = FocusDisabled,
> {
base: Div<V, I, F>,
uri: Option<SharedString>,
grayscale: bool,
}
pub fn img() -> Img {
pub fn img<V: 'static>() -> Img<V, StatelessInteraction<V>, FocusDisabled> {
Img {
style: RefinementCascade::default(),
base: div(),
uri: None,
grayscale: false,
}
}
impl Img {
pub fn uri(mut self, uri: impl Into<ArcCow<'static, str>>) -> Self {
impl<V, I, F> Img<V, I, F>
where
V: 'static,
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
pub fn uri(mut self, uri: impl Into<SharedString>) -> Self {
self.uri = Some(uri.into());
self
}
pub fn grayscale(mut self, grayscale: bool) -> Self {
self.grayscale = grayscale;
self
}
}
impl<V: 'static> Element<V> for Img {
type PaintState = ();
impl<V, F> Img<V, StatelessInteraction<V>, F>
where
F: ElementFocus<V>,
{
pub fn id(self, id: impl Into<ElementId>) -> Img<V, StatefulInteraction<V>, F> {
Img {
base: self.base.id(id),
uri: self.uri,
grayscale: self.grayscale,
}
}
}
impl<V, I, F> Component<V> for Img<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn render(self) -> AnyElement<V> {
AnyElement::new(self)
}
}
impl<V, I, F> Element<V> for Img<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
type ElementState = DivState;
fn id(&self) -> Option<crate::ElementId> {
self.base.id()
}
fn initialize(
&mut self,
view_state: &mut V,
element_state: Option<Self::ElementState>,
cx: &mut ViewContext<V>,
) -> Self::ElementState {
self.base.initialize(view_state, element_state, cx)
}
fn layout(
&mut self,
_: &mut V,
cx: &mut crate::ViewContext<V>,
) -> anyhow::Result<(gpui::LayoutId, Self::PaintState)>
where
Self: Sized,
{
let style = self.computed_style();
let layout_id = cx.add_layout_node(style, [])?;
Ok((layout_id, ()))
view_state: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) -> LayoutId {
self.base.layout(view_state, element_state, cx)
}
fn paint(
&mut self,
_: &mut V,
parent_origin: Vector2F,
layout: &gpui::Layout,
_: &mut Self::PaintState,
cx: &mut crate::ViewContext<V>,
) where
Self: Sized,
{
let style = self.computed_style();
let bounds = layout.bounds + parent_origin;
bounds: Bounds<Pixels>,
view: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) {
cx.stack(0, |cx| {
self.base.paint(bounds, view, element_state, cx);
});
style.paint_background(bounds, cx);
let style = self.base.compute_style(bounds, element_state, cx);
let corner_radii = style.corner_radii;
if let Some(uri) = &self.uri {
let image_future = cx.image_cache.get(uri.clone());
if let Some(uri) = self.uri.clone() {
let image_future = cx.image_cache.get(uri);
if let Some(data) = image_future
.clone()
.now_or_never()
.and_then(ResultExt::log_err)
{
let rem_size = cx.rem_size();
cx.scene().push_image(scene::Image {
bounds,
border: gpui::Border {
color: style.border_color.unwrap_or_default().into(),
top: style.border_widths.top.to_pixels(rem_size),
right: style.border_widths.right.to_pixels(rem_size),
bottom: style.border_widths.bottom.to_pixels(rem_size),
left: style.border_widths.left.to_pixels(rem_size),
},
corner_radii: style.corner_radii.to_gpui(bounds.size(), rem_size),
grayscale: false,
data,
})
let corner_radii = corner_radii.to_pixels(bounds.size, cx.rem_size());
cx.stack(1, |cx| {
cx.paint_image(bounds, corner_radii, data, self.grayscale)
.log_err()
});
} else {
cx.spawn(|this, mut cx| async move {
cx.spawn(|_, mut cx| async move {
if image_future.await.log_err().is_some() {
this.update(&mut cx, |_, cx| cx.notify()).ok();
cx.on_next_frame(|cx| cx.notify());
}
})
.detach();
.detach()
}
}
}
}
impl Styleable for Img {
type Style = Style;
fn style_cascade(&mut self) -> &mut RefinementCascade<Self::Style> {
&mut self.style
}
fn declared_style(&mut self) -> &mut <Self::Style as refineable::Refineable>::Refinement {
self.style.base()
impl<V, I, F> Styled for Img<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn style(&mut self) -> &mut StyleRefinement {
self.base.style()
}
}
impl StyleHelpers for Img {}
impl<V, I, F> StatelessInteractive<V> for Img<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn stateless_interaction(&mut self) -> &mut StatelessInteraction<V> {
self.base.stateless_interaction()
}
}
impl<V, F> StatefulInteractive<V> for Img<V, StatefulInteraction<V>, F>
where
F: ElementFocus<V>,
{
fn stateful_interaction(&mut self) -> &mut StatefulInteraction<V> {
self.base.stateful_interaction()
}
}
impl<V, I> Focusable<V> for Img<V, I, FocusEnabled<V>>
where
V: 'static,
I: ElementInteraction<V>,
{
fn focus_listeners(&mut self) -> &mut FocusListeners<V> {
self.base.focus_listeners()
}
fn set_focus_style(&mut self, style: StyleRefinement) {
self.base.set_focus_style(style)
}
fn set_focus_in_style(&mut self, style: StyleRefinement) {
self.base.set_focus_in_style(style)
}
fn set_in_focus_style(&mut self, style: StyleRefinement) {
self.base.set_in_focus_style(style)
}
}

View file

@ -1,113 +0,0 @@
use crate::{
element::{AnyElement, Element, IntoElement, Layout, ParentElement},
interactive::{InteractionHandlers, Interactive},
style::{Style, StyleHelpers, Styleable},
ViewContext,
};
use anyhow::Result;
use gpui::{geometry::vector::Vector2F, platform::MouseButtonEvent, LayoutId};
use refineable::{CascadeSlot, Refineable, RefinementCascade};
use smallvec::SmallVec;
use std::{cell::Cell, rc::Rc};
pub struct Pressable<E: Styleable> {
pressed: Rc<Cell<bool>>,
pressed_style: <E::Style as Refineable>::Refinement,
cascade_slot: CascadeSlot,
child: E,
}
pub fn pressable<E: Styleable>(mut child: E) -> Pressable<E> {
Pressable {
pressed: Rc::new(Cell::new(false)),
pressed_style: Default::default(),
cascade_slot: child.style_cascade().reserve(),
child,
}
}
impl<E: Styleable> Styleable for Pressable<E> {
type Style = E::Style;
fn declared_style(&mut self) -> &mut <Self::Style as Refineable>::Refinement {
&mut self.pressed_style
}
fn style_cascade(&mut self) -> &mut RefinementCascade<E::Style> {
self.child.style_cascade()
}
}
impl<V: 'static, E: Element<V> + Styleable> Element<V> for Pressable<E> {
type PaintState = E::PaintState;
fn layout(
&mut self,
view: &mut V,
cx: &mut ViewContext<V>,
) -> Result<(LayoutId, Self::PaintState)>
where
Self: Sized,
{
self.child.layout(view, cx)
}
fn paint(
&mut self,
view: &mut V,
parent_origin: Vector2F,
layout: &Layout,
paint_state: &mut Self::PaintState,
cx: &mut ViewContext<V>,
) where
Self: Sized,
{
let slot = self.cascade_slot;
let style = self.pressed.get().then_some(self.pressed_style.clone());
self.style_cascade().set(slot, style);
let pressed = self.pressed.clone();
let bounds = layout.bounds + parent_origin;
cx.on_event(layout.order, move |_view, event: &MouseButtonEvent, cx| {
if event.is_down {
if bounds.contains_point(event.position) {
pressed.set(true);
cx.repaint();
} else {
cx.bubble_event();
}
} else {
if pressed.get() {
pressed.set(false);
cx.repaint();
}
cx.bubble_event();
}
});
self.child
.paint(view, parent_origin, layout, paint_state, cx);
}
}
impl<E: Styleable<Style = Style>> StyleHelpers for Pressable<E> {}
impl<V: 'static, E: Interactive<V> + Styleable> Interactive<V> for Pressable<E> {
fn interaction_handlers(&mut self) -> &mut InteractionHandlers<V> {
self.child.interaction_handlers()
}
}
impl<V: 'static, E: ParentElement<V> + Styleable> ParentElement<V> for Pressable<E> {
fn children_mut(&mut self) -> &mut SmallVec<[AnyElement<V>; 2]> {
self.child.children_mut()
}
}
impl<V: 'static, E: Element<V> + Styleable> IntoElement<V> for Pressable<E> {
type Element = Self;
fn into_element(self) -> Self::Element {
self
}
}

View file

@ -1,84 +1,157 @@
use crate::{
self as gpui2, scene,
style::{Style, StyleHelpers, Styleable},
Element, IntoElement, Layout, LayoutId, Rgba,
div, AnyElement, Bounds, Component, Div, DivState, Element, ElementFocus, ElementId,
ElementInteraction, FocusDisabled, FocusEnabled, FocusListeners, Focusable, LayoutId, Pixels,
SharedString, StatefulInteraction, StatefulInteractive, StatelessInteraction,
StatelessInteractive, StyleRefinement, Styled, ViewContext,
};
use gpui::geometry::vector::Vector2F;
use refineable::RefinementCascade;
use std::borrow::Cow;
use util::ResultExt;
#[derive(IntoElement)]
pub struct Svg {
path: Option<Cow<'static, str>>,
style: RefinementCascade<Style>,
pub struct Svg<
V: 'static,
I: ElementInteraction<V> = StatelessInteraction<V>,
F: ElementFocus<V> = FocusDisabled,
> {
base: Div<V, I, F>,
path: Option<SharedString>,
}
pub fn svg() -> Svg {
pub fn svg<V: 'static>() -> Svg<V, StatelessInteraction<V>, FocusDisabled> {
Svg {
base: div(),
path: None,
style: RefinementCascade::<Style>::default(),
}
}
impl Svg {
pub fn path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
impl<V, I, F> Svg<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
pub fn path(mut self, path: impl Into<SharedString>) -> Self {
self.path = Some(path.into());
self
}
}
impl<V: 'static> Element<V> for Svg {
type PaintState = ();
fn layout(
&mut self,
_: &mut V,
cx: &mut crate::ViewContext<V>,
) -> anyhow::Result<(LayoutId, Self::PaintState)>
where
Self: Sized,
{
let style = self.computed_style();
Ok((cx.add_layout_node(style, [])?, ()))
}
fn paint(
&mut self,
_: &mut V,
parent_origin: Vector2F,
layout: &Layout,
_: &mut Self::PaintState,
cx: &mut crate::ViewContext<V>,
) where
Self: Sized,
{
let fill_color = self.computed_style().fill.and_then(|fill| fill.color());
if let Some((path, fill_color)) = self.path.as_ref().zip(fill_color) {
if let Some(svg_tree) = cx.asset_cache.svg(path).log_err() {
let icon = scene::Icon {
bounds: layout.bounds + parent_origin,
svg: svg_tree,
path: path.clone(),
color: Rgba::from(fill_color).into(),
};
cx.scene().push_icon(icon);
}
impl<V, F> Svg<V, StatelessInteraction<V>, F>
where
F: ElementFocus<V>,
{
pub fn id(self, id: impl Into<ElementId>) -> Svg<V, StatefulInteraction<V>, F> {
Svg {
base: self.base.id(id),
path: self.path,
}
}
}
impl Styleable for Svg {
type Style = Style;
fn style_cascade(&mut self) -> &mut refineable::RefinementCascade<Self::Style> {
&mut self.style
}
fn declared_style(&mut self) -> &mut <Self::Style as refineable::Refineable>::Refinement {
self.style.base()
impl<V, I, F> Component<V> for Svg<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn render(self) -> AnyElement<V> {
AnyElement::new(self)
}
}
impl StyleHelpers for Svg {}
impl<V, I, F> Element<V> for Svg<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
type ElementState = DivState;
fn id(&self) -> Option<crate::ElementId> {
self.base.id()
}
fn initialize(
&mut self,
view_state: &mut V,
element_state: Option<Self::ElementState>,
cx: &mut ViewContext<V>,
) -> Self::ElementState {
self.base.initialize(view_state, element_state, cx)
}
fn layout(
&mut self,
view_state: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) -> LayoutId {
self.base.layout(view_state, element_state, cx)
}
fn paint(
&mut self,
bounds: Bounds<Pixels>,
view: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) where
Self: Sized,
{
self.base.paint(bounds, view, element_state, cx);
let color = self
.base
.compute_style(bounds, element_state, cx)
.text
.color;
if let Some((path, color)) = self.path.as_ref().zip(color) {
cx.paint_svg(bounds, path.clone(), color).log_err();
}
}
}
impl<V, I, F> Styled for Svg<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn style(&mut self) -> &mut StyleRefinement {
self.base.style()
}
}
impl<V, I, F> StatelessInteractive<V> for Svg<V, I, F>
where
I: ElementInteraction<V>,
F: ElementFocus<V>,
{
fn stateless_interaction(&mut self) -> &mut StatelessInteraction<V> {
self.base.stateless_interaction()
}
}
impl<V, F> StatefulInteractive<V> for Svg<V, StatefulInteraction<V>, F>
where
V: 'static,
F: ElementFocus<V>,
{
fn stateful_interaction(&mut self) -> &mut StatefulInteraction<V> {
self.base.stateful_interaction()
}
}
impl<V: 'static, I> Focusable<V> for Svg<V, I, FocusEnabled<V>>
where
I: ElementInteraction<V>,
{
fn focus_listeners(&mut self) -> &mut FocusListeners<V> {
self.base.focus_listeners()
}
fn set_focus_style(&mut self, style: StyleRefinement) {
self.base.set_focus_style(style)
}
fn set_focus_in_style(&mut self, style: StyleRefinement) {
self.base.set_focus_in_style(style)
}
fn set_in_focus_style(&mut self, style: StyleRefinement) {
self.base.set_in_focus_style(style)
}
}

View file

@ -1,105 +1,145 @@
use crate::{
element::{Element, IntoElement, Layout},
ViewContext,
};
use anyhow::Result;
use gpui::{
geometry::{vector::Vector2F, Size},
text_layout::Line,
LayoutId,
AnyElement, BorrowWindow, Bounds, Component, Element, LayoutId, Line, Pixels, SharedString,
Size, ViewContext,
};
use parking_lot::Mutex;
use std::sync::Arc;
use util::arc_cow::ArcCow;
use smallvec::SmallVec;
use std::{marker::PhantomData, sync::Arc};
use util::ResultExt;
impl<V: 'static, S: Into<ArcCow<'static, str>>> IntoElement<V> for S {
type Element = Text;
fn into_element(self) -> Self::Element {
Text { text: self.into() }
impl<V: 'static> Component<V> for SharedString {
fn render(self) -> AnyElement<V> {
Text {
text: self,
state_type: PhantomData,
}
.render()
}
}
pub struct Text {
text: ArcCow<'static, str>,
impl<V: 'static> Component<V> for &'static str {
fn render(self) -> AnyElement<V> {
Text {
text: self.into(),
state_type: PhantomData,
}
.render()
}
}
impl<V: 'static> Element<V> for Text {
type PaintState = Arc<Mutex<Option<TextLayout>>>;
// TODO: Figure out how to pass `String` to `child` without this.
// This impl doesn't exist in the `gpui2` crate.
impl<V: 'static> Component<V> for String {
fn render(self) -> AnyElement<V> {
Text {
text: self.into(),
state_type: PhantomData,
}
.render()
}
}
pub struct Text<V> {
text: SharedString,
state_type: PhantomData<V>,
}
unsafe impl<V> Send for Text<V> {}
unsafe impl<V> Sync for Text<V> {}
impl<V: 'static> Component<V> for Text<V> {
fn render(self) -> AnyElement<V> {
AnyElement::new(self)
}
}
impl<V: 'static> Element<V> for Text<V> {
type ElementState = Arc<Mutex<Option<TextElementState>>>;
fn id(&self) -> Option<crate::ElementId> {
None
}
fn initialize(
&mut self,
_view_state: &mut V,
element_state: Option<Self::ElementState>,
_cx: &mut ViewContext<V>,
) -> Self::ElementState {
element_state.unwrap_or_default()
}
fn layout(
&mut self,
_view: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) -> Result<(LayoutId, Self::PaintState)> {
let layout_cache = cx.text_layout_cache().clone();
) -> LayoutId {
let text_system = cx.text_system().clone();
let text_style = cx.text_style();
let line_height = cx.font_cache().line_height(text_style.font_size);
let font_size = text_style.font_size * cx.rem_size();
let line_height = text_style
.line_height
.to_pixels(font_size.into(), cx.rem_size());
let text = self.text.clone();
let paint_state = Arc::new(Mutex::new(None));
let layout_id = cx.add_measured_layout_node(Default::default(), {
let paint_state = paint_state.clone();
move |_params| {
let line_layout = layout_cache.layout_str(
text.as_ref(),
text_style.font_size,
&[(text.len(), text_style.to_run())],
);
let size = Size {
width: line_layout.width(),
height: line_height,
let rem_size = cx.rem_size();
let layout_id = cx.request_measured_layout(Default::default(), rem_size, {
let element_state = element_state.clone();
move |known_dimensions, _| {
let Some(lines) = text_system
.layout_text(
&text,
font_size,
&[text_style.to_run(text.len())],
known_dimensions.width, // Wrap if we know the width.
)
.log_err()
else {
return Size::default();
};
paint_state.lock().replace(TextLayout {
line_layout: Arc::new(line_layout),
line_height,
});
let line_count = lines
.iter()
.map(|line| line.wrap_count() + 1)
.sum::<usize>();
let size = Size {
width: lines.iter().map(|line| line.layout.width).max().unwrap(),
height: line_height * line_count,
};
element_state
.lock()
.replace(TextElementState { lines, line_height });
size
}
});
Ok((layout_id?, paint_state))
layout_id
}
fn paint<'a>(
fn paint(
&mut self,
_view: &mut V,
parent_origin: Vector2F,
layout: &Layout,
paint_state: &mut Self::PaintState,
bounds: Bounds<Pixels>,
_: &mut V,
element_state: &mut Self::ElementState,
cx: &mut ViewContext<V>,
) {
let bounds = layout.bounds + parent_origin;
let line_layout;
let line_height;
{
let paint_state = paint_state.lock();
let paint_state = paint_state
.as_ref()
.expect("measurement has not been performed");
line_layout = paint_state.line_layout.clone();
line_height = paint_state.line_height;
let element_state = element_state.lock();
let element_state = element_state
.as_ref()
.expect("measurement has not been performed");
let line_height = element_state.line_height;
let mut line_origin = bounds.origin;
for line in &element_state.lines {
line.paint(line_origin, line_height, cx).log_err();
line_origin.y += line.size(line_height).height;
}
// TODO: We haven't added visible bounds to the new element system yet, so this is a placeholder.
let visible_bounds = bounds;
line_layout.paint(bounds.origin(), visible_bounds, line_height, cx.legacy_cx);
}
}
impl<V: 'static> IntoElement<V> for Text {
type Element = Self;
fn into_element(self) -> Self::Element {
self
}
}
pub struct TextLayout {
line_layout: Arc<Line>,
line_height: f32,
pub struct TextElementState {
lines: SmallVec<[Line; 1]>,
line_height: Pixels,
}

View file

@ -0,0 +1,290 @@
use crate::{AppContext, PlatformDispatcher};
use futures::{channel::mpsc, pin_mut, FutureExt};
use smol::prelude::*;
use std::{
fmt::Debug,
marker::PhantomData,
mem,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use util::TryFutureExt;
use waker_fn::waker_fn;
#[derive(Clone)]
pub struct Executor {
dispatcher: Arc<dyn PlatformDispatcher>,
}
#[must_use]
pub enum Task<T> {
Ready(Option<T>),
Spawned(async_task::Task<T>),
}
impl<T> Task<T> {
pub fn ready(val: T) -> Self {
Task::Ready(Some(val))
}
pub fn detach(self) {
match self {
Task::Ready(_) => {}
Task::Spawned(task) => task.detach(),
}
}
}
impl<E, T> Task<Result<T, E>>
where
T: 'static + Send,
E: 'static + Send + Debug,
{
pub fn detach_and_log_err(self, cx: &mut AppContext) {
cx.executor().spawn(self.log_err()).detach();
}
}
impl<T> Future for Task<T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match unsafe { self.get_unchecked_mut() } {
Task::Ready(val) => Poll::Ready(val.take().unwrap()),
Task::Spawned(task) => task.poll(cx),
}
}
}
impl Executor {
pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
Self { dispatcher }
}
/// Enqueues the given closure to be run on any thread. The closure returns
/// a future which will be run to completion on any available thread.
pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
where
R: Send + 'static,
{
let dispatcher = self.dispatcher.clone();
let (runnable, task) =
async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
runnable.schedule();
Task::Spawned(task)
}
/// Enqueues the given closure to run on the application's event loop.
/// Returns the result asynchronously.
pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
if self.dispatcher.is_main_thread() {
Task::ready(func())
} else {
self.spawn_on_main(move || async move { func() })
}
}
/// Enqueues the given closure to be run on the application's event loop. The
/// closure returns a future which will be run to completion on the main thread.
pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
where
F: Future<Output = R> + 'static,
R: Send + 'static,
{
let (runnable, task) = async_task::spawn(
{
let this = self.clone();
async move {
let task = this.spawn_on_main_local(func());
task.await
}
},
{
let dispatcher = self.dispatcher.clone();
move |runnable| dispatcher.dispatch_on_main_thread(runnable)
},
);
runnable.schedule();
Task::Spawned(task)
}
/// Enqueues the given closure to be run on the application's event loop. Must
/// be called on the main thread.
pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
where
R: 'static,
{
assert!(
self.dispatcher.is_main_thread(),
"must be called on main thread"
);
let dispatcher = self.dispatcher.clone();
let (runnable, task) = async_task::spawn_local(future, move |runnable| {
dispatcher.dispatch_on_main_thread(runnable)
});
runnable.schedule();
Task::Spawned(task)
}
pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
pin_mut!(future);
let (parker, unparker) = parking::pair();
let waker = waker_fn(move || {
unparker.unpark();
});
let mut cx = std::task::Context::from_waker(&waker);
loop {
match future.as_mut().poll(&mut cx) {
Poll::Ready(result) => return result,
Poll::Pending => {
if !self.dispatcher.poll() {
#[cfg(any(test, feature = "test-support"))]
if let Some(_) = self.dispatcher.as_test() {
panic!("blocked with nothing left to run")
}
parker.park();
}
}
}
}
}
pub fn block_with_timeout<R>(
&self,
duration: Duration,
future: impl Future<Output = R>,
) -> Result<R, impl Future<Output = R>> {
let mut future = Box::pin(future.fuse());
if duration.is_zero() {
return Err(future);
}
let mut timer = self.timer(duration).fuse();
let timeout = async {
futures::select_biased! {
value = future => Ok(value),
_ = timer => Err(()),
}
};
match self.block(timeout) {
Ok(value) => Ok(value),
Err(_) => Err(future),
}
}
pub async fn scoped<'scope, F>(&self, scheduler: F)
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone());
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn(f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
}
}
pub fn timer(&self, duration: Duration) -> Task<()> {
let (runnable, task) = async_task::spawn(async move {}, {
let dispatcher = self.dispatcher.clone();
move |runnable| dispatcher.dispatch_after(duration, runnable)
});
runnable.schedule();
Task::Spawned(task)
}
#[cfg(any(test, feature = "test-support"))]
pub fn start_waiting(&self) {
todo!("start_waiting")
}
#[cfg(any(test, feature = "test-support"))]
pub fn finish_waiting(&self) {
todo!("finish_waiting")
}
#[cfg(any(test, feature = "test-support"))]
pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
self.spawn(self.dispatcher.as_test().unwrap().simulate_random_delay())
}
#[cfg(any(test, feature = "test-support"))]
pub fn advance_clock(&self, duration: Duration) {
self.dispatcher.as_test().unwrap().advance_clock(duration)
}
#[cfg(any(test, feature = "test-support"))]
pub fn run_until_parked(&self) {
self.dispatcher.as_test().unwrap().run_until_parked()
}
pub fn num_cpus(&self) -> usize {
num_cpus::get()
}
pub fn is_main_thread(&self) -> bool {
self.dispatcher.is_main_thread()
}
}
pub struct Scope<'a> {
executor: Executor,
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
tx: Option<mpsc::Sender<()>>,
rx: mpsc::Receiver<()>,
lifetime: PhantomData<&'a ()>,
}
impl<'a> Scope<'a> {
fn new(executor: Executor) -> Self {
let (tx, rx) = mpsc::channel(1);
Self {
executor,
tx: Some(tx),
rx,
futures: Default::default(),
lifetime: PhantomData,
}
}
pub fn spawn<F>(&mut self, f: F)
where
F: Future<Output = ()> + Send + 'a,
{
let tx = self.tx.clone().unwrap();
// Safety: The 'a lifetime is guaranteed to outlive any of these futures because
// dropping this `Scope` blocks until all of the futures have resolved.
let f = unsafe {
mem::transmute::<
Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
>(Box::pin(async move {
f.await;
drop(tx);
}))
};
self.futures.push(f);
}
}
impl<'a> Drop for Scope<'a> {
fn drop(&mut self) {
self.tx.take().unwrap();
// Wait until the channel is closed, which means that all of the spawned
// futures have resolved.
self.executor.block(self.rx.next());
}
}

View file

@ -0,0 +1,252 @@
use crate::{
Bounds, DispatchPhase, Element, FocusEvent, FocusHandle, MouseDownEvent, Pixels, Style,
StyleRefinement, ViewContext, WindowContext,
};
use refineable::Refineable;
use smallvec::SmallVec;
pub type FocusListeners<V> = SmallVec<[FocusListener<V>; 2]>;
pub type FocusListener<V> =
Box<dyn Fn(&mut V, &FocusHandle, &FocusEvent, &mut ViewContext<V>) + Send + 'static>;
pub trait Focusable<V: 'static>: Element<V> {
fn focus_listeners(&mut self) -> &mut FocusListeners<V>;
fn set_focus_style(&mut self, style: StyleRefinement);
fn set_focus_in_style(&mut self, style: StyleRefinement);
fn set_in_focus_style(&mut self, style: StyleRefinement);
fn focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
where
Self: Sized,
{
self.set_focus_style(f(StyleRefinement::default()));
self
}
fn focus_in(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
where
Self: Sized,
{
self.set_focus_in_style(f(StyleRefinement::default()));
self
}
fn in_focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
where
Self: Sized,
{
self.set_in_focus_style(f(StyleRefinement::default()));
self
}
fn on_focus(
mut self,
listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + Send + 'static,
) -> Self
where
Self: Sized,
{
self.focus_listeners()
.push(Box::new(move |view, focus_handle, event, cx| {
if event.focused.as_ref() == Some(focus_handle) {
listener(view, event, cx)
}
}));
self
}
fn on_blur(
mut self,
listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + Send + 'static,
) -> Self
where
Self: Sized,
{
self.focus_listeners()
.push(Box::new(move |view, focus_handle, event, cx| {
if event.blurred.as_ref() == Some(focus_handle) {
listener(view, event, cx)
}
}));
self
}
fn on_focus_in(
mut self,
listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + Send + 'static,
) -> Self
where
Self: Sized,
{
self.focus_listeners()
.push(Box::new(move |view, focus_handle, event, cx| {
let descendant_blurred = event
.blurred
.as_ref()
.map_or(false, |blurred| focus_handle.contains(blurred, cx));
let descendant_focused = event
.focused
.as_ref()
.map_or(false, |focused| focus_handle.contains(focused, cx));
if !descendant_blurred && descendant_focused {
listener(view, event, cx)
}
}));
self
}
fn on_focus_out(
mut self,
listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + Send + 'static,
) -> Self
where
Self: Sized,
{
self.focus_listeners()
.push(Box::new(move |view, focus_handle, event, cx| {
let descendant_blurred = event
.blurred
.as_ref()
.map_or(false, |blurred| focus_handle.contains(blurred, cx));
let descendant_focused = event
.focused
.as_ref()
.map_or(false, |focused| focus_handle.contains(focused, cx));
if descendant_blurred && !descendant_focused {
listener(view, event, cx)
}
}));
self
}
}
pub trait ElementFocus<V: 'static>: 'static + Send {
fn as_focusable(&self) -> Option<&FocusEnabled<V>>;
fn as_focusable_mut(&mut self) -> Option<&mut FocusEnabled<V>>;
fn initialize<R>(
&mut self,
focus_handle: Option<FocusHandle>,
cx: &mut ViewContext<V>,
f: impl FnOnce(Option<FocusHandle>, &mut ViewContext<V>) -> R,
) -> R {
if let Some(focusable) = self.as_focusable_mut() {
let focus_handle = focusable
.focus_handle
.get_or_insert_with(|| focus_handle.unwrap_or_else(|| cx.focus_handle()))
.clone();
for listener in focusable.focus_listeners.drain(..) {
let focus_handle = focus_handle.clone();
cx.on_focus_changed(move |view, event, cx| {
listener(view, &focus_handle, event, cx)
});
}
cx.with_focus(focus_handle.clone(), |cx| f(Some(focus_handle), cx))
} else {
f(None, cx)
}
}
fn refine_style(&self, style: &mut Style, cx: &WindowContext) {
if let Some(focusable) = self.as_focusable() {
let focus_handle = focusable
.focus_handle
.as_ref()
.expect("must call initialize before refine_style");
if focus_handle.contains_focused(cx) {
style.refine(&focusable.focus_in_style);
}
if focus_handle.within_focused(cx) {
style.refine(&focusable.in_focus_style);
}
if focus_handle.is_focused(cx) {
style.refine(&focusable.focus_style);
}
}
}
fn paint(&self, bounds: Bounds<Pixels>, cx: &mut WindowContext) {
if let Some(focusable) = self.as_focusable() {
let focus_handle = focusable
.focus_handle
.clone()
.expect("must call initialize before paint");
cx.on_mouse_event(move |event: &MouseDownEvent, phase, cx| {
if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
if !cx.default_prevented() {
cx.focus(&focus_handle);
cx.prevent_default();
}
}
})
}
}
}
pub struct FocusEnabled<V> {
pub focus_handle: Option<FocusHandle>,
pub focus_listeners: FocusListeners<V>,
pub focus_style: StyleRefinement,
pub focus_in_style: StyleRefinement,
pub in_focus_style: StyleRefinement,
}
impl<V> FocusEnabled<V> {
pub fn new() -> Self {
Self {
focus_handle: None,
focus_listeners: FocusListeners::default(),
focus_style: StyleRefinement::default(),
focus_in_style: StyleRefinement::default(),
in_focus_style: StyleRefinement::default(),
}
}
pub fn tracked(handle: &FocusHandle) -> Self {
Self {
focus_handle: Some(handle.clone()),
focus_listeners: FocusListeners::default(),
focus_style: StyleRefinement::default(),
focus_in_style: StyleRefinement::default(),
in_focus_style: StyleRefinement::default(),
}
}
}
impl<V: 'static> ElementFocus<V> for FocusEnabled<V> {
fn as_focusable(&self) -> Option<&FocusEnabled<V>> {
Some(self)
}
fn as_focusable_mut(&mut self) -> Option<&mut FocusEnabled<V>> {
Some(self)
}
}
impl<V> From<FocusHandle> for FocusEnabled<V> {
fn from(value: FocusHandle) -> Self {
Self {
focus_handle: Some(value),
focus_listeners: FocusListeners::default(),
focus_style: StyleRefinement::default(),
focus_in_style: StyleRefinement::default(),
in_focus_style: StyleRefinement::default(),
}
}
}
pub struct FocusDisabled;
impl<V: 'static> ElementFocus<V> for FocusDisabled {
fn as_focusable(&self) -> Option<&FocusEnabled<V>> {
None
}
fn as_focusable_mut(&mut self) -> Option<&mut FocusEnabled<V>> {
None
}
}

1244
crates/gpui2/src/geometry.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -1,22 +1,364 @@
pub mod adapter;
pub mod color;
pub mod element;
pub mod elements;
pub mod interactive;
pub mod style;
pub mod view;
pub mod view_context;
mod action;
mod app;
mod assets;
mod color;
mod element;
mod elements;
mod executor;
mod focusable;
mod geometry;
mod image_cache;
mod interactive;
mod keymap;
mod platform;
mod scene;
mod style;
mod styled;
mod subscription;
mod svg_renderer;
mod taffy;
#[cfg(any(test, feature = "test-support"))]
mod test;
mod text_system;
mod util;
mod view;
mod window;
mod private {
/// A mechanism for restricting implementations of a trait to only those in GPUI.
/// See: https://predr.ag/blog/definitive-guide-to-sealed-traits-in-rust/
pub trait Sealed {}
}
pub use action::*;
pub use anyhow::Result;
pub use app::*;
pub use assets::*;
pub use color::*;
pub use element::{AnyElement, Element, IntoElement, Layout, ParentElement};
pub use geometry::{
rect::RectF,
vector::{vec2f, Vector2F},
};
pub use gpui::*;
pub use gpui2_macros::{Element, *};
pub use element::*;
pub use elements::*;
pub use executor::*;
pub use focusable::*;
pub use geometry::*;
pub use gpui2_macros::*;
pub use image_cache::*;
pub use interactive::*;
pub use platform::{Platform, WindowBounds, WindowOptions};
pub use keymap::*;
pub use platform::*;
use private::Sealed;
pub use refineable::*;
pub use scene::*;
pub use serde;
pub use serde_json;
pub use smallvec;
pub use smol::Timer;
pub use style::*;
pub use styled::*;
pub use subscription::*;
pub use svg_renderer::*;
pub use taffy::{AvailableSpace, LayoutId};
#[cfg(any(test, feature = "test-support"))]
pub use test::*;
pub use text_system::*;
pub use util::arc_cow::ArcCow;
pub use view::*;
pub use view_context::ViewContext;
pub use window::*;
use derive_more::{Deref, DerefMut};
use std::{
any::{Any, TypeId},
borrow::{Borrow, BorrowMut},
mem,
ops::{Deref, DerefMut},
sync::Arc,
};
use taffy::TaffyLayoutEngine;
type AnyBox = Box<dyn Any + Send>;
pub trait Context {
type ModelContext<'a, T>;
type Result<T>;
fn build_model<T>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send;
fn update_model<T: 'static, R>(
&mut self,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R>;
}
pub trait VisualContext: Context {
type ViewContext<'a, 'w, V>;
fn build_view<V>(
&mut self,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
) -> Self::Result<View<V>>
where
V: 'static + Send;
fn update_view<V: 'static, R>(
&mut self,
view: &View<V>,
update: impl FnOnce(&mut V, &mut Self::ViewContext<'_, '_, V>) -> R,
) -> Self::Result<R>;
}
pub trait Entity<T>: Sealed {
type Weak: 'static + Send;
fn entity_id(&self) -> EntityId;
fn downgrade(&self) -> Self::Weak;
fn upgrade_from(weak: &Self::Weak) -> Option<Self>
where
Self: Sized;
}
pub enum GlobalKey {
Numeric(usize),
View(EntityId),
Type(TypeId),
}
#[repr(transparent)]
pub struct MainThread<T>(T);
impl<T> Deref for MainThread<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for MainThread<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<C: Context> Context for MainThread<C> {
type ModelContext<'a, T> = MainThread<C::ModelContext<'a, T>>;
type Result<T> = C::Result<T>;
fn build_model<T>(
&mut self,
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send,
{
self.0.build_model(|cx| {
let cx = unsafe {
mem::transmute::<
&mut C::ModelContext<'_, T>,
&mut MainThread<C::ModelContext<'_, T>>,
>(cx)
};
build_model(cx)
})
}
fn update_model<T: 'static, R>(
&mut self,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> {
self.0.update_model(handle, |entity, cx| {
let cx = unsafe {
mem::transmute::<
&mut C::ModelContext<'_, T>,
&mut MainThread<C::ModelContext<'_, T>>,
>(cx)
};
update(entity, cx)
})
}
}
impl<C: VisualContext> VisualContext for MainThread<C> {
type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>;
fn build_view<V>(
&mut self,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
) -> Self::Result<View<V>>
where
V: 'static + Send,
{
self.0.build_view(|cx| {
let cx = unsafe {
mem::transmute::<
&mut C::ViewContext<'_, '_, V>,
&mut MainThread<C::ViewContext<'_, '_, V>>,
>(cx)
};
build_view_state(cx)
})
}
fn update_view<V: 'static, R>(
&mut self,
view: &View<V>,
update: impl FnOnce(&mut V, &mut Self::ViewContext<'_, '_, V>) -> R,
) -> Self::Result<R> {
self.0.update_view(view, |view_state, cx| {
let cx = unsafe {
mem::transmute::<
&mut C::ViewContext<'_, '_, V>,
&mut MainThread<C::ViewContext<'_, '_, V>>,
>(cx)
};
update(view_state, cx)
})
}
}
pub trait BorrowAppContext {
fn with_text_style<F, R>(&mut self, style: TextStyleRefinement, f: F) -> R
where
F: FnOnce(&mut Self) -> R;
fn set_global<T: Send + 'static>(&mut self, global: T);
}
impl<C> BorrowAppContext for C
where
C: BorrowMut<AppContext>,
{
fn with_text_style<F, R>(&mut self, style: TextStyleRefinement, f: F) -> R
where
F: FnOnce(&mut Self) -> R,
{
self.borrow_mut().push_text_style(style);
let result = f(self);
self.borrow_mut().pop_text_style();
result
}
fn set_global<G: 'static + Send>(&mut self, global: G) {
self.borrow_mut().set_global(global)
}
}
pub trait EventEmitter: 'static {
type Event: Any;
}
pub trait Flatten<T> {
fn flatten(self) -> Result<T>;
}
impl<T> Flatten<T> for Result<Result<T>> {
fn flatten(self) -> Result<T> {
self?
}
}
impl<T> Flatten<T> for Result<T> {
fn flatten(self) -> Result<T> {
self
}
}
#[derive(Deref, DerefMut, Eq, PartialEq, Hash, Clone)]
pub struct SharedString(ArcCow<'static, str>);
impl Default for SharedString {
fn default() -> Self {
Self(ArcCow::Owned("".into()))
}
}
impl AsRef<str> for SharedString {
fn as_ref(&self) -> &str {
&self.0
}
}
impl Borrow<str> for SharedString {
fn borrow(&self) -> &str {
self.as_ref()
}
}
impl std::fmt::Debug for SharedString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for SharedString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_ref())
}
}
impl<T: Into<ArcCow<'static, str>>> From<T> for SharedString {
fn from(value: T) -> Self {
Self(value.into())
}
}
pub enum Reference<'a, T> {
Immutable(&'a T),
Mutable(&'a mut T),
}
impl<'a, T> Deref for Reference<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
Reference::Immutable(target) => target,
Reference::Mutable(target) => target,
}
}
}
impl<'a, T> DerefMut for Reference<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
Reference::Immutable(_) => {
panic!("cannot mutably deref an immutable reference. this is a bug in GPUI.");
}
Reference::Mutable(target) => target,
}
}
}
pub(crate) struct MainThreadOnly<T: ?Sized> {
executor: Executor,
value: Arc<T>,
}
impl<T: ?Sized> Clone for MainThreadOnly<T> {
fn clone(&self) -> Self {
Self {
executor: self.executor.clone(),
value: self.value.clone(),
}
}
}
/// Allows a value to be accessed only on the main thread, allowing a non-`Send` type
/// to become `Send`.
impl<T: 'static + ?Sized> MainThreadOnly<T> {
pub(crate) fn new(value: Arc<T>, executor: Executor) -> Self {
Self { executor, value }
}
pub(crate) fn borrow_on_main_thread(&self) -> &T {
assert!(self.executor.is_main_thread());
&self.value
}
}
unsafe impl<T: ?Sized> Send for MainThreadOnly<T> {}

View file

@ -0,0 +1,99 @@
use crate::{ImageData, ImageId, SharedString};
use collections::HashMap;
use futures::{
future::{BoxFuture, Shared},
AsyncReadExt, FutureExt,
};
use image::ImageError;
use parking_lot::Mutex;
use std::sync::Arc;
use thiserror::Error;
use util::http::{self, HttpClient};
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct RenderImageParams {
pub(crate) image_id: ImageId,
}
#[derive(Debug, Error, Clone)]
pub enum Error {
#[error("http error: {0}")]
Client(#[from] http::Error),
#[error("IO error: {0}")]
Io(Arc<std::io::Error>),
#[error("unexpected http status: {status}, body: {body}")]
BadStatus {
status: http::StatusCode,
body: String,
},
#[error("image error: {0}")]
Image(Arc<ImageError>),
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Error::Io(Arc::new(error))
}
}
impl From<ImageError> for Error {
fn from(error: ImageError) -> Self {
Error::Image(Arc::new(error))
}
}
pub struct ImageCache {
client: Arc<dyn HttpClient>,
images: Arc<Mutex<HashMap<SharedString, FetchImageFuture>>>,
}
type FetchImageFuture = Shared<BoxFuture<'static, Result<Arc<ImageData>, Error>>>;
impl ImageCache {
pub fn new(client: Arc<dyn HttpClient>) -> Self {
ImageCache {
client,
images: Default::default(),
}
}
pub fn get(
&self,
uri: impl Into<SharedString>,
) -> Shared<BoxFuture<'static, Result<Arc<ImageData>, Error>>> {
let uri = uri.into();
let mut images = self.images.lock();
match images.get(&uri) {
Some(future) => future.clone(),
None => {
let client = self.client.clone();
let future = {
let uri = uri.clone();
async move {
let mut response = client.get(uri.as_ref(), ().into(), true).await?;
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
return Err(Error::BadStatus {
status: response.status(),
body: String::from_utf8_lossy(&body).into_owned(),
});
}
let format = image::guess_format(&body)?;
let image =
image::load_from_memory_with_format(&body, format)?.into_bgra8();
Ok(Arc::new(ImageData::new(image)))
}
}
.boxed()
.shared();
images.insert(uri, future.clone());
future
}
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,80 @@
use crate::{Action, DispatchContext, DispatchContextPredicate, KeyMatch, Keystroke};
use anyhow::Result;
use smallvec::SmallVec;
pub struct KeyBinding {
action: Box<dyn Action>,
pub(super) keystrokes: SmallVec<[Keystroke; 2]>,
pub(super) context_predicate: Option<DispatchContextPredicate>,
}
impl KeyBinding {
pub fn new<A: Action>(keystrokes: &str, action: A, context_predicate: Option<&str>) -> Self {
Self::load(keystrokes, Box::new(action), context_predicate).unwrap()
}
pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
let context = if let Some(context) = context {
Some(DispatchContextPredicate::parse(context)?)
} else {
None
};
let keystrokes = keystrokes
.split_whitespace()
.map(Keystroke::parse)
.collect::<Result<_>>()?;
Ok(Self {
keystrokes,
action,
context_predicate: context,
})
}
pub fn matches_context(&self, contexts: &[&DispatchContext]) -> bool {
self.context_predicate
.as_ref()
.map(|predicate| predicate.eval(contexts))
.unwrap_or(true)
}
pub fn match_keystrokes(
&self,
pending_keystrokes: &[Keystroke],
contexts: &[&DispatchContext],
) -> KeyMatch {
if self.keystrokes.as_ref().starts_with(&pending_keystrokes)
&& self.matches_context(contexts)
{
// If the binding is completed, push it onto the matches list
if self.keystrokes.as_ref().len() == pending_keystrokes.len() {
KeyMatch::Some(self.action.boxed_clone())
} else {
KeyMatch::Pending
}
} else {
KeyMatch::None
}
}
pub fn keystrokes_for_action(
&self,
action: &dyn Action,
contexts: &[&DispatchContext],
) -> Option<SmallVec<[Keystroke; 2]>> {
if self.action.partial_eq(action) && self.matches_context(contexts) {
Some(self.keystrokes.clone())
} else {
None
}
}
pub fn keystrokes(&self) -> &[Keystroke] {
self.keystrokes.as_slice()
}
pub fn action(&self) -> &dyn Action {
self.action.as_ref()
}
}

View file

@ -0,0 +1,398 @@
use crate::{DispatchContextPredicate, KeyBinding, Keystroke};
use collections::HashSet;
use smallvec::SmallVec;
use std::{any::TypeId, collections::HashMap};
#[derive(Copy, Clone, Eq, PartialEq, Default)]
pub struct KeymapVersion(usize);
#[derive(Default)]
pub struct Keymap {
bindings: Vec<KeyBinding>,
binding_indices_by_action_id: HashMap<TypeId, SmallVec<[usize; 3]>>,
disabled_keystrokes:
HashMap<SmallVec<[Keystroke; 2]>, HashSet<Option<DispatchContextPredicate>>>,
version: KeymapVersion,
}
impl Keymap {
pub fn new(bindings: Vec<KeyBinding>) -> Self {
let mut this = Self::default();
this.add_bindings(bindings);
this
}
pub fn version(&self) -> KeymapVersion {
self.version
}
pub fn bindings_for_action(&self, action_id: TypeId) -> impl Iterator<Item = &'_ KeyBinding> {
self.binding_indices_by_action_id
.get(&action_id)
.map(SmallVec::as_slice)
.unwrap_or(&[])
.iter()
.map(|ix| &self.bindings[*ix])
.filter(|binding| !self.binding_disabled(binding))
}
pub fn add_bindings<T: IntoIterator<Item = KeyBinding>>(&mut self, bindings: T) {
// todo!("no action")
// let no_action_id = (NoAction {}).id();
let mut new_bindings = Vec::new();
let has_new_disabled_keystrokes = false;
for binding in bindings {
// if binding.action().id() == no_action_id {
// has_new_disabled_keystrokes |= self
// .disabled_keystrokes
// .entry(binding.keystrokes)
// .or_default()
// .insert(binding.context_predicate);
// } else {
new_bindings.push(binding);
// }
}
if has_new_disabled_keystrokes {
self.binding_indices_by_action_id.retain(|_, indices| {
indices.retain(|ix| {
let binding = &self.bindings[*ix];
match self.disabled_keystrokes.get(&binding.keystrokes) {
Some(disabled_predicates) => {
!disabled_predicates.contains(&binding.context_predicate)
}
None => true,
}
});
!indices.is_empty()
});
}
for new_binding in new_bindings {
if !self.binding_disabled(&new_binding) {
self.binding_indices_by_action_id
.entry(new_binding.action().as_any().type_id())
.or_default()
.push(self.bindings.len());
self.bindings.push(new_binding);
}
}
self.version.0 += 1;
}
pub fn clear(&mut self) {
self.bindings.clear();
self.binding_indices_by_action_id.clear();
self.disabled_keystrokes.clear();
self.version.0 += 1;
}
pub fn bindings(&self) -> Vec<&KeyBinding> {
self.bindings
.iter()
.filter(|binding| !self.binding_disabled(binding))
.collect()
}
fn binding_disabled(&self, binding: &KeyBinding) -> bool {
match self.disabled_keystrokes.get(&binding.keystrokes) {
Some(disabled_predicates) => disabled_predicates.contains(&binding.context_predicate),
None => false,
}
}
}
// #[cfg(test)]
// mod tests {
// use crate::actions;
// use super::*;
// actions!(
// keymap_test,
// [Present1, Present2, Present3, Duplicate, Missing]
// );
// #[test]
// fn regular_keymap() {
// let present_1 = Binding::new("ctrl-q", Present1 {}, None);
// let present_2 = Binding::new("ctrl-w", Present2 {}, Some("pane"));
// let present_3 = Binding::new("ctrl-e", Present3 {}, Some("editor"));
// let keystroke_duplicate_to_1 = Binding::new("ctrl-q", Duplicate {}, None);
// let full_duplicate_to_2 = Binding::new("ctrl-w", Present2 {}, Some("pane"));
// let missing = Binding::new("ctrl-r", Missing {}, None);
// let all_bindings = [
// &present_1,
// &present_2,
// &present_3,
// &keystroke_duplicate_to_1,
// &full_duplicate_to_2,
// &missing,
// ];
// let mut keymap = Keymap::default();
// assert_absent(&keymap, &all_bindings);
// assert!(keymap.bindings().is_empty());
// keymap.add_bindings([present_1.clone(), present_2.clone(), present_3.clone()]);
// assert_absent(&keymap, &[&keystroke_duplicate_to_1, &missing]);
// assert_present(
// &keymap,
// &[(&present_1, "q"), (&present_2, "w"), (&present_3, "e")],
// );
// keymap.add_bindings([
// keystroke_duplicate_to_1.clone(),
// full_duplicate_to_2.clone(),
// ]);
// assert_absent(&keymap, &[&missing]);
// assert!(
// !keymap.binding_disabled(&keystroke_duplicate_to_1),
// "Duplicate binding 1 was added and should not be disabled"
// );
// assert!(
// !keymap.binding_disabled(&full_duplicate_to_2),
// "Duplicate binding 2 was added and should not be disabled"
// );
// assert_eq!(
// keymap
// .bindings_for_action(keystroke_duplicate_to_1.action().id())
// .map(|binding| &binding.keystrokes)
// .flatten()
// .collect::<Vec<_>>(),
// vec![&Keystroke {
// ctrl: true,
// alt: false,
// shift: false,
// cmd: false,
// function: false,
// key: "q".to_string(),
// ime_key: None,
// }],
// "{keystroke_duplicate_to_1:?} should have the expected keystroke in the keymap"
// );
// assert_eq!(
// keymap
// .bindings_for_action(full_duplicate_to_2.action().id())
// .map(|binding| &binding.keystrokes)
// .flatten()
// .collect::<Vec<_>>(),
// vec![
// &Keystroke {
// ctrl: true,
// alt: false,
// shift: false,
// cmd: false,
// function: false,
// key: "w".to_string(),
// ime_key: None,
// },
// &Keystroke {
// ctrl: true,
// alt: false,
// shift: false,
// cmd: false,
// function: false,
// key: "w".to_string(),
// ime_key: None,
// }
// ],
// "{full_duplicate_to_2:?} should have a duplicated keystroke in the keymap"
// );
// let updated_bindings = keymap.bindings();
// let expected_updated_bindings = vec![
// &present_1,
// &present_2,
// &present_3,
// &keystroke_duplicate_to_1,
// &full_duplicate_to_2,
// ];
// assert_eq!(
// updated_bindings.len(),
// expected_updated_bindings.len(),
// "Unexpected updated keymap bindings {updated_bindings:?}"
// );
// for (i, expected) in expected_updated_bindings.iter().enumerate() {
// let keymap_binding = &updated_bindings[i];
// assert_eq!(
// keymap_binding.context_predicate, expected.context_predicate,
// "Unexpected context predicate for keymap {i} element: {keymap_binding:?}"
// );
// assert_eq!(
// keymap_binding.keystrokes, expected.keystrokes,
// "Unexpected keystrokes for keymap {i} element: {keymap_binding:?}"
// );
// }
// keymap.clear();
// assert_absent(&keymap, &all_bindings);
// assert!(keymap.bindings().is_empty());
// }
// #[test]
// fn keymap_with_ignored() {
// let present_1 = Binding::new("ctrl-q", Present1 {}, None);
// let present_2 = Binding::new("ctrl-w", Present2 {}, Some("pane"));
// let present_3 = Binding::new("ctrl-e", Present3 {}, Some("editor"));
// let keystroke_duplicate_to_1 = Binding::new("ctrl-q", Duplicate {}, None);
// let full_duplicate_to_2 = Binding::new("ctrl-w", Present2 {}, Some("pane"));
// let ignored_1 = Binding::new("ctrl-q", NoAction {}, None);
// let ignored_2 = Binding::new("ctrl-w", NoAction {}, Some("pane"));
// let ignored_3_with_other_context =
// Binding::new("ctrl-e", NoAction {}, Some("other_context"));
// let mut keymap = Keymap::default();
// keymap.add_bindings([
// ignored_1.clone(),
// ignored_2.clone(),
// ignored_3_with_other_context.clone(),
// ]);
// assert_absent(&keymap, &[&present_3]);
// assert_disabled(
// &keymap,
// &[
// &present_1,
// &present_2,
// &ignored_1,
// &ignored_2,
// &ignored_3_with_other_context,
// ],
// );
// assert!(keymap.bindings().is_empty());
// keymap.clear();
// keymap.add_bindings([
// present_1.clone(),
// present_2.clone(),
// present_3.clone(),
// ignored_1.clone(),
// ignored_2.clone(),
// ignored_3_with_other_context.clone(),
// ]);
// assert_present(&keymap, &[(&present_3, "e")]);
// assert_disabled(
// &keymap,
// &[
// &present_1,
// &present_2,
// &ignored_1,
// &ignored_2,
// &ignored_3_with_other_context,
// ],
// );
// keymap.clear();
// keymap.add_bindings([
// present_1.clone(),
// present_2.clone(),
// present_3.clone(),
// ignored_1.clone(),
// ]);
// assert_present(&keymap, &[(&present_2, "w"), (&present_3, "e")]);
// assert_disabled(&keymap, &[&present_1, &ignored_1]);
// assert_absent(&keymap, &[&ignored_2, &ignored_3_with_other_context]);
// keymap.clear();
// keymap.add_bindings([
// present_1.clone(),
// present_2.clone(),
// present_3.clone(),
// keystroke_duplicate_to_1.clone(),
// full_duplicate_to_2.clone(),
// ignored_1.clone(),
// ignored_2.clone(),
// ignored_3_with_other_context.clone(),
// ]);
// assert_present(&keymap, &[(&present_3, "e")]);
// assert_disabled(
// &keymap,
// &[
// &present_1,
// &present_2,
// &keystroke_duplicate_to_1,
// &full_duplicate_to_2,
// &ignored_1,
// &ignored_2,
// &ignored_3_with_other_context,
// ],
// );
// keymap.clear();
// }
// #[track_caller]
// fn assert_present(keymap: &Keymap, expected_bindings: &[(&Binding, &str)]) {
// let keymap_bindings = keymap.bindings();
// assert_eq!(
// expected_bindings.len(),
// keymap_bindings.len(),
// "Unexpected keymap bindings {keymap_bindings:?}"
// );
// for (i, (expected, expected_key)) in expected_bindings.iter().enumerate() {
// assert!(
// !keymap.binding_disabled(expected),
// "{expected:?} should not be disabled as it was added into keymap for element {i}"
// );
// assert_eq!(
// keymap
// .bindings_for_action(expected.action().id())
// .map(|binding| &binding.keystrokes)
// .flatten()
// .collect::<Vec<_>>(),
// vec![&Keystroke {
// ctrl: true,
// alt: false,
// shift: false,
// cmd: false,
// function: false,
// key: expected_key.to_string(),
// ime_key: None,
// }],
// "{expected:?} should have the expected keystroke with key '{expected_key}' in the keymap for element {i}"
// );
// let keymap_binding = &keymap_bindings[i];
// assert_eq!(
// keymap_binding.context_predicate, expected.context_predicate,
// "Unexpected context predicate for keymap {i} element: {keymap_binding:?}"
// );
// assert_eq!(
// keymap_binding.keystrokes, expected.keystrokes,
// "Unexpected keystrokes for keymap {i} element: {keymap_binding:?}"
// );
// }
// }
// #[track_caller]
// fn assert_absent(keymap: &Keymap, bindings: &[&Binding]) {
// for binding in bindings.iter() {
// assert!(
// !keymap.binding_disabled(binding),
// "{binding:?} should not be disabled in the keymap where was not added"
// );
// assert_eq!(
// keymap.bindings_for_action(binding.action().id()).count(),
// 0,
// "{binding:?} should have no actions in the keymap where was not added"
// );
// }
// }
// #[track_caller]
// fn assert_disabled(keymap: &Keymap, bindings: &[&Binding]) {
// for binding in bindings.iter() {
// assert!(
// keymap.binding_disabled(binding),
// "{binding:?} should be disabled in the keymap"
// );
// assert_eq!(
// keymap.bindings_for_action(binding.action().id()).count(),
// 0,
// "{binding:?} should have no actions in the keymap where it was disabled"
// );
// }
// }
// }

View file

@ -0,0 +1,473 @@
use crate::{Action, DispatchContext, Keymap, KeymapVersion, Keystroke};
use parking_lot::Mutex;
use smallvec::SmallVec;
use std::sync::Arc;
pub struct KeyMatcher {
pending_keystrokes: Vec<Keystroke>,
keymap: Arc<Mutex<Keymap>>,
keymap_version: KeymapVersion,
}
impl KeyMatcher {
pub fn new(keymap: Arc<Mutex<Keymap>>) -> Self {
let keymap_version = keymap.lock().version();
Self {
pending_keystrokes: Vec::new(),
keymap_version,
keymap,
}
}
// todo!("replace with a function that calls an FnMut for every binding matching the action")
// pub fn bindings_for_action(&self, action_id: TypeId) -> impl Iterator<Item = &Binding> {
// self.keymap.lock().bindings_for_action(action_id)
// }
pub fn clear_pending(&mut self) {
self.pending_keystrokes.clear();
}
pub fn has_pending_keystrokes(&self) -> bool {
!self.pending_keystrokes.is_empty()
}
/// Pushes a keystroke onto the matcher.
/// The result of the new keystroke is returned:
/// KeyMatch::None =>
/// No match is valid for this key given any pending keystrokes.
/// KeyMatch::Pending =>
/// There exist bindings which are still waiting for more keys.
/// KeyMatch::Complete(matches) =>
/// One or more bindings have received the necessary key presses.
/// Bindings added later will take precedence over earlier bindings.
pub fn match_keystroke(
&mut self,
keystroke: &Keystroke,
context_stack: &[&DispatchContext],
) -> KeyMatch {
let keymap = self.keymap.lock();
// Clear pending keystrokes if the keymap has changed since the last matched keystroke.
if keymap.version() != self.keymap_version {
self.keymap_version = keymap.version();
self.pending_keystrokes.clear();
}
let mut pending_key = None;
for binding in keymap.bindings().iter().rev() {
for candidate in keystroke.match_candidates() {
self.pending_keystrokes.push(candidate.clone());
match binding.match_keystrokes(&self.pending_keystrokes, context_stack) {
KeyMatch::Some(action) => {
self.pending_keystrokes.clear();
return KeyMatch::Some(action);
}
KeyMatch::Pending => {
pending_key.get_or_insert(candidate);
}
KeyMatch::None => {}
}
self.pending_keystrokes.pop();
}
}
if let Some(pending_key) = pending_key {
self.pending_keystrokes.push(pending_key);
}
if self.pending_keystrokes.is_empty() {
KeyMatch::None
} else {
KeyMatch::Pending
}
}
pub fn keystrokes_for_action(
&self,
action: &dyn Action,
contexts: &[&DispatchContext],
) -> Option<SmallVec<[Keystroke; 2]>> {
self.keymap
.lock()
.bindings()
.iter()
.rev()
.find_map(|binding| binding.keystrokes_for_action(action, contexts))
}
}
pub enum KeyMatch {
None,
Pending,
Some(Box<dyn Action>),
}
impl KeyMatch {
pub fn is_some(&self) -> bool {
matches!(self, KeyMatch::Some(_))
}
}
// #[cfg(test)]
// mod tests {
// use anyhow::Result;
// use serde::Deserialize;
// use crate::{actions, impl_actions, keymap_matcher::ActionContext};
// use super::*;
// #[test]
// fn test_keymap_and_view_ordering() -> Result<()> {
// actions!(test, [EditorAction, ProjectPanelAction]);
// let mut editor = ActionContext::default();
// editor.add_identifier("Editor");
// let mut project_panel = ActionContext::default();
// project_panel.add_identifier("ProjectPanel");
// // Editor 'deeper' in than project panel
// let dispatch_path = vec![(2, editor), (1, project_panel)];
// // But editor actions 'higher' up in keymap
// let keymap = Keymap::new(vec![
// Binding::new("left", EditorAction, Some("Editor")),
// Binding::new("left", ProjectPanelAction, Some("ProjectPanel")),
// ]);
// let mut matcher = KeymapMatcher::new(keymap);
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("left")?, dispatch_path.clone()),
// KeyMatch::Matches(vec![
// (2, Box::new(EditorAction)),
// (1, Box::new(ProjectPanelAction)),
// ]),
// );
// Ok(())
// }
// #[test]
// fn test_push_keystroke() -> Result<()> {
// actions!(test, [B, AB, C, D, DA, E, EF]);
// let mut context1 = ActionContext::default();
// context1.add_identifier("1");
// let mut context2 = ActionContext::default();
// context2.add_identifier("2");
// let dispatch_path = vec![(2, context2), (1, context1)];
// let keymap = Keymap::new(vec![
// Binding::new("a b", AB, Some("1")),
// Binding::new("b", B, Some("2")),
// Binding::new("c", C, Some("2")),
// Binding::new("d", D, Some("1")),
// Binding::new("d", D, Some("2")),
// Binding::new("d a", DA, Some("2")),
// ]);
// let mut matcher = KeymapMatcher::new(keymap);
// // Binding with pending prefix always takes precedence
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
// KeyMatch::Pending,
// );
// // B alone doesn't match because a was pending, so AB is returned instead
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
// KeyMatch::Matches(vec![(1, Box::new(AB))]),
// );
// assert!(!matcher.has_pending_keystrokes());
// // Without an a prefix, B is dispatched like expected
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
// KeyMatch::Matches(vec![(2, Box::new(B))]),
// );
// assert!(!matcher.has_pending_keystrokes());
// // If a is prefixed, C will not be dispatched because there
// // was a pending binding for it
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
// KeyMatch::Pending,
// );
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("c")?, dispatch_path.clone()),
// KeyMatch::None,
// );
// assert!(!matcher.has_pending_keystrokes());
// // If a single keystroke matches multiple bindings in the tree
// // all of them are returned so that we can fallback if the action
// // handler decides to propagate the action
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("d")?, dispatch_path.clone()),
// KeyMatch::Matches(vec![(2, Box::new(D)), (1, Box::new(D))]),
// );
// // If none of the d action handlers consume the binding, a pending
// // binding may then be used
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
// KeyMatch::Matches(vec![(2, Box::new(DA))]),
// );
// assert!(!matcher.has_pending_keystrokes());
// Ok(())
// }
// #[test]
// fn test_keystroke_parsing() -> Result<()> {
// assert_eq!(
// Keystroke::parse("ctrl-p")?,
// Keystroke {
// key: "p".into(),
// ctrl: true,
// alt: false,
// shift: false,
// cmd: false,
// function: false,
// ime_key: None,
// }
// );
// assert_eq!(
// Keystroke::parse("alt-shift-down")?,
// Keystroke {
// key: "down".into(),
// ctrl: false,
// alt: true,
// shift: true,
// cmd: false,
// function: false,
// ime_key: None,
// }
// );
// assert_eq!(
// Keystroke::parse("shift-cmd--")?,
// Keystroke {
// key: "-".into(),
// ctrl: false,
// alt: false,
// shift: true,
// cmd: true,
// function: false,
// ime_key: None,
// }
// );
// Ok(())
// }
// #[test]
// fn test_context_predicate_parsing() -> Result<()> {
// use KeymapContextPredicate::*;
// assert_eq!(
// KeymapContextPredicate::parse("a && (b == c || d != e)")?,
// And(
// Box::new(Identifier("a".into())),
// Box::new(Or(
// Box::new(Equal("b".into(), "c".into())),
// Box::new(NotEqual("d".into(), "e".into())),
// ))
// )
// );
// assert_eq!(
// KeymapContextPredicate::parse("!a")?,
// Not(Box::new(Identifier("a".into())),)
// );
// Ok(())
// }
// #[test]
// fn test_context_predicate_eval() {
// let predicate = KeymapContextPredicate::parse("a && b || c == d").unwrap();
// let mut context = ActionContext::default();
// context.add_identifier("a");
// assert!(!predicate.eval(&[context]));
// let mut context = ActionContext::default();
// context.add_identifier("a");
// context.add_identifier("b");
// assert!(predicate.eval(&[context]));
// let mut context = ActionContext::default();
// context.add_identifier("a");
// context.add_key("c", "x");
// assert!(!predicate.eval(&[context]));
// let mut context = ActionContext::default();
// context.add_identifier("a");
// context.add_key("c", "d");
// assert!(predicate.eval(&[context]));
// let predicate = KeymapContextPredicate::parse("!a").unwrap();
// assert!(predicate.eval(&[ActionContext::default()]));
// }
// #[test]
// fn test_context_child_predicate_eval() {
// let predicate = KeymapContextPredicate::parse("a && b > c").unwrap();
// let contexts = [
// context_set(&["e", "f"]),
// context_set(&["c", "d"]), // match this context
// context_set(&["a", "b"]),
// ];
// assert!(!predicate.eval(&contexts[0..]));
// assert!(predicate.eval(&contexts[1..]));
// assert!(!predicate.eval(&contexts[2..]));
// let predicate = KeymapContextPredicate::parse("a && b > c && !d > e").unwrap();
// let contexts = [
// context_set(&["f"]),
// context_set(&["e"]), // only match this context
// context_set(&["c"]),
// context_set(&["a", "b"]),
// context_set(&["e"]),
// context_set(&["c", "d"]),
// context_set(&["a", "b"]),
// ];
// assert!(!predicate.eval(&contexts[0..]));
// assert!(predicate.eval(&contexts[1..]));
// assert!(!predicate.eval(&contexts[2..]));
// assert!(!predicate.eval(&contexts[3..]));
// assert!(!predicate.eval(&contexts[4..]));
// assert!(!predicate.eval(&contexts[5..]));
// assert!(!predicate.eval(&contexts[6..]));
// fn context_set(names: &[&str]) -> ActionContext {
// let mut keymap = ActionContext::new();
// names
// .iter()
// .for_each(|name| keymap.add_identifier(name.to_string()));
// keymap
// }
// }
// #[test]
// fn test_matcher() -> Result<()> {
// #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
// pub struct A(pub String);
// impl_actions!(test, [A]);
// actions!(test, [B, Ab, Dollar, Quote, Ess, Backtick]);
// #[derive(Clone, Debug, Eq, PartialEq)]
// struct ActionArg {
// a: &'static str,
// }
// let keymap = Keymap::new(vec![
// Binding::new("a", A("x".to_string()), Some("a")),
// Binding::new("b", B, Some("a")),
// Binding::new("a b", Ab, Some("a || b")),
// Binding::new("$", Dollar, Some("a")),
// Binding::new("\"", Quote, Some("a")),
// Binding::new("alt-s", Ess, Some("a")),
// Binding::new("ctrl-`", Backtick, Some("a")),
// ]);
// let mut context_a = ActionContext::default();
// context_a.add_identifier("a");
// let mut context_b = ActionContext::default();
// context_b.add_identifier("b");
// let mut matcher = KeymapMatcher::new(keymap);
// // Basic match
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, vec![(1, context_a.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(A("x".to_string())))])
// );
// matcher.clear_pending();
// // Multi-keystroke match
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, vec![(1, context_b.clone())]),
// KeyMatch::Pending
// );
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("b")?, vec![(1, context_b.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(Ab))])
// );
// matcher.clear_pending();
// // Failed matches don't interfere with matching subsequent keys
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("x")?, vec![(1, context_a.clone())]),
// KeyMatch::None
// );
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, vec![(1, context_a.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(A("x".to_string())))])
// );
// matcher.clear_pending();
// // Pending keystrokes are cleared when the context changes
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("a")?, vec![(1, context_b.clone())]),
// KeyMatch::Pending
// );
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("b")?, vec![(1, context_a.clone())]),
// KeyMatch::None
// );
// matcher.clear_pending();
// let mut context_c = ActionContext::default();
// context_c.add_identifier("c");
// // Pending keystrokes are maintained per-view
// assert_eq!(
// matcher.match_keystroke(
// Keystroke::parse("a")?,
// vec![(1, context_b.clone()), (2, context_c.clone())]
// ),
// KeyMatch::Pending
// );
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("b")?, vec![(1, context_b.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(Ab))])
// );
// // handle Czech $ (option + 4 key)
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("alt-ç->$")?, vec![(1, context_a.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(Dollar))])
// );
// // handle Brazillian quote (quote key then space key)
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("space->\"")?, vec![(1, context_a.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(Quote))])
// );
// // handle ctrl+` on a brazillian keyboard
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("ctrl-->`")?, vec![(1, context_a.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(Backtick))])
// );
// // handle alt-s on a US keyboard
// assert_eq!(
// matcher.match_keystroke(Keystroke::parse("alt-s->ß")?, vec![(1, context_a.clone())]),
// KeyMatch::Matches(vec![(1, Box::new(Ess))])
// );
// Ok(())
// }
// }

View file

@ -0,0 +1,7 @@
mod binding;
mod keymap;
mod matcher;
pub use binding::*;
pub use keymap::*;
pub use matcher::*;

View file

@ -0,0 +1,497 @@
mod keystroke;
#[cfg(target_os = "macos")]
mod mac;
#[cfg(any(test, feature = "test-support"))]
mod test;
use crate::{
AnyWindowHandle, Bounds, DevicePixels, Executor, Font, FontId, FontMetrics, FontRun,
GlobalPixels, GlyphId, InputEvent, LineLayout, Pixels, Point, RenderGlyphParams,
RenderImageParams, RenderSvgParams, Result, Scene, SharedString, Size,
};
use anyhow::anyhow;
use async_task::Runnable;
use futures::channel::oneshot;
use seahash::SeaHasher;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use std::time::Duration;
use std::{
any::Any,
fmt::{self, Debug, Display},
ops::Range,
path::{Path, PathBuf},
rc::Rc,
str::FromStr,
sync::Arc,
};
pub use keystroke::*;
#[cfg(target_os = "macos")]
pub use mac::*;
#[cfg(any(test, feature = "test-support"))]
pub use test::*;
pub use time::UtcOffset;
#[cfg(target_os = "macos")]
pub(crate) fn current_platform() -> Arc<dyn Platform> {
Arc::new(MacPlatform::new())
}
pub(crate) trait Platform: 'static {
fn executor(&self) -> Executor;
fn text_system(&self) -> Arc<dyn PlatformTextSystem>;
fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>);
fn quit(&self);
fn restart(&self);
fn activate(&self, ignoring_other_apps: bool);
fn hide(&self);
fn hide_other_apps(&self);
fn unhide_other_apps(&self);
fn displays(&self) -> Vec<Rc<dyn PlatformDisplay>>;
fn display(&self, id: DisplayId) -> Option<Rc<dyn PlatformDisplay>>;
fn main_window(&self) -> Option<AnyWindowHandle>;
fn open_window(
&self,
handle: AnyWindowHandle,
options: WindowOptions,
) -> Box<dyn PlatformWindow>;
fn set_display_link_output_callback(
&self,
display_id: DisplayId,
callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp)>,
);
fn start_display_link(&self, display_id: DisplayId);
fn stop_display_link(&self, display_id: DisplayId);
// fn add_status_item(&self, _handle: AnyWindowHandle) -> Box<dyn PlatformWindow>;
fn open_url(&self, url: &str);
fn on_open_urls(&self, callback: Box<dyn FnMut(Vec<String>)>);
fn prompt_for_paths(
&self,
options: PathPromptOptions,
) -> oneshot::Receiver<Option<Vec<PathBuf>>>;
fn prompt_for_new_path(&self, directory: &Path) -> oneshot::Receiver<Option<PathBuf>>;
fn reveal_path(&self, path: &Path);
fn on_become_active(&self, callback: Box<dyn FnMut()>);
fn on_resign_active(&self, callback: Box<dyn FnMut()>);
fn on_quit(&self, callback: Box<dyn FnMut()>);
fn on_reopen(&self, callback: Box<dyn FnMut()>);
fn on_event(&self, callback: Box<dyn FnMut(InputEvent) -> bool>);
fn os_name(&self) -> &'static str;
fn os_version(&self) -> Result<SemanticVersion>;
fn app_version(&self) -> Result<SemanticVersion>;
fn app_path(&self) -> Result<PathBuf>;
fn local_timezone(&self) -> UtcOffset;
fn path_for_auxiliary_executable(&self, name: &str) -> Result<PathBuf>;
fn set_cursor_style(&self, style: CursorStyle);
fn should_auto_hide_scrollbars(&self) -> bool;
fn write_to_clipboard(&self, item: ClipboardItem);
fn read_from_clipboard(&self) -> Option<ClipboardItem>;
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()>;
fn read_credentials(&self, url: &str) -> Result<Option<(String, Vec<u8>)>>;
fn delete_credentials(&self, url: &str) -> Result<()>;
}
pub trait PlatformDisplay: Send + Sync + Debug {
fn id(&self) -> DisplayId;
fn as_any(&self) -> &dyn Any;
fn bounds(&self) -> Bounds<GlobalPixels>;
}
#[derive(PartialEq, Eq, Hash, Copy, Clone)]
pub struct DisplayId(pub(crate) u32);
impl Debug for DisplayId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "DisplayId({})", self.0)
}
}
unsafe impl Send for DisplayId {}
pub(crate) trait PlatformWindow {
fn bounds(&self) -> WindowBounds;
fn content_size(&self) -> Size<Pixels>;
fn scale_factor(&self) -> f32;
fn titlebar_height(&self) -> Pixels;
fn appearance(&self) -> WindowAppearance;
fn display(&self) -> Rc<dyn PlatformDisplay>;
fn mouse_position(&self) -> Point<Pixels>;
fn as_any_mut(&mut self) -> &mut dyn Any;
fn set_input_handler(&mut self, input_handler: Box<dyn PlatformInputHandler>);
fn prompt(
&self,
level: WindowPromptLevel,
msg: &str,
answers: &[&str],
) -> oneshot::Receiver<usize>;
fn activate(&self);
fn set_title(&mut self, title: &str);
fn set_edited(&mut self, edited: bool);
fn show_character_palette(&self);
fn minimize(&self);
fn zoom(&self);
fn toggle_full_screen(&self);
fn on_input(&self, callback: Box<dyn FnMut(InputEvent) -> bool>);
fn on_active_status_change(&self, callback: Box<dyn FnMut(bool)>);
fn on_resize(&self, callback: Box<dyn FnMut(Size<Pixels>, f32)>);
fn on_fullscreen(&self, callback: Box<dyn FnMut(bool)>);
fn on_moved(&self, callback: Box<dyn FnMut()>);
fn on_should_close(&self, callback: Box<dyn FnMut() -> bool>);
fn on_close(&self, callback: Box<dyn FnOnce()>);
fn on_appearance_changed(&self, callback: Box<dyn FnMut()>);
fn is_topmost_for_position(&self, position: Point<Pixels>) -> bool;
fn draw(&self, scene: Scene);
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas>;
}
pub trait PlatformDispatcher: Send + Sync {
fn is_main_thread(&self) -> bool;
fn dispatch(&self, runnable: Runnable);
fn dispatch_on_main_thread(&self, runnable: Runnable);
fn dispatch_after(&self, duration: Duration, runnable: Runnable);
fn poll(&self) -> bool;
#[cfg(any(test, feature = "test-support"))]
fn as_test(&self) -> Option<&TestDispatcher> {
None
}
}
pub trait PlatformTextSystem: Send + Sync {
fn add_fonts(&self, fonts: &[Arc<Vec<u8>>]) -> Result<()>;
fn all_font_families(&self) -> Vec<String>;
fn font_id(&self, descriptor: &Font) -> Result<FontId>;
fn font_metrics(&self, font_id: FontId) -> FontMetrics;
fn typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>>;
fn advance(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Size<f32>>;
fn glyph_for_char(&self, font_id: FontId, ch: char) -> Option<GlyphId>;
fn glyph_raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>>;
fn rasterize_glyph(&self, params: &RenderGlyphParams) -> Result<(Size<DevicePixels>, Vec<u8>)>;
fn layout_line(&self, text: &str, font_size: Pixels, runs: &[FontRun]) -> LineLayout;
fn wrap_line(
&self,
text: &str,
font_id: FontId,
font_size: Pixels,
width: Pixels,
) -> Vec<usize>;
}
#[derive(Clone, Debug)]
pub struct AppMetadata {
pub os_name: &'static str,
pub os_version: Option<SemanticVersion>,
pub app_version: Option<SemanticVersion>,
}
#[derive(PartialEq, Eq, Hash, Clone)]
pub enum AtlasKey {
Glyph(RenderGlyphParams),
Svg(RenderSvgParams),
Image(RenderImageParams),
}
impl AtlasKey {
pub(crate) fn texture_kind(&self) -> AtlasTextureKind {
match self {
AtlasKey::Glyph(params) => {
if params.is_emoji {
AtlasTextureKind::Polychrome
} else {
AtlasTextureKind::Monochrome
}
}
AtlasKey::Svg(_) => AtlasTextureKind::Monochrome,
AtlasKey::Image(_) => AtlasTextureKind::Polychrome,
}
}
}
impl From<RenderGlyphParams> for AtlasKey {
fn from(params: RenderGlyphParams) -> Self {
Self::Glyph(params)
}
}
impl From<RenderSvgParams> for AtlasKey {
fn from(params: RenderSvgParams) -> Self {
Self::Svg(params)
}
}
impl From<RenderImageParams> for AtlasKey {
fn from(params: RenderImageParams) -> Self {
Self::Image(params)
}
}
pub trait PlatformAtlas: Send + Sync {
fn get_or_insert_with<'a>(
&self,
key: &AtlasKey,
build: &mut dyn FnMut() -> Result<(Size<DevicePixels>, Cow<'a, [u8]>)>,
) -> Result<AtlasTile>;
fn clear(&self);
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct AtlasTile {
pub(crate) texture_id: AtlasTextureId,
pub(crate) tile_id: TileId,
pub(crate) bounds: Bounds<DevicePixels>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(C)]
pub(crate) struct AtlasTextureId {
// We use u32 instead of usize for Metal Shader Language compatibility
pub(crate) index: u32,
pub(crate) kind: AtlasTextureKind,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(C)]
pub(crate) enum AtlasTextureKind {
Monochrome = 0,
Polychrome = 1,
Path = 2,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub(crate) struct TileId(pub(crate) u32);
impl From<etagere::AllocId> for TileId {
fn from(id: etagere::AllocId) -> Self {
Self(id.serialize())
}
}
impl From<TileId> for etagere::AllocId {
fn from(id: TileId) -> Self {
Self::deserialize(id.0)
}
}
pub trait PlatformInputHandler {
fn selected_text_range(&self) -> Option<Range<usize>>;
fn marked_text_range(&self) -> Option<Range<usize>>;
fn text_for_range(&self, range_utf16: Range<usize>) -> Option<String>;
fn replace_text_in_range(&mut self, replacement_range: Option<Range<usize>>, text: &str);
fn replace_and_mark_text_in_range(
&mut self,
range_utf16: Option<Range<usize>>,
new_text: &str,
new_selected_range: Option<Range<usize>>,
);
fn unmark_text(&mut self);
fn bounds_for_range(&self, range_utf16: Range<usize>) -> Option<Bounds<f32>>;
}
#[derive(Debug)]
pub struct WindowOptions {
pub bounds: WindowBounds,
pub titlebar: Option<TitlebarOptions>,
pub center: bool,
pub focus: bool,
pub show: bool,
pub kind: WindowKind,
pub is_movable: bool,
pub display_id: Option<DisplayId>,
}
impl Default for WindowOptions {
fn default() -> Self {
Self {
bounds: WindowBounds::default(),
titlebar: Some(TitlebarOptions {
title: Default::default(),
appears_transparent: Default::default(),
traffic_light_position: Default::default(),
}),
center: false,
focus: true,
show: true,
kind: WindowKind::Normal,
is_movable: true,
display_id: None,
}
}
}
#[derive(Debug, Default)]
pub struct TitlebarOptions {
pub title: Option<SharedString>,
pub appears_transparent: bool,
pub traffic_light_position: Option<Point<Pixels>>,
}
#[derive(Copy, Clone, Debug)]
pub enum Appearance {
Light,
VibrantLight,
Dark,
VibrantDark,
}
impl Default for Appearance {
fn default() -> Self {
Self::Light
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum WindowKind {
Normal,
PopUp,
}
#[derive(Copy, Clone, Debug, PartialEq, Default)]
pub enum WindowBounds {
Fullscreen,
#[default]
Maximized,
Fixed(Bounds<GlobalPixels>),
}
#[derive(Copy, Clone, Debug)]
pub enum WindowAppearance {
Light,
VibrantLight,
Dark,
VibrantDark,
}
impl Default for WindowAppearance {
fn default() -> Self {
Self::Light
}
}
#[derive(Copy, Clone, Debug, PartialEq, Default)]
pub enum WindowPromptLevel {
#[default]
Info,
Warning,
Critical,
}
#[derive(Copy, Clone, Debug)]
pub struct PathPromptOptions {
pub files: bool,
pub directories: bool,
pub multiple: bool,
}
#[derive(Copy, Clone, Debug)]
pub enum PromptLevel {
Info,
Warning,
Critical,
}
#[derive(Copy, Clone, Debug)]
pub enum CursorStyle {
Arrow,
ResizeLeftRight,
ResizeUpDown,
PointingHand,
IBeam,
}
impl Default for CursorStyle {
fn default() -> Self {
Self::Arrow
}
}
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct SemanticVersion {
major: usize,
minor: usize,
patch: usize,
}
impl FromStr for SemanticVersion {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
let mut components = s.trim().split('.');
let major = components
.next()
.ok_or_else(|| anyhow!("missing major version number"))?
.parse()?;
let minor = components
.next()
.ok_or_else(|| anyhow!("missing minor version number"))?
.parse()?;
let patch = components
.next()
.ok_or_else(|| anyhow!("missing patch version number"))?
.parse()?;
Ok(Self {
major,
minor,
patch,
})
}
}
impl Display for SemanticVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClipboardItem {
pub(crate) text: String,
pub(crate) metadata: Option<String>,
}
impl ClipboardItem {
pub fn new(text: String) -> Self {
Self {
text,
metadata: None,
}
}
pub fn with_metadata<T: Serialize>(mut self, metadata: T) -> Self {
self.metadata = Some(serde_json::to_string(&metadata).unwrap());
self
}
pub fn text(&self) -> &String {
&self.text
}
pub fn metadata<T>(&self) -> Option<T>
where
T: for<'a> Deserialize<'a>,
{
self.metadata
.as_ref()
.and_then(|m| serde_json::from_str(m).ok())
}
pub(crate) fn text_hash(text: &str) -> u64 {
let mut hasher = SeaHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
}

View file

@ -0,0 +1,151 @@
use anyhow::anyhow;
use serde::Deserialize;
use smallvec::SmallVec;
use std::fmt::Write;
#[derive(Clone, Debug, Eq, PartialEq, Default, Deserialize, Hash)]
pub struct Keystroke {
pub modifiers: Modifiers,
/// key is the character printed on the key that was pressed
/// e.g. for option-s, key is "s"
pub key: String,
/// ime_key is the character inserted by the IME engine when that key was pressed.
/// e.g. for option-s, ime_key is "ß"
pub ime_key: Option<String>,
}
impl Keystroke {
// When matching a key we cannot know whether the user intended to type
// the ime_key or the key. On some non-US keyboards keys we use in our
// bindings are behind option (for example `$` is typed `alt-ç` on a Czech keyboard),
// and on some keyboards the IME handler converts a sequence of keys into a
// specific character (for example `"` is typed as `" space` on a brazillian keyboard).
pub fn match_candidates(&self) -> SmallVec<[Keystroke; 2]> {
let mut possibilities = SmallVec::new();
match self.ime_key.as_ref() {
None => possibilities.push(self.clone()),
Some(ime_key) => {
possibilities.push(Keystroke {
modifiers: Modifiers {
control: self.modifiers.control,
alt: false,
shift: false,
command: false,
function: false,
},
key: ime_key.to_string(),
ime_key: None,
});
possibilities.push(Keystroke {
ime_key: None,
..self.clone()
});
}
}
possibilities
}
/// key syntax is:
/// [ctrl-][alt-][shift-][cmd-][fn-]key[->ime_key]
/// ime_key is only used for generating test events,
/// when matching a key with an ime_key set will be matched without it.
pub fn parse(source: &str) -> anyhow::Result<Self> {
let mut control = false;
let mut alt = false;
let mut shift = false;
let mut command = false;
let mut function = false;
let mut key = None;
let mut ime_key = None;
let mut components = source.split('-').peekable();
while let Some(component) = components.next() {
match component {
"ctrl" => control = true,
"alt" => alt = true,
"shift" => shift = true,
"cmd" => command = true,
"fn" => function = true,
_ => {
if let Some(next) = components.peek() {
if next.is_empty() && source.ends_with('-') {
key = Some(String::from("-"));
break;
} else if next.len() > 1 && next.starts_with('>') {
key = Some(String::from(component));
ime_key = Some(String::from(&next[1..]));
components.next();
} else {
return Err(anyhow!("Invalid keystroke `{}`", source));
}
} else {
key = Some(String::from(component));
}
}
}
}
let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
Ok(Keystroke {
modifiers: Modifiers {
control,
alt,
shift,
command,
function,
},
key,
ime_key,
})
}
}
impl std::fmt::Display for Keystroke {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.modifiers.control {
f.write_char('^')?;
}
if self.modifiers.alt {
f.write_char('⌥')?;
}
if self.modifiers.command {
f.write_char('⌘')?;
}
if self.modifiers.shift {
f.write_char('⇧')?;
}
let key = match self.key.as_str() {
"backspace" => '⌫',
"up" => '↑',
"down" => '↓',
"left" => '←',
"right" => '→',
"tab" => '⇥',
"escape" => '⎋',
key => {
if key.len() == 1 {
key.chars().next().unwrap().to_ascii_uppercase()
} else {
return f.write_str(key);
}
}
};
f.write_char(key)
}
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default, Deserialize, Hash)]
pub struct Modifiers {
pub control: bool,
pub alt: bool,
pub shift: bool,
pub command: bool,
pub function: bool,
}
impl Modifiers {
pub fn modified(&self) -> bool {
self.control || self.alt || self.shift || self.command || self.function
}
}

View file

@ -0,0 +1,165 @@
///! Macos screen have a y axis that goings up from the bottom of the screen and
///! an origin at the bottom left of the main display.
mod dispatcher;
mod display;
mod display_linker;
mod events;
mod metal_atlas;
mod metal_renderer;
mod open_type;
mod platform;
mod text_system;
mod window;
mod window_appearence;
use crate::{px, size, GlobalPixels, Pixels, Size};
use anyhow::anyhow;
use cocoa::{
base::{id, nil},
foundation::{NSAutoreleasePool, NSNotFound, NSRect, NSSize, NSString, NSUInteger, NSURL},
};
use metal_renderer::*;
use objc::{
msg_send,
runtime::{BOOL, NO, YES},
sel, sel_impl,
};
use std::{
ffi::{c_char, CStr, OsStr},
ops::Range,
os::unix::prelude::OsStrExt,
path::PathBuf,
};
pub use dispatcher::*;
pub use display::*;
pub use display_linker::*;
pub use metal_atlas::*;
pub use platform::*;
pub use text_system::*;
pub use window::*;
trait BoolExt {
fn to_objc(self) -> BOOL;
}
impl BoolExt for bool {
fn to_objc(self) -> BOOL {
if self {
YES
} else {
NO
}
}
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
struct NSRange {
pub location: NSUInteger,
pub length: NSUInteger,
}
impl NSRange {
fn invalid() -> Self {
Self {
location: NSNotFound as NSUInteger,
length: 0,
}
}
fn is_valid(&self) -> bool {
self.location != NSNotFound as NSUInteger
}
fn to_range(self) -> Option<Range<usize>> {
if self.is_valid() {
let start = self.location as usize;
let end = start + self.length as usize;
Some(start..end)
} else {
None
}
}
}
impl From<Range<usize>> for NSRange {
fn from(range: Range<usize>) -> Self {
NSRange {
location: range.start as NSUInteger,
length: range.len() as NSUInteger,
}
}
}
unsafe impl objc::Encode for NSRange {
fn encode() -> objc::Encoding {
let encoding = format!(
"{{NSRange={}{}}}",
NSUInteger::encode().as_str(),
NSUInteger::encode().as_str()
);
unsafe { objc::Encoding::from_str(&encoding) }
}
}
unsafe fn ns_string(string: &str) -> id {
NSString::alloc(nil).init_str(string).autorelease()
}
impl From<NSSize> for Size<Pixels> {
fn from(value: NSSize) -> Self {
Size {
width: px(value.width as f32),
height: px(value.height as f32),
}
}
}
pub trait NSRectExt {
fn size(&self) -> Size<Pixels>;
fn intersects(&self, other: Self) -> bool;
}
impl From<NSRect> for Size<Pixels> {
fn from(rect: NSRect) -> Self {
let NSSize { width, height } = rect.size;
size(width.into(), height.into())
}
}
impl From<NSRect> for Size<GlobalPixels> {
fn from(rect: NSRect) -> Self {
let NSSize { width, height } = rect.size;
size(width.into(), height.into())
}
}
// impl NSRectExt for NSRect {
// fn intersects(&self, other: Self) -> bool {
// self.size.width > 0.
// && self.size.height > 0.
// && other.size.width > 0.
// && other.size.height > 0.
// && self.origin.x <= other.origin.x + other.size.width
// && self.origin.x + self.size.width >= other.origin.x
// && self.origin.y <= other.origin.y + other.size.height
// && self.origin.y + self.size.height >= other.origin.y
// }
// }
// todo!
#[allow(unused)]
unsafe fn ns_url_to_path(url: id) -> crate::Result<PathBuf> {
let path: *mut c_char = msg_send![url, fileSystemRepresentation];
if path.is_null() {
Err(anyhow!(
"url is not a file path: {}",
CStr::from_ptr(url.absoluteString().UTF8String()).to_string_lossy()
))
} else {
Ok(PathBuf::from(OsStr::from_bytes(
CStr::from_ptr(path).to_bytes(),
)))
}
}

View file

@ -0,0 +1 @@
#include <dispatch/dispatch.h>

View file

@ -0,0 +1,100 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use crate::PlatformDispatcher;
use async_task::Runnable;
use objc::{
class, msg_send,
runtime::{BOOL, YES},
sel, sel_impl,
};
use std::{
ffi::c_void,
time::{Duration, SystemTime},
};
include!(concat!(env!("OUT_DIR"), "/dispatch_sys.rs"));
pub fn dispatch_get_main_queue() -> dispatch_queue_t {
unsafe { &_dispatch_main_q as *const _ as dispatch_queue_t }
}
pub struct MacDispatcher;
impl PlatformDispatcher for MacDispatcher {
fn is_main_thread(&self) -> bool {
let is_main_thread: BOOL = unsafe { msg_send![class!(NSThread), isMainThread] };
is_main_thread == YES
}
fn dispatch(&self, runnable: Runnable) {
unsafe {
dispatch_async_f(
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT.try_into().unwrap(), 0),
runnable.into_raw() as *mut c_void,
Some(trampoline),
);
}
}
fn dispatch_on_main_thread(&self, runnable: Runnable) {
unsafe {
dispatch_async_f(
dispatch_get_main_queue(),
runnable.into_raw() as *mut c_void,
Some(trampoline),
);
}
}
fn dispatch_after(&self, duration: Duration, runnable: Runnable) {
let now = SystemTime::now();
let after_duration = now
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_nanos() as u64
+ duration.as_nanos() as u64;
unsafe {
let queue =
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT.try_into().unwrap(), 0);
let when = dispatch_time(0, after_duration as i64);
dispatch_after_f(
when,
queue,
runnable.into_raw() as *mut c_void,
Some(trampoline),
);
}
}
fn poll(&self) -> bool {
false
}
}
extern "C" fn trampoline(runnable: *mut c_void) {
let task = unsafe { Runnable::from_raw(runnable as *mut ()) };
task.run();
}
// #include <dispatch/dispatch.h>
// int main(void) {
// dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
// // Do some lengthy background work here...
// printf("Background Work\n");
// dispatch_async(dispatch_get_main_queue(), ^{
// // Once done, update your UI on the main queue here.
// printf("UI Updated\n");
// });
// });
// sleep(3); // prevent the program from terminating immediately
// return 0;
// }
// ```

Some files were not shown because too many files have changed in this diff Show more