From 2d411303bb3b034e06bed2d4bba4b1ce275736da Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 29 Aug 2023 10:07:22 -0600 Subject: [PATCH 01/60] Use preview server when not on stable --- crates/client/src/client.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index a32c415f7e..d28c1ab1a9 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1011,9 +1011,9 @@ impl Client { credentials: &Credentials, cx: &AsyncAppContext, ) -> Task> { - let is_preview = cx.read(|cx| { + let use_preview_server = cx.read(|cx| { if cx.has_global::() { - *cx.global::() == ReleaseChannel::Preview + *cx.global::() != ReleaseChannel::Stable } else { false } @@ -1028,7 +1028,7 @@ impl Client { let http = self.http.clone(); cx.background().spawn(async move { - let mut rpc_url = Self::get_rpc_url(http, is_preview).await?; + let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?; let rpc_host = rpc_url .host_str() .zip(rpc_url.port_or_known_default()) From 4f8b95cf0d99955555b6b086bed7c3153cd5bc92 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 29 Aug 2023 15:44:51 -0400 Subject: [PATCH 02/60] add proper handling for open ai rate limit delays --- Cargo.lock | 65 ++++++++++++++++- crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/embedding.rs | 96 ++++++++++++++++---------- 3 files changed, 124 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 347976691d..e0eb1947e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3532,7 +3532,7 @@ dependencies = [ "gif", "jpeg-decoder", "num-iter", - "num-rational", + "num-rational 0.3.2", "num-traits", "png", "scoped_threadpool", @@ -4625,6 +4625,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +dependencies = [ + "num-bigint 0.2.6", + "num-complex", + "num-integer", + "num-iter", + "num-rational 0.2.4", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -4653,6 +4678,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -4685,6 +4720,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +dependencies = [ + "autocfg", + "num-bigint 0.2.6", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.3.2" @@ -5001,6 +5048,17 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parse_duration" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d" +dependencies = [ + "lazy_static", + "num", + "regex", +] + [[package]] name = "password-hash" version = "0.2.3" @@ -6667,6 +6725,7 @@ dependencies = [ "log", "matrixmultiply", "parking_lot 0.11.2", + "parse_duration", "picker", "postage", "pretty_assertions", @@ -6998,7 +7057,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80" dependencies = [ "chrono", - "num-bigint", + "num-bigint 0.4.4", "num-traits", "thiserror", ] @@ -7230,7 +7289,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint", + "num-bigint 0.4.4", "once_cell", "paste", "percent-encoding", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 4e817fcbe2..d46346e0ab 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,6 +39,7 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" +parse_duration = "2.1.1" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index f2269a786a..a9cb0245c4 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -7,6 +7,7 @@ use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; +use parse_duration::parse; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -84,10 +85,15 @@ impl OpenAIEmbeddings { span } - async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result> { + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) - .timeout(Duration::from_secs(4)) + .timeout(Duration::from_secs(request_timeout)) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( @@ -114,45 +120,23 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; + let mut request_timeout: u64 = 10; let mut truncated = false; let mut response: Response; let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self - .send_request(api_key, spans.iter().map(|x| &**x).collect()) + .send_request( + api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) .await?; request_number += 1; - if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK { - return Err(anyhow!( - "openai max retries, error: {:?}", - &response.status() - )); - } - match response.status() { - StatusCode::TOO_MANY_REQUESTS => { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - log::trace!( - "open ai rate limiting, delaying request by {:?} seconds", - delay.as_secs() - ); - self.executor.timer(delay).await; - } - StatusCode::BAD_REQUEST => { - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - } else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - log::trace!("open ai bad request: {:?} {:?}", &response.status(), body); - break; - } + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; } StatusCode::OK => { let mut body = String::new(); @@ -163,18 +147,60 @@ impl EmbeddingProvider for OpenAIEmbeddings { "openai embedding completed. tokens: {:?}", response.usage.total_tokens ); + return Ok(response .data .into_iter() .map(|embedding| embedding.embedding) .collect()); } + StatusCode::TOO_MANY_REQUESTS => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } _ => { - return Err(anyhow!("openai embedding failed {}", response.status())); + // TODO: Move this to parsing step + // Only truncate if it hasnt been truncated before + if !truncated { + for span in spans.iter_mut() { + *span = Self::truncate(span.clone()); + } + truncated = true; + } else { + // If failing once already truncated, log the error and break the loop + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } } } } - - Err(anyhow!("openai embedding failed")) + Err(anyhow!("openai max retries")) } } From a7e6a65debbe032edfd180e88f6be545edf89281 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Tue, 29 Aug 2023 17:14:44 -0400 Subject: [PATCH 03/60] reindex files in the background after they have not been edited for 10 minutes Co-authored-by: Max --- crates/semantic_index/src/semantic_index.rs | 416 +++++++++----------- 1 file changed, 188 insertions(+), 228 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 736f2c98a8..2da0d84baf 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -16,16 +16,18 @@ use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; -use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId}; +use project::{ + search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, +}; use smol::channel; use std::{ cmp::Ordering, - collections::HashMap, + collections::{BTreeMap, HashMap}, mem, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, - time::{Instant, SystemTime}, + time::{Duration, Instant, SystemTime}, }; use util::{ channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}, @@ -37,6 +39,7 @@ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 7; const EMBEDDINGS_BATCH_SIZE: usize = 80; +const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( fs: Arc, @@ -77,6 +80,7 @@ pub fn init( let semantic_index = SemanticIndex::new( fs, db_file_path, + // Arc::new(embedding::DummyEmbeddings {}), Arc::new(OpenAIEmbeddings { client: http_client, executor: cx.background(), @@ -113,9 +117,14 @@ struct ProjectState { worktree_db_ids: Vec<(WorktreeId, i64)>, _subscription: gpui::Subscription, outstanding_job_count_rx: watch::Receiver, - _outstanding_job_count_tx: Arc>>, - job_queue_tx: channel::Sender, - _queue_update_task: Task<()>, + outstanding_job_count_tx: Arc>>, + changed_paths: BTreeMap, +} + +struct ChangedPathInfo { + changed_at: Instant, + mtime: SystemTime, + is_deleted: bool, } #[derive(Clone)] @@ -133,31 +142,21 @@ impl JobHandle { } } } + impl ProjectState { fn new( - cx: &mut AppContext, subscription: gpui::Subscription, worktree_db_ids: Vec<(WorktreeId, i64)>, - outstanding_job_count_rx: watch::Receiver, - _outstanding_job_count_tx: Arc>>, + changed_paths: BTreeMap, ) -> Self { - let (job_queue_tx, job_queue_rx) = channel::unbounded(); - let _queue_update_task = cx.background().spawn({ - let mut worktree_queue = HashMap::new(); - async move { - while let Ok(operation) = job_queue_rx.recv().await { - Self::update_queue(&mut worktree_queue, operation); - } - } - }); - + let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0); + let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx)); Self { worktree_db_ids, outstanding_job_count_rx, - _outstanding_job_count_tx, + outstanding_job_count_tx, + changed_paths, _subscription: subscription, - _queue_update_task, - job_queue_tx, } } @@ -165,41 +164,6 @@ impl ProjectState { self.outstanding_job_count_rx.borrow().clone() } - fn update_queue(queue: &mut HashMap, operation: IndexOperation) { - match operation { - IndexOperation::FlushQueue => { - let queue = std::mem::take(queue); - for (_, op) in queue { - match op { - IndexOperation::IndexFile { - absolute_path: _, - payload, - tx, - } => { - let _ = tx.try_send(payload); - } - IndexOperation::DeleteFile { - absolute_path: _, - payload, - tx, - } => { - let _ = tx.try_send(payload); - } - _ => {} - } - } - } - IndexOperation::IndexFile { - ref absolute_path, .. - } - | IndexOperation::DeleteFile { - ref absolute_path, .. - } => { - queue.insert(absolute_path.clone(), operation); - } - } - } - fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option { self.worktree_db_ids .iter() @@ -230,23 +194,10 @@ pub struct PendingFile { worktree_db_id: i64, relative_path: PathBuf, absolute_path: PathBuf, - language: Arc, + language: Option>, modified_time: SystemTime, job_handle: JobHandle, } -enum IndexOperation { - IndexFile { - absolute_path: PathBuf, - payload: PendingFile, - tx: channel::Sender, - }, - DeleteFile { - absolute_path: PathBuf, - payload: DbOperation, - tx: channel::Sender, - }, - FlushQueue, -} pub struct SearchResult { pub buffer: ModelHandle, @@ -582,13 +533,13 @@ impl SemanticIndex { parsing_files_rx: &channel::Receiver, db_update_tx: &channel::Sender, ) { + let Some(language) = pending_file.language else { + return; + }; + if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { if let Some(documents) = retriever - .parse_file_with_template( - &pending_file.relative_path, - &content, - pending_file.language, - ) + .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { log::trace!( @@ -679,103 +630,50 @@ impl SemanticIndex { } fn project_entries_changed( - &self, + &mut self, project: ModelHandle, changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, cx: &mut ModelContext<'_, SemanticIndex>, worktree_id: &WorktreeId, - ) -> Result<()> { - let parsing_files_tx = self.parsing_files_tx.clone(); - let db_update_tx = self.db_update_tx.clone(); - let (job_queue_tx, outstanding_job_tx, worktree_db_id) = { - let state = self - .projects - .get(&project.downgrade()) - .ok_or(anyhow!("Project not yet initialized"))?; - let worktree_db_id = state - .db_id_for_worktree_id(*worktree_id) - .ok_or(anyhow!("Worktree ID in Database Not Available"))?; - ( - state.job_queue_tx.clone(), - state._outstanding_job_count_tx.clone(), - worktree_db_id, - ) + ) { + let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else { + return; + }; + let project = project.downgrade(); + let Some(project_state) = self.projects.get_mut(&project) else { + return; }; - let language_registry = self.language_registry.clone(); - let parsing_files_tx = parsing_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); + let worktree = worktree.read(cx); + let change_time = Instant::now(); + for (path, entry_id, change) in changes.iter() { + let Some(entry) = worktree.entry_for_id(*entry_id) else { + continue; + }; + if entry.is_ignored || entry.is_symlink || entry.is_external { + continue; + } + let project_path = ProjectPath { + worktree_id: *worktree_id, + path: path.clone(), + }; + project_state.changed_paths.insert( + project_path, + ChangedPathInfo { + changed_at: change_time, + mtime: entry.mtime, + is_deleted: *change == PathChange::Removed, + }, + ); + } - let worktree = project - .read(cx) - .worktree_for_id(worktree_id.clone(), cx) - .ok_or(anyhow!("Worktree not available"))? - .read(cx) - .snapshot(); - cx.spawn(|_, _| async move { - let worktree = worktree.clone(); - for (path, entry_id, path_change) in changes.iter() { - let relative_path = path.to_path_buf(); - let absolute_path = worktree.absolutize(path); - - let Some(entry) = worktree.entry_for_id(*entry_id) else { - continue; - }; - if entry.is_ignored || entry.is_symlink || entry.is_external { - continue; - } - - log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path); - match path_change { - PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => { - if let Ok(language) = language_registry - .language_for_file(&relative_path, None) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } - - let job_handle = JobHandle::new(&outstanding_job_tx); - let new_operation = IndexOperation::IndexFile { - absolute_path: absolute_path.clone(), - payload: PendingFile { - worktree_db_id, - relative_path, - absolute_path, - language, - modified_time: entry.mtime, - job_handle, - }, - tx: parsing_files_tx.clone(), - }; - let _ = job_queue_tx.try_send(new_operation); - } - } - PathChange::Removed => { - let new_operation = IndexOperation::DeleteFile { - absolute_path, - payload: DbOperation::Delete { - worktree_id: worktree_db_id, - path: relative_path, - }, - tx: db_update_tx.clone(), - }; - let _ = job_queue_tx.try_send(new_operation); - } - _ => {} - } + cx.spawn_weak(|this, mut cx| async move { + cx.background().timer(BACKGROUND_INDEXING_DELAY).await; + if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { + Self::reindex_changed_paths(this, project, Some(change_time), &mut cx).await; } }) .detach(); - - Ok(()) } pub fn initialize_project( @@ -805,14 +703,11 @@ impl SemanticIndex { let _subscription = cx.subscribe(&project, |this, project, event, cx| { if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - let _ = - this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); + this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); }; }); let language_registry = self.language_registry.clone(); - let parsing_files_tx = self.parsing_files_tx.clone(); - let db_update_tx = self.db_update_tx.clone(); cx.spawn(|this, mut cx| async move { futures::future::join_all(worktree_scans_complete).await; @@ -843,17 +738,13 @@ impl SemanticIndex { .map(|(a, b)| (*a, *b)) .collect(); - let (job_count_tx, job_count_rx) = watch::channel_with(0); - let job_count_tx = Arc::new(Mutex::new(job_count_tx)); - let job_count_tx_longlived = job_count_tx.clone(); - - let worktree_files = cx + let changed_paths = cx .background() .spawn(async move { - let mut worktree_files = Vec::new(); + let mut changed_paths = BTreeMap::new(); + let now = Instant::now(); for worktree in worktrees.into_iter() { let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap(); - let worktree_db_id = db_ids_by_worktree_id[&worktree.id()]; for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -876,59 +767,51 @@ impl SemanticIndex { continue; } - let path_buf = file.path.to_path_buf(); let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); let already_stored = stored_mtime .map_or(false, |existing_mtime| existing_mtime == file.mtime); if !already_stored { - let job_handle = JobHandle::new(&job_count_tx); - worktree_files.push(IndexOperation::IndexFile { - absolute_path: absolute_path.clone(), - payload: PendingFile { - worktree_db_id, - relative_path: path_buf, - absolute_path, - language, - job_handle, - modified_time: file.mtime, + changed_paths.insert( + ProjectPath { + worktree_id: worktree.id(), + path: file.path.clone(), }, - tx: parsing_files_tx.clone(), - }); + ChangedPathInfo { + changed_at: now, + mtime: file.mtime, + is_deleted: false, + }, + ); } } } + // Clean up entries from database that are no longer in the worktree. - for (path, _) in file_mtimes { - worktree_files.push(IndexOperation::DeleteFile { - absolute_path: worktree.absolutize(path.as_path()), - payload: DbOperation::Delete { - worktree_id: worktree_db_id, - path, + for (path, mtime) in file_mtimes { + changed_paths.insert( + ProjectPath { + worktree_id: worktree.id(), + path: path.into(), }, - tx: db_update_tx.clone(), - }); + ChangedPathInfo { + changed_at: now, + mtime, + is_deleted: true, + }, + ); } } - anyhow::Ok(worktree_files) + anyhow::Ok(changed_paths) }) .await?; - this.update(&mut cx, |this, cx| { - let project_state = ProjectState::new( - cx, - _subscription, - worktree_db_ids, - job_count_rx, - job_count_tx_longlived, + this.update(&mut cx, |this, _| { + this.projects.insert( + project.downgrade(), + ProjectState::new(_subscription, worktree_db_ids, changed_paths), ); - - for op in worktree_files { - let _ = project_state.job_queue_tx.try_send(op); - } - - this.projects.insert(project.downgrade(), project_state); }); Result::<(), _>::Ok(()) }) @@ -939,27 +822,17 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task)>> { - let state = self.projects.get_mut(&project.downgrade()); - let state = if state.is_none() { - return Task::Ready(Some(Err(anyhow!("Project not yet initialized")))); - } else { - state.unwrap() - }; - - // let parsing_files_tx = self.parsing_files_tx.clone(); - // let db_update_tx = self.db_update_tx.clone(); - let job_count_rx = state.outstanding_job_count_rx.clone(); - let count = state.get_outstanding_count(); - cx.spawn(|this, mut cx| async move { - this.update(&mut cx, |this, _| { - let Some(state) = this.projects.get_mut(&project.downgrade()) else { - return; - }; - let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue); - }); + Self::reindex_changed_paths(this.clone(), project.clone(), None, &mut cx).await; - Ok((count, job_count_rx)) + this.update(&mut cx, |this, _cx| { + let Some(state) = this.projects.get(&project.downgrade()) else { + return Err(anyhow!("Project not yet initialized")); + }; + let job_count_rx = state.outstanding_job_count_rx.clone(); + let count = state.get_outstanding_count(); + Ok((count, job_count_rx)) + }) }) } @@ -1110,6 +983,93 @@ impl SemanticIndex { .collect::>()) }) } + + async fn reindex_changed_paths( + this: ModelHandle, + project: ModelHandle, + last_changed_before: Option, + cx: &mut AsyncAppContext, + ) { + let mut pending_files = Vec::new(); + let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| { + if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { + let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; + let db_ids = &project_state.worktree_db_ids; + let mut worktree: Option> = None; + + project_state.changed_paths.retain(|path, info| { + if let Some(last_changed_before) = last_changed_before { + if info.changed_at > last_changed_before { + return true; + } + } + + if worktree + .as_ref() + .map_or(true, |tree| tree.read(cx).id() != path.worktree_id) + { + worktree = project.read(cx).worktree_for_id(path.worktree_id, cx); + } + let Some(worktree) = &worktree else { + return false; + }; + + let Some(worktree_db_id) = db_ids + .iter() + .find_map(|entry| (entry.0 == path.worktree_id).then_some(entry.1)) + else { + return false; + }; + + if info.is_deleted { + this.db_update_tx + .try_send(DbOperation::Delete { + worktree_id: worktree_db_id, + path: path.path.to_path_buf(), + }) + .ok(); + } else { + let absolute_path = worktree.read(cx).absolutize(&path.path); + let job_handle = JobHandle::new(&outstanding_job_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.path.to_path_buf(), + language: None, + job_handle, + modified_time: info.mtime, + worktree_db_id, + }); + } + + false + }); + } + + ( + this.language_registry.clone(), + this.parsing_files_tx.clone(), + ) + }); + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + pending_file.language = Some(language); + } + parsing_files_tx.try_send(pending_file).ok(); + } + } } impl Entity for SemanticIndex { From d3650594c386e2b96958a0fb552e5ad322a6df30 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 28 Aug 2023 11:47:37 -0600 Subject: [PATCH 04/60] Fix find_{,preceding}boundary to work on buffer text Before this change the bounday could mistakenly have happened on a soft line wrap. Also fixes interaction with inlays better. --- crates/editor/src/movement.rs | 174 ++++++------------ crates/vim/src/motion.rs | 26 ++- crates/vim/src/normal.rs | 2 +- crates/vim/src/normal/change.rs | 20 +- crates/vim/src/object.rs | 30 ++- crates/vim/src/test.rs | 18 ++ .../src/test/neovim_backed_test_context.rs | 7 +- crates/vim/test_data/test_end_of_word.json | 32 ++++ .../test_data/test_visual_word_object.json | 6 +- crates/vim/test_data/test_wrapped_lines.json | 5 + 10 files changed, 174 insertions(+), 146 deletions(-) create mode 100644 crates/vim/test_data/test_end_of_word.json diff --git a/crates/editor/src/movement.rs b/crates/editor/src/movement.rs index def6340e38..915da7b23f 100644 --- a/crates/editor/src/movement.rs +++ b/crates/editor/src/movement.rs @@ -1,8 +1,14 @@ use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint}; -use crate::{char_kind, CharKind, ToPoint}; +use crate::{char_kind, CharKind, ToOffset, ToPoint}; use language::Point; use std::ops::Range; +#[derive(Debug, PartialEq)] +pub enum FindRange { + SingleLine, + MultiLine, +} + pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint { if point.column() > 0 { *point.column_mut() -= 1; @@ -179,7 +185,7 @@ pub fn previous_word_start(map: &DisplaySnapshot, point: DisplayPoint) -> Displa let raw_point = point.to_point(map); let language = map.buffer_snapshot.language_at(raw_point); - find_preceding_boundary(map, point, |left, right| { + find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| { (char_kind(language, left) != char_kind(language, right) && !right.is_whitespace()) || left == '\n' }) @@ -188,7 +194,7 @@ pub fn previous_word_start(map: &DisplaySnapshot, point: DisplayPoint) -> Displa pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); let language = map.buffer_snapshot.language_at(raw_point); - find_preceding_boundary(map, point, |left, right| { + find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| { let is_word_start = char_kind(language, left) != char_kind(language, right) && !right.is_whitespace(); let is_subword_start = @@ -200,7 +206,7 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); let language = map.buffer_snapshot.language_at(raw_point); - find_boundary(map, point, |left, right| { + find_boundary(map, point, FindRange::MultiLine, |left, right| { (char_kind(language, left) != char_kind(language, right) && !left.is_whitespace()) || right == '\n' }) @@ -209,7 +215,7 @@ pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint { let raw_point = point.to_point(map); let language = map.buffer_snapshot.language_at(raw_point); - find_boundary(map, point, |left, right| { + find_boundary(map, point, FindRange::MultiLine, |left, right| { let is_word_end = (char_kind(language, left) != char_kind(language, right)) && !left.is_whitespace(); let is_subword_end = @@ -272,79 +278,34 @@ pub fn end_of_paragraph( map.max_point() } -/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the -/// given predicate returning true. The predicate is called with the character to the left and right -/// of the candidate boundary location, and will be called with `\n` characters indicating the start -/// or end of a line. +/// Scans for a boundary preceding the given start point `from` until a boundary is found, +/// indicated by the given predicate returning true. +/// The predicate is called with the character to the left and right of the candidate boundary location. +/// If FindRange::SingleLine is specified and no boundary is found before the start of the current line, the start of the current line will be returned. pub fn find_preceding_boundary( map: &DisplaySnapshot, from: DisplayPoint, + find_range: FindRange, mut is_boundary: impl FnMut(char, char) -> bool, ) -> DisplayPoint { - let mut start_column = 0; - let mut soft_wrap_row = from.row() + 1; + let mut prev_ch = None; + let mut offset = from.to_point(map).to_offset(&map.buffer_snapshot); - let mut prev = None; - for (ch, point) in map.reverse_chars_at(from) { - // Recompute soft_wrap_indent if the row has changed - if point.row() != soft_wrap_row { - soft_wrap_row = point.row(); - - if point.row() == 0 { - start_column = 0; - } else if let Some(indent) = map.soft_wrap_indent(point.row() - 1) { - start_column = indent; - } - } - - // If the current point is in the soft_wrap, skip comparing it - if point.column() < start_column { - continue; - } - - if let Some((prev_ch, prev_point)) = prev { - if is_boundary(ch, prev_ch) { - return map.clip_point(prev_point, Bias::Left); - } - } - - prev = Some((ch, point)); - } - map.clip_point(DisplayPoint::zero(), Bias::Left) -} - -/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the -/// given predicate returning true. The predicate is called with the character to the left and right -/// of the candidate boundary location, and will be called with `\n` characters indicating the start -/// or end of a line. If no boundary is found, the start of the line is returned. -pub fn find_preceding_boundary_in_line( - map: &DisplaySnapshot, - from: DisplayPoint, - mut is_boundary: impl FnMut(char, char) -> bool, -) -> DisplayPoint { - let mut start_column = 0; - if from.row() > 0 { - if let Some(indent) = map.soft_wrap_indent(from.row() - 1) { - start_column = indent; - } - } - - let mut prev = None; - for (ch, point) in map.reverse_chars_at(from) { - if let Some((prev_ch, prev_point)) = prev { - if is_boundary(ch, prev_ch) { - return map.clip_point(prev_point, Bias::Left); - } - } - - if ch == '\n' || point.column() < start_column { + for ch in map.buffer_snapshot.reversed_chars_at(offset) { + if find_range == FindRange::SingleLine && ch == '\n' { break; } + if let Some(prev_ch) = prev_ch { + if is_boundary(ch, prev_ch) { + break; + } + } - prev = Some((ch, point)); + offset -= ch.len_utf8(); + prev_ch = Some(ch); } - map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Left) + map.clip_point(offset.to_display_point(map), Bias::Left) } /// Scans for a boundary following the given start point until a boundary is found, indicated by the @@ -354,47 +315,26 @@ pub fn find_preceding_boundary_in_line( pub fn find_boundary( map: &DisplaySnapshot, from: DisplayPoint, + find_range: FindRange, mut is_boundary: impl FnMut(char, char) -> bool, ) -> DisplayPoint { + let mut offset = from.to_offset(&map, Bias::Right); let mut prev_ch = None; - for (ch, point) in map.chars_at(from) { - if let Some(prev_ch) = prev_ch { - if is_boundary(prev_ch, ch) { - return map.clip_point(point, Bias::Right); - } - } - prev_ch = Some(ch); - } - map.clip_point(map.max_point(), Bias::Right) -} - -/// Scans for a boundary following the given start point until a boundary is found, indicated by the -/// given predicate returning true. The predicate is called with the character to the left and right -/// of the candidate boundary location, and will be called with `\n` characters indicating the start -/// or end of a line. If no boundary is found, the end of the line is returned -pub fn find_boundary_in_line( - map: &DisplaySnapshot, - from: DisplayPoint, - mut is_boundary: impl FnMut(char, char) -> bool, -) -> DisplayPoint { - let mut prev = None; - for (ch, point) in map.chars_at(from) { - if let Some((prev_ch, _)) = prev { - if is_boundary(prev_ch, ch) { - return map.clip_point(point, Bias::Right); - } - } - - prev = Some((ch, point)); - - if ch == '\n' { + for ch in map.buffer_snapshot.chars_at(offset) { + if find_range == FindRange::SingleLine && ch == '\n' { break; } - } + if let Some(prev_ch) = prev_ch { + if is_boundary(prev_ch, ch) { + break; + } + } - // Return the last position checked so that we give a point right before the newline or eof. - map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Right) + offset += ch.len_utf8(); + prev_ch = Some(ch); + } + map.clip_point(offset.to_display_point(map), Bias::Right) } pub fn is_inside_word(map: &DisplaySnapshot, point: DisplayPoint) -> bool { @@ -533,7 +473,12 @@ mod tests { ) { let (snapshot, display_points) = marked_display_snapshot(marked_text, cx); assert_eq!( - find_preceding_boundary(&snapshot, display_points[1], is_boundary), + find_preceding_boundary( + &snapshot, + display_points[1], + FindRange::MultiLine, + is_boundary + ), display_points[0] ); } @@ -612,21 +557,15 @@ mod tests { find_preceding_boundary( &snapshot, buffer_snapshot.len().to_display_point(&snapshot), - |left, _| left == 'a', + FindRange::MultiLine, + |left, _| left == 'e', ), - 0.to_display_point(&snapshot), + snapshot + .buffer_snapshot + .offset_to_point(5) + .to_display_point(&snapshot), "Should not stop at inlays when looking for boundaries" ); - - assert_eq!( - find_preceding_boundary_in_line( - &snapshot, - buffer_snapshot.len().to_display_point(&snapshot), - |left, _| left == 'a', - ), - 0.to_display_point(&snapshot), - "Should not stop at inlays when looking for boundaries in line" - ); } #[gpui::test] @@ -699,7 +638,12 @@ mod tests { ) { let (snapshot, display_points) = marked_display_snapshot(marked_text, cx); assert_eq!( - find_boundary(&snapshot, display_points[0], is_boundary), + find_boundary( + &snapshot, + display_points[0], + FindRange::MultiLine, + is_boundary + ), display_points[1] ); } diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index 0d3fb700ef..6f28430796 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -3,7 +3,8 @@ use std::{cmp, sync::Arc}; use editor::{ char_kind, display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint}, - movement, Bias, CharKind, DisplayPoint, ToOffset, + movement::{self, FindRange}, + Bias, CharKind, DisplayPoint, ToOffset, }; use gpui::{actions, impl_actions, AppContext, WindowContext}; use language::{Point, Selection, SelectionGoal}; @@ -592,7 +593,7 @@ pub(crate) fn next_word_start( let language = map.buffer_snapshot.language_at(point.to_point(map)); for _ in 0..times { let mut crossed_newline = false; - point = movement::find_boundary(map, point, |left, right| { + point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| { let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); let at_newline = right == '\n'; @@ -616,8 +617,14 @@ fn next_word_end( ) -> DisplayPoint { let language = map.buffer_snapshot.language_at(point.to_point(map)); for _ in 0..times { - *point.column_mut() += 1; - point = movement::find_boundary(map, point, |left, right| { + if point.column() < map.line_len(point.row()) { + *point.column_mut() += 1; + } else if point.row() < map.max_buffer_row() { + *point.row_mut() += 1; + *point.column_mut() = 0; + } + // *point.column_mut() += 1; + point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| { let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); @@ -649,12 +656,13 @@ fn previous_word_start( for _ in 0..times { // This works even though find_preceding_boundary is called for every character in the line containing // cursor because the newline is checked only once. - point = movement::find_preceding_boundary(map, point, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + point = + movement::find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| { + let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); + let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); - (left_kind != right_kind && !right.is_whitespace()) || left == '\n' - }); + (left_kind != right_kind && !right.is_whitespace()) || left == '\n' + }); } point } diff --git a/crates/vim/src/normal.rs b/crates/vim/src/normal.rs index a73c518809..c8e623e4c1 100644 --- a/crates/vim/src/normal.rs +++ b/crates/vim/src/normal.rs @@ -445,7 +445,7 @@ mod test { } #[gpui::test] - async fn test_e(cx: &mut gpui::TestAppContext) { + async fn test_end_of_word(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await.binding(["e"]); cx.assert_all(indoc! {" Thˇe quicˇkˇ-browˇn diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 5591de89c6..6e64b050d1 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -1,7 +1,10 @@ use crate::{motion::Motion, object::Object, state::Mode, utils::copy_selections_content, Vim}; use editor::{ - char_kind, display_map::DisplaySnapshot, movement, scroll::autoscroll::Autoscroll, CharKind, - DisplayPoint, + char_kind, + display_map::DisplaySnapshot, + movement::{self, FindRange}, + scroll::autoscroll::Autoscroll, + CharKind, DisplayPoint, }; use gpui::WindowContext; use language::Selection; @@ -96,12 +99,15 @@ fn expand_changed_word_selection( .unwrap_or_default(); if in_word { - selection.end = movement::find_boundary(map, selection.end, |left, right| { - let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); - let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); + selection.end = + movement::find_boundary(map, selection.end, FindRange::MultiLine, |left, right| { + let left_kind = + char_kind(language, left).coerce_punctuation(ignore_punctuation); + let right_kind = + char_kind(language, right).coerce_punctuation(ignore_punctuation); - left_kind != right_kind && left_kind != CharKind::Whitespace - }); + left_kind != right_kind && left_kind != CharKind::Whitespace + }); true } else { Motion::NextWordStart { ignore_punctuation } diff --git a/crates/vim/src/object.rs b/crates/vim/src/object.rs index dd922e7af6..94906a1e80 100644 --- a/crates/vim/src/object.rs +++ b/crates/vim/src/object.rs @@ -1,6 +1,11 @@ use std::ops::Range; -use editor::{char_kind, display_map::DisplaySnapshot, movement, Bias, CharKind, DisplayPoint}; +use editor::{ + char_kind, + display_map::DisplaySnapshot, + movement::{self, FindRange}, + Bias, CharKind, DisplayPoint, +}; use gpui::{actions, impl_actions, AppContext, WindowContext}; use language::Selection; use serde::Deserialize; @@ -178,15 +183,16 @@ fn in_word( ) -> Option> { // Use motion::right so that we consider the character under the cursor when looking for the start let language = map.buffer_snapshot.language_at(relative_to.to_point(map)); - let start = movement::find_preceding_boundary_in_line( + let start = movement::find_preceding_boundary( map, right(map, relative_to, 1), + movement::FindRange::SingleLine, |left, right| { char_kind(language, left).coerce_punctuation(ignore_punctuation) != char_kind(language, right).coerce_punctuation(ignore_punctuation) }, ); - let end = movement::find_boundary_in_line(map, relative_to, |left, right| { + let end = movement::find_boundary(map, relative_to, FindRange::SingleLine, |left, right| { char_kind(language, left).coerce_punctuation(ignore_punctuation) != char_kind(language, right).coerce_punctuation(ignore_punctuation) }); @@ -241,9 +247,10 @@ fn around_next_word( ) -> Option> { let language = map.buffer_snapshot.language_at(relative_to.to_point(map)); // Get the start of the word - let start = movement::find_preceding_boundary_in_line( + let start = movement::find_preceding_boundary( map, right(map, relative_to, 1), + FindRange::SingleLine, |left, right| { char_kind(language, left).coerce_punctuation(ignore_punctuation) != char_kind(language, right).coerce_punctuation(ignore_punctuation) @@ -251,7 +258,7 @@ fn around_next_word( ); let mut word_found = false; - let end = movement::find_boundary(map, relative_to, |left, right| { + let end = movement::find_boundary(map, relative_to, FindRange::MultiLine, |left, right| { let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation); let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation); @@ -566,11 +573,18 @@ mod test { async fn test_visual_word_object(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await; - cx.set_shared_state("The quick ˇbrown\nfox").await; + /* + cx.set_shared_state("The quick ˇbrown\nfox").await; + cx.simulate_shared_keystrokes(["v"]).await; + cx.assert_shared_state("The quick «bˇ»rown\nfox").await; + cx.simulate_shared_keystrokes(["i", "w"]).await; + cx.assert_shared_state("The quick «brownˇ»\nfox").await; + */ + cx.set_shared_state("The quick brown\nˇ\nfox").await; cx.simulate_shared_keystrokes(["v"]).await; - cx.assert_shared_state("The quick «bˇ»rown\nfox").await; + cx.assert_shared_state("The quick brown\n«\nˇ»fox").await; cx.simulate_shared_keystrokes(["i", "w"]).await; - cx.assert_shared_state("The quick «brownˇ»\nfox").await; + cx.assert_shared_state("The quick brown\n«\nˇ»fox").await; cx.assert_binding_matches_all(["v", "i", "w"], WORD_LOCATIONS) .await; diff --git a/crates/vim/src/test.rs b/crates/vim/src/test.rs index 88fa375851..c6a212d77f 100644 --- a/crates/vim/src/test.rs +++ b/crates/vim/src/test.rs @@ -431,6 +431,24 @@ async fn test_wrapped_lines(cx: &mut gpui::TestAppContext) { twelve char "}) .await; + + // line wraps as: + // fourteen ch + // ar + // fourteen ch + // ar + cx.set_shared_state(indoc! { " + fourteen chaˇr + fourteen char + "}) + .await; + + cx.simulate_shared_keystrokes(["d", "i", "w"]).await; + cx.assert_shared_state(indoc! {" + fourteenˇ• + fourteen char + "}) + .await; } #[gpui::test] diff --git a/crates/vim/src/test/neovim_backed_test_context.rs b/crates/vim/src/test/neovim_backed_test_context.rs index d04b1b7768..b433a6bfc0 100644 --- a/crates/vim/src/test/neovim_backed_test_context.rs +++ b/crates/vim/src/test/neovim_backed_test_context.rs @@ -153,6 +153,7 @@ impl<'a> NeovimBackedTestContext<'a> { } pub async fn assert_shared_state(&mut self, marked_text: &str) { + let marked_text = marked_text.replace("•", " "); let neovim = self.neovim_state().await; let editor = self.editor_state(); if neovim == marked_text && neovim == editor { @@ -184,9 +185,9 @@ impl<'a> NeovimBackedTestContext<'a> { message, initial_state, self.recent_keystrokes.join(" "), - marked_text, - neovim, - editor + marked_text.replace(" \n", "•\n"), + neovim.replace(" \n", "•\n"), + editor.replace(" \n", "•\n") ) } diff --git a/crates/vim/test_data/test_end_of_word.json b/crates/vim/test_data/test_end_of_word.json new file mode 100644 index 0000000000..06f80dc245 --- /dev/null +++ b/crates/vim/test_data/test_end_of_word.json @@ -0,0 +1,32 @@ +{"Put":{"state":"Thˇe quick-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"e"} +{"Get":{"state":"The quicˇk-brown\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quickˇ-brown\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumpˇs over\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps oveˇr\nthe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} +{"Key":"e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} +{"Put":{"state":"Thˇe quick-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Put":{"state":"The quicˇk-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Put":{"state":"The quickˇ-brown\n\n\nfox_jumps over\nthe"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-browˇn\n\n\nfox_jumps over\nthe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumpˇs over\nthe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps oveˇr\nthe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} +{"Key":"shift-e"} +{"Get":{"state":"The quick-brown\n\n\nfox_jumps over\nthˇe","mode":"Normal"}} diff --git a/crates/vim/test_data/test_visual_word_object.json b/crates/vim/test_data/test_visual_word_object.json index 0041baf969..5e1a9839e9 100644 --- a/crates/vim/test_data/test_visual_word_object.json +++ b/crates/vim/test_data/test_visual_word_object.json @@ -1,9 +1,9 @@ -{"Put":{"state":"The quick ˇbrown\nfox"}} +{"Put":{"state":"The quick brown\nˇ\nfox"}} {"Key":"v"} -{"Get":{"state":"The quick «bˇ»rown\nfox","mode":"Visual"}} +{"Get":{"state":"The quick brown\n«\nˇ»fox","mode":"Visual"}} {"Key":"i"} {"Key":"w"} -{"Get":{"state":"The quick «brownˇ»\nfox","mode":"Visual"}} +{"Get":{"state":"The quick brown\n«\nˇ»fox","mode":"Visual"}} {"Put":{"state":"The quick ˇbrown \nfox jumps over\nthe lazy dog \n\n\n\nThe-quick brown \n \n \n fox-jumps over\nthe lazy dog \n\n"}} {"Key":"v"} {"Key":"i"} diff --git a/crates/vim/test_data/test_wrapped_lines.json b/crates/vim/test_data/test_wrapped_lines.json index 1ebbd4f205..1fbfc935d9 100644 --- a/crates/vim/test_data/test_wrapped_lines.json +++ b/crates/vim/test_data/test_wrapped_lines.json @@ -48,3 +48,8 @@ {"Key":"o"} {"Key":"escape"} {"Get":{"state":"twelve char\nˇo\ntwelve char twelve char\ntwelve char\n","mode":"Normal"}} +{"Put":{"state":"fourteen chaˇr\nfourteen char\n"}} +{"Key":"d"} +{"Key":"i"} +{"Key":"w"} +{"Get":{"state":"fourteenˇ \nfourteen char\n","mode":"Normal"}} From e377ada1a9aef3335f08543f6036b69c6adc0ddf Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 11:05:46 -0400 Subject: [PATCH 05/60] added token count to documents during parsing --- crates/semantic_index/src/embedding.rs | 14 +++++++++ crates/semantic_index/src/parsing.rs | 19 ++++++++++-- crates/semantic_index/src/semantic_index.rs | 3 +- .../src/semantic_index_tests.rs | 30 +++++++++++++------ 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index a9cb0245c4..72621d3138 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -54,6 +54,8 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; + fn count_tokens(&self, span: &str) -> usize; + // fn truncate(&self, span: &str) -> Result<&str>; } pub struct DummyEmbeddings {} @@ -66,6 +68,12 @@ impl EmbeddingProvider for DummyEmbeddings { let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } + + fn count_tokens(&self, span: &str) -> usize { + // For Dummy Providers, we are going to use OpenAI tokenization for ease + let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + tokens.len() + } } const OPENAI_INPUT_LIMIT: usize = 8190; @@ -111,6 +119,12 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { + fn count_tokens(&self, span: &str) -> usize { + // For Dummy Providers, we are going to use OpenAI tokenization for ease + let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + tokens.len() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 4aefb0b00d..b106e5055b 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,3 +1,4 @@ +use crate::embedding::EmbeddingProvider; use anyhow::{anyhow, Ok, Result}; use language::{Grammar, Language}; use sha1::{Digest, Sha1}; @@ -17,6 +18,7 @@ pub struct Document { pub content: String, pub embedding: Vec, pub sha1: [u8; 20], + pub token_count: usize, } const CODE_CONTEXT_TEMPLATE: &str = @@ -30,6 +32,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = pub struct CodeContextRetriever { pub parser: Parser, pub cursor: QueryCursor, + pub embedding_provider: Arc, } // Every match has an item, this represents the fundamental treesitter symbol and anchors the search @@ -47,10 +50,11 @@ pub struct CodeContextMatch { } impl CodeContextRetriever { - pub fn new() -> Self { + pub fn new(embedding_provider: Arc) -> Self { Self { parser: Parser::new(), cursor: QueryCursor::new(), + embedding_provider, } } @@ -68,12 +72,15 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); + let token_count = self.embedding_provider.count_tokens(&document_span); + Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Vec::new(), name: language_name.to_string(), sha1: sha1.finalize().into(), + token_count, }]) } @@ -85,12 +92,15 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); + let token_count = self.embedding_provider.count_tokens(&document_span); + Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Vec::new(), name: "Markdown".to_string(), sha1: sha1.finalize().into(), + token_count, }]) } @@ -166,10 +176,14 @@ impl CodeContextRetriever { let mut documents = self.parse_file(content, language)?; for document in &mut documents { - document.content = CODE_CONTEXT_TEMPLATE + let document_content = CODE_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("item", &document.content); + + let token_count = self.embedding_provider.count_tokens(&document_content); + document.content = document_content; + document.token_count = token_count; } Ok(documents) } @@ -272,6 +286,7 @@ impl CodeContextRetriever { range: item_range.clone(), embedding: vec![], sha1: sha1.finalize().into(), + token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 2da0d84baf..ab05ca7581 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -332,8 +332,9 @@ impl SemanticIndex { let parsing_files_rx = parsing_files_rx.clone(); let batch_files_tx = batch_files_tx.clone(); let db_update_tx = db_update_tx.clone(); + let embedding_provider = embedding_provider.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { - let mut retriever = CodeContextRetriever::new(); + let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { Self::parse_file( &fs, diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 32d8bb0fb8..cb318a9fd6 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,6 +1,6 @@ use crate::{ db::dot, - embedding::EmbeddingProvider, + embedding::{DummyEmbeddings, EmbeddingProvider}, parsing::{subtract_ranges, CodeContextRetriever, Document}, semantic_index_settings::SemanticIndexSettings, SearchResult, SemanticIndex, @@ -227,7 +227,8 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /// A doc comment @@ -314,7 +315,8 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" { @@ -397,7 +399,8 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /* globals importScripts, backend */ @@ -495,7 +498,8 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" -- Creates a new class @@ -568,7 +572,8 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" defmodule File.Stream do @@ -684,7 +689,8 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " /** @@ -836,7 +842,8 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" # This concern is inspired by "sudo mode" on GitHub. It @@ -1026,7 +1033,8 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let mut retriever = CodeContextRetriever::new(); + let embedding_provider = Arc::new(DummyEmbeddings {}); + let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" usize { + span.len() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); From 76caea80f7543cf86eaf0f4e899f06ea478f3d8a Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 11:58:45 -0400 Subject: [PATCH 06/60] add should_truncate to embedding providers --- crates/semantic_index/src/embedding.rs | 19 +++++++++++++++++++ .../src/semantic_index_tests.rs | 4 ++++ 2 files changed, 23 insertions(+) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 72621d3138..3dd979f01b 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -55,6 +55,7 @@ struct OpenAIEmbeddingUsage { pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; fn count_tokens(&self, span: &str) -> usize; + fn should_truncate(&self, span: &str) -> bool; // fn truncate(&self, span: &str) -> Result<&str>; } @@ -74,6 +75,20 @@ impl EmbeddingProvider for DummyEmbeddings { let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); tokens.len() } + + fn should_truncate(&self, span: &str) -> bool { + self.count_tokens(span) > OPENAI_INPUT_LIMIT + + // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + // let Ok(output) = { + // if tokens.len() > OPENAI_INPUT_LIMIT { + // tokens.truncate(OPENAI_INPUT_LIMIT); + // OPENAI_BPE_TOKENIZER.decode(tokens) + // } else { + // Ok(span) + // } + // }; + } } const OPENAI_INPUT_LIMIT: usize = 8190; @@ -125,6 +140,10 @@ impl EmbeddingProvider for OpenAIEmbeddings { tokens.len() } + fn should_truncate(&self, span: &str) -> bool { + self.count_tokens(span) > OPENAI_INPUT_LIMIT + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index cb318a9fd6..48cefd93b1 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1228,6 +1228,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { span.len() } + fn should_truncate(&self, span: &str) -> bool { + false + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); From 97810471569618955c241e4137629b578c46285b Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 12:13:26 -0400 Subject: [PATCH 07/60] move truncation to parsing step leveraging the EmbeddingProvider trait --- crates/semantic_index/src/embedding.rs | 78 +++++++++---------- crates/semantic_index/src/parsing.rs | 4 + .../src/semantic_index_tests.rs | 4 + 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 3dd979f01b..cba34439c8 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -56,7 +56,7 @@ pub trait EmbeddingProvider: Sync + Send { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; fn count_tokens(&self, span: &str) -> usize; fn should_truncate(&self, span: &str) -> bool; - // fn truncate(&self, span: &str) -> Result<&str>; + fn truncate(&self, span: &str) -> String; } pub struct DummyEmbeddings {} @@ -78,36 +78,27 @@ impl EmbeddingProvider for DummyEmbeddings { fn should_truncate(&self, span: &str) -> bool { self.count_tokens(span) > OPENAI_INPUT_LIMIT + } - // let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - // let Ok(output) = { - // if tokens.len() > OPENAI_INPUT_LIMIT { - // tokens.truncate(OPENAI_INPUT_LIMIT); - // OPENAI_BPE_TOKENIZER.decode(tokens) - // } else { - // Ok(span) - // } - // }; + fn truncate(&self, span: &str) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let output = if tokens.len() > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + output } } const OPENAI_INPUT_LIMIT: usize = 8190; impl OpenAIEmbeddings { - fn truncate(span: String) -> String { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref()); - if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); - if result.is_ok() { - let transformed = result.unwrap(); - return transformed; - } - } - - span - } - async fn send_request( &self, api_key: &str, @@ -144,6 +135,21 @@ impl EmbeddingProvider for OpenAIEmbeddings { self.count_tokens(span) > OPENAI_INPUT_LIMIT } + fn truncate(&self, span: &str) -> String { + let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + let output = if tokens.len() > OPENAI_INPUT_LIMIT { + tokens.truncate(OPENAI_INPUT_LIMIT); + OPENAI_BPE_TOKENIZER + .decode(tokens) + .ok() + .unwrap_or_else(|| span.to_string()) + } else { + span.to_string() + }; + + output + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -214,23 +220,13 @@ impl EmbeddingProvider for OpenAIEmbeddings { self.executor.timer(delay_duration).await; } _ => { - // TODO: Move this to parsing step - // Only truncate if it hasnt been truncated before - if !truncated { - for span in spans.iter_mut() { - *span = Self::truncate(span.clone()); - } - truncated = true; - } else { - // If failing once already truncated, log the error and break the loop - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "open ai bad request: {:?} {:?}", - &response.status(), - body - )); - } + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); } } } diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index b106e5055b..00849580bb 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -73,6 +73,7 @@ impl CodeContextRetriever { sha1.update(&document_span); let token_count = self.embedding_provider.count_tokens(&document_span); + let document_span = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -93,6 +94,7 @@ impl CodeContextRetriever { sha1.update(&document_span); let token_count = self.embedding_provider.count_tokens(&document_span); + let document_span = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -182,6 +184,8 @@ impl CodeContextRetriever { .replace("item", &document.content); let token_count = self.embedding_provider.count_tokens(&document_content); + let document_content = self.embedding_provider.truncate(&document_content); + document.content = document_content; document.token_count = token_count; } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 48cefd93b1..7093cf9fcf 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1232,6 +1232,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider { false } + fn truncate(&self, span: &str) -> String { + span.to_string() + } + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); From 76ce52df4ee0f4b4b977093f096c76e15b852ae3 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 16:01:28 -0400 Subject: [PATCH 08/60] move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch" Co-authored-by: Max --- crates/semantic_index/src/embedding.rs | 47 ++---- crates/semantic_index/src/embedding_queue.rs | 140 ++++++++++++++++ crates/semantic_index/src/parsing.rs | 10 +- .../src/semantic_index_tests.rs | 154 +++++++++++++----- crates/util/src/util.rs | 35 ++-- 5 files changed, 295 insertions(+), 91 deletions(-) create mode 100644 crates/semantic_index/src/embedding_queue.rs diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index cba34439c8..7db22c3716 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; - fn count_tokens(&self, span: &str) -> usize; - fn should_truncate(&self, span: &str) -> bool; - fn truncate(&self, span: &str) -> String; + async fn embed_batch(&self, spans: Vec) -> Result>>; + fn max_tokens_per_batch(&self) -> usize; + fn truncate(&self, span: &str) -> (String, usize); } pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } - fn count_tokens(&self, span: &str) -> usize { - // For Dummy Providers, we are going to use OpenAI tokenization for ease - let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - tokens.len() + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT } - fn should_truncate(&self, span: &str) -> bool { - self.count_tokens(span) > OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> String { + fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER .decode(tokens) @@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings { span.to_string() }; - output + (output, token_count) } } @@ -125,19 +119,14 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { - fn count_tokens(&self, span: &str) -> usize { - // For Dummy Providers, we are going to use OpenAI tokenization for ease - let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - tokens.len() + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT } - fn should_truncate(&self, span: &str) -> bool { - self.count_tokens(span) > OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> String { + fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER .decode(tokens) @@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings { span.to_string() }; - output + (output, token_count) } - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { let mut request_number = 0; let mut request_timeout: u64 = 10; - let mut truncated = false; let mut response: Response; - let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self .send_request( diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs new file mode 100644 index 0000000000..6609c39e78 --- /dev/null +++ b/crates/semantic_index/src/embedding_queue.rs @@ -0,0 +1,140 @@ +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; + +use gpui::AppContext; +use parking_lot::Mutex; +use smol::channel; + +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; + +#[derive(Clone)] +pub struct FileToEmbed { + pub worktree_id: i64, + pub path: PathBuf, + pub mtime: SystemTime, + pub documents: Vec, + pub job_handle: JobHandle, +} + +impl std::fmt::Debug for FileToEmbed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileToEmbed") + .field("worktree_id", &self.worktree_id) + .field("path", &self.path) + .field("mtime", &self.mtime) + .field("document", &self.documents) + .finish_non_exhaustive() + } +} + +impl PartialEq for FileToEmbed { + fn eq(&self, other: &Self) -> bool { + self.worktree_id == other.worktree_id + && self.path == other.path + && self.mtime == other.mtime + && self.documents == other.documents + } +} + +pub struct EmbeddingQueue { + embedding_provider: Arc, + pending_batch: Vec, + pending_batch_token_count: usize, + finished_files_tx: channel::Sender, + finished_files_rx: channel::Receiver, +} + +pub struct FileToEmbedFragment { + file: Arc>, + document_range: Range, +} + +impl EmbeddingQueue { + pub fn new(embedding_provider: Arc) -> Self { + let (finished_files_tx, finished_files_rx) = channel::unbounded(); + Self { + embedding_provider, + pending_batch: Vec::new(), + pending_batch_token_count: 0, + finished_files_tx, + finished_files_rx, + } + } + + pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { + let file = Arc::new(Mutex::new(file)); + + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: 0..0, + }); + + let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + for (ix, document) in file.lock().documents.iter().enumerate() { + let next_token_count = self.pending_batch_token_count + document.token_count; + if next_token_count > self.embedding_provider.max_tokens_per_batch() { + let range_end = fragment_range.end; + self.flush(cx); + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: range_end..range_end, + }); + fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + } + + fragment_range.end = ix + 1; + self.pending_batch_token_count += document.token_count; + } + } + + pub fn flush(&mut self, cx: &mut AppContext) { + let batch = mem::take(&mut self.pending_batch); + self.pending_batch_token_count = 0; + if batch.is_empty() { + return; + } + + let finished_files_tx = self.finished_files_tx.clone(); + let embedding_provider = self.embedding_provider.clone(); + cx.background().spawn(async move { + let mut spans = Vec::new(); + for fragment in &batch { + let file = fragment.file.lock(); + spans.extend( + file.documents[fragment.document_range.clone()] + .iter() + .map(|d| d.content.clone()), + ); + } + + match embedding_provider.embed_batch(spans).await { + Ok(embeddings) => { + let mut embeddings = embeddings.into_iter(); + for fragment in batch { + for document in + &mut fragment.file.lock().documents[fragment.document_range.clone()] + { + if let Some(embedding) = embeddings.next() { + document.embedding = embedding; + } else { + // + log::error!("number of embeddings returned different from number of documents"); + } + } + + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + } + Err(error) => { + log::error!("{:?}", error); + } + } + }) + .detach(); + } + + pub fn finished_files(&self) -> channel::Receiver { + self.finished_files_rx.clone() + } +} diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 00849580bb..51f1bd7ca9 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -72,8 +72,7 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); - let token_count = self.embedding_provider.count_tokens(&document_span); - let document_span = self.embedding_provider.truncate(&document_span); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -93,8 +92,7 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); - let token_count = self.embedding_provider.count_tokens(&document_span); - let document_span = self.embedding_provider.truncate(&document_span); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -183,8 +181,8 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("item", &document.content); - let token_count = self.embedding_provider.count_tokens(&document_content); - let document_content = self.embedding_provider.truncate(&document_content); + let (document_content, token_count) = + self.embedding_provider.truncate(&document_content); document.content = document_content; document.token_count = token_count; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7093cf9fcf..7178987165 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,14 +1,16 @@ use crate::{ db::dot, embedding::{DummyEmbeddings, EmbeddingProvider}, + embedding_queue::EmbeddingQueue, parsing::{subtract_ranges, CodeContextRetriever, Document}, semantic_index_settings::SemanticIndexSettings, - SearchResult, SemanticIndex, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; +use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project}; use rand::{rngs::StdRng, Rng}; @@ -20,8 +22,10 @@ use std::{ atomic::{self, AtomicUsize}, Arc, }, + time::SystemTime, }; use unindent::Unindent; +use util::RandomCharIter; #[ctor::ctor] fn init_logger() { @@ -32,11 +36,7 @@ fn init_logger() { #[gpui::test] async fn test_semantic_index(cx: &mut TestAppContext) { - cx.update(|cx| { - cx.set_global(SettingsStore::test(cx)); - settings::register::(cx); - settings::register::(cx); - }); + init_test(cx); let fs = FakeFs::new(cx.background()); fs.insert_tree( @@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let db_path = db_dir.path().join("db.sqlite"); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let store = SemanticIndex::new( + let semantic_index = SemanticIndex::new( fs.clone(), db_path, embedding_provider.clone(), @@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - let _ = store + let _ = semantic_index .update(cx, |store, cx| { store.initialize_project(project.clone(), cx) }) .await; - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); @@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx.foreground().run_until_parked(); assert_eq!(*outstanding_file_count.borrow(), 0); - let search_results = store + let search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { // Test Include Files Functonality let include_files = vec![PathMatcher::new("*.rs").unwrap()]; let exclude_files = vec![PathMatcher::new("*.rs").unwrap()]; - let rust_only_search_results = store + let rust_only_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx, ); - let no_rust_search_results = store + let no_rust_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx.foreground().run_until_parked(); let prev_embedding_count = embedding_provider.embedding_count(); - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); @@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) { ); } +#[gpui::test(iterations = 10)] +async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { + let (outstanding_job_count, _) = postage::watch::channel_with(0); + let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count)); + + let files = (1..=3) + .map(|file_ix| FileToEmbed { + worktree_id: 5, + path: format!("path-{file_ix}").into(), + mtime: SystemTime::now(), + documents: (0..rng.gen_range(4..22)) + .map(|document_ix| { + let content_len = rng.gen_range(10..100); + Document { + range: 0..10, + embedding: Vec::new(), + name: format!("document {document_ix}"), + content: RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect(), + sha1: rng.gen(), + token_count: rng.gen_range(10..30), + } + }) + .collect(), + job_handle: JobHandle::new(&outstanding_job_count), + }) + .collect::>(); + + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut queue = EmbeddingQueue::new(embedding_provider.clone()); + + let finished_files = cx.update(|cx| { + for file in &files { + queue.push(file.clone(), cx); + } + queue.flush(cx); + queue.finished_files() + }); + + cx.foreground().run_until_parked(); + let mut embedded_files: Vec<_> = files + .iter() + .map(|_| finished_files.try_recv().expect("no finished file")) + .collect(); + + let expected_files: Vec<_> = files + .iter() + .map(|file| { + let mut file = file.clone(); + for doc in &mut file.documents { + doc.embedding = embedding_provider.embed_sync(doc.content.as_ref()); + } + file + }) + .collect(); + + embedded_files.sort_by_key(|f| f.path.clone()); + + assert_eq!(embedded_files, expected_files); +} + #[track_caller] fn assert_search_results( actual: &[SearchResult], @@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider { fn embedding_count(&self) -> usize { self.embedding_count.load(atomic::Ordering::SeqCst) } + + fn embed_sync(&self, span: &str) -> Vec { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result + } } #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { - fn count_tokens(&self, span: &str) -> usize { - span.len() + fn truncate(&self, span: &str) -> (String, usize) { + (span.to_string(), 1) } - fn should_truncate(&self, span: &str) -> bool { - false + fn max_tokens_per_batch(&self) -> usize { + 200 } - fn truncate(&self, span: &str) -> String { - span.to_string() - } - - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - Ok(spans - .iter() - .map(|span| { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result - }) - .collect()) + Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } @@ -1704,3 +1762,11 @@ fn test_subtract_ranges() { assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]); } + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + cx.set_global(SettingsStore::test(cx)); + settings::register::(cx); + settings::register::(cx); + }); +} diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index c8beb86aef..785426ed4c 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -260,11 +260,22 @@ pub fn defer(f: F) -> impl Drop { Defer(Some(f)) } -pub struct RandomCharIter(T); +pub struct RandomCharIter { + rng: T, + simple_text: bool, +} impl RandomCharIter { pub fn new(rng: T) -> Self { - Self(rng) + Self { + rng, + simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()), + } + } + + pub fn with_simple_text(mut self) -> Self { + self.simple_text = true; + self } } @@ -272,25 +283,27 @@ impl Iterator for RandomCharIter { type Item = char; fn next(&mut self) -> Option { - if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) { - return if self.0.gen_range(0..100) < 5 { + if self.simple_text { + return if self.rng.gen_range(0..100) < 5 { Some('\n') } else { - Some(self.0.gen_range(b'a'..b'z' + 1).into()) + Some(self.rng.gen_range(b'a'..b'z' + 1).into()) }; } - match self.0.gen_range(0..100) { + match self.rng.gen_range(0..100) { // whitespace - 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(), + 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(), // two-byte greek letters - 20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))), + 20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))), // // three-byte characters - 33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(), + 33..=45 => ['✋', '✅', '❌', '❎', '⭐'] + .choose(&mut self.rng) + .copied(), // // four-byte characters - 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(), + 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(), // ascii letters - _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()), + _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()), } } } From 5abad58b0d81941726f81fd8e6e8ca876811163e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 16:58:45 -0400 Subject: [PATCH 09/60] moved semantic index to use embeddings queue to batch and managed for atomic database writes Co-authored-by: Max --- crates/semantic_index/src/embedding_queue.rs | 25 +- crates/semantic_index/src/semantic_index.rs | 238 +++--------------- .../src/semantic_index_tests.rs | 14 +- 3 files changed, 55 insertions(+), 222 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 6609c39e78..2b48b7a7d6 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,10 +1,8 @@ -use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; - -use gpui::AppContext; +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; - -use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; #[derive(Clone)] pub struct FileToEmbed { @@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed { pub struct EmbeddingQueue { embedding_provider: Arc, pending_batch: Vec, + executor: Arc, pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, @@ -49,10 +48,11 @@ pub struct FileToEmbedFragment { } impl EmbeddingQueue { - pub fn new(embedding_provider: Arc) -> Self { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { embedding_provider, + executor, pending_batch: Vec::new(), pending_batch_token_count: 0, finished_files_tx, @@ -60,7 +60,12 @@ impl EmbeddingQueue { } } - pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { + pub fn push(&mut self, file: FileToEmbed) { + if file.documents.is_empty() { + self.finished_files_tx.try_send(file).unwrap(); + return; + } + let file = Arc::new(Mutex::new(file)); self.pending_batch.push(FileToEmbedFragment { @@ -73,7 +78,7 @@ impl EmbeddingQueue { let next_token_count = self.pending_batch_token_count + document.token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; - self.flush(cx); + self.flush(); self.pending_batch.push(FileToEmbedFragment { file: file.clone(), document_range: range_end..range_end, @@ -86,7 +91,7 @@ impl EmbeddingQueue { } } - pub fn flush(&mut self, cx: &mut AppContext) { + pub fn flush(&mut self) { let batch = mem::take(&mut self.pending_batch); self.pending_batch_token_count = 0; if batch.is_empty() { @@ -95,7 +100,7 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - cx.background().spawn(async move { + self.executor.spawn(async move { let mut spans = Vec::new(); for fragment in &batch { let file = fragment.file.lock(); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index ab05ca7581..cde53182dc 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod embedding_queue; mod parsing; pub mod semantic_index_settings; @@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding_queue::{EmbeddingQueue, FileToEmbed}; use futures::{channel::oneshot, Future}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; @@ -23,7 +25,6 @@ use smol::channel; use std::{ cmp::Ordering, collections::{BTreeMap, HashMap}, - mem, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, @@ -38,7 +39,6 @@ use util::{ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 7; -const EMBEDDINGS_BATCH_SIZE: usize = 80; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( @@ -106,9 +106,8 @@ pub struct SemanticIndex { language_registry: Arc, db_update_tx: channel::Sender, parsing_files_tx: channel::Sender, + _embedding_task: Task<()>, _db_update_task: Task<()>, - _embed_batch_tasks: Vec>, - _batch_files_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -128,7 +127,7 @@ struct ChangedPathInfo { } #[derive(Clone)] -struct JobHandle { +pub struct JobHandle { /// The outer Arc is here to count the clones of a JobHandle instance; /// when the last handle to a given job is dropped, we decrement a counter (just once). tx: Arc>>>, @@ -230,17 +229,6 @@ enum DbOperation { }, } -enum EmbeddingJob { - Enqueue { - worktree_id: i64, - path: PathBuf, - mtime: SystemTime, - documents: Vec, - job_handle: JobHandle, - }, - Flush, -} - impl SemanticIndex { pub fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { @@ -287,52 +275,35 @@ impl SemanticIndex { } }); - // Group documents into batches and send them to the embedding provider. - let (embed_batch_tx, embed_batch_rx) = - channel::unbounded::, PathBuf, SystemTime, JobHandle)>>(); - let mut _embed_batch_tasks = Vec::new(); - for _ in 0..cx.background().num_cpus() { - let embed_batch_rx = embed_batch_rx.clone(); - _embed_batch_tasks.push(cx.background().spawn({ - let db_update_tx = db_update_tx.clone(); - let embedding_provider = embedding_provider.clone(); - async move { - while let Ok(embeddings_queue) = embed_batch_rx.recv().await { - Self::compute_embeddings_for_batch( - embeddings_queue, - &embedding_provider, - &db_update_tx, - ) - .await; - } + let embedding_queue = + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); + let _embedding_task = cx.background().spawn({ + let embedded_files = embedding_queue.finished_files(); + let db_update_tx = db_update_tx.clone(); + async move { + while let Ok(file) = embedded_files.recv().await { + db_update_tx + .try_send(DbOperation::InsertFile { + worktree_id: file.worktree_id, + documents: file.documents, + path: file.path, + mtime: file.mtime, + job_handle: file.job_handle, + }) + .ok(); } - })); - } - - // Group documents into batches and send them to the embedding provider. - let (batch_files_tx, batch_files_rx) = channel::unbounded::(); - let _batch_files_task = cx.background().spawn(async move { - let mut queue_len = 0; - let mut embeddings_queue = vec![]; - while let Ok(job) = batch_files_rx.recv().await { - Self::enqueue_documents_to_embed( - job, - &mut queue_len, - &mut embeddings_queue, - &embed_batch_tx, - ); } }); // Parse files into embeddable documents. let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); + let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { let fs = fs.clone(); let parsing_files_rx = parsing_files_rx.clone(); - let batch_files_tx = batch_files_tx.clone(); - let db_update_tx = db_update_tx.clone(); let embedding_provider = embedding_provider.clone(); + let embedding_queue = embedding_queue.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { @@ -340,9 +311,8 @@ impl SemanticIndex { &fs, pending_file, &mut retriever, - &batch_files_tx, + &embedding_queue, &parsing_files_rx, - &db_update_tx, ) .await; } @@ -361,8 +331,7 @@ impl SemanticIndex { db_update_tx, parsing_files_tx, _db_update_task, - _embed_batch_tasks, - _batch_files_task, + _embedding_task, _parsing_files_tasks, projects: HashMap::new(), } @@ -403,136 +372,12 @@ impl SemanticIndex { } } - async fn compute_embeddings_for_batch( - mut embeddings_queue: Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embedding_provider: &Arc, - db_update_tx: &channel::Sender, - ) { - let mut batch_documents = vec![]; - for (_, documents, _, _, _) in embeddings_queue.iter() { - batch_documents.extend(documents.iter().map(|document| document.content.as_str())); - } - - if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await { - log::trace!( - "created {} embeddings for {} files", - embeddings.len(), - embeddings_queue.len(), - ); - - let mut i = 0; - let mut j = 0; - - for embedding in embeddings.iter() { - while embeddings_queue[i].1.len() == j { - i += 1; - j = 0; - } - - embeddings_queue[i].1[j].embedding = embedding.to_owned(); - j += 1; - } - - for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } else { - // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed). - for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id, - documents: vec![], - path, - mtime, - job_handle, - }) - .await - .unwrap(); - } - } - } - - fn enqueue_documents_to_embed( - job: EmbeddingJob, - queue_len: &mut usize, - embeddings_queue: &mut Vec<(i64, Vec, PathBuf, SystemTime, JobHandle)>, - embed_batch_tx: &channel::Sender, PathBuf, SystemTime, JobHandle)>>, - ) { - // Handle edge case where individual file has more documents than max batch size - let should_flush = match job { - EmbeddingJob::Enqueue { - documents, - worktree_id, - path, - mtime, - job_handle, - } => { - // If documents is greater than embeddings batch size, recursively batch existing rows. - if &documents.len() > &EMBEDDINGS_BATCH_SIZE { - let first_job = EmbeddingJob::Enqueue { - documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - first_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - - let second_job = EmbeddingJob::Enqueue { - documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(), - worktree_id, - path: path.clone(), - mtime, - job_handle: job_handle.clone(), - }; - - Self::enqueue_documents_to_embed( - second_job, - queue_len, - embeddings_queue, - embed_batch_tx, - ); - return; - } else { - *queue_len += &documents.len(); - embeddings_queue.push((worktree_id, documents, path, mtime, job_handle)); - *queue_len >= EMBEDDINGS_BATCH_SIZE - } - } - EmbeddingJob::Flush => true, - }; - - if should_flush { - embed_batch_tx - .try_send(mem::take(embeddings_queue)) - .unwrap(); - *queue_len = 0; - } - } - async fn parse_file( fs: &Arc, pending_file: PendingFile, retriever: &mut CodeContextRetriever, - batch_files_tx: &channel::Sender, + embedding_queue: &Arc>, parsing_files_rx: &channel::Receiver, - db_update_tx: &channel::Sender, ) { let Some(language) = pending_file.language else { return; @@ -549,33 +394,18 @@ impl SemanticIndex { documents.len() ); - if documents.len() == 0 { - db_update_tx - .send(DbOperation::InsertFile { - worktree_id: pending_file.worktree_db_id, - documents, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - }) - .await - .unwrap(); - } else { - batch_files_tx - .try_send(EmbeddingJob::Enqueue { - worktree_id: pending_file.worktree_db_id, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - documents, - }) - .unwrap(); - } + embedding_queue.lock().push(FileToEmbed { + worktree_id: pending_file.worktree_db_id, + path: pending_file.relative_path, + mtime: pending_file.modified_time, + job_handle: pending_file.job_handle, + documents, + }); } } if parsing_files_rx.len() == 0 { - batch_files_tx.try_send(EmbeddingJob::Flush).unwrap(); + embedding_queue.lock().flush(); } } @@ -881,7 +711,7 @@ impl SemanticIndex { let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) + .embed_batch(vec![phrase]) .await? .into_iter() .next() diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7178987165..dc41c09f7a 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .collect::>(); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new(embedding_provider.clone()); - let finished_files = cx.update(|cx| { - for file in &files { - queue.push(file.clone(), cx); - } - queue.flush(cx); - queue.finished_files() - }); + let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background()); + for file in &files { + queue.push(file.clone()); + } + queue.flush(); cx.foreground().run_until_parked(); + let finished_files = queue.finished_files(); let mut embedded_files: Vec<_> = files .iter() .map(|_| finished_files.try_recv().expect("no finished file")) From 7d4d6c871ba88eafc8a084539a4619c8ba686872 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 17:42:16 -0400 Subject: [PATCH 10/60] fix bug for truncation ensuring no valid inputs are sent to openai --- crates/semantic_index/src/embedding.rs | 10 ++++------ crates/semantic_index/src/embedding_queue.rs | 8 +++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 7db22c3716..60e13a9e01 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -78,15 +78,13 @@ impl EmbeddingProvider for DummyEmbeddings { let token_count = tokens.len(); let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens) - .ok() - .unwrap_or_else(|| span.to_string()) + let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); + new_input.ok().unwrap_or_else(|| span.to_string()) } else { span.to_string() }; - (output, token_count) + (output, tokens.len()) } } @@ -120,7 +118,7 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { fn max_tokens_per_batch(&self) -> usize { - OPENAI_INPUT_LIMIT + 50000 } fn truncate(&self, span: &str) -> (String, usize) { diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 2b48b7a7d6..c3a5de1373 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -105,9 +105,11 @@ impl EmbeddingQueue { for fragment in &batch { let file = fragment.file.lock(); spans.extend( - file.documents[fragment.document_range.clone()] - .iter() - .map(|d| d.content.clone()), + { + file.documents[fragment.document_range.clone()] + .iter() + .map(|d| d.content.clone()) + } ); } From 35440be98e13df2d87f1e87e2ef750adf2ff59cc Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 16:54:11 +0200 Subject: [PATCH 11/60] Abstract away how database transactions are executed Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 630 +++++++++++--------- crates/semantic_index/src/semantic_index.rs | 199 ++----- 2 files changed, 397 insertions(+), 432 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 60ecf3b45f..652c2819ce 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,5 +1,7 @@ use crate::{parsing::Document, SEMANTIC_INDEX_VERSION}; use anyhow::{anyhow, Context, Result}; +use futures::channel::oneshot; +use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; use rusqlite::{ @@ -9,12 +11,14 @@ use rusqlite::{ use std::{ cmp::Ordering, collections::HashMap, + future::Future, ops::Range, path::{Path, PathBuf}, rc::Rc, sync::Arc, time::SystemTime, }; +use util::TryFutureExt; #[derive(Debug)] pub struct FileRecord { @@ -51,117 +55,161 @@ impl FromSql for Sha1 { } } +#[derive(Clone)] pub struct VectorDatabase { - db: rusqlite::Connection, + path: Arc, + transactions: smol::channel::Sender>, } impl VectorDatabase { - pub async fn new(fs: Arc, path: Arc) -> Result { + pub async fn new( + fs: Arc, + path: Arc, + executor: Arc, + ) -> Result { if let Some(db_directory) = path.parent() { fs.create_dir(db_directory).await?; } + let (transactions_tx, transactions_rx) = + smol::channel::unbounded::>(); + executor + .spawn({ + let path = path.clone(); + async move { + let connection = rusqlite::Connection::open(&path)?; + while let Ok(transaction) = transactions_rx.recv().await { + transaction(&connection); + } + + anyhow::Ok(()) + } + .log_err() + }) + .detach(); let this = Self { - db: rusqlite::Connection::open(path.as_path())?, + transactions: transactions_tx, + path, }; - this.initialize_database()?; + this.initialize_database().await?; Ok(this) } - fn get_existing_version(&self) -> Result { - let mut version_query = self - .db - .prepare("SELECT version from semantic_index_config")?; - version_query - .query_row([], |row| Ok(row.get::<_, i64>(0)?)) - .map_err(|err| anyhow!("version query failed: {err}")) + pub fn path(&self) -> &Arc { + &self.path } - fn initialize_database(&self) -> Result<()> { - rusqlite::vtab::array::load_module(&self.db)?; - - // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped - if self - .get_existing_version() - .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) - { - log::trace!("vector database schema up to date"); - return Ok(()); + fn transact(&self, transaction: F) -> impl Future> + where + F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result, + T: 'static + Send, + { + let (tx, rx) = oneshot::channel(); + let transactions = self.transactions.clone(); + async move { + if transactions + .send(Box::new(|connection| { + let result = transaction(connection); + let _ = tx.send(result); + })) + .await + .is_err() + { + return Err(anyhow!("connection was dropped"))?; + } + rx.await? } - - log::trace!("vector database schema out of date. updating..."); - self.db - .execute("DROP TABLE IF EXISTS documents", []) - .context("failed to drop 'documents' table")?; - self.db - .execute("DROP TABLE IF EXISTS files", []) - .context("failed to drop 'files' table")?; - self.db - .execute("DROP TABLE IF EXISTS worktrees", []) - .context("failed to drop 'worktrees' table")?; - self.db - .execute("DROP TABLE IF EXISTS semantic_index_config", []) - .context("failed to drop 'semantic_index_config' table")?; - - // Initialize Vector Databasing Tables - self.db.execute( - "CREATE TABLE semantic_index_config ( - version INTEGER NOT NULL - )", - [], - )?; - - self.db.execute( - "INSERT INTO semantic_index_config (version) VALUES (?1)", - params![SEMANTIC_INDEX_VERSION], - )?; - - self.db.execute( - "CREATE TABLE worktrees ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - absolute_path VARCHAR NOT NULL - ); - CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); - ", - [], - )?; - - self.db.execute( - "CREATE TABLE files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - worktree_id INTEGER NOT NULL, - relative_path VARCHAR NOT NULL, - mtime_seconds INTEGER NOT NULL, - mtime_nanos INTEGER NOT NULL, - FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE - )", - [], - )?; - - self.db.execute( - "CREATE TABLE documents ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_id INTEGER NOT NULL, - start_byte INTEGER NOT NULL, - end_byte INTEGER NOT NULL, - name VARCHAR NOT NULL, - embedding BLOB NOT NULL, - sha1 BLOB NOT NULL, - FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE - )", - [], - )?; - - log::trace!("vector database initialized with updated schema."); - Ok(()) } - pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> { - self.db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", - params![worktree_id, delete_path.to_str()], - )?; - Ok(()) + fn initialize_database(&self) -> impl Future> { + self.transact(|db| { + rusqlite::vtab::array::load_module(&db)?; + + // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped + let version_query = db.prepare("SELECT version from semantic_index_config"); + let version = version_query + .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?))); + if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) { + log::trace!("vector database schema up to date"); + return Ok(()); + } + + log::trace!("vector database schema out of date. updating..."); + db.execute("DROP TABLE IF EXISTS documents", []) + .context("failed to drop 'documents' table")?; + db.execute("DROP TABLE IF EXISTS files", []) + .context("failed to drop 'files' table")?; + db.execute("DROP TABLE IF EXISTS worktrees", []) + .context("failed to drop 'worktrees' table")?; + db.execute("DROP TABLE IF EXISTS semantic_index_config", []) + .context("failed to drop 'semantic_index_config' table")?; + + // Initialize Vector Databasing Tables + db.execute( + "CREATE TABLE semantic_index_config ( + version INTEGER NOT NULL + )", + [], + )?; + + db.execute( + "INSERT INTO semantic_index_config (version) VALUES (?1)", + params![SEMANTIC_INDEX_VERSION], + )?; + + db.execute( + "CREATE TABLE worktrees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + absolute_path VARCHAR NOT NULL + ); + CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); + ", + [], + )?; + + db.execute( + "CREATE TABLE files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + worktree_id INTEGER NOT NULL, + relative_path VARCHAR NOT NULL, + mtime_seconds INTEGER NOT NULL, + mtime_nanos INTEGER NOT NULL, + FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE + )", + [], + )?; + + db.execute( + "CREATE TABLE documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + start_byte INTEGER NOT NULL, + end_byte INTEGER NOT NULL, + name VARCHAR NOT NULL, + embedding BLOB NOT NULL, + sha1 BLOB NOT NULL, + FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE + )", + [], + )?; + + log::trace!("vector database initialized with updated schema."); + Ok(()) + }) + } + + pub fn delete_file( + &self, + worktree_id: i64, + delete_path: PathBuf, + ) -> impl Future> { + self.transact(move |db| { + db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", + params![worktree_id, delete_path.to_str()], + )?; + Ok(()) + }) } pub fn insert_file( @@ -170,117 +218,126 @@ impl VectorDatabase { path: PathBuf, mtime: SystemTime, documents: Vec, - ) -> Result<()> { - // Return the existing ID, if both the file and mtime match - let mtime = Timestamp::from(mtime); - let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; - let existing_id = existing_id_query - .query_row( - params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], - |row| Ok(row.get::<_, i64>(0)?), - ) - .map_err(|err| anyhow!(err)); - let file_id = if existing_id.is_ok() { - // If already exists, just return the existing id - existing_id.unwrap() - } else { - // Delete Existing Row - self.db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", - params![worktree_id, path.to_str()], + ) -> impl Future> { + self.transact(move |db| { + // Return the existing ID, if both the file and mtime match + let mtime = Timestamp::from(mtime); + + let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; + let existing_id = existing_id_query + .query_row( + params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], + |row| Ok(row.get::<_, i64>(0)?), + ); + + let file_id = if existing_id.is_ok() { + // If already exists, just return the existing id + existing_id? + } else { + // Delete Existing Row + db.execute( + "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", + params![worktree_id, path.to_str()], + )?; + db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; + db.last_insert_rowid() + }; + + // Currently inserting at approximately 3400 documents a second + // I imagine we can speed this up with a bulk insert of some kind. + for document in documents { + let embedding_blob = bincode::serialize(&document.embedding)?; + let sha_blob = bincode::serialize(&document.sha1)?; + + db.execute( + "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![ + file_id, + document.range.start.to_string(), + document.range.end.to_string(), + document.name, + embedding_blob, + sha_blob + ], + )?; + } + + Ok(()) + }) + } + + pub fn worktree_previously_indexed( + &self, + worktree_root_path: &Path, + ) -> impl Future> { + let worktree_root_path = worktree_root_path.to_string_lossy().into_owned(); + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?)); + + if worktree_id.is_ok() { + return Ok(true); + } else { + return Ok(false); + } + }) + } + + pub fn find_or_create_worktree( + &self, + worktree_root_path: PathBuf, + ) -> impl Future> { + self.transact(move |db| { + let mut worktree_query = + db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; + let worktree_id = worktree_query + .query_row(params![worktree_root_path.to_string_lossy()], |row| { + Ok(row.get::<_, i64>(0)?) + }); + + if worktree_id.is_ok() { + return Ok(worktree_id?); + } + + // If worktree_id is Err, insert new worktree + db.execute( + "INSERT into worktrees (absolute_path) VALUES (?1)", + params![worktree_root_path.to_string_lossy()], )?; - self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; - self.db.last_insert_rowid() - }; + Ok(db.last_insert_rowid()) + }) + } - // Currently inserting at approximately 3400 documents a second - // I imagine we can speed this up with a bulk insert of some kind. - for document in documents { - let embedding_blob = bincode::serialize(&document.embedding)?; - let sha_blob = bincode::serialize(&document.sha1)?; - - self.db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", - params![ - file_id, - document.range.start.to_string(), - document.range.end.to_string(), - document.name, - embedding_blob, - sha_blob - ], + pub fn get_file_mtimes( + &self, + worktree_id: i64, + ) -> impl Future>> { + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT relative_path, mtime_seconds, mtime_nanos + FROM files + WHERE worktree_id = ?1 + ORDER BY relative_path", )?; - } - - Ok(()) - } - - pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result { - let mut worktree_query = self - .db - .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - Ok(row.get::<_, i64>(0)?) - }) - .map_err(|err| anyhow!(err)); - - if worktree_id.is_ok() { - return Ok(true); - } else { - return Ok(false); - } - } - - pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result { - // Check that the absolute path doesnt exist - let mut worktree_query = self - .db - .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - Ok(row.get::<_, i64>(0)?) - }) - .map_err(|err| anyhow!(err)); - - if worktree_id.is_ok() { - return worktree_id; - } - - // If worktree_id is Err, insert new worktree - self.db.execute( - " - INSERT into worktrees (absolute_path) VALUES (?1) - ", - params![worktree_root_path.to_string_lossy()], - )?; - Ok(self.db.last_insert_rowid()) - } - - pub fn get_file_mtimes(&self, worktree_id: i64) -> Result> { - let mut statement = self.db.prepare( - " - SELECT relative_path, mtime_seconds, mtime_nanos - FROM files - WHERE worktree_id = ?1 - ORDER BY relative_path", - )?; - let mut result: HashMap = HashMap::new(); - for row in statement.query_map(params![worktree_id], |row| { - Ok(( - row.get::<_, String>(0)?.into(), - Timestamp { - seconds: row.get(1)?, - nanos: row.get(2)?, - } - .into(), - )) - })? { - let row = row?; - result.insert(row.0, row.1); - } - Ok(result) + let mut result: HashMap = HashMap::new(); + for row in statement.query_map(params![worktree_id], |row| { + Ok(( + row.get::<_, String>(0)?.into(), + Timestamp { + seconds: row.get(1)?, + nanos: row.get(2)?, + } + .into(), + )) + })? { + let row = row?; + result.insert(row.0, row.1); + } + Ok(result) + }) } pub fn top_k_search( @@ -288,21 +345,25 @@ impl VectorDatabase { query_embedding: &Vec, limit: usize, file_ids: &[i64], - ) -> Result> { - let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - self.for_each_document(file_ids, |id, embedding| { - let similarity = dot(&embedding, &query_embedding); - let ix = match results - .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; + ) -> impl Future>> { + let query_embedding = query_embedding.clone(); + let file_ids = file_ids.to_vec(); + self.transact(move |db| { + let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); + Self::for_each_document(db, &file_ids, |id, embedding| { + let similarity = dot(&embedding, &query_embedding); + let ix = match results.binary_search_by(|(_, s)| { + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; - Ok(results) + anyhow::Ok(results) + }) } pub fn retrieve_included_file_ids( @@ -310,37 +371,46 @@ impl VectorDatabase { worktree_ids: &[i64], includes: &[PathMatcher], excludes: &[PathMatcher], - ) -> Result> { - let mut file_query = self.db.prepare( - " - SELECT - id, relative_path - FROM - files - WHERE - worktree_id IN rarray(?) - ", - )?; + ) -> impl Future>> { + let worktree_ids = worktree_ids.to_vec(); + let includes = includes.to_vec(); + let excludes = excludes.to_vec(); + self.transact(move |db| { + let mut file_query = db.prepare( + " + SELECT + id, relative_path + FROM + files + WHERE + worktree_id IN rarray(?) + ", + )?; - let mut file_ids = Vec::::new(); - let mut rows = file_query.query([ids_to_sql(worktree_ids)])?; + let mut file_ids = Vec::::new(); + let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?; - while let Some(row) = rows.next()? { - let file_id = row.get(0)?; - let relative_path = row.get_ref(1)?.as_str()?; - let included = - includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); - let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); - if included && !excluded { - file_ids.push(file_id); + while let Some(row) = rows.next()? { + let file_id = row.get(0)?; + let relative_path = row.get_ref(1)?.as_str()?; + let included = + includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); + let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); + if included && !excluded { + file_ids.push(file_id); + } } - } - Ok(file_ids) + anyhow::Ok(file_ids) + }) } - fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec)) -> Result<()> { - let mut query_statement = self.db.prepare( + fn for_each_document( + db: &rusqlite::Connection, + file_ids: &[i64], + mut f: impl FnMut(i64, Vec), + ) -> Result<()> { + let mut query_statement = db.prepare( " SELECT id, embedding @@ -360,47 +430,53 @@ impl VectorDatabase { Ok(()) } - pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result)>> { - let mut statement = self.db.prepare( - " - SELECT - documents.id, - files.worktree_id, - files.relative_path, - documents.start_byte, - documents.end_byte - FROM - documents, files - WHERE - documents.file_id = files.id AND - documents.id in rarray(?) - ", - )?; + pub fn get_documents_by_ids( + &self, + ids: &[i64], + ) -> impl Future)>>> { + let ids = ids.to_vec(); + self.transact(move |db| { + let mut statement = db.prepare( + " + SELECT + documents.id, + files.worktree_id, + files.relative_path, + documents.start_byte, + documents.end_byte + FROM + documents, files + WHERE + documents.file_id = files.id AND + documents.id in rarray(?) + ", + )?; - let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| { - Ok(( - row.get::<_, i64>(0)?, - row.get::<_, i64>(1)?, - row.get::<_, String>(2)?.into(), - row.get(3)?..row.get(4)?, - )) - })?; + let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| { + Ok(( + row.get::<_, i64>(0)?, + row.get::<_, i64>(1)?, + row.get::<_, String>(2)?.into(), + row.get(3)?..row.get(4)?, + )) + })?; - let mut values_by_id = HashMap::)>::default(); - for row in result_iter { - let (id, worktree_id, path, range) = row?; - values_by_id.insert(id, (worktree_id, path, range)); - } + let mut values_by_id = HashMap::)>::default(); + for row in result_iter { + let (id, worktree_id, path, range) = row?; + values_by_id.insert(id, (worktree_id, path, range)); + } - let mut results = Vec::with_capacity(ids.len()); - for id in ids { - let value = values_by_id - .remove(id) - .ok_or(anyhow!("missing document id {}", id))?; - results.push(value); - } + let mut results = Vec::with_capacity(ids.len()); + for id in &ids { + let value = values_by_id + .remove(id) + .ok_or(anyhow!("missing document id {}", id))?; + results.push(value); + } - Ok(results) + Ok(results) + }) } } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index cde53182dc..7a0985b273 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -12,11 +12,10 @@ use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; -use futures::{channel::oneshot, Future}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; -use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES}; +use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; use project::{ search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, @@ -101,13 +100,11 @@ pub fn init( pub struct SemanticIndex { fs: Arc, - database_url: Arc, + db: VectorDatabase, embedding_provider: Arc, language_registry: Arc, - db_update_tx: channel::Sender, parsing_files_tx: channel::Sender, _embedding_task: Task<()>, - _db_update_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, } @@ -203,32 +200,6 @@ pub struct SearchResult { pub range: Range, } -enum DbOperation { - InsertFile { - worktree_id: i64, - documents: Vec, - path: PathBuf, - mtime: SystemTime, - job_handle: JobHandle, - }, - Delete { - worktree_id: i64, - path: PathBuf, - }, - FindOrCreateWorktree { - path: PathBuf, - sender: oneshot::Sender>, - }, - FileMTimes { - worktree_id: i64, - sender: oneshot::Sender>>, - }, - WorktreePreviouslyIndexed { - path: Arc, - sender: oneshot::Sender>, - }, -} - impl SemanticIndex { pub fn global(cx: &AppContext) -> Option> { if cx.has_global::>() { @@ -245,18 +216,14 @@ impl SemanticIndex { async fn new( fs: Arc, - database_url: PathBuf, + database_path: PathBuf, embedding_provider: Arc, language_registry: Arc, mut cx: AsyncAppContext, ) -> Result> { let t0 = Instant::now(); - let database_url = Arc::new(database_url); - - let db = cx - .background() - .spawn(VectorDatabase::new(fs.clone(), database_url.clone())) - .await?; + let database_path = Arc::from(database_path); + let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?; log::trace!( "db initialization took {:?} milliseconds", @@ -265,32 +232,16 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); - // Perform database operations - let (db_update_tx, db_update_rx) = channel::unbounded(); - let _db_update_task = cx.background().spawn({ - async move { - while let Ok(job) = db_update_rx.recv().await { - Self::run_db_operation(&db, job) - } - } - }); - let embedding_queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); - let db_update_tx = db_update_tx.clone(); + let db = db.clone(); async move { while let Ok(file) = embedded_files.recv().await { - db_update_tx - .try_send(DbOperation::InsertFile { - worktree_id: file.worktree_id, - documents: file.documents, - path: file.path, - mtime: file.mtime, - job_handle: file.job_handle, - }) - .ok(); + db.insert_file(file.worktree_id, file.path, file.mtime, file.documents) + .await + .log_err(); } } }); @@ -325,12 +276,10 @@ impl SemanticIndex { ); Self { fs, - database_url, + db, embedding_provider, language_registry, - db_update_tx, parsing_files_tx, - _db_update_task, _embedding_task, _parsing_files_tasks, projects: HashMap::new(), @@ -338,40 +287,6 @@ impl SemanticIndex { })) } - fn run_db_operation(db: &VectorDatabase, job: DbOperation) { - match job { - DbOperation::InsertFile { - worktree_id, - documents, - path, - mtime, - job_handle, - } => { - db.insert_file(worktree_id, path, mtime, documents) - .log_err(); - drop(job_handle) - } - DbOperation::Delete { worktree_id, path } => { - db.delete_file(worktree_id, path).log_err(); - } - DbOperation::FindOrCreateWorktree { path, sender } => { - let id = db.find_or_create_worktree(&path); - sender.send(id).ok(); - } - DbOperation::FileMTimes { - worktree_id: worktree_db_id, - sender, - } => { - let file_mtimes = db.get_file_mtimes(worktree_db_id); - sender.send(file_mtimes).ok(); - } - DbOperation::WorktreePreviouslyIndexed { path, sender } => { - let worktree_indexed = db.worktree_previously_indexed(path.as_ref()); - sender.send(worktree_indexed).ok(); - } - } - } - async fn parse_file( fs: &Arc, pending_file: PendingFile, @@ -409,36 +324,6 @@ impl SemanticIndex { } } - fn find_or_create_worktree(&self, path: PathBuf) -> impl Future> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx }) - .unwrap(); - async move { rx.await? } - } - - fn get_file_mtimes( - &self, - worktree_id: i64, - ) -> impl Future>> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::FileMTimes { - worktree_id, - sender: tx, - }) - .unwrap(); - async move { rx.await? } - } - - fn worktree_previously_indexed(&self, path: Arc) -> impl Future> { - let (tx, rx) = oneshot::channel(); - self.db_update_tx - .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx }) - .unwrap(); - async move { rx.await? } - } - pub fn project_previously_indexed( &mut self, project: ModelHandle, @@ -447,7 +332,10 @@ impl SemanticIndex { let worktrees_indexed_previously = project .read(cx) .worktrees(cx) - .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path())) + .map(|worktree| { + self.db + .worktree_previously_indexed(&worktree.read(cx).abs_path()) + }) .collect::>(); cx.spawn(|_, _cx| async move { let worktree_indexed_previously = @@ -528,7 +416,8 @@ impl SemanticIndex { .read(cx) .worktrees(cx) .map(|worktree| { - self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) + self.db + .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) }) .collect::>(); @@ -559,7 +448,7 @@ impl SemanticIndex { db_ids_by_worktree_id.insert(worktree.id(), db_id); worktree_file_mtimes.insert( worktree.id(), - this.read_with(&cx, |this, _| this.get_file_mtimes(db_id)) + this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id)) .await?, ); } @@ -704,11 +593,12 @@ impl SemanticIndex { .collect::>(); let embedding_provider = self.embedding_provider.clone(); - let database_url = self.database_url.clone(); + let db_path = self.db.path().clone(); let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { let t0 = Instant::now(); - let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; + let database = + VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?; let phrase_embedding = embedding_provider .embed_batch(vec![phrase]) @@ -722,8 +612,9 @@ impl SemanticIndex { t0.elapsed().as_millis() ); - let file_ids = - database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?; + let file_ids = database + .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) + .await?; let batch_n = cx.background().num_cpus(); let ids_len = file_ids.clone().len(); @@ -733,27 +624,24 @@ impl SemanticIndex { ids_len / batch_n }; - let mut result_tasks = Vec::new(); + let mut batch_results = Vec::new(); for batch in file_ids.chunks(batch_size) { let batch = batch.into_iter().map(|v| *v).collect::>(); let limit = limit.clone(); let fs = fs.clone(); - let database_url = database_url.clone(); + let db_path = db_path.clone(); let phrase_embedding = phrase_embedding.clone(); - let task = cx.background().spawn(async move { - let database = VectorDatabase::new(fs, database_url).await.log_err(); - if database.is_none() { - return Err(anyhow!("failed to acquire database connection")); - } else { - database - .unwrap() - .top_k_search(&phrase_embedding, limit, batch.as_slice()) - } - }); - result_tasks.push(task); + if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background()) + .await + .log_err() + { + batch_results.push(async move { + db.top_k_search(&phrase_embedding, limit, batch.as_slice()) + .await + }); + } } - - let batch_results = futures::future::join_all(result_tasks).await; + let batch_results = futures::future::join_all(batch_results).await; let mut results = Vec::new(); for batch_result in batch_results { @@ -772,7 +660,7 @@ impl SemanticIndex { } let ids = results.into_iter().map(|(id, _)| id).collect::>(); - let documents = database.get_documents_by_ids(ids.as_slice())?; + let documents = database.get_documents_by_ids(ids.as_slice()).await?; let mut tasks = Vec::new(); let mut ranges = Vec::new(); @@ -822,7 +710,8 @@ impl SemanticIndex { cx: &mut AsyncAppContext, ) { let mut pending_files = Vec::new(); - let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| { + let mut files_to_delete = Vec::new(); + let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| { if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; let db_ids = &project_state.worktree_db_ids; @@ -853,12 +742,7 @@ impl SemanticIndex { }; if info.is_deleted { - this.db_update_tx - .try_send(DbOperation::Delete { - worktree_id: worktree_db_id, - path: path.path.to_path_buf(), - }) - .ok(); + files_to_delete.push((worktree_db_id, path.path.to_path_buf())); } else { let absolute_path = worktree.read(cx).absolutize(&path.path); let job_handle = JobHandle::new(&outstanding_job_count_tx); @@ -877,11 +761,16 @@ impl SemanticIndex { } ( + this.db.clone(), this.language_registry.clone(), this.parsing_files_tx.clone(), ) }); + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); + } + for mut pending_file in pending_files { if let Ok(language) = language_registry .language_for_file(&pending_file.relative_path, None) From c763e728d12b413d27ae9f1477026dc82c0cf002 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 16:59:54 +0200 Subject: [PATCH 12/60] Write to and read from the database in a transactional way Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 652c2819ce..313df40674 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -58,7 +58,8 @@ impl FromSql for Sha1 { #[derive(Clone)] pub struct VectorDatabase { path: Arc, - transactions: smol::channel::Sender>, + transactions: + smol::channel::Sender>, } impl VectorDatabase { @@ -71,15 +72,16 @@ impl VectorDatabase { fs.create_dir(db_directory).await?; } - let (transactions_tx, transactions_rx) = - smol::channel::unbounded::>(); + let (transactions_tx, transactions_rx) = smol::channel::unbounded::< + Box, + >(); executor .spawn({ let path = path.clone(); async move { - let connection = rusqlite::Connection::open(&path)?; + let mut connection = rusqlite::Connection::open(&path)?; while let Ok(transaction) = transactions_rx.recv().await { - transaction(&connection); + transaction(&mut connection); } anyhow::Ok(()) @@ -99,9 +101,9 @@ impl VectorDatabase { &self.path } - fn transact(&self, transaction: F) -> impl Future> + fn transact(&self, f: F) -> impl Future> where - F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result, + F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result, T: 'static + Send, { let (tx, rx) = oneshot::channel(); @@ -109,7 +111,14 @@ impl VectorDatabase { async move { if transactions .send(Box::new(|connection| { - let result = transaction(connection); + let result = connection + .transaction() + .map_err(|err| anyhow!(err)) + .and_then(|transaction| { + let result = f(&transaction)?; + transaction.commit()?; + Ok(result) + }); let _ = tx.send(result); })) .await From 7b5974e8e9f0ef912cafce81ce87a59aa7351137 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Wed, 30 Aug 2023 13:04:25 +0300 Subject: [PATCH 13/60] Add LSP logs clear button --- crates/language_tools/src/lsp_log.rs | 52 ++++++++++++++++++++++++---- crates/theme/src/theme.rs | 1 + styles/src/style_tree/toolbar.ts | 9 ++++- styles/src/style_tree/workspace.ts | 2 ++ 4 files changed, 56 insertions(+), 8 deletions(-) diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 60c4e41666..3275b3ee01 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -570,10 +570,12 @@ impl View for LspLogToolbarItemView { let Some(log_view) = self.log_view.as_ref() else { return Empty::new().into_any(); }; - let log_view = log_view.read(cx); - let menu_rows = log_view.menu_items(cx).unwrap_or_default(); + let (menu_rows, current_server_id) = log_view.update(cx, |log_view, cx| { + let menu_rows = log_view.menu_items(cx).unwrap_or_default(); + let current_server_id = log_view.current_server_id; + (menu_rows, current_server_id) + }); - let current_server_id = log_view.current_server_id; let current_server = current_server_id.and_then(|current_server_id| { if let Ok(ix) = menu_rows.binary_search_by_key(¤t_server_id, |e| e.server_id) { Some(menu_rows[ix].clone()) @@ -583,8 +585,7 @@ impl View for LspLogToolbarItemView { }); enum Menu {} - - Stack::new() + let lsp_menu = Stack::new() .with_child(Self::render_language_server_menu_header( current_server, &theme, @@ -631,8 +632,45 @@ impl View for LspLogToolbarItemView { }) .aligned() .left() - .clipped() - .into_any() + .clipped(); + + enum LspCleanupButton {} + let log_cleanup_button = + MouseEventHandler::new::(1, cx, |state, cx| { + let theme = theme::current(cx).clone(); + let style = theme + .workspace + .toolbar + .toggleable_text_tool + .active_state() + .style_for(state); + Label::new("Clear", style.text.clone()) + .aligned() + .contained() + .with_style(style.container) + }) + .on_click(MouseButton::Left, move |_, this, cx| { + if let Some(log_view) = this.log_view.as_ref() { + log_view.update(cx, |log_view, cx| { + log_view.editor.update(cx, |editor, cx| { + editor.set_read_only(false); + editor.clear(cx); + editor.set_read_only(true); + }); + }) + } + }) + .with_cursor_style(CursorStyle::PointingHand) + .aligned() + .right(); + + Flex::row() + .with_child(lsp_menu) + .with_child(log_cleanup_button) + .contained() + .aligned() + .left() + .into_any_named("lsp log controls") } } diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index a51f18c4db..a542249788 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -408,6 +408,7 @@ pub struct Toolbar { pub height: f32, pub item_spacing: f32, pub toggleable_tool: Toggleable>, + pub toggleable_text_tool: Toggleable>, pub breadcrumb_height: f32, pub breadcrumbs: Interactive, } diff --git a/styles/src/style_tree/toolbar.ts b/styles/src/style_tree/toolbar.ts index 7292a220a8..0145ee2785 100644 --- a/styles/src/style_tree/toolbar.ts +++ b/styles/src/style_tree/toolbar.ts @@ -1,7 +1,8 @@ import { useTheme } from "../common" import { toggleable_icon_button } from "../component/icon_button" -import { interactive } from "../element" +import { interactive, toggleable } from "../element" import { background, border, foreground, text } from "./components" +import { text_button } from "../component"; export const toolbar = () => { const theme = useTheme() @@ -34,5 +35,11 @@ export const toolbar = () => { }, }, }), + toggleable_text_tool: toggleable({ + state: { + inactive: text_button({ variant: "ghost", layer: theme.highest, disabled: true, margin: { right: 4 }, text_properties: { size: "sm" } }), + active: text_button({ variant: "ghost", layer: theme.highest, margin: { right: 4 }, text_properties: { size: "sm" } }) + } + }), } } diff --git a/styles/src/style_tree/workspace.ts b/styles/src/style_tree/workspace.ts index ba89c7b05f..8fda5e0117 100644 --- a/styles/src/style_tree/workspace.ts +++ b/styles/src/style_tree/workspace.ts @@ -19,6 +19,8 @@ export default function workspace(): any { const { is_light } = theme + const TOOLBAR_ITEM_SPACING = 8; + return { background: background(theme.lowest), blank_pane: { From fe2300fdaa59f12df58736f0f4cf61db2b8ee8c3 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Wed, 30 Aug 2023 14:49:33 +0300 Subject: [PATCH 14/60] Style the clear button better, add border to button constructor options --- crates/language_tools/src/lsp_log.rs | 5 ++++- styles/src/component/text_button.ts | 6 ++++++ styles/src/style_tree/toolbar.ts | 17 +++++++++++++++-- styles/src/style_tree/workspace.ts | 2 -- 4 files changed, 25 insertions(+), 5 deletions(-) diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 3275b3ee01..a918e3d151 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -583,6 +583,7 @@ impl View for LspLogToolbarItemView { None } }); + let server_selected = current_server.is_some(); enum Menu {} let lsp_menu = Stack::new() @@ -642,12 +643,14 @@ impl View for LspLogToolbarItemView { .workspace .toolbar .toggleable_text_tool - .active_state() + .in_state(server_selected) .style_for(state); Label::new("Clear", style.text.clone()) .aligned() .contained() .with_style(style.container) + .constrained() + .with_height(theme.toolbar_dropdown_menu.row_height / 6.0 * 5.0) }) .on_click(MouseButton::Left, move |_, this, cx| { if let Some(log_view) = this.log_view.as_ref() { diff --git a/styles/src/component/text_button.ts b/styles/src/component/text_button.ts index 8333d9e81a..0e293e403a 100644 --- a/styles/src/component/text_button.ts +++ b/styles/src/component/text_button.ts @@ -1,5 +1,6 @@ import { interactive, toggleable } from "../element" import { + Border, TextProperties, background, foreground, @@ -16,6 +17,7 @@ interface TextButtonOptions { margin?: Partial disabled?: boolean text_properties?: TextProperties + border?: Border } type ToggleableTextButtonOptions = TextButtonOptions & { @@ -29,6 +31,7 @@ export function text_button({ margin, disabled, text_properties, + border, }: TextButtonOptions = {}) { const theme = useTheme() if (!color) color = "base" @@ -66,6 +69,7 @@ export function text_button({ }, state: { default: { + border, background: background_color, color: disabled ? foreground(layer ?? theme.lowest, "disabled") @@ -74,6 +78,7 @@ export function text_button({ hovered: disabled ? {} : { + border, background: background( layer ?? theme.lowest, color, @@ -88,6 +93,7 @@ export function text_button({ clicked: disabled ? {} : { + border, background: background( layer ?? theme.lowest, color, diff --git a/styles/src/style_tree/toolbar.ts b/styles/src/style_tree/toolbar.ts index 0145ee2785..01a09a0616 100644 --- a/styles/src/style_tree/toolbar.ts +++ b/styles/src/style_tree/toolbar.ts @@ -37,8 +37,21 @@ export const toolbar = () => { }), toggleable_text_tool: toggleable({ state: { - inactive: text_button({ variant: "ghost", layer: theme.highest, disabled: true, margin: { right: 4 }, text_properties: { size: "sm" } }), - active: text_button({ variant: "ghost", layer: theme.highest, margin: { right: 4 }, text_properties: { size: "sm" } }) + inactive: text_button({ + disabled: true, + variant: "ghost", + layer: theme.highest, + margin: { left: 4 }, + text_properties: { size: "sm" }, + border: border(theme.middle), + }), + active: text_button({ + variant: "ghost", + layer: theme.highest, + margin: { left: 4 }, + text_properties: { size: "sm" }, + border: border(theme.middle), + }), } }), } diff --git a/styles/src/style_tree/workspace.ts b/styles/src/style_tree/workspace.ts index 8fda5e0117..ba89c7b05f 100644 --- a/styles/src/style_tree/workspace.ts +++ b/styles/src/style_tree/workspace.ts @@ -19,8 +19,6 @@ export default function workspace(): any { const { is_light } = theme - const TOOLBAR_ITEM_SPACING = 8; - return { background: background(theme.lowest), blank_pane: { From 3001a46f6995cd900cae7bf633605dc0fb1334e4 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 17:55:43 +0200 Subject: [PATCH 15/60] Reify `Embedding`/`Sha1` structs that can be (de)serialized from SQL Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 76 ++---------- crates/semantic_index/src/embedding.rs | 114 +++++++++++++++++- crates/semantic_index/src/embedding_queue.rs | 2 +- crates/semantic_index/src/parsing.rs | 69 +++++++---- .../src/semantic_index_tests.rs | 57 +++------ 5 files changed, 180 insertions(+), 138 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 313df40674..81b05720d2 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,13 +1,10 @@ -use crate::{parsing::Document, SEMANTIC_INDEX_VERSION}; +use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION}; use anyhow::{anyhow, Context, Result}; use futures::channel::oneshot; use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; -use rusqlite::{ - params, - types::{FromSql, FromSqlResult, ValueRef}, -}; +use rusqlite::params; use std::{ cmp::Ordering, collections::HashMap, @@ -27,34 +24,6 @@ pub struct FileRecord { pub mtime: Timestamp, } -#[derive(Debug)] -struct Embedding(pub Vec); - -#[derive(Debug)] -struct Sha1(pub Vec); - -impl FromSql for Embedding { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let embedding: Result, Box> = bincode::deserialize(bytes); - if embedding.is_err() { - return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); - } - return Ok(Embedding(embedding.unwrap())); - } -} - -impl FromSql for Sha1 { - fn column_result(value: ValueRef) -> FromSqlResult { - let bytes = value.as_blob()?; - let sha1: Result, Box> = bincode::deserialize(bytes); - if sha1.is_err() { - return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err())); - } - return Ok(Sha1(sha1.unwrap())); - } -} - #[derive(Clone)] pub struct VectorDatabase { path: Arc, @@ -255,9 +224,6 @@ impl VectorDatabase { // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. for document in documents { - let embedding_blob = bincode::serialize(&document.embedding)?; - let sha_blob = bincode::serialize(&document.sha1)?; - db.execute( "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![ @@ -265,8 +231,8 @@ impl VectorDatabase { document.range.start.to_string(), document.range.end.to_string(), document.name, - embedding_blob, - sha_blob + document.embedding, + document.sha1 ], )?; } @@ -351,7 +317,7 @@ impl VectorDatabase { pub fn top_k_search( &self, - query_embedding: &Vec, + query_embedding: &Embedding, limit: usize, file_ids: &[i64], ) -> impl Future>> { @@ -360,7 +326,7 @@ impl VectorDatabase { self.transact(move |db| { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); Self::for_each_document(db, &file_ids, |id, embedding| { - let similarity = dot(&embedding, &query_embedding); + let similarity = embedding.similarity(&query_embedding); let ix = match results.binary_search_by(|(_, s)| { similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) }) { @@ -417,7 +383,7 @@ impl VectorDatabase { fn for_each_document( db: &rusqlite::Connection, file_ids: &[i64], - mut f: impl FnMut(i64, Vec), + mut f: impl FnMut(i64, Embedding), ) -> Result<()> { let mut query_statement = db.prepare( " @@ -435,7 +401,7 @@ impl VectorDatabase { Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) })? .filter_map(|row| row.ok()) - .for_each(|(id, embedding)| f(id, embedding.0)); + .for_each(|(id, embedding)| f(id, embedding)); Ok(()) } @@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc> { .collect::>(), ) } - -pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 { - let len = vec_a.len(); - assert_eq!(len, vec_b.len()); - - let mut result = 0.0; - unsafe { - matrixmultiply::sgemm( - 1, - len, - 1, - 1.0, - vec_a.as_ptr(), - len as isize, - 1, - vec_b.as_ptr(), - 1, - len as isize, - 0.0, - &mut result as *mut f32, - 1, - 1, - ); - } - result -} diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 60e13a9e01..97c25ca170 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -8,6 +8,8 @@ use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; use parse_duration::parse; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; @@ -20,6 +22,62 @@ lazy_static! { static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(Vec); + +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> f32 { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + result + } +} + +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} + #[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, @@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec) -> Result>>; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn truncate(&self, span: &str) -> (String, usize); } @@ -62,10 +120,10 @@ pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. - let dummy_vec = vec![0.32 as f32; 1536]; + let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); return Ok(vec![dummy_vec; spans.len()]); } @@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { (output, token_count) } - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { return Ok(response .data .into_iter() - .map(|embedding| embedding.embedding) + .map(|embedding| Embedding::from(embedding.embedding)) .collect()); } StatusCode::TOO_MANY_REQUESTS => { @@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings { Err(anyhow!("openai max retries")) } } + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. + ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() + } + } +} diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index c3a5de1373..4c82ced918 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -121,7 +121,7 @@ impl EmbeddingQueue { &mut fragment.file.lock().documents[fragment.document_range.clone()] { if let Some(embedding) = embeddings.next() { - document.embedding = embedding; + document.embedding = Some(embedding); } else { // log::error!("number of embeddings returned different from number of documents"); diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 51f1bd7ca9..2b67f41714 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,7 +1,11 @@ -use crate::embedding::EmbeddingProvider; -use anyhow::{anyhow, Ok, Result}; +use crate::embedding::{EmbeddingProvider, Embedding}; +use anyhow::{anyhow, Result}; use language::{Grammar, Language}; -use sha1::{Digest, Sha1}; +use rusqlite::{ + types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, + ToSql, +}; +use sha1::Digest; use std::{ cmp::{self, Reverse}, collections::HashSet, @@ -11,13 +15,43 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; +#[derive(Debug, PartialEq, Clone)] +pub struct Sha1([u8; 20]); + +impl FromSql for Sha1 { + fn column_result(value: ValueRef) -> FromSqlResult { + let blob = value.as_blob()?; + let bytes = + blob.try_into() + .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize { + expected_size: 20, + blob_size: blob.len(), + })?; + return Ok(Sha1(bytes)); + } +} + +impl ToSql for Sha1 { + fn to_sql(&self) -> rusqlite::Result { + self.0.to_sql() + } +} + +impl From<&'_ str> for Sha1 { + fn from(value: &'_ str) -> Self { + let mut sha1 = sha1::Sha1::new(); + sha1.update(value); + Self(sha1.finalize().into()) + } +} + #[derive(Debug, PartialEq, Clone)] pub struct Document { pub name: String, pub range: Range, pub content: String, - pub embedding: Vec, - pub sha1: [u8; 20], + pub embedding: Option, + pub sha1: Sha1, pub token_count: usize, } @@ -69,17 +103,16 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("", &content); - let mut sha1 = Sha1::new(); - sha1.update(&document_span); + let sha1 = Sha1::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: Default::default(), name: language_name.to_string(), - sha1: sha1.finalize().into(), + sha1, token_count, }]) } @@ -88,18 +121,14 @@ impl CodeContextRetriever { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - - let mut sha1 = Sha1::new(); - sha1.update(&document_span); - + let sha1 = Sha1::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { range: 0..content.len(), content: document_span, - embedding: Vec::new(), + embedding: None, name: "Markdown".to_string(), - sha1: sha1.finalize().into(), + sha1, token_count, }]) } @@ -279,15 +308,13 @@ impl CodeContextRetriever { ); } - let mut sha1 = Sha1::new(); - sha1.update(&document_content); - + let sha1 = Sha1::from(document_content.as_str()); documents.push(Document { name, content: document_content, range: item_range.clone(), - embedding: vec![], - sha1: sha1.finalize().into(), + embedding: None, + sha1, token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index dc41c09f7a..75232eb4d2 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,8 +1,7 @@ use crate::{ - db::dot, - embedding::{DummyEmbeddings, EmbeddingProvider}, + embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Document}, + parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1}, semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; @@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { documents: (0..rng.gen_range(4..22)) .map(|document_ix| { let content_len = rng.gen_range(10..100); + let content = RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect::(); + let sha1 = Sha1::from(content.as_str()); Document { range: 0..10, - embedding: Vec::new(), + embedding: None, name: format!("document {document_ix}"), - content: RandomCharIter::new(&mut rng) - .with_simple_text() - .take(content_len) - .collect(), - sha1: rng.gen(), + content, + sha1, token_count: rng.gen_range(10..30), } }) @@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .map(|file| { let mut file = file.clone(); for doc in &mut file.documents { - doc.embedding = embedding_provider.embed_sync(doc.content.as_ref()); + doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref())); } file }) @@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() { ); } -#[gpui::test] -fn test_dot_product(mut rng: StdRng) { - assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.); - assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.); - - for _ in 0..100 { - let size = 1536; - let mut a = vec![0.; size]; - let mut b = vec![0.; size]; - for (a, b) in a.iter_mut().zip(b.iter_mut()) { - *a = rng.gen(); - *b = rng.gen(); - } - - assert_eq!( - round_to_decimals(dot(&a, &b), 1), - round_to_decimals(reference_dot(&a, &b), 1) - ); - } - - fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { - let factor = (10.0 as f32).powi(decimal_places); - (n * factor).round() / factor - } - - fn reference_dot(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() - } -} - #[derive(Default)] struct FakeEmbeddingProvider { embedding_count: AtomicUsize, @@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider { self.embedding_count.load(atomic::Ordering::SeqCst) } - fn embed_sync(&self, span: &str) -> Vec { + fn embed_sync(&self, span: &str) -> Embedding { let mut result = vec![1.0; 26]; for letter in span.chars() { let letter = letter.to_ascii_lowercase(); @@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider { *x /= norm; } - result + result.into() } } @@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider { 200 } - async fn embed_batch(&self, spans: Vec) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) From 2503d54d1957f9b34c64af54b2a6d2e0e712ac13 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 18:00:36 +0200 Subject: [PATCH 16/60] Rename `Sha1` to `DocumentDigest` Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 12 ++++--- crates/semantic_index/src/parsing.rs | 35 +++++++++---------- crates/semantic_index/src/semantic_index.rs | 2 +- .../src/semantic_index_tests.rs | 6 ++-- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 81b05720d2..375934e7fe 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,4 +1,8 @@ -use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION}; +use crate::{ + embedding::Embedding, + parsing::{Document, DocumentDigest}, + SEMANTIC_INDEX_VERSION, +}; use anyhow::{anyhow, Context, Result}; use futures::channel::oneshot; use gpui::executor; @@ -165,7 +169,7 @@ impl VectorDatabase { end_byte INTEGER NOT NULL, name VARCHAR NOT NULL, embedding BLOB NOT NULL, - sha1 BLOB NOT NULL, + digest BLOB NOT NULL, FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE )", [], @@ -225,14 +229,14 @@ impl VectorDatabase { // I imagine we can speed this up with a bulk insert of some kind. for document in documents { db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, digest) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![ file_id, document.range.start.to_string(), document.range.end.to_string(), document.name, document.embedding, - document.sha1 + document.digest ], )?; } diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 2b67f41714..c0a94c6b73 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,11 +1,11 @@ -use crate::embedding::{EmbeddingProvider, Embedding}; +use crate::embedding::{Embedding, EmbeddingProvider}; use anyhow::{anyhow, Result}; use language::{Grammar, Language}; use rusqlite::{ types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, ToSql, }; -use sha1::Digest; +use sha1::{Digest, Sha1}; use std::{ cmp::{self, Reverse}, collections::HashSet, @@ -15,10 +15,10 @@ use std::{ }; use tree_sitter::{Parser, QueryCursor}; -#[derive(Debug, PartialEq, Clone)] -pub struct Sha1([u8; 20]); +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct DocumentDigest([u8; 20]); -impl FromSql for Sha1 { +impl FromSql for DocumentDigest { fn column_result(value: ValueRef) -> FromSqlResult { let blob = value.as_blob()?; let bytes = @@ -27,19 +27,19 @@ impl FromSql for Sha1 { expected_size: 20, blob_size: blob.len(), })?; - return Ok(Sha1(bytes)); + return Ok(DocumentDigest(bytes)); } } -impl ToSql for Sha1 { +impl ToSql for DocumentDigest { fn to_sql(&self) -> rusqlite::Result { self.0.to_sql() } } -impl From<&'_ str> for Sha1 { +impl From<&'_ str> for DocumentDigest { fn from(value: &'_ str) -> Self { - let mut sha1 = sha1::Sha1::new(); + let mut sha1 = Sha1::new(); sha1.update(value); Self(sha1.finalize().into()) } @@ -51,7 +51,7 @@ pub struct Document { pub range: Range, pub content: String, pub embedding: Option, - pub sha1: Sha1, + pub digest: DocumentDigest, pub token_count: usize, } @@ -102,17 +102,14 @@ impl CodeContextRetriever { .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("", &content); - - let sha1 = Sha1::from(document_span.as_str()); - + let digest = DocumentDigest::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: Default::default(), name: language_name.to_string(), - sha1, + digest, token_count, }]) } @@ -121,14 +118,14 @@ impl CodeContextRetriever { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - let sha1 = Sha1::from(document_span.as_str()); + let digest = DocumentDigest::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), content: document_span, embedding: None, name: "Markdown".to_string(), - sha1, + digest, token_count, }]) } @@ -308,13 +305,13 @@ impl CodeContextRetriever { ); } - let sha1 = Sha1::from(document_content.as_str()); + let sha1 = DocumentDigest::from(document_content.as_str()); documents.push(Document { name, content: document_content, range: item_range.clone(), embedding: None, - sha1, + digest: sha1, token_count: 0, }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 7a0985b273..0a9a808a64 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -37,7 +37,7 @@ use util::{ }; use workspace::WorkspaceCreated; -const SEMANTIC_INDEX_VERSION: usize = 7; +const SEMANTIC_INDEX_VERSION: usize = 8; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); pub fn init( diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 75232eb4d2..e65bc04412 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,7 +1,7 @@ use crate::{ embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1}, + parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest}, semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; @@ -220,13 +220,13 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .with_simple_text() .take(content_len) .collect::(); - let sha1 = Sha1::from(content.as_str()); + let digest = DocumentDigest::from(content.as_str()); Document { range: 0..10, embedding: None, name: format!("document {document_ix}"), content, - sha1, + digest, token_count: rng.gen_range(10..30), } }) From 220533ff1abf46066853eae31c11eb17b219554d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 31 Aug 2023 18:00:57 +0200 Subject: [PATCH 17/60] WIP --- crates/semantic_index/src/db.rs | 19 +++++++++++++++++++ crates/semantic_index/src/semantic_index.rs | 17 +++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 375934e7fe..134a70972f 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -264,6 +264,25 @@ impl VectorDatabase { }) } + pub fn embeddings_for_file( + &self, + worktree_id: i64, + relative_path: PathBuf, + ) -> impl Future>> { + let relative_path = relative_path.to_string_lossy().into_owned(); + self.transact(move |db| { + let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE files.worktree_id = ?1 AND files.relative_path = ?2")?; + let mut result: HashMap = HashMap::new(); + for row in query.query_map(params![worktree_id, relative_path], |row| { + Ok((row.get::<_, DocumentDigest>(0)?.into(), row.get::<_, Embedding>(1)?.into())) + })? { + let row = row?; + result.insert(row.0, row.1); + } + Ok(result) + }) + } + pub fn find_or_create_worktree( &self, worktree_root_path: PathBuf, diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 0a9a808a64..58166c1a22 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -309,6 +309,23 @@ impl SemanticIndex { documents.len() ); + todo!(); + // if let Some(embeddings) = db + // .embeddings_for_documents( + // pending_file.worktree_db_id, + // pending_file.relative_path, + // &documents, + // ) + // .await + // .log_err() + // { + // for (document, embedding) in documents.iter_mut().zip(embeddings) { + // if let Some(embedding) = embedding { + // document.embedding = embedding; + // } + // } + // } + embedding_queue.lock().push(FileToEmbed { worktree_id: pending_file.worktree_db_id, path: pending_file.relative_path, From 50cfb067e7c536636ed5bf7e119968d50843b287 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 31 Aug 2023 13:19:17 -0400 Subject: [PATCH 18/60] fill embeddings with database values and skip during embeddings queue --- crates/semantic_index/src/embedding_queue.rs | 34 ++++++++++++++++--- crates/semantic_index/src/semantic_index.rs | 35 ++++++++++---------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 4c82ced918..96493fc4d3 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -42,6 +42,7 @@ pub struct EmbeddingQueue { finished_files_rx: channel::Receiver, } +#[derive(Clone)] pub struct FileToEmbedFragment { file: Arc>, document_range: Range, @@ -74,8 +75,16 @@ impl EmbeddingQueue { }); let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + let mut saved_tokens = 0; for (ix, document) in file.lock().documents.iter().enumerate() { - let next_token_count = self.pending_batch_token_count + document.token_count; + let document_token_count = if document.embedding.is_none() { + document.token_count + } else { + saved_tokens += document.token_count; + 0 + }; + + let next_token_count = self.pending_batch_token_count + document_token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; self.flush(); @@ -87,8 +96,9 @@ impl EmbeddingQueue { } fragment_range.end = ix + 1; - self.pending_batch_token_count += document.token_count; + self.pending_batch_token_count += document_token_count; } + log::trace!("Saved Tokens: {:?}", saved_tokens); } pub fn flush(&mut self) { @@ -100,25 +110,41 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); + self.executor.spawn(async move { let mut spans = Vec::new(); + let mut document_count = 0; for fragment in &batch { let file = fragment.file.lock(); + document_count += file.documents[fragment.document_range.clone()].len(); spans.extend( { file.documents[fragment.document_range.clone()] - .iter() + .iter().filter(|d| d.embedding.is_none()) .map(|d| d.content.clone()) } ); } + log::trace!("Documents Length: {:?}", document_count); + log::trace!("Span Length: {:?}", spans.clone().len()); + + // If spans is 0, just send the fragment to the finished files if its the last one. + if spans.len() == 0 { + for fragment in batch.clone() { + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + return; + }; + match embedding_provider.embed_batch(spans).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { for document in - &mut fragment.file.lock().documents[fragment.document_range.clone()] + &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) { if let Some(embedding) = embeddings.next() { document.embedding = Some(embedding); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 58166c1a22..726b04583a 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -255,6 +255,7 @@ impl SemanticIndex { let parsing_files_rx = parsing_files_rx.clone(); let embedding_provider = embedding_provider.clone(); let embedding_queue = embedding_queue.clone(); + let db = db.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); while let Ok(pending_file) = parsing_files_rx.recv().await { @@ -264,6 +265,7 @@ impl SemanticIndex { &mut retriever, &embedding_queue, &parsing_files_rx, + &db, ) .await; } @@ -293,13 +295,14 @@ impl SemanticIndex { retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, parsing_files_rx: &channel::Receiver, + db: &VectorDatabase, ) { let Some(language) = pending_file.language else { return; }; if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { - if let Some(documents) = retriever + if let Some(mut documents) = retriever .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { @@ -309,22 +312,20 @@ impl SemanticIndex { documents.len() ); - todo!(); - // if let Some(embeddings) = db - // .embeddings_for_documents( - // pending_file.worktree_db_id, - // pending_file.relative_path, - // &documents, - // ) - // .await - // .log_err() - // { - // for (document, embedding) in documents.iter_mut().zip(embeddings) { - // if let Some(embedding) = embedding { - // document.embedding = embedding; - // } - // } - // } + if let Some(sha_to_embeddings) = db + .embeddings_for_file( + pending_file.worktree_db_id, + pending_file.relative_path.clone(), + ) + .await + .log_err() + { + for document in documents.iter_mut() { + if let Some(embedding) = sha_to_embeddings.get(&document.digest) { + document.embedding = Some(embedding.to_owned()); + } + } + } embedding_queue.lock().push(FileToEmbed { worktree_id: pending_file.worktree_db_id, From afa59abbcd8a6208a844227e122e0e439e50bfda Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 31 Aug 2023 16:42:39 -0400 Subject: [PATCH 19/60] WIP: work towards wiring up a embeddings_for_digest hashmap that is stored for all indexed files --- crates/semantic_index/src/db.rs | 36 ++++++++ crates/semantic_index/src/semantic_index.rs | 91 +++++++++++++++------ 2 files changed, 104 insertions(+), 23 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 134a70972f..4a953a2866 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -9,6 +9,7 @@ use gpui::executor; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; use rusqlite::params; +use rusqlite::types::Value; use std::{ cmp::Ordering, collections::HashMap, @@ -283,6 +284,41 @@ impl VectorDatabase { }) } + pub fn embeddings_for_files( + &self, + worktree_id_file_paths: Vec<(i64, PathBuf)>, + ) -> impl Future>> { + todo!(); + // The remainder of the code is wired up. + // I'm having a bit of trouble figuring out the rusqlite syntax for a WHERE (files.worktree_id, files.relative_path) IN (VALUES (?, ?), (?, ?)) query + async { Ok(HashMap::new()) } + // let mut embeddings_by_digest = HashMap::new(); + // self.transact(move |db| { + + // let worktree_ids: Rc> = Rc::new( + // worktree_id_file_paths + // .iter() + // .map(|(id, _)| Value::from(*id)) + // .collect(), + // ); + // let file_paths: Rc> = Rc::new(worktree_id_file_paths + // .iter() + // .map(|(_, path)| Value::from(path.to_string_lossy().to_string())) + // .collect()); + + // let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE (files.worktree_id, files.relative_path) IN (VALUES (rarray = (?1), rarray = (?2))")?; + + // for row in query.query_map(params![worktree_ids, file_paths], |row| { + // Ok((row.get::<_, DocumentDigest>(0)?, row.get::<_, Embedding>(1)?)) + // })? { + // if let Ok(row) = row { + // embeddings_by_digest.insert(row.0, row.1); + // } + // } + // Ok(embeddings_by_digest) + // }) + } + pub fn find_or_create_worktree( &self, worktree_root_path: PathBuf, diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 726b04583a..908ac1f4be 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -10,12 +10,12 @@ mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; use db::VectorDatabase; -use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; -use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES}; +use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; use project::{ search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, @@ -103,7 +103,7 @@ pub struct SemanticIndex { db: VectorDatabase, embedding_provider: Arc, language_registry: Arc, - parsing_files_tx: channel::Sender, + parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, @@ -247,7 +247,8 @@ impl SemanticIndex { }); // Parse files into embeddable documents. - let (parsing_files_tx, parsing_files_rx) = channel::unbounded::(); + let (parsing_files_tx, parsing_files_rx) = + channel::unbounded::<(Arc>, PendingFile)>(); let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { @@ -258,14 +259,16 @@ impl SemanticIndex { let db = db.clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - while let Ok(pending_file) = parsing_files_rx.recv().await { + while let Ok((embeddings_for_digest, pending_file)) = + parsing_files_rx.recv().await + { Self::parse_file( &fs, pending_file, &mut retriever, &embedding_queue, &parsing_files_rx, - &db, + &embeddings_for_digest, ) .await; } @@ -294,8 +297,11 @@ impl SemanticIndex { pending_file: PendingFile, retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, - parsing_files_rx: &channel::Receiver, - db: &VectorDatabase, + parsing_files_rx: &channel::Receiver<( + Arc>, + PendingFile, + )>, + embeddings_for_digest: &HashMap, ) { let Some(language) = pending_file.language else { return; @@ -312,18 +318,9 @@ impl SemanticIndex { documents.len() ); - if let Some(sha_to_embeddings) = db - .embeddings_for_file( - pending_file.worktree_db_id, - pending_file.relative_path.clone(), - ) - .await - .log_err() - { - for document in documents.iter_mut() { - if let Some(embedding) = sha_to_embeddings.get(&document.digest) { - document.embedding = Some(embedding.to_owned()); - } + for document in documents.iter_mut() { + if let Some(embedding) = embeddings_for_digest.get(&document.digest) { + document.embedding = Some(embedding.to_owned()); } } @@ -381,6 +378,17 @@ impl SemanticIndex { return; }; + let embeddings_for_digest = { + let mut worktree_id_file_paths = Vec::new(); + for (path, _) in &project_state.changed_paths { + if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id) + { + worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + } + } + self.db.embeddings_for_files(worktree_id_file_paths) + }; + let worktree = worktree.read(cx); let change_time = Instant::now(); for (path, entry_id, change) in changes.iter() { @@ -405,9 +413,18 @@ impl SemanticIndex { } cx.spawn_weak(|this, mut cx| async move { + let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default(); + cx.background().timer(BACKGROUND_INDEXING_DELAY).await; if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { - Self::reindex_changed_paths(this, project, Some(change_time), &mut cx).await; + Self::reindex_changed_paths( + this, + project, + Some(change_time), + &mut cx, + Arc::new(embeddings_for_digest), + ) + .await; } }) .detach(); @@ -561,7 +578,32 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task)>> { cx.spawn(|this, mut cx| async move { - Self::reindex_changed_paths(this.clone(), project.clone(), None, &mut cx).await; + let embeddings_for_digest = this.read_with(&cx, |this, cx| { + if let Some(state) = this.projects.get(&project.downgrade()) { + let mut worktree_id_file_paths = Vec::new(); + for (path, _) in &state.changed_paths { + if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id) + { + worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + } + } + + Ok(this.db.embeddings_for_files(worktree_id_file_paths)) + } else { + Err(anyhow!("Project not yet initialized")) + } + })?; + + let embeddings_for_digest = Arc::new(embeddings_for_digest.await?); + + Self::reindex_changed_paths( + this.clone(), + project.clone(), + None, + &mut cx, + embeddings_for_digest, + ) + .await; this.update(&mut cx, |this, _cx| { let Some(state) = this.projects.get(&project.downgrade()) else { @@ -726,6 +768,7 @@ impl SemanticIndex { project: ModelHandle, last_changed_before: Option, cx: &mut AsyncAppContext, + embeddings_for_digest: Arc>, ) { let mut pending_files = Vec::new(); let mut files_to_delete = Vec::new(); @@ -805,7 +848,9 @@ impl SemanticIndex { } pending_file.language = Some(language); } - parsing_files_tx.try_send(pending_file).ok(); + parsing_files_tx + .try_send((embeddings_for_digest.clone(), pending_file)) + .ok(); } } } From c4db914f0a4397878ffeb7ffb74c8f6a3522e272 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 08:59:18 -0400 Subject: [PATCH 20/60] move embeddings queue to use single hashmap for all changed paths Co-authored-by: Antonio --- crates/semantic_index/src/db.rs | 79 ++++++++----------- crates/semantic_index/src/semantic_index.rs | 14 +++- .../src/semantic_index_tests.rs | 5 +- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 4a953a2866..abb47cddf0 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -265,58 +265,43 @@ impl VectorDatabase { }) } - pub fn embeddings_for_file( - &self, - worktree_id: i64, - relative_path: PathBuf, - ) -> impl Future>> { - let relative_path = relative_path.to_string_lossy().into_owned(); - self.transact(move |db| { - let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE files.worktree_id = ?1 AND files.relative_path = ?2")?; - let mut result: HashMap = HashMap::new(); - for row in query.query_map(params![worktree_id, relative_path], |row| { - Ok((row.get::<_, DocumentDigest>(0)?.into(), row.get::<_, Embedding>(1)?.into())) - })? { - let row = row?; - result.insert(row.0, row.1); - } - Ok(result) - }) - } - pub fn embeddings_for_files( &self, - worktree_id_file_paths: Vec<(i64, PathBuf)>, + worktree_id_file_paths: HashMap>>, ) -> impl Future>> { - todo!(); - // The remainder of the code is wired up. - // I'm having a bit of trouble figuring out the rusqlite syntax for a WHERE (files.worktree_id, files.relative_path) IN (VALUES (?, ?), (?, ?)) query - async { Ok(HashMap::new()) } - // let mut embeddings_by_digest = HashMap::new(); - // self.transact(move |db| { + self.transact(move |db| { + let mut query = db.prepare( + " + SELECT digest, embedding + FROM documents + LEFT JOIN files ON files.id = documents.file_id + WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) + ", + )?; + let mut embeddings_by_digest = HashMap::new(); + for (worktree_id, file_paths) in worktree_id_file_paths { + let file_paths = Rc::new( + file_paths + .into_iter() + .map(|p| Value::Text(p.to_string_lossy().into_owned())) + .collect::>(), + ); + let rows = query.query_map(params![worktree_id, file_paths], |row| { + Ok(( + row.get::<_, DocumentDigest>(0)?, + row.get::<_, Embedding>(1)?, + )) + })?; - // let worktree_ids: Rc> = Rc::new( - // worktree_id_file_paths - // .iter() - // .map(|(id, _)| Value::from(*id)) - // .collect(), - // ); - // let file_paths: Rc> = Rc::new(worktree_id_file_paths - // .iter() - // .map(|(_, path)| Value::from(path.to_string_lossy().to_string())) - // .collect()); + for row in rows { + if let Ok(row) = row { + embeddings_by_digest.insert(row.0, row.1); + } + } + } - // let mut query = db.prepare("SELECT digest, embedding FROM documents LEFT JOIN files ON files.id = documents.file_id WHERE (files.worktree_id, files.relative_path) IN (VALUES (rarray = (?1), rarray = (?2))")?; - - // for row in query.query_map(params![worktree_ids, file_paths], |row| { - // Ok((row.get::<_, DocumentDigest>(0)?, row.get::<_, Embedding>(1)?)) - // })? { - // if let Ok(row) = row { - // embeddings_by_digest.insert(row.0, row.1); - // } - // } - // Ok(embeddings_by_digest) - // }) + Ok(embeddings_by_digest) + }) } pub fn find_or_create_worktree( diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 908ac1f4be..6d140931d6 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -379,11 +379,14 @@ impl SemanticIndex { }; let embeddings_for_digest = { - let mut worktree_id_file_paths = Vec::new(); + let mut worktree_id_file_paths = HashMap::new(); for (path, _) in &project_state.changed_paths { if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id) { - worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + worktree_id_file_paths + .entry(worktree_db_id) + .or_insert(Vec::new()) + .push(path.path.clone()); } } self.db.embeddings_for_files(worktree_id_file_paths) @@ -580,11 +583,14 @@ impl SemanticIndex { cx.spawn(|this, mut cx| async move { let embeddings_for_digest = this.read_with(&cx, |this, cx| { if let Some(state) = this.projects.get(&project.downgrade()) { - let mut worktree_id_file_paths = Vec::new(); + let mut worktree_id_file_paths = HashMap::default(); for (path, _) in &state.changed_paths { if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id) { - worktree_id_file_paths.push((worktree_db_id, path.path.to_path_buf())); + worktree_id_file_paths + .entry(worktree_db_id) + .or_insert(Vec::new()) + .push(path.path.clone()); } } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index e65bc04412..01f34a2b1d 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -55,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { fn bbb() { println!(\"bbbbbbbbbbbbb!\"); } + struct pqpqpqp {} ".unindent(), "file3.toml": " ZZZZZZZZZZZZZZZZZZ = 5 @@ -121,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { (Path::new("src/file2.rs").into(), 0), (Path::new("src/file3.toml").into(), 0), (Path::new("src/file1.rs").into(), 45), + (Path::new("src/file2.rs").into(), 45), ], cx, ); @@ -148,6 +150,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { (Path::new("src/file1.rs").into(), 0), (Path::new("src/file2.rs").into(), 0), (Path::new("src/file1.rs").into(), 45), + (Path::new("src/file2.rs").into(), 45), ], cx, ); @@ -199,7 +202,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { assert_eq!( embedding_provider.embedding_count() - prev_embedding_count, - 2 + 1 ); } From 524533cfb227dffba93adfec461fee722c73ba4d Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 11:24:08 -0400 Subject: [PATCH 21/60] flush embeddings queue when no files are parsed for 250 milliseconds Co-authored-by: Antonio --- crates/semantic_index/src/semantic_index.rs | 50 ++++++++++--------- .../src/semantic_index_tests.rs | 12 ++--- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6d140931d6..a8518ce695 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -12,6 +12,7 @@ use anyhow::{anyhow, Result}; use db::VectorDatabase; use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; +use futures::{FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; @@ -39,6 +40,7 @@ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 8; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); +const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); pub fn init( fs: Arc, @@ -253,24 +255,34 @@ impl SemanticIndex { let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { let fs = fs.clone(); - let parsing_files_rx = parsing_files_rx.clone(); + let mut parsing_files_rx = parsing_files_rx.clone(); let embedding_provider = embedding_provider.clone(); let embedding_queue = embedding_queue.clone(); - let db = db.clone(); + let background = cx.background().clone(); _parsing_files_tasks.push(cx.background().spawn(async move { let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - while let Ok((embeddings_for_digest, pending_file)) = - parsing_files_rx.recv().await - { - Self::parse_file( - &fs, - pending_file, - &mut retriever, - &embedding_queue, - &parsing_files_rx, - &embeddings_for_digest, - ) - .await; + loop { + let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse(); + let mut next_file_to_parse = parsing_files_rx.next().fuse(); + futures::select_biased! { + next_file_to_parse = next_file_to_parse => { + if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse { + Self::parse_file( + &fs, + pending_file, + &mut retriever, + &embedding_queue, + &embeddings_for_digest, + ) + .await + } else { + break; + } + }, + _ = timer => { + embedding_queue.lock().flush(); + } + } } })); } @@ -297,10 +309,6 @@ impl SemanticIndex { pending_file: PendingFile, retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, - parsing_files_rx: &channel::Receiver<( - Arc>, - PendingFile, - )>, embeddings_for_digest: &HashMap, ) { let Some(language) = pending_file.language else { @@ -333,10 +341,6 @@ impl SemanticIndex { }); } } - - if parsing_files_rx.len() == 0 { - embedding_queue.lock().flush(); - } } pub fn project_previously_indexed( @@ -581,7 +585,7 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task)>> { cx.spawn(|this, mut cx| async move { - let embeddings_for_digest = this.read_with(&cx, |this, cx| { + let embeddings_for_digest = this.read_with(&cx, |this, _| { if let Some(state) = this.projects.get(&project.downgrade()) { let mut worktree_id_file_paths = HashMap::default(); for (path, _) in &state.changed_paths { diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 01f34a2b1d..f549e68e04 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -3,11 +3,11 @@ use crate::{ embedding_queue::EmbeddingQueue, parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest}, semantic_index_settings::SemanticIndexSettings, - FileToEmbed, JobHandle, SearchResult, SemanticIndex, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; use anyhow::Result; use async_trait::async_trait; -use gpui::{Task, TestAppContext}; +use gpui::{executor::Deterministic, Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; use parking_lot::Mutex; use pretty_assertions::assert_eq; @@ -34,7 +34,7 @@ fn init_logger() { } #[gpui::test] -async fn test_semantic_index(cx: &mut TestAppContext) { +async fn test_semantic_index(deterministic: Arc, cx: &mut TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.background()); @@ -98,7 +98,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .await .unwrap(); assert_eq!(file_count, 3); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*outstanding_file_count.borrow(), 0); let search_results = semantic_index @@ -188,7 +188,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .await .unwrap(); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); let prev_embedding_count = embedding_provider.embedding_count(); let (file_count, outstanding_file_count) = semantic_index @@ -197,7 +197,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { .unwrap(); assert_eq!(file_count, 1); - cx.foreground().run_until_parked(); + deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*outstanding_file_count.borrow(), 0); assert_eq!( From e86964eb5d4f7e2a387d8faec32f18df8da91362 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 13:01:37 -0400 Subject: [PATCH 22/60] optimize insert file in vector database Co-authored-by: Max --- crates/semantic_index/src/db.rs | 65 ++++++++++----------- crates/semantic_index/src/semantic_index.rs | 2 +- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index abb47cddf0..6cfd01456d 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -162,6 +162,11 @@ impl VectorDatabase { [], )?; + db.execute( + "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)", + [], + )?; + db.execute( "CREATE TABLE documents ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -206,43 +211,37 @@ impl VectorDatabase { // Return the existing ID, if both the file and mtime match let mtime = Timestamp::from(mtime); - let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?; - let existing_id = existing_id_query - .query_row( - params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], - |row| Ok(row.get::<_, i64>(0)?), - ); + db.execute( + " + REPLACE INTO files + (worktree_id, relative_path, mtime_seconds, mtime_nanos) + VALUES (?1, ?2, ?3, ?4) + ", + params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], + )?; - let file_id = if existing_id.is_ok() { - // If already exists, just return the existing id - existing_id? - } else { - // Delete Existing Row - db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;", - params![worktree_id, path.to_str()], - )?; - db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?; - db.last_insert_rowid() - }; + let file_id = db.last_insert_rowid(); + + let mut query = db.prepare( + " + INSERT INTO documents + (file_id, start_byte, end_byte, name, embedding, digest) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ", + )?; - // Currently inserting at approximately 3400 documents a second - // I imagine we can speed this up with a bulk insert of some kind. for document in documents { - db.execute( - "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, digest) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", - params![ - file_id, - document.range.start.to_string(), - document.range.end.to_string(), - document.name, - document.embedding, - document.digest - ], - )?; - } + query.execute(params![ + file_id, + document.range.start.to_string(), + document.range.end.to_string(), + document.name, + document.embedding, + document.digest + ])?; + } - Ok(()) + Ok(()) }) } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index a8518ce695..e155fe3c74 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -38,7 +38,7 @@ use util::{ }; use workspace::WorkspaceCreated; -const SEMANTIC_INDEX_VERSION: usize = 8; +const SEMANTIC_INDEX_VERSION: usize = 9; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); From 54235f4fb179049b7f8b27eaf9de1cd5e7e54d33 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 13:04:09 -0400 Subject: [PATCH 23/60] updated embeddings background delay to 5 minutes Co-authored-by: Max --- crates/semantic_index/src/semantic_index.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index e155fe3c74..4e48b9cd71 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -39,7 +39,7 @@ use util::{ use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 9; -const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); +const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); pub fn init( From 6d7949654bdcfed2cb60c6f3faa8c0850edc527c Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 1 Sep 2023 11:14:27 -0600 Subject: [PATCH 24/60] Fix accidental visual selection on scroll As part of this fix partial page distance calculations to more closely match vim. --- crates/editor/src/editor.rs | 2 +- crates/editor/src/scroll/scroll_amount.rs | 2 +- crates/vim/src/normal/scroll.rs | 62 +++++++++++++++++++++-- 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 2ea2ec7453..d331b0a268 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1654,7 +1654,7 @@ impl Editor { .excerpt_containing(self.selections.newest_anchor().head(), cx) } - fn style(&self, cx: &AppContext) -> EditorStyle { + pub fn style(&self, cx: &AppContext) -> EditorStyle { build_style( settings::get::(cx), self.get_field_editor_theme.as_deref(), diff --git a/crates/editor/src/scroll/scroll_amount.rs b/crates/editor/src/scroll/scroll_amount.rs index f9d09adcf5..cadf37b31d 100644 --- a/crates/editor/src/scroll/scroll_amount.rs +++ b/crates/editor/src/scroll/scroll_amount.rs @@ -39,7 +39,7 @@ impl ScrollAmount { .visible_line_count() // subtract one to leave an anchor line // round towards zero (so page-up and page-down are symmetric) - .map(|l| ((l - 1.) * count).trunc()) + .map(|l| (l * count).trunc() - count.signum()) .unwrap_or(0.), } } diff --git a/crates/vim/src/normal/scroll.rs b/crates/vim/src/normal/scroll.rs index a2bbab0478..1b3dcee6ad 100644 --- a/crates/vim/src/normal/scroll.rs +++ b/crates/vim/src/normal/scroll.rs @@ -67,7 +67,8 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex let top_anchor = editor.scroll_manager.anchor().anchor; editor.change_selections(None, cx, |s| { - s.move_heads_with(|map, head, goal| { + s.move_with(|map, selection| { + let head = selection.head(); let top = top_anchor.to_display_point(map); let min_row = top.row() + VERTICAL_SCROLL_MARGIN as u32; let max_row = top.row() + visible_rows - VERTICAL_SCROLL_MARGIN as u32 - 1; @@ -79,7 +80,11 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex } else { head }; - (new_head, goal) + if selection.is_empty() { + selection.collapse_to(new_head, selection.goal) + } else { + selection.set_head(new_head, selection.goal) + }; }) }); } @@ -90,12 +95,35 @@ mod test { use crate::{state::Mode, test::VimTestContext}; use gpui::geometry::vector::vec2f; use indoc::indoc; + use language::Point; #[gpui::test] async fn test_scroll(cx: &mut gpui::TestAppContext) { let mut cx = VimTestContext::new(cx, true).await; - cx.set_state(indoc! {"ˇa\nb\nc\nd\ne\n"}, Mode::Normal); + let window = cx.window; + let line_height = + cx.editor(|editor, cx| editor.style(cx).text.line_height(cx.font_cache())); + window.simulate_resize(vec2f(1000., 8.0 * line_height - 1.0), &mut cx); + + cx.set_state( + indoc!( + "ˇone + two + three + four + five + six + seven + eight + nine + ten + eleven + twelve + " + ), + Mode::Normal, + ); cx.update_editor(|editor, cx| { assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.)) @@ -112,5 +140,33 @@ mod test { cx.update_editor(|editor, cx| { assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.)) }); + + // does not select in normal mode + cx.simulate_keystrokes(["g", "g"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.)) + }); + cx.simulate_keystrokes(["ctrl-d"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0)); + assert_eq!( + editor.selections.newest(cx).range(), + Point::new(5, 0)..Point::new(5, 0) + ) + }); + + // does select in visual mode + cx.simulate_keystrokes(["g", "g"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.)) + }); + cx.simulate_keystrokes(["v", "ctrl-d"]); + cx.update_editor(|editor, cx| { + assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0)); + assert_eq!( + editor.selections.newest(cx).range(), + Point::new(0, 0)..Point::new(5, 1) + ) + }); } } From af12977d1777fa1af1be6b1d0bbd7127b0752401 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 1 Sep 2023 12:23:45 -0600 Subject: [PATCH 25/60] vim: Add `S` to substitute line For zed-industries/community#1897 --- assets/keymaps/vim.json | 2 + .../LiveKitBridge/Package.resolved | 4 +- crates/vim/src/normal.rs | 9 +- crates/vim/src/normal/substitute.rs | 97 ++++++++++++++++++- .../vim/test_data/test_substitute_line.json | 29 ++++++ 5 files changed, 128 insertions(+), 13 deletions(-) create mode 100644 crates/vim/test_data/test_substitute_line.json diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index c7e6199f44..da094ea7e4 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -371,6 +371,7 @@ "Replace" ], "s": "vim::Substitute", + "shift-s": "vim::SubstituteLine", "> >": "editor::Indent", "< <": "editor::Outdent", "ctrl-pagedown": "pane::ActivateNextItem", @@ -446,6 +447,7 @@ } ], "s": "vim::Substitute", + "shift-s": "vim::SubstituteLine", "c": "vim::Substitute", "~": "vim::ChangeCase", "shift-i": [ diff --git a/crates/live_kit_client/LiveKitBridge/Package.resolved b/crates/live_kit_client/LiveKitBridge/Package.resolved index b925bc8f0d..85ae088565 100644 --- a/crates/live_kit_client/LiveKitBridge/Package.resolved +++ b/crates/live_kit_client/LiveKitBridge/Package.resolved @@ -42,8 +42,8 @@ "repositoryURL": "https://github.com/apple/swift-protobuf.git", "state": { "branch": null, - "revision": "ce20dc083ee485524b802669890291c0d8090170", - "version": "1.22.1" + "revision": "0af9125c4eae12a4973fb66574c53a54962a9e1e", + "version": "1.21.0" } } ] diff --git a/crates/vim/src/normal.rs b/crates/vim/src/normal.rs index a73c518809..1f8276c327 100644 --- a/crates/vim/src/normal.rs +++ b/crates/vim/src/normal.rs @@ -27,7 +27,6 @@ use self::{ case::change_case, change::{change_motion, change_object}, delete::{delete_motion, delete_object}, - substitute::substitute, yank::{yank_motion, yank_object}, }; @@ -44,7 +43,6 @@ actions!( ChangeToEndOfLine, DeleteToEndOfLine, Yank, - Substitute, ChangeCase, ] ); @@ -56,13 +54,8 @@ pub fn init(cx: &mut AppContext) { cx.add_action(insert_line_above); cx.add_action(insert_line_below); cx.add_action(change_case); + substitute::init(cx); search::init(cx); - cx.add_action(|_: &mut Workspace, _: &Substitute, cx| { - Vim::update(cx, |vim, cx| { - let times = vim.pop_number_operator(cx); - substitute(vim, times, cx); - }) - }); cx.add_action(|_: &mut Workspace, _: &DeleteLeft, cx| { Vim::update(cx, |vim, cx| { let times = vim.pop_number_operator(cx); diff --git a/crates/vim/src/normal/substitute.rs b/crates/vim/src/normal/substitute.rs index b04596240a..efdd43d0a4 100644 --- a/crates/vim/src/normal/substitute.rs +++ b/crates/vim/src/normal/substitute.rs @@ -1,10 +1,32 @@ -use gpui::WindowContext; +use editor::movement; +use gpui::{actions, AppContext, WindowContext}; use language::Point; +use workspace::Workspace; use crate::{motion::Motion, utils::copy_selections_content, Mode, Vim}; -pub fn substitute(vim: &mut Vim, count: Option, cx: &mut WindowContext) { - let line_mode = vim.state().mode == Mode::VisualLine; +actions!(vim, [Substitute, SubstituteLine]); + +pub(crate) fn init(cx: &mut AppContext) { + cx.add_action(|_: &mut Workspace, _: &Substitute, cx| { + Vim::update(cx, |vim, cx| { + let times = vim.pop_number_operator(cx); + substitute(vim, times, vim.state().mode == Mode::VisualLine, cx); + }) + }); + + cx.add_action(|_: &mut Workspace, _: &SubstituteLine, cx| { + Vim::update(cx, |vim, cx| { + if matches!(vim.state().mode, Mode::VisualBlock | Mode::Visual) { + vim.switch_mode(Mode::VisualLine, false, cx) + } + let count = vim.pop_number_operator(cx); + substitute(vim, count, true, cx) + }) + }); +} + +pub fn substitute(vim: &mut Vim, count: Option, line_mode: bool, cx: &mut WindowContext) { vim.update_active_editor(cx, |editor, cx| { editor.set_clip_at_line_ends(false, cx); editor.transact(cx, |editor, cx| { @@ -14,6 +36,11 @@ pub fn substitute(vim: &mut Vim, count: Option, cx: &mut WindowContext) { Motion::Right.expand_selection(map, selection, count, true); } if line_mode { + // in Visual mode when the selection contains the newline at the end + // of the line, we should exclude it. + if !selection.is_empty() && selection.end.column() == 0 { + selection.end = movement::left(map, selection.end); + } Motion::CurrentLine.expand_selection(map, selection, None, false); if let Some((point, _)) = (Motion::FirstNonWhitespace { display_lines: false, @@ -166,4 +193,68 @@ mod test { the laˇzy dog"}) .await; } + + #[gpui::test] + async fn test_substitute_line(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + let initial_state = indoc! {" + The quick brown + fox juˇmps over + the lazy dog + "}; + + // normal mode + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["shift-s", "o"]).await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + the lazy dog + "}) + .await; + + // visual mode + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["v", "k", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + oˇ + the lazy dog + "}) + .await; + + // visual block mode + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["ctrl-v", "j", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + "}) + .await; + + // visual mode including newline + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes(["v", "$", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + the lazy dog + "}) + .await; + + // indentation + cx.set_neovim_option("shiftwidth=4").await; + cx.set_shared_state(initial_state).await; + cx.simulate_shared_keystrokes([">", ">", "shift-s", "o"]) + .await; + cx.assert_shared_state(indoc! {" + The quick brown + oˇ + the lazy dog + "}) + .await; + } } diff --git a/crates/vim/test_data/test_substitute_line.json b/crates/vim/test_data/test_substitute_line.json new file mode 100644 index 0000000000..eb0a9825f8 --- /dev/null +++ b/crates/vim/test_data/test_substitute_line.json @@ -0,0 +1,29 @@ +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\noˇ\nthe lazy dog\n","mode":"Insert"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"v"} +{"Key":"k"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"oˇ\nthe lazy dog\n","mode":"Insert"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"ctrl-v"} +{"Key":"j"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\noˇ\n","mode":"Insert"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":"v"} +{"Key":"$"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\noˇ\nthe lazy dog\n","mode":"Insert"}} +{"SetOption":{"value":"shiftwidth=4"}} +{"Put":{"state":"The quick brown\nfox juˇmps over\nthe lazy dog\n"}} +{"Key":">"} +{"Key":">"} +{"Key":"shift-s"} +{"Key":"o"} +{"Get":{"state":"The quick brown\n oˇ\nthe lazy dog\n","mode":"Insert"}} From 8dbc0fe0333958d6057a99bb734700deb270bf8b Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 1 Sep 2023 17:07:20 -0400 Subject: [PATCH 26/60] update pragma settings for improved database performance --- crates/semantic_index/src/db.rs | 13 ++++++++++++- crates/semantic_index/src/semantic_index.rs | 1 - 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 6cfd01456d..2ececc1eb6 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -18,7 +18,7 @@ use std::{ path::{Path, PathBuf}, rc::Rc, sync::Arc, - time::SystemTime, + time::{Instant, SystemTime}, }; use util::TryFutureExt; @@ -54,6 +54,12 @@ impl VectorDatabase { let path = path.clone(); async move { let mut connection = rusqlite::Connection::open(&path)?; + + connection.pragma_update(None, "journal_mode", "wal")?; + connection.pragma_update(None, "synchronous", "normal")?; + connection.pragma_update(None, "cache_size", 1000000)?; + connection.pragma_update(None, "temp_store", "MEMORY")?; + while let Ok(transaction) = transactions_rx.recv().await { transaction(&mut connection); } @@ -222,6 +228,7 @@ impl VectorDatabase { let file_id = db.last_insert_rowid(); + let t0 = Instant::now(); let mut query = db.prepare( " INSERT INTO documents @@ -229,6 +236,10 @@ impl VectorDatabase { VALUES (?1, ?2, ?3, ?4, ?5, ?6) ", )?; + log::trace!( + "Preparing Query Took: {:?} milliseconds", + t0.elapsed().as_millis() + ); for document in documents { query.execute(params![ diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 4e48b9cd71..a917eabfc8 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -81,7 +81,6 @@ pub fn init( let semantic_index = SemanticIndex::new( fs, db_file_path, - // Arc::new(embedding::DummyEmbeddings {}), Arc::new(OpenAIEmbeddings { client: http_client, executor: cx.background(), From d370c72fbfbdf9d3fa9448b49bfefb408cc3ecd9 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 1 Sep 2023 15:31:52 -0700 Subject: [PATCH 27/60] Start work on rejoining channel buffers --- crates/channel/src/channel_buffer.rs | 27 ++ crates/channel/src/channel_store.rs | 153 ++++++++-- crates/collab/src/db/queries/buffers.rs | 265 +++++++++++++----- crates/collab/src/rpc.rs | 25 +- .../collab/src/tests/channel_buffer_tests.rs | 138 ++++++--- crates/language/src/proto.rs | 25 ++ crates/project/src/project.rs | 27 +- crates/rpc/proto/zed.proto | 26 ++ crates/rpc/src/proto.rs | 3 + 9 files changed, 526 insertions(+), 163 deletions(-) diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index 29f4d3493c..98ecbc5dcf 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -17,6 +17,7 @@ pub struct ChannelBuffer { connected: bool, collaborators: Vec, buffer: ModelHandle, + buffer_epoch: u64, client: Arc, subscription: Option, } @@ -73,6 +74,7 @@ impl ChannelBuffer { Self { buffer, + buffer_epoch: response.epoch, client, connected: true, collaborators, @@ -82,6 +84,26 @@ impl ChannelBuffer { })) } + pub(crate) fn replace_collaborators( + &mut self, + collaborators: Vec, + cx: &mut ModelContext, + ) { + for old_collaborator in &self.collaborators { + if collaborators + .iter() + .any(|c| c.replica_id == old_collaborator.replica_id) + { + self.buffer.update(cx, |buffer, cx| { + buffer.remove_peer(old_collaborator.replica_id as u16, cx) + }); + } + } + self.collaborators = collaborators; + cx.emit(Event::CollaboratorsChanged); + cx.notify(); + } + async fn handle_update_channel_buffer( this: ModelHandle, update_channel_buffer: TypedEnvelope, @@ -166,6 +188,10 @@ impl ChannelBuffer { } } + pub fn epoch(&self) -> u64 { + self.buffer_epoch + } + pub fn buffer(&self) -> ModelHandle { self.buffer.clone() } @@ -179,6 +205,7 @@ impl ChannelBuffer { } pub(crate) fn disconnect(&mut self, cx: &mut ModelContext) { + log::info!("channel buffer {} disconnected", self.channel.id); if self.connected { self.connected = false; self.subscription.take(); diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 861f731331..ec1652581d 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,13 +1,15 @@ use crate::channel_buffer::ChannelBuffer; use anyhow::{anyhow, Result}; -use client::{Client, Status, Subscription, User, UserId, UserStore}; +use client::{Client, Subscription, User, UserId, UserStore}; use collections::{hash_map, HashMap, HashSet}; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{proto, TypedEnvelope}; -use std::sync::Arc; +use std::{mem, sync::Arc, time::Duration}; use util::ResultExt; +pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); + pub type ChannelId = u64; pub struct ChannelStore { @@ -22,7 +24,8 @@ pub struct ChannelStore { client: Arc, user_store: ModelHandle, _rpc_subscription: Subscription, - _watch_connection_status: Task<()>, + _watch_connection_status: Task>, + disconnect_channel_buffers_task: Option>, _update_channels: Task<()>, } @@ -67,24 +70,20 @@ impl ChannelStore { let rpc_subscription = client.add_message_handler(cx.handle(), Self::handle_update_channels); - let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded(); let mut connection_status = client.status(); + let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded(); let watch_connection_status = cx.spawn_weak(|this, mut cx| async move { while let Some(status) = connection_status.next().await { - if !status.is_connected() { - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if matches!(status, Status::ConnectionLost | Status::SignedOut) { - this.handle_disconnect(cx); - } else { - this.disconnect_buffers(cx); - } - }); - } else { - break; - } + let this = this.upgrade(&cx)?; + if status.is_connected() { + this.update(&mut cx, |this, cx| this.handle_connect(cx)) + .await + .log_err()?; + } else { + this.update(&mut cx, |this, cx| this.handle_disconnect(cx)); } } + Some(()) }); Self { @@ -100,6 +99,7 @@ impl ChannelStore { user_store, _rpc_subscription: rpc_subscription, _watch_connection_status: watch_connection_status, + disconnect_channel_buffers_task: None, _update_channels: cx.spawn_weak(|this, mut cx| async move { while let Some(update_channels) = update_channels_rx.next().await { if let Some(this) = this.upgrade(&cx) { @@ -482,8 +482,102 @@ impl ChannelStore { Ok(()) } - fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) { - self.disconnect_buffers(cx); + fn handle_connect(&mut self, cx: &mut ModelContext) -> Task> { + self.disconnect_channel_buffers_task.take(); + + let mut buffer_versions = Vec::new(); + for buffer in self.opened_buffers.values() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + let channel_buffer = buffer.read(cx); + let buffer = channel_buffer.buffer().read(cx); + buffer_versions.push(proto::ChannelBufferVersion { + channel_id: channel_buffer.channel().id, + epoch: channel_buffer.epoch(), + version: language::proto::serialize_version(&buffer.version()), + }); + } + } + } + + let response = self.client.request(proto::RejoinChannelBuffers { + buffers: buffer_versions, + }); + + cx.spawn(|this, mut cx| async move { + let mut response = response.await?; + + this.update(&mut cx, |this, cx| { + this.opened_buffers.retain(|_, buffer| match buffer { + OpenedChannelBuffer::Open(channel_buffer) => { + let Some(channel_buffer) = channel_buffer.upgrade(cx) else { + return false; + }; + + channel_buffer.update(cx, |channel_buffer, cx| { + let channel_id = channel_buffer.channel().id; + if let Some(remote_buffer) = response + .buffers + .iter_mut() + .find(|buffer| buffer.channel_id == channel_id) + { + let channel_id = channel_buffer.channel().id; + let remote_version = + language::proto::deserialize_version(&remote_buffer.version); + + channel_buffer.replace_collaborators( + mem::take(&mut remote_buffer.collaborators), + cx, + ); + + let operations = channel_buffer + .buffer() + .update(cx, |buffer, cx| { + let outgoing_operations = + buffer.serialize_ops(Some(remote_version), cx); + let incoming_operations = + mem::take(&mut remote_buffer.operations) + .into_iter() + .map(language::proto::deserialize_operation) + .collect::>>()?; + buffer.apply_ops(incoming_operations, cx)?; + anyhow::Ok(outgoing_operations) + }) + .log_err(); + + if let Some(operations) = operations { + let client = this.client.clone(); + cx.background() + .spawn(async move { + let operations = operations.await; + for chunk in + language::proto::split_operations(operations) + { + client + .send(proto::UpdateChannelBuffer { + channel_id, + operations: chunk, + }) + .ok(); + } + }) + .detach(); + return true; + } + } + + channel_buffer.disconnect(cx); + false + }) + } + OpenedChannelBuffer::Loading(_) => true, + }); + }); + anyhow::Ok(()) + }) + } + + fn handle_disconnect(&mut self, cx: &mut ModelContext) { self.channels_by_id.clear(); self.channel_invitations.clear(); self.channel_participants.clear(); @@ -491,16 +585,23 @@ impl ChannelStore { self.channel_paths.clear(); self.outgoing_invites.clear(); cx.notify(); - } - fn disconnect_buffers(&mut self, cx: &mut ModelContext) { - for (_, buffer) in self.opened_buffers.drain() { - if let OpenedChannelBuffer::Open(buffer) = buffer { - if let Some(buffer) = buffer.upgrade(cx) { - buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + self.disconnect_channel_buffers_task.get_or_insert_with(|| { + cx.spawn_weak(|this, mut cx| async move { + cx.background().timer(RECONNECT_TIMEOUT).await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + for (_, buffer) in this.opened_buffers.drain() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + } + } + } + }); } - } - } + }) + }); } pub(crate) fn update_channels( diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index f120aea1c5..587ed058ff 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -10,8 +10,6 @@ impl Database { connection: ConnectionId, ) -> Result { self.transaction(|tx| async move { - let tx = tx; - self.check_user_is_channel_member(channel_id, user_id, &tx) .await?; @@ -70,7 +68,6 @@ impl Database { .await?; collaborators.push(collaborator); - // Assemble the buffer state let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?; Ok(proto::JoinChannelBufferResponse { @@ -78,6 +75,7 @@ impl Database { replica_id: replica_id.to_proto() as u32, base_text, operations, + epoch: buffer.epoch as u64, collaborators: collaborators .into_iter() .map(|collaborator| proto::Collaborator { @@ -91,6 +89,113 @@ impl Database { .await } + pub async fn rejoin_channel_buffers( + &self, + buffers: &[proto::ChannelBufferVersion], + user_id: UserId, + connection_id: ConnectionId, + ) -> Result { + self.transaction(|tx| async move { + let mut response = proto::RejoinChannelBuffersResponse::default(); + for client_buffer in buffers { + let channel_id = ChannelId::from_proto(client_buffer.channel_id); + if self + .check_user_is_channel_member(channel_id, user_id, &*tx) + .await + .is_err() + { + log::info!("user is not a member of channel"); + continue; + } + + let buffer = self.get_channel_buffer(channel_id, &*tx).await?; + let mut collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + + // If the buffer epoch hasn't changed since the client lost + // connection, then the client's buffer can be syncronized with + // the server's buffer. + if buffer.epoch as u64 != client_buffer.epoch { + continue; + } + + // If there is still a disconnected collaborator for the user, + // update the connection associated with that collaborator, and reuse + // that replica id. + if let Some(ix) = collaborators + .iter() + .position(|c| c.user_id == user_id && c.connection_lost) + { + let self_collaborator = &mut collaborators[ix]; + *self_collaborator = channel_buffer_collaborator::ActiveModel { + id: ActiveValue::Unchanged(self_collaborator.id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId( + connection_id.owner_id as i32, + )), + connection_lost: ActiveValue::Set(false), + ..Default::default() + } + .update(&*tx) + .await?; + } else { + continue; + } + + let client_version = version_from_wire(&client_buffer.version); + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; + + let mut rows = buffer_operation::Entity::find() + .filter( + buffer_operation::Column::BufferId + .eq(buffer.id) + .and(buffer_operation::Column::Epoch.eq(buffer.epoch)), + ) + .stream(&*tx) + .await?; + + // Find the server's version vector and any operations + // that the client has not seen. + let mut server_version = clock::Global::new(); + let mut operations = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let timestamp = clock::Lamport { + replica_id: row.replica_id as u16, + value: row.lamport_timestamp as u32, + }; + server_version.observe(timestamp); + if !client_version.observed(timestamp) { + operations.push(proto::Operation { + variant: Some(operation_from_storage(row, serialization_version)?), + }) + } + } + + response.buffers.push(proto::RejoinedChannelBuffer { + channel_id: client_buffer.channel_id, + version: version_to_wire(&server_version), + operations, + collaborators: collaborators + .into_iter() + .map(|collaborator| proto::Collaborator { + peer_id: Some(collaborator.connection().into()), + user_id: collaborator.user_id.to_proto(), + replica_id: collaborator.replica_id.0 as u32, + }) + .collect(), + }); + } + + Ok(response) + }) + .await + } + pub async fn leave_channel_buffer( &self, channel_id: ChannelId, @@ -103,6 +208,39 @@ impl Database { .await } + pub async fn leave_channel_buffers( + &self, + connection: ConnectionId, + ) -> Result)>> { + self.transaction(|tx| async move { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + + let channel_ids: Vec = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .filter(Condition::all().add( + channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), + )) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + + let mut result = Vec::new(); + for channel_id in channel_ids { + let collaborators = self + .leave_channel_buffer_internal(channel_id, connection, &*tx) + .await?; + result.push((channel_id, collaborators)); + } + + Ok(result) + }) + .await + } + pub async fn leave_channel_buffer_internal( &self, channel_id: ChannelId, @@ -143,45 +281,12 @@ impl Database { drop(rows); if connections.is_empty() { - self.snapshot_buffer(channel_id, &tx).await?; + self.snapshot_channel_buffer(channel_id, &tx).await?; } Ok(connections) } - pub async fn leave_channel_buffers( - &self, - connection: ConnectionId, - ) -> Result)>> { - self.transaction(|tx| async move { - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryChannelIds { - ChannelId, - } - - let channel_ids: Vec = channel_buffer_collaborator::Entity::find() - .select_only() - .column(channel_buffer_collaborator::Column::ChannelId) - .filter(Condition::all().add( - channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), - )) - .into_values::<_, QueryChannelIds>() - .all(&*tx) - .await?; - - let mut result = Vec::new(); - for channel_id in channel_ids { - let collaborators = self - .leave_channel_buffer_internal(channel_id, connection, &*tx) - .await?; - result.push((channel_id, collaborators)); - } - - Ok(result) - }) - .await - } - pub async fn get_channel_buffer_collaborators( &self, channel_id: ChannelId, @@ -224,20 +329,9 @@ impl Database { .await? .ok_or_else(|| anyhow!("no such buffer"))?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryVersion { - OperationSerializationVersion, - } - - let serialization_version: i32 = buffer - .find_related(buffer_snapshot::Entity) - .select_only() - .column(buffer_snapshot::Column::OperationSerializationVersion) - .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch)) - .into_values::<_, QueryVersion>() - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("missing buffer snapshot"))?; + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; let operations = operations .iter() @@ -270,6 +364,38 @@ impl Database { .await } + async fn get_buffer_operation_serialization_version( + &self, + buffer_id: BufferId, + epoch: i32, + tx: &DatabaseTransaction, + ) -> Result { + Ok(buffer_snapshot::Entity::find() + .filter(buffer_snapshot::Column::BufferId.eq(buffer_id)) + .filter(buffer_snapshot::Column::Epoch.eq(epoch)) + .select_only() + .column(buffer_snapshot::Column::OperationSerializationVersion) + .into_values::<_, QueryOperationSerializationVersion>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("missing buffer snapshot"))?) + } + + async fn get_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Model { + id: channel_id, + ..Default::default() + } + .find_related(buffer::Entity) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such buffer"))?) + } + async fn get_buffer_state( &self, buffer: &buffer::Model, @@ -303,27 +429,20 @@ impl Database { .await?; let mut operations = Vec::new(); while let Some(row) = rows.next().await { - let row = row?; - - let operation = operation_from_storage(row, version)?; operations.push(proto::Operation { - variant: Some(operation), + variant: Some(operation_from_storage(row?, version)?), }) } Ok((base_text, operations)) } - async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> { - let buffer = channel::Model { - id: channel_id, - ..Default::default() - } - .find_related(buffer::Entity) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such buffer"))?; - + async fn snapshot_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result<()> { + let buffer = self.get_channel_buffer(channel_id, tx).await?; let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?; if operations.is_empty() { return Ok(()); @@ -527,6 +646,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global { version } +fn version_to_wire(version: &clock::Global) -> Vec { + let mut message = Vec::new(); + for entry in version.iter() { + message.push(proto::VectorClockEntry { + replica_id: entry.replica_id as u32, + timestamp: entry.value, + }); + } + message +} + +#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] +enum QueryOperationSerializationVersion { + OperationSerializationVersion, +} + mod storage { #![allow(non_snake_case)] use prost::Message; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 6b44711c42..06aa00c9b8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -251,6 +251,7 @@ impl Server { .add_request_handler(join_channel_buffer) .add_request_handler(leave_channel_buffer) .add_message_handler(update_channel_buffer) + .add_request_handler(rejoin_channel_buffers) .add_request_handler(get_channel_members) .add_request_handler(respond_to_channel_invite) .add_request_handler(join_channel) @@ -854,13 +855,12 @@ async fn connection_lost( .await .trace_err(); - leave_channel_buffers_for_session(&session) - .await - .trace_err(); - futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { leave_room_for_session(&session).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); if !session .connection_pool() @@ -2547,6 +2547,23 @@ async fn update_channel_buffer( Ok(()) } +async fn rejoin_channel_buffers( + request: proto::RejoinChannelBuffers, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let rejoin_response = db + .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) + .await?; + + // TODO: inform channel buffer collaborators that this user has rejoined. + + response.send(rejoin_response)?; + + Ok(()) +} + async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 8ac4dbbd3f..5ba5b50429 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -21,20 +21,19 @@ async fn test_core_channel_buffers( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; - let zed_id = server + let channel_id = server .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) .await; // Client A joins the channel buffer let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); // Client A edits the buffer let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer()); - buffer_a.update(cx_a, |buffer, cx| { buffer.edit([(0..0, "hello world")], None, cx) }); @@ -45,17 +44,15 @@ async fn test_core_channel_buffers( buffer.edit([(0..5, "goodbye")], None, cx) }); buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx)); - deterministic.run_until_parked(); - assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world"); + deterministic.run_until_parked(); // Client B joins the channel buffer let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - channel_buffer_b.read_with(cx_b, |buffer, _| { assert_collaborators( buffer.collaborators(), @@ -91,9 +88,7 @@ async fn test_core_channel_buffers( // Client A rejoins the channel buffer let _channel_buffer_a = client_a .channel_store() - .update(cx_a, |channels, cx| { - channels.open_channel_buffer(zed_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); deterministic.run_until_parked(); @@ -136,7 +131,7 @@ async fn test_channel_buffer_replica_ids( let channel_id = server .make_channel( - "zed", + "the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b), (&client_c, cx_c)], ) @@ -160,23 +155,17 @@ async fn test_channel_buffer_replica_ids( // C first so that the replica IDs in the project and the channel buffer are different let channel_buffer_c = client_c .channel_store() - .update(cx_c, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); @@ -286,28 +275,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; - let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await; + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut []) + .await; let channel_buffer_1 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); let channel_buffer_2 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); let channel_buffer_3 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); // All concurrent tasks for opening a channel buffer return the same model handle. - let (channel_buffer_1, channel_buffer_2, channel_buffer_3) = + let (channel_buffer, channel_buffer_2, channel_buffer_3) = future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3) .await .unwrap(); - let model_id = channel_buffer_1.id(); - assert_eq!(channel_buffer_1, channel_buffer_2); - assert_eq!(channel_buffer_1, channel_buffer_3); + let channel_buffer_model_id = channel_buffer.id(); + assert_eq!(channel_buffer, channel_buffer_2); + assert_eq!(channel_buffer, channel_buffer_3); - channel_buffer_1.update(cx_a, |buffer, cx| { + channel_buffer.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "hello")], None, cx); }) @@ -315,7 +306,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu deterministic.run_until_parked(); cx_a.update(|_| { - drop(channel_buffer_1); + drop(channel_buffer); drop(channel_buffer_2); drop(channel_buffer_3); }); @@ -324,10 +315,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu // The channel buffer can be reopened after dropping it. let channel_buffer = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - assert_ne!(channel_buffer.id(), model_id); + assert_ne!(channel_buffer.id(), channel_buffer_model_id); channel_buffer.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, _| { assert_eq!(buffer.text(), "hello"); @@ -347,22 +338,17 @@ async fn test_channel_buffer_disconnect( let client_b = server.create_client(cx_b, "user_b").await; let channel_id = server - .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) .await; let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); @@ -375,7 +361,7 @@ async fn test_channel_buffer_disconnect( buffer.channel().as_ref(), &Channel { id: channel_id, - name: "zed".to_string() + name: "the-channel".to_string() } ); assert!(!buffer.is_connected()); @@ -403,13 +389,81 @@ async fn test_channel_buffer_disconnect( buffer.channel().as_ref(), &Channel { id: channel_id, - name: "zed".to_string() + name: "the-channel".to_string() } ); assert!(!buffer.is_connected()); }); } +#[gpui::test] +async fn test_rejoin_channel_buffer( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client A disconnects. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + // deterministic.advance_clock(RECEIVE_TIMEOUT); + + // Both clients make an edit. Both clients see their own edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + deterministic.run_until_parked(); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "12"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "01"); + }); + + // Client A reconnects. + server.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); +} + #[track_caller] fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option]) { assert_eq!( diff --git a/crates/language/src/proto.rs b/crates/language/src/proto.rs index 80eb972f42..c4abe39d47 100644 --- a/crates/language/src/proto.rs +++ b/crates/language/src/proto.rs @@ -127,6 +127,31 @@ pub fn serialize_undo_map_entry( } } +pub fn split_operations( + mut operations: Vec, +) -> impl Iterator> { + #[cfg(any(test, feature = "test-support"))] + const CHUNK_SIZE: usize = 5; + + #[cfg(not(any(test, feature = "test-support")))] + const CHUNK_SIZE: usize = 100; + + let mut done = false; + std::iter::from_fn(move || { + if done { + return None; + } + + let operations = operations + .drain(..std::cmp::min(CHUNK_SIZE, operations.len())) + .collect::>(); + if operations.is_empty() { + done = true; + } + Some(operations) + }) +} + pub fn serialize_selections(selections: &Arc<[Selection]>) -> Vec { selections.iter().map(serialize_selection).collect() } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 5cd13b8be8..0690cc9188 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -35,7 +35,7 @@ use language::{ point_to_lsp, proto::{ deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version, - serialize_anchor, serialize_version, + serialize_anchor, serialize_version, split_operations, }, range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction, CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent, @@ -8200,31 +8200,6 @@ impl LspAdapterDelegate for ProjectLspAdapterDelegate { } } -fn split_operations( - mut operations: Vec, -) -> impl Iterator> { - #[cfg(any(test, feature = "test-support"))] - const CHUNK_SIZE: usize = 5; - - #[cfg(not(any(test, feature = "test-support")))] - const CHUNK_SIZE: usize = 100; - - let mut done = false; - std::iter::from_fn(move || { - if done { - return None; - } - - let operations = operations - .drain(..cmp::min(CHUNK_SIZE, operations.len())) - .collect::>(); - if operations.is_empty() { - done = true; - } - Some(operations) - }) -} - fn serialize_symbol(symbol: &Symbol) -> proto::Symbol { proto::Symbol { language_server_name: symbol.language_server_name.0.to_string(), diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 92c85677f6..fe9093245e 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1,6 +1,8 @@ syntax = "proto3"; package zed.messages; +// Looking for a number? Search "// Current max" + message PeerId { uint32 owner_id = 1; uint32 id = 2; @@ -151,6 +153,8 @@ message Envelope { LeaveChannelBuffer leave_channel_buffer = 134; AddChannelBufferCollaborator add_channel_buffer_collaborator = 135; RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136; + RejoinChannelBuffers rejoin_channel_buffers = 139; + RejoinChannelBuffersResponse rejoin_channel_buffers_response = 140; // Current max } } @@ -616,6 +620,12 @@ message BufferVersion { repeated VectorClockEntry version = 2; } +message ChannelBufferVersion { + uint64 channel_id = 1; + repeated VectorClockEntry version = 2; + uint64 epoch = 3; +} + enum FormatTrigger { Save = 0; Manual = 1; @@ -1008,12 +1018,28 @@ message JoinChannelBuffer { uint64 channel_id = 1; } +message RejoinChannelBuffers { + repeated ChannelBufferVersion buffers = 1; +} + +message RejoinChannelBuffersResponse { + repeated RejoinedChannelBuffer buffers = 1; +} + message JoinChannelBufferResponse { uint64 buffer_id = 1; uint32 replica_id = 2; string base_text = 3; repeated Operation operations = 4; repeated Collaborator collaborators = 5; + uint64 epoch = 6; +} + +message RejoinedChannelBuffer { + uint64 channel_id = 1; + repeated VectorClockEntry version = 2; + repeated Operation operations = 3; + repeated Collaborator collaborators = 4; } message LeaveChannelBuffer { diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 2e4dce01e1..a600bc4970 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -229,6 +229,8 @@ messages!( (StartLanguageServer, Foreground), (SynchronizeBuffers, Foreground), (SynchronizeBuffersResponse, Foreground), + (RejoinChannelBuffers, Foreground), + (RejoinChannelBuffersResponse, Foreground), (Test, Foreground), (Unfollow, Foreground), (UnshareProject, Foreground), @@ -319,6 +321,7 @@ request_messages!( (SearchProject, SearchProjectResponse), (ShareProject, ShareProjectResponse), (SynchronizeBuffers, SynchronizeBuffersResponse), + (RejoinChannelBuffers, RejoinChannelBuffersResponse), (Test, Test), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), From d7e4cb4ab10da4f5a6ac936e26db67f5956349de Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 1 Sep 2023 16:52:41 -0700 Subject: [PATCH 28/60] executor: timers must be used --- crates/gpui/src/executor.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 712c854488..474ea8364f 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -106,6 +106,7 @@ pub struct Deterministic { parker: parking_lot::Mutex, } +#[must_use] pub enum Timer { Production(smol::Timer), #[cfg(any(test, feature = "test-support"))] From e6babce556d07d21faa60f013e9290f50516b157 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 1 Sep 2023 17:23:55 -0700 Subject: [PATCH 29/60] Broadcast new peer ids for rejoined channel collaborators --- crates/channel/src/channel_buffer.rs | 21 ++++++ crates/collab/src/db.rs | 5 ++ crates/collab/src/db/queries/buffers.rs | 69 +++++++++---------- crates/collab/src/rpc.rs | 24 ++++++- .../collab/src/tests/channel_buffer_tests.rs | 14 +++- crates/rpc/proto/zed.proto | 11 ++- crates/rpc/src/proto.rs | 4 +- 7 files changed, 104 insertions(+), 44 deletions(-) diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index 98ecbc5dcf..e11282cf79 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -10,6 +10,7 @@ pub(crate) fn init(client: &Arc) { client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer); client.add_model_message_handler(ChannelBuffer::handle_add_channel_buffer_collaborator); client.add_model_message_handler(ChannelBuffer::handle_remove_channel_buffer_collaborator); + client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborator); } pub struct ChannelBuffer { @@ -171,6 +172,26 @@ impl ChannelBuffer { Ok(()) } + async fn handle_update_channel_buffer_collaborator( + this: ModelHandle, + message: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + for collaborator in &mut this.collaborators { + if collaborator.peer_id == message.payload.old_peer_id { + collaborator.peer_id = message.payload.new_peer_id; + break; + } + } + cx.emit(Event::CollaboratorsChanged); + cx.notify(); + }); + + Ok(()) + } + fn on_buffer_update( &mut self, _: ModelHandle, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 888158188f..4a9983e600 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -435,6 +435,11 @@ pub struct ChannelsForUser { pub channels_with_admin_privileges: HashSet, } +pub struct RejoinedChannelBuffer { + pub buffer: proto::RejoinedChannelBuffer, + pub old_connection_id: ConnectionId, +} + #[derive(Clone)] pub struct JoinRoom { pub room: proto::Room, diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index 587ed058ff..79e20a2622 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -94,9 +94,9 @@ impl Database { buffers: &[proto::ChannelBufferVersion], user_id: UserId, connection_id: ConnectionId, - ) -> Result { + ) -> Result> { self.transaction(|tx| async move { - let mut response = proto::RejoinChannelBuffersResponse::default(); + let mut results = Vec::new(); for client_buffer in buffers { let channel_id = ChannelId::from_proto(client_buffer.channel_id); if self @@ -121,28 +121,24 @@ impl Database { continue; } - // If there is still a disconnected collaborator for the user, - // update the connection associated with that collaborator, and reuse - // that replica id. - if let Some(ix) = collaborators - .iter() - .position(|c| c.user_id == user_id && c.connection_lost) - { - let self_collaborator = &mut collaborators[ix]; - *self_collaborator = channel_buffer_collaborator::ActiveModel { - id: ActiveValue::Unchanged(self_collaborator.id), - connection_id: ActiveValue::Set(connection_id.id as i32), - connection_server_id: ActiveValue::Set(ServerId( - connection_id.owner_id as i32, - )), - connection_lost: ActiveValue::Set(false), - ..Default::default() - } - .update(&*tx) - .await?; - } else { + // Find the collaborator record for this user's previous lost + // connection. Update it with the new connection id. + let Some(self_collaborator) = collaborators + .iter_mut() + .find(|c| c.user_id == user_id && c.connection_lost) + else { continue; + }; + let old_connection_id = self_collaborator.connection(); + *self_collaborator = channel_buffer_collaborator::ActiveModel { + id: ActiveValue::Unchanged(self_collaborator.id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)), + connection_lost: ActiveValue::Set(false), + ..Default::default() } + .update(&*tx) + .await?; let client_version = version_from_wire(&client_buffer.version); let serialization_version = self @@ -176,22 +172,25 @@ impl Database { } } - response.buffers.push(proto::RejoinedChannelBuffer { - channel_id: client_buffer.channel_id, - version: version_to_wire(&server_version), - operations, - collaborators: collaborators - .into_iter() - .map(|collaborator| proto::Collaborator { - peer_id: Some(collaborator.connection().into()), - user_id: collaborator.user_id.to_proto(), - replica_id: collaborator.replica_id.0 as u32, - }) - .collect(), + results.push(RejoinedChannelBuffer { + old_connection_id, + buffer: proto::RejoinedChannelBuffer { + channel_id: client_buffer.channel_id, + version: version_to_wire(&server_version), + operations, + collaborators: collaborators + .into_iter() + .map(|collaborator| proto::Collaborator { + peer_id: Some(collaborator.connection().into()), + user_id: collaborator.user_id.to_proto(), + replica_id: collaborator.replica_id.0 as u32, + }) + .collect(), + }, }); } - Ok(response) + Ok(results) }) .await } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 06aa00c9b8..d221d1c99e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2553,13 +2553,31 @@ async fn rejoin_channel_buffers( session: Session, ) -> Result<()> { let db = session.db().await; - let rejoin_response = db + let buffers = db .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) .await?; - // TODO: inform channel buffer collaborators that this user has rejoined. + for buffer in &buffers { + let collaborators_to_notify = buffer + .buffer + .collaborators + .iter() + .filter_map(|c| Some(c.peer_id?.into())); + channel_buffer_updated( + session.connection_id, + collaborators_to_notify, + &proto::UpdateChannelBufferCollaborator { + channel_id: buffer.buffer.channel_id, + old_peer_id: Some(buffer.old_connection_id.into()), + new_peer_id: Some(session.connection_id.into()), + }, + &session.peer, + ); + } - response.send(rejoin_response)?; + response.send(proto::RejoinChannelBuffersResponse { + buffers: buffers.into_iter().map(|b| b.buffer).collect(), + })?; Ok(()) } diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 5ba5b50429..236771c2a5 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -432,9 +432,8 @@ async fn test_rejoin_channel_buffer( // Client A disconnects. server.forbid_connections(); server.disconnect_client(client_a.peer_id().unwrap()); - // deterministic.advance_clock(RECEIVE_TIMEOUT); - // Both clients make an edit. Both clients see their own edit. + // Both clients make an edit. channel_buffer_a.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, cx| { buffer.edit([(1..1, "2")], None, cx); @@ -445,6 +444,8 @@ async fn test_rejoin_channel_buffer( buffer.edit([(0..0, "0")], None, cx); }) }); + + // Both clients see their own edit. deterministic.run_until_parked(); channel_buffer_a.read_with(cx_a, |buffer, cx| { assert_eq!(buffer.buffer().read(cx).text(), "12"); @@ -453,7 +454,8 @@ async fn test_rejoin_channel_buffer( assert_eq!(buffer.buffer().read(cx).text(), "01"); }); - // Client A reconnects. + // Client A reconnects. Both clients see each other's edits, and see + // the same collaborators. server.allow_connections(); deterministic.advance_clock(RECEIVE_TIMEOUT); channel_buffer_a.read_with(cx_a, |buffer, cx| { @@ -462,6 +464,12 @@ async fn test_rejoin_channel_buffer( channel_buffer_b.read_with(cx_b, |buffer, cx| { assert_eq!(buffer.buffer().read(cx).text(), "012"); }); + + channel_buffer_a.read_with(cx_a, |buffer_a, _| { + channel_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_eq!(buffer_a.collaborators(), buffer_b.collaborators()); + }); + }); } #[track_caller] diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index fe9093245e..2e96d79f5e 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -153,8 +153,9 @@ message Envelope { LeaveChannelBuffer leave_channel_buffer = 134; AddChannelBufferCollaborator add_channel_buffer_collaborator = 135; RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136; - RejoinChannelBuffers rejoin_channel_buffers = 139; - RejoinChannelBuffersResponse rejoin_channel_buffers_response = 140; // Current max + UpdateChannelBufferCollaborator update_channel_buffer_collaborator = 139; + RejoinChannelBuffers rejoin_channel_buffers = 140; + RejoinChannelBuffersResponse rejoin_channel_buffers_response = 141; // Current max } } @@ -434,6 +435,12 @@ message RemoveChannelBufferCollaborator { PeerId peer_id = 2; } +message UpdateChannelBufferCollaborator { + uint64 channel_id = 1; + PeerId old_peer_id = 2; + PeerId new_peer_id = 3; +} + message GetDefinition { uint64 project_id = 1; uint64 buffer_id = 2; diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index a600bc4970..f643a8c168 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -259,6 +259,7 @@ messages!( (UpdateChannelBuffer, Foreground), (RemoveChannelBufferCollaborator, Foreground), (AddChannelBufferCollaborator, Foreground), + (UpdateChannelBufferCollaborator, Foreground), ); request_messages!( @@ -389,7 +390,8 @@ entity_messages!( channel_id, UpdateChannelBuffer, RemoveChannelBufferCollaborator, - AddChannelBufferCollaborator + AddChannelBufferCollaborator, + UpdateChannelBufferCollaborator ); const KIB: usize = 1024; From 6827ddf97d93bfcde40a5e1fdfd36024a5e85cba Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 1 Sep 2023 17:51:00 -0700 Subject: [PATCH 30/60] Start work on refreshing channel buffer collaborators on server restart --- crates/collab/src/db.rs | 5 +++ crates/collab/src/db/queries/buffers.rs | 28 ++++++++++++++--- crates/collab/src/db/queries/servers.rs | 31 ++++++++++++++----- crates/collab/src/rpc.rs | 20 ++++++++++-- .../src/tests/randomized_integration_tests.rs | 4 +-- 5 files changed, 73 insertions(+), 15 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 4a9983e600..823990eaf8 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -503,6 +503,11 @@ pub struct RefreshedRoom { pub canceled_calls_to_user_ids: Vec, } +pub struct RefreshedChannelBuffer { + pub connection_ids: Vec, + pub removed_collaborators: Vec, +} + pub struct Project { pub collaborators: Vec, pub worktrees: BTreeMap, diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index 79e20a2622..813255b80e 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -123,10 +123,11 @@ impl Database { // Find the collaborator record for this user's previous lost // connection. Update it with the new connection id. - let Some(self_collaborator) = collaborators - .iter_mut() - .find(|c| c.user_id == user_id && c.connection_lost) - else { + let server_id = ServerId(connection_id.owner_id as i32); + let Some(self_collaborator) = collaborators.iter_mut().find(|c| { + c.user_id == user_id + && (c.connection_lost || c.connection_server_id != server_id) + }) else { continue; }; let old_connection_id = self_collaborator.connection(); @@ -195,6 +196,25 @@ impl Database { .await } + pub async fn refresh_channel_buffer( + &self, + channel_id: ChannelId, + server_id: ServerId, + ) -> Result { + self.transaction(|tx| async move { + let mut connection_ids = Vec::new(); + let mut removed_collaborators = Vec::new(); + + // TODO + + Ok(RefreshedChannelBuffer { + connection_ids, + removed_collaborators, + }) + }) + .await + } + pub async fn leave_channel_buffer( &self, channel_id: ChannelId, diff --git a/crates/collab/src/db/queries/servers.rs b/crates/collab/src/db/queries/servers.rs index 08a2bda16a..2b1d0d2c0c 100644 --- a/crates/collab/src/db/queries/servers.rs +++ b/crates/collab/src/db/queries/servers.rs @@ -14,31 +14,48 @@ impl Database { .await } - pub async fn stale_room_ids( + pub async fn stale_server_resource_ids( &self, environment: &str, new_server_id: ServerId, - ) -> Result> { + ) -> Result<(Vec, Vec)> { self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { + enum QueryRoomIds { RoomId, } + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + let stale_server_epochs = self .stale_server_ids(environment, new_server_id, &tx) .await?; - Ok(room_participant::Entity::find() + let room_ids = room_participant::Entity::find() .select_only() .column(room_participant::Column::RoomId) .distinct() .filter( room_participant::Column::AnsweringConnectionServerId - .is_in(stale_server_epochs), + .is_in(stale_server_epochs.iter().copied()), ) - .into_values::<_, QueryAs>() + .into_values::<_, QueryRoomIds>() .all(&*tx) - .await?) + .await?; + let channel_ids = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .distinct() + .filter( + channel_buffer_collaborator::Column::ConnectionServerId + .is_in(stale_server_epochs.iter().copied()), + ) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + Ok((room_ids, channel_ids)) }) .await } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index d221d1c99e..95307ba725 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -278,13 +278,29 @@ impl Server { tracing::info!("waiting for cleanup timeout"); timeout.await; tracing::info!("cleanup timeout expired, retrieving stale rooms"); - if let Some(room_ids) = app_state + if let Some((room_ids, channel_ids)) = app_state .db - .stale_room_ids(&app_state.config.zed_environment, server_id) + .stale_server_resource_ids(&app_state.config.zed_environment, server_id) .await .trace_err() { tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms"); + + for channel_id in channel_ids { + if let Some(refreshed_channel_buffer) = app_state + .db + .refresh_channel_buffer(channel_id, server_id) + .await + .trace_err() + { + for connection_id in refreshed_channel_buffer.connection_ids { + for message in &refreshed_channel_buffer.removed_collaborators { + peer.send(connection_id, message.clone()).trace_err(); + } + } + } + } + for room_id in room_ids { let mut contacts_to_update = HashSet::default(); let mut canceled_calls_to_user_ids = Vec::new(); diff --git a/crates/collab/src/tests/randomized_integration_tests.rs b/crates/collab/src/tests/randomized_integration_tests.rs index e48753ed41..309fcf7e44 100644 --- a/crates/collab/src/tests/randomized_integration_tests.rs +++ b/crates/collab/src/tests/randomized_integration_tests.rs @@ -307,10 +307,10 @@ async fn apply_server_operation( server.start().await.unwrap(); deterministic.advance_clock(CLEANUP_TIMEOUT); let environment = &server.app_state.config.zed_environment; - let stale_room_ids = server + let (stale_room_ids, _) = server .app_state .db - .stale_room_ids(environment, server.id()) + .stale_server_resource_ids(environment, server.id()) .await .unwrap(); assert_eq!(stale_room_ids, vec![]); From 3a7b551e337826c5514eadeba81f1228a05557a9 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Sat, 2 Sep 2023 19:43:05 -0600 Subject: [PATCH 31/60] Fix tests with no neovim --- crates/vim/src/test/neovim_connection.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/vim/src/test/neovim_connection.rs b/crates/vim/src/test/neovim_connection.rs index 3e59080b13..e44e8d0e4c 100644 --- a/crates/vim/src/test/neovim_connection.rs +++ b/crates/vim/src/test/neovim_connection.rs @@ -237,6 +237,9 @@ impl NeovimConnection { #[cfg(not(feature = "neovim"))] pub async fn set_option(&mut self, value: &str) { + if let Some(NeovimData::Get { .. }) = self.data.front() { + self.data.pop_front(); + }; assert_eq!( self.data.pop_front(), Some(NeovimData::SetOption { From 55dd0b176c47f782d6c1a23c471077ff38823866 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Sat, 2 Sep 2023 19:52:18 -0600 Subject: [PATCH 32/60] Use consistent naming --- crates/vim/src/normal/substitute.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/vim/src/normal/substitute.rs b/crates/vim/src/normal/substitute.rs index efdd43d0a4..23b545abd8 100644 --- a/crates/vim/src/normal/substitute.rs +++ b/crates/vim/src/normal/substitute.rs @@ -10,8 +10,8 @@ actions!(vim, [Substitute, SubstituteLine]); pub(crate) fn init(cx: &mut AppContext) { cx.add_action(|_: &mut Workspace, _: &Substitute, cx| { Vim::update(cx, |vim, cx| { - let times = vim.pop_number_operator(cx); - substitute(vim, times, vim.state().mode == Mode::VisualLine, cx); + let count = vim.pop_number_operator(cx); + substitute(vim, count, vim.state().mode == Mode::VisualLine, cx); }) }); From 56db21d54bbd0dc83aa1a756152a18092c1ef8be Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 25 Aug 2023 14:38:42 -0600 Subject: [PATCH 33/60] Split ContextMenu actions This should have no user-visible impact. For vim `.` to repeat it's important that actions are replayable. Currently editor::MoveDown *sometimes* moves the cursor down, and *sometimes* selects the next completion. For replay we need to be able to separate the two. --- assets/keymaps/default.json | 11 ++++ crates/editor/src/editor.rs | 74 ++++++++++++++--------- crates/editor/src/editor_tests.rs | 2 +- crates/editor/src/scroll.rs | 4 -- crates/editor/src/scroll/scroll_amount.rs | 24 +------- 5 files changed, 57 insertions(+), 58 deletions(-) diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 1b2d8ce419..fa62a74f3f 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -515,6 +515,17 @@ "enter": "editor::ConfirmCodeAction" } }, + { + "context": "Editor && (showing_code_actions || showing_completions)", + "bindings": { + "up": "editor::ContextMenuPrev", + "ctrl-p": "editor::ContextMenuPrev", + "down": "editor::ContextMenuNext", + "ctrl-n": "editor::ContextMenuNext", + "pageup": "editor::ContextMenuFirst", + "pagedown": "editor::ContextMenuLast" + } + }, // Custom bindings { "bindings": { diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index d331b0a268..bdd29b04fa 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -312,6 +312,10 @@ actions!( CopyPath, CopyRelativePath, CopyHighlightJson, + ContextMenuFirst, + ContextMenuPrev, + ContextMenuNext, + ContextMenuLast, ] ); @@ -468,6 +472,10 @@ pub fn init(cx: &mut AppContext) { cx.add_action(Editor::next_copilot_suggestion); cx.add_action(Editor::previous_copilot_suggestion); cx.add_action(Editor::copilot_suggest); + cx.add_action(Editor::context_menu_first); + cx.add_action(Editor::context_menu_prev); + cx.add_action(Editor::context_menu_next); + cx.add_action(Editor::context_menu_last); hover_popover::init(cx); scroll::actions::init(cx); @@ -5166,12 +5174,6 @@ impl Editor { return; } - if let Some(context_menu) = self.context_menu.as_mut() { - if context_menu.select_prev(cx) { - return; - } - } - if matches!(self.mode, EditorMode::SingleLine) { cx.propagate_action(); return; @@ -5194,15 +5196,6 @@ impl Editor { return; } - if self - .context_menu - .as_mut() - .map(|menu| menu.select_first(cx)) - .unwrap_or(false) - { - return; - } - if matches!(self.mode, EditorMode::SingleLine) { cx.propagate_action(); return; @@ -5242,12 +5235,6 @@ impl Editor { pub fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext) { self.take_rename(true, cx); - if let Some(context_menu) = self.context_menu.as_mut() { - if context_menu.select_next(cx) { - return; - } - } - if self.mode == EditorMode::SingleLine { cx.propagate_action(); return; @@ -5315,6 +5302,30 @@ impl Editor { }); } + pub fn context_menu_first(&mut self, _: &ContextMenuFirst, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_first(cx); + } + } + + pub fn context_menu_prev(&mut self, _: &ContextMenuPrev, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_prev(cx); + } + } + + pub fn context_menu_next(&mut self, _: &ContextMenuNext, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_next(cx); + } + } + + pub fn context_menu_last(&mut self, _: &ContextMenuLast, cx: &mut ViewContext) { + if let Some(context_menu) = self.context_menu.as_mut() { + context_menu.select_last(cx); + } + } + pub fn move_to_previous_word_start( &mut self, _: &MoveToPreviousWordStart, @@ -8666,17 +8677,20 @@ impl View for Editor { if self.pending_rename.is_some() { keymap.add_identifier("renaming"); } - match self.context_menu.as_ref() { - Some(ContextMenu::Completions(_)) => { - keymap.add_identifier("menu"); - keymap.add_identifier("showing_completions") + if self.context_menu_visible() { + match self.context_menu.as_ref() { + Some(ContextMenu::Completions(_)) => { + keymap.add_identifier("menu"); + keymap.add_identifier("showing_completions") + } + Some(ContextMenu::CodeActions(_)) => { + keymap.add_identifier("menu"); + keymap.add_identifier("showing_code_actions") + } + None => {} } - Some(ContextMenu::CodeActions(_)) => { - keymap.add_identifier("menu"); - keymap.add_identifier("showing_code_actions") - } - None => {} } + for layer in self.keymap_context_layers.values() { keymap.extend(layer); } diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index ad97639d0b..74bd67e03a 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -5340,7 +5340,7 @@ async fn test_completion(cx: &mut gpui::TestAppContext) { cx.condition(|editor, _| editor.context_menu_visible()) .await; let apply_additional_edits = cx.update_editor(|editor, cx| { - editor.move_down(&MoveDown, cx); + editor.context_menu_next(&Default::default(), cx); editor .confirm_completion(&ConfirmCompletion::default(), cx) .unwrap() diff --git a/crates/editor/src/scroll.rs b/crates/editor/src/scroll.rs index d87bc0ae4f..8233f92a1a 100644 --- a/crates/editor/src/scroll.rs +++ b/crates/editor/src/scroll.rs @@ -378,10 +378,6 @@ impl Editor { return; } - if amount.move_context_menu_selection(self, cx) { - return; - } - let cur_position = self.scroll_position(cx); let new_pos = cur_position + vec2f(0., amount.lines(self)); self.set_scroll_position(new_pos, cx); diff --git a/crates/editor/src/scroll/scroll_amount.rs b/crates/editor/src/scroll/scroll_amount.rs index cadf37b31d..0edab2bdfc 100644 --- a/crates/editor/src/scroll/scroll_amount.rs +++ b/crates/editor/src/scroll/scroll_amount.rs @@ -1,8 +1,5 @@ -use gpui::ViewContext; -use serde::Deserialize; -use util::iife; - use crate::Editor; +use serde::Deserialize; #[derive(Clone, PartialEq, Deserialize)] pub enum ScrollAmount { @@ -13,25 +10,6 @@ pub enum ScrollAmount { } impl ScrollAmount { - pub fn move_context_menu_selection( - &self, - editor: &mut Editor, - cx: &mut ViewContext, - ) -> bool { - iife!({ - let context_menu = editor.context_menu.as_mut()?; - - match self { - Self::Line(c) if *c > 0. => context_menu.select_next(cx), - Self::Line(_) => context_menu.select_prev(cx), - Self::Page(c) if *c > 0. => context_menu.select_last(cx), - Self::Page(_) => context_menu.select_first(cx), - } - .then_some(()) - }) - .is_some() - } - pub fn lines(&self, editor: &mut Editor) -> f32 { match self { Self::Line(count) => *count, From d4cff684751e464e1fc733bd5cbe72772e95f891 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 5 Sep 2023 12:26:48 +0200 Subject: [PATCH 34/60] :art: --- Cargo.lock | 1 + crates/semantic_index/Cargo.toml | 2 + crates/semantic_index/src/db.rs | 12 +- crates/semantic_index/src/embedding_queue.rs | 4 +- crates/semantic_index/src/semantic_index.rs | 539 ++++++++++-------- .../src/semantic_index_tests.rs | 9 +- 6 files changed, 308 insertions(+), 259 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e0eb1947e2..c99e88b9b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6713,6 +6713,7 @@ dependencies = [ "anyhow", "async-trait", "bincode", + "collections", "ctor", "editor", "env_logger 0.9.3", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index d46346e0ab..72a36efd50 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -9,6 +9,7 @@ path = "src/semantic_index.rs" doctest = false [dependencies] +collections = { path = "../collections" } gpui = { path = "../gpui" } language = { path = "../language" } project = { path = "../project" } @@ -42,6 +43,7 @@ sha1 = "0.10.5" parse_duration = "2.1.1" [dev-dependencies] +collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 2ececc1eb6..5664210388 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -4,6 +4,7 @@ use crate::{ SEMANTIC_INDEX_VERSION, }; use anyhow::{anyhow, Context, Result}; +use collections::HashMap; use futures::channel::oneshot; use gpui::executor; use project::{search::PathMatcher, Fs}; @@ -12,7 +13,6 @@ use rusqlite::params; use rusqlite::types::Value; use std::{ cmp::Ordering, - collections::HashMap, future::Future, ops::Range, path::{Path, PathBuf}, @@ -195,7 +195,7 @@ impl VectorDatabase { pub fn delete_file( &self, worktree_id: i64, - delete_path: PathBuf, + delete_path: Arc, ) -> impl Future> { self.transact(move |db| { db.execute( @@ -209,7 +209,7 @@ impl VectorDatabase { pub fn insert_file( &self, worktree_id: i64, - path: PathBuf, + path: Arc, mtime: SystemTime, documents: Vec, ) -> impl Future> { @@ -288,7 +288,7 @@ impl VectorDatabase { WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) ", )?; - let mut embeddings_by_digest = HashMap::new(); + let mut embeddings_by_digest = HashMap::default(); for (worktree_id, file_paths) in worktree_id_file_paths { let file_paths = Rc::new( file_paths @@ -316,7 +316,7 @@ impl VectorDatabase { pub fn find_or_create_worktree( &self, - worktree_root_path: PathBuf, + worktree_root_path: Arc, ) -> impl Future> { self.transact(move |db| { let mut worktree_query = @@ -351,7 +351,7 @@ impl VectorDatabase { WHERE worktree_id = ?1 ORDER BY relative_path", )?; - let mut result: HashMap = HashMap::new(); + let mut result: HashMap = HashMap::default(); for row in statement.query_map(params![worktree_id], |row| { Ok(( row.get::<_, String>(0)?.into(), diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 96493fc4d3..f1abbde3a4 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -2,12 +2,12 @@ use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; -use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; +use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime}; #[derive(Clone)] pub struct FileToEmbed { pub worktree_id: i64, - pub path: PathBuf, + pub path: Arc, pub mtime: SystemTime, pub documents: Vec, pub job_handle: JobHandle, diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index a917eabfc8..6441d8d5c0 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -9,6 +9,7 @@ mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; use anyhow::{anyhow, Result}; +use collections::{BTreeMap, HashMap, HashSet}; use db::VectorDatabase; use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; @@ -18,13 +19,10 @@ use language::{Anchor, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; -use project::{ - search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId, -}; +use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId}; use smol::channel; use std::{ cmp::Ordering, - collections::{BTreeMap, HashMap}, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, @@ -34,7 +32,7 @@ use util::{ channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}, http::HttpClient, paths::EMBEDDINGS_DIR, - ResultExt, + ResultExt, TryFutureExt, }; use workspace::WorkspaceCreated; @@ -68,9 +66,7 @@ pub fn init( if let Some(workspace) = workspace.upgrade(cx) { let project = workspace.read(cx).project().clone(); if project.read(cx).is_local() { - semantic_index.update(cx, |index, cx| { - index.initialize_project(project, cx).detach_and_log_err(cx) - }); + semantic_index.update(cx, |index, cx| index.register_project(project, cx)); } } } @@ -111,11 +107,56 @@ pub struct SemanticIndex { } struct ProjectState { - worktree_db_ids: Vec<(WorktreeId, i64)>, - _subscription: gpui::Subscription, + worktrees: HashMap, outstanding_job_count_rx: watch::Receiver, outstanding_job_count_tx: Arc>>, - changed_paths: BTreeMap, + _subscription: gpui::Subscription, +} + +enum WorktreeState { + Registering(RegisteringWorktreeState), + Registered(RegisteredWorktreeState), +} + +impl WorktreeState { + fn paths_changed( + &mut self, + changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, + changed_at: Instant, + worktree: &Worktree, + ) { + let changed_paths = match self { + Self::Registering(state) => &mut state.changed_paths, + Self::Registered(state) => &mut state.changed_paths, + }; + + for (path, entry_id, change) in changes.iter() { + let Some(entry) = worktree.entry_for_id(*entry_id) else { + continue; + }; + if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() { + continue; + } + changed_paths.insert( + path.clone(), + ChangedPathInfo { + changed_at, + mtime: entry.mtime, + is_deleted: *change == PathChange::Removed, + }, + ); + } + } +} + +struct RegisteringWorktreeState { + changed_paths: BTreeMap, ChangedPathInfo>, + _registration: Task>, +} + +struct RegisteredWorktreeState { + db_id: i64, + changed_paths: BTreeMap, ChangedPathInfo>, } struct ChangedPathInfo { @@ -141,55 +182,42 @@ impl JobHandle { } impl ProjectState { - fn new( - subscription: gpui::Subscription, - worktree_db_ids: Vec<(WorktreeId, i64)>, - changed_paths: BTreeMap, - ) -> Self { + fn new(subscription: gpui::Subscription) -> Self { let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0); let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx)); Self { - worktree_db_ids, + worktrees: Default::default(), outstanding_job_count_rx, outstanding_job_count_tx, - changed_paths, _subscription: subscription, } } - pub fn get_outstanding_count(&self) -> usize { - self.outstanding_job_count_rx.borrow().clone() - } - fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option { - self.worktree_db_ids - .iter() - .find_map(|(worktree_id, db_id)| { - if *worktree_id == id { - Some(*db_id) - } else { - None - } - }) + match self.worktrees.get(&id)? { + WorktreeState::Registering(_) => None, + WorktreeState::Registered(state) => Some(state.db_id), + } } fn worktree_id_for_db_id(&self, id: i64) -> Option { - self.worktree_db_ids + self.worktrees .iter() - .find_map(|(worktree_id, db_id)| { - if *db_id == id { - Some(*worktree_id) - } else { - None - } + .find_map(|(worktree_id, worktree_state)| match worktree_state { + WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id), + _ => None, }) } + + fn worktree(&mut self, id: WorktreeId) -> Option<&mut WorktreeState> { + self.worktrees.get_mut(&id) + } } #[derive(Clone)] pub struct PendingFile { worktree_db_id: i64, - relative_path: PathBuf, + relative_path: Arc, absolute_path: PathBuf, language: Option>, modified_time: SystemTime, @@ -298,7 +326,7 @@ impl SemanticIndex { parsing_files_tx, _embedding_task, _parsing_files_tasks, - projects: HashMap::new(), + projects: Default::default(), } })) } @@ -369,9 +397,9 @@ impl SemanticIndex { fn project_entries_changed( &mut self, project: ModelHandle, + worktree_id: WorktreeId, changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, - cx: &mut ModelContext<'_, SemanticIndex>, - worktree_id: &WorktreeId, + cx: &mut ModelContext, ) { let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else { return; @@ -381,131 +409,103 @@ impl SemanticIndex { return; }; - let embeddings_for_digest = { - let mut worktree_id_file_paths = HashMap::new(); - for (path, _) in &project_state.changed_paths { - if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id) - { - worktree_id_file_paths - .entry(worktree_db_id) - .or_insert(Vec::new()) - .push(path.path.clone()); - } - } - self.db.embeddings_for_files(worktree_id_file_paths) - }; - - let worktree = worktree.read(cx); let change_time = Instant::now(); - for (path, entry_id, change) in changes.iter() { - let Some(entry) = worktree.entry_for_id(*entry_id) else { - continue; + let worktree = worktree.read(cx); + let worktree_state = if let Some(worktree_state) = project_state.worktree(worktree_id) { + worktree_state + } else { + return; + }; + worktree_state.paths_changed(changes, Instant::now(), worktree); + if let WorktreeState::Registered(worktree_state) = worktree_state { + let embeddings_for_digest = { + let worktree_paths = worktree_state + .changed_paths + .iter() + .map(|(path, _)| path.clone()) + .collect::>(); + let mut worktree_id_file_paths = HashMap::default(); + worktree_id_file_paths.insert(worktree_state.db_id, worktree_paths); + self.db.embeddings_for_files(worktree_id_file_paths) }; - if entry.is_ignored || entry.is_symlink || entry.is_external { - continue; - } - let project_path = ProjectPath { - worktree_id: *worktree_id, - path: path.clone(), - }; - project_state.changed_paths.insert( - project_path, - ChangedPathInfo { - changed_at: change_time, - mtime: entry.mtime, - is_deleted: *change == PathChange::Removed, - }, - ); + + cx.spawn_weak(|this, mut cx| async move { + let embeddings_for_digest = + embeddings_for_digest.await.log_err().unwrap_or_default(); + + cx.background().timer(BACKGROUND_INDEXING_DELAY).await; + if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { + Self::reindex_changed_paths( + this, + project, + Some(change_time), + &mut cx, + Arc::new(embeddings_for_digest), + ) + .await; + } + }) + .detach(); } - - cx.spawn_weak(|this, mut cx| async move { - let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default(); - - cx.background().timer(BACKGROUND_INDEXING_DELAY).await; - if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { - Self::reindex_changed_paths( - this, - project, - Some(change_time), - &mut cx, - Arc::new(embeddings_for_digest), - ) - .await; - } - }) - .detach(); } - pub fn initialize_project( + pub fn register_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { + log::trace!("Registering Project for Semantic Index"); + + let subscription = cx.subscribe(&project, |this, project, event, cx| match event { + project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { + this.project_worktrees_changed(project.clone(), cx); + } + project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { + this.project_entries_changed(project, *worktree_id, changes.clone(), cx); + } + _ => {} + }); + self.projects + .insert(project.downgrade(), ProjectState::new(subscription)); + self.project_worktrees_changed(project, cx); + } + + fn register_worktree( &mut self, project: ModelHandle, + worktree: ModelHandle, cx: &mut ModelContext, - ) -> Task> { - log::trace!("Initializing Project for Semantic Index"); - let worktree_scans_complete = project - .read(cx) - .worktrees(cx) - .map(|worktree| { - let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete(); - async move { - scan_complete.await; - } - }) - .collect::>(); - - let worktree_db_ids = project - .read(cx) - .worktrees(cx) - .map(|worktree| { - self.db - .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf()) - }) - .collect::>(); - - let _subscription = cx.subscribe(&project, |this, project, event, cx| { - if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event { - this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id); - }; - }); - + ) { + let project = project.downgrade(); + let project_state = if let Some(project_state) = self.projects.get_mut(&project) { + project_state + } else { + return; + }; + let worktree = if let Some(worktree) = worktree.read(cx).as_local() { + worktree + } else { + return; + }; + let worktree_abs_path = worktree.abs_path().clone(); + let scan_complete = worktree.scan_complete(); + let worktree_id = worktree.id(); + let db = self.db.clone(); let language_registry = self.language_registry.clone(); - - cx.spawn(|this, mut cx| async move { - futures::future::join_all(worktree_scans_complete).await; - - let worktree_db_ids = futures::future::join_all(worktree_db_ids).await; - let worktrees = project.read_with(&cx, |project, cx| { - project - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>() - }); - - let mut worktree_file_mtimes = HashMap::new(); - let mut db_ids_by_worktree_id = HashMap::new(); - - for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) { - let db_id = db_id?; - db_ids_by_worktree_id.insert(worktree.id(), db_id); - worktree_file_mtimes.insert( - worktree.id(), - this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id)) - .await?, - ); - } - - let worktree_db_ids = db_ids_by_worktree_id - .iter() - .map(|(a, b)| (*a, *b)) - .collect(); - - let changed_paths = cx - .background() - .spawn(async move { - let mut changed_paths = BTreeMap::new(); - let now = Instant::now(); - for worktree in worktrees.into_iter() { - let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap(); + let registration = cx.spawn(|this, mut cx| { + async move { + scan_complete.await; + let db_id = db.find_or_create_worktree(worktree_abs_path).await?; + let mut file_mtimes = db.get_file_mtimes(db_id).await?; + let worktree = if let Some(project) = project.upgrade(&cx) { + project + .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx)) + .ok_or_else(|| anyhow!("worktree not found"))? + } else { + return anyhow::Ok(()); + }; + let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot()); + let mut changed_paths = cx + .background() + .spawn(async move { + let mut changed_paths = BTreeMap::new(); + let now = Instant::now(); for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -534,10 +534,7 @@ impl SemanticIndex { if !already_stored { changed_paths.insert( - ProjectPath { - worktree_id: worktree.id(), - path: file.path.clone(), - }, + file.path.clone(), ChangedPathInfo { changed_at: now, mtime: file.mtime, @@ -551,10 +548,7 @@ impl SemanticIndex { // Clean up entries from database that are no longer in the worktree. for (path, mtime) in file_mtimes { changed_paths.insert( - ProjectPath { - worktree_id: worktree.id(), - path: path.into(), - }, + path.into(), ChangedPathInfo { changed_at: now, mtime, @@ -562,20 +556,80 @@ impl SemanticIndex { }, ); } + + anyhow::Ok(changed_paths) + }) + .await?; + this.update(&mut cx, |this, _| { + let project_state = this + .projects + .get_mut(&project) + .ok_or_else(|| anyhow!("project not registered"))?; + + if let Some(WorktreeState::Registering(state)) = + project_state.worktrees.remove(&worktree_id) + { + changed_paths.extend(state.changed_paths); } + project_state.worktrees.insert( + worktree_id, + WorktreeState::Registered(RegisteredWorktreeState { + db_id, + changed_paths, + }), + ); - anyhow::Ok(changed_paths) - }) - .await?; + anyhow::Ok(()) + })?; - this.update(&mut cx, |this, _| { - this.projects.insert( - project.downgrade(), - ProjectState::new(_subscription, worktree_db_ids, changed_paths), - ); - }); - Result::<(), _>::Ok(()) - }) + anyhow::Ok(()) + } + .log_err() + }); + project_state.worktrees.insert( + worktree_id, + WorktreeState::Registering(RegisteringWorktreeState { + changed_paths: Default::default(), + _registration: registration, + }), + ); + } + + fn project_worktrees_changed( + &mut self, + project: ModelHandle, + cx: &mut ModelContext, + ) { + let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade()) + { + project_state + } else { + return; + }; + + let mut worktrees = project + .read(cx) + .worktrees(cx) + .filter(|worktree| worktree.read(cx).is_local()) + .collect::>(); + let worktree_ids = worktrees + .iter() + .map(|worktree| worktree.read(cx).id()) + .collect::>(); + + // Remove worktrees that are no longer present + project_state + .worktrees + .retain(|worktree_id, _| worktree_ids.contains(worktree_id)); + + // Register new worktrees + worktrees.retain(|worktree| { + let worktree_id = worktree.read(cx).id(); + project_state.worktree(worktree_id).is_none() + }); + for worktree in worktrees { + self.register_worktree(project.clone(), worktree, cx); + } } pub fn index_project( @@ -583,28 +637,31 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task)>> { + let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade()) + { + project_state + } else { + return Task::ready(Err(anyhow!("project was not registered"))); + }; + let outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone(); + + let mut worktree_id_file_paths = HashMap::default(); + for worktree in project_state.worktrees.values() { + if let WorktreeState::Registered(worktree_state) = worktree { + for (path, _) in &worktree_state.changed_paths { + worktree_id_file_paths + .entry(worktree_state.db_id) + .or_insert(Vec::new()) + .push(path.clone()); + } + } + } + cx.spawn(|this, mut cx| async move { let embeddings_for_digest = this.read_with(&cx, |this, _| { - if let Some(state) = this.projects.get(&project.downgrade()) { - let mut worktree_id_file_paths = HashMap::default(); - for (path, _) in &state.changed_paths { - if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id) - { - worktree_id_file_paths - .entry(worktree_db_id) - .or_insert(Vec::new()) - .push(path.path.clone()); - } - } - - Ok(this.db.embeddings_for_files(worktree_id_file_paths)) - } else { - Err(anyhow!("Project not yet initialized")) - } - })?; - + this.db.embeddings_for_files(worktree_id_file_paths) + }); let embeddings_for_digest = Arc::new(embeddings_for_digest.await?); - Self::reindex_changed_paths( this.clone(), project.clone(), @@ -613,15 +670,8 @@ impl SemanticIndex { embeddings_for_digest, ) .await; - - this.update(&mut cx, |this, _cx| { - let Some(state) = this.projects.get(&project.downgrade()) else { - return Err(anyhow!("Project not yet initialized")); - }; - let job_count_rx = state.outstanding_job_count_rx.clone(); - let count = state.get_outstanding_count(); - Ok((count, job_count_rx)) - }) + let count = *outstanding_job_count_rx.borrow(); + Ok((count, outstanding_job_count_rx)) }) } @@ -784,50 +834,49 @@ impl SemanticIndex { let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| { if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; - let db_ids = &project_state.worktree_db_ids; - let mut worktree: Option> = None; + project_state + .worktrees + .retain(|worktree_id, worktree_state| { + let worktree = if let Some(worktree) = + project.read(cx).worktree_for_id(*worktree_id, cx) + { + worktree + } else { + return false; + }; + let worktree_state = + if let WorktreeState::Registered(worktree_state) = worktree_state { + worktree_state + } else { + return true; + }; - project_state.changed_paths.retain(|path, info| { - if let Some(last_changed_before) = last_changed_before { - if info.changed_at > last_changed_before { - return true; - } - } + worktree_state.changed_paths.retain(|path, info| { + if let Some(last_changed_before) = last_changed_before { + if info.changed_at > last_changed_before { + return true; + } + } - if worktree - .as_ref() - .map_or(true, |tree| tree.read(cx).id() != path.worktree_id) - { - worktree = project.read(cx).worktree_for_id(path.worktree_id, cx); - } - let Some(worktree) = &worktree else { - return false; - }; + if info.is_deleted { + files_to_delete.push((worktree_state.db_id, path.clone())); + } else { + let absolute_path = worktree.read(cx).absolutize(path); + let job_handle = JobHandle::new(&outstanding_job_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.clone(), + language: None, + job_handle, + modified_time: info.mtime, + worktree_db_id: worktree_state.db_id, + }); + } - let Some(worktree_db_id) = db_ids - .iter() - .find_map(|entry| (entry.0 == path.worktree_id).then_some(entry.1)) - else { - return false; - }; - - if info.is_deleted { - files_to_delete.push((worktree_db_id, path.path.to_path_buf())); - } else { - let absolute_path = worktree.read(cx).absolutize(&path.path); - let job_handle = JobHandle::new(&outstanding_job_count_tx); - pending_files.push(PendingFile { - absolute_path, - relative_path: path.path.to_path_buf(), - language: None, - job_handle, - modified_time: info.mtime, - worktree_db_id, + false }); - } - - false - }); + true + }); } ( diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index f549e68e04..2f28184f20 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -87,11 +87,8 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - let _ = semantic_index - .update(cx, |store, cx| { - store.initialize_project(project.clone(), cx) - }) - .await; + semantic_index.update(cx, |store, cx| store.register_project(project.clone(), cx)); + deterministic.run_until_parked(); let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) @@ -214,7 +211,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let files = (1..=3) .map(|file_ix| FileToEmbed { worktree_id: 5, - path: format!("path-{file_ix}").into(), + path: Path::new(&format!("path-{file_ix}")).into(), mtime: SystemTime::now(), documents: (0..rng.gen_range(4..22)) .map(|document_ix| { From 7b5a41dda22a4d99ae226aaeda658194d4b1689b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 5 Sep 2023 16:09:24 +0200 Subject: [PATCH 35/60] Move retrieval of embeddings from the db into `reindex_changed_files` Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/semantic_index.rs | 222 +++++++++----------- 1 file changed, 101 insertions(+), 121 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6441d8d5c0..0b2a77378e 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -418,30 +418,12 @@ impl SemanticIndex { }; worktree_state.paths_changed(changes, Instant::now(), worktree); if let WorktreeState::Registered(worktree_state) = worktree_state { - let embeddings_for_digest = { - let worktree_paths = worktree_state - .changed_paths - .iter() - .map(|(path, _)| path.clone()) - .collect::>(); - let mut worktree_id_file_paths = HashMap::default(); - worktree_id_file_paths.insert(worktree_state.db_id, worktree_paths); - self.db.embeddings_for_files(worktree_id_file_paths) - }; - cx.spawn_weak(|this, mut cx| async move { - let embeddings_for_digest = - embeddings_for_digest.await.log_err().unwrap_or_default(); - cx.background().timer(BACKGROUND_INDEXING_DELAY).await; if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { - Self::reindex_changed_paths( - this, - project, - Some(change_time), - &mut cx, - Arc::new(embeddings_for_digest), - ) + this.update(&mut cx, |this, cx| { + this.reindex_changed_paths(project, Some(change_time), cx) + }) .await; } }) @@ -644,31 +626,10 @@ impl SemanticIndex { return Task::ready(Err(anyhow!("project was not registered"))); }; let outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone(); - - let mut worktree_id_file_paths = HashMap::default(); - for worktree in project_state.worktrees.values() { - if let WorktreeState::Registered(worktree_state) = worktree { - for (path, _) in &worktree_state.changed_paths { - worktree_id_file_paths - .entry(worktree_state.db_id) - .or_insert(Vec::new()) - .push(path.clone()); - } - } - } - cx.spawn(|this, mut cx| async move { - let embeddings_for_digest = this.read_with(&cx, |this, _| { - this.db.embeddings_for_files(worktree_id_file_paths) - }); - let embeddings_for_digest = Arc::new(embeddings_for_digest.await?); - Self::reindex_changed_paths( - this.clone(), - project.clone(), - None, - &mut cx, - embeddings_for_digest, - ) + this.update(&mut cx, |this, cx| { + this.reindex_changed_paths(project.clone(), None, cx) + }) .await; let count = *outstanding_job_count_rx.borrow(); Ok((count, outstanding_job_count_rx)) @@ -822,94 +783,113 @@ impl SemanticIndex { }) } - async fn reindex_changed_paths( - this: ModelHandle, + fn reindex_changed_paths( + &mut self, project: ModelHandle, last_changed_before: Option, - cx: &mut AsyncAppContext, - embeddings_for_digest: Arc>, - ) { + cx: &mut ModelContext, + ) -> Task<()> { + let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade()) + { + project_state + } else { + return Task::ready(()); + }; + let mut pending_files = Vec::new(); let mut files_to_delete = Vec::new(); - let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| { - if let Some(project_state) = this.projects.get_mut(&project.downgrade()) { - let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; - project_state - .worktrees - .retain(|worktree_id, worktree_state| { - let worktree = if let Some(worktree) = - project.read(cx).worktree_for_id(*worktree_id, cx) - { - worktree - } else { - return false; - }; - let worktree_state = - if let WorktreeState::Registered(worktree_state) = worktree_state { - worktree_state - } else { - return true; - }; + let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; + project_state + .worktrees + .retain(|worktree_id, worktree_state| { + let worktree = + if let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx) { + worktree + } else { + return false; + }; + let worktree_state = + if let WorktreeState::Registered(worktree_state) = worktree_state { + worktree_state + } else { + return true; + }; - worktree_state.changed_paths.retain(|path, info| { - if let Some(last_changed_before) = last_changed_before { - if info.changed_at > last_changed_before { - return true; - } - } + worktree_state.changed_paths.retain(|path, info| { + if let Some(last_changed_before) = last_changed_before { + if info.changed_at > last_changed_before { + return true; + } + } - if info.is_deleted { - files_to_delete.push((worktree_state.db_id, path.clone())); - } else { - let absolute_path = worktree.read(cx).absolutize(path); - let job_handle = JobHandle::new(&outstanding_job_count_tx); - pending_files.push(PendingFile { - absolute_path, - relative_path: path.clone(), - language: None, - job_handle, - modified_time: info.mtime, - worktree_db_id: worktree_state.db_id, - }); - } - - false + if info.is_deleted { + files_to_delete.push((worktree_state.db_id, path.clone())); + } else { + let absolute_path = worktree.read(cx).absolutize(path); + let job_handle = JobHandle::new(&outstanding_job_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.clone(), + language: None, + job_handle, + modified_time: info.mtime, + worktree_db_id: worktree_state.db_id, }); - true - }); - } + } - ( - this.db.clone(), - this.language_registry.clone(), - this.parsing_files_tx.clone(), - ) - }); + false + }); + true + }); - for (worktree_db_id, path) in files_to_delete { - db.delete_file(worktree_db_id, path).await.log_err(); - } - - for mut pending_file in pending_files { - if let Ok(language) = language_registry - .language_for_file(&pending_file.relative_path, None) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; + let mut worktree_id_file_paths = HashMap::default(); + for worktree in project_state.worktrees.values() { + if let WorktreeState::Registered(worktree_state) = worktree { + for (path, _) in &worktree_state.changed_paths { + worktree_id_file_paths + .entry(worktree_state.db_id) + .or_insert(Vec::new()) + .push(path.clone()); } - pending_file.language = Some(language); } - parsing_files_tx - .try_send((embeddings_for_digest.clone(), pending_file)) - .ok(); } + + let db = self.db.clone(); + let language_registry = self.language_registry.clone(); + let parsing_files_tx = self.parsing_files_tx.clone(); + cx.background().spawn(async move { + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); + } + + let embeddings_for_digest = Arc::new( + db.embeddings_for_files(worktree_id_file_paths) + .await + .log_err() + .unwrap_or_default(), + ); + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + pending_file.language = Some(language); + } + parsing_files_tx + .try_send((embeddings_for_digest.clone(), pending_file)) + .ok(); + } + }) } } From 6b1dc63fc057df75acee794c4d8670c6fcb120f1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 5 Sep 2023 16:16:12 +0200 Subject: [PATCH 36/60] Retrieve embeddings based on pending files Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/semantic_index.rs | 35 ++++++++++----------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 0b2a77378e..8122cffdc9 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -417,7 +417,7 @@ impl SemanticIndex { return; }; worktree_state.paths_changed(changes, Instant::now(), worktree); - if let WorktreeState::Registered(worktree_state) = worktree_state { + if let WorktreeState::Registered(_) = worktree_state { cx.spawn_weak(|this, mut cx| async move { cx.background().timer(BACKGROUND_INDEXING_DELAY).await; if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { @@ -842,18 +842,6 @@ impl SemanticIndex { true }); - let mut worktree_id_file_paths = HashMap::default(); - for worktree in project_state.worktrees.values() { - if let WorktreeState::Registered(worktree_state) = worktree { - for (path, _) in &worktree_state.changed_paths { - worktree_id_file_paths - .entry(worktree_state.db_id) - .or_insert(Vec::new()) - .push(path.clone()); - } - } - } - let db = self.db.clone(); let language_registry = self.language_registry.clone(); let parsing_files_tx = self.parsing_files_tx.clone(); @@ -862,12 +850,21 @@ impl SemanticIndex { db.delete_file(worktree_db_id, path).await.log_err(); } - let embeddings_for_digest = Arc::new( - db.embeddings_for_files(worktree_id_file_paths) - .await - .log_err() - .unwrap_or_default(), - ); + let embeddings_for_digest = { + let mut files = HashMap::default(); + for pending_file in &pending_files { + files + .entry(pending_file.worktree_db_id) + .or_insert(Vec::new()) + .push(pending_file.relative_path.clone()); + } + Arc::new( + db.embeddings_for_files(files) + .await + .log_err() + .unwrap_or_default(), + ) + }; for mut pending_file in pending_files { if let Ok(language) = language_registry From e2479a7172a617bb61804973c3dbb2f3b924d87a Mon Sep 17 00:00:00 2001 From: Nate Butler Date: Tue, 5 Sep 2023 10:24:49 -0400 Subject: [PATCH 37/60] Fix cropped search filters --- styles/src/style_tree/search.ts | 4 ++-- styles/src/style_tree/toolbar.ts | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/styles/src/style_tree/search.ts b/styles/src/style_tree/search.ts index a93aab4ea8..8174690fde 100644 --- a/styles/src/style_tree/search.ts +++ b/styles/src/style_tree/search.ts @@ -48,7 +48,7 @@ export default function search(): any { } return { - padding: { top: 0, bottom: 0 }, + padding: { top: 4, bottom: 4 }, option_button: toggleable({ base: interactive({ @@ -394,7 +394,7 @@ export default function search(): any { }), }, }), - search_bar_row_height: 32, + search_bar_row_height: 34, search_row_spacing: 8, option_button_height: 22, modes_container: {}, diff --git a/styles/src/style_tree/toolbar.ts b/styles/src/style_tree/toolbar.ts index 01a09a0616..adf8fb866f 100644 --- a/styles/src/style_tree/toolbar.ts +++ b/styles/src/style_tree/toolbar.ts @@ -8,8 +8,8 @@ export const toolbar = () => { const theme = useTheme() return { - height: 32, - padding: { left: 4, right: 4, top: 4, bottom: 4 }, + height: 42, + padding: { left: 4, right: 4 }, background: background(theme.highest), border: border(theme.highest, { bottom: true }), item_spacing: 4, From 3c70b127bd2785a943ae81bc3643f73eeb8babae Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 5 Sep 2023 16:52:58 +0200 Subject: [PATCH 38/60] Simplify `SemanticIndex::index_project` Co-Authored-By: Kyle Caverly --- crates/search/src/project_search.rs | 40 ++-- crates/semantic_index/src/semantic_index.rs | 176 +++++++----------- .../src/semantic_index_tests.rs | 27 ++- 3 files changed, 99 insertions(+), 144 deletions(-) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 6364183877..f1a0ff71d3 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -12,15 +12,13 @@ use editor::{ SelectAll, MAX_TAB_TITLE_LEN, }; use futures::StreamExt; - -use gpui::platform::PromptLevel; - use gpui::{ - actions, elements::*, platform::MouseButton, Action, AnyElement, AnyViewHandle, AppContext, - Entity, ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle, - WeakModelHandle, WeakViewHandle, + actions, + elements::*, + platform::{MouseButton, PromptLevel}, + Action, AnyElement, AnyViewHandle, AppContext, Entity, ModelContext, ModelHandle, Subscription, + Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle, }; - use menu::Confirm; use postage::stream::Stream; use project::{ @@ -132,8 +130,7 @@ pub struct ProjectSearchView { } struct SemanticSearchState { - file_count: usize, - outstanding_file_count: usize, + pending_file_count: usize, _progress_task: Task<()>, } @@ -319,12 +316,8 @@ impl View for ProjectSearchView { }; let semantic_status = if let Some(semantic) = &self.semantic_state { - if semantic.outstanding_file_count > 0 { - format!( - "Indexing: {} of {}...", - semantic.file_count - semantic.outstanding_file_count, - semantic.file_count - ) + if semantic.pending_file_count > 0 { + format!("Remaining files to index: {}", semantic.pending_file_count) } else { "Indexing complete".to_string() } @@ -641,26 +634,25 @@ impl ProjectSearchView { let project = self.model.read(cx).project.clone(); - let index_task = semantic_index.update(cx, |semantic_index, cx| { - semantic_index.index_project(project, cx) + let mut pending_file_count_rx = semantic_index.update(cx, |semantic_index, cx| { + semantic_index.index_project(project.clone(), cx); + semantic_index.pending_file_count(&project).unwrap() }); cx.spawn(|search_view, mut cx| async move { - let (files_to_index, mut files_remaining_rx) = index_task.await?; - search_view.update(&mut cx, |search_view, cx| { cx.notify(); + let pending_file_count = *pending_file_count_rx.borrow(); search_view.semantic_state = Some(SemanticSearchState { - file_count: files_to_index, - outstanding_file_count: files_to_index, + pending_file_count, _progress_task: cx.spawn(|search_view, mut cx| async move { - while let Some(count) = files_remaining_rx.recv().await { + while let Some(count) = pending_file_count_rx.recv().await { search_view .update(&mut cx, |search_view, cx| { if let Some(semantic_search_state) = &mut search_view.semantic_state { - semantic_search_state.outstanding_file_count = count; + semantic_search_state.pending_file_count = count; cx.notify(); if count == 0 { return; @@ -959,7 +951,7 @@ impl ProjectSearchView { match mode { SearchMode::Semantic => { if let Some(semantic) = &mut self.semantic_state { - if semantic.outstanding_file_count > 0 { + if semantic.pending_file_count > 0 { return; } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 8122cffdc9..2de78ab7e3 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -66,7 +66,9 @@ pub fn init( if let Some(workspace) = workspace.upgrade(cx) { let project = workspace.read(cx).project().clone(); if project.read(cx).is_local() { - semantic_index.update(cx, |index, cx| index.register_project(project, cx)); + semantic_index.update(cx, |index, cx| { + index.register_project(project, cx); + }); } } } @@ -122,7 +124,6 @@ impl WorktreeState { fn paths_changed( &mut self, changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, - changed_at: Instant, worktree: &Worktree, ) { let changed_paths = match self { @@ -140,7 +141,6 @@ impl WorktreeState { changed_paths.insert( path.clone(), ChangedPathInfo { - changed_at, mtime: entry.mtime, is_deleted: *change == PathChange::Removed, }, @@ -160,7 +160,6 @@ struct RegisteredWorktreeState { } struct ChangedPathInfo { - changed_at: Instant, mtime: SystemTime, is_deleted: bool, } @@ -409,43 +408,47 @@ impl SemanticIndex { return; }; - let change_time = Instant::now(); let worktree = worktree.read(cx); let worktree_state = if let Some(worktree_state) = project_state.worktree(worktree_id) { worktree_state } else { return; }; - worktree_state.paths_changed(changes, Instant::now(), worktree); + worktree_state.paths_changed(changes, worktree); if let WorktreeState::Registered(_) = worktree_state { cx.spawn_weak(|this, mut cx| async move { cx.background().timer(BACKGROUND_INDEXING_DELAY).await; if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { - this.update(&mut cx, |this, cx| { - this.reindex_changed_paths(project, Some(change_time), cx) - }) - .await; + this.update(&mut cx, |this, cx| this.index_project(project, cx)); } }) .detach(); } } - pub fn register_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { - log::trace!("Registering Project for Semantic Index"); + fn register_project( + &mut self, + project: ModelHandle, + cx: &mut ModelContext, + ) -> &mut ProjectState { + if !self.projects.contains_key(&project.downgrade()) { + log::trace!("Registering Project for Semantic Index"); - let subscription = cx.subscribe(&project, |this, project, event, cx| match event { - project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { - this.project_worktrees_changed(project.clone(), cx); - } - project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { - this.project_entries_changed(project, *worktree_id, changes.clone(), cx); - } - _ => {} - }); - self.projects - .insert(project.downgrade(), ProjectState::new(subscription)); - self.project_worktrees_changed(project, cx); + let subscription = cx.subscribe(&project, |this, project, event, cx| match event { + project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { + this.project_worktrees_changed(project.clone(), cx); + } + project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { + this.project_entries_changed(project, *worktree_id, changes.clone(), cx); + } + _ => {} + }); + self.projects + .insert(project.downgrade(), ProjectState::new(subscription)); + self.project_worktrees_changed(project.clone(), cx); + } + + self.projects.get_mut(&project.downgrade()).unwrap() } fn register_worktree( @@ -487,7 +490,6 @@ impl SemanticIndex { .background() .spawn(async move { let mut changed_paths = BTreeMap::new(); - let now = Instant::now(); for file in worktree.files(false, 0) { let absolute_path = worktree.absolutize(&file.path); @@ -518,7 +520,6 @@ impl SemanticIndex { changed_paths.insert( file.path.clone(), ChangedPathInfo { - changed_at: now, mtime: file.mtime, is_deleted: false, }, @@ -532,7 +533,6 @@ impl SemanticIndex { changed_paths.insert( path.into(), ChangedPathInfo { - changed_at: now, mtime, is_deleted: true, }, @@ -614,29 +614,7 @@ impl SemanticIndex { } } - pub fn index_project( - &mut self, - project: ModelHandle, - cx: &mut ModelContext, - ) -> Task)>> { - let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade()) - { - project_state - } else { - return Task::ready(Err(anyhow!("project was not registered"))); - }; - let outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone(); - cx.spawn(|this, mut cx| async move { - this.update(&mut cx, |this, cx| { - this.reindex_changed_paths(project.clone(), None, cx) - }) - .await; - let count = *outstanding_job_count_rx.borrow(); - Ok((count, outstanding_job_count_rx)) - }) - } - - pub fn outstanding_job_count_rx( + pub fn pending_file_count( &self, project: &ModelHandle, ) -> Option> { @@ -783,18 +761,8 @@ impl SemanticIndex { }) } - fn reindex_changed_paths( - &mut self, - project: ModelHandle, - last_changed_before: Option, - cx: &mut ModelContext, - ) -> Task<()> { - let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade()) - { - project_state - } else { - return Task::ready(()); - }; + pub fn index_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { + let project_state = self.register_project(project.clone(), cx); let mut pending_files = Vec::new(); let mut files_to_delete = Vec::new(); @@ -816,12 +784,6 @@ impl SemanticIndex { }; worktree_state.changed_paths.retain(|path, info| { - if let Some(last_changed_before) = last_changed_before { - if info.changed_at > last_changed_before { - return true; - } - } - if info.is_deleted { files_to_delete.push((worktree_state.db_id, path.clone())); } else { @@ -845,48 +807,50 @@ impl SemanticIndex { let db = self.db.clone(); let language_registry = self.language_registry.clone(); let parsing_files_tx = self.parsing_files_tx.clone(); - cx.background().spawn(async move { - for (worktree_db_id, path) in files_to_delete { - db.delete_file(worktree_db_id, path).await.log_err(); - } - - let embeddings_for_digest = { - let mut files = HashMap::default(); - for pending_file in &pending_files { - files - .entry(pending_file.worktree_db_id) - .or_insert(Vec::new()) - .push(pending_file.relative_path.clone()); + cx.background() + .spawn(async move { + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); } - Arc::new( - db.embeddings_for_files(files) - .await - .log_err() - .unwrap_or_default(), - ) - }; - for mut pending_file in pending_files { - if let Ok(language) = language_registry - .language_for_file(&pending_file.relative_path, None) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; + let embeddings_for_digest = { + let mut files = HashMap::default(); + for pending_file in &pending_files { + files + .entry(pending_file.worktree_db_id) + .or_insert(Vec::new()) + .push(pending_file.relative_path.clone()); } - pending_file.language = Some(language); + Arc::new( + db.embeddings_for_files(files) + .await + .log_err() + .unwrap_or_default(), + ) + }; + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + pending_file.language = Some(language); + } + parsing_files_tx + .try_send((embeddings_for_digest.clone(), pending_file)) + .ok(); } - parsing_files_tx - .try_send((embeddings_for_digest.clone(), pending_file)) - .ok(); - } - }) + }) + .detach() } } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 2f28184f20..008a9e0434 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -87,16 +87,18 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - semantic_index.update(cx, |store, cx| store.register_project(project.clone(), cx)); + semantic_index.update(cx, |store, cx| { + store.register_project(project.clone(), cx); + }); deterministic.run_until_parked(); - let (file_count, outstanding_file_count) = semantic_index - .update(cx, |store, cx| store.index_project(project.clone(), cx)) - .await - .unwrap(); - assert_eq!(file_count, 3); + let pending_file_count = + semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap()); + semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); + deterministic.run_until_parked(); + assert_eq!(*pending_file_count.borrow(), 3); deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); - assert_eq!(*outstanding_file_count.borrow(), 0); + assert_eq!(*pending_file_count.borrow(), 0); let search_results = semantic_index .update(cx, |store, cx| { @@ -188,14 +190,11 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); let prev_embedding_count = embedding_provider.embedding_count(); - let (file_count, outstanding_file_count) = semantic_index - .update(cx, |store, cx| store.index_project(project.clone(), cx)) - .await - .unwrap(); - assert_eq!(file_count, 1); - + semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); + deterministic.run_until_parked(); + assert_eq!(*pending_file_count.borrow(), 1); deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); - assert_eq!(*outstanding_file_count.borrow(), 0); + assert_eq!(*pending_file_count.borrow(), 0); assert_eq!( embedding_provider.embedding_count() - prev_embedding_count, From 95b72a73ade5f6579be0e5c03becc2ab2f3c592e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 5 Sep 2023 17:17:58 +0200 Subject: [PATCH 39/60] Re-index project when a worktree is registered Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/semantic_index.rs | 69 +++++++------------ .../src/semantic_index_tests.rs | 7 +- 2 files changed, 24 insertions(+), 52 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 2de78ab7e3..6b94874c43 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -34,7 +34,6 @@ use util::{ paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt, }; -use workspace::WorkspaceCreated; const SEMANTIC_INDEX_VERSION: usize = 9; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); @@ -57,24 +56,6 @@ pub fn init( return; } - cx.subscribe_global::({ - move |event, cx| { - let Some(semantic_index) = SemanticIndex::global(cx) else { - return; - }; - let workspace = &event.0; - if let Some(workspace) = workspace.upgrade(cx) { - let project = workspace.read(cx).project().clone(); - if project.read(cx).is_local() { - semantic_index.update(cx, |index, cx| { - index.register_project(project, cx); - }); - } - } - } - }) - .detach(); - cx.spawn(move |mut cx| async move { let semantic_index = SemanticIndex::new( fs, @@ -426,31 +407,6 @@ impl SemanticIndex { } } - fn register_project( - &mut self, - project: ModelHandle, - cx: &mut ModelContext, - ) -> &mut ProjectState { - if !self.projects.contains_key(&project.downgrade()) { - log::trace!("Registering Project for Semantic Index"); - - let subscription = cx.subscribe(&project, |this, project, event, cx| match event { - project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { - this.project_worktrees_changed(project.clone(), cx); - } - project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { - this.project_entries_changed(project, *worktree_id, changes.clone(), cx); - } - _ => {} - }); - self.projects - .insert(project.downgrade(), ProjectState::new(subscription)); - self.project_worktrees_changed(project.clone(), cx); - } - - self.projects.get_mut(&project.downgrade()).unwrap() - } - fn register_worktree( &mut self, project: ModelHandle, @@ -542,11 +498,14 @@ impl SemanticIndex { anyhow::Ok(changed_paths) }) .await?; - this.update(&mut cx, |this, _| { + this.update(&mut cx, |this, cx| { let project_state = this .projects .get_mut(&project) .ok_or_else(|| anyhow!("project not registered"))?; + let project = project + .upgrade(cx) + .ok_or_else(|| anyhow!("project was dropped"))?; if let Some(WorktreeState::Registering(state)) = project_state.worktrees.remove(&worktree_id) @@ -560,6 +519,7 @@ impl SemanticIndex { changed_paths, }), ); + this.index_project(project, cx); anyhow::Ok(()) })?; @@ -762,7 +722,24 @@ impl SemanticIndex { } pub fn index_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { - let project_state = self.register_project(project.clone(), cx); + if !self.projects.contains_key(&project.downgrade()) { + log::trace!("Registering Project for Semantic Index"); + + let subscription = cx.subscribe(&project, |this, project, event, cx| match event { + project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { + this.project_worktrees_changed(project.clone(), cx); + } + project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { + this.project_entries_changed(project, *worktree_id, changes.clone(), cx); + } + _ => {} + }); + self.projects + .insert(project.downgrade(), ProjectState::new(subscription)); + self.project_worktrees_changed(project.clone(), cx); + } + + let project_state = self.projects.get_mut(&project.downgrade()).unwrap(); let mut pending_files = Vec::new(); let mut files_to_delete = Vec::new(); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 008a9e0434..ca5f7df533 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -87,14 +87,9 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - semantic_index.update(cx, |store, cx| { - store.register_project(project.clone(), cx); - }); - deterministic.run_until_parked(); - + semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); let pending_file_count = semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap()); - semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); deterministic.run_until_parked(); assert_eq!(*pending_file_count.borrow(), 3); deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); From ec5ff20b4ca31561c49590f5f61cc65ee5551588 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 5 Sep 2023 11:34:24 -0700 Subject: [PATCH 40/60] Implement clearing stale channel buffer participants on server restart Co-authored-by: Mikayla --- crates/channel/src/channel_store.rs | 4 + crates/collab/src/db.rs | 1 + crates/collab/src/db/queries/buffers.rs | 26 ++++- crates/collab/src/db/queries/rooms.rs | 2 +- crates/collab/src/db/queries/servers.rs | 1 + crates/collab/src/rpc.rs | 9 +- .../collab/src/tests/channel_buffer_tests.rs | 96 ++++++++++++++++++- 7 files changed, 133 insertions(+), 6 deletions(-) diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index ec1652581d..3d2f61d61f 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -500,6 +500,10 @@ impl ChannelStore { } } + if buffer_versions.is_empty() { + return Task::ready(Ok(())); + } + let response = self.client.request(proto::RejoinChannelBuffers { buffers: buffer_versions, }); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 823990eaf8..b5d968ddf3 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -435,6 +435,7 @@ pub struct ChannelsForUser { pub channels_with_admin_privileges: HashSet, } +#[derive(Debug)] pub struct RejoinedChannelBuffer { pub buffer: proto::RejoinedChannelBuffer, pub old_connection_id: ConnectionId, diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index 813255b80e..8236eb9c3b 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -118,6 +118,7 @@ impl Database { // connection, then the client's buffer can be syncronized with // the server's buffer. if buffer.epoch as u64 != client_buffer.epoch { + log::info!("can't rejoin buffer, epoch has changed"); continue; } @@ -128,6 +129,7 @@ impl Database { c.user_id == user_id && (c.connection_lost || c.connection_server_id != server_id) }) else { + log::info!("can't rejoin buffer, no previous collaborator found"); continue; }; let old_connection_id = self_collaborator.connection(); @@ -196,16 +198,36 @@ impl Database { .await } - pub async fn refresh_channel_buffer( + pub async fn clear_stale_channel_buffer_collaborators( &self, channel_id: ChannelId, server_id: ServerId, ) -> Result { self.transaction(|tx| async move { + let collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + let mut connection_ids = Vec::new(); let mut removed_collaborators = Vec::new(); + let mut collaborator_ids_to_remove = Vec::new(); + for collaborator in &collaborators { + if !collaborator.connection_lost && collaborator.connection_server_id == server_id { + connection_ids.push(collaborator.connection()); + } else { + removed_collaborators.push(proto::RemoveChannelBufferCollaborator { + channel_id: channel_id.to_proto(), + peer_id: Some(collaborator.connection().into()), + }); + collaborator_ids_to_remove.push(collaborator.id); + } + } - // TODO + channel_buffer_collaborator::Entity::delete_many() + .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove)) + .exec(&*tx) + .await?; Ok(RefreshedChannelBuffer { connection_ids, diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index 435e729fed..e348b50bee 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -1,7 +1,7 @@ use super::*; impl Database { - pub async fn refresh_room( + pub async fn clear_stale_room_participants( &self, room_id: RoomId, new_server_id: ServerId, diff --git a/crates/collab/src/db/queries/servers.rs b/crates/collab/src/db/queries/servers.rs index 2b1d0d2c0c..e5ceee8887 100644 --- a/crates/collab/src/db/queries/servers.rs +++ b/crates/collab/src/db/queries/servers.rs @@ -55,6 +55,7 @@ impl Database { .into_values::<_, QueryChannelIds>() .all(&*tx) .await?; + Ok((room_ids, channel_ids)) }) .await diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 95307ba725..e454fcbb9e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -285,11 +285,15 @@ impl Server { .trace_err() { tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms"); + tracing::info!( + stale_channel_buffer_count = channel_ids.len(), + "retrieved stale channel buffers" + ); for channel_id in channel_ids { if let Some(refreshed_channel_buffer) = app_state .db - .refresh_channel_buffer(channel_id, server_id) + .clear_stale_channel_buffer_collaborators(channel_id, server_id) .await .trace_err() { @@ -309,7 +313,7 @@ impl Server { if let Some(mut refreshed_room) = app_state .db - .refresh_room(room_id, server_id) + .clear_stale_room_participants(room_id, server_id) .await .trace_err() { @@ -873,6 +877,7 @@ async fn connection_lost( futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { + log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id); leave_room_for_session(&session).await.trace_err(); leave_channel_buffers_for_session(&session) .await diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 236771c2a5..fe286895b4 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -1,4 +1,7 @@ -use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; +use crate::{ + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::TestServer, +}; use call::ActiveCall; use channel::Channel; use client::UserId; @@ -472,6 +475,97 @@ async fn test_rejoin_channel_buffer( }); } +#[gpui::test] +async fn test_channel_buffers_and_server_restarts( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, + cx_c: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; + + let channel_id = server + .make_channel( + "the-channel", + (&client_a, cx_a), + &mut [(&client_b, cx_b), (&client_c, cx_c)], + ) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let _channel_buffer_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client C can't reconnect. + client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending())); + + // Server stops. + server.reset().await; + deterministic.advance_clock(RECEIVE_TIMEOUT); + + // While the server is down, both clients make an edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + + // Server restarts. + server.start().await.unwrap(); + deterministic.advance_clock(CLEANUP_TIMEOUT); + + // Clients reconnects. Clients A and B see each other's edits, and see + // that client C has disconnected. + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + + channel_buffer_a.read_with(cx_a, |buffer_a, _| { + channel_buffer_b.read_with(cx_b, |buffer_b, _| { + assert_eq!( + buffer_a + .collaborators() + .iter() + .map(|c| c.user_id) + .collect::>(), + vec![client_a.user_id().unwrap(), client_b.user_id().unwrap()] + ); + assert_eq!(buffer_a.collaborators(), buffer_b.collaborators()); + }); + }); +} + #[track_caller] fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option]) { assert_eq!( From 653d4976cd4a03e1043b3c4a50453a1ed5e27aeb Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Tue, 5 Sep 2023 17:13:09 -0400 Subject: [PATCH 41/60] Add operation for opening channel notes in channel based calls --- crates/call/src/call.rs | 2 +- crates/collab_ui/src/collab_panel.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 5af094df05..5886462ccf 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -403,7 +403,7 @@ impl ActiveCall { &self.pending_invites } - fn report_call_event(&self, operation: &'static str, cx: &AppContext) { + pub fn report_call_event(&self, operation: &'static str, cx: &AppContext) { if let Some(room) = self.room() { let room = room.read(cx); Self::report_call_event_for_room( diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index daaa483975..d27cdc8acf 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2249,6 +2249,9 @@ impl CollabPanel { anyhow::Ok(()) }) .detach(); + ActiveCall::global(cx).update(cx, |call, cx| { + call.report_call_event("open channel notes", cx) + }); } } From c802680084d8cadf7aa2a8ac096e5dbd452abd31 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 6 Sep 2023 09:41:51 +0200 Subject: [PATCH 42/60] Clip ranges returned by `SemanticIndex::search` The files may have changed since the last time they were parsed, so the ranges returned by `SemanticIndex::search` may be out of bounds. --- crates/semantic_index/src/semantic_index.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6b94874c43..7ce4c9c2e4 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -15,7 +15,7 @@ use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; use futures::{FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; -use language::{Anchor, Buffer, Language, LanguageRegistry}; +use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; @@ -713,7 +713,9 @@ impl SemanticIndex { .filter_map(|(buffer, range)| { let buffer = buffer.log_err()?; let range = buffer.read_with(&cx, |buffer, _| { - buffer.anchor_before(range.start)..buffer.anchor_after(range.end) + let start = buffer.clip_offset(range.start, Bias::Left); + let end = buffer.clip_offset(range.end, Bias::Right); + buffer.anchor_before(start)..buffer.anchor_after(end) }); Some(SearchResult { buffer, range }) }) From de0f53b39f5e41ef22cd90635de96dcb4e086259 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 6 Sep 2023 11:40:59 +0200 Subject: [PATCH 43/60] Ensure `SemanticIndex::search` waits for indexing to complete --- crates/search/src/project_search.rs | 4 +- crates/semantic_index/src/semantic_index.rs | 492 +++++++++++------- .../src/semantic_index_tests.rs | 29 +- 3 files changed, 307 insertions(+), 218 deletions(-) diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index f1a0ff71d3..9bebd448a7 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -635,7 +635,9 @@ impl ProjectSearchView { let project = self.model.read(cx).project.clone(); let mut pending_file_count_rx = semantic_index.update(cx, |semantic_index, cx| { - semantic_index.index_project(project.clone(), cx); + semantic_index + .index_project(project.clone(), cx) + .detach_and_log_err(cx); semantic_index.pending_file_count(&project).unwrap() }); diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 7ce4c9c2e4..a098152784 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -13,7 +13,7 @@ use collections::{BTreeMap, HashMap, HashSet}; use db::VectorDatabase; use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use embedding_queue::{EmbeddingQueue, FileToEmbed}; -use futures::{FutureExt, StreamExt}; +use futures::{future, FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; @@ -23,6 +23,7 @@ use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Work use smol::channel; use std::{ cmp::Ordering, + future::Future, ops::Range, path::{Path, PathBuf}, sync::{Arc, Weak}, @@ -32,7 +33,7 @@ use util::{ channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}, http::HttpClient, paths::EMBEDDINGS_DIR, - ResultExt, TryFutureExt, + ResultExt, }; const SEMANTIC_INDEX_VERSION: usize = 9; @@ -132,7 +133,21 @@ impl WorktreeState { struct RegisteringWorktreeState { changed_paths: BTreeMap, ChangedPathInfo>, - _registration: Task>, + done_rx: watch::Receiver>, + _registration: Task<()>, +} + +impl RegisteringWorktreeState { + fn done(&self) -> impl Future { + let mut done_rx = self.done_rx.clone(); + async move { + while let Some(result) = done_rx.next().await { + if result.is_some() { + break; + } + } + } + } } struct RegisteredWorktreeState { @@ -173,13 +188,6 @@ impl ProjectState { } } - fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option { - match self.worktrees.get(&id)? { - WorktreeState::Registering(_) => None, - WorktreeState::Registered(state) => Some(state.db_id), - } - } - fn worktree_id_for_db_id(&self, id: i64) -> Option { self.worktrees .iter() @@ -188,10 +196,6 @@ impl ProjectState { _ => None, }) } - - fn worktree(&mut self, id: WorktreeId) -> Option<&mut WorktreeState> { - self.worktrees.get_mut(&id) - } } #[derive(Clone)] @@ -390,17 +394,20 @@ impl SemanticIndex { }; let worktree = worktree.read(cx); - let worktree_state = if let Some(worktree_state) = project_state.worktree(worktree_id) { - worktree_state - } else { - return; - }; + let worktree_state = + if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) { + worktree_state + } else { + return; + }; worktree_state.paths_changed(changes, worktree); if let WorktreeState::Registered(_) = worktree_state { cx.spawn_weak(|this, mut cx| async move { cx.background().timer(BACKGROUND_INDEXING_DELAY).await; if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) { - this.update(&mut cx, |this, cx| this.index_project(project, cx)); + this.update(&mut cx, |this, cx| { + this.index_project(project, cx).detach_and_log_err(cx) + }); } }) .detach(); @@ -429,109 +436,126 @@ impl SemanticIndex { let worktree_id = worktree.id(); let db = self.db.clone(); let language_registry = self.language_registry.clone(); + let (mut done_tx, done_rx) = watch::channel(); let registration = cx.spawn(|this, mut cx| { async move { - scan_complete.await; - let db_id = db.find_or_create_worktree(worktree_abs_path).await?; - let mut file_mtimes = db.get_file_mtimes(db_id).await?; - let worktree = if let Some(project) = project.upgrade(&cx) { - project - .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx)) - .ok_or_else(|| anyhow!("worktree not found"))? - } else { - return anyhow::Ok(()); - }; - let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot()); - let mut changed_paths = cx - .background() - .spawn(async move { - let mut changed_paths = BTreeMap::new(); - for file in worktree.files(false, 0) { - let absolute_path = worktree.absolutize(&file.path); + let register = async { + scan_complete.await; + let db_id = db.find_or_create_worktree(worktree_abs_path).await?; + let mut file_mtimes = db.get_file_mtimes(db_id).await?; + let worktree = if let Some(project) = project.upgrade(&cx) { + project + .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx)) + .ok_or_else(|| anyhow!("worktree not found"))? + } else { + return anyhow::Ok(()); + }; + let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot()); + let mut changed_paths = cx + .background() + .spawn(async move { + let mut changed_paths = BTreeMap::new(); + for file in worktree.files(false, 0) { + let absolute_path = worktree.absolutize(&file.path); - if file.is_external || file.is_ignored || file.is_symlink { - continue; - } - - if let Ok(language) = language_registry - .language_for_file(&absolute_path, None) - .await - { - // Test if file is valid parseable file - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { + if file.is_external || file.is_ignored || file.is_symlink { continue; } - let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); - let already_stored = stored_mtime - .map_or(false, |existing_mtime| existing_mtime == file.mtime); + if let Ok(language) = language_registry + .language_for_file(&absolute_path, None) + .await + { + // Test if file is valid parseable file + if !PARSEABLE_ENTIRE_FILE_TYPES + .contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } - if !already_stored { - changed_paths.insert( - file.path.clone(), - ChangedPathInfo { - mtime: file.mtime, - is_deleted: false, - }, - ); + let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); + let already_stored = stored_mtime + .map_or(false, |existing_mtime| { + existing_mtime == file.mtime + }); + + if !already_stored { + changed_paths.insert( + file.path.clone(), + ChangedPathInfo { + mtime: file.mtime, + is_deleted: false, + }, + ); + } } } + + // Clean up entries from database that are no longer in the worktree. + for (path, mtime) in file_mtimes { + changed_paths.insert( + path.into(), + ChangedPathInfo { + mtime, + is_deleted: true, + }, + ); + } + + anyhow::Ok(changed_paths) + }) + .await?; + this.update(&mut cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project) + .ok_or_else(|| anyhow!("project not registered"))?; + let project = project + .upgrade(cx) + .ok_or_else(|| anyhow!("project was dropped"))?; + + if let Some(WorktreeState::Registering(state)) = + project_state.worktrees.remove(&worktree_id) + { + changed_paths.extend(state.changed_paths); } + project_state.worktrees.insert( + worktree_id, + WorktreeState::Registered(RegisteredWorktreeState { + db_id, + changed_paths, + }), + ); + this.index_project(project, cx).detach_and_log_err(cx); - // Clean up entries from database that are no longer in the worktree. - for (path, mtime) in file_mtimes { - changed_paths.insert( - path.into(), - ChangedPathInfo { - mtime, - is_deleted: true, - }, - ); - } - - anyhow::Ok(changed_paths) - }) - .await?; - this.update(&mut cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project) - .ok_or_else(|| anyhow!("project not registered"))?; - let project = project - .upgrade(cx) - .ok_or_else(|| anyhow!("project was dropped"))?; - - if let Some(WorktreeState::Registering(state)) = - project_state.worktrees.remove(&worktree_id) - { - changed_paths.extend(state.changed_paths); - } - project_state.worktrees.insert( - worktree_id, - WorktreeState::Registered(RegisteredWorktreeState { - db_id, - changed_paths, - }), - ); - this.index_project(project, cx); + anyhow::Ok(()) + })?; anyhow::Ok(()) - })?; + }; - anyhow::Ok(()) + if register.await.log_err().is_none() { + // Stop tracking this worktree if the registration failed. + this.update(&mut cx, |this, _| { + this.projects.get_mut(&project).map(|project_state| { + project_state.worktrees.remove(&worktree_id); + }); + }) + } + + *done_tx.borrow_mut() = Some(()); } - .log_err() }); project_state.worktrees.insert( worktree_id, WorktreeState::Registering(RegisteringWorktreeState { changed_paths: Default::default(), + done_rx, _registration: registration, }), ); @@ -567,7 +591,7 @@ impl SemanticIndex { // Register new worktrees worktrees.retain(|worktree| { let worktree_id = worktree.read(cx).id(); - project_state.worktree(worktree_id).is_none() + !project_state.worktrees.contains_key(&worktree_id) }); for worktree in worktrees { self.register_worktree(project.clone(), worktree, cx); @@ -595,25 +619,13 @@ impl SemanticIndex { excludes: Vec, cx: &mut ModelContext, ) -> Task>> { - let project_state = if let Some(state) = self.projects.get(&project.downgrade()) { - state - } else { - return Task::ready(Err(anyhow!("project not added"))); - }; - - let worktree_db_ids = project - .read(cx) - .worktrees(cx) - .filter_map(|worktree| { - let worktree_id = worktree.read(cx).id(); - project_state.db_id_for_worktree_id(worktree_id) - }) - .collect::>(); - + let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); let db_path = self.db.path().clone(); let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { + index.await?; + let t0 = Instant::now(); let database = VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?; @@ -630,6 +642,24 @@ impl SemanticIndex { t0.elapsed().as_millis() ); + let worktree_db_ids = this.read_with(&cx, |this, _| { + let project_state = this + .projects + .get(&project.downgrade()) + .ok_or_else(|| anyhow!("project was not indexed"))?; + let worktree_db_ids = project_state + .worktrees + .values() + .filter_map(|worktree| { + if let WorktreeState::Registered(worktree) = worktree { + Some(worktree.db_id) + } else { + None + } + }) + .collect::>(); + anyhow::Ok(worktree_db_ids) + })?; let file_ids = database .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) .await?; @@ -723,7 +753,11 @@ impl SemanticIndex { }) } - pub fn index_project(&mut self, project: ModelHandle, cx: &mut ModelContext) { + pub fn index_project( + &mut self, + project: ModelHandle, + cx: &mut ModelContext, + ) -> Task> { if !self.projects.contains_key(&project.downgrade()) { log::trace!("Registering Project for Semantic Index"); @@ -740,96 +774,152 @@ impl SemanticIndex { .insert(project.downgrade(), ProjectState::new(subscription)); self.project_worktrees_changed(project.clone(), cx); } - - let project_state = self.projects.get_mut(&project.downgrade()).unwrap(); - - let mut pending_files = Vec::new(); - let mut files_to_delete = Vec::new(); - let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; - project_state - .worktrees - .retain(|worktree_id, worktree_state| { - let worktree = - if let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx) { - worktree - } else { - return false; - }; - let worktree_state = - if let WorktreeState::Registered(worktree_state) = worktree_state { - worktree_state - } else { - return true; - }; - - worktree_state.changed_paths.retain(|path, info| { - if info.is_deleted { - files_to_delete.push((worktree_state.db_id, path.clone())); - } else { - let absolute_path = worktree.read(cx).absolutize(path); - let job_handle = JobHandle::new(&outstanding_job_count_tx); - pending_files.push(PendingFile { - absolute_path, - relative_path: path.clone(), - language: None, - job_handle, - modified_time: info.mtime, - worktree_db_id: worktree_state.db_id, - }); - } - - false - }); - true - }); + let project_state = self.projects.get(&project.downgrade()).unwrap(); + let mut outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone(); let db = self.db.clone(); let language_registry = self.language_registry.clone(); let parsing_files_tx = self.parsing_files_tx.clone(); - cx.background() - .spawn(async move { - for (worktree_db_id, path) in files_to_delete { - db.delete_file(worktree_db_id, path).await.log_err(); - } + let worktree_registration = self.wait_for_worktree_registration(&project, cx); - let embeddings_for_digest = { - let mut files = HashMap::default(); - for pending_file in &pending_files { - files - .entry(pending_file.worktree_db_id) - .or_insert(Vec::new()) - .push(pending_file.relative_path.clone()); - } - Arc::new( - db.embeddings_for_files(files) - .await - .log_err() - .unwrap_or_default(), - ) - }; + cx.spawn(|this, mut cx| async move { + worktree_registration.await?; - for mut pending_file in pending_files { - if let Ok(language) = language_registry - .language_for_file(&pending_file.relative_path, None) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() + let mut pending_files = Vec::new(); + let mut files_to_delete = Vec::new(); + this.update(&mut cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project.downgrade()) + .ok_or_else(|| anyhow!("project was dropped"))?; + let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; + + project_state + .worktrees + .retain(|worktree_id, worktree_state| { + let worktree = if let Some(worktree) = + project.read(cx).worktree_for_id(*worktree_id, cx) { - continue; - } - pending_file.language = Some(language); + worktree + } else { + return false; + }; + let worktree_state = + if let WorktreeState::Registered(worktree_state) = worktree_state { + worktree_state + } else { + return true; + }; + + worktree_state.changed_paths.retain(|path, info| { + if info.is_deleted { + files_to_delete.push((worktree_state.db_id, path.clone())); + } else { + let absolute_path = worktree.read(cx).absolutize(path); + let job_handle = JobHandle::new(outstanding_job_count_tx); + pending_files.push(PendingFile { + absolute_path, + relative_path: path.clone(), + language: None, + job_handle, + modified_time: info.mtime, + worktree_db_id: worktree_state.db_id, + }); + } + + false + }); + true + }); + + anyhow::Ok(()) + })?; + + cx.background() + .spawn(async move { + for (worktree_db_id, path) in files_to_delete { + db.delete_file(worktree_db_id, path).await.log_err(); } - parsing_files_tx - .try_send((embeddings_for_digest.clone(), pending_file)) - .ok(); + + let embeddings_for_digest = { + let mut files = HashMap::default(); + for pending_file in &pending_files { + files + .entry(pending_file.worktree_db_id) + .or_insert(Vec::new()) + .push(pending_file.relative_path.clone()); + } + Arc::new( + db.embeddings_for_files(files) + .await + .log_err() + .unwrap_or_default(), + ) + }; + + for mut pending_file in pending_files { + if let Ok(language) = language_registry + .language_for_file(&pending_file.relative_path, None) + .await + { + if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) + && &language.name().as_ref() != &"Markdown" + && language + .grammar() + .and_then(|grammar| grammar.embedding_config.as_ref()) + .is_none() + { + continue; + } + pending_file.language = Some(language); + } + parsing_files_tx + .try_send((embeddings_for_digest.clone(), pending_file)) + .ok(); + } + + // Wait until we're done indexing. + while let Some(count) = outstanding_job_count_rx.next().await { + if count == 0 { + break; + } + } + }) + .await; + + Ok(()) + }) + } + + fn wait_for_worktree_registration( + &self, + project: &ModelHandle, + cx: &mut ModelContext, + ) -> Task> { + let project = project.downgrade(); + cx.spawn_weak(|this, cx| async move { + loop { + let mut pending_worktrees = Vec::new(); + this.upgrade(&cx) + .ok_or_else(|| anyhow!("semantic index dropped"))? + .read_with(&cx, |this, _| { + if let Some(project) = this.projects.get(&project) { + for worktree in project.worktrees.values() { + if let WorktreeState::Registering(worktree) = worktree { + pending_worktrees.push(worktree.done()); + } + } + } + }); + + if pending_worktrees.is_empty() { + break; + } else { + future::join_all(pending_worktrees).await; } - }) - .detach() + } + Ok(()) + }) } } diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index ca5f7df533..fe1b6b9cf9 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -87,7 +87,16 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); + let search_results = semantic_index.update(cx, |store, cx| { + store.search_project( + project.clone(), + "aaaaaabbbbzz".to_string(), + 5, + vec![], + vec![], + cx, + ) + }); let pending_file_count = semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap()); deterministic.run_until_parked(); @@ -95,20 +104,7 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*pending_file_count.borrow(), 0); - let search_results = semantic_index - .update(cx, |store, cx| { - store.search_project( - project.clone(), - "aaaaaabbbbzz".to_string(), - 5, - vec![], - vec![], - cx, - ) - }) - .await - .unwrap(); - + let search_results = search_results.await.unwrap(); assert_search_results( &search_results, &[ @@ -185,11 +181,12 @@ async fn test_semantic_index(deterministic: Arc, cx: &mut TestApp deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); let prev_embedding_count = embedding_provider.embedding_count(); - semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); + let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); deterministic.run_until_parked(); assert_eq!(*pending_file_count.borrow(), 1); deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); assert_eq!(*pending_file_count.borrow(), 0); + index.await.unwrap(); assert_eq!( embedding_provider.embedding_count() - prev_embedding_count, From ce62173534cff576776a92e154149153b183936e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 6 Sep 2023 16:48:53 +0200 Subject: [PATCH 44/60] Rename `Document` to `Span` --- crates/semantic_index/src/db.rs | 57 ++++++++-------- crates/semantic_index/src/embedding_queue.rs | 54 +++++++-------- crates/semantic_index/src/parsing.rs | 66 +++++++++---------- crates/semantic_index/src/semantic_index.rs | 32 ++++----- .../src/semantic_index_tests.rs | 12 ++-- 5 files changed, 109 insertions(+), 112 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 5664210388..28bbd56156 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -1,6 +1,6 @@ use crate::{ embedding::Embedding, - parsing::{Document, DocumentDigest}, + parsing::{Span, SpanDigest}, SEMANTIC_INDEX_VERSION, }; use anyhow::{anyhow, Context, Result}; @@ -124,8 +124,8 @@ impl VectorDatabase { } log::trace!("vector database schema out of date. updating..."); - db.execute("DROP TABLE IF EXISTS documents", []) - .context("failed to drop 'documents' table")?; + db.execute("DROP TABLE IF EXISTS spans", []) + .context("failed to drop 'spans' table")?; db.execute("DROP TABLE IF EXISTS files", []) .context("failed to drop 'files' table")?; db.execute("DROP TABLE IF EXISTS worktrees", []) @@ -174,7 +174,7 @@ impl VectorDatabase { )?; db.execute( - "CREATE TABLE documents ( + "CREATE TABLE spans ( id INTEGER PRIMARY KEY AUTOINCREMENT, file_id INTEGER NOT NULL, start_byte INTEGER NOT NULL, @@ -211,7 +211,7 @@ impl VectorDatabase { worktree_id: i64, path: Arc, mtime: SystemTime, - documents: Vec, + spans: Vec, ) -> impl Future> { self.transact(move |db| { // Return the existing ID, if both the file and mtime match @@ -231,7 +231,7 @@ impl VectorDatabase { let t0 = Instant::now(); let mut query = db.prepare( " - INSERT INTO documents + INSERT INTO spans (file_id, start_byte, end_byte, name, embedding, digest) VALUES (?1, ?2, ?3, ?4, ?5, ?6) ", @@ -241,14 +241,14 @@ impl VectorDatabase { t0.elapsed().as_millis() ); - for document in documents { + for span in spans { query.execute(params![ file_id, - document.range.start.to_string(), - document.range.end.to_string(), - document.name, - document.embedding, - document.digest + span.range.start.to_string(), + span.range.end.to_string(), + span.name, + span.embedding, + span.digest ])?; } @@ -278,13 +278,13 @@ impl VectorDatabase { pub fn embeddings_for_files( &self, worktree_id_file_paths: HashMap>>, - ) -> impl Future>> { + ) -> impl Future>> { self.transact(move |db| { let mut query = db.prepare( " SELECT digest, embedding - FROM documents - LEFT JOIN files ON files.id = documents.file_id + FROM spans + LEFT JOIN files ON files.id = spans.file_id WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) ", )?; @@ -297,10 +297,7 @@ impl VectorDatabase { .collect::>(), ); let rows = query.query_map(params![worktree_id, file_paths], |row| { - Ok(( - row.get::<_, DocumentDigest>(0)?, - row.get::<_, Embedding>(1)?, - )) + Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) })?; for row in rows { @@ -379,7 +376,7 @@ impl VectorDatabase { let file_ids = file_ids.to_vec(); self.transact(move |db| { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - Self::for_each_document(db, &file_ids, |id, embedding| { + Self::for_each_span(db, &file_ids, |id, embedding| { let similarity = embedding.similarity(&query_embedding); let ix = match results.binary_search_by(|(_, s)| { similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) @@ -434,7 +431,7 @@ impl VectorDatabase { }) } - fn for_each_document( + fn for_each_span( db: &rusqlite::Connection, file_ids: &[i64], mut f: impl FnMut(i64, Embedding), @@ -444,7 +441,7 @@ impl VectorDatabase { SELECT id, embedding FROM - documents + spans WHERE file_id IN rarray(?) ", @@ -459,7 +456,7 @@ impl VectorDatabase { Ok(()) } - pub fn get_documents_by_ids( + pub fn spans_for_ids( &self, ids: &[i64], ) -> impl Future)>>> { @@ -468,16 +465,16 @@ impl VectorDatabase { let mut statement = db.prepare( " SELECT - documents.id, + spans.id, files.worktree_id, files.relative_path, - documents.start_byte, - documents.end_byte + spans.start_byte, + spans.end_byte FROM - documents, files + spans, files WHERE - documents.file_id = files.id AND - documents.id in rarray(?) + spans.file_id = files.id AND + spans.id in rarray(?) ", )?; @@ -500,7 +497,7 @@ impl VectorDatabase { for id in &ids { let value = values_by_id .remove(id) - .ok_or(anyhow!("missing document id {}", id))?; + .ok_or(anyhow!("missing span id {}", id))?; results.push(value); } diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index f1abbde3a4..024881f0b8 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,4 +1,4 @@ -use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; +use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle}; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -9,7 +9,7 @@ pub struct FileToEmbed { pub worktree_id: i64, pub path: Arc, pub mtime: SystemTime, - pub documents: Vec, + pub spans: Vec, pub job_handle: JobHandle, } @@ -19,7 +19,7 @@ impl std::fmt::Debug for FileToEmbed { .field("worktree_id", &self.worktree_id) .field("path", &self.path) .field("mtime", &self.mtime) - .field("document", &self.documents) + .field("spans", &self.spans) .finish_non_exhaustive() } } @@ -29,13 +29,13 @@ impl PartialEq for FileToEmbed { self.worktree_id == other.worktree_id && self.path == other.path && self.mtime == other.mtime - && self.documents == other.documents + && self.spans == other.spans } } pub struct EmbeddingQueue { embedding_provider: Arc, - pending_batch: Vec, + pending_batch: Vec, executor: Arc, pending_batch_token_count: usize, finished_files_tx: channel::Sender, @@ -43,9 +43,9 @@ pub struct EmbeddingQueue { } #[derive(Clone)] -pub struct FileToEmbedFragment { +pub struct FileFragmentToEmbed { file: Arc>, - document_range: Range, + span_range: Range, } impl EmbeddingQueue { @@ -62,41 +62,41 @@ impl EmbeddingQueue { } pub fn push(&mut self, file: FileToEmbed) { - if file.documents.is_empty() { + if file.spans.is_empty() { self.finished_files_tx.try_send(file).unwrap(); return; } let file = Arc::new(Mutex::new(file)); - self.pending_batch.push(FileToEmbedFragment { + self.pending_batch.push(FileFragmentToEmbed { file: file.clone(), - document_range: 0..0, + span_range: 0..0, }); - let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; let mut saved_tokens = 0; - for (ix, document) in file.lock().documents.iter().enumerate() { - let document_token_count = if document.embedding.is_none() { - document.token_count + for (ix, span) in file.lock().spans.iter().enumerate() { + let span_token_count = if span.embedding.is_none() { + span.token_count } else { - saved_tokens += document.token_count; + saved_tokens += span.token_count; 0 }; - let next_token_count = self.pending_batch_token_count + document_token_count; + let next_token_count = self.pending_batch_token_count + span_token_count; if next_token_count > self.embedding_provider.max_tokens_per_batch() { let range_end = fragment_range.end; self.flush(); - self.pending_batch.push(FileToEmbedFragment { + self.pending_batch.push(FileFragmentToEmbed { file: file.clone(), - document_range: range_end..range_end, + span_range: range_end..range_end, }); - fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; } fragment_range.end = ix + 1; - self.pending_batch_token_count += document_token_count; + self.pending_batch_token_count += span_token_count; } log::trace!("Saved Tokens: {:?}", saved_tokens); } @@ -113,20 +113,20 @@ impl EmbeddingQueue { self.executor.spawn(async move { let mut spans = Vec::new(); - let mut document_count = 0; + let mut span_count = 0; for fragment in &batch { let file = fragment.file.lock(); - document_count += file.documents[fragment.document_range.clone()].len(); + span_count += file.spans[fragment.span_range.clone()].len(); spans.extend( { - file.documents[fragment.document_range.clone()] + file.spans[fragment.span_range.clone()] .iter().filter(|d| d.embedding.is_none()) .map(|d| d.content.clone()) } ); } - log::trace!("Documents Length: {:?}", document_count); + log::trace!("Documents Length: {:?}", span_count); log::trace!("Span Length: {:?}", spans.clone().len()); // If spans is 0, just send the fragment to the finished files if its the last one. @@ -143,11 +143,11 @@ impl EmbeddingQueue { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { - for document in - &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) + for span in + &mut fragment.file.lock().spans[fragment.span_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) { if let Some(embedding) = embeddings.next() { - document.embedding = Some(embedding); + span.embedding = Some(embedding); } else { // log::error!("number of embeddings returned different from number of documents"); diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index c0a94c6b73..b6fc000e1d 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -16,9 +16,9 @@ use std::{ use tree_sitter::{Parser, QueryCursor}; #[derive(Debug, PartialEq, Eq, Clone, Hash)] -pub struct DocumentDigest([u8; 20]); +pub struct SpanDigest([u8; 20]); -impl FromSql for DocumentDigest { +impl FromSql for SpanDigest { fn column_result(value: ValueRef) -> FromSqlResult { let blob = value.as_blob()?; let bytes = @@ -27,17 +27,17 @@ impl FromSql for DocumentDigest { expected_size: 20, blob_size: blob.len(), })?; - return Ok(DocumentDigest(bytes)); + return Ok(SpanDigest(bytes)); } } -impl ToSql for DocumentDigest { +impl ToSql for SpanDigest { fn to_sql(&self) -> rusqlite::Result { self.0.to_sql() } } -impl From<&'_ str> for DocumentDigest { +impl From<&'_ str> for SpanDigest { fn from(value: &'_ str) -> Self { let mut sha1 = Sha1::new(); sha1.update(value); @@ -46,12 +46,12 @@ impl From<&'_ str> for DocumentDigest { } #[derive(Debug, PartialEq, Clone)] -pub struct Document { +pub struct Span { pub name: String, pub range: Range, pub content: String, pub embedding: Option, - pub digest: DocumentDigest, + pub digest: SpanDigest, pub token_count: usize, } @@ -97,14 +97,14 @@ impl CodeContextRetriever { relative_path: &Path, language_name: Arc, content: &str, - ) -> Result> { + ) -> Result> { let document_span = ENTIRE_FILE_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) .replace("", &content); - let digest = DocumentDigest::from(document_span.as_str()); + let digest = SpanDigest::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { + Ok(vec![Span { range: 0..content.len(), content: document_span, embedding: Default::default(), @@ -114,13 +114,13 @@ impl CodeContextRetriever { }]) } - fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result> { + fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result> { let document_span = MARKDOWN_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", &content); - let digest = DocumentDigest::from(document_span.as_str()); + let digest = SpanDigest::from(document_span.as_str()); let (document_span, token_count) = self.embedding_provider.truncate(&document_span); - Ok(vec![Document { + Ok(vec![Span { range: 0..content.len(), content: document_span, embedding: None, @@ -191,32 +191,32 @@ impl CodeContextRetriever { relative_path: &Path, content: &str, language: Arc, - ) -> Result> { + ) -> Result> { let language_name = language.name(); if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) { return self.parse_entire_file(relative_path, language_name, &content); - } else if &language_name.to_string() == &"Markdown".to_string() { + } else if language_name.as_ref() == "Markdown" { return self.parse_markdown_file(relative_path, &content); } - let mut documents = self.parse_file(content, language)?; - for document in &mut documents { + let mut spans = self.parse_file(content, language)?; + for span in &mut spans { let document_content = CODE_CONTEXT_TEMPLATE .replace("", relative_path.to_string_lossy().as_ref()) .replace("", language_name.as_ref()) - .replace("item", &document.content); + .replace("item", &span.content); let (document_content, token_count) = self.embedding_provider.truncate(&document_content); - document.content = document_content; - document.token_count = token_count; + span.content = document_content; + span.token_count = token_count; } - Ok(documents) + Ok(spans) } - pub fn parse_file(&mut self, content: &str, language: Arc) -> Result> { + pub fn parse_file(&mut self, content: &str, language: Arc) -> Result> { let grammar = language .grammar() .ok_or_else(|| anyhow!("no grammar for language"))?; @@ -227,7 +227,7 @@ impl CodeContextRetriever { let language_scope = language.default_scope(); let placeholder = language_scope.collapsed_placeholder(); - let mut documents = Vec::new(); + let mut spans = Vec::new(); let mut collapsed_ranges_within = Vec::new(); let mut parsed_name_ranges = HashSet::new(); for (i, context_match) in matches.iter().enumerate() { @@ -267,22 +267,22 @@ impl CodeContextRetriever { collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end))); - let mut document_content = String::new(); + let mut span_content = String::new(); for context_range in &context_match.context_ranges { add_content_from_range( - &mut document_content, + &mut span_content, content, context_range.clone(), context_match.start_col, ); - document_content.push_str("\n"); + span_content.push_str("\n"); } let mut offset = item_range.start; for collapsed_range in &collapsed_ranges_within { if collapsed_range.start > offset { add_content_from_range( - &mut document_content, + &mut span_content, content, offset..collapsed_range.start, context_match.start_col, @@ -291,24 +291,24 @@ impl CodeContextRetriever { } if collapsed_range.end > offset { - document_content.push_str(placeholder); + span_content.push_str(placeholder); offset = collapsed_range.end; } } if offset < item_range.end { add_content_from_range( - &mut document_content, + &mut span_content, content, offset..item_range.end, context_match.start_col, ); } - let sha1 = DocumentDigest::from(document_content.as_str()); - documents.push(Document { + let sha1 = SpanDigest::from(span_content.as_str()); + spans.push(Span { name, - content: document_content, + content: span_content, range: item_range.clone(), embedding: None, digest: sha1, @@ -316,7 +316,7 @@ impl CodeContextRetriever { }) } - return Ok(documents); + return Ok(spans); } } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index a098152784..1c1c40fa27 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -17,7 +17,7 @@ use futures::{future, FutureExt, StreamExt}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; use parking_lot::Mutex; -use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES}; +use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use postage::watch; use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId}; use smol::channel; @@ -36,7 +36,7 @@ use util::{ ResultExt, }; -const SEMANTIC_INDEX_VERSION: usize = 9; +const SEMANTIC_INDEX_VERSION: usize = 10; const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); @@ -84,7 +84,7 @@ pub struct SemanticIndex { db: VectorDatabase, embedding_provider: Arc, language_registry: Arc, - parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, + parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, @@ -252,16 +252,16 @@ impl SemanticIndex { let db = db.clone(); async move { while let Ok(file) = embedded_files.recv().await { - db.insert_file(file.worktree_id, file.path, file.mtime, file.documents) + db.insert_file(file.worktree_id, file.path, file.mtime, file.spans) .await .log_err(); } } }); - // Parse files into embeddable documents. + // Parse files into embeddable spans. let (parsing_files_tx, parsing_files_rx) = - channel::unbounded::<(Arc>, PendingFile)>(); + channel::unbounded::<(Arc>, PendingFile)>(); let embedding_queue = Arc::new(Mutex::new(embedding_queue)); let mut _parsing_files_tasks = Vec::new(); for _ in 0..cx.background().num_cpus() { @@ -320,26 +320,26 @@ impl SemanticIndex { pending_file: PendingFile, retriever: &mut CodeContextRetriever, embedding_queue: &Arc>, - embeddings_for_digest: &HashMap, + embeddings_for_digest: &HashMap, ) { let Some(language) = pending_file.language else { return; }; if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { - if let Some(mut documents) = retriever + if let Some(mut spans) = retriever .parse_file_with_template(&pending_file.relative_path, &content, language) .log_err() { log::trace!( - "parsed path {:?}: {} documents", + "parsed path {:?}: {} spans", pending_file.relative_path, - documents.len() + spans.len() ); - for document in documents.iter_mut() { - if let Some(embedding) = embeddings_for_digest.get(&document.digest) { - document.embedding = Some(embedding.to_owned()); + for span in &mut spans { + if let Some(embedding) = embeddings_for_digest.get(&span.digest) { + span.embedding = Some(embedding.to_owned()); } } @@ -348,7 +348,7 @@ impl SemanticIndex { path: pending_file.relative_path, mtime: pending_file.modified_time, job_handle: pending_file.job_handle, - documents, + spans: spans, }); } } @@ -708,13 +708,13 @@ impl SemanticIndex { } let ids = results.into_iter().map(|(id, _)| id).collect::>(); - let documents = database.get_documents_by_ids(ids.as_slice()).await?; + let spans = database.spans_for_ids(ids.as_slice()).await?; let mut tasks = Vec::new(); let mut ranges = Vec::new(); let weak_project = project.downgrade(); project.update(&mut cx, |project, cx| { - for (worktree_db_id, file_path, byte_range) in documents { + for (worktree_db_id, file_path, byte_range) in spans { let project_state = if let Some(state) = this.read(cx).projects.get(&weak_project) { state diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index fe1b6b9cf9..ffd8db8781 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,7 +1,7 @@ use crate::{ embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}, embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest}, + parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest}, semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; @@ -204,15 +204,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { worktree_id: 5, path: Path::new(&format!("path-{file_ix}")).into(), mtime: SystemTime::now(), - documents: (0..rng.gen_range(4..22)) + spans: (0..rng.gen_range(4..22)) .map(|document_ix| { let content_len = rng.gen_range(10..100); let content = RandomCharIter::new(&mut rng) .with_simple_text() .take(content_len) .collect::(); - let digest = DocumentDigest::from(content.as_str()); - Document { + let digest = SpanDigest::from(content.as_str()); + Span { range: 0..10, embedding: None, name: format!("document {document_ix}"), @@ -245,7 +245,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { .iter() .map(|file| { let mut file = file.clone(); - for doc in &mut file.documents { + for doc in &mut file.spans { doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref())); } file @@ -437,7 +437,7 @@ async fn test_code_context_retrieval_json() { } fn assert_documents_eq( - documents: &[Document], + documents: &[Span], expected_contents_and_start_offsets: &[(String, usize)], ) { assert_eq!( From 8d672f5d4cd532217209e6826728afc641e7c6d9 Mon Sep 17 00:00:00 2001 From: Julia Date: Wed, 9 Aug 2023 21:32:41 -0400 Subject: [PATCH 45/60] Remove NodeRuntime static & add fake implementation for tests --- Cargo.lock | 1 + crates/copilot/src/copilot.rs | 13 +- .../LiveKitBridge/Package.resolved | 4 +- crates/node_runtime/Cargo.toml | 1 + crates/node_runtime/src/node_runtime.rs | 159 +++++++++++------- crates/zed/src/languages.rs | 2 +- crates/zed/src/languages/css.rs | 12 +- crates/zed/src/languages/html.rs | 12 +- crates/zed/src/languages/json.rs | 12 +- crates/zed/src/languages/php.rs | 12 +- crates/zed/src/languages/python.rs | 12 +- crates/zed/src/languages/svelte.rs | 12 +- crates/zed/src/languages/tailwind.rs | 12 +- crates/zed/src/languages/typescript.rs | 22 +-- crates/zed/src/languages/yaml.rs | 15 +- crates/zed/src/main.rs | 4 +- crates/zed/src/zed.rs | 5 +- 17 files changed, 179 insertions(+), 131 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8786c8ed6e..05cd0ec21c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4582,6 +4582,7 @@ dependencies = [ "anyhow", "async-compression", "async-tar", + "async-trait", "futures 0.3.28", "gpui", "log", diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 499ae2e808..28c20d95bb 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -41,7 +41,7 @@ actions!( [Suggest, NextSuggestion, PreviousSuggestion, Reinstall] ); -pub fn init(http: Arc, node_runtime: Arc, cx: &mut AppContext) { +pub fn init(http: Arc, node_runtime: Arc, cx: &mut AppContext) { let copilot = cx.add_model({ let node_runtime = node_runtime.clone(); move |cx| Copilot::start(http, node_runtime, cx) @@ -265,7 +265,7 @@ pub struct Completion { pub struct Copilot { http: Arc, - node_runtime: Arc, + node_runtime: Arc, server: CopilotServer, buffers: HashSet>, } @@ -299,7 +299,7 @@ impl Copilot { fn start( http: Arc, - node_runtime: Arc, + node_runtime: Arc, cx: &mut ModelContext, ) -> Self { let mut this = Self { @@ -335,12 +335,15 @@ impl Copilot { #[cfg(any(test, feature = "test-support"))] pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle, lsp::FakeLanguageServer) { + use node_runtime::FakeNodeRuntime; + let (server, fake_server) = LanguageServer::fake("copilot".into(), Default::default(), cx.to_async()); let http = util::http::FakeHttpClient::create(|_| async { unreachable!() }); + let node_runtime = FakeNodeRuntime::new(); let this = cx.add_model(|_| Self { http: http.clone(), - node_runtime: NodeRuntime::instance(http), + node_runtime, server: CopilotServer::Running(RunningCopilotServer { lsp: Arc::new(server), sign_in_status: SignInStatus::Authorized, @@ -353,7 +356,7 @@ impl Copilot { fn start_language_server( http: Arc, - node_runtime: Arc, + node_runtime: Arc, this: ModelHandle, mut cx: AsyncAppContext, ) -> impl Future { diff --git a/crates/live_kit_client/LiveKitBridge/Package.resolved b/crates/live_kit_client/LiveKitBridge/Package.resolved index 85ae088565..b925bc8f0d 100644 --- a/crates/live_kit_client/LiveKitBridge/Package.resolved +++ b/crates/live_kit_client/LiveKitBridge/Package.resolved @@ -42,8 +42,8 @@ "repositoryURL": "https://github.com/apple/swift-protobuf.git", "state": { "branch": null, - "revision": "0af9125c4eae12a4973fb66574c53a54962a9e1e", - "version": "1.21.0" + "revision": "ce20dc083ee485524b802669890291c0d8090170", + "version": "1.22.1" } } ] diff --git a/crates/node_runtime/Cargo.toml b/crates/node_runtime/Cargo.toml index 53635f2725..2b9503468a 100644 --- a/crates/node_runtime/Cargo.toml +++ b/crates/node_runtime/Cargo.toml @@ -14,6 +14,7 @@ util = { path = "../util" } async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] } async-tar = "0.4.2" futures.workspace = true +async-trait.workspace = true anyhow.workspace = true parking_lot.workspace = true serde.workspace = true diff --git a/crates/node_runtime/src/node_runtime.rs b/crates/node_runtime/src/node_runtime.rs index d43c14ec7b..820a8b6f81 100644 --- a/crates/node_runtime/src/node_runtime.rs +++ b/crates/node_runtime/src/node_runtime.rs @@ -7,14 +7,12 @@ use std::process::{Output, Stdio}; use std::{ env::consts, path::{Path, PathBuf}, - sync::{Arc, OnceLock}, + sync::Arc, }; use util::http::HttpClient; const VERSION: &str = "v18.15.0"; -static RUNTIME_INSTANCE: OnceLock> = OnceLock::new(); - #[derive(Debug, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct NpmInfo { @@ -28,23 +26,88 @@ pub struct NpmInfoDistTags { latest: Option, } -pub struct NodeRuntime { +#[async_trait::async_trait] +pub trait NodeRuntime: Send + Sync { + async fn binary_path(&self) -> Result; + + async fn run_npm_subcommand( + &self, + directory: Option<&Path>, + subcommand: &str, + args: &[&str], + ) -> Result; + + async fn npm_package_latest_version(&self, name: &str) -> Result; + + async fn npm_install_packages(&self, directory: &Path, packages: &[(&str, &str)]) + -> Result<()>; +} + +pub struct RealNodeRuntime { http: Arc, } -impl NodeRuntime { - pub fn instance(http: Arc) -> Arc { - RUNTIME_INSTANCE - .get_or_init(|| Arc::new(NodeRuntime { http })) - .clone() +impl RealNodeRuntime { + pub fn new(http: Arc) -> Arc { + Arc::new(RealNodeRuntime { http }) } - pub async fn binary_path(&self) -> Result { + async fn install_if_needed(&self) -> Result { + log::info!("Node runtime install_if_needed"); + + let arch = match consts::ARCH { + "x86_64" => "x64", + "aarch64" => "arm64", + other => bail!("Running on unsupported platform: {other}"), + }; + + let folder_name = format!("node-{VERSION}-darwin-{arch}"); + let node_containing_dir = util::paths::SUPPORT_DIR.join("node"); + let node_dir = node_containing_dir.join(folder_name); + let node_binary = node_dir.join("bin/node"); + let npm_file = node_dir.join("bin/npm"); + + let result = Command::new(&node_binary) + .arg(npm_file) + .arg("--version") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .await; + let valid = matches!(result, Ok(status) if status.success()); + + if !valid { + _ = fs::remove_dir_all(&node_containing_dir).await; + fs::create_dir(&node_containing_dir) + .await + .context("error creating node containing dir")?; + + let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz"); + let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}"); + let mut response = self + .http + .get(&url, Default::default(), true) + .await + .context("error downloading Node binary tarball")?; + + let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); + let archive = Archive::new(decompressed_bytes); + archive.unpack(&node_containing_dir).await?; + } + + anyhow::Ok(node_dir) + } +} + +#[async_trait::async_trait] +impl NodeRuntime for RealNodeRuntime { + async fn binary_path(&self) -> Result { let installation_path = self.install_if_needed().await?; Ok(installation_path.join("bin/node")) } - pub async fn run_npm_subcommand( + async fn run_npm_subcommand( &self, directory: Option<&Path>, subcommand: &str, @@ -106,7 +169,7 @@ impl NodeRuntime { output.map_err(|e| anyhow!("{e}")) } - pub async fn npm_package_latest_version(&self, name: &str) -> Result { + async fn npm_package_latest_version(&self, name: &str) -> Result { let output = self .run_npm_subcommand( None, @@ -131,10 +194,10 @@ impl NodeRuntime { .ok_or_else(|| anyhow!("no version found for npm package {}", name)) } - pub async fn npm_install_packages( + async fn npm_install_packages( &self, directory: &Path, - packages: impl IntoIterator, + packages: &[(&str, &str)], ) -> Result<()> { let packages: Vec<_> = packages .into_iter() @@ -155,51 +218,31 @@ impl NodeRuntime { .await?; Ok(()) } +} - async fn install_if_needed(&self) -> Result { - log::info!("Node runtime install_if_needed"); +pub struct FakeNodeRuntime; - let arch = match consts::ARCH { - "x86_64" => "x64", - "aarch64" => "arm64", - other => bail!("Running on unsupported platform: {other}"), - }; - - let folder_name = format!("node-{VERSION}-darwin-{arch}"); - let node_containing_dir = util::paths::SUPPORT_DIR.join("node"); - let node_dir = node_containing_dir.join(folder_name); - let node_binary = node_dir.join("bin/node"); - let npm_file = node_dir.join("bin/npm"); - - let result = Command::new(&node_binary) - .arg(npm_file) - .arg("--version") - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status() - .await; - let valid = matches!(result, Ok(status) if status.success()); - - if !valid { - _ = fs::remove_dir_all(&node_containing_dir).await; - fs::create_dir(&node_containing_dir) - .await - .context("error creating node containing dir")?; - - let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz"); - let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}"); - let mut response = self - .http - .get(&url, Default::default(), true) - .await - .context("error downloading Node binary tarball")?; - - let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); - let archive = Archive::new(decompressed_bytes); - archive.unpack(&node_containing_dir).await?; - } - - anyhow::Ok(node_dir) +impl FakeNodeRuntime { + pub fn new() -> Arc { + Arc::new(FakeNodeRuntime) + } +} + +#[async_trait::async_trait] +impl NodeRuntime for FakeNodeRuntime { + async fn binary_path(&self) -> Result { + unreachable!() + } + + async fn run_npm_subcommand(&self, _: Option<&Path>, _: &str, _: &[&str]) -> Result { + unreachable!() + } + + async fn npm_package_latest_version(&self, _: &str) -> Result { + unreachable!() + } + + async fn npm_install_packages(&self, _: &Path, _: &[(&str, &str)]) -> Result<()> { + unreachable!() } } diff --git a/crates/zed/src/languages.rs b/crates/zed/src/languages.rs index f0b8a1444a..3fbb5aa14f 100644 --- a/crates/zed/src/languages.rs +++ b/crates/zed/src/languages.rs @@ -37,7 +37,7 @@ mod yaml; #[exclude = "*.rs"] struct LanguageDir; -pub fn init(languages: Arc, node_runtime: Arc) { +pub fn init(languages: Arc, node_runtime: Arc) { let language = |name, grammar, adapters| { languages.register(name, load_config(name), grammar, adapters, load_queries) }; diff --git a/crates/zed/src/languages/css.rs b/crates/zed/src/languages/css.rs index f2103050f3..fdbc179209 100644 --- a/crates/zed/src/languages/css.rs +++ b/crates/zed/src/languages/css.rs @@ -22,11 +22,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct CssLspAdapter { - node: Arc, + node: Arc, } impl CssLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { CssLspAdapter { node } } } @@ -65,7 +65,7 @@ impl LspAdapter for CssLspAdapter { self.node .npm_install_packages( &container_dir, - [("vscode-langservers-extracted", version.as_str())], + &[("vscode-langservers-extracted", version.as_str())], ) .await?; } @@ -81,14 +81,14 @@ impl LspAdapter for CssLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -100,7 +100,7 @@ impl LspAdapter for CssLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/html.rs b/crates/zed/src/languages/html.rs index cfb6a5dde9..b8f1c70cce 100644 --- a/crates/zed/src/languages/html.rs +++ b/crates/zed/src/languages/html.rs @@ -22,11 +22,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct HtmlLspAdapter { - node: Arc, + node: Arc, } impl HtmlLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { HtmlLspAdapter { node } } } @@ -65,7 +65,7 @@ impl LspAdapter for HtmlLspAdapter { self.node .npm_install_packages( &container_dir, - [("vscode-langservers-extracted", version.as_str())], + &[("vscode-langservers-extracted", version.as_str())], ) .await?; } @@ -81,14 +81,14 @@ impl LspAdapter for HtmlLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -100,7 +100,7 @@ impl LspAdapter for HtmlLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/json.rs b/crates/zed/src/languages/json.rs index 049549ac5d..63f909ae2a 100644 --- a/crates/zed/src/languages/json.rs +++ b/crates/zed/src/languages/json.rs @@ -27,12 +27,12 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct JsonLspAdapter { - node: Arc, + node: Arc, languages: Arc, } impl JsonLspAdapter { - pub fn new(node: Arc, languages: Arc) -> Self { + pub fn new(node: Arc, languages: Arc) -> Self { JsonLspAdapter { node, languages } } } @@ -71,7 +71,7 @@ impl LspAdapter for JsonLspAdapter { self.node .npm_install_packages( &container_dir, - [("vscode-json-languageserver", version.as_str())], + &[("vscode-json-languageserver", version.as_str())], ) .await?; } @@ -87,14 +87,14 @@ impl LspAdapter for JsonLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -148,7 +148,7 @@ impl LspAdapter for JsonLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/php.rs b/crates/zed/src/languages/php.rs index 73bb4b019c..3096fd16e6 100644 --- a/crates/zed/src/languages/php.rs +++ b/crates/zed/src/languages/php.rs @@ -23,14 +23,14 @@ fn intelephense_server_binary_arguments(server_path: &Path) -> Vec { pub struct IntelephenseVersion(String); pub struct IntelephenseLspAdapter { - node: Arc, + node: Arc, } impl IntelephenseLspAdapter { const SERVER_PATH: &'static str = "node_modules/intelephense/lib/intelephense.js"; #[allow(unused)] - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { Self { node } } } @@ -65,7 +65,7 @@ impl LspAdapter for IntelephenseLspAdapter { if fs::metadata(&server_path).await.is_err() { self.node - .npm_install_packages(&container_dir, [("intelephense", version.0.as_str())]) + .npm_install_packages(&container_dir, &[("intelephense", version.0.as_str())]) .await?; } Ok(LanguageServerBinary { @@ -79,14 +79,14 @@ impl LspAdapter for IntelephenseLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn label_for_completion( @@ -107,7 +107,7 @@ impl LspAdapter for IntelephenseLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/python.rs b/crates/zed/src/languages/python.rs index 956cf49551..c1539e9590 100644 --- a/crates/zed/src/languages/python.rs +++ b/crates/zed/src/languages/python.rs @@ -20,11 +20,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct PythonLspAdapter { - node: Arc, + node: Arc, } impl PythonLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { PythonLspAdapter { node } } } @@ -57,7 +57,7 @@ impl LspAdapter for PythonLspAdapter { if fs::metadata(&server_path).await.is_err() { self.node - .npm_install_packages(&container_dir, [("pyright", version.as_str())]) + .npm_install_packages(&container_dir, &[("pyright", version.as_str())]) .await?; } @@ -72,14 +72,14 @@ impl LspAdapter for PythonLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn process_completion(&self, item: &mut lsp::CompletionItem) { @@ -162,7 +162,7 @@ impl LspAdapter for PythonLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/svelte.rs b/crates/zed/src/languages/svelte.rs index 35665e864f..5e42d80e77 100644 --- a/crates/zed/src/languages/svelte.rs +++ b/crates/zed/src/languages/svelte.rs @@ -21,11 +21,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct SvelteLspAdapter { - node: Arc, + node: Arc, } impl SvelteLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { SvelteLspAdapter { node } } } @@ -64,7 +64,7 @@ impl LspAdapter for SvelteLspAdapter { self.node .npm_install_packages( &container_dir, - [("svelte-language-server", version.as_str())], + &[("svelte-language-server", version.as_str())], ) .await?; } @@ -80,14 +80,14 @@ impl LspAdapter for SvelteLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -99,7 +99,7 @@ impl LspAdapter for SvelteLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/tailwind.rs b/crates/zed/src/languages/tailwind.rs index 12a0a4e3b8..cf07fa71c9 100644 --- a/crates/zed/src/languages/tailwind.rs +++ b/crates/zed/src/languages/tailwind.rs @@ -26,11 +26,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct TailwindLspAdapter { - node: Arc, + node: Arc, } impl TailwindLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { TailwindLspAdapter { node } } } @@ -69,7 +69,7 @@ impl LspAdapter for TailwindLspAdapter { self.node .npm_install_packages( &container_dir, - [("@tailwindcss/language-server", version.as_str())], + &[("@tailwindcss/language-server", version.as_str())], ) .await?; } @@ -85,14 +85,14 @@ impl LspAdapter for TailwindLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn initialization_options(&self) -> Option { @@ -131,7 +131,7 @@ impl LspAdapter for TailwindLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/languages/typescript.rs b/crates/zed/src/languages/typescript.rs index 27074e164b..676d0fd4c0 100644 --- a/crates/zed/src/languages/typescript.rs +++ b/crates/zed/src/languages/typescript.rs @@ -33,14 +33,14 @@ fn eslint_server_binary_arguments(server_path: &Path) -> Vec { } pub struct TypeScriptLspAdapter { - node: Arc, + node: Arc, } impl TypeScriptLspAdapter { const OLD_SERVER_PATH: &'static str = "node_modules/typescript-language-server/lib/cli.js"; const NEW_SERVER_PATH: &'static str = "node_modules/typescript-language-server/lib/cli.mjs"; - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { TypeScriptLspAdapter { node } } } @@ -86,7 +86,7 @@ impl LspAdapter for TypeScriptLspAdapter { self.node .npm_install_packages( &container_dir, - [ + &[ ("typescript", version.typescript_version.as_str()), ( "typescript-language-server", @@ -108,14 +108,14 @@ impl LspAdapter for TypeScriptLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_ts_server_binary(container_dir, &self.node).await + get_cached_ts_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_ts_server_binary(container_dir, &self.node).await + get_cached_ts_server_binary(container_dir, &*self.node).await } fn code_action_kinds(&self) -> Option> { @@ -165,7 +165,7 @@ impl LspAdapter for TypeScriptLspAdapter { async fn get_cached_ts_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let old_server_path = container_dir.join(TypeScriptLspAdapter::OLD_SERVER_PATH); @@ -192,14 +192,14 @@ async fn get_cached_ts_server_binary( } pub struct EsLintLspAdapter { - node: Arc, + node: Arc, } impl EsLintLspAdapter { const SERVER_PATH: &'static str = "vscode-eslint/server/out/eslintServer.js"; #[allow(unused)] - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { EsLintLspAdapter { node } } } @@ -288,14 +288,14 @@ impl LspAdapter for EsLintLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_eslint_server_binary(container_dir, &self.node).await + get_cached_eslint_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_eslint_server_binary(container_dir, &self.node).await + get_cached_eslint_server_binary(container_dir, &*self.node).await } async fn label_for_completion( @@ -313,7 +313,7 @@ impl LspAdapter for EsLintLspAdapter { async fn get_cached_eslint_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { // This is unfortunate but we don't know what the version is to build a path directly diff --git a/crates/zed/src/languages/yaml.rs b/crates/zed/src/languages/yaml.rs index 21155cc231..8b438d0949 100644 --- a/crates/zed/src/languages/yaml.rs +++ b/crates/zed/src/languages/yaml.rs @@ -25,11 +25,11 @@ fn server_binary_arguments(server_path: &Path) -> Vec { } pub struct YamlLspAdapter { - node: Arc, + node: Arc, } impl YamlLspAdapter { - pub fn new(node: Arc) -> Self { + pub fn new(node: Arc) -> Self { YamlLspAdapter { node } } } @@ -66,7 +66,10 @@ impl LspAdapter for YamlLspAdapter { if fs::metadata(&server_path).await.is_err() { self.node - .npm_install_packages(&container_dir, [("yaml-language-server", version.as_str())]) + .npm_install_packages( + &container_dir, + &[("yaml-language-server", version.as_str())], + ) .await?; } @@ -81,14 +84,14 @@ impl LspAdapter for YamlLspAdapter { container_dir: PathBuf, _: &dyn LspAdapterDelegate, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } async fn installation_test_binary( &self, container_dir: PathBuf, ) -> Option { - get_cached_server_binary(container_dir, &self.node).await + get_cached_server_binary(container_dir, &*self.node).await } fn workspace_configuration(&self, cx: &mut AppContext) -> BoxFuture<'static, Value> { let tab_size = all_language_settings(None, cx) @@ -109,7 +112,7 @@ impl LspAdapter for YamlLspAdapter { async fn get_cached_server_binary( container_dir: PathBuf, - node: &NodeRuntime, + node: &dyn NodeRuntime, ) -> Option { (|| async move { let mut last_version_dir = None; diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 3e0a8a7a07..f78a4f6419 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -19,7 +19,7 @@ use gpui::{Action, App, AppContext, AssetSource, AsyncAppContext, Task}; use isahc::{config::Configurable, Request}; use language::{LanguageRegistry, Point}; use log::LevelFilter; -use node_runtime::NodeRuntime; +use node_runtime::RealNodeRuntime; use parking_lot::Mutex; use project::Fs; use serde::{Deserialize, Serialize}; @@ -138,7 +138,7 @@ fn main() { languages.set_executor(cx.background().clone()); languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone()); let languages = Arc::new(languages); - let node_runtime = NodeRuntime::instance(http.clone()); + let node_runtime = RealNodeRuntime::new(http.clone()); languages::init(languages.clone(), node_runtime.clone()); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http.clone(), cx)); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index ba8fa840f5..424bce60f2 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -723,7 +723,6 @@ mod tests { AppContext, AssetSource, Element, Entity, TestAppContext, View, ViewHandle, }; use language::LanguageRegistry; - use node_runtime::NodeRuntime; use project::{Project, ProjectPath}; use serde_json::json; use settings::{handle_settings_file_changes, watch_config_file, SettingsStore}; @@ -732,7 +731,6 @@ mod tests { path::{Path, PathBuf}, }; use theme::{ThemeRegistry, ThemeSettings}; - use util::http::FakeHttpClient; use workspace::{ item::{Item, ItemHandle}, open_new, open_paths, pane, NewFile, SplitDirection, WorkspaceHandle, @@ -2364,8 +2362,7 @@ mod tests { let mut languages = LanguageRegistry::test(); languages.set_executor(cx.background().clone()); let languages = Arc::new(languages); - let http = FakeHttpClient::with_404_response(); - let node_runtime = NodeRuntime::instance(http); + let node_runtime = node_runtime::FakeNodeRuntime::new(); languages::init(languages.clone(), node_runtime); for name in languages.language_names() { languages.language_for_name(&name); From 29e35531af141472ca10824114e69d1969c0e026 Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Wed, 6 Sep 2023 12:52:23 -0400 Subject: [PATCH 46/60] Temporarily comment out cargo check commands --- script/bump-zed-minor-versions | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/script/bump-zed-minor-versions b/script/bump-zed-minor-versions index 8dcf7e334e..2f1ad7345e 100755 --- a/script/bump-zed-minor-versions +++ b/script/bump-zed-minor-versions @@ -31,7 +31,7 @@ preview_tag_name="v${major}.${minor}.${patch}-pre" git fetch origin ${prev_minor_branch_name}:${prev_minor_branch_name} git fetch origin --tags -cargo check -q +# cargo check -q function cleanup { git checkout -q main @@ -89,7 +89,7 @@ git checkout -q main git clean -q -dff old_main_sha=$(git rev-parse HEAD) cargo set-version --package zed --bump minor -cargo check -q +# cargo check -q git commit -q --all --message "${next_minor_branch_name} dev" cat < Date: Wed, 6 Sep 2023 12:52:41 -0400 Subject: [PATCH 47/60] v0.104.x dev --- Cargo.lock | 2 +- crates/zed/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 05cd0ec21c..eb8ba7675f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9761,7 +9761,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.103.0" +version = "0.104.0" dependencies = [ "activity_indicator", "ai", diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 66d55b38f0..e102a66519 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] description = "The fast, collaborative code editor." edition = "2021" name = "zed" -version = "0.103.0" +version = "0.104.0" publish = false [lib] From 5b5c232cd13f087698da51e5ecd0104a1a10ee9a Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Wed, 6 Sep 2023 12:54:53 -0400 Subject: [PATCH 48/60] Revert "Temporarily comment out cargo check commands" This reverts commit 29e35531af141472ca10824114e69d1969c0e026. --- script/bump-zed-minor-versions | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/script/bump-zed-minor-versions b/script/bump-zed-minor-versions index 2f1ad7345e..8dcf7e334e 100755 --- a/script/bump-zed-minor-versions +++ b/script/bump-zed-minor-versions @@ -31,7 +31,7 @@ preview_tag_name="v${major}.${minor}.${patch}-pre" git fetch origin ${prev_minor_branch_name}:${prev_minor_branch_name} git fetch origin --tags -# cargo check -q +cargo check -q function cleanup { git checkout -q main @@ -89,7 +89,7 @@ git checkout -q main git clean -q -dff old_main_sha=$(git rev-parse HEAD) cargo set-version --package zed --bump minor -# cargo check -q +cargo check -q git commit -q --all --message "${next_minor_branch_name} dev" cat < Date: Wed, 6 Sep 2023 13:33:39 -0400 Subject: [PATCH 49/60] collab 0.20.0 --- Cargo.lock | 2 +- crates/collab/Cargo.toml | 2 +- crates/live_kit_client/LiveKitBridge/Package.resolved | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb8ba7675f..0b11bce1fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1453,7 +1453,7 @@ dependencies = [ [[package]] name = "collab" -version = "0.19.0" +version = "0.20.0" dependencies = [ "anyhow", "async-tungstenite", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 914e3f2dfb..fbdfbd2fe3 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -3,7 +3,7 @@ authors = ["Nathan Sobo "] default-run = "collab" edition = "2021" name = "collab" -version = "0.19.0" +version = "0.20.0" publish = false [[bin]] diff --git a/crates/live_kit_client/LiveKitBridge/Package.resolved b/crates/live_kit_client/LiveKitBridge/Package.resolved index b925bc8f0d..85ae088565 100644 --- a/crates/live_kit_client/LiveKitBridge/Package.resolved +++ b/crates/live_kit_client/LiveKitBridge/Package.resolved @@ -42,8 +42,8 @@ "repositoryURL": "https://github.com/apple/swift-protobuf.git", "state": { "branch": null, - "revision": "ce20dc083ee485524b802669890291c0d8090170", - "version": "1.22.1" + "revision": "0af9125c4eae12a4973fb66574c53a54962a9e1e", + "version": "1.21.0" } } ] From 17237f748ce984a6285fe91ace515ba5830e4916 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 6 Sep 2023 15:09:15 -0400 Subject: [PATCH 50/60] update token_count for OpenAIEmbeddings to accomodate for truncation --- crates/semantic_index/src/embedding.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 97c25ca170..8140e244bd 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -181,18 +181,17 @@ impl EmbeddingProvider for OpenAIEmbeddings { fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let token_count = tokens.len(); - let output = if token_count > OPENAI_INPUT_LIMIT { + let output = if tokens.len() > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER - .decode(tokens) + .decode(tokens.clone()) .ok() .unwrap_or_else(|| span.to_string()) } else { span.to_string() }; - (output, token_count) + (output, tokens.len()) } async fn embed_batch(&self, spans: Vec) -> Result> { From 265d02a583b01dd7b5e829f77efbfd415bec603e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 6 Sep 2023 15:09:46 -0400 Subject: [PATCH 51/60] update request timeout for open ai embeddings --- crates/semantic_index/src/embedding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index 8140e244bd..7228738525 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -203,7 +203,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { .ok_or_else(|| anyhow!("no api key"))?; let mut request_number = 0; - let mut request_timeout: u64 = 10; + let mut request_timeout: u64 = 15; let mut response: Response; while request_number < MAX_RETRIES { response = self From 66c3879306cb00f319ed78604cf689466e4f8ed8 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 6 Sep 2023 09:27:57 -0700 Subject: [PATCH 52/60] Extract randomized test infrastructure for use in other tests --- Cargo.lock | 1 + crates/collab/Cargo.toml | 1 + crates/collab/src/tests.rs | 557 +---- .../src/tests/random_channel_buffer_tests.rs | 49 + .../random_project_collaboration_tests.rs | 1573 ++++++++++++ .../src/tests/randomized_integration_tests.rs | 2199 ----------------- .../src/tests/randomized_test_helpers.rs | 694 ++++++ crates/collab/src/tests/test_server.rs | 551 +++++ crates/gpui_macros/src/gpui_macros.rs | 10 +- 9 files changed, 2887 insertions(+), 2748 deletions(-) create mode 100644 crates/collab/src/tests/random_channel_buffer_tests.rs create mode 100644 crates/collab/src/tests/random_project_collaboration_tests.rs delete mode 100644 crates/collab/src/tests/randomized_integration_tests.rs create mode 100644 crates/collab/src/tests/randomized_test_helpers.rs create mode 100644 crates/collab/src/tests/test_server.rs diff --git a/Cargo.lock b/Cargo.lock index a185542c63..4f68c54433 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1456,6 +1456,7 @@ name = "collab" version = "0.19.0" dependencies = [ "anyhow", + "async-trait", "async-tungstenite", "audio", "axum", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 914e3f2dfb..0346558407 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -80,6 +80,7 @@ theme = { path = "../theme" } workspace = { path = "../workspace", features = ["test-support"] } collab_ui = { path = "../collab_ui", features = ["test-support"] } +async-trait.workspace = true ctor.workspace = true env_logger.workspace = true indoc.workspace = true diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 25f059c0aa..3000f0d8c3 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -1,555 +1,18 @@ -use crate::{ - db::{tests::TestDb, NewUserParams, UserId}, - executor::Executor, - rpc::{Server, CLEANUP_TIMEOUT}, - AppState, -}; -use anyhow::anyhow; -use call::{ActiveCall, Room}; -use channel::ChannelStore; -use client::{ - self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore, -}; -use collections::{HashMap, HashSet}; -use fs::FakeFs; -use futures::{channel::oneshot, StreamExt as _}; -use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; -use language::LanguageRegistry; -use parking_lot::Mutex; -use project::{Project, WorktreeId}; -use settings::SettingsStore; -use std::{ - cell::{Ref, RefCell, RefMut}, - env, - ops::{Deref, DerefMut}, - path::Path, - sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, - Arc, - }, -}; -use util::http::FakeHttpClient; -use workspace::Workspace; +use call::Room; +use gpui::{ModelHandle, TestAppContext}; mod channel_buffer_tests; mod channel_tests; mod integration_tests; -mod randomized_integration_tests; +mod random_channel_buffer_tests; +mod random_project_collaboration_tests; +mod randomized_test_helpers; +mod test_server; -struct TestServer { - app_state: Arc, - server: Arc, - connection_killers: Arc>>>, - forbid_connections: Arc, - _test_db: TestDb, - test_live_kit_server: Arc, -} - -impl TestServer { - async fn start(deterministic: &Arc) -> Self { - static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - - let use_postgres = env::var("USE_POSTGRES").ok(); - let use_postgres = use_postgres.as_deref(); - let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { - TestDb::postgres(deterministic.build_background()) - } else { - TestDb::sqlite(deterministic.build_background()) - }; - let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); - let live_kit_server = live_kit_client::TestServer::create( - format!("http://livekit.{}.test", live_kit_server_id), - format!("devkey-{}", live_kit_server_id), - format!("secret-{}", live_kit_server_id), - deterministic.build_background(), - ) - .unwrap(); - let app_state = Self::build_app_state(&test_db, &live_kit_server).await; - let epoch = app_state - .db - .create_server(&app_state.config.zed_environment) - .await - .unwrap(); - let server = Server::new( - epoch, - app_state.clone(), - Executor::Deterministic(deterministic.build_background()), - ); - server.start().await.unwrap(); - // Advance clock to ensure the server's cleanup task is finished. - deterministic.advance_clock(CLEANUP_TIMEOUT); - Self { - app_state, - server, - connection_killers: Default::default(), - forbid_connections: Default::default(), - _test_db: test_db, - test_live_kit_server: live_kit_server, - } - } - - async fn reset(&self) { - self.app_state.db.reset(); - let epoch = self - .app_state - .db - .create_server(&self.app_state.config.zed_environment) - .await - .unwrap(); - self.server.reset(epoch); - } - - async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { - cx.update(|cx| { - if cx.has_global::() { - panic!("Same cx used to create two test clients") - } - cx.set_global(SettingsStore::test(cx)); - }); - - let http = FakeHttpClient::with_404_response(); - let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await - { - user.id - } else { - self.app_state - .db - .create_user( - &format!("{name}@example.com"), - false, - NewUserParams { - github_login: name.into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .expect("creating user failed") - .user_id - }; - let client_name = name.to_string(); - let mut client = cx.read(|cx| Client::new(http.clone(), cx)); - let server = self.server.clone(); - let db = self.app_state.db.clone(); - let connection_killers = self.connection_killers.clone(); - let forbid_connections = self.forbid_connections.clone(); - - Arc::get_mut(&mut client) - .unwrap() - .set_id(user_id.0 as usize) - .override_authenticate(move |cx| { - cx.spawn(|_| async move { - let access_token = "the-token".to_string(); - Ok(Credentials { - user_id: user_id.0 as u64, - access_token, - }) - }) - }) - .override_establish_connection(move |credentials, cx| { - assert_eq!(credentials.user_id, user_id.0 as u64); - assert_eq!(credentials.access_token, "the-token"); - - let server = server.clone(); - let db = db.clone(); - let connection_killers = connection_killers.clone(); - let forbid_connections = forbid_connections.clone(); - let client_name = client_name.clone(); - cx.spawn(move |cx| async move { - if forbid_connections.load(SeqCst) { - Err(EstablishConnectionError::other(anyhow!( - "server is forbidding connections" - ))) - } else { - let (client_conn, server_conn, killed) = - Connection::in_memory(cx.background()); - let (connection_id_tx, connection_id_rx) = oneshot::channel(); - let user = db - .get_user_by_id(user_id) - .await - .expect("retrieving user failed") - .unwrap(); - cx.background() - .spawn(server.handle_connection( - server_conn, - client_name, - user, - Some(connection_id_tx), - Executor::Deterministic(cx.background()), - )) - .detach(); - let connection_id = connection_id_rx.await.unwrap(); - connection_killers - .lock() - .insert(connection_id.into(), killed); - Ok(client_conn) - } - }) - }); - - let fs = FakeFs::new(cx.background()); - let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); - let channel_store = - cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx)); - let app_state = Arc::new(workspace::AppState { - client: client.clone(), - user_store: user_store.clone(), - channel_store: channel_store.clone(), - languages: Arc::new(LanguageRegistry::test()), - fs: fs.clone(), - build_window_options: |_, _, _| Default::default(), - initialize_workspace: |_, _, _, _| Task::ready(Ok(())), - background_actions: || &[], - }); - - cx.update(|cx| { - theme::init((), cx); - Project::init(&client, cx); - client::init(&client, cx); - language::init(cx); - editor::init_settings(cx); - workspace::init(app_state.clone(), cx); - audio::init((), cx); - call::init(client.clone(), user_store.clone(), cx); - channel::init(&client); - }); - - client - .authenticate_and_connect(false, &cx.to_async()) - .await - .unwrap(); - - let client = TestClient { - app_state, - username: name.to_string(), - state: Default::default(), - }; - client.wait_for_current_user(cx).await; - client - } - - fn disconnect_client(&self, peer_id: PeerId) { - self.connection_killers - .lock() - .remove(&peer_id) - .unwrap() - .store(true, SeqCst); - } - - fn forbid_connections(&self) { - self.forbid_connections.store(true, SeqCst); - } - - fn allow_connections(&self) { - self.forbid_connections.store(false, SeqCst); - } - - async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { - for ix in 1..clients.len() { - let (left, right) = clients.split_at_mut(ix); - let (client_a, cx_a) = left.last_mut().unwrap(); - for (client_b, cx_b) in right { - client_a - .app_state - .user_store - .update(*cx_a, |store, cx| { - store.request_contact(client_b.user_id().unwrap(), cx) - }) - .await - .unwrap(); - cx_a.foreground().run_until_parked(); - client_b - .app_state - .user_store - .update(*cx_b, |store, cx| { - store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx) - }) - .await - .unwrap(); - } - } - } - - async fn make_channel( - &self, - channel: &str, - admin: (&TestClient, &mut TestAppContext), - members: &mut [(&TestClient, &mut TestAppContext)], - ) -> u64 { - let (admin_client, admin_cx) = admin; - let channel_id = admin_client - .app_state - .channel_store - .update(admin_cx, |channel_store, cx| { - channel_store.create_channel(channel, None, cx) - }) - .await - .unwrap(); - - for (member_client, member_cx) in members { - admin_client - .app_state - .channel_store - .update(admin_cx, |channel_store, cx| { - channel_store.invite_member( - channel_id, - member_client.user_id().unwrap(), - false, - cx, - ) - }) - .await - .unwrap(); - - admin_cx.foreground().run_until_parked(); - - member_client - .app_state - .channel_store - .update(*member_cx, |channels, _| { - channels.respond_to_channel_invite(channel_id, true) - }) - .await - .unwrap(); - } - - channel_id - } - - async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { - self.make_contacts(clients).await; - - let (left, right) = clients.split_at_mut(1); - let (_client_a, cx_a) = &mut left[0]; - let active_call_a = cx_a.read(ActiveCall::global); - - for (client_b, cx_b) in right { - let user_id_b = client_b.current_user_id(*cx_b).to_proto(); - active_call_a - .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx)) - .await - .unwrap(); - - cx_b.foreground().run_until_parked(); - let active_call_b = cx_b.read(ActiveCall::global); - active_call_b - .update(*cx_b, |call, cx| call.accept_incoming(cx)) - .await - .unwrap(); - } - } - - async fn build_app_state( - test_db: &TestDb, - fake_server: &live_kit_client::TestServer, - ) -> Arc { - Arc::new(AppState { - db: test_db.db().clone(), - live_kit_client: Some(Arc::new(fake_server.create_api_client())), - config: Default::default(), - }) - } -} - -impl Deref for TestServer { - type Target = Server; - - fn deref(&self) -> &Self::Target { - &self.server - } -} - -impl Drop for TestServer { - fn drop(&mut self) { - self.server.teardown(); - self.test_live_kit_server.teardown().unwrap(); - } -} - -struct TestClient { - username: String, - state: RefCell, - app_state: Arc, -} - -#[derive(Default)] -struct TestClientState { - local_projects: Vec>, - remote_projects: Vec>, - buffers: HashMap, HashSet>>, -} - -impl Deref for TestClient { - type Target = Arc; - - fn deref(&self) -> &Self::Target { - &self.app_state.client - } -} - -struct ContactsSummary { - pub current: Vec, - pub outgoing_requests: Vec, - pub incoming_requests: Vec, -} - -impl TestClient { - pub fn fs(&self) -> &FakeFs { - self.app_state.fs.as_fake() - } - - pub fn channel_store(&self) -> &ModelHandle { - &self.app_state.channel_store - } - - pub fn user_store(&self) -> &ModelHandle { - &self.app_state.user_store - } - - pub fn language_registry(&self) -> &Arc { - &self.app_state.languages - } - - pub fn client(&self) -> &Arc { - &self.app_state.client - } - - pub fn current_user_id(&self, cx: &TestAppContext) -> UserId { - UserId::from_proto( - self.app_state - .user_store - .read_with(cx, |user_store, _| user_store.current_user().unwrap().id), - ) - } - - async fn wait_for_current_user(&self, cx: &TestAppContext) { - let mut authed_user = self - .app_state - .user_store - .read_with(cx, |user_store, _| user_store.watch_current_user()); - while authed_user.next().await.unwrap().is_none() {} - } - - async fn clear_contacts(&self, cx: &mut TestAppContext) { - self.app_state - .user_store - .update(cx, |store, _| store.clear_contacts()) - .await; - } - - fn local_projects<'a>(&'a self) -> impl Deref>> + 'a { - Ref::map(self.state.borrow(), |state| &state.local_projects) - } - - fn remote_projects<'a>(&'a self) -> impl Deref>> + 'a { - Ref::map(self.state.borrow(), |state| &state.remote_projects) - } - - fn local_projects_mut<'a>(&'a self) -> impl DerefMut>> + 'a { - RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects) - } - - fn remote_projects_mut<'a>(&'a self) -> impl DerefMut>> + 'a { - RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects) - } - - fn buffers_for_project<'a>( - &'a self, - project: &ModelHandle, - ) -> impl DerefMut>> + 'a { - RefMut::map(self.state.borrow_mut(), |state| { - state.buffers.entry(project.clone()).or_default() - }) - } - - fn buffers<'a>( - &'a self, - ) -> impl DerefMut, HashSet>>> + 'a - { - RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers) - } - - fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary { - self.app_state - .user_store - .read_with(cx, |store, _| ContactsSummary { - current: store - .contacts() - .iter() - .map(|contact| contact.user.github_login.clone()) - .collect(), - outgoing_requests: store - .outgoing_contact_requests() - .iter() - .map(|user| user.github_login.clone()) - .collect(), - incoming_requests: store - .incoming_contact_requests() - .iter() - .map(|user| user.github_login.clone()) - .collect(), - }) - } - - async fn build_local_project( - &self, - root_path: impl AsRef, - cx: &mut TestAppContext, - ) -> (ModelHandle, WorktreeId) { - let project = cx.update(|cx| { - Project::local( - self.client().clone(), - self.app_state.user_store.clone(), - self.app_state.languages.clone(), - self.app_state.fs.clone(), - cx, - ) - }); - let (worktree, _) = project - .update(cx, |p, cx| { - p.find_or_create_local_worktree(root_path, true, cx) - }) - .await - .unwrap(); - worktree - .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) - .await; - (project, worktree.read_with(cx, |tree, _| tree.id())) - } - - async fn build_remote_project( - &self, - host_project_id: u64, - guest_cx: &mut TestAppContext, - ) -> ModelHandle { - let active_call = guest_cx.read(ActiveCall::global); - let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone()); - room.update(guest_cx, |room, cx| { - room.join_project( - host_project_id, - self.app_state.languages.clone(), - self.app_state.fs.clone(), - cx, - ) - }) - .await - .unwrap() - } - - fn build_workspace( - &self, - project: &ModelHandle, - cx: &mut TestAppContext, - ) -> WindowHandle { - cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx)) - } -} - -impl Drop for TestClient { - fn drop(&mut self) { - self.app_state.client.teardown(); - } -} +pub use randomized_test_helpers::{ + run_randomized_test, save_randomized_test_plan, RandomizedTest, TestError, UserTestPlan, +}; +pub use test_server::{TestClient, TestServer}; #[derive(Debug, Eq, PartialEq)] struct RoomParticipants { diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs b/crates/collab/src/tests/random_channel_buffer_tests.rs new file mode 100644 index 0000000000..929e567977 --- /dev/null +++ b/crates/collab/src/tests/random_channel_buffer_tests.rs @@ -0,0 +1,49 @@ +use crate::tests::{run_randomized_test, RandomizedTest, TestClient, TestError, UserTestPlan}; +use anyhow::Result; +use async_trait::async_trait; +use gpui::{executor::Deterministic, TestAppContext}; +use rand::rngs::StdRng; +use serde_derive::{Deserialize, Serialize}; +use std::{rc::Rc, sync::Arc}; + +#[gpui::test] +async fn test_random_channel_buffers( + cx: &mut TestAppContext, + deterministic: Arc, + rng: StdRng, +) { + run_randomized_test::(cx, deterministic, rng).await; +} + +struct RandomChannelBufferTest; + +#[derive(Clone, Serialize, Deserialize)] +enum ChannelBufferOperation { + Join, +} + +#[async_trait(?Send)] +impl RandomizedTest for RandomChannelBufferTest { + type Operation = ChannelBufferOperation; + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> ChannelBufferOperation { + ChannelBufferOperation::Join + } + + async fn apply_operation( + client: &TestClient, + operation: ChannelBufferOperation, + cx: &mut TestAppContext, + ) -> Result<(), TestError> { + Ok(()) + } + + async fn on_client_added(client: &Rc) {} + + fn on_clients_quiesced(clients: &[(Rc, TestAppContext)]) {} +} diff --git a/crates/collab/src/tests/random_project_collaboration_tests.rs b/crates/collab/src/tests/random_project_collaboration_tests.rs new file mode 100644 index 0000000000..242cfbc162 --- /dev/null +++ b/crates/collab/src/tests/random_project_collaboration_tests.rs @@ -0,0 +1,1573 @@ +use crate::{ + db::UserId, + tests::{run_randomized_test, RandomizedTest, TestClient, TestError, UserTestPlan}, +}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use call::ActiveCall; +use collections::{BTreeMap, HashMap}; +use editor::Bias; +use fs::{repository::GitFileStatus, FakeFs, Fs as _}; +use futures::StreamExt; +use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; +use language::{range_to_lsp, FakeLspAdapter, Language, LanguageConfig, PointUtf16}; +use lsp::FakeLanguageServer; +use pretty_assertions::assert_eq; +use project::{search::SearchQuery, Project, ProjectPath}; +use rand::{ + distributions::{Alphanumeric, DistString}, + prelude::*, +}; +use serde::{Deserialize, Serialize}; +use std::{ + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use util::ResultExt; + +#[gpui::test( + iterations = 100, + on_failure = "crate::tests::save_randomized_test_plan" +)] +async fn test_random_project_collaboration( + cx: &mut TestAppContext, + deterministic: Arc, + rng: StdRng, +) { + run_randomized_test::(cx, deterministic, rng).await; +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum ClientOperation { + AcceptIncomingCall, + RejectIncomingCall, + LeaveCall, + InviteContactToCall { + user_id: UserId, + }, + OpenLocalProject { + first_root_name: String, + }, + OpenRemoteProject { + host_id: UserId, + first_root_name: String, + }, + AddWorktreeToProject { + project_root_name: String, + new_root_path: PathBuf, + }, + CloseRemoteProject { + project_root_name: String, + }, + OpenBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + }, + SearchProject { + project_root_name: String, + is_local: bool, + query: String, + detach: bool, + }, + EditBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + edits: Vec<(Range, Arc)>, + }, + CloseBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + }, + SaveBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + detach: bool, + }, + RequestLspDataInBuffer { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + offset: usize, + kind: LspRequestKind, + detach: bool, + }, + CreateWorktreeEntry { + project_root_name: String, + is_local: bool, + full_path: PathBuf, + is_dir: bool, + }, + WriteFsEntry { + path: PathBuf, + is_dir: bool, + content: String, + }, + GitOperation { + operation: GitOperation, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum GitOperation { + WriteGitIndex { + repo_path: PathBuf, + contents: Vec<(PathBuf, String)>, + }, + WriteGitBranch { + repo_path: PathBuf, + new_branch: Option, + }, + WriteGitStatuses { + repo_path: PathBuf, + statuses: Vec<(PathBuf, GitFileStatus)>, + git_operation: bool, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum LspRequestKind { + Rename, + Completion, + CodeAction, + Definition, + Highlights, +} + +struct ProjectCollaborationTest; + +#[async_trait(?Send)] +impl RandomizedTest for ProjectCollaborationTest { + type Operation = ClientOperation; + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> ClientOperation { + let call = cx.read(ActiveCall::global); + loop { + match rng.gen_range(0..100_u32) { + // Mutate the call + 0..=29 => { + // Respond to an incoming call + if call.read_with(cx, |call, _| call.incoming().borrow().is_some()) { + break if rng.gen_bool(0.7) { + ClientOperation::AcceptIncomingCall + } else { + ClientOperation::RejectIncomingCall + }; + } + + match rng.gen_range(0..100_u32) { + // Invite a contact to the current call + 0..=70 => { + let available_contacts = + client.user_store().read_with(cx, |user_store, _| { + user_store + .contacts() + .iter() + .filter(|contact| contact.online && !contact.busy) + .cloned() + .collect::>() + }); + if !available_contacts.is_empty() { + let contact = available_contacts.choose(rng).unwrap(); + break ClientOperation::InviteContactToCall { + user_id: UserId(contact.user.id as i32), + }; + } + } + + // Leave the current call + 71.. => { + if plan.allow_client_disconnection + && call.read_with(cx, |call, _| call.room().is_some()) + { + break ClientOperation::LeaveCall; + } + } + } + } + + // Mutate projects + 30..=59 => match rng.gen_range(0..100_u32) { + // Open a new project + 0..=70 => { + // Open a remote project + if let Some(room) = call.read_with(cx, |call, _| call.room().cloned()) { + let existing_remote_project_ids = cx.read(|cx| { + client + .remote_projects() + .iter() + .map(|p| p.read(cx).remote_id().unwrap()) + .collect::>() + }); + let new_remote_projects = room.read_with(cx, |room, _| { + room.remote_participants() + .values() + .flat_map(|participant| { + participant.projects.iter().filter_map(|project| { + if existing_remote_project_ids.contains(&project.id) { + None + } else { + Some(( + UserId::from_proto(participant.user.id), + project.worktree_root_names[0].clone(), + )) + } + }) + }) + .collect::>() + }); + if !new_remote_projects.is_empty() { + let (host_id, first_root_name) = + new_remote_projects.choose(rng).unwrap().clone(); + break ClientOperation::OpenRemoteProject { + host_id, + first_root_name, + }; + } + } + // Open a local project + else { + let first_root_name = plan.next_root_dir_name(); + break ClientOperation::OpenLocalProject { first_root_name }; + } + } + + // Close a remote project + 71..=80 => { + if !client.remote_projects().is_empty() { + let project = client.remote_projects().choose(rng).unwrap().clone(); + let first_root_name = root_name_for_project(&project, cx); + break ClientOperation::CloseRemoteProject { + project_root_name: first_root_name, + }; + } + } + + // Mutate project worktrees + 81.. => match rng.gen_range(0..100_u32) { + // Add a worktree to a local project + 0..=50 => { + let Some(project) = client.local_projects().choose(rng).cloned() else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let mut paths = client.fs().paths(false); + paths.remove(0); + let new_root_path = if paths.is_empty() || rng.gen() { + Path::new("/").join(&plan.next_root_dir_name()) + } else { + paths.choose(rng).unwrap().clone() + }; + break ClientOperation::AddWorktreeToProject { + project_root_name, + new_root_path, + }; + } + + // Add an entry to a worktree + _ => { + let Some(project) = choose_random_project(client, rng) else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let is_local = project.read_with(cx, |project, _| project.is_local()); + let worktree = project.read_with(cx, |project, cx| { + project + .worktrees(cx) + .filter(|worktree| { + let worktree = worktree.read(cx); + worktree.is_visible() + && worktree.entries(false).any(|e| e.is_file()) + && worktree.root_entry().map_or(false, |e| e.is_dir()) + }) + .choose(rng) + }); + let Some(worktree) = worktree else { continue }; + let is_dir = rng.gen::(); + let mut full_path = + worktree.read_with(cx, |w, _| PathBuf::from(w.root_name())); + full_path.push(gen_file_name(rng)); + if !is_dir { + full_path.set_extension("rs"); + } + break ClientOperation::CreateWorktreeEntry { + project_root_name, + is_local, + full_path, + is_dir, + }; + } + }, + }, + + // Query and mutate buffers + 60..=90 => { + let Some(project) = choose_random_project(client, rng) else { + continue; + }; + let project_root_name = root_name_for_project(&project, cx); + let is_local = project.read_with(cx, |project, _| project.is_local()); + + match rng.gen_range(0..100_u32) { + // Manipulate an existing buffer + 0..=70 => { + let Some(buffer) = client + .buffers_for_project(&project) + .iter() + .choose(rng) + .cloned() + else { + continue; + }; + + let full_path = buffer + .read_with(cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); + + match rng.gen_range(0..100_u32) { + // Close the buffer + 0..=15 => { + break ClientOperation::CloseBuffer { + project_root_name, + is_local, + full_path, + }; + } + // Save the buffer + 16..=29 if buffer.read_with(cx, |b, _| b.is_dirty()) => { + let detach = rng.gen_bool(0.3); + break ClientOperation::SaveBuffer { + project_root_name, + is_local, + full_path, + detach, + }; + } + // Edit the buffer + 30..=69 => { + let edits = buffer + .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); + break ClientOperation::EditBuffer { + project_root_name, + is_local, + full_path, + edits, + }; + } + // Make an LSP request + _ => { + let offset = buffer.read_with(cx, |buffer, _| { + buffer.clip_offset( + rng.gen_range(0..=buffer.len()), + language::Bias::Left, + ) + }); + let detach = rng.gen(); + break ClientOperation::RequestLspDataInBuffer { + project_root_name, + full_path, + offset, + is_local, + kind: match rng.gen_range(0..5_u32) { + 0 => LspRequestKind::Rename, + 1 => LspRequestKind::Highlights, + 2 => LspRequestKind::Definition, + 3 => LspRequestKind::CodeAction, + 4.. => LspRequestKind::Completion, + }, + detach, + }; + } + } + } + + 71..=80 => { + let query = rng.gen_range('a'..='z').to_string(); + let detach = rng.gen_bool(0.3); + break ClientOperation::SearchProject { + project_root_name, + is_local, + query, + detach, + }; + } + + // Open a buffer + 81.. => { + let worktree = project.read_with(cx, |project, cx| { + project + .worktrees(cx) + .filter(|worktree| { + let worktree = worktree.read(cx); + worktree.is_visible() + && worktree.entries(false).any(|e| e.is_file()) + }) + .choose(rng) + }); + let Some(worktree) = worktree else { continue }; + let full_path = worktree.read_with(cx, |worktree, _| { + let entry = worktree + .entries(false) + .filter(|e| e.is_file()) + .choose(rng) + .unwrap(); + if entry.path.as_ref() == Path::new("") { + Path::new(worktree.root_name()).into() + } else { + Path::new(worktree.root_name()).join(&entry.path) + } + }); + break ClientOperation::OpenBuffer { + project_root_name, + is_local, + full_path, + }; + } + } + } + + // Update a git related action + 91..=95 => { + break ClientOperation::GitOperation { + operation: generate_git_operation(rng, client), + }; + } + + // Create or update a file or directory + 96.. => { + let is_dir = rng.gen::(); + let content; + let mut path; + let dir_paths = client.fs().directories(false); + + if is_dir { + content = String::new(); + path = dir_paths.choose(rng).unwrap().clone(); + path.push(gen_file_name(rng)); + } else { + content = Alphanumeric.sample_string(rng, 16); + + // Create a new file or overwrite an existing file + let file_paths = client.fs().files(); + if file_paths.is_empty() || rng.gen_bool(0.5) { + path = dir_paths.choose(rng).unwrap().clone(); + path.push(gen_file_name(rng)); + path.set_extension("rs"); + } else { + path = file_paths.choose(rng).unwrap().clone() + }; + } + break ClientOperation::WriteFsEntry { + path, + is_dir, + content, + }; + } + } + } + } + + async fn apply_operation( + client: &TestClient, + operation: ClientOperation, + cx: &mut TestAppContext, + ) -> Result<(), TestError> { + match operation { + ClientOperation::AcceptIncomingCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: accepting incoming call", client.username); + active_call + .update(cx, |call, cx| call.accept_incoming(cx)) + .await?; + } + + ClientOperation::RejectIncomingCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: declining incoming call", client.username); + active_call.update(cx, |call, cx| call.decline_incoming(cx))?; + } + + ClientOperation::LeaveCall => { + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.room().is_none()) { + Err(TestError::Inapplicable)?; + } + + log::info!("{}: hanging up", client.username); + active_call.update(cx, |call, cx| call.hang_up(cx)).await?; + } + + ClientOperation::InviteContactToCall { user_id } => { + let active_call = cx.read(ActiveCall::global); + + log::info!("{}: inviting {}", client.username, user_id,); + active_call + .update(cx, |call, cx| call.invite(user_id.to_proto(), None, cx)) + .await + .log_err(); + } + + ClientOperation::OpenLocalProject { first_root_name } => { + log::info!( + "{}: opening local project at {:?}", + client.username, + first_root_name + ); + + let root_path = Path::new("/").join(&first_root_name); + client.fs().create_dir(&root_path).await.unwrap(); + client + .fs() + .create_file(&root_path.join("main.rs"), Default::default()) + .await + .unwrap(); + let project = client.build_local_project(root_path, cx).await.0; + ensure_project_shared(&project, client, cx).await; + client.local_projects_mut().push(project.clone()); + } + + ClientOperation::AddWorktreeToProject { + project_root_name, + new_root_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: finding/creating local worktree at {:?} to project with root path {}", + client.username, + new_root_path, + project_root_name + ); + + ensure_project_shared(&project, client, cx).await; + if !client.fs().paths(false).contains(&new_root_path) { + client.fs().create_dir(&new_root_path).await.unwrap(); + } + project + .update(cx, |project, cx| { + project.find_or_create_local_worktree(&new_root_path, true, cx) + }) + .await + .unwrap(); + } + + ClientOperation::CloseRemoteProject { project_root_name } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: closing remote project with root path {}", + client.username, + project_root_name, + ); + + let ix = client + .remote_projects() + .iter() + .position(|p| p == &project) + .unwrap(); + cx.update(|_| { + client.remote_projects_mut().remove(ix); + client.buffers().retain(|p, _| *p != project); + drop(project); + }); + } + + ClientOperation::OpenRemoteProject { + host_id, + first_root_name, + } => { + let active_call = cx.read(ActiveCall::global); + let project = active_call + .update(cx, |call, cx| { + let room = call.room().cloned()?; + let participant = room + .read(cx) + .remote_participants() + .get(&host_id.to_proto())?; + let project_id = participant + .projects + .iter() + .find(|project| project.worktree_root_names[0] == first_root_name)? + .id; + Some(room.update(cx, |room, cx| { + room.join_project( + project_id, + client.language_registry().clone(), + FakeFs::new(cx.background().clone()), + cx, + ) + })) + }) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: joining remote project of user {}, root name {}", + client.username, + host_id, + first_root_name, + ); + + let project = project.await?; + client.remote_projects_mut().push(project.clone()); + } + + ClientOperation::CreateWorktreeEntry { + project_root_name, + is_local, + full_path, + is_dir, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let project_path = project_path_for_full_path(&project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: creating {} at path {:?} in {} project {}", + client.username, + if is_dir { "dir" } else { "file" }, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + ); + + ensure_project_shared(&project, client, cx).await; + project + .update(cx, |p, cx| p.create_entry(project_path, is_dir, cx)) + .unwrap() + .await?; + } + + ClientOperation::OpenBuffer { + project_root_name, + is_local, + full_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let project_path = project_path_for_full_path(&project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: opening buffer {:?} in {} project {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + ); + + ensure_project_shared(&project, client, cx).await; + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await?; + client.buffers_for_project(&project).insert(buffer); + } + + ClientOperation::EditBuffer { + project_root_name, + is_local, + full_path, + edits, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: editing buffer {:?} in {} project {} with {:?}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + edits + ); + + ensure_project_shared(&project, client, cx).await; + buffer.update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + buffer.edit( + edits.into_iter().map(|(range, text)| { + let start = snapshot.clip_offset(range.start, Bias::Left); + let end = snapshot.clip_offset(range.end, Bias::Right); + (start..end, text) + }), + None, + cx, + ); + }); + } + + ClientOperation::CloseBuffer { + project_root_name, + is_local, + full_path, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: closing buffer {:?} in {} project {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name + ); + + ensure_project_shared(&project, client, cx).await; + cx.update(|_| { + client.buffers_for_project(&project).remove(&buffer); + drop(buffer); + }); + } + + ClientOperation::SaveBuffer { + project_root_name, + is_local, + full_path, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: saving buffer {:?} in {} project {}, {}", + client.username, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + if detach { "detaching" } else { "awaiting" } + ); + + ensure_project_shared(&project, client, cx).await; + let requested_version = buffer.read_with(cx, |buffer, _| buffer.version()); + let save = + project.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + let save = cx.spawn(|cx| async move { + save.await + .map_err(|err| anyhow!("save request failed: {:?}", err))?; + assert!(buffer + .read_with(&cx, |buffer, _| { buffer.saved_version().to_owned() }) + .observed_all(&requested_version)); + anyhow::Ok(()) + }); + if detach { + cx.update(|cx| save.detach_and_log_err(cx)); + } else { + save.await?; + } + } + + ClientOperation::RequestLspDataInBuffer { + project_root_name, + is_local, + full_path, + offset, + kind, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + let buffer = buffer_for_full_path(client, &project, &full_path, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: request LSP {:?} for buffer {:?} in {} project {}, {}", + client.username, + kind, + full_path, + if is_local { "local" } else { "remote" }, + project_root_name, + if detach { "detaching" } else { "awaiting" } + ); + + use futures::{FutureExt as _, TryFutureExt as _}; + let offset = buffer.read_with(cx, |b, _| b.clip_offset(offset, Bias::Left)); + let request = cx.foreground().spawn(project.update(cx, |project, cx| { + match kind { + LspRequestKind::Rename => project + .prepare_rename(buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Completion => project + .completions(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::CodeAction => project + .code_actions(&buffer, offset..offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Definition => project + .definition(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + LspRequestKind::Highlights => project + .document_highlights(&buffer, offset, cx) + .map_ok(|_| ()) + .boxed(), + } + })); + if detach { + request.detach(); + } else { + request.await?; + } + } + + ClientOperation::SearchProject { + project_root_name, + is_local, + query, + detach, + } => { + let project = project_for_root_name(client, &project_root_name, cx) + .ok_or(TestError::Inapplicable)?; + + log::info!( + "{}: search {} project {} for {:?}, {}", + client.username, + if is_local { "local" } else { "remote" }, + project_root_name, + query, + if detach { "detaching" } else { "awaiting" } + ); + + let mut search = project.update(cx, |project, cx| { + project.search( + SearchQuery::text(query, false, false, Vec::new(), Vec::new()), + cx, + ) + }); + drop(project); + let search = cx.background().spawn(async move { + let mut results = HashMap::default(); + while let Some((buffer, ranges)) = search.next().await { + results.entry(buffer).or_insert(ranges); + } + results + }); + search.await; + } + + ClientOperation::WriteFsEntry { + path, + is_dir, + content, + } => { + if !client + .fs() + .directories(false) + .contains(&path.parent().unwrap().to_owned()) + { + return Err(TestError::Inapplicable); + } + + if is_dir { + log::info!("{}: creating dir at {:?}", client.username, path); + client.fs().create_dir(&path).await.unwrap(); + } else { + let exists = client.fs().metadata(&path).await?.is_some(); + let verb = if exists { "updating" } else { "creating" }; + log::info!("{}: {} file at {:?}", verb, client.username, path); + + client + .fs() + .save(&path, &content.as_str().into(), text::LineEnding::Unix) + .await + .unwrap(); + } + } + + ClientOperation::GitOperation { operation } => match operation { + GitOperation::WriteGitIndex { + repo_path, + contents, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + + for (path, _) in contents.iter() { + if !client.fs().files().contains(&repo_path.join(path)) { + return Err(TestError::Inapplicable); + } + } + + log::info!( + "{}: writing git index for repo {:?}: {:?}", + client.username, + repo_path, + contents + ); + + let dot_git_dir = repo_path.join(".git"); + let contents = contents + .iter() + .map(|(path, contents)| (path.as_path(), contents.clone())) + .collect::>(); + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + client.fs().set_index_for_repo(&dot_git_dir, &contents); + } + GitOperation::WriteGitBranch { + repo_path, + new_branch, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + + log::info!( + "{}: writing git branch for repo {:?}: {:?}", + client.username, + repo_path, + new_branch + ); + + let dot_git_dir = repo_path.join(".git"); + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + client + .fs() + .set_branch_name(&dot_git_dir, new_branch.clone()); + } + GitOperation::WriteGitStatuses { + repo_path, + statuses, + git_operation, + } => { + if !client.fs().directories(false).contains(&repo_path) { + return Err(TestError::Inapplicable); + } + for (path, _) in statuses.iter() { + if !client.fs().files().contains(&repo_path.join(path)) { + return Err(TestError::Inapplicable); + } + } + + log::info!( + "{}: writing git statuses for repo {:?}: {:?}", + client.username, + repo_path, + statuses + ); + + let dot_git_dir = repo_path.join(".git"); + + let statuses = statuses + .iter() + .map(|(path, val)| (path.as_path(), val.clone())) + .collect::>(); + + if client.fs().metadata(&dot_git_dir).await?.is_none() { + client.fs().create_dir(&dot_git_dir).await?; + } + + if git_operation { + client.fs().set_status_for_repo_via_git_operation( + &dot_git_dir, + statuses.as_slice(), + ); + } else { + client.fs().set_status_for_repo_via_working_copy_change( + &dot_git_dir, + statuses.as_slice(), + ); + } + } + }, + } + Ok(()) + } + + async fn on_client_added(client: &Rc) { + let mut language = Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + None, + ); + language + .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { + name: "the-fake-language-server", + capabilities: lsp::LanguageServer::full_capabilities(), + initializer: Some(Box::new({ + let fs = client.app_state.fs.clone(); + move |fake_server: &mut FakeLanguageServer| { + fake_server.handle_request::( + |_, _| async move { + Ok(Some(lsp::CompletionResponse::Array(vec![ + lsp::CompletionItem { + text_edit: Some(lsp::CompletionTextEdit::Edit( + lsp::TextEdit { + range: lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 0), + ), + new_text: "the-new-text".to_string(), + }, + )), + ..Default::default() + }, + ]))) + }, + ); + + fake_server.handle_request::( + |_, _| async move { + Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( + lsp::CodeAction { + title: "the-code-action".to_string(), + ..Default::default() + }, + )])) + }, + ); + + fake_server.handle_request::( + |params, _| async move { + Ok(Some(lsp::PrepareRenameResponse::Range(lsp::Range::new( + params.position, + params.position, + )))) + }, + ); + + fake_server.handle_request::({ + let fs = fs.clone(); + move |_, cx| { + let background = cx.background(); + let mut rng = background.rng(); + let count = rng.gen_range::(1..3); + let files = fs.as_fake().files(); + let files = (0..count) + .map(|_| files.choose(&mut *rng).unwrap().clone()) + .collect::>(); + async move { + log::info!("LSP: Returning definitions in files {:?}", &files); + Ok(Some(lsp::GotoDefinitionResponse::Array( + files + .into_iter() + .map(|file| lsp::Location { + uri: lsp::Url::from_file_path(file).unwrap(), + range: Default::default(), + }) + .collect(), + ))) + } + } + }); + + fake_server.handle_request::( + move |_, cx| { + let mut highlights = Vec::new(); + let background = cx.background(); + let mut rng = background.rng(); + + let highlight_count = rng.gen_range(1..=5); + for _ in 0..highlight_count { + let start_row = rng.gen_range(0..100); + let start_column = rng.gen_range(0..100); + let end_row = rng.gen_range(0..100); + let end_column = rng.gen_range(0..100); + let start = PointUtf16::new(start_row, start_column); + let end = PointUtf16::new(end_row, end_column); + let range = if start > end { end..start } else { start..end }; + highlights.push(lsp::DocumentHighlight { + range: range_to_lsp(range.clone()), + kind: Some(lsp::DocumentHighlightKind::READ), + }); + } + highlights.sort_unstable_by_key(|highlight| { + (highlight.range.start, highlight.range.end) + }); + async move { Ok(Some(highlights)) } + }, + ); + } + })), + ..Default::default() + })) + .await; + client.app_state.languages.add(Arc::new(language)); + } + + fn on_clients_quiesced(clients: &[(Rc, TestAppContext)]) { + for (client, client_cx) in clients { + for guest_project in client.remote_projects().iter() { + guest_project.read_with(client_cx, |guest_project, cx| { + let host_project = clients.iter().find_map(|(client, cx)| { + let project = client + .local_projects() + .iter() + .find(|host_project| { + host_project.read_with(cx, |host_project, _| { + host_project.remote_id() == guest_project.remote_id() + }) + })? + .clone(); + Some((project, cx)) + }); + + if !guest_project.is_read_only() { + if let Some((host_project, host_cx)) = host_project { + let host_worktree_snapshots = + host_project.read_with(host_cx, |host_project, cx| { + host_project + .worktrees(cx) + .map(|worktree| { + let worktree = worktree.read(cx); + (worktree.id(), worktree.snapshot()) + }) + .collect::>() + }); + let guest_worktree_snapshots = guest_project + .worktrees(cx) + .map(|worktree| { + let worktree = worktree.read(cx); + (worktree.id(), worktree.snapshot()) + }) + .collect::>(); + + assert_eq!( + guest_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), + host_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), + "{} has different worktrees than the host for project {:?}", + client.username, guest_project.remote_id(), + ); + + for (id, host_snapshot) in &host_worktree_snapshots { + let guest_snapshot = &guest_worktree_snapshots[id]; + assert_eq!( + guest_snapshot.root_name(), + host_snapshot.root_name(), + "{} has different root name than the host for worktree {}, project {:?}", + client.username, + id, + guest_project.remote_id(), + ); + assert_eq!( + guest_snapshot.abs_path(), + host_snapshot.abs_path(), + "{} has different abs path than the host for worktree {}, project: {:?}", + client.username, + id, + guest_project.remote_id(), + ); + assert_eq!( + guest_snapshot.entries(false).collect::>(), + host_snapshot.entries(false).collect::>(), + "{} has different snapshot than the host for worktree {:?} ({:?}) and project {:?}", + client.username, + host_snapshot.abs_path(), + id, + guest_project.remote_id(), + ); + assert_eq!(guest_snapshot.repositories().collect::>(), host_snapshot.repositories().collect::>(), + "{} has different repositories than the host for worktree {:?} and project {:?}", + client.username, + host_snapshot.abs_path(), + guest_project.remote_id(), + ); + assert_eq!(guest_snapshot.scan_id(), host_snapshot.scan_id(), + "{} has different scan id than the host for worktree {:?} and project {:?}", + client.username, + host_snapshot.abs_path(), + guest_project.remote_id(), + ); + } + } + } + + for buffer in guest_project.opened_buffers(cx) { + let buffer = buffer.read(cx); + assert_eq!( + buffer.deferred_ops_len(), + 0, + "{} has deferred operations for buffer {:?} in project {:?}", + client.username, + buffer.file().unwrap().full_path(cx), + guest_project.remote_id(), + ); + } + }); + } + + let buffers = client.buffers().clone(); + for (guest_project, guest_buffers) in &buffers { + let project_id = if guest_project.read_with(client_cx, |project, _| { + project.is_local() || project.is_read_only() + }) { + continue; + } else { + guest_project + .read_with(client_cx, |project, _| project.remote_id()) + .unwrap() + }; + let guest_user_id = client.user_id().unwrap(); + + let host_project = clients.iter().find_map(|(client, cx)| { + let project = client + .local_projects() + .iter() + .find(|host_project| { + host_project.read_with(cx, |host_project, _| { + host_project.remote_id() == Some(project_id) + }) + })? + .clone(); + Some((client.user_id().unwrap(), project, cx)) + }); + + let (host_user_id, host_project, host_cx) = + if let Some((host_user_id, host_project, host_cx)) = host_project { + (host_user_id, host_project, host_cx) + } else { + continue; + }; + + for guest_buffer in guest_buffers { + let buffer_id = + guest_buffer.read_with(client_cx, |buffer, _| buffer.remote_id()); + let host_buffer = host_project.read_with(host_cx, |project, cx| { + project.buffer_for_id(buffer_id, cx).unwrap_or_else(|| { + panic!( + "host does not have buffer for guest:{}, peer:{:?}, id:{}", + client.username, + client.peer_id(), + buffer_id + ) + }) + }); + let path = host_buffer + .read_with(host_cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); + + assert_eq!( + guest_buffer.read_with(client_cx, |buffer, _| buffer.deferred_ops_len()), + 0, + "{}, buffer {}, path {:?} has deferred operations", + client.username, + buffer_id, + path, + ); + assert_eq!( + guest_buffer.read_with(client_cx, |buffer, _| buffer.text()), + host_buffer.read_with(host_cx, |buffer, _| buffer.text()), + "{}, buffer {}, path {:?}, differs from the host's buffer", + client.username, + buffer_id, + path + ); + + let host_file = host_buffer.read_with(host_cx, |b, _| b.file().cloned()); + let guest_file = guest_buffer.read_with(client_cx, |b, _| b.file().cloned()); + match (host_file, guest_file) { + (Some(host_file), Some(guest_file)) => { + assert_eq!(guest_file.path(), host_file.path()); + assert_eq!(guest_file.is_deleted(), host_file.is_deleted()); + assert_eq!( + guest_file.mtime(), + host_file.mtime(), + "guest {} mtime does not match host {} for path {:?} in project {}", + guest_user_id, + host_user_id, + guest_file.path(), + project_id, + ); + } + (None, None) => {} + (None, _) => panic!("host's file is None, guest's isn't"), + (_, None) => panic!("guest's file is None, hosts's isn't"), + } + + let host_diff_base = host_buffer + .read_with(host_cx, |b, _| b.diff_base().map(ToString::to_string)); + let guest_diff_base = guest_buffer + .read_with(client_cx, |b, _| b.diff_base().map(ToString::to_string)); + assert_eq!( + guest_diff_base, host_diff_base, + "guest {} diff base does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_version = + host_buffer.read_with(host_cx, |b, _| b.saved_version().clone()); + let guest_saved_version = + guest_buffer.read_with(client_cx, |b, _| b.saved_version().clone()); + assert_eq!( + guest_saved_version, host_saved_version, + "guest {} saved version does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_version_fingerprint = + host_buffer.read_with(host_cx, |b, _| b.saved_version_fingerprint()); + let guest_saved_version_fingerprint = + guest_buffer.read_with(client_cx, |b, _| b.saved_version_fingerprint()); + assert_eq!( + guest_saved_version_fingerprint, host_saved_version_fingerprint, + "guest {} saved fingerprint does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_saved_mtime = host_buffer.read_with(host_cx, |b, _| b.saved_mtime()); + let guest_saved_mtime = + guest_buffer.read_with(client_cx, |b, _| b.saved_mtime()); + assert_eq!( + guest_saved_mtime, host_saved_mtime, + "guest {} saved mtime does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_is_dirty = host_buffer.read_with(host_cx, |b, _| b.is_dirty()); + let guest_is_dirty = guest_buffer.read_with(client_cx, |b, _| b.is_dirty()); + assert_eq!(guest_is_dirty, host_is_dirty, + "guest {} dirty status does not match host's for path {path:?} in project {project_id}", + client.username + ); + + let host_has_conflict = host_buffer.read_with(host_cx, |b, _| b.has_conflict()); + let guest_has_conflict = + guest_buffer.read_with(client_cx, |b, _| b.has_conflict()); + assert_eq!(guest_has_conflict, host_has_conflict, + "guest {} conflict status does not match host's for path {path:?} in project {project_id}", + client.username + ); + } + } + } + } +} + +fn generate_git_operation(rng: &mut StdRng, client: &TestClient) -> GitOperation { + fn generate_file_paths( + repo_path: &Path, + rng: &mut StdRng, + client: &TestClient, + ) -> Vec { + let mut paths = client + .fs() + .files() + .into_iter() + .filter(|path| path.starts_with(repo_path)) + .collect::>(); + + let count = rng.gen_range(0..=paths.len()); + paths.shuffle(rng); + paths.truncate(count); + + paths + .iter() + .map(|path| path.strip_prefix(repo_path).unwrap().to_path_buf()) + .collect::>() + } + + let repo_path = client.fs().directories(false).choose(rng).unwrap().clone(); + + match rng.gen_range(0..100_u32) { + 0..=25 => { + let file_paths = generate_file_paths(&repo_path, rng, client); + + let contents = file_paths + .into_iter() + .map(|path| (path, Alphanumeric.sample_string(rng, 16))) + .collect(); + + GitOperation::WriteGitIndex { + repo_path, + contents, + } + } + 26..=63 => { + let new_branch = (rng.gen_range(0..10) > 3).then(|| Alphanumeric.sample_string(rng, 8)); + + GitOperation::WriteGitBranch { + repo_path, + new_branch, + } + } + 64..=100 => { + let file_paths = generate_file_paths(&repo_path, rng, client); + + let statuses = file_paths + .into_iter() + .map(|paths| { + ( + paths, + match rng.gen_range(0..3_u32) { + 0 => GitFileStatus::Added, + 1 => GitFileStatus::Modified, + 2 => GitFileStatus::Conflict, + _ => unreachable!(), + }, + ) + }) + .collect::>(); + + let git_operation = rng.gen::(); + + GitOperation::WriteGitStatuses { + repo_path, + statuses, + git_operation, + } + } + _ => unreachable!(), + } +} + +fn buffer_for_full_path( + client: &TestClient, + project: &ModelHandle, + full_path: &PathBuf, + cx: &TestAppContext, +) -> Option> { + client + .buffers_for_project(project) + .iter() + .find(|buffer| { + buffer.read_with(cx, |buffer, cx| { + buffer.file().unwrap().full_path(cx) == *full_path + }) + }) + .cloned() +} + +fn project_for_root_name( + client: &TestClient, + root_name: &str, + cx: &TestAppContext, +) -> Option> { + if let Some(ix) = project_ix_for_root_name(&*client.local_projects(), root_name, cx) { + return Some(client.local_projects()[ix].clone()); + } + if let Some(ix) = project_ix_for_root_name(&*client.remote_projects(), root_name, cx) { + return Some(client.remote_projects()[ix].clone()); + } + None +} + +fn project_ix_for_root_name( + projects: &[ModelHandle], + root_name: &str, + cx: &TestAppContext, +) -> Option { + projects.iter().position(|project| { + project.read_with(cx, |project, cx| { + let worktree = project.visible_worktrees(cx).next().unwrap(); + worktree.read(cx).root_name() == root_name + }) + }) +} + +fn root_name_for_project(project: &ModelHandle, cx: &TestAppContext) -> String { + project.read_with(cx, |project, cx| { + project + .visible_worktrees(cx) + .next() + .unwrap() + .read(cx) + .root_name() + .to_string() + }) +} + +fn project_path_for_full_path( + project: &ModelHandle, + full_path: &Path, + cx: &TestAppContext, +) -> Option { + let mut components = full_path.components(); + let root_name = components.next().unwrap().as_os_str().to_str().unwrap(); + let path = components.as_path().into(); + let worktree_id = project.read_with(cx, |project, cx| { + project.worktrees(cx).find_map(|worktree| { + let worktree = worktree.read(cx); + if worktree.root_name() == root_name { + Some(worktree.id()) + } else { + None + } + }) + })?; + Some(ProjectPath { worktree_id, path }) +} + +async fn ensure_project_shared( + project: &ModelHandle, + client: &TestClient, + cx: &mut TestAppContext, +) { + let first_root_name = root_name_for_project(project, cx); + let active_call = cx.read(ActiveCall::global); + if active_call.read_with(cx, |call, _| call.room().is_some()) + && project.read_with(cx, |project, _| project.is_local() && !project.is_shared()) + { + match active_call + .update(cx, |call, cx| call.share_project(project.clone(), cx)) + .await + { + Ok(project_id) => { + log::info!( + "{}: shared project {} with id {}", + client.username, + first_root_name, + project_id + ); + } + Err(error) => { + log::error!( + "{}: error sharing project {}: {:?}", + client.username, + first_root_name, + error + ); + } + } + } +} + +fn choose_random_project(client: &TestClient, rng: &mut StdRng) -> Option> { + client + .local_projects() + .iter() + .chain(client.remote_projects().iter()) + .choose(rng) + .cloned() +} + +fn gen_file_name(rng: &mut StdRng) -> String { + let mut name = String::new(); + for _ in 0..10 { + let letter = rng.gen_range('a'..='z'); + name.push(letter); + } + name +} diff --git a/crates/collab/src/tests/randomized_integration_tests.rs b/crates/collab/src/tests/randomized_integration_tests.rs deleted file mode 100644 index 309fcf7e44..0000000000 --- a/crates/collab/src/tests/randomized_integration_tests.rs +++ /dev/null @@ -1,2199 +0,0 @@ -use crate::{ - db::{self, NewUserParams, UserId}, - rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, - tests::{TestClient, TestServer}, -}; -use anyhow::{anyhow, Result}; -use call::ActiveCall; -use client::RECEIVE_TIMEOUT; -use collections::{BTreeMap, HashMap}; -use editor::Bias; -use fs::{repository::GitFileStatus, FakeFs, Fs as _}; -use futures::StreamExt as _; -use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext}; -use language::{range_to_lsp, FakeLspAdapter, Language, LanguageConfig, PointUtf16}; -use lsp::FakeLanguageServer; -use parking_lot::Mutex; -use pretty_assertions::assert_eq; -use project::{search::SearchQuery, Project, ProjectPath}; -use rand::{ - distributions::{Alphanumeric, DistString}, - prelude::*, -}; -use serde::{Deserialize, Serialize}; -use settings::SettingsStore; -use std::{ - env, - ops::Range, - path::{Path, PathBuf}, - rc::Rc, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, -}; -use util::ResultExt; - -lazy_static::lazy_static! { - static ref PLAN_LOAD_PATH: Option = path_env_var("LOAD_PLAN"); - static ref PLAN_SAVE_PATH: Option = path_env_var("SAVE_PLAN"); -} -static LOADED_PLAN_JSON: Mutex>> = Mutex::new(None); -static PLAN: Mutex>>> = Mutex::new(None); - -#[gpui::test(iterations = 100, on_failure = "on_failure")] -async fn test_random_collaboration( - cx: &mut TestAppContext, - deterministic: Arc, - rng: StdRng, -) { - deterministic.forbid_parking(); - - let max_peers = env::var("MAX_PEERS") - .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) - .unwrap_or(3); - let max_operations = env::var("OPERATIONS") - .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) - .unwrap_or(10); - - let mut server = TestServer::start(&deterministic).await; - let db = server.app_state.db.clone(); - - let mut users = Vec::new(); - for ix in 0..max_peers { - let username = format!("user-{}", ix + 1); - let user_id = db - .create_user( - &format!("{username}@example.com"), - false, - NewUserParams { - github_login: username.clone(), - github_user_id: (ix + 1) as i32, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - users.push(UserTestPlan { - user_id, - username, - online: false, - next_root_id: 0, - operation_ix: 0, - }); - } - - for (ix, user_a) in users.iter().enumerate() { - for user_b in &users[ix + 1..] { - server - .app_state - .db - .send_contact_request(user_a.user_id, user_b.user_id) - .await - .unwrap(); - server - .app_state - .db - .respond_to_contact_request(user_b.user_id, user_a.user_id, true) - .await - .unwrap(); - } - } - - let plan = Arc::new(Mutex::new(TestPlan::new(rng, users, max_operations))); - - if let Some(path) = &*PLAN_LOAD_PATH { - let json = LOADED_PLAN_JSON - .lock() - .get_or_insert_with(|| { - eprintln!("loaded test plan from path {:?}", path); - std::fs::read(path).unwrap() - }) - .clone(); - plan.lock().deserialize(json); - } - - PLAN.lock().replace(plan.clone()); - - let mut clients = Vec::new(); - let mut client_tasks = Vec::new(); - let mut operation_channels = Vec::new(); - - loop { - let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else { - break; - }; - applied.store(true, SeqCst); - let did_apply = apply_server_operation( - deterministic.clone(), - &mut server, - &mut clients, - &mut client_tasks, - &mut operation_channels, - plan.clone(), - next_operation, - cx, - ) - .await; - if !did_apply { - applied.store(false, SeqCst); - } - } - - drop(operation_channels); - deterministic.start_waiting(); - futures::future::join_all(client_tasks).await; - deterministic.finish_waiting(); - deterministic.run_until_parked(); - - check_consistency_between_clients(&clients); - - for (client, mut cx) in clients { - cx.update(|cx| { - let store = cx.remove_global::(); - cx.clear_globals(); - cx.set_global(store); - drop(client); - }); - } - - deterministic.run_until_parked(); -} - -fn on_failure() { - if let Some(plan) = PLAN.lock().clone() { - if let Some(path) = &*PLAN_SAVE_PATH { - eprintln!("saved test plan to path {:?}", path); - std::fs::write(path, plan.lock().serialize()).unwrap(); - } - } -} - -async fn apply_server_operation( - deterministic: Arc, - server: &mut TestServer, - clients: &mut Vec<(Rc, TestAppContext)>, - client_tasks: &mut Vec>, - operation_channels: &mut Vec>, - plan: Arc>, - operation: Operation, - cx: &mut TestAppContext, -) -> bool { - match operation { - Operation::AddConnection { user_id } => { - let username; - { - let mut plan = plan.lock(); - let user = plan.user(user_id); - if user.online { - return false; - } - user.online = true; - username = user.username.clone(); - }; - log::info!("Adding new connection for {}", username); - let next_entity_id = (user_id.0 * 10_000) as usize; - let mut client_cx = TestAppContext::new( - cx.foreground_platform(), - cx.platform(), - deterministic.build_foreground(user_id.0 as usize), - deterministic.build_background(), - cx.font_cache(), - cx.leak_detector(), - next_entity_id, - cx.function_name.clone(), - ); - - let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded(); - let client = Rc::new(server.create_client(&mut client_cx, &username).await); - operation_channels.push(operation_tx); - clients.push((client.clone(), client_cx.clone())); - client_tasks.push(client_cx.foreground().spawn(simulate_client( - client, - operation_rx, - plan.clone(), - client_cx, - ))); - - log::info!("Added connection for {}", username); - } - - Operation::RemoveConnection { - user_id: removed_user_id, - } => { - log::info!("Simulating full disconnection of user {}", removed_user_id); - let client_ix = clients - .iter() - .position(|(client, cx)| client.current_user_id(cx) == removed_user_id); - let Some(client_ix) = client_ix else { - return false; - }; - let user_connection_ids = server - .connection_pool - .lock() - .user_connection_ids(removed_user_id) - .collect::>(); - assert_eq!(user_connection_ids.len(), 1); - let removed_peer_id = user_connection_ids[0].into(); - let (client, mut client_cx) = clients.remove(client_ix); - let client_task = client_tasks.remove(client_ix); - operation_channels.remove(client_ix); - server.forbid_connections(); - server.disconnect_client(removed_peer_id); - deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - deterministic.start_waiting(); - log::info!("Waiting for user {} to exit...", removed_user_id); - client_task.await; - deterministic.finish_waiting(); - server.allow_connections(); - - for project in client.remote_projects().iter() { - project.read_with(&client_cx, |project, _| { - assert!( - project.is_read_only(), - "project {:?} should be read only", - project.remote_id() - ) - }); - } - - for (client, cx) in clients { - let contacts = server - .app_state - .db - .get_contacts(client.current_user_id(cx)) - .await - .unwrap(); - let pool = server.connection_pool.lock(); - for contact in contacts { - if let db::Contact::Accepted { user_id, busy, .. } = contact { - if user_id == removed_user_id { - assert!(!pool.is_user_online(user_id)); - assert!(!busy); - } - } - } - } - - log::info!("{} removed", client.username); - plan.lock().user(removed_user_id).online = false; - client_cx.update(|cx| { - cx.clear_globals(); - drop(client); - }); - } - - Operation::BounceConnection { user_id } => { - log::info!("Simulating temporary disconnection of user {}", user_id); - let user_connection_ids = server - .connection_pool - .lock() - .user_connection_ids(user_id) - .collect::>(); - if user_connection_ids.is_empty() { - return false; - } - assert_eq!(user_connection_ids.len(), 1); - let peer_id = user_connection_ids[0].into(); - server.disconnect_client(peer_id); - deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); - } - - Operation::RestartServer => { - log::info!("Simulating server restart"); - server.reset().await; - deterministic.advance_clock(RECEIVE_TIMEOUT); - server.start().await.unwrap(); - deterministic.advance_clock(CLEANUP_TIMEOUT); - let environment = &server.app_state.config.zed_environment; - let (stale_room_ids, _) = server - .app_state - .db - .stale_server_resource_ids(environment, server.id()) - .await - .unwrap(); - assert_eq!(stale_room_ids, vec![]); - } - - Operation::MutateClients { - user_ids, - batch_id, - quiesce, - } => { - let mut applied = false; - for user_id in user_ids { - let client_ix = clients - .iter() - .position(|(client, cx)| client.current_user_id(cx) == user_id); - let Some(client_ix) = client_ix else { continue }; - applied = true; - if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) { - log::error!("error signaling user {user_id}: {err}"); - } - } - - if quiesce && applied { - deterministic.run_until_parked(); - check_consistency_between_clients(&clients); - } - - return applied; - } - } - true -} - -async fn apply_client_operation( - client: &TestClient, - operation: ClientOperation, - cx: &mut TestAppContext, -) -> Result<(), TestError> { - match operation { - ClientOperation::AcceptIncomingCall => { - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { - Err(TestError::Inapplicable)?; - } - - log::info!("{}: accepting incoming call", client.username); - active_call - .update(cx, |call, cx| call.accept_incoming(cx)) - .await?; - } - - ClientOperation::RejectIncomingCall => { - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.incoming().borrow().is_none()) { - Err(TestError::Inapplicable)?; - } - - log::info!("{}: declining incoming call", client.username); - active_call.update(cx, |call, cx| call.decline_incoming(cx))?; - } - - ClientOperation::LeaveCall => { - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.room().is_none()) { - Err(TestError::Inapplicable)?; - } - - log::info!("{}: hanging up", client.username); - active_call.update(cx, |call, cx| call.hang_up(cx)).await?; - } - - ClientOperation::InviteContactToCall { user_id } => { - let active_call = cx.read(ActiveCall::global); - - log::info!("{}: inviting {}", client.username, user_id,); - active_call - .update(cx, |call, cx| call.invite(user_id.to_proto(), None, cx)) - .await - .log_err(); - } - - ClientOperation::OpenLocalProject { first_root_name } => { - log::info!( - "{}: opening local project at {:?}", - client.username, - first_root_name - ); - - let root_path = Path::new("/").join(&first_root_name); - client.fs().create_dir(&root_path).await.unwrap(); - client - .fs() - .create_file(&root_path.join("main.rs"), Default::default()) - .await - .unwrap(); - let project = client.build_local_project(root_path, cx).await.0; - ensure_project_shared(&project, client, cx).await; - client.local_projects_mut().push(project.clone()); - } - - ClientOperation::AddWorktreeToProject { - project_root_name, - new_root_path, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: finding/creating local worktree at {:?} to project with root path {}", - client.username, - new_root_path, - project_root_name - ); - - ensure_project_shared(&project, client, cx).await; - if !client.fs().paths(false).contains(&new_root_path) { - client.fs().create_dir(&new_root_path).await.unwrap(); - } - project - .update(cx, |project, cx| { - project.find_or_create_local_worktree(&new_root_path, true, cx) - }) - .await - .unwrap(); - } - - ClientOperation::CloseRemoteProject { project_root_name } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: closing remote project with root path {}", - client.username, - project_root_name, - ); - - let ix = client - .remote_projects() - .iter() - .position(|p| p == &project) - .unwrap(); - cx.update(|_| { - client.remote_projects_mut().remove(ix); - client.buffers().retain(|p, _| *p != project); - drop(project); - }); - } - - ClientOperation::OpenRemoteProject { - host_id, - first_root_name, - } => { - let active_call = cx.read(ActiveCall::global); - let project = active_call - .update(cx, |call, cx| { - let room = call.room().cloned()?; - let participant = room - .read(cx) - .remote_participants() - .get(&host_id.to_proto())?; - let project_id = participant - .projects - .iter() - .find(|project| project.worktree_root_names[0] == first_root_name)? - .id; - Some(room.update(cx, |room, cx| { - room.join_project( - project_id, - client.language_registry().clone(), - FakeFs::new(cx.background().clone()), - cx, - ) - })) - }) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: joining remote project of user {}, root name {}", - client.username, - host_id, - first_root_name, - ); - - let project = project.await?; - client.remote_projects_mut().push(project.clone()); - } - - ClientOperation::CreateWorktreeEntry { - project_root_name, - is_local, - full_path, - is_dir, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let project_path = project_path_for_full_path(&project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: creating {} at path {:?} in {} project {}", - client.username, - if is_dir { "dir" } else { "file" }, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - ); - - ensure_project_shared(&project, client, cx).await; - project - .update(cx, |p, cx| p.create_entry(project_path, is_dir, cx)) - .unwrap() - .await?; - } - - ClientOperation::OpenBuffer { - project_root_name, - is_local, - full_path, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let project_path = project_path_for_full_path(&project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: opening buffer {:?} in {} project {}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - ); - - ensure_project_shared(&project, client, cx).await; - let buffer = project - .update(cx, |project, cx| project.open_buffer(project_path, cx)) - .await?; - client.buffers_for_project(&project).insert(buffer); - } - - ClientOperation::EditBuffer { - project_root_name, - is_local, - full_path, - edits, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: editing buffer {:?} in {} project {} with {:?}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - edits - ); - - ensure_project_shared(&project, client, cx).await; - buffer.update(cx, |buffer, cx| { - let snapshot = buffer.snapshot(); - buffer.edit( - edits.into_iter().map(|(range, text)| { - let start = snapshot.clip_offset(range.start, Bias::Left); - let end = snapshot.clip_offset(range.end, Bias::Right); - (start..end, text) - }), - None, - cx, - ); - }); - } - - ClientOperation::CloseBuffer { - project_root_name, - is_local, - full_path, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: closing buffer {:?} in {} project {}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name - ); - - ensure_project_shared(&project, client, cx).await; - cx.update(|_| { - client.buffers_for_project(&project).remove(&buffer); - drop(buffer); - }); - } - - ClientOperation::SaveBuffer { - project_root_name, - is_local, - full_path, - detach, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: saving buffer {:?} in {} project {}, {}", - client.username, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - if detach { "detaching" } else { "awaiting" } - ); - - ensure_project_shared(&project, client, cx).await; - let requested_version = buffer.read_with(cx, |buffer, _| buffer.version()); - let save = project.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); - let save = cx.spawn(|cx| async move { - save.await - .map_err(|err| anyhow!("save request failed: {:?}", err))?; - assert!(buffer - .read_with(&cx, |buffer, _| { buffer.saved_version().to_owned() }) - .observed_all(&requested_version)); - anyhow::Ok(()) - }); - if detach { - cx.update(|cx| save.detach_and_log_err(cx)); - } else { - save.await?; - } - } - - ClientOperation::RequestLspDataInBuffer { - project_root_name, - is_local, - full_path, - offset, - kind, - detach, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - let buffer = buffer_for_full_path(client, &project, &full_path, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: request LSP {:?} for buffer {:?} in {} project {}, {}", - client.username, - kind, - full_path, - if is_local { "local" } else { "remote" }, - project_root_name, - if detach { "detaching" } else { "awaiting" } - ); - - use futures::{FutureExt as _, TryFutureExt as _}; - let offset = buffer.read_with(cx, |b, _| b.clip_offset(offset, Bias::Left)); - let request = cx.foreground().spawn(project.update(cx, |project, cx| { - match kind { - LspRequestKind::Rename => project - .prepare_rename(buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::Completion => project - .completions(&buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::CodeAction => project - .code_actions(&buffer, offset..offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::Definition => project - .definition(&buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - LspRequestKind::Highlights => project - .document_highlights(&buffer, offset, cx) - .map_ok(|_| ()) - .boxed(), - } - })); - if detach { - request.detach(); - } else { - request.await?; - } - } - - ClientOperation::SearchProject { - project_root_name, - is_local, - query, - detach, - } => { - let project = project_for_root_name(client, &project_root_name, cx) - .ok_or(TestError::Inapplicable)?; - - log::info!( - "{}: search {} project {} for {:?}, {}", - client.username, - if is_local { "local" } else { "remote" }, - project_root_name, - query, - if detach { "detaching" } else { "awaiting" } - ); - - let mut search = project.update(cx, |project, cx| { - project.search( - SearchQuery::text(query, false, false, Vec::new(), Vec::new()), - cx, - ) - }); - drop(project); - let search = cx.background().spawn(async move { - let mut results = HashMap::default(); - while let Some((buffer, ranges)) = search.next().await { - results.entry(buffer).or_insert(ranges); - } - results - }); - search.await; - } - - ClientOperation::WriteFsEntry { - path, - is_dir, - content, - } => { - if !client - .fs() - .directories(false) - .contains(&path.parent().unwrap().to_owned()) - { - return Err(TestError::Inapplicable); - } - - if is_dir { - log::info!("{}: creating dir at {:?}", client.username, path); - client.fs().create_dir(&path).await.unwrap(); - } else { - let exists = client.fs().metadata(&path).await?.is_some(); - let verb = if exists { "updating" } else { "creating" }; - log::info!("{}: {} file at {:?}", verb, client.username, path); - - client - .fs() - .save(&path, &content.as_str().into(), text::LineEnding::Unix) - .await - .unwrap(); - } - } - - ClientOperation::GitOperation { operation } => match operation { - GitOperation::WriteGitIndex { - repo_path, - contents, - } => { - if !client.fs().directories(false).contains(&repo_path) { - return Err(TestError::Inapplicable); - } - - for (path, _) in contents.iter() { - if !client.fs().files().contains(&repo_path.join(path)) { - return Err(TestError::Inapplicable); - } - } - - log::info!( - "{}: writing git index for repo {:?}: {:?}", - client.username, - repo_path, - contents - ); - - let dot_git_dir = repo_path.join(".git"); - let contents = contents - .iter() - .map(|(path, contents)| (path.as_path(), contents.clone())) - .collect::>(); - if client.fs().metadata(&dot_git_dir).await?.is_none() { - client.fs().create_dir(&dot_git_dir).await?; - } - client.fs().set_index_for_repo(&dot_git_dir, &contents); - } - GitOperation::WriteGitBranch { - repo_path, - new_branch, - } => { - if !client.fs().directories(false).contains(&repo_path) { - return Err(TestError::Inapplicable); - } - - log::info!( - "{}: writing git branch for repo {:?}: {:?}", - client.username, - repo_path, - new_branch - ); - - let dot_git_dir = repo_path.join(".git"); - if client.fs().metadata(&dot_git_dir).await?.is_none() { - client.fs().create_dir(&dot_git_dir).await?; - } - client.fs().set_branch_name(&dot_git_dir, new_branch); - } - GitOperation::WriteGitStatuses { - repo_path, - statuses, - git_operation, - } => { - if !client.fs().directories(false).contains(&repo_path) { - return Err(TestError::Inapplicable); - } - for (path, _) in statuses.iter() { - if !client.fs().files().contains(&repo_path.join(path)) { - return Err(TestError::Inapplicable); - } - } - - log::info!( - "{}: writing git statuses for repo {:?}: {:?}", - client.username, - repo_path, - statuses - ); - - let dot_git_dir = repo_path.join(".git"); - - let statuses = statuses - .iter() - .map(|(path, val)| (path.as_path(), val.clone())) - .collect::>(); - - if client.fs().metadata(&dot_git_dir).await?.is_none() { - client.fs().create_dir(&dot_git_dir).await?; - } - - if git_operation { - client - .fs() - .set_status_for_repo_via_git_operation(&dot_git_dir, statuses.as_slice()); - } else { - client.fs().set_status_for_repo_via_working_copy_change( - &dot_git_dir, - statuses.as_slice(), - ); - } - } - }, - } - Ok(()) -} - -fn check_consistency_between_clients(clients: &[(Rc, TestAppContext)]) { - for (client, client_cx) in clients { - for guest_project in client.remote_projects().iter() { - guest_project.read_with(client_cx, |guest_project, cx| { - let host_project = clients.iter().find_map(|(client, cx)| { - let project = client - .local_projects() - .iter() - .find(|host_project| { - host_project.read_with(cx, |host_project, _| { - host_project.remote_id() == guest_project.remote_id() - }) - })? - .clone(); - Some((project, cx)) - }); - - if !guest_project.is_read_only() { - if let Some((host_project, host_cx)) = host_project { - let host_worktree_snapshots = - host_project.read_with(host_cx, |host_project, cx| { - host_project - .worktrees(cx) - .map(|worktree| { - let worktree = worktree.read(cx); - (worktree.id(), worktree.snapshot()) - }) - .collect::>() - }); - let guest_worktree_snapshots = guest_project - .worktrees(cx) - .map(|worktree| { - let worktree = worktree.read(cx); - (worktree.id(), worktree.snapshot()) - }) - .collect::>(); - - assert_eq!( - guest_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), - host_worktree_snapshots.values().map(|w| w.abs_path()).collect::>(), - "{} has different worktrees than the host for project {:?}", - client.username, guest_project.remote_id(), - ); - - for (id, host_snapshot) in &host_worktree_snapshots { - let guest_snapshot = &guest_worktree_snapshots[id]; - assert_eq!( - guest_snapshot.root_name(), - host_snapshot.root_name(), - "{} has different root name than the host for worktree {}, project {:?}", - client.username, - id, - guest_project.remote_id(), - ); - assert_eq!( - guest_snapshot.abs_path(), - host_snapshot.abs_path(), - "{} has different abs path than the host for worktree {}, project: {:?}", - client.username, - id, - guest_project.remote_id(), - ); - assert_eq!( - guest_snapshot.entries(false).collect::>(), - host_snapshot.entries(false).collect::>(), - "{} has different snapshot than the host for worktree {:?} ({:?}) and project {:?}", - client.username, - host_snapshot.abs_path(), - id, - guest_project.remote_id(), - ); - assert_eq!(guest_snapshot.repositories().collect::>(), host_snapshot.repositories().collect::>(), - "{} has different repositories than the host for worktree {:?} and project {:?}", - client.username, - host_snapshot.abs_path(), - guest_project.remote_id(), - ); - assert_eq!(guest_snapshot.scan_id(), host_snapshot.scan_id(), - "{} has different scan id than the host for worktree {:?} and project {:?}", - client.username, - host_snapshot.abs_path(), - guest_project.remote_id(), - ); - } - } - } - - for buffer in guest_project.opened_buffers(cx) { - let buffer = buffer.read(cx); - assert_eq!( - buffer.deferred_ops_len(), - 0, - "{} has deferred operations for buffer {:?} in project {:?}", - client.username, - buffer.file().unwrap().full_path(cx), - guest_project.remote_id(), - ); - } - }); - } - - let buffers = client.buffers().clone(); - for (guest_project, guest_buffers) in &buffers { - let project_id = if guest_project.read_with(client_cx, |project, _| { - project.is_local() || project.is_read_only() - }) { - continue; - } else { - guest_project - .read_with(client_cx, |project, _| project.remote_id()) - .unwrap() - }; - let guest_user_id = client.user_id().unwrap(); - - let host_project = clients.iter().find_map(|(client, cx)| { - let project = client - .local_projects() - .iter() - .find(|host_project| { - host_project.read_with(cx, |host_project, _| { - host_project.remote_id() == Some(project_id) - }) - })? - .clone(); - Some((client.user_id().unwrap(), project, cx)) - }); - - let (host_user_id, host_project, host_cx) = - if let Some((host_user_id, host_project, host_cx)) = host_project { - (host_user_id, host_project, host_cx) - } else { - continue; - }; - - for guest_buffer in guest_buffers { - let buffer_id = guest_buffer.read_with(client_cx, |buffer, _| buffer.remote_id()); - let host_buffer = host_project.read_with(host_cx, |project, cx| { - project.buffer_for_id(buffer_id, cx).unwrap_or_else(|| { - panic!( - "host does not have buffer for guest:{}, peer:{:?}, id:{}", - client.username, - client.peer_id(), - buffer_id - ) - }) - }); - let path = host_buffer - .read_with(host_cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); - - assert_eq!( - guest_buffer.read_with(client_cx, |buffer, _| buffer.deferred_ops_len()), - 0, - "{}, buffer {}, path {:?} has deferred operations", - client.username, - buffer_id, - path, - ); - assert_eq!( - guest_buffer.read_with(client_cx, |buffer, _| buffer.text()), - host_buffer.read_with(host_cx, |buffer, _| buffer.text()), - "{}, buffer {}, path {:?}, differs from the host's buffer", - client.username, - buffer_id, - path - ); - - let host_file = host_buffer.read_with(host_cx, |b, _| b.file().cloned()); - let guest_file = guest_buffer.read_with(client_cx, |b, _| b.file().cloned()); - match (host_file, guest_file) { - (Some(host_file), Some(guest_file)) => { - assert_eq!(guest_file.path(), host_file.path()); - assert_eq!(guest_file.is_deleted(), host_file.is_deleted()); - assert_eq!( - guest_file.mtime(), - host_file.mtime(), - "guest {} mtime does not match host {} for path {:?} in project {}", - guest_user_id, - host_user_id, - guest_file.path(), - project_id, - ); - } - (None, None) => {} - (None, _) => panic!("host's file is None, guest's isn't"), - (_, None) => panic!("guest's file is None, hosts's isn't"), - } - - let host_diff_base = - host_buffer.read_with(host_cx, |b, _| b.diff_base().map(ToString::to_string)); - let guest_diff_base = guest_buffer - .read_with(client_cx, |b, _| b.diff_base().map(ToString::to_string)); - assert_eq!( - guest_diff_base, host_diff_base, - "guest {} diff base does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_saved_version = - host_buffer.read_with(host_cx, |b, _| b.saved_version().clone()); - let guest_saved_version = - guest_buffer.read_with(client_cx, |b, _| b.saved_version().clone()); - assert_eq!( - guest_saved_version, host_saved_version, - "guest {} saved version does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_saved_version_fingerprint = - host_buffer.read_with(host_cx, |b, _| b.saved_version_fingerprint()); - let guest_saved_version_fingerprint = - guest_buffer.read_with(client_cx, |b, _| b.saved_version_fingerprint()); - assert_eq!( - guest_saved_version_fingerprint, host_saved_version_fingerprint, - "guest {} saved fingerprint does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_saved_mtime = host_buffer.read_with(host_cx, |b, _| b.saved_mtime()); - let guest_saved_mtime = guest_buffer.read_with(client_cx, |b, _| b.saved_mtime()); - assert_eq!( - guest_saved_mtime, host_saved_mtime, - "guest {} saved mtime does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_is_dirty = host_buffer.read_with(host_cx, |b, _| b.is_dirty()); - let guest_is_dirty = guest_buffer.read_with(client_cx, |b, _| b.is_dirty()); - assert_eq!(guest_is_dirty, host_is_dirty, - "guest {} dirty status does not match host's for path {path:?} in project {project_id}", - client.username - ); - - let host_has_conflict = host_buffer.read_with(host_cx, |b, _| b.has_conflict()); - let guest_has_conflict = guest_buffer.read_with(client_cx, |b, _| b.has_conflict()); - assert_eq!(guest_has_conflict, host_has_conflict, - "guest {} conflict status does not match host's for path {path:?} in project {project_id}", - client.username - ); - } - } - } -} - -struct TestPlan { - rng: StdRng, - replay: bool, - stored_operations: Vec<(StoredOperation, Arc)>, - max_operations: usize, - operation_ix: usize, - users: Vec, - next_batch_id: usize, - allow_server_restarts: bool, - allow_client_reconnection: bool, - allow_client_disconnection: bool, -} - -struct UserTestPlan { - user_id: UserId, - username: String, - next_root_id: usize, - operation_ix: usize, - online: bool, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(untagged)] -enum StoredOperation { - Server(Operation), - Client { - user_id: UserId, - batch_id: usize, - operation: ClientOperation, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum Operation { - AddConnection { - user_id: UserId, - }, - RemoveConnection { - user_id: UserId, - }, - BounceConnection { - user_id: UserId, - }, - RestartServer, - MutateClients { - batch_id: usize, - #[serde(skip_serializing)] - #[serde(skip_deserializing)] - user_ids: Vec, - quiesce: bool, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum ClientOperation { - AcceptIncomingCall, - RejectIncomingCall, - LeaveCall, - InviteContactToCall { - user_id: UserId, - }, - OpenLocalProject { - first_root_name: String, - }, - OpenRemoteProject { - host_id: UserId, - first_root_name: String, - }, - AddWorktreeToProject { - project_root_name: String, - new_root_path: PathBuf, - }, - CloseRemoteProject { - project_root_name: String, - }, - OpenBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - }, - SearchProject { - project_root_name: String, - is_local: bool, - query: String, - detach: bool, - }, - EditBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - edits: Vec<(Range, Arc)>, - }, - CloseBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - }, - SaveBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - detach: bool, - }, - RequestLspDataInBuffer { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - offset: usize, - kind: LspRequestKind, - detach: bool, - }, - CreateWorktreeEntry { - project_root_name: String, - is_local: bool, - full_path: PathBuf, - is_dir: bool, - }, - WriteFsEntry { - path: PathBuf, - is_dir: bool, - content: String, - }, - GitOperation { - operation: GitOperation, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum GitOperation { - WriteGitIndex { - repo_path: PathBuf, - contents: Vec<(PathBuf, String)>, - }, - WriteGitBranch { - repo_path: PathBuf, - new_branch: Option, - }, - WriteGitStatuses { - repo_path: PathBuf, - statuses: Vec<(PathBuf, GitFileStatus)>, - git_operation: bool, - }, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -enum LspRequestKind { - Rename, - Completion, - CodeAction, - Definition, - Highlights, -} - -enum TestError { - Inapplicable, - Other(anyhow::Error), -} - -impl From for TestError { - fn from(value: anyhow::Error) -> Self { - Self::Other(value) - } -} - -impl TestPlan { - fn new(mut rng: StdRng, users: Vec, max_operations: usize) -> Self { - Self { - replay: false, - allow_server_restarts: rng.gen_bool(0.7), - allow_client_reconnection: rng.gen_bool(0.7), - allow_client_disconnection: rng.gen_bool(0.1), - stored_operations: Vec::new(), - operation_ix: 0, - next_batch_id: 0, - max_operations, - users, - rng, - } - } - - fn deserialize(&mut self, json: Vec) { - let stored_operations: Vec = serde_json::from_slice(&json).unwrap(); - self.replay = true; - self.stored_operations = stored_operations - .iter() - .cloned() - .enumerate() - .map(|(i, mut operation)| { - if let StoredOperation::Server(Operation::MutateClients { - batch_id: current_batch_id, - user_ids, - .. - }) = &mut operation - { - assert!(user_ids.is_empty()); - user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| { - if let StoredOperation::Client { - user_id, batch_id, .. - } = operation - { - if batch_id == current_batch_id { - return Some(user_id); - } - } - None - })); - user_ids.sort_unstable(); - } - (operation, Arc::new(AtomicBool::new(false))) - }) - .collect() - } - - fn serialize(&mut self) -> Vec { - // Format each operation as one line - let mut json = Vec::new(); - json.push(b'['); - for (operation, applied) in &self.stored_operations { - if !applied.load(SeqCst) { - continue; - } - if json.len() > 1 { - json.push(b','); - } - json.extend_from_slice(b"\n "); - serde_json::to_writer(&mut json, operation).unwrap(); - } - json.extend_from_slice(b"\n]\n"); - json - } - - fn next_server_operation( - &mut self, - clients: &[(Rc, TestAppContext)], - ) -> Option<(Operation, Arc)> { - if self.replay { - while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) { - self.operation_ix += 1; - if let (StoredOperation::Server(operation), applied) = stored_operation { - return Some((operation.clone(), applied.clone())); - } - } - None - } else { - let operation = self.generate_server_operation(clients)?; - let applied = Arc::new(AtomicBool::new(false)); - self.stored_operations - .push((StoredOperation::Server(operation.clone()), applied.clone())); - Some((operation, applied)) - } - } - - fn next_client_operation( - &mut self, - client: &TestClient, - current_batch_id: usize, - cx: &TestAppContext, - ) -> Option<(ClientOperation, Arc)> { - let current_user_id = client.current_user_id(cx); - let user_ix = self - .users - .iter() - .position(|user| user.user_id == current_user_id) - .unwrap(); - let user_plan = &mut self.users[user_ix]; - - if self.replay { - while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) { - user_plan.operation_ix += 1; - if let ( - StoredOperation::Client { - user_id, operation, .. - }, - applied, - ) = stored_operation - { - if user_id == ¤t_user_id { - return Some((operation.clone(), applied.clone())); - } - } - } - None - } else { - let operation = self.generate_client_operation(current_user_id, client, cx)?; - let applied = Arc::new(AtomicBool::new(false)); - self.stored_operations.push(( - StoredOperation::Client { - user_id: current_user_id, - batch_id: current_batch_id, - operation: operation.clone(), - }, - applied.clone(), - )); - Some((operation, applied)) - } - } - - fn generate_server_operation( - &mut self, - clients: &[(Rc, TestAppContext)], - ) -> Option { - if self.operation_ix == self.max_operations { - return None; - } - - Some(loop { - break match self.rng.gen_range(0..100) { - 0..=29 if clients.len() < self.users.len() => { - let user = self - .users - .iter() - .filter(|u| !u.online) - .choose(&mut self.rng) - .unwrap(); - self.operation_ix += 1; - Operation::AddConnection { - user_id: user.user_id, - } - } - 30..=34 if clients.len() > 1 && self.allow_client_disconnection => { - let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; - let user_id = client.current_user_id(cx); - self.operation_ix += 1; - Operation::RemoveConnection { user_id } - } - 35..=39 if clients.len() > 1 && self.allow_client_reconnection => { - let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; - let user_id = client.current_user_id(cx); - self.operation_ix += 1; - Operation::BounceConnection { user_id } - } - 40..=44 if self.allow_server_restarts && clients.len() > 1 => { - self.operation_ix += 1; - Operation::RestartServer - } - _ if !clients.is_empty() => { - let count = self - .rng - .gen_range(1..10) - .min(self.max_operations - self.operation_ix); - let batch_id = util::post_inc(&mut self.next_batch_id); - let mut user_ids = (0..count) - .map(|_| { - let ix = self.rng.gen_range(0..clients.len()); - let (client, cx) = &clients[ix]; - client.current_user_id(cx) - }) - .collect::>(); - user_ids.sort_unstable(); - Operation::MutateClients { - user_ids, - batch_id, - quiesce: self.rng.gen_bool(0.7), - } - } - _ => continue, - }; - }) - } - - fn generate_client_operation( - &mut self, - user_id: UserId, - client: &TestClient, - cx: &TestAppContext, - ) -> Option { - if self.operation_ix == self.max_operations { - return None; - } - - self.operation_ix += 1; - let call = cx.read(ActiveCall::global); - Some(loop { - match self.rng.gen_range(0..100_u32) { - // Mutate the call - 0..=29 => { - // Respond to an incoming call - if call.read_with(cx, |call, _| call.incoming().borrow().is_some()) { - break if self.rng.gen_bool(0.7) { - ClientOperation::AcceptIncomingCall - } else { - ClientOperation::RejectIncomingCall - }; - } - - match self.rng.gen_range(0..100_u32) { - // Invite a contact to the current call - 0..=70 => { - let available_contacts = - client.user_store().read_with(cx, |user_store, _| { - user_store - .contacts() - .iter() - .filter(|contact| contact.online && !contact.busy) - .cloned() - .collect::>() - }); - if !available_contacts.is_empty() { - let contact = available_contacts.choose(&mut self.rng).unwrap(); - break ClientOperation::InviteContactToCall { - user_id: UserId(contact.user.id as i32), - }; - } - } - - // Leave the current call - 71.. => { - if self.allow_client_disconnection - && call.read_with(cx, |call, _| call.room().is_some()) - { - break ClientOperation::LeaveCall; - } - } - } - } - - // Mutate projects - 30..=59 => match self.rng.gen_range(0..100_u32) { - // Open a new project - 0..=70 => { - // Open a remote project - if let Some(room) = call.read_with(cx, |call, _| call.room().cloned()) { - let existing_remote_project_ids = cx.read(|cx| { - client - .remote_projects() - .iter() - .map(|p| p.read(cx).remote_id().unwrap()) - .collect::>() - }); - let new_remote_projects = room.read_with(cx, |room, _| { - room.remote_participants() - .values() - .flat_map(|participant| { - participant.projects.iter().filter_map(|project| { - if existing_remote_project_ids.contains(&project.id) { - None - } else { - Some(( - UserId::from_proto(participant.user.id), - project.worktree_root_names[0].clone(), - )) - } - }) - }) - .collect::>() - }); - if !new_remote_projects.is_empty() { - let (host_id, first_root_name) = - new_remote_projects.choose(&mut self.rng).unwrap().clone(); - break ClientOperation::OpenRemoteProject { - host_id, - first_root_name, - }; - } - } - // Open a local project - else { - let first_root_name = self.next_root_dir_name(user_id); - break ClientOperation::OpenLocalProject { first_root_name }; - } - } - - // Close a remote project - 71..=80 => { - if !client.remote_projects().is_empty() { - let project = client - .remote_projects() - .choose(&mut self.rng) - .unwrap() - .clone(); - let first_root_name = root_name_for_project(&project, cx); - break ClientOperation::CloseRemoteProject { - project_root_name: first_root_name, - }; - } - } - - // Mutate project worktrees - 81.. => match self.rng.gen_range(0..100_u32) { - // Add a worktree to a local project - 0..=50 => { - let Some(project) = - client.local_projects().choose(&mut self.rng).cloned() - else { - continue; - }; - let project_root_name = root_name_for_project(&project, cx); - let mut paths = client.fs().paths(false); - paths.remove(0); - let new_root_path = if paths.is_empty() || self.rng.gen() { - Path::new("/").join(&self.next_root_dir_name(user_id)) - } else { - paths.choose(&mut self.rng).unwrap().clone() - }; - break ClientOperation::AddWorktreeToProject { - project_root_name, - new_root_path, - }; - } - - // Add an entry to a worktree - _ => { - let Some(project) = choose_random_project(client, &mut self.rng) else { - continue; - }; - let project_root_name = root_name_for_project(&project, cx); - let is_local = project.read_with(cx, |project, _| project.is_local()); - let worktree = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .filter(|worktree| { - let worktree = worktree.read(cx); - worktree.is_visible() - && worktree.entries(false).any(|e| e.is_file()) - && worktree.root_entry().map_or(false, |e| e.is_dir()) - }) - .choose(&mut self.rng) - }); - let Some(worktree) = worktree else { continue }; - let is_dir = self.rng.gen::(); - let mut full_path = - worktree.read_with(cx, |w, _| PathBuf::from(w.root_name())); - full_path.push(gen_file_name(&mut self.rng)); - if !is_dir { - full_path.set_extension("rs"); - } - break ClientOperation::CreateWorktreeEntry { - project_root_name, - is_local, - full_path, - is_dir, - }; - } - }, - }, - - // Query and mutate buffers - 60..=90 => { - let Some(project) = choose_random_project(client, &mut self.rng) else { - continue; - }; - let project_root_name = root_name_for_project(&project, cx); - let is_local = project.read_with(cx, |project, _| project.is_local()); - - match self.rng.gen_range(0..100_u32) { - // Manipulate an existing buffer - 0..=70 => { - let Some(buffer) = client - .buffers_for_project(&project) - .iter() - .choose(&mut self.rng) - .cloned() - else { - continue; - }; - - let full_path = buffer - .read_with(cx, |buffer, cx| buffer.file().unwrap().full_path(cx)); - - match self.rng.gen_range(0..100_u32) { - // Close the buffer - 0..=15 => { - break ClientOperation::CloseBuffer { - project_root_name, - is_local, - full_path, - }; - } - // Save the buffer - 16..=29 if buffer.read_with(cx, |b, _| b.is_dirty()) => { - let detach = self.rng.gen_bool(0.3); - break ClientOperation::SaveBuffer { - project_root_name, - is_local, - full_path, - detach, - }; - } - // Edit the buffer - 30..=69 => { - let edits = buffer.read_with(cx, |buffer, _| { - buffer.get_random_edits(&mut self.rng, 3) - }); - break ClientOperation::EditBuffer { - project_root_name, - is_local, - full_path, - edits, - }; - } - // Make an LSP request - _ => { - let offset = buffer.read_with(cx, |buffer, _| { - buffer.clip_offset( - self.rng.gen_range(0..=buffer.len()), - language::Bias::Left, - ) - }); - let detach = self.rng.gen(); - break ClientOperation::RequestLspDataInBuffer { - project_root_name, - full_path, - offset, - is_local, - kind: match self.rng.gen_range(0..5_u32) { - 0 => LspRequestKind::Rename, - 1 => LspRequestKind::Highlights, - 2 => LspRequestKind::Definition, - 3 => LspRequestKind::CodeAction, - 4.. => LspRequestKind::Completion, - }, - detach, - }; - } - } - } - - 71..=80 => { - let query = self.rng.gen_range('a'..='z').to_string(); - let detach = self.rng.gen_bool(0.3); - break ClientOperation::SearchProject { - project_root_name, - is_local, - query, - detach, - }; - } - - // Open a buffer - 81.. => { - let worktree = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .filter(|worktree| { - let worktree = worktree.read(cx); - worktree.is_visible() - && worktree.entries(false).any(|e| e.is_file()) - }) - .choose(&mut self.rng) - }); - let Some(worktree) = worktree else { continue }; - let full_path = worktree.read_with(cx, |worktree, _| { - let entry = worktree - .entries(false) - .filter(|e| e.is_file()) - .choose(&mut self.rng) - .unwrap(); - if entry.path.as_ref() == Path::new("") { - Path::new(worktree.root_name()).into() - } else { - Path::new(worktree.root_name()).join(&entry.path) - } - }); - break ClientOperation::OpenBuffer { - project_root_name, - is_local, - full_path, - }; - } - } - } - - // Update a git related action - 91..=95 => { - break ClientOperation::GitOperation { - operation: self.generate_git_operation(client), - }; - } - - // Create or update a file or directory - 96.. => { - let is_dir = self.rng.gen::(); - let content; - let mut path; - let dir_paths = client.fs().directories(false); - - if is_dir { - content = String::new(); - path = dir_paths.choose(&mut self.rng).unwrap().clone(); - path.push(gen_file_name(&mut self.rng)); - } else { - content = Alphanumeric.sample_string(&mut self.rng, 16); - - // Create a new file or overwrite an existing file - let file_paths = client.fs().files(); - if file_paths.is_empty() || self.rng.gen_bool(0.5) { - path = dir_paths.choose(&mut self.rng).unwrap().clone(); - path.push(gen_file_name(&mut self.rng)); - path.set_extension("rs"); - } else { - path = file_paths.choose(&mut self.rng).unwrap().clone() - }; - } - break ClientOperation::WriteFsEntry { - path, - is_dir, - content, - }; - } - } - }) - } - - fn generate_git_operation(&mut self, client: &TestClient) -> GitOperation { - fn generate_file_paths( - repo_path: &Path, - rng: &mut StdRng, - client: &TestClient, - ) -> Vec { - let mut paths = client - .fs() - .files() - .into_iter() - .filter(|path| path.starts_with(repo_path)) - .collect::>(); - - let count = rng.gen_range(0..=paths.len()); - paths.shuffle(rng); - paths.truncate(count); - - paths - .iter() - .map(|path| path.strip_prefix(repo_path).unwrap().to_path_buf()) - .collect::>() - } - - let repo_path = client - .fs() - .directories(false) - .choose(&mut self.rng) - .unwrap() - .clone(); - - match self.rng.gen_range(0..100_u32) { - 0..=25 => { - let file_paths = generate_file_paths(&repo_path, &mut self.rng, client); - - let contents = file_paths - .into_iter() - .map(|path| (path, Alphanumeric.sample_string(&mut self.rng, 16))) - .collect(); - - GitOperation::WriteGitIndex { - repo_path, - contents, - } - } - 26..=63 => { - let new_branch = (self.rng.gen_range(0..10) > 3) - .then(|| Alphanumeric.sample_string(&mut self.rng, 8)); - - GitOperation::WriteGitBranch { - repo_path, - new_branch, - } - } - 64..=100 => { - let file_paths = generate_file_paths(&repo_path, &mut self.rng, client); - - let statuses = file_paths - .into_iter() - .map(|paths| { - ( - paths, - match self.rng.gen_range(0..3_u32) { - 0 => GitFileStatus::Added, - 1 => GitFileStatus::Modified, - 2 => GitFileStatus::Conflict, - _ => unreachable!(), - }, - ) - }) - .collect::>(); - - let git_operation = self.rng.gen::(); - - GitOperation::WriteGitStatuses { - repo_path, - statuses, - git_operation, - } - } - _ => unreachable!(), - } - } - - fn next_root_dir_name(&mut self, user_id: UserId) -> String { - let user_ix = self - .users - .iter() - .position(|user| user.user_id == user_id) - .unwrap(); - let root_id = util::post_inc(&mut self.users[user_ix].next_root_id); - format!("dir-{user_id}-{root_id}") - } - - fn user(&mut self, user_id: UserId) -> &mut UserTestPlan { - let ix = self - .users - .iter() - .position(|user| user.user_id == user_id) - .unwrap(); - &mut self.users[ix] - } -} - -async fn simulate_client( - client: Rc, - mut operation_rx: futures::channel::mpsc::UnboundedReceiver, - plan: Arc>, - mut cx: TestAppContext, -) { - // Setup language server - let mut language = Language::new( - LanguageConfig { - name: "Rust".into(), - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - None, - ); - let _fake_language_servers = language - .set_fake_lsp_adapter(Arc::new(FakeLspAdapter { - name: "the-fake-language-server", - capabilities: lsp::LanguageServer::full_capabilities(), - initializer: Some(Box::new({ - let fs = client.app_state.fs.clone(); - move |fake_server: &mut FakeLanguageServer| { - fake_server.handle_request::( - |_, _| async move { - Ok(Some(lsp::CompletionResponse::Array(vec![ - lsp::CompletionItem { - text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { - range: lsp::Range::new( - lsp::Position::new(0, 0), - lsp::Position::new(0, 0), - ), - new_text: "the-new-text".to_string(), - })), - ..Default::default() - }, - ]))) - }, - ); - - fake_server.handle_request::( - |_, _| async move { - Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction( - lsp::CodeAction { - title: "the-code-action".to_string(), - ..Default::default() - }, - )])) - }, - ); - - fake_server.handle_request::( - |params, _| async move { - Ok(Some(lsp::PrepareRenameResponse::Range(lsp::Range::new( - params.position, - params.position, - )))) - }, - ); - - fake_server.handle_request::({ - let fs = fs.clone(); - move |_, cx| { - let background = cx.background(); - let mut rng = background.rng(); - let count = rng.gen_range::(1..3); - let files = fs.as_fake().files(); - let files = (0..count) - .map(|_| files.choose(&mut *rng).unwrap().clone()) - .collect::>(); - async move { - log::info!("LSP: Returning definitions in files {:?}", &files); - Ok(Some(lsp::GotoDefinitionResponse::Array( - files - .into_iter() - .map(|file| lsp::Location { - uri: lsp::Url::from_file_path(file).unwrap(), - range: Default::default(), - }) - .collect(), - ))) - } - } - }); - - fake_server.handle_request::( - move |_, cx| { - let mut highlights = Vec::new(); - let background = cx.background(); - let mut rng = background.rng(); - - let highlight_count = rng.gen_range(1..=5); - for _ in 0..highlight_count { - let start_row = rng.gen_range(0..100); - let start_column = rng.gen_range(0..100); - let end_row = rng.gen_range(0..100); - let end_column = rng.gen_range(0..100); - let start = PointUtf16::new(start_row, start_column); - let end = PointUtf16::new(end_row, end_column); - let range = if start > end { end..start } else { start..end }; - highlights.push(lsp::DocumentHighlight { - range: range_to_lsp(range.clone()), - kind: Some(lsp::DocumentHighlightKind::READ), - }); - } - highlights.sort_unstable_by_key(|highlight| { - (highlight.range.start, highlight.range.end) - }); - async move { Ok(Some(highlights)) } - }, - ); - } - })), - ..Default::default() - })) - .await; - client.app_state.languages.add(Arc::new(language)); - - while let Some(batch_id) = operation_rx.next().await { - let Some((operation, applied)) = plan.lock().next_client_operation(&client, batch_id, &cx) - else { - break; - }; - applied.store(true, SeqCst); - match apply_client_operation(&client, operation, &mut cx).await { - Ok(()) => {} - Err(TestError::Inapplicable) => { - applied.store(false, SeqCst); - log::info!("skipped operation"); - } - Err(TestError::Other(error)) => { - log::error!("{} error: {}", client.username, error); - } - } - cx.background().simulate_random_delay().await; - } - log::info!("{}: done", client.username); -} - -fn buffer_for_full_path( - client: &TestClient, - project: &ModelHandle, - full_path: &PathBuf, - cx: &TestAppContext, -) -> Option> { - client - .buffers_for_project(project) - .iter() - .find(|buffer| { - buffer.read_with(cx, |buffer, cx| { - buffer.file().unwrap().full_path(cx) == *full_path - }) - }) - .cloned() -} - -fn project_for_root_name( - client: &TestClient, - root_name: &str, - cx: &TestAppContext, -) -> Option> { - if let Some(ix) = project_ix_for_root_name(&*client.local_projects(), root_name, cx) { - return Some(client.local_projects()[ix].clone()); - } - if let Some(ix) = project_ix_for_root_name(&*client.remote_projects(), root_name, cx) { - return Some(client.remote_projects()[ix].clone()); - } - None -} - -fn project_ix_for_root_name( - projects: &[ModelHandle], - root_name: &str, - cx: &TestAppContext, -) -> Option { - projects.iter().position(|project| { - project.read_with(cx, |project, cx| { - let worktree = project.visible_worktrees(cx).next().unwrap(); - worktree.read(cx).root_name() == root_name - }) - }) -} - -fn root_name_for_project(project: &ModelHandle, cx: &TestAppContext) -> String { - project.read_with(cx, |project, cx| { - project - .visible_worktrees(cx) - .next() - .unwrap() - .read(cx) - .root_name() - .to_string() - }) -} - -fn project_path_for_full_path( - project: &ModelHandle, - full_path: &Path, - cx: &TestAppContext, -) -> Option { - let mut components = full_path.components(); - let root_name = components.next().unwrap().as_os_str().to_str().unwrap(); - let path = components.as_path().into(); - let worktree_id = project.read_with(cx, |project, cx| { - project.worktrees(cx).find_map(|worktree| { - let worktree = worktree.read(cx); - if worktree.root_name() == root_name { - Some(worktree.id()) - } else { - None - } - }) - })?; - Some(ProjectPath { worktree_id, path }) -} - -async fn ensure_project_shared( - project: &ModelHandle, - client: &TestClient, - cx: &mut TestAppContext, -) { - let first_root_name = root_name_for_project(project, cx); - let active_call = cx.read(ActiveCall::global); - if active_call.read_with(cx, |call, _| call.room().is_some()) - && project.read_with(cx, |project, _| project.is_local() && !project.is_shared()) - { - match active_call - .update(cx, |call, cx| call.share_project(project.clone(), cx)) - .await - { - Ok(project_id) => { - log::info!( - "{}: shared project {} with id {}", - client.username, - first_root_name, - project_id - ); - } - Err(error) => { - log::error!( - "{}: error sharing project {}: {:?}", - client.username, - first_root_name, - error - ); - } - } - } -} - -fn choose_random_project(client: &TestClient, rng: &mut StdRng) -> Option> { - client - .local_projects() - .iter() - .chain(client.remote_projects().iter()) - .choose(rng) - .cloned() -} - -fn gen_file_name(rng: &mut StdRng) -> String { - let mut name = String::new(); - for _ in 0..10 { - let letter = rng.gen_range('a'..='z'); - name.push(letter); - } - name -} - -fn path_env_var(name: &str) -> Option { - let value = env::var(name).ok()?; - let mut path = PathBuf::from(value); - if path.is_relative() { - let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - abs_path.pop(); - abs_path.pop(); - abs_path.push(path); - path = abs_path - } - Some(path) -} diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs new file mode 100644 index 0000000000..dc102b75c6 --- /dev/null +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -0,0 +1,694 @@ +use crate::{ + db::{self, Database, NewUserParams, UserId}, + rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + tests::{TestClient, TestServer}, +}; +use async_trait::async_trait; +use futures::StreamExt; +use gpui::{executor::Deterministic, Task, TestAppContext}; +use parking_lot::Mutex; +use rand::prelude::*; +use rpc::RECEIVE_TIMEOUT; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use settings::SettingsStore; +use std::{ + env, + path::PathBuf, + rc::Rc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, +}; + +lazy_static::lazy_static! { + static ref PLAN_LOAD_PATH: Option = path_env_var("LOAD_PLAN"); + static ref PLAN_SAVE_PATH: Option = path_env_var("SAVE_PLAN"); + static ref MAX_PEERS: usize = env::var("MAX_PEERS") + .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) + .unwrap_or(3); + static ref MAX_OPERATIONS: usize = env::var("OPERATIONS") + .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) + .unwrap_or(10); + +} + +static LOADED_PLAN_JSON: Mutex>> = Mutex::new(None); +static LAST_PLAN: Mutex Vec>>> = Mutex::new(None); + +struct TestPlan { + rng: StdRng, + replay: bool, + stored_operations: Vec<(StoredOperation, Arc)>, + max_operations: usize, + operation_ix: usize, + users: Vec, + next_batch_id: usize, + allow_server_restarts: bool, + allow_client_reconnection: bool, + allow_client_disconnection: bool, +} + +pub struct UserTestPlan { + pub user_id: UserId, + pub username: String, + pub allow_client_reconnection: bool, + pub allow_client_disconnection: bool, + next_root_id: usize, + operation_ix: usize, + online: bool, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum StoredOperation { + Server(ServerOperation), + Client { + user_id: UserId, + batch_id: usize, + operation: T, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +enum ServerOperation { + AddConnection { + user_id: UserId, + }, + RemoveConnection { + user_id: UserId, + }, + BounceConnection { + user_id: UserId, + }, + RestartServer, + MutateClients { + batch_id: usize, + #[serde(skip_serializing)] + #[serde(skip_deserializing)] + user_ids: Vec, + quiesce: bool, + }, +} + +pub enum TestError { + Inapplicable, + Other(anyhow::Error), +} + +#[async_trait(?Send)] +pub trait RandomizedTest: 'static + Sized { + type Operation: Send + Clone + Serialize + DeserializeOwned; + + fn generate_operation( + client: &TestClient, + rng: &mut StdRng, + plan: &mut UserTestPlan, + cx: &TestAppContext, + ) -> Self::Operation; + + async fn on_client_added(client: &Rc); + + fn on_clients_quiesced(client: &[(Rc, TestAppContext)]); + + async fn apply_operation( + client: &TestClient, + operation: Self::Operation, + cx: &mut TestAppContext, + ) -> Result<(), TestError>; +} + +pub async fn run_randomized_test( + cx: &mut TestAppContext, + deterministic: Arc, + rng: StdRng, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let plan = TestPlan::::new(server.app_state.db.clone(), rng).await; + + LAST_PLAN.lock().replace({ + let plan = plan.clone(); + Box::new(move || plan.lock().serialize()) + }); + + let mut clients = Vec::new(); + let mut client_tasks = Vec::new(); + let mut operation_channels = Vec::new(); + loop { + let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else { + break; + }; + applied.store(true, SeqCst); + let did_apply = TestPlan::apply_server_operation( + plan.clone(), + deterministic.clone(), + &mut server, + &mut clients, + &mut client_tasks, + &mut operation_channels, + next_operation, + cx, + ) + .await; + if !did_apply { + applied.store(false, SeqCst); + } + } + + drop(operation_channels); + deterministic.start_waiting(); + futures::future::join_all(client_tasks).await; + deterministic.finish_waiting(); + + deterministic.run_until_parked(); + T::on_clients_quiesced(&clients); + + for (client, mut cx) in clients { + cx.update(|cx| { + let store = cx.remove_global::(); + cx.clear_globals(); + cx.set_global(store); + drop(client); + }); + } + deterministic.run_until_parked(); + + if let Some(path) = &*PLAN_SAVE_PATH { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, plan.lock().serialize()).unwrap(); + } +} + +pub fn save_randomized_test_plan() { + if let Some(serialize_plan) = LAST_PLAN.lock().take() { + if let Some(path) = &*PLAN_SAVE_PATH { + eprintln!("saved test plan to path {:?}", path); + std::fs::write(path, serialize_plan()).unwrap(); + } + } +} + +impl TestPlan { + pub async fn new(db: Arc, mut rng: StdRng) -> Arc> { + let allow_server_restarts = rng.gen_bool(0.7); + let allow_client_reconnection = rng.gen_bool(0.7); + let allow_client_disconnection = rng.gen_bool(0.1); + + let mut users = Vec::new(); + for ix in 0..*MAX_PEERS { + let username = format!("user-{}", ix + 1); + let user_id = db + .create_user( + &format!("{username}@example.com"), + false, + NewUserParams { + github_login: username.clone(), + github_user_id: (ix + 1) as i32, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + users.push(UserTestPlan { + user_id, + username, + online: false, + next_root_id: 0, + operation_ix: 0, + allow_client_disconnection, + allow_client_reconnection, + }); + } + + for (ix, user_a) in users.iter().enumerate() { + for user_b in &users[ix + 1..] { + db.send_contact_request(user_a.user_id, user_b.user_id) + .await + .unwrap(); + db.respond_to_contact_request(user_b.user_id, user_a.user_id, true) + .await + .unwrap(); + } + } + + let plan = Arc::new(Mutex::new(Self { + replay: false, + allow_server_restarts, + allow_client_reconnection, + allow_client_disconnection, + stored_operations: Vec::new(), + operation_ix: 0, + next_batch_id: 0, + max_operations: *MAX_OPERATIONS, + users, + rng, + })); + + if let Some(path) = &*PLAN_LOAD_PATH { + let json = LOADED_PLAN_JSON + .lock() + .get_or_insert_with(|| { + eprintln!("loaded test plan from path {:?}", path); + std::fs::read(path).unwrap() + }) + .clone(); + plan.lock().deserialize(json); + } + + plan + } + + fn deserialize(&mut self, json: Vec) { + let stored_operations: Vec> = + serde_json::from_slice(&json).unwrap(); + self.replay = true; + self.stored_operations = stored_operations + .iter() + .cloned() + .enumerate() + .map(|(i, mut operation)| { + let did_apply = Arc::new(AtomicBool::new(false)); + if let StoredOperation::Server(ServerOperation::MutateClients { + batch_id: current_batch_id, + user_ids, + .. + }) = &mut operation + { + assert!(user_ids.is_empty()); + user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| { + if let StoredOperation::Client { + user_id, batch_id, .. + } = operation + { + if batch_id == current_batch_id { + return Some(user_id); + } + } + None + })); + user_ids.sort_unstable(); + } + (operation, did_apply) + }) + .collect() + } + + fn serialize(&mut self) -> Vec { + // Format each operation as one line + let mut json = Vec::new(); + json.push(b'['); + for (operation, applied) in &self.stored_operations { + if !applied.load(SeqCst) { + continue; + } + if json.len() > 1 { + json.push(b','); + } + json.extend_from_slice(b"\n "); + serde_json::to_writer(&mut json, operation).unwrap(); + } + json.extend_from_slice(b"\n]\n"); + json + } + + fn next_server_operation( + &mut self, + clients: &[(Rc, TestAppContext)], + ) -> Option<(ServerOperation, Arc)> { + if self.replay { + while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) { + self.operation_ix += 1; + if let (StoredOperation::Server(operation), applied) = stored_operation { + return Some((operation.clone(), applied.clone())); + } + } + None + } else { + let operation = self.generate_server_operation(clients)?; + let applied = Arc::new(AtomicBool::new(false)); + self.stored_operations + .push((StoredOperation::Server(operation.clone()), applied.clone())); + Some((operation, applied)) + } + } + + fn next_client_operation( + &mut self, + client: &TestClient, + current_batch_id: usize, + cx: &TestAppContext, + ) -> Option<(T::Operation, Arc)> { + let current_user_id = client.current_user_id(cx); + let user_ix = self + .users + .iter() + .position(|user| user.user_id == current_user_id) + .unwrap(); + let user_plan = &mut self.users[user_ix]; + + if self.replay { + while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) { + user_plan.operation_ix += 1; + if let ( + StoredOperation::Client { + user_id, operation, .. + }, + applied, + ) = stored_operation + { + if user_id == ¤t_user_id { + return Some((operation.clone(), applied.clone())); + } + } + } + None + } else { + if self.operation_ix == self.max_operations { + return None; + } + self.operation_ix += 1; + let operation = T::generate_operation( + client, + &mut self.rng, + self.users + .iter_mut() + .find(|user| user.user_id == current_user_id) + .unwrap(), + cx, + ); + let applied = Arc::new(AtomicBool::new(false)); + self.stored_operations.push(( + StoredOperation::Client { + user_id: current_user_id, + batch_id: current_batch_id, + operation: operation.clone(), + }, + applied.clone(), + )); + Some((operation, applied)) + } + } + + fn generate_server_operation( + &mut self, + clients: &[(Rc, TestAppContext)], + ) -> Option { + if self.operation_ix == self.max_operations { + return None; + } + + Some(loop { + break match self.rng.gen_range(0..100) { + 0..=29 if clients.len() < self.users.len() => { + let user = self + .users + .iter() + .filter(|u| !u.online) + .choose(&mut self.rng) + .unwrap(); + self.operation_ix += 1; + ServerOperation::AddConnection { + user_id: user.user_id, + } + } + 30..=34 if clients.len() > 1 && self.allow_client_disconnection => { + let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let user_id = client.current_user_id(cx); + self.operation_ix += 1; + ServerOperation::RemoveConnection { user_id } + } + 35..=39 if clients.len() > 1 && self.allow_client_reconnection => { + let (client, cx) = &clients[self.rng.gen_range(0..clients.len())]; + let user_id = client.current_user_id(cx); + self.operation_ix += 1; + ServerOperation::BounceConnection { user_id } + } + 40..=44 if self.allow_server_restarts && clients.len() > 1 => { + self.operation_ix += 1; + ServerOperation::RestartServer + } + _ if !clients.is_empty() => { + let count = self + .rng + .gen_range(1..10) + .min(self.max_operations - self.operation_ix); + let batch_id = util::post_inc(&mut self.next_batch_id); + let mut user_ids = (0..count) + .map(|_| { + let ix = self.rng.gen_range(0..clients.len()); + let (client, cx) = &clients[ix]; + client.current_user_id(cx) + }) + .collect::>(); + user_ids.sort_unstable(); + ServerOperation::MutateClients { + user_ids, + batch_id, + quiesce: self.rng.gen_bool(0.7), + } + } + _ => continue, + }; + }) + } + + async fn apply_server_operation( + plan: Arc>, + deterministic: Arc, + server: &mut TestServer, + clients: &mut Vec<(Rc, TestAppContext)>, + client_tasks: &mut Vec>, + operation_channels: &mut Vec>, + operation: ServerOperation, + cx: &mut TestAppContext, + ) -> bool { + match operation { + ServerOperation::AddConnection { user_id } => { + let username; + { + let mut plan = plan.lock(); + let user = plan.user(user_id); + if user.online { + return false; + } + user.online = true; + username = user.username.clone(); + }; + log::info!("adding new connection for {}", username); + let next_entity_id = (user_id.0 * 10_000) as usize; + let mut client_cx = TestAppContext::new( + cx.foreground_platform(), + cx.platform(), + deterministic.build_foreground(user_id.0 as usize), + deterministic.build_background(), + cx.font_cache(), + cx.leak_detector(), + next_entity_id, + cx.function_name.clone(), + ); + + let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded(); + let client = Rc::new(server.create_client(&mut client_cx, &username).await); + operation_channels.push(operation_tx); + clients.push((client.clone(), client_cx.clone())); + client_tasks.push(client_cx.foreground().spawn(Self::simulate_client( + plan.clone(), + client, + operation_rx, + client_cx, + ))); + + log::info!("added connection for {}", username); + } + + ServerOperation::RemoveConnection { + user_id: removed_user_id, + } => { + log::info!("simulating full disconnection of user {}", removed_user_id); + let client_ix = clients + .iter() + .position(|(client, cx)| client.current_user_id(cx) == removed_user_id); + let Some(client_ix) = client_ix else { + return false; + }; + let user_connection_ids = server + .connection_pool + .lock() + .user_connection_ids(removed_user_id) + .collect::>(); + assert_eq!(user_connection_ids.len(), 1); + let removed_peer_id = user_connection_ids[0].into(); + let (client, mut client_cx) = clients.remove(client_ix); + let client_task = client_tasks.remove(client_ix); + operation_channels.remove(client_ix); + server.forbid_connections(); + server.disconnect_client(removed_peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + deterministic.start_waiting(); + log::info!("waiting for user {} to exit...", removed_user_id); + client_task.await; + deterministic.finish_waiting(); + server.allow_connections(); + + for project in client.remote_projects().iter() { + project.read_with(&client_cx, |project, _| { + assert!( + project.is_read_only(), + "project {:?} should be read only", + project.remote_id() + ) + }); + } + + for (client, cx) in clients { + let contacts = server + .app_state + .db + .get_contacts(client.current_user_id(cx)) + .await + .unwrap(); + let pool = server.connection_pool.lock(); + for contact in contacts { + if let db::Contact::Accepted { user_id, busy, .. } = contact { + if user_id == removed_user_id { + assert!(!pool.is_user_online(user_id)); + assert!(!busy); + } + } + } + } + + log::info!("{} removed", client.username); + plan.lock().user(removed_user_id).online = false; + client_cx.update(|cx| { + cx.clear_globals(); + drop(client); + }); + } + + ServerOperation::BounceConnection { user_id } => { + log::info!("simulating temporary disconnection of user {}", user_id); + let user_connection_ids = server + .connection_pool + .lock() + .user_connection_ids(user_id) + .collect::>(); + if user_connection_ids.is_empty() { + return false; + } + assert_eq!(user_connection_ids.len(), 1); + let peer_id = user_connection_ids[0].into(); + server.disconnect_client(peer_id); + deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + } + + ServerOperation::RestartServer => { + log::info!("simulating server restart"); + server.reset().await; + deterministic.advance_clock(RECEIVE_TIMEOUT); + server.start().await.unwrap(); + deterministic.advance_clock(CLEANUP_TIMEOUT); + let environment = &server.app_state.config.zed_environment; + let (stale_room_ids, _) = server + .app_state + .db + .stale_server_resource_ids(environment, server.id()) + .await + .unwrap(); + assert_eq!(stale_room_ids, vec![]); + } + + ServerOperation::MutateClients { + user_ids, + batch_id, + quiesce, + } => { + let mut applied = false; + for user_id in user_ids { + let client_ix = clients + .iter() + .position(|(client, cx)| client.current_user_id(cx) == user_id); + let Some(client_ix) = client_ix else { continue }; + applied = true; + if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) { + log::error!("error signaling user {user_id}: {err}"); + } + } + + if quiesce && applied { + deterministic.run_until_parked(); + T::on_clients_quiesced(&clients); + } + + return applied; + } + } + true + } + + async fn simulate_client( + plan: Arc>, + client: Rc, + mut operation_rx: futures::channel::mpsc::UnboundedReceiver, + mut cx: TestAppContext, + ) { + T::on_client_added(&client).await; + + while let Some(batch_id) = operation_rx.next().await { + let Some((operation, applied)) = + plan.lock().next_client_operation(&client, batch_id, &cx) + else { + break; + }; + applied.store(true, SeqCst); + match T::apply_operation(&client, operation, &mut cx).await { + Ok(()) => {} + Err(TestError::Inapplicable) => { + applied.store(false, SeqCst); + log::info!("skipped operation"); + } + Err(TestError::Other(error)) => { + log::error!("{} error: {}", client.username, error); + } + } + cx.background().simulate_random_delay().await; + } + log::info!("{}: done", client.username); + } + + fn user(&mut self, user_id: UserId) -> &mut UserTestPlan { + self.users + .iter_mut() + .find(|user| user.user_id == user_id) + .unwrap() + } +} + +impl UserTestPlan { + pub fn next_root_dir_name(&mut self) -> String { + let user_id = self.user_id; + let root_id = util::post_inc(&mut self.next_root_id); + format!("dir-{user_id}-{root_id}") + } +} + +impl From for TestError { + fn from(value: anyhow::Error) -> Self { + Self::Other(value) + } +} + +fn path_env_var(name: &str) -> Option { + let value = env::var(name).ok()?; + let mut path = PathBuf::from(value); + if path.is_relative() { + let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + abs_path.pop(); + abs_path.pop(); + abs_path.push(path); + path = abs_path + } + Some(path) +} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs new file mode 100644 index 0000000000..44f6ac1450 --- /dev/null +++ b/crates/collab/src/tests/test_server.rs @@ -0,0 +1,551 @@ +use crate::{ + db::{tests::TestDb, NewUserParams, UserId}, + executor::Executor, + rpc::{Server, CLEANUP_TIMEOUT}, + AppState, +}; +use anyhow::anyhow; +use call::ActiveCall; +use channel::ChannelStore; +use client::{ + self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore, +}; +use collections::{HashMap, HashSet}; +use fs::FakeFs; +use futures::{channel::oneshot, StreamExt as _}; +use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; +use language::LanguageRegistry; +use parking_lot::Mutex; +use project::{Project, WorktreeId}; +use settings::SettingsStore; +use std::{ + cell::{Ref, RefCell, RefMut}, + env, + ops::{Deref, DerefMut}, + path::Path, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, + Arc, + }, +}; +use util::http::FakeHttpClient; +use workspace::Workspace; + +pub struct TestServer { + pub app_state: Arc, + pub test_live_kit_server: Arc, + server: Arc, + connection_killers: Arc>>>, + forbid_connections: Arc, + _test_db: TestDb, +} + +pub struct TestClient { + pub username: String, + pub app_state: Arc, + state: RefCell, +} + +#[derive(Default)] +struct TestClientState { + local_projects: Vec>, + remote_projects: Vec>, + buffers: HashMap, HashSet>>, +} + +pub struct ContactsSummary { + pub current: Vec, + pub outgoing_requests: Vec, + pub incoming_requests: Vec, +} + +impl TestServer { + pub async fn start(deterministic: &Arc) -> Self { + static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); + + let use_postgres = env::var("USE_POSTGRES").ok(); + let use_postgres = use_postgres.as_deref(); + let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { + TestDb::postgres(deterministic.build_background()) + } else { + TestDb::sqlite(deterministic.build_background()) + }; + let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); + let live_kit_server = live_kit_client::TestServer::create( + format!("http://livekit.{}.test", live_kit_server_id), + format!("devkey-{}", live_kit_server_id), + format!("secret-{}", live_kit_server_id), + deterministic.build_background(), + ) + .unwrap(); + let app_state = Self::build_app_state(&test_db, &live_kit_server).await; + let epoch = app_state + .db + .create_server(&app_state.config.zed_environment) + .await + .unwrap(); + let server = Server::new( + epoch, + app_state.clone(), + Executor::Deterministic(deterministic.build_background()), + ); + server.start().await.unwrap(); + // Advance clock to ensure the server's cleanup task is finished. + deterministic.advance_clock(CLEANUP_TIMEOUT); + Self { + app_state, + server, + connection_killers: Default::default(), + forbid_connections: Default::default(), + _test_db: test_db, + test_live_kit_server: live_kit_server, + } + } + + pub async fn reset(&self) { + self.app_state.db.reset(); + let epoch = self + .app_state + .db + .create_server(&self.app_state.config.zed_environment) + .await + .unwrap(); + self.server.reset(epoch); + } + + pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + cx.update(|cx| { + if cx.has_global::() { + panic!("Same cx used to create two test clients") + } + cx.set_global(SettingsStore::test(cx)); + }); + + let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await + { + user.id + } else { + self.app_state + .db + .create_user( + &format!("{name}@example.com"), + false, + NewUserParams { + github_login: name.into(), + github_user_id: 0, + invite_count: 0, + }, + ) + .await + .expect("creating user failed") + .user_id + }; + let client_name = name.to_string(); + let mut client = cx.read(|cx| Client::new(http.clone(), cx)); + let server = self.server.clone(); + let db = self.app_state.db.clone(); + let connection_killers = self.connection_killers.clone(); + let forbid_connections = self.forbid_connections.clone(); + + Arc::get_mut(&mut client) + .unwrap() + .set_id(user_id.0 as usize) + .override_authenticate(move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok(Credentials { + user_id: user_id.0 as u64, + access_token, + }) + }) + }) + .override_establish_connection(move |credentials, cx| { + assert_eq!(credentials.user_id, user_id.0 as u64); + assert_eq!(credentials.access_token, "the-token"); + + let server = server.clone(); + let db = db.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(EstablishConnectionError::other(anyhow!( + "server is forbidding connections" + ))) + } else { + let (client_conn, server_conn, killed) = + Connection::in_memory(cx.background()); + let (connection_id_tx, connection_id_rx) = oneshot::channel(); + let user = db + .get_user_by_id(user_id) + .await + .expect("retrieving user failed") + .unwrap(); + cx.background() + .spawn(server.handle_connection( + server_conn, + client_name, + user, + Some(connection_id_tx), + Executor::Deterministic(cx.background()), + )) + .detach(); + let connection_id = connection_id_rx.await.unwrap(); + connection_killers + .lock() + .insert(connection_id.into(), killed); + Ok(client_conn) + } + }) + }); + + let fs = FakeFs::new(cx.background()); + let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); + let channel_store = + cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx)); + let app_state = Arc::new(workspace::AppState { + client: client.clone(), + user_store: user_store.clone(), + channel_store: channel_store.clone(), + languages: Arc::new(LanguageRegistry::test()), + fs: fs.clone(), + build_window_options: |_, _, _| Default::default(), + initialize_workspace: |_, _, _, _| Task::ready(Ok(())), + background_actions: || &[], + }); + + cx.update(|cx| { + theme::init((), cx); + Project::init(&client, cx); + client::init(&client, cx); + language::init(cx); + editor::init_settings(cx); + workspace::init(app_state.clone(), cx); + audio::init((), cx); + call::init(client.clone(), user_store.clone(), cx); + channel::init(&client); + }); + + client + .authenticate_and_connect(false, &cx.to_async()) + .await + .unwrap(); + + let client = TestClient { + app_state, + username: name.to_string(), + state: Default::default(), + }; + client.wait_for_current_user(cx).await; + client + } + + pub fn disconnect_client(&self, peer_id: PeerId) { + self.connection_killers + .lock() + .remove(&peer_id) + .unwrap() + .store(true, SeqCst); + } + + pub fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + pub fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); + } + + pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { + for ix in 1..clients.len() { + let (left, right) = clients.split_at_mut(ix); + let (client_a, cx_a) = left.last_mut().unwrap(); + for (client_b, cx_b) in right { + client_a + .app_state + .user_store + .update(*cx_a, |store, cx| { + store.request_contact(client_b.user_id().unwrap(), cx) + }) + .await + .unwrap(); + cx_a.foreground().run_until_parked(); + client_b + .app_state + .user_store + .update(*cx_b, |store, cx| { + store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx) + }) + .await + .unwrap(); + } + } + } + + pub async fn make_channel( + &self, + channel: &str, + admin: (&TestClient, &mut TestAppContext), + members: &mut [(&TestClient, &mut TestAppContext)], + ) -> u64 { + let (admin_client, admin_cx) = admin; + let channel_id = admin_client + .app_state + .channel_store + .update(admin_cx, |channel_store, cx| { + channel_store.create_channel(channel, None, cx) + }) + .await + .unwrap(); + + for (member_client, member_cx) in members { + admin_client + .app_state + .channel_store + .update(admin_cx, |channel_store, cx| { + channel_store.invite_member( + channel_id, + member_client.user_id().unwrap(), + false, + cx, + ) + }) + .await + .unwrap(); + + admin_cx.foreground().run_until_parked(); + + member_client + .app_state + .channel_store + .update(*member_cx, |channels, _| { + channels.respond_to_channel_invite(channel_id, true) + }) + .await + .unwrap(); + } + + channel_id + } + + pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) { + self.make_contacts(clients).await; + + let (left, right) = clients.split_at_mut(1); + let (_client_a, cx_a) = &mut left[0]; + let active_call_a = cx_a.read(ActiveCall::global); + + for (client_b, cx_b) in right { + let user_id_b = client_b.current_user_id(*cx_b).to_proto(); + active_call_a + .update(*cx_a, |call, cx| call.invite(user_id_b, None, cx)) + .await + .unwrap(); + + cx_b.foreground().run_until_parked(); + let active_call_b = cx_b.read(ActiveCall::global); + active_call_b + .update(*cx_b, |call, cx| call.accept_incoming(cx)) + .await + .unwrap(); + } + } + + pub async fn build_app_state( + test_db: &TestDb, + fake_server: &live_kit_client::TestServer, + ) -> Arc { + Arc::new(AppState { + db: test_db.db().clone(), + live_kit_client: Some(Arc::new(fake_server.create_api_client())), + config: Default::default(), + }) + } +} + +impl Deref for TestServer { + type Target = Server; + + fn deref(&self) -> &Self::Target { + &self.server + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.server.teardown(); + self.test_live_kit_server.teardown().unwrap(); + } +} + +impl Deref for TestClient { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.app_state.client + } +} + +impl TestClient { + pub fn fs(&self) -> &FakeFs { + self.app_state.fs.as_fake() + } + + pub fn channel_store(&self) -> &ModelHandle { + &self.app_state.channel_store + } + + pub fn user_store(&self) -> &ModelHandle { + &self.app_state.user_store + } + + pub fn language_registry(&self) -> &Arc { + &self.app_state.languages + } + + pub fn client(&self) -> &Arc { + &self.app_state.client + } + + pub fn current_user_id(&self, cx: &TestAppContext) -> UserId { + UserId::from_proto( + self.app_state + .user_store + .read_with(cx, |user_store, _| user_store.current_user().unwrap().id), + ) + } + + pub async fn wait_for_current_user(&self, cx: &TestAppContext) { + let mut authed_user = self + .app_state + .user_store + .read_with(cx, |user_store, _| user_store.watch_current_user()); + while authed_user.next().await.unwrap().is_none() {} + } + + pub async fn clear_contacts(&self, cx: &mut TestAppContext) { + self.app_state + .user_store + .update(cx, |store, _| store.clear_contacts()) + .await; + } + + pub fn local_projects<'a>(&'a self) -> impl Deref>> + 'a { + Ref::map(self.state.borrow(), |state| &state.local_projects) + } + + pub fn remote_projects<'a>(&'a self) -> impl Deref>> + 'a { + Ref::map(self.state.borrow(), |state| &state.remote_projects) + } + + pub fn local_projects_mut<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects) + } + + pub fn remote_projects_mut<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects) + } + + pub fn buffers_for_project<'a>( + &'a self, + project: &ModelHandle, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| { + state.buffers.entry(project.clone()).or_default() + }) + } + + pub fn buffers<'a>( + &'a self, + ) -> impl DerefMut, HashSet>>> + 'a + { + RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers) + } + + pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary { + self.app_state + .user_store + .read_with(cx, |store, _| ContactsSummary { + current: store + .contacts() + .iter() + .map(|contact| contact.user.github_login.clone()) + .collect(), + outgoing_requests: store + .outgoing_contact_requests() + .iter() + .map(|user| user.github_login.clone()) + .collect(), + incoming_requests: store + .incoming_contact_requests() + .iter() + .map(|user| user.github_login.clone()) + .collect(), + }) + } + + pub async fn build_local_project( + &self, + root_path: impl AsRef, + cx: &mut TestAppContext, + ) -> (ModelHandle, WorktreeId) { + let project = cx.update(|cx| { + Project::local( + self.client().clone(), + self.app_state.user_store.clone(), + self.app_state.languages.clone(), + self.app_state.fs.clone(), + cx, + ) + }); + let (worktree, _) = project + .update(cx, |p, cx| { + p.find_or_create_local_worktree(root_path, true, cx) + }) + .await + .unwrap(); + worktree + .read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + (project, worktree.read_with(cx, |tree, _| tree.id())) + } + + pub async fn build_remote_project( + &self, + host_project_id: u64, + guest_cx: &mut TestAppContext, + ) -> ModelHandle { + let active_call = guest_cx.read(ActiveCall::global); + let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone()); + room.update(guest_cx, |room, cx| { + room.join_project( + host_project_id, + self.app_state.languages.clone(), + self.app_state.fs.clone(), + cx, + ) + }) + .await + .unwrap() + } + + pub fn build_workspace( + &self, + project: &ModelHandle, + cx: &mut TestAppContext, + ) -> WindowHandle { + cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx)) + } +} + +impl Drop for TestClient { + fn drop(&mut self) { + self.app_state.client.teardown(); + } +} diff --git a/crates/gpui_macros/src/gpui_macros.rs b/crates/gpui_macros/src/gpui_macros.rs index 7f70bc6a91..16f293afbe 100644 --- a/crates/gpui_macros/src/gpui_macros.rs +++ b/crates/gpui_macros/src/gpui_macros.rs @@ -37,8 +37,14 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { Some("seed") => starting_seed = parse_int(&meta.lit)?, Some("on_failure") => { if let Lit::Str(name) = meta.lit { - let ident = Ident::new(&name.value(), name.span()); - on_failure_fn_name = quote!(Some(#ident)); + let mut path = syn::Path { + leading_colon: None, + segments: Default::default(), + }; + for part in name.value().split("::") { + path.segments.push(Ident::new(part, name.span()).into()); + } + on_failure_fn_name = quote!(Some(#path)); } else { return Err(TokenStream::from( syn::Error::new( From e779adfe468fca4a6fd81435eec5f717846aeb52 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 6 Sep 2023 14:09:36 -0700 Subject: [PATCH 53/60] Add basic randomized integration test for channel notes --- crates/channel/src/channel_store.rs | 11 +- crates/collab/src/db/queries/channels.rs | 14 ++ .../src/tests/random_channel_buffer_tests.rs | 237 +++++++++++++++++- .../random_project_collaboration_tests.rs | 26 +- .../src/tests/randomized_test_helpers.rs | 37 ++- crates/collab/src/tests/test_server.rs | 9 +- 6 files changed, 295 insertions(+), 39 deletions(-) diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 3d2f61d61f..a4c8da6df4 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Result}; use client::{Client, Subscription, User, UserId, UserStore}; use collections::{hash_map, HashMap, HashSet}; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; -use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; +use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{proto, TypedEnvelope}; use std::{mem, sync::Arc, time::Duration}; use util::ResultExt; @@ -152,6 +152,15 @@ impl ChannelStore { self.channels_by_id.get(&channel_id) } + pub fn has_open_channel_buffer(&self, channel_id: ChannelId, cx: &AppContext) -> bool { + if let Some(buffer) = self.opened_buffers.get(&channel_id) { + if let OpenedChannelBuffer::Open(buffer) = buffer { + return buffer.upgrade(cx).is_some(); + } + } + false + } + pub fn open_channel_buffer( &mut self, channel_id: ChannelId, diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index e3d3643a61..5da4dd1464 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -1,6 +1,20 @@ use super::*; impl Database { + #[cfg(test)] + pub async fn all_channels(&self) -> Result> { + self.transaction(move |tx| async move { + let mut channels = Vec::new(); + let mut rows = channel::Entity::find().stream(&*tx).await?; + while let Some(row) = rows.next().await { + let row = row?; + channels.push((row.id, row.name)); + } + Ok(channels) + }) + .await + } + pub async fn create_root_channel( &self, name: &str, diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs b/crates/collab/src/tests/random_channel_buffer_tests.rs index 929e567977..933683eaa6 100644 --- a/crates/collab/src/tests/random_channel_buffer_tests.rs +++ b/crates/collab/src/tests/random_channel_buffer_tests.rs @@ -1,12 +1,16 @@ -use crate::tests::{run_randomized_test, RandomizedTest, TestClient, TestError, UserTestPlan}; +use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; use anyhow::Result; use async_trait::async_trait; use gpui::{executor::Deterministic, TestAppContext}; -use rand::rngs::StdRng; +use rand::prelude::*; use serde_derive::{Deserialize, Serialize}; -use std::{rc::Rc, sync::Arc}; +use std::{ops::Range, rc::Rc, sync::Arc}; +use text::Bias; -#[gpui::test] +#[gpui::test( + iterations = 100, + on_failure = "crate::tests::save_randomized_test_plan" +)] async fn test_random_channel_buffers( cx: &mut TestAppContext, deterministic: Arc, @@ -19,20 +23,105 @@ struct RandomChannelBufferTest; #[derive(Clone, Serialize, Deserialize)] enum ChannelBufferOperation { - Join, + JoinChannelNotes { + channel_name: String, + }, + LeaveChannelNotes { + channel_name: String, + }, + EditChannelNotes { + channel_name: String, + edits: Vec<(Range, Arc)>, + }, + Noop, } +const CHANNEL_COUNT: usize = 3; + #[async_trait(?Send)] impl RandomizedTest for RandomChannelBufferTest { type Operation = ChannelBufferOperation; + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) { + let db = &server.app_state.db; + for ix in 0..CHANNEL_COUNT { + let id = db + .create_channel( + &format!("channel-{ix}"), + None, + &format!("livekit-room-{ix}"), + users[0].user_id, + ) + .await + .unwrap(); + for user in &users[1..] { + db.invite_channel_member(id, user.user_id, users[0].user_id, false) + .await + .unwrap(); + db.respond_to_channel_invite(id, user.user_id, true) + .await + .unwrap(); + } + } + } + fn generate_operation( client: &TestClient, rng: &mut StdRng, - plan: &mut UserTestPlan, + _: &mut UserTestPlan, cx: &TestAppContext, ) -> ChannelBufferOperation { - ChannelBufferOperation::Join + let channel_store = client.channel_store().clone(); + let channel_buffers = client.channel_buffers(); + + // When signed out, we can't do anything unless a channel buffer is + // already open. + if channel_buffers.is_empty() + && channel_store.read_with(cx, |store, _| store.channel_count() == 0) + { + return ChannelBufferOperation::Noop; + } + + loop { + match rng.gen_range(0..100_u32) { + 0..=29 => { + let channel_name = client.channel_store().read_with(cx, |store, cx| { + store.channels().find_map(|(_, channel)| { + if store.has_open_channel_buffer(channel.id, cx) { + None + } else { + Some(channel.name.clone()) + } + }) + }); + if let Some(channel_name) = channel_name { + break ChannelBufferOperation::JoinChannelNotes { channel_name }; + } + } + + 30..=40 => { + if let Some(buffer) = channel_buffers.iter().choose(rng) { + let channel_name = buffer.read_with(cx, |b, _| b.channel().name.clone()); + break ChannelBufferOperation::LeaveChannelNotes { channel_name }; + } + } + + _ => { + if let Some(buffer) = channel_buffers.iter().choose(rng) { + break buffer.read_with(cx, |b, _| { + let channel_name = b.channel().name.clone(); + let edits = b + .buffer() + .read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3)); + ChannelBufferOperation::EditChannelNotes { + channel_name, + edits, + } + }); + } + } + } + } } async fn apply_operation( @@ -40,10 +129,140 @@ impl RandomizedTest for RandomChannelBufferTest { operation: ChannelBufferOperation, cx: &mut TestAppContext, ) -> Result<(), TestError> { + match operation { + ChannelBufferOperation::JoinChannelNotes { channel_name } => { + let buffer = client.channel_store().update(cx, |store, cx| { + let channel_id = store + .channels() + .find(|(_, c)| c.name == channel_name) + .unwrap() + .1 + .id; + if store.has_open_channel_buffer(channel_id, cx) { + Err(TestError::Inapplicable) + } else { + Ok(store.open_channel_buffer(channel_id, cx)) + } + })?; + + log::info!( + "{}: opening notes for channel {channel_name}", + client.username + ); + client.channel_buffers().insert(buffer.await?); + } + + ChannelBufferOperation::LeaveChannelNotes { channel_name } => { + let buffer = cx.update(|cx| { + let mut left_buffer = Err(TestError::Inapplicable); + client.channel_buffers().retain(|buffer| { + if buffer.read(cx).channel().name == channel_name { + left_buffer = Ok(buffer.clone()); + false + } else { + true + } + }); + left_buffer + })?; + + log::info!( + "{}: closing notes for channel {channel_name}", + client.username + ); + cx.update(|_| drop(buffer)); + } + + ChannelBufferOperation::EditChannelNotes { + channel_name, + edits, + } => { + let channel_buffer = cx + .read(|cx| { + client + .channel_buffers() + .iter() + .find(|buffer| buffer.read(cx).channel().name == channel_name) + .cloned() + }) + .ok_or_else(|| TestError::Inapplicable)?; + + log::info!( + "{}: editing notes for channel {channel_name} with {:?}", + client.username, + edits + ); + + channel_buffer.update(cx, |buffer, cx| { + let buffer = buffer.buffer(); + buffer.update(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + buffer.edit( + edits.into_iter().map(|(range, text)| { + let start = snapshot.clip_offset(range.start, Bias::Left); + let end = snapshot.clip_offset(range.end, Bias::Right); + (start..end, text) + }), + None, + cx, + ); + }); + }); + } + + ChannelBufferOperation::Noop => Err(TestError::Inapplicable)?, + } Ok(()) } - async fn on_client_added(client: &Rc) {} + async fn on_client_added(client: &Rc, cx: &mut TestAppContext) { + let channel_store = client.channel_store(); + while channel_store.read_with(cx, |store, _| store.channel_count() == 0) { + channel_store.next_notification(cx).await; + } + } - fn on_clients_quiesced(clients: &[(Rc, TestAppContext)]) {} + async fn on_quiesce(server: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { + let channels = server.app_state.db.all_channels().await.unwrap(); + + for (channel_id, channel_name) in channels { + let mut collaborator_user_ids = server + .app_state + .db + .get_channel_buffer_collaborators(channel_id) + .await + .unwrap() + .into_iter() + .map(|id| id.to_proto()) + .collect::>(); + collaborator_user_ids.sort(); + + for (client, client_cx) in clients.iter_mut() { + client_cx.update(|cx| { + client + .channel_buffers() + .retain(|b| b.read(cx).is_connected()); + + if let Some(channel_buffer) = client + .channel_buffers() + .iter() + .find(|b| b.read(cx).channel().id == channel_id.to_proto()) + { + let channel_buffer = channel_buffer.read(cx); + let collaborators = channel_buffer.collaborators(); + let mut user_ids = + collaborators.iter().map(|c| c.user_id).collect::>(); + user_ids.sort(); + assert_eq!( + user_ids, + collaborator_user_ids, + "client {} has different user ids for channel {} than the server", + client.user_id().unwrap(), + channel_name + ); + } + }); + } + } + } } diff --git a/crates/collab/src/tests/random_project_collaboration_tests.rs b/crates/collab/src/tests/random_project_collaboration_tests.rs index 242cfbc162..7570768249 100644 --- a/crates/collab/src/tests/random_project_collaboration_tests.rs +++ b/crates/collab/src/tests/random_project_collaboration_tests.rs @@ -1,7 +1,5 @@ -use crate::{ - db::UserId, - tests::{run_randomized_test, RandomizedTest, TestClient, TestError, UserTestPlan}, -}; +use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan}; +use crate::db::UserId; use anyhow::{anyhow, Result}; use async_trait::async_trait; use call::ActiveCall; @@ -145,6 +143,20 @@ struct ProjectCollaborationTest; impl RandomizedTest for ProjectCollaborationTest { type Operation = ClientOperation; + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) { + let db = &server.app_state.db; + for (ix, user_a) in users.iter().enumerate() { + for user_b in &users[ix + 1..] { + db.send_contact_request(user_a.user_id, user_b.user_id) + .await + .unwrap(); + db.respond_to_contact_request(user_b.user_id, user_a.user_id, true) + .await + .unwrap(); + } + } + } + fn generate_operation( client: &TestClient, rng: &mut StdRng, @@ -1005,7 +1017,7 @@ impl RandomizedTest for ProjectCollaborationTest { Ok(()) } - async fn on_client_added(client: &Rc) { + async fn on_client_added(client: &Rc, _: &mut TestAppContext) { let mut language = Language::new( LanguageConfig { name: "Rust".into(), @@ -1119,8 +1131,8 @@ impl RandomizedTest for ProjectCollaborationTest { client.app_state.languages.add(Arc::new(language)); } - fn on_clients_quiesced(clients: &[(Rc, TestAppContext)]) { - for (client, client_cx) in clients { + async fn on_quiesce(_: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { + for (client, client_cx) in clients.iter() { for guest_project in client.remote_projects().iter() { guest_project.read_with(client_cx, |guest_project, cx| { let host_project = clients.iter().find_map(|(client, cx)| { diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs index dc102b75c6..39598bdaf9 100644 --- a/crates/collab/src/tests/randomized_test_helpers.rs +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -1,5 +1,5 @@ use crate::{ - db::{self, Database, NewUserParams, UserId}, + db::{self, NewUserParams, UserId}, rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, tests::{TestClient, TestServer}, }; @@ -107,15 +107,17 @@ pub trait RandomizedTest: 'static + Sized { cx: &TestAppContext, ) -> Self::Operation; - async fn on_client_added(client: &Rc); - - fn on_clients_quiesced(client: &[(Rc, TestAppContext)]); - async fn apply_operation( client: &TestClient, operation: Self::Operation, cx: &mut TestAppContext, ) -> Result<(), TestError>; + + async fn initialize(server: &mut TestServer, users: &[UserTestPlan]); + + async fn on_client_added(client: &Rc, cx: &mut TestAppContext); + + async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc, TestAppContext)]); } pub async fn run_randomized_test( @@ -125,7 +127,7 @@ pub async fn run_randomized_test( ) { deterministic.forbid_parking(); let mut server = TestServer::start(&deterministic).await; - let plan = TestPlan::::new(server.app_state.db.clone(), rng).await; + let plan = TestPlan::::new(&mut server, rng).await; LAST_PLAN.lock().replace({ let plan = plan.clone(); @@ -162,7 +164,7 @@ pub async fn run_randomized_test( deterministic.finish_waiting(); deterministic.run_until_parked(); - T::on_clients_quiesced(&clients); + T::on_quiesce(&mut server, &mut clients).await; for (client, mut cx) in clients { cx.update(|cx| { @@ -190,7 +192,7 @@ pub fn save_randomized_test_plan() { } impl TestPlan { - pub async fn new(db: Arc, mut rng: StdRng) -> Arc> { + pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc> { let allow_server_restarts = rng.gen_bool(0.7); let allow_client_reconnection = rng.gen_bool(0.7); let allow_client_disconnection = rng.gen_bool(0.1); @@ -198,7 +200,9 @@ impl TestPlan { let mut users = Vec::new(); for ix in 0..*MAX_PEERS { let username = format!("user-{}", ix + 1); - let user_id = db + let user_id = server + .app_state + .db .create_user( &format!("{username}@example.com"), false, @@ -222,16 +226,7 @@ impl TestPlan { }); } - for (ix, user_a) in users.iter().enumerate() { - for user_b in &users[ix + 1..] { - db.send_contact_request(user_a.user_id, user_b.user_id) - .await - .unwrap(); - db.respond_to_contact_request(user_b.user_id, user_a.user_id, true) - .await - .unwrap(); - } - } + T::initialize(server, &users).await; let plan = Arc::new(Mutex::new(Self { replay: false, @@ -619,7 +614,7 @@ impl TestPlan { if quiesce && applied { deterministic.run_until_parked(); - T::on_clients_quiesced(&clients); + T::on_quiesce(server, clients).await; } return applied; @@ -634,7 +629,7 @@ impl TestPlan { mut operation_rx: futures::channel::mpsc::UnboundedReceiver, mut cx: TestAppContext, ) { - T::on_client_added(&client).await; + T::on_client_added(&client, &mut cx).await; while let Some(batch_id) = operation_rx.next().await { let Some((operation, applied)) = diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 44f6ac1450..eef1dde967 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -6,7 +6,7 @@ use crate::{ }; use anyhow::anyhow; use call::ActiveCall; -use channel::ChannelStore; +use channel::{channel_buffer::ChannelBuffer, ChannelStore}; use client::{ self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore, }; @@ -51,6 +51,7 @@ struct TestClientState { local_projects: Vec>, remote_projects: Vec>, buffers: HashMap, HashSet>>, + channel_buffers: HashSet>, } pub struct ContactsSummary { @@ -468,6 +469,12 @@ impl TestClient { RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers) } + pub fn channel_buffers<'a>( + &'a self, + ) -> impl DerefMut>> + 'a { + RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers) + } + pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary { self.app_state .user_store From b75e69d31b5874d2fe3fb0b69008a0d40349b41d Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 6 Sep 2023 14:25:07 -0700 Subject: [PATCH 54/60] Check that channel notes text converges in randomized test --- .../src/tests/random_channel_buffer_tests.rs | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/crates/collab/src/tests/random_channel_buffer_tests.rs b/crates/collab/src/tests/random_channel_buffer_tests.rs index 933683eaa6..a60d3d7d7d 100644 --- a/crates/collab/src/tests/random_channel_buffer_tests.rs +++ b/crates/collab/src/tests/random_channel_buffer_tests.rs @@ -225,7 +225,17 @@ impl RandomizedTest for RandomChannelBufferTest { async fn on_quiesce(server: &mut TestServer, clients: &mut [(Rc, TestAppContext)]) { let channels = server.app_state.db.all_channels().await.unwrap(); + for (client, client_cx) in clients.iter_mut() { + client_cx.update(|cx| { + client + .channel_buffers() + .retain(|b| b.read(cx).is_connected()); + }); + } + for (channel_id, channel_name) in channels { + let mut prev_text: Option<(u64, String)> = None; + let mut collaborator_user_ids = server .app_state .db @@ -237,18 +247,30 @@ impl RandomizedTest for RandomChannelBufferTest { .collect::>(); collaborator_user_ids.sort(); - for (client, client_cx) in clients.iter_mut() { - client_cx.update(|cx| { - client - .channel_buffers() - .retain(|b| b.read(cx).is_connected()); - + for (client, client_cx) in clients.iter() { + let user_id = client.user_id().unwrap(); + client_cx.read(|cx| { if let Some(channel_buffer) = client .channel_buffers() .iter() .find(|b| b.read(cx).channel().id == channel_id.to_proto()) { let channel_buffer = channel_buffer.read(cx); + + // Assert that channel buffer's text matches other clients' copies. + let text = channel_buffer.buffer().read(cx).text(); + if let Some((prev_user_id, prev_text)) = &prev_text { + assert_eq!( + &text, + prev_text, + "client {user_id} has different text than client {prev_user_id} for channel {channel_name}", + ); + } else { + prev_text = Some((user_id, text.clone())); + } + + // Assert that all clients and the server agree about who is present in the + // channel buffer. let collaborators = channel_buffer.collaborators(); let mut user_ids = collaborators.iter().map(|c| c.user_id).collect::>(); @@ -256,9 +278,7 @@ impl RandomizedTest for RandomChannelBufferTest { assert_eq!( user_ids, collaborator_user_ids, - "client {} has different user ids for channel {} than the server", - client.user_id().unwrap(), - channel_name + "client {user_id} has different user ids for channel {channel_name} than the server", ); } }); From ed2aed4f93c7549252cba5ac533b15bd03e00766 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 6 Sep 2023 14:29:11 -0700 Subject: [PATCH 55/60] Update test name in randomized-test-minimize script --- script/randomized-test-minimize | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/script/randomized-test-minimize b/script/randomized-test-minimize index ce0b7203b4..df003cbf3e 100755 --- a/script/randomized-test-minimize +++ b/script/randomized-test-minimize @@ -9,7 +9,6 @@ const CARGO_TEST_ARGS = [ '--release', '--lib', '--package', 'collab', - 'random_collaboration', ] if (require.main === module) { @@ -99,7 +98,7 @@ function buildTests() { } function runTests(env) { - const {status, stdout} = spawnSync('cargo', ['test', ...CARGO_TEST_ARGS], { + const {status, stdout} = spawnSync('cargo', ['test', ...CARGO_TEST_ARGS, 'random_project_collaboration'], { stdio: 'pipe', encoding: 'utf8', env: { From 58f58a629b86d5659bdf5ce6dc1b96c08104b6d6 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 6 Sep 2023 14:58:25 -0700 Subject: [PATCH 56/60] Tolerate channel buffer operations being re-sent --- crates/collab/src/db/queries/buffers.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index 8236eb9c3b..00de201403 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -380,6 +380,16 @@ impl Database { .collect::>(); if !operations.is_empty() { buffer_operation::Entity::insert_many(operations) + .on_conflict( + OnConflict::columns([ + buffer_operation::Column::BufferId, + buffer_operation::Column::Epoch, + buffer_operation::Column::LamportTimestamp, + buffer_operation::Column::ReplicaId, + ]) + .do_nothing() + .to_owned(), + ) .exec(&*tx) .await?; } From 39e13b667554691f7685c5b0e07dc8e0a479e6ef Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Wed, 6 Sep 2023 22:53:05 -0400 Subject: [PATCH 57/60] Allow call events to be logged without a room id --- crates/call/src/call.rs | 30 ++++++++++++++++------------ crates/client/src/telemetry.rs | 2 +- crates/collab_ui/src/collab_panel.rs | 18 +++++++++++++---- crates/collab_ui/src/collab_ui.rs | 8 ++++---- 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 5886462ccf..4db298fe98 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -273,7 +273,13 @@ impl ActiveCall { .borrow_mut() .take() .ok_or_else(|| anyhow!("no incoming call"))?; - Self::report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx); + Self::report_call_event_for_room( + "decline incoming", + Some(call.room_id), + None, + &self.client, + cx, + ); self.client.send(proto::DeclineCall { room_id: call.room_id, })?; @@ -403,22 +409,20 @@ impl ActiveCall { &self.pending_invites } - pub fn report_call_event(&self, operation: &'static str, cx: &AppContext) { - if let Some(room) = self.room() { - let room = room.read(cx); - Self::report_call_event_for_room( - operation, - room.id(), - room.channel_id(), - &self.client, - cx, - ) - } + fn report_call_event(&self, operation: &'static str, cx: &AppContext) { + let (room_id, channel_id) = match self.room() { + Some(room) => { + let room = room.read(cx); + (Some(room.id()), room.channel_id()) + } + None => (None, None), + }; + Self::report_call_event_for_room(operation, room_id, channel_id, &self.client, cx) } pub fn report_call_event_for_room( operation: &'static str, - room_id: u64, + room_id: Option, channel_id: Option, client: &Arc, cx: &AppContext, diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index 9cc5d13af0..f8642dd7fa 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -73,7 +73,7 @@ pub enum ClickhouseEvent { }, Call { operation: &'static str, - room_id: u64, + room_id: Option, channel_id: Option, }, } diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index d27cdc8acf..fba10c61ba 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2240,7 +2240,8 @@ impl CollabPanel { fn open_channel_buffer(&mut self, action: &OpenChannelBuffer, cx: &mut ViewContext) { if let Some(workspace) = self.workspace.upgrade(cx) { let pane = workspace.read(cx).active_pane().clone(); - let channel_view = ChannelView::open(action.channel_id, pane.clone(), workspace, cx); + let channel_id = action.channel_id; + let channel_view = ChannelView::open(channel_id, pane.clone(), workspace, cx); cx.spawn(|_, mut cx| async move { let channel_view = channel_view.await?; pane.update(&mut cx, |pane, cx| { @@ -2249,9 +2250,18 @@ impl CollabPanel { anyhow::Ok(()) }) .detach(); - ActiveCall::global(cx).update(cx, |call, cx| { - call.report_call_event("open channel notes", cx) - }); + let room_id = ActiveCall::global(cx) + .read(cx) + .room() + .map(|room| room.read(cx).id()); + + ActiveCall::report_call_event_for_room( + "open channel notes", + room_id, + Some(channel_id), + &self.client, + cx, + ); } } diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index 04644b62d9..ee34f600fa 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -49,7 +49,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) { if room.is_screen_sharing() { ActiveCall::report_call_event_for_room( "disable screen share", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, @@ -58,7 +58,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) { } else { ActiveCall::report_call_event_for_room( "enable screen share", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, @@ -78,7 +78,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) { if room.is_muted(cx) { ActiveCall::report_call_event_for_room( "enable microphone", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, @@ -86,7 +86,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) { } else { ActiveCall::report_call_event_for_room( "disable microphone", - room.id(), + Some(room.id()), room.channel_id(), &client, cx, From 3ad1befb11947200a4bfdcddec837732a8b7c6ce Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 7 Sep 2023 15:07:21 +0200 Subject: [PATCH 58/60] Remove unneeded logging Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/embedding_queue.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 024881f0b8..104a4eb8ee 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -75,12 +75,10 @@ impl EmbeddingQueue { }); let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; - let mut saved_tokens = 0; for (ix, span) in file.lock().spans.iter().enumerate() { let span_token_count = if span.embedding.is_none() { span.token_count } else { - saved_tokens += span.token_count; 0 }; @@ -98,7 +96,6 @@ impl EmbeddingQueue { fragment_range.end = ix + 1; self.pending_batch_token_count += span_token_count; } - log::trace!("Saved Tokens: {:?}", saved_tokens); } pub fn flush(&mut self) { @@ -113,10 +110,8 @@ impl EmbeddingQueue { self.executor.spawn(async move { let mut spans = Vec::new(); - let mut span_count = 0; for fragment in &batch { let file = fragment.file.lock(); - span_count += file.spans[fragment.span_range.clone()].len(); spans.extend( { file.spans[fragment.span_range.clone()] @@ -126,9 +121,6 @@ impl EmbeddingQueue { ); } - log::trace!("Documents Length: {:?}", span_count); - log::trace!("Span Length: {:?}", spans.clone().len()); - // If spans is 0, just send the fragment to the finished files if its the last one. if spans.len() == 0 { for fragment in batch.clone() { @@ -149,7 +141,6 @@ impl EmbeddingQueue { if let Some(embedding) = embeddings.next() { span.embedding = Some(embedding); } else { - // log::error!("number of embeddings returned different from number of documents"); } } From 757a28585256407787e4c97e00eb5c720e60c16b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 7 Sep 2023 15:15:16 +0200 Subject: [PATCH 59/60] Keep dropping the `documents` table if it exists This is because we renamed `documents` to `spans`. Co-Authored-By: Kyle Caverly --- crates/semantic_index/src/db.rs | 4 ++++ crates/semantic_index/src/semantic_index.rs | 24 ++++++++++----------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 28bbd56156..c35057594a 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -124,6 +124,10 @@ impl VectorDatabase { } log::trace!("vector database schema out of date. updating..."); + // We renamed the `documents` table to `spans`, so we want to drop + // `documents` without recreating it if it exists. + db.execute("DROP TABLE IF EXISTS documents", []) + .context("failed to drop 'documents' table")?; db.execute("DROP TABLE IF EXISTS spans", []) .context("failed to drop 'spans' table")?; db.execute("DROP TABLE IF EXISTS files", []) diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 1c1c40fa27..8bba2f1d0e 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -92,8 +92,8 @@ pub struct SemanticIndex { struct ProjectState { worktrees: HashMap, - outstanding_job_count_rx: watch::Receiver, - outstanding_job_count_tx: Arc>>, + pending_file_count_rx: watch::Receiver, + pending_file_count_tx: Arc>>, _subscription: gpui::Subscription, } @@ -178,12 +178,12 @@ impl JobHandle { impl ProjectState { fn new(subscription: gpui::Subscription) -> Self { - let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0); - let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx)); + let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0); + let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx)); Self { worktrees: Default::default(), - outstanding_job_count_rx, - outstanding_job_count_tx, + pending_file_count_rx, + pending_file_count_tx, _subscription: subscription, } } @@ -605,7 +605,7 @@ impl SemanticIndex { Some( self.projects .get(&project.downgrade())? - .outstanding_job_count_rx + .pending_file_count_rx .clone(), ) } @@ -774,8 +774,8 @@ impl SemanticIndex { .insert(project.downgrade(), ProjectState::new(subscription)); self.project_worktrees_changed(project.clone(), cx); } - let project_state = self.projects.get(&project.downgrade()).unwrap(); - let mut outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone(); + let project_state = &self.projects[&project.downgrade()]; + let mut pending_file_count_rx = project_state.pending_file_count_rx.clone(); let db = self.db.clone(); let language_registry = self.language_registry.clone(); @@ -792,7 +792,7 @@ impl SemanticIndex { .projects .get_mut(&project.downgrade()) .ok_or_else(|| anyhow!("project was dropped"))?; - let outstanding_job_count_tx = &project_state.outstanding_job_count_tx; + let pending_file_count_tx = &project_state.pending_file_count_tx; project_state .worktrees @@ -816,7 +816,7 @@ impl SemanticIndex { files_to_delete.push((worktree_state.db_id, path.clone())); } else { let absolute_path = worktree.read(cx).absolutize(path); - let job_handle = JobHandle::new(outstanding_job_count_tx); + let job_handle = JobHandle::new(pending_file_count_tx); pending_files.push(PendingFile { absolute_path, relative_path: path.clone(), @@ -879,7 +879,7 @@ impl SemanticIndex { } // Wait until we're done indexing. - while let Some(count) = outstanding_job_count_rx.next().await { + while let Some(count) = pending_file_count_rx.next().await { if count == 0 { break; } From a45c8c380f04f5f39bb360374007379ea1f96781 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 7 Sep 2023 15:25:23 +0200 Subject: [PATCH 60/60] :lipstick: --- crates/semantic_index/src/embedding_queue.rs | 81 ++++++++++---------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 104a4eb8ee..3026eef9ae 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -108,54 +108,55 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - self.executor.spawn(async move { - let mut spans = Vec::new(); - for fragment in &batch { - let file = fragment.file.lock(); - spans.extend( - { + self.executor + .spawn(async move { + let mut spans = Vec::new(); + for fragment in &batch { + let file = fragment.file.lock(); + spans.extend( file.spans[fragment.span_range.clone()] - .iter().filter(|d| d.embedding.is_none()) - .map(|d| d.content.clone()) - } - ); - } - - // If spans is 0, just send the fragment to the finished files if its the last one. - if spans.len() == 0 { - for fragment in batch.clone() { - if let Some(file) = Arc::into_inner(fragment.file) { - finished_files_tx.try_send(file.into_inner()).unwrap(); - } + .iter() + .filter(|d| d.embedding.is_none()) + .map(|d| d.content.clone()), + ); } - return; - }; - - match embedding_provider.embed_batch(spans).await { - Ok(embeddings) => { - let mut embeddings = embeddings.into_iter(); - for fragment in batch { - for span in - &mut fragment.file.lock().spans[fragment.span_range.clone()].iter_mut().filter(|d| d.embedding.is_none()) - { - if let Some(embedding) = embeddings.next() { - span.embedding = Some(embedding); - } else { - log::error!("number of embeddings returned different from number of documents"); - } - } + // If spans is 0, just send the fragment to the finished files if its the last one. + if spans.is_empty() { + for fragment in batch.clone() { if let Some(file) = Arc::into_inner(fragment.file) { finished_files_tx.try_send(file.into_inner()).unwrap(); } } + return; + }; + + match embedding_provider.embed_batch(spans).await { + Ok(embeddings) => { + let mut embeddings = embeddings.into_iter(); + for fragment in batch { + for span in &mut fragment.file.lock().spans[fragment.span_range.clone()] + .iter_mut() + .filter(|d| d.embedding.is_none()) + { + if let Some(embedding) = embeddings.next() { + span.embedding = Some(embedding); + } else { + log::error!("number of embeddings != number of documents"); + } + } + + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + } + Err(error) => { + log::error!("{:?}", error); + } } - Err(error) => { - log::error!("{:?}", error); - } - } - }) - .detach(); + }) + .detach(); } pub fn finished_files(&self) -> channel::Receiver {