diff --git a/Cargo.lock b/Cargo.lock index c99e88b9b4..121e9a28dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,14 +102,20 @@ dependencies = [ "anyhow", "chrono", "collections", + "ctor", "editor", + "env_logger 0.9.3", "fs", "futures 0.3.28", "gpui", + "indoc", "isahc", "language", + "log", "menu", + "ordered-float", "project", + "rand 0.8.5", "regex", "schemars", "search", @@ -1447,9 +1453,10 @@ dependencies = [ [[package]] name = "collab" -version = "0.18.0" +version = "0.20.0" dependencies = [ "anyhow", + "async-trait", "async-tungstenite", "audio", "axum", @@ -2762,6 +2769,7 @@ dependencies = [ "smol", "sum_tree", "tempfile", + "text", "time 0.3.27", "util", ] @@ -4170,8 +4178,7 @@ dependencies = [ [[package]] name = "lsp-types" version = "0.94.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1" +source = "git+https://github.com/zed-industries/lsp-types?branch=updated-completion-list-item-defaults#90a040a1d195687bd19e1df47463320a44e93d7a" dependencies = [ "bitflags 1.3.2", "serde", @@ -4576,6 +4583,7 @@ dependencies = [ "anyhow", "async-compression", "async-tar", + "async-trait", "futures 0.3.28", "gpui", "log", @@ -5707,6 +5715,7 @@ dependencies = [ name = "quick_action_bar" version = "0.1.0" dependencies = [ + "ai", "editor", "gpui", "search", @@ -7689,7 +7698,6 @@ dependencies = [ "ctor", "digest 0.9.0", "env_logger 0.9.3", - "fs", "gpui", "lazy_static", "log", @@ -9755,7 +9763,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.102.0" +version = "0.104.0" dependencies = [ "activity_indicator", "ai", diff --git a/assets/icons/check_circle.svg b/assets/icons/check_circle.svg index 85ba2e1f37..b48fe34631 100644 --- a/assets/icons/check_circle.svg +++ b/assets/icons/check_circle.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/error.svg b/assets/icons/error.svg index 82b9401d08..593629beee 100644 --- a/assets/icons/error.svg +++ b/assets/icons/error.svg @@ -1,4 +1,4 @@ - - + + diff --git a/assets/icons/warning.svg b/assets/icons/warning.svg index 6b3d0fd41e..e581def0d0 100644 --- a/assets/icons/warning.svg +++ b/assets/icons/warning.svg @@ -1,5 +1,6 @@ - - - + + + + diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index a1c771da02..fa62a74f3f 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -515,6 +515,17 @@ "enter": "editor::ConfirmCodeAction" } }, + { + "context": "Editor && (showing_code_actions || showing_completions)", + "bindings": { + "up": "editor::ContextMenuPrev", + "ctrl-p": "editor::ContextMenuPrev", + "down": "editor::ContextMenuNext", + "ctrl-n": "editor::ContextMenuNext", + "pageup": "editor::ContextMenuFirst", + "pagedown": "editor::ContextMenuLast" + } + }, // Custom bindings { "bindings": { @@ -522,7 +533,7 @@ // TODO: Move this to a dock open action "cmd-shift-c": "collab_panel::ToggleFocus", "cmd-alt-i": "zed::DebugElements", - "ctrl-shift-:": "editor::ToggleInlayHints", + "ctrl-:": "editor::ToggleInlayHints", } }, { @@ -530,7 +541,8 @@ "bindings": { "alt-enter": "editor::OpenExcerpts", "cmd-f8": "editor::GoToHunk", - "cmd-shift-f8": "editor::GoToPrevHunk" + "cmd-shift-f8": "editor::GoToPrevHunk", + "ctrl-enter": "assistant::InlineAssist" } }, { diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index c7e6199f44..da094ea7e4 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -371,6 +371,7 @@ "Replace" ], "s": "vim::Substitute", + "shift-s": "vim::SubstituteLine", "> >": "editor::Indent", "< <": "editor::Outdent", "ctrl-pagedown": "pane::ActivateNextItem", @@ -446,6 +447,7 @@ } ], "s": "vim::Substitute", + "shift-s": "vim::SubstituteLine", "c": "vim::Substitute", "~": "vim::ChangeCase", "shift-i": [ diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 013565e14f..4438f88108 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -24,7 +24,9 @@ workspace = { path = "../workspace" } anyhow.workspace = true chrono = { version = "0.4", features = ["serde"] } futures.workspace = true +indoc.workspace = true isahc.workspace = true +ordered-float.workspace = true regex.workspace = true schemars.workspace = true serde.workspace = true @@ -35,3 +37,8 @@ tiktoken-rs = "0.4" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } + +ctor.workspace = true +env_logger.workspace = true +log.workspace = true +rand.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index d2be651bd5..2c2d7e774e 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,28 +1,33 @@ pub mod assistant; mod assistant_settings; +mod streaming_diff; -use anyhow::Result; +use anyhow::{anyhow, Result}; pub use assistant::AssistantPanel; use assistant_settings::OpenAIModel; use chrono::{DateTime, Local}; use collections::HashMap; use fs::Fs; -use futures::StreamExt; -use gpui::AppContext; +use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; +use gpui::{executor::Background, AppContext}; +use isahc::{http::StatusCode, Request, RequestExt}; use regex::Regex; use serde::{Deserialize, Serialize}; use std::{ cmp::Reverse, ffi::OsStr, fmt::{self, Display}, + io, path::PathBuf, sync::Arc, }; use util::paths::CONVERSATIONS_DIR; +const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; + // Data types for chat completion requests #[derive(Debug, Serialize)] -struct OpenAIRequest { +pub struct OpenAIRequest { model: String, messages: Vec, stream: bool, @@ -116,7 +121,7 @@ struct RequestMessage { } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -struct ResponseMessage { +pub struct ResponseMessage { role: Option, content: Option, } @@ -150,7 +155,7 @@ impl Display for Role { } #[derive(Deserialize, Debug)] -struct OpenAIResponseStreamEvent { +pub struct OpenAIResponseStreamEvent { pub id: Option, pub object: String, pub created: u32, @@ -160,14 +165,14 @@ struct OpenAIResponseStreamEvent { } #[derive(Deserialize, Debug)] -struct Usage { +pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } #[derive(Deserialize, Debug)] -struct ChatChoiceDelta { +pub struct ChatChoiceDelta { pub index: u32, pub delta: ResponseMessage, pub finish_reason: Option, @@ -191,3 +196,97 @@ struct OpenAIChoice { pub fn init(cx: &mut AppContext) { assistant::init(cx); } + +pub async fn stream_completion( + api_key: String, + executor: Arc, + mut request: OpenAIRequest, +) -> Result>> { + request.stream = true; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = serde_json::to_string(&request)?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? + .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[cfg(test)] +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 3c561b0e03..9b384252fc 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,53 +1,63 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, - MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent, - RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage, + stream_completion, + streaming_diff::{Hunk, StreamingDiff}, + MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role, + SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; -use collections::{HashMap, HashSet}; +use collections::{hash_map, HashMap, HashSet, VecDeque}; use editor::{ - display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint}, + display_map::{ + BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint, + }, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, ToOffset, + Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint, }; use fs::Fs; -use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; +use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{ actions, - elements::*, - executor::Background, + elements::{ + ChildView, Component, Empty, Flex, Label, MouseEventHandler, ParentElement, SafeStylable, + Stack, Svg, Text, UniformList, UniformListState, + }, + fonts::HighlightStyle, geometry::vector::{vec2f, Vector2F}, platform::{CursorStyle, MouseButton}, - Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle, - Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, + Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext, + ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, + WindowContext, +}; +use language::{ + language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _, + TransactionId, }; -use isahc::{http::StatusCode, Request, RequestExt}; -use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use search::BufferSearchBar; -use serde::Deserialize; use settings::SettingsStore; use std::{ - cell::RefCell, + cell::{Cell, RefCell}, cmp, env, fmt::Write, - io, iter, + future, iter, ops::Range, path::{Path, PathBuf}, rc::Rc, sync::Arc, time::Duration, }; -use theme::AssistantStyle; +use theme::{ + components::{action_button::Button, ComponentExt}, + AssistantStyle, +}; use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt}; use workspace::{ dock::{DockPosition, Panel}, searchable::Direction, - Save, ToggleZoom, Toolbar, Workspace, + Save, Toast, ToggleZoom, Toolbar, Workspace, }; -const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; - actions!( assistant, [ @@ -58,6 +68,8 @@ actions!( QuoteSelection, ToggleFocus, ResetKey, + InlineAssist, + ToggleIncludeConversation, ] ); @@ -89,6 +101,13 @@ pub fn init(cx: &mut AppContext) { workspace.toggle_panel_focus::(cx); }, ); + cx.add_action(AssistantPanel::inline_assist); + cx.add_action(AssistantPanel::cancel_last_inline_assist); + cx.add_action(InlineAssistant::confirm); + cx.add_action(InlineAssistant::cancel); + cx.add_action(InlineAssistant::toggle_include_conversation); + cx.add_action(InlineAssistant::move_up); + cx.add_action(InlineAssistant::move_down); } #[derive(Debug)] @@ -118,10 +137,17 @@ pub struct AssistantPanel { languages: Arc, fs: Arc, subscriptions: Vec, + next_inline_assist_id: usize, + pending_inline_assists: HashMap, + pending_inline_assist_ids_by_editor: HashMap, Vec>, + include_conversation_in_next_inline_assist: bool, + inline_prompt_history: VecDeque, _watch_saved_conversations: Task>, } impl AssistantPanel { + const INLINE_PROMPT_HISTORY_MAX_LEN: usize = 20; + pub fn load( workspace: WeakViewHandle, cx: AsyncAppContext, @@ -181,6 +207,11 @@ impl AssistantPanel { width: None, height: None, subscriptions: Default::default(), + next_inline_assist_id: 0, + pending_inline_assists: Default::default(), + pending_inline_assist_ids_by_editor: Default::default(), + include_conversation_in_next_inline_assist: false, + inline_prompt_history: Default::default(), _watch_saved_conversations, }; @@ -201,6 +232,717 @@ impl AssistantPanel { }) } + pub fn inline_assist( + workspace: &mut Workspace, + _: &InlineAssist, + cx: &mut ViewContext, + ) { + let this = if let Some(this) = workspace.panel::(cx) { + if this + .update(cx, |assistant, cx| assistant.load_api_key(cx)) + .is_some() + { + this + } else { + workspace.focus_panel::(cx); + return; + } + } else { + return; + }; + + let active_editor = if let Some(active_editor) = workspace + .active_item(cx) + .and_then(|item| item.act_as::(cx)) + { + active_editor + } else { + return; + }; + + this.update(cx, |assistant, cx| { + assistant.new_inline_assist(&active_editor, cx) + }); + } + + fn new_inline_assist(&mut self, editor: &ViewHandle, cx: &mut ViewContext) { + let inline_assist_id = post_inc(&mut self.next_inline_assist_id); + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + let selection = editor.read(cx).selections.newest_anchor().clone(); + let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot); + let assist_kind = if editor.read(cx).selections.newest::(cx).is_empty() { + InlineAssistKind::Generate + } else { + InlineAssistKind::Transform + }; + let measurements = Rc::new(Cell::new(BlockMeasurements::default())); + let inline_assistant = cx.add_view(|cx| { + let assistant = InlineAssistant::new( + inline_assist_id, + assist_kind, + measurements.clone(), + self.include_conversation_in_next_inline_assist, + self.inline_prompt_history.clone(), + cx, + ); + cx.focus_self(); + assistant + }); + let block_id = editor.update(cx, |editor, cx| { + editor.change_selections(None, cx, |selections| { + selections.select_anchor_ranges([selection.head()..selection.head()]) + }); + editor.insert_blocks( + [BlockProperties { + style: BlockStyle::Flex, + position: selection.head().bias_left(&snapshot), + height: 2, + render: Arc::new({ + let inline_assistant = inline_assistant.clone(); + move |cx: &mut BlockContext| { + measurements.set(BlockMeasurements { + anchor_x: cx.anchor_x, + gutter_width: cx.gutter_width, + }); + ChildView::new(&inline_assistant, cx).into_any() + } + }), + disposition: if selection.reversed { + BlockDisposition::Above + } else { + BlockDisposition::Below + }, + }], + Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)), + cx, + )[0] + }); + + self.pending_inline_assists.insert( + inline_assist_id, + PendingInlineAssist { + kind: assist_kind, + editor: editor.downgrade(), + range, + highlighted_ranges: Default::default(), + inline_assistant: Some((block_id, inline_assistant.clone())), + code_generation: Task::ready(None), + transaction_id: None, + _subscriptions: vec![ + cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event), + cx.subscribe(editor, { + let inline_assistant = inline_assistant.downgrade(); + move |this, editor, event, cx| { + if let Some(inline_assistant) = inline_assistant.upgrade(cx) { + match event { + editor::Event::SelectionsChanged { local } => { + if *local && inline_assistant.read(cx).has_focus { + cx.focus(&editor); + } + } + editor::Event::TransactionUndone { + transaction_id: tx_id, + } => { + if let Some(pending_assist) = + this.pending_inline_assists.get(&inline_assist_id) + { + if pending_assist.transaction_id == Some(*tx_id) { + // Notice we are supplying `undo: false` here. This + // is because there's no need to undo the transaction + // because the user just did so. + this.close_inline_assist( + inline_assist_id, + false, + cx, + ); + } + } + } + _ => {} + } + } + } + }), + ], + }, + ); + self.pending_inline_assist_ids_by_editor + .entry(editor.downgrade()) + .or_default() + .push(inline_assist_id); + self.update_highlights_for_editor(&editor, cx); + } + + fn handle_inline_assistant_event( + &mut self, + inline_assistant: ViewHandle, + event: &InlineAssistantEvent, + cx: &mut ViewContext, + ) { + let assist_id = inline_assistant.read(cx).id; + match event { + InlineAssistantEvent::Confirmed { + prompt, + include_conversation, + } => { + self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); + } + InlineAssistantEvent::Canceled => { + self.close_inline_assist(assist_id, true, cx); + } + InlineAssistantEvent::Dismissed => { + self.hide_inline_assist(assist_id, cx); + } + InlineAssistantEvent::IncludeConversationToggled { + include_conversation, + } => { + self.include_conversation_in_next_inline_assist = *include_conversation; + } + } + } + + fn cancel_last_inline_assist( + workspace: &mut Workspace, + _: &editor::Cancel, + cx: &mut ViewContext, + ) { + if let Some(panel) = workspace.panel::(cx) { + if let Some(editor) = workspace + .active_item(cx) + .and_then(|item| item.downcast::()) + { + let handled = panel.update(cx, |panel, cx| { + if let Some(assist_id) = panel + .pending_inline_assist_ids_by_editor + .get(&editor.downgrade()) + .and_then(|assist_ids| assist_ids.last().copied()) + { + panel.close_inline_assist(assist_id, true, cx); + true + } else { + false + } + }); + if handled { + return; + } + } + } + + cx.propagate_action(); + } + + fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext) { + self.hide_inline_assist(assist_id, cx); + + if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) { + if let hash_map::Entry::Occupied(mut entry) = self + .pending_inline_assist_ids_by_editor + .entry(pending_assist.editor) + { + entry.get_mut().retain(|id| *id != assist_id); + if entry.get().is_empty() { + entry.remove(); + } + } + + if let Some(editor) = pending_assist.editor.upgrade(cx) { + self.update_highlights_for_editor(&editor, cx); + + if undo { + if let Some(transaction_id) = pending_assist.transaction_id { + editor.update(cx, |editor, cx| { + editor.buffer().update(cx, |buffer, cx| { + buffer.undo_transaction(transaction_id, cx) + }); + }); + } + } + } + } + } + + fn hide_inline_assist(&mut self, assist_id: usize, cx: &mut ViewContext) { + if let Some(pending_assist) = self.pending_inline_assists.get_mut(&assist_id) { + if let Some(editor) = pending_assist.editor.upgrade(cx) { + if let Some((block_id, _)) = pending_assist.inline_assistant.take() { + editor.update(cx, |editor, cx| { + editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + }); + } + } + } + } + + fn confirm_inline_assist( + &mut self, + inline_assist_id: usize, + user_prompt: &str, + include_conversation: bool, + cx: &mut ViewContext, + ) { + let api_key = if let Some(api_key) = self.api_key.borrow().clone() { + api_key + } else { + return; + }; + + let conversation = if include_conversation { + self.active_editor() + .map(|editor| editor.read(cx).conversation.clone()) + } else { + None + }; + + let pending_assist = + if let Some(pending_assist) = self.pending_inline_assists.get_mut(&inline_assist_id) { + pending_assist + } else { + return; + }; + + let editor = if let Some(editor) = pending_assist.editor.upgrade(cx) { + editor + } else { + return; + }; + + self.inline_prompt_history + .retain(|prompt| prompt != user_prompt); + self.inline_prompt_history.push_back(user_prompt.into()); + if self.inline_prompt_history.len() > Self::INLINE_PROMPT_HISTORY_MAX_LEN { + self.inline_prompt_history.pop_front(); + } + + let range = pending_assist.range.clone(); + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + let selected_text = snapshot + .text_for_range(range.start..range.end) + .collect::(); + + let selection_start = range.start.to_point(&snapshot); + let selection_end = range.end.to_point(&snapshot); + + let mut base_indent: Option = None; + let mut start_row = selection_start.row; + if snapshot.is_line_blank(start_row) { + if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { + start_row = prev_non_blank_row; + } + } + for row in start_row..=selection_end.row { + if snapshot.is_line_blank(row) { + continue; + } + + let line_indent = snapshot.indent_size_for_line(row); + if let Some(base_indent) = base_indent.as_mut() { + if line_indent.len < base_indent.len { + *base_indent = line_indent; + } + } else { + base_indent = Some(line_indent); + } + } + + let mut normalized_selected_text = selected_text.clone(); + if let Some(base_indent) = base_indent { + for row in selection_start.row..=selection_end.row { + let selection_row = row - selection_start.row; + let line_start = + normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); + let indent_len = if row == selection_start.row { + base_indent.len.saturating_sub(selection_start.column) + } else { + let line_len = normalized_selected_text.line_len(selection_row); + cmp::min(line_len, base_indent.len) + }; + let indent_end = cmp::min( + line_start + indent_len as usize, + normalized_selected_text.len(), + ); + normalized_selected_text.replace(line_start..indent_end, ""); + } + } + + let language = snapshot.language_at(range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None + } else { + Some(language.name()) + } + } else { + None + }; + let language_name = language_name.as_deref(); + + let mut prompt = String::new(); + if let Some(language_name) = language_name { + writeln!(prompt, "You're an expert {language_name} engineer.").unwrap(); + } + match pending_assist.kind { + InlineAssistKind::Transform => { + writeln!( + prompt, + "You're currently working inside an editor on this file:" + ) + .unwrap(); + if let Some(language_name) = language_name { + writeln!(prompt, "```{language_name}").unwrap(); + } else { + writeln!(prompt, "```").unwrap(); + } + for chunk in snapshot.text_for_range(Anchor::min()..Anchor::max()) { + write!(prompt, "{chunk}").unwrap(); + } + writeln!(prompt, "```").unwrap(); + + writeln!( + prompt, + "In particular, the user has selected the following text:" + ) + .unwrap(); + if let Some(language_name) = language_name { + writeln!(prompt, "```{language_name}").unwrap(); + } else { + writeln!(prompt, "```").unwrap(); + } + writeln!(prompt, "{normalized_selected_text}").unwrap(); + writeln!(prompt, "```").unwrap(); + writeln!(prompt).unwrap(); + writeln!( + prompt, + "Modify the selected text given the user prompt: {user_prompt}" + ) + .unwrap(); + writeln!( + prompt, + "You MUST reply only with the edited selected text, not the entire file." + ) + .unwrap(); + } + InlineAssistKind::Generate => { + writeln!( + prompt, + "You're currently working inside an editor on this file:" + ) + .unwrap(); + if let Some(language_name) = language_name { + writeln!(prompt, "```{language_name}").unwrap(); + } else { + writeln!(prompt, "```").unwrap(); + } + for chunk in snapshot.text_for_range(Anchor::min()..range.start) { + write!(prompt, "{chunk}").unwrap(); + } + write!(prompt, "<|>").unwrap(); + for chunk in snapshot.text_for_range(range.start..Anchor::max()) { + write!(prompt, "{chunk}").unwrap(); + } + writeln!(prompt).unwrap(); + writeln!(prompt, "```").unwrap(); + writeln!( + prompt, + "Assume the cursor is located where the `<|>` marker is." + ) + .unwrap(); + writeln!( + prompt, + "Text can't be replaced, so assume your answer will be inserted at the cursor." + ) + .unwrap(); + writeln!( + prompt, + "Complete the text given the user prompt: {user_prompt}" + ) + .unwrap(); + } + } + if let Some(language_name) = language_name { + writeln!(prompt, "Your answer MUST always be valid {language_name}.").unwrap(); + } + writeln!(prompt, "Always wrap your response in a Markdown codeblock.").unwrap(); + writeln!(prompt, "Never make remarks about the output.").unwrap(); + + let mut messages = Vec::new(); + let mut model = settings::get::(cx) + .default_open_ai_model + .clone(); + if let Some(conversation) = conversation { + let conversation = conversation.read(cx); + let buffer = conversation.buffer.read(cx); + messages.extend( + conversation + .messages(cx) + .map(|message| message.to_open_ai_message(buffer)), + ); + model = conversation.model.clone(); + } + + messages.push(RequestMessage { + role: Role::User, + content: prompt, + }); + let request = OpenAIRequest { + model: model.full_name().into(), + messages, + stream: true, + }; + let response = stream_completion(api_key, cx.background().clone(), request); + let editor = editor.downgrade(); + + pending_assist.code_generation = cx.spawn(|this, mut cx| { + async move { + let mut edit_start = range.start.to_offset(&snapshot); + + let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); + let diff = cx.background().spawn(async move { + let chunks = strip_markdown_codeblock(response.await?.filter_map( + |message| async move { + match message { + Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }, + )); + futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); + + let mut indent_len; + let indent_text; + if let Some(base_indent) = base_indent { + indent_len = base_indent.len; + indent_text = match base_indent.kind { + language::IndentKind::Space => " ", + language::IndentKind::Tab => "\t", + }; + } else { + indent_len = 0; + indent_text = ""; + }; + + let mut first_line_len = 0; + let mut first_line_non_whitespace_char_ix = None; + let mut first_line = true; + let mut new_text = String::new(); + + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + + let mut lines = chunk.split('\n'); + if let Some(mut line) = lines.next() { + if first_line { + if first_line_non_whitespace_char_ix.is_none() { + if let Some(mut char_ix) = + line.find(|ch: char| !ch.is_whitespace()) + { + line = &line[char_ix..]; + char_ix += first_line_len; + first_line_non_whitespace_char_ix = Some(char_ix); + let first_line_indent = char_ix + .saturating_sub(selection_start.column as usize) + as usize; + new_text.push_str(&indent_text.repeat(first_line_indent)); + indent_len = indent_len.saturating_sub(char_ix as u32); + } + } + first_line_len += line.len(); + } + + if first_line_non_whitespace_char_ix.is_some() { + new_text.push_str(line); + } + } + + for line in lines { + first_line = false; + new_text.push('\n'); + if !line.is_empty() { + new_text.push_str(&indent_text.repeat(indent_len as usize)); + } + new_text.push_str(line); + } + + let hunks = diff.push_new(&new_text); + hunks_tx.send(hunks).await?; + new_text.clear(); + } + hunks_tx.send(diff.finish()).await?; + + anyhow::Ok(()) + }); + + while let Some(hunks) = hunks_rx.next().await { + let editor = if let Some(editor) = editor.upgrade(&cx) { + editor + } else { + break; + }; + + let this = if let Some(this) = this.upgrade(&cx) { + this + } else { + break; + }; + + this.update(&mut cx, |this, cx| { + let pending_assist = if let Some(pending_assist) = + this.pending_inline_assists.get_mut(&inline_assist_id) + { + pending_assist + } else { + return; + }; + + pending_assist.highlighted_ranges.clear(); + editor.update(cx, |editor, cx| { + let transaction = editor.buffer().update(cx, |buffer, cx| { + // Avoid grouping assistant edits with user edits. + buffer.finalize_last_transaction(cx); + + buffer.start_transaction(cx); + buffer.edit( + hunks.into_iter().filter_map(|hunk| match hunk { + Hunk::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + Some((edit_start..edit_start, text)) + } + Hunk::Remove { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + Some((edit_range, String::new())) + } + Hunk::Keep { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start += len; + pending_assist.highlighted_ranges.push(edit_range); + None + } + }), + None, + cx, + ); + + buffer.end_transaction(cx) + }); + + if let Some(transaction) = transaction { + if let Some(first_transaction) = pending_assist.transaction_id { + // Group all assistant edits into the first transaction. + editor.buffer().update(cx, |buffer, cx| { + buffer.merge_transactions( + transaction, + first_transaction, + cx, + ) + }); + } else { + pending_assist.transaction_id = Some(transaction); + editor.buffer().update(cx, |buffer, cx| { + buffer.finalize_last_transaction(cx) + }); + } + } + }); + + this.update_highlights_for_editor(&editor, cx); + }); + } + + if let Err(error) = diff.await { + this.update(&mut cx, |this, cx| { + let pending_assist = if let Some(pending_assist) = + this.pending_inline_assists.get_mut(&inline_assist_id) + { + pending_assist + } else { + return; + }; + + if let Some((_, inline_assistant)) = + pending_assist.inline_assistant.as_ref() + { + inline_assistant.update(cx, |inline_assistant, cx| { + inline_assistant.set_error(error, cx); + }); + } else if let Some(workspace) = this.workspace.upgrade(cx) { + workspace.update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + inline_assist_id, + format!("Inline assistant error: {}", error), + ), + cx, + ); + }) + } + })?; + } else { + let _ = this.update(&mut cx, |this, cx| { + this.close_inline_assist(inline_assist_id, false, cx) + }); + } + + anyhow::Ok(()) + } + .log_err() + }); + } + + fn update_highlights_for_editor( + &self, + editor: &ViewHandle, + cx: &mut ViewContext, + ) { + let mut background_ranges = Vec::new(); + let mut foreground_ranges = Vec::new(); + let empty_inline_assist_ids = Vec::new(); + let inline_assist_ids = self + .pending_inline_assist_ids_by_editor + .get(&editor.downgrade()) + .unwrap_or(&empty_inline_assist_ids); + + for inline_assist_id in inline_assist_ids { + if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) { + background_ranges.push(pending_assist.range.clone()); + foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned()); + } + } + + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); + merge_ranges(&mut background_ranges, &snapshot); + merge_ranges(&mut foreground_ranges, &snapshot); + editor.update(cx, |editor, cx| { + if background_ranges.is_empty() { + editor.clear_background_highlights::(cx); + } else { + editor.highlight_background::( + background_ranges, + |theme| theme.assistant.inline.pending_edit_background, + cx, + ); + } + + if foreground_ranges.is_empty() { + editor.clear_text_highlights::(cx); + } else { + editor.highlight_text::( + foreground_ranges, + HighlightStyle { + fade_out: Some(0.6), + ..Default::default() + }, + cx, + ); + } + }); + } + fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( @@ -570,6 +1312,32 @@ impl AssistantPanel { .iter() .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } + + fn load_api_key(&mut self, cx: &mut ViewContext) -> Option { + if self.api_key.borrow().is_none() && !self.has_read_credentials { + self.has_read_credentials = true; + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + Some(api_key) + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + String::from_utf8(api_key).log_err() + } else { + None + }; + if let Some(api_key) = api_key { + *self.api_key.borrow_mut() = Some(api_key); + } else if self.api_key_editor.is_none() { + self.api_key_editor = Some(build_api_key_editor(cx)); + cx.notify(); + } + } + + self.api_key.borrow().clone() + } } fn build_api_key_editor(cx: &mut ViewContext) -> ViewHandle { @@ -753,27 +1521,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - if self.api_key.borrow().is_none() && !self.has_read_credentials { - self.has_read_credentials = true; - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - if let Some(api_key) = api_key { - *self.api_key.borrow_mut() = Some(api_key); - } else if self.api_key_editor.is_none() { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); - } - } + self.load_api_key(cx); if self.editors.is_empty() { self.new_conversation(cx); @@ -1068,15 +1816,20 @@ impl Conversation { cx: &mut ModelContext, ) -> Vec { let mut user_messages = Vec::new(); - let mut tasks = Vec::new(); - let last_message_id = self.message_anchors.iter().rev().find_map(|message| { - message - .start - .is_valid(self.buffer.read(cx)) - .then_some(message.id) - }); + let last_message_id = if let Some(last_message_id) = + self.message_anchors.iter().rev().find_map(|message| { + message + .start + .is_valid(self.buffer.read(cx)) + .then_some(message.id) + }) { + last_message_id + } else { + return Default::default(); + }; + let mut should_assist = false; for selected_message_id in selected_messages { let selected_message_role = if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { @@ -1093,144 +1846,111 @@ impl Conversation { cx, ) { user_messages.push(user_message); - } else { - continue; } } else { - let request = OpenAIRequest { - model: self.model.full_name().to_string(), - messages: self - .messages(cx) - .filter(|message| matches!(message.status, MessageStatus::Done)) - .flat_map(|message| { - let mut system_message = None; - if message.id == selected_message_id { - system_message = Some(RequestMessage { - role: Role::System, - content: concat!( - "Treat the following messages as additional knowledge you have learned about, ", - "but act as if they were not part of this conversation. That is, treat them ", - "as if the user didn't see them and couldn't possibly inquire about them." - ).into() - }); - } - - Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message) - }) - .chain(Some(RequestMessage { - role: Role::System, - content: format!( - "Direct your reply to message with id {}. Do not include a [Message X] header.", - selected_message_id.0 - ), - })) - .collect(), - stream: true, - }; - - let Some(api_key) = self.api_key.borrow().clone() else { - continue; - }; - let stream = stream_completion(api_key, cx.background().clone(), request); - let assistant_message = self - .insert_message_after( - selected_message_id, - Role::Assistant, - MessageStatus::Pending, - cx, - ) - .unwrap(); - - // Queue up the user's next reply - if Some(selected_message_id) == last_message_id { - let user_message = self - .insert_message_after( - assistant_message.id, - Role::User, - MessageStatus::Done, - cx, - ) - .unwrap(); - user_messages.push(user_message); - } - - tasks.push(cx.spawn_weak({ - |this, mut cx| async move { - let assistant_message_id = assistant_message.id; - let stream_completion = async { - let mut messages = stream.await?; - - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? - .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = this.message_anchors.iter().position( - |message| message.id == assistant_message_id, - )?; - this.buffer.update(cx, |buffer, cx| { - let offset = this.message_anchors[message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message - .start - .to_offset(buffer) - .saturating_sub(1) - }); - buffer.edit([(offset..offset, text)], None, cx); - }); - cx.emit(ConversationEvent::StreamedCompletion); - - Some(()) - }); - } - smol::future::yield_now().await; - } - - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? - .update(&mut cx, |this, cx| { - this.pending_completions.retain(|completion| { - completion.id != this.completion_count - }); - this.summarize(cx); - }); - - anyhow::Ok(()) - }; - - let result = stream_completion.await; - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - match result { - Ok(_) => { - metadata.status = MessageStatus::Done; - } - Err(error) => { - metadata.status = MessageStatus::Error( - error.to_string().trim().into(), - ); - } - } - cx.notify(); - } - }); - } - } - })); + should_assist = true; } } - if !tasks.is_empty() { + if should_assist { + let Some(api_key) = self.api_key.borrow().clone() else { + return Default::default(); + }; + + let request = OpenAIRequest { + model: self.model.full_name().to_string(), + messages: self + .messages(cx) + .filter(|message| matches!(message.status, MessageStatus::Done)) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .collect(), + stream: true, + }; + + let stream = stream_completion(api_key, cx.background().clone(), request); + let assistant_message = self + .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) + .unwrap(); + + // Queue up the user's next reply. + let user_message = self + .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx) + .unwrap(); + user_messages.push(user_message); + + let task = cx.spawn_weak({ + |this, mut cx| async move { + let assistant_message_id = assistant_message.id; + let stream_completion = async { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + let text: Arc = choice.delta.content?.into(); + let message_ix = + this.message_anchors.iter().position(|message| { + message.id == assistant_message_id + })?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message + .start + .to_offset(buffer) + .saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); + }); + cx.emit(ConversationEvent::StreamedCompletion); + + Some(()) + }); + } + smol::future::yield_now().await; + } + + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + this.pending_completions + .retain(|completion| completion.id != this.completion_count); + this.summarize(cx); + }); + + anyhow::Ok(()) + }; + + let result = stream_completion.await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) + { + match result { + Ok(_) => { + metadata.status = MessageStatus::Done; + } + Err(error) => { + metadata.status = + MessageStatus::Error(error.to_string().trim().into()); + } + } + cx.notify(); + } + }); + } + } + }); + self.pending_completions.push(PendingCompletion { id: post_inc(&mut self.completion_count), - _tasks: tasks, + _task: task, }); } @@ -1597,7 +2317,7 @@ impl Conversation { struct PendingCompletion { id: usize, - _tasks: Vec>, + _task: Task<()>, } enum ConversationEditorEvent { @@ -2145,8 +2865,9 @@ pub struct Message { impl Message { fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { - let mut content = format!("[Message {}]\n", self.id.0).to_string(); - content.extend(buffer.text_for_range(self.offset_range.clone())); + let content = buffer + .text_for_range(self.offset_range.clone()) + .collect::(); RequestMessage { role: self.role, content: content.trim_end().into(), @@ -2154,96 +2875,374 @@ impl Message { } } -async fn stream_completion( - api_key: String, - executor: Arc, - mut request: OpenAIRequest, -) -> Result>> { - request.stream = true; +enum InlineAssistantEvent { + Confirmed { + prompt: String, + include_conversation: bool, + }, + Canceled, + Dismissed, + IncludeConversationToggled { + include_conversation: bool, + }, +} - let (tx, rx) = futures::channel::mpsc::unbounded::>(); +#[derive(Copy, Clone)] +enum InlineAssistKind { + Transform, + Generate, +} - let json_data = serde_json::to_string(&request)?; - let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(json_data)? - .send_async() - .await?; +struct InlineAssistant { + id: usize, + prompt_editor: ViewHandle, + confirmed: bool, + has_focus: bool, + include_conversation: bool, + measurements: Rc>, + error: Option, + prompt_history: VecDeque, + prompt_history_ix: Option, + pending_prompt: String, + _subscription: Subscription, +} - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); +impl Entity for InlineAssistant { + type Event = InlineAssistantEvent; +} - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(&data)?; - Ok(Some(event)) +impl View for InlineAssistant { + fn ui_name() -> &'static str { + "InlineAssistant" + } + + fn render(&mut self, cx: &mut ViewContext) -> AnyElement { + enum ErrorIcon {} + let theme = theme::current(cx); + + Flex::row() + .with_child( + Flex::row() + .with_child( + Button::action(ToggleIncludeConversation) + .with_tooltip("Include Conversation", theme.tooltip.clone()) + .with_id(self.id) + .with_contents(theme::components::svg::Svg::new("icons/ai.svg")) + .toggleable(self.include_conversation) + .with_style(theme.assistant.inline.include_conversation.clone()) + .element() + .aligned(), + ) + .with_children(if let Some(error) = self.error.as_ref() { + Some( + Svg::new("icons/circle_x_mark_12.svg") + .with_color(theme.assistant.error_icon.color) + .constrained() + .with_width(theme.assistant.error_icon.width) + .contained() + .with_style(theme.assistant.error_icon.container) + .with_tooltip::( + self.id, + error.to_string(), + None, + theme.tooltip.clone(), + cx, + ) + .aligned(), + ) } else { - Ok(None) + None + }) + .aligned() + .constrained() + .dynamically({ + let measurements = self.measurements.clone(); + move |constraint, _, _| { + let measurements = measurements.get(); + SizeConstraint { + min: vec2f(measurements.gutter_width, constraint.min.y()), + max: vec2f(measurements.gutter_width, constraint.max.y()), + } + } + }), + ) + .with_child(Empty::new().constrained().dynamically({ + let measurements = self.measurements.clone(); + move |constraint, _, _| { + let measurements = measurements.get(); + SizeConstraint { + min: vec2f( + measurements.anchor_x - measurements.gutter_width, + constraint.min.y(), + ), + max: vec2f( + measurements.anchor_x - measurements.gutter_width, + constraint.max.y(), + ), } } + })) + .with_child( + ChildView::new(&self.prompt_editor, cx) + .aligned() + .left() + .flex(1., true), + ) + .contained() + .with_style(theme.assistant.inline.container) + .into_any() + .into_any() + } - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - let done = event.as_ref().map_or(false, |event| { - event - .choices - .last() - .map_or(false, |choice| choice.finish_reason.is_some()) - }); - if tx.unbounded_send(event).is_err() { - break; - } + fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext) { + cx.focus(&self.prompt_editor); + self.has_focus = true; + } - if done { - break; - } - } - } + fn focus_out(&mut self, _: gpui::AnyViewHandle, _: &mut ViewContext) { + self.has_focus = false; + } +} - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAIResponse { - error: OpenAIError, - } - - #[derive(Deserialize)] - struct OpenAIError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to OpenAI API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )), +impl InlineAssistant { + fn new( + id: usize, + kind: InlineAssistKind, + measurements: Rc>, + include_conversation: bool, + prompt_history: VecDeque, + cx: &mut ViewContext, + ) -> Self { + let prompt_editor = cx.add_view(|cx| { + let mut editor = Editor::single_line( + Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), + cx, + ); + let placeholder = match kind { + InlineAssistKind::Transform => "Enter transformation prompt…", + InlineAssistKind::Generate => "Enter generation prompt…", + }; + editor.set_placeholder_text(placeholder, cx); + editor + }); + let subscription = cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events); + Self { + id, + prompt_editor, + confirmed: false, + has_focus: false, + include_conversation, + measurements, + error: None, + prompt_history, + prompt_history_ix: None, + pending_prompt: String::new(), + _subscription: subscription, } } + + fn handle_prompt_editor_events( + &mut self, + _: ViewHandle, + event: &editor::Event, + cx: &mut ViewContext, + ) { + if let editor::Event::Edited = event { + self.pending_prompt = self.prompt_editor.read(cx).text(cx); + cx.notify(); + } + } + + fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { + cx.emit(InlineAssistantEvent::Canceled); + } + + fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + if self.confirmed { + cx.emit(InlineAssistantEvent::Dismissed); + } else { + let prompt = self.prompt_editor.read(cx).text(cx); + self.prompt_editor.update(cx, |editor, cx| { + editor.set_read_only(true); + editor.set_field_editor_style( + Some(Arc::new(|theme| { + theme.assistant.inline.disabled_editor.clone() + })), + cx, + ); + }); + cx.emit(InlineAssistantEvent::Confirmed { + prompt, + include_conversation: self.include_conversation, + }); + self.confirmed = true; + self.error = None; + cx.notify(); + } + } + + fn toggle_include_conversation( + &mut self, + _: &ToggleIncludeConversation, + cx: &mut ViewContext, + ) { + self.include_conversation = !self.include_conversation; + cx.emit(InlineAssistantEvent::IncludeConversationToggled { + include_conversation: self.include_conversation, + }); + cx.notify(); + } + + fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext) { + self.error = Some(error); + self.confirmed = false; + self.prompt_editor.update(cx, |editor, cx| { + editor.set_read_only(false); + editor.set_field_editor_style( + Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), + cx, + ); + }); + cx.notify(); + } + + fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { + if let Some(ix) = self.prompt_history_ix { + if ix > 0 { + self.prompt_history_ix = Some(ix - 1); + let prompt = self.prompt_history[ix - 1].clone(); + self.set_prompt(&prompt, cx); + } + } else if !self.prompt_history.is_empty() { + self.prompt_history_ix = Some(self.prompt_history.len() - 1); + let prompt = self.prompt_history[self.prompt_history.len() - 1].clone(); + self.set_prompt(&prompt, cx); + } + } + + fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { + if let Some(ix) = self.prompt_history_ix { + if ix < self.prompt_history.len() - 1 { + self.prompt_history_ix = Some(ix + 1); + let prompt = self.prompt_history[ix + 1].clone(); + self.set_prompt(&prompt, cx); + } else { + self.prompt_history_ix = None; + let pending_prompt = self.pending_prompt.clone(); + self.set_prompt(&pending_prompt, cx); + } + } + } + + fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext) { + self.prompt_editor.update(cx, |editor, cx| { + editor.buffer().update(cx, |buffer, cx| { + let len = buffer.len(cx); + buffer.edit([(0..len, prompt)], None, cx); + }); + }); + } +} + +// This wouldn't need to exist if we could pass parameters when rendering child views. +#[derive(Copy, Clone, Default)] +struct BlockMeasurements { + anchor_x: f32, + gutter_width: f32, +} + +struct PendingInlineAssist { + kind: InlineAssistKind, + editor: WeakViewHandle, + range: Range, + highlighted_ranges: Vec>, + inline_assistant: Option<(BlockId, ViewHandle)>, + code_generation: Task>, + transaction_id: Option, + _subscriptions: Vec, +} + +fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { + ranges.sort_unstable_by(|a, b| { + a.start + .cmp(&b.start, buffer) + .then_with(|| b.end.cmp(&a.end, buffer)) + }); + + let mut ix = 0; + while ix + 1 < ranges.len() { + let b = ranges[ix + 1].clone(); + let a = &mut ranges[ix]; + if a.end.cmp(&b.start, buffer).is_gt() { + if a.end.cmp(&b.end, buffer).is_lt() { + a.end = b.end; + } + ranges.remove(ix + 1); + } else { + ix += 1; + } + } +} + +fn strip_markdown_codeblock( + stream: impl Stream>, +) -> impl Stream> { + let mut first_line = true; + let mut buffer = String::new(); + let mut starts_with_fenced_code_block = false; + stream.filter_map(move |chunk| { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => return future::ready(Some(Err(err))), + }; + buffer.push_str(&chunk); + + if first_line { + if buffer == "" || buffer == "`" || buffer == "``" { + return future::ready(None); + } else if buffer.starts_with("```") { + starts_with_fenced_code_block = true; + if let Some(newline_ix) = buffer.find('\n') { + buffer.replace_range(..newline_ix + 1, ""); + first_line = false; + } else { + return future::ready(None); + } + } + } + + let text = if starts_with_fenced_code_block { + buffer + .strip_suffix("\n```\n") + .or_else(|| buffer.strip_suffix("\n```")) + .or_else(|| buffer.strip_suffix("\n``")) + .or_else(|| buffer.strip_suffix("\n`")) + .or_else(|| buffer.strip_suffix('\n')) + .unwrap_or(&buffer) + } else { + &buffer + }; + + if text.contains('\n') { + first_line = false; + } + + let remainder = buffer.split_off(text.len()); + let result = if buffer.is_empty() { + None + } else { + Some(Ok(buffer.clone())) + }; + buffer = remainder; + future::ready(result) + }) } #[cfg(test)] mod tests { use super::*; use crate::MessageId; + use futures::stream; use gpui::AppContext; #[gpui::test] @@ -2612,6 +3611,62 @@ mod tests { ); } + #[gpui::test] + async fn test_strip_markdown_codeblock() { + assert_eq!( + strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "```js\nLorem ipsum dolor\n```" + ); + assert_eq!( + strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "``\nLorem ipsum dolor\n```" + ); + + fn chunks(text: &str, size: usize) -> impl Stream> { + stream::iter( + text.chars() + .collect::>() + .chunks(size) + .map(|chunk| Ok(chunk.iter().collect::())) + .collect::>(), + ) + } + } + fn messages( conversation: &ModelHandle, cx: &AppContext, diff --git a/crates/ai/src/streaming_diff.rs b/crates/ai/src/streaming_diff.rs new file mode 100644 index 0000000000..7399a7b4fa --- /dev/null +++ b/crates/ai/src/streaming_diff.rs @@ -0,0 +1,293 @@ +use collections::HashMap; +use ordered_float::OrderedFloat; +use std::{ + cmp, + fmt::{self, Debug}, + ops::Range, +}; + +struct Matrix { + cells: Vec, + rows: usize, + cols: usize, +} + +impl Matrix { + fn new() -> Self { + Self { + cells: Vec::new(), + rows: 0, + cols: 0, + } + } + + fn resize(&mut self, rows: usize, cols: usize) { + self.cells.resize(rows * cols, 0.); + self.rows = rows; + self.cols = cols; + } + + fn get(&self, row: usize, col: usize) -> f64 { + if row >= self.rows { + panic!("row out of bounds") + } + + if col >= self.cols { + panic!("col out of bounds") + } + self.cells[col * self.rows + row] + } + + fn set(&mut self, row: usize, col: usize, value: f64) { + if row >= self.rows { + panic!("row out of bounds") + } + + if col >= self.cols { + panic!("col out of bounds") + } + + self.cells[col * self.rows + row] = value; + } +} + +impl Debug for Matrix { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f)?; + for i in 0..self.rows { + for j in 0..self.cols { + write!(f, "{:5}", self.get(i, j))?; + } + writeln!(f)?; + } + Ok(()) + } +} + +#[derive(Debug)] +pub enum Hunk { + Insert { text: String }, + Remove { len: usize }, + Keep { len: usize }, +} + +pub struct StreamingDiff { + old: Vec, + new: Vec, + scores: Matrix, + old_text_ix: usize, + new_text_ix: usize, + equal_runs: HashMap<(usize, usize), u32>, +} + +impl StreamingDiff { + const INSERTION_SCORE: f64 = -1.; + const DELETION_SCORE: f64 = -20.; + const EQUALITY_BASE: f64 = 1.8; + const MAX_EQUALITY_EXPONENT: i32 = 16; + + pub fn new(old: String) -> Self { + let old = old.chars().collect::>(); + let mut scores = Matrix::new(); + scores.resize(old.len() + 1, 1); + for i in 0..=old.len() { + scores.set(i, 0, i as f64 * Self::DELETION_SCORE); + } + Self { + old, + new: Vec::new(), + scores, + old_text_ix: 0, + new_text_ix: 0, + equal_runs: Default::default(), + } + } + + pub fn push_new(&mut self, text: &str) -> Vec { + self.new.extend(text.chars()); + self.scores.resize(self.old.len() + 1, self.new.len() + 1); + + for j in self.new_text_ix + 1..=self.new.len() { + self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE); + for i in 1..=self.old.len() { + let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE; + let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE; + let equality_score = if self.old[i - 1] == self.new[j - 1] { + let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0); + equal_run += 1; + self.equal_runs.insert((i, j), equal_run); + + let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT); + self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent) + } else { + f64::NEG_INFINITY + }; + + let score = insertion_score.max(deletion_score).max(equality_score); + self.scores.set(i, j, score); + } + } + + let mut max_score = f64::NEG_INFINITY; + let mut next_old_text_ix = self.old_text_ix; + let next_new_text_ix = self.new.len(); + for i in self.old_text_ix..=self.old.len() { + let score = self.scores.get(i, next_new_text_ix); + if score > max_score { + max_score = score; + next_old_text_ix = i; + } + } + + let hunks = self.backtrack(next_old_text_ix, next_new_text_ix); + self.old_text_ix = next_old_text_ix; + self.new_text_ix = next_new_text_ix; + hunks + } + + fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec { + let mut pending_insert: Option> = None; + let mut hunks = Vec::new(); + let mut i = old_text_ix; + let mut j = new_text_ix; + while (i, j) != (self.old_text_ix, self.new_text_ix) { + let insertion_score = if j > self.new_text_ix { + Some((i, j - 1)) + } else { + None + }; + let deletion_score = if i > self.old_text_ix { + Some((i - 1, j)) + } else { + None + }; + let equality_score = if i > self.old_text_ix && j > self.new_text_ix { + if self.old[i - 1] == self.new[j - 1] { + Some((i - 1, j - 1)) + } else { + None + } + } else { + None + }; + + let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score] + .iter() + .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j)))) + .unwrap() + .unwrap(); + + if prev_i == i && prev_j == j - 1 { + if let Some(pending_insert) = pending_insert.as_mut() { + pending_insert.start = prev_j; + } else { + pending_insert = Some(prev_j..j); + } + } else { + if let Some(range) = pending_insert.take() { + hunks.push(Hunk::Insert { + text: self.new[range].iter().collect(), + }); + } + + let char_len = self.old[i - 1].len_utf8(); + if prev_i == i - 1 && prev_j == j { + if let Some(Hunk::Remove { len }) = hunks.last_mut() { + *len += char_len; + } else { + hunks.push(Hunk::Remove { len: char_len }) + } + } else { + if let Some(Hunk::Keep { len }) = hunks.last_mut() { + *len += char_len; + } else { + hunks.push(Hunk::Keep { len: char_len }) + } + } + } + + i = prev_i; + j = prev_j; + } + + if let Some(range) = pending_insert.take() { + hunks.push(Hunk::Insert { + text: self.new[range].iter().collect(), + }); + } + + hunks.reverse(); + hunks + } + + pub fn finish(self) -> Vec { + self.backtrack(self.old.len(), self.new.len()) + } +} + +#[cfg(test)] +mod tests { + use std::env; + + use super::*; + use rand::prelude::*; + + #[gpui::test(iterations = 100)] + fn test_random_diffs(mut rng: StdRng) { + let old_text_len = env::var("OLD_TEXT_LEN") + .map(|i| i.parse().expect("invalid `OLD_TEXT_LEN` variable")) + .unwrap_or(10); + let new_text_len = env::var("NEW_TEXT_LEN") + .map(|i| i.parse().expect("invalid `NEW_TEXT_LEN` variable")) + .unwrap_or(10); + + let old = util::RandomCharIter::new(&mut rng) + .take(old_text_len) + .collect::(); + log::info!("old text: {:?}", old); + + let mut diff = StreamingDiff::new(old.clone()); + let mut hunks = Vec::new(); + let mut new_len = 0; + let mut new = String::new(); + while new_len < new_text_len { + let new_chunk_len = rng.gen_range(1..=new_text_len - new_len); + let new_chunk = util::RandomCharIter::new(&mut rng) + .take(new_len) + .collect::(); + log::info!("new chunk: {:?}", new_chunk); + new_len += new_chunk_len; + new.push_str(&new_chunk); + let new_hunks = diff.push_new(&new_chunk); + log::info!("hunks: {:?}", new_hunks); + hunks.extend(new_hunks); + } + let final_hunks = diff.finish(); + log::info!("final hunks: {:?}", final_hunks); + hunks.extend(final_hunks); + + log::info!("new text: {:?}", new); + let mut old_ix = 0; + let mut new_ix = 0; + let mut patched = String::new(); + for hunk in hunks { + match hunk { + Hunk::Keep { len } => { + assert_eq!(&old[old_ix..old_ix + len], &new[new_ix..new_ix + len]); + patched.push_str(&old[old_ix..old_ix + len]); + old_ix += len; + new_ix += len; + } + Hunk::Remove { len } => { + old_ix += len; + } + Hunk::Insert { text } => { + assert_eq!(text, &new[new_ix..new_ix + text.len()]); + patched.push_str(&text); + new_ix += text.len(); + } + } + } + assert_eq!(patched, new); + } +} diff --git a/crates/breadcrumbs/src/breadcrumbs.rs b/crates/breadcrumbs/src/breadcrumbs.rs index 615e238648..41985edb75 100644 --- a/crates/breadcrumbs/src/breadcrumbs.rs +++ b/crates/breadcrumbs/src/breadcrumbs.rs @@ -50,7 +50,7 @@ impl View for Breadcrumbs { let not_editor = active_item.downcast::().is_none(); let theme = theme::current(cx).clone(); - let style = &theme.workspace.breadcrumbs; + let style = &theme.workspace.toolbar.breadcrumbs; let breadcrumbs = match active_item.breadcrumbs(&theme, cx) { Some(breadcrumbs) => breadcrumbs, @@ -60,7 +60,7 @@ impl View for Breadcrumbs { .map(|breadcrumb| { Text::new( breadcrumb.text, - theme.workspace.breadcrumbs.default.text.clone(), + theme.workspace.toolbar.breadcrumbs.default.text.clone(), ) .with_highlights(breadcrumb.highlights.unwrap_or_default()) .into_any() @@ -68,10 +68,10 @@ impl View for Breadcrumbs { let crumbs = Flex::row() .with_children(Itertools::intersperse_with(breadcrumbs, || { - Label::new(" 〉 ", style.default.text.clone()).into_any() + Label::new(" › ", style.default.text.clone()).into_any() })) .constrained() - .with_height(theme.workspace.breadcrumb_height) + .with_height(theme.workspace.toolbar.breadcrumb_height) .contained(); if not_editor || !self.pane_focused { diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 5af094df05..4db298fe98 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -273,7 +273,13 @@ impl ActiveCall { .borrow_mut() .take() .ok_or_else(|| anyhow!("no incoming call"))?; - Self::report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx); + Self::report_call_event_for_room( + "decline incoming", + Some(call.room_id), + None, + &self.client, + cx, + ); self.client.send(proto::DeclineCall { room_id: call.room_id, })?; @@ -404,21 +410,19 @@ impl ActiveCall { } fn report_call_event(&self, operation: &'static str, cx: &AppContext) { - if let Some(room) = self.room() { - let room = room.read(cx); - Self::report_call_event_for_room( - operation, - room.id(), - room.channel_id(), - &self.client, - cx, - ) - } + let (room_id, channel_id) = match self.room() { + Some(room) => { + let room = room.read(cx); + (Some(room.id()), room.channel_id()) + } + None => (None, None), + }; + Self::report_call_event_for_room(operation, room_id, channel_id, &self.client, cx) } pub fn report_call_event_for_room( operation: &'static str, - room_id: u64, + room_id: Option, channel_id: Option, client: &Arc, cx: &AppContext, diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index 29f4d3493c..e11282cf79 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -10,6 +10,7 @@ pub(crate) fn init(client: &Arc) { client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer); client.add_model_message_handler(ChannelBuffer::handle_add_channel_buffer_collaborator); client.add_model_message_handler(ChannelBuffer::handle_remove_channel_buffer_collaborator); + client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborator); } pub struct ChannelBuffer { @@ -17,6 +18,7 @@ pub struct ChannelBuffer { connected: bool, collaborators: Vec, buffer: ModelHandle, + buffer_epoch: u64, client: Arc, subscription: Option, } @@ -73,6 +75,7 @@ impl ChannelBuffer { Self { buffer, + buffer_epoch: response.epoch, client, connected: true, collaborators, @@ -82,6 +85,26 @@ impl ChannelBuffer { })) } + pub(crate) fn replace_collaborators( + &mut self, + collaborators: Vec, + cx: &mut ModelContext, + ) { + for old_collaborator in &self.collaborators { + if collaborators + .iter() + .any(|c| c.replica_id == old_collaborator.replica_id) + { + self.buffer.update(cx, |buffer, cx| { + buffer.remove_peer(old_collaborator.replica_id as u16, cx) + }); + } + } + self.collaborators = collaborators; + cx.emit(Event::CollaboratorsChanged); + cx.notify(); + } + async fn handle_update_channel_buffer( this: ModelHandle, update_channel_buffer: TypedEnvelope, @@ -149,6 +172,26 @@ impl ChannelBuffer { Ok(()) } + async fn handle_update_channel_buffer_collaborator( + this: ModelHandle, + message: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + for collaborator in &mut this.collaborators { + if collaborator.peer_id == message.payload.old_peer_id { + collaborator.peer_id = message.payload.new_peer_id; + break; + } + } + cx.emit(Event::CollaboratorsChanged); + cx.notify(); + }); + + Ok(()) + } + fn on_buffer_update( &mut self, _: ModelHandle, @@ -166,6 +209,10 @@ impl ChannelBuffer { } } + pub fn epoch(&self) -> u64 { + self.buffer_epoch + } + pub fn buffer(&self) -> ModelHandle { self.buffer.clone() } @@ -179,6 +226,7 @@ impl ChannelBuffer { } pub(crate) fn disconnect(&mut self, cx: &mut ModelContext) { + log::info!("channel buffer {} disconnected", self.channel.id); if self.connected { self.connected = false; self.subscription.take(); diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 861f731331..a4c8da6df4 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,13 +1,15 @@ use crate::channel_buffer::ChannelBuffer; use anyhow::{anyhow, Result}; -use client::{Client, Status, Subscription, User, UserId, UserStore}; +use client::{Client, Subscription, User, UserId, UserStore}; use collections::{hash_map, HashMap, HashSet}; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; -use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; +use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{proto, TypedEnvelope}; -use std::sync::Arc; +use std::{mem, sync::Arc, time::Duration}; use util::ResultExt; +pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); + pub type ChannelId = u64; pub struct ChannelStore { @@ -22,7 +24,8 @@ pub struct ChannelStore { client: Arc, user_store: ModelHandle, _rpc_subscription: Subscription, - _watch_connection_status: Task<()>, + _watch_connection_status: Task>, + disconnect_channel_buffers_task: Option>, _update_channels: Task<()>, } @@ -67,24 +70,20 @@ impl ChannelStore { let rpc_subscription = client.add_message_handler(cx.handle(), Self::handle_update_channels); - let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded(); let mut connection_status = client.status(); + let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded(); let watch_connection_status = cx.spawn_weak(|this, mut cx| async move { while let Some(status) = connection_status.next().await { - if !status.is_connected() { - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if matches!(status, Status::ConnectionLost | Status::SignedOut) { - this.handle_disconnect(cx); - } else { - this.disconnect_buffers(cx); - } - }); - } else { - break; - } + let this = this.upgrade(&cx)?; + if status.is_connected() { + this.update(&mut cx, |this, cx| this.handle_connect(cx)) + .await + .log_err()?; + } else { + this.update(&mut cx, |this, cx| this.handle_disconnect(cx)); } } + Some(()) }); Self { @@ -100,6 +99,7 @@ impl ChannelStore { user_store, _rpc_subscription: rpc_subscription, _watch_connection_status: watch_connection_status, + disconnect_channel_buffers_task: None, _update_channels: cx.spawn_weak(|this, mut cx| async move { while let Some(update_channels) = update_channels_rx.next().await { if let Some(this) = this.upgrade(&cx) { @@ -152,6 +152,15 @@ impl ChannelStore { self.channels_by_id.get(&channel_id) } + pub fn has_open_channel_buffer(&self, channel_id: ChannelId, cx: &AppContext) -> bool { + if let Some(buffer) = self.opened_buffers.get(&channel_id) { + if let OpenedChannelBuffer::Open(buffer) = buffer { + return buffer.upgrade(cx).is_some(); + } + } + false + } + pub fn open_channel_buffer( &mut self, channel_id: ChannelId, @@ -482,8 +491,106 @@ impl ChannelStore { Ok(()) } - fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) { - self.disconnect_buffers(cx); + fn handle_connect(&mut self, cx: &mut ModelContext) -> Task> { + self.disconnect_channel_buffers_task.take(); + + let mut buffer_versions = Vec::new(); + for buffer in self.opened_buffers.values() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + let channel_buffer = buffer.read(cx); + let buffer = channel_buffer.buffer().read(cx); + buffer_versions.push(proto::ChannelBufferVersion { + channel_id: channel_buffer.channel().id, + epoch: channel_buffer.epoch(), + version: language::proto::serialize_version(&buffer.version()), + }); + } + } + } + + if buffer_versions.is_empty() { + return Task::ready(Ok(())); + } + + let response = self.client.request(proto::RejoinChannelBuffers { + buffers: buffer_versions, + }); + + cx.spawn(|this, mut cx| async move { + let mut response = response.await?; + + this.update(&mut cx, |this, cx| { + this.opened_buffers.retain(|_, buffer| match buffer { + OpenedChannelBuffer::Open(channel_buffer) => { + let Some(channel_buffer) = channel_buffer.upgrade(cx) else { + return false; + }; + + channel_buffer.update(cx, |channel_buffer, cx| { + let channel_id = channel_buffer.channel().id; + if let Some(remote_buffer) = response + .buffers + .iter_mut() + .find(|buffer| buffer.channel_id == channel_id) + { + let channel_id = channel_buffer.channel().id; + let remote_version = + language::proto::deserialize_version(&remote_buffer.version); + + channel_buffer.replace_collaborators( + mem::take(&mut remote_buffer.collaborators), + cx, + ); + + let operations = channel_buffer + .buffer() + .update(cx, |buffer, cx| { + let outgoing_operations = + buffer.serialize_ops(Some(remote_version), cx); + let incoming_operations = + mem::take(&mut remote_buffer.operations) + .into_iter() + .map(language::proto::deserialize_operation) + .collect::>>()?; + buffer.apply_ops(incoming_operations, cx)?; + anyhow::Ok(outgoing_operations) + }) + .log_err(); + + if let Some(operations) = operations { + let client = this.client.clone(); + cx.background() + .spawn(async move { + let operations = operations.await; + for chunk in + language::proto::split_operations(operations) + { + client + .send(proto::UpdateChannelBuffer { + channel_id, + operations: chunk, + }) + .ok(); + } + }) + .detach(); + return true; + } + } + + channel_buffer.disconnect(cx); + false + }) + } + OpenedChannelBuffer::Loading(_) => true, + }); + }); + anyhow::Ok(()) + }) + } + + fn handle_disconnect(&mut self, cx: &mut ModelContext) { self.channels_by_id.clear(); self.channel_invitations.clear(); self.channel_participants.clear(); @@ -491,16 +598,23 @@ impl ChannelStore { self.channel_paths.clear(); self.outgoing_invites.clear(); cx.notify(); - } - fn disconnect_buffers(&mut self, cx: &mut ModelContext) { - for (_, buffer) in self.opened_buffers.drain() { - if let OpenedChannelBuffer::Open(buffer) = buffer { - if let Some(buffer) = buffer.upgrade(cx) { - buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + self.disconnect_channel_buffers_task.get_or_insert_with(|| { + cx.spawn_weak(|this, mut cx| async move { + cx.background().timer(RECONNECT_TIMEOUT).await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + for (_, buffer) in this.opened_buffers.drain() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + } + } + } + }); } - } - } + }) + }); } pub(crate) fn update_channels( diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index a32c415f7e..d28c1ab1a9 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1011,9 +1011,9 @@ impl Client { credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { - let is_preview = cx.read(|cx| { + let use_preview_server = cx.read(|cx| { if cx.has_global::() { - *cx.global::() == ReleaseChannel::Preview + *cx.global::() != ReleaseChannel::Stable } else { false } @@ -1028,7 +1028,7 @@ impl Client { let http = self.http.clone(); cx.background().spawn(async move { - let mut rpc_url = Self::get_rpc_url(http, is_preview).await?; + let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?; let rpc_host = rpc_url .host_str() .zip(rpc_url.port_or_known_default()) diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 9cc5d13af0..f8642dd7fa 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -73,7 +73,7 @@ pub enum ClickhouseEvent { }, Call { operation: &'static str, - room_id: u64, + room_id: Option, channel_id: Option, }, } diff --git a/crates/clock/src/clock.rs b/crates/clock/src/clock.rs index bc936fcb99..3cbf8d6594 100644 --- a/crates/clock/src/clock.rs +++ b/crates/clock/src/clock.rs @@ -2,70 +2,17 @@ use smallvec::SmallVec; use std::{ cmp::{self, Ordering}, fmt, iter, - ops::{Add, AddAssign}, }; pub type ReplicaId = u16; pub type Seq = u32; -#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, Ord, PartialOrd)] -pub struct Local { - pub replica_id: ReplicaId, - pub value: Seq, -} - #[derive(Clone, Copy, Default, Eq, Hash, PartialEq)] pub struct Lamport { pub replica_id: ReplicaId, pub value: Seq, } -impl Local { - pub const MIN: Self = Self { - replica_id: ReplicaId::MIN, - value: Seq::MIN, - }; - pub const MAX: Self = Self { - replica_id: ReplicaId::MAX, - value: Seq::MAX, - }; - - pub fn new(replica_id: ReplicaId) -> Self { - Self { - replica_id, - value: 1, - } - } - - pub fn tick(&mut self) -> Self { - let timestamp = *self; - self.value += 1; - timestamp - } - - pub fn observe(&mut self, timestamp: Self) { - if timestamp.replica_id == self.replica_id { - self.value = cmp::max(self.value, timestamp.value + 1); - } - } -} - -impl<'a> Add<&'a Self> for Local { - type Output = Local; - - fn add(self, other: &'a Self) -> Self::Output { - *cmp::max(&self, other) - } -} - -impl<'a> AddAssign<&'a Local> for Local { - fn add_assign(&mut self, other: &Self) { - if *self < *other { - *self = *other; - } - } -} - /// A vector clock #[derive(Clone, Default, Hash, Eq, PartialEq)] pub struct Global(SmallVec<[u32; 8]>); @@ -79,7 +26,7 @@ impl Global { self.0.get(replica_id as usize).copied().unwrap_or(0) as Seq } - pub fn observe(&mut self, timestamp: Local) { + pub fn observe(&mut self, timestamp: Lamport) { if timestamp.value > 0 { let new_len = timestamp.replica_id as usize + 1; if new_len > self.0.len() { @@ -126,7 +73,7 @@ impl Global { self.0.resize(new_len, 0); } - pub fn observed(&self, timestamp: Local) -> bool { + pub fn observed(&self, timestamp: Lamport) -> bool { self.get(timestamp.replica_id) >= timestamp.value } @@ -178,16 +125,16 @@ impl Global { false } - pub fn iter(&self) -> impl Iterator + '_ { - self.0.iter().enumerate().map(|(replica_id, seq)| Local { + pub fn iter(&self) -> impl Iterator + '_ { + self.0.iter().enumerate().map(|(replica_id, seq)| Lamport { replica_id: replica_id as ReplicaId, value: *seq, }) } } -impl FromIterator for Global { - fn from_iter>(locals: T) -> Self { +impl FromIterator for Global { + fn from_iter>(locals: T) -> Self { let mut result = Self::new(); for local in locals { result.observe(local); @@ -212,6 +159,16 @@ impl PartialOrd for Lamport { } impl Lamport { + pub const MIN: Self = Self { + replica_id: ReplicaId::MIN, + value: Seq::MIN, + }; + + pub const MAX: Self = Self { + replica_id: ReplicaId::MAX, + value: Seq::MAX, + }; + pub fn new(replica_id: ReplicaId) -> Self { Self { value: 1, @@ -230,12 +187,6 @@ impl Lamport { } } -impl fmt::Debug for Local { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Local {{{}: {}}}", self.replica_id, self.value) - } -} - impl fmt::Debug for Lamport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Lamport {{{}: {}}}", self.replica_id, self.value) diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 8adc38615c..4b29a08015 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] default-run = "collab" edition = "2021" name = "collab" -version = "0.18.0" +version = "0.20.0" publish = false [[bin]] @@ -80,6 +80,7 @@ theme = { path = "../theme" } workspace = { path = "../workspace", features = ["test-support"] } collab_ui = { path = "../collab_ui", features = ["test-support"] } +async-trait.workspace = true ctor.workspace = true env_logger.workspace = true indoc.workspace = true diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 888158188f..b5d968ddf3 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -435,6 +435,12 @@ pub struct ChannelsForUser { pub channels_with_admin_privileges: HashSet, } +#[derive(Debug)] +pub struct RejoinedChannelBuffer { + pub buffer: proto::RejoinedChannelBuffer, + pub old_connection_id: ConnectionId, +} + #[derive(Clone)] pub struct JoinRoom { pub room: proto::Room, @@ -498,6 +504,11 @@ pub struct RefreshedRoom { pub canceled_calls_to_user_ids: Vec, } +pub struct RefreshedChannelBuffer { + pub connection_ids: Vec, + pub removed_collaborators: Vec, +} + pub struct Project { pub collaborators: Vec, pub worktrees: BTreeMap, diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index 354accc01a..00de201403 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -1,6 +1,6 @@ use super::*; use prost::Message; -use text::{EditOperation, InsertionTimestamp, UndoOperation}; +use text::{EditOperation, UndoOperation}; impl Database { pub async fn join_channel_buffer( @@ -10,8 +10,6 @@ impl Database { connection: ConnectionId, ) -> Result { self.transaction(|tx| async move { - let tx = tx; - self.check_user_is_channel_member(channel_id, user_id, &tx) .await?; @@ -70,7 +68,6 @@ impl Database { .await?; collaborators.push(collaborator); - // Assemble the buffer state let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?; Ok(proto::JoinChannelBufferResponse { @@ -78,6 +75,7 @@ impl Database { replica_id: replica_id.to_proto() as u32, base_text, operations, + epoch: buffer.epoch as u64, collaborators: collaborators .into_iter() .map(|collaborator| proto::Collaborator { @@ -91,6 +89,154 @@ impl Database { .await } + pub async fn rejoin_channel_buffers( + &self, + buffers: &[proto::ChannelBufferVersion], + user_id: UserId, + connection_id: ConnectionId, + ) -> Result> { + self.transaction(|tx| async move { + let mut results = Vec::new(); + for client_buffer in buffers { + let channel_id = ChannelId::from_proto(client_buffer.channel_id); + if self + .check_user_is_channel_member(channel_id, user_id, &*tx) + .await + .is_err() + { + log::info!("user is not a member of channel"); + continue; + } + + let buffer = self.get_channel_buffer(channel_id, &*tx).await?; + let mut collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + + // If the buffer epoch hasn't changed since the client lost + // connection, then the client's buffer can be syncronized with + // the server's buffer. + if buffer.epoch as u64 != client_buffer.epoch { + log::info!("can't rejoin buffer, epoch has changed"); + continue; + } + + // Find the collaborator record for this user's previous lost + // connection. Update it with the new connection id. + let server_id = ServerId(connection_id.owner_id as i32); + let Some(self_collaborator) = collaborators.iter_mut().find(|c| { + c.user_id == user_id + && (c.connection_lost || c.connection_server_id != server_id) + }) else { + log::info!("can't rejoin buffer, no previous collaborator found"); + continue; + }; + let old_connection_id = self_collaborator.connection(); + *self_collaborator = channel_buffer_collaborator::ActiveModel { + id: ActiveValue::Unchanged(self_collaborator.id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)), + connection_lost: ActiveValue::Set(false), + ..Default::default() + } + .update(&*tx) + .await?; + + let client_version = version_from_wire(&client_buffer.version); + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; + + let mut rows = buffer_operation::Entity::find() + .filter( + buffer_operation::Column::BufferId + .eq(buffer.id) + .and(buffer_operation::Column::Epoch.eq(buffer.epoch)), + ) + .stream(&*tx) + .await?; + + // Find the server's version vector and any operations + // that the client has not seen. + let mut server_version = clock::Global::new(); + let mut operations = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let timestamp = clock::Lamport { + replica_id: row.replica_id as u16, + value: row.lamport_timestamp as u32, + }; + server_version.observe(timestamp); + if !client_version.observed(timestamp) { + operations.push(proto::Operation { + variant: Some(operation_from_storage(row, serialization_version)?), + }) + } + } + + results.push(RejoinedChannelBuffer { + old_connection_id, + buffer: proto::RejoinedChannelBuffer { + channel_id: client_buffer.channel_id, + version: version_to_wire(&server_version), + operations, + collaborators: collaborators + .into_iter() + .map(|collaborator| proto::Collaborator { + peer_id: Some(collaborator.connection().into()), + user_id: collaborator.user_id.to_proto(), + replica_id: collaborator.replica_id.0 as u32, + }) + .collect(), + }, + }); + } + + Ok(results) + }) + .await + } + + pub async fn clear_stale_channel_buffer_collaborators( + &self, + channel_id: ChannelId, + server_id: ServerId, + ) -> Result { + self.transaction(|tx| async move { + let collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + + let mut connection_ids = Vec::new(); + let mut removed_collaborators = Vec::new(); + let mut collaborator_ids_to_remove = Vec::new(); + for collaborator in &collaborators { + if !collaborator.connection_lost && collaborator.connection_server_id == server_id { + connection_ids.push(collaborator.connection()); + } else { + removed_collaborators.push(proto::RemoveChannelBufferCollaborator { + channel_id: channel_id.to_proto(), + peer_id: Some(collaborator.connection().into()), + }); + collaborator_ids_to_remove.push(collaborator.id); + } + } + + channel_buffer_collaborator::Entity::delete_many() + .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove)) + .exec(&*tx) + .await?; + + Ok(RefreshedChannelBuffer { + connection_ids, + removed_collaborators, + }) + }) + .await + } + pub async fn leave_channel_buffer( &self, channel_id: ChannelId, @@ -103,6 +249,39 @@ impl Database { .await } + pub async fn leave_channel_buffers( + &self, + connection: ConnectionId, + ) -> Result)>> { + self.transaction(|tx| async move { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + + let channel_ids: Vec = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .filter(Condition::all().add( + channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), + )) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + + let mut result = Vec::new(); + for channel_id in channel_ids { + let collaborators = self + .leave_channel_buffer_internal(channel_id, connection, &*tx) + .await?; + result.push((channel_id, collaborators)); + } + + Ok(result) + }) + .await + } + pub async fn leave_channel_buffer_internal( &self, channel_id: ChannelId, @@ -143,46 +322,12 @@ impl Database { drop(rows); if connections.is_empty() { - self.snapshot_buffer(channel_id, &tx).await?; + self.snapshot_channel_buffer(channel_id, &tx).await?; } Ok(connections) } - pub async fn leave_channel_buffers( - &self, - connection: ConnectionId, - ) -> Result)>> { - self.transaction(|tx| async move { - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryChannelIds { - ChannelId, - } - - let channel_ids: Vec = channel_buffer_collaborator::Entity::find() - .select_only() - .column(channel_buffer_collaborator::Column::ChannelId) - .filter(Condition::all().add( - channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), - )) - .into_values::<_, QueryChannelIds>() - .all(&*tx) - .await?; - - let mut result = Vec::new(); - for channel_id in channel_ids { - let collaborators = self - .leave_channel_buffer_internal(channel_id, connection, &*tx) - .await?; - result.push((channel_id, collaborators)); - } - - Ok(result) - }) - .await - } - - #[cfg(debug_assertions)] pub async fn get_channel_buffer_collaborators( &self, channel_id: ChannelId, @@ -225,20 +370,9 @@ impl Database { .await? .ok_or_else(|| anyhow!("no such buffer"))?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryVersion { - OperationSerializationVersion, - } - - let serialization_version: i32 = buffer - .find_related(buffer_snapshot::Entity) - .select_only() - .column(buffer_snapshot::Column::OperationSerializationVersion) - .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch)) - .into_values::<_, QueryVersion>() - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("missing buffer snapshot"))?; + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; let operations = operations .iter() @@ -246,6 +380,16 @@ impl Database { .collect::>(); if !operations.is_empty() { buffer_operation::Entity::insert_many(operations) + .on_conflict( + OnConflict::columns([ + buffer_operation::Column::BufferId, + buffer_operation::Column::Epoch, + buffer_operation::Column::LamportTimestamp, + buffer_operation::Column::ReplicaId, + ]) + .do_nothing() + .to_owned(), + ) .exec(&*tx) .await?; } @@ -271,6 +415,38 @@ impl Database { .await } + async fn get_buffer_operation_serialization_version( + &self, + buffer_id: BufferId, + epoch: i32, + tx: &DatabaseTransaction, + ) -> Result { + Ok(buffer_snapshot::Entity::find() + .filter(buffer_snapshot::Column::BufferId.eq(buffer_id)) + .filter(buffer_snapshot::Column::Epoch.eq(epoch)) + .select_only() + .column(buffer_snapshot::Column::OperationSerializationVersion) + .into_values::<_, QueryOperationSerializationVersion>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("missing buffer snapshot"))?) + } + + async fn get_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Model { + id: channel_id, + ..Default::default() + } + .find_related(buffer::Entity) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such buffer"))?) + } + async fn get_buffer_state( &self, buffer: &buffer::Model, @@ -304,27 +480,20 @@ impl Database { .await?; let mut operations = Vec::new(); while let Some(row) = rows.next().await { - let row = row?; - - let operation = operation_from_storage(row, version)?; operations.push(proto::Operation { - variant: Some(operation), + variant: Some(operation_from_storage(row?, version)?), }) } Ok((base_text, operations)) } - async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> { - let buffer = channel::Model { - id: channel_id, - ..Default::default() - } - .find_related(buffer::Entity) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such buffer"))?; - + async fn snapshot_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result<()> { + let buffer = self.get_channel_buffer(channel_id, tx).await?; let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?; if operations.is_empty() { return Ok(()); @@ -370,7 +539,6 @@ fn operation_to_storage( operation.replica_id, operation.lamport_timestamp, storage::Operation { - local_timestamp: operation.local_timestamp, version: version_to_storage(&operation.version), is_undo: false, edit_ranges: operation @@ -389,7 +557,6 @@ fn operation_to_storage( operation.replica_id, operation.lamport_timestamp, storage::Operation { - local_timestamp: operation.local_timestamp, version: version_to_storage(&operation.version), is_undo: true, edit_ranges: Vec::new(), @@ -399,7 +566,7 @@ fn operation_to_storage( .iter() .map(|entry| storage::UndoCount { replica_id: entry.replica_id, - local_timestamp: entry.local_timestamp, + lamport_timestamp: entry.lamport_timestamp, count: entry.count, }) .collect(), @@ -427,7 +594,6 @@ fn operation_from_storage( Ok(if operation.is_undo { proto::operation::Variant::Undo(proto::operation::Undo { replica_id: row.replica_id as u32, - local_timestamp: operation.local_timestamp as u32, lamport_timestamp: row.lamport_timestamp as u32, version, counts: operation @@ -435,7 +601,7 @@ fn operation_from_storage( .iter() .map(|entry| proto::UndoCount { replica_id: entry.replica_id, - local_timestamp: entry.local_timestamp, + lamport_timestamp: entry.lamport_timestamp, count: entry.count, }) .collect(), @@ -443,7 +609,6 @@ fn operation_from_storage( } else { proto::operation::Variant::Edit(proto::operation::Edit { replica_id: row.replica_id as u32, - local_timestamp: operation.local_timestamp as u32, lamport_timestamp: row.lamport_timestamp as u32, version, ranges: operation @@ -483,10 +648,9 @@ fn version_from_storage(version: &Vec) -> Vec Option { match operation.variant? { proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation { - timestamp: InsertionTimestamp { + timestamp: clock::Lamport { replica_id: edit.replica_id as text::ReplicaId, - local: edit.local_timestamp, - lamport: edit.lamport_timestamp, + value: edit.lamport_timestamp, }, version: version_from_wire(&edit.version), ranges: edit @@ -498,32 +662,26 @@ pub fn operation_from_wire(operation: proto::Operation) -> Option Some(text::Operation::Undo { - lamport_timestamp: clock::Lamport { + proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation { + timestamp: clock::Lamport { replica_id: undo.replica_id as text::ReplicaId, value: undo.lamport_timestamp, }, - undo: UndoOperation { - id: clock::Local { - replica_id: undo.replica_id as text::ReplicaId, - value: undo.local_timestamp, - }, - version: version_from_wire(&undo.version), - counts: undo - .counts - .into_iter() - .map(|c| { - ( - clock::Local { - replica_id: c.replica_id as text::ReplicaId, - value: c.local_timestamp, - }, - c.count, - ) - }) - .collect(), - }, - }), + version: version_from_wire(&undo.version), + counts: undo + .counts + .into_iter() + .map(|c| { + ( + clock::Lamport { + replica_id: c.replica_id as text::ReplicaId, + value: c.lamport_timestamp, + }, + c.count, + ) + }) + .collect(), + })), _ => None, } } @@ -531,7 +689,7 @@ pub fn operation_from_wire(operation: proto::Operation) -> Option clock::Global { let mut version = clock::Global::new(); for entry in message { - version.observe(clock::Local { + version.observe(clock::Lamport { replica_id: entry.replica_id as text::ReplicaId, value: entry.timestamp, }); @@ -539,6 +697,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global { version } +fn version_to_wire(version: &clock::Global) -> Vec { + let mut message = Vec::new(); + for entry in version.iter() { + message.push(proto::VectorClockEntry { + replica_id: entry.replica_id as u32, + timestamp: entry.value, + }); + } + message +} + +#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] +enum QueryOperationSerializationVersion { + OperationSerializationVersion, +} + mod storage { #![allow(non_snake_case)] use prost::Message; @@ -546,8 +720,6 @@ mod storage { #[derive(Message)] pub struct Operation { - #[prost(uint32, tag = "1")] - pub local_timestamp: u32, #[prost(message, repeated, tag = "2")] pub version: Vec, #[prost(bool, tag = "3")] @@ -581,7 +753,7 @@ mod storage { #[prost(uint32, tag = "1")] pub replica_id: u32, #[prost(uint32, tag = "2")] - pub local_timestamp: u32, + pub lamport_timestamp: u32, #[prost(uint32, tag = "3")] pub count: u32, } diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index e3d3643a61..5da4dd1464 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -1,6 +1,20 @@ use super::*; impl Database { + #[cfg(test)] + pub async fn all_channels(&self) -> Result> { + self.transaction(move |tx| async move { + let mut channels = Vec::new(); + let mut rows = channel::Entity::find().stream(&*tx).await?; + while let Some(row) = rows.next().await { + let row = row?; + channels.push((row.id, row.name)); + } + Ok(channels) + }) + .await + } + pub async fn create_root_channel( &self, name: &str, diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index 435e729fed..e348b50bee 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -1,7 +1,7 @@ use super::*; impl Database { - pub async fn refresh_room( + pub async fn clear_stale_room_participants( &self, room_id: RoomId, new_server_id: ServerId, diff --git a/crates/collab/src/db/queries/servers.rs b/crates/collab/src/db/queries/servers.rs index 08a2bda16a..e5ceee8887 100644 --- a/crates/collab/src/db/queries/servers.rs +++ b/crates/collab/src/db/queries/servers.rs @@ -14,31 +14,49 @@ impl Database { .await } - pub async fn stale_room_ids( + pub async fn stale_server_resource_ids( &self, environment: &str, new_server_id: ServerId, - ) -> Result> { + ) -> Result<(Vec, Vec)> { self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { + enum QueryRoomIds { RoomId, } + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + let stale_server_epochs = self .stale_server_ids(environment, new_server_id, &tx) .await?; - Ok(room_participant::Entity::find() + let room_ids = room_participant::Entity::find() .select_only() .column(room_participant::Column::RoomId) .distinct() .filter( room_participant::Column::AnsweringConnectionServerId - .is_in(stale_server_epochs), + .is_in(stale_server_epochs.iter().copied()), ) - .into_values::<_, QueryAs>() + .into_values::<_, QueryRoomIds>() .all(&*tx) - .await?) + .await?; + let channel_ids = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .distinct() + .filter( + channel_buffer_collaborator::Column::ConnectionServerId + .is_in(stale_server_epochs.iter().copied()), + ) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + + Ok((room_ids, channel_ids)) }) .await } diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index bd7c3e9ffd..5cb1ef6ea3 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -241,7 +241,6 @@ impl Database { result } - #[cfg(debug_assertions)] pub async fn create_user_flag(&self, flag: &str) -> Result { self.transaction(|tx| async move { let flag = feature_flag::Entity::insert(feature_flag::ActiveModel { @@ -257,7 +256,6 @@ impl Database { .await } - #[cfg(debug_assertions)] pub async fn add_user_flag(&self, user: UserId, flag: FlagId) -> Result<()> { self.transaction(|tx| async move { user_feature::Entity::insert(user_feature::ActiveModel { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 6b44711c42..e454fcbb9e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -251,6 +251,7 @@ impl Server { .add_request_handler(join_channel_buffer) .add_request_handler(leave_channel_buffer) .add_message_handler(update_channel_buffer) + .add_request_handler(rejoin_channel_buffers) .add_request_handler(get_channel_members) .add_request_handler(respond_to_channel_invite) .add_request_handler(join_channel) @@ -277,13 +278,33 @@ impl Server { tracing::info!("waiting for cleanup timeout"); timeout.await; tracing::info!("cleanup timeout expired, retrieving stale rooms"); - if let Some(room_ids) = app_state + if let Some((room_ids, channel_ids)) = app_state .db - .stale_room_ids(&app_state.config.zed_environment, server_id) + .stale_server_resource_ids(&app_state.config.zed_environment, server_id) .await .trace_err() { tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms"); + tracing::info!( + stale_channel_buffer_count = channel_ids.len(), + "retrieved stale channel buffers" + ); + + for channel_id in channel_ids { + if let Some(refreshed_channel_buffer) = app_state + .db + .clear_stale_channel_buffer_collaborators(channel_id, server_id) + .await + .trace_err() + { + for connection_id in refreshed_channel_buffer.connection_ids { + for message in &refreshed_channel_buffer.removed_collaborators { + peer.send(connection_id, message.clone()).trace_err(); + } + } + } + } + for room_id in room_ids { let mut contacts_to_update = HashSet::default(); let mut canceled_calls_to_user_ids = Vec::new(); @@ -292,7 +313,7 @@ impl Server { if let Some(mut refreshed_room) = app_state .db - .refresh_room(room_id, server_id) + .clear_stale_room_participants(room_id, server_id) .await .trace_err() { @@ -854,13 +875,13 @@ async fn connection_lost( .await .trace_err(); - leave_channel_buffers_for_session(&session) - .await - .trace_err(); - futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { + log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id); leave_room_for_session(&session).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); if !session .connection_pool() @@ -2547,6 +2568,41 @@ async fn update_channel_buffer( Ok(()) } +async fn rejoin_channel_buffers( + request: proto::RejoinChannelBuffers, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let buffers = db + .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) + .await?; + + for buffer in &buffers { + let collaborators_to_notify = buffer + .buffer + .collaborators + .iter() + .filter_map(|c| Some(c.peer_id?.into())); + channel_buffer_updated( + session.connection_id, + collaborators_to_notify, + &proto::UpdateChannelBufferCollaborator { + channel_id: buffer.buffer.channel_id, + old_peer_id: Some(buffer.old_connection_id.into()), + new_peer_id: Some(session.connection_id.into()), + }, + &session.peer, + ); + } + + response.send(proto::RejoinChannelBuffersResponse { + buffers: buffers.into_iter().map(|b| b.buffer).collect(), + })?; + + Ok(()) +} + async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 25f059c0aa..3000f0d8c3 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -1,555 +1,18 @@ -use crate::{ - db::{tests::TestDb, NewUserParams, UserId}, - executor::Executor, - rpc::{Server, CLEANUP_TIMEOUT}, - AppState, -}; -use anyhow::anyhow; -use call::{ActiveCall, Room}; -use channel::ChannelStore; -use client::{ - self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore, -}; -use collections::{HashMap, HashSet}; -use fs::FakeFs; -use futures::{channel::oneshot, StreamExt as _}; -use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; -use language::LanguageRegistry; -use parking_lot::Mutex; -use project::{Project, WorktreeId}; -use settings::SettingsStore; -use std::{ - cell::{Ref, RefCell, RefMut}, - env, - ops::{Deref, DerefMut}, - path::Path, - sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, - Arc, - }, -}; -use util::http::FakeHttpClient; -use workspace::Workspace; +use call::Room; +use gpui::{ModelHandle, TestAppContext}; mod channel_buffer_tests; mod channel_tests; mod integration_tests; -mod randomized_integration_tests; +mod random_channel_buffer_tests; +mod random_project_collaboration_tests; +mod randomized_test_helpers; +mod test_server; -struct TestServer { - app_state: Arc, - server: Arc, - connection_killers: Arc>>>, - forbid_connections: Arc, - _test_db: TestDb, - test_live_kit_server: Arc, -} - -impl TestServer { - async fn start(deterministic: &Arc) -> Self { - static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - - let use_postgres = env::var("USE_POSTGRES").ok(); - let use_postgres = use_postgres.as_deref(); - let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { - TestDb::postgres(deterministic.build_background()) - } else { - TestDb::sqlite(deterministic.build_background()) - }; - let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); - let live_kit_server = live_kit_client::TestServer::create( - format!("http://livekit.{}.test", live_kit_server_id), - format!("devkey-{}", live_kit_server_id), - format!("secret-{}", live_kit_server_id), - deterministic.build_background(), - ) - .unwrap(); - let app_state = Self::build_app_state(&test_db, &live_kit_server).await; - let epoch = app_state - .db - .create_server(&app_state.config.zed_environment) - .await - .unwrap(); - let server = Server::new( - epoch, - app_state.clone(), - Executor::Deterministic(deterministic.build_background()), - ); - server.start().await.unwrap(); - // Advance clock to ensure the server's cleanup task is finished. - deterministic.advance_clock(CLEANUP_TIMEOUT); - Self { - app_state, - server, - connection_killers: Default::default(), - forbid_connections: Default::default(), - _test_db: test_db, - test_live_kit_server: live_kit_server, - } - } - - async fn reset(&self) { - self.app_state.db.reset(); - let epoch = self - .app_state - .db - .create_server(&self.app_state.config.zed_environment) - .await - .unwrap(); - self.server.reset(epoch); - } - - async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { - cx.update(|cx| { - if cx.has_global::() { - panic!("Same cx used to create two test clients") - } - cx.set_global(SettingsStore::test(cx)); - }); - - let http = FakeHttpClient::with_404_response(); - let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await - { - user.id - } else { - self.app_state - .db - .create_user( - &format!("{name}@example.com"), - false, - NewUserParams { - github_login: name.into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .expect("creating user failed") - .user_id - }; - let client_name = name.to_string(); - let mut client = cx.read(|cx| Client::new(http.clone(), cx)); - let server = self.server.clone(); - let db = self.app_state.db.clone(); - let connection_killers = self.connection_killers.clone(); - let forbid_connections = self.forbid_connections.clone(); - - Arc::get_mut(&mut client) - .unwrap() - .set_id(user_id.0 as usize) - .override_authenticate(move |cx| { - cx.spawn(|_| async move { - let access_token = "the-token".to_string(); - Ok(Credentials { - user_id: user_id.0 as u64, - access_token, - }) - }) - }) - .override_establish_connection(move |credentials, cx| { - assert_eq!(credentials.user_id, user_id.0 as u64); - assert_eq!(credentials.access_token, "the-token"); - - let server = server.clone(); - let db = db.clone(); - let connection_killers = connection_killers.clone(); - let forbid_connections = forbid_connections.clone(); - let client_name = client_name.clone(); - cx.spawn(move |cx| async move { - if forbid_connections.load(SeqCst) { - Err(EstablishConnectionError::other(anyhow!( - "server is forbidding connections" - ))) - } else { - let (client_conn, server_conn, killed) = - Connection::in_memory(cx.background()); - let (connection_id_tx, connection_id_rx) = oneshot::channel(); - let user = db - .get_user_by_id(user_id) - .await - .expect("retrieving user failed") - .unwrap(); - cx.background() - .spawn(server.handle_connection( - server_conn, - client_name, - user, - Some(connection_id_tx), - Executor::Deterministic(cx.background()), - )) - .detach(); - let connection_id = connection_id_rx.await.unwrap(); - connection_killers - .lock() - .insert(connection_id.into(), killed); - Ok(client_conn) - } - }) - }); - - let fs = FakeFs::new(cx.background()); - let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); - let channel_store = - cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx)); - let app_state = Arc::new(workspace::AppState { - client: client.clone(), - user_store: user_store.clone(), - channel_store: channel_store.clone(), - languages: Arc::new(LanguageRegistry::test()), - fs: fs.clone(), - build_window_options: |_, _, _| Default::default(), - initialize_workspace: |_, _, _, _| Task::ready(Ok(())), - background_actions: || &[], - }); - - cx.update(|cx| { - theme::init((), cx); - Project::init(&client, cx); - client::init(&client, cx); - language::init(cx); - editor::init_settings(cx); - workspace::init(app_state.clone(), cx); - audio::init((), cx); - call::init(client.clone(), user_store.clone(), cx); - channel::init(&client); - }); - - client - .authenticate_and_connect(false, &cx.to_async()) - .await - .unwrap(); - - let client = TestClient { - app_state, - username: name.to_string(), - state: Default::default(), - }; - client.wait_for_current_user(cx).await; - client - } - - fn disconnect_client(&self, peer_id: PeerId) { - self.connection_killers - .lock() - .remove(&peer_id) - .unwrap() - .store(true, SeqCst); - } - - fn forbid_connections(&self) { - self.forbid_connections.store(true, SeqCst); - } - - fn allow_connections(&self) { - self.forbid_connections.store(false, SeqCst); - } - - async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { - for ix in 1..clients.len() { - let (left, right) = clients.split_at_mut(ix); - let (client_a, cx_a) = left.last_mut().unwrap(); - for (client_b, cx_b) in right { - client_a - .app_state - .user_store - .update(*cx_a, |store, cx| { - store.request_contact(client_b.user_id().unwrap(), cx) - }) - .await - .unwrap(); - cx_a.foreground().run_until_parked(); - client_b - .app_state - .user_store - .update(*cx_b, |store, cx| { - store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx) - }) - .await - .unwrap(); - } - } - } - - async fn make_channel( - &self, - channel: &str, - admin: (&TestClient, &mut TestAppContext), - members: &mut [(&TestClient, &mut TestAppContext)], - ) -> u64 { - let (admin_client, admin_cx) = admin; - let channel_id = admin_client - .app_state - .channel_store - .update(admin_cx, |channel_store, cx| { - channel_store.create_channel(channel, None, cx) - }) - .await - .unwrap(); - - for (member_client, member_cx) in members { - admin_client - .app_state - .channel_store - .update(admin_cx, |channel_store, cx| { - channel_store.invite_member( - channel_id, - member_client.user_id().unwrap(), - false, - cx, - ) - }) - .await - .unwrap(); - - admin_cx.foreground().run_until_parked(); - - member_client - .app_state - .channel_store - .update(*member_cx, |channels, _| { - channels.respond_to_channel_invite(channel_id, true) - }) - .await - .unwrap(); - } - - channel_id - } - - async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { - self.make_contacts(clients).await; - - let (left, right) = clients.split_at_mut(1); - let (_client_a, cx_a) = &mut left[0]; - let active_call_a = cx_a.read(ActiveCall::global); - - for (client_b, cx_b) in right { - let user_id_b = client_b.current_user_id(*cx_b).to_proto(); - active_call_a - .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx)) - .await - .unwrap(); - - cx_b.foreground().run_until_parked(); - let active_call_b = cx_b.read(ActiveCall::global); - active_call_b - .update(*cx_b, |call, cx| call.accept_incoming(cx)) - .await - .unwrap(); - } - } - - async fn build_app_state( - test_db: &TestDb, - fake_server: &live_kit_client::TestServer, - ) -> Arc { - Arc::new(AppState { - db: test_db.db().clone(), - live_kit_client: Some(Arc::new(fake_server.create_api_client())), - config: Default::default(), - }) - } -} - -impl Deref for TestServer { - type Target = Server; - - fn deref(&self) -> &Self::Target { - &self.server - } -} - -impl Drop for TestServer { - fn drop(&mut self) { - self.server.teardown(); - self.test_live_kit_server.teardown().unwrap(); - } -} - -struct TestClient { - username: String, - state: RefCell, - app_state: Arc, -} - -#[derive(Default)] -struct TestClientState { - local_projects: Vec>, - remote_projects: Vec>, - buffers: HashMap, HashSet>>, -} - -impl Deref for TestClient { - type Target = Arc; - - fn deref(&self) -> &Self::Target { - &self.app_state.client - } -} - -struct ContactsSummary { - pub current: Vec, - pub outgoing_requests: Vec, - pub incoming_requests: Vec, -} - -impl TestClient { - pub fn fs(&self) -> &FakeFs { - self.app_state.fs.as_fake() - } - - pub fn channel_store(&self) -> &ModelHandle { - &self.app_state.channel_store - } - - pub fn user_store(&self) -> &ModelHandle { - &self.app_state.user_store - } - - pub fn language_registry(&self) -> &Arc { - &self.app_state.languages - } - - pub fn client(&self) -> &Arc { - &self.app_state.client - } - - pub fn current_user_id(&self, cx: &TestAppContext) -> UserId { - UserId::from_proto( - self.app_state - .user_store - .read_with(cx, |user_store, _| user_store.current_user().unwrap().id), - ) - } - - async fn wait_for_current_user(&self, cx: &TestAppContext) { - let mut authed_user = self - .app_state - .user_store - .read_with(cx, |user_store, _| user_store.watch_current_user()); - while authed_user.next().await.unwrap().is_none() {} - } - - async fn clear_contacts(&self, cx: &mut TestAppContext) { - self.app_state - .user_store - .update(cx, |store, _| store.clear_contacts()) - .await; - } - - fn local_projects<'a>(&'a self) -> impl Deref>> + 'a { - Ref::map(self.state.borrow(), |state| &state.local_projects) - } - - fn remote_projects<'a>(&'a self) -> impl Deref>> + 'a { - Ref::map(self.state.borrow(), |state| &state.remote_projects) - } - - fn local_projects_mut<'a>(&'a self) -> impl DerefMut>> + 'a { - RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects) - } - - fn remote_projects_mut<'a>(&'a self) -> impl DerefMut>> + 'a { - RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects) - } - - fn buffers_for_project<'a>( - &'a self, - project: &ModelHandle, - ) -> impl DerefMut>> + 'a { - RefMut::map(self.state.borrow_mut(), |state| { - state.buffers.entry(project.clone()).or_default() - }) - } - - fn buffers<'a>( - &'a self, - ) -> impl DerefMut, HashSet>>> + 'a - { - RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers) - } - - fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary { - self.app_state - .user_store - .read_with(cx, |store, _| ContactsSummary { - current: store - .contacts() - .iter() - .map(|contact| contact.user.github_login.clone()) - .collect(), - outgoing_requests: store - .outgoing_contact_requests() - .iter() - .map(|user| user.github_login.clone()) - .collect(), - incoming_requests: store - .incoming_contact_requests() - .iter() - .map(|user| user.github_login.clone()) - .collect(), - }) - } - - async fn build_local_project( - &self, - root_path: impl AsRef, - cx: &mut TestAppContext, - ) -> (ModelHandle, WorktreeId) { - let project = cx.update(|cx| { - Project::local( - self.client().clone(), - self.app_state.user_store.clone(), - self.app_state.languages.clone(), - self.app_state.fs.clone(), - cx, - ) - }); - let (worktree, _) = project - .update(cx, |p, cx| { - p.find_or_create_local_worktree(root_path, true, cx) - }) - .await - .unwrap(); - worktree - .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) - .await; - (project, worktree.read_with(cx, |tree, _| tree.id())) - } - - async fn build_remote_project( - &self, - host_project_id: u64, - guest_cx: &mut TestAppContext, - ) -> ModelHandle { - let active_call = guest_cx.read(ActiveCall::global); - let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone()); - room.update(guest_cx, |room, cx| { - room.join_project( - host_project_id, - self.app_state.languages.clone(), - self.app_state.fs.clone(), - cx, - ) - }) - .await - .unwrap() - } - - fn build_workspace( - &self, - project: &ModelHandle, - cx: &mut TestAppContext, - ) -> WindowHandle { - cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx)) - } -} - -impl Drop for TestClient { - fn drop(&mut self) { - self.app_state.client.teardown(); - } -} +pub use randomized_test_helpers::{ + run_randomized_test, save_randomized_test_plan, RandomizedTest, TestError, UserTestPlan, +}; +pub use test_server::{TestClient, TestServer}; #[derive(Debug, Eq, PartialEq)] struct RoomParticipants { diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 8ac4dbbd3f..fe286895b4 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -1,4 +1,7 @@ -use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; +use crate::{ + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::TestServer, +}; use call::ActiveCall; use channel::Channel; use client::UserId; @@ -21,20 +24,19 @@ async fn test_core_channel_buffers( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; - let zed_id = server + let channel_id = server .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) .await; // Client A joins the channel buffer let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); // Client A edits the buffer let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer()); - buffer_a.update(cx_a, |buffer, cx| { buffer.edit([(0..0, "hello world")], None, cx) }); @@ -45,17 +47,15 @@ async fn test_core_channel_buffers( buffer.edit([(0..5, "goodbye")], None, cx) }); buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx)); - deterministic.run_until_parked(); - assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world"); + deterministic.run_until_parked(); // Client B joins the channel buffer let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - channel_buffer_b.read_with(cx_b, |buffer, _| { assert_collaborators( buffer.collaborators(), @@ -91,9 +91,7 @@ async fn test_core_channel_buffers( // Client A rejoins the channel buffer let _channel_buffer_a = client_a .channel_store() - .update(cx_a, |channels, cx| { - channels.open_channel_buffer(zed_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); deterministic.run_until_parked(); @@ -136,7 +134,7 @@ async fn test_channel_buffer_replica_ids( let channel_id = server .make_channel( - "zed", + "the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b), (&client_c, cx_c)], ) @@ -160,23 +158,17 @@ async fn test_channel_buffer_replica_ids( // C first so that the replica IDs in the project and the channel buffer are different let channel_buffer_c = client_c .channel_store() - .update(cx_c, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); @@ -286,28 +278,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; - let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await; + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut []) + .await; let channel_buffer_1 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); let channel_buffer_2 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); let channel_buffer_3 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); // All concurrent tasks for opening a channel buffer return the same model handle. - let (channel_buffer_1, channel_buffer_2, channel_buffer_3) = + let (channel_buffer, channel_buffer_2, channel_buffer_3) = future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3) .await .unwrap(); - let model_id = channel_buffer_1.id(); - assert_eq!(channel_buffer_1, channel_buffer_2); - assert_eq!(channel_buffer_1, channel_buffer_3); + let channel_buffer_model_id = channel_buffer.id(); + assert_eq!(channel_buffer, channel_buffer_2); + assert_eq!(channel_buffer, channel_buffer_3); - channel_buffer_1.update(cx_a, |buffer, cx| { + channel_buffer.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "hello")], None, cx); }) @@ -315,7 +309,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu deterministic.run_until_parked(); cx_a.update(|_| { - drop(channel_buffer_1); + drop(channel_buffer); drop(channel_buffer_2); drop(channel_buffer_3); }); @@ -324,10 +318,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu // The channel buffer can be reopened after dropping it. let channel_buffer = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - assert_ne!(channel_buffer.id(), model_id); + assert_ne!(channel_buffer.id(), channel_buffer_model_id); channel_buffer.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, _| { assert_eq!(buffer.text(), "hello"); @@ -347,22 +341,17 @@ async fn test_channel_buffer_disconnect( let client_b = server.create_client(cx_b, "user_b").await; let channel_id = server - .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) .await; let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); @@ -375,7 +364,7 @@ async fn test_channel_buffer_disconnect( buffer.channel().as_ref(), &Channel { id: channel_id, - name: "zed".to_string() + name: "the-channel".to_string() } ); assert!(!buffer.is_connected()); @@ -403,13 +392,180 @@ async fn test_channel_buffer_disconnect( buffer.channel().as_ref(), &Channel { id: channel_id, - name: "zed".to_string() + name: "the-channel".to_string() } ); assert!(!buffer.is_connected()); }); } +#[gpui::test] +async fn test_rejoin_channel_buffer( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client A disconnects. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + + // Both clients make an edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + + // Both clients see their own edit. + deterministic.run_until_parked(); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "12"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "01"); + }); + + // Client A reconnects. Both clients see each other's edits, and see + // the same collaborators. + server.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + + channel_buffer_a.read_with(cx_a, |buffer_a, _| { + channel_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_eq!(buffer_a.collaborators(), buffer_b.collaborators()); + }); + }); +} + +#[gpui::test] +async fn test_channel_buffers_and_server_restarts( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + let channel_id = server + .make_channel( + "the-channel", + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let _channel_buffer_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client C can't reconnect. + client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + + // Server stops. + server.reset().await; + deterministic.advance_clock(RECEIVE_TIMEOUT); + + // While the server is down, both clients make an edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + + // Server restarts. + server.start().await.unwrap(); + deterministic.advance_clock(CLEANUP_TIMEOUT); + + // Clients reconnects. Clients A and B see each other's edits, and see + // that client C has disconnected. + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + + channel_buffer_a.read_with(cx_a, |buffer_a, _| { + channel_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_eq!( + buffer_a + .collaborators() + .iter() + .map(|c| c.user_id) + .collect::>(), + vec![client_a.user_id().unwrap(), client_b.user_id().unwrap()] + ); + assert_eq!(buffer_a.collaborators(), buffer_b.collaborators()); + }); + }); +} + #[track_caller] fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option]) { assert_eq!( diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index f64a82e32e..8121b0ac91 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -9,7 +9,7 @@ use editor::{ test::editor_test_context::EditorTestContext, ConfirmCodeAction, ConfirmCompletion, ConfirmRename, Editor, ExcerptRange, MultiBuffer, Redo, Rename, ToggleCodeActions, Undo, }; -use fs::{repository::GitFileStatus, FakeFs, Fs as _, LineEnding, RemoveOptions}; +use fs::{repository::GitFileStatus, FakeFs, Fs as _, RemoveOptions}; use futures::StreamExt as _; use gpui::{ executor::Deterministic, geometry::vector::vec2f, test::EmptyView, AppContext, ModelHandle, @@ -19,7 +19,7 @@ use indoc::indoc; use language::{ language_settings::{AllLanguageSettings, Formatter, InlayHintSettings}, tree_sitter_rust, Anchor, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, - LanguageConfig, OffsetRangeExt, Point, Rope, + LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, }; use live_kit_client::MacOSDisplay; use lsp::LanguageServerId; @@ -33,7 +33,7 @@ use std::{ path::{Path, PathBuf}, rc::Rc, sync::{ - atomic::{AtomicBool, AtomicU32, Ordering::SeqCst}, + atomic::{self, AtomicBool, AtomicUsize, Ordering::SeqCst}, Arc, }, }; @@ -7799,7 +7799,7 @@ async fn test_on_input_format_from_guest_to_host( }); } -#[gpui::test] +#[gpui::test(iterations = 10)] async fn test_mutual_editor_inlay_hint_cache_update( deterministic: Arc, cx_a: &mut TestAppContext, @@ -7913,30 +7913,27 @@ async fn test_mutual_editor_inlay_hint_cache_update( .unwrap(); // Set up the language server to return an additional inlay hint on each request. - let next_call_id = Arc::new(AtomicU32::new(0)); + let edits_made = Arc::new(AtomicUsize::new(0)); + let closure_edits_made = Arc::clone(&edits_made); fake_language_server .handle_request::(move |params, _| { - let task_next_call_id = Arc::clone(&next_call_id); + let task_edits_made = Arc::clone(&closure_edits_made); async move { assert_eq!( params.text_document.uri, lsp::Url::from_file_path("/a/main.rs").unwrap(), ); - let call_count = task_next_call_id.fetch_add(1, SeqCst); - Ok(Some( - (0..=call_count) - .map(|ix| lsp::InlayHint { - position: lsp::Position::new(0, ix), - label: lsp::InlayHintLabel::String(ix.to_string()), - kind: None, - text_edits: None, - tooltip: None, - padding_left: None, - padding_right: None, - data: None, - }) - .collect(), - )) + let edits_made = task_edits_made.load(atomic::Ordering::Acquire); + Ok(Some(vec![lsp::InlayHint { + position: lsp::Position::new(0, edits_made as u32), + label: lsp::InlayHintLabel::String(edits_made.to_string()), + kind: None, + text_edits: None, + tooltip: None, + padding_left: None, + padding_right: None, + data: None, + }])) } }) .next() @@ -7945,17 +7942,17 @@ async fn test_mutual_editor_inlay_hint_cache_update( deterministic.run_until_parked(); - let mut edits_made = 1; + let initial_edit = edits_made.load(atomic::Ordering::Acquire); editor_a.update(cx_a, |editor, _| { assert_eq!( - vec!["0".to_string()], + vec![initial_edit.to_string()], extract_hint_labels(editor), "Host should get its first hints when opens an editor" ); let inlay_cache = editor.inlay_hint_cache(); assert_eq!( inlay_cache.version(), - edits_made, + 1, "Host editor update the cache version after every cache/view change", ); }); @@ -7972,144 +7969,104 @@ async fn test_mutual_editor_inlay_hint_cache_update( deterministic.run_until_parked(); editor_b.update(cx_b, |editor, _| { assert_eq!( - vec!["0".to_string(), "1".to_string()], + vec![initial_edit.to_string()], extract_hint_labels(editor), "Client should get its first hints when opens an editor" ); let inlay_cache = editor.inlay_hint_cache(); assert_eq!( inlay_cache.version(), - edits_made, + 1, "Guest editor update the cache version after every cache/view change" ); }); + let after_client_edit = edits_made.fetch_add(1, atomic::Ordering::Release) + 1; editor_b.update(cx_b, |editor, cx| { editor.change_selections(None, cx, |s| s.select_ranges([13..13].clone())); editor.handle_input(":", cx); cx.focus(&editor_b); - edits_made += 1; }); deterministic.run_until_parked(); editor_a.update(cx_a, |editor, _| { assert_eq!( - vec![ - "0".to_string(), - "1".to_string(), - "2".to_string(), - "3".to_string() - ], + vec![after_client_edit.to_string()], extract_hint_labels(editor), - "Guest should get hints the 1st edit and 2nd LSP query" ); let inlay_cache = editor.inlay_hint_cache(); - assert_eq!(inlay_cache.version(), edits_made); + assert_eq!(inlay_cache.version(), 2); }); editor_b.update(cx_b, |editor, _| { assert_eq!( - vec!["0".to_string(), "1".to_string(), "2".to_string(),], + vec![after_client_edit.to_string()], extract_hint_labels(editor), - "Guest should get hints the 1st edit and 2nd LSP query" ); let inlay_cache = editor.inlay_hint_cache(); - assert_eq!(inlay_cache.version(), edits_made); + assert_eq!(inlay_cache.version(), 2); }); + let after_host_edit = edits_made.fetch_add(1, atomic::Ordering::Release) + 1; editor_a.update(cx_a, |editor, cx| { editor.change_selections(None, cx, |s| s.select_ranges([13..13])); editor.handle_input("a change to increment both buffers' versions", cx); cx.focus(&editor_a); - edits_made += 1; }); deterministic.run_until_parked(); editor_a.update(cx_a, |editor, _| { assert_eq!( - vec![ - "0".to_string(), - "1".to_string(), - "2".to_string(), - "3".to_string(), - "4".to_string() - ], + vec![after_host_edit.to_string()], extract_hint_labels(editor), - "Host should get hints from 3rd edit, 5th LSP query: \ -4th query was made by guest (but not applied) due to cache invalidation logic" ); let inlay_cache = editor.inlay_hint_cache(); - assert_eq!(inlay_cache.version(), edits_made); + assert_eq!(inlay_cache.version(), 3); }); editor_b.update(cx_b, |editor, _| { assert_eq!( - vec![ - "0".to_string(), - "1".to_string(), - "2".to_string(), - "3".to_string(), - "4".to_string(), - "5".to_string(), - ], + vec![after_host_edit.to_string()], extract_hint_labels(editor), - "Guest should get hints from 3rd edit, 6th LSP query" ); let inlay_cache = editor.inlay_hint_cache(); - assert_eq!(inlay_cache.version(), edits_made); + assert_eq!(inlay_cache.version(), 3); }); + let after_special_edit_for_refresh = edits_made.fetch_add(1, atomic::Ordering::Release) + 1; fake_language_server .request::(()) .await .expect("inlay refresh request failed"); - edits_made += 1; deterministic.run_until_parked(); editor_a.update(cx_a, |editor, _| { assert_eq!( - vec![ - "0".to_string(), - "1".to_string(), - "2".to_string(), - "3".to_string(), - "4".to_string(), - "5".to_string(), - "6".to_string(), - ], + vec![after_special_edit_for_refresh.to_string()], extract_hint_labels(editor), - "Host should react to /refresh LSP request and get new hints from 7th LSP query" + "Host should react to /refresh LSP request" ); let inlay_cache = editor.inlay_hint_cache(); assert_eq!( inlay_cache.version(), - edits_made, + 4, "Host should accepted all edits and bump its cache version every time" ); }); editor_b.update(cx_b, |editor, _| { assert_eq!( - vec![ - "0".to_string(), - "1".to_string(), - "2".to_string(), - "3".to_string(), - "4".to_string(), - "5".to_string(), - "6".to_string(), - "7".to_string(), - ], + vec![after_special_edit_for_refresh.to_string()], extract_hint_labels(editor), - "Guest should get a /refresh LSP request propagated by host and get new hints from 8th LSP query" + "Guest should get a /refresh LSP request propagated by host" ); let inlay_cache = editor.inlay_hint_cache(); assert_eq!( inlay_cache.version(), - edits_made, + 4, "Guest should accepted all edits and bump its cache version every time" ); }); } -#[gpui::test] +#[gpui::test(iterations = 10)] async fn test_inlay_hint_refresh_is_forwarded( deterministic: Arc, cx_a: &mut TestAppContext, @@ -8223,35 +8180,34 @@ async fn test_inlay_hint_refresh_is_forwarded( .downcast::() .unwrap(); + let other_hints = Arc::new(AtomicBool::new(false)); let fake_language_server = fake_language_servers.next().await.unwrap(); - let next_call_id = Arc::new(AtomicU32::new(0)); + let closure_other_hints = Arc::clone(&other_hints); fake_language_server .handle_request::(move |params, _| { - let task_next_call_id = Arc::clone(&next_call_id); + let task_other_hints = Arc::clone(&closure_other_hints); async move { assert_eq!( params.text_document.uri, lsp::Url::from_file_path("/a/main.rs").unwrap(), ); - let mut current_call_id = Arc::clone(&task_next_call_id).fetch_add(1, SeqCst); - let mut new_hints = Vec::with_capacity(current_call_id as usize); - loop { - new_hints.push(lsp::InlayHint { - position: lsp::Position::new(0, current_call_id), - label: lsp::InlayHintLabel::String(current_call_id.to_string()), - kind: None, - text_edits: None, - tooltip: None, - padding_left: None, - padding_right: None, - data: None, - }); - if current_call_id == 0 { - break; - } - current_call_id -= 1; - } - Ok(Some(new_hints)) + let other_hints = task_other_hints.load(atomic::Ordering::Acquire); + let character = if other_hints { 0 } else { 2 }; + let label = if other_hints { + "other hint" + } else { + "initial hint" + }; + Ok(Some(vec![lsp::InlayHint { + position: lsp::Position::new(0, character), + label: lsp::InlayHintLabel::String(label.to_string()), + kind: None, + text_edits: None, + tooltip: None, + padding_left: None, + padding_right: None, + data: None, + }])) } }) .next() @@ -8270,26 +8226,26 @@ async fn test_inlay_hint_refresh_is_forwarded( assert_eq!( inlay_cache.version(), 0, - "Host should not increment its cache version due to no changes", + "Turned off hints should not generate version updates" ); }); - let mut edits_made = 1; cx_b.foreground().run_until_parked(); editor_b.update(cx_b, |editor, _| { assert_eq!( - vec!["0".to_string()], + vec!["initial hint".to_string()], extract_hint_labels(editor), "Client should get its first hints when opens an editor" ); let inlay_cache = editor.inlay_hint_cache(); assert_eq!( inlay_cache.version(), - edits_made, - "Guest editor update the cache version after every cache/view change" + 1, + "Should update cache verison after first hints" ); }); + other_hints.fetch_or(true, atomic::Ordering::Release); fake_language_server .request::(()) .await @@ -8304,22 +8260,21 @@ async fn test_inlay_hint_refresh_is_forwarded( assert_eq!( inlay_cache.version(), 0, - "Host should not increment its cache version due to no changes", + "Turned off hints should not generate version updates, again" ); }); - edits_made += 1; cx_b.foreground().run_until_parked(); editor_b.update(cx_b, |editor, _| { assert_eq!( - vec!["0".to_string(), "1".to_string(),], + vec!["other hint".to_string()], extract_hint_labels(editor), "Guest should get a /refresh LSP request propagated by host despite host hints are off" ); let inlay_cache = editor.inlay_hint_cache(); assert_eq!( inlay_cache.version(), - edits_made, + 2, "Guest should accepted all edits and bump its cache version every time" ); }); diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs b/crates/collab/src/tests/random_channel_buffer_tests.rs new file mode 100644 index 0000000000..a60d3d7d7d --- /dev/null +++ b/crates/collab/src/tests/random_channel_buffer_tests.rs @@ -0,0 +1,288 @@ +use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; +use anyhow::Result; +use async_trait::async_trait; +use gpui::{executor::Deterministic, TestAppContext}; +use rand::prelude::*; +use serde_derive::{Deserialize, Serialize}; +use std::{ops::Range, rc::Rc, sync::Arc}; +use text::Bias; + +#[gpui::test( + iterations = 100, + on_failure = "crate::tests::save_randomized_test_plan" +)] +async fn test_random_channel_buffers( + cx: &mut TestAppContext, + deterministic: Arc, + rng: StdRng, +) { + run_randomized_test::(cx, deterministic, rng).await; +} + +struct RandomChannelBufferTest; + +#[derive(Clone, Serialize, Deserialize)] +enum ChannelBufferOperation { + JoinChannelNotes { + channel_name: String, + }, + LeaveChannelNotes { + channel_name: String, + }, + EditChannelNotes { + channel_name: String, + edits: Vec<(Range, Arc)>, + }, + Noop, +} + +const CHANNEL_COUNT: usize = 3; + +#[async_trait(?Send)] +impl RandomizedTest for RandomChannelBufferTest { + type Operation = ChannelBufferOperation; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) { + let db = &server.app_state.db; + for ix in 0..CHANNEL_COUNT { + let id = db + .create_channel( + &format!("channel-{ix}"), + None, + &format!("livekit-room-{ix}"), + users[0].user_id, + ) + .await + .unwrap(); + for user in &users[1..] { + db.invite_channel_member(id, user.user_id, users[0].user_id, false) + .await + .unwrap(); + db.respond_to_channel_invite(id, user.user_id, true) + .await + .unwrap(); + } + } + } + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + _: &mut UserTestPlan, + cx: &TestAppContext, + ) -> ChannelBufferOperation { + let channel_store = client.channel_store().clone(); + let channel_buffers = client.channel_buffers(); + + // When signed out, we can't do anything unless a channel buffer is + // already open. + if channel_buffers.is_empty() + && channel_store.read_with(cx, |store, _| store.channel_count() == 0) + { + return ChannelBufferOperation::Noop; + } + + loop { + match rng.gen_range(0..100_u32) { + 0..=29 => { + let channel_name = client.channel_store().read_with(cx, |store, cx| { + store.channels().find_map(|(_, channel)| { + if store.has_open_channel_buffer(channel.id, cx) { + None + } else { + Some(channel.name.clone()) + } + }) + }); + if let Some(channel_name) = channel_name { + break ChannelBufferOperation::JoinChannelNotes { channel_name }; + } + } + + 30..=40 => { + if let Some(buffer) = channel_buffers.iter().choose(rng) { + let channel_name = buffer.read_with(cx, |b, _| b.channel().name.clone()); + break ChannelBufferOperation::LeaveChannelNotes { channel_name }; + } + } + + _ => { + if let Some(buffer) = channel_buffers.iter().choose(rng) { + break buffer.read_with(cx, |b, _| { + let channel_name = b.channel().name.clone(); + let edits = b + .buffer() + .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); + ChannelBufferOperation::EditChannelNotes { + channel_name, + edits, + } + }); + } + } + } + } + } + + async fn apply_operation( + client: &TestClient, + operation: ChannelBufferOperation, + cx: &mut TestAppContext, + ) -> Result<(), TestError> { + match operation { + ChannelBufferOperation::JoinChannelNotes { channel_name } => { + let buffer = client.channel_store().update(cx, |store, cx| { + let channel_id = store + .channels() + .find(|(_, c)| c.name == channel_name) + .unwrap() + .1 + .id; + if store.has_open_channel_buffer(channel_id, cx) { + Err(TestError::Inapplicable) + } else { + Ok(store.open_channel_buffer(channel_id, cx)) + } + })?; + + log::info!( + "{}: opening notes for channel {channel_name}", + client.username + ); + client.channel_buffers().insert(buffer.await?); + } + + ChannelBufferOperation::LeaveChannelNotes { channel_name } => { + let buffer = cx.update(|cx| { + let mut left_buffer = Err(TestError::Inapplicable); + client.channel_buffers().retain(|buffer| { + if buffer.read(cx).channel().name == channel_name { + left_buffer = Ok(buffer.clone()); + false + } else { + true + } + }); + left_buffer + })?; + + log::info!( + "{}: closing notes for channel {channel_name}", + client.username + ); + cx.update(|_| drop(buffer)); + } + + ChannelBufferOperation::EditChannelNotes { + channel_name, + edits, + } => { + let channel_buffer = cx + .read(|cx| { + client + .channel_buffers() + .iter() + .find(|buffer| buffer.read(cx).channel().name == channel_name) + .cloned() + }) + .ok_or_else(|| TestError::Inapplicable)?; + + log::info!( + "{}: editing notes for channel {channel_name} with {:?}", + client.username, + edits + ); + + channel_buffer.update(cx, |buffer, cx| { + let buffer = buffer.buffer(); + buffer.update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + buffer.edit( + edits.into_iter().map(|(range, text)| { + let start = snapshot.clip_offset(range.start, Bias::Left); + let end = snapshot.clip_offset(range.end, Bias::Right); + (start..end, text) + }), + None, + cx, + ); + }); + }); + } + + ChannelBufferOperation::Noop => Err(TestError::Inapplicable)?, + } + Ok(()) + } + + async fn on_client_added(client: &Rc, cx: &mut TestAppContext) { + let channel_store = client.channel_store(); + while channel_store.read_with(cx, |store, _| store.channel_count() == 0) { + channel_store.next_notification(cx).await; + } + } + + async fn on_quiesce(server: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { + let channels = server.app_state.db.all_channels().await.unwrap(); + + for (client, client_cx) in clients.iter_mut() { + client_cx.update(|cx| { + client + .channel_buffers() + .retain(|b| b.read(cx).is_connected()); + }); + } + + for (channel_id, channel_name) in channels { + let mut prev_text: Option<(u64, String)> = None; + + let mut collaborator_user_ids = server + .app_state + .db + .get_channel_buffer_collaborators(channel_id) + .await + .unwrap() + .into_iter() + .map(|id| id.to_proto()) + .collect::>(); + collaborator_user_ids.sort(); + + for (client, client_cx) in clients.iter() { + let user_id = client.user_id().unwrap(); + client_cx.read(|cx| { + if let Some(channel_buffer) = client + .channel_buffers() + .iter() + .find(|b| b.read(cx).channel().id == channel_id.to_proto()) + { + let channel_buffer = channel_buffer.read(cx); + + // Assert that channel buffer's text matches other clients' copies. + let text = channel_buffer.buffer().read(cx).text(); + if let Some((prev_user_id, prev_text)) = &prev_text { + assert_eq!( + &text, + prev_text, + "client {user_id} has different text than client {prev_user_id} for channel {channel_name}", + ); + } else { + prev_text = Some((user_id, text.clone())); + } + + // Assert that all clients and the server agree about who is present in the + // channel buffer. + let collaborators = channel_buffer.collaborators(); + let mut user_ids = + collaborators.iter().map(|c| c.user_id).collect::>(); + user_ids.sort(); + assert_eq!( + user_ids, + collaborator_user_ids, + "client {user_id} has different user ids for channel {channel_name} than the server", + ); + } + }); + } + } + } +} diff --git a/crates/collab/src/tests/random_project_collaboration_tests.rs b/crates/collab/src/tests/random_project_collaboration_tests.rs new file mode 100644 index 0000000000..7570768249 --- /dev/null +++ b/crates/collab/src/tests/random_project_collaboration_tests.rs @@ -0,0 +1,1585 @@ +use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; +use crate::db::UserId; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use call::ActiveCall; +use collections::{BTreeMap, HashMap}; +use editor::Bias; +use fs::{repository::GitFileStatus, FakeFs, Fs as _}; +use futures::StreamExt; +use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; +use language::{range_to_lsp, FakeLspAdapter, Language, LanguageConfig, PointUtf16}; +use lsp::FakeLanguageServer; +use pretty_assertions::assert_eq; +use project::{search::SearchQuery, Project, ProjectPath}; +use rand::{ + distributions::{Alphanumeric, DistString}, + prelude::*, +}; +use serde::{Deserialize, Serialize}; +use std::{ + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use util::ResultExt; + +#[gpui::test( + iterations = 100, + on_failure = "crate::tests::save_randomized_test_plan" +)] +async fn test_random_project_collaboration( + cx: &mut TestAppContext, + deterministic: Arc, + rng: StdRng, +) { + run_randomized_test::(cx, deterministic, rng).await; +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum ClientOperation { + AcceptIncomingCall, + RejectIncomingCall, + LeaveCall, + InviteContactToCall { + user_id: UserId, + }, + OpenLocalProject { + first_root_name: String, + }, + OpenRemoteProject { + host_id: UserId, + first_root_name: String, + }, + AddWorktreeToProject { + project_root_name: String, + new_root_path: PathBuf, + }, + CloseRemoteProject { + project_root_name: String, + }, + OpenBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + }, + SearchProject { + project_root_name: String, + is_local: bool, + query: String, + detach: bool, + }, + EditBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + edits: Vec<(Range, Arc)>, + }, + CloseBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + }, + SaveBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + detach: bool, + }, + RequestLspDataInBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + offset: usize, + kind: LspRequestKind, + detach: bool, + }, + CreateWorktreeEntry { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + is_dir: bool, + }, + WriteFsEntry { + path: PathBuf, + is_dir: bool, + content: String, + }, + GitOperation { + operation: GitOperation, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum GitOperation { + WriteGitIndex { + repo_path: PathBuf, + contents: Vec<(PathBuf, String)>, + }, + WriteGitBranch { + repo_path: PathBuf, + new_branch: Option, + }, + WriteGitStatuses { + repo_path: PathBuf, + statuses: Vec<(PathBuf, GitFileStatus)>, + git_operation: bool, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum LspRequestKind { + Rename, + Completion, + CodeAction, + Definition, + Highlights, +} + +struct ProjectCollaborationTest; + +#[async_trait(?Send)] +impl RandomizedTest for ProjectCollaborationTest { + type Operation = ClientOperation; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) { + let db = &server.app_state.db; + for (ix, user_a) in users.iter().enumerate() { + for user_b in &users[ix + 1..] { + db.send_contact_request(user_a.user_id, user_b.user_id) + .await + .unwrap(); + db.respond_to_contact_request(user_b.user_id, user_a.user_id, true) + .await + .unwrap(); + } + } + } + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> ClientOperation { + let call = cx.read(ActiveCall::global); + loop { + match rng.gen_range(0..100_u32) { + // Mutate the call + 0..=29 => { + // Respond to an incoming call + if call.read_with(cx, |call, _| call.incoming().borrow().is_some()) { + break if rng.gen_bool(0.7) { + ClientOperation::AcceptIncomingCall + } else { + ClientOperation::RejectIncomingCall + }; + } + + match rng.gen_range(0..100_u32) { + // Invite a contact to the current call + 0..=70 => { + let available_contacts = + client.user_store().read_with(cx, |user_store, _| { + user_store + .contacts() + .iter() + .filter(|contact| contact.online && !contact.busy) + .cloned() + .collect::>() + }); + if !available_contacts.is_empty() { + let contact = available_contacts.choose(rng).unwrap(); + break ClientOperation::InviteContactToCall { + user_id: UserId(contact.user.id as i32), + }; + } + } + + // Leave the current call + 71.. => { + if plan.allow_client_disconnection + && call.read_with(cx, |call, _| call.room().is_some()) + { + break ClientOperation::LeaveCall; + } + } + } + } + + // Mutate projects + 30..=59 => match rng.gen_range(0..100_u32) { + // Open a new project + 0..=70 => { + // Open a remote project + if let Some(room) = call.read_with(cx, |call, _| call.room().cloned()) { + let existing_remote_project_ids = cx.read(|cx| { + client + .remote_projects() + .iter() + .map(|p| p.read(cx).remote_id().unwrap()) + .collect::>() + }); + let new_remote_projects = room.read_with(cx, |room, _| { + room.remote_participants() + .values() + .flat_map(|participant| { + participant.projects.iter().filter_map(|project| { + if existing_remote_project_ids.contains(&project.id) { + None + } else { + Some(( + UserId::from_proto(participant.user.id), + project.worktree_root_names[0].clone(), + )) + } + }) + }) + .collect::>() + }); + if !new_remote_projects.is_empty() { + let (host_id, first_root_name) = + new_remote_projects.choose(rng).unwrap().clone(); + break ClientOperation::OpenRemoteProject { + host_id, + first_root_name, + }; + } + } + // Open a local project + else { + let first_root_name = plan.next_root_dir_name(); + break ClientOperation::OpenLocalProject { first_root_name }; + } + } + + // Close a remote project + 71..=80 => { + if !client.remote_projects().is_empty() { + let project = client.remote_projects().choose(rng).unwrap().clone(); + let first_root_name = root_name_for_project(&project, cx); + break ClientOperation::CloseRemoteProject { + project_root_name: first_root_name, + }; + } + } + + // Mutate project worktrees + 81.. => match rng.gen_range(0..100_u32) { + // Add a worktree to a local project + 0..=50 => { + let Some(project) = client.local_projects().choose(rng).cloned() else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let mut paths = client.fs().paths(false); + paths.remove(0); + let new_root_path = if paths.is_empty() || rng.gen() { + Path::new("/").join(&plan.next_root_dir_name()) + } else { + paths.choose(rng).unwrap().clone() + }; + break ClientOperation::AddWorktreeToProject { + project_root_name, + new_root_path, + }; + } + + // Add an entry to a worktree + _ => { + let Some(project) = choose_random_project(client, rng) else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let is_local = project.read_with(cx, |project, _| project.is_local()); + let worktree = project.read_with(cx, |project, cx| { + project + .worktrees(cx) + .filter(|worktree| { + let worktree = worktree.read(cx); + worktree.is_visible() + && worktree.entries(false).any(|e| e.is_file()) + && worktree.root_entry().map_or(false, |e| e.is_dir()) + }) + .choose(rng) + }); + let Some(worktree) = worktree else { continue }; + let is_dir = rng.gen::(); + let mut full_path = + worktree.read_with(cx, |w, _| PathBuf::from(w.root_name())); + full_path.push(gen_file_name(rng)); + if !is_dir { + full_path.set_extension("rs"); + } + break ClientOperation::CreateWorktreeEntry { + project_root_name, + is_local, + full_path, + is_dir, + }; + } + }, + }, + + // Query and mutate buffers + 60..=90 => { + let Some(project) = choose_random_project(client, rng) else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let is_local = project.read_with(cx, |project, _| project.is_local()); + + match rng.gen_range(0..100_u32) { + // Manipulate an existing buffer + 0..=70 => { + let Some(buffer) = client + .buffers_for_project(&project) + .iter() + .choose(rng) + .cloned() + else { + continue; + }; + + let full_path = buffer + .read_with(cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); + + match rng.gen_range(0..100_u32) { + // Close the buffer + 0..=15 => { + break ClientOperation::CloseBuffer { + project_root_name, + is_local, + full_path, + }; + } + // Save the buffer + 16..=29 if buffer.read_with(cx, |b, _| b.is_dirty()) => { + let detach = rng.gen_bool(0.3); + break ClientOperation::SaveBuffer { + project_root_name, + is_local, + full_path, + detach, + }; + } + // Edit the buffer + 30..=69 => { + let edits = buffer + .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); + break ClientOperation::EditBuffer { + project_root_name, + is_local, + full_path, + edits, + }; + } + // Make an LSP request + _ => { + let offset = buffer.read_with(cx, |buffer, _| { + buffer.clip_offset( + rng.gen_range(0..=buffer.len()), + language::Bias::Left, + ) + }); + let detach = rng.gen(); + break ClientOperation::RequestLspDataInBuffer { + project_root_name, + full_path, + offset, + is_local, + kind: match rng.gen_range(0..5_u32) { + 0 => LspRequestKind::Rename, + 1 => LspRequestKind::Highlights, + 2 => LspRequestKind::Definition, + 3 => LspRequestKind::CodeAction, + 4.. => LspRequestKind::Completion, + }, + detach, + }; + } + } + } + + 71..=80 => { + let query = rng.gen_range('a'..='z').to_string(); + let detach = rng.gen_bool(0.3); + break ClientOperation::SearchProject { + project_root_name, + is_local, + query, + detach, + }; + } + + // Open a buffer + 81.. => { + let worktree = project.read_with(cx, |project, cx| { + project + .worktrees(cx) + .filter(|worktree| { + let worktree = worktree.read(cx); + worktree.is_visible() + && worktree.entries(false).any(|e| e.is_file()) + }) + .choose(rng) + }); + let Some(worktree) = worktree else { continue }; + let full_path = worktree.read_with(cx, |worktree, _| { + let entry = worktree + .entries(false) + .filter(|e| e.is_file()) + .choose(rng) + .unwrap(); + if entry.path.as_ref() == Path::new("") { + Path::new(worktree.root_name()).into() + } else { + Path::new(worktree.root_name()).join(&entry.path) + } + }); + break ClientOperation::OpenBuffer { + project_root_name, + is_local, + full_path, + }; + } + } + } + + // Update a git related action + 91..=95 => { + break ClientOperation::GitOperation { + operation: generate_git_operation(rng, client), + }; + } + + // Create or update a file or directory + 96.. => { + let is_dir = rng.gen::(); + let content; + let mut path; + let dir_paths = client.fs().directories(false); + + if is_dir { + content = String::new(); + path = dir_paths.choose(rng).unwrap().clone(); + path.push(gen_file_name(rng)); + } else { + content = Alphanumeric.sample_string(rng, 16); + + // Create a new file or overwrite an existing file + let file_paths = client.fs().files(); + if file_paths.is_empty() || rng.gen_bool(0.5) { + path = dir_paths.choose(rng).unwrap().clone(); + path.push(gen_file_name(rng)); + path.set_extension("rs"); + } else { + path = file_paths.choose(rng).unwrap().clone() + }; + } + break ClientOperation::WriteFsEntry { + path, + is_dir, + content, + }; + } + } + } + } + + async fn apply_operation( + client: &TestClient, + operation: ClientOperation, + cx: &mut TestAppContext, + ) -> Result<(), TestError> { + match operation { + ClientOperation::AcceptIncomingCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: accepting incoming call", client.username); + active_call + .update(cx, |call, cx| call.accept_incoming(cx)) + .await?; + } + + ClientOperation::RejectIncomingCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: declining incoming call", client.username); + active_call.update(cx, |call, cx| call.decline_incoming(cx))?; + } + + ClientOperation::LeaveCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.room().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: hanging up", client.username); + active_call.update(cx, |call, cx| call.hang_up(cx)).await?; + } + + ClientOperation::InviteContactToCall { user_id } => { + let active_call = cx.read(ActiveCall::global); + + log::info!("{}: inviting {}", client.username, user_id,); + active_call + .update(cx, |call, cx| call.invite(user_id.to_proto(), None, cx)) + .await + .log_err(); + } + + ClientOperation::OpenLocalProject { first_root_name } => { + log::info!( + "{}: opening local project at {:?}", + client.username, + first_root_name + ); + + let root_path = Path::new("/").join(&first_root_name); + client.fs().create_dir(&root_path).await.unwrap(); + client + .fs() + .create_file(&root_path.join("main.rs"), Default::default()) + .await + .unwrap(); + let project = client.build_local_project(root_path, cx).await.0; + ensure_project_shared(&project, client, cx).await; + client.local_projects_mut().push(project.clone()); + } + + ClientOperation::AddWorktreeToProject { + project_root_name, + new_root_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: finding/creating local worktree at {:?} to project with root path {}", + client.username, + new_root_path, + project_root_name + ); + + ensure_project_shared(&project, client, cx).await; + if !client.fs().paths(false).contains(&new_root_path) { + client.fs().create_dir(&new_root_path).await.unwrap(); + } + project + .update(cx, |project, cx| { + project.find_or_create_local_worktree(&new_root_path, true, cx) + }) + .await + .unwrap(); + } + + ClientOperation::CloseRemoteProject { project_root_name } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: closing remote project with root path {}", + client.username, + project_root_name, + ); + + let ix = client + .remote_projects() + .iter() + .position(|p| p == &project) + .unwrap(); + cx.update(|_| { + client.remote_projects_mut().remove(ix); + client.buffers().retain(|p, _| *p != project); + drop(project); + }); + } + + ClientOperation::OpenRemoteProject { + host_id, + first_root_name, + } => { + let active_call = cx.read(ActiveCall::global); + let project = active_call + .update(cx, |call, cx| { + let room = call.room().cloned()?; + let participant = room + .read(cx) + .remote_participants() + .get(&host_id.to_proto())?; + let project_id = participant + .projects + .iter() + .find(|project| project.worktree_root_names[0] == first_root_name)? + .id; + Some(room.update(cx, |room, cx| { + room.join_project( + project_id, + client.language_registry().clone(), + FakeFs::new(cx.background().clone()), + cx, + ) + })) + }) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: joining remote project of user {}, root name {}", + client.username, + host_id, + first_root_name, + ); + + let project = project.await?; + client.remote_projects_mut().push(project.clone()); + } + + ClientOperation::CreateWorktreeEntry { + project_root_name, + is_local, + full_path, + is_dir, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let project_path = project_path_for_full_path(&project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: creating {} at path {:?} in {} project {}", + client.username, + if is_dir { "dir" } else { "file" }, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + ); + + ensure_project_shared(&project, client, cx).await; + project + .update(cx, |p, cx| p.create_entry(project_path, is_dir, cx)) + .unwrap() + .await?; + } + + ClientOperation::OpenBuffer { + project_root_name, + is_local, + full_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let project_path = project_path_for_full_path(&project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: opening buffer {:?} in {} project {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + ); + + ensure_project_shared(&project, client, cx).await; + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await?; + client.buffers_for_project(&project).insert(buffer); + } + + ClientOperation::EditBuffer { + project_root_name, + is_local, + full_path, + edits, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: editing buffer {:?} in {} project {} with {:?}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + edits + ); + + ensure_project_shared(&project, client, cx).await; + buffer.update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + buffer.edit( + edits.into_iter().map(|(range, text)| { + let start = snapshot.clip_offset(range.start, Bias::Left); + let end = snapshot.clip_offset(range.end, Bias::Right); + (start..end, text) + }), + None, + cx, + ); + }); + } + + ClientOperation::CloseBuffer { + project_root_name, + is_local, + full_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: closing buffer {:?} in {} project {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name + ); + + ensure_project_shared(&project, client, cx).await; + cx.update(|_| { + client.buffers_for_project(&project).remove(&buffer); + drop(buffer); + }); + } + + ClientOperation::SaveBuffer { + project_root_name, + is_local, + full_path, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: saving buffer {:?} in {} project {}, {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + if detach { "detaching" } else { "awaiting" } + ); + + ensure_project_shared(&project, client, cx).await; + let requested_version = buffer.read_with(cx, |buffer, _| buffer.version()); + let save = + project.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + let save = cx.spawn(|cx| async move { + save.await + .map_err(|err| anyhow!("save request failed: {:?}", err))?; + assert!(buffer + .read_with(&cx, |buffer, _| { buffer.saved_version().to_owned() }) + .observed_all(&requested_version)); + anyhow::Ok(()) + }); + if detach { + cx.update(|cx| save.detach_and_log_err(cx)); + } else { + save.await?; + } + } + + ClientOperation::RequestLspDataInBuffer { + project_root_name, + is_local, + full_path, + offset, + kind, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: request LSP {:?} for buffer {:?} in {} project {}, {}", + client.username, + kind, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + if detach { "detaching" } else { "awaiting" } + ); + + use futures::{FutureExt as _, TryFutureExt as _}; + let offset = buffer.read_with(cx, |b, _| b.clip_offset(offset, Bias::Left)); + let request = cx.foreground().spawn(project.update(cx, |project, cx| { + match kind { + LspRequestKind::Rename => project + .prepare_rename(buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Completion => project + .completions(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::CodeAction => project + .code_actions(&buffer, offset..offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Definition => project + .definition(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Highlights => project + .document_highlights(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + } + })); + if detach { + request.detach(); + } else { + request.await?; + } + } + + ClientOperation::SearchProject { + project_root_name, + is_local, + query, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: search {} project {} for {:?}, {}", + client.username, + if is_local { "local" } else { "remote" }, + project_root_name, + query, + if detach { "detaching" } else { "awaiting" } + ); + + let mut search = project.update(cx, |project, cx| { + project.search( + SearchQuery::text(query, false, false, Vec::new(), Vec::new()), + cx, + ) + }); + drop(project); + let search = cx.background().spawn(async move { + let mut results = HashMap::default(); + while let Some((buffer, ranges)) = search.next().await { + results.entry(buffer).or_insert(ranges); + } + results + }); + search.await; + } + + ClientOperation::WriteFsEntry { + path, + is_dir, + content, + } => { + if !client + .fs() + .directories(false) + .contains(&path.parent().unwrap().to_owned()) + { + return Err(TestError::Inapplicable); + } + + if is_dir { + log::info!("{}: creating dir at {:?}", client.username, path); + client.fs().create_dir(&path).await.unwrap(); + } else { + let exists = client.fs().metadata(&path).await?.is_some(); + let verb = if exists { "updating" } else { "creating" }; + log::info!("{}: {} file at {:?}", verb, client.username, path); + + client + .fs() + .save(&path, &content.as_str().into(), text::LineEnding::Unix) + .await + .unwrap(); + } + } + + ClientOperation::GitOperation { operation } => match operation { + GitOperation::WriteGitIndex { + repo_path, + contents, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + + for (path, _) in contents.iter() { + if !client.fs().files().contains(&repo_path.join(path)) { + return Err(TestError::Inapplicable); + } + } + + log::info!( + "{}: writing git index for repo {:?}: {:?}", + client.username, + repo_path, + contents + ); + + let dot_git_dir = repo_path.join(".git"); + let contents = contents + .iter() + .map(|(path, contents)| (path.as_path(), contents.clone())) + .collect::>(); + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + client.fs().set_index_for_repo(&dot_git_dir, &contents); + } + GitOperation::WriteGitBranch { + repo_path, + new_branch, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + + log::info!( + "{}: writing git branch for repo {:?}: {:?}", + client.username, + repo_path, + new_branch + ); + + let dot_git_dir = repo_path.join(".git"); + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + client + .fs() + .set_branch_name(&dot_git_dir, new_branch.clone()); + } + GitOperation::WriteGitStatuses { + repo_path, + statuses, + git_operation, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + for (path, _) in statuses.iter() { + if !client.fs().files().contains(&repo_path.join(path)) { + return Err(TestError::Inapplicable); + } + } + + log::info!( + "{}: writing git statuses for repo {:?}: {:?}", + client.username, + repo_path, + statuses + ); + + let dot_git_dir = repo_path.join(".git"); + + let statuses = statuses + .iter() + .map(|(path, val)| (path.as_path(), val.clone())) + .collect::>(); + + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + + if git_operation { + client.fs().set_status_for_repo_via_git_operation( + &dot_git_dir, + statuses.as_slice(), + ); + } else { + client.fs().set_status_for_repo_via_working_copy_change( + &dot_git_dir, + statuses.as_slice(), + ); + } + } + }, + } + Ok(()) + } + + async fn on_client_added(client: &Rc, _: &mut TestAppContext) { + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + None, + ); + language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + name: "the-fake-language-server", + capabilities: lsp::LanguageServer::full_capabilities(), + initializer: Some(Box::new({ + let fs = client.app_state.fs.clone(); + move |fake_server: &mut FakeLanguageServer| { + fake_server.handle_request::( + |_, _| async move { + Ok(Some(lsp::CompletionResponse::Array(vec![ + lsp::CompletionItem { + text_edit: Some(lsp::CompletionTextEdit::Edit( + lsp::TextEdit { + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 0), + ), + new_text: "the-new-text".to_string(), + }, + )), + ..Default::default() + }, + ]))) + }, + ); + + fake_server.handle_request::( + |_, _| async move { + Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( + lsp::CodeAction { + title: "the-code-action".to_string(), + ..Default::default() + }, + )])) + }, + ); + + fake_server.handle_request::( + |params, _| async move { + Ok(Some(lsp::PrepareRenameResponse::Range(lsp::Range::new( + params.position, + params.position, + )))) + }, + ); + + fake_server.handle_request::({ + let fs = fs.clone(); + move |_, cx| { + let background = cx.background(); + let mut rng = background.rng(); + let count = rng.gen_range::(1..3); + let files = fs.as_fake().files(); + let files = (0..count) + .map(|_| files.choose(&mut *rng).unwrap().clone()) + .collect::>(); + async move { + log::info!("LSP: Returning definitions in files {:?}", &files); + Ok(Some(lsp::GotoDefinitionResponse::Array( + files + .into_iter() + .map(|file| lsp::Location { + uri: lsp::Url::from_file_path(file).unwrap(), + range: Default::default(), + }) + .collect(), + ))) + } + } + }); + + fake_server.handle_request::( + move |_, cx| { + let mut highlights = Vec::new(); + let background = cx.background(); + let mut rng = background.rng(); + + let highlight_count = rng.gen_range(1..=5); + for _ in 0..highlight_count { + let start_row = rng.gen_range(0..100); + let start_column = rng.gen_range(0..100); + let end_row = rng.gen_range(0..100); + let end_column = rng.gen_range(0..100); + let start = PointUtf16::new(start_row, start_column); + let end = PointUtf16::new(end_row, end_column); + let range = if start > end { end..start } else { start..end }; + highlights.push(lsp::DocumentHighlight { + range: range_to_lsp(range.clone()), + kind: Some(lsp::DocumentHighlightKind::READ), + }); + } + highlights.sort_unstable_by_key(|highlight| { + (highlight.range.start, highlight.range.end) + }); + async move { Ok(Some(highlights)) } + }, + ); + } + })), + ..Default::default() + })) + .await; + client.app_state.languages.add(Arc::new(language)); + } + + async fn on_quiesce(_: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { + for (client, client_cx) in clients.iter() { + for guest_project in client.remote_projects().iter() { + guest_project.read_with(client_cx, |guest_project, cx| { + let host_project = clients.iter().find_map(|(client, cx)| { + let project = client + .local_projects() + .iter() + .find(|host_project| { + host_project.read_with(cx, |host_project, _| { + host_project.remote_id() == guest_project.remote_id() + }) + })? + .clone(); + Some((project, cx)) + }); + + if !guest_project.is_read_only() { + if let Some((host_project, host_cx)) = host_project { + let host_worktree_snapshots = + host_project.read_with(host_cx, |host_project, cx| { + host_project + .worktrees(cx) + .map(|worktree| { + let worktree = worktree.read(cx); + (worktree.id(), worktree.snapshot()) + }) + .collect::>() + }); + let guest_worktree_snapshots = guest_project + .worktrees(cx) + .map(|worktree| { + let worktree = worktree.read(cx); + (worktree.id(), worktree.snapshot()) + }) + .collect::>(); + + assert_eq!( + guest_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), + host_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), + "{} has different worktrees than the host for project {:?}", + client.username, guest_project.remote_id(), + ); + + for (id, host_snapshot) in &host_worktree_snapshots { + let guest_snapshot = &guest_worktree_snapshots[id]; + assert_eq!( + guest_snapshot.root_name(), + host_snapshot.root_name(), + "{} has different root name than the host for worktree {}, project {:?}", + client.username, + id, + guest_project.remote_id(), + ); + assert_eq!( + guest_snapshot.abs_path(), + host_snapshot.abs_path(), + "{} has different abs path than the host for worktree {}, project: {:?}", + client.username, + id, + guest_project.remote_id(), + ); + assert_eq!( + guest_snapshot.entries(false).collect::>(), + host_snapshot.entries(false).collect::>(), + "{} has different snapshot than the host for worktree {:?} ({:?}) and project {:?}", + client.username, + host_snapshot.abs_path(), + id, + guest_project.remote_id(), + ); + assert_eq!(guest_snapshot.repositories().collect::>(), host_snapshot.repositories().collect::>(), + "{} has different repositories than the host for worktree {:?} and project {:?}", + client.username, + host_snapshot.abs_path(), + guest_project.remote_id(), + ); + assert_eq!(guest_snapshot.scan_id(), host_snapshot.scan_id(), + "{} has different scan id than the host for worktree {:?} and project {:?}", + client.username, + host_snapshot.abs_path(), + guest_project.remote_id(), + ); + } + } + } + + for buffer in guest_project.opened_buffers(cx) { + let buffer = buffer.read(cx); + assert_eq!( + buffer.deferred_ops_len(), + 0, + "{} has deferred operations for buffer {:?} in project {:?}", + client.username, + buffer.file().unwrap().full_path(cx), + guest_project.remote_id(), + ); + } + }); + } + + let buffers = client.buffers().clone(); + for (guest_project, guest_buffers) in &buffers { + let project_id = if guest_project.read_with(client_cx, |project, _| { + project.is_local() || project.is_read_only() + }) { + continue; + } else { + guest_project + .read_with(client_cx, |project, _| project.remote_id()) + .unwrap() + }; + let guest_user_id = client.user_id().unwrap(); + + let host_project = clients.iter().find_map(|(client, cx)| { + let project = client + .local_projects() + .iter() + .find(|host_project| { + host_project.read_with(cx, |host_project, _| { + host_project.remote_id() == Some(project_id) + }) + })? + .clone(); + Some((client.user_id().unwrap(), project, cx)) + }); + + let (host_user_id, host_project, host_cx) = + if let Some((host_user_id, host_project, host_cx)) = host_project { + (host_user_id, host_project, host_cx) + } else { + continue; + }; + + for guest_buffer in guest_buffers { + let buffer_id = + guest_buffer.read_with(client_cx, |buffer, _| buffer.remote_id()); + let host_buffer = host_project.read_with(host_cx, |project, cx| { + project.buffer_for_id(buffer_id, cx).unwrap_or_else(|| { + panic!( + "host does not have buffer for guest:{}, peer:{:?}, id:{}", + client.username, + client.peer_id(), + buffer_id + ) + }) + }); + let path = host_buffer + .read_with(host_cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); + + assert_eq!( + guest_buffer.read_with(client_cx, |buffer, _| buffer.deferred_ops_len()), + 0, + "{}, buffer {}, path {:?} has deferred operations", + client.username, + buffer_id, + path, + ); + assert_eq!( + guest_buffer.read_with(client_cx, |buffer, _| buffer.text()), + host_buffer.read_with(host_cx, |buffer, _| buffer.text()), + "{}, buffer {}, path {:?}, differs from the host's buffer", + client.username, + buffer_id, + path + ); + + let host_file = host_buffer.read_with(host_cx, |b, _| b.file().cloned()); + let guest_file = guest_buffer.read_with(client_cx, |b, _| b.file().cloned()); + match (host_file, guest_file) { + (Some(host_file), Some(guest_file)) => { + assert_eq!(guest_file.path(), host_file.path()); + assert_eq!(guest_file.is_deleted(), host_file.is_deleted()); + assert_eq!( + guest_file.mtime(), + host_file.mtime(), + "guest {} mtime does not match host {} for path {:?} in project {}", + guest_user_id, + host_user_id, + guest_file.path(), + project_id, + ); + } + (None, None) => {} + (None, _) => panic!("host's file is None, guest's isn't"), + (_, None) => panic!("guest's file is None, hosts's isn't"), + } + + let host_diff_base = host_buffer + .read_with(host_cx, |b, _| b.diff_base().map(ToString::to_string)); + let guest_diff_base = guest_buffer + .read_with(client_cx, |b, _| b.diff_base().map(ToString::to_string)); + assert_eq!( + guest_diff_base, host_diff_base, + "guest {} diff base does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_version = + host_buffer.read_with(host_cx, |b, _| b.saved_version().clone()); + let guest_saved_version = + guest_buffer.read_with(client_cx, |b, _| b.saved_version().clone()); + assert_eq!( + guest_saved_version, host_saved_version, + "guest {} saved version does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_version_fingerprint = + host_buffer.read_with(host_cx, |b, _| b.saved_version_fingerprint()); + let guest_saved_version_fingerprint = + guest_buffer.read_with(client_cx, |b, _| b.saved_version_fingerprint()); + assert_eq!( + guest_saved_version_fingerprint, host_saved_version_fingerprint, + "guest {} saved fingerprint does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_mtime = host_buffer.read_with(host_cx, |b, _| b.saved_mtime()); + let guest_saved_mtime = + guest_buffer.read_with(client_cx, |b, _| b.saved_mtime()); + assert_eq!( + guest_saved_mtime, host_saved_mtime, + "guest {} saved mtime does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_is_dirty = host_buffer.read_with(host_cx, |b, _| b.is_dirty()); + let guest_is_dirty = guest_buffer.read_with(client_cx, |b, _| b.is_dirty()); + assert_eq!(guest_is_dirty, host_is_dirty, + "guest {} dirty status does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_has_conflict = host_buffer.read_with(host_cx, |b, _| b.has_conflict()); + let guest_has_conflict = + guest_buffer.read_with(client_cx, |b, _| b.has_conflict()); + assert_eq!(guest_has_conflict, host_has_conflict, + "guest {} conflict status does not match host's for path {path:?} in project {project_id}", + client.username + ); + } + } + } + } +} + +fn generate_git_operation(rng: &mut StdRng, client: &TestClient) -> GitOperation { + fn generate_file_paths( + repo_path: &Path, + rng: &mut StdRng, + client: &TestClient, + ) -> Vec { + let mut paths = client + .fs() + .files() + .into_iter() + .filter(|path| path.starts_with(repo_path)) + .collect::>(); + + let count = rng.gen_range(0..=paths.len()); + paths.shuffle(rng); + paths.truncate(count); + + paths + .iter() + .map(|path| path.strip_prefix(repo_path).unwrap().to_path_buf()) + .collect::>() + } + + let repo_path = client.fs().directories(false).choose(rng).unwrap().clone(); + + match rng.gen_range(0..100_u32) { + 0..=25 => { + let file_paths = generate_file_paths(&repo_path, rng, client); + + let contents = file_paths + .into_iter() + .map(|path| (path, Alphanumeric.sample_string(rng, 16))) + .collect(); + + GitOperation::WriteGitIndex { + repo_path, + contents, + } + } + 26..=63 => { + let new_branch = (rng.gen_range(0..10) > 3).then(|| Alphanumeric.sample_string(rng, 8)); + + GitOperation::WriteGitBranch { + repo_path, + new_branch, + } + } + 64..=100 => { + let file_paths = generate_file_paths(&repo_path, rng, client); + + let statuses = file_paths + .into_iter() + .map(|paths| { + ( + paths, + match rng.gen_range(0..3_u32) { + 0 => GitFileStatus::Added, + 1 => GitFileStatus::Modified, + 2 => GitFileStatus::Conflict, + _ => unreachable!(), + }, + ) + }) + .collect::>(); + + let git_operation = rng.gen::(); + + GitOperation::WriteGitStatuses { + repo_path, + statuses, + git_operation, + } + } + _ => unreachable!(), + } +} + +fn buffer_for_full_path( + client: &TestClient, + project: &ModelHandle, + full_path: &PathBuf, + cx: &TestAppContext, +) -> Option> { + client + .buffers_for_project(project) + .iter() + .find(|buffer| { + buffer.read_with(cx, |buffer, cx| { + buffer.file().unwrap().full_path(cx) == *full_path + }) + }) + .cloned() +} + +fn project_for_root_name( + client: &TestClient, + root_name: &str, + cx: &TestAppContext, +) -> Option> { + if let Some(ix) = project_ix_for_root_name(&*client.local_projects(), root_name, cx) { + return Some(client.local_projects()[ix].clone()); + } + if let Some(ix) = project_ix_for_root_name(&*client.remote_projects(), root_name, cx) { + return Some(client.remote_projects()[ix].clone()); + } + None +} + +fn project_ix_for_root_name( + projects: &[ModelHandle], + root_name: &str, + cx: &TestAppContext, +) -> Option { + projects.iter().position(|project| { + project.read_with(cx, |project, cx| { + let worktree = project.visible_worktrees(cx).next().unwrap(); + worktree.read(cx).root_name() == root_name + }) + }) +} + +fn root_name_for_project(project: &ModelHandle, cx: &TestAppContext) -> String { + project.read_with(cx, |project, cx| { + project + .visible_worktrees(cx) + .next() + .unwrap() + .read(cx) + .root_name() + .to_string() + }) +} + +fn project_path_for_full_path( + project: &ModelHandle, + full_path: &Path, + cx: &TestAppContext, +) -> Option { + let mut components = full_path.components(); + let root_name = components.next().unwrap().as_os_str().to_str().unwrap(); + let path = components.as_path().into(); + let worktree_id = project.read_with(cx, |project, cx| { + project.worktrees(cx).find_map(|worktree| { + let worktree = worktree.read(cx); + if worktree.root_name() == root_name { + Some(worktree.id()) + } else { + None + } + }) + })?; + Some(ProjectPath { worktree_id, path }) +} + +async fn ensure_project_shared( + project: &ModelHandle, + client: &TestClient, + cx: &mut TestAppContext, +) { + let first_root_name = root_name_for_project(project, cx); + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.room().is_some()) + && project.read_with(cx, |project, _| project.is_local() && !project.is_shared()) + { + match active_call + .update(cx, |call, cx| call.share_project(project.clone(), cx)) + .await + { + Ok(project_id) => { + log::info!( + "{}: shared project {} with id {}", + client.username, + first_root_name, + project_id + ); + } + Err(error) => { + log::error!( + "{}: error sharing project {}: {:?}", + client.username, + first_root_name, + error + ); + } + } + } +} + +fn choose_random_project(client: &TestClient, rng: &mut StdRng) -> Option> { + client + .local_projects() + .iter() + .chain(client.remote_projects().iter()) + .choose(rng) + .cloned() +} + +fn gen_file_name(rng: &mut StdRng) -> String { + let mut name = String::new(); + for _ in 0..10 { + let letter = rng.gen_range('a'..='z'); + name.push(letter); + } + name +} diff --git a/crates/collab/src/tests/randomized_integration_tests.rs b/crates/collab/src/tests/randomized_integration_tests.rs deleted file mode 100644 index 814f248b6d..0000000000 --- a/crates/collab/src/tests/randomized_integration_tests.rs +++ /dev/null @@ -1,2199 +0,0 @@ -use crate::{ - db::{self, NewUserParams, UserId}, - rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, - tests::{TestClient, TestServer}, -}; -use anyhow::{anyhow, Result}; -use call::ActiveCall; -use client::RECEIVE_TIMEOUT; -use collections::{BTreeMap, HashMap}; -use editor::Bias; -use fs::{repository::GitFileStatus, FakeFs, Fs as _}; -use futures::StreamExt as _; -use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext}; -use language::{range_to_lsp, FakeLspAdapter, Language, LanguageConfig, PointUtf16}; -use lsp::FakeLanguageServer; -use parking_lot::Mutex; -use pretty_assertions::assert_eq; -use project::{search::SearchQuery, Project, ProjectPath}; -use rand::{ - distributions::{Alphanumeric, DistString}, - prelude::*, -}; -use serde::{Deserialize, Serialize}; -use settings::SettingsStore; -use std::{ - env, - ops::Range, - path::{Path, PathBuf}, - rc::Rc, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, -}; -use util::ResultExt; - -lazy_static::lazy_static! { - static ref PLAN_LOAD_PATH: Option = path_env_var("LOAD_PLAN"); - static ref PLAN_SAVE_PATH: Option = path_env_var("SAVE_PLAN"); -} -static LOADED_PLAN_JSON: Mutex>> = Mutex::new(None); -static PLAN: Mutex>>> = Mutex::new(None); - -#[gpui::test(iterations = 100, on_failure = "on_failure")] -async fn test_random_collaboration( - cx: &mut TestAppContext, - deterministic: Arc, - rng: StdRng, -) { - deterministic.forbid_parking(); - - let max_peers = env::var("MAX_PEERS") - .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) - .unwrap_or(3); - let max_operations = env::var("OPERATIONS") - .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) - .unwrap_or(10); - - let mut server = TestServer::start(&deterministic).await; - let db = server.app_state.db.clone(); - - let mut users = Vec::new(); - for ix in 0..max_peers { - let username = format!("user-{}", ix + 1); - let user_id = db - .create_user( - &format!("{username}@example.com"), - false, - NewUserParams { - github_login: username.clone(), - github_user_id: (ix + 1) as i32, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - users.push(UserTestPlan { - user_id, - username, - online: false, - next_root_id: 0, - operation_ix: 0, - }); - } - - for (ix, user_a) in users.iter().enumerate() { - for user_b in &users[ix + 1..] { - server - .app_state - .db - .send_contact_request(user_a.user_id, user_b.user_id) - .await - .unwrap(); - server - .app_state - .db - .respond_to_contact_request(user_b.user_id, user_a.user_id, true) - .await - .unwrap(); - } - } - - let plan = Arc::new(Mutex::new(TestPlan::new(rng, users, max_operations))); - - if let Some(path) = &*PLAN_LOAD_PATH { - let json = LOADED_PLAN_JSON - .lock() - .get_or_insert_with(|| { - eprintln!("loaded test plan from path {:?}", path); - std::fs::read(path).unwrap() - }) - .clone(); - plan.lock().deserialize(json); - } - - PLAN.lock().replace(plan.clone()); - - let mut clients = Vec::new(); - let mut client_tasks = Vec::new(); - let mut operation_channels = Vec::new(); - - loop { - let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else { - break; - }; - applied.store(true, SeqCst); - let did_apply = apply_server_operation( - deterministic.clone(), - &mut server, - &mut clients, - &mut client_tasks, - &mut operation_channels, - plan.clone(), - next_operation, - cx, - ) - .await; - if !did_apply { - applied.store(false, SeqCst); - } - } - - drop(operation_channels); - deterministic.start_waiting(); - futures::future::join_all(client_tasks).await; - deterministic.finish_waiting(); - deterministic.run_until_parked(); - - check_consistency_between_clients(&clients); - - for (client, mut cx) in clients { - cx.update(|cx| { - let store = cx.remove_global::(); - cx.clear_globals(); - cx.set_global(store); - drop(client); - }); - } - - deterministic.run_until_parked(); -} - -fn on_failure() { - if let Some(plan) = PLAN.lock().clone() { - if let Some(path) = &*PLAN_SAVE_PATH { - eprintln!("saved test plan to path {:?}", path); - std::fs::write(path, plan.lock().serialize()).unwrap(); - } - } -} - -async fn apply_server_operation( - deterministic: Arc, - server: &mut TestServer, - clients: &mut Vec<(Rc, TestAppContext)>, - client_tasks: &mut Vec>, - operation_channels: &mut Vec>, - plan: Arc>, - operation: Operation, - cx: &mut TestAppContext, -) -> bool { - match operation { - Operation::AddConnection { user_id } => { - let username; - { - let mut plan = plan.lock(); - let user = plan.user(user_id); - if user.online { - return false; - } - user.online = true; - username = user.username.clone(); - }; - log::info!("Adding new connection for {}", username); - let next_entity_id = (user_id.0 * 10_000) as usize; - let mut client_cx = TestAppContext::new( - cx.foreground_platform(), - cx.platform(), - deterministic.build_foreground(user_id.0 as usize), - deterministic.build_background(), - cx.font_cache(), - cx.leak_detector(), - next_entity_id, - cx.function_name.clone(), - ); - - let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded(); - let client = Rc::new(server.create_client(&mut client_cx, &username).await); - operation_channels.push(operation_tx); - clients.push((client.clone(), client_cx.clone())); - client_tasks.push(client_cx.foreground().spawn(simulate_client( - client, - operation_rx, - plan.clone(), - client_cx, - ))); - - log::info!("Added connection for {}", username); - } - - Operation::RemoveConnection { - user_id: removed_user_id, - } => { - log::info!("Simulating full disconnection of user {}", removed_user_id); - let client_ix = clients - .iter() - .position(|(client, cx)| client.current_user_id(cx) == removed_user_id); - let Some(client_ix) = client_ix else { - return false; - }; - let user_connection_ids = server - .connection_pool - .lock() - .user_connection_ids(removed_user_id) - .collect::>(); - assert_eq!(user_connection_ids.len(), 1); - let removed_peer_id = user_connection_ids[0].into(); - let (client, mut client_cx) = clients.remove(client_ix); - let client_task = client_tasks.remove(client_ix); - operation_channels.remove(client_ix); - server.forbid_connections(); - server.disconnect_client(removed_peer_id); - deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - deterministic.start_waiting(); - log::info!("Waiting for user {} to exit...", removed_user_id); - client_task.await; - deterministic.finish_waiting(); - server.allow_connections(); - - for project in client.remote_projects().iter() { - project.read_with(&client_cx, |project, _| { - assert!( - project.is_read_only(), - "project {:?} should be read only", - project.remote_id() - ) - }); - } - - for (client, cx) in clients { - let contacts = server - .app_state - .db - .get_contacts(client.current_user_id(cx)) - .await - .unwrap(); - let pool = server.connection_pool.lock(); - for contact in contacts { - if let db::Contact::Accepted { user_id, busy, .. } = contact { - if user_id == removed_user_id { - assert!(!pool.is_user_online(user_id)); - assert!(!busy); - } - } - } - } - - log::info!("{} removed", client.username); - plan.lock().user(removed_user_id).online = false; - client_cx.update(|cx| { - cx.clear_globals(); - drop(client); - }); - } - - Operation::BounceConnection { user_id } => { - log::info!("Simulating temporary disconnection of user {}", user_id); - let user_connection_ids = server - .connection_pool - .lock() - .user_connection_ids(user_id) - .collect::>(); - if user_connection_ids.is_empty() { - return false; - } - assert_eq!(user_connection_ids.len(), 1); - let peer_id = user_connection_ids[0].into(); - server.disconnect_client(peer_id); - deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - } - - Operation::RestartServer => { - log::info!("Simulating server restart"); - server.reset().await; - deterministic.advance_clock(RECEIVE_TIMEOUT); - server.start().await.unwrap(); - deterministic.advance_clock(CLEANUP_TIMEOUT); - let environment = &server.app_state.config.zed_environment; - let stale_room_ids = server - .app_state - .db - .stale_room_ids(environment, server.id()) - .await - .unwrap(); - assert_eq!(stale_room_ids, vec![]); - } - - Operation::MutateClients { - user_ids, - batch_id, - quiesce, - } => { - let mut applied = false; - for user_id in user_ids { - let client_ix = clients - .iter() - .position(|(client, cx)| client.current_user_id(cx) == user_id); - let Some(client_ix) = client_ix else { continue }; - applied = true; - if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) { - log::error!("error signaling user {user_id}: {err}"); - } - } - - if quiesce && applied { - deterministic.run_until_parked(); - check_consistency_between_clients(&clients); - } - - return applied; - } - } - true -} - -async fn apply_client_operation( - client: &TestClient, - operation: ClientOperation, - cx: &mut TestAppContext, -) -> Result<(), TestError> { - match operation { - ClientOperation::AcceptIncomingCall => { - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { - Err(TestError::Inapplicable)?; - } - - log::info!("{}: accepting incoming call", client.username); - active_call - .update(cx, |call, cx| call.accept_incoming(cx)) - .await?; - } - - ClientOperation::RejectIncomingCall => { - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { - Err(TestError::Inapplicable)?; - } - - log::info!("{}: declining incoming call", client.username); - active_call.update(cx, |call, cx| call.decline_incoming(cx))?; - } - - ClientOperation::LeaveCall => { - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.room().is_none()) { - Err(TestError::Inapplicable)?; - } - - log::info!("{}: hanging up", client.username); - active_call.update(cx, |call, cx| call.hang_up(cx)).await?; - } - - ClientOperation::InviteContactToCall { user_id } => { - let active_call = cx.read(ActiveCall::global); - - log::info!("{}: inviting {}", client.username, user_id,); - active_call - .update(cx, |call, cx| call.invite(user_id.to_proto(), None, cx)) - .await - .log_err(); - } - - ClientOperation::OpenLocalProject { first_root_name } => { - log::info!( - "{}: opening local project at {:?}", - client.username, - first_root_name - ); - - let root_path = Path::new("/").join(&first_root_name); - client.fs().create_dir(&root_path).await.unwrap(); - client - .fs() - .create_file(&root_path.join("main.rs"), Default::default()) - .await - .unwrap(); - let project = client.build_local_project(root_path, cx).await.0; - ensure_project_shared(&project, client, cx).await; - client.local_projects_mut().push(project.clone()); - } - - ClientOperation::AddWorktreeToProject { - project_root_name, - new_root_path, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: finding/creating local worktree at {:?} to project with root path {}", - client.username, - new_root_path, - project_root_name - ); - - ensure_project_shared(&project, client, cx).await; - if !client.fs().paths(false).contains(&new_root_path) { - client.fs().create_dir(&new_root_path).await.unwrap(); - } - project - .update(cx, |project, cx| { - project.find_or_create_local_worktree(&new_root_path, true, cx) - }) - .await - .unwrap(); - } - - ClientOperation::CloseRemoteProject { project_root_name } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: closing remote project with root path {}", - client.username, - project_root_name, - ); - - let ix = client - .remote_projects() - .iter() - .position(|p| p == &project) - .unwrap(); - cx.update(|_| { - client.remote_projects_mut().remove(ix); - client.buffers().retain(|p, _| *p != project); - drop(project); - }); - } - - ClientOperation::OpenRemoteProject { - host_id, - first_root_name, - } => { - let active_call = cx.read(ActiveCall::global); - let project = active_call - .update(cx, |call, cx| { - let room = call.room().cloned()?; - let participant = room - .read(cx) - .remote_participants() - .get(&host_id.to_proto())?; - let project_id = participant - .projects - .iter() - .find(|project| project.worktree_root_names[0] == first_root_name)? - .id; - Some(room.update(cx, |room, cx| { - room.join_project( - project_id, - client.language_registry().clone(), - FakeFs::new(cx.background().clone()), - cx, - ) - })) - }) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: joining remote project of user {}, root name {}", - client.username, - host_id, - first_root_name, - ); - - let project = project.await?; - client.remote_projects_mut().push(project.clone()); - } - - ClientOperation::CreateWorktreeEntry { - project_root_name, - is_local, - full_path, - is_dir, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let project_path = project_path_for_full_path(&project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: creating {} at path {:?} in {} project {}", - client.username, - if is_dir { "dir" } else { "file" }, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - ); - - ensure_project_shared(&project, client, cx).await; - project - .update(cx, |p, cx| p.create_entry(project_path, is_dir, cx)) - .unwrap() - .await?; - } - - ClientOperation::OpenBuffer { - project_root_name, - is_local, - full_path, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let project_path = project_path_for_full_path(&project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: opening buffer {:?} in {} project {}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - ); - - ensure_project_shared(&project, client, cx).await; - let buffer = project - .update(cx, |project, cx| project.open_buffer(project_path, cx)) - .await?; - client.buffers_for_project(&project).insert(buffer); - } - - ClientOperation::EditBuffer { - project_root_name, - is_local, - full_path, - edits, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: editing buffer {:?} in {} project {} with {:?}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - edits - ); - - ensure_project_shared(&project, client, cx).await; - buffer.update(cx, |buffer, cx| { - let snapshot = buffer.snapshot(); - buffer.edit( - edits.into_iter().map(|(range, text)| { - let start = snapshot.clip_offset(range.start, Bias::Left); - let end = snapshot.clip_offset(range.end, Bias::Right); - (start..end, text) - }), - None, - cx, - ); - }); - } - - ClientOperation::CloseBuffer { - project_root_name, - is_local, - full_path, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: closing buffer {:?} in {} project {}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name - ); - - ensure_project_shared(&project, client, cx).await; - cx.update(|_| { - client.buffers_for_project(&project).remove(&buffer); - drop(buffer); - }); - } - - ClientOperation::SaveBuffer { - project_root_name, - is_local, - full_path, - detach, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: saving buffer {:?} in {} project {}, {}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - if detach { "detaching" } else { "awaiting" } - ); - - ensure_project_shared(&project, client, cx).await; - let requested_version = buffer.read_with(cx, |buffer, _| buffer.version()); - let save = project.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); - let save = cx.spawn(|cx| async move { - save.await - .map_err(|err| anyhow!("save request failed: {:?}", err))?; - assert!(buffer - .read_with(&cx, |buffer, _| { buffer.saved_version().to_owned() }) - .observed_all(&requested_version)); - anyhow::Ok(()) - }); - if detach { - cx.update(|cx| save.detach_and_log_err(cx)); - } else { - save.await?; - } - } - - ClientOperation::RequestLspDataInBuffer { - project_root_name, - is_local, - full_path, - offset, - kind, - detach, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: request LSP {:?} for buffer {:?} in {} project {}, {}", - client.username, - kind, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - if detach { "detaching" } else { "awaiting" } - ); - - use futures::{FutureExt as _, TryFutureExt as _}; - let offset = buffer.read_with(cx, |b, _| b.clip_offset(offset, Bias::Left)); - let request = cx.foreground().spawn(project.update(cx, |project, cx| { - match kind { - LspRequestKind::Rename => project - .prepare_rename(buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::Completion => project - .completions(&buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::CodeAction => project - .code_actions(&buffer, offset..offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::Definition => project - .definition(&buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::Highlights => project - .document_highlights(&buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - } - })); - if detach { - request.detach(); - } else { - request.await?; - } - } - - ClientOperation::SearchProject { - project_root_name, - is_local, - query, - detach, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: search {} project {} for {:?}, {}", - client.username, - if is_local { "local" } else { "remote" }, - project_root_name, - query, - if detach { "detaching" } else { "awaiting" } - ); - - let mut search = project.update(cx, |project, cx| { - project.search( - SearchQuery::text(query, false, false, Vec::new(), Vec::new()), - cx, - ) - }); - drop(project); - let search = cx.background().spawn(async move { - let mut results = HashMap::default(); - while let Some((buffer, ranges)) = search.next().await { - results.entry(buffer).or_insert(ranges); - } - results - }); - search.await; - } - - ClientOperation::WriteFsEntry { - path, - is_dir, - content, - } => { - if !client - .fs() - .directories(false) - .contains(&path.parent().unwrap().to_owned()) - { - return Err(TestError::Inapplicable); - } - - if is_dir { - log::info!("{}: creating dir at {:?}", client.username, path); - client.fs().create_dir(&path).await.unwrap(); - } else { - let exists = client.fs().metadata(&path).await?.is_some(); - let verb = if exists { "updating" } else { "creating" }; - log::info!("{}: {} file at {:?}", verb, client.username, path); - - client - .fs() - .save(&path, &content.as_str().into(), fs::LineEnding::Unix) - .await - .unwrap(); - } - } - - ClientOperation::GitOperation { operation } => match operation { - GitOperation::WriteGitIndex { - repo_path, - contents, - } => { - if !client.fs().directories(false).contains(&repo_path) { - return Err(TestError::Inapplicable); - } - - for (path, _) in contents.iter() { - if !client.fs().files().contains(&repo_path.join(path)) { - return Err(TestError::Inapplicable); - } - } - - log::info!( - "{}: writing git index for repo {:?}: {:?}", - client.username, - repo_path, - contents - ); - - let dot_git_dir = repo_path.join(".git"); - let contents = contents - .iter() - .map(|(path, contents)| (path.as_path(), contents.clone())) - .collect::>(); - if client.fs().metadata(&dot_git_dir).await?.is_none() { - client.fs().create_dir(&dot_git_dir).await?; - } - client.fs().set_index_for_repo(&dot_git_dir, &contents); - } - GitOperation::WriteGitBranch { - repo_path, - new_branch, - } => { - if !client.fs().directories(false).contains(&repo_path) { - return Err(TestError::Inapplicable); - } - - log::info!( - "{}: writing git branch for repo {:?}: {:?}", - client.username, - repo_path, - new_branch - ); - - let dot_git_dir = repo_path.join(".git"); - if client.fs().metadata(&dot_git_dir).await?.is_none() { - client.fs().create_dir(&dot_git_dir).await?; - } - client.fs().set_branch_name(&dot_git_dir, new_branch); - } - GitOperation::WriteGitStatuses { - repo_path, - statuses, - git_operation, - } => { - if !client.fs().directories(false).contains(&repo_path) { - return Err(TestError::Inapplicable); - } - for (path, _) in statuses.iter() { - if !client.fs().files().contains(&repo_path.join(path)) { - return Err(TestError::Inapplicable); - } - } - - log::info!( - "{}: writing git statuses for repo {:?}: {:?}", - client.username, - repo_path, - statuses - ); - - let dot_git_dir = repo_path.join(".git"); - - let statuses = statuses - .iter() - .map(|(path, val)| (path.as_path(), val.clone())) - .collect::>(); - - if client.fs().metadata(&dot_git_dir).await?.is_none() { - client.fs().create_dir(&dot_git_dir).await?; - } - - if git_operation { - client - .fs() - .set_status_for_repo_via_git_operation(&dot_git_dir, statuses.as_slice()); - } else { - client.fs().set_status_for_repo_via_working_copy_change( - &dot_git_dir, - statuses.as_slice(), - ); - } - } - }, - } - Ok(()) -} - -fn check_consistency_between_clients(clients: &[(Rc, TestAppContext)]) { - for (client, client_cx) in clients { - for guest_project in client.remote_projects().iter() { - guest_project.read_with(client_cx, |guest_project, cx| { - let host_project = clients.iter().find_map(|(client, cx)| { - let project = client - .local_projects() - .iter() - .find(|host_project| { - host_project.read_with(cx, |host_project, _| { - host_project.remote_id() == guest_project.remote_id() - }) - })? - .clone(); - Some((project, cx)) - }); - - if !guest_project.is_read_only() { - if let Some((host_project, host_cx)) = host_project { - let host_worktree_snapshots = - host_project.read_with(host_cx, |host_project, cx| { - host_project - .worktrees(cx) - .map(|worktree| { - let worktree = worktree.read(cx); - (worktree.id(), worktree.snapshot()) - }) - .collect::>() - }); - let guest_worktree_snapshots = guest_project - .worktrees(cx) - .map(|worktree| { - let worktree = worktree.read(cx); - (worktree.id(), worktree.snapshot()) - }) - .collect::>(); - - assert_eq!( - guest_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), - host_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), - "{} has different worktrees than the host for project {:?}", - client.username, guest_project.remote_id(), - ); - - for (id, host_snapshot) in &host_worktree_snapshots { - let guest_snapshot = &guest_worktree_snapshots[id]; - assert_eq!( - guest_snapshot.root_name(), - host_snapshot.root_name(), - "{} has different root name than the host for worktree {}, project {:?}", - client.username, - id, - guest_project.remote_id(), - ); - assert_eq!( - guest_snapshot.abs_path(), - host_snapshot.abs_path(), - "{} has different abs path than the host for worktree {}, project: {:?}", - client.username, - id, - guest_project.remote_id(), - ); - assert_eq!( - guest_snapshot.entries(false).collect::>(), - host_snapshot.entries(false).collect::>(), - "{} has different snapshot than the host for worktree {:?} ({:?}) and project {:?}", - client.username, - host_snapshot.abs_path(), - id, - guest_project.remote_id(), - ); - assert_eq!(guest_snapshot.repositories().collect::>(), host_snapshot.repositories().collect::>(), - "{} has different repositories than the host for worktree {:?} and project {:?}", - client.username, - host_snapshot.abs_path(), - guest_project.remote_id(), - ); - assert_eq!(guest_snapshot.scan_id(), host_snapshot.scan_id(), - "{} has different scan id than the host for worktree {:?} and project {:?}", - client.username, - host_snapshot.abs_path(), - guest_project.remote_id(), - ); - } - } - } - - for buffer in guest_project.opened_buffers(cx) { - let buffer = buffer.read(cx); - assert_eq!( - buffer.deferred_ops_len(), - 0, - "{} has deferred operations for buffer {:?} in project {:?}", - client.username, - buffer.file().unwrap().full_path(cx), - guest_project.remote_id(), - ); - } - }); - } - - let buffers = client.buffers().clone(); - for (guest_project, guest_buffers) in &buffers { - let project_id = if guest_project.read_with(client_cx, |project, _| { - project.is_local() || project.is_read_only() - }) { - continue; - } else { - guest_project - .read_with(client_cx, |project, _| project.remote_id()) - .unwrap() - }; - let guest_user_id = client.user_id().unwrap(); - - let host_project = clients.iter().find_map(|(client, cx)| { - let project = client - .local_projects() - .iter() - .find(|host_project| { - host_project.read_with(cx, |host_project, _| { - host_project.remote_id() == Some(project_id) - }) - })? - .clone(); - Some((client.user_id().unwrap(), project, cx)) - }); - - let (host_user_id, host_project, host_cx) = - if let Some((host_user_id, host_project, host_cx)) = host_project { - (host_user_id, host_project, host_cx) - } else { - continue; - }; - - for guest_buffer in guest_buffers { - let buffer_id = guest_buffer.read_with(client_cx, |buffer, _| buffer.remote_id()); - let host_buffer = host_project.read_with(host_cx, |project, cx| { - project.buffer_for_id(buffer_id, cx).unwrap_or_else(|| { - panic!( - "host does not have buffer for guest:{}, peer:{:?}, id:{}", - client.username, - client.peer_id(), - buffer_id - ) - }) - }); - let path = host_buffer - .read_with(host_cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); - - assert_eq!( - guest_buffer.read_with(client_cx, |buffer, _| buffer.deferred_ops_len()), - 0, - "{}, buffer {}, path {:?} has deferred operations", - client.username, - buffer_id, - path, - ); - assert_eq!( - guest_buffer.read_with(client_cx, |buffer, _| buffer.text()), - host_buffer.read_with(host_cx, |buffer, _| buffer.text()), - "{}, buffer {}, path {:?}, differs from the host's buffer", - client.username, - buffer_id, - path - ); - - let host_file = host_buffer.read_with(host_cx, |b, _| b.file().cloned()); - let guest_file = guest_buffer.read_with(client_cx, |b, _| b.file().cloned()); - match (host_file, guest_file) { - (Some(host_file), Some(guest_file)) => { - assert_eq!(guest_file.path(), host_file.path()); - assert_eq!(guest_file.is_deleted(), host_file.is_deleted()); - assert_eq!( - guest_file.mtime(), - host_file.mtime(), - "guest {} mtime does not match host {} for path {:?} in project {}", - guest_user_id, - host_user_id, - guest_file.path(), - project_id, - ); - } - (None, None) => {} - (None, _) => panic!("host's file is None, guest's isn't"), - (_, None) => panic!("guest's file is None, hosts's isn't"), - } - - let host_diff_base = - host_buffer.read_with(host_cx, |b, _| b.diff_base().map(ToString::to_string)); - let guest_diff_base = guest_buffer - .read_with(client_cx, |b, _| b.diff_base().map(ToString::to_string)); - assert_eq!( - guest_diff_base, host_diff_base, - "guest {} diff base does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_saved_version = - host_buffer.read_with(host_cx, |b, _| b.saved_version().clone()); - let guest_saved_version = - guest_buffer.read_with(client_cx, |b, _| b.saved_version().clone()); - assert_eq!( - guest_saved_version, host_saved_version, - "guest {} saved version does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_saved_version_fingerprint = - host_buffer.read_with(host_cx, |b, _| b.saved_version_fingerprint()); - let guest_saved_version_fingerprint = - guest_buffer.read_with(client_cx, |b, _| b.saved_version_fingerprint()); - assert_eq!( - guest_saved_version_fingerprint, host_saved_version_fingerprint, - "guest {} saved fingerprint does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_saved_mtime = host_buffer.read_with(host_cx, |b, _| b.saved_mtime()); - let guest_saved_mtime = guest_buffer.read_with(client_cx, |b, _| b.saved_mtime()); - assert_eq!( - guest_saved_mtime, host_saved_mtime, - "guest {} saved mtime does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_is_dirty = host_buffer.read_with(host_cx, |b, _| b.is_dirty()); - let guest_is_dirty = guest_buffer.read_with(client_cx, |b, _| b.is_dirty()); - assert_eq!(guest_is_dirty, host_is_dirty, - "guest {} dirty status does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_has_conflict = host_buffer.read_with(host_cx, |b, _| b.has_conflict()); - let guest_has_conflict = guest_buffer.read_with(client_cx, |b, _| b.has_conflict()); - assert_eq!(guest_has_conflict, host_has_conflict, - "guest {} conflict status does not match host's for path {path:?} in project {project_id}", - client.username - ); - } - } - } -} - -struct TestPlan { - rng: StdRng, - replay: bool, - stored_operations: Vec<(StoredOperation, Arc)>, - max_operations: usize, - operation_ix: usize, - users: Vec, - next_batch_id: usize, - allow_server_restarts: bool, - allow_client_reconnection: bool, - allow_client_disconnection: bool, -} - -struct UserTestPlan { - user_id: UserId, - username: String, - next_root_id: usize, - operation_ix: usize, - online: bool, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(untagged)] -enum StoredOperation { - Server(Operation), - Client { - user_id: UserId, - batch_id: usize, - operation: ClientOperation, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum Operation { - AddConnection { - user_id: UserId, - }, - RemoveConnection { - user_id: UserId, - }, - BounceConnection { - user_id: UserId, - }, - RestartServer, - MutateClients { - batch_id: usize, - #[serde(skip_serializing)] - #[serde(skip_deserializing)] - user_ids: Vec, - quiesce: bool, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum ClientOperation { - AcceptIncomingCall, - RejectIncomingCall, - LeaveCall, - InviteContactToCall { - user_id: UserId, - }, - OpenLocalProject { - first_root_name: String, - }, - OpenRemoteProject { - host_id: UserId, - first_root_name: String, - }, - AddWorktreeToProject { - project_root_name: String, - new_root_path: PathBuf, - }, - CloseRemoteProject { - project_root_name: String, - }, - OpenBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - }, - SearchProject { - project_root_name: String, - is_local: bool, - query: String, - detach: bool, - }, - EditBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - edits: Vec<(Range, Arc)>, - }, - CloseBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - }, - SaveBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - detach: bool, - }, - RequestLspDataInBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - offset: usize, - kind: LspRequestKind, - detach: bool, - }, - CreateWorktreeEntry { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - is_dir: bool, - }, - WriteFsEntry { - path: PathBuf, - is_dir: bool, - content: String, - }, - GitOperation { - operation: GitOperation, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum GitOperation { - WriteGitIndex { - repo_path: PathBuf, - contents: Vec<(PathBuf, String)>, - }, - WriteGitBranch { - repo_path: PathBuf, - new_branch: Option, - }, - WriteGitStatuses { - repo_path: PathBuf, - statuses: Vec<(PathBuf, GitFileStatus)>, - git_operation: bool, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum LspRequestKind { - Rename, - Completion, - CodeAction, - Definition, - Highlights, -} - -enum TestError { - Inapplicable, - Other(anyhow::Error), -} - -impl From for TestError { - fn from(value: anyhow::Error) -> Self { - Self::Other(value) - } -} - -impl TestPlan { - fn new(mut rng: StdRng, users: Vec, max_operations: usize) -> Self { - Self { - replay: false, - allow_server_restarts: rng.gen_bool(0.7), - allow_client_reconnection: rng.gen_bool(0.7), - allow_client_disconnection: rng.gen_bool(0.1), - stored_operations: Vec::new(), - operation_ix: 0, - next_batch_id: 0, - max_operations, - users, - rng, - } - } - - fn deserialize(&mut self, json: Vec) { - let stored_operations: Vec = serde_json::from_slice(&json).unwrap(); - self.replay = true; - self.stored_operations = stored_operations - .iter() - .cloned() - .enumerate() - .map(|(i, mut operation)| { - if let StoredOperation::Server(Operation::MutateClients { - batch_id: current_batch_id, - user_ids, - .. - }) = &mut operation - { - assert!(user_ids.is_empty()); - user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| { - if let StoredOperation::Client { - user_id, batch_id, .. - } = operation - { - if batch_id == current_batch_id { - return Some(user_id); - } - } - None - })); - user_ids.sort_unstable(); - } - (operation, Arc::new(AtomicBool::new(false))) - }) - .collect() - } - - fn serialize(&mut self) -> Vec { - // Format each operation as one line - let mut json = Vec::new(); - json.push(b'['); - for (operation, applied) in &self.stored_operations { - if !applied.load(SeqCst) { - continue; - } - if json.len() > 1 { - json.push(b','); - } - json.extend_from_slice(b"\n "); - serde_json::to_writer(&mut json, operation).unwrap(); - } - json.extend_from_slice(b"\n]\n"); - json - } - - fn next_server_operation( - &mut self, - clients: &[(Rc, TestAppContext)], - ) -> Option<(Operation, Arc)> { - if self.replay { - while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) { - self.operation_ix += 1; - if let (StoredOperation::Server(operation), applied) = stored_operation { - return Some((operation.clone(), applied.clone())); - } - } - None - } else { - let operation = self.generate_server_operation(clients)?; - let applied = Arc::new(AtomicBool::new(false)); - self.stored_operations - .push((StoredOperation::Server(operation.clone()), applied.clone())); - Some((operation, applied)) - } - } - - fn next_client_operation( - &mut self, - client: &TestClient, - current_batch_id: usize, - cx: &TestAppContext, - ) -> Option<(ClientOperation, Arc)> { - let current_user_id = client.current_user_id(cx); - let user_ix = self - .users - .iter() - .position(|user| user.user_id == current_user_id) - .unwrap(); - let user_plan = &mut self.users[user_ix]; - - if self.replay { - while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) { - user_plan.operation_ix += 1; - if let ( - StoredOperation::Client { - user_id, operation, .. - }, - applied, - ) = stored_operation - { - if user_id == ¤t_user_id { - return Some((operation.clone(), applied.clone())); - } - } - } - None - } else { - let operation = self.generate_client_operation(current_user_id, client, cx)?; - let applied = Arc::new(AtomicBool::new(false)); - self.stored_operations.push(( - StoredOperation::Client { - user_id: current_user_id, - batch_id: current_batch_id, - operation: operation.clone(), - }, - applied.clone(), - )); - Some((operation, applied)) - } - } - - fn generate_server_operation( - &mut self, - clients: &[(Rc, TestAppContext)], - ) -> Option { - if self.operation_ix == self.max_operations { - return None; - } - - Some(loop { - break match self.rng.gen_range(0..100) { - 0..=29 if clients.len() < self.users.len() => { - let user = self - .users - .iter() - .filter(|u| !u.online) - .choose(&mut self.rng) - .unwrap(); - self.operation_ix += 1; - Operation::AddConnection { - user_id: user.user_id, - } - } - 30..=34 if clients.len() > 1 && self.allow_client_disconnection => { - let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; - let user_id = client.current_user_id(cx); - self.operation_ix += 1; - Operation::RemoveConnection { user_id } - } - 35..=39 if clients.len() > 1 && self.allow_client_reconnection => { - let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; - let user_id = client.current_user_id(cx); - self.operation_ix += 1; - Operation::BounceConnection { user_id } - } - 40..=44 if self.allow_server_restarts && clients.len() > 1 => { - self.operation_ix += 1; - Operation::RestartServer - } - _ if !clients.is_empty() => { - let count = self - .rng - .gen_range(1..10) - .min(self.max_operations - self.operation_ix); - let batch_id = util::post_inc(&mut self.next_batch_id); - let mut user_ids = (0..count) - .map(|_| { - let ix = self.rng.gen_range(0..clients.len()); - let (client, cx) = &clients[ix]; - client.current_user_id(cx) - }) - .collect::>(); - user_ids.sort_unstable(); - Operation::MutateClients { - user_ids, - batch_id, - quiesce: self.rng.gen_bool(0.7), - } - } - _ => continue, - }; - }) - } - - fn generate_client_operation( - &mut self, - user_id: UserId, - client: &TestClient, - cx: &TestAppContext, - ) -> Option { - if self.operation_ix == self.max_operations { - return None; - } - - self.operation_ix += 1; - let call = cx.read(ActiveCall::global); - Some(loop { - match self.rng.gen_range(0..100_u32) { - // Mutate the call - 0..=29 => { - // Respond to an incoming call - if call.read_with(cx, |call, _| call.incoming().borrow().is_some()) { - break if self.rng.gen_bool(0.7) { - ClientOperation::AcceptIncomingCall - } else { - ClientOperation::RejectIncomingCall - }; - } - - match self.rng.gen_range(0..100_u32) { - // Invite a contact to the current call - 0..=70 => { - let available_contacts = - client.user_store().read_with(cx, |user_store, _| { - user_store - .contacts() - .iter() - .filter(|contact| contact.online && !contact.busy) - .cloned() - .collect::>() - }); - if !available_contacts.is_empty() { - let contact = available_contacts.choose(&mut self.rng).unwrap(); - break ClientOperation::InviteContactToCall { - user_id: UserId(contact.user.id as i32), - }; - } - } - - // Leave the current call - 71.. => { - if self.allow_client_disconnection - && call.read_with(cx, |call, _| call.room().is_some()) - { - break ClientOperation::LeaveCall; - } - } - } - } - - // Mutate projects - 30..=59 => match self.rng.gen_range(0..100_u32) { - // Open a new project - 0..=70 => { - // Open a remote project - if let Some(room) = call.read_with(cx, |call, _| call.room().cloned()) { - let existing_remote_project_ids = cx.read(|cx| { - client - .remote_projects() - .iter() - .map(|p| p.read(cx).remote_id().unwrap()) - .collect::>() - }); - let new_remote_projects = room.read_with(cx, |room, _| { - room.remote_participants() - .values() - .flat_map(|participant| { - participant.projects.iter().filter_map(|project| { - if existing_remote_project_ids.contains(&project.id) { - None - } else { - Some(( - UserId::from_proto(participant.user.id), - project.worktree_root_names[0].clone(), - )) - } - }) - }) - .collect::>() - }); - if !new_remote_projects.is_empty() { - let (host_id, first_root_name) = - new_remote_projects.choose(&mut self.rng).unwrap().clone(); - break ClientOperation::OpenRemoteProject { - host_id, - first_root_name, - }; - } - } - // Open a local project - else { - let first_root_name = self.next_root_dir_name(user_id); - break ClientOperation::OpenLocalProject { first_root_name }; - } - } - - // Close a remote project - 71..=80 => { - if !client.remote_projects().is_empty() { - let project = client - .remote_projects() - .choose(&mut self.rng) - .unwrap() - .clone(); - let first_root_name = root_name_for_project(&project, cx); - break ClientOperation::CloseRemoteProject { - project_root_name: first_root_name, - }; - } - } - - // Mutate project worktrees - 81.. => match self.rng.gen_range(0..100_u32) { - // Add a worktree to a local project - 0..=50 => { - let Some(project) = - client.local_projects().choose(&mut self.rng).cloned() - else { - continue; - }; - let project_root_name = root_name_for_project(&project, cx); - let mut paths = client.fs().paths(false); - paths.remove(0); - let new_root_path = if paths.is_empty() || self.rng.gen() { - Path::new("/").join(&self.next_root_dir_name(user_id)) - } else { - paths.choose(&mut self.rng).unwrap().clone() - }; - break ClientOperation::AddWorktreeToProject { - project_root_name, - new_root_path, - }; - } - - // Add an entry to a worktree - _ => { - let Some(project) = choose_random_project(client, &mut self.rng) else { - continue; - }; - let project_root_name = root_name_for_project(&project, cx); - let is_local = project.read_with(cx, |project, _| project.is_local()); - let worktree = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .filter(|worktree| { - let worktree = worktree.read(cx); - worktree.is_visible() - && worktree.entries(false).any(|e| e.is_file()) - && worktree.root_entry().map_or(false, |e| e.is_dir()) - }) - .choose(&mut self.rng) - }); - let Some(worktree) = worktree else { continue }; - let is_dir = self.rng.gen::(); - let mut full_path = - worktree.read_with(cx, |w, _| PathBuf::from(w.root_name())); - full_path.push(gen_file_name(&mut self.rng)); - if !is_dir { - full_path.set_extension("rs"); - } - break ClientOperation::CreateWorktreeEntry { - project_root_name, - is_local, - full_path, - is_dir, - }; - } - }, - }, - - // Query and mutate buffers - 60..=90 => { - let Some(project) = choose_random_project(client, &mut self.rng) else { - continue; - }; - let project_root_name = root_name_for_project(&project, cx); - let is_local = project.read_with(cx, |project, _| project.is_local()); - - match self.rng.gen_range(0..100_u32) { - // Manipulate an existing buffer - 0..=70 => { - let Some(buffer) = client - .buffers_for_project(&project) - .iter() - .choose(&mut self.rng) - .cloned() - else { - continue; - }; - - let full_path = buffer - .read_with(cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); - - match self.rng.gen_range(0..100_u32) { - // Close the buffer - 0..=15 => { - break ClientOperation::CloseBuffer { - project_root_name, - is_local, - full_path, - }; - } - // Save the buffer - 16..=29 if buffer.read_with(cx, |b, _| b.is_dirty()) => { - let detach = self.rng.gen_bool(0.3); - break ClientOperation::SaveBuffer { - project_root_name, - is_local, - full_path, - detach, - }; - } - // Edit the buffer - 30..=69 => { - let edits = buffer.read_with(cx, |buffer, _| { - buffer.get_random_edits(&mut self.rng, 3) - }); - break ClientOperation::EditBuffer { - project_root_name, - is_local, - full_path, - edits, - }; - } - // Make an LSP request - _ => { - let offset = buffer.read_with(cx, |buffer, _| { - buffer.clip_offset( - self.rng.gen_range(0..=buffer.len()), - language::Bias::Left, - ) - }); - let detach = self.rng.gen(); - break ClientOperation::RequestLspDataInBuffer { - project_root_name, - full_path, - offset, - is_local, - kind: match self.rng.gen_range(0..5_u32) { - 0 => LspRequestKind::Rename, - 1 => LspRequestKind::Highlights, - 2 => LspRequestKind::Definition, - 3 => LspRequestKind::CodeAction, - 4.. => LspRequestKind::Completion, - }, - detach, - }; - } - } - } - - 71..=80 => { - let query = self.rng.gen_range('a'..='z').to_string(); - let detach = self.rng.gen_bool(0.3); - break ClientOperation::SearchProject { - project_root_name, - is_local, - query, - detach, - }; - } - - // Open a buffer - 81.. => { - let worktree = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .filter(|worktree| { - let worktree = worktree.read(cx); - worktree.is_visible() - && worktree.entries(false).any(|e| e.is_file()) - }) - .choose(&mut self.rng) - }); - let Some(worktree) = worktree else { continue }; - let full_path = worktree.read_with(cx, |worktree, _| { - let entry = worktree - .entries(false) - .filter(|e| e.is_file()) - .choose(&mut self.rng) - .unwrap(); - if entry.path.as_ref() == Path::new("") { - Path::new(worktree.root_name()).into() - } else { - Path::new(worktree.root_name()).join(&entry.path) - } - }); - break ClientOperation::OpenBuffer { - project_root_name, - is_local, - full_path, - }; - } - } - } - - // Update a git related action - 91..=95 => { - break ClientOperation::GitOperation { - operation: self.generate_git_operation(client), - }; - } - - // Create or update a file or directory - 96.. => { - let is_dir = self.rng.gen::(); - let content; - let mut path; - let dir_paths = client.fs().directories(false); - - if is_dir { - content = String::new(); - path = dir_paths.choose(&mut self.rng).unwrap().clone(); - path.push(gen_file_name(&mut self.rng)); - } else { - content = Alphanumeric.sample_string(&mut self.rng, 16); - - // Create a new file or overwrite an existing file - let file_paths = client.fs().files(); - if file_paths.is_empty() || self.rng.gen_bool(0.5) { - path = dir_paths.choose(&mut self.rng).unwrap().clone(); - path.push(gen_file_name(&mut self.rng)); - path.set_extension("rs"); - } else { - path = file_paths.choose(&mut self.rng).unwrap().clone() - }; - } - break ClientOperation::WriteFsEntry { - path, - is_dir, - content, - }; - } - } - }) - } - - fn generate_git_operation(&mut self, client: &TestClient) -> GitOperation { - fn generate_file_paths( - repo_path: &Path, - rng: &mut StdRng, - client: &TestClient, - ) -> Vec { - let mut paths = client - .fs() - .files() - .into_iter() - .filter(|path| path.starts_with(repo_path)) - .collect::>(); - - let count = rng.gen_range(0..=paths.len()); - paths.shuffle(rng); - paths.truncate(count); - - paths - .iter() - .map(|path| path.strip_prefix(repo_path).unwrap().to_path_buf()) - .collect::>() - } - - let repo_path = client - .fs() - .directories(false) - .choose(&mut self.rng) - .unwrap() - .clone(); - - match self.rng.gen_range(0..100_u32) { - 0..=25 => { - let file_paths = generate_file_paths(&repo_path, &mut self.rng, client); - - let contents = file_paths - .into_iter() - .map(|path| (path, Alphanumeric.sample_string(&mut self.rng, 16))) - .collect(); - - GitOperation::WriteGitIndex { - repo_path, - contents, - } - } - 26..=63 => { - let new_branch = (self.rng.gen_range(0..10) > 3) - .then(|| Alphanumeric.sample_string(&mut self.rng, 8)); - - GitOperation::WriteGitBranch { - repo_path, - new_branch, - } - } - 64..=100 => { - let file_paths = generate_file_paths(&repo_path, &mut self.rng, client); - - let statuses = file_paths - .into_iter() - .map(|paths| { - ( - paths, - match self.rng.gen_range(0..3_u32) { - 0 => GitFileStatus::Added, - 1 => GitFileStatus::Modified, - 2 => GitFileStatus::Conflict, - _ => unreachable!(), - }, - ) - }) - .collect::>(); - - let git_operation = self.rng.gen::(); - - GitOperation::WriteGitStatuses { - repo_path, - statuses, - git_operation, - } - } - _ => unreachable!(), - } - } - - fn next_root_dir_name(&mut self, user_id: UserId) -> String { - let user_ix = self - .users - .iter() - .position(|user| user.user_id == user_id) - .unwrap(); - let root_id = util::post_inc(&mut self.users[user_ix].next_root_id); - format!("dir-{user_id}-{root_id}") - } - - fn user(&mut self, user_id: UserId) -> &mut UserTestPlan { - let ix = self - .users - .iter() - .position(|user| user.user_id == user_id) - .unwrap(); - &mut self.users[ix] - } -} - -async fn simulate_client( - client: Rc, - mut operation_rx: futures::channel::mpsc::UnboundedReceiver, - plan: Arc>, - mut cx: TestAppContext, -) { - // Setup language server - let mut language = Language::new( - LanguageConfig { - name: "Rust".into(), - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - None, - ); - let _fake_language_servers = language - .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { - name: "the-fake-language-server", - capabilities: lsp::LanguageServer::full_capabilities(), - initializer: Some(Box::new({ - let fs = client.app_state.fs.clone(); - move |fake_server: &mut FakeLanguageServer| { - fake_server.handle_request::( - |_, _| async move { - Ok(Some(lsp::CompletionResponse::Array(vec![ - lsp::CompletionItem { - text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { - range: lsp::Range::new( - lsp::Position::new(0, 0), - lsp::Position::new(0, 0), - ), - new_text: "the-new-text".to_string(), - })), - ..Default::default() - }, - ]))) - }, - ); - - fake_server.handle_request::( - |_, _| async move { - Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( - lsp::CodeAction { - title: "the-code-action".to_string(), - ..Default::default() - }, - )])) - }, - ); - - fake_server.handle_request::( - |params, _| async move { - Ok(Some(lsp::PrepareRenameResponse::Range(lsp::Range::new( - params.position, - params.position, - )))) - }, - ); - - fake_server.handle_request::({ - let fs = fs.clone(); - move |_, cx| { - let background = cx.background(); - let mut rng = background.rng(); - let count = rng.gen_range::(1..3); - let files = fs.as_fake().files(); - let files = (0..count) - .map(|_| files.choose(&mut *rng).unwrap().clone()) - .collect::>(); - async move { - log::info!("LSP: Returning definitions in files {:?}", &files); - Ok(Some(lsp::GotoDefinitionResponse::Array( - files - .into_iter() - .map(|file| lsp::Location { - uri: lsp::Url::from_file_path(file).unwrap(), - range: Default::default(), - }) - .collect(), - ))) - } - } - }); - - fake_server.handle_request::( - move |_, cx| { - let mut highlights = Vec::new(); - let background = cx.background(); - let mut rng = background.rng(); - - let highlight_count = rng.gen_range(1..=5); - for _ in 0..highlight_count { - let start_row = rng.gen_range(0..100); - let start_column = rng.gen_range(0..100); - let end_row = rng.gen_range(0..100); - let end_column = rng.gen_range(0..100); - let start = PointUtf16::new(start_row, start_column); - let end = PointUtf16::new(end_row, end_column); - let range = if start > end { end..start } else { start..end }; - highlights.push(lsp::DocumentHighlight { - range: range_to_lsp(range.clone()), - kind: Some(lsp::DocumentHighlightKind::READ), - }); - } - highlights.sort_unstable_by_key(|highlight| { - (highlight.range.start, highlight.range.end) - }); - async move { Ok(Some(highlights)) } - }, - ); - } - })), - ..Default::default() - })) - .await; - client.app_state.languages.add(Arc::new(language)); - - while let Some(batch_id) = operation_rx.next().await { - let Some((operation, applied)) = plan.lock().next_client_operation(&client, batch_id, &cx) - else { - break; - }; - applied.store(true, SeqCst); - match apply_client_operation(&client, operation, &mut cx).await { - Ok(()) => {} - Err(TestError::Inapplicable) => { - applied.store(false, SeqCst); - log::info!("skipped operation"); - } - Err(TestError::Other(error)) => { - log::error!("{} error: {}", client.username, error); - } - } - cx.background().simulate_random_delay().await; - } - log::info!("{}: done", client.username); -} - -fn buffer_for_full_path( - client: &TestClient, - project: &ModelHandle, - full_path: &PathBuf, - cx: &TestAppContext, -) -> Option> { - client - .buffers_for_project(project) - .iter() - .find(|buffer| { - buffer.read_with(cx, |buffer, cx| { - buffer.file().unwrap().full_path(cx) == *full_path - }) - }) - .cloned() -} - -fn project_for_root_name( - client: &TestClient, - root_name: &str, - cx: &TestAppContext, -) -> Option> { - if let Some(ix) = project_ix_for_root_name(&*client.local_projects(), root_name, cx) { - return Some(client.local_projects()[ix].clone()); - } - if let Some(ix) = project_ix_for_root_name(&*client.remote_projects(), root_name, cx) { - return Some(client.remote_projects()[ix].clone()); - } - None -} - -fn project_ix_for_root_name( - projects: &[ModelHandle], - root_name: &str, - cx: &TestAppContext, -) -> Option { - projects.iter().position(|project| { - project.read_with(cx, |project, cx| { - let worktree = project.visible_worktrees(cx).next().unwrap(); - worktree.read(cx).root_name() == root_name - }) - }) -} - -fn root_name_for_project(project: &ModelHandle, cx: &TestAppContext) -> String { - project.read_with(cx, |project, cx| { - project - .visible_worktrees(cx) - .next() - .unwrap() - .read(cx) - .root_name() - .to_string() - }) -} - -fn project_path_for_full_path( - project: &ModelHandle, - full_path: &Path, - cx: &TestAppContext, -) -> Option { - let mut components = full_path.components(); - let root_name = components.next().unwrap().as_os_str().to_str().unwrap(); - let path = components.as_path().into(); - let worktree_id = project.read_with(cx, |project, cx| { - project.worktrees(cx).find_map(|worktree| { - let worktree = worktree.read(cx); - if worktree.root_name() == root_name { - Some(worktree.id()) - } else { - None - } - }) - })?; - Some(ProjectPath { worktree_id, path }) -} - -async fn ensure_project_shared( - project: &ModelHandle, - client: &TestClient, - cx: &mut TestAppContext, -) { - let first_root_name = root_name_for_project(project, cx); - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.room().is_some()) - && project.read_with(cx, |project, _| project.is_local() && !project.is_shared()) - { - match active_call - .update(cx, |call, cx| call.share_project(project.clone(), cx)) - .await - { - Ok(project_id) => { - log::info!( - "{}: shared project {} with id {}", - client.username, - first_root_name, - project_id - ); - } - Err(error) => { - log::error!( - "{}: error sharing project {}: {:?}", - client.username, - first_root_name, - error - ); - } - } - } -} - -fn choose_random_project(client: &TestClient, rng: &mut StdRng) -> Option> { - client - .local_projects() - .iter() - .chain(client.remote_projects().iter()) - .choose(rng) - .cloned() -} - -fn gen_file_name(rng: &mut StdRng) -> String { - let mut name = String::new(); - for _ in 0..10 { - let letter = rng.gen_range('a'..='z'); - name.push(letter); - } - name -} - -fn path_env_var(name: &str) -> Option { - let value = env::var(name).ok()?; - let mut path = PathBuf::from(value); - if path.is_relative() { - let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - abs_path.pop(); - abs_path.pop(); - abs_path.push(path); - path = abs_path - } - Some(path) -} diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs new file mode 100644 index 0000000000..39598bdaf9 --- /dev/null +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -0,0 +1,689 @@ +use crate::{ + db::{self, NewUserParams, UserId}, + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::{TestClient, TestServer}, +}; +use async_trait::async_trait; +use futures::StreamExt; +use gpui::{executor::Deterministic, Task, TestAppContext}; +use parking_lot::Mutex; +use rand::prelude::*; +use rpc::RECEIVE_TIMEOUT; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use settings::SettingsStore; +use std::{ + env, + path::PathBuf, + rc::Rc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, +}; + +lazy_static::lazy_static! { + static ref PLAN_LOAD_PATH: Option = path_env_var("LOAD_PLAN"); + static ref PLAN_SAVE_PATH: Option = path_env_var("SAVE_PLAN"); + static ref MAX_PEERS: usize = env::var("MAX_PEERS") + .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) + .unwrap_or(3); + static ref MAX_OPERATIONS: usize = env::var("OPERATIONS") + .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) + .unwrap_or(10); + +} + +static LOADED_PLAN_JSON: Mutex>> = Mutex::new(None); +static LAST_PLAN: Mutex Vec>>> = Mutex::new(None); + +struct TestPlan { + rng: StdRng, + replay: bool, + stored_operations: Vec<(StoredOperation, Arc)>, + max_operations: usize, + operation_ix: usize, + users: Vec, + next_batch_id: usize, + allow_server_restarts: bool, + allow_client_reconnection: bool, + allow_client_disconnection: bool, +} + +pub struct UserTestPlan { + pub user_id: UserId, + pub username: String, + pub allow_client_reconnection: bool, + pub allow_client_disconnection: bool, + next_root_id: usize, + operation_ix: usize, + online: bool, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum StoredOperation { + Server(ServerOperation), + Client { + user_id: UserId, + batch_id: usize, + operation: T, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum ServerOperation { + AddConnection { + user_id: UserId, + }, + RemoveConnection { + user_id: UserId, + }, + BounceConnection { + user_id: UserId, + }, + RestartServer, + MutateClients { + batch_id: usize, + #[serde(skip_serializing)] + #[serde(skip_deserializing)] + user_ids: Vec, + quiesce: bool, + }, +} + +pub enum TestError { + Inapplicable, + Other(anyhow::Error), +} + +#[async_trait(?Send)] +pub trait RandomizedTest: 'static + Sized { + type Operation: Send + Clone + Serialize + DeserializeOwned; + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> Self::Operation; + + async fn apply_operation( + client: &TestClient, + operation: Self::Operation, + cx: &mut TestAppContext, + ) -> Result<(), TestError>; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]); + + async fn on_client_added(client: &Rc, cx: &mut TestAppContext); + + async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc, TestAppContext)]); +} + +pub async fn run_randomized_test( + cx: &mut TestAppContext, + deterministic: Arc, + rng: StdRng, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let plan = TestPlan::::new(&mut server, rng).await; + + LAST_PLAN.lock().replace({ + let plan = plan.clone(); + Box::new(move || plan.lock().serialize()) + }); + + let mut clients = Vec::new(); + let mut client_tasks = Vec::new(); + let mut operation_channels = Vec::new(); + loop { + let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else { + break; + }; + applied.store(true, SeqCst); + let did_apply = TestPlan::apply_server_operation( + plan.clone(), + deterministic.clone(), + &mut server, + &mut clients, + &mut client_tasks, + &mut operation_channels, + next_operation, + cx, + ) + .await; + if !did_apply { + applied.store(false, SeqCst); + } + } + + drop(operation_channels); + deterministic.start_waiting(); + futures::future::join_all(client_tasks).await; + deterministic.finish_waiting(); + + deterministic.run_until_parked(); + T::on_quiesce(&mut server, &mut clients).await; + + for (client, mut cx) in clients { + cx.update(|cx| { + let store = cx.remove_global::(); + cx.clear_globals(); + cx.set_global(store); + drop(client); + }); + } + deterministic.run_until_parked(); + + if let Some(path) = &*PLAN_SAVE_PATH { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, plan.lock().serialize()).unwrap(); + } +} + +pub fn save_randomized_test_plan() { + if let Some(serialize_plan) = LAST_PLAN.lock().take() { + if let Some(path) = &*PLAN_SAVE_PATH { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, serialize_plan()).unwrap(); + } + } +} + +impl TestPlan { + pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc> { + let allow_server_restarts = rng.gen_bool(0.7); + let allow_client_reconnection = rng.gen_bool(0.7); + let allow_client_disconnection = rng.gen_bool(0.1); + + let mut users = Vec::new(); + for ix in 0..*MAX_PEERS { + let username = format!("user-{}", ix + 1); + let user_id = server + .app_state + .db + .create_user( + &format!("{username}@example.com"), + false, + NewUserParams { + github_login: username.clone(), + github_user_id: (ix + 1) as i32, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + users.push(UserTestPlan { + user_id, + username, + online: false, + next_root_id: 0, + operation_ix: 0, + allow_client_disconnection, + allow_client_reconnection, + }); + } + + T::initialize(server, &users).await; + + let plan = Arc::new(Mutex::new(Self { + replay: false, + allow_server_restarts, + allow_client_reconnection, + allow_client_disconnection, + stored_operations: Vec::new(), + operation_ix: 0, + next_batch_id: 0, + max_operations: *MAX_OPERATIONS, + users, + rng, + })); + + if let Some(path) = &*PLAN_LOAD_PATH { + let json = LOADED_PLAN_JSON + .lock() + .get_or_insert_with(|| { + eprintln!("loaded test plan from path {:?}", path); + std::fs::read(path).unwrap() + }) + .clone(); + plan.lock().deserialize(json); + } + + plan + } + + fn deserialize(&mut self, json: Vec) { + let stored_operations: Vec> = + serde_json::from_slice(&json).unwrap(); + self.replay = true; + self.stored_operations = stored_operations + .iter() + .cloned() + .enumerate() + .map(|(i, mut operation)| { + let did_apply = Arc::new(AtomicBool::new(false)); + if let StoredOperation::Server(ServerOperation::MutateClients { + batch_id: current_batch_id, + user_ids, + .. + }) = &mut operation + { + assert!(user_ids.is_empty()); + user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| { + if let StoredOperation::Client { + user_id, batch_id, .. + } = operation + { + if batch_id == current_batch_id { + return Some(user_id); + } + } + None + })); + user_ids.sort_unstable(); + } + (operation, did_apply) + }) + .collect() + } + + fn serialize(&mut self) -> Vec { + // Format each operation as one line + let mut json = Vec::new(); + json.push(b'['); + for (operation, applied) in &self.stored_operations { + if !applied.load(SeqCst) { + continue; + } + if json.len() > 1 { + json.push(b','); + } + json.extend_from_slice(b"\n "); + serde_json::to_writer(&mut json, operation).unwrap(); + } + json.extend_from_slice(b"\n]\n"); + json + } + + fn next_server_operation( + &mut self, + clients: &[(Rc, TestAppContext)], + ) -> Option<(ServerOperation, Arc)> { + if self.replay { + while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) { + self.operation_ix += 1; + if let (StoredOperation::Server(operation), applied) = stored_operation { + return Some((operation.clone(), applied.clone())); + } + } + None + } else { + let operation = self.generate_server_operation(clients)?; + let applied = Arc::new(AtomicBool::new(false)); + self.stored_operations + .push((StoredOperation::Server(operation.clone()), applied.clone())); + Some((operation, applied)) + } + } + + fn next_client_operation( + &mut self, + client: &TestClient, + current_batch_id: usize, + cx: &TestAppContext, + ) -> Option<(T::Operation, Arc)> { + let current_user_id = client.current_user_id(cx); + let user_ix = self + .users + .iter() + .position(|user| user.user_id == current_user_id) + .unwrap(); + let user_plan = &mut self.users[user_ix]; + + if self.replay { + while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) { + user_plan.operation_ix += 1; + if let ( + StoredOperation::Client { + user_id, operation, .. + }, + applied, + ) = stored_operation + { + if user_id == ¤t_user_id { + return Some((operation.clone(), applied.clone())); + } + } + } + None + } else { + if self.operation_ix == self.max_operations { + return None; + } + self.operation_ix += 1; + let operation = T::generate_operation( + client, + &mut self.rng, + self.users + .iter_mut() + .find(|user| user.user_id == current_user_id) + .unwrap(), + cx, + ); + let applied = Arc::new(AtomicBool::new(false)); + self.stored_operations.push(( + StoredOperation::Client { + user_id: current_user_id, + batch_id: current_batch_id, + operation: operation.clone(), + }, + applied.clone(), + )); + Some((operation, applied)) + } + } + + fn generate_server_operation( + &mut self, + clients: &[(Rc, TestAppContext)], + ) -> Option { + if self.operation_ix == self.max_operations { + return None; + } + + Some(loop { + break match self.rng.gen_range(0..100) { + 0..=29 if clients.len() < self.users.len() => { + let user = self + .users + .iter() + .filter(|u| !u.online) + .choose(&mut self.rng) + .unwrap(); + self.operation_ix += 1; + ServerOperation::AddConnection { + user_id: user.user_id, + } + } + 30..=34 if clients.len() > 1 && self.allow_client_disconnection => { + let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let user_id = client.current_user_id(cx); + self.operation_ix += 1; + ServerOperation::RemoveConnection { user_id } + } + 35..=39 if clients.len() > 1 && self.allow_client_reconnection => { + let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let user_id = client.current_user_id(cx); + self.operation_ix += 1; + ServerOperation::BounceConnection { user_id } + } + 40..=44 if self.allow_server_restarts && clients.len() > 1 => { + self.operation_ix += 1; + ServerOperation::RestartServer + } + _ if !clients.is_empty() => { + let count = self + .rng + .gen_range(1..10) + .min(self.max_operations - self.operation_ix); + let batch_id = util::post_inc(&mut self.next_batch_id); + let mut user_ids = (0..count) + .map(|_| { + let ix = self.rng.gen_range(0..clients.len()); + let (client, cx) = &clients[ix]; + client.current_user_id(cx) + }) + .collect::>(); + user_ids.sort_unstable(); + ServerOperation::MutateClients { + user_ids, + batch_id, + quiesce: self.rng.gen_bool(0.7), + } + } + _ => continue, + }; + }) + } + + async fn apply_server_operation( + plan: Arc>, + deterministic: Arc, + server: &mut TestServer, + clients: &mut Vec<(Rc, TestAppContext)>, + client_tasks: &mut Vec>, + operation_channels: &mut Vec>, + operation: ServerOperation, + cx: &mut TestAppContext, + ) -> bool { + match operation { + ServerOperation::AddConnection { user_id } => { + let username; + { + let mut plan = plan.lock(); + let user = plan.user(user_id); + if user.online { + return false; + } + user.online = true; + username = user.username.clone(); + }; + log::info!("adding new connection for {}", username); + let next_entity_id = (user_id.0 * 10_000) as usize; + let mut client_cx = TestAppContext::new( + cx.foreground_platform(), + cx.platform(), + deterministic.build_foreground(user_id.0 as usize), + deterministic.build_background(), + cx.font_cache(), + cx.leak_detector(), + next_entity_id, + cx.function_name.clone(), + ); + + let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded(); + let client = Rc::new(server.create_client(&mut client_cx, &username).await); + operation_channels.push(operation_tx); + clients.push((client.clone(), client_cx.clone())); + client_tasks.push(client_cx.foreground().spawn(Self::simulate_client( + plan.clone(), + client, + operation_rx, + client_cx, + ))); + + log::info!("added connection for {}", username); + } + + ServerOperation::RemoveConnection { + user_id: removed_user_id, + } => { + log::info!("simulating full disconnection of user {}", removed_user_id); + let client_ix = clients + .iter() + .position(|(client, cx)| client.current_user_id(cx) == removed_user_id); + let Some(client_ix) = client_ix else { + return false; + }; + let user_connection_ids = server + .connection_pool + .lock() + .user_connection_ids(removed_user_id) + .collect::>(); + assert_eq!(user_connection_ids.len(), 1); + let removed_peer_id = user_connection_ids[0].into(); + let (client, mut client_cx) = clients.remove(client_ix); + let client_task = client_tasks.remove(client_ix); + operation_channels.remove(client_ix); + server.forbid_connections(); + server.disconnect_client(removed_peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + deterministic.start_waiting(); + log::info!("waiting for user {} to exit...", removed_user_id); + client_task.await; + deterministic.finish_waiting(); + server.allow_connections(); + + for project in client.remote_projects().iter() { + project.read_with(&client_cx, |project, _| { + assert!( + project.is_read_only(), + "project {:?} should be read only", + project.remote_id() + ) + }); + } + + for (client, cx) in clients { + let contacts = server + .app_state + .db + .get_contacts(client.current_user_id(cx)) + .await + .unwrap(); + let pool = server.connection_pool.lock(); + for contact in contacts { + if let db::Contact::Accepted { user_id, busy, .. } = contact { + if user_id == removed_user_id { + assert!(!pool.is_user_online(user_id)); + assert!(!busy); + } + } + } + } + + log::info!("{} removed", client.username); + plan.lock().user(removed_user_id).online = false; + client_cx.update(|cx| { + cx.clear_globals(); + drop(client); + }); + } + + ServerOperation::BounceConnection { user_id } => { + log::info!("simulating temporary disconnection of user {}", user_id); + let user_connection_ids = server + .connection_pool + .lock() + .user_connection_ids(user_id) + .collect::>(); + if user_connection_ids.is_empty() { + return false; + } + assert_eq!(user_connection_ids.len(), 1); + let peer_id = user_connection_ids[0].into(); + server.disconnect_client(peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + } + + ServerOperation::RestartServer => { + log::info!("simulating server restart"); + server.reset().await; + deterministic.advance_clock(RECEIVE_TIMEOUT); + server.start().await.unwrap(); + deterministic.advance_clock(CLEANUP_TIMEOUT); + let environment = &server.app_state.config.zed_environment; + let (stale_room_ids, _) = server + .app_state + .db + .stale_server_resource_ids(environment, server.id()) + .await + .unwrap(); + assert_eq!(stale_room_ids, vec![]); + } + + ServerOperation::MutateClients { + user_ids, + batch_id, + quiesce, + } => { + let mut applied = false; + for user_id in user_ids { + let client_ix = clients + .iter() + .position(|(client, cx)| client.current_user_id(cx) == user_id); + let Some(client_ix) = client_ix else { continue }; + applied = true; + if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) { + log::error!("error signaling user {user_id}: {err}"); + } + } + + if quiesce && applied { + deterministic.run_until_parked(); + T::on_quiesce(server, clients).await; + } + + return applied; + } + } + true + } + + async fn simulate_client( + plan: Arc>, + client: Rc, + mut operation_rx: futures::channel::mpsc::UnboundedReceiver, + mut cx: TestAppContext, + ) { + T::on_client_added(&client, &mut cx).await; + + while let Some(batch_id) = operation_rx.next().await { + let Some((operation, applied)) = + plan.lock().next_client_operation(&client, batch_id, &cx) + else { + break; + }; + applied.store(true, SeqCst); + match T::apply_operation(&client, operation, &mut cx).await { + Ok(()) => {} + Err(TestError::Inapplicable) => { + applied.store(false, SeqCst); + log::info!("skipped operation"); + } + Err(TestError::Other(error)) => { + log::error!("{} error: {}", client.username, error); + } + } + cx.background().simulate_random_delay().await; + } + log::info!("{}: done", client.username); + } + + fn user(&mut self, user_id: UserId) -> &mut UserTestPlan { + self.users + .iter_mut() + .find(|user| user.user_id == user_id) + .unwrap() + } +} + +impl UserTestPlan { + pub fn next_root_dir_name(&mut self) -> String { + let user_id = self.user_id; + let root_id = util::post_inc(&mut self.next_root_id); + format!("dir-{user_id}-{root_id}") + } +} + +impl From for TestError { + fn from(value: anyhow::Error) -> Self { + Self::Other(value) + } +} + +fn path_env_var(name: &str) -> Option { + let value = env::var(name).ok()?; + let mut path = PathBuf::from(value); + if path.is_relative() { + let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + abs_path.pop(); + abs_path.pop(); + abs_path.push(path); + path = abs_path + } + Some(path) +} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs new file mode 100644 index 0000000000..eef1dde967 --- /dev/null +++ b/crates/collab/src/tests/test_server.rs @@ -0,0 +1,558 @@ +use crate::{ + db::{tests::TestDb, NewUserParams, UserId}, + executor::Executor, + rpc::{Server, CLEANUP_TIMEOUT}, + AppState, +}; +use anyhow::anyhow; +use call::ActiveCall; +use channel::{channel_buffer::ChannelBuffer, ChannelStore}; +use client::{ + self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore, +}; +use collections::{HashMap, HashSet}; +use fs::FakeFs; +use futures::{channel::oneshot, StreamExt as _}; +use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; +use language::LanguageRegistry; +use parking_lot::Mutex; +use project::{Project, WorktreeId}; +use settings::SettingsStore; +use std::{ + cell::{Ref, RefCell, RefMut}, + env, + ops::{Deref, DerefMut}, + path::Path, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, + Arc, + }, +}; +use util::http::FakeHttpClient; +use workspace::Workspace; + +pub struct TestServer { + pub app_state: Arc, + pub test_live_kit_server: Arc, + server: Arc, + connection_killers: Arc>>>, + forbid_connections: Arc, + _test_db: TestDb, +} + +pub struct TestClient { + pub username: String, + pub app_state: Arc, + state: RefCell, +} + +#[derive(Default)] +struct TestClientState { + local_projects: Vec>, + remote_projects: Vec>, + buffers: HashMap, HashSet>>, + channel_buffers: HashSet>, +} + +pub struct ContactsSummary { + pub current: Vec, + pub outgoing_requests: Vec, + pub incoming_requests: Vec, +} + +impl TestServer { + pub async fn start(deterministic: &Arc) -> Self { + static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); + + let use_postgres = env::var("USE_POSTGRES").ok(); + let use_postgres = use_postgres.as_deref(); + let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { + TestDb::postgres(deterministic.build_background()) + } else { + TestDb::sqlite(deterministic.build_background()) + }; + let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); + let live_kit_server = live_kit_client::TestServer::create( + format!("http://livekit.{}.test", live_kit_server_id), + format!("devkey-{}", live_kit_server_id), + format!("secret-{}", live_kit_server_id), + deterministic.build_background(), + ) + .unwrap(); + let app_state = Self::build_app_state(&test_db, &live_kit_server).await; + let epoch = app_state + .db + .create_server(&app_state.config.zed_environment) + .await + .unwrap(); + let server = Server::new( + epoch, + app_state.clone(), + Executor::Deterministic(deterministic.build_background()), + ); + server.start().await.unwrap(); + // Advance clock to ensure the server's cleanup task is finished. + deterministic.advance_clock(CLEANUP_TIMEOUT); + Self { + app_state, + server, + connection_killers: Default::default(), + forbid_connections: Default::default(), + _test_db: test_db, + test_live_kit_server: live_kit_server, + } + } + + pub async fn reset(&self) { + self.app_state.db.reset(); + let epoch = self + .app_state + .db + .create_server(&self.app_state.config.zed_environment) + .await + .unwrap(); + self.server.reset(epoch); + } + + pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + cx.update(|cx| { + if cx.has_global::() { + panic!("Same cx used to create two test clients") + } + cx.set_global(SettingsStore::test(cx)); + }); + + let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await + { + user.id + } else { + self.app_state + .db + .create_user( + &format!("{name}@example.com"), + false, + NewUserParams { + github_login: name.into(), + github_user_id: 0, + invite_count: 0, + }, + ) + .await + .expect("creating user failed") + .user_id + }; + let client_name = name.to_string(); + let mut client = cx.read(|cx| Client::new(http.clone(), cx)); + let server = self.server.clone(); + let db = self.app_state.db.clone(); + let connection_killers = self.connection_killers.clone(); + let forbid_connections = self.forbid_connections.clone(); + + Arc::get_mut(&mut client) + .unwrap() + .set_id(user_id.0 as usize) + .override_authenticate(move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok(Credentials { + user_id: user_id.0 as u64, + access_token, + }) + }) + }) + .override_establish_connection(move |credentials, cx| { + assert_eq!(credentials.user_id, user_id.0 as u64); + assert_eq!(credentials.access_token, "the-token"); + + let server = server.clone(); + let db = db.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(EstablishConnectionError::other(anyhow!( + "server is forbidding connections" + ))) + } else { + let (client_conn, server_conn, killed) = + Connection::in_memory(cx.background()); + let (connection_id_tx, connection_id_rx) = oneshot::channel(); + let user = db + .get_user_by_id(user_id) + .await + .expect("retrieving user failed") + .unwrap(); + cx.background() + .spawn(server.handle_connection( + server_conn, + client_name, + user, + Some(connection_id_tx), + Executor::Deterministic(cx.background()), + )) + .detach(); + let connection_id = connection_id_rx.await.unwrap(); + connection_killers + .lock() + .insert(connection_id.into(), killed); + Ok(client_conn) + } + }) + }); + + let fs = FakeFs::new(cx.background()); + let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); + let channel_store = + cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx)); + let app_state = Arc::new(workspace::AppState { + client: client.clone(), + user_store: user_store.clone(), + channel_store: channel_store.clone(), + languages: Arc::new(LanguageRegistry::test()), + fs: fs.clone(), + build_window_options: |_, _, _| Default::default(), + initialize_workspace: |_, _, _, _| Task::ready(Ok(())), + background_actions: || &[], + }); + + cx.update(|cx| { + theme::init((), cx); + Project::init(&client, cx); + client::init(&client, cx); + language::init(cx); + editor::init_settings(cx); + workspace::init(app_state.clone(), cx); + audio::init((), cx); + call::init(client.clone(), user_store.clone(), cx); + channel::init(&client); + }); + + client + .authenticate_and_connect(false, &cx.to_async()) + .await + .unwrap(); + + let client = TestClient { + app_state, + username: name.to_string(), + state: Default::default(), + }; + client.wait_for_current_user(cx).await; + client + } + + pub fn disconnect_client(&self, peer_id: PeerId) { + self.connection_killers + .lock() + .remove(&peer_id) + .unwrap() + .store(true, SeqCst); + } + + pub fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + pub fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); + } + + pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { + for ix in 1..clients.len() { + let (left, right) = clients.split_at_mut(ix); + let (client_a, cx_a) = left.last_mut().unwrap(); + for (client_b, cx_b) in right { + client_a + .app_state + .user_store + .update(*cx_a, |store, cx| { + store.request_contact(client_b.user_id().unwrap(), cx) + }) + .await + .unwrap(); + cx_a.foreground().run_until_parked(); + client_b + .app_state + .user_store + .update(*cx_b, |store, cx| { + store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx) + }) + .await + .unwrap(); + } + } + } + + pub async fn make_channel( + &self, + channel: &str, + admin: (&TestClient, &mut TestAppContext), + members: &mut [(&TestClient, &mut TestAppContext)], + ) -> u64 { + let (admin_client, admin_cx) = admin; + let channel_id = admin_client + .app_state + .channel_store + .update(admin_cx, |channel_store, cx| { + channel_store.create_channel(channel, None, cx) + }) + .await + .unwrap(); + + for (member_client, member_cx) in members { + admin_client + .app_state + .channel_store + .update(admin_cx, |channel_store, cx| { + channel_store.invite_member( + channel_id, + member_client.user_id().unwrap(), + false, + cx, + ) + }) + .await + .unwrap(); + + admin_cx.foreground().run_until_parked(); + + member_client + .app_state + .channel_store + .update(*member_cx, |channels, _| { + channels.respond_to_channel_invite(channel_id, true) + }) + .await + .unwrap(); + } + + channel_id + } + + pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { + self.make_contacts(clients).await; + + let (left, right) = clients.split_at_mut(1); + let (_client_a, cx_a) = &mut left[0]; + let active_call_a = cx_a.read(ActiveCall::global); + + for (client_b, cx_b) in right { + let user_id_b = client_b.current_user_id(*cx_b).to_proto(); + active_call_a + .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx)) + .await + .unwrap(); + + cx_b.foreground().run_until_parked(); + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(*cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + } + } + + pub async fn build_app_state( + test_db: &TestDb, + fake_server: &live_kit_client::TestServer, + ) -> Arc { + Arc::new(AppState { + db: test_db.db().clone(), + live_kit_client: Some(Arc::new(fake_server.create_api_client())), + config: Default::default(), + }) + } +} + +impl Deref for TestServer { + type Target = Server; + + fn deref(&self) -> &Self::Target { + &self.server + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.server.teardown(); + self.test_live_kit_server.teardown().unwrap(); + } +} + +impl Deref for TestClient { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.app_state.client + } +} + +impl TestClient { + pub fn fs(&self) -> &FakeFs { + self.app_state.fs.as_fake() + } + + pub fn channel_store(&self) -> &ModelHandle { + &self.app_state.channel_store + } + + pub fn user_store(&self) -> &ModelHandle { + &self.app_state.user_store + } + + pub fn language_registry(&self) -> &Arc { + &self.app_state.languages + } + + pub fn client(&self) -> &Arc { + &self.app_state.client + } + + pub fn current_user_id(&self, cx: &TestAppContext) -> UserId { + UserId::from_proto( + self.app_state + .user_store + .read_with(cx, |user_store, _| user_store.current_user().unwrap().id), + ) + } + + pub async fn wait_for_current_user(&self, cx: &TestAppContext) { + let mut authed_user = self + .app_state + .user_store + .read_with(cx, |user_store, _| user_store.watch_current_user()); + while authed_user.next().await.unwrap().is_none() {} + } + + pub async fn clear_contacts(&self, cx: &mut TestAppContext) { + self.app_state + .user_store + .update(cx, |store, _| store.clear_contacts()) + .await; + } + + pub fn local_projects<'a>(&'a self) -> impl Deref>> + 'a { + Ref::map(self.state.borrow(), |state| &state.local_projects) + } + + pub fn remote_projects<'a>(&'a self) -> impl Deref>> + 'a { + Ref::map(self.state.borrow(), |state| &state.remote_projects) + } + + pub fn local_projects_mut<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects) + } + + pub fn remote_projects_mut<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects) + } + + pub fn buffers_for_project<'a>( + &'a self, + project: &ModelHandle, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| { + state.buffers.entry(project.clone()).or_default() + }) + } + + pub fn buffers<'a>( + &'a self, + ) -> impl DerefMut, HashSet>>> + 'a + { + RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers) + } + + pub fn channel_buffers<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers) + } + + pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary { + self.app_state + .user_store + .read_with(cx, |store, _| ContactsSummary { + current: store + .contacts() + .iter() + .map(|contact| contact.user.github_login.clone()) + .collect(), + outgoing_requests: store + .outgoing_contact_requests() + .iter() + .map(|user| user.github_login.clone()) + .collect(), + incoming_requests: store + .incoming_contact_requests() + .iter() + .map(|user| user.github_login.clone()) + .collect(), + }) + } + + pub async fn build_local_project( + &self, + root_path: impl AsRef, + cx: &mut TestAppContext, + ) -> (ModelHandle, WorktreeId) { + let project = cx.update(|cx| { + Project::local( + self.client().clone(), + self.app_state.user_store.clone(), + self.app_state.languages.clone(), + self.app_state.fs.clone(), + cx, + ) + }); + let (worktree, _) = project + .update(cx, |p, cx| { + p.find_or_create_local_worktree(root_path, true, cx) + }) + .await + .unwrap(); + worktree + .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + (project, worktree.read_with(cx, |tree, _| tree.id())) + } + + pub async fn build_remote_project( + &self, + host_project_id: u64, + guest_cx: &mut TestAppContext, + ) -> ModelHandle { + let active_call = guest_cx.read(ActiveCall::global); + let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone()); + room.update(guest_cx, |room, cx| { + room.join_project( + host_project_id, + self.app_state.languages.clone(), + self.app_state.fs.clone(), + cx, + ) + }) + .await + .unwrap() + } + + pub fn build_workspace( + &self, + project: &ModelHandle, + cx: &mut TestAppContext, + ) -> WindowHandle { + cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx)) + } +} + +impl Drop for TestClient { + fn drop(&mut self) { + self.app_state.client.teardown(); + } +} diff --git a/crates/collab_ui/src/channel_view.rs b/crates/collab_ui/src/channel_view.rs index a34f10b2db..5086cc8b37 100644 --- a/crates/collab_ui/src/channel_view.rs +++ b/crates/collab_ui/src/channel_view.rs @@ -213,7 +213,7 @@ impl Item for ChannelView { } fn is_singleton(&self, _cx: &AppContext) -> bool { - true + false } fn navigate(&mut self, data: Box, cx: &mut ViewContext) -> bool { diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 0593bfcb1f..fba10c61ba 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -1106,23 +1106,17 @@ impl CollabPanel { ) -> AnyElement { enum OpenSharedScreen {} - let font_cache = cx.font_cache(); - let host_avatar_height = theme + let host_avatar_width = theme .contact_avatar .width .or(theme.contact_avatar.height) .unwrap_or(0.); - let row = &theme.project_row.inactive_state().default; let tree_branch = theme.tree_branch; - let line_height = row.name.text.line_height(font_cache); - let cap_height = row.name.text.cap_height(font_cache); - let baseline_offset = - row.name.text.baseline_offset(font_cache) + (theme.row_height - line_height) / 2.; MouseEventHandler::new::( peer_id.as_u64() as usize, cx, - |mouse_state, _| { + |mouse_state, cx| { let tree_branch = *tree_branch.in_state(is_selected).style_for(mouse_state); let row = theme .project_row @@ -1130,49 +1124,20 @@ impl CollabPanel { .style_for(mouse_state); Flex::row() - .with_child( - Stack::new() - .with_child(Canvas::new(move |scene, bounds, _, _, _| { - let start_x = bounds.min_x() + (bounds.width() / 2.) - - (tree_branch.width / 2.); - let end_x = bounds.max_x(); - let start_y = bounds.min_y(); - let end_y = bounds.min_y() + baseline_offset - (cap_height / 2.); - - scene.push_quad(gpui::Quad { - bounds: RectF::from_points( - vec2f(start_x, start_y), - vec2f( - start_x + tree_branch.width, - if is_last { end_y } else { bounds.max_y() }, - ), - ), - background: Some(tree_branch.color), - border: gpui::Border::default(), - corner_radii: (0.).into(), - }); - scene.push_quad(gpui::Quad { - bounds: RectF::from_points( - vec2f(start_x, end_y), - vec2f(end_x, end_y + tree_branch.width), - ), - background: Some(tree_branch.color), - border: gpui::Border::default(), - corner_radii: (0.).into(), - }); - })) - .constrained() - .with_width(host_avatar_height), - ) + .with_child(render_tree_branch( + tree_branch, + &row.name.text, + is_last, + vec2f(host_avatar_width, theme.row_height), + cx.font_cache(), + )) .with_child( Svg::new("icons/disable_screen_sharing_12.svg") - .with_color(row.icon.color) + .with_color(theme.channel_hash.color) .constrained() - .with_width(row.icon.width) + .with_width(theme.channel_hash.width) .aligned() - .left() - .contained() - .with_style(row.icon.container), + .left(), ) .with_child( Label::new("Screen", row.name.text.clone()) @@ -2275,7 +2240,8 @@ impl CollabPanel { fn open_channel_buffer(&mut self, action: &OpenChannelBuffer, cx: &mut ViewContext) { if let Some(workspace) = self.workspace.upgrade(cx) { let pane = workspace.read(cx).active_pane().clone(); - let channel_view = ChannelView::open(action.channel_id, pane.clone(), workspace, cx); + let channel_id = action.channel_id; + let channel_view = ChannelView::open(channel_id, pane.clone(), workspace, cx); cx.spawn(|_, mut cx| async move { let channel_view = channel_view.await?; pane.update(&mut cx, |pane, cx| { @@ -2284,6 +2250,18 @@ impl CollabPanel { anyhow::Ok(()) }) .detach(); + let room_id = ActiveCall::global(cx) + .read(cx) + .room() + .map(|room| room.read(cx).id()); + + ActiveCall::report_call_event_for_room( + "open channel notes", + room_id, + Some(channel_id), + &self.client, + cx, + ); } } @@ -2553,27 +2531,16 @@ impl View for CollabPanel { .with_child( Flex::column() .with_child( - Flex::row() - .with_child( - ChildView::new(&self.filter_editor, cx) - .contained() - .with_style(theme.user_query_editor.container) - .flex(1.0, true), - ) - .constrained() - .with_width(self.size(cx)), - ) - .with_child( - List::new(self.list_state.clone()) - .constrained() - .with_width(self.size(cx)) - .flex(1., true) - .into_any(), + Flex::row().with_child( + ChildView::new(&self.filter_editor, cx) + .contained() + .with_style(theme.user_query_editor.container) + .flex(1.0, true), + ), ) + .with_child(List::new(self.list_state.clone()).flex(1., true).into_any()) .contained() .with_style(theme.container) - .constrained() - .with_width(self.size(cx)) .into_any(), ) .with_children( diff --git a/crates/collab_ui/src/collab_titlebar_item.rs b/crates/collab_ui/src/collab_titlebar_item.rs index 684ddca08d..95b9868937 100644 --- a/crates/collab_ui/src/collab_titlebar_item.rs +++ b/crates/collab_ui/src/collab_titlebar_item.rs @@ -213,7 +213,6 @@ impl CollabTitlebarItem { .map(|branch| util::truncate_and_trailoff(&branch, MAX_BRANCH_NAME_LENGTH)); let project_style = theme.titlebar.project_menu_button.clone(); let git_style = theme.titlebar.git_menu_button.clone(); - let divider_style = theme.titlebar.project_name_divider.clone(); let item_spacing = theme.titlebar.item_spacing; let mut ret = Flex::row().with_child( @@ -248,49 +247,37 @@ impl CollabTitlebarItem { ); if let Some(git_branch) = branch_prepended { ret = ret.with_child( - Flex::row() - .with_child( - Label::new("/", divider_style.text) - .contained() - .with_style(divider_style.container) - .aligned() - .left(), - ) - .with_child( - Stack::new() - .with_child( - MouseEventHandler::new::( - 0, - cx, - |mouse_state, cx| { - enum BranchPopoverTooltip {} - let style = git_style - .in_state(self.branch_popover.is_some()) - .style_for(mouse_state); - Label::new(git_branch, style.text.clone()) - .contained() - .with_style(style.container.clone()) - .with_margin_right(item_spacing) - .aligned() - .left() - .with_tooltip::( - 0, - "Recent branches", - Some(Box::new(ToggleVcsMenu)), - theme.tooltip.clone(), - cx, - ) - .into_any_named("title-project-branch") - }, - ) - .with_cursor_style(CursorStyle::PointingHand) - .on_down(MouseButton::Left, move |_, this, cx| { - this.toggle_vcs_menu(&Default::default(), cx) - }) - .on_click(MouseButton::Left, move |_, _, _| {}), - ) - .with_children(self.render_branches_popover_host(&theme.titlebar, cx)), - ), + Flex::row().with_child( + Stack::new() + .with_child( + MouseEventHandler::new::(0, cx, |mouse_state, cx| { + enum BranchPopoverTooltip {} + let style = git_style + .in_state(self.branch_popover.is_some()) + .style_for(mouse_state); + Label::new(git_branch, style.text.clone()) + .contained() + .with_style(style.container.clone()) + .with_margin_right(item_spacing) + .aligned() + .left() + .with_tooltip::( + 0, + "Recent branches", + Some(Box::new(ToggleVcsMenu)), + theme.tooltip.clone(), + cx, + ) + .into_any_named("title-project-branch") + }) + .with_cursor_style(CursorStyle::PointingHand) + .on_down(MouseButton::Left, move |_, this, cx| { + this.toggle_vcs_menu(&Default::default(), cx) + }) + .on_click(MouseButton::Left, move |_, _, _| {}), + ) + .with_children(self.render_branches_popover_host(&theme.titlebar, cx)), + ), ) } ret.into_any() diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index 04644b62d9..ee34f600fa 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -49,7 +49,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) { if room.is_screen_sharing() { ActiveCall::report_call_event_for_room( "disable screen share", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, @@ -58,7 +58,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) { } else { ActiveCall::report_call_event_for_room( "enable screen share", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, @@ -78,7 +78,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) { if room.is_muted(cx) { ActiveCall::report_call_event_for_room( "enable microphone", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, @@ -86,7 +86,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) { } else { ActiveCall::report_call_event_for_room( "disable microphone", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 427134894f..28c20d95bb 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -41,7 +41,7 @@ actions!( [Suggest, NextSuggestion, PreviousSuggestion, Reinstall] ); -pub fn init(http: Arc, node_runtime: Arc, cx: &mut AppContext) { +pub fn init(http: Arc, node_runtime: Arc, cx: &mut AppContext) { let copilot = cx.add_model({ let node_runtime = node_runtime.clone(); move |cx| Copilot::start(http, node_runtime, cx) @@ -265,7 +265,7 @@ pub struct Completion { pub struct Copilot { http: Arc, - node_runtime: Arc, + node_runtime: Arc, server: CopilotServer, buffers: HashSet>, } @@ -299,7 +299,7 @@ impl Copilot { fn start( http: Arc, - node_runtime: Arc, + node_runtime: Arc, cx: &mut ModelContext, ) -> Self { let mut this = Self { @@ -335,12 +335,15 @@ impl Copilot { #[cfg(any(test, feature = "test-support"))] pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle, lsp::FakeLanguageServer) { + use node_runtime::FakeNodeRuntime; + let (server, fake_server) = LanguageServer::fake("copilot".into(), Default::default(), cx.to_async()); let http = util::http::FakeHttpClient::create(|_| async { unreachable!() }); + let node_runtime = FakeNodeRuntime::new(); let this = cx.add_model(|_| Self { http: http.clone(), - node_runtime: NodeRuntime::instance(http), + node_runtime, server: CopilotServer::Running(RunningCopilotServer { lsp: Arc::new(server), sign_in_status: SignInStatus::Authorized, @@ -353,7 +356,7 @@ impl Copilot { fn start_language_server( http: Arc, - node_runtime: Arc, + node_runtime: Arc, this: ModelHandle, mut cx: AsyncAppContext, ) -> impl Future { @@ -1188,7 +1191,7 @@ mod tests { _: u64, _: &clock::Global, _: language::RopeFingerprint, - _: ::fs::LineEnding, + _: language::LineEnding, _: std::time::SystemTime, _: &mut AppContext, ) { diff --git a/crates/editor/src/blink_manager.rs b/crates/editor/src/blink_manager.rs index 24ea4774aa..fa5a3af0c6 100644 --- a/crates/editor/src/blink_manager.rs +++ b/crates/editor/src/blink_manager.rs @@ -37,10 +37,7 @@ impl BlinkManager { } pub fn pause_blinking(&mut self, cx: &mut ModelContext) { - if !self.visible { - self.visible = true; - cx.notify(); - } + self.show_cursor(cx); let epoch = self.next_blink_epoch(); let interval = self.blink_interval; @@ -82,7 +79,13 @@ impl BlinkManager { }) .detach(); } - } else if !self.visible { + } else { + self.show_cursor(cx); + } + } + + pub fn show_cursor(&mut self, cx: &mut ModelContext<'_, BlinkManager>) { + if !self.visible { self.visible = true; cx.notify(); } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index fe8e4e338c..bdd29b04fa 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -44,7 +44,7 @@ use gpui::{ elements::*, executor, fonts::{self, HighlightStyle, TextStyle}, - geometry::vector::Vector2F, + geometry::vector::{vec2f, Vector2F}, impl_actions, keymap_matcher::KeymapContext, platform::{CursorStyle, MouseButton}, @@ -312,6 +312,10 @@ actions!( CopyPath, CopyRelativePath, CopyHighlightJson, + ContextMenuFirst, + ContextMenuPrev, + ContextMenuNext, + ContextMenuLast, ] ); @@ -468,6 +472,10 @@ pub fn init(cx: &mut AppContext) { cx.add_action(Editor::next_copilot_suggestion); cx.add_action(Editor::previous_copilot_suggestion); cx.add_action(Editor::copilot_suggest); + cx.add_action(Editor::context_menu_first); + cx.add_action(Editor::context_menu_prev); + cx.add_action(Editor::context_menu_next); + cx.add_action(Editor::context_menu_last); hover_popover::init(cx); scroll::actions::init(cx); @@ -820,6 +828,7 @@ struct CompletionsMenu { id: CompletionId, initial_position: Anchor, buffer: ModelHandle, + project: Option>, completions: Arc<[Completion]>, match_candidates: Vec, matches: Arc<[StringMatch]>, @@ -863,6 +872,48 @@ impl CompletionsMenu { fn render(&self, style: EditorStyle, cx: &mut ViewContext) -> AnyElement { enum CompletionTag {} + let language_servers = self.project.as_ref().map(|project| { + project + .read(cx) + .language_servers_for_buffer(self.buffer.read(cx), cx) + .filter(|(_, server)| server.capabilities().completion_provider.is_some()) + .map(|(adapter, server)| (server.server_id(), adapter.short_name)) + .collect::>() + }); + let needs_server_name = language_servers + .as_ref() + .map_or(false, |servers| servers.len() > 1); + + let get_server_name = + move |lookup_server_id: lsp::LanguageServerId| -> Option<&'static str> { + language_servers + .iter() + .flatten() + .find_map(|(server_id, server_name)| { + if *server_id == lookup_server_id { + Some(*server_name) + } else { + None + } + }) + }; + + let widest_completion_ix = self + .matches + .iter() + .enumerate() + .max_by_key(|(_, mat)| { + let completion = &self.completions[mat.candidate_id]; + let mut len = completion.label.text.chars().count(); + + if let Some(server_name) = get_server_name(completion.server_id) { + len += server_name.chars().count(); + } + + len + }) + .map(|(ix, _)| ix); + let completions = self.completions.clone(); let matches = self.matches.clone(); let selected_item = self.selected_item; @@ -889,19 +940,83 @@ impl CompletionsMenu { style.autocomplete.item }; - Text::new(completion.label.text.clone(), style.text.clone()) - .with_soft_wrap(false) - .with_highlights(combine_syntax_and_fuzzy_match_highlights( - &completion.label.text, - style.text.color.into(), - styled_runs_for_code_label( - &completion.label, - &style.syntax, - ), - &mat.positions, - )) - .contained() - .with_style(item_style) + let completion_label = + Text::new(completion.label.text.clone(), style.text.clone()) + .with_soft_wrap(false) + .with_highlights( + combine_syntax_and_fuzzy_match_highlights( + &completion.label.text, + style.text.color.into(), + styled_runs_for_code_label( + &completion.label, + &style.syntax, + ), + &mat.positions, + ), + ); + + if let Some(server_name) = get_server_name(completion.server_id) { + Flex::row() + .with_child(completion_label) + .with_children((|| { + if !needs_server_name { + return None; + } + + let text_style = TextStyle { + color: style.autocomplete.server_name_color, + font_size: style.text.font_size + * style.autocomplete.server_name_size_percent, + ..style.text.clone() + }; + + let label = Text::new(server_name, text_style) + .aligned() + .constrained() + .dynamically(move |constraint, _, _| { + gpui::SizeConstraint { + min: constraint.min, + max: vec2f( + constraint.max.x(), + constraint.min.y(), + ), + } + }); + + if Some(item_ix) == widest_completion_ix { + Some( + label + .contained() + .with_style( + style + .autocomplete + .server_name_container, + ) + .into_any(), + ) + } else { + Some(label.flex_float().into_any()) + } + })()) + .into_any() + } else { + completion_label.into_any() + } + .contained() + .with_style(item_style) + .constrained() + .dynamically( + move |constraint, _, _| { + if Some(item_ix) == widest_completion_ix { + constraint + } else { + gpui::SizeConstraint { + min: constraint.min, + max: constraint.min, + } + } + }, + ) }, ) .with_cursor_style(CursorStyle::PointingHand) @@ -918,19 +1033,7 @@ impl CompletionsMenu { } }, ) - .with_width_from_item( - self.matches - .iter() - .enumerate() - .max_by_key(|(_, mat)| { - self.completions[mat.candidate_id] - .label - .text - .chars() - .count() - }) - .map(|(ix, _)| ix), - ) + .with_width_from_item(widest_completion_ix) .contained() .with_style(container_style) .into_any() @@ -1454,6 +1557,16 @@ impl Editor { cx.observe(&display_map, Self::on_display_map_changed), cx.observe(&blink_manager, |_, _, cx| cx.notify()), cx.observe_global::(Self::settings_changed), + cx.observe_window_activation(|editor, active, cx| { + editor.blink_manager.update(cx, |blink_manager, cx| { + if active { + blink_manager.enable(cx); + } else { + blink_manager.show_cursor(cx); + blink_manager.disable(cx); + } + }); + }), ], }; @@ -1549,7 +1662,7 @@ impl Editor { .excerpt_containing(self.selections.newest_anchor().head(), cx) } - fn style(&self, cx: &AppContext) -> EditorStyle { + pub fn style(&self, cx: &AppContext) -> EditorStyle { build_style( settings::get::(cx), self.get_field_editor_theme.as_deref(), @@ -1625,6 +1738,15 @@ impl Editor { self.read_only = read_only; } + pub fn set_field_editor_style( + &mut self, + style: Option>, + cx: &mut ViewContext, + ) { + self.get_field_editor_theme = style; + cx.notify(); + } + pub fn replica_id_map(&self) -> Option<&HashMap> { self.replica_id_mapping.as_ref() } @@ -2964,6 +3086,7 @@ impl Editor { }); let id = post_inc(&mut self.next_completion_id); + let project = self.project.clone(); let task = cx.spawn(|this, mut cx| { async move { let menu = if let Some(completions) = completions.await.log_err() { @@ -2982,6 +3105,7 @@ impl Editor { }) .collect(), buffer, + project, completions: completions.into(), matches: Vec::new().into(), selected_item: 0, @@ -4979,6 +5103,9 @@ impl Editor { self.unmark_text(cx); self.refresh_copilot_suggestions(true, cx); cx.emit(Event::Edited); + cx.emit(Event::TransactionUndone { + transaction_id: tx_id, + }); } } @@ -5047,12 +5174,6 @@ impl Editor { return; } - if let Some(context_menu) = self.context_menu.as_mut() { - if context_menu.select_prev(cx) { - return; - } - } - if matches!(self.mode, EditorMode::SingleLine) { cx.propagate_action(); return; @@ -5075,15 +5196,6 @@ impl Editor { return; } - if self - .context_menu - .as_mut() - .map(|menu| menu.select_first(cx)) - .unwrap_or(false) - { - return; - } - if matches!(self.mode, EditorMode::SingleLine) { cx.propagate_action(); return; @@ -5123,12 +5235,6 @@ impl Editor { pub fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { self.take_rename(true, cx); - if let Some(context_menu) = self.context_menu.as_mut() { - if context_menu.select_next(cx) { - return; - } - } - if self.mode == EditorMode::SingleLine { cx.propagate_action(); return; @@ -5196,6 +5302,30 @@ impl Editor { }); } + pub fn context_menu_first(&mut self, _: &ContextMenuFirst, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_first(cx); + } + } + + pub fn context_menu_prev(&mut self, _: &ContextMenuPrev, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_prev(cx); + } + } + + pub fn context_menu_next(&mut self, _: &ContextMenuNext, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_next(cx); + } + } + + pub fn context_menu_last(&mut self, _: &ContextMenuLast, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_last(cx); + } + } + pub fn move_to_previous_word_start( &mut self, _: &MoveToPreviousWordStart, @@ -8418,6 +8548,9 @@ pub enum Event { local: bool, autoscroll: bool, }, + TransactionUndone { + transaction_id: TransactionId, + }, Closed, } @@ -8458,7 +8591,7 @@ impl View for Editor { "Editor" } - fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext) { + fn focus_in(&mut self, focused: AnyViewHandle, cx: &mut ViewContext) { if cx.is_self_focused() { let focused_event = EditorFocused(cx.handle()); cx.emit(Event::Focused); @@ -8466,7 +8599,7 @@ impl View for Editor { } if let Some(rename) = self.pending_rename.as_ref() { cx.focus(&rename.editor); - } else { + } else if cx.is_self_focused() || !focused.is::() { if !self.focused { self.blink_manager.update(cx, BlinkManager::enable); } @@ -8544,17 +8677,20 @@ impl View for Editor { if self.pending_rename.is_some() { keymap.add_identifier("renaming"); } - match self.context_menu.as_ref() { - Some(ContextMenu::Completions(_)) => { - keymap.add_identifier("menu"); - keymap.add_identifier("showing_completions") + if self.context_menu_visible() { + match self.context_menu.as_ref() { + Some(ContextMenu::Completions(_)) => { + keymap.add_identifier("menu"); + keymap.add_identifier("showing_completions") + } + Some(ContextMenu::CodeActions(_)) => { + keymap.add_identifier("menu"); + keymap.add_identifier("showing_code_actions") + } + None => {} } - Some(ContextMenu::CodeActions(_)) => { - keymap.add_identifier("menu"); - keymap.add_identifier("showing_code_actions") - } - None => {} } + for layer in self.keymap_context_layers.values() { keymap.extend(layer); } @@ -9161,6 +9297,7 @@ pub fn split_words<'a>(text: &'a str) -> impl std::iter::Iterator(move |_, _| async move { + Ok(Some(lsp::CompletionResponse::Array(vec![ + lsp::CompletionItem { + label: "bg-blue".into(), + ..Default::default() + }, + lsp::CompletionItem { + label: "bg-red".into(), + ..Default::default() + }, + lsp::CompletionItem { + label: "bg-yellow".into(), + ..Default::default() + }, + ]))) + }); + + cx.set_state(r#"

"#); + + // Trigger completion when typing a dash, because the dash is an extra + // word character in the 'element' scope, which contains the cursor. + cx.simulate_keystroke("-"); + cx.foreground().run_until_parked(); + cx.update_editor(|editor, _| { + if let Some(ContextMenu::Completions(menu)) = &editor.context_menu { + assert_eq!( + menu.matches.iter().map(|m| &m.string).collect::>(), + &["bg-red", "bg-blue", "bg-yellow"] + ); + } else { + panic!("expected completion menu to be open"); + } + }); + + cx.simulate_keystroke("l"); + cx.foreground().run_until_parked(); + cx.update_editor(|editor, _| { + if let Some(ContextMenu::Completions(menu)) = &editor.context_menu { + assert_eq!( + menu.matches.iter().map(|m| &m.string).collect::>(), + &["bg-blue", "bg-yellow"] + ); + } else { + panic!("expected completion menu to be open"); + } + }); + + // When filtering completions, consider the character after the '-' to + // be the start of a subword. + cx.set_state(r#"

"#); + cx.simulate_keystroke("l"); + cx.foreground().run_until_parked(); + cx.update_editor(|editor, _| { + if let Some(ContextMenu::Completions(menu)) = &editor.context_menu { + assert_eq!( + menu.matches.iter().map(|m| &m.string).collect::>(), + &["bg-yellow"] + ); + } else { + panic!("expected completion menu to be open"); + } + }); +} + fn empty_range(row: usize, column: usize) -> Range { let point = DisplayPoint::new(row as u32, column as u32); point..point diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 62f4c8c806..90fe6ccc52 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -2251,7 +2251,7 @@ impl Element for EditorElement { let replica_id = if let Some(mapping) = &editor.replica_id_mapping { mapping.get(&replica_id).copied() } else { - None + Some(replica_id) }; // The local selections match the leader's selections. diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index def6340e38..974af4bc24 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -1,8 +1,14 @@ use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint}; -use crate::{char_kind, CharKind, ToPoint}; +use crate::{char_kind, CharKind, ToOffset, ToPoint}; use language::Point; use std::ops::Range; +#[derive(Debug, PartialEq)] +pub enum FindRange { + SingleLine, + MultiLine, +} + pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint { if point.column() > 0 { *point.column_mut() -= 1; @@ -177,20 +183,21 @@ pub fn line_end( pub fn previous_word_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); - let language = map.buffer_snapshot.language_at(raw_point); + let scope = map.buffer_snapshot.language_scope_at(raw_point); - find_preceding_boundary(map, point, |left, right| { - (char_kind(language, left) != char_kind(language, right) && !right.is_whitespace()) + find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| { + (char_kind(&scope, left) != char_kind(&scope, right) && !right.is_whitespace()) || left == '\n' }) } pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); - let language = map.buffer_snapshot.language_at(raw_point); - find_preceding_boundary(map, point, |left, right| { + let scope = map.buffer_snapshot.language_scope_at(raw_point); + + find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| { let is_word_start = - char_kind(language, left) != char_kind(language, right) && !right.is_whitespace(); + char_kind(&scope, left) != char_kind(&scope, right) && !right.is_whitespace(); let is_subword_start = left == '_' && right != '_' || left.is_lowercase() && right.is_uppercase(); is_word_start || is_subword_start || left == '\n' @@ -199,19 +206,21 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); - let language = map.buffer_snapshot.language_at(raw_point); - find_boundary(map, point, |left, right| { - (char_kind(language, left) != char_kind(language, right) && !left.is_whitespace()) + let scope = map.buffer_snapshot.language_scope_at(raw_point); + + find_boundary(map, point, FindRange::MultiLine, |left, right| { + (char_kind(&scope, left) != char_kind(&scope, right) && !left.is_whitespace()) || right == '\n' }) } pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); - let language = map.buffer_snapshot.language_at(raw_point); - find_boundary(map, point, |left, right| { + let scope = map.buffer_snapshot.language_scope_at(raw_point); + + find_boundary(map, point, FindRange::MultiLine, |left, right| { let is_word_end = - (char_kind(language, left) != char_kind(language, right)) && !left.is_whitespace(); + (char_kind(&scope, left) != char_kind(&scope, right)) && !left.is_whitespace(); let is_subword_end = left != '_' && right == '_' || left.is_lowercase() && right.is_uppercase(); is_word_end || is_subword_end || right == '\n' @@ -272,79 +281,34 @@ pub fn end_of_paragraph( map.max_point() } -/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the -/// given predicate returning true. The predicate is called with the character to the left and right -/// of the candidate boundary location, and will be called with `\n` characters indicating the start -/// or end of a line. +/// Scans for a boundary preceding the given start point `from` until a boundary is found, +/// indicated by the given predicate returning true. +/// The predicate is called with the character to the left and right of the candidate boundary location. +/// If FindRange::SingleLine is specified and no boundary is found before the start of the current line, the start of the current line will be returned. pub fn find_preceding_boundary( map: &DisplaySnapshot, from: DisplayPoint, + find_range: FindRange, mut is_boundary: impl FnMut(char, char) -> bool, ) -> DisplayPoint { - let mut start_column = 0; - let mut soft_wrap_row = from.row() + 1; + let mut prev_ch = None; + let mut offset = from.to_point(map).to_offset(&map.buffer_snapshot); - let mut prev = None; - for (ch, point) in map.reverse_chars_at(from) { - // Recompute soft_wrap_indent if the row has changed - if point.row() != soft_wrap_row { - soft_wrap_row = point.row(); - - if point.row() == 0 { - start_column = 0; - } else if let Some(indent) = map.soft_wrap_indent(point.row() - 1) { - start_column = indent; - } - } - - // If the current point is in the soft_wrap, skip comparing it - if point.column() < start_column { - continue; - } - - if let Some((prev_ch, prev_point)) = prev { - if is_boundary(ch, prev_ch) { - return map.clip_point(prev_point, Bias::Left); - } - } - - prev = Some((ch, point)); - } - map.clip_point(DisplayPoint::zero(), Bias::Left) -} - -/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the -/// given predicate returning true. The predicate is called with the character to the left and right -/// of the candidate boundary location, and will be called with `\n` characters indicating the start -/// or end of a line. If no boundary is found, the start of the line is returned. -pub fn find_preceding_boundary_in_line( - map: &DisplaySnapshot, - from: DisplayPoint, - mut is_boundary: impl FnMut(char, char) -> bool, -) -> DisplayPoint { - let mut start_column = 0; - if from.row() > 0 { - if let Some(indent) = map.soft_wrap_indent(from.row() - 1) { - start_column = indent; - } - } - - let mut prev = None; - for (ch, point) in map.reverse_chars_at(from) { - if let Some((prev_ch, prev_point)) = prev { - if is_boundary(ch, prev_ch) { - return map.clip_point(prev_point, Bias::Left); - } - } - - if ch == '\n' || point.column() < start_column { + for ch in map.buffer_snapshot.reversed_chars_at(offset) { + if find_range == FindRange::SingleLine && ch == '\n' { break; } + if let Some(prev_ch) = prev_ch { + if is_boundary(ch, prev_ch) { + break; + } + } - prev = Some((ch, point)); + offset -= ch.len_utf8(); + prev_ch = Some(ch); } - map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Left) + map.clip_point(offset.to_display_point(map), Bias::Left) } /// Scans for a boundary following the given start point until a boundary is found, indicated by the @@ -354,59 +318,38 @@ pub fn find_preceding_boundary_in_line( pub fn find_boundary( map: &DisplaySnapshot, from: DisplayPoint, + find_range: FindRange, mut is_boundary: impl FnMut(char, char) -> bool, ) -> DisplayPoint { + let mut offset = from.to_offset(&map, Bias::Right); let mut prev_ch = None; - for (ch, point) in map.chars_at(from) { - if let Some(prev_ch) = prev_ch { - if is_boundary(prev_ch, ch) { - return map.clip_point(point, Bias::Right); - } - } - prev_ch = Some(ch); - } - map.clip_point(map.max_point(), Bias::Right) -} - -/// Scans for a boundary following the given start point until a boundary is found, indicated by the -/// given predicate returning true. The predicate is called with the character to the left and right -/// of the candidate boundary location, and will be called with `\n` characters indicating the start -/// or end of a line. If no boundary is found, the end of the line is returned -pub fn find_boundary_in_line( - map: &DisplaySnapshot, - from: DisplayPoint, - mut is_boundary: impl FnMut(char, char) -> bool, -) -> DisplayPoint { - let mut prev = None; - for (ch, point) in map.chars_at(from) { - if let Some((prev_ch, _)) = prev { - if is_boundary(prev_ch, ch) { - return map.clip_point(point, Bias::Right); - } - } - - prev = Some((ch, point)); - - if ch == '\n' { + for ch in map.buffer_snapshot.chars_at(offset) { + if find_range == FindRange::SingleLine && ch == '\n' { break; } - } + if let Some(prev_ch) = prev_ch { + if is_boundary(prev_ch, ch) { + break; + } + } - // Return the last position checked so that we give a point right before the newline or eof. - map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Right) + offset += ch.len_utf8(); + prev_ch = Some(ch); + } + map.clip_point(offset.to_display_point(map), Bias::Right) } pub fn is_inside_word(map: &DisplaySnapshot, point: DisplayPoint) -> bool { let raw_point = point.to_point(map); - let language = map.buffer_snapshot.language_at(raw_point); + let scope = map.buffer_snapshot.language_scope_at(raw_point); let ix = map.clip_point(point, Bias::Left).to_offset(map, Bias::Left); let text = &map.buffer_snapshot; - let next_char_kind = text.chars_at(ix).next().map(|c| char_kind(language, c)); + let next_char_kind = text.chars_at(ix).next().map(|c| char_kind(&scope, c)); let prev_char_kind = text .reversed_chars_at(ix) .next() - .map(|c| char_kind(language, c)); + .map(|c| char_kind(&scope, c)); prev_char_kind.zip(next_char_kind) == Some((CharKind::Word, CharKind::Word)) } @@ -533,7 +476,12 @@ mod tests { ) { let (snapshot, display_points) = marked_display_snapshot(marked_text, cx); assert_eq!( - find_preceding_boundary(&snapshot, display_points[1], is_boundary), + find_preceding_boundary( + &snapshot, + display_points[1], + FindRange::MultiLine, + is_boundary + ), display_points[0] ); } @@ -612,21 +560,15 @@ mod tests { find_preceding_boundary( &snapshot, buffer_snapshot.len().to_display_point(&snapshot), - |left, _| left == 'a', + FindRange::MultiLine, + |left, _| left == 'e', ), - 0.to_display_point(&snapshot), + snapshot + .buffer_snapshot + .offset_to_point(5) + .to_display_point(&snapshot), "Should not stop at inlays when looking for boundaries" ); - - assert_eq!( - find_preceding_boundary_in_line( - &snapshot, - buffer_snapshot.len().to_display_point(&snapshot), - |left, _| left == 'a', - ), - 0.to_display_point(&snapshot), - "Should not stop at inlays when looking for boundaries in line" - ); } #[gpui::test] @@ -699,7 +641,12 @@ mod tests { ) { let (snapshot, display_points) = marked_display_snapshot(marked_text, cx); assert_eq!( - find_boundary(&snapshot, display_points[0], is_boundary), + find_boundary( + &snapshot, + display_points[0], + FindRange::MultiLine, + is_boundary + ), display_points[1] ); } diff --git a/crates/editor/src/multi_buffer.rs b/crates/editor/src/multi_buffer.rs index 3ace5adbc7..74283fd778 100644 --- a/crates/editor/src/multi_buffer.rs +++ b/crates/editor/src/multi_buffer.rs @@ -617,6 +617,42 @@ impl MultiBuffer { } } + pub fn merge_transactions( + &mut self, + transaction: TransactionId, + destination: TransactionId, + cx: &mut ModelContext, + ) { + if let Some(buffer) = self.as_singleton() { + buffer.update(cx, |buffer, _| { + buffer.merge_transactions(transaction, destination) + }); + } else { + if let Some(transaction) = self.history.forget(transaction) { + if let Some(destination) = self.history.transaction_mut(destination) { + for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions { + if let Some(destination_buffer_transaction_id) = + destination.buffer_transactions.get(&buffer_id) + { + if let Some(state) = self.buffers.borrow().get(&buffer_id) { + state.buffer.update(cx, |buffer, _| { + buffer.merge_transactions( + buffer_transaction_id, + *destination_buffer_transaction_id, + ) + }); + } + } else { + destination + .buffer_transactions + .insert(buffer_id, buffer_transaction_id); + } + } + } + } + } + } + pub fn finalize_last_transaction(&mut self, cx: &mut ModelContext) { self.history.finalize_last_transaction(); for BufferState { buffer, .. } in self.buffers.borrow().values() { @@ -788,6 +824,20 @@ impl MultiBuffer { None } + pub fn undo_transaction(&mut self, transaction_id: TransactionId, cx: &mut ModelContext) { + if let Some(buffer) = self.as_singleton() { + buffer.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); + } else if let Some(transaction) = self.history.remove_from_undo(transaction_id) { + for (buffer_id, transaction_id) in &transaction.buffer_transactions { + if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) { + buffer.update(cx, |buffer, cx| { + buffer.undo_transaction(*transaction_id, cx) + }); + } + } + } + } + pub fn stream_excerpts_with_context_lines( &mut self, buffer: ModelHandle, @@ -1367,13 +1417,13 @@ impl MultiBuffer { return false; } - let language = self.language_at(position.clone(), cx); - - if char_kind(language.as_ref(), char) == CharKind::Word { + let snapshot = self.snapshot(cx); + let position = position.to_offset(&snapshot); + let scope = snapshot.language_scope_at(position); + if char_kind(&scope, char) == CharKind::Word { return true; } - let snapshot = self.snapshot(cx); let anchor = snapshot.anchor_before(position); anchor .buffer_id @@ -1875,8 +1925,8 @@ impl MultiBufferSnapshot { let mut next_chars = self.chars_at(start).peekable(); let mut prev_chars = self.reversed_chars_at(start).peekable(); - let language = self.language_at(start); - let kind = |c| char_kind(language, c); + let scope = self.language_scope_at(start); + let kind = |c| char_kind(&scope, c); let word_kind = cmp::max( prev_chars.peek().copied().map(kind), next_chars.peek().copied().map(kind), @@ -2316,6 +2366,16 @@ impl MultiBufferSnapshot { } } + pub fn prev_non_blank_row(&self, mut row: u32) -> Option { + while row > 0 { + row -= 1; + if !self.is_line_blank(row) { + return Some(row); + } + } + None + } + pub fn line_len(&self, row: u32) -> u32 { if let Some((_, range)) = self.buffer_line_for_row(row) { range.end.column - range.start.column @@ -3347,6 +3407,35 @@ impl History { } } + fn forget(&mut self, transaction_id: TransactionId) -> Option { + if let Some(ix) = self + .undo_stack + .iter() + .rposition(|transaction| transaction.id == transaction_id) + { + Some(self.undo_stack.remove(ix)) + } else if let Some(ix) = self + .redo_stack + .iter() + .rposition(|transaction| transaction.id == transaction_id) + { + Some(self.redo_stack.remove(ix)) + } else { + None + } + } + + fn transaction_mut(&mut self, transaction_id: TransactionId) -> Option<&mut Transaction> { + self.undo_stack + .iter_mut() + .find(|transaction| transaction.id == transaction_id) + .or_else(|| { + self.redo_stack + .iter_mut() + .find(|transaction| transaction.id == transaction_id) + }) + } + fn pop_undo(&mut self) -> Option<&mut Transaction> { assert_eq!(self.transaction_depth, 0); if let Some(transaction) = self.undo_stack.pop() { @@ -3367,6 +3456,16 @@ impl History { } } + fn remove_from_undo(&mut self, transaction_id: TransactionId) -> Option<&Transaction> { + let ix = self + .undo_stack + .iter() + .rposition(|transaction| transaction.id == transaction_id)?; + let transaction = self.undo_stack.remove(ix); + self.redo_stack.push(transaction); + self.redo_stack.last() + } + fn group(&mut self) -> Option { let mut count = 0; let mut transactions = self.undo_stack.iter(); diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index d87bc0ae4f..8233f92a1a 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -378,10 +378,6 @@ impl Editor { return; } - if amount.move_context_menu_selection(self, cx) { - return; - } - let cur_position = self.scroll_position(cx); let new_pos = cur_position + vec2f(0., amount.lines(self)); self.set_scroll_position(new_pos, cx); diff --git a/crates/editor/src/scroll/scroll_amount.rs b/crates/editor/src/scroll/scroll_amount.rs index f9d09adcf5..0edab2bdfc 100644 --- a/crates/editor/src/scroll/scroll_amount.rs +++ b/crates/editor/src/scroll/scroll_amount.rs @@ -1,8 +1,5 @@ -use gpui::ViewContext; -use serde::Deserialize; -use util::iife; - use crate::Editor; +use serde::Deserialize; #[derive(Clone, PartialEq, Deserialize)] pub enum ScrollAmount { @@ -13,25 +10,6 @@ pub enum ScrollAmount { } impl ScrollAmount { - pub fn move_context_menu_selection( - &self, - editor: &mut Editor, - cx: &mut ViewContext, - ) -> bool { - iife!({ - let context_menu = editor.context_menu.as_mut()?; - - match self { - Self::Line(c) if *c > 0. => context_menu.select_next(cx), - Self::Line(_) => context_menu.select_prev(cx), - Self::Page(c) if *c > 0. => context_menu.select_last(cx), - Self::Page(_) => context_menu.select_first(cx), - } - .then_some(()) - }) - .is_some() - } - pub fn lines(&self, editor: &mut Editor) -> f32 { match self { Self::Line(count) => *count, @@ -39,7 +17,7 @@ impl ScrollAmount { .visible_line_count() // subtract one to leave an anchor line // round towards zero (so page-up and page-down are symmetric) - .map(|l| ((l - 1.) * count).trunc()) + .map(|l| (l * count).trunc() - count.signum()) .unwrap_or(0.), } } diff --git a/crates/editor/src/test/editor_lsp_test_context.rs b/crates/editor/src/test/editor_lsp_test_context.rs index 668d6abf21..bbc4911f35 100644 --- a/crates/editor/src/test/editor_lsp_test_context.rs +++ b/crates/editor/src/test/editor_lsp_test_context.rs @@ -51,7 +51,7 @@ impl<'a> EditorLspTestContext<'a> { language .path_suffixes() .first() - .unwrap_or(&"txt".to_string()) + .expect("language must have a path suffix for EditorLspTestContext") ); let mut fake_servers = language diff --git a/crates/fs/Cargo.toml b/crates/fs/Cargo.toml index b3ebd224b0..7584dec21a 100644 --- a/crates/fs/Cargo.toml +++ b/crates/fs/Cargo.toml @@ -12,6 +12,7 @@ collections = { path = "../collections" } gpui = { path = "../gpui" } lsp = { path = "../lsp" } rope = { path = "../rope" } +text = { path = "../text" } util = { path = "../util" } sum_tree = { path = "../sum_tree" } rpc = { path = "../rpc" } diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index ec8a249ff4..ecaee4534e 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -4,14 +4,10 @@ use anyhow::{anyhow, Result}; use fsevent::EventStream; use futures::{future::BoxFuture, Stream, StreamExt}; use git2::Repository as LibGitRepository; -use lazy_static::lazy_static; use parking_lot::Mutex; -use regex::Regex; use repository::GitRepository; use rope::Rope; use smol::io::{AsyncReadExt, AsyncWriteExt}; -use std::borrow::Cow; -use std::cmp; use std::io::Write; use std::sync::Arc; use std::{ @@ -22,6 +18,7 @@ use std::{ time::{Duration, SystemTime}, }; use tempfile::NamedTempFile; +use text::LineEnding; use util::ResultExt; #[cfg(any(test, feature = "test-support"))] @@ -33,66 +30,6 @@ use std::ffi::OsStr; #[cfg(any(test, feature = "test-support"))] use std::sync::Weak; -lazy_static! { - static ref LINE_SEPARATORS_REGEX: Regex = Regex::new("\r\n|\r|\u{2028}|\u{2029}").unwrap(); -} - -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum LineEnding { - Unix, - Windows, -} - -impl Default for LineEnding { - fn default() -> Self { - #[cfg(unix)] - return Self::Unix; - - #[cfg(not(unix))] - return Self::CRLF; - } -} - -impl LineEnding { - pub fn as_str(&self) -> &'static str { - match self { - LineEnding::Unix => "\n", - LineEnding::Windows => "\r\n", - } - } - - pub fn detect(text: &str) -> Self { - let mut max_ix = cmp::min(text.len(), 1000); - while !text.is_char_boundary(max_ix) { - max_ix -= 1; - } - - if let Some(ix) = text[..max_ix].find(&['\n']) { - if ix > 0 && text.as_bytes()[ix - 1] == b'\r' { - Self::Windows - } else { - Self::Unix - } - } else { - Self::default() - } - } - - pub fn normalize(text: &mut String) { - if let Cow::Owned(replaced) = LINE_SEPARATORS_REGEX.replace_all(text, "\n") { - *text = replaced; - } - } - - pub fn normalize_arc(text: Arc) -> Arc { - if let Cow::Owned(replaced) = LINE_SEPARATORS_REGEX.replace_all(&text, "\n") { - replaced.into() - } else { - text - } - } -} - #[async_trait::async_trait] pub trait Fs: Send + Sync { async fn create_dir(&self, path: &Path) -> Result<()>; @@ -520,7 +457,7 @@ impl FakeFsState { } #[cfg(any(test, feature = "test-support"))] -lazy_static! { +lazy_static::lazy_static! { pub static ref FS_DOT_GIT: &'static OsStr = OsStr::new(".git"); } diff --git a/crates/gpui/src/elements/flex.rs b/crates/gpui/src/elements/flex.rs index d43a152a64..80dfb0625c 100644 --- a/crates/gpui/src/elements/flex.rs +++ b/crates/gpui/src/elements/flex.rs @@ -88,8 +88,7 @@ impl Flex { cx: &mut LayoutContext, ) { let cross_axis = self.axis.invert(); - let last = self.children.len() - 1; - for (ix, child) in &mut self.children.iter_mut().enumerate() { + for child in self.children.iter_mut() { if let Some(metadata) = child.metadata::() { if let Some((flex, expanded)) = metadata.flex { if expanded != layout_expanded { @@ -101,10 +100,6 @@ impl Flex { } else { let space_per_flex = *remaining_space / *remaining_flex; space_per_flex * flex - } - if ix == 0 || ix == last { - self.spacing / 2. - } else { - self.spacing }; let child_min = if expanded { child_max } else { 0. }; let child_constraint = match self.axis { @@ -144,13 +139,12 @@ impl Element for Flex { cx: &mut LayoutContext, ) -> (Vector2F, Self::LayoutState) { let mut total_flex = None; - let mut fixed_space = 0.0; + let mut fixed_space = self.children.len().saturating_sub(1) as f32 * self.spacing; let mut contains_float = false; let cross_axis = self.axis.invert(); let mut cross_axis_max: f32 = 0.0; - let last = self.children.len().saturating_sub(1); - for (ix, child) in &mut self.children.iter_mut().enumerate() { + for child in self.children.iter_mut() { let metadata = child.metadata::(); contains_float |= metadata.map_or(false, |metadata| metadata.float); @@ -168,12 +162,7 @@ impl Element for Flex { ), }; let size = child.layout(child_constraint, view, cx); - fixed_space += size.along(self.axis) - + if ix == 0 || ix == last { - self.spacing / 2. - } else { - self.spacing - }; + fixed_space += size.along(self.axis); cross_axis_max = cross_axis_max.max(size.along(cross_axis)); } } @@ -333,8 +322,7 @@ impl Element for Flex { } } - let last = self.children.len().saturating_sub(1); - for (ix, child) in &mut self.children.iter_mut().enumerate() { + for child in self.children.iter_mut() { if remaining_space > 0. { if let Some(metadata) = child.metadata::() { if metadata.float { @@ -372,11 +360,9 @@ impl Element for Flex { child.paint(scene, aligned_child_origin, visible_bounds, view, cx); - let spacing = if ix == last { 0. } else { self.spacing }; - match self.axis { - Axis::Horizontal => child_origin += vec2f(child.size().x() + spacing, 0.0), - Axis::Vertical => child_origin += vec2f(0.0, child.size().y() + spacing), + Axis::Horizontal => child_origin += vec2f(child.size().x() + self.spacing, 0.0), + Axis::Vertical => child_origin += vec2f(0.0, child.size().y() + self.spacing), } } diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 712c854488..474ea8364f 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -106,6 +106,7 @@ pub struct Deterministic { parker: parking_lot::Mutex, } +#[must_use] pub enum Timer { Production(smol::Timer), #[cfg(any(test, feature = "test-support"))] diff --git a/crates/gpui_macros/src/gpui_macros.rs b/crates/gpui_macros/src/gpui_macros.rs index 7f70bc6a91..16f293afbe 100644 --- a/crates/gpui_macros/src/gpui_macros.rs +++ b/crates/gpui_macros/src/gpui_macros.rs @@ -37,8 +37,14 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { Some("seed") => starting_seed = parse_int(&meta.lit)?, Some("on_failure") => { if let Lit::Str(name) = meta.lit { - let ident = Ident::new(&name.value(), name.span()); - on_failure_fn_name = quote!(Some(#ident)); + let mut path = syn::Path { + leading_colon: None, + segments: Default::default(), + }; + for part in name.value().split("::") { + path.segments.push(Ident::new(part, name.span()).into()); + } + on_failure_fn_name = quote!(Some(#path)); } else { return Err(TokenStream::from( syn::Error::new( diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 902ed26b57..38b2842c12 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -15,7 +15,6 @@ use crate::{ }; use anyhow::{anyhow, Result}; pub use clock::ReplicaId; -use fs::LineEnding; use futures::FutureExt as _; use gpui::{fonts::HighlightStyle, AppContext, Entity, ModelContext, Task}; use lsp::LanguageServerId; @@ -149,6 +148,7 @@ pub struct Completion { pub old_range: Range, pub new_text: String, pub label: CodeLabel, + pub server_id: LanguageServerId, pub lsp_completion: lsp::CompletionItem, } @@ -439,7 +439,7 @@ impl Buffer { operations.extend( text_operations .iter() - .filter(|(_, op)| !since.observed(op.local_timestamp())) + .filter(|(_, op)| !since.observed(op.timestamp())) .map(|(_, op)| proto::serialize_operation(&Operation::Buffer(op.clone()))), ); operations.sort_unstable_by_key(proto::lamport_timestamp_for_operation); @@ -1298,9 +1298,13 @@ impl Buffer { self.text.forget_transaction(transaction_id); } + pub fn merge_transactions(&mut self, transaction: TransactionId, destination: TransactionId) { + self.text.merge_transactions(transaction, destination); + } + pub fn wait_for_edits( &mut self, - edit_ids: impl IntoIterator, + edit_ids: impl IntoIterator, ) -> impl Future> { self.text.wait_for_edits(edit_ids) } @@ -1358,7 +1362,7 @@ impl Buffer { } } - pub fn set_text(&mut self, text: T, cx: &mut ModelContext) -> Option + pub fn set_text(&mut self, text: T, cx: &mut ModelContext) -> Option where T: Into>, { @@ -1371,7 +1375,7 @@ impl Buffer { edits_iter: I, autoindent_mode: Option, cx: &mut ModelContext, - ) -> Option + ) -> Option where I: IntoIterator, T)>, S: ToOffset, @@ -1408,7 +1412,7 @@ impl Buffer { .and_then(|mode| self.language.as_ref().map(|_| (self.snapshot(), mode))); let edit_operation = self.text.edit(edits.iter().cloned()); - let edit_id = edit_operation.local_timestamp(); + let edit_id = edit_operation.timestamp(); if let Some((before_edit, mode)) = autoindent_request { let mut delta = 0isize; @@ -1664,6 +1668,22 @@ impl Buffer { } } + pub fn undo_transaction( + &mut self, + transaction_id: TransactionId, + cx: &mut ModelContext, + ) -> bool { + let was_dirty = self.is_dirty(); + let old_version = self.version.clone(); + if let Some(operation) = self.text.undo_transaction(transaction_id) { + self.send_operation(Operation::Buffer(operation), cx); + self.did_edit(&old_version, was_dirty, cx); + true + } else { + false + } + } + pub fn undo_to_transaction( &mut self, transaction_id: TransactionId, @@ -2197,8 +2217,8 @@ impl BufferSnapshot { let mut next_chars = self.chars_at(start).peekable(); let mut prev_chars = self.reversed_chars_at(start).peekable(); - let language = self.language_at(start); - let kind = |c| char_kind(language, c); + let scope = self.language_scope_at(start); + let kind = |c| char_kind(&scope, c); let word_kind = cmp::max( prev_chars.peek().copied().map(kind), next_chars.peek().copied().map(kind), @@ -3012,17 +3032,21 @@ pub fn contiguous_ranges( }) } -pub fn char_kind(language: Option<&Arc>, c: char) -> CharKind { +pub fn char_kind(scope: &Option, c: char) -> CharKind { if c.is_whitespace() { return CharKind::Whitespace; } else if c.is_alphanumeric() || c == '_' { return CharKind::Word; } - if let Some(language) = language { - if language.config.word_characters.contains(&c) { - return CharKind::Word; + + if let Some(scope) = scope { + if let Some(characters) = scope.word_characters() { + if characters.contains(&c) { + return CharKind::Word; + } } } + CharKind::Punctuation } diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index db3749aa25..3bedf5b7a8 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -5,7 +5,6 @@ use crate::language_settings::{ use super::*; use clock::ReplicaId; use collections::BTreeMap; -use fs::LineEnding; use gpui::{AppContext, ModelHandle}; use indoc::indoc; use proto::deserialize_operation; @@ -20,6 +19,7 @@ use std::{ time::{Duration, Instant}, }; use text::network::Network; +use text::LineEnding; use unindent::Unindent as _; use util::{assert_set_eq, post_inc, test::marked_text_ranges, RandomCharIter}; diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 7a9e6b83ce..2193b5c07e 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -46,7 +46,7 @@ use theme::{SyntaxTheme, Theme}; use tree_sitter::{self, Query}; use unicase::UniCase; use util::{http::HttpClient, paths::PathExt}; -use util::{merge_json_value_into, post_inc, ResultExt, TryFutureExt as _, UnwrapFuture}; +use util::{post_inc, ResultExt, TryFutureExt as _, UnwrapFuture}; #[cfg(any(test, feature = "test-support"))] use futures::channel::mpsc; @@ -57,6 +57,7 @@ pub use diagnostic_set::DiagnosticEntry; pub use lsp::LanguageServerId; pub use outline::{Outline, OutlineItem}; pub use syntax_map::{OwnedSyntaxLayerInfo, SyntaxLayerInfo}; +pub use text::LineEnding; pub use tree_sitter::{Parser, Tree}; pub fn init(cx: &mut AppContext) { @@ -90,6 +91,7 @@ pub struct LanguageServerName(pub Arc); /// once at startup, and caches the results. pub struct CachedLspAdapter { pub name: LanguageServerName, + pub short_name: &'static str, pub initialization_options: Option, pub disk_based_diagnostic_sources: Vec, pub disk_based_diagnostics_progress_token: Option, @@ -100,6 +102,7 @@ pub struct CachedLspAdapter { impl CachedLspAdapter { pub async fn new(adapter: Arc) -> Arc { let name = adapter.name().await; + let short_name = adapter.short_name(); let initialization_options = adapter.initialization_options().await; let disk_based_diagnostic_sources = adapter.disk_based_diagnostic_sources().await; let disk_based_diagnostics_progress_token = @@ -108,6 +111,7 @@ impl CachedLspAdapter { Arc::new(CachedLspAdapter { name, + short_name, initialization_options, disk_based_diagnostic_sources, disk_based_diagnostics_progress_token, @@ -175,10 +179,7 @@ impl CachedLspAdapter { self.adapter.code_action_kinds() } - pub fn workspace_configuration( - &self, - cx: &mut AppContext, - ) -> Option> { + pub fn workspace_configuration(&self, cx: &mut AppContext) -> BoxFuture<'static, Value> { self.adapter.workspace_configuration(cx) } @@ -219,6 +220,8 @@ pub trait LspAdapterDelegate: Send + Sync { pub trait LspAdapter: 'static + Send + Sync { async fn name(&self) -> LanguageServerName; + fn short_name(&self) -> &'static str; + async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, @@ -287,8 +290,8 @@ pub trait LspAdapter: 'static + Send + Sync { None } - fn workspace_configuration(&self, _: &mut AppContext) -> Option> { - None + fn workspace_configuration(&self, _: &mut AppContext) -> BoxFuture<'static, Value> { + futures::future::ready(serde_json::json!({})).boxed() } fn code_action_kinds(&self) -> Option> { @@ -343,6 +346,8 @@ pub struct LanguageConfig { #[serde(default)] pub block_comment: Option<(Arc, Arc)>, #[serde(default)] + pub scope_opt_in_language_servers: Vec, + #[serde(default)] pub overrides: HashMap, #[serde(default)] pub word_characters: HashSet, @@ -373,6 +378,10 @@ pub struct LanguageConfigOverride { pub block_comment: Override<(Arc, Arc)>, #[serde(skip_deserializing)] pub disabled_bracket_ixs: Vec, + #[serde(default)] + pub word_characters: Override>, + #[serde(default)] + pub opt_into_language_servers: Vec, } #[derive(Clone, Deserialize, Debug)] @@ -411,6 +420,7 @@ impl Default for LanguageConfig { autoclose_before: Default::default(), line_comment: Default::default(), block_comment: Default::default(), + scope_opt_in_language_servers: Default::default(), overrides: Default::default(), collapsed_placeholder: Default::default(), word_characters: Default::default(), @@ -685,41 +695,6 @@ impl LanguageRegistry { result } - pub fn workspace_configuration(&self, cx: &mut AppContext) -> Task { - let lsp_adapters = { - let state = self.state.read(); - state - .available_languages - .iter() - .filter(|l| !l.loaded) - .flat_map(|l| l.lsp_adapters.clone()) - .chain( - state - .languages - .iter() - .flat_map(|language| &language.adapters) - .map(|adapter| adapter.adapter.clone()), - ) - .collect::>() - }; - - let mut language_configs = Vec::new(); - for adapter in &lsp_adapters { - if let Some(language_config) = adapter.workspace_configuration(cx) { - language_configs.push(language_config); - } - } - - cx.background().spawn(async move { - let mut config = serde_json::json!({}); - let language_configs = futures::future::join_all(language_configs).await; - for language_config in language_configs { - merge_json_value_into(language_config, &mut config); - } - config - }) - } - pub fn add(&self, language: Arc) { self.state.write().add(language); } @@ -1383,13 +1358,23 @@ impl Language { Ok(self) } - pub fn with_override_query(mut self, source: &str) -> Result { + pub fn with_override_query(mut self, source: &str) -> anyhow::Result { let query = Query::new(self.grammar_mut().ts_language, source)?; let mut override_configs_by_id = HashMap::default(); for (ix, name) in query.capture_names().iter().enumerate() { if !name.starts_with('_') { let value = self.config.overrides.remove(name).unwrap_or_default(); + for server_name in &value.opt_into_language_servers { + if !self + .config + .scope_opt_in_language_servers + .contains(server_name) + { + util::debug_panic!("Server {server_name:?} has been opted-in by scope {name:?} but has not been marked as an opt-in server"); + } + } + override_configs_by_id.insert(ix as u32, (name.clone(), value)); } } @@ -1595,6 +1580,13 @@ impl LanguageScope { .map(|e| (&e.0, &e.1)) } + pub fn word_characters(&self) -> Option<&HashSet> { + Override::as_option( + self.config_override().map(|o| &o.word_characters), + Some(&self.language.config.word_characters), + ) + } + pub fn brackets(&self) -> impl Iterator { let mut disabled_ids = self .config_override() @@ -1621,6 +1613,20 @@ impl LanguageScope { c.is_whitespace() || self.language.config.autoclose_before.contains(c) } + pub fn language_allowed(&self, name: &LanguageServerName) -> bool { + let config = &self.language.config; + let opt_in_servers = &config.scope_opt_in_language_servers; + if opt_in_servers.iter().any(|o| *o == *name.0) { + if let Some(over) = self.config_override() { + over.opt_into_language_servers.iter().any(|o| *o == *name.0) + } else { + false + } + } else { + true + } + } + fn config_override(&self) -> Option<&LanguageConfigOverride> { let id = self.override_id?; let grammar = self.language.grammar.as_ref()?; @@ -1725,6 +1731,10 @@ impl LspAdapter for Arc { LanguageServerName(self.name.into()) } + fn short_name(&self) -> &'static str { + "FakeLspAdapter" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, diff --git a/crates/language/src/proto.rs b/crates/language/src/proto.rs index 09c5ec7fc3..c4abe39d47 100644 --- a/crates/language/src/proto.rs +++ b/crates/language/src/proto.rs @@ -20,17 +20,17 @@ pub fn deserialize_fingerprint(fingerprint: &str) -> Result { .map_err(|error| anyhow!("invalid fingerprint: {}", error)) } -pub fn deserialize_line_ending(message: proto::LineEnding) -> fs::LineEnding { +pub fn deserialize_line_ending(message: proto::LineEnding) -> text::LineEnding { match message { - proto::LineEnding::Unix => fs::LineEnding::Unix, - proto::LineEnding::Windows => fs::LineEnding::Windows, + proto::LineEnding::Unix => text::LineEnding::Unix, + proto::LineEnding::Windows => text::LineEnding::Windows, } } -pub fn serialize_line_ending(message: fs::LineEnding) -> proto::LineEnding { +pub fn serialize_line_ending(message: text::LineEnding) -> proto::LineEnding { match message { - fs::LineEnding::Unix => proto::LineEnding::Unix, - fs::LineEnding::Windows => proto::LineEnding::Windows, + text::LineEnding::Unix => proto::LineEnding::Unix, + text::LineEnding::Windows => proto::LineEnding::Windows, } } @@ -41,24 +41,22 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation { proto::operation::Variant::Edit(serialize_edit_operation(edit)) } - crate::Operation::Buffer(text::Operation::Undo { - undo, - lamport_timestamp, - }) => proto::operation::Variant::Undo(proto::operation::Undo { - replica_id: undo.id.replica_id as u32, - local_timestamp: undo.id.value, - lamport_timestamp: lamport_timestamp.value, - version: serialize_version(&undo.version), - counts: undo - .counts - .iter() - .map(|(edit_id, count)| proto::UndoCount { - replica_id: edit_id.replica_id as u32, - local_timestamp: edit_id.value, - count: *count, - }) - .collect(), - }), + crate::Operation::Buffer(text::Operation::Undo(undo)) => { + proto::operation::Variant::Undo(proto::operation::Undo { + replica_id: undo.timestamp.replica_id as u32, + lamport_timestamp: undo.timestamp.value, + version: serialize_version(&undo.version), + counts: undo + .counts + .iter() + .map(|(edit_id, count)| proto::UndoCount { + replica_id: edit_id.replica_id as u32, + lamport_timestamp: edit_id.value, + count: *count, + }) + .collect(), + }) + } crate::Operation::UpdateSelections { selections, @@ -101,8 +99,7 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation { pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::Edit { proto::operation::Edit { replica_id: operation.timestamp.replica_id as u32, - local_timestamp: operation.timestamp.local, - lamport_timestamp: operation.timestamp.lamport, + lamport_timestamp: operation.timestamp.value, version: serialize_version(&operation.version), ranges: operation.ranges.iter().map(serialize_range).collect(), new_text: operation @@ -114,7 +111,7 @@ pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation:: } pub fn serialize_undo_map_entry( - (edit_id, counts): (&clock::Local, &[(clock::Local, u32)]), + (edit_id, counts): (&clock::Lamport, &[(clock::Lamport, u32)]), ) -> proto::UndoMapEntry { proto::UndoMapEntry { replica_id: edit_id.replica_id as u32, @@ -123,13 +120,38 @@ pub fn serialize_undo_map_entry( .iter() .map(|(undo_id, count)| proto::UndoCount { replica_id: undo_id.replica_id as u32, - local_timestamp: undo_id.value, + lamport_timestamp: undo_id.value, count: *count, }) .collect(), } } +pub fn split_operations( + mut operations: Vec, +) -> impl Iterator> { + #[cfg(any(test, feature = "test-support"))] + const CHUNK_SIZE: usize = 5; + + #[cfg(not(any(test, feature = "test-support")))] + const CHUNK_SIZE: usize = 100; + + let mut done = false; + std::iter::from_fn(move || { + if done { + return None; + } + + let operations = operations + .drain(..std::cmp::min(CHUNK_SIZE, operations.len())) + .collect::>(); + if operations.is_empty() { + done = true; + } + Some(operations) + }) +} + pub fn serialize_selections(selections: &Arc<[Selection]>) -> Vec { selections.iter().map(serialize_selection).collect() } @@ -197,7 +219,7 @@ pub fn serialize_diagnostics<'a>( pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor { proto::Anchor { replica_id: anchor.timestamp.replica_id as u32, - local_timestamp: anchor.timestamp.value, + timestamp: anchor.timestamp.value, offset: anchor.offset as u64, bias: match anchor.bias { Bias::Left => proto::Bias::Left as i32, @@ -218,32 +240,26 @@ pub fn deserialize_operation(message: proto::Operation) -> Result { - crate::Operation::Buffer(text::Operation::Undo { - lamport_timestamp: clock::Lamport { + crate::Operation::Buffer(text::Operation::Undo(UndoOperation { + timestamp: clock::Lamport { replica_id: undo.replica_id as ReplicaId, value: undo.lamport_timestamp, }, - undo: UndoOperation { - id: clock::Local { - replica_id: undo.replica_id as ReplicaId, - value: undo.local_timestamp, - }, - version: deserialize_version(&undo.version), - counts: undo - .counts - .into_iter() - .map(|c| { - ( - clock::Local { - replica_id: c.replica_id as ReplicaId, - value: c.local_timestamp, - }, - c.count, - ) - }) - .collect(), - }, - }) + version: deserialize_version(&undo.version), + counts: undo + .counts + .into_iter() + .map(|c| { + ( + clock::Lamport { + replica_id: c.replica_id as ReplicaId, + value: c.lamport_timestamp, + }, + c.count, + ) + }) + .collect(), + })) } proto::operation::Variant::UpdateSelections(message) => { let selections = message @@ -298,10 +314,9 @@ pub fn deserialize_operation(message: proto::Operation) -> Result EditOperation { EditOperation { - timestamp: InsertionTimestamp { + timestamp: clock::Lamport { replica_id: edit.replica_id as ReplicaId, - local: edit.local_timestamp, - lamport: edit.lamport_timestamp, + value: edit.lamport_timestamp, }, version: deserialize_version(&edit.version), ranges: edit.ranges.into_iter().map(deserialize_range).collect(), @@ -311,9 +326,9 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation pub fn deserialize_undo_map_entry( entry: proto::UndoMapEntry, -) -> (clock::Local, Vec<(clock::Local, u32)>) { +) -> (clock::Lamport, Vec<(clock::Lamport, u32)>) { ( - clock::Local { + clock::Lamport { replica_id: entry.replica_id as u16, value: entry.local_timestamp, }, @@ -322,9 +337,9 @@ pub fn deserialize_undo_map_entry( .into_iter() .map(|undo_count| { ( - clock::Local { + clock::Lamport { replica_id: undo_count.replica_id as u16, - value: undo_count.local_timestamp, + value: undo_count.lamport_timestamp, }, undo_count.count, ) @@ -384,9 +399,9 @@ pub fn deserialize_diagnostics( pub fn deserialize_anchor(anchor: proto::Anchor) -> Option { Some(Anchor { - timestamp: clock::Local { + timestamp: clock::Lamport { replica_id: anchor.replica_id as ReplicaId, - value: anchor.local_timestamp, + value: anchor.timestamp, }, offset: anchor.offset as usize, bias: match proto::Bias::from_i32(anchor.bias)? { @@ -434,6 +449,7 @@ pub fn serialize_completion(completion: &Completion) -> proto::Completion { old_start: Some(serialize_anchor(&completion.old_range.start)), old_end: Some(serialize_anchor(&completion.old_range.end)), new_text: completion.new_text.clone(), + server_id: completion.server_id.0 as u64, lsp_completion: serde_json::to_vec(&completion.lsp_completion).unwrap(), } } @@ -466,6 +482,7 @@ pub async fn deserialize_completion( lsp_completion.filter_text.as_deref(), ) }), + server_id: LanguageServerId(completion.server_id as usize), lsp_completion, }) } @@ -498,12 +515,12 @@ pub fn deserialize_code_action(action: proto::CodeAction) -> Result pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction { proto::Transaction { - id: Some(serialize_local_timestamp(transaction.id)), + id: Some(serialize_timestamp(transaction.id)), edit_ids: transaction .edit_ids .iter() .copied() - .map(serialize_local_timestamp) + .map(serialize_timestamp) .collect(), start: serialize_version(&transaction.start), } @@ -511,7 +528,7 @@ pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction { pub fn deserialize_transaction(transaction: proto::Transaction) -> Result { Ok(Transaction { - id: deserialize_local_timestamp( + id: deserialize_timestamp( transaction .id .ok_or_else(|| anyhow!("missing transaction id"))?, @@ -519,21 +536,21 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result proto::LocalTimestamp { - proto::LocalTimestamp { +pub fn serialize_timestamp(timestamp: clock::Lamport) -> proto::LamportTimestamp { + proto::LamportTimestamp { replica_id: timestamp.replica_id as u32, value: timestamp.value, } } -pub fn deserialize_local_timestamp(timestamp: proto::LocalTimestamp) -> clock::Local { - clock::Local { +pub fn deserialize_timestamp(timestamp: proto::LamportTimestamp) -> clock::Lamport { + clock::Lamport { replica_id: timestamp.replica_id as ReplicaId, value: timestamp.value, } @@ -553,7 +570,7 @@ pub fn deserialize_range(range: proto::Range) -> Range { pub fn deserialize_version(message: &[proto::VectorClockEntry]) -> clock::Global { let mut version = clock::Global::new(); for entry in message { - version.observe(clock::Local { + version.observe(clock::Lamport { replica_id: entry.replica_id as ReplicaId, value: entry.timestamp, }); diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 51bdb4c5ce..a918e3d151 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -12,6 +12,7 @@ use gpui::{ ViewHandle, WeakModelHandle, }; use language::{Buffer, LanguageServerId, LanguageServerName}; +use lsp::IoKind; use project::{Project, Worktree}; use std::{borrow::Cow, sync::Arc}; use theme::{ui, Theme}; @@ -26,7 +27,7 @@ const RECEIVE_LINE: &str = "// Receive:\n"; pub struct LogStore { projects: HashMap, ProjectState>, - io_tx: mpsc::UnboundedSender<(WeakModelHandle, LanguageServerId, bool, String)>, + io_tx: mpsc::UnboundedSender<(WeakModelHandle, LanguageServerId, IoKind, String)>, } struct ProjectState { @@ -37,12 +38,12 @@ struct ProjectState { struct LanguageServerState { log_buffer: ModelHandle, rpc_state: Option, + _subscription: Option, } struct LanguageServerRpcState { buffer: ModelHandle, last_message_kind: Option, - _subscription: lsp::Subscription, } pub struct LspLogView { @@ -118,11 +119,11 @@ impl LogStore { io_tx, }; cx.spawn_weak(|this, mut cx| async move { - while let Some((project, server_id, is_output, mut message)) = io_rx.next().await { + while let Some((project, server_id, io_kind, mut message)) = io_rx.next().await { if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| { message.push('\n'); - this.on_io(project, server_id, is_output, &message, cx); + this.on_io(project, server_id, io_kind, &message, cx); }); } } @@ -168,22 +169,29 @@ impl LogStore { cx: &mut ModelContext, ) -> Option> { let project_state = self.projects.get_mut(&project.downgrade())?; - Some( - project_state - .servers - .entry(id) - .or_insert_with(|| { - cx.notify(); - LanguageServerState { - rpc_state: None, - log_buffer: cx - .add_model(|cx| Buffer::new(0, cx.model_id() as u64, "")) - .clone(), - } - }) - .log_buffer - .clone(), - ) + let server_state = project_state.servers.entry(id).or_insert_with(|| { + cx.notify(); + LanguageServerState { + rpc_state: None, + log_buffer: cx + .add_model(|cx| Buffer::new(0, cx.model_id() as u64, "")) + .clone(), + _subscription: None, + } + }); + + let server = project.read(cx).language_server_for_id(id); + let weak_project = project.downgrade(); + let io_tx = self.io_tx.clone(); + server_state._subscription = server.map(|server| { + server.on_io(move |io_kind, message| { + io_tx + .unbounded_send((weak_project, id, io_kind, message.to_string())) + .ok(); + }) + }); + + Some(server_state.log_buffer.clone()) } fn add_language_server_log( @@ -230,7 +238,7 @@ impl LogStore { Some(server_state.log_buffer.clone()) } - pub fn enable_rpc_trace_for_language_server( + fn enable_rpc_trace_for_language_server( &mut self, project: &ModelHandle, server_id: LanguageServerId, @@ -239,9 +247,7 @@ impl LogStore { let weak_project = project.downgrade(); let project_state = self.projects.get_mut(&weak_project)?; let server_state = project_state.servers.get_mut(&server_id)?; - let server = project.read(cx).language_server_for_id(server_id)?; let rpc_state = server_state.rpc_state.get_or_insert_with(|| { - let io_tx = self.io_tx.clone(); let language = project.read(cx).languages().language_for_name("JSON"); let buffer = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "")); cx.spawn_weak({ @@ -258,11 +264,6 @@ impl LogStore { LanguageServerRpcState { buffer, last_message_kind: None, - _subscription: server.on_io(move |is_received, json| { - io_tx - .unbounded_send((weak_project, server_id, is_received, json.to_string())) - .ok(); - }), } }); Some(rpc_state.buffer.clone()) @@ -285,10 +286,25 @@ impl LogStore { &mut self, project: WeakModelHandle, language_server_id: LanguageServerId, - is_received: bool, + io_kind: IoKind, message: &str, cx: &mut AppContext, ) -> Option<()> { + let is_received = match io_kind { + IoKind::StdOut => true, + IoKind::StdIn => false, + IoKind::StdErr => { + let project = project.upgrade(cx)?; + project.update(cx, |_, cx| { + cx.emit(project::Event::LanguageServerLog( + language_server_id, + format!("stderr: {}\n", message.trim()), + )) + }); + return Some(()); + } + }; + let state = self .projects .get_mut(&project)? @@ -554,10 +570,12 @@ impl View for LspLogToolbarItemView { let Some(log_view) = self.log_view.as_ref() else { return Empty::new().into_any(); }; - let log_view = log_view.read(cx); - let menu_rows = log_view.menu_items(cx).unwrap_or_default(); + let (menu_rows, current_server_id) = log_view.update(cx, |log_view, cx| { + let menu_rows = log_view.menu_items(cx).unwrap_or_default(); + let current_server_id = log_view.current_server_id; + (menu_rows, current_server_id) + }); - let current_server_id = log_view.current_server_id; let current_server = current_server_id.and_then(|current_server_id| { if let Ok(ix) = menu_rows.binary_search_by_key(¤t_server_id, |e| e.server_id) { Some(menu_rows[ix].clone()) @@ -565,10 +583,10 @@ impl View for LspLogToolbarItemView { None } }); + let server_selected = current_server.is_some(); enum Menu {} - - Stack::new() + let lsp_menu = Stack::new() .with_child(Self::render_language_server_menu_header( current_server, &theme, @@ -615,8 +633,47 @@ impl View for LspLogToolbarItemView { }) .aligned() .left() - .clipped() - .into_any() + .clipped(); + + enum LspCleanupButton {} + let log_cleanup_button = + MouseEventHandler::new::(1, cx, |state, cx| { + let theme = theme::current(cx).clone(); + let style = theme + .workspace + .toolbar + .toggleable_text_tool + .in_state(server_selected) + .style_for(state); + Label::new("Clear", style.text.clone()) + .aligned() + .contained() + .with_style(style.container) + .constrained() + .with_height(theme.toolbar_dropdown_menu.row_height / 6.0 * 5.0) + }) + .on_click(MouseButton::Left, move |_, this, cx| { + if let Some(log_view) = this.log_view.as_ref() { + log_view.update(cx, |log_view, cx| { + log_view.editor.update(cx, |editor, cx| { + editor.set_read_only(false); + editor.clear(cx); + editor.set_read_only(true); + }); + }) + } + }) + .with_cursor_style(CursorStyle::PointingHand) + .aligned() + .right(); + + Flex::row() + .with_child(lsp_menu) + .with_child(log_cleanup_button) + .contained() + .aligned() + .left() + .into_any_named("lsp log controls") } } diff --git a/crates/lsp/Cargo.toml b/crates/lsp/Cargo.toml index 47e0995c85..653c25b7bb 100644 --- a/crates/lsp/Cargo.toml +++ b/crates/lsp/Cargo.toml @@ -20,7 +20,7 @@ anyhow.workspace = true async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553", optional = true } futures.workspace = true log.workspace = true -lsp-types = "0.94" +lsp-types = { git = "https://github.com/zed-industries/lsp-types", branch = "updated-completion-list-item-defaults" } parking_lot.workspace = true postage.workspace = true serde.workspace = true diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index d49dafff2f..dcfce4f1fb 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -4,7 +4,7 @@ pub use lsp_types::*; use anyhow::{anyhow, Context, Result}; use collections::HashMap; -use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite}; +use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite, FutureExt}; use gpui::{executor, AsyncAppContext, Task}; use parking_lot::Mutex; use postage::{barrier, prelude::Stream}; @@ -26,16 +26,25 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, Arc, Weak, }, + time::{Duration, Instant}, }; use std::{path::Path, process::Stdio}; use util::{ResultExt, TryFutureExt}; const JSON_RPC_VERSION: &str = "2.0"; const CONTENT_LEN_HEADER: &str = "Content-Length: "; +const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2); type NotificationHandler = Box, &str, AsyncAppContext)>; type ResponseHandler = Box)>; -type IoHandler = Box; +type IoHandler = Box; + +#[derive(Debug, Clone, Copy)] +pub enum IoKind { + StdOut, + StdIn, + StdErr, +} #[derive(Debug, Clone, Deserialize)] pub struct LanguageServerBinary { @@ -144,16 +153,18 @@ impl LanguageServer { .args(binary.arguments) .stdin(Stdio::piped()) .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) + .stderr(Stdio::piped()) .kill_on_drop(true) .spawn()?; let stdin = server.stdin.take().unwrap(); - let stout = server.stdout.take().unwrap(); + let stdout = server.stdout.take().unwrap(); + let stderr = server.stderr.take().unwrap(); let mut server = Self::new_internal( server_id.clone(), stdin, - stout, + stdout, + Some(stderr), Some(server), root_path, code_action_kinds, @@ -181,10 +192,11 @@ impl LanguageServer { Ok(server) } - fn new_internal( + fn new_internal( server_id: LanguageServerId, stdin: Stdin, stdout: Stdout, + stderr: Option, server: Option, root_path: &Path, code_action_kinds: Option>, @@ -194,7 +206,8 @@ impl LanguageServer { where Stdin: AsyncWrite + Unpin + Send + 'static, Stdout: AsyncRead + Unpin + Send + 'static, - F: FnMut(AnyNotification) + 'static + Send, + Stderr: AsyncRead + Unpin + Send + 'static, + F: FnMut(AnyNotification) + 'static + Send + Clone, { let (outbound_tx, outbound_rx) = channel::unbounded::(); let (output_done_tx, output_done_rx) = barrier::channel(); @@ -203,17 +216,27 @@ impl LanguageServer { let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); let io_handlers = Arc::new(Mutex::new(HashMap::default())); - let input_task = cx.spawn(|cx| { - Self::handle_input( - stdout, - on_unhandled_notification, - notification_handlers.clone(), - response_handlers.clone(), - io_handlers.clone(), - cx, - ) + + let stdout_input_task = cx.spawn(|cx| { + { + Self::handle_input( + stdout, + on_unhandled_notification.clone(), + notification_handlers.clone(), + response_handlers.clone(), + io_handlers.clone(), + cx, + ) + } .log_err() }); + let stderr_input_task = stderr + .map(|stderr| cx.spawn(|_| Self::handle_stderr(stderr, io_handlers.clone()).log_err())) + .unwrap_or_else(|| Task::Ready(Some(None))); + let input_task = cx.spawn(|_| async move { + let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); + stdout.or(stderr) + }); let output_task = cx.background().spawn({ Self::handle_output( stdin, @@ -282,9 +305,9 @@ impl LanguageServer { stdout.read_exact(&mut buffer).await?; if let Ok(message) = str::from_utf8(&buffer) { - log::trace!("incoming message:{}", message); + log::trace!("incoming message: {}", message); for handler in io_handlers.lock().values_mut() { - handler(true, message); + handler(IoKind::StdOut, message); } } @@ -327,6 +350,30 @@ impl LanguageServer { } } + async fn handle_stderr( + stderr: Stderr, + io_handlers: Arc>>, + ) -> anyhow::Result<()> + where + Stderr: AsyncRead + Unpin + Send + 'static, + { + let mut stderr = BufReader::new(stderr); + let mut buffer = Vec::new(); + loop { + buffer.clear(); + stderr.read_until(b'\n', &mut buffer).await?; + if let Ok(message) = str::from_utf8(&buffer) { + log::trace!("incoming stderr message:{message}"); + for handler in io_handlers.lock().values_mut() { + handler(IoKind::StdErr, message); + } + } + + // Don't starve the main thread when receiving lots of messages at once. + smol::future::yield_now().await; + } + } + async fn handle_output( stdin: Stdin, outbound_rx: channel::Receiver, @@ -348,7 +395,7 @@ impl LanguageServer { while let Ok(message) = outbound_rx.recv().await { log::trace!("outgoing message:{}", message); for handler in io_handlers.lock().values_mut() { - handler(false, &message); + handler(IoKind::StdIn, &message); } content_len_buffer.clear(); @@ -423,6 +470,14 @@ impl LanguageServer { }), ..Default::default() }), + completion_list: Some(CompletionListCapability { + item_defaults: Some(vec![ + "commitCharacters".to_owned(), + "editRange".to_owned(), + "insertTextMode".to_owned(), + "data".to_owned(), + ]), + }), ..Default::default() }), rename: Some(RenameClientCapabilities { @@ -532,7 +587,7 @@ impl LanguageServer { #[must_use] pub fn on_io(&self, f: F) -> Subscription where - F: 'static + Send + FnMut(bool, &str), + F: 'static + Send + FnMut(IoKind, &str), { let id = self.next_id.fetch_add(1, SeqCst); self.io_handlers.lock().insert(id, Box::new(f)); @@ -695,7 +750,7 @@ impl LanguageServer { outbound_tx: &channel::Sender, executor: &Arc, params: T::Params, - ) -> impl 'static + Future> + ) -> impl 'static + Future> where T::Result: 'static + Send, { @@ -736,10 +791,25 @@ impl LanguageServer { .try_send(message) .context("failed to write to language server's stdin"); + let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse(); + let started = Instant::now(); async move { handle_response?; send?; - rx.await? + + let method = T::METHOD; + futures::select! { + response = rx.fuse() => { + let elapsed = started.elapsed(); + log::trace!("Took {elapsed:?} to recieve response to {method:?} id {id}"); + response? + } + + _ = timeout => { + log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}"); + anyhow::bail!("LSP request timeout"); + } + } } } @@ -851,6 +921,7 @@ impl LanguageServer { LanguageServerId(0), stdin_writer, stdout_reader, + None::, None, Path::new("/"), None, @@ -862,6 +933,7 @@ impl LanguageServer { LanguageServerId(0), stdout_writer, stdin_reader, + None::, None, Path::new("/"), None, diff --git a/crates/node_runtime/Cargo.toml b/crates/node_runtime/Cargo.toml index 53635f2725..2b9503468a 100644 --- a/crates/node_runtime/Cargo.toml +++ b/crates/node_runtime/Cargo.toml @@ -14,6 +14,7 @@ util = { path = "../util" } async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] } async-tar = "0.4.2" futures.workspace = true +async-trait.workspace = true anyhow.workspace = true parking_lot.workspace = true serde.workspace = true diff --git a/crates/node_runtime/src/node_runtime.rs b/crates/node_runtime/src/node_runtime.rs index d43c14ec7b..820a8b6f81 100644 --- a/crates/node_runtime/src/node_runtime.rs +++ b/crates/node_runtime/src/node_runtime.rs @@ -7,14 +7,12 @@ use std::process::{Output, Stdio}; use std::{ env::consts, path::{Path, PathBuf}, - sync::{Arc, OnceLock}, + sync::Arc, }; use util::http::HttpClient; const VERSION: &str = "v18.15.0"; -static RUNTIME_INSTANCE: OnceLock> = OnceLock::new(); - #[derive(Debug, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct NpmInfo { @@ -28,23 +26,88 @@ pub struct NpmInfoDistTags { latest: Option, } -pub struct NodeRuntime { +#[async_trait::async_trait] +pub trait NodeRuntime: Send + Sync { + async fn binary_path(&self) -> Result; + + async fn run_npm_subcommand( + &self, + directory: Option<&Path>, + subcommand: &str, + args: &[&str], + ) -> Result; + + async fn npm_package_latest_version(&self, name: &str) -> Result; + + async fn npm_install_packages(&self, directory: &Path, packages: &[(&str, &str)]) + -> Result<()>; +} + +pub struct RealNodeRuntime { http: Arc, } -impl NodeRuntime { - pub fn instance(http: Arc) -> Arc { - RUNTIME_INSTANCE - .get_or_init(|| Arc::new(NodeRuntime { http })) - .clone() +impl RealNodeRuntime { + pub fn new(http: Arc) -> Arc { + Arc::new(RealNodeRuntime { http }) } - pub async fn binary_path(&self) -> Result { + async fn install_if_needed(&self) -> Result { + log::info!("Node runtime install_if_needed"); + + let arch = match consts::ARCH { + "x86_64" => "x64", + "aarch64" => "arm64", + other => bail!("Running on unsupported platform: {other}"), + }; + + let folder_name = format!("node-{VERSION}-darwin-{arch}"); + let node_containing_dir = util::paths::SUPPORT_DIR.join("node"); + let node_dir = node_containing_dir.join(folder_name); + let node_binary = node_dir.join("bin/node"); + let npm_file = node_dir.join("bin/npm"); + + let result = Command::new(&node_binary) + .arg(npm_file) + .arg("--version") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .await; + let valid = matches!(result, Ok(status) if status.success()); + + if !valid { + _ = fs::remove_dir_all(&node_containing_dir).await; + fs::create_dir(&node_containing_dir) + .await + .context("error creating node containing dir")?; + + let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz"); + let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}"); + let mut response = self + .http + .get(&url, Default::default(), true) + .await + .context("error downloading Node binary tarball")?; + + let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); + let archive = Archive::new(decompressed_bytes); + archive.unpack(&node_containing_dir).await?; + } + + anyhow::Ok(node_dir) + } +} + +#[async_trait::async_trait] +impl NodeRuntime for RealNodeRuntime { + async fn binary_path(&self) -> Result { let installation_path = self.install_if_needed().await?; Ok(installation_path.join("bin/node")) } - pub async fn run_npm_subcommand( + async fn run_npm_subcommand( &self, directory: Option<&Path>, subcommand: &str, @@ -106,7 +169,7 @@ impl NodeRuntime { output.map_err(|e| anyhow!("{e}")) } - pub async fn npm_package_latest_version(&self, name: &str) -> Result { + async fn npm_package_latest_version(&self, name: &str) -> Result { let output = self .run_npm_subcommand( None, @@ -131,10 +194,10 @@ impl NodeRuntime { .ok_or_else(|| anyhow!("no version found for npm package {}", name)) } - pub async fn npm_install_packages( + async fn npm_install_packages( &self, directory: &Path, - packages: impl IntoIterator, + packages: &[(&str, &str)], ) -> Result<()> { let packages: Vec<_> = packages .into_iter() @@ -155,51 +218,31 @@ impl NodeRuntime { .await?; Ok(()) } +} - async fn install_if_needed(&self) -> Result { - log::info!("Node runtime install_if_needed"); +pub struct FakeNodeRuntime; - let arch = match consts::ARCH { - "x86_64" => "x64", - "aarch64" => "arm64", - other => bail!("Running on unsupported platform: {other}"), - }; - - let folder_name = format!("node-{VERSION}-darwin-{arch}"); - let node_containing_dir = util::paths::SUPPORT_DIR.join("node"); - let node_dir = node_containing_dir.join(folder_name); - let node_binary = node_dir.join("bin/node"); - let npm_file = node_dir.join("bin/npm"); - - let result = Command::new(&node_binary) - .arg(npm_file) - .arg("--version") - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status() - .await; - let valid = matches!(result, Ok(status) if status.success()); - - if !valid { - _ = fs::remove_dir_all(&node_containing_dir).await; - fs::create_dir(&node_containing_dir) - .await - .context("error creating node containing dir")?; - - let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz"); - let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}"); - let mut response = self - .http - .get(&url, Default::default(), true) - .await - .context("error downloading Node binary tarball")?; - - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = Archive::new(decompressed_bytes); - archive.unpack(&node_containing_dir).await?; - } - - anyhow::Ok(node_dir) +impl FakeNodeRuntime { + pub fn new() -> Arc { + Arc::new(FakeNodeRuntime) + } +} + +#[async_trait::async_trait] +impl NodeRuntime for FakeNodeRuntime { + async fn binary_path(&self) -> Result { + unreachable!() + } + + async fn run_npm_subcommand(&self, _: Option<&Path>, _: &str, _: &[&str]) -> Result { + unreachable!() + } + + async fn npm_package_latest_version(&self, _: &str) -> Result { + unreachable!() + } + + async fn npm_install_packages(&self, _: &Path, _: &[(&str, &str)]) -> Result<()> { + unreachable!() } } diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index 8239cf8690..8beaea5031 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -6,7 +6,6 @@ use crate::{ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use client::proto::{self, PeerId}; -use fs::LineEnding; use futures::future; use gpui::{AppContext, AsyncAppContext, ModelHandle}; use language::{ @@ -17,8 +16,12 @@ use language::{ CodeAction, Completion, OffsetRangeExt, PointUtf16, ToOffset, ToPointUtf16, Transaction, Unclipped, }; -use lsp::{DocumentHighlightKind, LanguageServer, LanguageServerId, OneOf, ServerCapabilities}; +use lsp::{ + CompletionListItemDefaultsEditRange, DocumentHighlightKind, LanguageServer, LanguageServerId, + OneOf, ServerCapabilities, +}; use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc}; +use text::LineEnding; pub fn lsp_formatting_options(tab_size: u32) -> lsp::FormattingOptions { lsp::FormattingOptions { @@ -1340,13 +1343,19 @@ impl LspCommand for GetCompletions { completions: Option, _: ModelHandle, buffer: ModelHandle, - _: LanguageServerId, + server_id: LanguageServerId, cx: AsyncAppContext, ) -> Result> { + let mut response_list = None; let completions = if let Some(completions) = completions { match completions { lsp::CompletionResponse::Array(completions) => completions, - lsp::CompletionResponse::List(list) => list.items, + + lsp::CompletionResponse::List(mut list) => { + let items = std::mem::take(&mut list.items); + response_list = Some(list); + items + } } } else { Default::default() @@ -1356,6 +1365,7 @@ impl LspCommand for GetCompletions { let language = buffer.language().cloned(); let snapshot = buffer.snapshot(); let clipped_position = buffer.clip_point_utf16(Unclipped(self.position), Bias::Left); + let mut range_for_token = None; completions .into_iter() @@ -1376,6 +1386,7 @@ impl LspCommand for GetCompletions { edit.new_text.clone(), ) } + // If the language server does not provide a range, then infer // the range based on the syntax tree. None => { @@ -1383,27 +1394,51 @@ impl LspCommand for GetCompletions { log::info!("completion out of expected range"); return None; } - let Range { start, end } = range_for_token - .get_or_insert_with(|| { - let offset = self.position.to_offset(&snapshot); - let (range, kind) = snapshot.surrounding_word(offset); - if kind == Some(CharKind::Word) { - range - } else { - offset..offset - } - }) - .clone(); + + let default_edit_range = response_list + .as_ref() + .and_then(|list| list.item_defaults.as_ref()) + .and_then(|defaults| defaults.edit_range.as_ref()) + .and_then(|range| match range { + CompletionListItemDefaultsEditRange::Range(r) => Some(r), + _ => None, + }); + + let range = if let Some(range) = default_edit_range { + let range = range_from_lsp(range.clone()); + let start = snapshot.clip_point_utf16(range.start, Bias::Left); + let end = snapshot.clip_point_utf16(range.end, Bias::Left); + if start != range.start.0 || end != range.end.0 { + log::info!("completion out of expected range"); + return None; + } + + snapshot.anchor_before(start)..snapshot.anchor_after(end) + } else { + range_for_token + .get_or_insert_with(|| { + let offset = self.position.to_offset(&snapshot); + let (range, kind) = snapshot.surrounding_word(offset); + let range = if kind == Some(CharKind::Word) { + range + } else { + offset..offset + }; + + snapshot.anchor_before(range.start) + ..snapshot.anchor_after(range.end) + }) + .clone() + }; + let text = lsp_completion .insert_text .as_ref() .unwrap_or(&lsp_completion.label) .clone(); - ( - snapshot.anchor_before(start)..snapshot.anchor_after(end), - text, - ) + (range, text) } + Some(lsp::CompletionTextEdit::InsertAndReplace(_)) => { log::info!("unsupported insert/replace completion"); return None; @@ -1427,6 +1462,7 @@ impl LspCommand for GetCompletions { lsp_completion.filter_text.as_deref(), ) }), + server_id, lsp_completion, } }) diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index f839c8d5c5..0690cc9188 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -35,7 +35,7 @@ use language::{ point_to_lsp, proto::{ deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version, - serialize_anchor, serialize_version, + serialize_anchor, serialize_version, split_operations, }, range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction, CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent, @@ -156,6 +156,11 @@ struct DelayedDebounced { cancel_channel: Option>, } +enum LanguageServerToQuery { + Primary, + Other(LanguageServerId), +} + impl DelayedDebounced { fn new() -> DelayedDebounced { DelayedDebounced { @@ -634,7 +639,7 @@ impl Project { cx.observe_global::(Self::on_settings_changed) ], _maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx), - _maintain_workspace_config: Self::maintain_workspace_config(languages.clone(), cx), + _maintain_workspace_config: Self::maintain_workspace_config(cx), active_entry: None, languages, client, @@ -704,7 +709,7 @@ impl Project { collaborators: Default::default(), join_project_response_message_id: response.message_id, _maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx), - _maintain_workspace_config: Self::maintain_workspace_config(languages.clone(), cx), + _maintain_workspace_config: Self::maintain_workspace_config(cx), languages, user_store: user_store.clone(), fs, @@ -2472,35 +2477,42 @@ impl Project { }) } - fn maintain_workspace_config( - languages: Arc, - cx: &mut ModelContext, - ) -> Task<()> { + fn maintain_workspace_config(cx: &mut ModelContext) -> Task<()> { let (mut settings_changed_tx, mut settings_changed_rx) = watch::channel(); let _ = postage::stream::Stream::try_recv(&mut settings_changed_rx); let settings_observation = cx.observe_global::(move |_, _| { *settings_changed_tx.borrow_mut() = (); }); + cx.spawn_weak(|this, mut cx| async move { while let Some(_) = settings_changed_rx.next().await { - let workspace_config = cx.update(|cx| languages.workspace_configuration(cx)).await; - if let Some(this) = this.upgrade(&cx) { - this.read_with(&cx, |this, _| { - for server_state in this.language_servers.values() { - if let LanguageServerState::Running { server, .. } = server_state { - server - .notify::( - lsp::DidChangeConfigurationParams { - settings: workspace_config.clone(), - }, - ) - .ok(); - } - } - }) - } else { + let Some(this) = this.upgrade(&cx) else { break; + }; + + let servers: Vec<_> = this.read_with(&cx, |this, _| { + this.language_servers + .values() + .filter_map(|state| match state { + LanguageServerState::Starting(_) => None, + LanguageServerState::Running { + adapter, server, .. + } => Some((adapter.clone(), server.clone())), + }) + .collect() + }); + + for (adapter, server) in servers { + let workspace_config = + cx.update(|cx| adapter.workspace_configuration(cx)).await; + server + .notify::( + lsp::DidChangeConfigurationParams { + settings: workspace_config.clone(), + }, + ) + .ok(); } } @@ -2615,7 +2627,6 @@ impl Project { let state = LanguageServerState::Starting({ let adapter = adapter.clone(); let server_name = adapter.name.0.clone(); - let languages = self.languages.clone(); let language = language.clone(); let key = key.clone(); @@ -2625,7 +2636,6 @@ impl Project { initialization_options, pending_server, adapter.clone(), - languages, language.clone(), server_id, key, @@ -2729,7 +2739,6 @@ impl Project { initialization_options: Option, pending_server: PendingLanguageServer, adapter: Arc, - languages: Arc, language: Arc, server_id: LanguageServerId, key: (WorktreeId, LanguageServerName), @@ -2740,7 +2749,6 @@ impl Project { initialization_options, pending_server, adapter.clone(), - languages, server_id, cx, ); @@ -2773,16 +2781,13 @@ impl Project { initialization_options: Option, pending_server: PendingLanguageServer, adapter: Arc, - languages: Arc, server_id: LanguageServerId, cx: &mut AsyncAppContext, ) -> Result>> { - let workspace_config = cx.update(|cx| languages.workspace_configuration(cx)).await; + let workspace_config = cx.update(|cx| adapter.workspace_configuration(cx)).await; let language_server = match pending_server.task.await? { - Some(server) => server.initialize(initialization_options).await?, - None => { - return Ok(None); - } + Some(server) => server, + None => return Ok(None), }; language_server @@ -2821,12 +2826,12 @@ impl Project { language_server .on_request::({ - let languages = languages.clone(); + let adapter = adapter.clone(); move |params, mut cx| { - let languages = languages.clone(); + let adapter = adapter.clone(); async move { let workspace_config = - cx.update(|cx| languages.workspace_configuration(cx)).await; + cx.update(|cx| adapter.workspace_configuration(cx)).await; Ok(params .items .into_iter() @@ -2932,6 +2937,8 @@ impl Project { }) .detach(); + let language_server = language_server.initialize(initialization_options).await?; + language_server .notify::( lsp::DidChangeConfigurationParams { @@ -3892,7 +3899,7 @@ impl Project { let file = File::from_dyn(buffer.file())?; let buffer_abs_path = file.as_local().map(|f| f.abs_path(cx)); let server = self - .primary_language_servers_for_buffer(buffer, cx) + .primary_language_server_for_buffer(buffer, cx) .map(|s| s.1.clone()); Some((buffer_handle, buffer_abs_path, server)) }) @@ -4197,7 +4204,12 @@ impl Project { cx: &mut ModelContext, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer.clone(), GetDefinition { position }, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::Primary, + GetDefinition { position }, + cx, + ) } pub fn type_definition( @@ -4207,7 +4219,12 @@ impl Project { cx: &mut ModelContext, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer.clone(), GetTypeDefinition { position }, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::Primary, + GetTypeDefinition { position }, + cx, + ) } pub fn references( @@ -4217,7 +4234,12 @@ impl Project { cx: &mut ModelContext, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer.clone(), GetReferences { position }, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::Primary, + GetReferences { position }, + cx, + ) } pub fn document_highlights( @@ -4227,7 +4249,12 @@ impl Project { cx: &mut ModelContext, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer.clone(), GetDocumentHighlights { position }, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::Primary, + GetDocumentHighlights { position }, + cx, + ) } pub fn symbols(&self, query: &str, cx: &mut ModelContext) -> Task>> { @@ -4455,17 +4482,66 @@ impl Project { cx: &mut ModelContext, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer.clone(), GetHover { position }, cx) + self.request_lsp( + buffer.clone(), + LanguageServerToQuery::Primary, + GetHover { position }, + cx, + ) } - pub fn completions( + pub fn completions( &self, buffer: &ModelHandle, position: T, cx: &mut ModelContext, ) -> Task>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer.clone(), GetCompletions { position }, cx) + if self.is_local() { + let snapshot = buffer.read(cx).snapshot(); + let offset = position.to_offset(&snapshot); + let scope = snapshot.language_scope_at(offset); + + let server_ids: Vec<_> = self + .language_servers_for_buffer(buffer.read(cx), cx) + .filter(|(_, server)| server.capabilities().completion_provider.is_some()) + .filter(|(adapter, _)| { + scope + .as_ref() + .map(|scope| scope.language_allowed(&adapter.name)) + .unwrap_or(true) + }) + .map(|(_, server)| server.server_id()) + .collect(); + + let buffer = buffer.clone(); + cx.spawn(|this, mut cx| async move { + let mut tasks = Vec::with_capacity(server_ids.len()); + this.update(&mut cx, |this, cx| { + for server_id in server_ids { + tasks.push(this.request_lsp( + buffer.clone(), + LanguageServerToQuery::Other(server_id), + GetCompletions { position }, + cx, + )); + } + }); + + let mut completions = Vec::new(); + for task in tasks { + if let Ok(new_completions) = task.await { + completions.extend_from_slice(&new_completions); + } + } + + Ok(completions) + }) + } else if let Some(project_id) = self.remote_id() { + self.send_lsp_proto_request(buffer.clone(), project_id, GetCompletions { position }, cx) + } else { + Task::ready(Ok(Default::default())) + } } pub fn apply_additional_edits_for_completion( @@ -4479,7 +4555,8 @@ impl Project { let buffer_id = buffer.remote_id(); if self.is_local() { - let lang_server = match self.primary_language_servers_for_buffer(buffer, cx) { + let server_id = completion.server_id; + let lang_server = match self.language_server_for_buffer(buffer, server_id, cx) { Some((_, server)) => server.clone(), _ => return Task::ready(Ok(Default::default())), }; @@ -4586,7 +4663,12 @@ impl Project { ) -> Task>> { let buffer = buffer_handle.read(cx); let range = buffer.anchor_before(range.start)..buffer.anchor_before(range.end); - self.request_lsp(buffer_handle.clone(), GetCodeActions { range }, cx) + self.request_lsp( + buffer_handle.clone(), + LanguageServerToQuery::Primary, + GetCodeActions { range }, + cx, + ) } pub fn apply_code_action( @@ -4942,7 +5024,12 @@ impl Project { cx: &mut ModelContext, ) -> Task>>> { let position = position.to_point_utf16(buffer.read(cx)); - self.request_lsp(buffer, PrepareRename { position }, cx) + self.request_lsp( + buffer, + LanguageServerToQuery::Primary, + PrepareRename { position }, + cx, + ) } pub fn perform_rename( @@ -4956,6 +5043,7 @@ impl Project { let position = position.to_point_utf16(buffer.read(cx)); self.request_lsp( buffer, + LanguageServerToQuery::Primary, PerformRename { position, new_name, @@ -4983,6 +5071,7 @@ impl Project { }); self.request_lsp( buffer.clone(), + LanguageServerToQuery::Primary, OnTypeFormatting { position, trigger, @@ -5008,7 +5097,12 @@ impl Project { let lsp_request = InlayHints { range }; if self.is_local() { - let lsp_request_task = self.request_lsp(buffer_handle.clone(), lsp_request, cx); + let lsp_request_task = self.request_lsp( + buffer_handle.clone(), + LanguageServerToQuery::Primary, + lsp_request, + cx, + ); cx.spawn(|_, mut cx| async move { buffer_handle .update(&mut cx, |buffer, _| { @@ -5441,10 +5535,10 @@ impl Project { .await; } - // TODO: Wire this up to allow selecting a server? fn request_lsp( &self, buffer_handle: ModelHandle, + server: LanguageServerToQuery, request: R, cx: &mut ModelContext, ) -> Task> @@ -5453,11 +5547,19 @@ impl Project { { let buffer = buffer_handle.read(cx); if self.is_local() { + let language_server = match server { + LanguageServerToQuery::Primary => { + match self.primary_language_server_for_buffer(buffer, cx) { + Some((_, server)) => Some(Arc::clone(server)), + None => return Task::ready(Ok(Default::default())), + } + } + LanguageServerToQuery::Other(id) => self + .language_server_for_buffer(buffer, id, cx) + .map(|(_, server)| Arc::clone(server)), + }; let file = File::from_dyn(buffer.file()).and_then(File::as_local); - if let Some((file, language_server)) = file.zip( - self.primary_language_servers_for_buffer(buffer, cx) - .map(|(_, server)| server.clone()), - ) { + if let (Some(file), Some(language_server)) = (file, language_server) { let lsp_params = request.to_lsp(&file.abs_path(cx), buffer, &language_server, cx); return cx.spawn(|this, cx| async move { if !request.check_capabilities(language_server.capabilities()) { @@ -5490,31 +5592,40 @@ impl Project { }); } } else if let Some(project_id) = self.remote_id() { - let rpc = self.client.clone(); - let message = request.to_proto(project_id, buffer); - return cx.spawn_weak(|this, cx| async move { - // Ensure the project is still alive by the time the task - // is scheduled. - this.upgrade(&cx) - .ok_or_else(|| anyhow!("project dropped"))?; - - let response = rpc.request(message).await?; - - let this = this - .upgrade(&cx) - .ok_or_else(|| anyhow!("project dropped"))?; - if this.read_with(&cx, |this, _| this.is_read_only()) { - Err(anyhow!("disconnected before completing request")) - } else { - request - .response_from_proto(response, this, buffer_handle, cx) - .await - } - }); + return self.send_lsp_proto_request(buffer_handle, project_id, request, cx); } + Task::ready(Ok(Default::default())) } + fn send_lsp_proto_request( + &self, + buffer: ModelHandle, + project_id: u64, + request: R, + cx: &mut ModelContext<'_, Project>, + ) -> Task::Response>> { + let rpc = self.client.clone(); + let message = request.to_proto(project_id, buffer.read(cx)); + cx.spawn_weak(|this, cx| async move { + // Ensure the project is still alive by the time the task + // is scheduled. + this.upgrade(&cx) + .ok_or_else(|| anyhow!("project dropped"))?; + let response = rpc.request(message).await?; + let this = this + .upgrade(&cx) + .ok_or_else(|| anyhow!("project dropped"))?; + if this.read_with(&cx, |this, _| this.is_read_only()) { + Err(anyhow!("disconnected before completing request")) + } else { + request + .response_from_proto(response, this, buffer, cx) + .await + } + }) + } + fn sort_candidates_and_open_buffers( mut matching_paths_rx: Receiver, cx: &mut ModelContext, @@ -7150,7 +7261,7 @@ impl Project { let buffer_version = buffer_handle.read_with(&cx, |buffer, _| buffer.version()); let response = this .update(&mut cx, |this, cx| { - this.request_lsp(buffer_handle, request, cx) + this.request_lsp(buffer_handle, LanguageServerToQuery::Primary, request, cx) }) .await?; this.update(&mut cx, |this, cx| { @@ -7867,7 +7978,7 @@ impl Project { }) } - fn primary_language_servers_for_buffer( + fn primary_language_server_for_buffer( &self, buffer: &Buffer, cx: &AppContext, @@ -8089,31 +8200,6 @@ impl LspAdapterDelegate for ProjectLspAdapterDelegate { } } -fn split_operations( - mut operations: Vec, -) -> impl Iterator> { - #[cfg(any(test, feature = "test-support"))] - const CHUNK_SIZE: usize = 5; - - #[cfg(not(any(test, feature = "test-support")))] - const CHUNK_SIZE: usize = 100; - - let mut done = false; - std::iter::from_fn(move || { - if done { - return None; - } - - let operations = operations - .drain(..cmp::min(CHUNK_SIZE, operations.len())) - .collect::>(); - if operations.is_empty() { - done = true; - } - Some(operations) - }) -} - fn serialize_symbol(symbol: &Symbol) -> proto::Symbol { proto::Symbol { language_server_name: symbol.language_server_name.0.to_string(), diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 7c5983a0a9..b6adb371e1 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -1,11 +1,11 @@ use crate::{search::PathMatcher, worktree::WorktreeModelHandle, Event, *}; -use fs::{FakeFs, LineEnding, RealFs}; +use fs::{FakeFs, RealFs}; use futures::{future, StreamExt}; use gpui::{executor::Deterministic, test::subscribe, AppContext}; use language::{ language_settings::{AllLanguageSettings, LanguageSettingsContent}, tree_sitter_rust, tree_sitter_typescript, Diagnostic, FakeLspAdapter, LanguageConfig, - OffsetRangeExt, Point, ToPoint, + LineEnding, OffsetRangeExt, Point, ToPoint, }; use lsp::Url; use parking_lot::Mutex; @@ -2272,7 +2272,18 @@ async fn test_completions_without_edit_ranges(cx: &mut gpui::TestAppContext) { }, Some(tree_sitter_typescript::language_typescript()), ); - let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + let mut fake_language_servers = language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + trigger_characters: Some(vec![":".to_string()]), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + })) + .await; let fs = FakeFs::new(cx.background()); fs.insert_tree( @@ -2358,7 +2369,18 @@ async fn test_completions_with_carriage_returns(cx: &mut gpui::TestAppContext) { }, Some(tree_sitter_typescript::language_typescript()), ); - let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await; + let mut fake_language_servers = language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + trigger_characters: Some(vec![":".to_string()]), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + })) + .await; let fs = FakeFs::new(cx.background()); fs.insert_tree( diff --git a/crates/project/src/search.rs b/crates/project/src/search.rs index a3c6583052..6c53d2e934 100644 --- a/crates/project/src/search.rs +++ b/crates/project/src/search.rs @@ -225,15 +225,14 @@ impl SearchQuery { if self.as_str().is_empty() { return Default::default(); } - let language = buffer.language_at(0); + + let range_offset = subrange.as_ref().map(|r| r.start).unwrap_or(0); let rope = if let Some(range) = subrange { buffer.as_rope().slice(range) } else { buffer.as_rope().clone() }; - let kind = |c| char_kind(language, c); - let mut matches = Vec::new(); match self { Self::Text { @@ -249,6 +248,9 @@ impl SearchQuery { let mat = mat.unwrap(); if *whole_word { + let scope = buffer.language_scope_at(range_offset + mat.start()); + let kind = |c| char_kind(&scope, c); + let prev_kind = rope.reversed_chars_at(mat.start()).next().map(kind); let start_kind = kind(rope.chars_at(mat.start()).next().unwrap()); let end_kind = kind(rope.reversed_chars_at(mat.end()).next().unwrap()); diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index e6e0f37cc7..2de3671033 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -8,7 +8,7 @@ use clock::ReplicaId; use collections::{HashMap, HashSet, VecDeque}; use fs::{ repository::{GitFileStatus, GitRepository, RepoPath}, - Fs, LineEnding, + Fs, }; use futures::{ channel::{ @@ -27,7 +27,7 @@ use language::{ deserialize_fingerprint, deserialize_version, serialize_fingerprint, serialize_line_ending, serialize_version, }, - Buffer, DiagnosticEntry, File as _, PointUtf16, Rope, RopeFingerprint, Unclipped, + Buffer, DiagnosticEntry, File as _, LineEnding, PointUtf16, Rope, RopeFingerprint, Unclipped, }; use lsp::LanguageServerId; use parking_lot::Mutex; diff --git a/crates/quick_action_bar/Cargo.toml b/crates/quick_action_bar/Cargo.toml index 6953ac0e02..1f8ec4e92b 100644 --- a/crates/quick_action_bar/Cargo.toml +++ b/crates/quick_action_bar/Cargo.toml @@ -9,6 +9,7 @@ path = "src/quick_action_bar.rs" doctest = false [dependencies] +ai = { path = "../ai" } editor = { path = "../editor" } gpui = { path = "../gpui" } search = { path = "../search" } diff --git a/crates/quick_action_bar/src/quick_action_bar.rs b/crates/quick_action_bar/src/quick_action_bar.rs index 1804c2b1fc..b3d9784f1f 100644 --- a/crates/quick_action_bar/src/quick_action_bar.rs +++ b/crates/quick_action_bar/src/quick_action_bar.rs @@ -1,25 +1,29 @@ +use ai::{assistant::InlineAssist, AssistantPanel}; use editor::Editor; use gpui::{ elements::{Empty, Flex, MouseEventHandler, ParentElement, Svg}, platform::{CursorStyle, MouseButton}, Action, AnyElement, Element, Entity, EventContext, Subscription, View, ViewContext, ViewHandle, + WeakViewHandle, }; use search::{buffer_search, BufferSearchBar}; -use workspace::{item::ItemHandle, ToolbarItemLocation, ToolbarItemView}; +use workspace::{item::ItemHandle, ToolbarItemLocation, ToolbarItemView, Workspace}; pub struct QuickActionBar { buffer_search_bar: ViewHandle, active_item: Option>, _inlay_hints_enabled_subscription: Option, + workspace: WeakViewHandle, } impl QuickActionBar { - pub fn new(buffer_search_bar: ViewHandle) -> Self { + pub fn new(buffer_search_bar: ViewHandle, workspace: &Workspace) -> Self { Self { buffer_search_bar, active_item: None, _inlay_hints_enabled_subscription: None, + workspace: workspace.weak_handle(), } } @@ -88,6 +92,21 @@ impl View for QuickActionBar { )); } + bar.add_child(render_quick_action_bar_button( + 2, + "icons/radix/magic-wand.svg", + false, + ("Inline Assist".into(), Some(Box::new(InlineAssist))), + cx, + move |this, cx| { + if let Some(workspace) = this.workspace.upgrade(cx) { + workspace.update(cx, |workspace, cx| { + AssistantPanel::inline_assist(workspace, &Default::default(), cx); + }); + } + }, + )); + bar.into_any() } } @@ -152,9 +171,10 @@ impl ToolbarItemView for QuickActionBar { cx.notify(); } })); + ToolbarItemLocation::PrimaryRight { flex: None } + } else { + ToolbarItemLocation::Hidden } - - ToolbarItemLocation::PrimaryRight { flex: None } } None => { self.active_item = None; diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 2bfb090bb2..9c764c468e 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -384,6 +384,16 @@ impl<'a> From<&'a str> for Rope { } } +impl<'a> FromIterator<&'a str> for Rope { + fn from_iter>(iter: T) -> Self { + let mut rope = Rope::new(); + for chunk in iter { + rope.push(chunk); + } + rope + } +} + impl From for Rope { fn from(text: String) -> Self { Rope::from(text.as_str()) diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 94d6075ecf..2e96d79f5e 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1,6 +1,8 @@ syntax = "proto3"; package zed.messages; +// Looking for a number? Search "// Current max" + message PeerId { uint32 owner_id = 1; uint32 id = 2; @@ -151,6 +153,9 @@ message Envelope { LeaveChannelBuffer leave_channel_buffer = 134; AddChannelBufferCollaborator add_channel_buffer_collaborator = 135; RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136; + UpdateChannelBufferCollaborator update_channel_buffer_collaborator = 139; + RejoinChannelBuffers rejoin_channel_buffers = 140; + RejoinChannelBuffersResponse rejoin_channel_buffers_response = 141; // Current max } } @@ -430,6 +435,12 @@ message RemoveChannelBufferCollaborator { PeerId peer_id = 2; } +message UpdateChannelBufferCollaborator { + uint64 channel_id = 1; + PeerId old_peer_id = 2; + PeerId new_peer_id = 3; +} + message GetDefinition { uint64 project_id = 1; uint64 buffer_id = 2; @@ -616,6 +627,12 @@ message BufferVersion { repeated VectorClockEntry version = 2; } +message ChannelBufferVersion { + uint64 channel_id = 1; + repeated VectorClockEntry version = 2; + uint64 epoch = 3; +} + enum FormatTrigger { Save = 0; Manual = 1; @@ -657,7 +674,8 @@ message Completion { Anchor old_start = 1; Anchor old_end = 2; string new_text = 3; - bytes lsp_completion = 4; + uint64 server_id = 4; + bytes lsp_completion = 5; } message GetCodeActions { @@ -860,12 +878,12 @@ message ProjectTransaction { } message Transaction { - LocalTimestamp id = 1; - repeated LocalTimestamp edit_ids = 2; + LamportTimestamp id = 1; + repeated LamportTimestamp edit_ids = 2; repeated VectorClockEntry start = 3; } -message LocalTimestamp { +message LamportTimestamp { uint32 replica_id = 1; uint32 value = 2; } @@ -1007,12 +1025,28 @@ message JoinChannelBuffer { uint64 channel_id = 1; } +message RejoinChannelBuffers { + repeated ChannelBufferVersion buffers = 1; +} + +message RejoinChannelBuffersResponse { + repeated RejoinedChannelBuffer buffers = 1; +} + message JoinChannelBufferResponse { uint64 buffer_id = 1; uint32 replica_id = 2; string base_text = 3; repeated Operation operations = 4; repeated Collaborator collaborators = 5; + uint64 epoch = 6; +} + +message RejoinedChannelBuffer { + uint64 channel_id = 1; + repeated VectorClockEntry version = 2; + repeated Operation operations = 3; + repeated Collaborator collaborators = 4; } message LeaveChannelBuffer { @@ -1279,7 +1313,7 @@ message Excerpt { message Anchor { uint32 replica_id = 1; - uint32 local_timestamp = 2; + uint32 timestamp = 2; uint64 offset = 3; Bias bias = 4; optional uint64 buffer_id = 5; @@ -1323,19 +1357,17 @@ message Operation { message Edit { uint32 replica_id = 1; - uint32 local_timestamp = 2; - uint32 lamport_timestamp = 3; - repeated VectorClockEntry version = 4; - repeated Range ranges = 5; - repeated string new_text = 6; + uint32 lamport_timestamp = 2; + repeated VectorClockEntry version = 3; + repeated Range ranges = 4; + repeated string new_text = 5; } message Undo { uint32 replica_id = 1; - uint32 local_timestamp = 2; - uint32 lamport_timestamp = 3; - repeated VectorClockEntry version = 4; - repeated UndoCount counts = 5; + uint32 lamport_timestamp = 2; + repeated VectorClockEntry version = 3; + repeated UndoCount counts = 4; } message UpdateSelections { @@ -1361,7 +1393,7 @@ message UndoMapEntry { message UndoCount { uint32 replica_id = 1; - uint32 local_timestamp = 2; + uint32 lamport_timestamp = 2; uint32 count = 3; } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 2e4dce01e1..f643a8c168 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -229,6 +229,8 @@ messages!( (StartLanguageServer, Foreground), (SynchronizeBuffers, Foreground), (SynchronizeBuffersResponse, Foreground), + (RejoinChannelBuffers, Foreground), + (RejoinChannelBuffersResponse, Foreground), (Test, Foreground), (Unfollow, Foreground), (UnshareProject, Foreground), @@ -257,6 +259,7 @@ messages!( (UpdateChannelBuffer, Foreground), (RemoveChannelBufferCollaborator, Foreground), (AddChannelBufferCollaborator, Foreground), + (UpdateChannelBufferCollaborator, Foreground), ); request_messages!( @@ -319,6 +322,7 @@ request_messages!( (SearchProject, SearchProjectResponse), (ShareProject, ShareProjectResponse), (SynchronizeBuffers, SynchronizeBuffersResponse), + (RejoinChannelBuffers, RejoinChannelBuffersResponse), (Test, Test), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), @@ -386,7 +390,8 @@ entity_messages!( channel_id, UpdateChannelBuffer, RemoveChannelBufferCollaborator, - AddChannelBufferCollaborator + AddChannelBufferCollaborator, + UpdateChannelBufferCollaborator ); const KIB: usize = 1024; diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index bc9dd6f80b..d64cbae929 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -6,4 +6,4 @@ pub use conn::Connection; pub use peer::*; mod macros; -pub const PROTOCOL_VERSION: u32 = 61; +pub const PROTOCOL_VERSION: u32 = 62; diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index e708781eca..78729df936 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -1,6 +1,6 @@ use crate::{ history::SearchHistory, - mode::{next_mode, SearchMode}, + mode::{next_mode, SearchMode, Side}, search_bar::{render_nav_button, render_search_mode_button}, CycleMode, NextHistoryQuery, PreviousHistoryQuery, SearchOptions, SelectAllMatches, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleWholeWord, @@ -156,11 +156,12 @@ impl View for BufferSearchBar { self.query_editor.update(cx, |editor, cx| { editor.set_placeholder_text(new_placeholder_text, cx); }); - let search_button_for_mode = |mode, cx: &mut ViewContext| { + let search_button_for_mode = |mode, side, cx: &mut ViewContext| { let is_active = self.current_mode == mode; render_search_mode_button( mode, + side, is_active, move |_, this, cx| { this.activate_search_mode(mode, cx); @@ -212,20 +213,11 @@ impl View for BufferSearchBar { ) }; - let icon_style = theme.search.editor_icon.clone(); - let nav_column = Flex::row() - .with_child(self.render_action_button("Select All", cx)) - .with_child(nav_button_for_direction("<", Direction::Prev, cx)) - .with_child(nav_button_for_direction(">", Direction::Next, cx)) - .with_child(Flex::row().with_children(match_count)) - .constrained() - .with_height(theme.search.search_bar_row_height); - - let query = Flex::row() + let query_column = Flex::row() .with_child( - Svg::for_style(icon_style.icon) + Svg::for_style(theme.search.editor_icon.clone().icon) .contained() - .with_style(icon_style.container), + .with_style(theme.search.editor_icon.clone().container), ) .with_child(ChildView::new(&self.query_editor, cx).flex(1., true)) .with_child( @@ -244,49 +236,45 @@ impl View for BufferSearchBar { .contained(), ) .align_children_center() - .flex(1., true); - let editor_column = Flex::row() - .with_child( - query - .contained() - .with_style(query_container_style) - .constrained() - .with_min_width(theme.search.editor.min_width) - .with_max_width(theme.search.editor.max_width) - .with_height(theme.search.search_bar_row_height) - .flex(1., false), - ) .contained() + .with_style(query_container_style) .constrained() + .with_min_width(theme.search.editor.min_width) + .with_max_width(theme.search.editor.max_width) .with_height(theme.search.search_bar_row_height) .flex(1., false); + let mode_column = Flex::row() - .with_child( - Flex::row() - .with_child(search_button_for_mode(SearchMode::Text, cx)) - .with_child(search_button_for_mode(SearchMode::Regex, cx)) - .contained() - .with_style(theme.search.modes_container), - ) - .with_child(super::search_bar::render_close_button( - "Dismiss Buffer Search", - &theme.search, + .with_child(search_button_for_mode( + SearchMode::Text, + Some(Side::Left), cx, - |_, this, cx| this.dismiss(&Default::default(), cx), - Some(Box::new(Dismiss)), )) + .with_child(search_button_for_mode( + SearchMode::Regex, + Some(Side::Right), + cx, + )) + .contained() + .with_style(theme.search.modes_container) + .constrained() + .with_height(theme.search.search_bar_row_height); + + let nav_column = Flex::row() + .with_child(self.render_action_button("all", cx)) + .with_child(Flex::row().with_children(match_count)) + .with_child(nav_button_for_direction("<", Direction::Prev, cx)) + .with_child(nav_button_for_direction(">", Direction::Next, cx)) .constrained() .with_height(theme.search.search_bar_row_height) - .aligned() - .right() .flex_float(); + Flex::row() - .with_child(editor_column) - .with_child(nav_column) + .with_child(query_column) .with_child(mode_column) + .with_child(nav_column) .contained() .with_style(theme.search.container) - .aligned() .into_any_named("search bar") } } @@ -340,8 +328,9 @@ impl ToolbarItemView for BufferSearchBar { ToolbarItemLocation::Hidden } } + fn row_count(&self, _: &ViewContext) -> usize { - 2 + 1 } } diff --git a/crates/search/src/mode.rs b/crates/search/src/mode.rs index 2c180be761..8afc2bd3f4 100644 --- a/crates/search/src/mode.rs +++ b/crates/search/src/mode.rs @@ -48,41 +48,18 @@ impl SearchMode { SearchMode::Regex => Box::new(ActivateRegexMode), } } - - pub(crate) fn border_right(&self) -> bool { - match self { - SearchMode::Regex => true, - SearchMode::Text => true, - SearchMode::Semantic => true, - } - } - - pub(crate) fn border_left(&self) -> bool { - match self { - SearchMode::Text => true, - _ => false, - } - } - - pub(crate) fn button_side(&self) -> Option { - match self { - SearchMode::Text => Some(Side::Left), - SearchMode::Semantic => None, - SearchMode::Regex => Some(Side::Right), - } - } } pub(crate) fn next_mode(mode: &SearchMode, semantic_enabled: bool) -> SearchMode { - let next_text_state = if semantic_enabled { - SearchMode::Semantic - } else { - SearchMode::Regex - }; - match mode { - SearchMode::Text => next_text_state, - SearchMode::Semantic => SearchMode::Regex, - SearchMode::Regex => SearchMode::Text, + SearchMode::Text => SearchMode::Regex, + SearchMode::Regex => { + if semantic_enabled { + SearchMode::Semantic + } else { + SearchMode::Text + } + } + SearchMode::Semantic => SearchMode::Text, } } diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 9bebd448a7..7088f394bc 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -1,6 +1,6 @@ use crate::{ history::SearchHistory, - mode::SearchMode, + mode::{SearchMode, Side}, search_bar::{render_nav_button, render_option_button_icon, render_search_mode_button}, ActivateRegexMode, CycleMode, NextHistoryQuery, PreviousHistoryQuery, SearchOptions, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleWholeWord, @@ -1418,8 +1418,13 @@ impl View for ProjectSearchBar { }, cx, ); + let search = _search.read(cx); + let is_semantic_available = SemanticIndex::enabled(cx); let is_semantic_disabled = search.semantic_state.is_none(); + let icon_style = theme.search.editor_icon.clone(); + let is_active = search.active_match_index.is_some(); + let render_option_button_icon = |path, option, cx: &mut ViewContext| { crate::search_bar::render_option_button_icon( self.is_option_enabled(option, cx), @@ -1445,28 +1450,23 @@ impl View for ProjectSearchBar { render_option_button_icon("icons/word_search_12.svg", SearchOptions::WHOLE_WORD, cx) }); - let search = _search.read(cx); - let icon_style = theme.search.editor_icon.clone(); - - // Editor Functionality - let query = Flex::row() - .with_child( - Svg::for_style(icon_style.icon) - .contained() - .with_style(icon_style.container), + let search_button_for_mode = |mode, side, cx: &mut ViewContext| { + let is_active = if let Some(search) = self.active_project_search.as_ref() { + let search = search.read(cx); + search.current_mode == mode + } else { + false + }; + render_search_mode_button( + mode, + side, + is_active, + move |_, this, cx| { + this.activate_search_mode(mode, cx); + }, + cx, ) - .with_child(ChildView::new(&search.query_editor, cx).flex(1., true)) - .with_child( - Flex::row() - .with_child(filter_button) - .with_children(case_sensitive) - .with_children(whole_word) - .flex(1., false) - .constrained() - .contained(), - ) - .align_children_center() - .flex(1., true); + }; let search = _search.read(cx); @@ -1484,50 +1484,6 @@ impl View for ProjectSearchBar { theme.search.include_exclude_editor.input.container }; - let included_files_view = ChildView::new(&search.included_files_editor, cx) - .contained() - .flex(1., true); - let excluded_files_view = ChildView::new(&search.excluded_files_editor, cx) - .contained() - .flex(1., true); - let filters = search.filters_enabled.then(|| { - Flex::row() - .with_child( - included_files_view - .contained() - .with_style(include_container_style) - .constrained() - .with_height(theme.search.search_bar_row_height) - .with_min_width(theme.search.include_exclude_editor.min_width) - .with_max_width(theme.search.include_exclude_editor.max_width), - ) - .with_child( - excluded_files_view - .contained() - .with_style(exclude_container_style) - .constrained() - .with_height(theme.search.search_bar_row_height) - .with_min_width(theme.search.include_exclude_editor.min_width) - .with_max_width(theme.search.include_exclude_editor.max_width), - ) - .contained() - .with_padding_top(theme.workspace.toolbar.container.padding.bottom) - }); - - let editor_column = Flex::column() - .with_child( - query - .contained() - .with_style(query_container_style) - .constrained() - .with_min_width(theme.search.editor.min_width) - .with_max_width(theme.search.editor.max_width) - .with_height(theme.search.search_bar_row_height) - .flex(1., false), - ) - .with_children(filters) - .flex(1., false); - let matches = search.active_match_index.map(|match_ix| { Label::new( format!( @@ -1542,25 +1498,81 @@ impl View for ProjectSearchBar { .aligned() }); - let search_button_for_mode = |mode, cx: &mut ViewContext| { - let is_active = if let Some(search) = self.active_project_search.as_ref() { - let search = search.read(cx); - search.current_mode == mode - } else { - false - }; - render_search_mode_button( - mode, - is_active, - move |_, this, cx| { - this.activate_search_mode(mode, cx); - }, - cx, + let query_column = Flex::column() + .with_spacing(theme.search.search_row_spacing) + .with_child( + Flex::row() + .with_child( + Svg::for_style(icon_style.icon) + .contained() + .with_style(icon_style.container), + ) + .with_child(ChildView::new(&search.query_editor, cx).flex(1., true)) + .with_child( + Flex::row() + .with_child(filter_button) + .with_children(case_sensitive) + .with_children(whole_word) + .flex(1., false) + .constrained() + .contained(), + ) + .align_children_center() + .contained() + .with_style(query_container_style) + .constrained() + .with_min_width(theme.search.editor.min_width) + .with_max_width(theme.search.editor.max_width) + .with_height(theme.search.search_bar_row_height) + .flex(1., false), ) - }; - let is_active = search.active_match_index.is_some(); - let semantic_index = SemanticIndex::enabled(cx) - .then(|| search_button_for_mode(SearchMode::Semantic, cx)); + .with_children(search.filters_enabled.then(|| { + Flex::row() + .with_child( + ChildView::new(&search.included_files_editor, cx) + .contained() + .with_style(include_container_style) + .constrained() + .with_height(theme.search.search_bar_row_height) + .flex(1., true), + ) + .with_child( + ChildView::new(&search.excluded_files_editor, cx) + .contained() + .with_style(exclude_container_style) + .constrained() + .with_height(theme.search.search_bar_row_height) + .flex(1., true), + ) + .constrained() + .with_min_width(theme.search.editor.min_width) + .with_max_width(theme.search.editor.max_width) + .flex(1., false) + })) + .flex(1., false); + + let mode_column = + Flex::row() + .with_child(search_button_for_mode( + SearchMode::Text, + Some(Side::Left), + cx, + )) + .with_child(search_button_for_mode( + SearchMode::Regex, + if is_semantic_available { + None + } else { + Some(Side::Right) + }, + cx, + )) + .with_children(is_semantic_available.then(|| { + search_button_for_mode(SearchMode::Semantic, Some(Side::Right), cx) + })) + .contained() + .with_style(theme.search.modes_container); + let nav_button_for_direction = |label, direction, cx: &mut ViewContext| { render_nav_button( label, @@ -1576,43 +1588,17 @@ impl View for ProjectSearchBar { }; let nav_column = Flex::row() + .with_child(Flex::row().with_children(matches)) .with_child(nav_button_for_direction("<", Direction::Prev, cx)) .with_child(nav_button_for_direction(">", Direction::Next, cx)) - .with_child(Flex::row().with_children(matches)) - .constrained() - .with_height(theme.search.search_bar_row_height); - - let mode_column = Flex::row() - .with_child( - Flex::row() - .with_child(search_button_for_mode(SearchMode::Text, cx)) - .with_children(semantic_index) - .with_child(search_button_for_mode(SearchMode::Regex, cx)) - .contained() - .with_style(theme.search.modes_container), - ) - .with_child(super::search_bar::render_close_button( - "Dismiss Project Search", - &theme.search, - cx, - |_, this, cx| { - if let Some(search) = this.active_project_search.as_mut() { - search.update(cx, |_, cx| cx.emit(ViewEvent::Dismiss)) - } - }, - None, - )) .constrained() .with_height(theme.search.search_bar_row_height) - .aligned() - .right() - .top() .flex_float(); Flex::row() - .with_child(editor_column) - .with_child(nav_column) + .with_child(query_column) .with_child(mode_column) + .with_child(nav_column) .contained() .with_style(theme.search.container) .into_any_named("project search") @@ -1641,7 +1627,7 @@ impl ToolbarItemView for ProjectSearchBar { self.subscription = Some(cx.observe(&search, |_, _, cx| cx.notify())); self.active_project_search = Some(search); ToolbarItemLocation::PrimaryLeft { - flex: Some((1., false)), + flex: Some((1., true)), } } else { ToolbarItemLocation::Hidden @@ -1649,13 +1635,12 @@ impl ToolbarItemView for ProjectSearchBar { } fn row_count(&self, cx: &ViewContext) -> usize { - self.active_project_search - .as_ref() - .map(|search| { - let offset = search.read(cx).filters_enabled as usize; - 2 + offset - }) - .unwrap_or_else(|| 2) + if let Some(search) = self.active_project_search.as_ref() { + if search.read(cx).filters_enabled { + return 2; + } + } + 1 } } diff --git a/crates/search/src/search_bar.rs b/crates/search/src/search_bar.rs index 7d3c5261ea..d1a5a0380a 100644 --- a/crates/search/src/search_bar.rs +++ b/crates/search/src/search_bar.rs @@ -13,34 +13,6 @@ use crate::{ SelectNextMatch, SelectPrevMatch, }; -pub(super) fn render_close_button( - tooltip: &'static str, - theme: &theme::Search, - cx: &mut ViewContext, - on_click: impl Fn(MouseClick, &mut V, &mut EventContext) + 'static, - dismiss_action: Option>, -) -> AnyElement { - let tooltip_style = theme::current(cx).tooltip.clone(); - - enum CloseButton {} - MouseEventHandler::new::(0, cx, |state, _| { - let style = theme.dismiss_button.style_for(state); - Svg::new("icons/x_mark_8.svg") - .with_color(style.color) - .constrained() - .with_width(style.icon_width) - .aligned() - .contained() - .with_style(style.container) - .constrained() - .with_height(theme.search_bar_row_height) - }) - .on_click(MouseButton::Left, on_click) - .with_cursor_style(CursorStyle::PointingHand) - .with_tooltip::(0, tooltip.to_string(), dismiss_action, tooltip_style, cx) - .into_any() -} - pub(super) fn render_nav_button( icon: &'static str, direction: Direction, @@ -111,6 +83,7 @@ pub(super) fn render_nav_button( pub(crate) fn render_search_mode_button( mode: SearchMode, + side: Option, is_active: bool, on_click: impl Fn(MouseClick, &mut V, &mut EventContext) + 'static, cx: &mut ViewContext, @@ -119,41 +92,41 @@ pub(crate) fn render_search_mode_button( enum SearchModeButton {} MouseEventHandler::new::(mode.region_id(), cx, |state, cx| { let theme = theme::current(cx); - let mut style = theme + let style = theme .search .mode_button .in_state(is_active) .style_for(state) .clone(); - style.container.border.left = mode.border_left(); - style.container.border.right = mode.border_right(); - let label = Label::new(mode.label(), style.text.clone()) - .aligned() - .contained(); - let mut container_style = style.container.clone(); - if let Some(button_side) = mode.button_side() { + let mut container_style = style.container; + if let Some(button_side) = side { if button_side == Side::Left { + container_style.border.left = true; container_style.corner_radii = CornerRadii { bottom_right: 0., top_right: 0., ..container_style.corner_radii }; - label.with_style(container_style) } else { + container_style.border.left = false; container_style.corner_radii = CornerRadii { bottom_left: 0., top_left: 0., ..container_style.corner_radii }; - label.with_style(container_style) } } else { + container_style.border.left = false; container_style.corner_radii = CornerRadii::default(); - label.with_style(container_style) } - .constrained() - .with_height(theme.search.search_bar_row_height) + + Label::new(mode.label(), style.text) + .aligned() + .contained() + .with_style(container_style) + .constrained() + .with_height(theme.search.search_bar_row_height) }) .on_click(MouseButton::Left, on_click) .with_cursor_style(CursorStyle::PointingHand) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 97c25ca170..7228738525 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -181,18 +181,17 @@ impl EmbeddingProvider for OpenAIEmbeddings { fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let token_count = tokens.len(); - let output = if token_count > OPENAI_INPUT_LIMIT { + let output = if tokens.len() > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER - .decode(tokens) + .decode(tokens.clone()) .ok() .unwrap_or_else(|| span.to_string()) } else { span.to_string() }; - (output, token_count) + (output, tokens.len()) } async fn embed_batch(&self, spans: Vec) -> Result> { @@ -204,7 +203,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; - let mut request_timeout: u64 = 10; + let mut request_timeout: u64 = 15; let mut response: Response; while request_number < MAX_RETRIES { response = self diff --git a/crates/text/Cargo.toml b/crates/text/Cargo.toml index adec2d4fd0..d1bc6cc8f8 100644 --- a/crates/text/Cargo.toml +++ b/crates/text/Cargo.toml @@ -14,7 +14,6 @@ test-support = ["rand"] [dependencies] clock = { path = "../clock" } collections = { path = "../collections" } -fs = { path = "../fs" } rope = { path = "../rope" } sum_tree = { path = "../sum_tree" } util = { path = "../util" } @@ -32,6 +31,7 @@ regex.workspace = true [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } +util = { path = "../util", features = ["test-support"] } ctor.workspace = true env_logger.workspace = true rand.workspace = true diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs index b5f4fb24ec..084be0e336 100644 --- a/crates/text/src/anchor.rs +++ b/crates/text/src/anchor.rs @@ -8,7 +8,7 @@ use sum_tree::Bias; #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash, Default)] pub struct Anchor { - pub timestamp: clock::Local, + pub timestamp: clock::Lamport, pub offset: usize, pub bias: Bias, pub buffer_id: Option, @@ -16,14 +16,14 @@ pub struct Anchor { impl Anchor { pub const MIN: Self = Self { - timestamp: clock::Local::MIN, + timestamp: clock::Lamport::MIN, offset: usize::MIN, bias: Bias::Left, buffer_id: None, }; pub const MAX: Self = Self { - timestamp: clock::Local::MAX, + timestamp: clock::Lamport::MAX, offset: usize::MAX, bias: Bias::Right, buffer_id: None, diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index 4a97faf015..c05ea1109c 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -14,16 +14,17 @@ pub use anchor::*; use anyhow::{anyhow, Result}; pub use clock::ReplicaId; use collections::{HashMap, HashSet}; -use fs::LineEnding; use locator::Locator; use operation_queue::OperationQueue; pub use patch::Patch; use postage::{oneshot, prelude::*}; +use lazy_static::lazy_static; +use regex::Regex; pub use rope::*; pub use selection::*; - use std::{ + borrow::Cow, cmp::{self, Ordering, Reverse}, future::Future, iter::Iterator, @@ -36,22 +37,25 @@ pub use subscription::*; pub use sum_tree::Bias; use sum_tree::{FilterCursor, SumTree, TreeMap}; use undo_map::UndoMap; +use util::ResultExt; #[cfg(any(test, feature = "test-support"))] use util::RandomCharIter; -pub type TransactionId = clock::Local; +lazy_static! { + static ref LINE_SEPARATORS_REGEX: Regex = Regex::new("\r\n|\r|\u{2028}|\u{2029}").unwrap(); +} + +pub type TransactionId = clock::Lamport; pub struct Buffer { snapshot: BufferSnapshot, history: History, deferred_ops: OperationQueue, deferred_replicas: HashSet, - replica_id: ReplicaId, - local_clock: clock::Local, pub lamport_clock: clock::Lamport, subscriptions: Topic, - edit_id_resolvers: HashMap>>, + edit_id_resolvers: HashMap>>, wait_for_version_txs: Vec<(clock::Global, oneshot::Sender<()>)>, } @@ -79,7 +83,7 @@ pub struct HistoryEntry { #[derive(Clone, Debug)] pub struct Transaction { pub id: TransactionId, - pub edit_ids: Vec, + pub edit_ids: Vec, pub start: clock::Global, } @@ -91,8 +95,8 @@ impl HistoryEntry { struct History { base_text: Rope, - operations: TreeMap, - insertion_slices: HashMap>, + operations: TreeMap, + insertion_slices: HashMap>, undo_stack: Vec, redo_stack: Vec, transaction_depth: usize, @@ -101,7 +105,7 @@ struct History { #[derive(Clone, Debug)] struct InsertionSlice { - insertion_id: clock::Local, + insertion_id: clock::Lamport, range: Range, } @@ -123,18 +127,18 @@ impl History { } fn push(&mut self, op: Operation) { - self.operations.insert(op.local_timestamp(), op); + self.operations.insert(op.timestamp(), op); } fn start_transaction( &mut self, start: clock::Global, now: Instant, - local_clock: &mut clock::Local, + clock: &mut clock::Lamport, ) -> Option { self.transaction_depth += 1; if self.transaction_depth == 1 { - let id = local_clock.tick(); + let id = clock.tick(); self.undo_stack.push(HistoryEntry { transaction: Transaction { id, @@ -245,7 +249,7 @@ impl History { self.redo_stack.clear(); } - fn push_undo(&mut self, op_id: clock::Local) { + fn push_undo(&mut self, op_id: clock::Lamport) { assert_ne!(self.transaction_depth, 0); if let Some(Operation::Edit(_)) = self.operations.get(&op_id) { let last_transaction = self.undo_stack.last_mut().unwrap(); @@ -263,7 +267,19 @@ impl History { } } - fn remove_from_undo(&mut self, transaction_id: TransactionId) -> &[HistoryEntry] { + fn remove_from_undo(&mut self, transaction_id: TransactionId) -> Option<&HistoryEntry> { + assert_eq!(self.transaction_depth, 0); + + let entry_ix = self + .undo_stack + .iter() + .rposition(|entry| entry.transaction.id == transaction_id)?; + let entry = self.undo_stack.remove(entry_ix); + self.redo_stack.push(entry); + self.redo_stack.last() + } + + fn remove_from_undo_until(&mut self, transaction_id: TransactionId) -> &[HistoryEntry] { assert_eq!(self.transaction_depth, 0); let redo_stack_start_len = self.redo_stack.len(); @@ -278,20 +294,43 @@ impl History { &self.redo_stack[redo_stack_start_len..] } - fn forget(&mut self, transaction_id: TransactionId) { + fn forget(&mut self, transaction_id: TransactionId) -> Option { assert_eq!(self.transaction_depth, 0); if let Some(entry_ix) = self .undo_stack .iter() .rposition(|entry| entry.transaction.id == transaction_id) { - self.undo_stack.remove(entry_ix); + Some(self.undo_stack.remove(entry_ix).transaction) } else if let Some(entry_ix) = self .redo_stack .iter() .rposition(|entry| entry.transaction.id == transaction_id) { - self.undo_stack.remove(entry_ix); + Some(self.redo_stack.remove(entry_ix).transaction) + } else { + None + } + } + + fn transaction_mut(&mut self, transaction_id: TransactionId) -> Option<&mut Transaction> { + let entry = self + .undo_stack + .iter_mut() + .rfind(|entry| entry.transaction.id == transaction_id) + .or_else(|| { + self.redo_stack + .iter_mut() + .rfind(|entry| entry.transaction.id == transaction_id) + })?; + Some(&mut entry.transaction) + } + + fn merge_transactions(&mut self, transaction: TransactionId, destination: TransactionId) { + if let Some(transaction) = self.forget(transaction) { + if let Some(destination) = self.transaction_mut(destination) { + destination.edit_ids.extend(transaction.edit_ids); + } } } @@ -371,37 +410,14 @@ impl Edit<(D1, D2)> { } } -#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] -pub struct InsertionTimestamp { - pub replica_id: ReplicaId, - pub local: clock::Seq, - pub lamport: clock::Seq, -} - -impl InsertionTimestamp { - pub fn local(&self) -> clock::Local { - clock::Local { - replica_id: self.replica_id, - value: self.local, - } - } - - pub fn lamport(&self) -> clock::Lamport { - clock::Lamport { - replica_id: self.replica_id, - value: self.lamport, - } - } -} - #[derive(Eq, PartialEq, Clone, Debug)] pub struct Fragment { pub id: Locator, - pub insertion_timestamp: InsertionTimestamp, + pub timestamp: clock::Lamport, pub insertion_offset: usize, pub len: usize, pub visible: bool, - pub deletions: HashSet, + pub deletions: HashSet, pub max_undos: clock::Global, } @@ -429,29 +445,26 @@ impl<'a> sum_tree::Dimension<'a, FragmentSummary> for FragmentTextSummary { #[derive(Eq, PartialEq, Clone, Debug)] struct InsertionFragment { - timestamp: clock::Local, + timestamp: clock::Lamport, split_offset: usize, fragment_id: Locator, } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] struct InsertionFragmentKey { - timestamp: clock::Local, + timestamp: clock::Lamport, split_offset: usize, } #[derive(Clone, Debug, Eq, PartialEq)] pub enum Operation { Edit(EditOperation), - Undo { - undo: UndoOperation, - lamport_timestamp: clock::Lamport, - }, + Undo(UndoOperation), } #[derive(Clone, Debug, Eq, PartialEq)] pub struct EditOperation { - pub timestamp: InsertionTimestamp, + pub timestamp: clock::Lamport, pub version: clock::Global, pub ranges: Vec>, pub new_text: Vec>, @@ -459,9 +472,9 @@ pub struct EditOperation { #[derive(Clone, Debug, Eq, PartialEq)] pub struct UndoOperation { - pub id: clock::Local, - pub counts: HashMap, + pub timestamp: clock::Lamport, pub version: clock::Global, + pub counts: HashMap, } impl Buffer { @@ -473,24 +486,21 @@ impl Buffer { let mut fragments = SumTree::new(); let mut insertions = SumTree::new(); - let mut local_clock = clock::Local::new(replica_id); let mut lamport_clock = clock::Lamport::new(replica_id); let mut version = clock::Global::new(); let visible_text = history.base_text.clone(); if !visible_text.is_empty() { - let insertion_timestamp = InsertionTimestamp { + let insertion_timestamp = clock::Lamport { replica_id: 0, - local: 1, - lamport: 1, + value: 1, }; - local_clock.observe(insertion_timestamp.local()); - lamport_clock.observe(insertion_timestamp.lamport()); - version.observe(insertion_timestamp.local()); + lamport_clock.observe(insertion_timestamp); + version.observe(insertion_timestamp); let fragment_id = Locator::between(&Locator::min(), &Locator::max()); let fragment = Fragment { id: fragment_id, - insertion_timestamp, + timestamp: insertion_timestamp, insertion_offset: 0, len: visible_text.len(), visible: true, @@ -516,8 +526,6 @@ impl Buffer { history, deferred_ops: OperationQueue::new(), deferred_replicas: HashSet::default(), - replica_id, - local_clock, lamport_clock, subscriptions: Default::default(), edit_id_resolvers: Default::default(), @@ -534,7 +542,7 @@ impl Buffer { } pub fn replica_id(&self) -> ReplicaId { - self.local_clock.replica_id + self.lamport_clock.replica_id } pub fn remote_id(&self) -> u64 { @@ -561,16 +569,12 @@ impl Buffer { .map(|(range, new_text)| (range, new_text.into())); self.start_transaction(); - let timestamp = InsertionTimestamp { - replica_id: self.replica_id, - local: self.local_clock.tick().value, - lamport: self.lamport_clock.tick().value, - }; + let timestamp = self.lamport_clock.tick(); let operation = Operation::Edit(self.apply_local_edit(edits, timestamp)); self.history.push(operation.clone()); - self.history.push_undo(operation.local_timestamp()); - self.snapshot.version.observe(operation.local_timestamp()); + self.history.push_undo(operation.timestamp()); + self.snapshot.version.observe(operation.timestamp()); self.end_transaction(); operation } @@ -578,7 +582,7 @@ impl Buffer { fn apply_local_edit>>( &mut self, edits: impl ExactSizeIterator, T)>, - timestamp: InsertionTimestamp, + timestamp: clock::Lamport, ) -> EditOperation { let mut edits_patch = Patch::default(); let mut edit_op = EditOperation { @@ -655,7 +659,7 @@ impl Buffer { .item() .map_or(&Locator::max(), |old_fragment| &old_fragment.id), ), - insertion_timestamp: timestamp, + timestamp, insertion_offset, len: new_text.len(), deletions: Default::default(), @@ -685,7 +689,7 @@ impl Buffer { intersection.insertion_offset += fragment_start - old_fragments.start().visible; intersection.id = Locator::between(&new_fragments.summary().max_id, &intersection.id); - intersection.deletions.insert(timestamp.local()); + intersection.deletions.insert(timestamp); intersection.visible = false; } if intersection.len > 0 { @@ -740,7 +744,7 @@ impl Buffer { self.subscriptions.publish_mut(&edits_patch); self.history .insertion_slices - .insert(timestamp.local(), insertion_slices); + .insert(timestamp, insertion_slices); edit_op } @@ -767,28 +771,23 @@ impl Buffer { fn apply_op(&mut self, op: Operation) -> Result<()> { match op { Operation::Edit(edit) => { - if !self.version.observed(edit.timestamp.local()) { + if !self.version.observed(edit.timestamp) { self.apply_remote_edit( &edit.version, &edit.ranges, &edit.new_text, edit.timestamp, ); - self.snapshot.version.observe(edit.timestamp.local()); - self.local_clock.observe(edit.timestamp.local()); - self.lamport_clock.observe(edit.timestamp.lamport()); - self.resolve_edit(edit.timestamp.local()); + self.snapshot.version.observe(edit.timestamp); + self.lamport_clock.observe(edit.timestamp); + self.resolve_edit(edit.timestamp); } } - Operation::Undo { - undo, - lamport_timestamp, - } => { - if !self.version.observed(undo.id) { + Operation::Undo(undo) => { + if !self.version.observed(undo.timestamp) { self.apply_undo(&undo)?; - self.snapshot.version.observe(undo.id); - self.local_clock.observe(undo.id); - self.lamport_clock.observe(lamport_timestamp); + self.snapshot.version.observe(undo.timestamp); + self.lamport_clock.observe(undo.timestamp); } } } @@ -808,7 +807,7 @@ impl Buffer { version: &clock::Global, ranges: &[Range], new_text: &[Arc], - timestamp: InsertionTimestamp, + timestamp: clock::Lamport, ) { if ranges.is_empty() { return; @@ -875,9 +874,7 @@ impl Buffer { // Skip over insertions that are concurrent to this edit, but have a lower lamport // timestamp. while let Some(fragment) = old_fragments.item() { - if fragment_start == range.start - && fragment.insertion_timestamp.lamport() > timestamp.lamport() - { + if fragment_start == range.start && fragment.timestamp > timestamp { new_ropes.push_fragment(fragment, fragment.visible); new_fragments.push(fragment.clone(), &None); old_fragments.next(&cx); @@ -914,7 +911,7 @@ impl Buffer { .item() .map_or(&Locator::max(), |old_fragment| &old_fragment.id), ), - insertion_timestamp: timestamp, + timestamp, insertion_offset, len: new_text.len(), deletions: Default::default(), @@ -945,7 +942,7 @@ impl Buffer { fragment_start - old_fragments.start().0.full_offset(); intersection.id = Locator::between(&new_fragments.summary().max_id, &intersection.id); - intersection.deletions.insert(timestamp.local()); + intersection.deletions.insert(timestamp); intersection.visible = false; insertion_slices.push(intersection.insertion_slice()); } @@ -997,13 +994,13 @@ impl Buffer { self.snapshot.insertions.edit(new_insertions, &()); self.history .insertion_slices - .insert(timestamp.local(), insertion_slices); + .insert(timestamp, insertion_slices); self.subscriptions.publish_mut(&edits_patch) } fn fragment_ids_for_edits<'a>( &'a self, - edit_ids: impl Iterator, + edit_ids: impl Iterator, ) -> Vec<&'a Locator> { // Get all of the insertion slices changed by the given edits. let mut insertion_slices = Vec::new(); @@ -1064,7 +1061,7 @@ impl Buffer { let fragment_was_visible = fragment.visible; fragment.visible = fragment.is_visible(&self.undo_map); - fragment.max_undos.observe(undo.id); + fragment.max_undos.observe(undo.timestamp); let old_start = old_fragments.start().1; let new_start = new_fragments.summary().text.visible; @@ -1118,10 +1115,10 @@ impl Buffer { if self.deferred_replicas.contains(&op.replica_id()) { false } else { - match op { - Operation::Edit(edit) => self.version.observed_all(&edit.version), - Operation::Undo { undo, .. } => self.version.observed_all(&undo.version), - } + self.version.observed_all(match op { + Operation::Edit(edit) => &edit.version, + Operation::Undo(undo) => &undo.version, + }) } } @@ -1139,7 +1136,7 @@ impl Buffer { pub fn start_transaction_at(&mut self, now: Instant) -> Option { self.history - .start_transaction(self.version.clone(), now, &mut self.local_clock) + .start_transaction(self.version.clone(), now, &mut self.lamport_clock) } pub fn end_transaction(&mut self) -> Option<(TransactionId, clock::Global)> { @@ -1168,7 +1165,7 @@ impl Buffer { &self.history.base_text } - pub fn operations(&self) -> &TreeMap { + pub fn operations(&self) -> &TreeMap { &self.history.operations } @@ -1183,11 +1180,20 @@ impl Buffer { } } + pub fn undo_transaction(&mut self, transaction_id: TransactionId) -> Option { + let transaction = self + .history + .remove_from_undo(transaction_id)? + .transaction + .clone(); + self.undo_or_redo(transaction).log_err() + } + #[allow(clippy::needless_collect)] pub fn undo_to_transaction(&mut self, transaction_id: TransactionId) -> Vec { let transactions = self .history - .remove_from_undo(transaction_id) + .remove_from_undo_until(transaction_id) .iter() .map(|entry| entry.transaction.clone()) .collect::>(); @@ -1202,6 +1208,10 @@ impl Buffer { self.history.forget(transaction_id); } + pub fn merge_transactions(&mut self, transaction: TransactionId, destination: TransactionId) { + self.history.merge_transactions(transaction, destination); + } + pub fn redo(&mut self) -> Option<(TransactionId, Operation)> { if let Some(entry) = self.history.pop_redo() { let transaction = entry.transaction.clone(); @@ -1235,16 +1245,13 @@ impl Buffer { } let undo = UndoOperation { - id: self.local_clock.tick(), + timestamp: self.lamport_clock.tick(), version: self.version(), counts, }; self.apply_undo(&undo)?; - let operation = Operation::Undo { - undo, - lamport_timestamp: self.lamport_clock.tick(), - }; - self.snapshot.version.observe(operation.local_timestamp()); + self.snapshot.version.observe(undo.timestamp); + let operation = Operation::Undo(undo); self.history.push(operation.clone()); Ok(operation) } @@ -1309,7 +1316,7 @@ impl Buffer { pub fn wait_for_edits( &mut self, - edit_ids: impl IntoIterator, + edit_ids: impl IntoIterator, ) -> impl 'static + Future> { let mut futures = Vec::new(); for edit_id in edit_ids { @@ -1381,7 +1388,7 @@ impl Buffer { self.wait_for_version_txs.clear(); } - fn resolve_edit(&mut self, edit_id: clock::Local) { + fn resolve_edit(&mut self, edit_id: clock::Lamport) { for mut tx in self .edit_id_resolvers .remove(&edit_id) @@ -1459,7 +1466,7 @@ impl Buffer { .insertions .get( &InsertionFragmentKey { - timestamp: fragment.insertion_timestamp.local(), + timestamp: fragment.timestamp, split_offset: fragment.insertion_offset, }, &(), @@ -1942,7 +1949,7 @@ impl BufferSnapshot { let fragment = fragment_cursor.item().unwrap(); let overshoot = offset - *fragment_cursor.start(); Anchor { - timestamp: fragment.insertion_timestamp.local(), + timestamp: fragment.timestamp, offset: fragment.insertion_offset + overshoot, bias, buffer_id: Some(self.remote_id), @@ -2134,15 +2141,14 @@ impl<'a, D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator fo break; } - let timestamp = fragment.insertion_timestamp.local(); let start_anchor = Anchor { - timestamp, + timestamp: fragment.timestamp, offset: fragment.insertion_offset, bias: Bias::Right, buffer_id: Some(self.buffer_id), }; let end_anchor = Anchor { - timestamp, + timestamp: fragment.timestamp, offset: fragment.insertion_offset + fragment.len, bias: Bias::Left, buffer_id: Some(self.buffer_id), @@ -2215,19 +2221,17 @@ impl<'a, D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator fo impl Fragment { fn insertion_slice(&self) -> InsertionSlice { InsertionSlice { - insertion_id: self.insertion_timestamp.local(), + insertion_id: self.timestamp, range: self.insertion_offset..self.insertion_offset + self.len, } } fn is_visible(&self, undos: &UndoMap) -> bool { - !undos.is_undone(self.insertion_timestamp.local()) - && self.deletions.iter().all(|d| undos.is_undone(*d)) + !undos.is_undone(self.timestamp) && self.deletions.iter().all(|d| undos.is_undone(*d)) } fn was_visible(&self, version: &clock::Global, undos: &UndoMap) -> bool { - (version.observed(self.insertion_timestamp.local()) - && !undos.was_undone(self.insertion_timestamp.local(), version)) + (version.observed(self.timestamp) && !undos.was_undone(self.timestamp, version)) && self .deletions .iter() @@ -2240,14 +2244,14 @@ impl sum_tree::Item for Fragment { fn summary(&self) -> Self::Summary { let mut max_version = clock::Global::new(); - max_version.observe(self.insertion_timestamp.local()); + max_version.observe(self.timestamp); for deletion in &self.deletions { max_version.observe(*deletion); } max_version.join(&self.max_undos); let mut min_insertion_version = clock::Global::new(); - min_insertion_version.observe(self.insertion_timestamp.local()); + min_insertion_version.observe(self.timestamp); let max_insertion_version = min_insertion_version.clone(); if self.visible { FragmentSummary { @@ -2324,7 +2328,7 @@ impl sum_tree::KeyedItem for InsertionFragment { impl InsertionFragment { fn new(fragment: &Fragment) -> Self { Self { - timestamp: fragment.insertion_timestamp.local(), + timestamp: fragment.timestamp, split_offset: fragment.insertion_offset, fragment_id: fragment.id.clone(), } @@ -2447,10 +2451,10 @@ impl Operation { operation_queue::Operation::lamport_timestamp(self).replica_id } - pub fn local_timestamp(&self) -> clock::Local { + pub fn timestamp(&self) -> clock::Lamport { match self { - Operation::Edit(edit) => edit.timestamp.local(), - Operation::Undo { undo, .. } => undo.id, + Operation::Edit(edit) => edit.timestamp, + Operation::Undo(undo) => undo.timestamp, } } @@ -2469,10 +2473,8 @@ impl Operation { impl operation_queue::Operation for Operation { fn lamport_timestamp(&self) -> clock::Lamport { match self { - Operation::Edit(edit) => edit.timestamp.lamport(), - Operation::Undo { - lamport_timestamp, .. - } => *lamport_timestamp, + Operation::Edit(edit) => edit.timestamp, + Operation::Undo(undo) => undo.timestamp, } } } @@ -2622,3 +2624,59 @@ impl FromAnchor for usize { snapshot.summary_for_anchor(anchor) } } + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum LineEnding { + Unix, + Windows, +} + +impl Default for LineEnding { + fn default() -> Self { + #[cfg(unix)] + return Self::Unix; + + #[cfg(not(unix))] + return Self::CRLF; + } +} + +impl LineEnding { + pub fn as_str(&self) -> &'static str { + match self { + LineEnding::Unix => "\n", + LineEnding::Windows => "\r\n", + } + } + + pub fn detect(text: &str) -> Self { + let mut max_ix = cmp::min(text.len(), 1000); + while !text.is_char_boundary(max_ix) { + max_ix -= 1; + } + + if let Some(ix) = text[..max_ix].find(&['\n']) { + if ix > 0 && text.as_bytes()[ix - 1] == b'\r' { + Self::Windows + } else { + Self::Unix + } + } else { + Self::default() + } + } + + pub fn normalize(text: &mut String) { + if let Cow::Owned(replaced) = LINE_SEPARATORS_REGEX.replace_all(text, "\n") { + *text = replaced; + } + } + + pub fn normalize_arc(text: Arc) -> Arc { + if let Cow::Owned(replaced) = LINE_SEPARATORS_REGEX.replace_all(&text, "\n") { + replaced.into() + } else { + text + } + } +} diff --git a/crates/text/src/undo_map.rs b/crates/text/src/undo_map.rs index ff1b241e73..f95809c02e 100644 --- a/crates/text/src/undo_map.rs +++ b/crates/text/src/undo_map.rs @@ -26,8 +26,8 @@ impl sum_tree::KeyedItem for UndoMapEntry { #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] struct UndoMapKey { - edit_id: clock::Local, - undo_id: clock::Local, + edit_id: clock::Lamport, + undo_id: clock::Lamport, } impl sum_tree::Summary for UndoMapKey { @@ -50,7 +50,7 @@ impl UndoMap { sum_tree::Edit::Insert(UndoMapEntry { key: UndoMapKey { edit_id: *edit_id, - undo_id: undo.id, + undo_id: undo.timestamp, }, undo_count: *count, }) @@ -59,11 +59,11 @@ impl UndoMap { self.0.edit(edits, &()); } - pub fn is_undone(&self, edit_id: clock::Local) -> bool { + pub fn is_undone(&self, edit_id: clock::Lamport) -> bool { self.undo_count(edit_id) % 2 == 1 } - pub fn was_undone(&self, edit_id: clock::Local, version: &clock::Global) -> bool { + pub fn was_undone(&self, edit_id: clock::Lamport, version: &clock::Global) -> bool { let mut cursor = self.0.cursor::(); cursor.seek( &UndoMapKey { @@ -88,7 +88,7 @@ impl UndoMap { undo_count % 2 == 1 } - pub fn undo_count(&self, edit_id: clock::Local) -> u32 { + pub fn undo_count(&self, edit_id: clock::Lamport) -> u32 { let mut cursor = self.0.cursor::(); cursor.seek( &UndoMapKey { diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index 9005fc9757..a542249788 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -88,8 +88,6 @@ pub struct Workspace { pub dock: Dock, pub status_bar: StatusBar, pub toolbar: Toolbar, - pub breadcrumb_height: f32, - pub breadcrumbs: Interactive, pub disconnected_overlay: ContainedText, pub modal: ContainerStyle, pub zoomed_panel_foreground: ContainerStyle, @@ -120,7 +118,6 @@ pub struct Titlebar { pub height: f32, pub menu: TitlebarMenu, pub project_menu_button: Toggleable>, - pub project_name_divider: ContainedText, pub git_menu_button: Toggleable>, pub item_spacing: f32, pub face_pile_spacing: f32, @@ -411,6 +408,9 @@ pub struct Toolbar { pub height: f32, pub item_spacing: f32, pub toggleable_tool: Toggleable>, + pub toggleable_text_tool: Toggleable>, + pub breadcrumb_height: f32, + pub breadcrumbs: Interactive, } #[derive(Clone, Deserialize, Default, JsonSchema)] @@ -437,11 +437,11 @@ pub struct Search { pub match_index: ContainedText, pub major_results_status: TextStyle, pub minor_results_status: TextStyle, - pub dismiss_button: Interactive, pub editor_icon: IconStyle, pub mode_button: Toggleable>, pub nav_button: Toggleable>, pub search_bar_row_height: f32, + pub search_row_spacing: f32, pub option_button_height: f32, pub modes_container: ContainerStyle, } @@ -835,6 +835,9 @@ pub struct AutocompleteStyle { pub selected_item: ContainerStyle, pub hovered_item: ContainerStyle, pub match_highlight: HighlightStyle, + pub server_name_container: ContainerStyle, + pub server_name_color: Color, + pub server_name_size_percent: f32, } #[derive(Clone, Copy, Default, Deserialize, JsonSchema)] @@ -1150,6 +1153,17 @@ pub struct AssistantStyle { pub api_key_editor: FieldEditor, pub api_key_prompt: ContainedText, pub saved_conversation: SavedConversation, + pub inline: InlineAssistantStyle, +} + +#[derive(Clone, Deserialize, Default, JsonSchema)] +pub struct InlineAssistantStyle { + #[serde(flatten)] + pub container: ContainerStyle, + pub editor: FieldEditor, + pub disabled_editor: FieldEditor, + pub pending_edit_background: Color, + pub include_conversation: ToggleIconButtonStyle, } #[derive(Clone, Deserialize, Default, JsonSchema)] diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index 0d3fb700ef..16bccb6963 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -3,7 +3,8 @@ use std::{cmp, sync::Arc}; use editor::{ char_kind, display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint}, - movement, Bias, CharKind, DisplayPoint, ToOffset, + movement::{self, FindRange}, + Bias, CharKind, DisplayPoint, ToOffset, }; use gpui::{actions, impl_actions, AppContext, WindowContext}; use language::{Point, Selection, SelectionGoal}; @@ -589,12 +590,12 @@ pub(crate) fn next_word_start( ignore_punctuation: bool, times: usize, ) -> DisplayPoint { - let language = map.buffer_snapshot.language_at(point.to_point(map)); + let scope = map.buffer_snapshot.language_scope_at(point.to_point(map)); for _ in 0..times { let mut crossed_newline = false; - point = movement::find_boundary(map, point, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| { + let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation); + let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation); let at_newline = right == '\n'; let found = (left_kind != right_kind && right_kind != CharKind::Whitespace) @@ -614,12 +615,17 @@ fn next_word_end( ignore_punctuation: bool, times: usize, ) -> DisplayPoint { - let language = map.buffer_snapshot.language_at(point.to_point(map)); + let scope = map.buffer_snapshot.language_scope_at(point.to_point(map)); for _ in 0..times { - *point.column_mut() += 1; - point = movement::find_boundary(map, point, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + if point.column() < map.line_len(point.row()) { + *point.column_mut() += 1; + } else if point.row() < map.max_buffer_row() { + *point.row_mut() += 1; + *point.column_mut() = 0; + } + point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| { + let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation); + let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation); left_kind != right_kind && left_kind != CharKind::Whitespace }); @@ -645,16 +651,17 @@ fn previous_word_start( ignore_punctuation: bool, times: usize, ) -> DisplayPoint { - let language = map.buffer_snapshot.language_at(point.to_point(map)); + let scope = map.buffer_snapshot.language_scope_at(point.to_point(map)); for _ in 0..times { // This works even though find_preceding_boundary is called for every character in the line containing // cursor because the newline is checked only once. - point = movement::find_preceding_boundary(map, point, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + point = + movement::find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| { + let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation); + let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation); - (left_kind != right_kind && !right.is_whitespace()) || left == '\n' - }); + (left_kind != right_kind && !right.is_whitespace()) || left == '\n' + }); } point } @@ -665,7 +672,7 @@ fn first_non_whitespace( from: DisplayPoint, ) -> DisplayPoint { let mut last_point = start_of_line(map, display_lines, from); - let language = map.buffer_snapshot.language_at(from.to_point(map)); + let scope = map.buffer_snapshot.language_scope_at(from.to_point(map)); for (ch, point) in map.chars_at(last_point) { if ch == '\n' { return from; @@ -673,7 +680,7 @@ fn first_non_whitespace( last_point = point; - if char_kind(language, ch) != CharKind::Whitespace { + if char_kind(&scope, ch) != CharKind::Whitespace { break; } } diff --git a/crates/vim/src/normal.rs b/crates/vim/src/normal.rs index a73c518809..ce1e4c3d6d 100644 --- a/crates/vim/src/normal.rs +++ b/crates/vim/src/normal.rs @@ -27,7 +27,6 @@ use self::{ case::change_case, change::{change_motion, change_object}, delete::{delete_motion, delete_object}, - substitute::substitute, yank::{yank_motion, yank_object}, }; @@ -44,7 +43,6 @@ actions!( ChangeToEndOfLine, DeleteToEndOfLine, Yank, - Substitute, ChangeCase, ] ); @@ -56,13 +54,8 @@ pub fn init(cx: &mut AppContext) { cx.add_action(insert_line_above); cx.add_action(insert_line_below); cx.add_action(change_case); + substitute::init(cx); search::init(cx); - cx.add_action(|_: &mut Workspace, _: &Substitute, cx| { - Vim::update(cx, |vim, cx| { - let times = vim.pop_number_operator(cx); - substitute(vim, times, cx); - }) - }); cx.add_action(|_: &mut Workspace, _: &DeleteLeft, cx| { Vim::update(cx, |vim, cx| { let times = vim.pop_number_operator(cx); @@ -445,7 +438,7 @@ mod test { } #[gpui::test] - async fn test_e(cx: &mut gpui::TestAppContext) { + async fn test_end_of_word(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await.binding(["e"]); cx.assert_all(indoc! {" Thˇe quicˇkˇ-browˇn diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 5591de89c6..836ce1492b 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -1,7 +1,10 @@ use crate::{motion::Motion, object::Object, state::Mode, utils::copy_selections_content, Vim}; use editor::{ - char_kind, display_map::DisplaySnapshot, movement, scroll::autoscroll::Autoscroll, CharKind, - DisplayPoint, + char_kind, + display_map::DisplaySnapshot, + movement::{self, FindRange}, + scroll::autoscroll::Autoscroll, + CharKind, DisplayPoint, }; use gpui::WindowContext; use language::Selection; @@ -86,22 +89,24 @@ fn expand_changed_word_selection( ignore_punctuation: bool, ) -> bool { if times.is_none() || times.unwrap() == 1 { - let language = map + let scope = map .buffer_snapshot - .language_at(selection.start.to_point(map)); + .language_scope_at(selection.start.to_point(map)); let in_word = map .chars_at(selection.head()) .next() - .map(|(c, _)| char_kind(language, c) != CharKind::Whitespace) + .map(|(c, _)| char_kind(&scope, c) != CharKind::Whitespace) .unwrap_or_default(); if in_word { - selection.end = movement::find_boundary(map, selection.end, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + selection.end = + movement::find_boundary(map, selection.end, FindRange::MultiLine, |left, right| { + let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation); + let right_kind = + char_kind(&scope, right).coerce_punctuation(ignore_punctuation); - left_kind != right_kind && left_kind != CharKind::Whitespace - }); + left_kind != right_kind && left_kind != CharKind::Whitespace + }); true } else { Motion::NextWordStart { ignore_punctuation } diff --git a/crates/vim/src/normal/scroll.rs b/crates/vim/src/normal/scroll.rs index a2bbab0478..1b3dcee6ad 100644 --- a/crates/vim/src/normal/scroll.rs +++ b/crates/vim/src/normal/scroll.rs @@ -67,7 +67,8 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex let top_anchor = editor.scroll_manager.anchor().anchor; editor.change_selections(None, cx, |s| { - s.move_heads_with(|map, head, goal| { + s.move_with(|map, selection| { + let head = selection.head(); let top = top_anchor.to_display_point(map); let min_row = top.row() + VERTICAL_SCROLL_MARGIN as u32; let max_row = top.row() + visible_rows - VERTICAL_SCROLL_MARGIN as u32 - 1; @@ -79,7 +80,11 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex } else { head }; - (new_head, goal) + if selection.is_empty() { + selection.collapse_to(new_head, selection.goal) + } else { + selection.set_head(new_head, selection.goal) + }; }) }); } @@ -90,12 +95,35 @@ mod test { use crate::{state::Mode, test::VimTestContext}; use gpui::geometry::vector::vec2f; use indoc::indoc; + use language::Point; #[gpui::test] async fn test_scroll(cx: &mut gpui::TestAppContext) { let mut cx = VimTestContext::new(cx, true).await; - cx.set_state(indoc! {"ˇa\nb\nc\nd\ne\n"}, Mode::Normal); + let window = cx.window; + let line_height = + cx.editor(|editor, cx| editor.style(cx).text.line_height(cx.font_cache())); + window.simulate_resize(vec2f(1000., 8.0 * line_height - 1.0), &mut cx); + + cx.set_state( + indoc!( + "ˇone + two + three + four + five + six + seven + eight + nine + ten + eleven + twelve + " + ), + Mode::Normal, + ); cx.update_editor(|editor, cx| { assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.)) @@ -112,5 +140,33 @@ mod test { cx.update_editor(|editor, cx| { assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.)) }); + + // does not select in normal mode + cx.simulate_keystrokes(["g", "g"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.)) + }); + cx.simulate_keystrokes(["ctrl-d"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0)); + assert_eq!( + editor.selections.newest(cx).range(), + Point::new(5, 0)..Point::new(5, 0) + ) + }); + + // does select in visual mode + cx.simulate_keystrokes(["g", "g"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.)) + }); + cx.simulate_keystrokes(["v", "ctrl-d"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0)); + assert_eq!( + editor.selections.newest(cx).range(), + Point::new(0, 0)..Point::new(5, 1) + ) + }); } } diff --git a/crates/vim/src/normal/substitute.rs b/crates/vim/src/normal/substitute.rs index b04596240a..23b545abd8 100644 --- a/crates/vim/src/normal/substitute.rs +++ b/crates/vim/src/normal/substitute.rs @@ -1,10 +1,32 @@ -use gpui::WindowContext; +use editor::movement; +use gpui::{actions, AppContext, WindowContext}; use language::Point; +use workspace::Workspace; use crate::{motion::Motion, utils::copy_selections_content, Mode, Vim}; -pub fn substitute(vim: &mut Vim, count: Option, cx: &mut WindowContext) { - let line_mode = vim.state().mode == Mode::VisualLine; +actions!(vim, [Substitute, SubstituteLine]); + +pub(crate) fn init(cx: &mut AppContext) { + cx.add_action(|_: &mut Workspace, _: &Substitute, cx| { + Vim::update(cx, |vim, cx| { + let count = vim.pop_number_operator(cx); + substitute(vim, count, vim.state().mode == Mode::VisualLine, cx); + }) + }); + + cx.add_action(|_: &mut Workspace, _: &SubstituteLine, cx| { + Vim::update(cx, |vim, cx| { + if matches!(vim.state().mode, Mode::VisualBlock | Mode::Visual) { + vim.switch_mode(Mode::VisualLine, false, cx) + } + let count = vim.pop_number_operator(cx); + substitute(vim, count, true, cx) + }) + }); +} + +pub fn substitute(vim: &mut Vim, count: Option, line_mode: bool, cx: &mut WindowContext) { vim.update_active_editor(cx, |editor, cx| { editor.set_clip_at_line_ends(false, cx); editor.transact(cx, |editor, cx| { @@ -14,6 +36,11 @@ pub fn substitute(vim: &mut Vim, count: Option, cx: &mut WindowContext) { Motion::Right.expand_selection(map, selection, count, true); } if line_mode { + // in Visual mode when the selection contains the newline at the end + // of the line, we should exclude it. + if !selection.is_empty() && selection.end.column() == 0 { + selection.end = movement::left(map, selection.end); + } Motion::CurrentLine.expand_selection(map, selection, None, false); if let Some((point, _)) = (Motion::FirstNonWhitespace { display_lines: false, @@ -166,4 +193,68 @@ mod test { the laˇzy dog"}) .await; } + + #[gpui::test] + async fn test_substitute_line(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + let initial_state = indoc! {" + The quick brown + fox juˇmps over + the lazy dog + "}; + + // normal mode + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["shift-s", "o"]).await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + the lazy dog + "}) + .await; + + // visual mode + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["v", "k", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + oˇ + the lazy dog + "}) + .await; + + // visual block mode + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["ctrl-v", "j", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + "}) + .await; + + // visual mode including newline + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["v", "$", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + the lazy dog + "}) + .await; + + // indentation + cx.set_neovim_option("shiftwidth=4").await; + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes([">", ">", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + the lazy dog + "}) + .await; + } } diff --git a/crates/vim/src/object.rs b/crates/vim/src/object.rs index dd922e7af6..653d4ca7b6 100644 --- a/crates/vim/src/object.rs +++ b/crates/vim/src/object.rs @@ -1,6 +1,11 @@ use std::ops::Range; -use editor::{char_kind, display_map::DisplaySnapshot, movement, Bias, CharKind, DisplayPoint}; +use editor::{ + char_kind, + display_map::DisplaySnapshot, + movement::{self, FindRange}, + Bias, CharKind, DisplayPoint, +}; use gpui::{actions, impl_actions, AppContext, WindowContext}; use language::Selection; use serde::Deserialize; @@ -177,18 +182,22 @@ fn in_word( ignore_punctuation: bool, ) -> Option> { // Use motion::right so that we consider the character under the cursor when looking for the start - let language = map.buffer_snapshot.language_at(relative_to.to_point(map)); - let start = movement::find_preceding_boundary_in_line( + let scope = map + .buffer_snapshot + .language_scope_at(relative_to.to_point(map)); + let start = movement::find_preceding_boundary( map, right(map, relative_to, 1), + movement::FindRange::SingleLine, |left, right| { - char_kind(language, left).coerce_punctuation(ignore_punctuation) - != char_kind(language, right).coerce_punctuation(ignore_punctuation) + char_kind(&scope, left).coerce_punctuation(ignore_punctuation) + != char_kind(&scope, right).coerce_punctuation(ignore_punctuation) }, ); - let end = movement::find_boundary_in_line(map, relative_to, |left, right| { - char_kind(language, left).coerce_punctuation(ignore_punctuation) - != char_kind(language, right).coerce_punctuation(ignore_punctuation) + + let end = movement::find_boundary(map, relative_to, FindRange::SingleLine, |left, right| { + char_kind(&scope, left).coerce_punctuation(ignore_punctuation) + != char_kind(&scope, right).coerce_punctuation(ignore_punctuation) }); Some(start..end) @@ -211,11 +220,13 @@ fn around_word( relative_to: DisplayPoint, ignore_punctuation: bool, ) -> Option> { - let language = map.buffer_snapshot.language_at(relative_to.to_point(map)); + let scope = map + .buffer_snapshot + .language_scope_at(relative_to.to_point(map)); let in_word = map .chars_at(relative_to) .next() - .map(|(c, _)| char_kind(language, c) != CharKind::Whitespace) + .map(|(c, _)| char_kind(&scope, c) != CharKind::Whitespace) .unwrap_or(false); if in_word { @@ -239,21 +250,24 @@ fn around_next_word( relative_to: DisplayPoint, ignore_punctuation: bool, ) -> Option> { - let language = map.buffer_snapshot.language_at(relative_to.to_point(map)); + let scope = map + .buffer_snapshot + .language_scope_at(relative_to.to_point(map)); // Get the start of the word - let start = movement::find_preceding_boundary_in_line( + let start = movement::find_preceding_boundary( map, right(map, relative_to, 1), + FindRange::SingleLine, |left, right| { - char_kind(language, left).coerce_punctuation(ignore_punctuation) - != char_kind(language, right).coerce_punctuation(ignore_punctuation) + char_kind(&scope, left).coerce_punctuation(ignore_punctuation) + != char_kind(&scope, right).coerce_punctuation(ignore_punctuation) }, ); let mut word_found = false; - let end = movement::find_boundary(map, relative_to, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + let end = movement::find_boundary(map, relative_to, FindRange::MultiLine, |left, right| { + let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation); + let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation); let found = (word_found && left_kind != right_kind) || right == '\n' && left == '\n'; @@ -566,11 +580,18 @@ mod test { async fn test_visual_word_object(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await; - cx.set_shared_state("The quick ˇbrown\nfox").await; + /* + cx.set_shared_state("The quick ˇbrown\nfox").await; + cx.simulate_shared_keystrokes(["v"]).await; + cx.assert_shared_state("The quick «bˇ»rown\nfox").await; + cx.simulate_shared_keystrokes(["i", "w"]).await; + cx.assert_shared_state("The quick «brownˇ»\nfox").await; + */ + cx.set_shared_state("The quick brown\nˇ\nfox").await; cx.simulate_shared_keystrokes(["v"]).await; - cx.assert_shared_state("The quick «bˇ»rown\nfox").await; + cx.assert_shared_state("The quick brown\n«\nˇ»fox").await; cx.simulate_shared_keystrokes(["i", "w"]).await; - cx.assert_shared_state("The quick «brownˇ»\nfox").await; + cx.assert_shared_state("The quick brown\n«\nˇ»fox").await; cx.assert_binding_matches_all(["v", "i", "w"], WORD_LOCATIONS) .await; diff --git a/crates/vim/src/test.rs b/crates/vim/src/test.rs index 88fa375851..c6a212d77f 100644 --- a/crates/vim/src/test.rs +++ b/crates/vim/src/test.rs @@ -431,6 +431,24 @@ async fn test_wrapped_lines(cx: &mut gpui::TestAppContext) { twelve char "}) .await; + + // line wraps as: + // fourteen ch + // ar + // fourteen ch + // ar + cx.set_shared_state(indoc! { " + fourteen chaˇr + fourteen char + "}) + .await; + + cx.simulate_shared_keystrokes(["d", "i", "w"]).await; + cx.assert_shared_state(indoc! {" + fourteenˇ• + fourteen char + "}) + .await; } #[gpui::test] diff --git a/crates/vim/src/test/neovim_backed_test_context.rs b/crates/vim/src/test/neovim_backed_test_context.rs index d04b1b7768..b433a6bfc0 100644 --- a/crates/vim/src/test/neovim_backed_test_context.rs +++ b/crates/vim/src/test/neovim_backed_test_context.rs @@ -153,6 +153,7 @@ impl<'a> NeovimBackedTestContext<'a> { } pub async fn assert_shared_state(&mut self, marked_text: &str) { + let marked_text = marked_text.replace("•", " "); let neovim = self.neovim_state().await; let editor = self.editor_state(); if neovim == marked_text && neovim == editor { @@ -184,9 +185,9 @@ impl<'a> NeovimBackedTestContext<'a> { message, initial_state, self.recent_keystrokes.join(" "), - marked_text, - neovim, - editor + marked_text.replace(" \n", "•\n"), + neovim.replace(" \n", "•\n"), + editor.replace(" \n", "•\n") ) } diff --git a/crates/vim/src/test/neovim_connection.rs b/crates/vim/src/test/neovim_connection.rs index 3e59080b13..e44e8d0e4c 100644 --- a/crates/vim/src/test/neovim_connection.rs +++ b/crates/vim/src/test/neovim_connection.rs @@ -237,6 +237,9 @@ impl NeovimConnection { #[cfg(not(feature = "neovim"))] pub async fn set_option(&mut self, value: &str) { + if let Some(NeovimData::Get { .. }) = self.data.front() { + self.data.pop_front(); + }; assert_eq!( self.data.pop_front(), Some(NeovimData::SetOption { diff --git a/crates/vim/test_data/test_end_of_word.json b/crates/vim/test_data/test_end_of_word.json new file mode 100644 index 0000000000..06f80dc245 --- /dev/null +++ b/crates/vim/test_data/test_end_of_word.json @@ -0,0 +1,32 @@ +{"Put":{"state":"Thˇe quick-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"e"} +{"Get":{"state":"The quicˇk-brown\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quickˇ-brown\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumpˇs over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps oveˇr\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} +{"Put":{"state":"Thˇe quick-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Put":{"state":"The quicˇk-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Put":{"state":"The quickˇ-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumpˇs over\nthe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps oveˇr\nthe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} diff --git a/crates/vim/test_data/test_substitute_line.json b/crates/vim/test_data/test_substitute_line.json new file mode 100644 index 0000000000..eb0a9825f8 --- /dev/null +++ b/crates/vim/test_data/test_substitute_line.json @@ -0,0 +1,29 @@ +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\noˇ\nthe lazy dog\n","mode":"Insert"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"v"} +{"Key":"k"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"oˇ\nthe lazy dog\n","mode":"Insert"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"ctrl-v"} +{"Key":"j"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\noˇ\n","mode":"Insert"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"v"} +{"Key":"$"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\noˇ\nthe lazy dog\n","mode":"Insert"}} +{"SetOption":{"value":"shiftwidth=4"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":">"} +{"Key":">"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\n oˇ\nthe lazy dog\n","mode":"Insert"}} diff --git a/crates/vim/test_data/test_visual_word_object.json b/crates/vim/test_data/test_visual_word_object.json index 0041baf969..5e1a9839e9 100644 --- a/crates/vim/test_data/test_visual_word_object.json +++ b/crates/vim/test_data/test_visual_word_object.json @@ -1,9 +1,9 @@ -{"Put":{"state":"The quick ˇbrown\nfox"}} +{"Put":{"state":"The quick brown\nˇ\nfox"}} {"Key":"v"} -{"Get":{"state":"The quick «bˇ»rown\nfox","mode":"Visual"}} +{"Get":{"state":"The quick brown\n«\nˇ»fox","mode":"Visual"}} {"Key":"i"} {"Key":"w"} -{"Get":{"state":"The quick «brownˇ»\nfox","mode":"Visual"}} +{"Get":{"state":"The quick brown\n«\nˇ»fox","mode":"Visual"}} {"Put":{"state":"The quick ˇbrown \nfox jumps over\nthe lazy dog \n\n\n\nThe-quick brown \n \n \n fox-jumps over\nthe lazy dog \n\n"}} {"Key":"v"} {"Key":"i"} diff --git a/crates/vim/test_data/test_wrapped_lines.json b/crates/vim/test_data/test_wrapped_lines.json index 1ebbd4f205..1fbfc935d9 100644 --- a/crates/vim/test_data/test_wrapped_lines.json +++ b/crates/vim/test_data/test_wrapped_lines.json @@ -48,3 +48,8 @@ {"Key":"o"} {"Key":"escape"} {"Get":{"state":"twelve char\nˇo\ntwelve char twelve char\ntwelve char\n","mode":"Normal"}} +{"Put":{"state":"fourteen chaˇr\nfourteen char\n"}} +{"Key":"d"} +{"Key":"i"} +{"Key":"w"} +{"Get":{"state":"fourteenˇ \nfourteen char\n","mode":"Normal"}} diff --git a/crates/workspace/src/toolbar.rs b/crates/workspace/src/toolbar.rs index 72c879d6d4..c3f4bb9723 100644 --- a/crates/workspace/src/toolbar.rs +++ b/crates/workspace/src/toolbar.rs @@ -81,10 +81,7 @@ impl View for Toolbar { ToolbarItemLocation::PrimaryLeft { flex } => { primary_items_row_count = primary_items_row_count.max(item.row_count(cx)); - let left_item = ChildView::new(item.as_any(), cx) - .aligned() - .contained() - .with_margin_right(spacing); + let left_item = ChildView::new(item.as_any(), cx).aligned(); if let Some((flex, expanded)) = flex { primary_left_items.push(left_item.flex(flex, expanded).into_any()); } else { @@ -94,11 +91,7 @@ impl View for Toolbar { ToolbarItemLocation::PrimaryRight { flex } => { primary_items_row_count = primary_items_row_count.max(item.row_count(cx)); - let right_item = ChildView::new(item.as_any(), cx) - .aligned() - .contained() - .with_margin_left(spacing) - .flex_float(); + let right_item = ChildView::new(item.as_any(), cx).aligned().flex_float(); if let Some((flex, expanded)) = flex { primary_right_items.push(right_item.flex(flex, expanded).into_any()); } else { @@ -120,7 +113,7 @@ impl View for Toolbar { let container_style = theme.container; let height = theme.height * primary_items_row_count as f32; - let mut primary_items = Flex::row(); + let mut primary_items = Flex::row().with_spacing(spacing); primary_items.extend(primary_left_items); primary_items.extend(primary_right_items); diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 2a97764647..e102a66519 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] description = "The fast, collaborative code editor." edition = "2021" name = "zed" -version = "0.102.0" +version = "0.104.0" publish = false [lib] diff --git a/crates/zed/src/languages.rs b/crates/zed/src/languages.rs index eb31c08dd2..3fbb5aa14f 100644 --- a/crates/zed/src/languages.rs +++ b/crates/zed/src/languages.rs @@ -6,6 +6,7 @@ use std::{borrow::Cow, str, sync::Arc}; use util::asset_str; mod c; +mod css; mod elixir; mod go; mod html; @@ -18,6 +19,7 @@ mod python; mod ruby; mod rust; mod svelte; +mod tailwind; mod typescript; mod yaml; @@ -35,7 +37,7 @@ mod yaml; #[exclude = "*.rs"] struct LanguageDir; -pub fn init(languages: Arc, node_runtime: Arc) { +pub fn init(languages: Arc, node_runtime: Arc) { let language = |name, grammar, adapters| { languages.register(name, load_config(name), grammar, adapters, load_queries) }; @@ -51,7 +53,14 @@ pub fn init(languages: Arc, node_runtime: Arc) { tree_sitter_cpp::language(), vec![Arc::new(c::CLspAdapter)], ); - language("css", tree_sitter_css::language(), vec![]); + language( + "css", + tree_sitter_css::language(), + vec![ + Arc::new(css::CssLspAdapter::new(node_runtime.clone())), + Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())), + ], + ); language( "elixir", tree_sitter_elixir::language(), @@ -95,6 +104,7 @@ pub fn init(languages: Arc, node_runtime: Arc) { vec![ Arc::new(typescript::TypeScriptLspAdapter::new(node_runtime.clone())), Arc::new(typescript::EsLintLspAdapter::new(node_runtime.clone())), + Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())), ], ); language( @@ -111,12 +121,16 @@ pub fn init(languages: Arc, node_runtime: Arc) { vec![ Arc::new(typescript::TypeScriptLspAdapter::new(node_runtime.clone())), Arc::new(typescript::EsLintLspAdapter::new(node_runtime.clone())), + Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())), ], ); language( "html", tree_sitter_html::language(), - vec![Arc::new(html::HtmlLspAdapter::new(node_runtime.clone()))], + vec![ + Arc::new(html::HtmlLspAdapter::new(node_runtime.clone())), + Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())), + ], ); language( "ruby", diff --git a/crates/zed/src/languages/c.rs b/crates/zed/src/languages/c.rs index c5041136c9..27a65570b6 100644 --- a/crates/zed/src/languages/c.rs +++ b/crates/zed/src/languages/c.rs @@ -19,6 +19,10 @@ impl super::LspAdapter for CLspAdapter { LanguageServerName("clangd".into()) } + fn short_name(&self) -> &'static str { + "clangd" + } + async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, diff --git a/crates/zed/src/languages/css.rs b/crates/zed/src/languages/css.rs new file mode 100644 index 0000000000..fdbc179209 --- /dev/null +++ b/crates/zed/src/languages/css.rs @@ -0,0 +1,130 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::StreamExt; +use language::{LanguageServerName, LspAdapter, LspAdapterDelegate}; +use lsp::LanguageServerBinary; +use node_runtime::NodeRuntime; +use serde_json::json; +use smol::fs; +use std::{ + any::Any, + ffi::OsString, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::ResultExt; + +const SERVER_PATH: &'static str = + "node_modules/vscode-langservers-extracted/bin/vscode-css-language-server"; + +fn server_binary_arguments(server_path: &Path) -> Vec { + vec![server_path.into(), "--stdio".into()] +} + +pub struct CssLspAdapter { + node: Arc, +} + +impl CssLspAdapter { + pub fn new(node: Arc) -> Self { + CssLspAdapter { node } + } +} + +#[async_trait] +impl LspAdapter for CssLspAdapter { + async fn name(&self) -> LanguageServerName { + LanguageServerName("vscode-css-language-server".into()) + } + + fn short_name(&self) -> &'static str { + "css" + } + + async fn fetch_latest_server_version( + &self, + _: &dyn LspAdapterDelegate, + ) -> Result> { + Ok(Box::new( + self.node + .npm_package_latest_version("vscode-langservers-extracted") + .await?, + ) as Box<_>) + } + + async fn fetch_server_binary( + &self, + version: Box, + container_dir: PathBuf, + _: &dyn LspAdapterDelegate, + ) -> Result { + let version = version.downcast::().unwrap(); + let server_path = container_dir.join(SERVER_PATH); + + if fs::metadata(&server_path).await.is_err() { + self.node + .npm_install_packages( + &container_dir, + &[("vscode-langservers-extracted", version.as_str())], + ) + .await?; + } + + Ok(LanguageServerBinary { + path: self.node.binary_path().await?, + arguments: server_binary_arguments(&server_path), + }) + } + + async fn cached_server_binary( + &self, + container_dir: PathBuf, + _: &dyn LspAdapterDelegate, + ) -> Option { + get_cached_server_binary(container_dir, &*self.node).await + } + + async fn installation_test_binary( + &self, + container_dir: PathBuf, + ) -> Option { + get_cached_server_binary(container_dir, &*self.node).await + } + + async fn initialization_options(&self) -> Option { + Some(json!({ + "provideFormatter": true + })) + } +} + +async fn get_cached_server_binary( + container_dir: PathBuf, + node: &dyn NodeRuntime, +) -> Option { + (|| async move { + let mut last_version_dir = None; + let mut entries = fs::read_dir(&container_dir).await?; + while let Some(entry) = entries.next().await { + let entry = entry?; + if entry.file_type().await?.is_dir() { + last_version_dir = Some(entry.path()); + } + } + let last_version_dir = last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?; + let server_path = last_version_dir.join(SERVER_PATH); + if server_path.exists() { + Ok(LanguageServerBinary { + path: node.binary_path().await?, + arguments: server_binary_arguments(&server_path), + }) + } else { + Err(anyhow!( + "missing executable in directory {:?}", + last_version_dir + )) + } + })() + .await + .log_err() +} diff --git a/crates/zed/src/languages/css/config.toml b/crates/zed/src/languages/css/config.toml index ba9660c4ed..05de4be8a3 100644 --- a/crates/zed/src/languages/css/config.toml +++ b/crates/zed/src/languages/css/config.toml @@ -8,3 +8,4 @@ brackets = [ { start = "\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, { start = "'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, ] +word_characters = ["-"] diff --git a/crates/zed/src/languages/elixir.rs b/crates/zed/src/languages/elixir.rs index c32927e15c..b166feda76 100644 --- a/crates/zed/src/languages/elixir.rs +++ b/crates/zed/src/languages/elixir.rs @@ -27,6 +27,10 @@ impl LspAdapter for ElixirLspAdapter { LanguageServerName("elixir-ls".into()) } + fn short_name(&self) -> &'static str { + "elixir-ls" + } + fn will_start_server( &self, delegate: &Arc, diff --git a/crates/zed/src/languages/go.rs b/crates/zed/src/languages/go.rs index d7982f7bdb..19b7013709 100644 --- a/crates/zed/src/languages/go.rs +++ b/crates/zed/src/languages/go.rs @@ -37,6 +37,10 @@ impl super::LspAdapter for GoLspAdapter { LanguageServerName("gopls".into()) } + fn short_name(&self) -> &'static str { + "gopls" + } + async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, diff --git a/crates/zed/src/languages/html.rs b/crates/zed/src/languages/html.rs index ecc839fca6..b8f1c70cce 100644 --- a/crates/zed/src/languages/html.rs +++ b/crates/zed/src/languages/html.rs @@ -22,11 +22,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct HtmlLspAdapter { - node: Arc, + node: Arc, } impl HtmlLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { HtmlLspAdapter { node } } } @@ -37,6 +37,10 @@ impl LspAdapter for HtmlLspAdapter { LanguageServerName("vscode-html-language-server".into()) } + fn short_name(&self) -> &'static str { + "html" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, @@ -61,7 +65,7 @@ impl LspAdapter for HtmlLspAdapter { self.node .npm_install_packages( &container_dir, - [("vscode-langservers-extracted", version.as_str())], + &[("vscode-langservers-extracted", version.as_str())], ) .await?; } @@ -77,14 +81,14 @@ impl LspAdapter for HtmlLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -96,7 +100,7 @@ impl LspAdapter for HtmlLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/html/config.toml b/crates/zed/src/languages/html/config.toml index 077a421ce1..164e095cee 100644 --- a/crates/zed/src/languages/html/config.toml +++ b/crates/zed/src/languages/html/config.toml @@ -10,3 +10,4 @@ brackets = [ { start = "<", end = ">", close = true, newline = true, not_in = ["comment", "string"] }, { start = "!--", end = " --", close = true, newline = false, not_in = ["comment", "string"] }, ] +word_characters = ["-"] diff --git a/crates/zed/src/languages/javascript/config.toml b/crates/zed/src/languages/javascript/config.toml index 0435f96c92..2394c57539 100644 --- a/crates/zed/src/languages/javascript/config.toml +++ b/crates/zed/src/languages/javascript/config.toml @@ -14,7 +14,12 @@ brackets = [ { start = "/*", end = " */", close = true, newline = false, not_in = ["comment", "string"] }, ] word_characters = ["$", "#"] +scope_opt_in_language_servers = ["tailwindcss-language-server"] [overrides.element] line_comment = { remove = true } block_comment = ["{/* ", " */}"] + +[overrides.string] +word_characters = ["-"] +opt_into_language_servers = ["tailwindcss-language-server"] diff --git a/crates/zed/src/languages/json.rs b/crates/zed/src/languages/json.rs index 61d19ce5b6..63f909ae2a 100644 --- a/crates/zed/src/languages/json.rs +++ b/crates/zed/src/languages/json.rs @@ -27,12 +27,12 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct JsonLspAdapter { - node: Arc, + node: Arc, languages: Arc, } impl JsonLspAdapter { - pub fn new(node: Arc, languages: Arc) -> Self { + pub fn new(node: Arc, languages: Arc) -> Self { JsonLspAdapter { node, languages } } } @@ -43,6 +43,10 @@ impl LspAdapter for JsonLspAdapter { LanguageServerName("json-language-server".into()) } + fn short_name(&self) -> &'static str { + "json" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, @@ -67,7 +71,7 @@ impl LspAdapter for JsonLspAdapter { self.node .npm_install_packages( &container_dir, - [("vscode-json-languageserver", version.as_str())], + &[("vscode-json-languageserver", version.as_str())], ) .await?; } @@ -83,14 +87,14 @@ impl LspAdapter for JsonLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -102,7 +106,7 @@ impl LspAdapter for JsonLspAdapter { fn workspace_configuration( &self, cx: &mut AppContext, - ) -> Option> { + ) -> BoxFuture<'static, serde_json::Value> { let action_names = cx.all_action_names().collect::>(); let staff_mode = cx.is_staff(); let language_names = &self.languages.language_names(); @@ -113,29 +117,28 @@ impl LspAdapter for JsonLspAdapter { }, cx, ); - Some( - future::ready(serde_json::json!({ - "json": { - "format": { - "enable": true, + + future::ready(serde_json::json!({ + "json": { + "format": { + "enable": true, + }, + "schemas": [ + { + "fileMatch": [ + schema_file_match(&paths::SETTINGS), + &*paths::LOCAL_SETTINGS_RELATIVE_PATH, + ], + "schema": settings_schema, }, - "schemas": [ - { - "fileMatch": [ - schema_file_match(&paths::SETTINGS), - &*paths::LOCAL_SETTINGS_RELATIVE_PATH, - ], - "schema": settings_schema, - }, - { - "fileMatch": [schema_file_match(&paths::KEYMAP)], - "schema": KeymapFile::generate_json_schema(&action_names), - } - ] - } - })) - .boxed(), - ) + { + "fileMatch": [schema_file_match(&paths::KEYMAP)], + "schema": KeymapFile::generate_json_schema(&action_names), + } + ] + } + })) + .boxed() } async fn language_ids(&self) -> HashMap { @@ -145,7 +148,7 @@ impl LspAdapter for JsonLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/language_plugin.rs b/crates/zed/src/languages/language_plugin.rs index b071936392..b2405d8bb8 100644 --- a/crates/zed/src/languages/language_plugin.rs +++ b/crates/zed/src/languages/language_plugin.rs @@ -70,6 +70,10 @@ impl LspAdapter for PluginLspAdapter { LanguageServerName(name.into()) } + fn short_name(&self) -> &'static str { + "PluginLspAdapter" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, diff --git a/crates/zed/src/languages/lua.rs b/crates/zed/src/languages/lua.rs index 7c5c7179d0..8187847c9a 100644 --- a/crates/zed/src/languages/lua.rs +++ b/crates/zed/src/languages/lua.rs @@ -6,7 +6,7 @@ use futures::{io::BufReader, StreamExt}; use language::{LanguageServerName, LspAdapterDelegate}; use lsp::LanguageServerBinary; use smol::fs; -use std::{any::Any, env::consts, ffi::OsString, path::PathBuf}; +use std::{any::Any, env::consts, path::PathBuf}; use util::{ async_iife, github::{latest_github_release, GitHubLspBinaryVersion}, @@ -16,19 +16,16 @@ use util::{ #[derive(Copy, Clone)] pub struct LuaLspAdapter; -fn server_binary_arguments() -> Vec { - vec![ - "--logpath=~/lua-language-server.log".into(), - "--loglevel=trace".into(), - ] -} - #[async_trait] impl super::LspAdapter for LuaLspAdapter { async fn name(&self) -> LanguageServerName { LanguageServerName("lua-language-server".into()) } + fn short_name(&self) -> &'static str { + "lua" + } + async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, @@ -83,7 +80,7 @@ impl super::LspAdapter for LuaLspAdapter { .await?; Ok(LanguageServerBinary { path: binary_path, - arguments: server_binary_arguments(), + arguments: Vec::new(), }) } @@ -127,7 +124,7 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option Vec { pub struct IntelephenseVersion(String); pub struct IntelephenseLspAdapter { - node: Arc, + node: Arc, } impl IntelephenseLspAdapter { const SERVER_PATH: &'static str = "node_modules/intelephense/lib/intelephense.js"; #[allow(unused)] - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { Self { node } } } @@ -41,6 +41,10 @@ impl LspAdapter for IntelephenseLspAdapter { LanguageServerName("intelephense".into()) } + fn short_name(&self) -> &'static str { + "php" + } + async fn fetch_latest_server_version( &self, _delegate: &dyn LspAdapterDelegate, @@ -61,7 +65,7 @@ impl LspAdapter for IntelephenseLspAdapter { if fs::metadata(&server_path).await.is_err() { self.node - .npm_install_packages(&container_dir, [("intelephense", version.0.as_str())]) + .npm_install_packages(&container_dir, &[("intelephense", version.0.as_str())]) .await?; } Ok(LanguageServerBinary { @@ -75,14 +79,14 @@ impl LspAdapter for IntelephenseLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn label_for_completion( @@ -103,7 +107,7 @@ impl LspAdapter for IntelephenseLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/python.rs b/crates/zed/src/languages/python.rs index d89a4171e9..c1539e9590 100644 --- a/crates/zed/src/languages/python.rs +++ b/crates/zed/src/languages/python.rs @@ -20,11 +20,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct PythonLspAdapter { - node: Arc, + node: Arc, } impl PythonLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { PythonLspAdapter { node } } } @@ -35,6 +35,10 @@ impl LspAdapter for PythonLspAdapter { LanguageServerName("pyright".into()) } + fn short_name(&self) -> &'static str { + "pyright" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, @@ -53,7 +57,7 @@ impl LspAdapter for PythonLspAdapter { if fs::metadata(&server_path).await.is_err() { self.node - .npm_install_packages(&container_dir, [("pyright", version.as_str())]) + .npm_install_packages(&container_dir, &[("pyright", version.as_str())]) .await?; } @@ -68,14 +72,14 @@ impl LspAdapter for PythonLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn process_completion(&self, item: &mut lsp::CompletionItem) { @@ -158,7 +162,7 @@ impl LspAdapter for PythonLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/ruby.rs b/crates/zed/src/languages/ruby.rs index 358441352a..3890b90dbd 100644 --- a/crates/zed/src/languages/ruby.rs +++ b/crates/zed/src/languages/ruby.rs @@ -12,6 +12,10 @@ impl LspAdapter for RubyLanguageServer { LanguageServerName("solargraph".into()) } + fn short_name(&self) -> &'static str { + "solargraph" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, diff --git a/crates/zed/src/languages/rust.rs b/crates/zed/src/languages/rust.rs index d550d126bb..854eeb7e08 100644 --- a/crates/zed/src/languages/rust.rs +++ b/crates/zed/src/languages/rust.rs @@ -22,6 +22,10 @@ impl LspAdapter for RustLspAdapter { LanguageServerName("rust-analyzer".into()) } + fn short_name(&self) -> &'static str { + "rust" + } + async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, diff --git a/crates/zed/src/languages/svelte.rs b/crates/zed/src/languages/svelte.rs index 8416859f5a..5e42d80e77 100644 --- a/crates/zed/src/languages/svelte.rs +++ b/crates/zed/src/languages/svelte.rs @@ -21,11 +21,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct SvelteLspAdapter { - node: Arc, + node: Arc, } impl SvelteLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { SvelteLspAdapter { node } } } @@ -36,6 +36,10 @@ impl LspAdapter for SvelteLspAdapter { LanguageServerName("svelte-language-server".into()) } + fn short_name(&self) -> &'static str { + "svelte" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, @@ -60,7 +64,7 @@ impl LspAdapter for SvelteLspAdapter { self.node .npm_install_packages( &container_dir, - [("svelte-language-server", version.as_str())], + &[("svelte-language-server", version.as_str())], ) .await?; } @@ -76,14 +80,14 @@ impl LspAdapter for SvelteLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -95,7 +99,7 @@ impl LspAdapter for SvelteLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/tailwind.rs b/crates/zed/src/languages/tailwind.rs new file mode 100644 index 0000000000..cf07fa71c9 --- /dev/null +++ b/crates/zed/src/languages/tailwind.rs @@ -0,0 +1,161 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use collections::HashMap; +use futures::{ + future::{self, BoxFuture}, + FutureExt, StreamExt, +}; +use gpui::AppContext; +use language::{LanguageServerName, LspAdapter, LspAdapterDelegate}; +use lsp::LanguageServerBinary; +use node_runtime::NodeRuntime; +use serde_json::{json, Value}; +use smol::fs; +use std::{ + any::Any, + ffi::OsString, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::ResultExt; + +const SERVER_PATH: &'static str = "node_modules/.bin/tailwindcss-language-server"; + +fn server_binary_arguments(server_path: &Path) -> Vec { + vec![server_path.into(), "--stdio".into()] +} + +pub struct TailwindLspAdapter { + node: Arc, +} + +impl TailwindLspAdapter { + pub fn new(node: Arc) -> Self { + TailwindLspAdapter { node } + } +} + +#[async_trait] +impl LspAdapter for TailwindLspAdapter { + async fn name(&self) -> LanguageServerName { + LanguageServerName("tailwindcss-language-server".into()) + } + + fn short_name(&self) -> &'static str { + "tailwind" + } + + async fn fetch_latest_server_version( + &self, + _: &dyn LspAdapterDelegate, + ) -> Result> { + Ok(Box::new( + self.node + .npm_package_latest_version("@tailwindcss/language-server") + .await?, + ) as Box<_>) + } + + async fn fetch_server_binary( + &self, + version: Box, + container_dir: PathBuf, + _: &dyn LspAdapterDelegate, + ) -> Result { + let version = version.downcast::().unwrap(); + let server_path = container_dir.join(SERVER_PATH); + + if fs::metadata(&server_path).await.is_err() { + self.node + .npm_install_packages( + &container_dir, + &[("@tailwindcss/language-server", version.as_str())], + ) + .await?; + } + + Ok(LanguageServerBinary { + path: self.node.binary_path().await?, + arguments: server_binary_arguments(&server_path), + }) + } + + async fn cached_server_binary( + &self, + container_dir: PathBuf, + _: &dyn LspAdapterDelegate, + ) -> Option { + get_cached_server_binary(container_dir, &*self.node).await + } + + async fn installation_test_binary( + &self, + container_dir: PathBuf, + ) -> Option { + get_cached_server_binary(container_dir, &*self.node).await + } + + async fn initialization_options(&self) -> Option { + Some(json!({ + "provideFormatter": true, + "userLanguages": { + "html": "html", + "css": "css", + "javascript": "javascript", + "typescriptreact": "typescriptreact", + }, + })) + } + + fn workspace_configuration(&self, _: &mut AppContext) -> BoxFuture<'static, Value> { + future::ready(json!({ + "tailwindCSS": { + "emmetCompletions": true, + } + })) + .boxed() + } + + async fn language_ids(&self) -> HashMap { + HashMap::from_iter( + [ + ("HTML".to_string(), "html".to_string()), + ("CSS".to_string(), "css".to_string()), + ("JavaScript".to_string(), "javascript".to_string()), + ("TSX".to_string(), "typescriptreact".to_string()), + ] + .into_iter(), + ) + } +} + +async fn get_cached_server_binary( + container_dir: PathBuf, + node: &dyn NodeRuntime, +) -> Option { + (|| async move { + let mut last_version_dir = None; + let mut entries = fs::read_dir(&container_dir).await?; + while let Some(entry) = entries.next().await { + let entry = entry?; + if entry.file_type().await?.is_dir() { + last_version_dir = Some(entry.path()); + } + } + let last_version_dir = last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?; + let server_path = last_version_dir.join(SERVER_PATH); + if server_path.exists() { + Ok(LanguageServerBinary { + path: node.binary_path().await?, + arguments: server_binary_arguments(&server_path), + }) + } else { + Err(anyhow!( + "missing executable in directory {:?}", + last_version_dir + )) + } + })() + .await + .log_err() +} diff --git a/crates/zed/src/languages/tsx/config.toml b/crates/zed/src/languages/tsx/config.toml index 63d1f85e64..a7f99bef5e 100644 --- a/crates/zed/src/languages/tsx/config.toml +++ b/crates/zed/src/languages/tsx/config.toml @@ -13,7 +13,12 @@ brackets = [ { start = "/*", end = " */", close = true, newline = false, not_in = ["string", "comment"] }, ] word_characters = ["#", "$"] +scope_opt_in_language_servers = ["tailwindcss-language-server"] [overrides.element] line_comment = { remove = true } block_comment = ["{/* ", " */}"] + +[overrides.string] +word_characters = ["-"] +opt_into_language_servers = ["tailwindcss-language-server"] diff --git a/crates/zed/src/languages/typescript.rs b/crates/zed/src/languages/typescript.rs index 34a512f300..676d0fd4c0 100644 --- a/crates/zed/src/languages/typescript.rs +++ b/crates/zed/src/languages/typescript.rs @@ -33,14 +33,14 @@ fn eslint_server_binary_arguments(server_path: &Path) -> Vec { } pub struct TypeScriptLspAdapter { - node: Arc, + node: Arc, } impl TypeScriptLspAdapter { const OLD_SERVER_PATH: &'static str = "node_modules/typescript-language-server/lib/cli.js"; const NEW_SERVER_PATH: &'static str = "node_modules/typescript-language-server/lib/cli.mjs"; - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { TypeScriptLspAdapter { node } } } @@ -56,6 +56,10 @@ impl LspAdapter for TypeScriptLspAdapter { LanguageServerName("typescript-language-server".into()) } + fn short_name(&self) -> &'static str { + "tsserver" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, @@ -82,7 +86,7 @@ impl LspAdapter for TypeScriptLspAdapter { self.node .npm_install_packages( &container_dir, - [ + &[ ("typescript", version.typescript_version.as_str()), ( "typescript-language-server", @@ -104,14 +108,14 @@ impl LspAdapter for TypeScriptLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_ts_server_binary(container_dir, &self.node).await + get_cached_ts_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_ts_server_binary(container_dir, &self.node).await + get_cached_ts_server_binary(container_dir, &*self.node).await } fn code_action_kinds(&self) -> Option> { @@ -161,7 +165,7 @@ impl LspAdapter for TypeScriptLspAdapter { async fn get_cached_ts_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let old_server_path = container_dir.join(TypeScriptLspAdapter::OLD_SERVER_PATH); @@ -188,38 +192,40 @@ async fn get_cached_ts_server_binary( } pub struct EsLintLspAdapter { - node: Arc, + node: Arc, } impl EsLintLspAdapter { const SERVER_PATH: &'static str = "vscode-eslint/server/out/eslintServer.js"; #[allow(unused)] - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { EsLintLspAdapter { node } } } #[async_trait] impl LspAdapter for EsLintLspAdapter { - fn workspace_configuration(&self, _: &mut AppContext) -> Option> { - Some( - future::ready(json!({ - "": { - "validate": "on", - "rulesCustomizations": [], - "run": "onType", - "nodePath": null, - } - })) - .boxed(), - ) + fn workspace_configuration(&self, _: &mut AppContext) -> BoxFuture<'static, Value> { + future::ready(json!({ + "": { + "validate": "on", + "rulesCustomizations": [], + "run": "onType", + "nodePath": null, + } + })) + .boxed() } async fn name(&self) -> LanguageServerName { LanguageServerName("eslint".into()) } + fn short_name(&self) -> &'static str { + "eslint" + } + async fn fetch_latest_server_version( &self, delegate: &dyn LspAdapterDelegate, @@ -282,14 +288,14 @@ impl LspAdapter for EsLintLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_eslint_server_binary(container_dir, &self.node).await + get_cached_eslint_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_eslint_server_binary(container_dir, &self.node).await + get_cached_eslint_server_binary(container_dir, &*self.node).await } async fn label_for_completion( @@ -307,7 +313,7 @@ impl LspAdapter for EsLintLspAdapter { async fn get_cached_eslint_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { // This is unfortunate but we don't know what the version is to build a path directly diff --git a/crates/zed/src/languages/yaml.rs b/crates/zed/src/languages/yaml.rs index b57c6f5699..8b438d0949 100644 --- a/crates/zed/src/languages/yaml.rs +++ b/crates/zed/src/languages/yaml.rs @@ -25,11 +25,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct YamlLspAdapter { - node: Arc, + node: Arc, } impl YamlLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { YamlLspAdapter { node } } } @@ -40,6 +40,10 @@ impl LspAdapter for YamlLspAdapter { LanguageServerName("yaml-language-server".into()) } + fn short_name(&self) -> &'static str { + "yaml" + } + async fn fetch_latest_server_version( &self, _: &dyn LspAdapterDelegate, @@ -62,7 +66,10 @@ impl LspAdapter for YamlLspAdapter { if fs::metadata(&server_path).await.is_err() { self.node - .npm_install_packages(&container_dir, [("yaml-language-server", version.as_str())]) + .npm_install_packages( + &container_dir, + &[("yaml-language-server", version.as_str())], + ) .await?; } @@ -77,36 +84,35 @@ impl LspAdapter for YamlLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } - fn workspace_configuration(&self, cx: &mut AppContext) -> Option> { + fn workspace_configuration(&self, cx: &mut AppContext) -> BoxFuture<'static, Value> { let tab_size = all_language_settings(None, cx) .language(Some("YAML")) .tab_size; - Some( - future::ready(serde_json::json!({ - "yaml": { - "keyOrdering": false - }, - "[yaml]": { - "editor.tabSize": tab_size, - } - })) - .boxed(), - ) + + future::ready(serde_json::json!({ + "yaml": { + "keyOrdering": false + }, + "[yaml]": { + "editor.tabSize": tab_size, + } + })) + .boxed() } } async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 3e0a8a7a07..f78a4f6419 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -19,7 +19,7 @@ use gpui::{Action, App, AppContext, AssetSource, AsyncAppContext, Task}; use isahc::{config::Configurable, Request}; use language::{LanguageRegistry, Point}; use log::LevelFilter; -use node_runtime::NodeRuntime; +use node_runtime::RealNodeRuntime; use parking_lot::Mutex; use project::Fs; use serde::{Deserialize, Serialize}; @@ -138,7 +138,7 @@ fn main() { languages.set_executor(cx.background().clone()); languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone()); let languages = Arc::new(languages); - let node_runtime = NodeRuntime::instance(http.clone()); + let node_runtime = RealNodeRuntime::new(http.clone()); languages::init(languages.clone(), node_runtime.clone()); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http.clone(), cx)); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 9ea406fc3e..424bce60f2 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -264,8 +264,9 @@ pub fn initialize_workspace( toolbar.add_item(breadcrumbs, cx); let buffer_search_bar = cx.add_view(BufferSearchBar::new); toolbar.add_item(buffer_search_bar.clone(), cx); - let quick_action_bar = - cx.add_view(|_| QuickActionBar::new(buffer_search_bar)); + let quick_action_bar = cx.add_view(|_| { + QuickActionBar::new(buffer_search_bar, workspace) + }); toolbar.add_item(quick_action_bar, cx); let project_search_bar = cx.add_view(|_| ProjectSearchBar::new()); toolbar.add_item(project_search_bar, cx); @@ -722,7 +723,6 @@ mod tests { AppContext, AssetSource, Element, Entity, TestAppContext, View, ViewHandle, }; use language::LanguageRegistry; - use node_runtime::NodeRuntime; use project::{Project, ProjectPath}; use serde_json::json; use settings::{handle_settings_file_changes, watch_config_file, SettingsStore}; @@ -731,7 +731,6 @@ mod tests { path::{Path, PathBuf}, }; use theme::{ThemeRegistry, ThemeSettings}; - use util::http::FakeHttpClient; use workspace::{ item::{Item, ItemHandle}, open_new, open_paths, pane, NewFile, SplitDirection, WorkspaceHandle, @@ -2363,8 +2362,7 @@ mod tests { let mut languages = LanguageRegistry::test(); languages.set_executor(cx.background().clone()); let languages = Arc::new(languages); - let http = FakeHttpClient::with_404_response(); - let node_runtime = NodeRuntime::instance(http); + let node_runtime = node_runtime::FakeNodeRuntime::new(); languages::init(languages.clone(), node_runtime); for name in languages.language_names() { languages.language_for_name(&name); diff --git a/script/randomized-test-minimize b/script/randomized-test-minimize index ce0b7203b4..df003cbf3e 100755 --- a/script/randomized-test-minimize +++ b/script/randomized-test-minimize @@ -9,7 +9,6 @@ const CARGO_TEST_ARGS = [ '--release', '--lib', '--package', 'collab', - 'random_collaboration', ] if (require.main === module) { @@ -99,7 +98,7 @@ function buildTests() { } function runTests(env) { - const {status, stdout} = spawnSync('cargo', ['test', ...CARGO_TEST_ARGS], { + const {status, stdout} = spawnSync('cargo', ['test', ...CARGO_TEST_ARGS, 'random_project_collaboration'], { stdio: 'pipe', encoding: 'utf8', env: { diff --git a/styles/src/build_themes.ts b/styles/src/build_themes.ts index 17575663a1..4d262f8146 100644 --- a/styles/src/build_themes.ts +++ b/styles/src/build_themes.ts @@ -21,9 +21,7 @@ function clear_themes(theme_directory: string) { } } -const all_themes: Theme[] = themes.map((theme) => - create_theme(theme) -) +const all_themes: Theme[] = themes.map((theme) => create_theme(theme)) function write_themes(themes: Theme[], output_directory: string) { clear_themes(output_directory) @@ -34,10 +32,7 @@ function write_themes(themes: Theme[], output_directory: string) { const style_tree = app() const style_tree_json = JSON.stringify(style_tree, null, 2) const temp_path = path.join(temp_directory, `${theme.name}.json`) - const out_path = path.join( - output_directory, - `${theme.name}.json` - ) + const out_path = path.join(output_directory, `${theme.name}.json`) fs.writeFileSync(temp_path, style_tree_json) fs.renameSync(temp_path, out_path) console.log(`- ${out_path} created`) diff --git a/styles/src/build_tokens.ts b/styles/src/build_tokens.ts index fd6aa18ced..3c52b6d989 100644 --- a/styles/src/build_tokens.ts +++ b/styles/src/build_tokens.ts @@ -83,8 +83,6 @@ function write_tokens(themes: Theme[], tokens_directory: string) { console.log(`- ${METADATA_FILE} created`) } -const all_themes: Theme[] = themes.map((theme) => - create_theme(theme) -) +const all_themes: Theme[] = themes.map((theme) => create_theme(theme)) write_tokens(all_themes, TOKENS_DIRECTORY) diff --git a/styles/src/component/button.ts b/styles/src/component/button.ts index 3b554ae37a..e0e831b082 100644 --- a/styles/src/component/button.ts +++ b/styles/src/component/button.ts @@ -5,7 +5,7 @@ import { TextStyle, background } from "../style_tree/components" // eslint-disable-next-line @typescript-eslint/no-namespace export namespace Button { export type Options = { - layer: Layer, + layer: Layer background: keyof Theme["lowest"] color: keyof Theme["lowest"] variant: Button.Variant @@ -16,13 +16,13 @@ export namespace Button { bottom?: number left?: number right?: number - }, + } states: { - enabled?: boolean, - hovered?: boolean, - pressed?: boolean, - focused?: boolean, - disabled?: boolean, + enabled?: boolean + hovered?: boolean + pressed?: boolean + focused?: boolean + disabled?: boolean } } @@ -38,26 +38,26 @@ export namespace Button { export const CORNER_RADIUS = 6 export const variant = { - Default: 'filled', - Outline: 'outline', - Ghost: 'ghost' + Default: "filled", + Outline: "outline", + Ghost: "ghost", } as const - export type Variant = typeof variant[keyof typeof variant] + export type Variant = (typeof variant)[keyof typeof variant] export const shape = { - Rectangle: 'rectangle', - Square: 'square' + Rectangle: "rectangle", + Square: "square", } as const - export type Shape = typeof shape[keyof typeof shape] + export type Shape = (typeof shape)[keyof typeof shape] export const size = { Small: "sm", - Medium: "md" + Medium: "md", } as const - export type Size = typeof size[keyof typeof size] + export type Size = (typeof size)[keyof typeof size] export type BaseStyle = { corder_radius: number @@ -67,8 +67,8 @@ export namespace Button { bottom: number left: number right: number - }, - margin: Button.Options['margin'] + } + margin: Button.Options["margin"] button_height: number } @@ -81,15 +81,18 @@ export namespace Button { shape: Button.shape.Rectangle, states: { hovered: true, - pressed: true - } + pressed: true, + }, } ): BaseStyle => { const theme = useTheme() const layer = options.layer ?? theme.middle const color = options.color ?? "base" - const background_color = options.variant === Button.variant.Ghost ? null : background(layer, options.background ?? color) + const background_color = + options.variant === Button.variant.Ghost + ? null + : background(layer, options.background ?? color) const m = { top: options.margin?.top ?? 0, @@ -106,8 +109,14 @@ export namespace Button { padding: { top: padding, bottom: padding, - left: options.shape === Button.shape.Rectangle ? padding + Button.RECTANGLE_PADDING : padding, - right: options.shape === Button.shape.Rectangle ? padding + Button.RECTANGLE_PADDING : padding + left: + options.shape === Button.shape.Rectangle + ? padding + Button.RECTANGLE_PADDING + : padding, + right: + options.shape === Button.shape.Rectangle + ? padding + Button.RECTANGLE_PADDING + : padding, }, margin: m, button_height: 16, diff --git a/styles/src/component/icon_button.ts b/styles/src/component/icon_button.ts index 935909afdb..5b7c61b17c 100644 --- a/styles/src/component/icon_button.ts +++ b/styles/src/component/icon_button.ts @@ -11,11 +11,9 @@ export type Margin = { } interface IconButtonOptions { - layer?: - | Theme["lowest"] - | Theme["middle"] - | Theme["highest"] + layer?: Theme["lowest"] | Theme["middle"] | Theme["highest"] color?: keyof Theme["lowest"] + background_color?: keyof Theme["lowest"] margin?: Partial variant?: Button.Variant size?: Button.Size @@ -23,18 +21,25 @@ interface IconButtonOptions { type ToggleableIconButtonOptions = IconButtonOptions & { active_color?: keyof Theme["lowest"] + active_background_color?: keyof Theme["lowest"] active_layer?: Layer + active_variant?: Button.Variant } -export function icon_button({ color, margin, layer, variant, size }: IconButtonOptions = { - variant: Button.variant.Default, - size: Button.size.Medium, -}) { +export function icon_button( + { color, background_color, margin, layer, variant, size }: IconButtonOptions = { + variant: Button.variant.Default, + size: Button.size.Medium, + } +) { const theme = useTheme() if (!color) color = "base" - const background_color = variant === Button.variant.Ghost ? null : background(layer ?? theme.lowest, color) + const default_background = + variant === Button.variant.Ghost + ? null + : background(layer ?? theme.lowest, background_color ?? color) const m = { top: margin?.top ?? 0, @@ -55,42 +60,51 @@ export function icon_button({ color, margin, layer, variant, size }: IconButtonO corner_radius: 6, padding: padding, margin: m, - icon_width: 12, + icon_width: 14, icon_height: 14, button_width: size === Button.size.Small ? 16 : 20, button_height: 14, }, state: { default: { - background: background_color, + background: default_background, color: foreground(layer ?? theme.lowest, color), }, hovered: { - background: background(layer ?? theme.lowest, color, "hovered"), + background: background(layer ?? theme.lowest, background_color ?? color, "hovered"), color: foreground(layer ?? theme.lowest, color, "hovered"), }, clicked: { - background: background(layer ?? theme.lowest, color, "pressed"), + background: background(layer ?? theme.lowest, background_color ?? color, "pressed"), color: foreground(layer ?? theme.lowest, color, "pressed"), }, }, }) } -export function toggleable_icon_button( - theme: Theme, - { color, active_color, margin, variant, size, active_layer }: ToggleableIconButtonOptions -) { +export function toggleable_icon_button({ + color, + background_color, + active_color, + active_background_color, + active_variant, + margin, + variant, + size, + active_layer, +}: ToggleableIconButtonOptions) { if (!color) color = "base" return toggleable({ state: { - inactive: icon_button({ color, margin, variant, size }), + inactive: icon_button({ color, background_color, margin, variant, size }), active: icon_button({ color: active_color ? active_color : color, + background_color: active_background_color ? active_background_color : background_color, margin, layer: active_layer, - size + variant: active_variant || variant, + size, }), }, }) diff --git a/styles/src/component/index.ts b/styles/src/component/index.ts new file mode 100644 index 0000000000..f2cbc7b26a --- /dev/null +++ b/styles/src/component/index.ts @@ -0,0 +1,6 @@ +export * from "./icon_button" +export * from "./indicator" +export * from "./input" +export * from "./tab" +export * from "./tab_bar_button" +export * from "./text_button" diff --git a/styles/src/component/indicator.ts b/styles/src/component/indicator.ts index 81a3b40da7..b3d2105f6a 100644 --- a/styles/src/component/indicator.ts +++ b/styles/src/component/indicator.ts @@ -1,7 +1,13 @@ import { foreground } from "../style_tree/components" import { Layer, StyleSets } from "../theme" -export const indicator = ({ layer, color }: { layer: Layer, color: StyleSets }) => ({ +export const indicator = ({ + layer, + color, +}: { + layer: Layer + color: StyleSets +}) => ({ corner_radius: 4, padding: 4, margin: { top: 12, left: 12 }, diff --git a/styles/src/component/input.ts b/styles/src/component/input.ts index cadfcc8d4a..5921210f88 100644 --- a/styles/src/component/input.ts +++ b/styles/src/component/input.ts @@ -18,6 +18,6 @@ export const input = () => { bottom: 3, left: 12, right: 8, - } + }, } } diff --git a/styles/src/component/label_button.ts b/styles/src/component/label_button.ts deleted file mode 100644 index 3f1c54a7f6..0000000000 --- a/styles/src/component/label_button.ts +++ /dev/null @@ -1,78 +0,0 @@ -import { Interactive, interactive, toggleable, Toggleable } from "../element" -import { TextStyle, background, text } from "../style_tree/components" -import { useTheme } from "../theme" -import { Button } from "./button" - -type LabelButtonStyle = { - corder_radius: number - background: string | null - padding: { - top: number - bottom: number - left: number - right: number - }, - margin: Button.Options['margin'] - button_height: number -} & TextStyle - -/** Styles an Interactive<ContainedText> */ -export function label_button_style( - options: Partial = { - variant: Button.variant.Default, - shape: Button.shape.Rectangle, - states: { - hovered: true, - pressed: true - } - } -): Interactive { - const theme = useTheme() - - const base = Button.button_base(options) - const layer = options.layer ?? theme.middle - const color = options.color ?? "base" - - const default_state = { - ...base, - ...text(layer ?? theme.lowest, "sans", color), - font_size: Button.FONT_SIZE, - } - - return interactive({ - base: default_state, - state: { - hovered: { - background: background(layer, options.background ?? color, "hovered") - }, - clicked: { - background: background(layer, options.background ?? color, "pressed") - } - } - }) -} - -/** Styles an Toggleable<Interactive<ContainedText>> */ -export function toggle_label_button_style( - options: Partial = { - variant: Button.variant.Default, - shape: Button.shape.Rectangle, - states: { - hovered: true, - pressed: true - } - } -): Toggleable> { - const activeOptions = { - ...options, - color: options.active_color || options.color, - background: options.active_background || options.background - } - - return toggleable({ - state: { - inactive: label_button_style(options), - active: label_button_style(activeOptions), - }, - }) -} diff --git a/styles/src/component/tab.ts b/styles/src/component/tab.ts index 9938fb9311..6f73b6f3fb 100644 --- a/styles/src/component/tab.ts +++ b/styles/src/component/tab.ts @@ -9,7 +9,7 @@ type TabProps = { export const tab = ({ layer }: TabProps) => { const active_color = text(layer, "sans", "base").color const inactive_border: Border = { - color: '#FFFFFF00', + color: "#FFFFFF00", width: 1, bottom: true, left: false, @@ -27,7 +27,7 @@ export const tab = ({ layer }: TabProps) => { top: 8, left: 8, right: 8, - bottom: 6 + bottom: 6, }, border: inactive_border, } @@ -35,17 +35,17 @@ export const tab = ({ layer }: TabProps) => { const i = interactive({ state: { default: { - ...base + ...base, }, hovered: { ...base, - ...text(layer, "sans", "base", "hovered") + ...text(layer, "sans", "base", "hovered"), }, clicked: { ...base, - ...text(layer, "sans", "base", "pressed") + ...text(layer, "sans", "base", "pressed"), }, - } + }, }) return toggleable({ @@ -60,14 +60,14 @@ export const tab = ({ layer }: TabProps) => { hovered: { ...i, ...text(layer, "sans", "base", "hovered"), - border: active_border + border: active_border, }, clicked: { ...i, ...text(layer, "sans", "base", "pressed"), - border: active_border + border: active_border, }, - } - } + }, + }, }) } diff --git a/styles/src/component/tab_bar_button.ts b/styles/src/component/tab_bar_button.ts index 0c43e7010e..9e7f9acfc3 100644 --- a/styles/src/component/tab_bar_button.ts +++ b/styles/src/component/tab_bar_button.ts @@ -12,44 +12,47 @@ type TabBarButtonProps = TabBarButtonOptions & { state?: Partial>> } -export function tab_bar_button(theme: Theme, { icon, color = "base" }: TabBarButtonProps) { +export function tab_bar_button( + theme: Theme, + { icon, color = "base" }: TabBarButtonProps +) { const button_spacing = 8 - return ( - interactive({ - base: { - icon: { - color: foreground(theme.middle, color), - asset: icon, - dimensions: { - width: 15, - height: 15, - }, + return interactive({ + base: { + icon: { + color: foreground(theme.middle, color), + asset: icon, + dimensions: { + width: 15, + height: 15, }, + }, + container: { + corner_radius: 4, + padding: { + top: 4, + bottom: 4, + left: 4, + right: 4, + }, + margin: { + left: button_spacing / 2, + right: button_spacing / 2, + }, + }, + }, + state: { + hovered: { container: { - corner_radius: 4, - padding: { - top: 4, bottom: 4, left: 4, right: 4 - }, - margin: { - left: button_spacing / 2, - right: button_spacing / 2, - }, + background: background(theme.middle, color, "hovered"), }, }, - state: { - hovered: { - container: { - background: background(theme.middle, color, "hovered"), - - } - }, - clicked: { - container: { - background: background(theme.middle, color, "pressed"), - } + clicked: { + container: { + background: background(theme.middle, color, "pressed"), }, }, - }) - ) + }, + }) } diff --git a/styles/src/component/text_button.ts b/styles/src/component/text_button.ts index b911cd5b77..0e293e403a 100644 --- a/styles/src/component/text_button.ts +++ b/styles/src/component/text_button.ts @@ -1,5 +1,6 @@ import { interactive, toggleable } from "../element" import { + Border, TextProperties, background, foreground, @@ -10,14 +11,13 @@ import { Button } from "./button" import { Margin } from "./icon_button" interface TextButtonOptions { - layer?: - | Theme["lowest"] - | Theme["middle"] - | Theme["highest"] + layer?: Theme["lowest"] | Theme["middle"] | Theme["highest"] variant?: Button.Variant color?: keyof Theme["lowest"] margin?: Partial + disabled?: boolean text_properties?: TextProperties + border?: Border } type ToggleableTextButtonOptions = TextButtonOptions & { @@ -29,12 +29,17 @@ export function text_button({ color, layer, margin, + disabled, text_properties, + border, }: TextButtonOptions = {}) { const theme = useTheme() if (!color) color = "base" - const background_color = variant === Button.variant.Ghost ? null : background(layer ?? theme.lowest, color) + const background_color = + variant === Button.variant.Ghost + ? null + : background(layer ?? theme.lowest, color) const text_options: TextProperties = { size: "xs", @@ -64,17 +69,42 @@ export function text_button({ }, state: { default: { + border, background: background_color, - color: foreground(layer ?? theme.lowest, color), - }, - hovered: { - background: background(layer ?? theme.lowest, color, "hovered"), - color: foreground(layer ?? theme.lowest, color, "hovered"), - }, - clicked: { - background: background(layer ?? theme.lowest, color, "pressed"), - color: foreground(layer ?? theme.lowest, color, "pressed"), + color: disabled + ? foreground(layer ?? theme.lowest, "disabled") + : foreground(layer ?? theme.lowest, color), }, + hovered: disabled + ? {} + : { + border, + background: background( + layer ?? theme.lowest, + color, + "hovered" + ), + color: foreground( + layer ?? theme.lowest, + color, + "hovered" + ), + }, + clicked: disabled + ? {} + : { + border, + background: background( + layer ?? theme.lowest, + color, + "pressed" + ), + color: foreground( + layer ?? theme.lowest, + color, + "pressed" + ), + }, }, }) } diff --git a/styles/src/element/index.ts b/styles/src/element/index.ts index d41b4e2cc3..0586399fb1 100644 --- a/styles/src/element/index.ts +++ b/styles/src/element/index.ts @@ -1,4 +1,6 @@ import { interactive, Interactive } from "./interactive" import { toggleable, Toggleable } from "./toggle" +export * from "./padding" +export * from "./margin" export { interactive, Interactive, toggleable, Toggleable } diff --git a/styles/src/element/margin.ts b/styles/src/element/margin.ts new file mode 100644 index 0000000000..5bbdd646a8 --- /dev/null +++ b/styles/src/element/margin.ts @@ -0,0 +1,41 @@ +type MarginOptions = { + all?: number + left?: number + right?: number + top?: number + bottom?: number +} + +export type MarginStyle = { + top: number + bottom: number + left: number + right: number +} + +export const margin_style = (options: MarginOptions): MarginStyle => { + const { all, top, bottom, left, right } = options + + if (all !== undefined) + return { + top: all, + bottom: all, + left: all, + right: all, + } + + if ( + top === undefined && + bottom === undefined && + left === undefined && + right === undefined + ) + throw new Error("Margin must have at least one value") + + return { + top: top || 0, + bottom: bottom || 0, + left: left || 0, + right: right || 0, + } +} diff --git a/styles/src/element/padding.ts b/styles/src/element/padding.ts new file mode 100644 index 0000000000..b94e263922 --- /dev/null +++ b/styles/src/element/padding.ts @@ -0,0 +1,41 @@ +type PaddingOptions = { + all?: number + left?: number + right?: number + top?: number + bottom?: number +} + +export type PaddingStyle = { + top: number + bottom: number + left: number + right: number +} + +export const padding_style = (options: PaddingOptions): PaddingStyle => { + const { all, top, bottom, left, right } = options + + if (all !== undefined) + return { + top: all, + bottom: all, + left: all, + right: all, + } + + if ( + top === undefined && + bottom === undefined && + left === undefined && + right === undefined + ) + throw new Error("Padding must have at least one value") + + return { + top: top || 0, + bottom: bottom || 0, + left: left || 0, + right: right || 0, + } +} diff --git a/styles/src/style_tree/assistant.ts b/styles/src/style_tree/assistant.ts index cfc1f8d813..7a41f45e53 100644 --- a/styles/src/style_tree/assistant.ts +++ b/styles/src/style_tree/assistant.ts @@ -1,5 +1,5 @@ import { text, border, background, foreground, TextStyle } from "./components" -import { Interactive, interactive } from "../element" +import { Interactive, interactive, toggleable } from "../element" import { tab_bar_button } from "../component/tab_bar_button" import { StyleSets, useTheme } from "../theme" @@ -8,50 +8,48 @@ type RoleCycleButton = TextStyle & { } // TODO: Replace these with zed types type RemainingTokens = TextStyle & { - background: string, - margin: { top: number, right: number }, + background: string + margin: { top: number; right: number } padding: { - right: number, - left: number, - top: number, - bottom: number, - }, - corner_radius: number, + right: number + left: number + top: number + bottom: number + } + corner_radius: number } export default function assistant(): any { const theme = useTheme() - const interactive_role = (color: StyleSets): Interactive => { - return ( - interactive({ - base: { + const interactive_role = ( + color: StyleSets + ): Interactive => { + return interactive({ + base: { + ...text(theme.highest, "sans", color, { size: "sm" }), + }, + state: { + hovered: { ...text(theme.highest, "sans", color, { size: "sm" }), + background: background(theme.highest, color, "hovered"), }, - state: { - hovered: { - ...text(theme.highest, "sans", color, { size: "sm" }), - background: background(theme.highest, color, "hovered"), - }, - clicked: { - ...text(theme.highest, "sans", color, { size: "sm" }), - background: background(theme.highest, color, "pressed"), - } + clicked: { + ...text(theme.highest, "sans", color, { size: "sm" }), + background: background(theme.highest, color, "pressed"), }, - }) - ) + }, + }) } const tokens_remaining = (color: StyleSets): RemainingTokens => { - return ( - { - ...text(theme.highest, "mono", color, { size: "xs" }), - background: background(theme.highest, "on", "default"), - margin: { top: 12, right: 20 }, - padding: { right: 4, left: 4, top: 1, bottom: 1 }, - corner_radius: 6, - } - ) + return { + ...text(theme.highest, "mono", color, { size: "xs" }), + background: background(theme.highest, "on", "default"), + margin: { top: 12, right: 20 }, + padding: { right: 4, left: 4, top: 1, bottom: 1 }, + corner_radius: 6, + } } return { @@ -59,6 +57,85 @@ export default function assistant(): any { background: background(theme.highest), padding: { left: 12 }, }, + inline: { + background: background(theme.highest), + margin: { top: 3, bottom: 3 }, + border: border(theme.lowest, "on", { + top: true, + bottom: true, + overlay: true, + }), + editor: { + text: text(theme.highest, "mono", "default", { size: "sm" }), + placeholder_text: text(theme.highest, "sans", "on", "disabled"), + selection: theme.players[0], + }, + disabled_editor: { + text: text(theme.highest, "mono", "disabled", { size: "sm" }), + placeholder_text: text(theme.highest, "sans", "on", "disabled"), + selection: { + cursor: text(theme.highest, "mono", "disabled").color, + selection: theme.players[0].selection, + }, + }, + pending_edit_background: background(theme.highest, "positive"), + include_conversation: toggleable({ + base: interactive({ + base: { + icon_size: 12, + color: foreground(theme.highest, "variant"), + + button_width: 12, + background: background(theme.highest, "on"), + corner_radius: 2, + border: { + width: 1., color: background(theme.highest, "on") + }, + padding: { + left: 4, + right: 4, + top: 4, + bottom: 4, + }, + }, + state: { + hovered: { + ...text(theme.highest, "mono", "variant", "hovered"), + background: background(theme.highest, "on", "hovered"), + border: { + width: 1., color: background(theme.highest, "on", "hovered") + }, + }, + clicked: { + ...text(theme.highest, "mono", "variant", "pressed"), + background: background(theme.highest, "on", "pressed"), + border: { + width: 1., color: background(theme.highest, "on", "pressed") + }, + }, + }, + }), + state: { + active: { + default: { + icon_size: 12, + button_width: 12, + color: foreground(theme.highest, "variant"), + background: background(theme.highest, "accent"), + border: border(theme.highest, "accent"), + }, + hovered: { + background: background(theme.highest, "accent", "hovered"), + border: border(theme.highest, "accent", "hovered"), + }, + clicked: { + background: background(theme.highest, "accent", "pressed"), + border: border(theme.highest, "accent", "pressed"), + }, + }, + }, + }), + }, message_header: { margin: { bottom: 4, top: 4 }, background: background(theme.highest), @@ -93,7 +170,10 @@ export default function assistant(): any { base: { background: background(theme.middle), padding: { top: 4, bottom: 4 }, - border: border(theme.middle, "default", { top: true, overlay: true }), + border: border(theme.middle, "default", { + top: true, + overlay: true, + }), }, state: { hovered: { @@ -101,7 +181,7 @@ export default function assistant(): any { }, clicked: { background: background(theme.middle, "pressed"), - } + }, }, }), saved_at: { diff --git a/styles/src/style_tree/collab_modals.ts b/styles/src/style_tree/collab_modals.ts index 0f50e01a39..f9b22b6867 100644 --- a/styles/src/style_tree/collab_modals.ts +++ b/styles/src/style_tree/collab_modals.ts @@ -39,7 +39,12 @@ export default function channel_modal(): any { row_height: ITEM_HEIGHT, header: { background: background(theme.lowest), - border: border(theme.middle, { "bottom": true, "top": false, left: false, right: false }), + border: border(theme.middle, { + bottom: true, + top: false, + left: false, + right: false, + }), padding: { top: SPACING, left: SPACING - BUTTON_OFFSET, @@ -48,7 +53,7 @@ export default function channel_modal(): any { corner_radii: { top_right: 12, top_left: 12, - } + }, }, body: { background: background(theme.middle), @@ -57,12 +62,11 @@ export default function channel_modal(): any { left: SPACING, right: SPACING, bottom: SPACING, - }, corner_radii: { bottom_right: 12, bottom_left: 12, - } + }, }, modal: { background: background(theme.middle), @@ -74,7 +78,6 @@ export default function channel_modal(): any { right: 0, top: 0, }, - }, // FIXME: due to a bug in the picker's size calculation, this must be 600 max_height: 600, @@ -83,7 +86,7 @@ export default function channel_modal(): any { ...text(theme.middle, "sans", "on", { size: "lg" }), padding: { left: BUTTON_OFFSET, - } + }, }, picker: { empty_container: {}, @@ -108,8 +111,8 @@ export default function channel_modal(): any { background: background(theme.middle), padding: { left: 7, - right: 7 - } + right: 7, + }, }, cancel_invite_button: { ...text(theme.middle, "sans", { size: "xs" }), @@ -125,7 +128,7 @@ export default function channel_modal(): any { padding: { left: 4, right: 4, - } + }, }, contact_avatar: { corner_radius: 10, @@ -147,6 +150,6 @@ export default function channel_modal(): any { background: background(theme.middle, "disabled"), color: foreground(theme.middle, "disabled"), }, - } + }, } } diff --git a/styles/src/style_tree/collab_panel.ts b/styles/src/style_tree/collab_panel.ts index 07f367c8af..4d605d118c 100644 --- a/styles/src/style_tree/collab_panel.ts +++ b/styles/src/style_tree/collab_panel.ts @@ -27,7 +27,7 @@ export default function contacts_panel(): any { color: foreground(layer, "on"), icon_width: 14, button_width: 16, - corner_radius: 8 + corner_radius: 8, } const project_row = { @@ -61,7 +61,7 @@ export default function contacts_panel(): any { width: 14, } - const header_icon_button = toggleable_icon_button(theme, { + const header_icon_button = toggleable_icon_button({ variant: "ghost", size: "sm", active_layer: theme.lowest, @@ -275,7 +275,7 @@ export default function contacts_panel(): any { list_empty_label_container: { margin: { left: NAME_MARGIN, - } + }, }, list_empty_icon: { color: foreground(layer, "variant"), @@ -289,7 +289,7 @@ export default function contacts_panel(): any { top: SPACING / 2, bottom: SPACING / 2, left: SPACING, - right: SPACING + right: SPACING, }, }, state: { @@ -330,7 +330,7 @@ export default function contacts_panel(): any { right: 4, }, background: background(layer, "hovered"), - ...text(layer, "sans", "hovered", { size: "xs" }) + ...text(layer, "sans", "hovered", { size: "xs" }), }, contact_status_free: indicator({ layer, color: "positive" }), contact_status_busy: indicator({ layer, color: "negative" }), @@ -404,7 +404,7 @@ export default function contacts_panel(): any { channel_editor: { padding: { left: NAME_MARGIN, - } - } + }, + }, } } diff --git a/styles/src/style_tree/component_test.ts b/styles/src/style_tree/component_test.ts index e2bb0915c1..71057c67ea 100644 --- a/styles/src/style_tree/component_test.ts +++ b/styles/src/style_tree/component_test.ts @@ -1,4 +1,3 @@ - import { useTheme } from "../common" import { text_button } from "../component/text_button" import { icon_button } from "../component/icon_button" @@ -14,14 +13,14 @@ export default function contacts_panel(): any { base: text_button({}), state: { active: { - ...text_button({ color: "accent" }) - } - } + ...text_button({ color: "accent" }), + }, + }, }), disclosure: { ...text(theme.lowest, "sans", "base"), button: icon_button({ variant: "ghost" }), spacing: 4, - } + }, } } diff --git a/styles/src/style_tree/contacts_popover.ts b/styles/src/style_tree/contacts_popover.ts index 0e76bbb38a..dcd84c3252 100644 --- a/styles/src/style_tree/contacts_popover.ts +++ b/styles/src/style_tree/contacts_popover.ts @@ -3,5 +3,4 @@ import { background, border } from "./components" export default function contacts_popover(): any { const theme = useTheme() - } diff --git a/styles/src/style_tree/editor.ts b/styles/src/style_tree/editor.ts index 9277a2e7a1..8fd512c5d4 100644 --- a/styles/src/style_tree/editor.ts +++ b/styles/src/style_tree/editor.ts @@ -206,6 +206,9 @@ export default function editor(): any { match_highlight: foreground(theme.middle, "accent", "active"), background: background(theme.middle, "active"), }, + server_name_container: { padding: { left: 40 } }, + server_name_color: text(theme.middle, "sans", "disabled", {}).color, + server_name_size_percent: 0.75, }, diagnostic_header: { background: background(theme.middle), @@ -307,7 +310,7 @@ export default function editor(): any { ? with_opacity(theme.ramps.green(0.5).hex(), 0.8) : with_opacity(theme.ramps.green(0.4).hex(), 0.8), }, - selections: foreground(layer, "accent") + selections: foreground(layer, "accent"), }, composition_mark: { underline: { diff --git a/styles/src/style_tree/feedback.ts b/styles/src/style_tree/feedback.ts index b1bd96e165..0349359533 100644 --- a/styles/src/style_tree/feedback.ts +++ b/styles/src/style_tree/feedback.ts @@ -37,7 +37,7 @@ export default function feedback(): any { ...text(theme.highest, "mono", "on", "disabled"), background: background(theme.highest, "on", "disabled"), border: border(theme.highest, "on", "disabled"), - } + }, }, }), button_margin: 8, diff --git a/styles/src/style_tree/picker.ts b/styles/src/style_tree/picker.ts index 28ae854787..317f600b1e 100644 --- a/styles/src/style_tree/picker.ts +++ b/styles/src/style_tree/picker.ts @@ -152,7 +152,7 @@ export default function picker(): any { 0.5 ), }, - } + }, }), } } diff --git a/styles/src/style_tree/project_panel.ts b/styles/src/style_tree/project_panel.ts index e239f9a840..51958af145 100644 --- a/styles/src/style_tree/project_panel.ts +++ b/styles/src/style_tree/project_panel.ts @@ -64,17 +64,17 @@ export default function project_panel(): any { const unselected_default_style = merge( base_properties, unselected?.default ?? {}, - {}, + {} ) const unselected_hovered_style = merge( base_properties, { background: background(theme.middle, "hovered") }, - unselected?.hovered ?? {}, + unselected?.hovered ?? {} ) const unselected_clicked_style = merge( base_properties, { background: background(theme.middle, "pressed") }, - unselected?.clicked ?? {}, + unselected?.clicked ?? {} ) const selected_default_style = merge( base_properties, @@ -82,7 +82,7 @@ export default function project_panel(): any { background: background(theme.lowest), text: text(theme.lowest, "sans", { size: "sm" }), }, - selected_style?.default ?? {}, + selected_style?.default ?? {} ) const selected_hovered_style = merge( base_properties, @@ -90,7 +90,7 @@ export default function project_panel(): any { background: background(theme.lowest, "hovered"), text: text(theme.lowest, "sans", { size: "sm" }), }, - selected_style?.hovered ?? {}, + selected_style?.hovered ?? {} ) const selected_clicked_style = merge( base_properties, @@ -98,7 +98,7 @@ export default function project_panel(): any { background: background(theme.lowest, "pressed"), text: text(theme.lowest, "sans", { size: "sm" }), }, - selected_style?.clicked ?? {}, + selected_style?.clicked ?? {} ) return toggleable({ @@ -175,7 +175,7 @@ export default function project_panel(): any { default: { icon_color: foreground(theme.middle, "variant"), }, - }, + } ), cut_entry: entry( { @@ -190,7 +190,7 @@ export default function project_panel(): any { size: "sm", }), }, - }, + } ), filename_editor: { background: background(theme.middle, "on"), diff --git a/styles/src/style_tree/search.ts b/styles/src/style_tree/search.ts index 4493634a8e..8174690fde 100644 --- a/styles/src/style_tree/search.ts +++ b/styles/src/style_tree/search.ts @@ -2,9 +2,23 @@ import { with_opacity } from "../theme/color" import { background, border, foreground, text } from "./components" import { interactive, toggleable } from "../element" import { useTheme } from "../theme" +import { text_button } from "../component/text_button" + +const search_results = () => { + const theme = useTheme() + + return { + // TODO: Add an activeMatchBackground on the rust side to differentiate between active and inactive + match_background: with_opacity( + foreground(theme.highest, "accent"), + 0.4 + ), + } +} export default function search(): any { const theme = useTheme() + const SEARCH_ROW_SPACING = 12 // Search input const editor = { @@ -17,7 +31,7 @@ export default function search(): any { text: text(theme.highest, "mono", "default"), border: border(theme.highest), margin: { - right: 9, + right: SEARCH_ROW_SPACING, }, padding: { top: 4, @@ -34,12 +48,8 @@ export default function search(): any { } return { - padding: { top: 16, bottom: 16, left: 16, right: 16 }, - // TODO: Add an activeMatchBackground on the rust side to differentiate between active and inactive - match_background: with_opacity( - foreground(theme.highest, "accent"), - 0.4 - ), + padding: { top: 4, bottom: 4 }, + option_button: toggleable({ base: interactive({ base: { @@ -50,7 +60,8 @@ export default function search(): any { corner_radius: 2, margin: { right: 2 }, border: { - width: 1., color: background(theme.highest, "on") + width: 1, + color: background(theme.highest, "on"), }, padding: { left: 4, @@ -64,14 +75,16 @@ export default function search(): any { ...text(theme.highest, "mono", "variant", "hovered"), background: background(theme.highest, "on", "hovered"), border: { - width: 1., color: background(theme.highest, "on", "hovered") + width: 1, + color: background(theme.highest, "on", "hovered"), }, }, clicked: { ...text(theme.highest, "mono", "variant", "pressed"), background: background(theme.highest, "on", "pressed"), border: { - width: 1., color: background(theme.highest, "on", "pressed") + width: 1, + color: background(theme.highest, "on", "pressed"), }, }, }, @@ -86,11 +99,19 @@ export default function search(): any { border: border(theme.highest, "accent"), }, hovered: { - background: background(theme.highest, "accent", "hovered"), + background: background( + theme.highest, + "accent", + "hovered" + ), border: border(theme.highest, "accent", "hovered"), }, clicked: { - background: background(theme.highest, "accent", "pressed"), + background: background( + theme.highest, + "accent", + "pressed" + ), border: border(theme.highest, "accent", "pressed"), }, }, @@ -107,7 +128,8 @@ export default function search(): any { corner_radius: 2, margin: { right: 2 }, border: { - width: 1., color: background(theme.highest, "on") + width: 1, + color: background(theme.highest, "on"), }, padding: { left: 4, @@ -121,14 +143,16 @@ export default function search(): any { ...text(theme.highest, "mono", "variant", "hovered"), background: background(theme.highest, "on", "hovered"), border: { - width: 1., color: background(theme.highest, "on", "hovered") + width: 1, + color: background(theme.highest, "on", "hovered"), }, }, clicked: { ...text(theme.highest, "mono", "variant", "pressed"), background: background(theme.highest, "on", "pressed"), border: { - width: 1., color: background(theme.highest, "on", "pressed") + width: 1, + color: background(theme.highest, "on", "pressed"), }, }, }, @@ -143,58 +167,43 @@ export default function search(): any { border: border(theme.highest, "accent"), }, hovered: { - background: background(theme.highest, "accent", "hovered"), + background: background( + theme.highest, + "accent", + "hovered" + ), border: border(theme.highest, "accent", "hovered"), }, clicked: { - background: background(theme.highest, "accent", "pressed"), + background: background( + theme.highest, + "accent", + "pressed" + ), border: border(theme.highest, "accent", "pressed"), }, }, }, }), + // Search tool buttons + // HACK: This is not how disabled elements should be created + // Disabled elements should use a disabled state of an interactive element, not a toggleable element with the inactive state being disabled action_button: toggleable({ - base: interactive({ - base: { - ...text(theme.highest, "mono", "disabled"), - background: background(theme.highest, "disabled"), - corner_radius: 6, - border: border(theme.highest, "disabled"), - padding: { - // bottom: 2, - left: 10, - right: 10, - // top: 2, - }, - margin: { - right: 9, - } - }, - state: { - hovered: {} - }, - }), state: { - active: interactive({ - base: { - ...text(theme.highest, "mono", "on"), - background: background(theme.highest, "on"), - border: border(theme.highest, "on"), - }, - state: { - hovered: { - ...text(theme.highest, "mono", "on", "hovered"), - background: background(theme.highest, "on", "hovered"), - border: border(theme.highest, "on", "hovered"), - }, - clicked: { - ...text(theme.highest, "mono", "on", "pressed"), - background: background(theme.highest, "on", "pressed"), - border: border(theme.highest, "on", "pressed"), - }, - }, - }) - } + inactive: text_button({ + variant: "ghost", + layer: theme.highest, + disabled: true, + margin: { right: SEARCH_ROW_SPACING }, + text_properties: { size: "sm" }, + }), + active: text_button({ + variant: "ghost", + layer: theme.highest, + margin: { right: SEARCH_ROW_SPACING }, + text_properties: { size: "sm" }, + }), + }, }), editor, invalid_editor: { @@ -207,15 +216,15 @@ export default function search(): any { border: border(theme.highest, "negative"), }, match_index: { - ...text(theme.highest, "mono", "variant"), + ...text(theme.highest, "mono", { size: "sm" }), padding: { - left: 9, + right: SEARCH_ROW_SPACING, }, }, option_button_group: { padding: { - left: 12, - right: 12, + left: SEARCH_ROW_SPACING, + right: SEARCH_ROW_SPACING, }, }, include_exclude_inputs: { @@ -232,60 +241,37 @@ export default function search(): any { ...text(theme.highest, "mono", "variant"), size: 13, }, - dismiss_button: interactive({ - base: { - color: foreground(theme.highest, "variant"), - icon_width: 14, - button_width: 32, - corner_radius: 6, - padding: { - // // top: 10, - // bottom: 10, - left: 10, - right: 10, - }, - - background: background(theme.highest, "variant"), - - border: border(theme.highest, "on"), - }, - state: { - hovered: { - color: foreground(theme.highest, "hovered"), - background: background(theme.highest, "variant", "hovered") - }, - clicked: { - color: foreground(theme.highest, "pressed"), - background: background(theme.highest, "variant", "pressed") - }, - }, - }), + // Input Icon editor_icon: { icon: { - color: foreground(theme.highest, "variant"), - asset: "icons/magnifying_glass_12.svg", + color: foreground(theme.highest, "disabled"), + asset: "icons/magnifying_glass.svg", dimensions: { - width: 12, - height: 12, - } + width: 14, + height: 14, + }, }, container: { - margin: { right: 6 }, - padding: { left: 2, right: 2 }, - } + margin: { right: 4 }, + padding: { left: 1, right: 1 }, + }, }, + // Toggle group buttons - Text | Regex | Semantic mode_button: toggleable({ base: interactive({ base: { - ...text(theme.highest, "mono", "variant"), + ...text(theme.highest, "mono", "variant", { size: "sm" }), background: background(theme.highest, "variant"), border: { ...border(theme.highest, "on"), left: false, - right: false + right: false, + }, + margin: { + top: 1, + bottom: 1, }, - padding: { left: 10, right: 10, @@ -294,13 +280,25 @@ export default function search(): any { }, state: { hovered: { - ...text(theme.highest, "mono", "variant", "hovered"), - background: background(theme.highest, "variant", "hovered"), + ...text(theme.highest, "mono", "variant", "hovered", { + size: "sm", + }), + background: background( + theme.highest, + "variant", + "hovered" + ), border: border(theme.highest, "on", "hovered"), }, clicked: { - ...text(theme.highest, "mono", "variant", "pressed"), - background: background(theme.highest, "variant", "pressed"), + ...text(theme.highest, "mono", "variant", "pressed", { + size: "sm", + }), + background: background( + theme.highest, + "variant", + "pressed" + ), border: border(theme.highest, "on", "pressed"), }, }, @@ -308,20 +306,27 @@ export default function search(): any { state: { active: { default: { - ...text(theme.highest, "mono", "on"), - background: background(theme.highest, "on") + ...text(theme.highest, "mono", "on", { size: "sm" }), + background: background(theme.highest, "on"), }, hovered: { - ...text(theme.highest, "mono", "on", "hovered"), - background: background(theme.highest, "on", "hovered") + ...text(theme.highest, "mono", "on", "hovered", { + size: "sm", + }), + background: background(theme.highest, "on", "hovered"), }, clicked: { - ...text(theme.highest, "mono", "on", "pressed"), - background: background(theme.highest, "on", "pressed") + ...text(theme.highest, "mono", "on", "pressed", { + size: "sm", + }), + background: background(theme.highest, "on", "pressed"), }, }, }, }), + // Next/Previous Match buttons + // HACK: This is not how disabled elements should be created + // Disabled elements should use a disabled state of an interactive element, not a toggleable element with the inactive state being disabled nav_button: toggleable({ state: { inactive: interactive({ @@ -334,15 +339,18 @@ export default function search(): any { left: false, right: false, }, - + margin: { + top: 1, + bottom: 1, + }, padding: { left: 10, right: 10, }, }, state: { - hovered: {} - } + hovered: {}, + }, }), active: interactive({ base: { @@ -354,7 +362,10 @@ export default function search(): any { left: false, right: false, }, - + margin: { + top: 1, + bottom: 1, + }, padding: { left: 10, right: 10, @@ -363,25 +374,30 @@ export default function search(): any { state: { hovered: { ...text(theme.highest, "mono", "on", "hovered"), - background: background(theme.highest, "on", "hovered"), + background: background( + theme.highest, + "on", + "hovered" + ), border: border(theme.highest, "on", "hovered"), }, clicked: { ...text(theme.highest, "mono", "on", "pressed"), - background: background(theme.highest, "on", "pressed"), + background: background( + theme.highest, + "on", + "pressed" + ), border: border(theme.highest, "on", "pressed"), }, }, - }) - } + }), + }, }), - search_bar_row_height: 32, + search_bar_row_height: 34, + search_row_spacing: 8, option_button_height: 22, - modes_container: { - margin: { - right: 9 - } - } - + modes_container: {}, + ...search_results(), } } diff --git a/styles/src/style_tree/status_bar.ts b/styles/src/style_tree/status_bar.ts index 2d3b81f7c2..b279bbac14 100644 --- a/styles/src/style_tree/status_bar.ts +++ b/styles/src/style_tree/status_bar.ts @@ -34,9 +34,11 @@ export default function status_bar(): any { ...text(layer, "mono", "base", { size: "xs" }), }, active_language: text_button({ - color: "base" + color: "base", + }), + auto_update_progress_message: text(layer, "sans", "base", { + size: "xs", }), - auto_update_progress_message: text(layer, "sans", "base", { size: "xs" }), auto_update_done_message: text(layer, "sans", "base", { size: "xs" }), lsp_status: interactive({ base: { @@ -73,34 +75,36 @@ export default function status_bar(): any { icon_color_error: foreground(layer, "negative"), container_ok: { corner_radius: 6, - padding: { top: 3, bottom: 3, left: 7, right: 7 }, - }, - container_warning: { - ...diagnostic_status_container, - background: background(layer, "warning"), - border: border(layer, "warning"), - }, - container_error: { - ...diagnostic_status_container, - background: background(layer, "negative"), - border: border(layer, "negative"), + padding: { top: 2, bottom: 2, left: 6, right: 6 }, }, + container_warning: diagnostic_status_container, + container_error: diagnostic_status_container }, state: { hovered: { icon_color_ok: foreground(layer, "on"), container_ok: { - background: background(layer, "on", "hovered"), + background: background(layer, "hovered") }, container_warning: { - background: background(layer, "warning", "hovered"), - border: border(layer, "warning", "hovered"), + background: background(layer, "hovered") }, container_error: { - background: background(layer, "negative", "hovered"), - border: border(layer, "negative", "hovered"), + background: background(layer, "hovered") }, }, + clicked: { + icon_color_ok: foreground(layer, "on"), + container_ok: { + background: background(layer, "pressed") + }, + container_warning: { + background: background(layer, "pressed") + }, + container_error: { + background: background(layer, "pressed") + } + } }, }), panel_buttons: { @@ -125,7 +129,7 @@ export default function status_bar(): any { }, clicked: { background: background(layer, "pressed"), - } + }, }, }), state: { diff --git a/styles/src/style_tree/tab_bar.ts b/styles/src/style_tree/tab_bar.ts index 129bd17869..23ff03a6a3 100644 --- a/styles/src/style_tree/tab_bar.ts +++ b/styles/src/style_tree/tab_bar.ts @@ -93,7 +93,7 @@ export default function tab_bar(): any { border: border(theme.lowest, "on", { bottom: true, overlay: true, - }) + }), }, state: { hovered: { @@ -101,7 +101,7 @@ export default function tab_bar(): any { background: background(theme.highest, "on", "hovered"), }, disabled: { - color: foreground(theme.highest, "on", "disabled") + color: foreground(theme.highest, "on", "disabled"), }, }, }) @@ -162,6 +162,6 @@ export default function tab_bar(): any { right: false, }, }, - nav_button: nav_button + nav_button: nav_button, } } diff --git a/styles/src/style_tree/titlebar.ts b/styles/src/style_tree/titlebar.ts index 0a0b69e596..672907b22c 100644 --- a/styles/src/style_tree/titlebar.ts +++ b/styles/src/style_tree/titlebar.ts @@ -1,8 +1,6 @@ -import { icon_button, toggleable_icon_button } from "../component/icon_button" -import { toggleable_text_button } from "../component/text_button" +import { icon_button, toggleable_icon_button, toggleable_text_button } from "../component" import { interactive, toggleable } from "../element" -import { useTheme } from "../theme" -import { with_opacity } from "../theme/color" +import { useTheme, with_opacity } from "../theme" import { background, border, foreground, text } from "./components" const ITEM_SPACING = 8 @@ -34,16 +32,17 @@ function call_controls() { } return { - toggle_microphone_button: toggleable_icon_button(theme, { + toggle_microphone_button: toggleable_icon_button({ margin: { ...margin_y, left: space.group, right: space.half_item, }, active_color: "negative", + active_background_color: "negative", }), - toggle_speakers_button: toggleable_icon_button(theme, { + toggle_speakers_button: toggleable_icon_button({ margin: { ...margin_y, left: space.half_item, @@ -51,13 +50,14 @@ function call_controls() { }, }), - screen_share_button: toggleable_icon_button(theme, { + screen_share_button: toggleable_icon_button({ margin: { ...margin_y, left: space.half_item, right: space.group, }, active_color: "accent", + active_background_color: "accent", }), muted: foreground(theme.lowest, "negative"), @@ -183,14 +183,12 @@ export function titlebar(): any { height: 400, }, - // Project - project_name_divider: text(theme.lowest, "sans", "variant"), - project_menu_button: toggleable_text_button(theme, { - color: 'base', + color: "base" }), + git_menu_button: toggleable_text_button(theme, { - color: 'variant', + color: "variant", }), // Collaborators @@ -263,7 +261,7 @@ export function titlebar(): any { ...call_controls(), - toggle_contacts_button: toggleable_icon_button(theme, { + toggle_contacts_button: toggleable_icon_button({ margin: { left: ITEM_SPACING, }, diff --git a/styles/src/style_tree/toolbar.ts b/styles/src/style_tree/toolbar.ts new file mode 100644 index 0000000000..adf8fb866f --- /dev/null +++ b/styles/src/style_tree/toolbar.ts @@ -0,0 +1,58 @@ +import { useTheme } from "../common" +import { toggleable_icon_button } from "../component/icon_button" +import { interactive, toggleable } from "../element" +import { background, border, foreground, text } from "./components" +import { text_button } from "../component"; + +export const toolbar = () => { + const theme = useTheme() + + return { + height: 42, + padding: { left: 4, right: 4 }, + background: background(theme.highest), + border: border(theme.highest, { bottom: true }), + item_spacing: 4, + toggleable_tool: toggleable_icon_button({ + margin: { left: 4 }, + variant: "ghost", + active_color: "accent", + }), + breadcrumb_height: 24, + breadcrumbs: interactive({ + base: { + ...text(theme.highest, "sans", "variant"), + corner_radius: 6, + padding: { + left: 6, + right: 6, + }, + }, + state: { + hovered: { + color: foreground(theme.highest, "on", "hovered"), + background: background(theme.highest, "on", "hovered"), + }, + }, + }), + toggleable_text_tool: toggleable({ + state: { + inactive: text_button({ + disabled: true, + variant: "ghost", + layer: theme.highest, + margin: { left: 4 }, + text_properties: { size: "sm" }, + border: border(theme.middle), + }), + active: text_button({ + variant: "ghost", + layer: theme.highest, + margin: { left: 4 }, + text_properties: { size: "sm" }, + border: border(theme.middle), + }), + } + }), + } +} diff --git a/styles/src/style_tree/workspace.ts b/styles/src/style_tree/workspace.ts index ecfb572f7e..ba89c7b05f 100644 --- a/styles/src/style_tree/workspace.ts +++ b/styles/src/style_tree/workspace.ts @@ -12,7 +12,7 @@ import tabBar from "./tab_bar" import { interactive } from "../element" import { titlebar } from "./titlebar" import { useTheme } from "../theme" -import { toggleable_icon_button } from "../component/icon_button" +import { toolbar } from "./toolbar" export default function workspace(): any { const theme = useTheme() @@ -128,35 +128,7 @@ export default function workspace(): any { }, status_bar: statusBar(), titlebar: titlebar(), - toolbar: { - height: 34, - background: background(theme.highest), - border: border(theme.highest, { bottom: true }), - item_spacing: 8, - toggleable_tool: toggleable_icon_button(theme, { - margin: { left: 8 }, - variant: "ghost", - active_color: "accent", - }), - padding: { left: 8, right: 8, top: 4, bottom: 4 }, - }, - breadcrumb_height: 24, - breadcrumbs: interactive({ - base: { - ...text(theme.highest, "sans", "variant"), - corner_radius: 6, - padding: { - left: 6, - right: 6, - }, - }, - state: { - hovered: { - color: foreground(theme.highest, "on", "hovered"), - background: background(theme.highest, "on", "hovered"), - }, - }, - }), + toolbar: toolbar(), disconnected_overlay: { ...text(theme.lowest, "sans"), background: with_opacity(background(theme.lowest), 0.8), diff --git a/styles/src/theme/create_theme.ts b/styles/src/theme/create_theme.ts index ab3c96f280..61471616fb 100644 --- a/styles/src/theme/create_theme.ts +++ b/styles/src/theme/create_theme.ts @@ -13,16 +13,16 @@ export interface Theme { is_light: boolean /** - * App background, other elements that should sit directly on top of the background. - */ + * App background, other elements that should sit directly on top of the background. + */ lowest: Layer /** - * Panels, tabs, other UI surfaces that sit on top of the background. - */ + * Panels, tabs, other UI surfaces that sit on top of the background. + */ middle: Layer /** - * Editors like code buffers, conversation editors, etc. - */ + * Editors like code buffers, conversation editors, etc. + */ highest: Layer ramps: RampSet @@ -206,7 +206,10 @@ function build_color_family(ramps: RampSet): ColorFamily { for (const ramp in ramps) { const ramp_value = ramps[ramp as keyof RampSet] - const lightnessValues = [ramp_value(0).get('hsl.l') * 100, ramp_value(1).get('hsl.l') * 100] + const lightnessValues = [ + ramp_value(0).get("hsl.l") * 100, + ramp_value(1).get("hsl.l") * 100, + ] const low = Math.min(...lightnessValues) const high = Math.max(...lightnessValues) const range = high - low diff --git a/styles/src/theme/index.ts b/styles/src/theme/index.ts index ca8aaa461f..47110940f5 100644 --- a/styles/src/theme/index.ts +++ b/styles/src/theme/index.ts @@ -23,3 +23,4 @@ export * from "./create_theme" export * from "./ramps" export * from "./syntax" export * from "./theme_config" +export * from "./color" diff --git a/styles/src/theme/tokens/theme.ts b/styles/src/theme/tokens/theme.ts index f759bc8139..e2c3bb33d3 100644 --- a/styles/src/theme/tokens/theme.ts +++ b/styles/src/theme/tokens/theme.ts @@ -4,11 +4,7 @@ import { SingleOtherToken, TokenTypes, } from "@tokens-studio/types" -import { - Shadow, - SyntaxHighlightStyle, - ThemeSyntax, -} from "../create_theme" +import { Shadow, SyntaxHighlightStyle, ThemeSyntax } from "../create_theme" import { LayerToken, layer_token } from "./layer" import { PlayersToken, players_token } from "./players" import { color_token } from "./token" diff --git a/styles/tsconfig.json b/styles/tsconfig.json index 281bd74b21..940442e1b7 100644 --- a/styles/tsconfig.json +++ b/styles/tsconfig.json @@ -21,10 +21,7 @@ "experimentalDecorators": true, "strictPropertyInitialization": false, "skipLibCheck": true, - "useUnknownInCatchVariables": false, - "baseUrl": "." + "useUnknownInCatchVariables": false }, - "exclude": [ - "node_modules" - ] + "exclude": ["node_modules"] }