Add `/auto` behind a feature flag that's disabled for now, even for
staff.

We've decided on a different design for context inference, but there are
parts of /auto that will be useful for that, so we want them in the code
base even if they're unused for now.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Richard Feldman 2024-09-13 13:17:49 -04:00 committed by GitHub
parent 93a3e8bc94
commit 91ffa02e2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 2776 additions and 1054 deletions

View file

@ -19,14 +19,18 @@ crate-type = ["bin"]
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
blake3.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true
gpui.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
heed.workspace = true
http_client.workspace = true

View file

@ -4,7 +4,7 @@ use gpui::App;
use http_client::HttpClientWithUrl;
use language::language_settings::AllLanguageSettings;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticDb};
use settings::SettingsStore;
use std::{
path::{Path, PathBuf},
@ -50,7 +50,7 @@ fn main() {
));
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
let semantic_index = SemanticDb::new(
PathBuf::from("/tmp/semantic-index-db.mdb"),
embedding_provider,
&mut cx,
@ -71,6 +71,7 @@ fn main() {
let project_index = cx
.update(|cx| semantic_index.project_index(project.clone(), cx))
.unwrap()
.unwrap();
let (tx, rx) = oneshot::channel();

View file

@ -12,6 +12,12 @@ use futures::{future::BoxFuture, FutureExt};
use serde::{Deserialize, Serialize};
use std::{fmt, future};
/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
fn batch_size(&self) -> usize;
}
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);
@ -68,12 +74,6 @@ impl fmt::Display for Embedding {
}
}
/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
fn batch_size(&self) -> usize;
}
#[derive(Debug)]
pub struct TextToEmbed<'a> {
pub text: &'a str,

View file

@ -0,0 +1,469 @@
use crate::{
chunking::{self, Chunk},
embedding::{Embedding, EmbeddingProvider, TextToEmbed},
indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use log;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
future::Future,
iter,
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;
pub struct EmbeddingIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
impl EmbeddingIndex {
pub fn new(
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
) -> Self {
Self {
worktree,
fs,
db_connection,
db: embedding_db,
language_registry,
embedding_provider,
entry_ids_being_indexed,
}
}
pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
&self.db
}
pub fn index_entries_changed_on_disk(
&self,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree, cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
pub fn index_updated_entries(
&self,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let db = self.db;
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let mut db_entries = db
.iter(&txn)
.context("failed to create iterator")?
.move_between_keys()
.peekable();
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
for entry in worktree.files(false, 0) {
log::trace!("scanning for embedding index: {:?}", &entry.path);
let entry_db_key = db_key_for_path(&entry.path);
let mut saved_mtime = None;
while let Some(db_entry) = db_entries.peek() {
match db_entry {
Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
Ordering::Less => {
if let Some(deletion_range) = deletion_range.as_mut() {
deletion_range.1 = Bound::Included(db_path);
} else {
deletion_range =
Some((Bound::Included(db_path), Bound::Included(db_path)));
}
db_entries.next();
}
Ordering::Equal => {
if let Some(deletion_range) = deletion_range.take() {
deleted_entry_ranges_tx
.send((
deletion_range.0.map(ToString::to_string),
deletion_range.1.map(ToString::to_string),
))
.await?;
}
saved_mtime = db_embedded_file.mtime;
db_entries.next();
break;
}
Ordering::Greater => {
break;
}
},
Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
}
}
if entry.mtime != saved_mtime {
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
if let Some(db_entry) = db_entries.next() {
let (db_path, _) = db_entry?;
deleted_entry_ranges_tx
.send((Bound::Included(db_path.to_string()), Bound::Unbounded))
.await?;
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn scan_updated_entries(
&self,
worktree: Snapshot,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {
match status {
project::PathChange::Added
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
}
project::PathChange::Removed => {
let db_path = db_key_for_path(path);
deleted_entry_ranges_tx
.send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
.await?;
}
project::PathChange::Loaded => {
// Do nothing.
}
}
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn chunk_files(
&self,
worktree_abs_path: Arc<Path>,
entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
cx: &AppContext,
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
let fs = self.fs.clone();
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
while let Ok((entry, handle)) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
match fs.load(&entry_abs_path).await {
Ok(text) => {
let language = language_registry
.language_for_file_path(&entry.path)
.await
.ok();
let chunked_file = ChunkedFile {
chunks: chunking::chunk_text(
&text,
language.as_ref(),
&entry.path,
),
handle,
path: entry.path,
mtime: entry.mtime,
text,
};
if chunked_files_tx.send(chunked_file).await.is_err() {
return;
}
}
Err(_)=> {
log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
}
}
}
});
}
})
.await;
Ok(())
});
ChunkFiles {
files: chunked_files_rx,
task,
}
}
pub fn embed_files(
embedding_provider: Arc<dyn EmbeddingProvider>,
chunked_files: channel::Receiver<ChunkedFile>,
cx: &AppContext,
) -> EmbedFiles {
let embedding_provider = embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
chunked_files.chunks_timeout(512, Duration::from_secs(2));
while let Some(chunked_files) = chunked_file_batches.next().await {
// View the batch of files as a vec of chunks
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
// Once those are done, reassemble them back into the files in which they belong
// If any embeddings fail for a file, the entire file is discarded
let chunks: Vec<TextToEmbed> = chunked_files
.iter()
.flat_map(|file| {
file.chunks.iter().map(|chunk| TextToEmbed {
text: &file.text[chunk.range.clone()],
digest: chunk.digest,
})
})
.collect::<Vec<_>>();
let mut embeddings: Vec<Option<Embedding>> = Vec::new();
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
if let Some(batch_embeddings) =
embedding_provider.embed(embedding_batch).await.log_err()
{
if batch_embeddings.len() == embedding_batch.len() {
embeddings.extend(batch_embeddings.into_iter().map(Some));
continue;
}
log::error!(
"embedding provider returned unexpected embedding count {}, expected {}",
batch_embeddings.len(), embedding_batch.len()
);
}
embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
}
let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
let mut embedded_file = EmbeddedFile {
path: chunked_file.path,
mtime: chunked_file.mtime,
chunks: Vec::new(),
};
let mut embedded_all_chunks = true;
for (chunk, embedding) in
chunked_file.chunks.into_iter().zip(embeddings.by_ref())
{
if let Some(embedding) = embedding {
embedded_file
.chunks
.push(EmbeddedChunk { chunk, embedding });
} else {
embedded_all_chunks = false;
}
}
if embedded_all_chunks {
embedded_files_tx
.send((embedded_file, chunked_file.handle))
.await?;
}
}
}
Ok(())
});
EmbedFiles {
files: embedded_files_rx,
task,
}
}
fn persist_embeddings(
&self,
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
while let Some(deletion_range) = deleted_entry_ranges.next().await {
let mut txn = db_connection.write_txn()?;
let start = deletion_range.0.as_ref().map(|start| start.as_str());
let end = deletion_range.1.as_ref().map(|end| end.as_str());
log::debug!("deleting embeddings in range {:?}", &(start, end));
db.delete_range(&mut txn, &(start, end))?;
txn.commit()?;
}
let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
while let Some(embedded_files) = embedded_files.next().await {
let mut txn = db_connection.write_txn()?;
for (file, _) in &embedded_files {
log::debug!("saving embedding for file {:?}", file.path);
let key = db_key_for_path(&file.path);
db.put(&mut txn, &key, file)?;
}
txn.commit()?;
drop(embedded_files);
log::debug!("committed");
}
Ok(())
})
}
pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
let connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
let tx = connection
.read_txn()
.context("failed to create read transaction")?;
let result = db
.iter(&tx)?
.map(|entry| Ok(entry?.1.path.clone()))
.collect::<Result<Vec<Arc<Path>>>>();
drop(tx);
result
})
}
pub fn chunks_for_path(
&self,
path: Arc<Path>,
cx: &AppContext,
) -> Task<Result<Vec<EmbeddedChunk>>> {
let connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
let tx = connection
.read_txn()
.context("failed to create read transaction")?;
Ok(db
.get(&tx, &db_key_for_path(&path))?
.ok_or_else(|| anyhow!("no such path"))?
.chunks
.clone())
})
}
}
struct ScanEntries {
updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
task: Task<Result<()>>,
}
struct ChunkFiles {
files: channel::Receiver<ChunkedFile>,
task: Task<Result<()>>,
}
pub struct ChunkedFile {
pub path: Arc<Path>,
pub mtime: Option<SystemTime>,
pub handle: IndexingEntryHandle,
pub text: String,
pub chunks: Vec<Chunk>,
}
pub struct EmbedFiles {
pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
pub task: Task<Result<()>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
pub path: Arc<Path>,
pub mtime: Option<SystemTime>,
pub chunks: Vec<EmbeddedChunk>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
pub chunk: Chunk,
pub embedding: Embedding,
}
fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}

