diff --git a/Cargo.lock b/Cargo.lock index d827b314c8..e2e5de95c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,7 @@ dependencies = [ "futures 0.3.28", "gpui", "isahc", + "language", "lazy_static", "log", "matrixmultiply", @@ -1467,7 +1468,7 @@ dependencies = [ [[package]] name = "collab" -version = "0.24.0" +version = "0.25.0" dependencies = [ "anyhow", "async-trait", @@ -1503,6 +1504,7 @@ dependencies = [ "lsp", "nanoid", "node_runtime", + "notifications", "parking_lot 0.11.2", "pretty_assertions", "project", @@ -1558,13 +1560,17 @@ dependencies = [ "fuzzy", "gpui", "language", + "lazy_static", "log", "menu", + "notifications", "picker", "postage", + "pretty_assertions", "project", "recent_projects", "rich_text", + "rpc", "schemars", "serde", "serde_derive", @@ -1573,6 +1579,7 @@ dependencies = [ "theme", "theme_selector", "time", + "tree-sitter-markdown", "util", "vcs_menu", "workspace", @@ -4730,6 +4737,26 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "notifications" +version = "0.1.0" +dependencies = [ + "anyhow", + "channel", + "client", + "clock", + "collections", + "db", + "feature_flags", + "gpui", + "rpc", + "settings", + "sum_tree", + "text", + "time", + "util", +] + [[package]] name = "ntapi" version = "0.3.7" @@ -6404,8 +6431,10 @@ dependencies = [ "rsa 0.4.0", "serde", "serde_derive", + "serde_json", "smol", "smol-timeout", + "strum", "tempdir", "tracing", "util", @@ -6626,6 +6655,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "rustybuzz" version = "0.3.0" @@ -7700,6 +7735,22 @@ name = "strum" version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.37", +] [[package]] name = "subtle" @@ -9098,6 +9149,7 @@ name = "vcs_menu" version = "0.1.0" dependencies = [ "anyhow", + "fs", "fuzzy", "gpui", "picker", @@ -10042,7 +10094,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.109.0" +version = "0.110.0" dependencies = [ "activity_indicator", "ai", @@ -10097,6 +10149,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "notifications", "num_cpus", "outline", "parking_lot 0.11.2", diff --git a/Cargo.toml b/Cargo.toml index 995cd15edd..cf977b8fe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ members = [ "crates/media", "crates/menu", "crates/node_runtime", + "crates/notifications", "crates/outline", "crates/picker", "crates/plugin", @@ -112,6 +113,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] } serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] } smallvec = { version = "1.6", features = ["union"] } smol = { version = "1.2" } +strum = { version = "0.25.0", features = ["derive"] } sysinfo = "0.29.10" tempdir = { version = "0.3.7" } thiserror = { version = "1.0.29" } diff --git a/assets/icons/bell.svg b/assets/icons/bell.svg new file mode 100644 index 0000000000..ea1c6dd42e --- /dev/null +++ b/assets/icons/bell.svg @@ -0,0 +1,8 @@ + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index 4143e5dd41..e70b563359 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -142,6 +142,14 @@ // Default width of the channels panel. "default_width": 240 }, + "notification_panel": { + // Whether to show the collaboration panel button in the status bar. + "button": true, + // Where to dock channels panel. Can be 'left' or 'right'. + "dock": "right", + // Default width of the channels panel. + "default_width": 240 + }, "assistant": { // Whether to show the assistant panel button in the status bar. "button": true, diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 542d7f422f..b24c4e5ece 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -11,6 +11,7 @@ doctest = false [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } +language = { path = "../language" } async-trait.workspace = true anyhow.workspace = true futures.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 5256a6a643..f168c15793 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,2 +1,4 @@ pub mod completion; pub mod embedding; +pub mod models; +pub mod templates; diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 170b2268f9..de6ce9da71 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -53,6 +53,8 @@ pub struct OpenAIRequest { pub model: String, pub messages: Vec, pub stream: bool, + pub stop: Vec, + pub temperature: f32, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 4587ece0a2..4d5e40fad9 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::serde_json; +use gpui::{serde_json, ViewContext}; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; @@ -20,9 +20,11 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::completion::OPENAI_API_URL; lazy_static! { - static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } @@ -87,6 +89,7 @@ impl Embedding { #[derive(Clone)] pub struct OpenAIEmbeddings { + pub api_key: Option, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings { const OPENAI_INPUT_LIMIT: usize = 8190; impl OpenAIEmbeddings { - pub fn new(client: Arc, executor: Arc) -> Self { + pub fn authenticate(&mut self, cx: &mut ViewContext) { + if self.api_key.is_none() { + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + Some(api_key) + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + String::from_utf8(api_key).log_err() + } else { + None + }; + + if let Some(api_key) = api_key { + self.api_key = Some(api_key); + } + } + } + pub fn new( + api_key: Option, + client: Arc, + executor: Arc, + ) -> Self { let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); OpenAIEmbeddings { + api_key, client, executor, rate_limit_count_rx, @@ -237,8 +265,9 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { fn is_authenticated(&self) -> bool { - OPENAI_API_KEY.as_ref().is_some() + self.api_key.is_some() } + fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + let Some(api_key) = self.api_key.clone() else { + return Err(anyhow!("no open ai key provided")); + }; let mut request_number = 0; let mut rate_limiting = false; @@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { while request_number < MAX_RETRIES { response = self .send_request( - api_key, + &api_key, spans.iter().map(|x| &**x).collect(), request_timeout, ) diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs new file mode 100644 index 0000000000..d0206cc41c --- /dev/null +++ b/crates/ai/src/models.rs @@ -0,0 +1,66 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate(&self, content: &str, length: usize) -> anyhow::Result; + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} + +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate(&self, content: &str, length: usize) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + bpe.decode(tokens[..length].to_vec()) + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + bpe.decode(tokens[length..].to_vec()) + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs new file mode 100644 index 0000000000..bda1d6c30e --- /dev/null +++ b/crates/ai/src/templates/base.rs @@ -0,0 +1,350 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::templates::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc, + pub user_prompt: Option, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? - self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate(&content, max_token_length)?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate(&content, max_token_length)?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + #[derive(Clone)] + struct DummyLanguageModel { + capacity: usize, + } + + impl LanguageModel for DummyLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate(&self, content: &str, length: usize) -> anyhow::Result { + anyhow::Ok( + content.chars().collect::>()[..length] + .into_iter() + .collect::(), + ) + } + fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { + anyhow::Ok( + content.chars().collect::>()[length..] + .into_iter() + .collect::(), + ) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } + } + + let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs new file mode 100644 index 0000000000..1afd61192e --- /dev/null +++ b/crates/ai/src/templates/file_context.rs @@ -0,0 +1,160 @@ +use anyhow::anyhow; +use language::BufferSnapshot; +use language::ToOffset; + +use crate::models::LanguageModel; +use crate::templates::base::PromptArguments; +use crate::templates::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate_start(&start_window, start_goal_tokens)?; + let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? + writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs new file mode 100644 index 0000000000..1eeb197f93 --- /dev/null +++ b/crates/ai/src/templates/generate.rs @@ -0,0 +1,95 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." + ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate(&prompt, max_tokens)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/templates/mod.rs new file mode 100644 index 0000000000..0025269a44 --- /dev/null +++ b/crates/ai/src/templates/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/templates/preamble.rs new file mode 100644 index 0000000000..9eabaaeb97 --- /dev/null +++ b/crates/ai/src/templates/preamble.rs @@ -0,0 +1,52 @@ +use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." + )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/templates/repository_context.rs new file mode 100644 index 0000000000..a8e7f4b5af --- /dev/null +++ b/crates/ai/src/templates/repository_context.rs @@ -0,0 +1,94 @@ +use crate::templates::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui::{AsyncAppContext, ModelHandle}; +use language::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new(buffer: ModelHandle, range: Range, cx: &AsyncAppContext) -> Self { + let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + }); + + PromptCodeSnippet { + path: file_path, + language_name, + content, + } + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 65edb1832f..ca8c54a285 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,12 +1,15 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, codegen::{self, Codegen, CodegenKind}, - prompts::{generate_content_prompt, PromptCodeSnippet}, + prompts::generate_content_prompt, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; -use ai::completion::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, +use ai::{ + completion::{ + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + }, + templates::repository_context::PromptCodeSnippet, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -609,6 +612,18 @@ impl AssistantPanel { let project = pending_assist.project.clone(); + let project_name = if let Some(project) = project.upgrade(cx) { + Some( + project + .read(cx) + .worktree_root_names(cx) + .collect::>() + .join("/"), + ) + } else { + None + }; + self.inline_prompt_history .retain(|prompt| prompt != user_prompt); self.inline_prompt_history.push_back(user_prompt.into()); @@ -646,7 +661,19 @@ impl AssistantPanel { None }; - let codegen_kind = codegen.read(cx).kind().clone(); + // Higher Temperature increases the randomness of model outputs. + // If Markdown or No Language is Known, increase the randomness for more creative output + // If Code, decrease temperature to get more deterministic outputs + let temperature = if let Some(language) = language_name.clone() { + if language.to_string() != "Markdown".to_string() { + 0.5 + } else { + 1.0 + } + } else { + 1.0 + }; + let user_prompt = user_prompt.to_string(); let snippets = if retrieve_context { @@ -668,14 +695,7 @@ impl AssistantPanel { let snippets = cx.spawn(|_, cx| async move { let mut snippets = Vec::new(); for result in search_results.await { - snippets.push(PromptCodeSnippet::new(result, &cx)); - - // snippets.push(result.buffer.read_with(&cx, |buffer, _| { - // buffer - // .snapshot() - // .text_for_range(result.range) - // .collect::() - // })); + snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx)); } snippets }); @@ -696,11 +716,11 @@ impl AssistantPanel { generate_content_prompt( user_prompt, language_name, - &buffer, + buffer, range, - codegen_kind, snippets, model_name, + project_name, ) }); @@ -717,18 +737,23 @@ impl AssistantPanel { } cx.spawn(|_, mut cx| async move { - let prompt = prompt.await; + // I Don't know if we want to return a ? here. + let prompt = prompt.await?; messages.push(RequestMessage { role: Role::User, content: prompt, }); + let request = OpenAIRequest { model: model.full_name().into(), messages, stream: true, + stop: vec!["|END|>".to_string()], + temperature, }; codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); + anyhow::Ok(()) }) .detach(); } @@ -1718,6 +1743,8 @@ impl Conversation { .map(|message| message.to_open_ai_message(self.buffer.read(cx))) .collect(), stream: true, + stop: vec![], + temperature: 1.0, }; let stream = stream_completion(api_key, cx.background().clone(), request); @@ -2002,6 +2029,8 @@ impl Conversation { model: self.model.full_name().to_string(), messages: messages.collect(), stream: true, + stop: vec![], + temperature: 1.0, }; let stream = stream_completion(api_key, cx.background().clone(), request); diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 7aafe75920..dffcbc2923 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,60 +1,13 @@ -use crate::codegen::CodegenKind; -use gpui::AsyncAppContext; +use ai::models::{LanguageModel, OpenAILanguageModel}; +use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; +use ai::templates::file_context::FileContext; +use ai::templates::generate::GenerateInlineContent; +use ai::templates::preamble::EngineerPreamble; +use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; -use semantic_index::SearchResult; use std::cmp::{self, Reverse}; -use std::fmt::Write; use std::ops::Range; -use std::path::PathBuf; -use tiktoken_rs::ChatCompletionRequestMessage; - -pub struct PromptCodeSnippet { - path: Option, - language_name: Option, - content: String, -} - -impl PromptCodeSnippet { - pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self { - let (content, language_name, file_path) = - search_result.buffer.read_with(cx, |buffer, _| { - let snapshot = buffer.snapshot(); - let content = snapshot - .text_for_range(search_result.range.clone()) - .collect::(); - - let language_name = buffer - .language() - .and_then(|language| Some(language.name().to_string())); - - let file_path = buffer - .file() - .and_then(|file| Some(file.path().to_path_buf())); - - (content, language_name, file_path) - }); - - PromptCodeSnippet { - path: file_path, - language_name, - content, - } - } -} - -impl ToString for PromptCodeSnippet { - fn to_string(&self) -> String { - let path = self - .path - .as_ref() - .and_then(|path| Some(path.to_string_lossy().to_string())) - .unwrap_or("".to_string()); - let language_name = self.language_name.clone().unwrap_or("".to_string()); - let content = self.content.clone(); - - format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") - } -} +use std::sync::Arc; #[allow(dead_code)] fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { @@ -170,134 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S pub fn generate_content_prompt( user_prompt: String, language_name: Option<&str>, - buffer: &BufferSnapshot, - range: Range, - kind: CodegenKind, + buffer: BufferSnapshot, + range: Range, search_results: Vec, model: &str, -) -> String { - const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; - const RESERVED_TOKENS_FOR_GENERATION: usize = 1000; - - let mut prompts = Vec::new(); - let range = range.to_offset(buffer); - - // General Preamble - if let Some(language_name) = language_name { - prompts.push(format!("You're an expert {language_name} engineer.\n")); + project_name: Option, +) -> anyhow::Result { + // Using new Prompt Templates + let openai_model: Arc = Arc::new(OpenAILanguageModel::load(model)); + let lang_name = if let Some(language_name) = language_name { + Some(language_name.to_string()) } else { - prompts.push("You're an expert engineer.\n".to_string()); - } - - // Snippets - let mut snippet_position = prompts.len() - 1; - - let mut content = String::new(); - content.extend(buffer.text_for_range(0..range.start)); - if range.start == range.end { - content.push_str("<|START|>"); - } else { - content.push_str("<|START|"); - } - content.extend(buffer.text_for_range(range.clone())); - if range.start != range.end { - content.push_str("|END|>"); - } - content.extend(buffer.text_for_range(range.end..buffer.len())); - - prompts.push("The file you are currently working on has the following content:\n".to_string()); - - if let Some(language_name) = language_name { - let language_name = language_name.to_lowercase(); - prompts.push(format!("```{language_name}\n{content}\n```")); - } else { - prompts.push(format!("```\n{content}\n```")); - } - - match kind { - CodegenKind::Generate { position: _ } => { - prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string()); - prompts - .push("Assume the cursor is located where the `<|START|` marker is.".to_string()); - prompts.push( - "Text can't be replaced, so assume your answer will be inserted at the cursor." - .to_string(), - ); - prompts.push(format!( - "Generate text based on the users prompt: {user_prompt}" - )); - } - CodegenKind::Transform { range: _ } => { - prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string()); - prompts.push(format!( - "Modify the users code selected text based upon the users prompt: '{user_prompt}'" - )); - prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string()); - } - } - - if let Some(language_name) = language_name { - prompts.push(format!( - "Your answer MUST always and only be valid {language_name}" - )); - } - prompts.push("Never make remarks about the output.".to_string()); - prompts.push("Do not return any text, except the generated code.".to_string()); - prompts.push("Do not wrap your text in a Markdown block".to_string()); - - let current_messages = [ChatCompletionRequestMessage { - role: "user".to_string(), - content: Some(prompts.join("\n")), - function_call: None, - name: None, - }]; - - let mut remaining_token_count = if let Ok(current_token_count) = - tiktoken_rs::num_tokens_from_messages(model, ¤t_messages) - { - let max_token_count = tiktoken_rs::model::get_context_size(model); - let intermediate_token_count = max_token_count - current_token_count; - - if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION { - 0 - } else { - intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION - } - } else { - // If tiktoken fails to count token count, assume we have no space remaining. - 0 + None }; - // TODO: - // - add repository name to snippet - // - add file path - // - add language - if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) { - let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; + let args = PromptArguments { + model: openai_model, + language_name: lang_name.clone(), + project_name, + snippets: search_results.clone(), + reserved_tokens: 1000, + buffer: Some(buffer), + selected_range: Some(range), + user_prompt: Some(user_prompt.clone()), + }; - for search_result in search_results { - let mut snippet_prompt = template.to_string(); - let snippet = search_result.to_string(); - writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap(); + let templates: Vec<(PromptPriority, Box)> = vec![ + (PromptPriority::Mandatory, Box::new(EngineerPreamble {})), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(RepositoryContext {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(FileContext {}), + ), + ( + PromptPriority::Mandatory, + Box::new(GenerateInlineContent {}), + ), + ]; + let chain = PromptChain::new(args, templates); + let (prompt, _) = chain.generate(true)?; - let token_count = encoding - .encode_with_special_tokens(snippet_prompt.as_str()) - .len(); - if token_count <= remaining_token_count { - if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT { - prompts.insert(snippet_position, snippet_prompt); - snippet_position += 1; - remaining_token_count -= token_count; - // If you have already added the template to the prompt, remove the template. - template = ""; - } - } else { - break; - } - } - } - - prompts.join("\n") + anyhow::Ok(prompt) } #[cfg(test)] diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index d31d4b3c8c..b6db304a70 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -7,7 +7,10 @@ use gpui::{AppContext, ModelHandle}; use std::sync::Arc; pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; -pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId}; +pub use channel_chat::{ + mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, + MessageParams, +}; pub use channel_store::{ Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore, }; diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 299c91759c..ef11d96424 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -3,12 +3,17 @@ use anyhow::{anyhow, Result}; use client::{ proto, user::{User, UserStore}, - Client, Subscription, TypedEnvelope, + Client, Subscription, TypedEnvelope, UserId, }; use futures::lock::Mutex; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; use rand::prelude::*; -use std::{collections::HashSet, mem, ops::Range, sync::Arc}; +use std::{ + collections::HashSet, + mem, + ops::{ControlFlow, Range}, + sync::Arc, +}; use sum_tree::{Bias, SumTree}; use time::OffsetDateTime; use util::{post_inc, ResultExt as _, TryFutureExt}; @@ -16,6 +21,7 @@ use util::{post_inc, ResultExt as _, TryFutureExt}; pub struct ChannelChat { pub channel_id: ChannelId, messages: SumTree, + acknowledged_message_ids: HashSet, channel_store: ModelHandle, loaded_all_messages: bool, last_acknowledged_id: Option, @@ -27,6 +33,12 @@ pub struct ChannelChat { _subscription: Subscription, } +#[derive(Debug, PartialEq, Eq)] +pub struct MessageParams { + pub text: String, + pub mentions: Vec<(Range, UserId)>, +} + #[derive(Clone, Debug)] pub struct ChannelMessage { pub id: ChannelMessageId, @@ -34,6 +46,7 @@ pub struct ChannelMessage { pub timestamp: OffsetDateTime, pub sender: Arc, pub nonce: u128, + pub mentions: Vec<(Range, UserId)>, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -105,6 +118,7 @@ impl ChannelChat { rpc: client, outgoing_messages_lock: Default::default(), messages: Default::default(), + acknowledged_message_ids: Default::default(), loaded_all_messages, next_pending_message_id: 0, last_acknowledged_id: None, @@ -123,12 +137,16 @@ impl ChannelChat { .cloned() } + pub fn client(&self) -> &Arc { + &self.rpc + } + pub fn send_message( &mut self, - body: String, + message: MessageParams, cx: &mut ModelContext, - ) -> Result>> { - if body.is_empty() { + ) -> Result>> { + if message.text.is_empty() { Err(anyhow!("message body can't be empty"))?; } @@ -145,9 +163,10 @@ impl ChannelChat { SumTree::from_item( ChannelMessage { id: pending_id, - body: body.clone(), + body: message.text.clone(), sender: current_user, timestamp: OffsetDateTime::now_utc(), + mentions: message.mentions.clone(), nonce, }, &(), @@ -161,20 +180,18 @@ impl ChannelChat { let outgoing_message_guard = outgoing_messages_lock.lock().await; let request = rpc.request(proto::SendChannelMessage { channel_id, - body, + body: message.text, nonce: Some(nonce.into()), + mentions: mentions_to_proto(&message.mentions), }); let response = request.await?; drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; + let response = response.message.ok_or_else(|| anyhow!("invalid message"))?; + let id = response.id; + let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) + Ok(id) }) })) } @@ -194,41 +211,76 @@ impl ChannelChat { }) } - pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { - if !self.loaded_all_messages { - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.channel_id; - if let Some(before_message_id) = - self.messages.first().and_then(|message| match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }) - { - cx.spawn(|this, mut cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id, - before_message_id, - }) - .await?; - let loaded_all_messages = response.done; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(messages, cx); - }); - anyhow::Ok(()) + pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> Option>> { + if self.loaded_all_messages { + return None; + } + + let rpc = self.rpc.clone(); + let user_store = self.user_store.clone(); + let channel_id = self.channel_id; + let before_message_id = self.first_loaded_message_id()?; + Some(cx.spawn(|this, mut cx| { + async move { + let response = rpc + .request(proto::GetChannelMessages { + channel_id, + before_message_id, + }) + .await?; + let loaded_all_messages = response.done; + let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; + this.update(&mut cx, |this, cx| { + this.loaded_all_messages = loaded_all_messages; + this.insert_messages(messages, cx); + }); + anyhow::Ok(()) + } + .log_err() + })) + } + + pub fn first_loaded_message_id(&mut self) -> Option { + self.messages.first().and_then(|message| match message.id { + ChannelMessageId::Saved(id) => Some(id), + ChannelMessageId::Pending(_) => None, + }) + } + + /// Load all of the chat messages since a certain message id. + /// + /// For now, we always maintain a suffix of the channel's messages. + pub async fn load_history_since_message( + chat: ModelHandle, + message_id: u64, + mut cx: AsyncAppContext, + ) -> Option { + loop { + let step = chat.update(&mut cx, |chat, cx| { + if let Some(first_id) = chat.first_loaded_message_id() { + if first_id <= message_id { + let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(); + let message_id = ChannelMessageId::Saved(message_id); + cursor.seek(&message_id, Bias::Left, &()); + return ControlFlow::Break( + if cursor + .item() + .map_or(false, |message| message.id == message_id) + { + Some(cursor.start().1 .0) + } else { + None + }, + ); } - .log_err() - }) - .detach(); - return true; + } + ControlFlow::Continue(chat.load_more_messages(cx)) + }); + match step { + ControlFlow::Break(ix) => return ix, + ControlFlow::Continue(task) => task?.await?, } } - false } pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext) { @@ -287,6 +339,7 @@ impl ChannelChat { let request = rpc.request(proto::SendChannelMessage { channel_id, body: pending_message.body, + mentions: mentions_to_proto(&pending_message.mentions), nonce: Some(pending_message.nonce.into()), }); let response = request.await?; @@ -322,6 +375,17 @@ impl ChannelChat { cursor.item().unwrap() } + pub fn acknowledge_message(&mut self, id: u64) { + if self.acknowledged_message_ids.insert(id) { + self.rpc + .send(proto::AckChannelMessage { + channel_id: self.channel_id, + message_id: id, + }) + .ok(); + } + } + pub fn messages_in_range(&self, range: Range) -> impl Iterator { let mut cursor = self.messages.cursor::(); cursor.seek(&Count(range.start), Bias::Right, &()); @@ -454,22 +518,7 @@ async fn messages_from_proto( user_store: &ModelHandle, cx: &mut AsyncAppContext, ) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - }) - .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } + let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?; let mut result = SumTree::new(); result.extend(messages, &()); Ok(result) @@ -489,6 +538,14 @@ impl ChannelMessage { Ok(ChannelMessage { id: ChannelMessageId::Saved(message.id), body: message.body, + mentions: message + .mentions + .into_iter() + .filter_map(|mention| { + let range = mention.range?; + Some((range.start as usize..range.end as usize, mention.user_id)) + }) + .collect(), timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, sender, nonce: message @@ -501,6 +558,43 @@ impl ChannelMessage { pub fn is_pending(&self) -> bool { matches!(self.id, ChannelMessageId::Pending(_)) } + + pub async fn from_proto_vec( + proto_messages: Vec, + user_store: &ModelHandle, + cx: &mut AsyncAppContext, + ) -> Result> { + let unique_user_ids = proto_messages + .iter() + .map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + user_store + .update(cx, |user_store, cx| { + user_store.get_users(unique_user_ids, cx) + }) + .await?; + + let mut messages = Vec::with_capacity(proto_messages.len()); + for message in proto_messages { + messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); + } + Ok(messages) + } +} + +pub fn mentions_to_proto(mentions: &[(Range, UserId)]) -> Vec { + mentions + .iter() + .map(|(range, user_id)| proto::ChatMention { + range: Some(proto::Range { + start: range.start as u64, + end: range.end as u64, + }), + user_id: *user_id as u64, + }) + .collect() } impl sum_tree::Item for ChannelMessage { @@ -541,3 +635,12 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { self.0 += summary.count; } } + +impl<'a> From<&'a str> for MessageParams { + fn from(value: &'a str) -> Self { + Self { + text: value.into(), + mentions: Vec::new(), + } + } +} diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 2e1e92431e..2665c2e1ec 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,6 +1,6 @@ mod channel_index; -use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat}; +use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage}; use anyhow::{anyhow, Result}; use channel_index::ChannelIndex; use client::{Client, Subscription, User, UserId, UserStore}; @@ -157,9 +157,6 @@ impl ChannelStore { this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx)); } } - if status.is_connected() { - } else { - } } Some(()) }); @@ -245,6 +242,12 @@ impl ChannelStore { self.channel_index.by_id().values().nth(ix) } + pub fn has_channel_invitation(&self, channel_id: ChannelId) -> bool { + self.channel_invitations + .iter() + .any(|channel| channel.id == channel_id) + } + pub fn channel_invitations(&self) -> &[Arc] { &self.channel_invitations } @@ -278,6 +281,33 @@ impl ChannelStore { ) } + pub fn fetch_channel_messages( + &self, + message_ids: Vec, + cx: &mut ModelContext, + ) -> Task>> { + let request = if message_ids.is_empty() { + None + } else { + Some( + self.client + .request(proto::GetChannelMessagesById { message_ids }), + ) + }; + cx.spawn_weak(|this, mut cx| async move { + if let Some(request) = request { + let response = request.await?; + let this = this + .upgrade(&cx) + .ok_or_else(|| anyhow!("channel store dropped"))?; + let user_store = this.read_with(&cx, |this, _| this.user_store.clone()); + ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await + } else { + Ok(Vec::new()) + } + }) + } + pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option { self.channel_index .by_id() @@ -689,14 +719,15 @@ impl ChannelStore { &mut self, channel_id: ChannelId, accept: bool, - ) -> impl Future> { + cx: &mut ModelContext, + ) -> Task> { let client = self.client.clone(); - async move { + cx.background().spawn(async move { client .request(proto::RespondToChannelInvite { channel_id, accept }) .await?; Ok(()) - } + }) } pub fn get_channel_member_details( @@ -764,6 +795,11 @@ impl ChannelStore { } fn handle_connect(&mut self, cx: &mut ModelContext) -> Task> { + self.channel_index.clear(); + self.channel_invitations.clear(); + self.channel_participants.clear(); + self.channel_index.clear(); + self.outgoing_invites.clear(); self.disconnect_channel_buffers_task.take(); for chat in self.opened_chats.values() { @@ -873,11 +909,6 @@ impl ChannelStore { } fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext) { - self.channel_index.clear(); - self.channel_invitations.clear(); - self.channel_participants.clear(); - self.channel_index.clear(); - self.outgoing_invites.clear(); cx.notify(); self.disconnect_channel_buffers_task.get_or_insert_with(|| { diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index 1358378d16..dd4f24e2ca 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -210,6 +210,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "a".into(), timestamp: 1000, sender_id: 5, + mentions: vec![], nonce: Some(1.into()), }, proto::ChannelMessage { @@ -217,6 +218,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "b".into(), timestamp: 1001, sender_id: 6, + mentions: vec![], nonce: Some(2.into()), }, ], @@ -263,6 +265,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "c".into(), timestamp: 1002, sender_id: 7, + mentions: vec![], nonce: Some(3.into()), }), }); @@ -300,7 +303,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { // Scroll up to view older messages. channel.update(cx, |channel, cx| { - assert!(channel.load_more_messages(cx)); + channel.load_more_messages(cx).unwrap().detach(); }); let get_messages = server.receive::().await.unwrap(); assert_eq!(get_messages.payload.channel_id, 5); @@ -316,6 +319,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 998, sender_id: 5, nonce: Some(4.into()), + mentions: vec![], }, proto::ChannelMessage { id: 9, @@ -323,6 +327,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 999, sender_id: 6, nonce: Some(5.into()), + mentions: vec![], }, ], }, diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 6aa41708e3..8299b7c6e4 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -293,21 +293,19 @@ impl UserStore { // No need to paralellize here let mut updated_contacts = Vec::new(); for contact in message.contacts { - let should_notify = contact.should_notify; - updated_contacts.push(( - Arc::new(Contact::from_proto(contact, &this, &mut cx).await?), - should_notify, + updated_contacts.push(Arc::new( + Contact::from_proto(contact, &this, &mut cx).await?, )); } let mut incoming_requests = Vec::new(); for request in message.incoming_requests { - incoming_requests.push({ - let user = this - .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx)) - .await?; - (user, request.should_notify) - }); + incoming_requests.push( + this.update(&mut cx, |this, cx| { + this.get_user(request.requester_id, cx) + }) + .await?, + ); } let mut outgoing_requests = Vec::new(); @@ -330,13 +328,7 @@ impl UserStore { this.contacts .retain(|contact| !removed_contacts.contains(&contact.user.id)); // Update existing contacts and insert new ones - for (updated_contact, should_notify) in updated_contacts { - if should_notify { - cx.emit(Event::Contact { - user: updated_contact.user.clone(), - kind: ContactEventKind::Accepted, - }); - } + for updated_contact in updated_contacts { match this.contacts.binary_search_by_key( &&updated_contact.user.github_login, |contact| &contact.user.github_login, @@ -359,14 +351,7 @@ impl UserStore { } }); // Update existing incoming requests and insert new ones - for (user, should_notify) in incoming_requests { - if should_notify { - cx.emit(Event::Contact { - user: user.clone(), - kind: ContactEventKind::Requested, - }); - } - + for user in incoming_requests { match this .incoming_contact_requests .binary_search_by_key(&&user.github_login, |contact| { @@ -415,6 +400,12 @@ impl UserStore { &self.incoming_contact_requests } + pub fn has_incoming_contact_request(&self, user_id: u64) -> bool { + self.incoming_contact_requests + .iter() + .any(|user| user.id == user_id) + } + pub fn outgoing_contact_requests(&self) -> &[Arc] { &self.outgoing_contact_requests } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b91f0e1a5f..64bc191b21 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] default-run = "collab" edition = "2021" name = "collab" -version = "0.24.0" +version = "0.25.0" publish = false [[bin]] @@ -73,6 +73,7 @@ git = { path = "../git", features = ["test-support"] } live_kit_client = { path = "../live_kit_client", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } node_runtime = { path = "../node_runtime" } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 8eb6b52fd8..7fa808b498 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -192,7 +192,7 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); CREATE TABLE "channels" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" VARCHAR NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT now, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "visibility" VARCHAR NOT NULL ); @@ -214,7 +214,15 @@ CREATE TABLE IF NOT EXISTS "channel_messages" ( "nonce" BLOB NOT NULL ); CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); -CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); + +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); CREATE TABLE "channel_paths" ( "id_path" TEXT NOT NULL PRIMARY KEY, @@ -314,3 +322,26 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( ); CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); + +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE "notifications" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP, + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab/migrations/20231004130100_create_notifications.sql b/crates/collab/migrations/20231004130100_create_notifications.sql new file mode 100644 index 0000000000..93c282c631 --- /dev/null +++ b/crates/collab/migrations/20231004130100_create_notifications.sql @@ -0,0 +1,22 @@ +CREATE TABLE "notification_kinds" ( + "id" SERIAL PRIMARY KEY, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE notifications ( + "id" SERIAL PRIMARY KEY, + "created_at" TIMESTAMP NOT NULL DEFAULT now(), + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab/migrations/20231018102700_create_mentions.sql b/crates/collab/migrations/20231018102700_create_mentions.sql new file mode 100644 index 0000000000..221a1748cf --- /dev/null +++ b/crates/collab/migrations/20231018102700_create_mentions.sql @@ -0,0 +1,11 @@ +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); + +-- We use 'on conflict update' with this index, so it should be per-user. +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); +DROP INDEX "index_channel_messages_on_nonce"; diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index cb1594e941..88fe0a647b 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -71,7 +71,6 @@ async fn main() { db::NewUserParams { github_login: github_user.login, github_user_id: github_user.id, - invite_count: 5, }, ) .await diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 936b655200..8c6bca2113 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -13,6 +13,7 @@ use anyhow::anyhow; use collections::{BTreeMap, HashMap, HashSet}; use dashmap::DashMap; use futures::StreamExt; +use queries::channels::ChannelGraph; use rand::{prelude::StdRng, Rng, SeedableRng}; use rpc::{ proto::{self}, @@ -20,7 +21,7 @@ use rpc::{ }; use sea_orm::{ entity::prelude::*, - sea_query::{Alias, Expr, OnConflict, Query}, + sea_query::{Alias, Expr, OnConflict}, ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, @@ -47,14 +48,14 @@ pub use ids::*; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; -use self::queries::channels::ChannelGraph; - pub struct Database { options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, rng: Mutex, executor: Executor, + notification_kinds_by_id: HashMap, + notification_kinds_by_name: HashMap, #[cfg(test)] runtime: Option, } @@ -69,6 +70,8 @@ impl Database { pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), rng: Mutex::new(StdRng::seed_from_u64(0)), + notification_kinds_by_id: HashMap::default(), + notification_kinds_by_name: HashMap::default(), executor, #[cfg(test)] runtime: None, @@ -121,6 +124,11 @@ impl Database { Ok(new_migrations) } + pub async fn initialize_static_data(&mut self) -> Result<()> { + self.initialize_notification_kinds().await?; + Ok(()) + } + pub async fn transaction(&self, f: F) -> Result where F: Send + Fn(TransactionHandle) -> Fut, @@ -361,18 +369,9 @@ impl RoomGuard { #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - busy: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, + Accepted { user_id: UserId, busy: bool }, + Outgoing { user_id: UserId }, + Incoming { user_id: UserId }, } impl Contact { @@ -385,6 +384,15 @@ impl Contact { } } +pub type NotificationBatch = Vec<(UserId, proto::Notification)>; + +pub struct CreatedChannelMessage { + pub message_id: MessageId, + pub participant_connection_ids: Vec, + pub channel_members: Vec, + pub notifications: NotificationBatch, +} + #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] pub struct Invite { pub email_address: String, @@ -417,7 +425,6 @@ pub struct WaitlistSummary { pub struct NewUserParams { pub github_login: String, pub github_user_id: i32, - pub invite_count: i32, } #[derive(Debug)] @@ -466,6 +473,24 @@ pub enum SetMemberRoleResult { MembershipUpdated(MembershipUpdated), } +#[derive(Debug)] +pub struct InviteMemberResult { + pub channel: Channel, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RespondToChannelInvite { + pub membership_update: Option, + pub notifications: NotificationBatch, +} + +#[derive(Debug)] +pub struct RemoveChannelMemberResult { + pub membership_update: MembershipUpdated, + pub notification_id: Option, +} + #[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)] pub struct Channel { pub id: ChannelId, diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 4c6e7116ee..5f0df90811 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -81,6 +81,8 @@ id_type!(SignupId); id_type!(UserId); id_type!(ChannelBufferCollaboratorId); id_type!(FlagId); +id_type!(NotificationId); +id_type!(NotificationKindId); #[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] #[sea_orm(rs_type = "String", db_type = "String(None)")] diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 80bd8704b2..629e26f1a9 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod messages; +pub mod notifications; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/access_tokens.rs b/crates/collab/src/db/queries/access_tokens.rs index def9428a2b..589b6483df 100644 --- a/crates/collab/src/db/queries/access_tokens.rs +++ b/crates/collab/src/db/queries/access_tokens.rs @@ -1,4 +1,5 @@ use super::*; +use sea_orm::sea_query::Query; impl Database { pub async fn create_access_token( diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index dbc701e4ae..0abd36049d 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -349,11 +349,11 @@ impl Database { &self, channel_id: ChannelId, invitee_id: UserId, - admin_id: UserId, + inviter_id: UserId, role: ChannelRole, - ) -> Result { + ) -> Result { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, admin_id, &*tx) + self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) .await?; channel_member::ActiveModel { @@ -371,11 +371,31 @@ impl Database { .await? .unwrap(); - Ok(Channel { + let channel = Channel { id: channel.id, visibility: channel.visibility, name: channel.name, role, + }; + + let notifications = self + .create_notification( + invitee_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: channel.name.clone(), + inviter_id: inviter_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect(); + + Ok(InviteMemberResult { + channel, + notifications, }) }) .await @@ -445,9 +465,9 @@ impl Database { channel_id: ChannelId, user_id: UserId, accept: bool, - ) -> Result> { + ) -> Result { self.transaction(move |tx| async move { - if accept { + let membership_update = if accept { let rows_affected = channel_member::Entity::update_many() .set(channel_member::ActiveModel { accepted: ActiveValue::Set(accept), @@ -467,26 +487,45 @@ impl Database { Err(anyhow!("no such invitation"))?; } - return Ok(Some( + Some( self.calculate_membership_updated(channel_id, user_id, &*tx) .await?, - )); - } + ) + } else { + let rows_affected = channel_member::Entity::delete_many() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await? + .rows_affected; + if rows_affected == 0 { + Err(anyhow!("no such invitation"))?; + } - let rows_affected = channel_member::ActiveModel { - channel_id: ActiveValue::Unchanged(channel_id), - user_id: ActiveValue::Unchanged(user_id), - ..Default::default() - } - .delete(&*tx) - .await? - .rows_affected; + None + }; - if rows_affected == 0 { - Err(anyhow!("no such invitation"))?; - } - - Ok(None) + Ok(RespondToChannelInvite { + membership_update, + notifications: self + .mark_notification_as_read_with_response( + user_id, + &rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + accept, + &*tx, + ) + .await? + .into_iter() + .collect(), + }) }) .await } @@ -550,7 +589,7 @@ impl Database { channel_id: ChannelId, member_id: UserId, admin_id: UserId, - ) -> Result { + ) -> Result { self.transaction(|tx| async move { self.check_user_is_channel_admin(channel_id, admin_id, &*tx) .await?; @@ -568,9 +607,22 @@ impl Database { Err(anyhow!("no such member"))?; } - Ok(self - .calculate_membership_updated(channel_id, member_id, &*tx) - .await?) + Ok(RemoveChannelMemberResult { + membership_update: self + .calculate_membership_updated(channel_id, member_id, &*tx) + .await?, + notification_id: self + .remove_notification( + member_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + &*tx, + ) + .await?, + }) }) .await } @@ -911,6 +963,47 @@ impl Database { .await } + pub async fn get_channel_participant_details( + &self, + channel_id: ChannelId, + user_id: UserId, + ) -> Result> { + let (role, members) = self + .transaction(move |tx| async move { + let role = self + .check_user_is_channel_participant(channel_id, user_id, &*tx) + .await?; + Ok(( + role, + self.get_channel_participant_details_internal(channel_id, &*tx) + .await?, + )) + }) + .await?; + + if role == ChannelRole::Admin { + Ok(members + .into_iter() + .map(|channel_member| channel_member.to_proto()) + .collect()) + } else { + return Ok(members + .into_iter() + .filter_map(|member| { + if member.kind == proto::channel_member::Kind::Invitee { + return None; + } + Some(ChannelMember { + role: member.role, + user_id: member.user_id, + kind: proto::channel_member::Kind::Member, + }) + }) + .map(|channel_member| channel_member.to_proto()) + .collect()); + } + } + async fn get_channel_participant_details_internal( &self, channel_id: ChannelId, @@ -1003,28 +1096,6 @@ impl Database { .collect()) } - pub async fn get_channel_participant_details( - &self, - channel_id: ChannelId, - admin_id: UserId, - ) -> Result> { - let members = self - .transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, admin_id, &*tx) - .await?; - - Ok(self - .get_channel_participant_details_internal(channel_id, &*tx) - .await?) - }) - .await?; - - Ok(members - .into_iter() - .map(|channel_member| channel_member.to_proto()) - .collect()) - } - pub async fn get_channel_participants( &self, channel_id: ChannelId, @@ -1062,9 +1133,10 @@ impl Database { channel_id: ChannelId, user_id: UserId, tx: &DatabaseTransaction, - ) -> Result<()> { - match self.channel_role_for_user(channel_id, user_id, tx).await? { - Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()), + ) -> Result { + let channel_role = self.channel_role_for_user(channel_id, user_id, tx).await?; + match channel_role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()), Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( "user is not a channel member or channel does not exist" ))?, diff --git a/crates/collab/src/db/queries/contacts.rs b/crates/collab/src/db/queries/contacts.rs index 2171f1a6bf..f31f1addbd 100644 --- a/crates/collab/src/db/queries/contacts.rs +++ b/crates/collab/src/db/queries/contacts.rs @@ -8,7 +8,6 @@ impl Database { user_id_b: UserId, a_to_b: bool, accepted: bool, - should_notify: bool, user_a_busy: bool, user_b_busy: bool, } @@ -53,7 +52,6 @@ impl Database { if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify && db_contact.a_to_b, busy: db_contact.user_b_busy, }); } else if db_contact.a_to_b { @@ -63,19 +61,16 @@ impl Database { } else { contacts.push(Contact::Incoming { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify, }); } } else if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify && !db_contact.a_to_b, busy: db_contact.user_a_busy, }); } else if db_contact.a_to_b { contacts.push(Contact::Incoming { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify, }); } else { contacts.push(Contact::Outgoing { @@ -124,7 +119,11 @@ impl Database { .await } - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + pub async fn send_contact_request( + &self, + sender_id: UserId, + receiver_id: UserId, + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) @@ -161,11 +160,22 @@ impl Database { .exec_without_returning(&*tx) .await?; - if rows_affected == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? + if rows_affected == 0 { + Err(anyhow!("contact already requested"))?; } + + Ok(self + .create_notification( + receiver_id, + rpc::Notification::ContactRequest { + sender_id: sender_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect()) }) .await } @@ -179,7 +189,11 @@ impl Database { /// /// * `requester_id` - The user that initiates this request /// * `responder_id` - The user that will be removed - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result { + pub async fn remove_contact( + &self, + requester_id: UserId, + responder_id: UserId, + ) -> Result<(bool, Option)> { self.transaction(|tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) @@ -198,7 +212,21 @@ impl Database { .ok_or_else(|| anyhow!("no such contact"))?; contact::Entity::delete_by_id(contact.id).exec(&*tx).await?; - Ok(contact.accepted) + + let mut deleted_notification_id = None; + if !contact.accepted { + deleted_notification_id = self + .remove_notification( + responder_id, + rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + &*tx, + ) + .await?; + } + + Ok((contact.accepted, deleted_notification_id)) }) .await } @@ -249,7 +277,7 @@ impl Database { responder_id: UserId, requester_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) @@ -287,11 +315,38 @@ impl Database { result.rows_affected }; - if rows_affected == 1 { - Ok(()) - } else { + if rows_affected == 0 { Err(anyhow!("no such contact request"))? } + + let mut notifications = Vec::new(); + notifications.extend( + self.mark_notification_as_read_with_response( + responder_id, + &rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + accept, + &*tx, + ) + .await?, + ); + + if accept { + notifications.extend( + self.create_notification( + requester_id, + rpc::Notification::ContactRequestAccepted { + responder_id: responder_id.to_proto(), + }, + true, + &*tx, + ) + .await?, + ); + } + + Ok(notifications) }) .await } diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index 4bcac025c5..562c4e45c4 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -1,4 +1,7 @@ use super::*; +use futures::Stream; +use rpc::Notification; +use sea_orm::TryInsertResult; use time::OffsetDateTime; impl Database { @@ -87,43 +90,118 @@ impl Database { condition = condition.add(channel_message::Column::Id.lt(before_message_id)); } - let mut rows = channel_message::Entity::find() + let rows = channel_message::Entity::find() .filter(condition) .order_by_desc(channel_message::Column::Id) .limit(count as u64) .stream(&*tx) .await?; - let mut messages = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; - let nonce = row.nonce.as_u64_pair(); - messages.push(proto::ChannelMessage { - id: row.id.to_proto(), - sender_id: row.sender_id.to_proto(), - body: row.body, - timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, - nonce: Some(proto::Nonce { - upper_half: nonce.0, - lower_half: nonce.1, + self.load_channel_messages(rows, &*tx).await + }) + .await + } + + pub async fn get_channel_messages_by_id( + &self, + user_id: UserId, + message_ids: &[MessageId], + ) -> Result> { + self.transaction(|tx| async move { + let rows = channel_message::Entity::find() + .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) + .order_by_desc(channel_message::Column::Id) + .stream(&*tx) + .await?; + + let mut channel_ids = HashSet::::default(); + let messages = self + .load_channel_messages( + rows.map(|row| { + row.map(|row| { + channel_ids.insert(row.channel_id); + row + }) }), - }); + &*tx, + ) + .await?; + + for channel_id in channel_ids { + self.check_user_is_channel_member(channel_id, user_id, &*tx) + .await?; } - drop(rows); - messages.reverse(); + Ok(messages) }) .await } + async fn load_channel_messages( + &self, + mut rows: impl Send + Unpin + Stream>, + tx: &DatabaseTransaction, + ) -> Result> { + let mut messages = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let nonce = row.nonce.as_u64_pair(); + messages.push(proto::ChannelMessage { + id: row.id.to_proto(), + sender_id: row.sender_id.to_proto(), + body: row.body, + timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, + mentions: vec![], + nonce: Some(proto::Nonce { + upper_half: nonce.0, + lower_half: nonce.1, + }), + }); + } + drop(rows); + messages.reverse(); + + let mut mentions = channel_message_mention::Entity::find() + .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) + .order_by_asc(channel_message_mention::Column::MessageId) + .order_by_asc(channel_message_mention::Column::StartOffset) + .stream(&*tx) + .await?; + + let mut message_ix = 0; + while let Some(mention) = mentions.next().await { + let mention = mention?; + let message_id = mention.message_id.to_proto(); + while let Some(message) = messages.get_mut(message_ix) { + if message.id < message_id { + message_ix += 1; + } else { + if message.id == message_id { + message.mentions.push(proto::ChatMention { + range: Some(proto::Range { + start: mention.start_offset as u64, + end: mention.end_offset as u64, + }), + user_id: mention.user_id.to_proto(), + }); + } + break; + } + } + } + + Ok(messages) + } + pub async fn create_channel_message( &self, channel_id: ChannelId, user_id: UserId, body: &str, + mentions: &[proto::ChatMention], timestamp: OffsetDateTime, nonce: u128, - ) -> Result<(MessageId, Vec, Vec)> { + ) -> Result { self.transaction(|tx| async move { self.check_user_is_channel_participant(channel_id, user_id, &*tx) .await?; @@ -153,7 +231,7 @@ impl Database { let timestamp = timestamp.to_offset(time::UtcOffset::UTC); let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); - let message = channel_message::Entity::insert(channel_message::ActiveModel { + let result = channel_message::Entity::insert(channel_message::ActiveModel { channel_id: ActiveValue::Set(channel_id), sender_id: ActiveValue::Set(user_id), body: ActiveValue::Set(body.to_string()), @@ -162,35 +240,85 @@ impl Database { id: ActiveValue::NotSet, }) .on_conflict( - OnConflict::column(channel_message::Column::Nonce) - .update_column(channel_message::Column::Nonce) - .to_owned(), + OnConflict::columns([ + channel_message::Column::SenderId, + channel_message::Column::Nonce, + ]) + .do_nothing() + .to_owned(), ) + .do_nothing() .exec(&*tx) .await?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryConnectionId { - ConnectionId, - } + let message_id; + let mut notifications = Vec::new(); + match result { + TryInsertResult::Inserted(result) => { + message_id = result.last_insert_id; + let mentioned_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); + let mentions = mentions + .iter() + .filter_map(|mention| { + let range = mention.range.as_ref()?; + if !body.is_char_boundary(range.start as usize) + || !body.is_char_boundary(range.end as usize) + { + return None; + } + Some(channel_message_mention::ActiveModel { + message_id: ActiveValue::Set(message_id), + start_offset: ActiveValue::Set(range.start as i32), + end_offset: ActiveValue::Set(range.end as i32), + user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)), + }) + }) + .collect::>(); + if !mentions.is_empty() { + channel_message_mention::Entity::insert_many(mentions) + .exec(&*tx) + .await?; + } - // Observe this message for the sender - self.observe_channel_message_internal( - channel_id, - user_id, - message.last_insert_id, - &*tx, - ) - .await?; + for mentioned_user in mentioned_user_ids { + notifications.extend( + self.create_notification( + UserId::from_proto(mentioned_user), + rpc::Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: user_id.to_proto(), + channel_id: channel_id.to_proto(), + }, + false, + &*tx, + ) + .await?, + ); + } + + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) + .await?; + } + _ => { + message_id = channel_message::Entity::find() + .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce))) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to insert message"))? + .id; + } + } let mut channel_members = self.get_channel_participants(channel_id, &*tx).await?; channel_members.retain(|member| !participant_user_ids.contains(member)); - Ok(( - message.last_insert_id, + Ok(CreatedChannelMessage { + message_id, participant_connection_ids, channel_members, - )) + notifications, + }) }) .await } @@ -200,11 +328,24 @@ impl Database { channel_id: ChannelId, user_id: UserId, message_id: MessageId, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) .await?; - Ok(()) + let mut batch = NotificationBatch::default(); + batch.extend( + self.mark_notification_as_read( + user_id, + &Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: Default::default(), + channel_id: Default::default(), + }, + &*tx, + ) + .await?, + ); + Ok(batch) }) .await } diff --git a/crates/collab/src/db/queries/notifications.rs b/crates/collab/src/db/queries/notifications.rs new file mode 100644 index 0000000000..6f2511c23e --- /dev/null +++ b/crates/collab/src/db/queries/notifications.rs @@ -0,0 +1,262 @@ +use super::*; +use rpc::Notification; + +impl Database { + pub async fn initialize_notification_kinds(&mut self) -> Result<()> { + notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map( + |kind| notification_kind::ActiveModel { + name: ActiveValue::Set(kind.to_string()), + ..Default::default() + }, + )) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&self.pool) + .await?; + + let mut rows = notification_kind::Entity::find().stream(&self.pool).await?; + while let Some(row) = rows.next().await { + let row = row?; + self.notification_kinds_by_name.insert(row.name, row.id); + } + + for name in Notification::all_variant_names() { + if let Some(id) = self.notification_kinds_by_name.get(*name).copied() { + self.notification_kinds_by_id.insert(id, name); + } + } + + Ok(()) + } + + pub async fn get_notifications( + &self, + recipient_id: UserId, + limit: usize, + before_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let mut result = Vec::new(); + let mut condition = + Condition::all().add(notification::Column::RecipientId.eq(recipient_id)); + + if let Some(before_id) = before_id { + condition = condition.add(notification::Column::Id.lt(before_id)); + } + + let mut rows = notification::Entity::find() + .filter(condition) + .order_by_desc(notification::Column::Id) + .limit(limit as u64) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + let kind = row.kind; + if let Some(proto) = model_to_proto(self, row) { + result.push(proto); + } else { + log::warn!("unknown notification kind {:?}", kind); + } + } + result.reverse(); + Ok(result) + }) + .await + } + + /// Create a notification. If `avoid_duplicates` is set to true, then avoid + /// creating a new notification if the given recipient already has an + /// unread notification with the given kind and entity id. + pub async fn create_notification( + &self, + recipient_id: UserId, + notification: Notification, + avoid_duplicates: bool, + tx: &DatabaseTransaction, + ) -> Result> { + if avoid_duplicates { + if self + .find_notification(recipient_id, ¬ification, tx) + .await? + .is_some() + { + return Ok(None); + } + } + + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + let model = notification::ActiveModel { + recipient_id: ActiveValue::Set(recipient_id), + kind: ActiveValue::Set(kind), + entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)), + content: ActiveValue::Set(proto.content.clone()), + ..Default::default() + } + .save(&*tx) + .await?; + + Ok(Some(( + recipient_id, + proto::Notification { + id: model.id.as_ref().to_proto(), + kind: proto.kind, + timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64, + is_read: false, + response: None, + content: proto.content, + entity_id: proto.entity_id, + }, + ))) + } + + /// Remove an unread notification with the given recipient, kind and + /// entity id. + pub async fn remove_notification( + &self, + recipient_id: UserId, + notification: Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let id = self + .find_notification(recipient_id, ¬ification, tx) + .await?; + if let Some(id) = id { + notification::Entity::delete_by_id(id).exec(tx).await?; + } + Ok(id) + } + + /// Populate the response for the notification with the given kind and + /// entity id. + pub async fn mark_notification_as_read_with_response( + &self, + recipient_id: UserId, + notification: &Notification, + response: bool, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx) + .await + } + + pub async fn mark_notification_as_read( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, None, tx) + .await + } + + pub async fn mark_notification_as_read_by_id( + &self, + recipient_id: UserId, + notification_id: NotificationId, + ) -> Result { + self.transaction(|tx| async move { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(notification_id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(model_to_proto(self, row) + .map(|notification| (recipient_id, notification)) + .into_iter() + .collect()) + }) + .await + } + + async fn mark_notification_as_read_internal( + &self, + recipient_id: UserId, + notification: &Notification, + response: Option, + tx: &DatabaseTransaction, + ) -> Result> { + if let Some(id) = self + .find_notification(recipient_id, notification, &*tx) + .await? + { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + response: if let Some(response) = response { + ActiveValue::Set(Some(response)) + } else { + ActiveValue::NotSet + }, + ..Default::default() + }) + .exec(tx) + .await?; + Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification))) + } else { + Ok(None) + } + } + + /// Find an unread notification by its recipient, kind and entity id. + async fn find_notification( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryIds { + Id, + } + + Ok(notification::Entity::find() + .select_only() + .column(notification::Column::Id) + .filter( + Condition::all() + .add(notification::Column::RecipientId.eq(recipient_id)) + .add(notification::Column::IsRead.eq(false)) + .add(notification::Column::Kind.eq(kind)) + .add(if proto.entity_id.is_some() { + notification::Column::EntityId.eq(proto.entity_id) + } else { + notification::Column::EntityId.is_null() + }), + ) + .into_values::<_, QueryIds>() + .one(&*tx) + .await?) + } +} + +fn model_to_proto(this: &Database, row: notification::Model) -> Option { + let kind = this.notification_kinds_by_id.get(&row.kind)?; + Some(proto::Notification { + id: row.id.to_proto(), + kind: kind.to_string(), + timestamp: row.created_at.assume_utc().unix_timestamp() as u64, + is_read: row.is_read, + response: row.response, + content: row.content, + entity_id: row.entity_id.map(|id| id as u64), + }) +} + +fn notification_kind_from_proto( + this: &Database, + proto: &proto::Notification, +) -> Result { + Ok(this + .notification_kinds_by_name + .get(&proto.kind) + .copied() + .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?) +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index e19391da7d..0acb266d9d 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -7,11 +7,14 @@ pub mod channel_buffer_collaborator; pub mod channel_chat_participant; pub mod channel_member; pub mod channel_message; +pub mod channel_message_mention; pub mod channel_path; pub mod contact; pub mod feature_flag; pub mod follower; pub mod language_server; +pub mod notification; +pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; pub mod project; diff --git a/crates/collab/src/db/tables/channel_message_mention.rs b/crates/collab/src/db/tables/channel_message_mention.rs new file mode 100644 index 0000000000..6155b057f0 --- /dev/null +++ b/crates/collab/src/db/tables/channel_message_mention.rs @@ -0,0 +1,43 @@ +use crate::db::{MessageId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_message_mentions")] +pub struct Model { + #[sea_orm(primary_key)] + pub message_id: MessageId, + #[sea_orm(primary_key)] + pub start_offset: i32, + pub end_offset: i32, + pub user_id: UserId, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel_message::Entity", + from = "Column::MessageId", + to = "super::channel_message::Column::Id" + )] + Message, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + MentionedUser, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Message.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::MentionedUser.def() + } +} diff --git a/crates/collab/src/db/tables/notification.rs b/crates/collab/src/db/tables/notification.rs new file mode 100644 index 0000000000..3105198fa2 --- /dev/null +++ b/crates/collab/src/db/tables/notification.rs @@ -0,0 +1,29 @@ +use crate::db::{NotificationId, NotificationKindId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationId, + pub created_at: PrimitiveDateTime, + pub recipient_id: UserId, + pub kind: NotificationKindId, + pub entity_id: Option, + pub content: String, + pub is_read: bool, + pub response: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::RecipientId", + to = "super::user::Column::Id" + )] + Recipient, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/notification_kind.rs b/crates/collab/src/db/tables/notification_kind.rs new file mode 100644 index 0000000000..865b5da04b --- /dev/null +++ b/crates/collab/src/db/tables/notification_kind.rs @@ -0,0 +1,15 @@ +use crate::db::NotificationKindId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notification_kinds")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationKindId, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index a53509a59f..8f2939040d 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -10,7 +10,10 @@ use parking_lot::Mutex; use rpc::proto::ChannelEdge; use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, + Arc, +}; const TEST_RELEASE_CHANNEL: &'static str = "test"; @@ -31,7 +34,7 @@ impl TestDb { let mut db = runtime.block_on(async { let mut options = ConnectOptions::new(url); options.max_connections(5); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let sql = include_str!(concat!( @@ -45,6 +48,7 @@ impl TestDb { )) .await .unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -79,11 +83,12 @@ impl TestDb { options .max_connections(5) .idle_timeout(Duration::from_secs(0)); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -176,3 +181,27 @@ fn graph( graph } + +static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); + +async fn new_test_user(db: &Arc, email: &str) -> UserId { + db.create_user( + email, + false, + NewUserParams { + github_login: email[0..email.find("@").unwrap()].to_string(), + github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst), + }, + ) + .await + .unwrap() + .user_id +} + +static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1); +fn new_test_connection(server: ServerId) -> ConnectionId { + ConnectionId { + id: TEST_CONNECTION_ID.fetch_add(1, SeqCst), + owner_id: server.0 as u32, + } +} diff --git a/crates/collab/src/db/tests/buffer_tests.rs b/crates/collab/src/db/tests/buffer_tests.rs index 51ba9bf655..222514da0b 100644 --- a/crates/collab/src/db/tests/buffer_tests.rs +++ b/crates/collab/src/db/tests/buffer_tests.rs @@ -17,7 +17,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -30,7 +29,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -45,7 +43,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_c".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -178,7 +175,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -191,7 +187,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 405839ba18..32cba2fdd9 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -1,20 +1,17 @@ -use collections::{HashMap, HashSet}; -use rpc::{ - proto::{self}, - ConnectionId, -}; +use std::sync::Arc; use crate::{ db::{ queries::channels::ChannelGraph, - tests::{graph, TEST_RELEASE_CHANNEL}, - ChannelId, ChannelRole, Database, NewUserParams, RoomId, ServerId, UserId, + tests::{graph, new_test_connection, new_test_user, TEST_RELEASE_CHANNEL}, + ChannelId, ChannelRole, Database, NewUserParams, RoomId, }, test_both_dbs, }; -use std::sync::{ - atomic::{AtomicI32, AtomicU32, Ordering}, - Arc, +use collections::{HashMap, HashSet}; +use rpc::{ + proto::{self}, + ConnectionId, }; test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); @@ -305,7 +302,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -319,7 +315,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user2".into(), github_user_id: 6, - invite_count: 0, }, ) .await @@ -360,7 +355,6 @@ async fn test_db_channel_moving(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -727,7 +721,6 @@ async fn test_db_channel_moving_bugs(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -1122,28 +1115,3 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option)]) pretty_assertions::assert_eq!(actual_map, expected_map) } - -static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); - -async fn new_test_user(db: &Arc, email: &str) -> UserId { - db.create_user( - email, - false, - NewUserParams { - github_login: email[0..email.find("@").unwrap()].to_string(), - github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst), - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id -} - -static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1); -fn new_test_connection(server: ServerId) -> ConnectionId { - ConnectionId { - id: TEST_CONNECTION_ID.fetch_add(1, Ordering::SeqCst), - owner_id: server.0 as u32, - } -} diff --git a/crates/collab/src/db/tests/db_tests.rs b/crates/collab/src/db/tests/db_tests.rs index 1520e081c0..c4b82f8cec 100644 --- a/crates/collab/src/db/tests/db_tests.rs +++ b/crates/collab/src/db/tests/db_tests.rs @@ -22,7 +22,6 @@ async fn test_get_users(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -88,7 +87,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login1".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -101,7 +99,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login2".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -156,7 +153,6 @@ async fn test_create_access_tokens(db: &Arc) { NewUserParams { github_login: "u1".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -238,7 +234,6 @@ async fn test_add_contacts(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -264,10 +259,7 @@ async fn test_add_contacts(db: &Arc) { ); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] + &[Contact::Incoming { user_id: user_1 }] ); // User 2 dismisses the contact request notification without accepting or rejecting. @@ -280,10 +272,7 @@ async fn test_add_contacts(db: &Arc) { .unwrap(); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] + &[Contact::Incoming { user_id: user_1 }] ); // User can't accept their own contact request @@ -299,7 +288,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }], ); @@ -309,7 +297,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -326,7 +313,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }] ); @@ -339,7 +325,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }] ); @@ -353,12 +338,10 @@ async fn test_add_contacts(db: &Arc) { &[ Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }, Contact::Accepted { user_id: user_3, - should_notify: false, busy: false, } ] @@ -367,7 +350,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -383,7 +365,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -391,7 +372,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -415,7 +395,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person1".into(), github_user_id: 101, - invite_count: 5, }, ) .await @@ -431,7 +410,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person2".into(), github_user_id: 102, - invite_count: 5, }, ) .await @@ -460,7 +438,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -472,7 +449,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -554,7 +530,6 @@ async fn test_fuzzy_search_users() { NewUserParams { github_login: github_login.into(), github_user_id: i as i32, - invite_count: 0, }, ) .await @@ -596,7 +571,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -608,7 +582,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/feature_flag_tests.rs b/crates/collab/src/db/tests/feature_flag_tests.rs index 9d5f039747..0286a6308e 100644 --- a/crates/collab/src/db/tests/feature_flag_tests.rs +++ b/crates/collab/src/db/tests/feature_flag_tests.rs @@ -18,7 +18,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user1"), github_user_id: 1, - invite_count: 0, }, ) .await @@ -32,7 +31,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user2"), github_user_id: 2, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index f9d97e2181..10d9778612 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -1,7 +1,9 @@ +use super::new_test_user; use crate::{ - db::{ChannelRole, Database, MessageId, NewUserParams}, + db::{ChannelRole, Database, MessageId}, test_both_dbs, }; +use channel::mentions_to_proto; use std::sync::Arc; use time::OffsetDateTime; @@ -12,39 +14,38 @@ test_both_dbs!( ); async fn test_channel_message_retrieval(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let channel = db.create_root_channel("channel", user).await.unwrap(); + let user = new_test_user(db, "user@example.com").await; + let result = db.create_channel("channel", None, user).await.unwrap(); let owner_id = db.create_server("test").await.unwrap().0 as u32; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); + db.join_channel_chat( + result.channel.id, + rpc::ConnectionId { owner_id, id: 0 }, + user, + ) + .await + .unwrap(); let mut all_messages = Vec::new(); for i in 0..10 { all_messages.push( - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) - .await - .unwrap() - .0 - .to_proto(), + db.create_channel_message( + result.channel.id, + user, + &i.to_string(), + &[], + OffsetDateTime::now_utc(), + i, + ) + .await + .unwrap() + .message_id + .to_proto(), ); } let messages = db - .get_channel_messages(channel, user, 3, None) + .get_channel_messages(result.channel.id, user, 3, None) .await .unwrap() .into_iter() @@ -54,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc) { let messages = db .get_channel_messages( - channel, + result.channel.id, user, 4, Some(MessageId::from_proto(all_messages[6])), @@ -74,99 +75,154 @@ test_both_dbs!( ); async fn test_channel_message_nonces(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + let channel = db.create_root_channel("channel", user_a).await.unwrap(); + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_c, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a) + .await + .unwrap(); + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b) + .await + .unwrap(); + + // As user A, create messages that re-use the same nonces. The requests + // succeed, but return the same ids. + let id1 = db + .create_channel_message( + channel, + user_a, + "hi @user_b", + &mentions_to_proto(&[(3..10, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 100, ) .await .unwrap() - .user_id; - let channel = db.create_root_channel("channel", user).await.unwrap(); + .message_id; + let id2 = db + .create_channel_message( + channel, + user_a, + "hello, fellow users", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; + let id3 = db + .create_channel_message( + channel, + user_a, + "bye @user_c (same nonce as first message)", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + let id4 = db + .create_channel_message( + channel, + user_a, + "omg (same nonce as second message)", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; - let owner_id = db.create_server("test").await.unwrap().0 as u32; + // As a different user, reuse one of the same nonces. This request succeeds + // and returns a different id. + let id5 = db + .create_channel_message( + channel, + user_b, + "omg @user_a (same nonce as user_a's first message)", + &mentions_to_proto(&[(4..11, user_a.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) - .await - .unwrap(); + assert_ne!(id1, id2); + assert_eq!(id1, id3); + assert_eq!(id2, id4); + assert_ne!(id5, id1); - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + let messages = db + .get_channel_messages(channel, user_a, 5, None) .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); + .unwrap() + .into_iter() + .map(|m| (m.id, m.body, m.mentions)) + .collect::>(); + assert_eq!( + messages, + &[ + ( + id1.to_proto(), + "hi @user_b".into(), + mentions_to_proto(&[(3..10, user_b.to_proto())]), + ), + ( + id2.to_proto(), + "hello, fellow users".into(), + mentions_to_proto(&[]) + ), + ( + id5.to_proto(), + "omg @user_a (same nonce as user_a's first message)".into(), + mentions_to_proto(&[(4..11, user_a.to_proto())]), + ), + ] + ); } test_both_dbs!( - test_channel_message_new_notification, - test_channel_message_new_notification_postgres, - test_channel_message_new_notification_sqlite + test_unseen_channel_messages, + test_unseen_channel_messages_postgres, + test_unseen_channel_messages_sqlite ); -async fn test_channel_message_new_notification(db: &Arc) { - let user = db - .create_user( - "user_a@example.com", - false, - NewUserParams { - github_login: "user_a".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let observer = db - .create_user( - "user_b@example.com", - false, - NewUserParams { - github_login: "user_b".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; +async fn test_unseen_channel_messages(db: &Arc) { + let user = new_test_user(db, "user_a@example.com").await; + let observer = new_test_user(db, "user_b@example.com").await; let channel_1 = db.create_root_channel("channel", user).await.unwrap(); - let channel_2 = db.create_root_channel("channel-2", user).await.unwrap(); db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) .await .unwrap(); + db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) + .await + .unwrap(); db.respond_to_channel_invite(channel_1, observer, true) .await .unwrap(); - - db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) - .await - .unwrap(); - db.respond_to_channel_invite(channel_2, observer, true) .await .unwrap(); @@ -179,28 +235,31 @@ async fn test_channel_message_new_notification(db: &Arc) { .unwrap(); let _ = db - .create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1) + .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1) .await .unwrap(); - let (second_message, _, _) = db - .create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2) + let second_message = db + .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2) .await - .unwrap(); + .unwrap() + .message_id; - let (third_message, _, _) = db - .create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3) + let third_message = db + .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3) .await - .unwrap(); + .unwrap() + .message_id; db.join_channel_chat(channel_2, user_connection_id, user) .await .unwrap(); - let (fourth_message, _, _) = db - .create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4) + let fourth_message = db + .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4) .await - .unwrap(); + .unwrap() + .message_id; // Check that observer has new messages let unseen_messages = db @@ -295,3 +354,101 @@ async fn test_channel_message_new_notification(db: &Arc) { }] ); } + +test_both_dbs!( + test_channel_message_mentions, + test_channel_message_mentions_postgres, + test_channel_message_mentions_sqlite +); + +async fn test_channel_message_mentions(db: &Arc) { + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + + let channel = db + .create_channel("channel", None, user_a) + .await + .unwrap() + .channel + .id; + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + let connection_id = rpc::ConnectionId { owner_id, id: 0 }; + db.join_channel_chat(channel, connection_id, user_a) + .await + .unwrap(); + + db.create_channel_message( + channel, + user_a, + "hi @user_b and @user_c", + &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 1, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "bye @user_c", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 2, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "umm", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 3, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "@user_b, stop.", + &mentions_to_proto(&[(0..7, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 4, + ) + .await + .unwrap(); + + let messages = db + .get_channel_messages(channel, user_b, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.body, m.mentions)) + .collect::>(); + assert_eq!( + &messages, + &[ + ( + "hi @user_b and @user_c".into(), + mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + ), + ( + "bye @user_c".into(), + mentions_to_proto(&[(4..11, user_c.to_proto())]), + ), + ("umm".into(), mentions_to_proto(&[]),), + ( + "@user_b, stop.".into(), + mentions_to_proto(&[(0..7, user_b.to_proto())]), + ), + ] + ); +} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 13fb8ed0eb..85216525b0 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -119,7 +119,9 @@ impl AppState { pub async fn new(config: Config) -> Result> { let mut db_options = db::ConnectOptions::new(config.database_url.clone()); db_options.max_connections(config.database_max_connections); - let db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options, Executor::Production).await?; + db.initialize_notification_kinds().await?; + let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9974ded821..2d89ec47a2 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,9 +3,11 @@ mod connection_pool; use crate::{ auth, db::{ - self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, Database, - MembershipUpdated, MessageId, MoveChannelResult, ProjectId, RenameChannelResult, RoomId, - ServerId, SetChannelVisibilityResult, User, UserId, + self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, + CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, + MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult, + RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult, + User, UserId, }, executor::Executor, AppState, Result, @@ -71,6 +73,7 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; +const NOTIFICATION_COUNT_PER_PAGE: usize = 50; lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = @@ -271,6 +274,9 @@ impl Server { .add_request_handler(send_channel_message) .add_request_handler(remove_channel_message) .add_request_handler(get_channel_messages) + .add_request_handler(get_channel_messages_by_id) + .add_request_handler(get_notifications) + .add_request_handler(mark_notification_as_read) .add_request_handler(link_channel) .add_request_handler(unlink_channel) .add_request_handler(move_channel) @@ -390,7 +396,7 @@ impl Server { let contacts = app_state.db.get_contacts(user_id).await.trace_err(); if let Some((busy, contacts)) = busy.zip(contacts) { let pool = pool.lock(); - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, @@ -584,7 +590,7 @@ impl Server { let (contacts, channels_for_user, channel_invites) = future::try_join3( this.app_state.db.get_contacts(user_id), this.app_state.db.get_channels_for_user(user_id), - this.app_state.db.get_channel_invites_for_user(user_id) + this.app_state.db.get_channel_invites_for_user(user_id), ).await?; { @@ -690,7 +696,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { let pool = self.connection_pool.lock(); - let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + let invitee_contact = contact_for_user(invitee_id, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, @@ -2066,7 +2072,7 @@ async fn request_contact( return Err(anyhow!("cannot add yourself as a contact"))?; } - session + let notifications = session .db() .await .send_contact_request(requester_id, responder_id) @@ -2089,16 +2095,14 @@ async fn request_contact( .incoming_requests .push(proto::IncomingContactRequest { requester_id: requester_id.to_proto(), - should_notify: true, }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(responder_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2117,7 +2121,8 @@ async fn respond_to_contact_request( } else { let accept = request.response == proto::ContactRequestResponse::Accept as i32; - db.respond_to_contact_request(responder_id, requester_id, accept) + let notifications = db + .respond_to_contact_request(responder_id, requester_id, accept) .await?; let requester_busy = db.is_user_busy(requester_id).await?; let responder_busy = db.is_user_busy(responder_id).await?; @@ -2128,7 +2133,7 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(requester_id, false, requester_busy, &pool)); + .push(contact_for_user(requester_id, requester_busy, &pool)); } update .remove_incoming_requests @@ -2142,14 +2147,17 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(responder_id, true, responder_busy, &pool)); + .push(contact_for_user(responder_id, responder_busy, &pool)); } update .remove_outgoing_requests .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { session.peer.send(connection_id, update.clone())?; } + + send_notifications(&*pool, &session.peer, notifications); } response.send(proto::Ack {})?; @@ -2164,7 +2172,8 @@ async fn remove_contact( let requester_id = session.user_id; let responder_id = UserId::from_proto(request.user_id); let db = session.db().await; - let contact_accepted = db.remove_contact(requester_id, responder_id).await?; + let (contact_accepted, deleted_notification_id) = + db.remove_contact(requester_id, responder_id).await?; let pool = session.connection_pool().await; // Update outgoing contact requests of requester @@ -2191,6 +2200,14 @@ async fn remove_contact( } for connection_id in pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; + if let Some(notification_id) = deleted_notification_id { + session.peer.send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + )?; + } } response.send(proto::Ack {})?; @@ -2268,7 +2285,10 @@ async fn invite_channel_member( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let invitee_id = UserId::from_proto(request.user_id); - let channel = db + let InviteMemberResult { + channel, + notifications, + } = db .invite_channel_member( channel_id, invitee_id, @@ -2282,14 +2302,13 @@ async fn invite_channel_member( ..Default::default() }; - for connection_id in session - .connection_pool() - .await - .user_connection_ids(invitee_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(invitee_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2303,13 +2322,33 @@ async fn remove_channel_member( let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - let membership_updated = db + let RemoveChannelMemberResult { + membership_update, + notification_id, + } = db .remove_channel_member(channel_id, member_id, session.user_id) .await?; - dbg!(&membership_updated); - - notify_membership_updated(membership_updated, member_id, &session).await?; + let connection_pool = &session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ); + for connection_id in connection_pool.user_connection_ids(member_id) { + if let Some(notification_id) = notification_id { + session + .peer + .send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + ) + .trace_err(); + } + } response.send(proto::Ack {})?; Ok(()) @@ -2374,7 +2413,13 @@ async fn set_channel_member_role( match result { db::SetMemberRoleResult::MembershipUpdated(membership_update) => { - notify_membership_updated(membership_update, member_id, &session).await?; + let connection_pool = session.connection_pool().await; + notify_membership_updated( + &connection_pool, + membership_update, + member_id, + &session.peer, + ) } db::SetMemberRoleResult::InviteUpdated(channel) => { let update = proto::UpdateChannels { @@ -2535,24 +2580,34 @@ async fn respond_to_channel_invite( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - let result = db + let RespondToChannelInvite { + membership_update, + notifications, + } = db .respond_to_channel_invite(channel_id, session.user_id, request.accept) .await?; - if let Some(accept_invite_result) = result { - notify_membership_updated(accept_invite_result, session.user_id, &session).await?; + let connection_pool = session.connection_pool().await; + if let Some(membership_update) = membership_update { + notify_membership_updated( + &connection_pool, + membership_update, + session.user_id, + &session.peer, + ); } else { let update = proto::UpdateChannels { remove_channel_invitations: vec![channel_id.to_proto()], ..Default::default() }; - let connection_pool = session.connection_pool().await; for connection_id in connection_pool.user_connection_ids(session.user_id) { session.peer.send(connection_id, update.clone())?; } }; + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) @@ -2635,8 +2690,14 @@ async fn join_channel_internal( live_kit_connection_info, })?; + let connection_pool = session.connection_pool().await; if let Some(accept_invite_result) = accept_invite_result { - notify_membership_updated(accept_invite_result, session.user_id, &session).await?; + notify_membership_updated( + &connection_pool, + accept_invite_result, + session.user_id, + &session.peer, + ); } room_updated(&joined_room.room, &session.peer); @@ -2805,6 +2866,29 @@ fn channel_buffer_updated( }); } +fn send_notifications( + connection_pool: &ConnectionPool, + peer: &Peer, + notifications: db::NotificationBatch, +) { + for (user_id, notification) in notifications { + for connection_id in connection_pool.user_connection_ids(user_id) { + if let Err(error) = peer.send( + connection_id, + proto::AddNotification { + notification: Some(notification.clone()), + }, + ) { + tracing::error!( + "failed to send notification to {:?} {}", + connection_id, + error + ); + } + } + } +} + async fn send_channel_message( request: proto::SendChannelMessage, response: Response, @@ -2819,19 +2903,27 @@ async fn send_channel_message( return Err(anyhow!("message can't be blank"))?; } + // TODO: adjust mentions if body is trimmed + let timestamp = OffsetDateTime::now_utc(); let nonce = request .nonce .ok_or_else(|| anyhow!("nonce can't be blank"))?; let channel_id = ChannelId::from_proto(request.channel_id); - let (message_id, connection_ids, non_participants) = session + let CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + } = session .db() .await .create_channel_message( channel_id, session.user_id, &body, + &request.mentions, timestamp, nonce.clone().into(), ) @@ -2840,18 +2932,23 @@ async fn send_channel_message( sender_id: session.user_id.to_proto(), id: message_id.to_proto(), body, + mentions: request.mentions, timestamp: timestamp.unix_timestamp() as u64, nonce: Some(nonce), }; - broadcast(Some(session.connection_id), connection_ids, |connection| { - session.peer.send( - connection, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); + broadcast( + Some(session.connection_id), + participant_connection_ids, + |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }, + ); response.send(proto::SendChannelMessageResponse { message: Some(message), })?; @@ -2859,7 +2956,7 @@ async fn send_channel_message( let pool = &*session.connection_pool().await; broadcast( None, - non_participants + channel_members .iter() .flat_map(|user_id| pool.user_connection_ids(*user_id)), |peer_id| { @@ -2875,6 +2972,7 @@ async fn send_channel_message( ) }, ); + send_notifications(pool, &session.peer, notifications); Ok(()) } @@ -2904,11 +3002,16 @@ async fn acknowledge_channel_message( ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); - session + let notifications = session .db() .await .observe_channel_message(channel_id, session.user_id, message_id) .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); Ok(()) } @@ -2983,6 +3086,72 @@ async fn get_channel_messages( Ok(()) } +async fn get_channel_messages_by_id( + request: proto::GetChannelMessagesById, + response: Response, + session: Session, +) -> Result<()> { + let message_ids = request + .message_ids + .iter() + .map(|id| MessageId::from_proto(*id)) + .collect::>(); + let messages = session + .db() + .await + .get_channel_messages_by_id(session.user_id, &message_ids) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn get_notifications( + request: proto::GetNotifications, + response: Response, + session: Session, +) -> Result<()> { + let notifications = session + .db() + .await + .get_notifications( + session.user_id, + NOTIFICATION_COUNT_PER_PAGE, + request + .before_id + .map(|id| db::NotificationId::from_proto(id)), + ) + .await?; + response.send(proto::GetNotificationsResponse { + done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE, + notifications, + })?; + Ok(()) +} + +async fn mark_notification_as_read( + request: proto::MarkNotificationRead, + response: Response, + session: Session, +) -> Result<()> { + let database = &session.db().await; + let notifications = database + .mark_notification_as_read_by_id( + session.user_id, + NotificationId::from_proto(request.notification_id), + ) + .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); + response.send(proto::Ack {})?; + Ok(()) +} + async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = session @@ -3052,11 +3221,12 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } } -async fn notify_membership_updated( +fn notify_membership_updated( + connection_pool: &ConnectionPool, result: MembershipUpdated, user_id: UserId, - session: &Session, -) -> Result<()> { + peer: &Peer, +) { let mut update = build_channels_update(result.new_channels, vec![]); update.delete_channels = result .removed_channels @@ -3065,11 +3235,9 @@ async fn notify_membership_updated( .collect(); update.remove_channel_invitations = vec![result.channel_id.to_proto()]; - let connection_pool = session.connection_pool().await; for connection_id in connection_pool.user_connection_ids(user_id) { - session.peer.send(connection_id, update.clone())?; + peer.send(connection_id, update.clone()).trace_err(); } - Ok(()) } fn build_channels_update( @@ -3120,42 +3288,28 @@ fn build_initial_contacts_update( for contact in contacts { match contact { - db::Contact::Accepted { - user_id, - should_notify, - busy, - } => { - update - .contacts - .push(contact_for_user(user_id, should_notify, busy, &pool)); + db::Contact::Accepted { user_id, busy } => { + update.contacts.push(contact_for_user(user_id, busy, &pool)); } db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), + db::Contact::Incoming { user_id } => { + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + }) + } } } update } -fn contact_for_user( - user_id: UserId, - should_notify: bool, - busy: bool, - pool: &ConnectionPool, -) -> proto::Contact { +fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact { proto::Contact { user_id: user_id.to_proto(), online: pool.is_user_online(user_id), busy, - should_notify, } } @@ -3216,7 +3370,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> let busy = db.is_user_busy(user_id).await?; let pool = session.connection_pool().await; - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index e78bbe3466..139910e1f6 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -6,6 +6,7 @@ mod channel_message_tests; mod channel_tests; mod following_tests; mod integration_tests; +mod notification_tests; mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs index 0fc3b085ed..918eb053d3 100644 --- a/crates/collab/src/tests/channel_message_tests.rs +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -1,27 +1,30 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; -use channel::{ChannelChat, ChannelMessageId}; +use channel::{ChannelChat, ChannelMessageId, MessageParams}; use collab_ui::chat_panel::ChatPanel; use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext}; +use rpc::Notification; use std::sync::Arc; use workspace::dock::Panel; #[gpui::test] async fn test_basic_channel_messages( deterministic: Arc, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, + mut cx_a: &mut TestAppContext, + mut cx_b: &mut TestAppContext, + mut cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; let channel_id = server .make_channel( "the-channel", None, (&client_a, cx_a), - &mut [(&client_b, cx_b)], + &mut [(&client_b, cx_b), (&client_c, cx_c)], ) .await; @@ -36,8 +39,17 @@ async fn test_basic_channel_messages( .await .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + let message_id = channel_chat_a + .update(cx_a, |c, cx| { + c.send_message( + MessageParams { + text: "hi @user_c!".into(), + mentions: vec![(3..10, client_c.id())], + }, + cx, + ) + .unwrap() + }) .await .unwrap(); channel_chat_a @@ -52,15 +64,55 @@ async fn test_basic_channel_messages( .unwrap(); deterministic.run_until_parked(); - channel_chat_a.update(cx_a, |c, _| { + + let channel_chat_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + for (chat, cx) in [ + (&channel_chat_a, &mut cx_a), + (&channel_chat_b, &mut cx_b), + (&channel_chat_c, &mut cx_c), + ] { + chat.update(*cx, |c, _| { + assert_eq!( + c.messages() + .iter() + .map(|m| (m.body.as_str(), m.mentions.as_slice())) + .collect::>(), + vec![ + ("hi @user_c!", [(3..10, client_c.id())].as_slice()), + ("two", &[]), + ("three", &[]) + ], + "results for user {}", + c.client().id(), + ); + }); + } + + client_c.notification_store().update(cx_c, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); assert_eq!( - c.messages() - .iter() - .map(|m| m.body.as_str()) - .collect::>(), - vec!["one", "two", "three"] + store.notification_at(0).unwrap().notification, + Notification::ChannelMessageMention { + message_id, + sender_id: client_a.id(), + channel_id, + } ); - }) + assert_eq!( + store.notification_at(1).unwrap().notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + }); } #[gpui::test] @@ -280,7 +332,7 @@ async fn test_channel_message_changes( chat_panel_b .update(cx_b, |chat_panel, cx| { chat_panel.set_active(true, cx); - chat_panel.select_channel(channel_id, cx) + chat_panel.select_channel(channel_id, None, cx) }) .await .unwrap(); diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 3df9597777..63cc78b651 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -126,8 +126,8 @@ async fn test_core_channels( // Client B accepts the invitation. client_b .channel_store() - .update(cx_b, |channels, _| { - channels.respond_to_channel_invite(channel_a_id, true) + .update(cx_b, |channels, cx| { + channels.respond_to_channel_invite(channel_a_id, true, cx) }) .await .unwrap(); @@ -153,7 +153,6 @@ async fn test_core_channels( }, ], ); - dbg!("-------"); let channel_c_id = client_a .channel_store() @@ -289,11 +288,17 @@ async fn test_core_channels( // Client B no longer has access to the channel assert_channels(client_b.channel_store(), cx_b, &[]); - // When disconnected, client A sees no channels. server.forbid_connections(); server.disconnect_client(client_a.peer_id().unwrap()); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - assert_channels(client_a.channel_store(), cx_a, &[]); + + client_b + .channel_store() + .update(cx_b, |channel_store, cx| { + channel_store.rename(channel_a_id, "channel-a-renamed", cx) + }) + .await + .unwrap(); server.allow_connections(); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); @@ -302,7 +307,7 @@ async fn test_core_channels( cx_a, &[ExpectedChannel { id: channel_a_id, - name: "channel-a".to_string(), + name: "channel-a-renamed".to_string(), depth: 0, role: ChannelRole::Admin, }], @@ -886,8 +891,8 @@ async fn test_lost_channel_creation( // Client B accepts the invite client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(channel_id, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); @@ -951,16 +956,16 @@ async fn test_channel_link_notifications( client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(zed_channel, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) }) .await .unwrap(); client_c .channel_store() - .update(cx_c, |channel_store, _| { - channel_store.respond_to_channel_invite(zed_channel, true) + .update(cx_c, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) }) .await .unwrap(); @@ -1162,16 +1167,16 @@ async fn test_channel_membership_notifications( client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(zed_channel, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(zed_channel, true, cx) }) .await .unwrap(); client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(vim_channel, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(vim_channel, true, cx) }) .await .unwrap(); diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index f3857e3db3..a28f2ae87f 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -1,6 +1,6 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; -use collab_ui::project_shared_notification::ProjectSharedNotification; +use collab_ui::notifications::project_shared_notification::ProjectSharedNotification; use editor::{Editor, ExcerptRange, MultiBuffer}; use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle}; use live_kit_client::MacOSDisplay; diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index d6d449fd47..8396e8947f 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -15,8 +15,8 @@ use gpui::{executor::Deterministic, test::EmptyView, AppContext, ModelHandle, Te use indoc::indoc; use language::{ language_settings::{AllLanguageSettings, Formatter, InlayHintSettings}, - tree_sitter_rust, Anchor, BundledFormatter, Diagnostic, DiagnosticEntry, FakeLspAdapter, - Language, LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, + tree_sitter_rust, Anchor, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, + LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, }; use live_kit_client::MacOSDisplay; use lsp::LanguageServerId; @@ -4530,6 +4530,7 @@ async fn test_prettier_formatting_buffer( LanguageConfig { name: "Rust".into(), path_suffixes: vec!["rs".to_string()], + prettier_parser_name: Some("test_parser".to_string()), ..Default::default() }, Some(tree_sitter_rust::language()), @@ -4537,10 +4538,7 @@ async fn test_prettier_formatting_buffer( let test_plugin = "test_plugin"; let mut fake_language_servers = language .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { - enabled_formatters: vec![BundledFormatter::Prettier { - parser_name: Some("test_parser"), - plugin_names: vec![test_plugin], - }], + prettier_plugins: vec![test_plugin], ..Default::default() })) .await; diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs new file mode 100644 index 0000000000..1114470449 --- /dev/null +++ b/crates/collab/src/tests/notification_tests.rs @@ -0,0 +1,159 @@ +use crate::tests::TestServer; +use gpui::{executor::Deterministic, TestAppContext}; +use notifications::NotificationEvent; +use parking_lot::Mutex; +use rpc::{proto, Notification}; +use std::sync::Arc; + +#[gpui::test] +async fn test_notifications( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let notification_events_a = Arc::new(Mutex::new(Vec::new())); + let notification_events_b = Arc::new(Mutex::new(Vec::new())); + client_a.notification_store().update(cx_a, |_, cx| { + let events = notification_events_a.clone(); + cx.subscribe(&cx.handle(), move |_, _, event, _| { + events.lock().push(event.clone()); + }) + .detach() + }); + client_b.notification_store().update(cx_b, |_, cx| { + let events = notification_events_b.clone(); + cx.subscribe(&cx.handle(), move |_, _, event, _| { + events.lock().push(event.clone()); + }) + .detach() + }); + + // Client A sends a contact request to client B. + client_a + .user_store() + .update(cx_a, |store, cx| store.request_contact(client_b.id(), cx)) + .await + .unwrap(); + + // Client B receives a contact request notification and responds to the + // request, accepting it. + deterministic.run_until_parked(); + client_b.notification_store().update(cx_b, |store, cx| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ContactRequest { + sender_id: client_a.id() + } + ); + assert!(!entry.is_read); + assert_eq!( + ¬ification_events_b.lock()[0..], + &[ + NotificationEvent::NewNotification { + entry: entry.clone(), + }, + NotificationEvent::NotificationsUpdated { + old_range: 0..0, + new_count: 1 + } + ] + ); + + store.respond_to_notification(entry.notification.clone(), true, cx); + }); + + // Client B sees the notification is now read, and that they responded. + deterministic.run_until_parked(); + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 0); + + let entry = store.notification_at(0).unwrap(); + assert!(entry.is_read); + assert_eq!(entry.response, Some(true)); + assert_eq!( + ¬ification_events_b.lock()[2..], + &[ + NotificationEvent::NotificationRead { + entry: entry.clone(), + }, + NotificationEvent::NotificationsUpdated { + old_range: 0..1, + new_count: 1 + } + ] + ); + }); + + // Client A receives a notification that client B accepted their request. + client_a.notification_store().read_with(cx_a, |store, _| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ContactRequestAccepted { + responder_id: client_b.id() + } + ); + assert!(!entry.is_read); + }); + + // Client A creates a channel and invites client B to be a member. + let channel_id = client_a + .channel_store() + .update(cx_a, |store, cx| { + store.create_channel("the-channel", None, cx) + }) + .await + .unwrap(); + client_a + .channel_store() + .update(cx_a, |store, cx| { + store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx) + }) + .await + .unwrap(); + + // Client B receives a channel invitation notification and responds to the + // invitation, accepting it. + deterministic.run_until_parked(); + client_b.notification_store().update(cx_b, |store, cx| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + assert!(!entry.is_read); + + store.respond_to_notification(entry.notification.clone(), true, cx); + }); + + // Client B sees the notification is now read, and that they responded. + deterministic.run_until_parked(); + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 0); + + let entry = store.notification_at(0).unwrap(); + assert!(entry.is_read); + assert_eq!(entry.response, Some(true)); + }); +} diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs index 39598bdaf9..1cec945282 100644 --- a/crates/collab/src/tests/randomized_test_helpers.rs +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -208,8 +208,7 @@ impl TestPlan { false, NewUserParams { github_login: username.clone(), - github_user_id: (ix + 1) as i32, - invite_count: 0, + github_user_id: ix as i32, }, ) .await diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 81d86943fc..d6ebe1e84e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -16,6 +16,7 @@ use futures::{channel::oneshot, StreamExt as _}; use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; use language::LanguageRegistry; use node_runtime::FakeNodeRuntime; +use notifications::NotificationStore; use parking_lot::Mutex; use project::{Project, WorktreeId}; use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT}; @@ -46,6 +47,7 @@ pub struct TestClient { pub username: String, pub app_state: Arc, channel_store: ModelHandle, + notification_store: ModelHandle, state: RefCell, } @@ -138,7 +140,6 @@ impl TestServer { NewUserParams { github_login: name.into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -231,7 +232,8 @@ impl TestServer { workspace::init(app_state.clone(), cx); audio::init((), cx); call::init(client.clone(), user_store.clone(), cx); - channel::init(&client, user_store, cx); + channel::init(&client, user_store.clone(), cx); + notifications::init(client.clone(), user_store, cx); }); client @@ -243,6 +245,7 @@ impl TestServer { app_state, username: name.to_string(), channel_store: cx.read(ChannelStore::global).clone(), + notification_store: cx.read(NotificationStore::global).clone(), state: Default::default(), }; client.wait_for_current_user(cx).await; @@ -338,8 +341,8 @@ impl TestServer { member_cx .read(ChannelStore::global) - .update(*member_cx, |channels, _| { - channels.respond_to_channel_invite(channel_id, true) + .update(*member_cx, |channels, cx| { + channels.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); @@ -448,6 +451,10 @@ impl TestClient { &self.channel_store } + pub fn notification_store(&self) -> &ModelHandle { + &self.notification_store + } + pub fn user_store(&self) -> &ModelHandle { &self.app_state.user_store } diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index aac4c924f8..791c6b2fa7 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -37,10 +37,12 @@ fuzzy = { path = "../fuzzy" } gpui = { path = "../gpui" } language = { path = "../language" } menu = { path = "../menu" } +notifications = { path = "../notifications" } rich_text = { path = "../rich_text" } picker = { path = "../picker" } project = { path = "../project" } -recent_projects = {path = "../recent_projects"} +recent_projects = { path = "../recent_projects" } +rpc = { path = "../rpc" } settings = { path = "../settings" } feature_flags = {path = "../feature_flags"} theme = { path = "../theme" } @@ -52,6 +54,7 @@ zed-actions = {path = "../zed-actions"} anyhow.workspace = true futures.workspace = true +lazy_static.workspace = true log.workspace = true schemars.workspace = true postage.workspace = true @@ -66,7 +69,12 @@ client = { path = "../client", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] } editor = { path = "../editor", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } + +pretty_assertions.workspace = true +tree-sitter-markdown.workspace = true diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 8f492e2e36..5a4dafb6d4 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -1,4 +1,6 @@ -use crate::{channel_view::ChannelView, ChatPanelSettings}; +use crate::{ + channel_view::ChannelView, is_channels_feature_enabled, render_avatar, ChatPanelSettings, +}; use anyhow::Result; use call::ActiveCall; use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore}; @@ -6,18 +8,18 @@ use client::Client; use collections::HashMap; use db::kvp::KEY_VALUE_STORE; use editor::Editor; -use feature_flags::{ChannelsAlpha, FeatureFlagAppExt}; use gpui::{ actions, elements::*, platform::{CursorStyle, MouseButton}, serde_json, views::{ItemType, Select, SelectStyle}, - AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task, - View, ViewContext, ViewHandle, WeakViewHandle, + AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View, + ViewContext, ViewHandle, WeakViewHandle, }; -use language::{language_settings::SoftWrap, LanguageRegistry}; +use language::LanguageRegistry; use menu::Confirm; +use message_editor::MessageEditor; use project::Fs; use rich_text::RichText; use serde::{Deserialize, Serialize}; @@ -31,6 +33,8 @@ use workspace::{ Workspace, }; +mod message_editor; + const MESSAGE_LOADING_THRESHOLD: usize = 50; const CHAT_PANEL_KEY: &'static str = "ChatPanel"; @@ -40,7 +44,7 @@ pub struct ChatPanel { languages: Arc, active_chat: Option<(ModelHandle, Subscription)>, message_list: ListState, - input_editor: ViewHandle, + input_editor: ViewHandle, channel_select: ViewHandle