diff --git a/.github/workflows/release_actions.yml b/.github/workflows/release_actions.yml index c1df24a8e5..550eda882b 100644 --- a/.github/workflows/release_actions.yml +++ b/.github/workflows/release_actions.yml @@ -20,9 +20,7 @@ jobs: id: get-content with: stringToTruncate: | - 📣 Zed ${{ github.event.release.tag_name }} was just released! - - Restart your Zed or head to ${{ steps.get-release-url.outputs.URL }} to grab it. + 📣 Zed [${{ github.event.release.tag_name }}](${{ steps.get-release-url.outputs.URL }}) was just released! ${{ github.event.release.body }} maxLength: 2000 diff --git a/Cargo.lock b/Cargo.lock index 3dd85395b1..3aca27106c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,7 @@ dependencies = [ "futures 0.3.28", "gpui", "isahc", + "language", "lazy_static", "log", "matrixmultiply", @@ -103,7 +104,34 @@ dependencies = [ "rusqlite", "serde", "serde_json", - "tiktoken-rs 0.5.4", + "tiktoken-rs", + "util", +] + +[[package]] +name = "ai2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bincode", + "futures 0.3.28", + "gpui2", + "isahc", + "language2", + "lazy_static", + "log", + "matrixmultiply", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "parse_duration", + "postage", + "rand 0.8.5", + "regex", + "rusqlite", + "serde", + "serde_json", + "tiktoken-rs", "util", ] @@ -309,6 +337,7 @@ dependencies = [ "language", "log", "menu", + "multi_buffer", "ordered-float 2.10.0", "parking_lot 0.11.2", "project", @@ -316,12 +345,13 @@ dependencies = [ "regex", "schemars", "search", + "semantic_index", "serde", "serde_json", "settings", "smol", "theme", - "tiktoken-rs 0.4.5", + "tiktoken-rs", "util", "uuid 1.4.1", "workspace", @@ -1573,7 +1603,7 @@ dependencies = [ [[package]] name = "collab" -version = "0.24.0" +version = "0.27.0" dependencies = [ "anyhow", "async-trait", @@ -1609,6 +1639,7 @@ dependencies = [ "lsp", "nanoid", "node_runtime", + "notifications", "parking_lot 0.11.2", "pretty_assertions", "project", @@ -1664,20 +1695,26 @@ dependencies = [ "fuzzy", "gpui", "language", + "lazy_static", "log", "menu", + "notifications", "picker", "postage", + "pretty_assertions", "project", "recent_projects", "rich_text", + "rpc", "schemars", "serde", "serde_derive", "settings", + "smallvec", "theme", "theme_selector", "time", + "tree-sitter-markdown", "util", "vcs_menu", "workspace", @@ -1731,6 +1768,7 @@ dependencies = [ "theme", "util", "workspace", + "zed-actions", ] [[package]] @@ -1810,6 +1848,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "parking_lot 0.11.2", "rpc", "serde", "serde_derive", @@ -2556,11 +2595,11 @@ dependencies = [ "lazy_static", "log", "lsp", + "multi_buffer", "ordered-float 2.10.0", "parking_lot 0.11.2", "postage", "project", - "pulldown-cmark", "rand 0.8.5", "rich_text", "rpc", @@ -4159,6 +4198,24 @@ dependencies = [ "workspace", ] +[[package]] +name = "journal2" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "dirs 4.0.0", + "editor", + "gpui2", + "log", + "schemars", + "serde", + "settings2", + "shellexpand", + "util", + "workspace", +] + [[package]] name = "jpeg-decoder" version = "0.1.22" @@ -4244,6 +4301,7 @@ dependencies = [ "lsp", "parking_lot 0.11.2", "postage", + "pulldown-cmark", "rand 0.8.5", "regex", "rpc", @@ -4764,6 +4822,13 @@ dependencies = [ "gpui", ] +[[package]] +name = "menu2" +version = "0.1.0" +dependencies = [ + "gpui2", +] + [[package]] name = "metal" version = "0.21.0" @@ -4921,6 +4986,55 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"7843ec2de400bcbc6a6328c958dc38e5359da6e93e72e37bc5246bf1ae776389" +[[package]] +name = "multi_buffer" +version = "0.1.0" +dependencies = [ + "aho-corasick", + "anyhow", + "client", + "clock", + "collections", + "context_menu", + "convert_case 0.6.0", + "copilot", + "ctor", + "env_logger 0.9.3", + "futures 0.3.28", + "git", + "gpui", + "indoc", + "itertools 0.10.5", + "language", + "lazy_static", + "log", + "lsp", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "postage", + "project", + "pulldown-cmark", + "rand 0.8.5", + "rich_text", + "schemars", + "serde", + "serde_derive", + "settings", + "smallvec", + "smol", + "snippet", + "sum_tree", + "text", + "theme", + "tree-sitter", + "tree-sitter-html", + "tree-sitter-rust", + "tree-sitter-typescript", + "unindent", + "util", + "workspace", +] + [[package]] name = "multimap" version = "0.8.3" @@ -5070,6 +5184,26 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "notifications" +version = "0.1.0" +dependencies = [ + "anyhow", + "channel", + "client", + "clock", + "collections", + "db", + "feature_flags", + "gpui", + "rpc", + "settings", + "sum_tree", + "text", + "time", + "util", +] + [[package]] name = "ntapi" version = "0.3.7" @@ -5886,6 +6020,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "parking_lot 0.11.2", "serde", "serde_derive", "serde_json", @@ -6831,8 +6966,10 @@ dependencies = [ "rsa 0.4.0", "serde", "serde_derive", + "serde_json", "smol", "smol-timeout", + "strum", "tempdir", "tracing", "util", @@ -7407,7 +7544,7 @@ dependencies = [ "smol", "tempdir", "theme", - "tiktoken-rs 0.5.4", + "tiktoken-rs", "tree-sitter", "tree-sitter-cpp", "tree-sitter-elixir", @@ -7421,7 +7558,6 @@ dependencies = [ "unindent", "util", "workspace", - "zed", ] [[package]] @@ -8638,6 +8774,7 @@ version = "0.1.0" dependencies = [ "anyhow", "clap 4.4.4", + "convert_case 0.6.0", "gpui2", "log", "rust-embed", @@ -8713,21 +8850,6 @@ dependencies = [ "weezl", ] -[[package]] -name = "tiktoken-rs" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614" -dependencies = [ - "anyhow", - "base64 0.21.4", - "bstr", - "fancy-regex", - "lazy_static", - "parking_lot 0.12.1", - "rustc-hash", -] - [[package]] name = "tiktoken-rs" version = "0.5.4" @@ -9148,8 +9270,8 @@ dependencies = [ [[package]] name = "tree-sitter-bash" -version = "0.19.0" -source = "git+https://github.com/tree-sitter/tree-sitter-bash?rev=1b0321ee85701d5036c334a6f04761cdc672e64c#1b0321ee85701d5036c334a6f04761cdc672e64c" +version = "0.20.4" +source = "git+https://github.com/tree-sitter/tree-sitter-bash?rev=7331995b19b8f8aba2d5e26deb51d2195c18bc94#7331995b19b8f8aba2d5e26deb51d2195c18bc94" dependencies = [ "cc", "tree-sitter", @@ -9388,6 +9510,15 @@ dependencies = [ "tree-sitter", ] +[[package]] +name = "tree-sitter-vue" +version = "0.0.1" +source = "git+https://github.com/zed-industries/tree-sitter-vue?rev=95b2890#95b28908d90e928c308866f7631e73ef6e1d4b5f" +dependencies = [ + "cc", + "tree-sitter", +] + [[package]] name = "tree-sitter-yaml" version = "0.0.1" @@ -9469,10 +9600,8 @@ dependencies = [ "itertools 0.11.0", "rand 0.8.5", "serde", - "settings", "smallvec", "strum", - "theme", "theme2", ] @@ -9714,6 +9843,7 @@ name = "vcs_menu" version = "0.1.0" dependencies = [ "anyhow", + "fs", "fuzzy", "gpui", "picker", @@ -10658,9 +10788,10 @@ dependencies = [ [[package]] name = "zed" -version = "0.109.0" +version = "0.111.0" dependencies = [ "activity_indicator", + "ai", "anyhow", 
"assistant", "async-compression", @@ -10712,6 +10843,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "notifications", "num_cpus", "outline", "parking_lot 0.11.2", @@ -10773,6 +10905,7 @@ dependencies = [ "tree-sitter-svelte", "tree-sitter-toml", "tree-sitter-typescript", + "tree-sitter-vue", "tree-sitter-yaml", "unindent", "url", @@ -10790,12 +10923,14 @@ name = "zed-actions" version = "0.1.0" dependencies = [ "gpui", + "serde", ] [[package]] name = "zed2" version = "0.109.0" dependencies = [ + "ai2", "anyhow", "async-compression", "async-recursion 0.3.2", @@ -10811,7 +10946,7 @@ dependencies = [ "ctor", "db2", "env_logger 0.9.3", - "feature_flags", + "feature_flags2", "fs2", "fsevent", "futures 0.3.28", @@ -10822,12 +10957,13 @@ dependencies = [ "indexmap 1.9.3", "install_cli", "isahc", + "journal2", "language2", "language_tools", "lazy_static", "libc", "log", - "lsp", + "lsp2", "node_runtime", "num_cpus", "parking_lot 0.11.2", @@ -10880,6 +11016,7 @@ dependencies = [ "tree-sitter-svelte", "tree-sitter-toml", "tree-sitter-typescript", + "tree-sitter-vue", "tree-sitter-yaml", "unindent", "url", diff --git a/Cargo.toml b/Cargo.toml index 82af9265dd..ac490ce935 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ members = [ "crates/install_cli", "crates/install_cli2", "crates/journal", + "crates/journal2", "crates/language", "crates/language2", "crates/language_selector", @@ -58,7 +59,10 @@ members = [ "crates/lsp2", "crates/media", "crates/menu", + "crates/menu2", + "crates/multi_buffer", "crates/node_runtime", + "crates/notifications", "crates/outline", "crates/picker", "crates/plugin", @@ -133,6 +137,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] } serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] } smallvec = { version = "1.6", features = ["union"] } smol = { version = "1.2" } +strum = { version = "0.25.0", features = ["derive"] } sysinfo = "0.29.10" tempdir = { version = "0.3.7" } thiserror = { version = "1.0.29" } @@ -144,7 +149,7 @@ pretty_assertions = "1.3.0" git2 = { version = "0.15", default-features = false} uuid = { version = "1.1.2", features = ["v4"] } -tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "1b0321ee85701d5036c334a6f04761cdc672e64c" } +tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "7331995b19b8f8aba2d5e26deb51d2195c18bc94" } tree-sitter-c = "0.20.1" tree-sitter-cpp = { git = "https://github.com/tree-sitter/tree-sitter-cpp", rev="f44509141e7e483323d2ec178f2d2e6c0fc041c1" } tree-sitter-css = { git = "https://github.com/tree-sitter/tree-sitter-css", rev = "769203d0f9abe1a9a691ac2b9fe4bb4397a73c51" } @@ -170,7 +175,7 @@ tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", tree-sitter-lua = "0.0.14" tree-sitter-nix = { git = "https://github.com/nix-community/tree-sitter-nix", rev = "66e3e9ce9180ae08fc57372061006ef83f0abde7" } tree-sitter-nu = { git = "https://github.com/nushell/tree-sitter-nu", rev = "786689b0562b9799ce53e824cb45a1a2a04dc673"} - +tree-sitter-vue = {git = "https://github.com/zed-industries/tree-sitter-vue", rev = "95b2890"} [patch.crates-io] tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "35a6052fbcafc5e5fc0f9415b8652be7dcaf7222" } async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" } diff --git a/Procfile b/Procfile index 2eb7de20fb..3f42c3a967 100644 --- a/Procfile +++ b/Procfile @@ -1,4 +1,4 @@ 
web: cd ../zed.dev && PORT=3000 npm run dev -collab: cd crates/collab && RUST_LOG=${RUST_LOG:-collab=info} cargo run serve +collab: cd crates/collab && RUST_LOG=${RUST_LOG:-warn,collab=info} cargo run serve livekit: livekit-server --dev postgrest: postgrest crates/collab/admin_api.conf diff --git a/assets/icons/bell.svg b/assets/icons/bell.svg new file mode 100644 index 0000000000..ea1c6dd42e --- /dev/null +++ b/assets/icons/bell.svg @@ -0,0 +1,8 @@ + + + diff --git a/assets/icons/link.svg b/assets/icons/link.svg new file mode 100644 index 0000000000..4925bd8e00 --- /dev/null +++ b/assets/icons/link.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/public.svg b/assets/icons/public.svg new file mode 100644 index 0000000000..38278cdaba --- /dev/null +++ b/assets/icons/public.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/update.svg b/assets/icons/update.svg new file mode 100644 index 0000000000..b529b2b08b --- /dev/null +++ b/assets/icons/update.svg @@ -0,0 +1,8 @@ + + + diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 8422d53abc..ef6a655bdc 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -370,42 +370,15 @@ { "context": "Pane", "bindings": { - "ctrl-1": [ - "pane::ActivateItem", - 0 - ], - "ctrl-2": [ - "pane::ActivateItem", - 1 - ], - "ctrl-3": [ - "pane::ActivateItem", - 2 - ], - "ctrl-4": [ - "pane::ActivateItem", - 3 - ], - "ctrl-5": [ - "pane::ActivateItem", - 4 - ], - "ctrl-6": [ - "pane::ActivateItem", - 5 - ], - "ctrl-7": [ - "pane::ActivateItem", - 6 - ], - "ctrl-8": [ - "pane::ActivateItem", - 7 - ], - "ctrl-9": [ - "pane::ActivateItem", - 8 - ], + "ctrl-1": ["pane::ActivateItem", 0], + "ctrl-2": ["pane::ActivateItem", 1], + "ctrl-3": ["pane::ActivateItem", 2], + "ctrl-4": ["pane::ActivateItem", 3], + "ctrl-5": ["pane::ActivateItem", 4], + "ctrl-6": ["pane::ActivateItem", 5], + "ctrl-7": ["pane::ActivateItem", 6], + "ctrl-8": ["pane::ActivateItem", 7], + "ctrl-9": ["pane::ActivateItem", 8], "ctrl-0": "pane::ActivateLastItem", "ctrl--": "pane::GoBack", "ctrl-_": "pane::GoForward", @@ -416,42 +389,15 @@ { "context": "Workspace", "bindings": { - "cmd-1": [ - "workspace::ActivatePane", - 0 - ], - "cmd-2": [ - "workspace::ActivatePane", - 1 - ], - "cmd-3": [ - "workspace::ActivatePane", - 2 - ], - "cmd-4": [ - "workspace::ActivatePane", - 3 - ], - "cmd-5": [ - "workspace::ActivatePane", - 4 - ], - "cmd-6": [ - "workspace::ActivatePane", - 5 - ], - "cmd-7": [ - "workspace::ActivatePane", - 6 - ], - "cmd-8": [ - "workspace::ActivatePane", - 7 - ], - "cmd-9": [ - "workspace::ActivatePane", - 8 - ], + "cmd-1": ["workspace::ActivatePane", 0], + "cmd-2": ["workspace::ActivatePane", 1], + "cmd-3": ["workspace::ActivatePane", 2], + "cmd-4": ["workspace::ActivatePane", 3], + "cmd-5": ["workspace::ActivatePane", 4], + "cmd-6": ["workspace::ActivatePane", 5], + "cmd-7": ["workspace::ActivatePane", 6], + "cmd-8": ["workspace::ActivatePane", 7], + "cmd-9": ["workspace::ActivatePane", 8], "cmd-b": "workspace::ToggleLeftDock", "cmd-r": "workspace::ToggleRightDock", "cmd-j": "workspace::ToggleBottomDock", @@ -494,38 +440,14 @@ }, { "bindings": { - "cmd-k cmd-left": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "cmd-k cmd-right": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "cmd-k cmd-up": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "cmd-k cmd-down": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "cmd-k shift-left": [ - "workspace::SwapPaneInDirection", - "Left" - ], - "cmd-k shift-right": [ 
- "workspace::SwapPaneInDirection", - "Right" - ], - "cmd-k shift-up": [ - "workspace::SwapPaneInDirection", - "Up" - ], - "cmd-k shift-down": [ - "workspace::SwapPaneInDirection", - "Down" - ] + "cmd-k cmd-left": ["workspace::ActivatePaneInDirection", "Left"], + "cmd-k cmd-right": ["workspace::ActivatePaneInDirection", "Right"], + "cmd-k cmd-up": ["workspace::ActivatePaneInDirection", "Up"], + "cmd-k cmd-down": ["workspace::ActivatePaneInDirection", "Down"], + "cmd-k shift-left": ["workspace::SwapPaneInDirection", "Left"], + "cmd-k shift-right": ["workspace::SwapPaneInDirection", "Right"], + "cmd-k shift-up": ["workspace::SwapPaneInDirection", "Up"], + "cmd-k shift-down": ["workspace::SwapPaneInDirection", "Down"] } }, // Bindings from Atom @@ -627,14 +549,6 @@ "space": "collab_panel::InsertSpace" } }, - { - "context": "(CollabPanel && not_editing) > Editor", - "bindings": { - "cmd-c": "collab_panel::StartLinkChannel", - "cmd-x": "collab_panel::StartMoveChannel", - "cmd-v": "collab_panel::MoveOrLinkToSelected" - } - }, { "context": "ChannelModal", "bindings": { @@ -655,57 +569,21 @@ "cmd-v": "terminal::Paste", "cmd-k": "terminal::Clear", // Some nice conveniences - "cmd-backspace": [ - "terminal::SendText", - "\u0015" - ], - "cmd-right": [ - "terminal::SendText", - "\u0005" - ], - "cmd-left": [ - "terminal::SendText", - "\u0001" - ], + "cmd-backspace": ["terminal::SendText", "\u0015"], + "cmd-right": ["terminal::SendText", "\u0005"], + "cmd-left": ["terminal::SendText", "\u0001"], // Terminal.app compatibility - "alt-left": [ - "terminal::SendText", - "\u001bb" - ], - "alt-right": [ - "terminal::SendText", - "\u001bf" - ], + "alt-left": ["terminal::SendText", "\u001bb"], + "alt-right": ["terminal::SendText", "\u001bf"], // There are conflicting bindings for these keys in the global context. 
// these bindings override them, remove at your own risk: - "up": [ - "terminal::SendKeystroke", - "up" - ], - "pageup": [ - "terminal::SendKeystroke", - "pageup" - ], - "down": [ - "terminal::SendKeystroke", - "down" - ], - "pagedown": [ - "terminal::SendKeystroke", - "pagedown" - ], - "escape": [ - "terminal::SendKeystroke", - "escape" - ], - "enter": [ - "terminal::SendKeystroke", - "enter" - ], - "ctrl-c": [ - "terminal::SendKeystroke", - "ctrl-c" - ] + "up": ["terminal::SendKeystroke", "up"], + "pageup": ["terminal::SendKeystroke", "pageup"], + "down": ["terminal::SendKeystroke", "down"], + "pagedown": ["terminal::SendKeystroke", "pagedown"], + "escape": ["terminal::SendKeystroke", "escape"], + "enter": ["terminal::SendKeystroke", "enter"], + "ctrl-c": ["terminal::SendKeystroke", "ctrl-c"] } } ] diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index ea025747d8..81235bb72a 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -39,6 +39,7 @@ "w": "vim::NextWordStart", "{": "vim::StartOfParagraph", "}": "vim::EndOfParagraph", + "|": "vim::GoToColumn", "shift-w": [ "vim::NextWordStart", { @@ -97,14 +98,8 @@ "ctrl-o": "pane::GoBack", "ctrl-i": "pane::GoForward", "ctrl-]": "editor::GoToDefinition", - "escape": [ - "vim::SwitchMode", - "Normal" - ], - "ctrl+[": [ - "vim::SwitchMode", - "Normal" - ], + "escape": ["vim::SwitchMode", "Normal"], + "ctrl+[": ["vim::SwitchMode", "Normal"], "v": "vim::ToggleVisual", "shift-v": "vim::ToggleVisualLine", "ctrl-v": "vim::ToggleVisualBlock", @@ -233,123 +228,36 @@ } ], // Count support - "1": [ - "vim::Number", - 1 - ], - "2": [ - "vim::Number", - 2 - ], - "3": [ - "vim::Number", - 3 - ], - "4": [ - "vim::Number", - 4 - ], - "5": [ - "vim::Number", - 5 - ], - "6": [ - "vim::Number", - 6 - ], - "7": [ - "vim::Number", - 7 - ], - "8": [ - "vim::Number", - 8 - ], - "9": [ - "vim::Number", - 9 - ], + "1": ["vim::Number", 1], + "2": ["vim::Number", 2], + "3": ["vim::Number", 3], + "4": ["vim::Number", 4], + "5": ["vim::Number", 5], + "6": ["vim::Number", 6], + "7": ["vim::Number", 7], + "8": ["vim::Number", 8], + "9": ["vim::Number", 9], // window related commands (ctrl-w X) - "ctrl-w left": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "ctrl-w right": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "ctrl-w up": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "ctrl-w down": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "ctrl-w h": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "ctrl-w l": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "ctrl-w k": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "ctrl-w j": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "ctrl-w ctrl-h": [ - "workspace::ActivatePaneInDirection", - "Left" - ], - "ctrl-w ctrl-l": [ - "workspace::ActivatePaneInDirection", - "Right" - ], - "ctrl-w ctrl-k": [ - "workspace::ActivatePaneInDirection", - "Up" - ], - "ctrl-w ctrl-j": [ - "workspace::ActivatePaneInDirection", - "Down" - ], - "ctrl-w shift-left": [ - "workspace::SwapPaneInDirection", - "Left" - ], - "ctrl-w shift-right": [ - "workspace::SwapPaneInDirection", - "Right" - ], - "ctrl-w shift-up": [ - "workspace::SwapPaneInDirection", - "Up" - ], - "ctrl-w shift-down": [ - "workspace::SwapPaneInDirection", - "Down" - ], - "ctrl-w shift-h": [ - "workspace::SwapPaneInDirection", - "Left" - ], - "ctrl-w shift-l": [ - "workspace::SwapPaneInDirection", - "Right" - ], - "ctrl-w shift-k": [ - "workspace::SwapPaneInDirection", - "Up" - ], 
- "ctrl-w shift-j": [ - "workspace::SwapPaneInDirection", - "Down" - ], + "ctrl-w left": ["workspace::ActivatePaneInDirection", "Left"], + "ctrl-w right": ["workspace::ActivatePaneInDirection", "Right"], + "ctrl-w up": ["workspace::ActivatePaneInDirection", "Up"], + "ctrl-w down": ["workspace::ActivatePaneInDirection", "Down"], + "ctrl-w h": ["workspace::ActivatePaneInDirection", "Left"], + "ctrl-w l": ["workspace::ActivatePaneInDirection", "Right"], + "ctrl-w k": ["workspace::ActivatePaneInDirection", "Up"], + "ctrl-w j": ["workspace::ActivatePaneInDirection", "Down"], + "ctrl-w ctrl-h": ["workspace::ActivatePaneInDirection", "Left"], + "ctrl-w ctrl-l": ["workspace::ActivatePaneInDirection", "Right"], + "ctrl-w ctrl-k": ["workspace::ActivatePaneInDirection", "Up"], + "ctrl-w ctrl-j": ["workspace::ActivatePaneInDirection", "Down"], + "ctrl-w shift-left": ["workspace::SwapPaneInDirection", "Left"], + "ctrl-w shift-right": ["workspace::SwapPaneInDirection", "Right"], + "ctrl-w shift-up": ["workspace::SwapPaneInDirection", "Up"], + "ctrl-w shift-down": ["workspace::SwapPaneInDirection", "Down"], + "ctrl-w shift-h": ["workspace::SwapPaneInDirection", "Left"], + "ctrl-w shift-l": ["workspace::SwapPaneInDirection", "Right"], + "ctrl-w shift-k": ["workspace::SwapPaneInDirection", "Up"], + "ctrl-w shift-j": ["workspace::SwapPaneInDirection", "Down"], "ctrl-w g t": "pane::ActivateNextItem", "ctrl-w ctrl-g t": "pane::ActivateNextItem", "ctrl-w g shift-t": "pane::ActivatePrevItem", @@ -371,14 +279,8 @@ "ctrl-w ctrl-q": "pane::CloseAllItems", "ctrl-w o": "workspace::CloseInactiveTabsAndPanes", "ctrl-w ctrl-o": "workspace::CloseInactiveTabsAndPanes", - "ctrl-w n": [ - "workspace::NewFileInDirection", - "Up" - ], - "ctrl-w ctrl-n": [ - "workspace::NewFileInDirection", - "Up" - ] + "ctrl-w n": ["workspace::NewFileInDirection", "Up"], + "ctrl-w ctrl-n": ["workspace::NewFileInDirection", "Up"] } }, { @@ -393,21 +295,12 @@ "context": "Editor && vim_mode == normal && vim_operator == none && !VimWaiting", "bindings": { ".": "vim::Repeat", - "c": [ - "vim::PushOperator", - "Change" - ], + "c": ["vim::PushOperator", "Change"], "shift-c": "vim::ChangeToEndOfLine", - "d": [ - "vim::PushOperator", - "Delete" - ], + "d": ["vim::PushOperator", "Delete"], "shift-d": "vim::DeleteToEndOfLine", "shift-j": "vim::JoinLines", - "y": [ - "vim::PushOperator", - "Yank" - ], + "y": ["vim::PushOperator", "Yank"], "shift-y": "vim::YankLine", "i": "vim::InsertBefore", "shift-i": "vim::InsertFirstNonWhitespace", @@ -443,10 +336,7 @@ "backwards": true } ], - "r": [ - "vim::PushOperator", - "Replace" - ], + "r": ["vim::PushOperator", "Replace"], "s": "vim::Substitute", "shift-s": "vim::SubstituteLine", "> >": "editor::Indent", @@ -458,10 +348,7 @@ { "context": "Editor && VimCount", "bindings": { - "0": [ - "vim::Number", - 0 - ] + "0": ["vim::Number", 0] } }, { @@ -497,12 +384,15 @@ "'": "vim::Quotes", "`": "vim::BackQuotes", "\"": "vim::DoubleQuotes", + "|": "vim::VerticalBars", "(": "vim::Parentheses", ")": "vim::Parentheses", + "b": "vim::Parentheses", "[": "vim::SquareBrackets", "]": "vim::SquareBrackets", "{": "vim::CurlyBrackets", "}": "vim::CurlyBrackets", + "shift-b": "vim::CurlyBrackets", "<": "vim::AngleBrackets", ">": "vim::AngleBrackets" } @@ -548,22 +438,10 @@ "shift-i": "vim::InsertBefore", "shift-a": "vim::InsertAfter", "shift-j": "vim::JoinLines", - "r": [ - "vim::PushOperator", - "Replace" - ], - "ctrl-c": [ - "vim::SwitchMode", - "Normal" - ], - "escape": [ - "vim::SwitchMode", - "Normal" - ], - "ctrl+[": [ - 
"vim::SwitchMode", - "Normal" - ], + "r": ["vim::PushOperator", "Replace"], + "ctrl-c": ["vim::SwitchMode", "Normal"], + "escape": ["vim::SwitchMode", "Normal"], + "ctrl+[": ["vim::SwitchMode", "Normal"], ">": "editor::Indent", "<": "editor::Outdent", "i": [ @@ -602,14 +480,8 @@ "bindings": { "tab": "vim::Tab", "enter": "vim::Enter", - "escape": [ - "vim::SwitchMode", - "Normal" - ], - "ctrl+[": [ - "vim::SwitchMode", - "Normal" - ] + "escape": ["vim::SwitchMode", "Normal"], + "ctrl+[": ["vim::SwitchMode", "Normal"] } }, { diff --git a/assets/settings/default.json b/assets/settings/default.json index 1611d80e2f..19c73ca021 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -50,6 +50,9 @@ // Whether to pop the completions menu while typing in an editor without // explicitly requesting it. "show_completions_on_input": true, + // Whether to display inline and alongside documentation for items in the + // completions menu + "show_completion_documentation": true, // Whether to show wrap guides in the editor. Setting this to true will // show a guide at the 'preferred_line_length' value if softwrap is set to // 'preferred_line_length', and will show any additional guides as specified @@ -139,6 +142,14 @@ // Default width of the channels panel. "default_width": 240 }, + "notification_panel": { + // Whether to show the collaboration panel button in the status bar. + "button": true, + // Where to dock channels panel. Can be 'left' or 'right'. + "dock": "right", + // Default width of the channels panel. + "default_width": 380 + }, "assistant": { // Whether to show the assistant panel button in the status bar. "button": true, diff --git a/crates/Cargo.toml b/crates/Cargo.toml new file mode 100644 index 0000000000..fb49a4b515 --- /dev/null +++ b/crates/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui = { path = "../gpui" } +util = { path = "../util" } +language = { path = "../language" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 542d7f422f..fb49a4b515 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -8,9 +8,13 @@ publish = false path = "src/ai.rs" doctest = false +[features] +test-support = [] + [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } +language = { path = "../language" } async-trait.workspace = true anyhow.workspace = true futures.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 5256a6a643..dda22d2a1d 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,2 +1,8 @@ +pub mod auth; pub mod completion; pub mod embedding; +pub mod models; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs new file mode 100644 index 
0000000000..c6256df216 --- /dev/null +++ b/crates/ai/src/auth.rs @@ -0,0 +1,15 @@ +use gpui::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); + fn delete_credentials(&self, cx: &AppContext); +} diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 170b2268f9..30a60fcf1d 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,212 +1,23 @@ -use anyhow::{anyhow, Result}; -use futures::{ - future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, - Stream, StreamExt, -}; -use gpui::executor::Background; -use isahc::{http::StatusCode, Request, RequestExt}; -use serde::{Deserialize, Serialize}; -use std::{ - fmt::{self, Display}, - io, - sync::Arc, -}; +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; -pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; +use crate::{auth::CredentialProvider, models::LanguageModel}; -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; } -impl Role { - pub fn cycle(&mut self) { - *self = match self { - Role::User => Role::Assistant, - Role::Assistant => Role::System, - Role::System => Role::User, - } - } -} - -impl Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "User"), - Role::Assistant => write!(f, "Assistant"), - Role::System => write!(f, "System"), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct RequestMessage { - pub role: Role, - pub content: String, -} - -#[derive(Debug, Default, Serialize)] -pub struct OpenAIRequest { - pub model: String, - pub messages: Vec, - pub stream: bool, -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct ResponseMessage { - pub role: Option, - pub content: Option, -} - -#[derive(Deserialize, Debug)] -pub struct OpenAIUsage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -#[derive(Deserialize, Debug)] -pub struct ChatChoiceDelta { - pub index: u32, - pub delta: ResponseMessage, - pub finish_reason: Option, -} - -#[derive(Deserialize, Debug)] -pub struct OpenAIResponseStreamEvent { - pub id: Option, - pub object: String, - pub created: u32, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -pub async fn stream_completion( - api_key: String, - executor: Arc, - mut request: OpenAIRequest, -) -> Result>> { - request.stream = true; - - let (tx, rx) = futures::channel::mpsc::unbounded::>(); - - let json_data = serde_json::to_string(&request)?; - let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(json_data)? 
- .send_async() - .await?; - - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); - - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(&data)?; - Ok(Some(event)) - } else { - Ok(None) - } - } - - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - let done = event.as_ref().map_or(false, |event| { - event - .choices - .last() - .map_or(false, |choice| choice.finish_reason.is_some()) - }); - if tx.unbounded_send(event).is_err() { - break; - } - - if done { - break; - } - } - } - - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAIResponse { - error: OpenAIError, - } - - #[derive(Deserialize)] - struct OpenAIError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to OpenAI API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )), - } - } -} - -pub trait CompletionProvider { +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; fn complete( &self, - prompt: OpenAIRequest, + prompt: Box, ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; } -pub struct OpenAICompletionProvider { - api_key: String, - executor: Arc, -} - -impl OpenAICompletionProvider { - pub fn new(api_key: String, executor: Arc) -> Self { - Self { api_key, executor } - } -} - -impl CompletionProvider for OpenAICompletionProvider { - fn complete( - &self, - prompt: OpenAIRequest, - ) -> BoxFuture<'static, Result>>> { - let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); - async move { - let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) - } - .boxed() +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() } } diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 332470aa54..6768b7ce7b 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -1,30 +1,13 @@ -use anyhow::{anyhow, Result}; +use std::time::Instant; + +use anyhow::Result; use async_trait::async_trait; -use futures::AsyncReadExt; -use gpui::executor::Background; -use gpui::serde_json; -use isahc::http::StatusCode; -use isahc::prelude::Configurable; -use isahc::{AsyncBody, Response}; -use lazy_static::lazy_static; use ordered_float::OrderedFloat; -use parking_lot::Mutex; -use parse_duration::parse; -use postage::watch; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; -use serde::{Deserialize, Serialize}; -use std::env; -use std::ops::Add; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tiktoken_rs::{cl100k_base, CoreBPE}; -use util::http::{HttpClient, Request}; -lazy_static! 
{ - static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); - static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); -} +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] pub struct Embedding(pub Vec); @@ -85,295 +68,14 @@ impl Embedding { } } -// impl FromSql for Embedding { -// fn column_result(value: ValueRef) -> FromSqlResult { -// let bytes = value.as_blob()?; -// let embedding: Result, Box> = bincode::deserialize(bytes); -// if embedding.is_err() { -// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); -// } -// Ok(Embedding(embedding.unwrap())) -// } -// } - -// impl ToSql for Embedding { -// fn to_sql(&self) -> rusqlite::Result { -// let bytes = bincode::serialize(&self.0) -// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; -// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) -// } -// } - -#[derive(Clone)] -pub struct OpenAIEmbeddings { - pub client: Arc, - pub executor: Arc, - rate_limit_count_rx: watch::Receiver>, - rate_limit_count_tx: Arc>>>, -} - -#[derive(Serialize)] -struct OpenAIEmbeddingRequest<'a> { - model: &'static str, - input: Vec<&'a str>, -} - -#[derive(Deserialize)] -struct OpenAIEmbeddingResponse { - data: Vec, - usage: OpenAIEmbeddingUsage, -} - -#[derive(Debug, Deserialize)] -struct OpenAIEmbedding { - embedding: Vec, - index: usize, - object: String, -} - -#[derive(Deserialize)] -struct OpenAIEmbeddingUsage { - prompt_tokens: usize, - total_tokens: usize, -} - #[async_trait] -pub trait EmbeddingProvider: Sync + Send { - fn is_authenticated(&self) -> bool; +pub trait EmbeddingProvider: CredentialProvider { + fn base_model(&self) -> Box; async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; - fn truncate(&self, span: &str) -> (String, usize); fn rate_limit_expiration(&self) -> Option; } -pub struct DummyEmbeddings {} - -#[async_trait] -impl EmbeddingProvider for DummyEmbeddings { - fn is_authenticated(&self) -> bool { - true - } - fn rate_limit_expiration(&self) -> Option { - None - } - async fn embed_batch(&self, spans: Vec) -> Result> { - // 1024 is the OpenAI Embeddings size for ada models. - // the model we will likely be starting with. 
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); - return Ok(vec![dummy_vec; spans.len()]); - } - - fn max_tokens_per_batch(&self) -> usize { - OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let token_count = tokens.len(); - let output = if token_count > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); - new_input.ok().unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; - - (output, tokens.len()) - } -} - -const OPENAI_INPUT_LIMIT: usize = 8190; - -impl OpenAIEmbeddings { - pub fn new(client: Arc, executor: Arc) -> Self { - let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); - let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); - - OpenAIEmbeddings { - client, - executor, - rate_limit_count_rx, - rate_limit_count_tx, - } - } - - fn resolve_rate_limit(&self) { - let reset_time = *self.rate_limit_count_tx.lock().borrow(); - - if let Some(reset_time) = reset_time { - if Instant::now() >= reset_time { - *self.rate_limit_count_tx.lock().borrow_mut() = None - } - } - - log::trace!( - "resolving reset time: {:?}", - *self.rate_limit_count_tx.lock().borrow() - ); - } - - fn update_reset_time(&self, reset_time: Instant) { - let original_time = *self.rate_limit_count_tx.lock().borrow(); - - let updated_time = if let Some(original_time) = original_time { - if reset_time < original_time { - Some(reset_time) - } else { - Some(original_time) - } - } else { - Some(reset_time) - }; - - log::trace!("updating rate limit time: {:?}", updated_time); - - *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; - } - async fn send_request( - &self, - api_key: &str, - spans: Vec<&str>, - request_timeout: u64, - ) -> Result> { - let request = Request::post("https://api.openai.com/v1/embeddings") - .redirect_policy(isahc::config::RedirectPolicy::Follow) - .timeout(Duration::from_secs(request_timeout)) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body( - serde_json::to_string(&OpenAIEmbeddingRequest { - input: spans.clone(), - model: "text-embedding-ada-002", - }) - .unwrap() - .into(), - )?; - - Ok(self.client.send(request).await?) 
- } -} - -#[async_trait] -impl EmbeddingProvider for OpenAIEmbeddings { - fn is_authenticated(&self) -> bool { - OPENAI_API_KEY.as_ref().is_some() - } - fn max_tokens_per_batch(&self) -> usize { - 50000 - } - - fn rate_limit_expiration(&self) -> Option { - *self.rate_limit_count_rx.borrow() - } - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens.clone()) - .ok() - .unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; - - (output, tokens.len()) - } - - async fn embed_batch(&self, spans: Vec) -> Result> { - const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; - const MAX_RETRIES: usize = 4; - - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; - - let mut request_number = 0; - let mut rate_limiting = false; - let mut request_timeout: u64 = 15; - let mut response: Response; - while request_number < MAX_RETRIES { - response = self - .send_request( - api_key, - spans.iter().map(|x| &**x).collect(), - request_timeout, - ) - .await?; - request_number += 1; - - match response.status() { - StatusCode::REQUEST_TIMEOUT => { - request_timeout += 5; - } - StatusCode::OK => { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; - - log::trace!( - "openai embedding completed. tokens: {:?}", - response.usage.total_tokens - ); - - // If we complete a request successfully that was previously rate_limited - // resolve the rate limit - if rate_limiting { - self.resolve_rate_limit() - } - - return Ok(response - .data - .into_iter() - .map(|embedding| Embedding::from(embedding.embedding)) - .collect()); - } - StatusCode::TOO_MANY_REQUESTS => { - rate_limiting = true; - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - let delay_duration = { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - if let Some(time_to_reset) = - response.headers().get("x-ratelimit-reset-tokens") - { - if let Ok(time_str) = time_to_reset.to_str() { - parse(time_str).unwrap_or(delay) - } else { - delay - } - } else { - delay - } - }; - - // If we've previously rate limited, increment the duration but not the count - let reset_time = Instant::now().add(delay_duration); - self.update_reset_time(reset_time); - - log::trace!( - "openai rate limiting: waiting {:?} until lifted", - &delay_duration - ); - - self.executor.timer(delay_duration).await; - } - _ => { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "open ai bad request: {:?} {:?}", - &response.status(), - body - )); - } - } - } - Err(anyhow!("openai max retries")) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs new file mode 100644 index 0000000000..1db3d58c6f --- /dev/null +++ b/crates/ai/src/models.rs @@ -0,0 +1,16 @@ +pub enum TruncationDirection { + Start, + End, +} + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs new file mode 
100644 index 0000000000..75bad00154 --- /dev/null +++ b/crates/ai/src/prompts/base.rs @@ -0,0 +1,330 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::prompts::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc, + pub user_prompt: Option, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? 
- self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; + + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + 
Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation On + // Should truncate the prompts to fit within the model capacity + let capacity = 20; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai/src/prompts/file_context.rs b/crates/ai/src/prompts/file_context.rs new file mode 100644 index 0000000000..f108a62f6f --- /dev/null +++ b/crates/ai/src/prompts/file_context.rs @@ -0,0 +1,164 @@ +use anyhow::anyhow; +use language::BufferSnapshot; +use language::ToOffset; + +use crate::models::LanguageModel; +use crate::models::TruncationDirection; +use crate::prompts::base::PromptArguments; +use crate::prompts::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}",
+ buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?; + let truncated_end_window = + model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we don't have a selected range, include the entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here?
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai/src/prompts/generate.rs b/crates/ai/src/prompts/generate.rs new file mode 100644 index 0000000000..c7be620107 --- /dev/null +++ b/crates/ai/src/prompts/generate.rs @@ -0,0 +1,99 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." 
+ ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the user's prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the user's prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|>' spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the user's prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate( + &prompt, + max_tokens, + crate::models::TruncationDirection::End, + )?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/prompts/mod.rs b/crates/ai/src/prompts/mod.rs new file mode 100644 index 0000000000..0025269a44 --- /dev/null +++ b/crates/ai/src/prompts/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai/src/prompts/preamble.rs b/crates/ai/src/prompts/preamble.rs new file mode 100644 index 0000000000..92e0edeb78 --- /dev/null +++ b/crates/ai/src/prompts/preamble.rs @@ -0,0 +1,52 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." + )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)?
+ args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai/src/prompts/repository_context.rs b/crates/ai/src/prompts/repository_context.rs new file mode 100644 index 0000000000..c21b0f995c --- /dev/null +++ b/crates/ai/src/prompts/repository_context.rs @@ -0,0 +1,94 @@ +use crate::prompts::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui::{AsyncAppContext, ModelHandle}; +use language::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { + let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + }); + + PromptCodeSnippet { + path: file_path, + language_name, + content, + } + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs new file mode 100644 index 0000000000..acd0f9d910 --- /dev/null +++ b/crates/ai/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git 
a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs new file mode 100644 index 0000000000..94685fd233 --- /dev/null +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -0,0 +1,298 @@ +use anyhow::{anyhow, Result}; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui::{executor::Background, AppContext}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
+ .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[derive(Clone)] +pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, + credential: Arc>, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(model_name: &str, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + Self { + model, + credential, + executor, + } + } +} + +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>> { + // Currently the CompletionRequest for OpenAI, includes a 'model' parameter + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model based, due to the langauge model. 
+ // At some point in the future we should rectify this. + let credential = self.credential.read().clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs new file mode 100644 index 0000000000..fbfd0028f9 --- /dev/null +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui::executor::Background; +use gpui::{serde_json, AppContext}; +use isahc::http::StatusCode; +use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; +use lazy_static::lazy_static; +use parking_lot::{Mutex, RwLock}; +use parse_duration::parse; +use postage::watch; +use serde::{Deserialize, Serialize}; +use std::env; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tiktoken_rs::{cl100k_base, CoreBPE}; +use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; + +use crate::providers::open_ai::OPENAI_API_URL; + +lazy_static! { + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: Instant) { + let original_time = 
*self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + .await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. 
tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000..7d2f86045d --- /dev/null +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs new file mode 100644 index 0000000000..6e306c80b9 --- /dev/null +++ b/crates/ai/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai/src/providers/open_ai/new.rs b/crates/ai/src/providers/open_ai/new.rs new file mode 100644 index 
0000000000..c7d67f2ba1 --- /dev/null +++ b/crates/ai/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result<usize>; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result<String>; + fn capacity(&self) -> anyhow::Result<usize>; +} diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs new file mode 100644 index 0000000000..d4165f3cca --- /dev/null +++ b/crates/ai/src/test.rs @@ -0,0 +1,191 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result<usize> { + anyhow::Ok(content.chars().collect::<Vec<char>>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result<String> { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length] + .into_iter() + .collect::<String>(), + TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai2/Cargo.toml b/crates/ai2/Cargo.toml new file mode 100644 index 0000000000..4f06840e8e --- /dev/null +++ b/crates/ai2/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai2" +version = "0.1.0" 
+edition = "2021" +publish = false + +[lib] +path = "src/ai2.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui2 = { path = "../gpui2" } +util = { path = "../util" } +language2 = { path = "../language2" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } diff --git a/crates/ai2/src/ai2.rs b/crates/ai2/src/ai2.rs new file mode 100644 index 0000000000..dda22d2a1d --- /dev/null +++ b/crates/ai2/src/ai2.rs @@ -0,0 +1,8 @@ +pub mod auth; +pub mod completion; +pub mod embedding; +pub mod models; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai2/src/auth.rs b/crates/ai2/src/auth.rs new file mode 100644 index 0000000000..e4670bb449 --- /dev/null +++ b/crates/ai2/src/auth.rs @@ -0,0 +1,17 @@ +use async_trait::async_trait; +use gpui2::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +#[async_trait] +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential; + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential); + async fn delete_credentials(&self, cx: &mut AppContext); +} diff --git a/crates/ai2/src/completion.rs b/crates/ai2/src/completion.rs new file mode 100644 index 0000000000..30a60fcf1d --- /dev/null +++ b/crates/ai2/src/completion.rs @@ -0,0 +1,23 @@ +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; + +use crate::{auth::CredentialProvider, models::LanguageModel}; + +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} + +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } +} diff --git a/crates/ai2/src/embedding.rs b/crates/ai2/src/embedding.rs new file mode 100644 index 0000000000..7ea4786178 --- /dev/null +++ b/crates/ai2/src/embedding.rs @@ -0,0 +1,123 @@ +use std::time::Instant; + +use anyhow::Result; +use async_trait::async_trait; +use ordered_float::OrderedFloat; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; + +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; + +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(pub Vec); + +// This is needed for semantic index functionality +// Unfortunately it has to live wherever the "Embedding" struct is created. 
+// Keeping this in here though, introduces a 'rusqlite' dependency into AI +// which is less than ideal +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> OrderedFloat { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + OrderedFloat(result) + } +} + +#[async_trait] +pub trait EmbeddingProvider: CredentialProvider { + fn base_model(&self) -> Box; + async fn embed_batch(&self, spans: Vec) -> Result>; + fn max_tokens_per_batch(&self) -> usize; + fn rate_limit_expiration(&self) -> Option; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui2::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. 
+ ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> { + OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()) + } + } +} diff --git a/crates/ai2/src/models.rs b/crates/ai2/src/models.rs new file mode 100644 index 0000000000..1db3d58c6f --- /dev/null +++ b/crates/ai2/src/models.rs @@ -0,0 +1,16 @@ +pub enum TruncationDirection { + Start, + End, +} + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result<usize>; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result<String>; + fn capacity(&self) -> anyhow::Result<usize>; +} diff --git a/crates/ai2/src/prompts/base.rs b/crates/ai2/src/prompts/base.rs new file mode 100644 index 0000000000..29091d0f5b --- /dev/null +++ b/crates/ai2/src/prompts/base.rs @@ -0,0 +1,330 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language2::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::prompts::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc<dyn LanguageModel>, + pub user_prompt: Option<String>, + pub language_name: Option<String>, + pub project_name: Option<String>, + pub snippets: Vec<PromptCodeSnippet>, + pub reserved_tokens: usize, + pub buffer: Option<BufferSnapshot>, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option<usize>, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { ..
}, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? - self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; + + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { 
order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai2/src/prompts/file_context.rs b/crates/ai2/src/prompts/file_context.rs new file mode 100644 index 0000000000..4a741beb24 --- /dev/null +++ b/crates/ai2/src/prompts/file_context.rs @@ -0,0 +1,164 @@ +use anyhow::anyhow; +use language2::BufferSnapshot; +use language2::ToOffset; + +use crate::models::LanguageModel; +use 
crate::models::TruncationDirection; +use crate::prompts::base::PromptArguments; +use crate::prompts::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?; + let truncated_end_window = + model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? 
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai2/src/prompts/generate.rs b/crates/ai2/src/prompts/generate.rs new file mode 100644 index 0000000000..c7be620107 --- /dev/null +++ b/crates/ai2/src/prompts/generate.rs @@ -0,0 +1,99 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." 
+ ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate( + &prompt, + max_tokens, + crate::models::TruncationDirection::End, + )?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai2/src/prompts/mod.rs b/crates/ai2/src/prompts/mod.rs new file mode 100644 index 0000000000..0025269a44 --- /dev/null +++ b/crates/ai2/src/prompts/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai2/src/prompts/preamble.rs b/crates/ai2/src/prompts/preamble.rs new file mode 100644 index 0000000000..92e0edeb78 --- /dev/null +++ b/crates/ai2/src/prompts/preamble.rs @@ -0,0 +1,52 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." + )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)? 
+ args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai2/src/prompts/repository_context.rs b/crates/ai2/src/prompts/repository_context.rs new file mode 100644 index 0000000000..1bb75de7d2 --- /dev/null +++ b/crates/ai2/src/prompts/repository_context.rs @@ -0,0 +1,98 @@ +use crate::prompts::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui2::{AsyncAppContext, Model}; +use language2::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new( + buffer: Model, + range: Range, + cx: &mut AsyncAppContext, + ) -> anyhow::Result { + let (content, language_name, file_path) = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + })?; + + anyhow::Ok(PromptCodeSnippet { + path: file_path, + language_name, + content, + }) + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/ai2/src/providers/mod.rs b/crates/ai2/src/providers/mod.rs new file mode 100644 index 0000000000..acd0f9d910 --- /dev/null +++ b/crates/ai2/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git 
a/crates/ai2/src/providers/open_ai/completion.rs b/crates/ai2/src/providers/open_ai/completion.rs new file mode 100644 index 0000000000..eca5611027 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/completion.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui2::{AppContext, Executor}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
+ .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[derive(Clone)] +pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, + credential: Arc>, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(model_name: &str, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + Self { + model, + credential, + executor, + } + } +} + +#[async_trait] +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>> { + // Currently the CompletionRequest for OpenAI, includes a 'model' parameter + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model based, due to the langauge model. + // At some point in the future we should rectify this. + let credential = self.credential.read().clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai2/src/providers/open_ai/embedding.rs b/crates/ai2/src/providers/open_ai/embedding.rs new file mode 100644 index 0000000000..fc49c15134 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/embedding.rs @@ -0,0 +1,313 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui2::Executor; +use gpui2::{serde_json, AppContext}; +use isahc::http::StatusCode; +use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; +use lazy_static::lazy_static; +use parking_lot::{Mutex, RwLock}; +use parse_duration::parse; +use postage::watch; +use serde::{Deserialize, Serialize}; +use std::env; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tiktoken_rs::{cl100k_base, CoreBPE}; +use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; + +use crate::providers::open_ai::OPENAI_API_URL; + +lazy_static! 
{ + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: Instant) { + let original_time = *self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +#[async_trait] +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + .await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. 
tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai2/src/providers/open_ai/mod.rs b/crates/ai2/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000..7d2f86045d --- /dev/null +++ b/crates/ai2/src/providers/open_ai/mod.rs @@ -0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai2/src/providers/open_ai/model.rs b/crates/ai2/src/providers/open_ai/model.rs new file mode 100644 index 0000000000..6e306c80b9 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai2/src/providers/open_ai/new.rs b/crates/ai2/src/providers/open_ai/new.rs new file mode 100644 index 
0000000000..c7d67f2ba1 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/test.rs b/crates/ai2/src/test.rs new file mode 100644 index 0000000000..ee88529aec --- /dev/null +++ b/crates/ai2/src/test.rs @@ -0,0 +1,193 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui2::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +#[async_trait] +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index f1daf47bab..fc885f6b36 100644 --- 
a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -17,13 +17,17 @@ fs = { path = "../fs" } gpui = { path = "../gpui" } language = { path = "../language" } menu = { path = "../menu" } +multi_buffer = { path = "../multi_buffer" } search = { path = "../search" } settings = { path = "../settings" } theme = { path = "../theme" } util = { path = "../util" } workspace = { path = "../workspace" } -uuid.workspace = true +semantic_index = { path = "../semantic_index" } +project = { path = "../project" } +uuid.workspace = true +log.workspace = true anyhow.workspace = true chrono = { version = "0.4", features = ["serde"] } futures.workspace = true @@ -36,11 +40,12 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true smol.workspace = true -tiktoken-rs = "0.4" +tiktoken-rs = "0.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +ai = { path = "../ai", features = ["test-support"]} ctor.workspace = true env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 6c9b14333e..91d61a19f9 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -4,7 +4,7 @@ mod codegen; mod prompts; mod streaming_diff; -use ai::completion::Role; +use ai::providers::open_ai::Role; use anyhow::Result; pub use assistant_panel::AssistantPanel; use assistant_settings::OpenAIModel; diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index b1c6038602..03eb3c238f 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -5,9 +5,14 @@ use crate::{ MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; -use ai::completion::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + +use ai::{ + auth::ProviderCredential, + completion::{CompletionProvider, CompletionRequest}, + providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage}, }; + +use ai::prompts::repository_context::PromptCodeSnippet; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings}; @@ -29,24 +34,26 @@ use gpui::{ }, fonts::HighlightStyle, geometry::vector::{vec2f, Vector2F}, - platform::{CursorStyle, MouseButton}, + platform::{CursorStyle, MouseButton, PromptLevel}, Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext, - ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, - WindowContext, + ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, + WeakModelHandle, WeakViewHandle, WindowContext, }; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; +use project::Project; use search::BufferSearchBar; +use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ - cell::{Cell, RefCell}, - cmp, env, + cell::Cell, + cmp, fmt::Write, iter, ops::Range, path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::Duration, + time::{Duration, Instant}, }; use theme::{ components::{action_button::Button, ComponentExt}, @@ -72,6 +79,7 @@ actions!( ResetKey, InlineAssist, ToggleIncludeConversation, + ToggleRetrieveContext, ] ); @@ -91,8 +99,8 @@ pub fn init(cx: &mut AppContext) { 
cx.capture_action(ConversationEditor::copy); cx.add_action(ConversationEditor::split); cx.capture_action(ConversationEditor::cycle_message_role); - cx.add_action(AssistantPanel::save_api_key); - cx.add_action(AssistantPanel::reset_api_key); + cx.add_action(AssistantPanel::save_credentials); + cx.add_action(AssistantPanel::reset_credentials); cx.add_action(AssistantPanel::toggle_zoom); cx.add_action(AssistantPanel::deploy); cx.add_action(AssistantPanel::select_next_match); @@ -108,6 +116,7 @@ pub fn init(cx: &mut AppContext) { cx.add_action(InlineAssistant::confirm); cx.add_action(InlineAssistant::cancel); cx.add_action(InlineAssistant::toggle_include_conversation); + cx.add_action(InlineAssistant::toggle_retrieve_context); cx.add_action(InlineAssistant::move_up); cx.add_action(InlineAssistant::move_down); } @@ -133,9 +142,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - api_key: Rc>>, + completion_provider: Box, api_key_editor: Option>, - has_read_credentials: bool, languages: Arc, fs: Arc, subscriptions: Vec, @@ -145,6 +153,8 @@ pub struct AssistantPanel { include_conversation_in_next_inline_assist: bool, inline_prompt_history: VecDeque, _watch_saved_conversations: Task>, + semantic_index: Option>, + retrieve_context_in_next_inline_assist: bool, } impl AssistantPanel { @@ -191,6 +201,14 @@ impl AssistantPanel { toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx); toolbar }); + + let semantic_index = SemanticIndex::global(cx); + // Defaulting currently to GPT4, allow for this to be set via config. + let completion_provider = Box::new(OpenAICompletionProvider::new( + "gpt-4", + cx.background().clone(), + )); + let mut this = Self { workspace: workspace_handle, active_editor_index: Default::default(), @@ -201,9 +219,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - api_key: Rc::new(RefCell::new(None)), + completion_provider, api_key_editor: None, - has_read_credentials: false, languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), width: None, @@ -215,6 +232,8 @@ impl AssistantPanel { include_conversation_in_next_inline_assist: false, inline_prompt_history: Default::default(), _watch_saved_conversations, + semantic_index, + retrieve_context_in_next_inline_assist: false, }; let mut old_dock_position = this.position(cx); @@ -240,10 +259,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this - .update(cx, |assistant, cx| assistant.load_api_key(cx)) - .is_some() - { + if this.update(cx, |assistant, _| assistant.has_credentials()) { this } else { workspace.focus_panel::(cx); @@ -262,20 +278,21 @@ impl AssistantPanel { return; }; + let project = workspace.project(); + this.update(cx, |assistant, cx| { - assistant.new_inline_assist(&active_editor, cx) + assistant.new_inline_assist(&active_editor, cx, project) }); } - fn new_inline_assist(&mut self, editor: &ViewHandle, cx: &mut ViewContext) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - + fn new_inline_assist( + &mut self, + editor: &ViewHandle, + cx: &mut ViewContext, + project: &ModelHandle, + ) { let selection = editor.read(cx).selections.newest_anchor().clone(); - if selection.start.excerpt_id() != selection.end.excerpt_id() { + if selection.start.excerpt_id != selection.end.excerpt_id { return; } let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); @@ -304,14 +321,38 @@ impl AssistantPanel { let 
inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( - api_key, + "gpt-4", cx.background().clone(), )); + // Retrieve Credentials Authenticates the Provider + // provider.retrieve_credentials(cx); + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); + if let Some(semantic_index) = self.semantic_index.clone() { + let project = project.clone(); + cx.spawn(|_, mut cx| async move { + let previously_indexed = semantic_index + .update(&mut cx, |index, cx| { + index.project_previously_indexed(&project, cx) + }) + .await + .unwrap_or(false); + if previously_indexed { + let _ = semantic_index + .update(&mut cx, |index, cx| { + index.index_project(project.clone(), cx) + }) + .await; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + let measurements = Rc::new(Cell::new(BlockMeasurements::default())); let inline_assistant = cx.add_view(|cx| { let assistant = InlineAssistant::new( @@ -322,6 +363,9 @@ impl AssistantPanel { codegen.clone(), self.workspace.clone(), cx, + self.retrieve_context_in_next_inline_assist, + self.semantic_index.clone(), + project.clone(), ); cx.focus_self(); assistant @@ -362,6 +406,7 @@ impl AssistantPanel { editor: editor.downgrade(), inline_assistant: Some((block_id, inline_assistant.clone())), codegen: codegen.clone(), + project: project.downgrade(), _subscriptions: vec![ cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event), cx.subscribe(editor, { @@ -440,8 +485,15 @@ impl AssistantPanel { InlineAssistantEvent::Confirmed { prompt, include_conversation, + retrieve_context, } => { - self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); + self.confirm_inline_assist( + assist_id, + prompt, + *include_conversation, + cx, + *retrieve_context, + ); } InlineAssistantEvent::Canceled => { self.finish_inline_assist(assist_id, true, cx); @@ -454,6 +506,9 @@ impl AssistantPanel { } => { self.include_conversation_in_next_inline_assist = *include_conversation; } + InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => { + self.retrieve_context_in_next_inline_assist = *retrieve_context + } } } @@ -532,6 +587,7 @@ impl AssistantPanel { user_prompt: &str, include_conversation: bool, cx: &mut ViewContext, + retrieve_context: bool, ) { let conversation = if include_conversation { self.active_editor() @@ -553,6 +609,20 @@ impl AssistantPanel { return; }; + let project = pending_assist.project.clone(); + + let project_name = if let Some(project) = project.upgrade(cx) { + Some( + project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + } else { + None + }; + self.inline_prompt_history .retain(|prompt| prompt != user_prompt); self.inline_prompt_history.push_back(user_prompt.into()); @@ -590,13 +660,70 @@ impl AssistantPanel { None }; - let codegen_kind = codegen.read(cx).kind().clone(); + // Higher Temperature increases the randomness of model outputs. 
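+ // (For reference: OpenAI's chat completions API accepts temperatures from 0.0 to 2.0, and lower values make sampling more deterministic.)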
+ // If Markdown or No Language is Known, increase the randomness for more creative output + // If Code, decrease temperature to get more deterministic outputs + let temperature = if let Some(language) = language_name.clone() { + if language.to_string() != "Markdown".to_string() { + 0.5 + } else { + 1.0 + } + } else { + 1.0 + }; + let user_prompt = user_prompt.to_string(); - let mut messages = Vec::new(); + let snippets = if retrieve_context { + let Some(project) = project.upgrade(cx) else { + return; + }; + + let search_results = if let Some(semantic_index) = self.semantic_index.clone() { + let search_results = semantic_index.update(cx, |this, cx| { + this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx) + }); + + cx.background() + .spawn(async move { search_results.await.unwrap_or_default() }) + } else { + Task::ready(Vec::new()) + }; + + let snippets = cx.spawn(|_, cx| async move { + let mut snippets = Vec::new(); + for result in search_results.await { + snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx)); + } + snippets + }); + snippets + } else { + Task::ready(Vec::new()) + }; + let mut model = settings::get::(cx) .default_open_ai_model .clone(); + let model_name = model.full_name(); + + let prompt = cx.background().spawn(async move { + let snippets = snippets.await; + + let language_name = language_name.as_deref(); + generate_content_prompt( + user_prompt, + language_name, + buffer, + range, + snippets, + model_name, + project_name, + ) + }); + + let mut messages = Vec::new(); if let Some(conversation) = conversation { let conversation = conversation.read(cx); let buffer = conversation.buffer.read(cx); @@ -608,24 +735,25 @@ impl AssistantPanel { model = conversation.model.clone(); } - let prompt = cx.background().spawn(async move { - let language_name = language_name.as_deref(); - generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind) - }); - cx.spawn(|_, mut cx| async move { - let prompt = prompt.await; + // I Don't know if we want to return a ? here. 
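+ // Propagating the error with `?` ends this detached task early, so a failed prompt build simply aborts the inline assist rather than sending a partial request.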
+ let prompt = prompt.await?; messages.push(RequestMessage { role: Role::User, content: prompt, }); - let request = OpenAIRequest { + + let request = Box::new(OpenAIRequest { model: model.full_name().into(), messages, stream: true, - }; + stop: vec!["|END|>".to_string()], + temperature, + }); + codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); + anyhow::Ok(()) }) .detach(); } @@ -683,7 +811,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.api_key.clone(), + self.completion_provider.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -742,17 +870,19 @@ impl AssistantPanel { } } - fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { if let Some(api_key) = self .api_key_editor .as_ref() .map(|editor| editor.read(cx).text(cx)) { if !api_key.is_empty() { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - *self.api_key.borrow_mut() = Some(api_key); + let credential = ProviderCredential::Credentials { + api_key: api_key.clone(), + }; + + self.completion_provider.save_credentials(cx, credential); + self.api_key_editor.take(); cx.focus_self(); cx.notify(); @@ -762,9 +892,8 @@ impl AssistantPanel { } } - fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - self.api_key.take(); + fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { + self.completion_provider.delete_credentials(cx); self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1023,13 +1152,12 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let api_key = self.api_key.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) + Conversation::deserialize(saved_conversation, path.clone(), languages, cx) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1053,30 +1181,12 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn load_api_key(&mut self, cx: &mut ViewContext) -> Option { - if self.api_key.borrow().is_none() && !self.has_read_credentials { - self.has_read_credentials = true; - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - if let Some(api_key) = api_key { - *self.api_key.borrow_mut() = Some(api_key); - } else if self.api_key_editor.is_none() { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); - } - } + fn has_credentials(&mut self) -> bool { + self.completion_provider.has_credentials() + } - self.api_key.borrow().clone() + fn load_credentials(&mut self, cx: &mut ViewContext) { + self.completion_provider.retrieve_credentials(cx); } } @@ -1261,7 +1371,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if 
active { - self.load_api_key(cx); + self.load_credentials(cx); if self.editors.is_empty() { self.new_conversation(cx); @@ -1326,10 +1436,10 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - api_key: Rc>>, pending_save: Task>, path: Option, _subscriptions: Vec, + completion_provider: Box, } impl Entity for Conversation { @@ -1338,9 +1448,9 @@ impl Entity for Conversation { impl Conversation { fn new( - api_key: Rc>>, language_registry: Arc, cx: &mut ModelContext, + completion_provider: Box, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1379,8 +1489,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - api_key, buffer, + completion_provider, }; let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), @@ -1426,7 +1536,6 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - api_key: Rc>>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1435,6 +1544,10 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; + let completion_provider: Box = Box::new( + OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), + ); + completion_provider.retrieve_credentials(cx); let markdown = language_registry.language_for_name("Markdown"); let mut message_anchors = Vec::new(); let mut next_message_id = MessageId(0); @@ -1481,8 +1594,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - api_key, buffer, + completion_provider, }; this.count_remaining_tokens(cx); this @@ -1514,12 +1627,14 @@ impl Conversation { Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: self - .buffer - .read(cx) - .text_for_range(message.offset_range) - .collect(), + content: Some( + self.buffer + .read(cx) + .text_for_range(message.offset_range) + .collect(), + ), name: None, + function_call: None, }) }) .collect::>(); @@ -1601,11 +1716,11 @@ impl Conversation { } if should_assist { - let Some(api_key) = self.api_key.borrow().clone() else { + if !self.completion_provider.has_credentials() { return Default::default(); - }; + } - let request = OpenAIRequest { + let request: Box = Box::new(OpenAIRequest { model: self.model.full_name().to_string(), messages: self .messages(cx) @@ -1613,9 +1728,11 @@ impl Conversation { .map(|message| message.to_open_ai_message(self.buffer.read(cx))) .collect(), stream: true, - }; + stop: vec![], + temperature: 1.0, + }); - let stream = stream_completion(api_key, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1633,33 +1750,28 @@ impl Conversation { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? 
- .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = - this.message_anchors.iter().position(|message| { - message.id == assistant_message_id - })?; - this.buffer.update(cx, |buffer, cx| { - let offset = this.message_anchors[message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message - .start - .to_offset(buffer) - .saturating_sub(1) - }); - buffer.edit([(offset..offset, text)], None, cx); - }); - cx.emit(ConversationEvent::StreamedCompletion); + let text = message?; - Some(()) + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + let message_ix = this + .message_anchors + .iter() + .position(|message| message.id == assistant_message_id)?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message.start.to_offset(buffer).saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); }); - } + cx.emit(ConversationEvent::StreamedCompletion); + + Some(()) + }); smol::future::yield_now().await; } @@ -1881,55 +1993,54 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let api_key = self.api_key.borrow().clone(); - if let Some(api_key) = api_key { - let messages = self - .messages(cx) - .take(2) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .chain(Some(RequestMessage { - role: Role::User, - content: - "Summarize the conversation into a short title without punctuation" - .into(), - })); - let request = OpenAIRequest { - model: self.model.full_name().to_string(), - messages: messages.collect(), - stream: true, - }; - - let stream = stream_completion(api_key, cx.background().clone(), request); - self.pending_summary = cx.spawn(|this, mut cx| { - async move { - let mut messages = stream.await?; - - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } - } - - this.update(&mut cx, |this, cx| { - if let Some(summary) = this.summary.as_mut() { - summary.done = true; - cx.emit(ConversationEvent::SummaryChanged); - } - }); - - anyhow::Ok(()) - } - .log_err() - }); + if !self.completion_provider.has_credentials() { + return; } + + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: "Summarize the conversation into a short title without punctuation" + .into(), + })); + let request: Box = Box::new(OpenAIRequest { + model: self.model.full_name().to_string(), + messages: messages.collect(), + stream: true, + stop: vec![], + temperature: 1.0, + }); + + let stream = self.completion_provider.complete(request); + self.pending_summary = cx.spawn(|this, mut cx| { + async move { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let text = message?; + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); + } + + this.update(&mut 
cx, |this, cx| { + if let Some(summary) = this.summary.as_mut() { + summary.done = true; + cx.emit(ConversationEvent::SummaryChanged); + } + }); + + anyhow::Ok(()) + } + .log_err() + }); } } @@ -2090,13 +2201,14 @@ struct ConversationEditor { impl ConversationEditor { fn new( - api_key: Rc>>, + completion_provider: Box, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); + let conversation = + cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -2638,12 +2750,16 @@ enum InlineAssistantEvent { Confirmed { prompt: String, include_conversation: bool, + retrieve_context: bool, }, Canceled, Dismissed, IncludeConversationToggled { include_conversation: bool, }, + RetrieveContextToggled { + retrieve_context: bool, + }, } struct InlineAssistant { @@ -2659,6 +2775,11 @@ struct InlineAssistant { pending_prompt: String, codegen: ModelHandle, _subscriptions: Vec, + retrieve_context: bool, + semantic_index: Option>, + semantic_permissioned: Option, + project: WeakModelHandle, + maintain_rate_limit: Option>, } impl Entity for InlineAssistant { @@ -2675,51 +2796,65 @@ impl View for InlineAssistant { let theme = theme::current(cx); Flex::row() - .with_child( - Flex::row() - .with_child( - Button::action(ToggleIncludeConversation) - .with_tooltip("Include Conversation", theme.tooltip.clone()) + .with_children([Flex::row() + .with_child( + Button::action(ToggleIncludeConversation) + .with_tooltip("Include Conversation", theme.tooltip.clone()) + .with_id(self.id) + .with_contents(theme::components::svg::Svg::new("icons/ai.svg")) + .toggleable(self.include_conversation) + .with_style(theme.assistant.inline.include_conversation.clone()) + .element() + .aligned(), + ) + .with_children(if SemanticIndex::enabled(cx) { + Some( + Button::action(ToggleRetrieveContext) + .with_tooltip("Retrieve Context", theme.tooltip.clone()) .with_id(self.id) - .with_contents(theme::components::svg::Svg::new("icons/ai.svg")) - .toggleable(self.include_conversation) - .with_style(theme.assistant.inline.include_conversation.clone()) + .with_contents(theme::components::svg::Svg::new( + "icons/magnifying_glass.svg", + )) + .toggleable(self.retrieve_context) + .with_style(theme.assistant.inline.retrieve_context.clone()) .element() .aligned(), ) - .with_children(if let Some(error) = self.codegen.read(cx).error() { - Some( - Svg::new("icons/error.svg") - .with_color(theme.assistant.error_icon.color) - .constrained() - .with_width(theme.assistant.error_icon.width) - .contained() - .with_style(theme.assistant.error_icon.container) - .with_tooltip::( - self.id, - error.to_string(), - None, - theme.tooltip.clone(), - cx, - ) - .aligned(), - ) - } else { - None - }) - .aligned() - .constrained() - .dynamically({ - let measurements = self.measurements.clone(); - move |constraint, _, _| { - let measurements = measurements.get(); - SizeConstraint { - min: vec2f(measurements.gutter_width, constraint.min.y()), - max: vec2f(measurements.gutter_width, constraint.max.y()), - } + } else { + None + }) + .with_children(if let Some(error) = self.codegen.read(cx).error() { + Some( + Svg::new("icons/error.svg") + .with_color(theme.assistant.error_icon.color) + .constrained() + .with_width(theme.assistant.error_icon.width) + .contained() + .with_style(theme.assistant.error_icon.container) + .with_tooltip::( + self.id, + error.to_string(), 
+ None, + theme.tooltip.clone(), + cx, + ) + .aligned(), + ) + } else { + None + }) + .aligned() + .constrained() + .dynamically({ + let measurements = self.measurements.clone(); + move |constraint, _, _| { + let measurements = measurements.get(); + SizeConstraint { + min: vec2f(measurements.gutter_width, constraint.min.y()), + max: vec2f(measurements.gutter_width, constraint.max.y()), } - }), - ) + } + })]) .with_child(Empty::new().constrained().dynamically({ let measurements = self.measurements.clone(); move |constraint, _, _| { @@ -2742,6 +2877,16 @@ impl View for InlineAssistant { .left() .flex(1., true), ) + .with_children(if self.retrieve_context { + Some( + Flex::row() + .with_children(self.retrieve_context_status(cx)) + .flex(1., true) + .aligned(), + ) + } else { + None + }) .contained() .with_style(theme.assistant.inline.container) .into_any() @@ -2767,6 +2912,9 @@ impl InlineAssistant { codegen: ModelHandle, workspace: WeakViewHandle, cx: &mut ViewContext, + retrieve_context: bool, + semantic_index: Option>, + project: ModelHandle, ) -> Self { let prompt_editor = cx.add_view(|cx| { let mut editor = Editor::single_line( @@ -2780,11 +2928,16 @@ impl InlineAssistant { editor.set_placeholder_text(placeholder, cx); editor }); - let subscriptions = vec![ + let mut subscriptions = vec![ cx.observe(&codegen, Self::handle_codegen_changed), cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), ]; - Self { + + if let Some(semantic_index) = semantic_index.clone() { + subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed)); + } + + let assistant = Self { id, prompt_editor, workspace, @@ -2797,7 +2950,33 @@ impl InlineAssistant { pending_prompt: String::new(), codegen, _subscriptions: subscriptions, + retrieve_context, + semantic_permissioned: None, + semantic_index, + project: project.downgrade(), + maintain_rate_limit: None, + }; + + assistant.index_project(cx).log_err(); + + assistant + } + + fn semantic_permissioned(&self, cx: &mut ViewContext) -> Task> { + if let Some(value) = self.semantic_permissioned { + return Task::ready(Ok(value)); } + + let Some(project) = self.project.upgrade(cx) else { + return Task::ready(Err(anyhow!("project was dropped"))); + }; + + self.semantic_index + .as_ref() + .map(|semantic| { + semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx)) + }) + .unwrap_or(Task::ready(Ok(false))) } fn handle_prompt_editor_events( @@ -2812,6 +2991,37 @@ impl InlineAssistant { } } + fn semantic_index_changed( + &mut self, + semantic_index: ModelHandle, + cx: &mut ViewContext, + ) { + let Some(project) = self.project.upgrade(cx) else { + return; + }; + + let status = semantic_index.read(cx).status(&project); + match status { + SemanticIndexStatus::Indexing { + rate_limit_expiry: Some(_), + .. 
+ } => { + if self.maintain_rate_limit.is_none() { + self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move { + loop { + cx.background().timer(Duration::from_secs(1)).await; + this.update(&mut cx, |_, cx| cx.notify()).log_err(); + } + })); + } + return; + } + _ => { + self.maintain_rate_limit = None; + } + } + } + fn handle_codegen_changed(&mut self, _: ModelHandle, cx: &mut ViewContext) { let is_read_only = !self.codegen.read(cx).idle(); self.prompt_editor.update(cx, |editor, cx| { @@ -2861,12 +3071,241 @@ impl InlineAssistant { cx.emit(InlineAssistantEvent::Confirmed { prompt, include_conversation: self.include_conversation, + retrieve_context: self.retrieve_context, }); self.confirmed = true; cx.notify(); } } + fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext) { + let semantic_permissioned = self.semantic_permissioned(cx); + + let Some(project) = self.project.upgrade(cx) else { + return; + }; + + let project_name = project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"); + let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0; + let prompt_text = format!("Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API", project_name, + if is_plural { + "s" + } else {""}); + + cx.spawn(|this, mut cx| async move { + // If Necessary prompt user + if !semantic_permissioned.await.unwrap_or(false) { + let mut answer = this.update(&mut cx, |_, cx| { + cx.prompt( + PromptLevel::Info, + prompt_text.as_str(), + &["Continue", "Cancel"], + ) + })?; + + if answer.next().await == Some(0) { + this.update(&mut cx, |this, _| { + this.semantic_permissioned = Some(true); + })?; + } else { + return anyhow::Ok(()); + } + } + + // If permissioned, update context appropriately + this.update(&mut cx, |this, cx| { + this.retrieve_context = !this.retrieve_context; + + cx.emit(InlineAssistantEvent::RetrieveContextToggled { + retrieve_context: this.retrieve_context, + }); + + if this.retrieve_context { + this.index_project(cx).log_err(); + } + + cx.notify(); + })?; + + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + + fn index_project(&self, cx: &mut ViewContext) -> anyhow::Result<()> { + let Some(project) = self.project.upgrade(cx) else { + return Err(anyhow!("project was dropped!")); + }; + + let semantic_permissioned = self.semantic_permissioned(cx); + if let Some(semantic_index) = SemanticIndex::global(cx) { + cx.spawn(|_, mut cx| async move { + // This has to be updated to accomodate for semantic_permissions + if semantic_permissioned.await.unwrap_or(false) { + semantic_index + .update(&mut cx, |index, cx| index.index_project(project, cx)) + .await + } else { + Err(anyhow!("project is not permissioned for semantic indexing")) + } + }) + .detach_and_log_err(cx); + } + + anyhow::Ok(()) + } + + fn retrieve_context_status( + &self, + cx: &mut ViewContext, + ) -> Option> { + enum ContextStatusIcon {} + + let Some(project) = self.project.upgrade(cx) else { + return None; + }; + + if let Some(semantic_index) = SemanticIndex::global(cx) { + let status = semantic_index.update(cx, |index, _| index.status(&project)); + let theme = theme::current(cx); + match status { + SemanticIndexStatus::NotAuthenticated {} => Some( + Svg::new("icons/error.svg") + .with_color(theme.assistant.error_icon.color) + .constrained() + .with_width(theme.assistant.error_icon.width) + .contained() + .with_style(theme.assistant.error_icon.container) + .with_tooltip::( + self.id, + "Not Authenticated. 
Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.", + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ), + SemanticIndexStatus::NotIndexed {} => Some( + Svg::new("icons/error.svg") + .with_color(theme.assistant.inline.context_status.error_icon.color) + .constrained() + .with_width(theme.assistant.inline.context_status.error_icon.width) + .contained() + .with_style(theme.assistant.inline.context_status.error_icon.container) + .with_tooltip::( + self.id, + "Not Indexed", + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ), + SemanticIndexStatus::Indexing { + remaining_files, + rate_limit_expiry, + } => { + + let mut status_text = if remaining_files == 0 { + "Indexing...".to_string() + } else { + format!("Remaining files to index: {remaining_files}") + }; + + if let Some(rate_limit_expiry) = rate_limit_expiry { + let remaining_seconds = rate_limit_expiry.duration_since(Instant::now()); + if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 { + write!( + status_text, + " (rate limit expires in {}s)", + remaining_seconds.as_secs() + ) + .unwrap(); + } + } + Some( + Svg::new("icons/update.svg") + .with_color(theme.assistant.inline.context_status.in_progress_icon.color) + .constrained() + .with_width(theme.assistant.inline.context_status.in_progress_icon.width) + .contained() + .with_style(theme.assistant.inline.context_status.in_progress_icon.container) + .with_tooltip::( + self.id, + status_text, + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ) + } + SemanticIndexStatus::Indexed {} => Some( + Svg::new("icons/check.svg") + .with_color(theme.assistant.inline.context_status.complete_icon.color) + .constrained() + .with_width(theme.assistant.inline.context_status.complete_icon.width) + .contained() + .with_style(theme.assistant.inline.context_status.complete_icon.container) + .with_tooltip::( + self.id, + "Index up to date", + None, + theme.tooltip.clone(), + cx, + ) + .aligned() + .into_any(), + ), + } + } else { + None + } + } + + // fn retrieve_context_status(&self, cx: &mut ViewContext) -> String { + // let project = self.project.clone(); + // if let Some(semantic_index) = self.semantic_index.clone() { + // let status = semantic_index.update(cx, |index, cx| index.status(&project)); + // return match status { + // // This theoretically shouldnt be a valid code path + // // As the inline assistant cant be launched without an API key + // // We keep it here for safety + // semantic_index::SemanticIndexStatus::NotAuthenticated => { + // "Not Authenticated!\nPlease ensure you have an `OPENAI_API_KEY` in your environment variables.".to_string() + // } + // semantic_index::SemanticIndexStatus::Indexed => { + // "Indexing Complete!".to_string() + // } + // semantic_index::SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => { + + // let mut status = format!("Remaining files to index for Context Retrieval: {remaining_files}"); + + // if let Some(rate_limit_expiry) = rate_limit_expiry { + // let remaining_seconds = + // rate_limit_expiry.duration_since(Instant::now()); + // if remaining_seconds > Duration::from_secs(0) { + // write!(status, " (rate limit resets in {}s)", remaining_seconds.as_secs()).unwrap(); + // } + // } + // status + // } + // semantic_index::SemanticIndexStatus::NotIndexed => { + // "Not Indexed for Context Retrieval".to_string() + // } + // }; + // } + + // "".to_string() + // } + fn toggle_include_conversation( &mut self, _: 
&ToggleIncludeConversation, @@ -2929,6 +3368,7 @@ struct PendingInlineAssist { inline_assistant: Option<(BlockId, ViewHandle)>, codegen: ModelHandle, _subscriptions: Vec, + project: WeakModelHandle, } fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { @@ -2957,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { mod tests { use super::*; use crate::MessageId; + use ai::test::FakeCompletionProvider; use gpui::AppContext; #[gpui::test] @@ -2964,7 +3405,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3092,7 +3535,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let completion_provider = Box::new(FakeCompletionProvider::new()); + + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3188,7 +3633,8 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3270,8 +3716,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); + let completion_provider = Box::new(FakeCompletionProvider::new()); let conversation = - cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); + cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3308,7 +3755,6 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Default::default(), registry.clone(), cx, ) diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index b6ef6b5cfa..0466259b24 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,10 +1,11 @@ use crate::streaming_diff::{Hunk, StreamingDiff}; -use ai::completion::{CompletionProvider, OpenAIRequest}; +use ai::completion::{CompletionProvider, CompletionRequest}; use anyhow::Result; -use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; +use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; +use multi_buffer; use std::{cmp, future, ops::Range, sync::Arc}; pub enum Event { @@ -95,7 +96,7 @@ impl 
Codegen { self.error.as_ref() } - pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { + pub fn start(&mut self, prompt: Box, cx: &mut ModelContext) { let range = self.range(); let snapshot = self.snapshot.clone(); let selected_text = snapshot @@ -335,17 +336,25 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use futures::{ - future::BoxFuture, - stream::{self, BoxStream}, - }; + use ai::test::FakeCompletionProvider; + use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; - use parking_lot::Mutex; use rand::prelude::*; + use serde::Serialize; use settings::SettingsStore; - use smol::future::FutureExt; + + #[derive(Serialize)] + pub struct DummyCompletionRequest { + pub name: String, + } + + impl CompletionRequest for DummyCompletionRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } + } #[gpui::test(iterations = 10)] async fn test_transform_autoindent( @@ -371,7 +380,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -380,7 +389,11 @@ mod tests { cx, ) }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let request = Box::new(DummyCompletionRequest { + name: "test".to_string(), + }); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( " let mut x = 0;\n", @@ -433,7 +446,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -442,7 +455,11 @@ mod tests { cx, ) }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let request = Box::new(DummyCompletionRequest { + name: "test".to_string(), + }); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( "t mut x = 0;\n", @@ -495,7 +512,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -504,7 +521,11 @@ mod tests { cx, ) }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let request = Box::new(DummyCompletionRequest { + name: "test".to_string(), + }); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( "let mut x = 0;\n", @@ -592,38 +613,6 @@ mod tests { } } - struct TestCompletionProvider { - last_completion_tx: Mutex>>, - } - - impl TestCompletionProvider { - fn new() -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } - - fn send_completion(&self, completion: impl Into) { - let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); - } - - fn finish_completion(&self) { - self.last_completion_tx.lock().take().unwrap(); - } - } - - impl CompletionProvider for TestCompletionProvider { - fn complete( - &self, - _prompt: OpenAIRequest, - ) -> 
BoxFuture<'static, Result>>> { - let (tx, rx) = mpsc::channel(1); - *self.last_completion_tx.lock() = Some(tx); - async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() - } - } - fn rust_lang() -> Language { Language::new( LanguageConfig { diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index d326a7f445..25af023c40 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,8 +1,14 @@ -use crate::codegen::CodegenKind; +use ai::models::LanguageModel; +use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; +use ai::prompts::file_context::FileContext; +use ai::prompts::generate::GenerateInlineContent; +use ai::prompts::preamble::EngineerPreamble; +use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext}; +use ai::providers::open_ai::OpenAILanguageModel; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use std::cmp::{self, Reverse}; -use std::fmt::Write; use std::ops::Range; +use std::sync::Arc; #[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { @@ -118,86 +124,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S pub fn generate_content_prompt( user_prompt: String, language_name: Option<&str>, - buffer: &BufferSnapshot, - range: Range, - kind: CodegenKind, -) -> String { - let range = range.to_offset(buffer); - let mut prompt = String::new(); - - // General Preamble - if let Some(language_name) = language_name { - writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap(); + buffer: BufferSnapshot, + range: Range, + search_results: Vec, + model: &str, + project_name: Option, +) -> anyhow::Result { + // Using new Prompt Templates + let openai_model: Arc = Arc::new(OpenAILanguageModel::load(model)); + let lang_name = if let Some(language_name) = language_name { + Some(language_name.to_string()) } else { - writeln!(prompt, "You're an expert engineer.\n").unwrap(); - } + None + }; - let mut content = String::new(); - content.extend(buffer.text_for_range(0..range.start)); - if range.start == range.end { - content.push_str("<|START|>"); - } else { - content.push_str("<|START|"); - } - content.extend(buffer.text_for_range(range.clone())); - if range.start != range.end { - content.push_str("|END|>"); - } - content.extend(buffer.text_for_range(range.end..buffer.len())); + let args = PromptArguments { + model: openai_model, + language_name: lang_name.clone(), + project_name, + snippets: search_results.clone(), + reserved_tokens: 1000, + buffer: Some(buffer), + selected_range: Some(range), + user_prompt: Some(user_prompt.clone()), + }; - writeln!( - prompt, - "The file you are currently working on has the following content:" - ) - .unwrap(); - if let Some(language_name) = language_name { - let language_name = language_name.to_lowercase(); - writeln!(prompt, "```{language_name}\n{content}\n```").unwrap(); - } else { - writeln!(prompt, "```\n{content}\n```").unwrap(); - } + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::Mandatory, Box::new(EngineerPreamble {})), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(RepositoryContext {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(FileContext {}), + ), + ( + PromptPriority::Mandatory, + Box::new(GenerateInlineContent {}), + ), + ]; + let chain = PromptChain::new(args, templates); + let (prompt, _) = chain.generate(true)?; - match kind { - CodegenKind::Generate { position: _ } => { - writeln!(prompt, "In particular, the user's 
cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap(); - writeln!( - prompt, - "Assume the cursor is located where the `<|START|` marker is." - ) - .unwrap(); - writeln!( - prompt, - "Text can't be replaced, so assume your answer will be inserted at the cursor." - ) - .unwrap(); - writeln!( - prompt, - "Generate text based on the users prompt: {user_prompt}" - ) - .unwrap(); - } - CodegenKind::Transform { range: _ } => { - writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); - writeln!( - prompt, - "Modify the users code selected text based upon the users prompt: {user_prompt}" - ) - .unwrap(); - writeln!( - prompt, - "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file." - ) - .unwrap(); - } - } - - if let Some(language_name) = language_name { - writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap(); - } - writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap(); - writeln!(prompt, "Never make remarks about the output.").unwrap(); - - prompt + anyhow::Ok(prompt) } #[cfg(test)] diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 0846341325..ca1a60bd63 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -10,7 +10,7 @@ use client::{ ZED_ALWAYS_ACTIVE, }; use collections::HashSet; -use futures::{future::Shared, FutureExt}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt}; use gpui::{ AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task, WeakModelHandle, @@ -37,10 +37,42 @@ pub struct IncomingCall { pub initial_project: Option, } +pub struct OneAtATime { + cancel: Option>, +} + +impl OneAtATime { + /// spawn a task in the given context. + /// if another task is spawned before that resolves, or if the OneAtATime itself is dropped, the first task will be cancelled and return Ok(None) + /// otherwise you'll see the result of the task. + fn spawn(&mut self, cx: &mut AppContext, f: F) -> Task>> + where + F: 'static + FnOnce(AsyncAppContext) -> Fut, + Fut: Future>, + R: 'static, + { + let (tx, rx) = oneshot::channel(); + self.cancel.replace(tx); + cx.spawn(|cx| async move { + futures::select_biased! { + _ = rx.fuse() => Ok(None), + result = f(cx).fuse() => result.map(Some), + } + }) + } + + fn running(&self) -> bool { + self.cancel + .as_ref() + .is_some_and(|cancel| !cancel.is_canceled()) + } +} + /// Singleton global maintaining the user's participation in a room across workspaces. 
pub struct ActiveCall { room: Option<(ModelHandle, Vec)>, pending_room_creation: Option, Arc>>>>, + _join_debouncer: OneAtATime, location: Option>, pending_invites: HashSet, incoming_call: ( @@ -69,6 +101,7 @@ impl ActiveCall { pending_invites: Default::default(), incoming_call: watch::channel(), + _join_debouncer: OneAtATime { cancel: None }, _subscriptions: vec![ client.add_request_handler(cx.handle(), Self::handle_incoming_call), client.add_message_handler(cx.handle(), Self::handle_call_canceled), @@ -143,6 +176,10 @@ impl ActiveCall { } cx.notify(); + if self._join_debouncer.running() { + return Task::ready(Ok(())); + } + let room = if let Some(room) = self.room().cloned() { Some(Task::ready(Ok(room)).shared()) } else { @@ -259,11 +296,20 @@ impl ActiveCall { return Task::ready(Err(anyhow!("no incoming call"))); }; - let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx); + if self.pending_room_creation.is_some() { + return Task::ready(Ok(())); + } + + let room_id = call.room_id.clone(); + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self + ._join_debouncer + .spawn(cx, move |cx| Room::join(room_id, client, user_store, cx)); cx.spawn(|this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx)) .await?; this.update(&mut cx, |this, cx| { this.report_call_event("accept incoming", cx) @@ -290,20 +336,28 @@ impl ActiveCall { &mut self, channel_id: u64, cx: &mut ModelContext, - ) -> Task>> { + ) -> Task>>> { if let Some(room) = self.room().cloned() { if room.read(cx).channel_id() == Some(channel_id) { - return Task::ready(Ok(room)); + return Task::ready(Ok(Some(room))); } else { room.update(cx, |room, cx| room.clear_state(cx)); } } - let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx); + if self.pending_room_creation.is_some() { + return Task::ready(Ok(None)); + } - cx.spawn(|this, mut cx| async move { + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let join = self._join_debouncer.spawn(cx, move |cx| async move { + Room::join_channel(channel_id, client, user_store, cx).await + }); + + cx.spawn(move |this, mut cx| async move { let room = join.await?; - this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx)) + this.update(&mut cx, |this, cx| this.set_room(room.clone(), cx)) .await?; this.update(&mut cx, |this, cx| { this.report_call_event("join channel", cx) @@ -457,3 +511,40 @@ pub fn report_call_event_for_channel( }; telemetry.report_clickhouse_event(event, telemetry_settings); } + +#[cfg(test)] +mod test { + use gpui::TestAppContext; + + use crate::OneAtATime; + + #[gpui::test] + async fn test_one_at_a_time(cx: &mut TestAppContext) { + let mut one_at_a_time = OneAtATime { cancel: None }; + + assert_eq!( + cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(1) })) + .await + .unwrap(), + Some(1) + ); + + let (a, b) = cx.update(|cx| { + ( + one_at_a_time.spawn(cx, |_| async { + assert!(false); + Ok(2) + }), + one_at_a_time.spawn(cx, |_| async { Ok(3) }), + ) + }); + + assert_eq!(a.await.unwrap(), None); + assert_eq!(b.await.unwrap(), Some(3)); + + let promise = cx.update(|cx| one_at_a_time.spawn(cx, |_| async { Ok(4) })); + drop(one_at_a_time); + + assert_eq!(promise.await.unwrap(), None); + } +} diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 4e52f57f60..8d37194f3a 100644 --- 
a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -1,7 +1,6 @@ use crate::{ call_settings::CallSettings, participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack}, - IncomingCall, }; use anyhow::{anyhow, Result}; use audio::{Audio, Sound}; @@ -55,7 +54,7 @@ pub enum Event { pub struct Room { id: u64, - channel_id: Option, + pub channel_id: Option, live_kit: Option, status: RoomStatus, shared_projects: HashSet>, @@ -122,6 +121,10 @@ impl Room { } } + pub fn can_publish(&self) -> bool { + self.live_kit.as_ref().is_some_and(|room| room.can_publish) + } + fn new( id: u64, channel_id: Option, @@ -181,20 +184,23 @@ impl Room { }); let connect = room.connect(&connection_info.server_url, &connection_info.token); - cx.spawn(|this, mut cx| async move { - connect.await?; + if connection_info.can_publish { + cx.spawn(|this, mut cx| async move { + connect.await?; - if !cx.read(Self::mute_on_join) { - this.update(&mut cx, |this, cx| this.share_microphone(cx)) - .await?; - } + if !cx.read(Self::mute_on_join) { + this.update(&mut cx, |this, cx| this.share_microphone(cx)) + .await?; + } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } Some(LiveKitRoom { room, + can_publish: connection_info.can_publish, screen_track: LocalTrack::None, microphone_track: LocalTrack::None, next_publish_id: 0, @@ -284,37 +290,32 @@ impl Room { }) } - pub(crate) fn join_channel( + pub(crate) async fn join_channel( channel_id: u64, client: Arc, user_store: ModelHandle, - cx: &mut AppContext, - ) -> Task>> { - cx.spawn(|cx| async move { - Self::from_join_response( - client.request(proto::JoinChannel { channel_id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinChannel { channel_id }).await?, + client, + user_store, + cx, + ) } - pub(crate) fn join( - call: &IncomingCall, + pub(crate) async fn join( + room_id: u64, client: Arc, user_store: ModelHandle, - cx: &mut AppContext, - ) -> Task>> { - let id = call.room_id; - cx.spawn(|cx| async move { - Self::from_join_response( - client.request(proto::JoinRoom { id }).await?, - client, - user_store, - cx, - ) - }) + cx: AsyncAppContext, + ) -> Result> { + Self::from_join_response( + client.request(proto::JoinRoom { id: room_id }).await?, + client, + user_store, + cx, + ) } pub fn mute_on_join(cx: &AppContext) -> bool { @@ -1251,7 +1252,7 @@ impl Room { .read_with(&cx, |this, _| { this.live_kit .as_ref() - .map(|live_kit| live_kit.room.publish_audio_track(&track)) + .map(|live_kit| live_kit.room.publish_audio_track(track)) }) .ok_or_else(|| anyhow!("live-kit was not initialized"))? .await @@ -1337,7 +1338,7 @@ impl Room { .read_with(&cx, |this, _| { this.live_kit .as_ref() - .map(|live_kit| live_kit.room.publish_video_track(&track)) + .map(|live_kit| live_kit.room.publish_video_track(track)) }) .ok_or_else(|| anyhow!("live-kit was not initialized"))? 
.await @@ -1498,6 +1499,7 @@ struct LiveKitRoom { deafened: bool, speaking: bool, next_publish_id: usize, + can_publish: bool, _maintain_room: Task<()>, _maintain_tracks: [Task<()>; 2], } diff --git a/crates/call2/src/call2.rs b/crates/call2/src/call2.rs index 1a514164ba..fd09dc3180 100644 --- a/crates/call2/src/call2.rs +++ b/crates/call2/src/call2.rs @@ -12,8 +12,8 @@ use client2::{ use collections::HashSet; use futures::{future::Shared, FutureExt}; use gpui2::{ - AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Subscription, Task, - WeakHandle, + AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task, + WeakModel, }; use postage::watch; use project2::Project; @@ -23,10 +23,10 @@ use std::sync::Arc; pub use participant::ParticipantLocation; pub use room::Room; -pub fn init(client: Arc, user_store: Handle, cx: &mut AppContext) { +pub fn init(client: Arc, user_store: Model, cx: &mut AppContext) { CallSettings::register(cx); - let active_call = cx.entity(|cx| ActiveCall::new(client, user_store, cx)); + let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx)); cx.set_global(active_call); } @@ -40,16 +40,16 @@ pub struct IncomingCall { /// Singleton global maintaining the user's participation in a room across workspaces. pub struct ActiveCall { - room: Option<(Handle, Vec)>, - pending_room_creation: Option, Arc>>>>, - location: Option>, + room: Option<(Model, Vec)>, + pending_room_creation: Option, Arc>>>>, + location: Option>, pending_invites: HashSet, incoming_call: ( watch::Sender>, watch::Receiver>, ), client: Arc, - user_store: Handle, + user_store: Model, _subscriptions: Vec, } @@ -58,11 +58,7 @@ impl EventEmitter for ActiveCall { } impl ActiveCall { - fn new( - client: Arc, - user_store: Handle, - cx: &mut ModelContext, - ) -> Self { + fn new(client: Arc, user_store: Model, cx: &mut ModelContext) -> Self { Self { room: None, pending_room_creation: None, @@ -71,8 +67,8 @@ impl ActiveCall { incoming_call: watch::channel(), _subscriptions: vec![ - client.add_request_handler(cx.weak_handle(), Self::handle_incoming_call), - client.add_message_handler(cx.weak_handle(), Self::handle_call_canceled), + client.add_request_handler(cx.weak_model(), Self::handle_incoming_call), + client.add_message_handler(cx.weak_model(), Self::handle_call_canceled), ], client, user_store, @@ -84,7 +80,7 @@ impl ActiveCall { } async fn handle_incoming_call( - this: Handle, + this: Model, envelope: TypedEnvelope, _: Arc, mut cx: AsyncAppContext, @@ -112,7 +108,7 @@ impl ActiveCall { } async fn handle_call_canceled( - this: Handle, + this: Model, envelope: TypedEnvelope, _: Arc, mut cx: AsyncAppContext, @@ -129,14 +125,14 @@ impl ActiveCall { Ok(()) } - pub fn global(cx: &AppContext) -> Handle { - cx.global::>().clone() + pub fn global(cx: &AppContext) -> Model { + cx.global::>().clone() } pub fn invite( &mut self, called_user_id: u64, - initial_project: Option>, + initial_project: Option>, cx: &mut ModelContext, ) -> Task> { if !self.pending_invites.insert(called_user_id) { @@ -291,7 +287,7 @@ impl ActiveCall { &mut self, channel_id: u64, cx: &mut ModelContext, - ) -> Task>> { + ) -> Task>> { if let Some(room) = self.room().cloned() { if room.read(cx).channel_id() == Some(channel_id) { return Task::ready(Ok(room)); @@ -327,7 +323,7 @@ impl ActiveCall { pub fn share_project( &mut self, - project: Handle, + project: Model, cx: &mut ModelContext, ) -> Task> { if let Some((room, _)) = self.room.as_ref() { @@ -340,7 +336,7 @@ impl 
ActiveCall { pub fn unshare_project( &mut self, - project: Handle, + project: Model, cx: &mut ModelContext, ) -> Result<()> { if let Some((room, _)) = self.room.as_ref() { @@ -351,13 +347,13 @@ impl ActiveCall { } } - pub fn location(&self) -> Option<&WeakHandle> { + pub fn location(&self) -> Option<&WeakModel> { self.location.as_ref() } pub fn set_location( &mut self, - project: Option<&Handle>, + project: Option<&Model>, cx: &mut ModelContext, ) -> Task> { if project.is_some() || !*ZED_ALWAYS_ACTIVE { @@ -371,7 +367,7 @@ impl ActiveCall { fn set_room( &mut self, - room: Option>, + room: Option>, cx: &mut ModelContext, ) -> Task> { if room.as_ref() != self.room.as_ref().map(|room| &room.0) { @@ -407,7 +403,7 @@ impl ActiveCall { } } - pub fn room(&self) -> Option<&Handle> { + pub fn room(&self) -> Option<&Model> { self.room.as_ref().map(|(room, _)| room) } diff --git a/crates/call2/src/participant.rs b/crates/call2/src/participant.rs index 3f594ac944..7f3e91dbba 100644 --- a/crates/call2/src/participant.rs +++ b/crates/call2/src/participant.rs @@ -1,10 +1,8 @@ use anyhow::{anyhow, Result}; use client2::ParticipantIndex; use client2::{proto, User}; -use collections::HashMap; -use gpui2::WeakHandle; +use gpui2::WeakModel; pub use live_kit_client::Frame; -use live_kit_client::RemoteAudioTrack; use project2::Project; use std::{fmt, sync::Arc}; @@ -35,7 +33,7 @@ impl ParticipantLocation { #[derive(Clone, Default)] pub struct LocalParticipant { pub projects: Vec, - pub active_project: Option>, + pub active_project: Option>, } #[derive(Clone, Debug)] @@ -47,8 +45,8 @@ pub struct RemoteParticipant { pub participant_index: ParticipantIndex, pub muted: bool, pub speaking: bool, - pub video_tracks: HashMap>, - pub audio_tracks: HashMap>, + // pub video_tracks: HashMap>, + // pub audio_tracks: HashMap>, } #[derive(Clone)] diff --git a/crates/call2/src/room.rs b/crates/call2/src/room.rs index 27ffe68c50..b7bac52a8b 100644 --- a/crates/call2/src/room.rs +++ b/crates/call2/src/room.rs @@ -1,3 +1,6 @@ +#![allow(dead_code, unused)] +// todo!() + use crate::{ call_settings::CallSettings, participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack}, @@ -13,18 +16,15 @@ use collections::{BTreeMap, HashMap, HashSet}; use fs2::Fs; use futures::{FutureExt, StreamExt}; use gpui2::{ - AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Task, WeakHandle, + AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel, }; use language2::LanguageRegistry; -use live_kit_client::{ - LocalAudioTrack, LocalTrackPublication, LocalVideoTrack, RemoteAudioTrackUpdate, - RemoteVideoTrackUpdate, -}; +use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate}; use postage::{sink::Sink, stream::Stream, watch}; use project2::Project; use settings2::Settings; -use std::{future::Future, mem, sync::Arc, time::Duration}; -use util::{post_inc, ResultExt, TryFutureExt}; +use std::{future::Future, sync::Arc, time::Duration}; +use util::{ResultExt, TryFutureExt}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); @@ -59,10 +59,10 @@ pub enum Event { pub struct Room { id: u64, channel_id: Option, - live_kit: Option, + // live_kit: Option, status: RoomStatus, - shared_projects: HashSet>, - joined_projects: HashSet>, + shared_projects: HashSet>, + joined_projects: HashSet>, local_participant: LocalParticipant, remote_participants: BTreeMap, pending_participants: Vec>, @@ -70,7 +70,7 @@ pub struct Room { 
pending_call_count: usize, leave_when_empty: bool, client: Arc, - user_store: Handle, + user_store: Model, follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec>, client_subscriptions: Vec, _subscriptions: Vec, @@ -95,14 +95,15 @@ impl Room { #[cfg(any(test, feature = "test-support"))] pub fn is_connected(&self) -> bool { - if let Some(live_kit) = self.live_kit.as_ref() { - matches!( - *live_kit.room.status().borrow(), - live_kit_client::ConnectionState::Connected { .. } - ) - } else { - false - } + false + // if let Some(live_kit) = self.live_kit.as_ref() { + // matches!( + // *live_kit.room.status().borrow(), + // live_kit_client::ConnectionState::Connected { .. } + // ) + // } else { + // false + // } } fn new( @@ -110,140 +111,141 @@ impl Room { channel_id: Option, live_kit_connection_info: Option, client: Arc, - user_store: Handle, + user_store: Model, cx: &mut ModelContext, ) -> Self { - let live_kit_room = if let Some(connection_info) = live_kit_connection_info { - let room = live_kit_client::Room::new(); - let mut status = room.status(); - // Consume the initial status of the room. - let _ = status.try_recv(); - let _maintain_room = cx.spawn(|this, mut cx| async move { - while let Some(status) = status.next().await { - let this = if let Some(this) = this.upgrade() { - this - } else { - break; - }; + todo!() + // let _live_kit_room = if let Some(connection_info) = live_kit_connection_info { + // let room = live_kit_client::Room::new(); + // let mut status = room.status(); + // // Consume the initial status of the room. + // let _ = status.try_recv(); + // let _maintain_room = cx.spawn(|this, mut cx| async move { + // while let Some(status) = status.next().await { + // let this = if let Some(this) = this.upgrade() { + // this + // } else { + // break; + // }; - if status == live_kit_client::ConnectionState::Disconnected { - this.update(&mut cx, |this, cx| this.leave(cx).log_err()) - .ok(); - break; - } - } - }); + // if status == live_kit_client::ConnectionState::Disconnected { + // this.update(&mut cx, |this, cx| this.leave(cx).log_err()) + // .ok(); + // break; + // } + // } + // }); - let mut track_video_changes = room.remote_video_track_updates(); - let _maintain_video_tracks = cx.spawn(|this, mut cx| async move { - while let Some(track_change) = track_video_changes.next().await { - let this = if let Some(this) = this.upgrade() { - this - } else { - break; - }; + // let mut track_video_changes = room.remote_video_track_updates(); + // let _maintain_video_tracks = cx.spawn(|this, mut cx| async move { + // while let Some(track_change) = track_video_changes.next().await { + // let this = if let Some(this) = this.upgrade() { + // this + // } else { + // break; + // }; - this.update(&mut cx, |this, cx| { - this.remote_video_track_updated(track_change, cx).log_err() - }) - .ok(); - } - }); + // this.update(&mut cx, |this, cx| { + // this.remote_video_track_updated(track_change, cx).log_err() + // }) + // .ok(); + // } + // }); - let mut track_audio_changes = room.remote_audio_track_updates(); - let _maintain_audio_tracks = cx.spawn(|this, mut cx| async move { - while let Some(track_change) = track_audio_changes.next().await { - let this = if let Some(this) = this.upgrade() { - this - } else { - break; - }; + // let mut track_audio_changes = room.remote_audio_track_updates(); + // let _maintain_audio_tracks = cx.spawn(|this, mut cx| async move { + // while let Some(track_change) = track_audio_changes.next().await { + // let this = if let Some(this) = this.upgrade() { + // this + 
// } else { + // break; + // }; - this.update(&mut cx, |this, cx| { - this.remote_audio_track_updated(track_change, cx).log_err() - }) - .ok(); - } - }); + // this.update(&mut cx, |this, cx| { + // this.remote_audio_track_updated(track_change, cx).log_err() + // }) + // .ok(); + // } + // }); - let connect = room.connect(&connection_info.server_url, &connection_info.token); - cx.spawn(|this, mut cx| async move { - connect.await?; + // let connect = room.connect(&connection_info.server_url, &connection_info.token); + // cx.spawn(|this, mut cx| async move { + // connect.await?; - if !cx.update(|cx| Self::mute_on_join(cx))? { - this.update(&mut cx, |this, cx| this.share_microphone(cx))? - .await?; - } + // if !cx.update(|cx| Self::mute_on_join(cx))? { + // this.update(&mut cx, |this, cx| this.share_microphone(cx))? + // .await?; + // } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); + // anyhow::Ok(()) + // }) + // .detach_and_log_err(cx); - Some(LiveKitRoom { - room, - screen_track: LocalTrack::None, - microphone_track: LocalTrack::None, - next_publish_id: 0, - muted_by_user: false, - deafened: false, - speaking: false, - _maintain_room, - _maintain_tracks: [_maintain_video_tracks, _maintain_audio_tracks], - }) - } else { - None - }; + // Some(LiveKitRoom { + // room, + // screen_track: LocalTrack::None, + // microphone_track: LocalTrack::None, + // next_publish_id: 0, + // muted_by_user: false, + // deafened: false, + // speaking: false, + // _maintain_room, + // _maintain_tracks: [_maintain_video_tracks, _maintain_audio_tracks], + // }) + // } else { + // None + // }; - let maintain_connection = cx.spawn({ - let client = client.clone(); - move |this, cx| Self::maintain_connection(this, client.clone(), cx).log_err() - }); + // let maintain_connection = cx.spawn({ + // let client = client.clone(); + // move |this, cx| Self::maintain_connection(this, client.clone(), cx).log_err() + // }); - Audio::play_sound(Sound::Joined, cx); + // Audio::play_sound(Sound::Joined, cx); - let (room_update_completed_tx, room_update_completed_rx) = watch::channel(); + // let (room_update_completed_tx, room_update_completed_rx) = watch::channel(); - Self { - id, - channel_id, - live_kit: live_kit_room, - status: RoomStatus::Online, - shared_projects: Default::default(), - joined_projects: Default::default(), - participant_user_ids: Default::default(), - local_participant: Default::default(), - remote_participants: Default::default(), - pending_participants: Default::default(), - pending_call_count: 0, - client_subscriptions: vec![ - client.add_message_handler(cx.weak_handle(), Self::handle_room_updated) - ], - _subscriptions: vec![ - cx.on_release(Self::released), - cx.on_app_quit(Self::app_will_quit), - ], - leave_when_empty: false, - pending_room_update: None, - client, - user_store, - follows_by_leader_id_project_id: Default::default(), - maintain_connection: Some(maintain_connection), - room_update_completed_tx, - room_update_completed_rx, - } + // Self { + // id, + // channel_id, + // // live_kit: live_kit_room, + // status: RoomStatus::Online, + // shared_projects: Default::default(), + // joined_projects: Default::default(), + // participant_user_ids: Default::default(), + // local_participant: Default::default(), + // remote_participants: Default::default(), + // pending_participants: Default::default(), + // pending_call_count: 0, + // client_subscriptions: vec![ + // client.add_message_handler(cx.weak_handle(), Self::handle_room_updated) + // ], + // _subscriptions: vec![ + // 
cx.on_release(Self::released), + // cx.on_app_quit(Self::app_will_quit), + // ], + // leave_when_empty: false, + // pending_room_update: None, + // client, + // user_store, + // follows_by_leader_id_project_id: Default::default(), + // maintain_connection: Some(maintain_connection), + // room_update_completed_tx, + // room_update_completed_rx, + // } } pub(crate) fn create( called_user_id: u64, - initial_project: Option>, + initial_project: Option>, client: Arc, - user_store: Handle, + user_store: Model, cx: &mut AppContext, - ) -> Task>> { + ) -> Task>> { cx.spawn(move |mut cx| async move { let response = client.request(proto::CreateRoom {}).await?; let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; - let room = cx.entity(|cx| { + let room = cx.build_model(|cx| { Self::new( room_proto.id, None, @@ -281,9 +283,9 @@ impl Room { pub(crate) fn join_channel( channel_id: u64, client: Arc, - user_store: Handle, + user_store: Model, cx: &mut AppContext, - ) -> Task>> { + ) -> Task>> { cx.spawn(move |cx| async move { Self::from_join_response( client.request(proto::JoinChannel { channel_id }).await?, @@ -297,9 +299,9 @@ impl Room { pub(crate) fn join( call: &IncomingCall, client: Arc, - user_store: Handle, + user_store: Model, cx: &mut AppContext, - ) -> Task>> { + ) -> Task>> { let id = call.room_id; cx.spawn(move |cx| async move { Self::from_join_response( @@ -341,11 +343,11 @@ impl Room { fn from_join_response( response: proto::JoinRoomResponse, client: Arc, - user_store: Handle, + user_store: Model, mut cx: AsyncAppContext, - ) -> Result> { + ) -> Result> { let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; - let room = cx.entity(|cx| { + let room = cx.build_model(|cx| { Self::new( room_proto.id, response.channel_id, @@ -416,13 +418,13 @@ impl Room { self.pending_participants.clear(); self.participant_user_ids.clear(); self.client_subscriptions.clear(); - self.live_kit.take(); + // self.live_kit.take(); self.pending_room_update.take(); self.maintain_connection.take(); } async fn maintain_connection( - this: WeakHandle, + this: WeakModel, client: Arc, mut cx: AsyncAppContext, ) -> Result<()> { @@ -659,7 +661,7 @@ impl Room { } async fn handle_room_updated( - this: Handle, + this: Model, envelope: TypedEnvelope, _: Arc, mut cx: AsyncAppContext, @@ -792,43 +794,43 @@ impl Room { location, muted: true, speaking: false, - video_tracks: Default::default(), - audio_tracks: Default::default(), + // video_tracks: Default::default(), + // audio_tracks: Default::default(), }, ); Audio::play_sound(Sound::Joined, cx); - if let Some(live_kit) = this.live_kit.as_ref() { - let video_tracks = - live_kit.room.remote_video_tracks(&user.id.to_string()); - let audio_tracks = - live_kit.room.remote_audio_tracks(&user.id.to_string()); - let publications = live_kit - .room - .remote_audio_track_publications(&user.id.to_string()); + // if let Some(live_kit) = this.live_kit.as_ref() { + // let video_tracks = + // live_kit.room.remote_video_tracks(&user.id.to_string()); + // let audio_tracks = + // live_kit.room.remote_audio_tracks(&user.id.to_string()); + // let publications = live_kit + // .room + // .remote_audio_track_publications(&user.id.to_string()); - for track in video_tracks { - this.remote_video_track_updated( - RemoteVideoTrackUpdate::Subscribed(track), - cx, - ) - .log_err(); - } + // for track in video_tracks { + // this.remote_video_track_updated( + // RemoteVideoTrackUpdate::Subscribed(track), + // cx, + // ) + // .log_err(); + // } - for (track, publication) in - 
audio_tracks.iter().zip(publications.iter()) - { - this.remote_audio_track_updated( - RemoteAudioTrackUpdate::Subscribed( - track.clone(), - publication.clone(), - ), - cx, - ) - .log_err(); - } - } + // for (track, publication) in + // audio_tracks.iter().zip(publications.iter()) + // { + // this.remote_audio_track_updated( + // RemoteAudioTrackUpdate::Subscribed( + // track.clone(), + // publication.clone(), + // ), + // cx, + // ) + // .log_err(); + // } + // } } } @@ -916,6 +918,7 @@ impl Room { change: RemoteVideoTrackUpdate, cx: &mut ModelContext, ) -> Result<()> { + todo!(); match change { RemoteVideoTrackUpdate::Subscribed(track) => { let user_id = track.publisher_id().parse()?; @@ -924,12 +927,12 @@ impl Room { .remote_participants .get_mut(&user_id) .ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?; - participant.video_tracks.insert( - track_id.clone(), - Arc::new(RemoteVideoTrack { - live_kit_track: track, - }), - ); + // participant.video_tracks.insert( + // track_id.clone(), + // Arc::new(RemoteVideoTrack { + // live_kit_track: track, + // }), + // ); cx.emit(Event::RemoteVideoTracksChanged { participant_id: participant.peer_id, }); @@ -943,7 +946,7 @@ impl Room { .remote_participants .get_mut(&user_id) .ok_or_else(|| anyhow!("unsubscribed from track by unknown participant"))?; - participant.video_tracks.remove(&track_id); + // participant.video_tracks.remove(&track_id); cx.emit(Event::RemoteVideoTracksChanged { participant_id: participant.peer_id, }); @@ -973,62 +976,65 @@ impl Room { participant.speaking = false; } } - if let Some(id) = self.client.user_id() { - if let Some(room) = &mut self.live_kit { - if let Ok(_) = speaker_ids.binary_search(&id) { - room.speaking = true; - } else { - room.speaking = false; - } - } - } + // todo!() + // if let Some(id) = self.client.user_id() { + // if let Some(room) = &mut self.live_kit { + // if let Ok(_) = speaker_ids.binary_search(&id) { + // room.speaking = true; + // } else { + // room.speaking = false; + // } + // } + // } cx.notify(); } RemoteAudioTrackUpdate::MuteChanged { track_id, muted } => { - let mut found = false; - for participant in &mut self.remote_participants.values_mut() { - for track in participant.audio_tracks.values() { - if track.sid() == track_id { - found = true; - break; - } - } - if found { - participant.muted = muted; - break; - } - } + // todo!() + // let mut found = false; + // for participant in &mut self.remote_participants.values_mut() { + // for track in participant.audio_tracks.values() { + // if track.sid() == track_id { + // found = true; + // break; + // } + // } + // if found { + // participant.muted = muted; + // break; + // } + // } cx.notify(); } RemoteAudioTrackUpdate::Subscribed(track, publication) => { - let user_id = track.publisher_id().parse()?; - let track_id = track.sid().to_string(); - let participant = self - .remote_participants - .get_mut(&user_id) - .ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?; + // todo!() + // let user_id = track.publisher_id().parse()?; + // let track_id = track.sid().to_string(); + // let participant = self + // .remote_participants + // .get_mut(&user_id) + // .ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?; + // // participant.audio_tracks.insert(track_id.clone(), track); + // participant.muted = publication.is_muted(); - participant.audio_tracks.insert(track_id.clone(), track); - participant.muted = publication.is_muted(); - - cx.emit(Event::RemoteAudioTracksChanged { - participant_id: 
participant.peer_id, - }); + // cx.emit(Event::RemoteAudioTracksChanged { + // participant_id: participant.peer_id, + // }); } RemoteAudioTrackUpdate::Unsubscribed { publisher_id, track_id, } => { - let user_id = publisher_id.parse()?; - let participant = self - .remote_participants - .get_mut(&user_id) - .ok_or_else(|| anyhow!("unsubscribed from track by unknown participant"))?; - participant.audio_tracks.remove(&track_id); - cx.emit(Event::RemoteAudioTracksChanged { - participant_id: participant.peer_id, - }); + // todo!() + // let user_id = publisher_id.parse()?; + // let participant = self + // .remote_participants + // .get_mut(&user_id) + // .ok_or_else(|| anyhow!("unsubscribed from track by unknown participant"))?; + // participant.audio_tracks.remove(&track_id); + // cx.emit(Event::RemoteAudioTracksChanged { + // participant_id: participant.peer_id, + // }); } } @@ -1095,7 +1101,7 @@ impl Room { language_registry: Arc, fs: Arc, cx: &mut ModelContext, - ) -> Task>> { + ) -> Task>> { let client = self.client.clone(); let user_store = self.user_store.clone(); cx.emit(Event::RemoteProjectJoined { project_id: id }); @@ -1119,7 +1125,7 @@ impl Room { pub(crate) fn share_project( &mut self, - project: Handle, + project: Model, cx: &mut ModelContext, ) -> Task> { if let Some(project_id) = project.read(cx).remote_id() { @@ -1155,7 +1161,7 @@ impl Room { pub(crate) fn unshare_project( &mut self, - project: Handle, + project: Model, cx: &mut ModelContext, ) -> Result<()> { let project_id = match project.read(cx).remote_id() { @@ -1169,7 +1175,7 @@ impl Room { pub(crate) fn set_location( &mut self, - project: Option<&Handle>, + project: Option<&Model>, cx: &mut ModelContext, ) -> Task> { if self.status.is_offline() { @@ -1209,269 +1215,278 @@ impl Room { } pub fn is_screen_sharing(&self) -> bool { - self.live_kit.as_ref().map_or(false, |live_kit| { - !matches!(live_kit.screen_track, LocalTrack::None) - }) + todo!() + // self.live_kit.as_ref().map_or(false, |live_kit| { + // !matches!(live_kit.screen_track, LocalTrack::None) + // }) } pub fn is_sharing_mic(&self) -> bool { - self.live_kit.as_ref().map_or(false, |live_kit| { - !matches!(live_kit.microphone_track, LocalTrack::None) - }) + todo!() + // self.live_kit.as_ref().map_or(false, |live_kit| { + // !matches!(live_kit.microphone_track, LocalTrack::None) + // }) } pub fn is_muted(&self, cx: &AppContext) -> bool { - self.live_kit - .as_ref() - .and_then(|live_kit| match &live_kit.microphone_track { - LocalTrack::None => Some(Self::mute_on_join(cx)), - LocalTrack::Pending { muted, .. } => Some(*muted), - LocalTrack::Published { muted, .. } => Some(*muted), - }) - .unwrap_or(false) + todo!() + // self.live_kit + // .as_ref() + // .and_then(|live_kit| match &live_kit.microphone_track { + // LocalTrack::None => Some(Self::mute_on_join(cx)), + // LocalTrack::Pending { muted, .. } => Some(*muted), + // LocalTrack::Published { muted, .. 
} => Some(*muted), + // }) + // .unwrap_or(false) } pub fn is_speaking(&self) -> bool { - self.live_kit - .as_ref() - .map_or(false, |live_kit| live_kit.speaking) + todo!() + // self.live_kit + // .as_ref() + // .map_or(false, |live_kit| live_kit.speaking) } pub fn is_deafened(&self) -> Option { - self.live_kit.as_ref().map(|live_kit| live_kit.deafened) + // self.live_kit.as_ref().map(|live_kit| live_kit.deafened) + todo!() } #[track_caller] pub fn share_microphone(&mut self, cx: &mut ModelContext) -> Task> { - if self.status.is_offline() { - return Task::ready(Err(anyhow!("room is offline"))); - } else if self.is_sharing_mic() { - return Task::ready(Err(anyhow!("microphone was already shared"))); - } + todo!() + // if self.status.is_offline() { + // return Task::ready(Err(anyhow!("room is offline"))); + // } else if self.is_sharing_mic() { + // return Task::ready(Err(anyhow!("microphone was already shared"))); + // } - let publish_id = if let Some(live_kit) = self.live_kit.as_mut() { - let publish_id = post_inc(&mut live_kit.next_publish_id); - live_kit.microphone_track = LocalTrack::Pending { - publish_id, - muted: false, - }; - cx.notify(); - publish_id - } else { - return Task::ready(Err(anyhow!("live-kit was not initialized"))); - }; + // let publish_id = if let Some(live_kit) = self.live_kit.as_mut() { + // let publish_id = post_inc(&mut live_kit.next_publish_id); + // live_kit.microphone_track = LocalTrack::Pending { + // publish_id, + // muted: false, + // }; + // cx.notify(); + // publish_id + // } else { + // return Task::ready(Err(anyhow!("live-kit was not initialized"))); + // }; - cx.spawn(move |this, mut cx| async move { - let publish_track = async { - let track = LocalAudioTrack::create(); - this.upgrade() - .ok_or_else(|| anyhow!("room was dropped"))? - .update(&mut cx, |this, _| { - this.live_kit - .as_ref() - .map(|live_kit| live_kit.room.publish_audio_track(&track)) - })? - .ok_or_else(|| anyhow!("live-kit was not initialized"))? - .await - }; + // cx.spawn(move |this, mut cx| async move { + // let publish_track = async { + // let track = LocalAudioTrack::create(); + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? + // .update(&mut cx, |this, _| { + // this.live_kit + // .as_ref() + // .map(|live_kit| live_kit.room.publish_audio_track(track)) + // })? + // .ok_or_else(|| anyhow!("live-kit was not initialized"))? + // .await + // }; - let publication = publish_track.await; - this.upgrade() - .ok_or_else(|| anyhow!("room was dropped"))? - .update(&mut cx, |this, cx| { - let live_kit = this - .live_kit - .as_mut() - .ok_or_else(|| anyhow!("live-kit was not initialized"))?; + // let publication = publish_track.await; + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? 
+ // .update(&mut cx, |this, cx| { + // let live_kit = this + // .live_kit + // .as_mut() + // .ok_or_else(|| anyhow!("live-kit was not initialized"))?; - let (canceled, muted) = if let LocalTrack::Pending { - publish_id: cur_publish_id, - muted, - } = &live_kit.microphone_track - { - (*cur_publish_id != publish_id, *muted) - } else { - (true, false) - }; + // let (canceled, muted) = if let LocalTrack::Pending { + // publish_id: cur_publish_id, + // muted, + // } = &live_kit.microphone_track + // { + // (*cur_publish_id != publish_id, *muted) + // } else { + // (true, false) + // }; - match publication { - Ok(publication) => { - if canceled { - live_kit.room.unpublish_track(publication); - } else { - if muted { - cx.executor().spawn(publication.set_mute(muted)).detach(); - } - live_kit.microphone_track = LocalTrack::Published { - track_publication: publication, - muted, - }; - cx.notify(); - } - Ok(()) - } - Err(error) => { - if canceled { - Ok(()) - } else { - live_kit.microphone_track = LocalTrack::None; - cx.notify(); - Err(error) - } - } - } - })? - }) + // match publication { + // Ok(publication) => { + // if canceled { + // live_kit.room.unpublish_track(publication); + // } else { + // if muted { + // cx.executor().spawn(publication.set_mute(muted)).detach(); + // } + // live_kit.microphone_track = LocalTrack::Published { + // track_publication: publication, + // muted, + // }; + // cx.notify(); + // } + // Ok(()) + // } + // Err(error) => { + // if canceled { + // Ok(()) + // } else { + // live_kit.microphone_track = LocalTrack::None; + // cx.notify(); + // Err(error) + // } + // } + // } + // })? + // }) } pub fn share_screen(&mut self, cx: &mut ModelContext) -> Task> { - if self.status.is_offline() { - return Task::ready(Err(anyhow!("room is offline"))); - } else if self.is_screen_sharing() { - return Task::ready(Err(anyhow!("screen was already shared"))); - } + todo!() + // if self.status.is_offline() { + // return Task::ready(Err(anyhow!("room is offline"))); + // } else if self.is_screen_sharing() { + // return Task::ready(Err(anyhow!("screen was already shared"))); + // } - let (displays, publish_id) = if let Some(live_kit) = self.live_kit.as_mut() { - let publish_id = post_inc(&mut live_kit.next_publish_id); - live_kit.screen_track = LocalTrack::Pending { - publish_id, - muted: false, - }; - cx.notify(); - (live_kit.room.display_sources(), publish_id) - } else { - return Task::ready(Err(anyhow!("live-kit was not initialized"))); - }; + // let (displays, publish_id) = if let Some(live_kit) = self.live_kit.as_mut() { + // let publish_id = post_inc(&mut live_kit.next_publish_id); + // live_kit.screen_track = LocalTrack::Pending { + // publish_id, + // muted: false, + // }; + // cx.notify(); + // (live_kit.room.display_sources(), publish_id) + // } else { + // return Task::ready(Err(anyhow!("live-kit was not initialized"))); + // }; - cx.spawn(move |this, mut cx| async move { - let publish_track = async { - let displays = displays.await?; - let display = displays - .first() - .ok_or_else(|| anyhow!("no display found"))?; - let track = LocalVideoTrack::screen_share_for_display(&display); - this.upgrade() - .ok_or_else(|| anyhow!("room was dropped"))? - .update(&mut cx, |this, _| { - this.live_kit - .as_ref() - .map(|live_kit| live_kit.room.publish_video_track(&track)) - })? - .ok_or_else(|| anyhow!("live-kit was not initialized"))? 
- .await - }; + // cx.spawn(move |this, mut cx| async move { + // let publish_track = async { + // let displays = displays.await?; + // let display = displays + // .first() + // .ok_or_else(|| anyhow!("no display found"))?; + // let track = LocalVideoTrack::screen_share_for_display(&display); + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? + // .update(&mut cx, |this, _| { + // this.live_kit + // .as_ref() + // .map(|live_kit| live_kit.room.publish_video_track(track)) + // })? + // .ok_or_else(|| anyhow!("live-kit was not initialized"))? + // .await + // }; - let publication = publish_track.await; - this.upgrade() - .ok_or_else(|| anyhow!("room was dropped"))? - .update(&mut cx, |this, cx| { - let live_kit = this - .live_kit - .as_mut() - .ok_or_else(|| anyhow!("live-kit was not initialized"))?; + // let publication = publish_track.await; + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? + // .update(&mut cx, |this, cx| { + // let live_kit = this + // .live_kit + // .as_mut() + // .ok_or_else(|| anyhow!("live-kit was not initialized"))?; - let (canceled, muted) = if let LocalTrack::Pending { - publish_id: cur_publish_id, - muted, - } = &live_kit.screen_track - { - (*cur_publish_id != publish_id, *muted) - } else { - (true, false) - }; + // let (canceled, muted) = if let LocalTrack::Pending { + // publish_id: cur_publish_id, + // muted, + // } = &live_kit.screen_track + // { + // (*cur_publish_id != publish_id, *muted) + // } else { + // (true, false) + // }; - match publication { - Ok(publication) => { - if canceled { - live_kit.room.unpublish_track(publication); - } else { - if muted { - cx.executor().spawn(publication.set_mute(muted)).detach(); - } - live_kit.screen_track = LocalTrack::Published { - track_publication: publication, - muted, - }; - cx.notify(); - } + // match publication { + // Ok(publication) => { + // if canceled { + // live_kit.room.unpublish_track(publication); + // } else { + // if muted { + // cx.executor().spawn(publication.set_mute(muted)).detach(); + // } + // live_kit.screen_track = LocalTrack::Published { + // track_publication: publication, + // muted, + // }; + // cx.notify(); + // } - Audio::play_sound(Sound::StartScreenshare, cx); + // Audio::play_sound(Sound::StartScreenshare, cx); - Ok(()) - } - Err(error) => { - if canceled { - Ok(()) - } else { - live_kit.screen_track = LocalTrack::None; - cx.notify(); - Err(error) - } - } - } - })? - }) + // Ok(()) + // } + // Err(error) => { + // if canceled { + // Ok(()) + // } else { + // live_kit.screen_track = LocalTrack::None; + // cx.notify(); + // Err(error) + // } + // } + // } + // })? 
+ // }) } pub fn toggle_mute(&mut self, cx: &mut ModelContext) -> Result>> { - let should_mute = !self.is_muted(cx); - if let Some(live_kit) = self.live_kit.as_mut() { - if matches!(live_kit.microphone_track, LocalTrack::None) { - return Ok(self.share_microphone(cx)); - } + todo!() + // let should_mute = !self.is_muted(cx); + // if let Some(live_kit) = self.live_kit.as_mut() { + // if matches!(live_kit.microphone_track, LocalTrack::None) { + // return Ok(self.share_microphone(cx)); + // } - let (ret_task, old_muted) = live_kit.set_mute(should_mute, cx)?; - live_kit.muted_by_user = should_mute; + // let (ret_task, old_muted) = live_kit.set_mute(should_mute, cx)?; + // live_kit.muted_by_user = should_mute; - if old_muted == true && live_kit.deafened == true { - if let Some(task) = self.toggle_deafen(cx).ok() { - task.detach(); - } - } + // if old_muted == true && live_kit.deafened == true { + // if let Some(task) = self.toggle_deafen(cx).ok() { + // task.detach(); + // } + // } - Ok(ret_task) - } else { - Err(anyhow!("LiveKit not started")) - } + // Ok(ret_task) + // } else { + // Err(anyhow!("LiveKit not started")) + // } } pub fn toggle_deafen(&mut self, cx: &mut ModelContext) -> Result>> { - if let Some(live_kit) = self.live_kit.as_mut() { - (*live_kit).deafened = !live_kit.deafened; + todo!() + // if let Some(live_kit) = self.live_kit.as_mut() { + // (*live_kit).deafened = !live_kit.deafened; - let mut tasks = Vec::with_capacity(self.remote_participants.len()); - // Context notification is sent within set_mute itself. - let mut mute_task = None; - // When deafening, mute user's mic as well. - // When undeafening, unmute user's mic unless it was manually muted prior to deafening. - if live_kit.deafened || !live_kit.muted_by_user { - mute_task = Some(live_kit.set_mute(live_kit.deafened, cx)?.0); - }; - for participant in self.remote_participants.values() { - for track in live_kit - .room - .remote_audio_track_publications(&participant.user.id.to_string()) - { - let deafened = live_kit.deafened; - tasks.push( - cx.executor() - .spawn_on_main(move || track.set_enabled(!deafened)), - ); - } - } + // let mut tasks = Vec::with_capacity(self.remote_participants.len()); + // // Context notification is sent within set_mute itself. + // let mut mute_task = None; + // // When deafening, mute user's mic as well. + // // When undeafening, unmute user's mic unless it was manually muted prior to deafening. 
+ // if live_kit.deafened || !live_kit.muted_by_user { + // mute_task = Some(live_kit.set_mute(live_kit.deafened, cx)?.0); + // }; + // for participant in self.remote_participants.values() { + // for track in live_kit + // .room + // .remote_audio_track_publications(&participant.user.id.to_string()) + // { + // let deafened = live_kit.deafened; + // tasks.push( + // cx.executor() + // .spawn_on_main(move || track.set_enabled(!deafened)), + // ); + // } + // } - Ok(cx.executor().spawn_on_main(|| async { - if let Some(mute_task) = mute_task { - mute_task.await?; - } - for task in tasks { - task.await?; - } - Ok(()) - })) - } else { - Err(anyhow!("LiveKit not started")) - } + // Ok(cx.executor().spawn_on_main(|| async { + // if let Some(mute_task) = mute_task { + // mute_task.await?; + // } + // for task in tasks { + // task.await?; + // } + // Ok(()) + // })) + // } else { + // Err(anyhow!("LiveKit not started")) + // } } pub fn unshare_screen(&mut self, cx: &mut ModelContext) -> Result<()> { @@ -1479,35 +1494,37 @@ impl Room { return Err(anyhow!("room is offline")); } - let live_kit = self - .live_kit - .as_mut() - .ok_or_else(|| anyhow!("live-kit was not initialized"))?; - match mem::take(&mut live_kit.screen_track) { - LocalTrack::None => Err(anyhow!("screen was not shared")), - LocalTrack::Pending { .. } => { - cx.notify(); - Ok(()) - } - LocalTrack::Published { - track_publication, .. - } => { - live_kit.room.unpublish_track(track_publication); - cx.notify(); + todo!() + // let live_kit = self + // .live_kit + // .as_mut() + // .ok_or_else(|| anyhow!("live-kit was not initialized"))?; + // match mem::take(&mut live_kit.screen_track) { + // LocalTrack::None => Err(anyhow!("screen was not shared")), + // LocalTrack::Pending { .. } => { + // cx.notify(); + // Ok(()) + // } + // LocalTrack::Published { + // track_publication, .. 
+ // } => { + // live_kit.room.unpublish_track(track_publication); + // cx.notify(); - Audio::play_sound(Sound::StopScreenshare, cx); - Ok(()) - } - } + // Audio::play_sound(Sound::StopScreenshare, cx); + // Ok(()) + // } + // } } #[cfg(any(test, feature = "test-support"))] pub fn set_display_sources(&self, sources: Vec) { - self.live_kit - .as_ref() - .unwrap() - .room - .set_display_sources(sources); + todo!() + // self.live_kit + // .as_ref() + // .unwrap() + // .room + // .set_display_sources(sources); } } diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index d31d4b3c8c..d0a32e16ff 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -7,10 +7,11 @@ use gpui::{AppContext, ModelHandle}; use std::sync::Arc; pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; -pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId}; -pub use channel_store::{ - Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore, +pub use channel_chat::{ + mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, + MessageParams, }; +pub use channel_store::{Channel, ChannelEvent, ChannelId, ChannelMembership, ChannelStore}; #[cfg(test)] mod channel_store_tests; diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index ab7ea78ac1..9089973d32 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -1,4 +1,4 @@ -use crate::Channel; +use crate::{Channel, ChannelId, ChannelStore}; use anyhow::Result; use client::{Client, Collaborator, UserStore}; use collections::HashMap; @@ -19,10 +19,11 @@ pub(crate) fn init(client: &Arc) { } pub struct ChannelBuffer { - pub(crate) channel: Arc, + pub channel_id: ChannelId, connected: bool, collaborators: HashMap, user_store: ModelHandle, + channel_store: ModelHandle, buffer: ModelHandle, buffer_epoch: u64, client: Arc, @@ -34,6 +35,7 @@ pub enum ChannelBufferEvent { CollaboratorsChanged, Disconnected, BufferEdited, + ChannelChanged, } impl Entity for ChannelBuffer { @@ -46,7 +48,7 @@ impl Entity for ChannelBuffer { } self.client .send(proto::LeaveChannelBuffer { - channel_id: self.channel.id, + channel_id: self.channel_id, }) .log_err(); } @@ -58,6 +60,7 @@ impl ChannelBuffer { channel: Arc, client: Arc, user_store: ModelHandle, + channel_store: ModelHandle, mut cx: AsyncAppContext, ) -> Result> { let response = client @@ -90,9 +93,10 @@ impl ChannelBuffer { connected: true, collaborators: Default::default(), acknowledge_task: None, - channel, + channel_id: channel.id, subscription: Some(subscription.set_model(&cx.handle(), &mut cx.to_async())), user_store, + channel_store, }; this.replace_collaborators(response.collaborators, cx); this @@ -179,7 +183,7 @@ impl ChannelBuffer { let operation = language::proto::serialize_operation(operation); self.client .send(proto::UpdateChannelBuffer { - channel_id: self.channel.id, + channel_id: self.channel_id, operations: vec![operation], }) .log_err(); @@ -223,12 +227,15 @@ impl ChannelBuffer { &self.collaborators } - pub fn channel(&self) -> Arc { - self.channel.clone() + pub fn channel(&self, cx: &AppContext) -> Option> { + self.channel_store + .read(cx) + .channel_for_id(self.channel_id) + .cloned() } pub(crate) fn disconnect(&mut self, cx: &mut ModelContext) { - log::info!("channel buffer {} disconnected", self.channel.id); + log::info!("channel buffer {} disconnected", 
self.channel_id); if self.connected { self.connected = false; self.subscription.take(); @@ -237,6 +244,11 @@ impl ChannelBuffer { } } + pub(crate) fn channel_changed(&mut self, cx: &mut ModelContext) { + cx.emit(ChannelBufferEvent::ChannelChanged); + cx.notify() + } + pub fn is_connected(&self) -> bool { self.connected } diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 734182886b..ef11d96424 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -3,19 +3,25 @@ use anyhow::{anyhow, Result}; use client::{ proto, user::{User, UserStore}, - Client, Subscription, TypedEnvelope, + Client, Subscription, TypedEnvelope, UserId, }; use futures::lock::Mutex; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; use rand::prelude::*; -use std::{collections::HashSet, mem, ops::Range, sync::Arc}; +use std::{ + collections::HashSet, + mem, + ops::{ControlFlow, Range}, + sync::Arc, +}; use sum_tree::{Bias, SumTree}; use time::OffsetDateTime; use util::{post_inc, ResultExt as _, TryFutureExt}; pub struct ChannelChat { - channel: Arc, + pub channel_id: ChannelId, messages: SumTree, + acknowledged_message_ids: HashSet, channel_store: ModelHandle, loaded_all_messages: bool, last_acknowledged_id: Option, @@ -27,6 +33,12 @@ pub struct ChannelChat { _subscription: Subscription, } +#[derive(Debug, PartialEq, Eq)] +pub struct MessageParams { + pub text: String, + pub mentions: Vec<(Range, UserId)>, +} + #[derive(Clone, Debug)] pub struct ChannelMessage { pub id: ChannelMessageId, @@ -34,6 +46,7 @@ pub struct ChannelMessage { pub timestamp: OffsetDateTime, pub sender: Arc, pub nonce: u128, + pub mentions: Vec<(Range, UserId)>, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -74,7 +87,7 @@ impl Entity for ChannelChat { fn release(&mut self, _: &mut AppContext) { self.rpc .send(proto::LeaveChannelChat { - channel_id: self.channel.id, + channel_id: self.channel_id, }) .log_err(); } @@ -99,12 +112,13 @@ impl ChannelChat { Ok(cx.add_model(|cx| { let mut this = Self { - channel, + channel_id: channel.id, user_store, channel_store, rpc: client, outgoing_messages_lock: Default::default(), messages: Default::default(), + acknowledged_message_ids: Default::default(), loaded_all_messages, next_pending_message_id: 0, last_acknowledged_id: None, @@ -116,16 +130,23 @@ impl ChannelChat { })) } - pub fn channel(&self) -> &Arc { - &self.channel + pub fn channel(&self, cx: &AppContext) -> Option> { + self.channel_store + .read(cx) + .channel_for_id(self.channel_id) + .cloned() + } + + pub fn client(&self) -> &Arc { + &self.rpc } pub fn send_message( &mut self, - body: String, + message: MessageParams, cx: &mut ModelContext, - ) -> Result>> { - if body.is_empty() { + ) -> Result>> { + if message.text.is_empty() { Err(anyhow!("message body can't be empty"))?; } @@ -135,16 +156,17 @@ impl ChannelChat { .current_user() .ok_or_else(|| anyhow!("current_user is not present"))?; - let channel_id = self.channel.id; + let channel_id = self.channel_id; let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); let nonce = self.rng.gen(); self.insert_messages( SumTree::from_item( ChannelMessage { id: pending_id, - body: body.clone(), + body: message.text.clone(), sender: current_user, timestamp: OffsetDateTime::now_utc(), + mentions: message.mentions.clone(), nonce, }, &(), @@ -158,27 +180,25 @@ impl ChannelChat { let outgoing_message_guard = outgoing_messages_lock.lock().await; let 
request = rpc.request(proto::SendChannelMessage { channel_id, - body, + body: message.text, nonce: Some(nonce.into()), + mentions: mentions_to_proto(&message.mentions), }); let response = request.await?; drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; + let response = response.message.ok_or_else(|| anyhow!("invalid message"))?; + let id = response.id; + let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) + Ok(id) }) })) } pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext) -> Task> { let response = self.rpc.request(proto::RemoveChannelMessage { - channel_id: self.channel.id, + channel_id: self.channel_id, message_id: id, }); cx.spawn(|this, mut cx| async move { @@ -191,41 +211,76 @@ impl ChannelChat { }) } - pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { - if !self.loaded_all_messages { - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.channel.id; - if let Some(before_message_id) = - self.messages.first().and_then(|message| match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }) - { - cx.spawn(|this, mut cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id, - before_message_id, - }) - .await?; - let loaded_all_messages = response.done; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(messages, cx); - }); - anyhow::Ok(()) + pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> Option>> { + if self.loaded_all_messages { + return None; + } + + let rpc = self.rpc.clone(); + let user_store = self.user_store.clone(); + let channel_id = self.channel_id; + let before_message_id = self.first_loaded_message_id()?; + Some(cx.spawn(|this, mut cx| { + async move { + let response = rpc + .request(proto::GetChannelMessages { + channel_id, + before_message_id, + }) + .await?; + let loaded_all_messages = response.done; + let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; + this.update(&mut cx, |this, cx| { + this.loaded_all_messages = loaded_all_messages; + this.insert_messages(messages, cx); + }); + anyhow::Ok(()) + } + .log_err() + })) + } + + pub fn first_loaded_message_id(&mut self) -> Option { + self.messages.first().and_then(|message| match message.id { + ChannelMessageId::Saved(id) => Some(id), + ChannelMessageId::Pending(_) => None, + }) + } + + /// Load all of the chat messages since a certain message id. + /// + /// For now, we always maintain a suffix of the channel's messages. 
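Since `load_more_messages` now returns `Option<Task<Option<()>>>` instead of a `bool`, a caller can await one page of older history before deciding whether to fetch another. A minimal sketch of that call pattern, mirroring the usage inside `load_history_since_message`; the helper name and simplified signature are assumptions, not code from this diff:

```rust
// Hypothetical helper: page in one batch of older history and wait for it.
async fn load_one_older_page(
    chat: ModelHandle<ChannelChat>,
    mut cx: AsyncAppContext,
) -> Option<()> {
    // `None` here means all reachable history is already loaded locally.
    let task = chat.update(&mut cx, |chat, cx| chat.load_more_messages(cx))?;
    // The spawned request logs its own errors, so awaiting yields Option<()>.
    task.await
}
```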
+ pub async fn load_history_since_message( + chat: ModelHandle, + message_id: u64, + mut cx: AsyncAppContext, + ) -> Option { + loop { + let step = chat.update(&mut cx, |chat, cx| { + if let Some(first_id) = chat.first_loaded_message_id() { + if first_id <= message_id { + let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(); + let message_id = ChannelMessageId::Saved(message_id); + cursor.seek(&message_id, Bias::Left, &()); + return ControlFlow::Break( + if cursor + .item() + .map_or(false, |message| message.id == message_id) + { + Some(cursor.start().1 .0) + } else { + None + }, + ); } - .log_err() - }) - .detach(); - return true; + } + ControlFlow::Continue(chat.load_more_messages(cx)) + }); + match step { + ControlFlow::Break(ix) => return ix, + ControlFlow::Continue(task) => task?.await?, } } - false } pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext) { @@ -236,13 +291,13 @@ impl ChannelChat { { self.rpc .send(proto::AckChannelMessage { - channel_id: self.channel.id, + channel_id: self.channel_id, message_id: latest_message_id, }) .ok(); self.last_acknowledged_id = Some(latest_message_id); self.channel_store.update(cx, |store, cx| { - store.acknowledge_message_id(self.channel.id, latest_message_id, cx); + store.acknowledge_message_id(self.channel_id, latest_message_id, cx); }); } } @@ -251,7 +306,7 @@ impl ChannelChat { pub fn rejoin(&mut self, cx: &mut ModelContext) { let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); - let channel_id = self.channel.id; + let channel_id = self.channel_id; cx.spawn(|this, mut cx| { async move { let response = rpc.request(proto::JoinChannelChat { channel_id }).await?; @@ -284,6 +339,7 @@ impl ChannelChat { let request = rpc.request(proto::SendChannelMessage { channel_id, body: pending_message.body, + mentions: mentions_to_proto(&pending_message.mentions), nonce: Some(pending_message.nonce.into()), }); let response = request.await?; @@ -319,6 +375,17 @@ impl ChannelChat { cursor.item().unwrap() } + pub fn acknowledge_message(&mut self, id: u64) { + if self.acknowledged_message_ids.insert(id) { + self.rpc + .send(proto::AckChannelMessage { + channel_id: self.channel_id, + message_id: id, + }) + .ok(); + } + } + pub fn messages_in_range(&self, range: Range) -> impl Iterator { let mut cursor = self.messages.cursor::(); cursor.seek(&Count(range.start), Bias::Right, &()); @@ -348,7 +415,7 @@ impl ChannelChat { this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); cx.emit(ChannelChatEvent::NewMessage { - channel_id: this.channel.id, + channel_id: this.channel_id, message_id, }) }); @@ -451,22 +518,7 @@ async fn messages_from_proto( user_store: &ModelHandle, cx: &mut AsyncAppContext, ) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - }) - .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } + let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?; let mut result = SumTree::new(); result.extend(messages, &()); Ok(result) @@ -486,6 +538,14 @@ impl ChannelMessage { Ok(ChannelMessage { id: ChannelMessageId::Saved(message.id), body: message.body, + mentions: message + .mentions + .into_iter() + .filter_map(|mention| { + let range = 
mention.range?; + Some((range.start as usize..range.end as usize, mention.user_id)) + }) + .collect(), timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, sender, nonce: message @@ -498,6 +558,43 @@ impl ChannelMessage { pub fn is_pending(&self) -> bool { matches!(self.id, ChannelMessageId::Pending(_)) } + + pub async fn from_proto_vec( + proto_messages: Vec, + user_store: &ModelHandle, + cx: &mut AsyncAppContext, + ) -> Result> { + let unique_user_ids = proto_messages + .iter() + .map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + user_store + .update(cx, |user_store, cx| { + user_store.get_users(unique_user_ids, cx) + }) + .await?; + + let mut messages = Vec::with_capacity(proto_messages.len()); + for message in proto_messages { + messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); + } + Ok(messages) + } +} + +pub fn mentions_to_proto(mentions: &[(Range, UserId)]) -> Vec { + mentions + .iter() + .map(|(range, user_id)| proto::ChatMention { + range: Some(proto::Range { + start: range.start as u64, + end: range.end as u64, + }), + user_id: *user_id as u64, + }) + .collect() } impl sum_tree::Item for ChannelMessage { @@ -538,3 +635,12 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { self.0 += summary.count; } } + +impl<'a> From<&'a str> for MessageParams { + fn from(value: &'a str) -> Self { + Self { + text: value.into(), + mentions: Vec::new(), + } + } +} diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index bceb2c094d..efa05d51a9 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,6 +1,6 @@ mod channel_index; -use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat}; +use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage}; use anyhow::{anyhow, Result}; use channel_index::ChannelIndex; use client::{Client, Subscription, User, UserId, UserStore}; @@ -9,11 +9,10 @@ use db::RELEASE_CHANNEL; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{ - proto::{self, ChannelEdge, ChannelPermission}, + proto::{self, ChannelVisibility}, TypedEnvelope, }; -use serde_derive::{Deserialize, Serialize}; -use std::{borrow::Cow, hash::Hash, mem, ops::Deref, sync::Arc, time::Duration}; +use std::{mem, sync::Arc, time::Duration}; use util::ResultExt; pub fn init(client: &Arc, user_store: ModelHandle, cx: &mut AppContext) { @@ -27,10 +26,9 @@ pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); pub type ChannelId = u64; pub struct ChannelStore { - channel_index: ChannelIndex, + pub channel_index: ChannelIndex, channel_invitations: Vec>, channel_participants: HashMap>>, - channels_with_admin_privileges: HashSet, outgoing_invites: HashSet<(ChannelId, UserId)>, update_channels_tx: mpsc::UnboundedSender, opened_buffers: HashMap>, @@ -43,14 +41,15 @@ pub struct ChannelStore { _update_channels: Task<()>, } -pub type ChannelData = (Channel, ChannelPath); - #[derive(Clone, Debug, PartialEq)] pub struct Channel { pub id: ChannelId, pub name: String, + pub visibility: proto::ChannelVisibility, + pub role: proto::ChannelRole, pub unseen_note_version: Option<(u64, clock::Global)>, pub unseen_message_id: Option, + pub parent_path: Vec, } impl Channel { @@ -71,15 +70,41 @@ impl Channel { slug.trim_matches(|c| c == '-').to_string() } -} -#[derive(Debug, PartialEq, Eq, 
PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] -pub struct ChannelPath(Arc<[ChannelId]>); + pub fn can_edit_notes(&self) -> bool { + self.role == proto::ChannelRole::Member || self.role == proto::ChannelRole::Admin + } +} pub struct ChannelMembership { pub user: Arc, pub kind: proto::channel_member::Kind, - pub admin: bool, + pub role: proto::ChannelRole, +} +impl ChannelMembership { + pub fn sort_key(&self) -> MembershipSortKey { + MembershipSortKey { + role_order: match self.role { + proto::ChannelRole::Admin => 0, + proto::ChannelRole::Member => 1, + proto::ChannelRole::Banned => 2, + proto::ChannelRole::Guest => 3, + }, + kind_order: match self.kind { + proto::channel_member::Kind::Member => 0, + proto::channel_member::Kind::AncestorMember => 1, + proto::channel_member::Kind::Invitee => 2, + }, + username_order: self.user.github_login.as_str(), + } + } +} + +#[derive(PartialOrd, Ord, PartialEq, Eq)] +pub struct MembershipSortKey<'a> { + role_order: u8, + kind_order: u8, + username_order: &'a str, } pub enum ChannelEvent { @@ -127,9 +152,6 @@ impl ChannelStore { this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx)); } } - if status.is_connected() { - } else { - } } Some(()) }); @@ -138,7 +160,6 @@ impl ChannelStore { channel_invitations: Vec::default(), channel_index: ChannelIndex::default(), channel_participants: Default::default(), - channels_with_admin_privileges: Default::default(), outgoing_invites: Default::default(), opened_buffers: Default::default(), opened_chats: Default::default(), @@ -167,16 +188,6 @@ impl ChannelStore { self.client.clone() } - pub fn has_children(&self, channel_id: ChannelId) -> bool { - self.channel_index.iter().any(|path| { - if let Some(ix) = path.iter().position(|id| *id == channel_id) { - path.len() > ix + 1 - } else { - false - } - }) - } - /// Returns the number of unique channels in the store pub fn channel_count(&self) -> usize { self.channel_index.by_id().len() @@ -196,26 +207,31 @@ impl ChannelStore { } /// Iterate over all entries in the channel DAG - pub fn channel_dag_entries(&self) -> impl '_ + Iterator)> { - self.channel_index.iter().map(move |path| { - let id = path.last().unwrap(); - let channel = self.channel_for_id(*id).unwrap(); - (path.len() - 1, channel) - }) + pub fn ordered_channels(&self) -> impl '_ + Iterator)> { + self.channel_index + .ordered_channels() + .iter() + .filter_map(move |id| { + let channel = self.channel_index.by_id().get(id)?; + Some((channel.parent_path.len(), channel)) + }) } - pub fn channel_dag_entry_at(&self, ix: usize) -> Option<(&Arc, &ChannelPath)> { - let path = self.channel_index.get(ix)?; - let id = path.last().unwrap(); - let channel = self.channel_for_id(*id).unwrap(); - - Some((channel, path)) + pub fn channel_at_index(&self, ix: usize) -> Option<&Arc> { + let channel_id = self.channel_index.ordered_channels().get(ix)?; + self.channel_index.by_id().get(channel_id) } pub fn channel_at(&self, ix: usize) -> Option<&Arc> { self.channel_index.by_id().values().nth(ix) } + pub fn has_channel_invitation(&self, channel_id: ChannelId) -> bool { + self.channel_invitations + .iter() + .any(|channel| channel.id == channel_id) + } + pub fn channel_invitations(&self) -> &[Arc] { &self.channel_invitations } @@ -240,14 +256,42 @@ impl ChannelStore { ) -> Task>> { let client = self.client.clone(); let user_store = self.user_store.clone(); + let channel_store = cx.handle(); self.open_channel_resource( channel_id, |this| &mut this.opened_buffers, - |channel, cx| ChannelBuffer::new(channel, client, 
user_store, cx), + |channel, cx| ChannelBuffer::new(channel, client, user_store, channel_store, cx), cx, ) } + pub fn fetch_channel_messages( + &self, + message_ids: Vec, + cx: &mut ModelContext, + ) -> Task>> { + let request = if message_ids.is_empty() { + None + } else { + Some( + self.client + .request(proto::GetChannelMessagesById { message_ids }), + ) + }; + cx.spawn_weak(|this, mut cx| async move { + if let Some(request) = request { + let response = request.await?; + let this = this + .upgrade(&cx) + .ok_or_else(|| anyhow!("channel store dropped"))?; + let user_store = this.read_with(&cx, |this, _| this.user_store.clone()); + ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await + } else { + Ok(Vec::new()) + } + }) + } + pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option { self.channel_index .by_id() @@ -393,16 +437,11 @@ impl ChannelStore { .spawn(async move { task.await.map_err(|error| anyhow!("{}", error)) }) } - pub fn is_user_admin(&self, channel_id: ChannelId) -> bool { - self.channel_index.iter().any(|path| { - if let Some(ix) = path.iter().position(|id| *id == channel_id) { - path[..=ix] - .iter() - .any(|id| self.channels_with_admin_privileges.contains(id)) - } else { - false - } - }) + pub fn is_channel_admin(&self, channel_id: ChannelId) -> bool { + let Some(channel) = self.channel_for_id(channel_id) else { + return false; + }; + channel.role == proto::ChannelRole::Admin } pub fn channel_participants(&self, channel_id: ChannelId) -> &[Arc] { @@ -429,24 +468,19 @@ impl ChannelStore { .ok_or_else(|| anyhow!("missing channel in response"))?; let channel_id = channel.id; - let parent_edge = if let Some(parent_id) = parent_id { - vec![ChannelEdge { - channel_id: channel.id, - parent_id, - }] - } else { - vec![] - }; + // let parent_edge = if let Some(parent_id) = parent_id { + // vec![ChannelEdge { + // channel_id: channel.id, + // parent_id, + // }] + // } else { + // vec![] + // }; this.update(&mut cx, |this, cx| { let task = this.update_channels( proto::UpdateChannels { channels: vec![channel], - insert_edge: parent_edge, - channel_permissions: vec![ChannelPermission { - channel_id, - is_admin: true, - }], ..Default::default() }, cx, @@ -464,52 +498,34 @@ impl ChannelStore { }) } - pub fn link_channel( - &mut self, - channel_id: ChannelId, - to: ChannelId, - cx: &mut ModelContext, - ) -> Task> { - let client = self.client.clone(); - cx.spawn(|_, _| async move { - let _ = client - .request(proto::LinkChannel { channel_id, to }) - .await?; - - Ok(()) - }) - } - - pub fn unlink_channel( - &mut self, - channel_id: ChannelId, - from: ChannelId, - cx: &mut ModelContext, - ) -> Task> { - let client = self.client.clone(); - cx.spawn(|_, _| async move { - let _ = client - .request(proto::UnlinkChannel { channel_id, from }) - .await?; - - Ok(()) - }) - } - pub fn move_channel( &mut self, channel_id: ChannelId, - from: ChannelId, - to: ChannelId, + to: Option, cx: &mut ModelContext, ) -> Task> { let client = self.client.clone(); cx.spawn(|_, _| async move { let _ = client - .request(proto::MoveChannel { + .request(proto::MoveChannel { channel_id, to }) + .await?; + + Ok(()) + }) + } + + pub fn set_channel_visibility( + &mut self, + channel_id: ChannelId, + visibility: ChannelVisibility, + cx: &mut ModelContext, + ) -> Task> { + let client = self.client.clone(); + cx.spawn(|_, _| async move { + let _ = client + .request(proto::SetChannelVisibility { channel_id, - from, - to, + visibility: visibility.into(), }) .await?; @@ -521,7 +537,7 @@ 
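With `proto::ChannelRole` replacing the old `channels_with_admin_privileges` set, permission checks reduce to reading the channel's own `role`, as `is_channel_admin` and `Channel::can_edit_notes` do. A sketch of how further capability checks could compose on top of that enum; the helper names are illustrative assumptions, not part of this diff:

```rust
// Illustrative only: capability checks derived from a channel's role.
use rpc::proto::ChannelRole;

fn can_manage_members(role: ChannelRole) -> bool {
    role == ChannelRole::Admin
}

fn can_participate(role: ChannelRole) -> bool {
    matches!(role, ChannelRole::Admin | ChannelRole::Member)
}
```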
impl ChannelStore { &mut self, channel_id: ChannelId, user_id: UserId, - admin: bool, + role: proto::ChannelRole, cx: &mut ModelContext, ) -> Task> { if !self.outgoing_invites.insert((channel_id, user_id)) { @@ -535,7 +551,7 @@ impl ChannelStore { .request(proto::InviteChannelMember { channel_id, user_id, - admin, + role: role.into(), }) .await; @@ -579,11 +595,11 @@ impl ChannelStore { }) } - pub fn set_member_admin( + pub fn set_member_role( &mut self, channel_id: ChannelId, user_id: UserId, - admin: bool, + role: proto::ChannelRole, cx: &mut ModelContext, ) -> Task> { if !self.outgoing_invites.insert((channel_id, user_id)) { @@ -594,10 +610,10 @@ impl ChannelStore { let client = self.client.clone(); cx.spawn(|this, mut cx| async move { let result = client - .request(proto::SetChannelMemberAdmin { + .request(proto::SetChannelMemberRole { channel_id, user_id, - admin, + role: role.into(), }) .await; @@ -649,14 +665,15 @@ impl ChannelStore { &mut self, channel_id: ChannelId, accept: bool, - ) -> impl Future> { + cx: &mut ModelContext, + ) -> Task> { let client = self.client.clone(); - async move { + cx.background().spawn(async move { client .request(proto::RespondToChannelInvite { channel_id, accept }) .await?; Ok(()) - } + }) } pub fn get_channel_member_details( @@ -685,8 +702,8 @@ impl ChannelStore { .filter_map(|(user, member)| { Some(ChannelMembership { user, - admin: member.admin, - kind: proto::channel_member::Kind::from_i32(member.kind)?, + role: member.role(), + kind: member.kind(), }) }) .collect()) @@ -724,6 +741,11 @@ impl ChannelStore { } fn handle_connect(&mut self, cx: &mut ModelContext) -> Task> { + self.channel_index.clear(); + self.channel_invitations.clear(); + self.channel_participants.clear(); + self.channel_index.clear(); + self.outgoing_invites.clear(); self.disconnect_channel_buffers_task.take(); for chat in self.opened_chats.values() { @@ -743,7 +765,7 @@ impl ChannelStore { let channel_buffer = buffer.read(cx); let buffer = channel_buffer.buffer().read(cx); buffer_versions.push(proto::ChannelBufferVersion { - channel_id: channel_buffer.channel().id, + channel_id: channel_buffer.channel_id, epoch: channel_buffer.epoch(), version: language::proto::serialize_version(&buffer.version()), }); @@ -770,13 +792,13 @@ impl ChannelStore { }; channel_buffer.update(cx, |channel_buffer, cx| { - let channel_id = channel_buffer.channel().id; + let channel_id = channel_buffer.channel_id; if let Some(remote_buffer) = response .buffers .iter_mut() .find(|buffer| buffer.channel_id == channel_id) { - let channel_id = channel_buffer.channel().id; + let channel_id = channel_buffer.channel_id; let remote_version = language::proto::deserialize_version(&remote_buffer.version); @@ -833,12 +855,6 @@ impl ChannelStore { } fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext) { - self.channel_index.clear(); - self.channel_invitations.clear(); - self.channel_participants.clear(); - self.channels_with_admin_privileges.clear(); - self.channel_index.clear(); - self.outgoing_invites.clear(); cx.notify(); self.disconnect_channel_buffers_task.get_or_insert_with(|| { @@ -881,9 +897,12 @@ impl ChannelStore { ix, Arc::new(Channel { id: channel.id, + visibility: channel.visibility(), + role: channel.role(), name: channel.name, unseen_note_version: None, unseen_message_id: None, + parent_path: channel.parent_path, }), ), } @@ -891,8 +910,6 @@ impl ChannelStore { let channels_changed = !payload.channels.is_empty() || !payload.delete_channels.is_empty() - || 
!payload.insert_edge.is_empty() - || !payload.delete_edge.is_empty() || !payload.unseen_channel_messages.is_empty() || !payload.unseen_channel_buffer_changes.is_empty(); @@ -900,12 +917,17 @@ impl ChannelStore { if !payload.delete_channels.is_empty() { self.channel_index.delete_channels(&payload.delete_channels); self.channel_participants - .retain(|channel_id, _| !payload.delete_channels.contains(channel_id)); - self.channels_with_admin_privileges - .retain(|channel_id| !payload.delete_channels.contains(channel_id)); + .retain(|channel_id, _| !&payload.delete_channels.contains(channel_id)); for channel_id in &payload.delete_channels { let channel_id = *channel_id; + if payload + .channels + .iter() + .any(|channel| channel.id == channel_id) + { + continue; + } if let Some(OpenedModelHandle::Open(buffer)) = self.opened_buffers.remove(&channel_id) { @@ -918,7 +940,16 @@ impl ChannelStore { let mut index = self.channel_index.bulk_insert(); for channel in payload.channels { - index.insert(channel) + let id = channel.id; + let channel_changed = index.insert(channel); + + if channel_changed { + if let Some(OpenedModelHandle::Open(buffer)) = self.opened_buffers.get(&id) { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, ChannelBuffer::channel_changed); + } + } + } } for unseen_buffer_change in payload.unseen_channel_buffer_changes { @@ -936,24 +967,6 @@ impl ChannelStore { unseen_channel_message.message_id, ); } - - for edge in payload.insert_edge { - index.insert_edge(edge.channel_id, edge.parent_id); - } - - for edge in payload.delete_edge { - index.delete_edge(edge.parent_id, edge.channel_id); - } - } - - for permission in payload.channel_permissions { - if permission.is_admin { - self.channels_with_admin_privileges - .insert(permission.channel_id); - } else { - self.channels_with_admin_privileges - .remove(&permission.channel_id); - } } cx.notify(); @@ -1002,44 +1015,3 @@ impl ChannelStore { })) } } - -impl Deref for ChannelPath { - type Target = [ChannelId]; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl ChannelPath { - pub fn new(path: Arc<[ChannelId]>) -> Self { - debug_assert!(path.len() >= 1); - Self(path) - } - - pub fn parent_id(&self) -> Option { - self.0.len().checked_sub(2).map(|i| self.0[i]) - } - - pub fn channel_id(&self) -> ChannelId { - self.0[self.0.len() - 1] - } -} - -impl From for Cow<'static, ChannelPath> { - fn from(value: ChannelPath) -> Self { - Cow::Owned(value) - } -} - -impl<'a> From<&'a ChannelPath> for Cow<'a, ChannelPath> { - fn from(value: &'a ChannelPath) -> Self { - Cow::Borrowed(value) - } -} - -impl Default for ChannelPath { - fn default() -> Self { - ChannelPath(Arc::from([])) - } -} diff --git a/crates/channel/src/channel_store/channel_index.rs b/crates/channel/src/channel_store/channel_index.rs index bf0de1b644..97b2ab6318 100644 --- a/crates/channel/src/channel_store/channel_index.rs +++ b/crates/channel/src/channel_store/channel_index.rs @@ -1,14 +1,11 @@ -use std::{ops::Deref, sync::Arc}; - use crate::{Channel, ChannelId}; use collections::BTreeMap; use rpc::proto; - -use super::ChannelPath; +use std::sync::Arc; #[derive(Default, Debug)] pub struct ChannelIndex { - paths: Vec, + channels_ordered: Vec, channels_by_id: BTreeMap>, } @@ -17,8 +14,12 @@ impl ChannelIndex { &self.channels_by_id } + pub fn ordered_channels(&self) -> &[ChannelId] { + &self.channels_ordered + } + pub fn clear(&mut self) { - self.paths.clear(); + self.channels_ordered.clear(); self.channels_by_id.clear(); } @@ -26,15 +27,13 @@ impl ChannelIndex { 
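With the edge-based `ChannelPath` index removed, each `Channel` now carries its full ancestry in `parent_path`, so tree depth and the immediate parent can be read straight off that vector (as `ordered_channels` does above). A standalone sketch with simplified `u64` ids, not code from this diff:

```rust
// Depth in the channel tree is the number of ancestors above this channel.
fn depth(parent_path: &[u64]) -> usize {
    parent_path.len()
}

// The immediate parent, if any, is the last id in the ancestry path.
fn parent_id(parent_path: &[u64]) -> Option<u64> {
    parent_path.last().copied()
}

fn main() {
    assert_eq!(depth(&[]), 0); // a root channel
    assert_eq!(depth(&[0, 1]), 2); // a grandchild channel
    assert_eq!(parent_id(&[0, 1]), Some(1));
    assert_eq!(parent_id(&[]), None);
}
```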
pub fn delete_channels(&mut self, channels: &[ChannelId]) { self.channels_by_id .retain(|channel_id, _| !channels.contains(channel_id)); - self.paths.retain(|path| { - path.iter() - .all(|channel_id| self.channels_by_id.contains_key(channel_id)) - }); + self.channels_ordered + .retain(|channel_id| !channels.contains(channel_id)); } pub fn bulk_insert(&mut self) -> ChannelPathsInsertGuard { ChannelPathsInsertGuard { - paths: &mut self.paths, + channels_ordered: &mut self.channels_ordered, channels_by_id: &mut self.channels_by_id, } } @@ -77,42 +76,15 @@ impl ChannelIndex { } } -impl Deref for ChannelIndex { - type Target = [ChannelPath]; - - fn deref(&self) -> &Self::Target { - &self.paths - } -} - /// A guard for ensuring that the paths index maintains its sort and uniqueness /// invariants after a series of insertions #[derive(Debug)] pub struct ChannelPathsInsertGuard<'a> { - paths: &'a mut Vec, + channels_ordered: &'a mut Vec, channels_by_id: &'a mut BTreeMap>, } impl<'a> ChannelPathsInsertGuard<'a> { - /// Remove the given edge from this index. This will not remove the channel. - /// If this operation would result in a dangling edge, re-insert it. - pub fn delete_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) { - self.paths.retain(|path| { - !path - .windows(2) - .any(|window| window == [parent_id, channel_id]) - }); - - // Ensure that there is at least one channel path in the index - if !self - .paths - .iter() - .any(|path| path.iter().any(|id| id == &channel_id)) - { - self.insert_root(channel_id); - } - } - pub fn note_changed(&mut self, channel_id: ChannelId, epoch: u64, version: &clock::Global) { insert_note_changed(&mut self.channels_by_id, channel_id, epoch, &version); } @@ -121,91 +93,65 @@ impl<'a> ChannelPathsInsertGuard<'a> { insert_new_message(&mut self.channels_by_id, channel_id, message_id) } - pub fn insert(&mut self, channel_proto: proto::Channel) { + pub fn insert(&mut self, channel_proto: proto::Channel) -> bool { + let mut ret = false; if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) { - Arc::make_mut(existing_channel).name = channel_proto.name; + let existing_channel = Arc::make_mut(existing_channel); + + ret = existing_channel.visibility != channel_proto.visibility() + || existing_channel.role != channel_proto.role() + || existing_channel.name != channel_proto.name; + + existing_channel.visibility = channel_proto.visibility(); + existing_channel.role = channel_proto.role(); + existing_channel.name = channel_proto.name; } else { self.channels_by_id.insert( channel_proto.id, Arc::new(Channel { id: channel_proto.id, + visibility: channel_proto.visibility(), + role: channel_proto.role(), name: channel_proto.name, unseen_note_version: None, unseen_message_id: None, + parent_path: channel_proto.parent_path, }), ); self.insert_root(channel_proto.id); } - } - - pub fn insert_edge(&mut self, channel_id: ChannelId, parent_id: ChannelId) { - let mut parents = Vec::new(); - let mut descendants = Vec::new(); - let mut ixs_to_remove = Vec::new(); - - for (ix, path) in self.paths.iter().enumerate() { - if path - .windows(2) - .any(|window| window[0] == parent_id && window[1] == channel_id) - { - // We already have this edge in the index - return; - } - if path.ends_with(&[parent_id]) { - parents.push(path); - } else if let Some(position) = path.iter().position(|id| id == &channel_id) { - if position == 0 { - ixs_to_remove.push(ix); - } - descendants.push(path.split_at(position).1); - } - } - - let mut new_paths = Vec::new(); - for 
parent in parents.iter() { - if descendants.is_empty() { - let mut new_path = Vec::with_capacity(parent.len() + 1); - new_path.extend_from_slice(parent); - new_path.push(channel_id); - new_paths.push(ChannelPath::new(new_path.into())); - } else { - for descendant in descendants.iter() { - let mut new_path = Vec::with_capacity(parent.len() + descendant.len()); - new_path.extend_from_slice(parent); - new_path.extend_from_slice(descendant); - new_paths.push(ChannelPath::new(new_path.into())); - } - } - } - - for ix in ixs_to_remove.into_iter().rev() { - self.paths.swap_remove(ix); - } - self.paths.extend(new_paths) + ret } fn insert_root(&mut self, channel_id: ChannelId) { - self.paths.push(ChannelPath::new(Arc::from([channel_id]))); + self.channels_ordered.push(channel_id); } } impl<'a> Drop for ChannelPathsInsertGuard<'a> { fn drop(&mut self) { - self.paths.sort_by(|a, b| { - let a = channel_path_sorting_key(a, &self.channels_by_id); - let b = channel_path_sorting_key(b, &self.channels_by_id); + self.channels_ordered.sort_by(|a, b| { + let a = channel_path_sorting_key(*a, &self.channels_by_id); + let b = channel_path_sorting_key(*b, &self.channels_by_id); a.cmp(b) }); - self.paths.dedup(); + self.channels_ordered.dedup(); } } fn channel_path_sorting_key<'a>( - path: &'a [ChannelId], + id: ChannelId, channels_by_id: &'a BTreeMap>, -) -> impl 'a + Iterator> { - path.iter() - .map(|id| Some(channels_by_id.get(id)?.name.as_str())) +) -> impl Iterator { + let (parent_path, name) = channels_by_id + .get(&id) + .map_or((&[] as &[_], None), |channel| { + (channel.parent_path.as_slice(), Some(channel.name.as_str())) + }); + parent_path + .iter() + .filter_map(|id| Some(channels_by_id.get(id)?.name.as_str())) + .chain(name) } fn insert_note_changed( diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index 9303a52092..ff8761ee91 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -3,7 +3,7 @@ use crate::channel_chat::ChannelChatEvent; use super::*; use client::{test::FakeServer, Client, UserStore}; use gpui::{AppContext, ModelHandle, TestAppContext}; -use rpc::proto; +use rpc::proto::{self}; use settings::SettingsStore; use util::http::FakeHttpClient; @@ -18,16 +18,18 @@ fn test_update_channels(cx: &mut AppContext) { proto::Channel { id: 1, name: "b".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: Vec::new(), }, proto::Channel { id: 2, name: "a".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Member.into(), + parent_path: Vec::new(), }, ], - channel_permissions: vec![proto::ChannelPermission { - channel_id: 1, - is_admin: true, - }], ..Default::default() }, cx, @@ -36,8 +38,8 @@ fn test_update_channels(cx: &mut AppContext) { &channel_store, &[ // - (0, "a".to_string(), false), - (0, "b".to_string(), true), + (0, "a".to_string(), proto::ChannelRole::Member), + (0, "b".to_string(), proto::ChannelRole::Admin), ], cx, ); @@ -49,20 +51,16 @@ fn test_update_channels(cx: &mut AppContext) { proto::Channel { id: 3, name: "x".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![1], }, proto::Channel { id: 4, name: "y".to_string(), - }, - ], - insert_edge: vec![ - proto::ChannelEdge { - parent_id: 1, - channel_id: 3, - }, - proto::ChannelEdge { - parent_id: 2, - channel_id: 4, + visibility: 
proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Member.into(), + parent_path: vec![2], }, ], ..Default::default() @@ -72,10 +70,10 @@ fn test_update_channels(cx: &mut AppContext) { assert_channels( &channel_store, &[ - (0, "a".to_string(), false), - (1, "y".to_string(), false), - (0, "b".to_string(), true), - (1, "x".to_string(), true), + (0, "a".to_string(), proto::ChannelRole::Member), + (1, "y".to_string(), proto::ChannelRole::Member), + (0, "b".to_string(), proto::ChannelRole::Admin), + (1, "x".to_string(), proto::ChannelRole::Admin), ], cx, ); @@ -92,30 +90,25 @@ fn test_dangling_channel_paths(cx: &mut AppContext) { proto::Channel { id: 0, name: "a".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![], }, proto::Channel { id: 1, name: "b".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![0], }, proto::Channel { id: 2, name: "c".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Admin.into(), + parent_path: vec![0, 1], }, ], - insert_edge: vec![ - proto::ChannelEdge { - parent_id: 0, - channel_id: 1, - }, - proto::ChannelEdge { - parent_id: 1, - channel_id: 2, - }, - ], - channel_permissions: vec![proto::ChannelPermission { - channel_id: 0, - is_admin: true, - }], ..Default::default() }, cx, @@ -125,9 +118,9 @@ fn test_dangling_channel_paths(cx: &mut AppContext) { &channel_store, &[ // - (0, "a".to_string(), true), - (1, "b".to_string(), true), - (2, "c".to_string(), true), + (0, "a".to_string(), proto::ChannelRole::Admin), + (1, "b".to_string(), proto::ChannelRole::Admin), + (2, "c".to_string(), proto::ChannelRole::Admin), ], cx, ); @@ -142,7 +135,11 @@ fn test_dangling_channel_paths(cx: &mut AppContext) { ); // Make sure that the 1/2/3 path is gone - assert_channels(&channel_store, &[(0, "a".to_string(), true)], cx); + assert_channels( + &channel_store, + &[(0, "a".to_string(), proto::ChannelRole::Admin)], + cx, + ); } #[gpui::test] @@ -158,12 +155,19 @@ async fn test_channel_messages(cx: &mut TestAppContext) { channels: vec![proto::Channel { id: channel_id, name: "the-channel".to_string(), + visibility: proto::ChannelVisibility::Members as i32, + role: proto::ChannelRole::Member.into(), + parent_path: vec![], }], ..Default::default() }); cx.foreground().run_until_parked(); cx.read(|cx| { - assert_channels(&channel_store, &[(0, "the-channel".to_string(), false)], cx); + assert_channels( + &channel_store, + &[(0, "the-channel".to_string(), proto::ChannelRole::Member)], + cx, + ); }); let get_users = server.receive::().await.unwrap(); @@ -181,7 +185,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { // Join a channel and populate its existing messages. 
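The ordering asserted in this test follows from `channel_path_sorting_key`, which compares channels by the sequence of their ancestors' names followed by their own name. A standalone illustration of that comparison using plain string paths, not code from this diff:

```rust
fn main() {
    // Name-paths for the channels in the test above:
    // "a", "y" under "a", "b", "x" under "b".
    let mut keys = vec![
        vec!["b"],
        vec!["a", "y"],
        vec!["a"],
        vec!["b", "x"],
    ];
    // Lexicographic comparison of name-paths yields a, a/y, b, b/x,
    // matching the (depth, name, role) tuples asserted against ordered_channels.
    keys.sort();
    assert_eq!(
        keys,
        vec![vec!["a"], vec!["a", "y"], vec!["b"], vec!["b", "x"]]
    );
}
```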
let channel = channel_store.update(cx, |store, cx| { - let channel_id = store.channel_dag_entries().next().unwrap().1.id; + let channel_id = store.ordered_channels().next().unwrap().1.id; store.open_channel_chat(channel_id, cx) }); let join_channel = server.receive::().await.unwrap(); @@ -194,6 +198,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "a".into(), timestamp: 1000, sender_id: 5, + mentions: vec![], nonce: Some(1.into()), }, proto::ChannelMessage { @@ -201,6 +206,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "b".into(), timestamp: 1001, sender_id: 6, + mentions: vec![], nonce: Some(2.into()), }, ], @@ -247,6 +253,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "c".into(), timestamp: 1002, sender_id: 7, + mentions: vec![], nonce: Some(3.into()), }), }); @@ -284,7 +291,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { // Scroll up to view older messages. channel.update(cx, |channel, cx| { - assert!(channel.load_more_messages(cx)); + channel.load_more_messages(cx).unwrap().detach(); }); let get_messages = server.receive::().await.unwrap(); assert_eq!(get_messages.payload.channel_id, 5); @@ -300,6 +307,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 998, sender_id: 5, nonce: Some(4.into()), + mentions: vec![], }, proto::ChannelMessage { id: 9, @@ -307,6 +315,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 999, sender_id: 6, nonce: Some(5.into()), + mentions: vec![], }, ], }, @@ -358,19 +367,13 @@ fn update_channels( #[track_caller] fn assert_channels( channel_store: &ModelHandle, - expected_channels: &[(usize, String, bool)], + expected_channels: &[(usize, String, proto::ChannelRole)], cx: &AppContext, ) { let actual = channel_store.read_with(cx, |store, _| { store - .channel_dag_entries() - .map(|(depth, channel)| { - ( - depth, - channel.name.to_string(), - store.is_user_admin(channel.id), - ) - }) + .ordered_channels() + .map(|(depth, channel)| (depth, channel.name.to_string(), channel.role)) .collect::>() }); assert_eq!(actual, expected_channels); diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 70878bf2e4..fd93aaeec8 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -4,7 +4,9 @@ use lazy_static::lazy_static; use parking_lot::Mutex; use serde::Serialize; use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration}; -use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt}; +use sysinfo::{ + CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt, +}; use tempfile::NamedTempFile; use util::http::HttpClient; use util::{channel::ReleaseChannel, TryFutureExt}; @@ -166,8 +168,16 @@ impl Telemetry { let this = self.clone(); cx.spawn(|mut cx| async move { - let mut system = System::new_all(); - system.refresh_all(); + // Avoiding calling `System::new_all()`, as there have been crashes related to it + let refresh_kind = RefreshKind::new() + .with_memory() // For memory usage + .with_processes(ProcessRefreshKind::everything()) // For process usage + .with_cpu(CpuRefreshKind::everything()); // For core count + + let mut system = System::new_with_specifics(refresh_kind); + + // Avoiding calling `refresh_all()`, just update what we need + system.refresh_specifics(refresh_kind); loop { // Waiting some amount of time before the first query is important to get a reasonable value @@ -175,8 +185,7 @@ impl Telemetry { const DURATION_BETWEEN_SYSTEM_EVENTS: 
Duration = Duration::from_secs(60); smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await; - system.refresh_memory(); - system.refresh_processes(); + system.refresh_specifics(refresh_kind); let current_process = Pid::from_u32(std::process::id()); let Some(process) = system.processes().get(¤t_process) else { diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 6aa41708e3..8299b7c6e4 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -293,21 +293,19 @@ impl UserStore { // No need to paralellize here let mut updated_contacts = Vec::new(); for contact in message.contacts { - let should_notify = contact.should_notify; - updated_contacts.push(( - Arc::new(Contact::from_proto(contact, &this, &mut cx).await?), - should_notify, + updated_contacts.push(Arc::new( + Contact::from_proto(contact, &this, &mut cx).await?, )); } let mut incoming_requests = Vec::new(); for request in message.incoming_requests { - incoming_requests.push({ - let user = this - .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx)) - .await?; - (user, request.should_notify) - }); + incoming_requests.push( + this.update(&mut cx, |this, cx| { + this.get_user(request.requester_id, cx) + }) + .await?, + ); } let mut outgoing_requests = Vec::new(); @@ -330,13 +328,7 @@ impl UserStore { this.contacts .retain(|contact| !removed_contacts.contains(&contact.user.id)); // Update existing contacts and insert new ones - for (updated_contact, should_notify) in updated_contacts { - if should_notify { - cx.emit(Event::Contact { - user: updated_contact.user.clone(), - kind: ContactEventKind::Accepted, - }); - } + for updated_contact in updated_contacts { match this.contacts.binary_search_by_key( &&updated_contact.user.github_login, |contact| &contact.user.github_login, @@ -359,14 +351,7 @@ impl UserStore { } }); // Update existing incoming requests and insert new ones - for (user, should_notify) in incoming_requests { - if should_notify { - cx.emit(Event::Contact { - user: user.clone(), - kind: ContactEventKind::Requested, - }); - } - + for user in incoming_requests { match this .incoming_contact_requests .binary_search_by_key(&&user.github_login, |contact| { @@ -415,6 +400,12 @@ impl UserStore { &self.incoming_contact_requests } + pub fn has_incoming_contact_request(&self, user_id: u64) -> bool { + self.incoming_contact_requests + .iter() + .any(|user| user.id == user_id) + } + pub fn outgoing_contact_requests(&self) -> &[Arc] { &self.outgoing_contact_requests } diff --git a/crates/client2/src/client2.rs b/crates/client2/src/client2.rs index 79b0205c91..19e8685c28 100644 --- a/crates/client2/src/client2.rs +++ b/crates/client2/src/client2.rs @@ -14,8 +14,8 @@ use futures::{ future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt, }; use gpui2::{ - serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Handle, SemanticVersion, - Task, WeakHandle, + serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task, + WeakModel, }; use lazy_static::lazy_static; use parking_lot::RwLock; @@ -227,7 +227,7 @@ struct ClientState { _reconnect_task: Option>, reconnect_interval: Duration, entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, - models_by_message_type: HashMap, + models_by_message_type: HashMap, entity_types_by_message_type: HashMap, #[allow(clippy::type_complexity)] message_handlers: HashMap< @@ -236,7 +236,7 @@ struct ClientState { dyn Send + Sync + Fn( - AnyHandle, + AnyModel, Box, 
&Arc, AsyncAppContext, @@ -246,7 +246,7 @@ struct ClientState { } enum WeakSubscriber { - Entity { handle: AnyWeakHandle }, + Entity { handle: AnyWeakModel }, Pending(Vec>), } @@ -312,9 +312,9 @@ pub struct PendingEntitySubscription { impl PendingEntitySubscription where - T: 'static + Send + Sync, + T: 'static + Send, { - pub fn set_model(mut self, model: &Handle, cx: &mut AsyncAppContext) -> Subscription { + pub fn set_model(mut self, model: &Model, cx: &mut AsyncAppContext) -> Subscription { self.consumed = true; let mut state = self.client.state.write(); let id = (TypeId::of::(), self.remote_id); @@ -529,7 +529,7 @@ impl Client { remote_id: u64, ) -> Result> where - T: 'static + Send + Sync, + T: 'static + Send, { let id = (TypeId::of::(), remote_id); @@ -552,13 +552,13 @@ impl Client { #[track_caller] pub fn add_message_handler( self: &Arc, - entity: WeakHandle, + entity: WeakModel, handler: H, ) -> Subscription where M: EnvelopedMessage, - E: 'static + Send + Sync, - H: 'static + Send + Sync + Fn(Handle, TypedEnvelope, Arc, AsyncAppContext) -> F, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future> + Send, { let message_type_id = TypeId::of::(); @@ -594,13 +594,13 @@ impl Client { pub fn add_request_handler( self: &Arc, - model: WeakHandle, + model: WeakModel, handler: H, ) -> Subscription where M: RequestMessage, - E: 'static + Send + Sync, - H: 'static + Send + Sync + Fn(Handle, TypedEnvelope, Arc, AsyncAppContext) -> F, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future> + Send, { self.add_message_handler(model, move |handle, envelope, this, cx| { @@ -615,8 +615,8 @@ impl Client { pub fn add_model_message_handler(self: &Arc, handler: H) where M: EntityMessage, - E: 'static + Send + Sync, - H: 'static + Send + Sync + Fn(Handle, TypedEnvelope, Arc, AsyncAppContext) -> F, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future> + Send, { self.add_entity_message_handler::(move |subscriber, message, client, cx| { @@ -627,8 +627,8 @@ impl Client { fn add_entity_message_handler(self: &Arc, handler: H) where M: EntityMessage, - E: 'static + Send + Sync, - H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + E: 'static + Send, + H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future> + Send, { let model_type_id = TypeId::of::(); @@ -666,8 +666,8 @@ impl Client { pub fn add_model_request_handler(self: &Arc, handler: H) where M: EntityMessage + RequestMessage, - E: 'static + Send + Sync, - H: 'static + Send + Sync + Fn(Handle, TypedEnvelope, Arc, AsyncAppContext) -> F, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future> + Send, { self.add_model_message_handler(move |entity, envelope, client, cx| { @@ -1546,7 +1546,7 @@ mod tests { let (done_tx1, mut done_rx1) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded(); client.add_model_message_handler( - move |model: Handle, _: TypedEnvelope, _, mut cx| { + move |model: Model, _: TypedEnvelope, _, mut cx| { match model.update(&mut cx, |model, _| model.id).unwrap() { 1 => done_tx1.try_send(()).unwrap(), 2 => done_tx2.try_send(()).unwrap(), @@ -1555,15 +1555,15 @@ mod tests { async { Ok(()) } }, ); - let model1 = cx.entity(|_| Model { + let model1 = 
cx.build_model(|_| TestModel { id: 1, subscription: None, }); - let model2 = cx.entity(|_| Model { + let model2 = cx.build_model(|_| TestModel { id: 2, subscription: None, }); - let model3 = cx.entity(|_| Model { + let model3 = cx.build_model(|_| TestModel { id: 3, subscription: None, }); @@ -1596,7 +1596,7 @@ mod tests { let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); let server = FakeServer::for_client(user_id, &client, cx).await; - let model = cx.entity(|_| Model::default()); + let model = cx.build_model(|_| TestModel::default()); let (done_tx1, _done_rx1) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded(); let subscription1 = client.add_message_handler( @@ -1624,11 +1624,11 @@ mod tests { let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); let server = FakeServer::for_client(user_id, &client, cx).await; - let model = cx.entity(|_| Model::default()); + let model = cx.build_model(|_| TestModel::default()); let (done_tx, mut done_rx) = smol::channel::unbounded(); let subscription = client.add_message_handler( model.clone().downgrade(), - move |model: Handle, _: TypedEnvelope, _, mut cx| { + move |model: Model, _: TypedEnvelope, _, mut cx| { model .update(&mut cx, |model, _| model.subscription.take()) .unwrap(); @@ -1644,7 +1644,7 @@ mod tests { } #[derive(Default)] - struct Model { + struct TestModel { id: usize, subscription: Option, } diff --git a/crates/client2/src/telemetry.rs b/crates/client2/src/telemetry.rs index 1b64e94107..47d1c143e1 100644 --- a/crates/client2/src/telemetry.rs +++ b/crates/client2/src/telemetry.rs @@ -5,7 +5,9 @@ use parking_lot::Mutex; use serde::Serialize; use settings2::Settings; use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration}; -use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt}; +use sysinfo::{ + CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt, +}; use tempfile::NamedTempFile; use util::http::HttpClient; use util::{channel::ReleaseChannel, TryFutureExt}; @@ -161,8 +163,16 @@ impl Telemetry { let this = self.clone(); cx.spawn(|cx| async move { - let mut system = System::new_all(); - system.refresh_all(); + // Avoiding calling `System::new_all()`, as there have been crashes related to it + let refresh_kind = RefreshKind::new() + .with_memory() // For memory usage + .with_processes(ProcessRefreshKind::everything()) // For process usage + .with_cpu(CpuRefreshKind::everything()); // For core count + + let mut system = System::new_with_specifics(refresh_kind); + + // Avoiding calling `refresh_all()`, just update what we need + system.refresh_specifics(refresh_kind); loop { // Waiting some amount of time before the first query is important to get a reasonable value @@ -170,8 +180,7 @@ impl Telemetry { const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60); smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await; - system.refresh_memory(); - system.refresh_processes(); + system.refresh_specifics(refresh_kind); let current_process = Pid::from_u32(std::process::id()); let Some(process) = system.processes().get(¤t_process) else { diff --git a/crates/client2/src/test.rs b/crates/client2/src/test.rs index 1b32d35092..f30547dcfc 100644 --- a/crates/client2/src/test.rs +++ b/crates/client2/src/test.rs @@ -1,7 +1,7 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{anyhow, Result}; use futures::{stream::BoxStream, StreamExt}; -use 
gpui2::{Context, Executor, Handle, TestAppContext}; +use gpui2::{Context, Executor, Model, TestAppContext}; use parking_lot::Mutex; use rpc2::{ proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, @@ -194,9 +194,9 @@ impl FakeServer { &self, client: Arc, cx: &mut TestAppContext, - ) -> Handle { + ) -> Model { let http_client = FakeHttpClient::with_404_response(); - let user_store = cx.entity(|cx| UserStore::new(client, http_client, cx)); + let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx)); assert_eq!( self.receive::() .await diff --git a/crates/client2/src/user.rs b/crates/client2/src/user.rs index 41cf46ea8f..2a8cf34af4 100644 --- a/crates/client2/src/user.rs +++ b/crates/client2/src/user.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result}; use collections::{hash_map::Entry, HashMap, HashSet}; use feature_flags2::FeatureFlagAppExt; use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt}; -use gpui2::{AsyncAppContext, EventEmitter, Handle, ImageData, ModelContext, Task}; +use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task}; use postage::{sink::Sink, watch}; use rpc2::proto::{RequestMessage, UsersResponse}; use std::sync::{Arc, Weak}; @@ -122,9 +122,9 @@ impl UserStore { let (mut current_user_tx, current_user_rx) = watch::channel(); let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); let rpc_subscriptions = vec![ - client.add_message_handler(cx.weak_handle(), Self::handle_update_contacts), - client.add_message_handler(cx.weak_handle(), Self::handle_update_invite_info), - client.add_message_handler(cx.weak_handle(), Self::handle_show_contacts), + client.add_message_handler(cx.weak_model(), Self::handle_update_contacts), + client.add_message_handler(cx.weak_model(), Self::handle_update_invite_info), + client.add_message_handler(cx.weak_model(), Self::handle_show_contacts), ]; Self { users: Default::default(), @@ -213,7 +213,7 @@ impl UserStore { } async fn handle_update_invite_info( - this: Handle, + this: Model, message: TypedEnvelope, _: Arc, mut cx: AsyncAppContext, @@ -229,7 +229,7 @@ impl UserStore { } async fn handle_show_contacts( - this: Handle, + this: Model, _: TypedEnvelope, _: Arc, mut cx: AsyncAppContext, @@ -243,7 +243,7 @@ impl UserStore { } async fn handle_update_contacts( - this: Handle, + this: Model, message: TypedEnvelope, _: Arc, mut cx: AsyncAppContext, @@ -690,7 +690,7 @@ impl User { impl Contact { async fn from_proto( contact: proto::Contact, - user_store: &Handle, + user_store: &Model, cx: &mut AsyncAppContext, ) -> Result { let user = user_store diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b91f0e1a5f..987c295407 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] default-run = "collab" edition = "2021" name = "collab" -version = "0.24.0" +version = "0.27.0" publish = false [[bin]] @@ -73,6 +73,7 @@ git = { path = "../git", features = ["test-support"] } live_kit_client = { path = "../live_kit_client", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } node_runtime = { path = "../node_runtime" } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql 
b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 5a84bfd796..775a4c1bbe 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -44,7 +44,7 @@ CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id"); CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, - "room_id" INTEGER REFERENCES rooms (id) NOT NULL, + "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE NOT NULL, "host_user_id" INTEGER REFERENCES users (id) NOT NULL, "host_connection_id" INTEGER, "host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, @@ -192,9 +192,13 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); CREATE TABLE "channels" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" VARCHAR NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT now + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "visibility" VARCHAR NOT NULL, + "parent_path" TEXT ); +CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path"); + CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "user_id" INTEGER NOT NULL REFERENCES users (id), @@ -213,19 +217,22 @@ CREATE TABLE IF NOT EXISTS "channel_messages" ( "nonce" BLOB NOT NULL ); CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); -CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); -CREATE TABLE "channel_paths" ( - "id_path" TEXT NOT NULL PRIMARY KEY, - "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) ); -CREATE INDEX "index_channel_paths_on_channel_id" ON "channel_paths" ("channel_id"); CREATE TABLE "channel_members" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, "admin" BOOLEAN NOT NULL DEFAULT false, + "role" VARCHAR, "accepted" BOOLEAN NOT NULL DEFAULT false, "updated_at" TIMESTAMP NOT NULL DEFAULT now ); @@ -312,3 +319,26 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( ); CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); + +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE "notifications" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP, + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git 
a/crates/collab/migrations/20231004130100_create_notifications.sql b/crates/collab/migrations/20231004130100_create_notifications.sql
new file mode 100644
index 0000000000..93c282c631
--- /dev/null
+++ b/crates/collab/migrations/20231004130100_create_notifications.sql
@@ -0,0 +1,22 @@
+CREATE TABLE "notification_kinds" (
+    "id" SERIAL PRIMARY KEY,
+    "name" VARCHAR NOT NULL
+);
+
+CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
+
+CREATE TABLE notifications (
+    "id" SERIAL PRIMARY KEY,
+    "created_at" TIMESTAMP NOT NULL DEFAULT now(),
+    "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+    "entity_id" INTEGER,
+    "content" TEXT,
+    "is_read" BOOLEAN NOT NULL DEFAULT FALSE,
+    "response" BOOLEAN
+);
+
+CREATE INDEX
+    "index_notifications_on_recipient_id_is_read_kind_entity_id"
+    ON "notifications"
+    ("recipient_id", "is_read", "kind", "entity_id");
diff --git a/crates/collab/migrations/20231011214412_add_guest_role.sql b/crates/collab/migrations/20231011214412_add_guest_role.sql
new file mode 100644
index 0000000000..1713547158
--- /dev/null
+++ b/crates/collab/migrations/20231011214412_add_guest_role.sql
@@ -0,0 +1,4 @@
+ALTER TABLE channel_members ADD COLUMN role TEXT;
+UPDATE channel_members SET role = CASE WHEN admin THEN 'admin' ELSE 'member' END;
+
+ALTER TABLE channels ADD COLUMN visibility TEXT NOT NULL DEFAULT 'members';
diff --git a/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql b/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql
new file mode 100644
index 0000000000..be535ff7fa
--- /dev/null
+++ b/crates/collab/migrations/20231017185833_projects_room_id_fkey_on_delete_cascade.sql
@@ -0,0 +1,8 @@
+-- Add migration script here
+
+ALTER TABLE projects
+    DROP CONSTRAINT projects_room_id_fkey,
+    ADD CONSTRAINT projects_room_id_fkey
+    FOREIGN KEY (room_id)
+    REFERENCES rooms (id)
+    ON DELETE CASCADE;
diff --git a/crates/collab/migrations/20231018102700_create_mentions.sql b/crates/collab/migrations/20231018102700_create_mentions.sql
new file mode 100644
index 0000000000..221a1748cf
--- /dev/null
+++ b/crates/collab/migrations/20231018102700_create_mentions.sql
@@ -0,0 +1,11 @@
+CREATE TABLE "channel_message_mentions" (
+    "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,
+    "start_offset" INTEGER NOT NULL,
+    "end_offset" INTEGER NOT NULL,
+    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    PRIMARY KEY(message_id, start_offset)
+);
+
+-- We use 'on conflict update' with this index, so it should be per-user.
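-- [Editor's aside - illustrative only, not part of this patch.] The unique
-- (sender_id, nonce) index created on the next line is the target a
-- Postgres-style idempotent insert would name; assuming the remaining message
-- columns are supplied as well, the shape of such a statement is roughly:
--
--     INSERT INTO channel_messages (channel_id, sender_id, nonce /*, ... */)
--     VALUES ($1, $2, $3 /*, ... */)
--     ON CONFLICT (sender_id, nonce) DO UPDATE SET nonce = excluded.nonce
--     RETURNING id;
--
-- Scoping uniqueness to (sender_id, nonce) rather than nonce alone lets two
-- different senders reuse the same client-generated nonce without colliding.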
+CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); +DROP INDEX "index_channel_messages_on_nonce"; diff --git a/crates/collab/migrations/20231024085546_move_channel_paths_to_channels_table.sql b/crates/collab/migrations/20231024085546_move_channel_paths_to_channels_table.sql new file mode 100644 index 0000000000..d9fc6c8722 --- /dev/null +++ b/crates/collab/migrations/20231024085546_move_channel_paths_to_channels_table.sql @@ -0,0 +1,12 @@ +ALTER TABLE channels ADD COLUMN parent_path TEXT; + +UPDATE channels +SET parent_path = substr( + channel_paths.id_path, + 2, + length(channel_paths.id_path) - length('/' || channel_paths.channel_id::text || '/') +) +FROM channel_paths +WHERE channel_paths.channel_id = channels.id; + +CREATE INDEX "index_channels_on_parent_path" ON "channels" ("parent_path"); diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index cb1594e941..88fe0a647b 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -71,7 +71,6 @@ async fn main() { db::NewUserParams { github_login: github_user.login, github_user_id: github_user.id, - invite_count: 5, }, ) .await diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e60b7cc33d..df33416a46 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -20,7 +20,7 @@ use rpc::{ }; use sea_orm::{ entity::prelude::*, - sea_query::{Alias, Expr, OnConflict, Query}, + sea_query::{Alias, Expr, OnConflict}, ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, @@ -47,14 +47,14 @@ pub use ids::*; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; -use self::queries::channels::ChannelGraph; - pub struct Database { options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, rng: Mutex, executor: Executor, + notification_kinds_by_id: HashMap, + notification_kinds_by_name: HashMap, #[cfg(test)] runtime: Option, } @@ -69,6 +69,8 @@ impl Database { pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), rng: Mutex::new(StdRng::seed_from_u64(0)), + notification_kinds_by_id: HashMap::default(), + notification_kinds_by_name: HashMap::default(), executor, #[cfg(test)] runtime: None, @@ -121,6 +123,11 @@ impl Database { Ok(new_migrations) } + pub async fn initialize_static_data(&mut self) -> Result<()> { + self.initialize_notification_kinds().await?; + Ok(()) + } + pub async fn transaction(&self, f: F) -> Result where F: Send + Fn(TransactionHandle) -> Fut, @@ -361,18 +368,9 @@ impl RoomGuard { #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - busy: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, + Accepted { user_id: UserId, busy: bool }, + Outgoing { user_id: UserId }, + Incoming { user_id: UserId }, } impl Contact { @@ -385,6 +383,15 @@ impl Contact { } } +pub type NotificationBatch = Vec<(UserId, proto::Notification)>; + +pub struct CreatedChannelMessage { + pub message_id: MessageId, + pub participant_connection_ids: Vec, + pub channel_members: Vec, + pub notifications: NotificationBatch, +} + #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] pub struct Invite { pub email_address: String, @@ -417,7 +424,6 @@ pub struct WaitlistSummary { pub struct 
NewUserParams { pub github_login: String, pub github_user_id: i32, - pub invite_count: i32, } #[derive(Debug)] @@ -428,17 +434,115 @@ pub struct NewUserResult { pub signup_device_id: Option, } -#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)] +#[derive(Debug)] +pub struct MoveChannelResult { + pub participants_to_update: HashMap, + pub participants_to_remove: HashSet, + pub moved_channels: HashSet, +} + +#[derive(Debug)] +pub struct RenameChannelResult { + pub channel: Channel, + pub participants_to_update: HashMap, +} + +#[derive(Debug)] +pub struct CreateChannelResult { + pub channel: Channel, + pub participants_to_update: Vec<(UserId, ChannelsForUser)>, +} + +#[derive(Debug)] +pub struct SetChannelVisibilityResult { + pub participants_to_update: HashMap, + pub participants_to_remove: HashSet, + pub channels_to_remove: Vec, +} + +#[derive(Debug)] +pub struct MembershipUpdated { + pub channel_id: ChannelId, + pub new_channels: ChannelsForUser, + pub removed_channels: Vec, +} + +#[derive(Debug)] +pub enum SetMemberRoleResult { + InviteUpdated(Channel), + MembershipUpdated(MembershipUpdated), +} + +#[derive(Debug)] +pub struct InviteMemberResult { + pub channel: Channel, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RespondToChannelInvite { + pub membership_update: Option, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RemoveChannelMemberResult { + pub membership_update: MembershipUpdated, + pub notification_id: Option, +} + +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Channel { pub id: ChannelId, pub name: String, + pub visibility: ChannelVisibility, + pub role: ChannelRole, + pub parent_path: Vec, +} + +impl Channel { + fn from_model(value: channel::Model, role: ChannelRole) -> Self { + Channel { + id: value.id, + visibility: value.visibility, + name: value.clone().name, + role, + parent_path: value.ancestors().collect(), + } + } + + pub fn to_proto(&self) -> proto::Channel { + proto::Channel { + id: self.id.to_proto(), + name: self.name.clone(), + visibility: self.visibility.into(), + role: self.role.into(), + parent_path: self.parent_path.iter().map(|c| c.to_proto()).collect(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ChannelMember { + pub role: ChannelRole, + pub user_id: UserId, + pub kind: proto::channel_member::Kind, +} + +impl ChannelMember { + pub fn to_proto(&self) -> proto::ChannelMember { + proto::ChannelMember { + role: self.role.into(), + user_id: self.user_id.to_proto(), + kind: self.kind.into(), + } + } } #[derive(Debug, PartialEq)] pub struct ChannelsForUser { - pub channels: ChannelGraph, + pub channels: Vec, pub channel_participants: HashMap>, - pub channels_with_admin_privileges: HashSet, pub unseen_buffer_changes: Vec, pub channel_messages: Vec, } diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 23bb9e53bf..5f0df90811 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -1,4 +1,5 @@ use crate::Result; +use rpc::proto; use sea_orm::{entity::prelude::*, DbErr}; use serde::{Deserialize, Serialize}; @@ -80,3 +81,119 @@ id_type!(SignupId); id_type!(UserId); id_type!(ChannelBufferCollaboratorId); id_type!(FlagId); +id_type!(NotificationId); +id_type!(NotificationKindId); + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum ChannelRole { + #[sea_orm(string_value = "admin")] + Admin, + #[sea_orm(string_value = "member")] + 
#[default] + Member, + #[sea_orm(string_value = "guest")] + Guest, + #[sea_orm(string_value = "banned")] + Banned, +} + +impl ChannelRole { + pub fn should_override(&self, other: Self) -> bool { + use ChannelRole::*; + match self { + Admin => matches!(other, Member | Banned | Guest), + Member => matches!(other, Banned | Guest), + Banned => matches!(other, Guest), + Guest => false, + } + } + + pub fn max(&self, other: Self) -> Self { + if self.should_override(other) { + *self + } else { + other + } + } + + pub fn can_see_all_descendants(&self) -> bool { + use ChannelRole::*; + match self { + Admin | Member => true, + Guest | Banned => false, + } + } + + pub fn can_only_see_public_descendants(&self) -> bool { + use ChannelRole::*; + match self { + Guest => true, + Admin | Member | Banned => false, + } + } +} + +impl From for ChannelRole { + fn from(value: proto::ChannelRole) -> Self { + match value { + proto::ChannelRole::Admin => ChannelRole::Admin, + proto::ChannelRole::Member => ChannelRole::Member, + proto::ChannelRole::Guest => ChannelRole::Guest, + proto::ChannelRole::Banned => ChannelRole::Banned, + } + } +} + +impl Into for ChannelRole { + fn into(self) -> proto::ChannelRole { + match self { + ChannelRole::Admin => proto::ChannelRole::Admin, + ChannelRole::Member => proto::ChannelRole::Member, + ChannelRole::Guest => proto::ChannelRole::Guest, + ChannelRole::Banned => proto::ChannelRole::Banned, + } + } +} + +impl Into for ChannelRole { + fn into(self) -> i32 { + let proto: proto::ChannelRole = self.into(); + proto.into() + } +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum ChannelVisibility { + #[sea_orm(string_value = "public")] + Public, + #[sea_orm(string_value = "members")] + #[default] + Members, +} + +impl From for ChannelVisibility { + fn from(value: proto::ChannelVisibility) -> Self { + match value { + proto::ChannelVisibility::Public => ChannelVisibility::Public, + proto::ChannelVisibility::Members => ChannelVisibility::Members, + } + } +} + +impl Into for ChannelVisibility { + fn into(self) -> proto::ChannelVisibility { + match self { + ChannelVisibility::Public => proto::ChannelVisibility::Public, + ChannelVisibility::Members => proto::ChannelVisibility::Members, + } + } +} + +impl Into for ChannelVisibility { + fn into(self) -> i32 { + let proto: proto::ChannelVisibility = self.into(); + proto.into() + } +} diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 80bd8704b2..629e26f1a9 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod messages; +pub mod notifications; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/access_tokens.rs b/crates/collab/src/db/queries/access_tokens.rs index def9428a2b..589b6483df 100644 --- a/crates/collab/src/db/queries/access_tokens.rs +++ b/crates/collab/src/db/queries/access_tokens.rs @@ -1,4 +1,5 @@ use super::*; +use sea_orm::sea_query::Query; impl Database { pub async fn create_access_token( diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index c85432f2bb..9eddb1f618 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -16,7 +16,8 @@ impl Database { connection: ConnectionId, ) -> Result { self.transaction(|tx| async move { - 
self.check_user_is_channel_member(channel_id, user_id, &tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &tx) .await?; let buffer = channel::Model { @@ -129,9 +130,11 @@ impl Database { self.transaction(|tx| async move { let mut results = Vec::new(); for client_buffer in buffers { - let channel_id = ChannelId::from_proto(client_buffer.channel_id); + let channel = self + .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx) + .await?; if self - .check_user_is_channel_member(channel_id, user_id, &*tx) + .check_user_is_channel_participant(&channel, user_id, &*tx) .await .is_err() { @@ -139,9 +142,9 @@ impl Database { continue; } - let buffer = self.get_channel_buffer(channel_id, &*tx).await?; + let buffer = self.get_channel_buffer(channel.id, &*tx).await?; let mut collaborators = channel_buffer_collaborator::Entity::find() - .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id)) .all(&*tx) .await?; @@ -439,7 +442,8 @@ impl Database { Vec, )> { self.transaction(move |tx| async move { - self.check_user_is_channel_member(channel_id, user, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_member(&channel, user, &*tx) .await?; let buffer = buffer::Entity::find() @@ -482,7 +486,7 @@ impl Database { ) .await?; - channel_members = self.get_channel_members_internal(channel_id, &*tx).await?; + channel_members = self.get_channel_participants(&channel, &*tx).await?; let collaborators = self .get_channel_buffer_collaborators_internal(channel_id, &*tx) .await?; diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index c576d2406b..68b06e435d 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -1,8 +1,6 @@ use super::*; -use rpc::proto::ChannelEdge; -use smallvec::SmallVec; - -type ChannelDescendants = HashMap>; +use rpc::proto::channel_member::Kind; +use sea_orm::TryGetableMany; impl Database { #[cfg(test)] @@ -19,71 +17,242 @@ impl Database { .await } + #[cfg(test)] pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result { - self.create_channel(name, None, creator_id).await + Ok(self + .create_channel(name, None, creator_id) + .await? + .channel + .id) + } + + #[cfg(test)] + pub async fn create_sub_channel( + &self, + name: &str, + parent: ChannelId, + creator_id: UserId, + ) -> Result { + Ok(self + .create_channel(name, Some(parent), creator_id) + .await? 
+ .channel + .id) } pub async fn create_channel( &self, name: &str, - parent: Option, - creator_id: UserId, - ) -> Result { + parent_channel_id: Option, + admin_id: UserId, + ) -> Result { let name = Self::sanitize_channel_name(name)?; self.transaction(move |tx| async move { - if let Some(parent) = parent { - self.check_user_is_channel_admin(parent, creator_id, &*tx) + let mut parent = None; + + if let Some(parent_channel_id) = parent_channel_id { + let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?; + self.check_user_is_channel_admin(&parent_channel, admin_id, &*tx) .await?; + parent = Some(parent_channel); } let channel = channel::ActiveModel { + id: ActiveValue::NotSet, name: ActiveValue::Set(name.to_string()), - ..Default::default() + visibility: ActiveValue::Set(ChannelVisibility::Members), + parent_path: ActiveValue::Set( + parent + .as_ref() + .map_or(String::new(), |parent| parent.path()), + ), } .insert(&*tx) .await?; - if let Some(parent) = parent { - let sql = r#" - INSERT INTO channel_paths - (id_path, channel_id) - SELECT - id_path || $1 || '/', $2 - FROM - channel_paths - WHERE - channel_id = $3 - "#; - let channel_paths_stmt = Statement::from_sql_and_values( - self.pool.get_database_backend(), - sql, - [ - channel.id.to_proto().into(), - channel.id.to_proto().into(), - parent.to_proto().into(), - ], - ); - tx.execute(channel_paths_stmt).await?; + let participants_to_update; + if let Some(parent) = &parent { + participants_to_update = self + .participants_to_notify_for_channel_change(parent, &*tx) + .await?; } else { - channel_path::Entity::insert(channel_path::ActiveModel { + participants_to_update = vec![]; + + channel_member::ActiveModel { + id: ActiveValue::NotSet, channel_id: ActiveValue::Set(channel.id), - id_path: ActiveValue::Set(format!("/{}/", channel.id)), + user_id: ActiveValue::Set(admin_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Admin), + } + .insert(&*tx) + .await?; + }; + + Ok(CreateChannelResult { + channel: Channel::from_model(channel, ChannelRole::Admin), + participants_to_update, + }) + }) + .await + } + + pub async fn join_channel( + &self, + channel_id: ChannelId, + user_id: UserId, + connection: ConnectionId, + environment: &str, + ) -> Result<(JoinRoom, Option, ChannelRole)> { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let mut role = self.channel_role_for_user(&channel, user_id, &*tx).await?; + + let mut accept_invite_result = None; + + if role.is_none() { + if let Some(invitation) = self + .pending_invite_for_channel(&channel, user_id, &*tx) + .await? + { + // note, this may be a parent channel + role = Some(invitation.role); + channel_member::Entity::update(channel_member::ActiveModel { + accepted: ActiveValue::Set(true), + ..invitation.into_active_model() + }) + .exec(&*tx) + .await?; + + accept_invite_result = Some( + self.calculate_membership_updated(&channel, user_id, &*tx) + .await?, + ); + + debug_assert!( + self.channel_role_for_user(&channel, user_id, &*tx).await? == role + ); + } + } + + if channel.visibility == ChannelVisibility::Public { + role = Some(ChannelRole::Guest); + let channel_to_join = self + .public_ancestors_including_self(&channel, &*tx) + .await? 
+ .first() + .cloned() + .unwrap_or(channel.clone()); + + channel_member::Entity::insert(channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_to_join.id), + user_id: ActiveValue::Set(user_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Guest), }) .exec(&*tx) .await?; + + accept_invite_result = Some( + self.calculate_membership_updated(&channel_to_join, user_id, &*tx) + .await?, + ); + + debug_assert!(self.channel_role_for_user(&channel, user_id, &*tx).await? == role); } - channel_member::ActiveModel { - channel_id: ActiveValue::Set(channel.id), - user_id: ActiveValue::Set(creator_id), - accepted: ActiveValue::Set(true), - admin: ActiveValue::Set(true), - ..Default::default() + if role.is_none() || role == Some(ChannelRole::Banned) { + Err(anyhow!("not allowed"))? } - .insert(&*tx) - .await?; - Ok(channel.id) + let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + let room_id = self + .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx) + .await?; + + self.join_channel_room_internal(room_id, user_id, connection, &*tx) + .await + .map(|jr| (jr, accept_invite_result, role.unwrap())) + }) + .await + } + + pub async fn set_channel_visibility( + &self, + channel_id: ChannelId, + visibility: ChannelVisibility, + admin_id: UserId, + ) -> Result { + self.transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + + self.check_user_is_channel_admin(&channel, admin_id, &*tx) + .await?; + + let previous_members = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + let mut model = channel.into_active_model(); + model.visibility = ActiveValue::Set(visibility); + let channel = model.update(&*tx).await?; + + let mut participants_to_update: HashMap = self + .participants_to_notify_for_channel_change(&channel, &*tx) + .await? + .into_iter() + .collect(); + + let mut channels_to_remove: Vec = vec![]; + let mut participants_to_remove: HashSet = HashSet::default(); + match visibility { + ChannelVisibility::Members => { + let all_descendents: Vec = self + .get_channel_descendants_including_self(vec![channel_id], &*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect(); + + channels_to_remove = channel::Entity::find() + .filter( + channel::Column::Id + .is_in(all_descendents) + .and(channel::Column::Visibility.eq(ChannelVisibility::Public)), + ) + .all(&*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect(); + + channels_to_remove.push(channel_id); + + for member in previous_members { + if member.role.can_only_see_public_descendants() { + participants_to_remove.insert(member.user_id); + } + } + } + ChannelVisibility::Public => { + if let Some(public_parent) = self.public_parent_channel(&channel, &*tx).await? 
{ + let parent_updates = self + .participants_to_notify_for_channel_change(&public_parent, &*tx) + .await?; + + for (user_id, channels) in parent_updates { + participants_to_update.insert(user_id, channels); + } + } + } + } + + Ok(SetChannelVisibilityResult { + participants_to_update, + participants_to_remove, + channels_to_remove, + }) }) .await } @@ -94,37 +263,12 @@ impl Database { user_id: UserId, ) -> Result<(Vec, Vec)> { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, user_id, &*tx) .await?; - // Don't remove descendant channels that have additional parents. - let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?; - { - let mut channels_to_keep = channel_path::Entity::find() - .filter( - channel_path::Column::ChannelId - .is_in( - channels_to_remove - .keys() - .copied() - .filter(|&id| id != channel_id), - ) - .and( - channel_path::Column::IdPath - .not_like(&format!("%/{}/%", channel_id)), - ), - ) - .stream(&*tx) - .await?; - while let Some(row) = channels_to_keep.next().await { - let row = row?; - channels_to_remove.remove(&row.channel_id); - } - } - - let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?; let members_to_notify: Vec = channel_member::Entity::find() - .filter(channel_member::Column::ChannelId.is_in(channel_ancestors)) + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) .select_only() .column(channel_member::Column::UserId) .distinct() @@ -132,25 +276,19 @@ impl Database { .all(&*tx) .await?; + let channels_to_remove = self + .get_channel_descendants_including_self(vec![channel.id], &*tx) + .await? + .into_iter() + .map(|channel| channel.id) + .collect::>(); + channel::Entity::delete_many() - .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied())) + .filter(channel::Column::Id.is_in(channels_to_remove.iter().copied())) .exec(&*tx) .await?; - // Delete any other paths that include this channel - let sql = r#" - DELETE FROM channel_paths - WHERE - id_path LIKE '%' || $1 || '%' - "#; - let channel_paths_stmt = Statement::from_sql_and_values( - self.pool.get_database_backend(), - sql, - [channel_id.to_proto().into()], - ); - tx.execute(channel_paths_stmt).await?; - - Ok((channels_to_remove.into_keys().collect(), members_to_notify)) + Ok((channels_to_remove, members_to_notify)) }) .await } @@ -160,23 +298,44 @@ impl Database { channel_id: ChannelId, invitee_id: UserId, inviter_id: UserId, - is_admin: bool, - ) -> Result<()> { + role: ChannelRole, + ) -> Result { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, inviter_id, &*tx) .await?; channel_member::ActiveModel { + id: ActiveValue::NotSet, channel_id: ActiveValue::Set(channel_id), user_id: ActiveValue::Set(invitee_id), accepted: ActiveValue::Set(false), - admin: ActiveValue::Set(is_admin), - ..Default::default() + role: ActiveValue::Set(role), } .insert(&*tx) .await?; - Ok(()) + let channel = Channel::from_model(channel, role); + + let notifications = self + .create_notification( + invitee_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: channel.name.clone(), + inviter_id: inviter_id.to_proto(), + }, + true, + &*tx, + ) + .await? 
+ .into_iter() + .collect(); + + Ok(InviteMemberResult { + channel, + notifications, + }) }) .await } @@ -192,24 +351,37 @@ impl Database { pub async fn rename_channel( &self, channel_id: ChannelId, - user_id: UserId, + admin_id: UserId, new_name: &str, - ) -> Result { + ) -> Result { self.transaction(move |tx| async move { let new_name = Self::sanitize_channel_name(new_name)?.to_string(); - self.check_user_is_channel_admin(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; - channel::ActiveModel { - id: ActiveValue::Unchanged(channel_id), - name: ActiveValue::Set(new_name.clone()), - ..Default::default() - } - .update(&*tx) - .await?; + let mut model = channel.into_active_model(); + model.name = ActiveValue::Set(new_name.clone()); + let channel = model.update(&*tx).await?; - Ok(new_name) + let participants = self + .get_channel_participant_details_internal(&channel, &*tx) + .await?; + + Ok(RenameChannelResult { + channel: Channel::from_model(channel.clone(), role), + participants_to_update: participants + .iter() + .map(|participant| { + ( + participant.user_id, + Channel::from_model(channel.clone(), participant.role), + ) + }) + .collect(), + }) }) .await } @@ -219,10 +391,12 @@ impl Database { channel_id: ChannelId, user_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(move |tx| async move { - let rows_affected = if accept { - channel_member::Entity::update_many() + let channel = self.get_channel_internal(channel_id, &*tx).await?; + + let membership_update = if accept { + let rows_affected = channel_member::Entity::update_many() .set(channel_member::ActiveModel { accepted: ActiveValue::Set(accept), ..Default::default() @@ -235,35 +409,91 @@ impl Database { ) .exec(&*tx) .await? - .rows_affected - } else { - channel_member::ActiveModel { - channel_id: ActiveValue::Unchanged(channel_id), - user_id: ActiveValue::Unchanged(user_id), - ..Default::default() + .rows_affected; + + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; } - .delete(&*tx) - .await? - .rows_affected + + Some( + self.calculate_membership_updated(&channel, user_id, &*tx) + .await?, + ) + } else { + let rows_affected = channel_member::Entity::delete_many() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await? + .rows_affected; + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; + } + + None }; - if rows_affected == 0 { - Err(anyhow!("no such invitation"))?; - } - - Ok(()) + Ok(RespondToChannelInvite { + membership_update, + notifications: self + .mark_notification_as_read_with_response( + user_id, + &rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + accept, + &*tx, + ) + .await? + .into_iter() + .collect(), + }) }) .await } + async fn calculate_membership_updated( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let new_channels = self.get_user_channels(user_id, Some(channel), &*tx).await?; + let removed_channels = self + .get_channel_descendants_including_self(vec![channel.id], &*tx) + .await? 
+ .into_iter() + .filter_map(|channel| { + if !new_channels.channels.iter().any(|c| c.id == channel.id) { + Some(channel.id) + } else { + None + } + }) + .collect::>(); + + Ok(MembershipUpdated { + channel_id: channel.id, + new_channels, + removed_channels, + }) + } + pub async fn remove_channel_member( &self, channel_id: ChannelId, member_id: UserId, - remover_id: UserId, - ) -> Result<()> { + admin_id: UserId, + ) -> Result { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, remover_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; let result = channel_member::Entity::delete_many() @@ -279,13 +509,30 @@ impl Database { Err(anyhow!("no such member"))?; } - Ok(()) + Ok(RemoveChannelMemberResult { + membership_update: self + .calculate_membership_updated(&channel, member_id, &*tx) + .await?, + notification_id: self + .remove_notification( + member_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + &*tx, + ) + .await?, + }) }) .await } pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result> { self.transaction(|tx| async move { + let mut role_for_channel: HashMap = HashMap::default(); + let channel_invites = channel_member::Entity::find() .filter( channel_member::Column::UserId @@ -295,22 +542,20 @@ impl Database { .all(&*tx) .await?; + for invite in channel_invites { + role_for_channel.insert(invite.channel_id, invite.role); + } + let channels = channel::Entity::find() - .filter( - channel::Column::Id.is_in( - channel_invites - .into_iter() - .map(|channel_member| channel_member.channel_id), - ), - ) + .filter(channel::Column::Id.is_in(role_for_channel.keys().copied())) .all(&*tx) .await?; let channels = channels .into_iter() - .map(|channel| Channel { - id: channel.id, - name: channel.name, + .filter_map(|channel| { + let role = *role_for_channel.get(&channel.id)?; + Some(Channel::from_model(channel, role)) }) .collect(); @@ -319,88 +564,11 @@ impl Database { .await } - async fn get_channel_graph( - &self, - parents_by_child_id: ChannelDescendants, - trim_dangling_parents: bool, - tx: &DatabaseTransaction, - ) -> Result { - let mut channels = Vec::with_capacity(parents_by_child_id.len()); - { - let mut rows = channel::Entity::find() - .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied())) - .stream(&*tx) - .await?; - while let Some(row) = rows.next().await { - let row = row?; - channels.push(Channel { - id: row.id, - name: row.name, - }) - } - } - - let mut edges = Vec::with_capacity(parents_by_child_id.len()); - for (channel, parents) in parents_by_child_id.iter() { - for parent in parents.into_iter() { - if trim_dangling_parents { - if parents_by_child_id.contains_key(parent) { - edges.push(ChannelEdge { - channel_id: channel.to_proto(), - parent_id: parent.to_proto(), - }); - } - } else { - edges.push(ChannelEdge { - channel_id: channel.to_proto(), - parent_id: parent.to_proto(), - }); - } - } - } - - Ok(ChannelGraph { channels, edges }) - } - pub async fn get_channels_for_user(&self, user_id: UserId) -> Result { self.transaction(|tx| async move { let tx = tx; - let channel_memberships = channel_member::Entity::find() - .filter( - channel_member::Column::UserId - .eq(user_id) - .and(channel_member::Column::Accepted.eq(true)), - ) - .all(&*tx) - .await?; - - self.get_user_channels(user_id, channel_memberships, &tx) - 
.await - }) - .await - } - - pub async fn get_channel_for_user( - &self, - channel_id: ChannelId, - user_id: UserId, - ) -> Result { - self.transaction(|tx| async move { - let tx = tx; - - let channel_membership = channel_member::Entity::find() - .filter( - channel_member::Column::UserId - .eq(user_id) - .and(channel_member::Column::ChannelId.eq(channel_id)) - .and(channel_member::Column::Accepted.eq(true)), - ) - .all(&*tx) - .await?; - - self.get_user_channels(user_id, channel_membership, &tx) - .await + self.get_user_channels(user_id, None, &tx).await }) .await } @@ -408,22 +576,78 @@ impl Database { pub async fn get_user_channels( &self, user_id: UserId, - channel_memberships: Vec, + ancestor_channel: Option<&channel::Model>, tx: &DatabaseTransaction, ) -> Result { - let parents_by_child_id = self - .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx) + let channel_memberships = channel_member::Entity::find() + .filter( + channel_member::Column::UserId + .eq(user_id) + .and(channel_member::Column::Accepted.eq(true)), + ) + .all(&*tx) .await?; - let channels_with_admin_privileges = channel_memberships - .iter() - .filter_map(|membership| membership.admin.then_some(membership.channel_id)) + let descendants = self + .get_channel_descendants_including_self( + channel_memberships.iter().map(|m| m.channel_id), + &*tx, + ) + .await?; + + let mut roles_by_channel_id: HashMap = HashMap::default(); + for membership in channel_memberships.iter() { + roles_by_channel_id.insert(membership.channel_id, membership.role); + } + + let mut visible_channel_ids: HashSet = HashSet::default(); + + let channels: Vec = descendants + .into_iter() + .filter_map(|channel| { + let parent_role = channel + .parent_id() + .and_then(|parent_id| roles_by_channel_id.get(&parent_id)); + + let role = if let Some(parent_role) = parent_role { + let role = if let Some(existing_role) = roles_by_channel_id.get(&channel.id) { + existing_role.max(*parent_role) + } else { + *parent_role + }; + roles_by_channel_id.insert(channel.id, role); + role + } else { + *roles_by_channel_id.get(&channel.id)? 
+ }; + + let can_see_parent_paths = role.can_see_all_descendants() + || role.can_only_see_public_descendants() + && channel.visibility == ChannelVisibility::Public; + if !can_see_parent_paths { + return None; + } + + visible_channel_ids.insert(channel.id); + + if let Some(ancestor) = ancestor_channel { + if !channel + .ancestors_including_self() + .any(|id| id == ancestor.id) + { + return None; + } + } + + let mut channel = Channel::from_model(channel, role); + channel + .parent_path + .retain(|id| visible_channel_ids.contains(&id)); + + Some(channel) + }) .collect(); - let graph = self - .get_channel_graph(parents_by_child_id, true, &tx) - .await?; - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryUserIdsAndChannelIds { ChannelId, @@ -434,7 +658,7 @@ impl Database { { let mut rows = room_participant::Entity::find() .inner_join(room::Entity) - .filter(room::Column::ChannelId.is_in(graph.channels.iter().map(|c| c.id))) + .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id))) .select_only() .column(room::Column::ChannelId) .column(room_participant::Column::UserId) @@ -447,7 +671,7 @@ impl Database { } } - let channel_ids = graph.channels.iter().map(|c| c.id).collect::>(); + let channel_ids = channels.iter().map(|c| c.id).collect::>(); let channel_buffer_changes = self .unseen_channel_buffer_changes(user_id, &channel_ids, &*tx) .await?; @@ -457,228 +681,428 @@ impl Database { .await?; Ok(ChannelsForUser { - channels: graph, + channels, channel_participants, - channels_with_admin_privileges, unseen_buffer_changes: channel_buffer_changes, channel_messages: unseen_messages, }) } - pub async fn get_channel_members(&self, id: ChannelId) -> Result> { - self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await }) - .await + async fn participants_to_notify_for_channel_change( + &self, + new_parent: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let mut results: Vec<(UserId, ChannelsForUser)> = Vec::new(); + + let members = self + .get_channel_participant_details_internal(new_parent, &*tx) + .await?; + + for member in members.iter() { + if !member.role.can_see_all_descendants() { + continue; + } + results.push(( + member.user_id, + self.get_user_channels(member.user_id, Some(new_parent), &*tx) + .await?, + )) + } + + let public_parents = self + .public_ancestors_including_self(new_parent, &*tx) + .await?; + let public_parent = public_parents.last(); + + let Some(public_parent) = public_parent else { + return Ok(results); + }; + + // could save some time in the common case by skipping this if the + // new channel is not public and has no public descendants. + let public_members = if public_parent == new_parent { + members + } else { + self.get_channel_participant_details_internal(public_parent, &*tx) + .await? 
+ }; + + for member in public_members { + if !member.role.can_only_see_public_descendants() { + continue; + }; + results.push(( + member.user_id, + self.get_user_channels(member.user_id, Some(public_parent), &*tx) + .await?, + )) + } + + Ok(results) } - pub async fn set_channel_member_admin( + pub async fn set_channel_member_role( &self, channel_id: ChannelId, - from: UserId, + admin_id: UserId, for_user: UserId, - admin: bool, - ) -> Result<()> { + role: ChannelRole, + ) -> Result { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, from, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; - let result = channel_member::Entity::update_many() + let membership = channel_member::Entity::find() .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(for_user)), ) - .set(channel_member::ActiveModel { - admin: ActiveValue::set(admin), - ..Default::default() - }) - .exec(&*tx) + .one(&*tx) .await?; - if result.rows_affected == 0 { - Err(anyhow!("no such member"))?; - } + let Some(membership) = membership else { + Err(anyhow!("no such member"))? + }; - Ok(()) + let mut update = membership.into_active_model(); + update.role = ActiveValue::Set(role); + let updated = channel_member::Entity::update(update).exec(&*tx).await?; + + if updated.accepted { + Ok(SetMemberRoleResult::MembershipUpdated( + self.calculate_membership_updated(&channel, for_user, &*tx) + .await?, + )) + } else { + Ok(SetMemberRoleResult::InviteUpdated(Channel::from_model( + channel, role, + ))) + } }) .await } - pub async fn get_channel_member_details( + pub async fn get_channel_participant_details( &self, channel_id: ChannelId, user_id: UserId, ) -> Result> { - self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, user_id, &*tx) - .await?; + let (role, members) = self + .transaction(move |tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + Ok(( + role, + self.get_channel_participant_details_internal(&channel, &*tx) + .await?, + )) + }) + .await?; - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryMemberDetails { - UserId, - Admin, - IsDirectMember, - Accepted, - } - - let tx = tx; - let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?; - let mut stream = channel_member::Entity::find() - .distinct() - .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied())) - .select_only() - .column(channel_member::Column::UserId) - .column(channel_member::Column::Admin) - .column_as( - channel_member::Column::ChannelId.eq(channel_id), - QueryMemberDetails::IsDirectMember, - ) - .column(channel_member::Column::Accepted) - .order_by_asc(channel_member::Column::UserId) - .into_values::<_, QueryMemberDetails>() - .stream(&*tx) - .await?; - - let mut rows = Vec::::new(); - while let Some(row) = stream.next().await { - let (user_id, is_admin, is_direct_member, is_invite_accepted): ( - UserId, - bool, - bool, - bool, - ) = row?; - let kind = match (is_direct_member, is_invite_accepted) { - (true, true) => proto::channel_member::Kind::Member, - (true, false) => proto::channel_member::Kind::Invitee, - (false, true) => proto::channel_member::Kind::AncestorMember, - (false, false) => continue, - }; - let user_id = user_id.to_proto(); - let kind = kind.into(); - if let Some(last_row) = 
rows.last_mut() { - if last_row.user_id == user_id { - if is_direct_member { - last_row.kind = kind; - last_row.admin = is_admin; - } - continue; + if role == ChannelRole::Admin { + Ok(members + .into_iter() + .map(|channel_member| channel_member.to_proto()) + .collect()) + } else { + return Ok(members + .into_iter() + .filter_map(|member| { + if member.kind == proto::channel_member::Kind::Invitee { + return None; } - } - rows.push(proto::ChannelMember { - user_id, - kind, - admin: is_admin, - }); - } - - Ok(rows) - }) - .await + Some(ChannelMember { + role: member.role, + user_id: member.user_id, + kind: proto::channel_member::Kind::Member, + }) + }) + .map(|channel_member| channel_member.to_proto()) + .collect()); + } } - pub async fn get_channel_members_internal( + async fn get_channel_participant_details_internal( &self, - id: ChannelId, + channel: &channel::Model, tx: &DatabaseTransaction, - ) -> Result> { - let ancestor_ids = self.get_channel_ancestors(id, tx).await?; - let user_ids = channel_member::Entity::find() - .distinct() - .filter( - channel_member::Column::ChannelId - .is_in(ancestor_ids.iter().copied()) - .and(channel_member::Column::Accepted.eq(true)), - ) + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryMemberDetails { + UserId, + Role, + IsDirectMember, + Accepted, + Visibility, + } + + let mut stream = channel_member::Entity::find() + .left_join(channel::Entity) + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) .select_only() .column(channel_member::Column::UserId) - .into_values::<_, QueryUserIds>() - .all(&*tx) + .column(channel_member::Column::Role) + .column_as( + channel_member::Column::ChannelId.eq(channel.id), + QueryMemberDetails::IsDirectMember, + ) + .column(channel_member::Column::Accepted) + .column(channel::Column::Visibility) + .into_values::<_, QueryMemberDetails>() + .stream(&*tx) .await?; - Ok(user_ids) + + let mut user_details: HashMap = HashMap::default(); + + while let Some(user_membership) = stream.next().await { + let (user_id, channel_role, is_direct_member, is_invite_accepted, visibility): ( + UserId, + ChannelRole, + bool, + bool, + ChannelVisibility, + ) = user_membership?; + let kind = match (is_direct_member, is_invite_accepted) { + (true, true) => proto::channel_member::Kind::Member, + (true, false) => proto::channel_member::Kind::Invitee, + (false, true) => proto::channel_member::Kind::AncestorMember, + (false, false) => continue, + }; + + if channel_role == ChannelRole::Guest + && visibility != ChannelVisibility::Public + && channel.visibility != ChannelVisibility::Public + { + continue; + } + + if let Some(details_mut) = user_details.get_mut(&user_id) { + if channel_role.should_override(details_mut.role) { + details_mut.role = channel_role; + } + if kind == Kind::Member { + details_mut.kind = kind; + // the UI is going to be a bit confusing if you already have permissions + // that are greater than or equal to the ones you're being invited to. 
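    // [Editor's aside] Net effect of this merge loop: a user's `kind` resolves
    // with Member taking precedence over Invitee, and Invitee over
    // AncestorMember, while `role` is raised independently through
    // `ChannelRole::should_override` (Admin > Member > Banned > Guest).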
+ } else if kind == Kind::Invitee && details_mut.kind == Kind::AncestorMember { + details_mut.kind = kind; + } + } else { + user_details.insert( + user_id, + ChannelMember { + user_id, + kind, + role: channel_role, + }, + ); + } + } + + Ok(user_details + .into_iter() + .map(|(_, details)| details) + .collect()) } - pub async fn check_user_is_channel_member( + pub async fn get_channel_participants( &self, - channel_id: ChannelId, - user_id: UserId, + channel: &channel::Model, tx: &DatabaseTransaction, - ) -> Result<()> { - let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; - channel_member::Entity::find() - .filter( - channel_member::Column::ChannelId - .is_in(channel_ids) - .and(channel_member::Column::UserId.eq(user_id)), - ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?; - Ok(()) + ) -> Result> { + let participants = self + .get_channel_participant_details_internal(channel, &*tx) + .await?; + Ok(participants + .into_iter() + .map(|member| member.user_id) + .collect()) } pub async fn check_user_is_channel_admin( &self, - channel_id: ChannelId, + channel: &channel::Model, user_id: UserId, tx: &DatabaseTransaction, - ) -> Result<()> { - let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; - channel_member::Entity::find() + ) -> Result { + let role = self.channel_role_for_user(channel, user_id, tx).await?; + match role { + Some(ChannelRole::Admin) => Ok(role.unwrap()), + Some(ChannelRole::Member) + | Some(ChannelRole::Banned) + | Some(ChannelRole::Guest) + | None => Err(anyhow!( + "user is not a channel admin or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_member( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let channel_role = self.channel_role_for_user(channel, user_id, tx).await?; + match channel_role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()), + Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( + "user is not a channel member or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_participant( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result { + let role = self.channel_role_for_user(channel, user_id, tx).await?; + match role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => { + Ok(role.unwrap()) + } + Some(ChannelRole::Banned) | None => Err(anyhow!( + "user is not a channel participant or channel does not exist" + ))?, + } + } + + pub async fn pending_invite_for_channel( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + let row = channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.is_in(channel.ancestors_including_self())) + .filter(channel_member::Column::UserId.eq(user_id)) + .filter(channel_member::Column::Accepted.eq(false)) + .one(&*tx) + .await?; + + Ok(row) + } + + pub async fn public_parent_channel( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let mut path = self.public_ancestors_including_self(channel, &*tx).await?; + if path.last().unwrap().id == channel.id { + path.pop(); + } + Ok(path.pop()) + } + + pub async fn public_ancestors_including_self( + &self, + channel: &channel::Model, + tx: &DatabaseTransaction, + ) -> Result> { + let visible_channels = channel::Entity::find() + 
.filter(channel::Column::Id.is_in(channel.ancestors_including_self())) + .filter(channel::Column::Visibility.eq(ChannelVisibility::Public)) + .order_by_asc(channel::Column::ParentPath) + .all(&*tx) + .await?; + + Ok(visible_channels) + } + + pub async fn channel_role_for_user( + &self, + channel: &channel::Model, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelMembership { + ChannelId, + Role, + Visibility, + } + + let mut rows = channel_member::Entity::find() + .left_join(channel::Entity) .filter( channel_member::Column::ChannelId - .is_in(channel_ids) + .is_in(channel.ancestors_including_self()) .and(channel_member::Column::UserId.eq(user_id)) - .and(channel_member::Column::Admin.eq(true)), + .and(channel_member::Column::Accepted.eq(true)), ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?; - Ok(()) - } - - /// Returns the channel ancestors, deepest first - pub async fn get_channel_ancestors( - &self, - channel_id: ChannelId, - tx: &DatabaseTransaction, - ) -> Result> { - let paths = channel_path::Entity::find() - .filter(channel_path::Column::ChannelId.eq(channel_id)) - .order_by(channel_path::Column::IdPath, sea_orm::Order::Desc) - .all(tx) + .select_only() + .column(channel_member::Column::ChannelId) + .column(channel_member::Column::Role) + .column(channel::Column::Visibility) + .into_values::<_, QueryChannelMembership>() + .stream(&*tx) .await?; - let mut channel_ids = Vec::new(); - for path in paths { - for id in path.id_path.trim_matches('/').split('/') { - if let Ok(id) = id.parse() { - let id = ChannelId::from_proto(id); - if let Err(ix) = channel_ids.binary_search(&id) { - channel_ids.insert(ix, id); + + let mut user_role: Option = None; + + let mut is_participant = false; + let mut current_channel_visibility = None; + + // note these channels are not iterated in any particular order, + // our current logic takes the highest permission available. + while let Some(row) = rows.next().await { + let (membership_channel, role, visibility): ( + ChannelId, + ChannelRole, + ChannelVisibility, + ) = row?; + + match role { + ChannelRole::Admin | ChannelRole::Member | ChannelRole::Banned => { + if let Some(users_role) = user_role { + user_role = Some(users_role.max(role)); + } else { + user_role = Some(role) } } + ChannelRole::Guest if visibility == ChannelVisibility::Public => { + is_participant = true + } + ChannelRole::Guest => {} + } + if channel.id == membership_channel { + current_channel_visibility = Some(visibility); } } - Ok(channel_ids) + // free up database connection + drop(rows); + + if is_participant && user_role.is_none() { + if current_channel_visibility.is_none() { + current_channel_visibility = channel::Entity::find() + .filter(channel::Column::Id.eq(channel.id)) + .one(&*tx) + .await? + .map(|channel| channel.visibility); + } + if current_channel_visibility == Some(ChannelVisibility::Public) { + user_role = Some(ChannelRole::Guest); + } + } + + Ok(user_role) } - /// Returns the channel descendants, - /// Structured as a map from child ids to their parent ids - /// For example, the descendants of 'a' in this DAG: - /// - /// /- b -\ - /// a -- c -- d - /// - /// would be: - /// { - /// a: [], - /// b: [a], - /// c: [a], - /// d: [a, c], - /// } - async fn get_channel_descendants( + // Get the descendants of the given set if channels, ordered by their + // path. 
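    // [Editor's aside] Sketch of the materialized-path convention this query
    // relies on: `parent_path` holds the slash-terminated ids of a channel's
    // ancestors, so a channel #5 nested under #1 -> #2 stores parent_path
    // "1/2/" and its own path is "1/2/5/". A channel's descendants are the
    // rows whose parent_path starts with that channel's own path (the
    // `LIKE parent_path || id || '/%'` filter in the raw SQL below), and
    // ordering by the full path keeps every channel grouped with its
    // descendants.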
+ async fn get_channel_descendants_including_self( &self, channel_ids: impl IntoIterator, tx: &DatabaseTransaction, - ) -> Result { + ) -> Result> { let mut values = String::new(); for id in channel_ids { if !values.is_empty() { @@ -688,403 +1112,201 @@ impl Database { } if values.is_empty() { - return Ok(HashMap::default()); + return Ok(vec![]); } let sql = format!( r#" - SELECT - descendant_paths.* + SELECT DISTINCT + descendant_channels.*, + descendant_channels.parent_path || descendant_channels.id as full_path FROM - channel_paths parent_paths, channel_paths descendant_paths + channels parent_channels, channels descendant_channels WHERE - parent_paths.channel_id IN ({values}) AND - descendant_paths.id_path LIKE (parent_paths.id_path || '%') - "# + descendant_channels.id IN ({values}) OR + ( + parent_channels.id IN ({values}) AND + descendant_channels.parent_path LIKE (parent_channels.parent_path || parent_channels.id || '/%') + ) + ORDER BY + full_path ASC + "# ); - let stmt = Statement::from_string(self.pool.get_database_backend(), sql); - - let mut parents_by_child_id: ChannelDescendants = HashMap::default(); - let mut paths = channel_path::Entity::find() - .from_raw_sql(stmt) - .stream(tx) - .await?; - - while let Some(path) = paths.next().await { - let path = path?; - let ids = path.id_path.trim_matches('/').split('/'); - let mut parent_id = None; - for id in ids { - if let Ok(id) = id.parse() { - let id = ChannelId::from_proto(id); - if id == path.channel_id { - break; - } - parent_id = Some(id); - } - } - let entry = parents_by_child_id.entry(path.channel_id).or_default(); - if let Some(parent_id) = parent_id { - entry.insert(parent_id); - } - } - - Ok(parents_by_child_id) + Ok(channel::Entity::find() + .from_raw_sql(Statement::from_string( + self.pool.get_database_backend(), + sql, + )) + .all(tx) + .await?) } - /// Returns the channel with the given ID and: - /// - true if the user is a member - /// - false if the user hasn't accepted the invitation yet - pub async fn get_channel( - &self, - channel_id: ChannelId, - user_id: UserId, - ) -> Result> { + /// Returns the channel with the given ID + pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result { self.transaction(|tx| async move { - let tx = tx; + let channel = self.get_channel_internal(channel_id, &*tx).await?; + let role = self + .check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; - let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?; - - if let Some(channel) = channel { - if self - .check_user_is_channel_member(channel_id, user_id, &*tx) - .await - .is_err() - { - return Ok(None); - } - - let channel_membership = channel_member::Entity::find() - .filter( - channel_member::Column::ChannelId - .eq(channel_id) - .and(channel_member::Column::UserId.eq(user_id)), - ) - .one(&*tx) - .await?; - - let is_accepted = channel_membership - .map(|membership| membership.accepted) - .unwrap_or(false); - - Ok(Some(( - Channel { - id: channel.id, - name: channel.name, - }, - is_accepted, - ))) - } else { - Ok(None) - } + Ok(Channel::from_model(channel, role)) }) .await } - pub async fn get_or_create_channel_room( + pub async fn get_channel_internal( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Entity::find_by_id(channel_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such channel"))?) 
+ } + + pub(crate) async fn get_or_create_channel_room( &self, channel_id: ChannelId, live_kit_room: &str, - enviroment: &str, - ) -> Result { - self.transaction(|tx| async move { - let tx = tx; - - let room = room::Entity::find() - .filter(room::Column::ChannelId.eq(channel_id)) - .one(&*tx) - .await?; - - let room_id = if let Some(room) = room { - room.id - } else { - let result = room::Entity::insert(room::ActiveModel { - channel_id: ActiveValue::Set(Some(channel_id)), - live_kit_room: ActiveValue::Set(live_kit_room.to_string()), - enviroment: ActiveValue::Set(Some(enviroment.to_string())), - ..Default::default() - }) - .exec(&*tx) - .await?; - - result.last_insert_id - }; - - Ok(room_id) - }) - .await - } - - // Insert an edge from the given channel to the given other channel. - pub async fn link_channel( - &self, - user: UserId, - channel: ChannelId, - to: ChannelId, - ) -> Result { - self.transaction(|tx| async move { - // Note that even with these maxed permissions, this linking operation - // is still insecure because you can't remove someone's permissions to a - // channel if they've linked the channel to one where they're an admin. - self.check_user_is_channel_admin(channel, user, &*tx) - .await?; - - self.link_channel_internal(user, channel, to, &*tx).await - }) - .await - } - - pub async fn link_channel_internal( - &self, - user: UserId, - channel: ChannelId, - to: ChannelId, + environment: &str, tx: &DatabaseTransaction, - ) -> Result { - self.check_user_is_channel_admin(to, user, &*tx).await?; - - let paths = channel_path::Entity::find() - .filter(channel_path::Column::IdPath.like(&format!("%/{}/%", channel))) - .all(tx) + ) -> Result { + let room = room::Entity::find() + .filter(room::Column::ChannelId.eq(channel_id)) + .one(&*tx) .await?; - let mut new_path_suffixes = HashSet::default(); - for path in paths { - if let Some(start_offset) = path.id_path.find(&format!("/{}/", channel)) { - new_path_suffixes.insert(( - path.channel_id, - path.id_path[(start_offset + 1)..].to_string(), - )); - } - } - - let paths_to_new_parent = channel_path::Entity::find() - .filter(channel_path::Column::ChannelId.eq(to)) - .all(tx) - .await?; - - let mut new_paths = Vec::new(); - for path in paths_to_new_parent { - if path.id_path.contains(&format!("/{}/", channel)) { - Err(anyhow!("cycle"))?; - } - - new_paths.extend(new_path_suffixes.iter().map(|(channel_id, path_suffix)| { - channel_path::ActiveModel { - channel_id: ActiveValue::Set(*channel_id), - id_path: ActiveValue::Set(format!("{}{}", &path.id_path, path_suffix)), + let room_id = if let Some(room) = room { + if let Some(env) = room.enviroment { + if &env != environment { + Err(anyhow!("must join using the {} release", env))?; } - })); - } - - channel_path::Entity::insert_many(new_paths) + } + room.id + } else { + let result = room::Entity::insert(room::ActiveModel { + channel_id: ActiveValue::Set(Some(channel_id)), + live_kit_room: ActiveValue::Set(live_kit_room.to_string()), + enviroment: ActiveValue::Set(Some(environment.to_string())), + ..Default::default() + }) .exec(&*tx) .await?; - // remove any root edges for the channel we just linked - { - channel_path::Entity::delete_many() - .filter(channel_path::Column::IdPath.like(&format!("/{}/%", channel))) - .exec(&*tx) - .await?; - } + result.last_insert_id + }; - let mut channel_descendants = self.get_channel_descendants([channel], &*tx).await?; - if let Some(channel) = channel_descendants.get_mut(&channel) { - // Remove the other parents - channel.clear(); - channel.insert(to); - } - 
- let channels = self - .get_channel_graph(channel_descendants, false, &*tx) - .await?; - - Ok(channels) + Ok(room_id) } - /// Unlink a channel from a given parent. This will add in a root edge if - /// the channel has no other parents after this operation. - pub async fn unlink_channel( - &self, - user: UserId, - channel: ChannelId, - from: ChannelId, - ) -> Result<()> { - self.transaction(|tx| async move { - // Note that even with these maxed permissions, this linking operation - // is still insecure because you can't remove someone's permissions to a - // channel if they've linked the channel to one where they're an admin. - self.check_user_is_channel_admin(channel, user, &*tx) - .await?; - - self.unlink_channel_internal(user, channel, from, &*tx) - .await?; - - Ok(()) - }) - .await - } - - pub async fn unlink_channel_internal( - &self, - user: UserId, - channel: ChannelId, - from: ChannelId, - tx: &DatabaseTransaction, - ) -> Result<()> { - self.check_user_is_channel_admin(from, user, &*tx).await?; - - let sql = r#" - DELETE FROM channel_paths - WHERE - id_path LIKE '%/' || $1 || '/' || $2 || '/%' - RETURNING id_path, channel_id - "#; - - let paths = channel_path::Entity::find() - .from_raw_sql(Statement::from_sql_and_values( - self.pool.get_database_backend(), - sql, - [from.to_proto().into(), channel.to_proto().into()], - )) - .all(&*tx) - .await?; - - let is_stranded = channel_path::Entity::find() - .filter(channel_path::Column::ChannelId.eq(channel)) - .count(&*tx) - .await? - == 0; - - // Make sure that there is always at least one path to the channel - if is_stranded { - let root_paths: Vec<_> = paths - .iter() - .map(|path| { - let start_offset = path.id_path.find(&format!("/{}/", channel)).unwrap(); - channel_path::ActiveModel { - channel_id: ActiveValue::Set(path.channel_id), - id_path: ActiveValue::Set(path.id_path[start_offset..].to_string()), - } - }) - .collect(); - channel_path::Entity::insert_many(root_paths) - .exec(&*tx) - .await?; - } - - Ok(()) - } - - /// Move a channel from one parent to another, returns the - /// Channels that were moved for notifying clients + /// Move a channel from one parent to another pub async fn move_channel( &self, - user: UserId, - channel: ChannelId, - from: ChannelId, - to: ChannelId, - ) -> Result { - if from == to { - return Ok(ChannelGraph { - channels: vec![], - edges: vec![], - }); - } - + channel_id: ChannelId, + new_parent_id: Option, + admin_id: UserId, + ) -> Result> { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel, user, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_admin(&channel, admin_id, &*tx) .await?; - let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?; + let new_parent_path; + let new_parent_channel; + if let Some(new_parent_id) = new_parent_id { + let new_parent = self.get_channel_internal(new_parent_id, &*tx).await?; + self.check_user_is_channel_admin(&new_parent, admin_id, &*tx) + .await?; - self.unlink_channel_internal(user, channel, from, &*tx) + new_parent_path = new_parent.path(); + new_parent_channel = Some(new_parent); + } else { + new_parent_path = String::new(); + new_parent_channel = None; + }; + + let previous_participants = self + .get_channel_participant_details_internal(&channel, &*tx) .await?; - Ok(moved_channels) + let old_path = format!("{}{}/", channel.parent_path, channel.id); + let new_path = format!("{}{}/", new_parent_path, channel.id); + + if old_path == new_path { + return 
Ok(None); + } + + let mut model = channel.into_active_model(); + model.parent_path = ActiveValue::Set(new_parent_path); + let channel = model.update(&*tx).await?; + + if new_parent_channel.is_none() { + channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(admin_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Admin), + } + .insert(&*tx) + .await?; + } + + let descendent_ids = + ChannelId::find_by_statement::(Statement::from_sql_and_values( + self.pool.get_database_backend(), + " + UPDATE channels SET parent_path = REPLACE(parent_path, $1, $2) + WHERE parent_path LIKE $3 || '%' + RETURNING id + ", + [old_path.clone().into(), new_path.into(), old_path.into()], + )) + .all(&*tx) + .await?; + + let participants_to_update: HashMap<_, _> = self + .participants_to_notify_for_channel_change( + new_parent_channel.as_ref().unwrap_or(&channel), + &*tx, + ) + .await? + .into_iter() + .collect(); + + let mut moved_channels: HashSet = HashSet::default(); + for id in descendent_ids { + moved_channels.insert(id); + } + moved_channels.insert(channel_id); + + let mut participants_to_remove: HashSet = HashSet::default(); + for participant in previous_participants { + if participant.kind == proto::channel_member::Kind::AncestorMember { + if !participants_to_update.contains_key(&participant.user_id) { + participants_to_remove.insert(participant.user_id); + } + } + } + + Ok(Some(MoveChannelResult { + participants_to_remove, + participants_to_update, + moved_channels, + })) }) .await } } +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +enum QueryIds { + Id, +} + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryUserIds { UserId, } - -#[derive(Debug)] -pub struct ChannelGraph { - pub channels: Vec, - pub edges: Vec, -} - -impl ChannelGraph { - pub fn is_empty(&self) -> bool { - self.channels.is_empty() && self.edges.is_empty() - } -} - -#[cfg(test)] -impl PartialEq for ChannelGraph { - fn eq(&self, other: &Self) -> bool { - // Order independent comparison for tests - let channels_set = self.channels.iter().collect::>(); - let other_channels_set = other.channels.iter().collect::>(); - let edges_set = self - .edges - .iter() - .map(|edge| (edge.channel_id, edge.parent_id)) - .collect::>(); - let other_edges_set = other - .edges - .iter() - .map(|edge| (edge.channel_id, edge.parent_id)) - .collect::>(); - - channels_set == other_channels_set && edges_set == other_edges_set - } -} - -#[cfg(not(test))] -impl PartialEq for ChannelGraph { - fn eq(&self, other: &Self) -> bool { - self.channels == other.channels && self.edges == other.edges - } -} - -struct SmallSet(SmallVec<[T; 1]>); - -impl Deref for SmallSet { - type Target = [T]; - - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -impl Default for SmallSet { - fn default() -> Self { - Self(SmallVec::new()) - } -} - -impl SmallSet { - fn insert(&mut self, value: T) -> bool - where - T: Ord, - { - match self.binary_search(&value) { - Ok(_) => false, - Err(ix) => { - self.0.insert(ix, value); - true - } - } - } - - fn clear(&mut self) { - self.0.clear(); - } -} diff --git a/crates/collab/src/db/queries/contacts.rs b/crates/collab/src/db/queries/contacts.rs index 2171f1a6bf..f31f1addbd 100644 --- a/crates/collab/src/db/queries/contacts.rs +++ b/crates/collab/src/db/queries/contacts.rs @@ -8,7 +8,6 @@ impl Database { user_id_b: UserId, a_to_b: bool, accepted: bool, - should_notify: bool, user_a_busy: bool, user_b_busy: bool, } @@ 
-53,7 +52,6 @@ impl Database { if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify && db_contact.a_to_b, busy: db_contact.user_b_busy, }); } else if db_contact.a_to_b { @@ -63,19 +61,16 @@ impl Database { } else { contacts.push(Contact::Incoming { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify, }); } } else if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify && !db_contact.a_to_b, busy: db_contact.user_a_busy, }); } else if db_contact.a_to_b { contacts.push(Contact::Incoming { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify, }); } else { contacts.push(Contact::Outgoing { @@ -124,7 +119,11 @@ impl Database { .await } - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + pub async fn send_contact_request( + &self, + sender_id: UserId, + receiver_id: UserId, + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) @@ -161,11 +160,22 @@ impl Database { .exec_without_returning(&*tx) .await?; - if rows_affected == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? + if rows_affected == 0 { + Err(anyhow!("contact already requested"))?; } + + Ok(self + .create_notification( + receiver_id, + rpc::Notification::ContactRequest { + sender_id: sender_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect()) }) .await } @@ -179,7 +189,11 @@ impl Database { /// /// * `requester_id` - The user that initiates this request /// * `responder_id` - The user that will be removed - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result { + pub async fn remove_contact( + &self, + requester_id: UserId, + responder_id: UserId, + ) -> Result<(bool, Option)> { self.transaction(|tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) @@ -198,7 +212,21 @@ impl Database { .ok_or_else(|| anyhow!("no such contact"))?; contact::Entity::delete_by_id(contact.id).exec(&*tx).await?; - Ok(contact.accepted) + + let mut deleted_notification_id = None; + if !contact.accepted { + deleted_notification_id = self + .remove_notification( + responder_id, + rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + &*tx, + ) + .await?; + } + + Ok((contact.accepted, deleted_notification_id)) }) .await } @@ -249,7 +277,7 @@ impl Database { responder_id: UserId, requester_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) @@ -287,11 +315,38 @@ impl Database { result.rows_affected }; - if rows_affected == 1 { - Ok(()) - } else { + if rows_affected == 0 { Err(anyhow!("no such contact request"))? 
} + + let mut notifications = Vec::new(); + notifications.extend( + self.mark_notification_as_read_with_response( + responder_id, + &rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + accept, + &*tx, + ) + .await?, + ); + + if accept { + notifications.extend( + self.create_notification( + requester_id, + rpc::Notification::ContactRequestAccepted { + responder_id: responder_id.to_proto(), + }, + true, + &*tx, + ) + .await?, + ); + } + + Ok(notifications) }) .await } diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index a48d425d90..47bb27df39 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -1,4 +1,6 @@ use super::*; +use rpc::Notification; +use sea_orm::TryInsertResult; use time::OffsetDateTime; impl Database { @@ -9,7 +11,8 @@ impl Database { user_id: UserId, ) -> Result<()> { self.transaction(|tx| async move { - self.check_user_is_channel_member(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) .await?; channel_chat_participant::ActiveModel { id: ActiveValue::NotSet, @@ -77,7 +80,8 @@ impl Database { before_message_id: Option, ) -> Result> { self.transaction(|tx| async move { - self.check_user_is_channel_member(channel_id, user_id, &*tx) + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) .await?; let mut condition = @@ -87,33 +91,103 @@ impl Database { condition = condition.add(channel_message::Column::Id.lt(before_message_id)); } - let mut rows = channel_message::Entity::find() + let rows = channel_message::Entity::find() .filter(condition) .order_by_desc(channel_message::Column::Id) .limit(count as u64) - .stream(&*tx) + .all(&*tx) .await?; - let mut messages = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; + self.load_channel_messages(rows, &*tx).await + }) + .await + } + + pub async fn get_channel_messages_by_id( + &self, + user_id: UserId, + message_ids: &[MessageId], + ) -> Result> { + self.transaction(|tx| async move { + let rows = channel_message::Entity::find() + .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) + .order_by_desc(channel_message::Column::Id) + .all(&*tx) + .await?; + + let mut channels = HashMap::::default(); + for row in &rows { + channels.insert( + row.channel_id, + self.get_channel_internal(row.channel_id, &*tx).await?, + ); + } + + for (_, channel) in channels { + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + } + + let messages = self.load_channel_messages(rows, &*tx).await?; + Ok(messages) + }) + .await + } + + async fn load_channel_messages( + &self, + rows: Vec, + tx: &DatabaseTransaction, + ) -> Result> { + let mut messages = rows + .into_iter() + .map(|row| { let nonce = row.nonce.as_u64_pair(); - messages.push(proto::ChannelMessage { + proto::ChannelMessage { id: row.id.to_proto(), sender_id: row.sender_id.to_proto(), body: row.body, timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, + mentions: vec![], nonce: Some(proto::Nonce { upper_half: nonce.0, lower_half: nonce.1, }), - }); + } + }) + .collect::>(); + messages.reverse(); + + let mut mentions = channel_message_mention::Entity::find() + .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) + .order_by_asc(channel_message_mention::Column::MessageId) + 
.order_by_asc(channel_message_mention::Column::StartOffset) + .stream(&*tx) + .await?; + + let mut message_ix = 0; + while let Some(mention) = mentions.next().await { + let mention = mention?; + let message_id = mention.message_id.to_proto(); + while let Some(message) = messages.get_mut(message_ix) { + if message.id < message_id { + message_ix += 1; + } else { + if message.id == message_id { + message.mentions.push(proto::ChatMention { + range: Some(proto::Range { + start: mention.start_offset as u64, + end: mention.end_offset as u64, + }), + user_id: mention.user_id.to_proto(), + }); + } + break; + } } - drop(rows); - messages.reverse(); - Ok(messages) - }) - .await + } + + Ok(messages) } pub async fn create_channel_message( @@ -121,10 +195,15 @@ impl Database { channel_id: ChannelId, user_id: UserId, body: &str, + mentions: &[proto::ChatMention], timestamp: OffsetDateTime, nonce: u128, - ) -> Result<(MessageId, Vec, Vec)> { + ) -> Result { self.transaction(|tx| async move { + let channel = self.get_channel_internal(channel_id, &*tx).await?; + self.check_user_is_channel_participant(&channel, user_id, &*tx) + .await?; + let mut rows = channel_chat_participant::Entity::find() .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) .stream(&*tx) @@ -150,7 +229,7 @@ impl Database { let timestamp = timestamp.to_offset(time::UtcOffset::UTC); let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); - let message = channel_message::Entity::insert(channel_message::ActiveModel { + let result = channel_message::Entity::insert(channel_message::ActiveModel { channel_id: ActiveValue::Set(channel_id), sender_id: ActiveValue::Set(user_id), body: ActiveValue::Set(body.to_string()), @@ -159,35 +238,85 @@ impl Database { id: ActiveValue::NotSet, }) .on_conflict( - OnConflict::column(channel_message::Column::Nonce) - .update_column(channel_message::Column::Nonce) - .to_owned(), + OnConflict::columns([ + channel_message::Column::SenderId, + channel_message::Column::Nonce, + ]) + .do_nothing() + .to_owned(), ) + .do_nothing() .exec(&*tx) .await?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryConnectionId { - ConnectionId, + let message_id; + let mut notifications = Vec::new(); + match result { + TryInsertResult::Inserted(result) => { + message_id = result.last_insert_id; + let mentioned_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); + let mentions = mentions + .iter() + .filter_map(|mention| { + let range = mention.range.as_ref()?; + if !body.is_char_boundary(range.start as usize) + || !body.is_char_boundary(range.end as usize) + { + return None; + } + Some(channel_message_mention::ActiveModel { + message_id: ActiveValue::Set(message_id), + start_offset: ActiveValue::Set(range.start as i32), + end_offset: ActiveValue::Set(range.end as i32), + user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)), + }) + }) + .collect::>(); + if !mentions.is_empty() { + channel_message_mention::Entity::insert_many(mentions) + .exec(&*tx) + .await?; + } + + for mentioned_user in mentioned_user_ids { + notifications.extend( + self.create_notification( + UserId::from_proto(mentioned_user), + rpc::Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: user_id.to_proto(), + channel_id: channel_id.to_proto(), + }, + false, + &*tx, + ) + .await?, + ); + } + + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) + .await?; + } + _ => { + message_id = channel_message::Entity::find() + 
.filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce))) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to insert message"))? + .id; + } } - // Observe this message for the sender - self.observe_channel_message_internal( - channel_id, - user_id, - message.last_insert_id, - &*tx, - ) - .await?; - - let mut channel_members = self.get_channel_members_internal(channel_id, &*tx).await?; + let mut channel_members = self.get_channel_participants(&channel, &*tx).await?; channel_members.retain(|member| !participant_user_ids.contains(member)); - Ok(( - message.last_insert_id, + Ok(CreatedChannelMessage { + message_id, participant_connection_ids, channel_members, - )) + notifications, + }) }) .await } @@ -197,11 +326,24 @@ impl Database { channel_id: ChannelId, user_id: UserId, message_id: MessageId, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) .await?; - Ok(()) + let mut batch = NotificationBatch::default(); + batch.extend( + self.mark_notification_as_read( + user_id, + &Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: Default::default(), + channel_id: Default::default(), + }, + &*tx, + ) + .await?, + ); + Ok(batch) }) .await } @@ -337,8 +479,23 @@ impl Database { .filter(channel_message::Column::SenderId.eq(user_id)) .exec(&*tx) .await?; + if result.rows_affected == 0 { - Err(anyhow!("no such message"))?; + let channel = self.get_channel_internal(channel_id, &*tx).await?; + if self + .check_user_is_channel_admin(&channel, user_id, &*tx) + .await + .is_ok() + { + let result = channel_message::Entity::delete_by_id(message_id) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such message"))?; + } + } else { + Err(anyhow!("operation could not be completed"))?; + } } Ok(participant_connection_ids) diff --git a/crates/collab/src/db/queries/notifications.rs b/crates/collab/src/db/queries/notifications.rs new file mode 100644 index 0000000000..6f2511c23e --- /dev/null +++ b/crates/collab/src/db/queries/notifications.rs @@ -0,0 +1,262 @@ +use super::*; +use rpc::Notification; + +impl Database { + pub async fn initialize_notification_kinds(&mut self) -> Result<()> { + notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map( + |kind| notification_kind::ActiveModel { + name: ActiveValue::Set(kind.to_string()), + ..Default::default() + }, + )) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&self.pool) + .await?; + + let mut rows = notification_kind::Entity::find().stream(&self.pool).await?; + while let Some(row) = rows.next().await { + let row = row?; + self.notification_kinds_by_name.insert(row.name, row.id); + } + + for name in Notification::all_variant_names() { + if let Some(id) = self.notification_kinds_by_name.get(*name).copied() { + self.notification_kinds_by_id.insert(id, name); + } + } + + Ok(()) + } + + pub async fn get_notifications( + &self, + recipient_id: UserId, + limit: usize, + before_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let mut result = Vec::new(); + let mut condition = + Condition::all().add(notification::Column::RecipientId.eq(recipient_id)); + + if let Some(before_id) = before_id { + condition = condition.add(notification::Column::Id.lt(before_id)); + } + + let mut rows = notification::Entity::find() + .filter(condition) + .order_by_desc(notification::Column::Id) + .limit(limit as u64) + .stream(&*tx) + .await?; 
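The get_notifications query built just above pages backwards through a user's notifications: it filters on RecipientId, optionally on Id less than before_id, orders by Id descending, takes limit rows, and the loop that follows reverses the collected rows so the caller receives them oldest-first. As a reference for that shape only, here is a dependency-free sketch of the same keyset pagination over plain integers; the function name and the u64 ids are placeholders, not types from the patch.

    // Newest `limit` ids strictly below `before_id`, returned in ascending order,
    // mirroring ORDER BY id DESC + LIMIT followed by a reverse().
    fn page_desc_then_reverse(ids: &[u64], limit: usize, before_id: Option<u64>) -> Vec<u64> {
        let mut page: Vec<u64> = ids
            .iter()
            .copied()
            .filter(|id| before_id.map_or(true, |before| *id < before))
            .collect();
        page.sort_unstable_by(|a, b| b.cmp(a)); // ORDER BY id DESC
        page.truncate(limit);                   // LIMIT
        page.reverse();                         // oldest-first for the caller
        page
    }

    fn main() {
        let ids = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(page_desc_then_reverse(&ids, 3, None), vec![6, 7, 8]);
        // Passing the oldest id of the previous page as before_id walks further back.
        assert_eq!(page_desc_then_reverse(&ids, 3, Some(6)), vec![3, 4, 5]);
    }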
+ while let Some(row) = rows.next().await { + let row = row?; + let kind = row.kind; + if let Some(proto) = model_to_proto(self, row) { + result.push(proto); + } else { + log::warn!("unknown notification kind {:?}", kind); + } + } + result.reverse(); + Ok(result) + }) + .await + } + + /// Create a notification. If `avoid_duplicates` is set to true, then avoid + /// creating a new notification if the given recipient already has an + /// unread notification with the given kind and entity id. + pub async fn create_notification( + &self, + recipient_id: UserId, + notification: Notification, + avoid_duplicates: bool, + tx: &DatabaseTransaction, + ) -> Result> { + if avoid_duplicates { + if self + .find_notification(recipient_id, ¬ification, tx) + .await? + .is_some() + { + return Ok(None); + } + } + + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + let model = notification::ActiveModel { + recipient_id: ActiveValue::Set(recipient_id), + kind: ActiveValue::Set(kind), + entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)), + content: ActiveValue::Set(proto.content.clone()), + ..Default::default() + } + .save(&*tx) + .await?; + + Ok(Some(( + recipient_id, + proto::Notification { + id: model.id.as_ref().to_proto(), + kind: proto.kind, + timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64, + is_read: false, + response: None, + content: proto.content, + entity_id: proto.entity_id, + }, + ))) + } + + /// Remove an unread notification with the given recipient, kind and + /// entity id. + pub async fn remove_notification( + &self, + recipient_id: UserId, + notification: Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let id = self + .find_notification(recipient_id, ¬ification, tx) + .await?; + if let Some(id) = id { + notification::Entity::delete_by_id(id).exec(tx).await?; + } + Ok(id) + } + + /// Populate the response for the notification with the given kind and + /// entity id. + pub async fn mark_notification_as_read_with_response( + &self, + recipient_id: UserId, + notification: &Notification, + response: bool, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx) + .await + } + + pub async fn mark_notification_as_read( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, None, tx) + .await + } + + pub async fn mark_notification_as_read_by_id( + &self, + recipient_id: UserId, + notification_id: NotificationId, + ) -> Result { + self.transaction(|tx| async move { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(notification_id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(model_to_proto(self, row) + .map(|notification| (recipient_id, notification)) + .into_iter() + .collect()) + }) + .await + } + + async fn mark_notification_as_read_internal( + &self, + recipient_id: UserId, + notification: &Notification, + response: Option, + tx: &DatabaseTransaction, + ) -> Result> { + if let Some(id) = self + .find_notification(recipient_id, notification, &*tx) + .await? 
+ { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + response: if let Some(response) = response { + ActiveValue::Set(Some(response)) + } else { + ActiveValue::NotSet + }, + ..Default::default() + }) + .exec(tx) + .await?; + Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification))) + } else { + Ok(None) + } + } + + /// Find an unread notification by its recipient, kind and entity id. + async fn find_notification( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryIds { + Id, + } + + Ok(notification::Entity::find() + .select_only() + .column(notification::Column::Id) + .filter( + Condition::all() + .add(notification::Column::RecipientId.eq(recipient_id)) + .add(notification::Column::IsRead.eq(false)) + .add(notification::Column::Kind.eq(kind)) + .add(if proto.entity_id.is_some() { + notification::Column::EntityId.eq(proto.entity_id) + } else { + notification::Column::EntityId.is_null() + }), + ) + .into_values::<_, QueryIds>() + .one(&*tx) + .await?) + } +} + +fn model_to_proto(this: &Database, row: notification::Model) -> Option { + let kind = this.notification_kinds_by_id.get(&row.kind)?; + Some(proto::Notification { + id: row.id.to_proto(), + kind: kind.to_string(), + timestamp: row.created_at.assume_utc().unix_timestamp() as u64, + is_read: row.is_read, + response: row.response, + content: row.content, + entity_id: row.entity_id.map(|id| id as u64), + }) +} + +fn notification_kind_from_proto( + this: &Database, + proto: &proto::Notification, +) -> Result { + Ok(this + .notification_kinds_by_name + .get(&proto.kind) + .copied() + .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?) +} diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index a38c77dc0f..40fdf5d58f 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -50,10 +50,10 @@ impl Database { .map(|participant| participant.user_id), ); - let (channel_id, room) = self.get_channel_room(room_id, &tx).await?; + let (channel, room) = self.get_channel_room(room_id, &tx).await?; let channel_members; - if let Some(channel_id) = channel_id { - channel_members = self.get_channel_members_internal(channel_id, &tx).await?; + if let Some(channel) = &channel { + channel_members = self.get_channel_participants(channel, &tx).await?; } else { channel_members = Vec::new(); @@ -69,7 +69,7 @@ impl Database { Ok(RefreshedRoom { room, - channel_id, + channel_id: channel.map(|channel| channel.id), channel_members, stale_participant_user_ids, canceled_calls_to_user_ids, @@ -298,98 +298,137 @@ impl Database { } } - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryParticipantIndices { - ParticipantIndex, + if channel_id.is_some() { + Err(anyhow!("tried to join channel call directly"))? 
} - let existing_participant_indices: Vec = room_participant::Entity::find() - .filter( - room_participant::Column::RoomId - .eq(room_id) - .and(room_participant::Column::ParticipantIndex.is_not_null()), - ) - .select_only() - .column(room_participant::Column::ParticipantIndex) - .into_values::<_, QueryParticipantIndices>() - .all(&*tx) + + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) .await?; - let mut participant_index = 0; - while existing_participant_indices.contains(&participant_index) { - participant_index += 1; - } - - if let Some(channel_id) = channel_id { - self.check_user_is_channel_member(channel_id, user_id, &*tx) - .await?; - - room_participant::Entity::insert_many([room_participant::ActiveModel { - room_id: ActiveValue::set(room_id), - user_id: ActiveValue::set(user_id), + let result = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .set(room_participant::ActiveModel { + participant_index: ActiveValue::Set(Some(participant_index)), answering_connection_id: ActiveValue::set(Some(connection.id as i32)), answering_connection_server_id: ActiveValue::set(Some(ServerId( connection.owner_id as i32, ))), answering_connection_lost: ActiveValue::set(false), - calling_user_id: ActiveValue::set(user_id), - calling_connection_id: ActiveValue::set(connection.id as i32), - calling_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - participant_index: ActiveValue::Set(Some(participant_index)), ..Default::default() - }]) - .on_conflict( - OnConflict::columns([room_participant::Column::UserId]) - .update_columns([ - room_participant::Column::AnsweringConnectionId, - room_participant::Column::AnsweringConnectionServerId, - room_participant::Column::AnsweringConnectionLost, - room_participant::Column::ParticipantIndex, - ]) - .to_owned(), - ) + }) .exec(&*tx) .await?; - } else { - let result = room_participant::Entity::update_many() - .filter( - Condition::all() - .add(room_participant::Column::RoomId.eq(room_id)) - .add(room_participant::Column::UserId.eq(user_id)) - .add(room_participant::Column::AnsweringConnectionId.is_null()), - ) - .set(room_participant::ActiveModel { - participant_index: ActiveValue::Set(Some(participant_index)), - answering_connection_id: ActiveValue::set(Some(connection.id as i32)), - answering_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - answering_connection_lost: ActiveValue::set(false), - ..Default::default() - }) - .exec(&*tx) - .await?; - if result.rows_affected == 0 { - Err(anyhow!("room does not exist or was already joined"))?; - } + if result.rows_affected == 0 { + Err(anyhow!("room does not exist or was already joined"))?; } let room = self.get_room(room_id, &tx).await?; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_members_internal(channel_id, &tx).await? 
- } else { - Vec::new() - }; Ok(JoinRoom { room, - channel_id, - channel_members, + channel_id: None, + channel_members: vec![], }) }) .await } + async fn get_next_participant_index_internal( + &self, + room_id: RoomId, + tx: &DatabaseTransaction, + ) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryParticipantIndices { + ParticipantIndex, + } + let existing_participant_indices: Vec = room_participant::Entity::find() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::ParticipantIndex.is_not_null()), + ) + .select_only() + .column(room_participant::Column::ParticipantIndex) + .into_values::<_, QueryParticipantIndices>() + .all(&*tx) + .await?; + + let mut participant_index = 0; + while existing_participant_indices.contains(&participant_index) { + participant_index += 1; + } + + Ok(participant_index) + } + + pub async fn channel_id_for_room(&self, room_id: RoomId) -> Result> { + self.transaction(|tx| async move { + let room: Option = room::Entity::find() + .filter(room::Column::Id.eq(room_id)) + .one(&*tx) + .await?; + + Ok(room.and_then(|room| room.channel_id)) + }) + .await + } + + pub(crate) async fn join_channel_room_internal( + &self, + room_id: RoomId, + user_id: UserId, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result { + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) + .await?; + + room_participant::Entity::insert_many([room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + participant_index: ActiveValue::Set(Some(participant_index)), + ..Default::default() + }]) + .on_conflict( + OnConflict::columns([room_participant::Column::UserId]) + .update_columns([ + room_participant::Column::AnsweringConnectionId, + room_participant::Column::AnsweringConnectionServerId, + room_participant::Column::AnsweringConnectionLost, + room_participant::Column::ParticipantIndex, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?; + let channel_members = self.get_channel_participants(&channel, &*tx).await?; + Ok(JoinRoom { + room, + channel_id: Some(channel.id), + channel_members, + }) + } + pub async fn rejoin_room( &self, rejoin_room: proto::RejoinRoom, @@ -679,16 +718,16 @@ impl Database { }); } - let (channel_id, room) = self.get_channel_room(room_id, &tx).await?; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_members_internal(channel_id, &tx).await? + let (channel, room) = self.get_channel_room(room_id, &tx).await?; + let channel_members = if let Some(channel) = &channel { + self.get_channel_participants(&channel, &tx).await? 
} else { Vec::new() }; Ok(RejoinedRoom { room, - channel_id, + channel_id: channel.map(|channel| channel.id), channel_members, rejoined_projects, reshared_projects, @@ -830,7 +869,7 @@ impl Database { .exec(&*tx) .await?; - let (channel_id, room) = self.get_channel_room(room_id, &tx).await?; + let (channel, room) = self.get_channel_room(room_id, &tx).await?; let deleted = if room.participants.is_empty() { let result = room::Entity::delete_by_id(room_id).exec(&*tx).await?; result.rows_affected > 0 @@ -838,14 +877,14 @@ impl Database { false }; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_members_internal(channel_id, &tx).await? + let channel_members = if let Some(channel) = &channel { + self.get_channel_participants(channel, &tx).await? } else { Vec::new() }; let left_room = LeftRoom { room, - channel_id, + channel_id: channel.map(|channel| channel.id), channel_members, left_projects, canceled_calls_to_user_ids, @@ -1033,7 +1072,7 @@ impl Database { &self, room_id: RoomId, tx: &DatabaseTransaction, - ) -> Result<(Option, proto::Room)> { + ) -> Result<(Option, proto::Room)> { let db_room = room::Entity::find_by_id(room_id) .one(tx) .await? @@ -1142,9 +1181,16 @@ impl Database { project_id: db_follower.project_id.to_proto(), }); } + drop(db_followers); + + let channel = if let Some(channel_id) = db_room.channel_id { + Some(self.get_channel_internal(channel_id, &*tx).await?) + } else { + None + }; Ok(( - db_room.channel_id, + channel, proto::Room { id: db_room.id.to_proto(), live_kit_room: db_room.live_kit_room, diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index e19391da7d..4f28ce4fbd 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -7,11 +7,13 @@ pub mod channel_buffer_collaborator; pub mod channel_chat_participant; pub mod channel_member; pub mod channel_message; -pub mod channel_path; +pub mod channel_message_mention; pub mod contact; pub mod feature_flag; pub mod follower; pub mod language_server; +pub mod notification; +pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; pub mod project; diff --git a/crates/collab/src/db/tables/channel.rs b/crates/collab/src/db/tables/channel.rs index 54f12defc1..e30ec9af61 100644 --- a/crates/collab/src/db/tables/channel.rs +++ b/crates/collab/src/db/tables/channel.rs @@ -1,4 +1,4 @@ -use crate::db::ChannelId; +use crate::db::{ChannelId, ChannelVisibility}; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] @@ -7,6 +7,29 @@ pub struct Model { #[sea_orm(primary_key)] pub id: ChannelId, pub name: String, + pub visibility: ChannelVisibility, + pub parent_path: String, +} + +impl Model { + pub fn parent_id(&self) -> Option { + self.ancestors().last() + } + + pub fn ancestors(&self) -> impl Iterator + '_ { + self.parent_path + .trim_end_matches('/') + .split('/') + .filter_map(|id| Some(ChannelId::from_proto(id.parse().ok()?))) + } + + pub fn ancestors_including_self(&self) -> impl Iterator + '_ { + self.ancestors().chain(Some(self.id)) + } + + pub fn path(&self) -> String { + format!("{}{}/", self.parent_path, self.id) + } } impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_member.rs b/crates/collab/src/db/tables/channel_member.rs index ba3db5a155..5498a00856 100644 --- a/crates/collab/src/db/tables/channel_member.rs +++ b/crates/collab/src/db/tables/channel_member.rs @@ -1,7 +1,7 @@ -use crate::db::{channel_member, 
ChannelId, ChannelMemberId, UserId}; +use crate::db::{channel_member, ChannelId, ChannelMemberId, ChannelRole, UserId}; use sea_orm::entity::prelude::*; -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "channel_members")] pub struct Model { #[sea_orm(primary_key)] @@ -9,7 +9,7 @@ pub struct Model { pub channel_id: ChannelId, pub user_id: UserId, pub accepted: bool, - pub admin: bool, + pub role: ChannelRole, } impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_message_mention.rs b/crates/collab/src/db/tables/channel_message_mention.rs new file mode 100644 index 0000000000..6155b057f0 --- /dev/null +++ b/crates/collab/src/db/tables/channel_message_mention.rs @@ -0,0 +1,43 @@ +use crate::db::{MessageId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_message_mentions")] +pub struct Model { + #[sea_orm(primary_key)] + pub message_id: MessageId, + #[sea_orm(primary_key)] + pub start_offset: i32, + pub end_offset: i32, + pub user_id: UserId, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel_message::Entity", + from = "Column::MessageId", + to = "super::channel_message::Column::Id" + )] + Message, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + MentionedUser, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Message.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::MentionedUser.def() + } +} diff --git a/crates/collab/src/db/tables/notification.rs b/crates/collab/src/db/tables/notification.rs new file mode 100644 index 0000000000..3105198fa2 --- /dev/null +++ b/crates/collab/src/db/tables/notification.rs @@ -0,0 +1,29 @@ +use crate::db::{NotificationId, NotificationKindId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationId, + pub created_at: PrimitiveDateTime, + pub recipient_id: UserId, + pub kind: NotificationKindId, + pub entity_id: Option, + pub content: String, + pub is_read: bool, + pub response: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::RecipientId", + to = "super::user::Column::Id" + )] + Recipient, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_path.rs b/crates/collab/src/db/tables/notification_kind.rs similarity index 51% rename from crates/collab/src/db/tables/channel_path.rs rename to crates/collab/src/db/tables/notification_kind.rs index 323f116dae..865b5da04b 100644 --- a/crates/collab/src/db/tables/channel_path.rs +++ b/crates/collab/src/db/tables/notification_kind.rs @@ -1,15 +1,15 @@ -use crate::db::ChannelId; +use crate::db::NotificationKindId; use sea_orm::entity::prelude::*; -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "channel_paths")] +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notification_kinds")] pub struct Model { #[sea_orm(primary_key)] - pub id_path: 
String, - pub channel_id: ChannelId, + pub id: NotificationKindId, + pub name: String, } -impl ActiveModelBehavior for ActiveModel {} - #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 6a91fd6ffe..b6a89ff6f8 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -7,10 +7,12 @@ mod message_tests; use super::*; use gpui::executor::Background; use parking_lot::Mutex; -use rpc::proto::ChannelEdge; use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, + Arc, +}; const TEST_RELEASE_CHANNEL: &'static str = "test"; @@ -31,7 +33,7 @@ impl TestDb { let mut db = runtime.block_on(async { let mut options = ConnectOptions::new(url); options.max_connections(5); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let sql = include_str!(concat!( @@ -45,6 +47,7 @@ impl TestDb { )) .await .unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -79,11 +82,12 @@ impl TestDb { options .max_connections(5) .idle_timeout(Duration::from_secs(0)); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -148,26 +152,39 @@ impl Drop for TestDb { } } -/// The second tuples are (channel_id, parent) -fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)]) -> ChannelGraph { - let mut graph = ChannelGraph { - channels: vec![], - edges: vec![], - }; - - for (id, name) in channels { - graph.channels.push(Channel { +fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str, ChannelRole)]) -> Vec { + channels + .iter() + .map(|(id, parent_path, name, role)| Channel { id: *id, name: name.to_string(), + visibility: ChannelVisibility::Members, + role: *role, + parent_path: parent_path.to_vec(), }) - } - - for (channel, parent) in edges { - graph.edges.push(ChannelEdge { - channel_id: channel.to_proto(), - parent_id: parent.to_proto(), - }) - } - - graph + .collect() +} + +static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); + +async fn new_test_user(db: &Arc, email: &str) -> UserId { + db.create_user( + email, + false, + NewUserParams { + github_login: email[0..email.find("@").unwrap()].to_string(), + github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst), + }, + ) + .await + .unwrap() + .user_id +} + +static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1); +fn new_test_connection(server: ServerId) -> ConnectionId { + ConnectionId { + id: TEST_CONNECTION_ID.fetch_add(1, SeqCst), + owner_id: server.0 as u32, + } } diff --git a/crates/collab/src/db/tests/buffer_tests.rs b/crates/collab/src/db/tests/buffer_tests.rs index 0ac41a8b0b..222514da0b 100644 --- a/crates/collab/src/db/tests/buffer_tests.rs +++ b/crates/collab/src/db/tests/buffer_tests.rs @@ -17,7 +17,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -30,7 +29,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { 
github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -45,7 +43,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_c".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -56,7 +53,7 @@ async fn test_channel_buffers(db: &Arc) { let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); - db.invite_channel_member(zed_id, b_id, a_id, false) + db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) .await .unwrap(); @@ -178,7 +175,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -191,7 +187,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -211,7 +206,7 @@ async fn test_channel_buffers_last_operations(db: &Database) { .await .unwrap(); - db.invite_channel_member(channel, observer_id, user_id, false) + db.invite_channel_member(channel, observer_id, user_id, ChannelRole::Member) .await .unwrap(); db.respond_to_channel_invite(channel, observer_id, true) diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 7d2bc04a35..43526c7f24 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -1,56 +1,28 @@ -use collections::{HashMap, HashSet}; +use crate::{ + db::{ + tests::{channel_tree, new_test_connection, new_test_user, TEST_RELEASE_CHANNEL}, + Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, + }, + test_both_dbs, +}; use rpc::{ proto::{self}, ConnectionId, }; - -use crate::{ - db::{ - queries::channels::ChannelGraph, - tests::{graph, TEST_RELEASE_CHANNEL}, - ChannelId, Database, NewUserParams, - }, - test_both_dbs, -}; use std::sync::Arc; test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); async fn test_channels(db: &Arc) { - let a_id = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 5, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let b_id = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "user2".into(), - github_user_id: 6, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let a_id = new_test_user(db, "user1@example.com").await; + let b_id = new_test_user(db, "user2@example.com").await; let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); // Make sure that people cannot read channels they haven't been invited to - assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none()); + assert!(db.get_channel(zed_id, b_id).await.is_err()); - db.invite_channel_member(zed_id, b_id, a_id, false) + db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member) .await .unwrap(); @@ -58,99 +30,103 @@ async fn test_channels(db: &Arc) { .await .unwrap(); - let crdb_id = db.create_channel("crdb", Some(zed_id), a_id).await.unwrap(); + let crdb_id = db.create_sub_channel("crdb", zed_id, a_id).await.unwrap(); let livestreaming_id = db - .create_channel("livestreaming", Some(zed_id), a_id) + .create_sub_channel("livestreaming", zed_id, a_id) .await .unwrap(); let replace_id = db - .create_channel("replace", Some(zed_id), a_id) + .create_sub_channel("replace", zed_id, a_id) .await .unwrap(); - let mut members = db.get_channel_members(replace_id).await.unwrap(); + let mut members = db + 
.transaction(|tx| async move { + let channel = db.get_channel_internal(replace_id, &*tx).await?; + Ok(db.get_channel_participants(&channel, &*tx).await?) + }) + .await + .unwrap(); members.sort(); assert_eq!(members, &[a_id, b_id]); let rust_id = db.create_root_channel("rust", a_id).await.unwrap(); - let cargo_id = db - .create_channel("cargo", Some(rust_id), a_id) - .await - .unwrap(); + let cargo_id = db.create_sub_channel("cargo", rust_id, a_id).await.unwrap(); let cargo_ra_id = db - .create_channel("cargo-ra", Some(cargo_id), a_id) + .create_sub_channel("cargo-ra", cargo_id, a_id) .await .unwrap(); let result = db.get_channels_for_user(a_id).await.unwrap(); assert_eq!( result.channels, - graph( - &[ - (zed_id, "zed"), - (crdb_id, "crdb"), - (livestreaming_id, "livestreaming"), - (replace_id, "replace"), - (rust_id, "rust"), - (cargo_id, "cargo"), - (cargo_ra_id, "cargo-ra") - ], - &[ - (crdb_id, zed_id), - (livestreaming_id, zed_id), - (replace_id, zed_id), - (cargo_id, rust_id), - (cargo_ra_id, cargo_id), - ] - ) + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Admin), + (crdb_id, &[zed_id], "crdb", ChannelRole::Admin), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Admin + ), + (replace_id, &[zed_id], "replace", ChannelRole::Admin), + (rust_id, &[], "rust", ChannelRole::Admin), + (cargo_id, &[rust_id], "cargo", ChannelRole::Admin), + ( + cargo_ra_id, + &[rust_id, cargo_id], + "cargo-ra", + ChannelRole::Admin + ) + ],) ); let result = db.get_channels_for_user(b_id).await.unwrap(); assert_eq!( result.channels, - graph( - &[ - (zed_id, "zed"), - (crdb_id, "crdb"), - (livestreaming_id, "livestreaming"), - (replace_id, "replace") - ], - &[ - (crdb_id, zed_id), - (livestreaming_id, zed_id), - (replace_id, zed_id) - ] - ) + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Member), + (crdb_id, &[zed_id], "crdb", ChannelRole::Member), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Member + ), + (replace_id, &[zed_id], "replace", ChannelRole::Member) + ],) ); // Update member permissions - let set_subchannel_admin = db.set_channel_member_admin(crdb_id, a_id, b_id, true).await; + let set_subchannel_admin = db + .set_channel_member_role(crdb_id, a_id, b_id, ChannelRole::Admin) + .await; assert!(set_subchannel_admin.is_err()); - let set_channel_admin = db.set_channel_member_admin(zed_id, a_id, b_id, true).await; + let set_channel_admin = db + .set_channel_member_role(zed_id, a_id, b_id, ChannelRole::Admin) + .await; assert!(set_channel_admin.is_ok()); let result = db.get_channels_for_user(b_id).await.unwrap(); assert_eq!( result.channels, - graph( - &[ - (zed_id, "zed"), - (crdb_id, "crdb"), - (livestreaming_id, "livestreaming"), - (replace_id, "replace") - ], - &[ - (crdb_id, zed_id), - (livestreaming_id, zed_id), - (replace_id, zed_id) - ] - ) + channel_tree(&[ + (zed_id, &[], "zed", ChannelRole::Admin), + (crdb_id, &[zed_id], "crdb", ChannelRole::Admin), + ( + livestreaming_id, + &[zed_id], + "livestreaming", + ChannelRole::Admin + ), + (replace_id, &[zed_id], "replace", ChannelRole::Admin) + ],) ); // Remove a single channel db.delete_channel(crdb_id, a_id).await.unwrap(); - assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none()); + assert!(db.get_channel(crdb_id, a_id).await.is_err()); // Remove a channel tree let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap(); @@ -158,9 +134,9 @@ async fn test_channels(db: &Arc) { assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]); assert_eq!(user_ids, 
&[a_id]); - assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none()); - assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none()); - assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none()); + assert!(db.get_channel(rust_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_id, a_id).await.is_err()); + assert!(db.get_channel(cargo_ra_id, a_id).await.is_err()); } test_both_dbs!( @@ -172,43 +148,15 @@ test_both_dbs!( async fn test_joining_channels(db: &Arc) { let owner_id = db.create_server("test").await.unwrap().0 as u32; - let user_1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 5, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "user2".into(), - github_user_id: 6, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let user_1 = new_test_user(db, "user1@example.com").await; + let user_2 = new_test_user(db, "user2@example.com").await; let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); - let room_1 = db - .get_or_create_channel_room(channel_1, "1", TEST_RELEASE_CHANNEL) - .await - .unwrap(); // can join a room with membership to its channel - let joined_room = db - .join_room( - room_1, + let (joined_room, _, _) = db + .join_channel( + channel_1, user_1, ConnectionId { owner_id, id: 1 }, TEST_RELEASE_CHANNEL, @@ -217,11 +165,12 @@ async fn test_joining_channels(db: &Arc) { .unwrap(); assert_eq!(joined_room.room.participants.len(), 1); + let room_id = RoomId::from_proto(joined_room.room.id); drop(joined_room); // cannot join a room without membership to its channel assert!(db .join_room( - room_1, + room_id, user_2, ConnectionId { owner_id, id: 1 }, TEST_RELEASE_CHANNEL @@ -239,58 +188,21 @@ test_both_dbs!( async fn test_channel_invites(db: &Arc) { db.create_server("test").await.unwrap(); - let user_1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 5, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "user2".into(), - github_user_id: 6, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let user_3 = db - .create_user( - "user3@example.com", - false, - NewUserParams { - github_login: "user3".into(), - github_user_id: 7, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let user_1 = new_test_user(db, "user1@example.com").await; + let user_2 = new_test_user(db, "user2@example.com").await; + let user_3 = new_test_user(db, "user3@example.com").await; let channel_1_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); let channel_1_2 = db.create_root_channel("channel_2", user_1).await.unwrap(); - db.invite_channel_member(channel_1_1, user_2, user_1, false) + db.invite_channel_member(channel_1_1, user_2, user_1, ChannelRole::Member) .await .unwrap(); - db.invite_channel_member(channel_1_2, user_2, user_1, false) + db.invite_channel_member(channel_1_2, user_2, user_1, ChannelRole::Member) .await .unwrap(); - db.invite_channel_member(channel_1_1, user_3, user_1, true) + db.invite_channel_member(channel_1_1, user_3, user_1, ChannelRole::Admin) .await .unwrap(); @@ -314,27 +226,29 @@ async fn test_channel_invites(db: &Arc) { assert_eq!(user_3_invites, &[channel_1_1]); - let members = db - .get_channel_member_details(channel_1_1, user_1) + 
let mut members = db + .get_channel_participant_details(channel_1_1, user_1) .await .unwrap(); + + members.sort_by_key(|member| member.user_id); assert_eq!( members, &[ proto::ChannelMember { user_id: user_1.to_proto(), kind: proto::channel_member::Kind::Member.into(), - admin: true, + role: proto::ChannelRole::Admin.into(), }, proto::ChannelMember { user_id: user_2.to_proto(), kind: proto::channel_member::Kind::Invitee.into(), - admin: false, + role: proto::ChannelRole::Member.into(), }, proto::ChannelMember { user_id: user_3.to_proto(), kind: proto::channel_member::Kind::Invitee.into(), - admin: true, + role: proto::ChannelRole::Admin.into(), }, ] ); @@ -344,12 +258,12 @@ async fn test_channel_invites(db: &Arc) { .unwrap(); let channel_1_3 = db - .create_channel("channel_3", Some(channel_1_1), user_1) + .create_sub_channel("channel_3", channel_1_1, user_1) .await .unwrap(); let members = db - .get_channel_member_details(channel_1_3, user_1) + .get_channel_participant_details(channel_1_3, user_1) .await .unwrap(); assert_eq!( @@ -357,13 +271,13 @@ async fn test_channel_invites(db: &Arc) { &[ proto::ChannelMember { user_id: user_1.to_proto(), - kind: proto::channel_member::Kind::Member.into(), - admin: true, + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), }, proto::ChannelMember { user_id: user_2.to_proto(), kind: proto::channel_member::Kind::AncestorMember.into(), - admin: false, + role: proto::ChannelRole::Member.into(), }, ] ); @@ -385,7 +299,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -399,7 +312,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user2".into(), github_user_id: 6, - invite_count: 0, }, ) .await @@ -412,18 +324,10 @@ async fn test_channel_renames(db: &Arc) { .await .unwrap(); - let zed_archive_id = zed_id; - - let (channel, _) = db - .get_channel(zed_archive_id, user_1) - .await - .unwrap() - .unwrap(); + let channel = db.get_channel(zed_id, user_1).await.unwrap(); assert_eq!(channel.name, "zed-archive"); - let non_permissioned_rename = db - .rename_channel(zed_archive_id, user_2, "hacked-lol") - .await; + let non_permissioned_rename = db.rename_channel(zed_id, user_2, "hacked-lol").await; assert!(non_permissioned_rename.is_err()); let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await; @@ -444,7 +348,6 @@ async fn test_db_channel_moving(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -453,20 +356,17 @@ async fn test_db_channel_moving(db: &Arc) { let zed_id = db.create_root_channel("zed", a_id).await.unwrap(); - let crdb_id = db.create_channel("crdb", Some(zed_id), a_id).await.unwrap(); + let crdb_id = db.create_sub_channel("crdb", zed_id, a_id).await.unwrap(); - let gpui2_id = db - .create_channel("gpui2", Some(zed_id), a_id) - .await - .unwrap(); + let gpui2_id = db.create_sub_channel("gpui2", zed_id, a_id).await.unwrap(); let livestreaming_id = db - .create_channel("livestreaming", Some(crdb_id), a_id) + .create_sub_channel("livestreaming", crdb_id, a_id) .await .unwrap(); let livestreaming_dag_id = db - .create_channel("livestreaming_dag", Some(livestreaming_id), a_id) + .create_sub_channel("livestreaming_dag", livestreaming_id, a_id) .await .unwrap(); @@ -476,316 +376,16 @@ async fn test_db_channel_moving(db: &Arc) { // /- gpui2 // zed -- crdb - livestreaming - livestreaming_dag let result = 
db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( + assert_channel_tree( result.channels, &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), + (zed_id, &[]), + (crdb_id, &[zed_id]), + (livestreaming_id, &[zed_id, crdb_id]), + (livestreaming_dag_id, &[zed_id, crdb_id, livestreaming_id]), + (gpui2_id, &[zed_id]), ], ); - - // Attempt to make a cycle - assert!(db - .link_channel(a_id, zed_id, livestreaming_id) - .await - .is_err()); - - // ======================================================================== - // Make a link - db.link_channel(a_id, livestreaming_id, zed_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 - // zed -- crdb - livestreaming - livestreaming_dag - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - ], - ); - - // ======================================================================== - // Create a new channel below a channel with multiple parents - let livestreaming_dag_sub_id = db - .create_channel("livestreaming_dag_sub", Some(livestreaming_dag_id), a_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 - // zed -- crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test a complex DAG by making another link - let returned_channels = db - .link_channel(a_id, livestreaming_dag_sub_id, livestreaming_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 /---------------------\ - // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id - // \--------/ - - // make sure we're getting just the new link - // Not using the assert_dag helper because we want to make sure we're returning the full data - pretty_assertions::assert_eq!( - returned_channels, - graph( - &[(livestreaming_dag_sub_id, "livestreaming_dag_sub")], - &[(livestreaming_dag_sub_id, livestreaming_id)] - ) - ); - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test a complex DAG by making another link - let returned_channels = db - .link_channel(a_id, livestreaming_id, gpui2_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 -\ /---------------------\ - // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub_id - // \---------/ - - // Make sure that we're correctly getting the full sub-dag - pretty_assertions::assert_eq!( - returned_channels, - 
graph( - &[ - (livestreaming_id, "livestreaming"), - (livestreaming_dag_id, "livestreaming_dag"), - (livestreaming_dag_sub_id, "livestreaming_dag_sub"), - ], - &[ - (livestreaming_id, gpui2_id), - (livestreaming_dag_id, livestreaming_id), - (livestreaming_dag_sub_id, livestreaming_id), - (livestreaming_dag_sub_id, livestreaming_dag_id), - ] - ) - ); - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_id, Some(gpui2_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test unlinking in a complex DAG by removing the inner link - db.unlink_channel(a_id, livestreaming_dag_sub_id, livestreaming_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 -\ - // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub - // \---------/ - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(gpui2_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test unlinking in a complex DAG by removing the inner link - db.unlink_channel(a_id, livestreaming_id, gpui2_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 - // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Test moving DAG nodes by moving livestreaming to be below gpui2 - db.move_channel(a_id, livestreaming_id, crdb_id, gpui2_id) - .await - .unwrap(); - - // DAG is now: - // /- gpui2 -- livestreaming - livestreaming_dag - livestreaming_dag_sub - // zed - crdb / - // \---------/ - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (gpui2_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(gpui2_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Deleting a channel should not delete children that still have other parents - db.delete_channel(gpui2_id, a_id).await.unwrap(); - - // DAG is now: - // zed - crdb - // \- livestreaming - livestreaming_dag - livestreaming_dag_sub - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_dag_id, 
Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Unlinking a channel from it's parent should automatically promote it to a root channel - db.unlink_channel(a_id, crdb_id, zed_id).await.unwrap(); - - // DAG is now: - // crdb - // zed - // \- livestreaming - livestreaming_dag - livestreaming_dag_sub - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, None), - (livestreaming_id, Some(zed_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // You should be able to move a root channel into a non-root channel - db.link_channel(a_id, crdb_id, zed_id).await.unwrap(); - - // DAG is now: - // zed - crdb - // \- livestreaming - livestreaming_dag - livestreaming_dag_sub - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // ======================================================================== - // Prep for DAG deletion test - db.link_channel(a_id, livestreaming_id, crdb_id) - .await - .unwrap(); - - // DAG is now: - // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub - // \--------/ - - let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_dag( - result.channels, - &[ - (zed_id, None), - (crdb_id, Some(zed_id)), - (livestreaming_id, Some(zed_id)), - (livestreaming_id, Some(crdb_id)), - (livestreaming_dag_id, Some(livestreaming_id)), - (livestreaming_dag_sub_id, Some(livestreaming_dag_id)), - ], - ); - - // Deleting the parent of a DAG should delete the whole DAG: - db.delete_channel(zed_id, a_id).await.unwrap(); - let result = db.get_channels_for_user(a_id).await.unwrap(); - - assert!(result.channels.is_empty()) } test_both_dbs!( @@ -802,7 +402,6 @@ async fn test_db_channel_moving_bugs(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -812,12 +411,12 @@ async fn test_db_channel_moving_bugs(db: &Arc) { let zed_id = db.create_root_channel("zed", user_id).await.unwrap(); let projects_id = db - .create_channel("projects", Some(zed_id), user_id) + .create_sub_channel("projects", zed_id, user_id) .await .unwrap(); let livestreaming_id = db - .create_channel("livestreaming", Some(projects_id), user_id) + .create_sub_channel("livestreaming", projects_id, user_id) .await .unwrap(); @@ -825,48 +424,396 @@ async fn test_db_channel_moving_bugs(db: &Arc) { // Move to same parent should be a no-op assert!(db - .move_channel(user_id, projects_id, zed_id, zed_id) + .move_channel(projects_id, Some(zed_id), user_id) .await .unwrap() - .is_empty()); - - // Stranding a channel should retain it's sub channels - db.unlink_channel(user_id, projects_id, zed_id) - .await - .unwrap(); + .is_none()); let result = db.get_channels_for_user(user_id).await.unwrap(); - assert_dag( + assert_channel_tree( result.channels, &[ - (zed_id, None), - (projects_id, None), - (livestreaming_id, Some(projects_id)), + (zed_id, &[]), + (projects_id, &[zed_id]), + (livestreaming_id, &[zed_id, projects_id]), + ], + ); + + // Move the project 
channel to the root + db.move_channel(projects_id, None, user_id).await.unwrap(); + let result = db.get_channels_for_user(user_id).await.unwrap(); + assert_channel_tree( + result.channels, + &[ + (zed_id, &[]), + (projects_id, &[]), + (livestreaming_id, &[projects_id]), ], ); } -#[track_caller] -fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option)]) { - let mut actual_map: HashMap> = HashMap::default(); - for channel in actual.channels { - actual_map.insert(channel.id, HashSet::default()); - } - for edge in actual.edges { - actual_map - .get_mut(&ChannelId::from_proto(edge.channel_id)) - .unwrap() - .insert(ChannelId::from_proto(edge.parent_id)); - } +test_both_dbs!( + test_user_is_channel_participant, + test_user_is_channel_participant_postgres, + test_user_is_channel_participant_sqlite +); - let mut expected_map: HashMap> = HashMap::default(); +async fn test_user_is_channel_participant(db: &Arc) { + let admin = new_test_user(db, "admin@example.com").await; + let member = new_test_user(db, "member@example.com").await; + let guest = new_test_user(db, "guest@example.com").await; - for (child, parent) in expected { - let entry = expected_map.entry(*child).or_default(); - if let Some(parent) = parent { - entry.insert(*parent); - } - } + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + let active_channel_id = db + .create_sub_channel("active", zed_channel, admin) + .await + .unwrap(); + let vim_channel_id = db + .create_sub_channel("vim", active_channel_id, admin) + .await + .unwrap(); - pretty_assertions::assert_eq!(actual_map, expected_map) + db.set_channel_visibility(vim_channel_id, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + db.invite_channel_member(active_channel_id, member, admin, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(vim_channel_id, guest, admin, ChannelRole::Guest) + .await + .unwrap(); + + db.respond_to_channel_invite(active_channel_id, member, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + admin, + &*tx, + ) + .await + }) + .await + .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + member, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::Invitee.into(), + role: proto::ChannelRole::Guest.into(), + }, + ] + ); + + db.respond_to_channel_invite(vim_channel_id, guest, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await?, + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let channels = db.get_channels_for_user(guest).await.unwrap().channels; + assert_channel_tree(channels, &[(vim_channel_id, &[])]); + let channels = db.get_channels_for_user(member).await.unwrap().channels; + 
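Editor's aside: the assertion that follows, like the assert_channel_tree helper defined at the bottom of this test file, compares each returned channel against the ancestor ids the querying user is expected to see. A channel's parent_path appears to list its visible ancestors, outermost first, so a channel the user sees as a root carries an empty path — which is why the guest above sees vim on its own while a member of active sees vim nested beneath it. A minimal sketch of that comparison, with Channel and ChannelId as simplified stand-ins for the real types:

    // Editor's sketch, not part of the diff.
    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    struct ChannelId(u64);

    struct Channel {
        id: ChannelId,
        parent_path: Vec<ChannelId>, // visible ancestors, outermost first
    }

    fn tree_of(channels: &[Channel]) -> Vec<(ChannelId, &[ChannelId])> {
        channels
            .iter()
            .map(|channel| (channel.id, channel.parent_path.as_slice()))
            .collect()
    }

    fn main() {
        let (active, vim) = (ChannelId(1), ChannelId(2));
        // A guest who can only see the public `vim` channel sees it as a root...
        let guest_view = vec![Channel { id: vim, parent_path: vec![] }];
        assert_eq!(tree_of(&guest_view), vec![(vim, &[] as &[ChannelId])]);
        // ...while a member of `active` sees `vim` nested beneath it.
        let member_view = vec![
            Channel { id: active, parent_path: vec![] },
            Channel { id: vim, parent_path: vec![active] },
        ];
        assert_eq!(
            tree_of(&member_view),
            vec![(active, &[] as &[ChannelId]), (vim, &[active] as &[ChannelId])]
        );
    }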
assert_channel_tree( + channels, + &[ + (active_channel_id, &[]), + (vim_channel_id, &[active_channel_id]), + ], + ); + + db.set_channel_member_role(vim_channel_id, admin, guest, ChannelRole::Banned) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .is_err()); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::Member.into(), + role: proto::ChannelRole::Banned.into(), + }, + ] + ); + + db.remove_channel_member(vim_channel_id, guest, admin) + .await + .unwrap(); + + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.invite_channel_member(zed_channel, guest, admin, ChannelRole::Guest) + .await + .unwrap(); + + // currently people invited to parent channels are not shown here + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + ] + ); + + db.respond_to_channel_invite(zed_channel, guest, true) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(zed_channel, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(active_channel_id, &*tx) + .await + .unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .is_err(),); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant( + &db.get_channel_internal(vim_channel_id, &*tx).await.unwrap(), + guest, + &*tx, + ) + .await + }) + .await + .unwrap(); + + let mut members = db + .get_channel_participant_details(vim_channel_id, admin) + .await + .unwrap(); + + members.sort_by_key(|member| member.user_id); + + assert_eq!( + members, + &[ + proto::ChannelMember { + user_id: admin.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Admin.into(), + }, + proto::ChannelMember { + user_id: member.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Member.into(), + }, + proto::ChannelMember { + user_id: guest.to_proto(), + kind: proto::channel_member::Kind::AncestorMember.into(), + role: proto::ChannelRole::Guest.into(), + }, + ] + ); + + let channels = db.get_channels_for_user(guest).await.unwrap().channels; + assert_channel_tree( + channels, + &[(zed_channel, &[]), (vim_channel_id, &[zed_channel])], + ) +} + +test_both_dbs!( 
+ test_user_joins_correct_channel, + test_user_joins_correct_channel_postgres, + test_user_joins_correct_channel_sqlite +); + +async fn test_user_joins_correct_channel(db: &Arc) { + let admin = new_test_user(db, "admin@example.com").await; + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + + let active_channel = db + .create_sub_channel("active", zed_channel, admin) + .await + .unwrap(); + + let vim_channel = db + .create_sub_channel("vim", active_channel, admin) + .await + .unwrap(); + + let vim2_channel = db + .create_sub_channel("vim2", vim_channel, admin) + .await + .unwrap(); + + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + db.set_channel_visibility(vim2_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + let most_public = db + .transaction(|tx| async move { + Ok(db + .public_ancestors_including_self( + &db.get_channel_internal(vim_channel, &*tx).await.unwrap(), + &tx, + ) + .await? + .first() + .cloned()) + }) + .await + .unwrap() + .unwrap() + .id; + + assert_eq!(most_public, zed_channel) +} + +test_both_dbs!( + test_guest_access, + test_guest_access_postgres, + test_guest_access_sqlite +); + +async fn test_guest_access(db: &Arc) { + let server = db.create_server("test").await.unwrap(); + + let admin = new_test_user(db, "admin@example.com").await; + let guest = new_test_user(db, "guest@example.com").await; + let guest_connection = new_test_connection(server); + + let zed_channel = db.create_root_channel("zed", admin).await.unwrap(); + db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin) + .await + .unwrap(); + + assert!(db + .join_channel_chat(zed_channel, guest_connection, guest) + .await + .is_err()); + + db.join_channel(zed_channel, guest, guest_connection, TEST_RELEASE_CHANNEL) + .await + .unwrap(); + + assert!(db + .join_channel_chat(zed_channel, guest_connection, guest) + .await + .is_ok()) +} + +#[track_caller] +fn assert_channel_tree(actual: Vec, expected: &[(ChannelId, &[ChannelId])]) { + let actual = actual + .iter() + .map(|channel| (channel.id, channel.parent_path.as_slice())) + .collect::>(); + pretty_assertions::assert_eq!( + actual, + expected.to_vec(), + "wrong channel ids and parent paths" + ); } diff --git a/crates/collab/src/db/tests/db_tests.rs b/crates/collab/src/db/tests/db_tests.rs index 1520e081c0..c4b82f8cec 100644 --- a/crates/collab/src/db/tests/db_tests.rs +++ b/crates/collab/src/db/tests/db_tests.rs @@ -22,7 +22,6 @@ async fn test_get_users(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -88,7 +87,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login1".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -101,7 +99,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login2".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -156,7 +153,6 @@ async fn test_create_access_tokens(db: &Arc) { NewUserParams { github_login: "u1".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -238,7 +234,6 @@ async fn test_add_contacts(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -264,10 +259,7 @@ async fn test_add_contacts(db: &Arc) { ); assert_eq!( 
db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] + &[Contact::Incoming { user_id: user_1 }] ); // User 2 dismisses the contact request notification without accepting or rejecting. @@ -280,10 +272,7 @@ async fn test_add_contacts(db: &Arc) { .unwrap(); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] + &[Contact::Incoming { user_id: user_1 }] ); // User can't accept their own contact request @@ -299,7 +288,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }], ); @@ -309,7 +297,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -326,7 +313,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }] ); @@ -339,7 +325,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }] ); @@ -353,12 +338,10 @@ async fn test_add_contacts(db: &Arc) { &[ Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }, Contact::Accepted { user_id: user_3, - should_notify: false, busy: false, } ] @@ -367,7 +350,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -383,7 +365,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -391,7 +372,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -415,7 +395,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person1".into(), github_user_id: 101, - invite_count: 5, }, ) .await @@ -431,7 +410,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person2".into(), github_user_id: 102, - invite_count: 5, }, ) .await @@ -460,7 +438,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -472,7 +449,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -554,7 +530,6 @@ async fn test_fuzzy_search_users() { NewUserParams { github_login: github_login.into(), github_user_id: i as i32, - invite_count: 0, }, ) .await @@ -596,7 +571,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -608,7 +582,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/feature_flag_tests.rs b/crates/collab/src/db/tests/feature_flag_tests.rs index 9d5f039747..0286a6308e 100644 --- a/crates/collab/src/db/tests/feature_flag_tests.rs +++ b/crates/collab/src/db/tests/feature_flag_tests.rs @@ -18,7 +18,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user1"), github_user_id: 1, - invite_count: 0, }, 
) .await @@ -32,7 +31,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user2"), github_user_id: 2, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index e758fcfb5d..10d9778612 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -1,7 +1,9 @@ +use super::new_test_user; use crate::{ - db::{Database, MessageId, NewUserParams}, + db::{ChannelRole, Database, MessageId}, test_both_dbs, }; +use channel::mentions_to_proto; use std::sync::Arc; use time::OffsetDateTime; @@ -12,39 +14,38 @@ test_both_dbs!( ); async fn test_channel_message_retrieval(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let channel = db.create_channel("channel", None, user).await.unwrap(); + let user = new_test_user(db, "user@example.com").await; + let result = db.create_channel("channel", None, user).await.unwrap(); let owner_id = db.create_server("test").await.unwrap().0 as u32; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); + db.join_channel_chat( + result.channel.id, + rpc::ConnectionId { owner_id, id: 0 }, + user, + ) + .await + .unwrap(); let mut all_messages = Vec::new(); for i in 0..10 { all_messages.push( - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) - .await - .unwrap() - .0 - .to_proto(), + db.create_channel_message( + result.channel.id, + user, + &i.to_string(), + &[], + OffsetDateTime::now_utc(), + i, + ) + .await + .unwrap() + .message_id + .to_proto(), ); } let messages = db - .get_channel_messages(channel, user, 3, None) + .get_channel_messages(result.channel.id, user, 3, None) .await .unwrap() .into_iter() @@ -54,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc) { let messages = db .get_channel_messages( - channel, + result.channel.id, user, 4, Some(MessageId::from_proto(all_messages[6])), @@ -74,99 +75,154 @@ test_both_dbs!( ); async fn test_channel_message_nonces(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + let channel = db.create_root_channel("channel", user_a).await.unwrap(); + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_c, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a) + .await + .unwrap(); + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b) + .await + .unwrap(); + + // As user A, create messages that re-use the same nonces. The requests + // succeed, but return the same ids. 
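Editor's aside: a minimal sketch of the nonce behaviour this test pins down, using hypothetical MessageStore, UserId, and MessageId stand-ins rather than the real collab database API. Message creation is idempotent per (sender, nonce) pair, and the nonce is scoped to the sender, so a retry returns the original id while a different user reusing the nonce still gets a fresh one:

    // Editor's sketch, not part of the diff.
    use std::collections::HashMap;

    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    struct UserId(u64);
    #[derive(Clone, Copy, PartialEq, Eq, Debug)]
    struct MessageId(u64);

    #[derive(Default)]
    struct MessageStore {
        next_id: u64,
        by_nonce: HashMap<(UserId, u128), MessageId>,
    }

    impl MessageStore {
        fn create_message(&mut self, sender: UserId, nonce: u128, _body: &str) -> MessageId {
            // A repeated (sender, nonce) pair returns the id created the first time.
            if let Some(existing) = self.by_nonce.get(&(sender, nonce)) {
                return *existing;
            }
            self.next_id += 1;
            let id = MessageId(self.next_id);
            self.by_nonce.insert((sender, nonce), id);
            id
        }
    }

    fn main() {
        let mut store = MessageStore::default();
        let (user_a, user_b) = (UserId(1), UserId(2));
        let id1 = store.create_message(user_a, 100, "hi @user_b");
        let id3 = store.create_message(user_a, 100, "retry with the same nonce");
        let id5 = store.create_message(user_b, 100, "same nonce, different sender");
        assert_eq!(id1, id3); // same sender + nonce => same message
        assert_ne!(id1, id5); // nonce is scoped to the sender
    }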
+ let id1 = db + .create_channel_message( + channel, + user_a, + "hi @user_b", + &mentions_to_proto(&[(3..10, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 100, ) .await .unwrap() - .user_id; - let channel = db.create_channel("channel", None, user).await.unwrap(); + .message_id; + let id2 = db + .create_channel_message( + channel, + user_a, + "hello, fellow users", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; + let id3 = db + .create_channel_message( + channel, + user_a, + "bye @user_c (same nonce as first message)", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + let id4 = db + .create_channel_message( + channel, + user_a, + "omg (same nonce as second message)", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; - let owner_id = db.create_server("test").await.unwrap().0 as u32; + // As a different user, reuse one of the same nonces. This request succeeds + // and returns a different id. + let id5 = db + .create_channel_message( + channel, + user_b, + "omg @user_a (same nonce as user_a's first message)", + &mentions_to_proto(&[(4..11, user_a.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); + assert_ne!(id1, id2); + assert_eq!(id1, id3); + assert_eq!(id2, id4); + assert_ne!(id5, id1); - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + let messages = db + .get_channel_messages(channel, user_a, 5, None) .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); + .unwrap() + .into_iter() + .map(|m| (m.id, m.body, m.mentions)) + .collect::>(); + assert_eq!( + messages, + &[ + ( + id1.to_proto(), + "hi @user_b".into(), + mentions_to_proto(&[(3..10, user_b.to_proto())]), + ), + ( + id2.to_proto(), + "hello, fellow users".into(), + mentions_to_proto(&[]) + ), + ( + id5.to_proto(), + "omg @user_a (same nonce as user_a's first message)".into(), + mentions_to_proto(&[(4..11, user_a.to_proto())]), + ), + ] + ); } test_both_dbs!( - test_channel_message_new_notification, - test_channel_message_new_notification_postgres, - test_channel_message_new_notification_sqlite + test_unseen_channel_messages, + test_unseen_channel_messages_postgres, + test_unseen_channel_messages_sqlite ); -async fn test_channel_message_new_notification(db: &Arc) { - let user = db - .create_user( - "user_a@example.com", - false, - NewUserParams { - github_login: "user_a".into(), - github_user_id: 1, - invite_count: 0, - }, - ) +async fn test_unseen_channel_messages(db: &Arc) { + let user = new_test_user(db, "user_a@example.com").await; + let observer = new_test_user(db, "user_b@example.com").await; + + let channel_1 = db.create_root_channel("channel", user).await.unwrap(); + let channel_2 = db.create_root_channel("channel-2", user).await.unwrap(); + + db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) .await - .unwrap() - 
.user_id; - let observer = db - .create_user( - "user_b@example.com", - false, - NewUserParams { - github_login: "user_b".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let channel_1 = db.create_channel("channel", None, user).await.unwrap(); - - let channel_2 = db.create_channel("channel-2", None, user).await.unwrap(); - - db.invite_channel_member(channel_1, observer, user, false) + .unwrap(); + db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) .await .unwrap(); db.respond_to_channel_invite(channel_1, observer, true) .await .unwrap(); - - db.invite_channel_member(channel_2, observer, user, false) - .await - .unwrap(); - db.respond_to_channel_invite(channel_2, observer, true) .await .unwrap(); @@ -179,28 +235,31 @@ async fn test_channel_message_new_notification(db: &Arc) { .unwrap(); let _ = db - .create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1) + .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1) .await .unwrap(); - let (second_message, _, _) = db - .create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2) + let second_message = db + .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2) .await - .unwrap(); + .unwrap() + .message_id; - let (third_message, _, _) = db - .create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3) + let third_message = db + .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3) .await - .unwrap(); + .unwrap() + .message_id; db.join_channel_chat(channel_2, user_connection_id, user) .await .unwrap(); - let (fourth_message, _, _) = db - .create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4) + let fourth_message = db + .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4) .await - .unwrap(); + .unwrap() + .message_id; // Check that observer has new messages let unseen_messages = db @@ -295,3 +354,101 @@ async fn test_channel_message_new_notification(db: &Arc) { }] ); } + +test_both_dbs!( + test_channel_message_mentions, + test_channel_message_mentions_postgres, + test_channel_message_mentions_sqlite +); + +async fn test_channel_message_mentions(db: &Arc) { + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + + let channel = db + .create_channel("channel", None, user_a) + .await + .unwrap() + .channel + .id; + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + let connection_id = rpc::ConnectionId { owner_id, id: 0 }; + db.join_channel_chat(channel, connection_id, user_a) + .await + .unwrap(); + + db.create_channel_message( + channel, + user_a, + "hi @user_b and @user_c", + &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 1, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "bye @user_c", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 2, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "umm", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 3, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + 
user_a, + "@user_b, stop.", + &mentions_to_proto(&[(0..7, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 4, + ) + .await + .unwrap(); + + let messages = db + .get_channel_messages(channel, user_b, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.body, m.mentions)) + .collect::>(); + assert_eq!( + &messages, + &[ + ( + "hi @user_b and @user_c".into(), + mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + ), + ( + "bye @user_c".into(), + mentions_to_proto(&[(4..11, user_c.to_proto())]), + ), + ("umm".into(), mentions_to_proto(&[]),), + ( + "@user_b, stop.".into(), + mentions_to_proto(&[(0..7, user_b.to_proto())]), + ), + ] + ); +} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 13fb8ed0eb..85216525b0 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -119,7 +119,9 @@ impl AppState { pub async fn new(config: Config) -> Result> { let mut db_options = db::ConnectOptions::new(config.database_url.clone()); db_options.max_connections(config.database_max_connections); - let db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options, Executor::Production).await?; + db.initialize_notification_kinds().await?; + let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e5c6d94ce0..7e847e8bff 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,8 +3,11 @@ mod connection_pool; use crate::{ auth, db::{ - self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, - ServerId, User, UserId, + self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, + CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, + MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult, + RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult, + User, UserId, }, executor::Executor, AppState, Result, @@ -38,8 +41,8 @@ use lazy_static::lazy_static; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, - LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators, + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, + RequestMessage, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, Peer, Receipt, TypedEnvelope, }; @@ -70,6 +73,7 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; +const NOTIFICATION_COUNT_PER_PAGE: usize = 50; lazy_static! 
{ static ref METRIC_CONNECTIONS: IntGauge = @@ -225,6 +229,7 @@ impl Server { .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) @@ -254,7 +259,8 @@ impl Server { .add_request_handler(delete_channel) .add_request_handler(invite_channel_member) .add_request_handler(remove_channel_member) - .add_request_handler(set_channel_member_admin) + .add_request_handler(set_channel_member_role) + .add_request_handler(set_channel_visibility) .add_request_handler(rename_channel) .add_request_handler(join_channel_buffer) .add_request_handler(leave_channel_buffer) @@ -268,8 +274,9 @@ impl Server { .add_request_handler(send_channel_message) .add_request_handler(remove_channel_message) .add_request_handler(get_channel_messages) - .add_request_handler(link_channel) - .add_request_handler(unlink_channel) + .add_request_handler(get_channel_messages_by_id) + .add_request_handler(get_notifications) + .add_request_handler(mark_notification_as_read) .add_request_handler(move_channel) .add_request_handler(follow) .add_message_handler(unfollow) @@ -387,7 +394,7 @@ impl Server { let contacts = app_state.db.get_contacts(user_id).await.trace_err(); if let Some((busy, contacts)) = busy.zip(contacts) { let pool = pool.lock(); - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, @@ -581,14 +588,14 @@ impl Server { let (contacts, channels_for_user, channel_invites) = future::try_join3( this.app_state.db.get_contacts(user_id), this.app_state.db.get_channels_for_user(user_id), - this.app_state.db.get_channel_invites_for_user(user_id) + this.app_state.db.get_channel_invites_for_user(user_id), ).await?; { let mut pool = this.connection_pool.lock(); pool.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; - this.peer.send(connection_id, build_initial_channels_update( + this.peer.send(connection_id, build_channels_update( channels_for_user, channel_invites ))?; @@ -687,7 +694,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? 
{ if let Some(code) = &user.invite_code { let pool = self.connection_pool.lock(); - let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + let invitee_contact = contact_for_user(invitee_id, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, @@ -935,7 +942,7 @@ async fn create_room( let live_kit_room = live_kit_room.clone(); let live_kit = session.live_kit_client.as_ref(); - util::async_iife!({ + util::async_maybe!({ let live_kit = live_kit?; let token = live_kit @@ -945,6 +952,7 @@ async fn create_room( Some(proto::LiveKitConnectionInfo { server_url: live_kit.url().into(), token, + can_publish: true, }) }) } @@ -976,6 +984,13 @@ async fn join_room( session: Session, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); + + let channel_id = session.db().await.channel_id_for_room(room_id).await?; + + if let Some(channel_id) = channel_id { + return join_channel_internal(channel_id, Box::new(response), session).await; + } + let joined_room = { let room = session .db() @@ -991,16 +1006,6 @@ async fn join_room( room.into_inner() }; - if let Some(channel_id) = joined_room.channel_id { - channel_updated( - channel_id, - &joined_room.room, - &joined_room.channel_members, - &session.peer, - &*session.connection_pool().await, - ) - } - for connection_id in session .connection_pool() .await @@ -1028,6 +1033,7 @@ async fn join_room( Some(proto::LiveKitConnectionInfo { server_url: live_kit.url().into(), token, + can_publish: true, }) } else { None @@ -1038,7 +1044,7 @@ async fn join_room( response.send(proto::JoinRoomResponse { room: Some(joined_room.room), - channel_id: joined_room.channel_id.map(|id| id.to_proto()), + channel_id: None, live_kit_connection_info, })?; @@ -2064,7 +2070,7 @@ async fn request_contact( return Err(anyhow!("cannot add yourself as a contact"))?; } - session + let notifications = session .db() .await .send_contact_request(requester_id, responder_id) @@ -2087,16 +2093,14 @@ async fn request_contact( .incoming_requests .push(proto::IncomingContactRequest { requester_id: requester_id.to_proto(), - should_notify: true, }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(responder_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2115,7 +2119,8 @@ async fn respond_to_contact_request( } else { let accept = request.response == proto::ContactRequestResponse::Accept as i32; - db.respond_to_contact_request(responder_id, requester_id, accept) + let notifications = db + .respond_to_contact_request(responder_id, requester_id, accept) .await?; let requester_busy = db.is_user_busy(requester_id).await?; let responder_busy = db.is_user_busy(responder_id).await?; @@ -2126,7 +2131,7 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(requester_id, false, requester_busy, &pool)); + .push(contact_for_user(requester_id, requester_busy, &pool)); } update .remove_incoming_requests @@ -2140,14 +2145,17 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(responder_id, true, responder_busy, &pool)); + .push(contact_for_user(responder_id, responder_busy, &pool)); } update .remove_outgoing_requests .push(responder_id.to_proto()); + for connection_id 
in pool.user_connection_ids(requester_id) { session.peer.send(connection_id, update.clone())?; } + + send_notifications(&*pool, &session.peer, notifications); } response.send(proto::Ack {})?; @@ -2162,7 +2170,8 @@ async fn remove_contact( let requester_id = session.user_id; let responder_id = UserId::from_proto(request.user_id); let db = session.db().await; - let contact_accepted = db.remove_contact(requester_id, responder_id).await?; + let (contact_accepted, deleted_notification_id) = + db.remove_contact(requester_id, responder_id).await?; let pool = session.connection_pool().await; // Update outgoing contact requests of requester @@ -2189,6 +2198,14 @@ async fn remove_contact( } for connection_id in pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; + if let Some(notification_id) = deleted_notification_id { + session.peer.send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + )?; + } } response.send(proto::Ack {})?; @@ -2203,37 +2220,21 @@ async fn create_channel( let db = session.db().await; let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); - let id = db + let CreateChannelResult { + channel, + participants_to_update, + } = db .create_channel(&request.name, parent_id, session.user_id) .await?; - let channel = proto::Channel { - id: id.to_proto(), - name: request.name, - }; - response.send(proto::CreateChannelResponse { - channel: Some(channel.clone()), + channel: Some(channel.to_proto()), parent_id: request.parent_id, })?; - let Some(parent_id) = parent_id else { - return Ok(()); - }; - - let update = proto::UpdateChannels { - channels: vec![channel], - insert_edge: vec![ChannelEdge { - parent_id: parent_id.to_proto(), - channel_id: id.to_proto(), - }], - ..Default::default() - }; - - let user_ids_to_notify = db.get_channel_members(parent_id).await?; - let connection_pool = session.connection_pool().await; - for user_id in user_ids_to_notify { + for (user_id, channels) in participants_to_update { + let update = build_channels_update(channels, vec![]); for connection_id in connection_pool.user_connection_ids(user_id) { if user_id == session.user_id { continue; @@ -2282,27 +2283,30 @@ async fn invite_channel_member( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let invitee_id = UserId::from_proto(request.user_id); - db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin) + let InviteMemberResult { + channel, + notifications, + } = db + .invite_channel_member( + channel_id, + invitee_id, + session.user_id, + request.role().into(), + ) .await?; - let (channel, _) = db - .get_channel(channel_id, session.user_id) - .await? 
- .ok_or_else(|| anyhow!("channel not found"))?; + let update = proto::UpdateChannels { + channel_invitations: vec![channel.to_proto()], + ..Default::default() + }; - let mut update = proto::UpdateChannels::default(); - update.channel_invitations.push(proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(invitee_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(invitee_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2316,54 +2320,117 @@ async fn remove_channel_member( let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - db.remove_channel_member(channel_id, member_id, session.user_id) + let RemoveChannelMemberResult { + membership_update, + notification_id, + } = db + .remove_channel_member(channel_id, member_id, session.user_id) .await?; - let mut update = proto::UpdateChannels::default(); - update.delete_channels.push(channel_id.to_proto()); - - for connection_id in session - .connection_pool() - .await - .user_connection_ids(member_id) - { - session.peer.send(connection_id, update.clone())?; + let connection_pool = &session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ); + for connection_id in connection_pool.user_connection_ids(member_id) { + if let Some(notification_id) = notification_id { + session + .peer + .send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + ) + .trace_err(); + } } response.send(proto::Ack {})?; Ok(()) } -async fn set_channel_member_admin( - request: proto::SetChannelMemberAdmin, - response: Response, +async fn set_channel_visibility( + request: proto::SetChannelVisibility, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let channel_id = ChannelId::from_proto(request.channel_id); + let visibility = request.visibility().into(); + + let SetChannelVisibilityResult { + participants_to_update, + participants_to_remove, + channels_to_remove, + } = db + .set_channel_visibility(channel_id, visibility, session.user_id) + .await?; + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let update = build_channels_update(channels, vec![]); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + response.send(proto::Ack {})?; + Ok(()) +} + +async fn set_channel_member_role( + request: proto::SetChannelMemberRole, + response: Response, session: Session, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - db.set_channel_member_admin(channel_id, session.user_id, member_id, request.admin) + let result = db + .set_channel_member_role( + channel_id, + session.user_id, + member_id, + 
request.role().into(), + ) .await?; - let (channel, has_accepted) = db - .get_channel(channel_id, member_id) - .await? - .ok_or_else(|| anyhow!("channel not found"))?; + match result { + db::SetMemberRoleResult::MembershipUpdated(membership_update) => { + let connection_pool = session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ) + } + db::SetMemberRoleResult::InviteUpdated(channel) => { + let update = proto::UpdateChannels { + channel_invitations: vec![channel.to_proto()], + ..Default::default() + }; - let mut update = proto::UpdateChannels::default(); - if has_accepted { - update.channel_permissions.push(proto::ChannelPermission { - channel_id: channel.id.to_proto(), - is_admin: request.admin, - }); - } - - for connection_id in session - .connection_pool() - .await - .user_connection_ids(member_id) - { - session.peer.send(connection_id, update.clone())?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(member_id) + { + session.peer.send(connection_id, update.clone())?; + } + } } response.send(proto::Ack {})?; @@ -2377,25 +2444,25 @@ async fn rename_channel( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - let new_name = db + let RenameChannelResult { + channel, + participants_to_update, + } = db .rename_channel(channel_id, session.user_id, &request.name) .await?; - let channel = proto::Channel { - id: request.channel_id, - name: new_name, - }; response.send(proto::RenameChannelResponse { - channel: Some(channel.clone()), + channel: Some(channel.to_proto()), })?; - let mut update = proto::UpdateChannels::default(); - update.channels.push(channel); - - let member_ids = db.get_channel_members(channel_id).await?; let connection_pool = session.connection_pool().await; - for member_id in member_ids { - for connection_id in connection_pool.user_connection_ids(member_id) { + for (user_id, channel) in participants_to_update { + for connection_id in connection_pool.user_connection_ids(user_id) { + let update = proto::UpdateChannels { + channels: vec![channel.to_proto()], + ..Default::default() + }; + session.peer.send(connection_id, update.clone())?; } } @@ -2403,129 +2470,55 @@ async fn rename_channel( Ok(()) } -async fn link_channel( - request: proto::LinkChannel, - response: Response, - session: Session, -) -> Result<()> { - let db = session.db().await; - let channel_id = ChannelId::from_proto(request.channel_id); - let to = ChannelId::from_proto(request.to); - let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?; - - let members = db.get_channel_members(to).await?; - let connection_pool = session.connection_pool().await; - let update = proto::UpdateChannels { - channels: channels_to_send - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }) - .collect(), - insert_edge: channels_to_send.edges, - ..Default::default() - }; - for member_id in members { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - response.send(Ack {})?; - - Ok(()) -} - -async fn unlink_channel( - request: proto::UnlinkChannel, - response: Response, - session: Session, -) -> Result<()> { - let db = session.db().await; - let channel_id = ChannelId::from_proto(request.channel_id); - let from = ChannelId::from_proto(request.from); - - db.unlink_channel(session.user_id, channel_id, 
from).await?; - - let members = db.get_channel_members(from).await?; - - let update = proto::UpdateChannels { - delete_edge: vec![proto::ChannelEdge { - channel_id: channel_id.to_proto(), - parent_id: from.to_proto(), - }], - ..Default::default() - }; - let connection_pool = session.connection_pool().await; - for member_id in members { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - response.send(Ack {})?; - - Ok(()) -} - async fn move_channel( request: proto::MoveChannel, response: Response, session: Session, ) -> Result<()> { - let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - let from_parent = ChannelId::from_proto(request.from); - let to = ChannelId::from_proto(request.to); + let to = request.to.map(ChannelId::from_proto); - let channels_to_send = db - .move_channel(session.user_id, channel_id, from_parent, to) + let result = session + .db() + .await + .move_channel(channel_id, to, session.user_id) .await?; - if channels_to_send.is_empty() { - response.send(Ack {})?; - return Ok(()); - } - - let members_from = db.get_channel_members(from_parent).await?; - let members_to = db.get_channel_members(to).await?; - - let update = proto::UpdateChannels { - delete_edge: vec![proto::ChannelEdge { - channel_id: channel_id.to_proto(), - parent_id: from_parent.to_proto(), - }], - ..Default::default() - }; - let connection_pool = session.connection_pool().await; - for member_id in members_from { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } - - let update = proto::UpdateChannels { - channels: channels_to_send - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }) - .collect(), - insert_edge: channels_to_send.edges, - ..Default::default() - }; - for member_id in members_to { - for connection_id in connection_pool.user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; - } - } + notify_channel_moved(result, session).await?; response.send(Ack {})?; + Ok(()) +} +async fn notify_channel_moved(result: Option, session: Session) -> Result<()> { + let Some(MoveChannelResult { + participants_to_remove, + participants_to_update, + moved_channels, + }) = result + else { + return Ok(()); + }; + let moved_channels: Vec = moved_channels.iter().map(|id| id.to_proto()).collect(); + + let connection_pool = session.connection_pool().await; + for (user_id, channels) in participants_to_update { + let mut update = build_channels_update(channels, vec![]); + update.delete_channels = moved_channels.clone(); + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + for user_id in participants_to_remove { + let update = proto::UpdateChannels { + delete_channels: moved_channels.clone(), + ..Default::default() + }; + for connection_id in connection_pool.user_connection_ids(user_id) { + session.peer.send(connection_id, update.clone())?; + } + } Ok(()) } @@ -2537,7 +2530,7 @@ async fn get_channel_members( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let members = db - .get_channel_member_details(channel_id, session.user_id) + .get_channel_participant_details(channel_id, session.user_id) .await?; response.send(proto::GetChannelMembersResponse { members })?; Ok(()) @@ -2550,54 +2543,34 @@ async fn 
respond_to_channel_invite( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - db.respond_to_channel_invite(channel_id, session.user_id, request.accept) + let RespondToChannelInvite { + membership_update, + notifications, + } = db + .respond_to_channel_invite(channel_id, session.user_id, request.accept) .await?; - let mut update = proto::UpdateChannels::default(); - update - .remove_channel_invitations - .push(channel_id.to_proto()); - if request.accept { - let result = db.get_channel_for_user(channel_id, session.user_id).await?; - update - .channels - .extend( - result - .channels - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }), - ); - update.unseen_channel_messages = result.channel_messages; - update.unseen_channel_buffer_changes = result.unseen_buffer_changes; - update.insert_edge = result.channels.edges; - update - .channel_participants - .extend( - result - .channel_participants - .into_iter() - .map(|(channel_id, user_ids)| proto::ChannelParticipants { - channel_id: channel_id.to_proto(), - participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(), - }), - ); - update - .channel_permissions - .extend( - result - .channels_with_admin_privileges - .into_iter() - .map(|channel_id| proto::ChannelPermission { - channel_id: channel_id.to_proto(), - is_admin: true, - }), - ); - } - session.peer.send(session.connection_id, update)?; + let connection_pool = session.connection_pool().await; + if let Some(membership_update) = membership_update { + notify_membership_updated( + &connection_pool, + membership_update, + session.user_id, + &session.peer, + ); + } else { + let update = proto::UpdateChannels { + remove_channel_invitations: vec![channel_id.to_proto()], + ..Default::default() + }; + + for connection_id in connection_pool.user_connection_ids(session.user_id) { + session.peer.send(connection_id, update.clone())?; + } + }; + + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) @@ -2609,19 +2582,35 @@ async fn join_channel( session: Session, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); - let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + join_channel_internal(channel_id, Box::new(response), session).await +} +trait JoinChannelInternalResponse { + fn send(self, result: proto::JoinRoomResponse) -> Result<()>; +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} + +async fn join_channel_internal( + channel_id: ChannelId, + response: Box, + session: Session, +) -> Result<()> { let joined_room = { leave_room_for_session(&session).await?; let db = session.db().await; - let room_id = db - .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME) - .await?; - - let joined_room = db - .join_room( - room_id, + let (joined_room, membership_updated, role) = db + .join_channel( + channel_id, session.user_id, session.connection_id, RELEASE_CHANNEL_NAME.as_str(), @@ -2629,16 +2618,32 @@ async fn join_channel( .await?; let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { - let token = live_kit - .room_token( - &joined_room.room.live_kit_room, - 
&session.user_id.to_string(), + let (can_publish, token) = if role == ChannelRole::Guest { + ( + false, + live_kit + .guest_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err()?, ) - .trace_err()?; + } else { + ( + true, + live_kit + .room_token( + &joined_room.room.live_kit_room, + &session.user_id.to_string(), + ) + .trace_err()?, + ) + }; Some(LiveKitConnectionInfo { server_url: live_kit.url().into(), token, + can_publish, }) }); @@ -2648,9 +2653,19 @@ async fn join_channel( live_kit_connection_info, })?; + let connection_pool = session.connection_pool().await; + if let Some(membership_updated) = membership_updated { + notify_membership_updated( + &connection_pool, + membership_updated, + session.user_id, + &session.peer, + ); + } + room_updated(&joined_room.room, &session.peer); - joined_room.into_inner() + joined_room }; channel_updated( @@ -2662,7 +2677,6 @@ async fn join_channel( ); update_user_contacts(session.user_id, &session).await?; - Ok(()) } @@ -2815,6 +2829,29 @@ fn channel_buffer_updated( }); } +fn send_notifications( + connection_pool: &ConnectionPool, + peer: &Peer, + notifications: db::NotificationBatch, +) { + for (user_id, notification) in notifications { + for connection_id in connection_pool.user_connection_ids(user_id) { + if let Err(error) = peer.send( + connection_id, + proto::AddNotification { + notification: Some(notification.clone()), + }, + ) { + tracing::error!( + "failed to send notification to {:?} {}", + connection_id, + error + ); + } + } + } +} + async fn send_channel_message( request: proto::SendChannelMessage, response: Response, @@ -2829,19 +2866,27 @@ async fn send_channel_message( return Err(anyhow!("message can't be blank"))?; } + // TODO: adjust mentions if body is trimmed + let timestamp = OffsetDateTime::now_utc(); let nonce = request .nonce .ok_or_else(|| anyhow!("nonce can't be blank"))?; let channel_id = ChannelId::from_proto(request.channel_id); - let (message_id, connection_ids, non_participants) = session + let CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + } = session .db() .await .create_channel_message( channel_id, session.user_id, &body, + &request.mentions, timestamp, nonce.clone().into(), ) @@ -2850,18 +2895,23 @@ async fn send_channel_message( sender_id: session.user_id.to_proto(), id: message_id.to_proto(), body, + mentions: request.mentions, timestamp: timestamp.unix_timestamp() as u64, nonce: Some(nonce), }; - broadcast(Some(session.connection_id), connection_ids, |connection| { - session.peer.send( - connection, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); + broadcast( + Some(session.connection_id), + participant_connection_ids, + |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }, + ); response.send(proto::SendChannelMessageResponse { message: Some(message), })?; @@ -2869,7 +2919,7 @@ async fn send_channel_message( let pool = &*session.connection_pool().await; broadcast( None, - non_participants + channel_members .iter() .flat_map(|user_id| pool.user_connection_ids(*user_id)), |peer_id| { @@ -2885,6 +2935,7 @@ async fn send_channel_message( ) }, ); + send_notifications(pool, &session.peer, notifications); Ok(()) } @@ -2914,11 +2965,16 @@ async fn acknowledge_channel_message( ) -> Result<()> { let channel_id = 
ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); - session + let notifications = session .db() .await .observe_channel_message(channel_id, session.user_id, message_id) .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); Ok(()) } @@ -2993,6 +3049,72 @@ async fn get_channel_messages( Ok(()) } +async fn get_channel_messages_by_id( + request: proto::GetChannelMessagesById, + response: Response, + session: Session, +) -> Result<()> { + let message_ids = request + .message_ids + .iter() + .map(|id| MessageId::from_proto(*id)) + .collect::>(); + let messages = session + .db() + .await + .get_channel_messages_by_id(session.user_id, &message_ids) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn get_notifications( + request: proto::GetNotifications, + response: Response, + session: Session, +) -> Result<()> { + let notifications = session + .db() + .await + .get_notifications( + session.user_id, + NOTIFICATION_COUNT_PER_PAGE, + request + .before_id + .map(|id| db::NotificationId::from_proto(id)), + ) + .await?; + response.send(proto::GetNotificationsResponse { + done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE, + notifications, + })?; + Ok(()) +} + +async fn mark_notification_as_read( + request: proto::MarkNotificationRead, + response: Response, + session: Session, +) -> Result<()> { + let database = &session.db().await; + let notifications = database + .mark_notification_as_read_by_id( + session.user_id, + NotificationId::from_proto(request.notification_id), + ) + .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); + response.send(proto::Ack {})?; + Ok(()) +} + async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = session @@ -3062,22 +3184,37 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } } -fn build_initial_channels_update( +fn notify_membership_updated( + connection_pool: &ConnectionPool, + result: MembershipUpdated, + user_id: UserId, + peer: &Peer, +) { + let mut update = build_channels_update(result.new_channels, vec![]); + update.delete_channels = result + .removed_channels + .into_iter() + .map(|id| id.to_proto()) + .collect(); + update.remove_channel_invitations = vec![result.channel_id.to_proto()]; + + for connection_id in connection_pool.user_connection_ids(user_id) { + peer.send(connection_id, update.clone()).trace_err(); + } +} + +fn build_channels_update( channels: ChannelsForUser, channel_invites: Vec, ) -> proto::UpdateChannels { let mut update = proto::UpdateChannels::default(); - for channel in channels.channels.channels { - update.channels.push(proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }); + for channel in channels.channels { + update.channels.push(channel.to_proto()); } update.unseen_channel_buffer_changes = channels.unseen_buffer_changes; update.unseen_channel_messages = channels.channel_messages; - update.insert_edge = channels.channels.edges; for (channel_id, participants) in channels.channel_participants { update @@ -3088,23 +3225,8 @@ fn build_initial_channels_update( }); } - update - .channel_permissions - .extend( - channels - .channels_with_admin_privileges - .into_iter() - .map(|id| proto::ChannelPermission { - 
channel_id: id.to_proto(), - is_admin: true, - }), - ); - for channel in channel_invites { - update.channel_invitations.push(proto::Channel { - id: channel.id.to_proto(), - name: channel.name, - }); + update.channel_invitations.push(channel.to_proto()); } update @@ -3118,42 +3240,28 @@ fn build_initial_contacts_update( for contact in contacts { match contact { - db::Contact::Accepted { - user_id, - should_notify, - busy, - } => { - update - .contacts - .push(contact_for_user(user_id, should_notify, busy, &pool)); + db::Contact::Accepted { user_id, busy } => { + update.contacts.push(contact_for_user(user_id, busy, &pool)); } db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), + db::Contact::Incoming { user_id } => { + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + }) + } } } update } -fn contact_for_user( - user_id: UserId, - should_notify: bool, - busy: bool, - pool: &ConnectionPool, -) -> proto::Contact { +fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact { proto::Contact { user_id: user_id.to_proto(), online: pool.is_user_online(user_id), busy, - should_notify, } } @@ -3214,7 +3322,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> let busy = db.is_user_busy(user_id).await?; let pool = session.connection_pool().await; - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index e78bbe3466..e8da66a75a 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -6,6 +6,7 @@ mod channel_message_tests; mod channel_tests; mod following_tests; mod integration_tests; +mod notification_tests; mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; @@ -39,3 +40,7 @@ fn room_participants(room: &ModelHandle, cx: &mut TestAppContext) -> RoomP RoomParticipants { remote, pending } }) } + +fn channel_id(room: &ModelHandle, cx: &mut TestAppContext) -> Option { + cx.read(|cx| room.read(cx).channel_id()) +} diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index a0b9b52484..5ca40a3c2d 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -3,7 +3,7 @@ use crate::{ tests::TestServer, }; use call::ActiveCall; -use channel::{Channel, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; +use channel::ACKNOWLEDGE_DEBOUNCE_INTERVAL; use client::ParticipantIndex; use client::{Collaborator, UserId}; use collab_ui::channel_view::ChannelView; @@ -407,11 +407,8 @@ async fn test_channel_buffer_disconnect( server.disconnect_client(client_a.peer_id().unwrap()); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - channel_buffer_a.update(cx_a, |buffer, _| { - assert_eq!( - buffer.channel().as_ref(), - &channel(channel_id, "the-channel") - ); + channel_buffer_a.update(cx_a, |buffer, cx| { + assert_eq!(buffer.channel(cx).unwrap().name, "the-channel"); assert!(!buffer.is_connected()); }); @@ -432,24 +429,12 @@ async fn test_channel_buffer_disconnect( 
deterministic.run_until_parked(); // Channel buffer observed the deletion - channel_buffer_b.update(cx_b, |buffer, _| { - assert_eq!( - buffer.channel().as_ref(), - &channel(channel_id, "the-channel") - ); + channel_buffer_b.update(cx_b, |buffer, cx| { + assert!(buffer.channel(cx).is_none()); assert!(!buffer.is_connected()); }); } -fn channel(id: u64, name: &'static str) -> Channel { - Channel { - id, - name: name.to_string(), - unseen_note_version: None, - unseen_message_id: None, - } -} - #[gpui::test] async fn test_rejoin_channel_buffer( deterministic: Arc, @@ -694,7 +679,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .await .unwrap(); channel_view_1_a.update(cx_a, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-1"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-1"); notes.editor.update(cx, |editor, cx| { editor.insert("Hello from A.", cx); editor.change_selections(None, cx, |selections| { @@ -726,7 +711,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .expect("active item is not a channel view") }); channel_view_1_b.read_with(cx_b, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-1"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-1"); let editor = notes.editor.read(cx); assert_eq!(editor.text(cx), "Hello from A."); assert_eq!(editor.selections.ranges::(cx), &[3..4]); @@ -738,7 +723,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .await .unwrap(); channel_view_2_a.read_with(cx_a, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-2"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-2"); }); // Client B is taken to the notes for channel 2. @@ -755,7 +740,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( .expect("active item is not a channel view") }); channel_view_2_b.read_with(cx_b, |notes, cx| { - assert_eq!(notes.channel(cx).name, "channel-2"); + assert_eq!(notes.channel(cx).unwrap().name, "channel-2"); }); } diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs index 0fc3b085ed..918eb053d3 100644 --- a/crates/collab/src/tests/channel_message_tests.rs +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -1,27 +1,30 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; -use channel::{ChannelChat, ChannelMessageId}; +use channel::{ChannelChat, ChannelMessageId, MessageParams}; use collab_ui::chat_panel::ChatPanel; use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext}; +use rpc::Notification; use std::sync::Arc; use workspace::dock::Panel; #[gpui::test] async fn test_basic_channel_messages( deterministic: Arc, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, + mut cx_a: &mut TestAppContext, + mut cx_b: &mut TestAppContext, + mut cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; let channel_id = server .make_channel( "the-channel", None, (&client_a, cx_a), - &mut [(&client_b, cx_b)], + &mut [(&client_b, cx_b), (&client_c, cx_c)], ) .await; @@ -36,8 +39,17 @@ async fn test_basic_channel_messages( .await .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + let message_id = channel_chat_a + .update(cx_a, |c, cx| { + c.send_message( + MessageParams { 
+ text: "hi @user_c!".into(), + mentions: vec![(3..10, client_c.id())], + }, + cx, + ) + .unwrap() + }) .await .unwrap(); channel_chat_a @@ -52,15 +64,55 @@ async fn test_basic_channel_messages( .unwrap(); deterministic.run_until_parked(); - channel_chat_a.update(cx_a, |c, _| { + + let channel_chat_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + for (chat, cx) in [ + (&channel_chat_a, &mut cx_a), + (&channel_chat_b, &mut cx_b), + (&channel_chat_c, &mut cx_c), + ] { + chat.update(*cx, |c, _| { + assert_eq!( + c.messages() + .iter() + .map(|m| (m.body.as_str(), m.mentions.as_slice())) + .collect::>(), + vec![ + ("hi @user_c!", [(3..10, client_c.id())].as_slice()), + ("two", &[]), + ("three", &[]) + ], + "results for user {}", + c.client().id(), + ); + }); + } + + client_c.notification_store().update(cx_c, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); assert_eq!( - c.messages() - .iter() - .map(|m| m.body.as_str()) - .collect::>(), - vec!["one", "two", "three"] + store.notification_at(0).unwrap().notification, + Notification::ChannelMessageMention { + message_id, + sender_id: client_a.id(), + channel_id, + } ); - }) + assert_eq!( + store.notification_at(1).unwrap().notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + }); } #[gpui::test] @@ -280,7 +332,7 @@ async fn test_channel_message_changes( chat_panel_b .update(cx_b, |chat_panel, cx| { chat_panel.set_active(true, cx); - chat_panel.select_channel(channel_id, cx) + chat_panel.select_channel(channel_id, None, cx) }) .await .unwrap(); diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 7cfcce832b..a33ded6492 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -1,12 +1,17 @@ use crate::{ + db::{self, UserId}, rpc::RECONNECT_TIMEOUT, tests::{room_participants, RoomParticipants, TestServer}, }; use call::ActiveCall; use channel::{ChannelId, ChannelMembership, ChannelStore}; use client::User; +use futures::future::try_join_all; use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; -use rpc::{proto, RECEIVE_TIMEOUT}; +use rpc::{ + proto::{self, ChannelRole}, + RECEIVE_TIMEOUT, +}; use std::sync::Arc; #[gpui::test] @@ -44,22 +49,19 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), depth: 1, - user_is_admin: true, + role: ChannelRole::Admin, }, ], ); client_b.channel_store().read_with(cx_b, |channels, _| { - assert!(channels - .channel_dag_entries() - .collect::>() - .is_empty()) + assert!(channels.ordered_channels().collect::>().is_empty()) }); // Invite client B to channel A as client A. 
@@ -68,7 +70,12 @@ async fn test_core_channels( .update(cx_a, |store, cx| { assert!(!store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); - let invite = store.invite_member(channel_a_id, client_b.user_id().unwrap(), false, cx); + let invite = store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ); // Make sure we're synchronously storing the pending invite assert!(store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap())); @@ -86,7 +93,7 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: false, + role: ChannelRole::Member, }], ); @@ -103,12 +110,12 @@ async fn test_core_channels( &[ ( client_a.user_id().unwrap(), - true, + proto::ChannelRole::Admin, proto::channel_member::Kind::Member, ), ( client_b.user_id().unwrap(), - false, + proto::ChannelRole::Member, proto::channel_member::Kind::Invitee, ), ], @@ -117,8 +124,8 @@ async fn test_core_channels( // Client B accepts the invitation. client_b .channel_store() - .update(cx_b, |channels, _| { - channels.respond_to_channel_invite(channel_a_id, true) + .update(cx_b, |channels, cx| { + channels.respond_to_channel_invite(channel_a_id, true, cx) }) .await .unwrap(); @@ -133,13 +140,13 @@ async fn test_core_channels( ExpectedChannel { id: channel_a_id, name: "channel-a".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 0, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 1, }, ], @@ -161,19 +168,19 @@ async fn test_core_channels( ExpectedChannel { id: channel_a_id, name: "channel-a".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 0, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 1, }, ExpectedChannel { id: channel_c_id, name: "channel-c".to_string(), - user_is_admin: false, + role: ChannelRole::Member, depth: 2, }, ], @@ -183,7 +190,12 @@ async fn test_core_channels( client_a .channel_store() .update(cx_a, |store, cx| { - store.set_member_admin(channel_a_id, client_b.user_id().unwrap(), true, cx) + store.set_member_role( + channel_a_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Admin, + cx, + ) }) .await .unwrap(); @@ -200,19 +212,19 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { id: channel_b_id, name: "channel-b".to_string(), depth: 1, - user_is_admin: true, + role: ChannelRole::Admin, }, ExpectedChannel { id: channel_c_id, name: "channel-c".to_string(), depth: 2, - user_is_admin: true, + role: ChannelRole::Admin, }, ], ); @@ -234,7 +246,7 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); assert_channels( @@ -244,7 +256,7 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); @@ -267,18 +279,27 @@ async fn test_core_channels( id: channel_a_id, name: "channel-a".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); // Client B no longer has access to the channel assert_channels(client_b.channel_store(), cx_b, &[]); - // When disconnected, client A sees no channels. 
server.forbid_connections(); server.disconnect_client(client_a.peer_id().unwrap()); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - assert_channels(client_a.channel_store(), cx_a, &[]); + + server + .app_state + .db + .rename_channel( + db::ChannelId::from_proto(channel_a_id), + UserId::from_proto(client_a.id()), + "channel-a-renamed", + ) + .await + .unwrap(); server.allow_connections(); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); @@ -287,9 +308,9 @@ async fn test_core_channels( cx_a, &[ExpectedChannel { id: channel_a_id, - name: "channel-a".to_string(), + name: "channel-a-renamed".to_string(), depth: 0, - user_is_admin: true, + role: ChannelRole::Admin, }], ); } @@ -305,12 +326,12 @@ fn assert_participants_eq(participants: &[Arc], expected_partitipants: &[u #[track_caller] fn assert_members_eq( members: &[ChannelMembership], - expected_members: &[(u64, bool, proto::channel_member::Kind)], + expected_members: &[(u64, proto::ChannelRole, proto::channel_member::Kind)], ) { assert_eq!( members .iter() - .map(|member| (member.user.id, member.admin, member.kind)) + .map(|member| (member.user.id, member.role, member.kind)) .collect::>(), expected_members ); @@ -397,7 +418,7 @@ async fn test_channel_room( id: zed_id, name: "zed".to_string(), depth: 0, - user_is_admin: false, + role: ChannelRole::Member, }], ); client_b.channel_store().read_with(cx_b, |channels, _| { @@ -611,7 +632,12 @@ async fn test_permissions_update_while_invited( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.invite_member(rust_id, client_b.user_id().unwrap(), false, cx) + channel_store.invite_member( + rust_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ) }) .await .unwrap(); @@ -625,7 +651,7 @@ async fn test_permissions_update_while_invited( depth: 0, id: rust_id, name: "rust".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); assert_channels(client_b.channel_store(), cx_b, &[]); @@ -634,7 +660,12 @@ async fn test_permissions_update_while_invited( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.set_member_admin(rust_id, client_b.user_id().unwrap(), true, cx) + channel_store.set_member_role( + rust_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Admin, + cx, + ) }) .await .unwrap(); @@ -648,7 +679,7 @@ async fn test_permissions_update_while_invited( depth: 0, id: rust_id, name: "rust".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); assert_channels(client_b.channel_store(), cx_b, &[]); @@ -688,7 +719,7 @@ async fn test_channel_rename( depth: 0, id: rust_id, name: "rust-archive".to_string(), - user_is_admin: true, + role: ChannelRole::Admin, }], ); @@ -700,7 +731,7 @@ async fn test_channel_rename( depth: 0, id: rust_id, name: "rust-archive".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); } @@ -803,7 +834,12 @@ async fn test_lost_channel_creation( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.invite_member(channel_id, client_b.user_id().unwrap(), false, cx) + channel_store.invite_member( + channel_id, + client_b.user_id().unwrap(), + proto::ChannelRole::Member, + cx, + ) }) .await .unwrap(); @@ -818,7 +854,7 @@ async fn test_lost_channel_creation( depth: 0, id: channel_id, name: "x".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }], ); @@ -842,13 +878,13 @@ async fn test_lost_channel_creation( depth: 0, id: channel_id, name: "x".to_string(), - user_is_admin: true, + 
role: ChannelRole::Admin, }, ExpectedChannel { depth: 1, id: subchannel_id, name: "subchannel".to_string(), - user_is_admin: true, + role: ChannelRole::Admin, }, ], ); @@ -856,8 +892,8 @@ async fn test_lost_channel_creation( // Client B accepts the invite client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(channel_id, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); @@ -873,31 +909,489 @@ async fn test_lost_channel_creation( depth: 0, id: channel_id, name: "x".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }, ExpectedChannel { depth: 1, id: subchannel_id, name: "subchannel".to_string(), - user_is_admin: false, + role: ChannelRole::Member, }, ], ); } #[gpui::test] -async fn test_channel_moving( +async fn test_channel_link_notifications( deterministic: Arc, cx_a: &mut TestAppContext, cx_b: &mut TestAppContext, cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; + let user_b = client_b.user_id().unwrap(); + let user_c = client_c.user_id().unwrap(); + + let channels = server + .make_channel_tree(&[("zed", None)], (&client_a, cx_a)) + .await; + let zed_channel = channels[0]; + + try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| { + [ + channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx), + channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Member, cx), + channel_store.invite_member(zed_channel, user_c, proto::ChannelRole::Guest, cx), + ] + })) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + client_c + .channel_store() + .update(cx_c, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + // we have an admin (a), member (b) and guest (c) all part of the zed channel. 
+ + // create a new private channel, make it public, and move it under the previous one, and verify it shows for b and not c + let active_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("active", Some(zed_channel), cx) + }) + .await + .unwrap(); + + // the new channel shows for b and not c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(zed_channel, 0), (active_channel, 1)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(zed_channel, 0), (active_channel, 1)], + ); + assert_channels_list_shape(client_c.channel_store(), cx_c, &[(zed_channel, 0)]); + + let vim_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("vim", None, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(vim_channel, Some(active_channel), cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + // the new channel shows for b and c + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + &[(zed_channel, 0), (active_channel, 1), (vim_channel, 2)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(zed_channel, 0), (active_channel, 1), (vim_channel, 2)], + ); + assert_channels_list_shape( + client_c.channel_store(), + cx_c, + &[(zed_channel, 0), (vim_channel, 1)], + ); + + let helix_channel = client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.create_channel("helix", None, cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.move_channel(helix_channel, Some(vim_channel), cx) + }) + .await + .unwrap(); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility( + helix_channel, + proto::ChannelVisibility::Public, + cx, + ) + }) + .await + .unwrap(); + + // the new channel shows for b and c + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[ + (zed_channel, 0), + (active_channel, 1), + (vim_channel, 2), + (helix_channel, 3), + ], + ); + assert_channels_list_shape( + client_c.channel_store(), + cx_c, + &[(zed_channel, 0), (vim_channel, 1), (helix_channel, 2)], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Members, cx) + }) + .await + .unwrap(); + + // the members-only channel is still shown for c, but hidden for b + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[ + (zed_channel, 0), + (active_channel, 1), + (vim_channel, 2), + (helix_channel, 3), + ], + ); + client_b + .channel_store() + .read_with(cx_b, |channel_store, _| { + assert_eq!( + channel_store + .channel_for_id(vim_channel) + .unwrap() + .visibility, + proto::ChannelVisibility::Members + ) + }); + + assert_channels_list_shape(client_c.channel_store(), cx_c, &[(zed_channel, 0)]); +} + +#[gpui::test] +async fn test_channel_membership_notifications( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, 
"user_a").await; + let client_b = server.create_client(cx_b, "user_c").await; + + let user_b = client_b.user_id().unwrap(); + + let channels = server + .make_channel_tree( + &[ + ("zed", None), + ("active", Some("zed")), + ("vim", Some("active")), + ], + (&client_a, cx_a), + ) + .await; + let zed_channel = channels[0]; + let _active_channel = channels[1]; + let vim_channel = channels[2]; + + try_join_all(client_a.channel_store().update(cx_a, |channel_store, cx| { + [ + channel_store.set_channel_visibility(zed_channel, proto::ChannelVisibility::Public, cx), + channel_store.set_channel_visibility(vim_channel, proto::ChannelVisibility::Public, cx), + channel_store.invite_member(vim_channel, user_b, proto::ChannelRole::Member, cx), + channel_store.invite_member(zed_channel, user_b, proto::ChannelRole::Guest, cx), + ] + })) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) + }) + .await + .unwrap(); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(vim_channel, true, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + // we have an admin (a), and a guest (b) with access to all of zed, and membership in vim. + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: zed_channel, + name: "zed".to_string(), + role: ChannelRole::Guest, + }, + ExpectedChannel { + depth: 1, + id: vim_channel, + name: "vim".to_string(), + role: ChannelRole::Member, + }, + ], + ); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.remove_member(vim_channel, user_b, cx) + }) + .await + .unwrap(); + + deterministic.run_until_parked(); + + assert_channels( + client_b.channel_store(), + cx_b, + &[ + ExpectedChannel { + depth: 0, + id: zed_channel, + name: "zed".to_string(), + role: ChannelRole::Guest, + }, + ExpectedChannel { + depth: 1, + id: vim_channel, + name: "vim".to_string(), + role: ChannelRole::Guest, + }, + ], + ) +} + +#[gpui::test] +async fn test_guest_access( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a = channels[0]; + let channel_b = channels[1]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // Non-members should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a, cx)) + .await + .is_err()); + + // Make channels A and B public + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_b, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + // Client B joins channel A as a guest + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + assert_channels_list_shape( + client_a.channel_store(), + cx_a, + 
&[(channel_a, 0), (channel_b, 1)], + ); + assert_channels_list_shape( + client_b.channel_store(), + cx_b, + &[(channel_a, 0), (channel_b, 1)], + ); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_a); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a, proto::ChannelVisibility::Members, cx) + }) + .await + .unwrap(); + + assert_channels_list_shape(client_b.channel_store(), cx_b, &[]); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + assert_channels_list_shape(client_b.channel_store(), cx_b, &[(channel_b, 0)]); +} + +#[gpui::test] +async fn test_invite_access( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree( + &[("channel-a", None), ("channel-b", Some("channel-a"))], + (&client_a, cx_a), + ) + .await; + let channel_a_id = channels[0]; + let channel_b_id = channels[0]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .is_err()); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.invite_member( + channel_a_id, + client_b.user_id().unwrap(), + ChannelRole::Member, + cx, + ) + }) + .await + .unwrap(); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + + client_b.channel_store().update(cx_b, |channel_store, _| { + assert!(channel_store.channel_for_id(channel_b_id).is_some()); + assert!(channel_store.channel_for_id(channel_a_id).is_some()); + }); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_b_id); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }) +} + +#[gpui::test] +async fn test_channel_moving( + deterministic: Arc, + cx_a: &mut TestAppContext, + _cx_b: &mut TestAppContext, + _cx_c: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + // let client_b = server.create_client(cx_b, "user_b").await; + // let client_c = server.create_client(cx_c, "user_c").await; + let channels = server .make_channel_tree( &[ @@ -930,7 +1424,7 @@ async fn test_channel_moving( client_a .channel_store() .update(cx_a, |channel_store, cx| { - channel_store.move_channel(channel_d_id, channel_c_id, channel_b_id, cx) + channel_store.move_channel(channel_d_id, Some(channel_b_id), cx) }) .await .unwrap(); @@ -948,188 +1442,6 @@ async fn test_channel_moving( (channel_d_id, 2), ], ); - - client_a - .channel_store() - .update(cx_a, |channel_store, cx| { - channel_store.link_channel(channel_d_id, channel_c_id, cx) - }) - .await - .unwrap(); - - // Current shape for A: - // /------\ - // a - b -- c -- d - assert_channels_list_shape( - client_a.channel_store(), - cx_a, - 
&[ - (channel_a_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - ], - ); - - let b_channels = server - .make_channel_tree( - &[ - ("channel-mu", None), - ("channel-gamma", Some("channel-mu")), - ("channel-epsilon", Some("channel-mu")), - ], - (&client_b, cx_b), - ) - .await; - let channel_mu_id = b_channels[0]; - let channel_ga_id = b_channels[1]; - let channel_ep_id = b_channels[2]; - - // Current shape for B: - // /- ep - // mu -- ga - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[(channel_mu_id, 0), (channel_ep_id, 1), (channel_ga_id, 1)], - ); - - client_a - .add_admin_to_channel((&client_b, cx_b), channel_b_id, cx_a) - .await; - - // Current shape for B: - // /- ep - // mu -- ga - // /---------\ - // b -- c -- d - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[ - // New channels from a - (channel_b_id, 0), - (channel_c_id, 1), - (channel_d_id, 2), - (channel_d_id, 1), - // B's old channels - (channel_mu_id, 0), - (channel_ep_id, 1), - (channel_ga_id, 1), - ], - ); - - client_b - .add_admin_to_channel((&client_c, cx_c), channel_ep_id, cx_b) - .await; - - // Current shape for C: - // - ep - assert_channels_list_shape(client_c.channel_store(), cx_c, &[(channel_ep_id, 0)]); - - client_b - .channel_store() - .update(cx_b, |channel_store, cx| { - channel_store.link_channel(channel_b_id, channel_ep_id, cx) - }) - .await - .unwrap(); - - // Current shape for B: - // /---------\ - // /- ep -- b -- c -- d - // mu -- ga - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[ - (channel_mu_id, 0), - (channel_ep_id, 1), - (channel_b_id, 2), - (channel_c_id, 3), - (channel_d_id, 4), - (channel_d_id, 3), - (channel_ga_id, 1), - ], - ); - - // Current shape for C: - // /---------\ - // ep -- b -- c -- d - assert_channels_list_shape( - client_c.channel_store(), - cx_c, - &[ - (channel_ep_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - ], - ); - - client_b - .channel_store() - .update(cx_b, |channel_store, cx| { - channel_store.link_channel(channel_ga_id, channel_b_id, cx) - }) - .await - .unwrap(); - - // Current shape for B: - // /---------\ - // /- ep -- b -- c -- d - // / \ - // mu ---------- ga - assert_channels_list_shape( - client_b.channel_store(), - cx_b, - &[ - (channel_mu_id, 0), - (channel_ep_id, 1), - (channel_b_id, 2), - (channel_c_id, 3), - (channel_d_id, 4), - (channel_d_id, 3), - (channel_ga_id, 3), - (channel_ga_id, 1), - ], - ); - - // Current shape for A: - // /------\ - // a - b -- c -- d - // \-- ga - assert_channels_list_shape( - client_a.channel_store(), - cx_a, - &[ - (channel_a_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - (channel_ga_id, 2), - ], - ); - - // Current shape for C: - // /-------\ - // ep -- b -- c -- d - // \-- ga - assert_channels_list_shape( - client_c.channel_store(), - cx_c, - &[ - (channel_ep_id, 0), - (channel_b_id, 1), - (channel_c_id, 2), - (channel_d_id, 3), - (channel_d_id, 2), - (channel_ga_id, 2), - ], - ); } #[derive(Debug, PartialEq)] @@ -1137,7 +1449,7 @@ struct ExpectedChannel { depth: usize, id: ChannelId, name: String, - user_is_admin: bool, + role: ChannelRole, } #[track_caller] @@ -1154,7 +1466,7 @@ fn assert_channel_invitations( depth: 0, name: channel.name.clone(), id: channel.id, - user_is_admin: store.is_user_admin(channel.id), + role: channel.role, }) .collect::>() }); @@ -1169,12 +1481,12 @@ fn assert_channels( ) { let actual = 
channel_store.read_with(cx, |store, _| { store - .channel_dag_entries() + .ordered_channels() .map(|(depth, channel)| ExpectedChannel { depth, name: channel.name.clone(), id: channel.id, - user_is_admin: store.is_user_admin(channel.id), + role: channel.role, }) .collect::>() }); @@ -1191,7 +1503,7 @@ fn assert_channels_list_shape( let actual = channel_store.read_with(cx, |store, _| { store - .channel_dag_entries() + .ordered_channels() .map(|(depth, channel)| (channel.id, depth)) .collect::>() }); diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index f3857e3db3..a28f2ae87f 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -1,6 +1,6 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; -use collab_ui::project_shared_notification::ProjectSharedNotification; +use collab_ui::notifications::project_shared_notification::ProjectSharedNotification; use editor::{Editor, ExcerptRange, MultiBuffer}; use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle}; use live_kit_client::MacOSDisplay; diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index d6d449fd47..550c3a2bd8 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -1,6 +1,6 @@ use crate::{ rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, - tests::{room_participants, RoomParticipants, TestClient, TestServer}, + tests::{channel_id, room_participants, RoomParticipants, TestClient, TestServer}, }; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; @@ -15,8 +15,8 @@ use gpui::{executor::Deterministic, test::EmptyView, AppContext, ModelHandle, Te use indoc::indoc; use language::{ language_settings::{AllLanguageSettings, Formatter, InlayHintSettings}, - tree_sitter_rust, Anchor, BundledFormatter, Diagnostic, DiagnosticEntry, FakeLspAdapter, - Language, LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, + tree_sitter_rust, Anchor, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, + LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, }; use live_kit_client::MacOSDisplay; use lsp::LanguageServerId; @@ -469,6 +469,119 @@ async fn test_calling_multiple_users_simultaneously( ); } +#[gpui::test(iterations = 10)] +async fn test_joining_channels_and_calling_multiple_users_simultaneously( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + server + .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)]) + .await; + + let channel_1 = server + .make_channel( + "channel1", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_2 = server + .make_channel( + "channel2", + None, + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let active_call_a = cx_a.read(ActiveCall::global); + + // Simultaneously join channel 1 and then channel 2 + active_call_a + .update(cx_a, |call, cx| call.join_channel(channel_1, cx)) + .detach(); + let join_channel_2 = active_call_a.update(cx_a, |call, cx| 
call.join_channel(channel_2, cx)); + + join_channel_2.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); + + assert_eq!(channel_id(&room_a, cx_a), Some(channel_2)); + + // Leave the room + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + hang_up + }) + .await + .unwrap(); + + // Initiating invites and then joining a channel should fail gracefully + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + + let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx)); + + b_invite.await.unwrap(); + c_invite.await.unwrap(); + join_channel.await.unwrap(); + + let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); + + assert_eq!( + room_participants(&room_a, cx_a), + RoomParticipants { + remote: Default::default(), + pending: vec!["user_b".to_string(), "user_c".to_string()] + } + ); + + assert_eq!(channel_id(&room_a, cx_a), None); + + // Leave the room + active_call_a + .update(cx_a, |call, cx| { + let hang_up = call.hang_up(cx); + hang_up + }) + .await + .unwrap(); + + // Simultaneously join channel 1 and call user B and user C from client A. + let join_channel = active_call_a.update(cx_a, |call, cx| call.join_channel(channel_1, cx)); + + let b_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_b.user_id().unwrap(), None, cx) + }); + let c_invite = active_call_a.update(cx_a, |call, cx| { + call.invite(client_c.user_id().unwrap(), None, cx) + }); + + join_channel.await.unwrap(); + b_invite.await.unwrap(); + c_invite.await.unwrap(); + + active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone()); + deterministic.run_until_parked(); +} + #[gpui::test(iterations = 10)] async fn test_room_uniqueness( deterministic: Arc, @@ -4530,6 +4643,7 @@ async fn test_prettier_formatting_buffer( LanguageConfig { name: "Rust".into(), path_suffixes: vec!["rs".to_string()], + prettier_parser_name: Some("test_parser".to_string()), ..Default::default() }, Some(tree_sitter_rust::language()), @@ -4537,10 +4651,7 @@ async fn test_prettier_formatting_buffer( let test_plugin = "test_plugin"; let mut fake_language_servers = language .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { - enabled_formatters: vec![BundledFormatter::Prettier { - parser_name: Some("test_parser"), - plugin_names: vec![test_plugin], - }], + prettier_plugins: vec![test_plugin], ..Default::default() })) .await; @@ -4557,11 +4668,7 @@ async fn test_prettier_formatting_buffer( .insert_tree(&directory, json!({ "a.rs": buffer_text })) .await; let (project_a, worktree_id) = client_a.build_local_project(&directory, cx_a).await; - let prettier_format_suffix = project_a.update(cx_a, |project, _| { - let suffix = project.enable_test_prettier(&[test_plugin]); - project.languages().add(language); - suffix - }); + let prettier_format_suffix = project::TEST_PRETTIER_FORMAT_SUFFIX; let buffer_a = cx_a .background() .spawn(project_a.update(cx_a, |p, cx| p.open_buffer((worktree_id, "a.rs"), cx))) diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs new file mode 100644 index 0000000000..1114470449 --- /dev/null +++ b/crates/collab/src/tests/notification_tests.rs @@ -0,0 +1,159 @@ +use 
crate::tests::TestServer;
+use gpui::{executor::Deterministic, TestAppContext};
+use notifications::NotificationEvent;
+use parking_lot::Mutex;
+use rpc::{proto, Notification};
+use std::sync::Arc;
+
+#[gpui::test]
+async fn test_notifications(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let notification_events_a = Arc::new(Mutex::new(Vec::new()));
+    let notification_events_b = Arc::new(Mutex::new(Vec::new()));
+    client_a.notification_store().update(cx_a, |_, cx| {
+        let events = notification_events_a.clone();
+        cx.subscribe(&cx.handle(), move |_, _, event, _| {
+            events.lock().push(event.clone());
+        })
+        .detach()
+    });
+    client_b.notification_store().update(cx_b, |_, cx| {
+        let events = notification_events_b.clone();
+        cx.subscribe(&cx.handle(), move |_, _, event, _| {
+            events.lock().push(event.clone());
+        })
+        .detach()
+    });
+
+    // Client A sends a contact request to client B.
+    client_a
+        .user_store()
+        .update(cx_a, |store, cx| store.request_contact(client_b.id(), cx))
+        .await
+        .unwrap();
+
+    // Client B receives a contact request notification and responds to the
+    // request, accepting it.
+    deterministic.run_until_parked();
+    client_b.notification_store().update(cx_b, |store, cx| {
+        assert_eq!(store.notification_count(), 1);
+        assert_eq!(store.unread_notification_count(), 1);
+
+        let entry = store.notification_at(0).unwrap();
+        assert_eq!(
+            entry.notification,
+            Notification::ContactRequest {
+                sender_id: client_a.id()
+            }
+        );
+        assert!(!entry.is_read);
+        assert_eq!(
+            &notification_events_b.lock()[0..],
+            &[
+                NotificationEvent::NewNotification {
+                    entry: entry.clone(),
+                },
+                NotificationEvent::NotificationsUpdated {
+                    old_range: 0..0,
+                    new_count: 1
+                }
+            ]
+        );
+
+        store.respond_to_notification(entry.notification.clone(), true, cx);
+    });
+
+    // Client B sees the notification is now read, and that they responded.
+    deterministic.run_until_parked();
+    client_b.notification_store().read_with(cx_b, |store, _| {
+        assert_eq!(store.notification_count(), 1);
+        assert_eq!(store.unread_notification_count(), 0);
+
+        let entry = store.notification_at(0).unwrap();
+        assert!(entry.is_read);
+        assert_eq!(entry.response, Some(true));
+        assert_eq!(
+            &notification_events_b.lock()[2..],
+            &[
+                NotificationEvent::NotificationRead {
+                    entry: entry.clone(),
+                },
+                NotificationEvent::NotificationsUpdated {
+                    old_range: 0..1,
+                    new_count: 1
+                }
+            ]
+        );
+    });
+
+    // Client A receives a notification that client B accepted their request.
+    client_a.notification_store().read_with(cx_a, |store, _| {
+        assert_eq!(store.notification_count(), 1);
+        assert_eq!(store.unread_notification_count(), 1);
+
+        let entry = store.notification_at(0).unwrap();
+        assert_eq!(
+            entry.notification,
+            Notification::ContactRequestAccepted {
+                responder_id: client_b.id()
+            }
+        );
+        assert!(!entry.is_read);
+    });
+
+    // Client A creates a channel and invites client B to be a member.
+    let channel_id = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| {
+            store.create_channel("the-channel", None, cx)
+        })
+        .await
+        .unwrap();
+    client_a
+        .channel_store()
+        .update(cx_a, |store, cx| {
+            store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx)
+        })
+        .await
+        .unwrap();
+
+    // Client B receives a channel invitation notification and responds to the
+    // invitation, accepting it.
+    deterministic.run_until_parked();
+    client_b.notification_store().update(cx_b, |store, cx| {
+        assert_eq!(store.notification_count(), 2);
+        assert_eq!(store.unread_notification_count(), 1);
+
+        let entry = store.notification_at(0).unwrap();
+        assert_eq!(
+            entry.notification,
+            Notification::ChannelInvitation {
+                channel_id,
+                channel_name: "the-channel".to_string(),
+                inviter_id: client_a.id()
+            }
+        );
+        assert!(!entry.is_read);
+
+        store.respond_to_notification(entry.notification.clone(), true, cx);
+    });
+
+    // Client B sees the notification is now read, and that they responded.
+    deterministic.run_until_parked();
+    client_b.notification_store().read_with(cx_b, |store, _| {
+        assert_eq!(store.notification_count(), 2);
+        assert_eq!(store.unread_notification_count(), 0);
+
+        let entry = store.notification_at(0).unwrap();
+        assert!(entry.is_read);
+        assert_eq!(entry.response, Some(true));
+    });
+}
diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs b/crates/collab/src/tests/random_channel_buffer_tests.rs
index 6e0bef225c..38bc3f7c12 100644
--- a/crates/collab/src/tests/random_channel_buffer_tests.rs
+++ b/crates/collab/src/tests/random_channel_buffer_tests.rs
@@ -1,3 +1,5 @@
+use crate::db::ChannelRole;
+
 use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan};
 use anyhow::Result;
 use async_trait::async_trait;
@@ -46,11 +48,11 @@ impl RandomizedTest for RandomChannelBufferTest {
         let db = &server.app_state.db;
         for ix in 0..CHANNEL_COUNT {
             let id = db
-                .create_channel(&format!("channel-{ix}"), None, users[0].user_id)
+                .create_root_channel(&format!("channel-{ix}"), users[0].user_id)
                 .await
                 .unwrap();
             for user in &users[1..]
{ - db.invite_channel_member(id, user.user_id, users[0].user_id, false) + db.invite_channel_member(id, user.user_id, users[0].user_id, ChannelRole::Member) .await .unwrap(); db.respond_to_channel_invite(id, user.user_id, true) @@ -81,7 +83,7 @@ impl RandomizedTest for RandomChannelBufferTest { match rng.gen_range(0..100_u32) { 0..=29 => { let channel_name = client.channel_store().read_with(cx, |store, cx| { - store.channel_dag_entries().find_map(|(_, channel)| { + store.ordered_channels().find_map(|(_, channel)| { if store.has_open_channel_buffer(channel.id, cx) { None } else { @@ -96,15 +98,16 @@ impl RandomizedTest for RandomChannelBufferTest { 30..=40 => { if let Some(buffer) = channel_buffers.iter().choose(rng) { - let channel_name = buffer.read_with(cx, |b, _| b.channel().name.clone()); + let channel_name = + buffer.read_with(cx, |b, cx| b.channel(cx).unwrap().name.clone()); break ChannelBufferOperation::LeaveChannelNotes { channel_name }; } } _ => { if let Some(buffer) = channel_buffers.iter().choose(rng) { - break buffer.read_with(cx, |b, _| { - let channel_name = b.channel().name.clone(); + break buffer.read_with(cx, |b, cx| { + let channel_name = b.channel(cx).unwrap().name.clone(); let edits = b .buffer() .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); @@ -128,7 +131,7 @@ impl RandomizedTest for RandomChannelBufferTest { ChannelBufferOperation::JoinChannelNotes { channel_name } => { let buffer = client.channel_store().update(cx, |store, cx| { let channel_id = store - .channel_dag_entries() + .ordered_channels() .find(|(_, c)| c.name == channel_name) .unwrap() .1 @@ -151,7 +154,7 @@ impl RandomizedTest for RandomChannelBufferTest { let buffer = cx.update(|cx| { let mut left_buffer = Err(TestError::Inapplicable); client.channel_buffers().retain(|buffer| { - if buffer.read(cx).channel().name == channel_name { + if buffer.read(cx).channel(cx).unwrap().name == channel_name { left_buffer = Ok(buffer.clone()); false } else { @@ -177,7 +180,9 @@ impl RandomizedTest for RandomChannelBufferTest { client .channel_buffers() .iter() - .find(|buffer| buffer.read(cx).channel().name == channel_name) + .find(|buffer| { + buffer.read(cx).channel(cx).unwrap().name == channel_name + }) .cloned() }) .ok_or_else(|| TestError::Inapplicable)?; @@ -248,7 +253,7 @@ impl RandomizedTest for RandomChannelBufferTest { if let Some(channel_buffer) = client .channel_buffers() .iter() - .find(|b| b.read(cx).channel().id == channel_id.to_proto()) + .find(|b| b.read(cx).channel_id == channel_id.to_proto()) { let channel_buffer = channel_buffer.read(cx); diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs index 39598bdaf9..1cec945282 100644 --- a/crates/collab/src/tests/randomized_test_helpers.rs +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -208,8 +208,7 @@ impl TestPlan { false, NewUserParams { github_login: username.clone(), - github_user_id: (ix + 1) as i32, - invite_count: 0, + github_user_id: ix as i32, }, ) .await diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 7397489b34..d6ebe1e84e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -16,9 +16,10 @@ use futures::{channel::oneshot, StreamExt as _}; use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; use language::LanguageRegistry; use node_runtime::FakeNodeRuntime; +use notifications::NotificationStore; use parking_lot::Mutex; use 
-use rpc::RECEIVE_TIMEOUT;
+use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT};
 use settings::SettingsStore;
 use std::{
     cell::{Ref, RefCell, RefMut},
@@ -46,6 +47,7 @@ pub struct TestClient {
     pub username: String,
     pub app_state: Arc<workspace::AppState>,
     channel_store: ModelHandle<ChannelStore>,
+    notification_store: ModelHandle<NotificationStore>,
     state: RefCell<TestClientState>,
 }
 
@@ -138,7 +140,6 @@ impl TestServer {
                 NewUserParams {
                     github_login: name.into(),
                     github_user_id: 0,
-                    invite_count: 0,
                 },
             )
             .await
@@ -231,7 +232,8 @@ impl TestServer {
             workspace::init(app_state.clone(), cx);
             audio::init((), cx);
             call::init(client.clone(), user_store.clone(), cx);
-            channel::init(&client, user_store, cx);
+            channel::init(&client, user_store.clone(), cx);
+            notifications::init(client.clone(), user_store, cx);
         });
 
         client
@@ -243,6 +245,7 @@ impl TestServer {
            app_state,
            username: name.to_string(),
            channel_store: cx.read(ChannelStore::global).clone(),
+            notification_store: cx.read(NotificationStore::global).clone(),
            state: Default::default(),
        };
        client.wait_for_current_user(cx).await;
@@ -327,7 +330,7 @@ impl TestServer {
                    channel_store.invite_member(
                        channel_id,
                        member_client.user_id().unwrap(),
-                        false,
+                        ChannelRole::Member,
                        cx,
                    )
                })
@@ -338,8 +341,8 @@ impl TestServer {
            member_cx
                .read(ChannelStore::global)
-                .update(*member_cx, |channels, _| {
-                    channels.respond_to_channel_invite(channel_id, true)
+                .update(*member_cx, |channels, cx| {
+                    channels.respond_to_channel_invite(channel_id, true, cx)
                })
                .await
                .unwrap();
 
@@ -448,6 +451,10 @@ impl TestClient {
         &self.channel_store
     }
 
+    pub fn notification_store(&self) -> &ModelHandle<NotificationStore> {
+        &self.notification_store
+    }
+
     pub fn user_store(&self) -> &ModelHandle<UserStore> {
         &self.app_state.user_store
     }
@@ -604,33 +611,6 @@ impl TestClient {
     ) -> WindowHandle<Workspace> {
         cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
     }
-
-    pub async fn add_admin_to_channel(
-        &self,
-        user: (&TestClient, &mut TestAppContext),
-        channel: u64,
-        cx_self: &mut TestAppContext,
-    ) {
-        let (other_client, other_cx) = user;
-
-        cx_self
-            .read(ChannelStore::global)
-            .update(cx_self, |channel_store, cx| {
-                channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx)
-            })
-            .await
-            .unwrap();
-
-        cx_self.foreground().run_until_parked();
-
-        other_cx
-            .read(ChannelStore::global)
-            .update(other_cx, |channel_store, _| {
-                channel_store.respond_to_channel_invite(channel, true)
-            })
-            .await
-            .unwrap();
-    }
 }
 
 impl Drop for TestClient {
diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml
index 98790778c9..791c6b2fa7 100644
--- a/crates/collab_ui/Cargo.toml
+++ b/crates/collab_ui/Cargo.toml
@@ -37,10 +37,12 @@ fuzzy = { path = "../fuzzy" }
 gpui = { path = "../gpui" }
 language = { path = "../language" }
 menu = { path = "../menu" }
+notifications = { path = "../notifications" }
 rich_text = { path = "../rich_text" }
 picker = { path = "../picker" }
 project = { path = "../project" }
-recent_projects = {path = "../recent_projects"}
+recent_projects = { path = "../recent_projects" }
+rpc = { path = "../rpc" }
 settings = { path = "../settings" }
 feature_flags = {path = "../feature_flags"}
 theme = { path = "../theme" }
@@ -52,12 +54,14 @@ zed-actions = {path = "../zed-actions"}
 
 anyhow.workspace = true
 futures.workspace = true
+lazy_static.workspace = true
 log.workspace = true
 schemars.workspace = true
 postage.workspace = true
 serde.workspace = true
 serde_derive.workspace = true
 time.workspace = true
+smallvec.workspace = true
 
 [dev-dependencies]
 call = { path = "../call", features = ["test-support"] }
@@ -65,7 +69,12 @@ client = { path = "../client", features = ["test-support"] }
 collections = { path = "../collections", features = ["test-support"] }
 editor = { path = "../editor", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
+notifications = { path = "../notifications", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
+rpc = { path = "../rpc", features = ["test-support"] }
 settings = { path = "../settings", features = ["test-support"] }
 util = { path = "../util", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
+
+pretty_assertions.workspace = true
+tree-sitter-markdown.workspace = true
diff --git a/crates/collab_ui/src/channel_view.rs b/crates/collab_ui/src/channel_view.rs
index e62ee8ef4b..1bdcebd018 100644
--- a/crates/collab_ui/src/channel_view.rs
+++ b/crates/collab_ui/src/channel_view.rs
@@ -15,13 +15,14 @@ use gpui::{
     ViewContext, ViewHandle,
 };
 use project::Project;
+use smallvec::SmallVec;
 use std::{
     any::{Any, TypeId},
     sync::Arc,
 };
 use util::ResultExt;
 use workspace::{
-    item::{FollowableItem, Item, ItemHandle},
+    item::{FollowableItem, Item, ItemEvent, ItemHandle},
     register_followable_item,
     searchable::SearchableItemHandle,
     ItemNavHistory, Pane, SaveIntent, ViewId, Workspace, WorkspaceId,
@@ -140,6 +141,12 @@ impl ChannelView {
                 editor.set_collaboration_hub(Box::new(ChannelBufferCollaborationHub(
                     channel_buffer.clone(),
                 )));
+                editor.set_read_only(
+                    !channel_buffer
+                        .read(cx)
+                        .channel(cx)
+                        .is_some_and(|c| c.can_edit_notes()),
+                );
                 editor
             });
             let _editor_event_subscription = cx.subscribe(&editor, |_, _, e, cx| cx.emit(e.clone()));
@@ -157,8 +164,8 @@ impl ChannelView {
         }
     }
 
-    pub fn channel(&self, cx: &AppContext) -> Arc<Channel> {
-        self.channel_buffer.read(cx).channel()
+    pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
+        self.channel_buffer.read(cx).channel(cx)
     }
 
     fn handle_channel_buffer_event(
@@ -172,6 +179,13 @@ impl ChannelView {
                 editor.set_read_only(true);
                 cx.notify();
             }),
+            ChannelBufferEvent::ChannelChanged => {
+                self.editor.update(cx, |editor, cx| {
+                    editor.set_read_only(!self.channel(cx).is_some_and(|c| c.can_edit_notes()));
+                    cx.emit(editor::Event::TitleChanged);
+                    cx.notify()
+                });
+            }
             ChannelBufferEvent::BufferEdited => {
                 if cx.is_self_focused() || self.editor.is_focused(cx) {
                     self.acknowledge_buffer_version(cx);
@@ -179,7 +193,7 @@ impl ChannelView {
                     self.channel_store.update(cx, |store, cx| {
                         let channel_buffer = self.channel_buffer.read(cx);
                         store.notes_changed(
-                            channel_buffer.channel().id,
+                            channel_buffer.channel_id,
                             channel_buffer.epoch(),
                             &channel_buffer.buffer().read(cx).version(),
                             cx,
@@ -187,7 +201,7 @@ impl ChannelView {
                     });
                 }
             }
-            _ => {}
+            ChannelBufferEvent::CollaboratorsChanged => {}
         }
     }
 
@@ -195,7 +209,7 @@ impl ChannelView {
         self.channel_store.update(cx, |store, cx| {
             let channel_buffer = self.channel_buffer.read(cx);
             store.acknowledge_notes_version(
-                channel_buffer.channel().id,
+                channel_buffer.channel_id,
                 channel_buffer.epoch(),
                 &channel_buffer.buffer().read(cx).version(),
                 cx,
@@ -250,11 +264,17 @@ impl Item for ChannelView {
         style: &theme::Tab,
         cx: &gpui::AppContext,
     ) -> AnyElement {
-        let channel_name = &self.channel_buffer.read(cx).channel().name;
-        let label = if self.channel_buffer.read(cx).is_connected() {
-            format!("#{}", channel_name)
+        let label = if let Some(channel) = self.channel(cx) {
+            match (
+                channel.can_edit_notes(),
+                self.channel_buffer.read(cx).is_connected(),
+            ) {
+                (true, true) => format!("#{}", channel.name),
+                (false, true) => format!("#{} (read-only)", channel.name),
+                (_, false) => format!("#{} (disconnected)", channel.name),
+            }
         } else {
-            format!("#{} (disconnected)", channel_name)
+            format!("channel notes (disconnected)")
         };
         Label::new(label, style.label.to_owned()).into_any()
     }
@@ -298,6 +318,10 @@ impl Item for ChannelView {
     fn pixel_position_of_cursor(&self, cx: &AppContext) -> Option {
         self.editor.read(cx).pixel_position_of_cursor(cx)
     }
+
+    fn to_item_events(event: &Self::Event) -> SmallVec<[ItemEvent; 2]> {
+        editor::Editor::to_item_events(event)
+    }
 }
 
 impl FollowableItem for ChannelView {
@@ -313,7 +337,7 @@ impl FollowableItem for ChannelView {
 
         Some(proto::view::Variant::ChannelView(
             proto::view::ChannelView {
-                channel_id: channel_buffer.channel().id,
+                channel_id: channel_buffer.channel_id,
                 editor: if let Some(proto::view::Variant::Editor(proto)) =
                     self.editor.read(cx).to_state_proto(cx)
                 {
diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs
index 1a17b48f19..5a4dafb6d4 100644
--- a/crates/collab_ui/src/chat_panel.rs
+++ b/crates/collab_ui/src/chat_panel.rs
@@ -1,4 +1,6 @@
-use crate::{channel_view::ChannelView, ChatPanelSettings};
+use crate::{
+    channel_view::ChannelView, is_channels_feature_enabled, render_avatar, ChatPanelSettings,
+};
 use anyhow::Result;
 use call::ActiveCall;
 use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore};
@@ -6,18 +8,18 @@ use client::Client;
 use collections::HashMap;
 use db::kvp::KEY_VALUE_STORE;
 use editor::Editor;
-use feature_flags::{ChannelsAlpha, FeatureFlagAppExt};
 use gpui::{
     actions,
     elements::*,
     platform::{CursorStyle, MouseButton},
     serde_json,
     views::{ItemType, Select, SelectStyle},
-    AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task,
-    View, ViewContext, ViewHandle, WeakViewHandle,
+    AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View,
+    ViewContext, ViewHandle, WeakViewHandle,
 };
-use language::{language_settings::SoftWrap, LanguageRegistry};
+use language::LanguageRegistry;
 use menu::Confirm;
+use message_editor::MessageEditor;
 use project::Fs;
 use rich_text::RichText;
 use serde::{Deserialize, Serialize};
@@ -31,6 +33,8 @@ use workspace::{
     Workspace,
 };
 
+mod message_editor;
+
 const MESSAGE_LOADING_THRESHOLD: usize = 50;
 const CHAT_PANEL_KEY: &'static str = "ChatPanel";
 
@@ -40,7 +44,7 @@ pub struct ChatPanel {
     languages: Arc<LanguageRegistry>,
     active_chat: Option<(ModelHandle<ChannelChat>, Subscription)>,
     message_list: ListState<ChatPanel>,
-    input_editor: ViewHandle<Editor>,
+    input_editor: ViewHandle<MessageEditor>,
     channel_select: ViewHandle