View file

@ -0,0 +1,49 @@
use collections::HashSet;
use parking_lot::Mutex;
use project::ProjectEntryId;
use smol::channel;
use std::sync::{Arc, Weak};
/// The set of entries that are currently being indexed.
pub struct IndexingEntrySet {
entry_ids: Mutex<HashSet<ProjectEntryId>>,
tx: channel::Sender<()>,
}
/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
pub(crate) struct IndexingEntryHandle {
entry_id: ProjectEntryId,
set: Weak<IndexingEntrySet>,
}
impl IndexingEntrySet {
pub fn new(tx: channel::Sender<()>) -> Self {
Self {
entry_ids: Default::default(),
tx,
}
}
pub fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
self.entry_ids.lock().insert(entry_id);
self.tx.send_blocking(()).ok();
IndexingEntryHandle {
entry_id,
set: Arc::downgrade(self),
}
}
pub fn len(&self) -> usize {
self.entry_ids.lock().len()
}
}
impl Drop for IndexingEntryHandle {
fn drop(&mut self) {
if let Some(set) = self.set.upgrade() {
set.tx.send_blocking(()).ok();
set.entry_ids.lock().remove(&self.entry_id);
}
}
}

View file

@ -0,0 +1,523 @@
use crate::{
embedding::{EmbeddingProvider, TextToEmbed},
summary_index::FileSummary,
worktree_index::{WorktreeIndex, WorktreeIndexHandle},
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use fs::Fs;
use futures::{stream::StreamExt, FutureExt};
use gpui::{
AppContext, Entity, EntityId, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel,
};
use language::LanguageRegistry;
use log;
use project::{Project, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
use util::ResultExt;
#[derive(Debug)]
pub struct SearchResult {
pub worktree: Model<Worktree>,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
}
pub struct WorktreeSearchResult {
pub worktree_id: WorktreeId,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
Idle,
Loading,
Scanning { remaining_count: NonZeroUsize },
}
pub struct ProjectIndex {
db_connection: heed::Env,
project: WeakModel<Project>,
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
last_status: Status,
status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
_maintain_status: Task<()>,
_subscription: Subscription,
}
impl ProjectIndex {
pub fn new(
project: Model<Project>,
db_connection: heed::Env,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let language_registry = project.read(cx).languages().clone();
let fs = project.read(cx).fs().clone();
let (status_tx, mut status_rx) = channel::unbounded();
let mut this = ProjectIndex {
db_connection,
project: project.downgrade(),
worktree_indices: HashMap::default(),
language_registry,
fs,
status_tx,
last_status: Status::Idle,
embedding_provider,
_subscription: cx.subscribe(&project, Self::handle_project_event),
_maintain_status: cx.spawn(|this, mut cx| async move {
while status_rx.next().await.is_some() {
if this
.update(&mut cx, |this, cx| this.update_status(cx))
.is_err()
{
break;
}
}
}),
};
this.update_worktree_indices(cx);
this
}
pub fn status(&self) -> Status {
self.last_status
}
pub fn project(&self) -> WeakModel<Project> {
self.project.clone()
}
pub fn fs(&self) -> Arc<dyn Fs> {
self.fs.clone()
}
fn handle_project_event(
&mut self,
_: Model<Project>,
event: &project::Event,
cx: &mut ModelContext<Self>,
) {
match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
self.update_worktree_indices(cx);
}
_ => {}
}
}
fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
let Some(project) = self.project.upgrade() else {
return;
};
let worktrees = project
.read(cx)
.visible_worktrees(cx)
.filter_map(|worktree| {
if worktree.read(cx).is_local() {
Some((worktree.entity_id(), worktree))
} else {
None
}
})
.collect::<HashMap<_, _>>();
self.worktree_indices
.retain(|worktree_id, _| worktrees.contains_key(worktree_id));
for (worktree_id, worktree) in worktrees {
self.worktree_indices.entry(worktree_id).or_insert_with(|| {
let worktree_index = WorktreeIndex::load(
worktree.clone(),
self.db_connection.clone(),
self.language_registry.clone(),
self.fs.clone(),
self.status_tx.clone(),
self.embedding_provider.clone(),
cx,
);
let load_worktree = cx.spawn(|this, mut cx| async move {
let result = match worktree_index.await {
Ok(worktree_index) => {
this.update(&mut cx, |this, _| {
this.worktree_indices.insert(
worktree_id,
WorktreeIndexHandle::Loaded {
index: worktree_index.clone(),
},
);
})?;
Ok(worktree_index)
}
Err(error) => {
this.update(&mut cx, |this, _cx| {
this.worktree_indices.remove(&worktree_id)
})?;
Err(Arc::new(error))
}
};
this.update(&mut cx, |this, cx| this.update_status(cx))?;
result
});
WorktreeIndexHandle::Loading {
index: load_worktree.shared(),
}
});
}
self.update_status(cx);
}
fn update_status(&mut self, cx: &mut ModelContext<Self>) {
let mut indexing_count = 0;
let mut any_loading = false;
for index in self.worktree_indices.values_mut() {
match index {
WorktreeIndexHandle::Loading { .. } => {
any_loading = true;
break;
}
WorktreeIndexHandle::Loaded { index, .. } => {
indexing_count += index.read(cx).entry_ids_being_indexed().len();
}
}
}
let status = if any_loading {
Status::Loading
} else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
Status::Scanning { remaining_count }
} else {
Status::Idle
};
if status != self.last_status {
self.last_status = status;
cx.emit(status);
}
}
pub fn search(
&self,
query: String,
limit: usize,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
let worktree_index = worktree_index.clone();
let chunks_tx = chunks_tx.clone();
worktree_scan_tasks.push(cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
index
.read_with(&cx, |index, cx| {
let worktree_id = index.worktree().read(cx).id();
let db_connection = index.db_connection().clone();
let db = *index.embedding_index().db();
cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((worktree_id, db_embedded_file.path.clone(), chunk))
.await?;
}
}
anyhow::Ok(())
})
})?
.await
}));
}
drop(chunks_tx);
let project = self.project.clone();
let embedding_provider = self.embedding_provider.clone();
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {query}");
let query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
.await?;
let query_embedding = query_embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("no embedding for query"))?;
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
results_by_worker.push(Vec::<WorktreeSearchResult>::new());
}
#[cfg(debug_assertions)]
let search_start = std::time::Instant::now();
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let score = chunk.embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
score,
},
);
results.truncate(limit);
}
});
}
})
.await;
for scan_task in futures::future::join_all(worktree_scan_tasks).await {
scan_task.log_err();
}
project.read_with(&cx, |project, cx| {
let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
for worker_results in results_by_worker {
search_results.extend(worker_results.into_iter().filter_map(|result| {
Some(SearchResult {
worktree: project.worktree_for_id(result.worktree_id, cx)?,
path: result.path,
range: result.range,
score: result.score,
})
}));
}
search_results.sort_unstable_by(|a, b| {
b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
});
search_results.truncate(limit);
#[cfg(debug_assertions)]
{
let search_elapsed = search_start.elapsed();
log::debug!(
"searched {} entries in {:?}",
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}
search_results
})
})
}
#[cfg(test)]
pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
let mut result = 0;
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
result += index.read(cx).path_count()?;
}
}
Ok(result)
}
pub(crate) fn worktree_index(
&self,
worktree_id: WorktreeId,
cx: &AppContext,
) -> Option<Model<WorktreeIndex>> {
for index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = index {
if index.read(cx).worktree().read(cx).id() == worktree_id {
return Some(index.clone());
}
}
}
None
}
pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
let mut result = self
.worktree_indices
.values()
.filter_map(|index| {
if let WorktreeIndexHandle::Loaded { index, .. } = index {
Some(index.clone())
} else {
None
}
})
.collect::<Vec<_>>();
result.sort_by_key(|index| index.read(cx).worktree().read(cx).id());
result
}
pub fn all_summaries(&self, cx: &AppContext) -> Task<Result<Vec<FileSummary>>> {
let (summaries_tx, summaries_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
let worktree_index = worktree_index.clone();
let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone();
worktree_scan_tasks.push(cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
index
.read_with(&cx, |index, cx| {
let db_connection = index.db_connection().clone();
let summary_index = index.summary_index();
let file_digest_db = summary_index.file_digest_db();
let summary_db = summary_index.summary_db();
cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create db read transaction")?;
let db_entries = file_digest_db
.iter(&txn)
.context("failed to iterate database")?;
for db_entry in db_entries {
let (file_path, db_file) = db_entry?;
match summary_db.get(&txn, &db_file.digest) {
Ok(opt_summary) => {
// Currently, we only use summaries we already have. If the file hasn't been
// summarized yet, then we skip it and don't include it in the inferred context.
// If we want to do just-in-time summarization, this would be the place to do it!
if let Some(summary) = opt_summary {
summaries_tx
.send((file_path.to_string(), summary.to_string()))
.await?;
} else {
log::warn!("No summary found for {:?}", &db_file);
}
}
Err(err) => {
log::error!(
"Error reading from summary database: {:?}",
err
);
}
}
}
anyhow::Ok(())
})
})?
.await
}));
}
drop(summaries_tx);
let project = self.project.clone();
cx.spawn(|cx| async move {
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
results_by_worker.push(Vec::<FileSummary>::new());
}
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((filename, summary)) = summaries_rx.recv().await {
results.push(FileSummary { filename, summary });
}
});
}
})
.await;
for scan_task in futures::future::join_all(worktree_scan_tasks).await {
scan_task.log_err();
}
project.read_with(&cx, |_project, _cx| {
results_by_worker.into_iter().flatten().collect()
})
})
}
/// Empty out the backlogs of all the worktrees in the project
pub fn flush_summary_backlogs(&self, cx: &AppContext) -> impl Future<Output = ()> {
let flush_start = std::time::Instant::now();
futures::future::join_all(self.worktree_indices.values().map(|worktree_index| {
let worktree_index = worktree_index.clone();
cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
let worktree_abs_path =
cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?;
index
.read_with(&cx, |index, cx| {
cx.background_executor()
.spawn(index.summary_index().flush_backlog(worktree_abs_path, cx))
})?
.await
})
}))
.map(move |results| {
// Log any errors, but don't block the user. These summaries are supposed to
// improve quality by providing extra context, but they aren't hard requirements!
for result in results {
if let Err(err) = result {
log::error!("Error flushing summary backlog: {:?}", err);
}
}
log::info!("Summary backlog flushed in {:?}", flush_start.elapsed());
})
}
pub fn remaining_summaries(&self, cx: &mut ModelContext<Self>) -> usize {
self.worktree_indices(cx)
.iter()
.map(|index| index.read(cx).summary_index().backlog_len())
.sum()
}
}
impl EventEmitter<Status> for ProjectIndex {}

