// crates/semantic_index/src/project_index.rs

use crate::{
embedding::{EmbeddingProvider, TextToEmbed},
summary_index::FileSummary,
worktree_index::{WorktreeIndex, WorktreeIndexHandle},
};
use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
use fs::Fs;
use futures::FutureExt;
use gpui::{
App, AppContext as _, Context, Entity, EntityId, EventEmitter, Subscription, Task, WeakEntity,
};
use language::LanguageRegistry;
use log;
use project::{Project, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
future::Future,
num::NonZeroUsize,
ops::{Range, RangeInclusive},
path::{Path, PathBuf},
sync::Arc,
};
use util::ResultExt;
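
/// A single semantic-search hit: the worktree containing the match, the
/// file's path within it, the byte range of the matching chunk, the chunk's
/// similarity score, and the index of the query (within the batch passed to
/// [`ProjectIndex::search`]) that scored highest against it.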
#[derive(Debug)]
pub struct SearchResult {
pub worktree: Entity<Worktree>,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
pub query_index: usize,
}
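
/// A [`SearchResult`] whose excerpt has been loaded from disk: the chunk's
/// byte range is resolved to a row range and the excerpt text is attached.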
#[derive(Debug, PartialEq, Eq)]
pub struct LoadedSearchResult {
pub path: Arc<Path>,
pub full_path: PathBuf,
pub excerpt_content: String,
pub row_range: RangeInclusive<u32>,
pub query_index: usize,
}
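
/// An intermediate result produced on background workers. It carries a
/// [`WorktreeId`] rather than a worktree handle, so results can be collected
/// off the main thread and resolved against the [`Project`] afterwards.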
pub struct WorktreeSearchResult {
pub worktree_id: WorktreeId,
pub path: Arc<Path>,
pub range: Range<usize>,
pub query_index: usize,
pub score: f32,
}
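
/// Indexing status, emitted as an event whenever it changes (see the
/// [`EventEmitter<Status>`] impl at the bottom of this file):
/// - `Idle`: nothing in flight.
/// - `Loading`: at least one worktree index is still loading from disk.
/// - `Scanning`: all indices are loaded; `remaining_count` entries are still
///   being indexed.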
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
Idle,
Loading,
Scanning { remaining_count: NonZeroUsize },
}
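
/// Maintains a [`WorktreeIndex`] for each visible local worktree of a
/// [`Project`], all backed by the same heed (LMDB) environment, and answers
/// semantic-search and summary queries across every worktree at once.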
pub struct ProjectIndex {
db_connection: heed::Env,
project: WeakEntity<Project>,
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
last_status: Status,
status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
_maintain_status: Task<()>,
_subscription: Subscription,
}
impl ProjectIndex {
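    /// Builds the index for `project`, subscribing to worktree added/removed
    /// events and spawning a task that recomputes [`Status`] whenever a
    /// worktree index reports progress on the status channel.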
pub fn new(
project: Entity<Project>,
db_connection: heed::Env,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut Context<Self>,
) -> Self {
let language_registry = project.read(cx).languages().clone();
let fs = project.read(cx).fs().clone();
let (status_tx, status_rx) = channel::unbounded();
let mut this = ProjectIndex {
db_connection,
project: project.downgrade(),
worktree_indices: HashMap::default(),
language_registry,
fs,
status_tx,
last_status: Status::Idle,
embedding_provider,
_subscription: cx.subscribe(&project, Self::handle_project_event),
_maintain_status: cx.spawn(async move |this, cx| {
while status_rx.recv().await.is_ok() {
if this.update(cx, |this, cx| this.update_status(cx)).is_err() {
break;
}
}
}),
};
this.update_worktree_indices(cx);
this
}
pub fn status(&self) -> Status {
self.last_status
}
pub fn project(&self) -> WeakEntity<Project> {
self.project.clone()
}
pub fn fs(&self) -> Arc<dyn Fs> {
self.fs.clone()
}
fn handle_project_event(
&mut self,
_: Entity<Project>,
event: &project::Event,
cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.update_worktree_indices(cx);
}
_ => {}
}
}
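
    /// Reconciles `worktree_indices` with the project's current set of
    /// visible local worktrees: entries for removed worktrees are dropped,
    /// and a [`WorktreeIndex`] load is started for each new worktree, held
    /// as [`WorktreeIndexHandle::Loading`] until it resolves.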
fn update_worktree_indices(&mut self, cx: &mut Context<Self>) {
let Some(project) = self.project.upgrade() else {
return;
};
let worktrees = project
.read(cx)
.visible_worktrees(cx)
.filter_map(|worktree| {
if worktree.read(cx).is_local() {
Some((worktree.entity_id(), worktree))
} else {
None
}
})
.collect::<HashMap<_, _>>();
self.worktree_indices
.retain(|worktree_id, _| worktrees.contains_key(worktree_id));
for (worktree_id, worktree) in worktrees {
self.worktree_indices.entry(worktree_id).or_insert_with(|| {
let worktree_index = WorktreeIndex::load(
worktree.clone(),
self.db_connection.clone(),
self.language_registry.clone(),
self.fs.clone(),
self.status_tx.clone(),
self.embedding_provider.clone(),
cx,
);
let load_worktree = cx.spawn(async move |this, cx| {
let result = match worktree_index.await {
Ok(worktree_index) => {
this.update(cx, |this, _| {
this.worktree_indices.insert(
worktree_id,
WorktreeIndexHandle::Loaded {
index: worktree_index.clone(),
},
);
})?;
Ok(worktree_index)
}
Err(error) => {
this.update(cx, |this, _cx| {
this.worktree_indices.remove(&worktree_id)
})?;
Err(Arc::new(error))
}
};
this.update(cx, |this, cx| this.update_status(cx))?;
result
});
WorktreeIndexHandle::Loading {
index: load_worktree.shared(),
}
});
}
self.update_status(cx);
}
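
    /// Recomputes the aggregate [`Status`]: `Loading` if any worktree index
    /// is still loading, otherwise `Scanning` while any entries remain to be
    /// indexed, otherwise `Idle`. Emits the status as an event only when it
    /// changes.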
fn update_status(&mut self, cx: &mut Context<Self>) {
let mut indexing_count = 0;
let mut any_loading = false;
for index in self.worktree_indices.values_mut() {
match index {
WorktreeIndexHandle::Loading { .. } => {
any_loading = true;
break;
}
WorktreeIndexHandle::Loaded { index, .. } => {
indexing_count += index.read(cx).entry_ids_being_indexed().len();
}
}
}
let status = if any_loading {
Status::Loading
} else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
Status::Scanning { remaining_count }
} else {
Status::Idle
};
if status != self.last_status {
self.last_status = status;
cx.emit(status);
}
}
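
    /// Embeds `queries`, scans every chunk embedding stored for every
    /// worktree, and returns the `limit` most similar chunks across all
    /// queries. Chunks are streamed over a bounded channel to one worker per
    /// CPU; each worker maintains its own top-`limit` list, and the lists
    /// are merged, sorted, and truncated at the end.
    ///
    /// A sketch of typical usage. Illustrative only: it assumes a
    /// `project_index: Entity<ProjectIndex>` and an `App` context in scope.
    ///
    /// ```ignore
    /// let search = project_index.read(cx).search(
    ///     vec!["where are embeddings stored?".into()],
    ///     16,
    ///     cx,
    /// );
    /// cx.spawn(async move |_| {
    ///     for result in search.await? {
    ///         println!("{:?} (score {})", result.path, result.score);
    ///     }
    ///     anyhow::Ok(())
    /// })
    /// .detach();
    /// ```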
pub fn search(
&self,
queries: Vec<String>,
limit: usize,
cx: &App,
) -> Task<Result<Vec<SearchResult>>> {
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
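        // One scan task per worktree: wait for its index to finish loading
        // if necessary, then stream every stored (worktree, path, chunk)
        // triple into the channel for the scoring workers below.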
for worktree_index in self.worktree_indices.values() {
let worktree_index = worktree_index.clone();
let chunks_tx = chunks_tx.clone();
worktree_scan_tasks.push(cx.spawn(async move |cx| {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
index
.read_with(cx, |index, cx| {
let worktree_id = index.worktree().read(cx).id();
let db_connection = index.db_connection().clone();
let db = *index.embedding_index().db();
cx.background_spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((worktree_id, db_embedded_file.path.clone(), chunk))
.await?;
}
}
anyhow::Ok(())
})
})?
.await
}));
}
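        // Drop the original sender so the channel closes once all scan
        // tasks complete, letting the workers' recv loops terminate.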
drop(chunks_tx);
let project = self.project.clone();
let embedding_provider = self.embedding_provider.clone();
cx.spawn(async move |cx| {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {queries:?}");
let queries: Vec<TextToEmbed> = queries
.iter()
.map(|s| TextToEmbed::new(s.as_str()))
.collect();
let query_embeddings = embedding_provider.embed(&queries[..]).await?;
if query_embeddings.len() != queries.len() {
return Err(anyhow!(
"The number of query embeddings does not match the number of queries"
));
}
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
results_by_worker.push(Vec::<WorktreeSearchResult>::new());
}
#[cfg(debug_assertions)]
let search_start = std::time::Instant::now();
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let (score, query_index) =
chunk.embedding.similarity(&query_embeddings);
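                            // `results` is kept sorted by descending score;
                            // binary search finds the insertion point,
                            // treating incomparable (NaN) scores as equal.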
let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
if ix < limit {
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
query_index,
score,
},
);
if results.len() > limit {
results.pop();
}
}
}
});
}
})
.await;
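            // Surface per-worktree scan failures in the log without failing
            // the whole search.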
for scan_task in futures::future::join_all(worktree_scan_tasks).await {
scan_task.log_err();
}
project.read_with(cx, |project, cx| {
let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
for worker_results in results_by_worker {
search_results.extend(worker_results.into_iter().filter_map(|result| {
Some(SearchResult {
worktree: project.worktree_for_id(result.worktree_id, cx)?,
path: result.path,
range: result.range,
score: result.score,
query_index: result.query_index,
})
}));
}
search_results.sort_unstable_by(|a, b| {
b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
});
search_results.truncate(limit);
#[cfg(debug_assertions)]
{
let search_elapsed = search_start.elapsed();
log::debug!(
"searched {} entries in {:?}",
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}
search_results
})
})
}
#[cfg(test)]
pub fn path_count(&self, cx: &App) -> Result<u64> {
let mut result = 0;
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
result += index.read(cx).path_count()?;
}
}
Ok(result)
}
pub(crate) fn worktree_index(
&self,
worktree_id: WorktreeId,
cx: &App,
) -> Option<Entity<WorktreeIndex>> {
for index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = index {
if index.read(cx).worktree().read(cx).id() == worktree_id {
return Some(index.clone());
}
}
}
None
}
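
    /// All loaded worktree indices, ordered by worktree id.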
pub(crate) fn worktree_indices(&self, cx: &App) -> Vec<Entity<WorktreeIndex>> {
let mut result = self
.worktree_indices
.values()
.filter_map(|index| {
if let WorktreeIndexHandle::Loaded { index, .. } = index {
Some(index.clone())
} else {
None
}
})
.collect::<Vec<_>>();
result.sort_by_key(|index| index.read(cx).worktree().read(cx).id());
result
}
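
    /// Collects every stored file summary across all worktree indices.
    /// Summaries are looked up by content digest; files that haven't been
    /// summarized yet are skipped (see the comment on just-in-time
    /// summarization below).
    ///
    /// Illustrative usage, under the same assumptions as the example on
    /// [`ProjectIndex::search`]:
    ///
    /// ```ignore
    /// let summaries = project_index.read(cx).all_summaries(cx);
    /// cx.spawn(async move |_| {
    ///     for FileSummary { filename, summary } in summaries.await? {
    ///         println!("{filename}: {summary}");
    ///     }
    ///     anyhow::Ok(())
    /// })
    /// .detach();
    /// ```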
pub fn all_summaries(&self, cx: &App) -> Task<Result<Vec<FileSummary>>> {
let (summaries_tx, summaries_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
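        // Mirrors the scan in `search` above: one task per worktree,
        // streaming (path, summary) pairs into the channel.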
for worktree_index in self.worktree_indices.values() {
let worktree_index = worktree_index.clone();
let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone();
worktree_scan_tasks.push(cx.spawn(async move |cx| {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
index
.read_with(cx, |index, cx| {
let db_connection = index.db_connection().clone();
let summary_index = index.summary_index();
let file_digest_db = summary_index.file_digest_db();
let summary_db = summary_index.summary_db();
cx.background_spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create db read transaction")?;
let db_entries = file_digest_db
.iter(&txn)
.context("failed to iterate database")?;
for db_entry in db_entries {
let (file_path, db_file) = db_entry?;
match summary_db.get(&txn, &db_file.digest) {
Ok(opt_summary) => {
// Currently, we only use summaries we already have. If the file hasn't been
// summarized yet, then we skip it and don't include it in the inferred context.
// If we want to do just-in-time summarization, this would be the place to do it!
if let Some(summary) = opt_summary {
summaries_tx
.send((file_path.to_string(), summary.to_string()))
.await?;
} else {
log::warn!("No summary found for {:?}", &db_file);
}
}
Err(err) => {
log::error!(
"Error reading from summary database: {:?}",
err
);
}
}
}
anyhow::Ok(())
})
})?
.await
}));
}
drop(summaries_tx);
let project = self.project.clone();
cx.spawn(async move |cx| {
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
results_by_worker.push(Vec::<FileSummary>::new());
}
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((filename, summary)) = summaries_rx.recv().await {
results.push(FileSummary { filename, summary });
}
});
}
})
.await;
for scan_task in futures::future::join_all(worktree_scan_tasks).await {
scan_task.log_err();
}
project.read_with(cx, |_project, _cx| {
results_by_worker.into_iter().flatten().collect()
})
})
}
    /// Flush the summary backlogs of all the worktrees in the project,
    /// logging any errors instead of propagating them.
pub fn flush_summary_backlogs(&self, cx: &App) -> impl Future<Output = ()> {
let flush_start = std::time::Instant::now();
futures::future::join_all(self.worktree_indices.values().map(|worktree_index| {
let worktree_index = worktree_index.clone();
cx.spawn(async move |cx| {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
let worktree_abs_path =
cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?;
index
.read_with(cx, |index, cx| {
cx.background_spawn(
index.summary_index().flush_backlog(worktree_abs_path, cx),
)
})?
.await
})
}))
.map(move |results| {
// Log any errors, but don't block the user. These summaries are supposed to
// improve quality by providing extra context, but they aren't hard requirements!
for result in results {
if let Err(err) = result {
log::error!("Error flushing summary backlog: {:?}", err);
}
}
log::info!("Summary backlog flushed in {:?}", flush_start.elapsed());
})
}
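
    /// Total number of files still waiting to be summarized across all
    /// loaded worktree indices.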
pub fn remaining_summaries(&self, cx: &mut Context<Self>) -> usize {
self.worktree_indices(cx)
.iter()
.map(|index| index.read(cx).summary_index().backlog_len())
.sum()
}
}
impl EventEmitter<Status> for ProjectIndex {}