View file

@ -55,8 +55,12 @@ impl ProjectIndexDebugView {
for index in worktree_indices {
let (root_path, worktree_id, worktree_paths) =
index.read_with(&cx, |index, cx| {
let worktree = index.worktree.read(cx);
(worktree.abs_path(), worktree.id(), index.paths(cx))
let worktree = index.worktree().read(cx);
(
worktree.abs_path(),
worktree.id(),
index.embedding_index().paths(cx),
)
})?;
rows.push(Row::Worktree(root_path));
rows.extend(
@ -82,10 +86,12 @@ impl ProjectIndexDebugView {
cx: &mut ViewContext<Self>,
) -> Option<()> {
let project_index = self.index.read(cx);
let fs = project_index.fs.clone();
let fs = project_index.fs().clone();
let worktree_index = project_index.worktree_index(worktree_id, cx)?.read(cx);
let root_path = worktree_index.worktree.read(cx).abs_path();
let chunks = worktree_index.chunks_for_path(file_path.clone(), cx);
let root_path = worktree_index.worktree().read(cx).abs_path();
let chunks = worktree_index
.embedding_index()
.chunks_for_path(file_path.clone(), cx);
cx.spawn(|this, mut cx| async move {
let chunks = chunks.await?;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,48 @@
use collections::HashMap;
use std::{path::Path, sync::Arc, time::SystemTime};
const MAX_FILES_BEFORE_RESUMMARIZE: usize = 4;
const MAX_BYTES_BEFORE_RESUMMARIZE: u64 = 1_000_000; // 1 MB
#[derive(Default, Debug)]
pub struct SummaryBacklog {
/// Key: path to a file that needs summarization, but that we haven't summarized yet. Value: that file's size on disk, in bytes, and its mtime.
files: HashMap<Arc<Path>, (u64, Option<SystemTime>)>,
/// Cache of the sum of all values in `files`, so we don't have to traverse the whole map to check if we're over the byte limit.
total_bytes: u64,
}
impl SummaryBacklog {
/// Store the given path in the backlog, along with how many bytes are in it.
pub fn insert(&mut self, path: Arc<Path>, bytes_on_disk: u64, mtime: Option<SystemTime>) {
let (prev_bytes, _) = self
.files
.insert(path, (bytes_on_disk, mtime))
.unwrap_or_default(); // Default to 0 prev_bytes
// Update the cached total by subtracting out the old amount and adding the new one.
self.total_bytes = self.total_bytes - prev_bytes + bytes_on_disk;
}
/// Returns true if the total number of bytes in the backlog exceeds a predefined threshold.
pub fn needs_drain(&self) -> bool {
self.files.len() > MAX_FILES_BEFORE_RESUMMARIZE ||
// The whole purpose of the cached total_bytes is to make this comparison cheap.
// Otherwise we'd have to traverse the entire dictionary every time we wanted this answer.
self.total_bytes > MAX_BYTES_BEFORE_RESUMMARIZE
}
/// Remove all the entries in the backlog and return the file paths as an iterator.
#[allow(clippy::needless_lifetimes)] // Clippy thinks this 'a can be elided, but eliding it gives a compile error
pub fn drain<'a>(&'a mut self) -> impl Iterator<Item = (Arc<Path>, Option<SystemTime>)> + 'a {
self.total_bytes = 0;
self.files
.drain()
.map(|(path, (_size, mtime))| (path, mtime))
}
pub fn len(&self) -> usize {
self.files.len()
}
}

View file

@ -0,0 +1,693 @@
use anyhow::{anyhow, Context as _, Result};
use arrayvec::ArrayString;
use fs::Fs;
use futures::{stream::StreamExt, TryFutureExt};
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::{
types::{SerdeBincode, Str},
RoTxn,
};
use language_model::{
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, Role,
};
use log;
use parking_lot::Mutex;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
future::Future,
path::Path,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;
use crate::{indexing::IndexingEntrySet, summary_backlog::SummaryBacklog};
#[derive(Serialize, Deserialize, Debug)]
pub struct FileSummary {
pub filename: String,
pub summary: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct UnsummarizedFile {
// Path to the file on disk
path: Arc<Path>,
// The mtime of the file on disk
mtime: Option<SystemTime>,
// BLAKE3 hash of the source file's contents
digest: Blake3Digest,
// The source file's contents
contents: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct SummarizedFile {
// Path to the file on disk
path: String,
// The mtime of the file on disk
mtime: Option<SystemTime>,
// BLAKE3 hash of the source file's contents
digest: Blake3Digest,
// The LLM's summary of the file's contents
summary: String,
}
/// This is what blake3's to_hex() method returns - see https://docs.rs/blake3/1.5.3/src/blake3/lib.rs.html#246
pub type Blake3Digest = ArrayString<{ blake3::OUT_LEN * 2 }>;
#[derive(Debug, Serialize, Deserialize)]
pub struct FileDigest {
pub mtime: Option<SystemTime>,
pub digest: Blake3Digest,
}
struct NeedsSummary {
files: channel::Receiver<UnsummarizedFile>,
task: Task<Result<()>>,
}
struct SummarizeFiles {
files: channel::Receiver<SummarizedFile>,
task: Task<Result<()>>,
}
pub struct SummaryIndex {
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>, // Key: file path. Val: BLAKE3 digest of its contents.
summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>, // Key: BLAKE3 digest of a file's contents. Val: LLM summary of those contents.
backlog: Arc<Mutex<SummaryBacklog>>,
_entry_ids_being_indexed: Arc<IndexingEntrySet>, // TODO can this be removed?
}
struct Backlogged {
paths_to_digest: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
task: Task<Result<()>>,
}
struct MightNeedSummaryFiles {
files: channel::Receiver<UnsummarizedFile>,
task: Task<Result<()>>,
}
impl SummaryIndex {
pub fn new(
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>,
_entry_ids_being_indexed: Arc<IndexingEntrySet>,
) -> Self {
Self {
worktree,
fs,
db_connection,
file_digest_db,
summary_db,
_entry_ids_being_indexed,
backlog: Default::default(),
}
}
pub fn file_digest_db(&self) -> heed::Database<Str, SerdeBincode<FileDigest>> {
self.file_digest_db
}
pub fn summary_db(&self) -> heed::Database<SerdeBincode<Blake3Digest>, Str> {
self.summary_db
}
pub fn index_entries_changed_on_disk(
&self,
is_auto_available: bool,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let start = Instant::now();
let backlogged;
let digest;
let needs_summary;
let summaries;
let persist;
if is_auto_available {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
backlogged = self.scan_entries(worktree, cx);
digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
needs_summary = self.check_summary_cache(digest.files, cx);
summaries = self.summarize_files(needs_summary.files, cx);
persist = self.persist_summaries(summaries.files, cx);
} else {
// This feature is only staff-shipped, so make the rest of these no-ops.
backlogged = Backlogged {
paths_to_digest: channel::unbounded().1,
task: Task::ready(Ok(())),
};
digest = MightNeedSummaryFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
needs_summary = NeedsSummary {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
summaries = SummarizeFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
persist = Task::ready(Ok(()));
}
async move {
futures::try_join!(
backlogged.task,
digest.task,
needs_summary.task,
summaries.task,
persist
)?;
if is_auto_available {
log::info!(
"Summarizing everything that changed on disk took {:?}",
start.elapsed()
);
}
Ok(())
}
}
pub fn index_updated_entries(
&mut self,
updated_entries: UpdatedEntriesSet,
is_auto_available: bool,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let start = Instant::now();
let backlogged;
let digest;
let needs_summary;
let summaries;
let persist;
if is_auto_available {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
backlogged = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
needs_summary = self.check_summary_cache(digest.files, cx);
summaries = self.summarize_files(needs_summary.files, cx);
persist = self.persist_summaries(summaries.files, cx);
} else {
// This feature is only staff-shipped, so make the rest of these no-ops.
backlogged = Backlogged {
paths_to_digest: channel::unbounded().1,
task: Task::ready(Ok(())),
};
digest = MightNeedSummaryFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
needs_summary = NeedsSummary {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
summaries = SummarizeFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
persist = Task::ready(Ok(()));
}
async move {
futures::try_join!(
backlogged.task,
digest.task,
needs_summary.task,
summaries.task,
persist
)?;
log::info!("Summarizing updated entries took {:?}", start.elapsed());
Ok(())
}
}
fn check_summary_cache(
&self,
mut might_need_summary: channel::Receiver<UnsummarizedFile>,
cx: &AppContext,
) -> NeedsSummary {
let db_connection = self.db_connection.clone();
let db = self.summary_db;
let (needs_summary_tx, needs_summary_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
while let Some(file) = might_need_summary.next().await {
let tx = db_connection
.read_txn()
.context("Failed to create read transaction for checking which hashes are in summary cache")?;
match db.get(&tx, &file.digest) {
Ok(opt_answer) => {
if opt_answer.is_none() {
// It's not in the summary cache db, so we need to summarize it.
log::debug!("File {:?} (digest {:?}) was NOT in the db cache and needs to be resummarized.", file.path.display(), &file.digest);
needs_summary_tx.send(file).await?;
} else {
log::debug!("File {:?} (digest {:?}) was in the db cache and does not need to be resummarized.", file.path.display(), &file.digest);
}
}
Err(err) => {
log::error!("Reading from the summaries database failed: {:?}", err);
}
}
}
Ok(())
});
NeedsSummary {
files: needs_summary_rx,
task,
}
}
fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> Backlogged {
let (tx, rx) = channel::bounded(512);
let db_connection = self.db_connection.clone();
let digest_db = self.file_digest_db;
let backlog = Arc::clone(&self.backlog);
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
for entry in worktree.files(false, 0) {
let needs_summary =
Self::add_to_backlog(Arc::clone(&backlog), digest_db, &txn, entry);
if !needs_summary.is_empty() {
tx.send(needs_summary).await?;
}
}
// TODO delete db entries for deleted files
Ok(())
});
Backlogged {
paths_to_digest: rx,
task,
}
}
fn add_to_backlog(
backlog: Arc<Mutex<SummaryBacklog>>,
digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
txn: &RoTxn<'_>,
entry: &Entry,
) -> Vec<(Arc<Path>, Option<SystemTime>)> {
let entry_db_key = db_key_for_path(&entry.path);
match digest_db.get(&txn, &entry_db_key) {
Ok(opt_saved_digest) => {
// The file path is the same, but the mtime is different. (Or there was no mtime.)
// It needs updating, so add it to the backlog! Then, if the backlog is full, drain it and summarize its contents.
if entry.mtime != opt_saved_digest.and_then(|digest| digest.mtime) {
let mut backlog = backlog.lock();
log::info!(
"Inserting {:?} ({:?} bytes) into backlog",
&entry.path,
entry.size,
);
backlog.insert(Arc::clone(&entry.path), entry.size, entry.mtime);
if backlog.needs_drain() {
log::info!("Draining summary backlog...");
return backlog.drain().collect();
}
}
}
Err(err) => {
log::error!(
"Error trying to get file digest db entry {:?}: {:?}",
&entry_db_key,
err
);
}
}
Vec::new()
}
fn scan_updated_entries(
&self,
worktree: Snapshot,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> Backlogged {
log::info!("Scanning for updated entries that might need summarization...");
let (tx, rx) = channel::bounded(512);
// let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let digest_db = self.file_digest_db;
let backlog = Arc::clone(&self.backlog);
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
for (path, entry_id, status) in updated_entries.iter() {
match status {
project::PathChange::Loaded
| project::PathChange::Added
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
let needs_summary = Self::add_to_backlog(
Arc::clone(&backlog),
digest_db,
&txn,
entry,
);
if !needs_summary.is_empty() {
tx.send(needs_summary).await?;
}
}
}
}
project::PathChange::Removed => {
let _db_path = db_key_for_path(path);
// TODO delete db entries for deleted files
// deleted_entry_ranges_tx
// .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
// .await?;
}
}
}
Ok(())
});
Backlogged {
paths_to_digest: rx,
// deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn digest_files(
&self,
paths: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
worktree_abs_path: Arc<Path>,
cx: &AppContext,
) -> MightNeedSummaryFiles {
let fs = self.fs.clone();
let (rx, tx) = channel::bounded(2048);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
while let Ok(pairs) = paths.recv().await {
// Note: we could process all these files concurrently if desired. Might or might not speed things up.
for (path, mtime) in pairs {
let entry_abs_path = worktree_abs_path.join(&path);
// Load the file's contents and compute its hash digest.
let unsummarized_file = {
let Some(contents) = fs
.load(&entry_abs_path)
.await
.with_context(|| {
format!("failed to read path {entry_abs_path:?}")
})
.log_err()
else {
continue;
};
let digest = {
let mut hasher = blake3::Hasher::new();
// Incorporate both the (relative) file path as well as the contents of the file into the hash.
// This is because in some languages and frameworks, identical files can do different things
// depending on their paths (e.g. Rails controllers). It's also why we send the path to the model.
hasher.update(path.display().to_string().as_bytes());
hasher.update(contents.as_bytes());
hasher.finalize().to_hex()
};
UnsummarizedFile {
digest,
contents,
path,
mtime,
}
};
if let Err(err) = rx
.send(unsummarized_file)
.map_err(|error| anyhow!(error))
.await
{
log::error!("Error: {:?}", err);
return;
}
}
}
});
}
})
.await;
Ok(())
});
MightNeedSummaryFiles { files: tx, task }
}
fn summarize_files(
&self,
mut unsummarized_files: channel::Receiver<UnsummarizedFile>,
cx: &AppContext,
) -> SummarizeFiles {
let (summarized_tx, summarized_rx) = channel::bounded(512);
let task = cx.spawn(|cx| async move {
while let Some(file) = unsummarized_files.next().await {
log::debug!("Summarizing {:?}", file);
let summary = cx
.update(|cx| Self::summarize_code(&file.contents, &file.path, cx))?
.await
.unwrap_or_else(|err| {
// Log a warning because we'll continue anyway.
// In the future, we may want to try splitting it up into multiple requests and concatenating the summaries,
// but this might give bad summaries due to cutting off source code files in the middle.
log::warn!("Failed to summarize {} - {:?}", file.path.display(), err);
String::new()
});
// Note that the summary could be empty because of an error talking to a cloud provider,
// e.g. because the context limit was exceeded. In that case, we return Ok(String::new()).
if !summary.is_empty() {
summarized_tx
.send(SummarizedFile {
path: file.path.display().to_string(),
digest: file.digest,
summary,
mtime: file.mtime,
})
.await?
}
}
Ok(())
});
SummarizeFiles {
files: summarized_rx,
task,
}
}
fn summarize_code(
code: &str,
path: &Path,
cx: &AppContext,
) -> impl Future<Output = Result<String>> {
let start = Instant::now();
let (summary_model_id, use_cache): (LanguageModelId, bool) = (
"Qwen/Qwen2-7B-Instruct".to_string().into(), // TODO read this from the user's settings.
false, // qwen2 doesn't have a cache, but we should probably infer this from the model
);
let Some(model) = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.find(|model| &model.id() == &summary_model_id)
else {
return cx.background_executor().spawn(async move {
Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))
});
};
let utf8_path = path.to_string_lossy();
const PROMPT_BEFORE_CODE: &str = "Summarize what the code in this file does in 3 sentences, using no newlines or bullet points in the summary:";
let prompt = format!("{PROMPT_BEFORE_CODE}\n{utf8_path}:\n{code}");
log::debug!(
"Summarizing code by sending this prompt to {:?}: {:?}",
model.name(),
&prompt
);
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: use_cache,
}],
tools: Vec::new(),
stop: Vec::new(),
temperature: 1.0,
};
let code_len = code.len();
cx.spawn(|cx| async move {
let stream = model.stream_completion(request, &cx);
cx.background_executor()
.spawn(async move {
let answer: String = stream
.await?
.filter_map(|event| async {
if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
Some(text)
} else {
None
}
})
.collect()
.await;
log::info!(
"It took {:?} to summarize {:?} bytes of code.",
start.elapsed(),
code_len
);
log::debug!("Summary was: {:?}", &answer);
Ok(answer)
})
.await
// TODO if summarization failed, put it back in the backlog!
})
}
fn persist_summaries(
&self,
summaries: channel::Receiver<SummarizedFile>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let digest_db = self.file_digest_db;
let summary_db = self.summary_db;
cx.background_executor().spawn(async move {
let mut summaries = summaries.chunks_timeout(4096, Duration::from_secs(2));
while let Some(summaries) = summaries.next().await {
let mut txn = db_connection.write_txn()?;
for file in &summaries {
log::debug!(
"Saving summary of {:?} - which is {} bytes of summary for content digest {:?}",
&file.path,
file.summary.len(),
file.digest
);
digest_db.put(
&mut txn,
&file.path,
&FileDigest {
mtime: file.mtime,
digest: file.digest,
},
)?;
summary_db.put(&mut txn, &file.digest, &file.summary)?;
}
txn.commit()?;
drop(summaries);
log::debug!("committed summaries");
}
Ok(())
})
}
/// Empty out the backlog of files that haven't been resummarized, and resummarize them immediately.
pub(crate) fn flush_backlog(
&self,
worktree_abs_path: Arc<Path>,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let start = Instant::now();
let backlogged = {
let (tx, rx) = channel::bounded(512);
let needs_summary: Vec<(Arc<Path>, Option<SystemTime>)> = {
let mut backlog = self.backlog.lock();
backlog.drain().collect()
};
let task = cx.background_executor().spawn(async move {
tx.send(needs_summary).await?;
Ok(())
});
Backlogged {
paths_to_digest: rx,
task,
}
};
let digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
let needs_summary = self.check_summary_cache(digest.files, cx);
let summaries = self.summarize_files(needs_summary.files, cx);
let persist = self.persist_summaries(summaries.files, cx);
async move {
futures::try_join!(
backlogged.task,
digest.task,
needs_summary.task,
summaries.task,
persist
)?;
log::info!("Summarizing backlogged entries took {:?}", start.elapsed());
Ok(())
}
}
pub(crate) fn backlog_len(&self) -> usize {
self.backlog.lock().len()
}
}
fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}

View file

@ -0,0 +1,217 @@
use crate::embedding::EmbeddingProvider;
use crate::embedding_index::EmbeddingIndex;
use crate::indexing::IndexingEntrySet;
use crate::summary_index::SummaryIndex;
use anyhow::Result;
use feature_flags::{AutoCommand, FeatureFlagAppExt};
use fs::Fs;
use futures::future::Shared;
use gpui::{
AppContext, AsyncAppContext, Context, Model, ModelContext, Subscription, Task, WeakModel,
};
use language::LanguageRegistry;
use log;
use project::{UpdatedEntriesSet, Worktree};
use smol::channel;
use std::sync::Arc;
use util::ResultExt;
#[derive(Clone)]
pub enum WorktreeIndexHandle {
Loading {
index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
},
Loaded {
index: Model<WorktreeIndex>,
},
}
pub struct WorktreeIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
embedding_index: EmbeddingIndex,
summary_index: SummaryIndex,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
_index_entries: Task<Result<()>>,
_subscription: Subscription,
}
impl WorktreeIndex {
pub fn load(
worktree: Model<Worktree>,
db_connection: heed::Env,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let worktree_for_index = worktree.clone();
let worktree_for_summary = worktree.clone();
let worktree_abs_path = worktree.read(cx).abs_path();
let embedding_fs = Arc::clone(&fs);
let summary_fs = fs;
cx.spawn(|mut cx| async move {
let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
let (embedding_index, summary_index) = cx
.background_executor()
.spawn({
let entries_being_indexed = Arc::clone(&entries_being_indexed);
let db_connection = db_connection.clone();
async move {
let mut txn = db_connection.write_txn()?;
let embedding_index = {
let db_name = worktree_abs_path.to_string_lossy();
let db = db_connection.create_database(&mut txn, Some(&db_name))?;
EmbeddingIndex::new(
worktree_for_index,
embedding_fs,
db_connection.clone(),
db,
language_registry,
embedding_provider,
Arc::clone(&entries_being_indexed),
)
};
let summary_index = {
let file_digest_db = {
let db_name =
// Prepend something that wouldn't be found at the beginning of an
// absolute path, so we don't get db key namespace conflicts with
// embeddings, which use the abs path as a key.
format!("digests-{}", worktree_abs_path.to_string_lossy());
db_connection.create_database(&mut txn, Some(&db_name))?
};
let summary_db = {
let db_name =
// Prepend something that wouldn't be found at the beginning of an
// absolute path, so we don't get db key namespace conflicts with
// embeddings, which use the abs path as a key.
format!("summaries-{}", worktree_abs_path.to_string_lossy());
db_connection.create_database(&mut txn, Some(&db_name))?
};
SummaryIndex::new(
worktree_for_summary,
summary_fs,
db_connection.clone(),
file_digest_db,
summary_db,
Arc::clone(&entries_being_indexed),
)
};
txn.commit()?;
anyhow::Ok((embedding_index, summary_index))
}
})
.await?;
cx.new_model(|cx| {
Self::new(
worktree,
db_connection,
embedding_index,
summary_index,
entries_being_indexed,
cx,
)
})
})
}
#[allow(clippy::too_many_arguments)]
pub fn new(
worktree: Model<Worktree>,
db_connection: heed::Env,
embedding_index: EmbeddingIndex,
summary_index: SummaryIndex,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
cx: &mut ModelContext<Self>,
) -> Self {
let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
if let worktree::Event::UpdatedEntries(update) = event {
log::debug!("Updating entries...");
_ = updated_entries_tx.try_send(update.clone());
}
});
Self {
db_connection,
embedding_index,
summary_index,
worktree,
entry_ids_being_indexed,
_index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
_subscription,
}
}
pub fn entry_ids_being_indexed(&self) -> &IndexingEntrySet {
self.entry_ids_being_indexed.as_ref()
}
pub fn worktree(&self) -> &Model<Worktree> {
&self.worktree
}
pub fn db_connection(&self) -> &heed::Env {
&self.db_connection
}
pub fn embedding_index(&self) -> &EmbeddingIndex {
&self.embedding_index
}
pub fn summary_index(&self) -> &SummaryIndex {
&self.summary_index
}
async fn index_entries(
this: WeakModel<Self>,
updated_entries: channel::Receiver<UpdatedEntriesSet>,
mut cx: AsyncAppContext,
) -> Result<()> {
let is_auto_available = cx.update(|cx| cx.wait_for_flag::<AutoCommand>())?.await;
let index = this.update(&mut cx, |this, cx| {
futures::future::try_join(
this.embedding_index.index_entries_changed_on_disk(cx),
this.summary_index
.index_entries_changed_on_disk(is_auto_available, cx),
)
})?;
index.await.log_err();
while let Ok(updated_entries) = updated_entries.recv().await {
let is_auto_available = cx
.update(|cx| cx.has_flag::<AutoCommand>())
.unwrap_or(false);
let index = this.update(&mut cx, |this, cx| {
futures::future::try_join(
this.embedding_index
.index_updated_entries(updated_entries.clone(), cx),
this.summary_index.index_updated_entries(
updated_entries,
is_auto_available,
cx,
),
)
})?;
index.await.log_err();
}
Ok(())
}
#[cfg(test)]
pub fn path_count(&self) -> Result<u64> {
use anyhow::Context;
let txn = self
.db_connection
.read_txn()
.context("failed to create read transaction")?;
Ok(self.embedding_index().db().len(&txn)?)
}
}