Include modified buffers in semantic search results (#2970)

This pull request introduces an additional step to
`SemanticIndex::search_project` that includes the content of buffers
that are modified but haven't been saved yet. In most cases, the buffer
will contain a small portion of changed spans that are potentially not
included in the index. To reuse all the other spans that haven't
changed, we will query the database for embeddings by their digest. This
means we have to index spans by their digest, which incurs some write-time
penalty, but in our tests this didn't noticeably slow down indexing.

Release Notes:

- Improved semantic search to include results from modified buffers.
(preview-only)
This commit is contained in:
Antonio Scandurra 2023-09-15 12:24:10 +02:00 committed by GitHub
commit a1250b8525
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 294 additions and 60 deletions

1
Cargo.lock generated
View file

@ -6739,6 +6739,7 @@ dependencies = [
"lazy_static", "lazy_static",
"log", "log",
"matrixmultiply", "matrixmultiply",
"ordered-float",
"parking_lot 0.11.2", "parking_lot 0.11.2",
"parse_duration", "parse_duration",
"picker", "picker",

View file

@ -912,7 +912,6 @@ impl Project {
self.user_store.clone() self.user_store.clone()
} }
#[cfg(any(test, feature = "test-support"))]
pub fn opened_buffers(&self, cx: &AppContext) -> Vec<ModelHandle<Buffer>> { pub fn opened_buffers(&self, cx: &AppContext) -> Vec<ModelHandle<Buffer>> {
self.opened_buffers self.opened_buffers
.values() .values()

View file

@ -23,6 +23,7 @@ settings = { path = "../settings" }
anyhow.workspace = true anyhow.workspace = true
postage.workspace = true postage.workspace = true
futures.workspace = true futures.workspace = true
ordered-float.workspace = true
smol.workspace = true smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] } rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true isahc.workspace = true

View file

@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
use collections::HashMap; use collections::HashMap;
use futures::channel::oneshot; use futures::channel::oneshot;
use gpui::executor; use gpui::executor;
use ordered_float::OrderedFloat;
use project::{search::PathMatcher, Fs}; use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp; use rpc::proto::Timestamp;
use rusqlite::params; use rusqlite::params;
use rusqlite::types::Value; use rusqlite::types::Value;
use std::{ use std::{
cmp::Ordering, cmp::Reverse,
future::Future, future::Future,
ops::Range, ops::Range,
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -190,6 +191,10 @@ impl VectorDatabase {
)", )",
[], [],
)?; )?;
db.execute(
"CREATE INDEX spans_digest ON spans (digest)",
[],
)?;
log::trace!("vector database initialized with updated schema."); log::trace!("vector database initialized with updated schema.");
Ok(()) Ok(())
@ -274,6 +279,39 @@ impl VectorDatabase {
}) })
} }
/// Looks up previously computed embeddings by span digest.
///
/// Returns a map from digest to embedding for every entry of `digests`
/// that already has a row in the `spans` table; digests with no stored
/// embedding are simply absent from the returned map, so callers can use
/// membership in the map to decide which spans still need to be embedded.
pub fn embeddings_for_digests(
&self,
digests: Vec<SpanDigest>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
// `rarray(?)` is rusqlite's carray-style table-valued function: it lets
// us bind a whole Vec of values as the right-hand side of an IN clause
// in a single prepared statement.
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
WHERE digest IN rarray(?)
",
)?;
let mut embeddings_by_digest = HashMap::default();
// rusqlite's rarray binding requires an Rc<Vec<Value>>; each 20-byte
// digest is passed as a BLOB value.
let digests = Rc::new(
digests
.into_iter()
.map(|p| Value::Blob(p.0.to_vec()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![digests], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for row in rows {
// NOTE(review): rows that fail to decode are silently skipped here;
// a missing entry just means the caller re-embeds that span, so this
// is best-effort by design — confirm that is intended.
if let Ok(row) = row {
embeddings_by_digest.insert(row.0, row.1);
}
}
Ok(embeddings_by_digest)
})
}
pub fn embeddings_for_files( pub fn embeddings_for_files(
&self, &self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>, worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
@ -370,16 +408,16 @@ impl VectorDatabase {
query_embedding: &Embedding, query_embedding: &Embedding,
limit: usize, limit: usize,
file_ids: &[i64], file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, f32)>>> { ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let query_embedding = query_embedding.clone(); let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec(); let file_ids = file_ids.to_vec();
self.transact(move |db| { self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
Self::for_each_span(db, &file_ids, |id, embedding| { Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding); let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| { let ix = match results
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
}) { {
Ok(ix) => ix, Ok(ix) => ix,
Err(ix) => ix, Err(ix) => ix,
}; };

View file

@ -7,6 +7,7 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable; use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response}; use isahc::{AsyncBody, Response};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex; use parking_lot::Mutex;
use parse_duration::parse; use parse_duration::parse;
use postage::watch; use postage::watch;
@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
} }
impl Embedding { impl Embedding {
pub fn similarity(&self, other: &Self) -> f32 { pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len(); let len = self.0.len();
assert_eq!(len, other.0.len()); assert_eq!(len, other.0.len());
@ -58,7 +59,7 @@ impl Embedding {
1, 1,
); );
} }
result OrderedFloat(result)
} }
} }
@ -379,13 +380,13 @@ mod tests {
); );
} }
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places); let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor (n * factor).round() / factor
} }
fn reference_dot(a: &[f32], b: &[f32]) -> f32 { fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
} }
} }
} }

View file

@ -7,6 +7,7 @@ use rusqlite::{
}; };
use sha1::{Digest, Sha1}; use sha1::{Digest, Sha1};
use std::{ use std::{
borrow::Cow,
cmp::{self, Reverse}, cmp::{self, Reverse},
collections::HashSet, collections::HashSet,
ops::Range, ops::Range,
@ -16,7 +17,7 @@ use std::{
use tree_sitter::{Parser, QueryCursor}; use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)] #[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest([u8; 20]); pub struct SpanDigest(pub [u8; 20]);
impl FromSql for SpanDigest { impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> { fn column_result(value: ValueRef) -> FromSqlResult<Self> {
@ -94,12 +95,15 @@ impl CodeContextRetriever {
fn parse_entire_file( fn parse_entire_file(
&self, &self,
relative_path: &Path, relative_path: Option<&Path>,
language_name: Arc<str>, language_name: Arc<str>,
content: &str, content: &str,
) -> Result<Vec<Span>> { ) -> Result<Vec<Span>> {
let document_span = ENTIRE_FILE_TEMPLATE let document_span = ENTIRE_FILE_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref()) .replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref()) .replace("<language>", language_name.as_ref())
.replace("<item>", &content); .replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str()); let digest = SpanDigest::from(document_span.as_str());
@ -114,9 +118,16 @@ impl CodeContextRetriever {
}]) }])
} }
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> { fn parse_markdown_file(
&self,
relative_path: Option<&Path>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = MARKDOWN_CONTEXT_TEMPLATE let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref()) .replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<item>", &content); .replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str()); let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span); let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
@ -188,7 +199,7 @@ impl CodeContextRetriever {
pub fn parse_file_with_template( pub fn parse_file_with_template(
&mut self, &mut self,
relative_path: &Path, relative_path: Option<&Path>,
content: &str, content: &str,
language: Arc<Language>, language: Arc<Language>,
) -> Result<Vec<Span>> { ) -> Result<Vec<Span>> {
@ -196,14 +207,17 @@ impl CodeContextRetriever {
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) { if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
return self.parse_entire_file(relative_path, language_name, &content); return self.parse_entire_file(relative_path, language_name, &content);
} else if language_name.as_ref() == "Markdown" { } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
return self.parse_markdown_file(relative_path, &content); return self.parse_markdown_file(relative_path, &content);
} }
let mut spans = self.parse_file(content, language)?; let mut spans = self.parse_file(content, language)?;
for span in &mut spans { for span in &mut spans {
let document_content = CODE_CONTEXT_TEMPLATE let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref()) .replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref()) .replace("<language>", language_name.as_ref())
.replace("item", &span.content); .replace("item", &span.content);

View file

@ -16,14 +16,16 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt}; use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use ordered_float::OrderedFloat;
use parking_lot::Mutex; use parking_lot::Mutex;
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch; use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId}; use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel; use smol::channel;
use std::{ use std::{
cmp::Ordering, cmp::Reverse,
future::Future, future::Future,
mem,
ops::Range, ops::Range,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::{Arc, Weak}, sync::{Arc, Weak},
@ -37,7 +39,7 @@ use util::{
}; };
use workspace::WorkspaceCreated; use workspace::WorkspaceCreated;
const SEMANTIC_INDEX_VERSION: usize = 10; const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
@ -262,9 +264,11 @@ pub struct PendingFile {
job_handle: JobHandle, job_handle: JobHandle,
} }
#[derive(Clone)]
pub struct SearchResult { pub struct SearchResult {
pub buffer: ModelHandle<Buffer>, pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>, pub range: Range<Anchor>,
pub similarity: OrderedFloat<f32>,
} }
impl SemanticIndex { impl SemanticIndex {
@ -402,7 +406,7 @@ impl SemanticIndex {
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
if let Some(mut spans) = retriever if let Some(mut spans) = retriever
.parse_file_with_template(&pending_file.relative_path, &content, language) .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
.log_err() .log_err()
{ {
log::trace!( log::trace!(
@ -422,7 +426,7 @@ impl SemanticIndex {
path: pending_file.relative_path, path: pending_file.relative_path,
mtime: pending_file.modified_time, mtime: pending_file.modified_time,
job_handle: pending_file.job_handle, job_handle: pending_file.job_handle,
spans: spans, spans,
}); });
} }
} }
@ -687,39 +691,71 @@ impl SemanticIndex {
pub fn search_project( pub fn search_project(
&mut self, &mut self,
project: ModelHandle<Project>, project: ModelHandle<Project>,
phrase: String, query: String,
limit: usize, limit: usize,
includes: Vec<PathMatcher>, includes: Vec<PathMatcher>,
excludes: Vec<PathMatcher>, excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> { ) -> Task<Result<Vec<SearchResult>>> {
if query.is_empty() {
return Task::ready(Ok(Vec::new()));
}
let index = self.index_project(project.clone(), cx); let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
cx.spawn(|this, mut cx| async move {
let query = embedding_provider
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
index.await?;
let search_start = Instant::now();
let modified_buffer_results = this.update(&mut cx, |this, cx| {
this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
});
let file_results = this.update(&mut cx, |this, cx| {
this.search_files(project, query, limit, includes, excludes, cx)
});
let (modified_buffer_results, file_results) =
futures::join!(modified_buffer_results, file_results);
// Weave together the results from modified buffers and files.
let mut results = Vec::new();
let mut modified_buffers = HashSet::default();
for result in modified_buffer_results.log_err().unwrap_or_default() {
modified_buffers.insert(result.buffer.clone());
results.push(result);
}
for result in file_results.log_err().unwrap_or_default() {
if !modified_buffers.contains(&result.buffer) {
results.push(result);
}
}
results.sort_by_key(|result| Reverse(result.similarity));
results.truncate(limit);
log::trace!("Semantic search took {:?}", search_start.elapsed());
Ok(results)
})
}
pub fn search_files(
&mut self,
project: ModelHandle<Project>,
query: Embedding,
limit: usize,
includes: Vec<PathMatcher>,
excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
let db_path = self.db.path().clone(); let db_path = self.db.path().clone();
let fs = self.fs.clone(); let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let database = let database =
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?; VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
if phrase.len() == 0 {
return Ok(Vec::new());
}
let phrase_embedding = embedding_provider
.embed_batch(vec![phrase])
.await?
.into_iter()
.next()
.unwrap();
log::trace!(
"Embedding search phrase took: {:?} milliseconds",
t0.elapsed().as_millis()
);
let worktree_db_ids = this.read_with(&cx, |this, _| { let worktree_db_ids = this.read_with(&cx, |this, _| {
let project_state = this let project_state = this
.projects .projects
@ -738,6 +774,7 @@ impl SemanticIndex {
.collect::<Vec<i64>>(); .collect::<Vec<i64>>();
anyhow::Ok(worktree_db_ids) anyhow::Ok(worktree_db_ids)
})?; })?;
let file_ids = database let file_ids = database
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
.await?; .await?;
@ -756,26 +793,26 @@ impl SemanticIndex {
let limit = limit.clone(); let limit = limit.clone();
let fs = fs.clone(); let fs = fs.clone();
let db_path = db_path.clone(); let db_path = db_path.clone();
let phrase_embedding = phrase_embedding.clone(); let query = query.clone();
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background()) if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
.await .await
.log_err() .log_err()
{ {
batch_results.push(async move { batch_results.push(async move {
db.top_k_search(&phrase_embedding, limit, batch.as_slice()) db.top_k_search(&query, limit, batch.as_slice()).await
.await
}); });
} }
} }
let batch_results = futures::future::join_all(batch_results).await; let batch_results = futures::future::join_all(batch_results).await;
let mut results = Vec::new(); let mut results = Vec::new();
for batch_result in batch_results { for batch_result in batch_results {
if batch_result.is_ok() { if batch_result.is_ok() {
for (id, similarity) in batch_result.unwrap() { for (id, similarity) in batch_result.unwrap() {
let ix = match results.binary_search_by(|(_, s)| { let ix = match results
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
}) { {
Ok(ix) => ix, Ok(ix) => ix,
Err(ix) => ix, Err(ix) => ix,
}; };
@ -785,7 +822,11 @@ impl SemanticIndex {
} }
} }
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>(); let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
let scores = results
.into_iter()
.map(|(_, score)| score)
.collect::<Vec<_>>();
let spans = database.spans_for_ids(ids.as_slice()).await?; let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new(); let mut tasks = Vec::new();
@ -810,24 +851,106 @@ impl SemanticIndex {
let buffers = futures::future::join_all(tasks).await; let buffers = futures::future::join_all(tasks).await;
log::trace!(
"Semantic Searching took: {:?} milliseconds in total",
t0.elapsed().as_millis()
);
Ok(buffers Ok(buffers
.into_iter() .into_iter()
.zip(ranges) .zip(ranges)
.filter_map(|(buffer, range)| { .zip(scores)
.filter_map(|((buffer, range), similarity)| {
let buffer = buffer.log_err()?; let buffer = buffer.log_err()?;
let range = buffer.read_with(&cx, |buffer, _| { let range = buffer.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left); let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right); let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end) buffer.anchor_before(start)..buffer.anchor_after(end)
}); });
Some(SearchResult { buffer, range }) Some(SearchResult {
buffer,
range,
similarity,
}) })
.collect::<Vec<_>>()) })
.collect())
})
}
/// Searches dirty (modified-but-unsaved) open buffers for spans similar to
/// `query`, returning up to `limit` results sorted by descending similarity.
///
/// Buffers whose resolved file path matches any of `excludes` are skipped.
/// Each remaining buffer is re-parsed into spans with no path ("untitled"
/// template slot) and embedded via `embed_spans`, which reuses database
/// embeddings for spans whose digest is already indexed.
fn search_modified_buffers(
&self,
project: &ModelHandle<Project>,
query: Embedding,
limit: usize,
excludes: &[PathMatcher],
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
// Snapshot the dirty, non-excluded buffers on the main thread before
// moving the work to the background executor.
let modified_buffers = project
.read(cx)
.opened_buffers(cx)
.into_iter()
.filter_map(|buffer_handle| {
let buffer = buffer_handle.read(cx);
let snapshot = buffer.snapshot();
let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
excludes.iter().any(|matcher| matcher.is_match(&path))
});
if buffer.is_dirty() && !excluded {
Some((buffer_handle, snapshot))
} else {
None
}
})
.collect::<HashMap<_, _>>();
let embedding_provider = self.embedding_provider.clone();
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
for (buffer, snapshot) in modified_buffers {
// Fall back to plain text when the buffer has no detected language.
let language = snapshot
.language_at(0)
.cloned()
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
let mut spans = retriever
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
.await
.log_err()
.is_some()
{
for span in spans {
// embed_spans guarantees every span has an embedding on success,
// so this unwrap cannot fail here.
let similarity = span.embedding.unwrap().similarity(&query);
// Keep `results` sorted by descending similarity; Reverse flips
// the ordering so binary search finds the insertion point.
let ix = match results
.binary_search_by_key(&Reverse(similarity), |result| {
Reverse(result.similarity)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
// Clip the span's byte range to the snapshot and convert it to
// anchors so it stays valid as the buffer is edited later.
let range = {
let start = snapshot.clip_offset(span.range.start, Bias::Left);
let end = snapshot.clip_offset(span.range.end, Bias::Right);
snapshot.anchor_before(start)..snapshot.anchor_after(end)
};
results.insert(
ix,
SearchResult {
buffer: buffer.clone(),
range,
similarity,
},
);
// Truncate eagerly so the vec never grows past `limit`.
results.truncate(limit);
}
}
}
Ok(results)
}) })
} }
@ -1011,6 +1134,63 @@ impl SemanticIndex {
Ok(()) Ok(())
}) })
} }
/// Populates `span.embedding` for every span in `spans`, reusing embeddings
/// cached in the database where a span's digest is already indexed and
/// batching the remainder through `embedding_provider`.
///
/// Batches are flushed whenever adding the next span would exceed
/// `max_tokens_per_batch()`. Returns an error if the provider produces
/// fewer embeddings than the number of uncached spans.
async fn embed_spans(
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();
// First, ask the database which digests already have embeddings so we
// only pay the provider for spans whose content actually changed.
let digests = spans
.iter()
.map(|span| span.digest.clone())
.collect::<Vec<_>>();
// Cache lookup is best-effort: on failure we fall back to embedding
// everything rather than aborting the search.
let embeddings_for_digests = db
.embeddings_for_digests(digests)
.await
.log_err()
.unwrap_or_default();
for span in &*spans {
if embeddings_for_digests.contains_key(&span.digest) {
continue;
};
// Flush the current batch before it would exceed the provider's
// per-request token budget.
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
}
batch_tokens += span.token_count;
batch.push(span.content.clone());
}
// Flush any trailing partial batch.
if !batch.is_empty() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
}
// Weave cached and freshly computed embeddings back onto the spans in
// order: cached spans take their database entry, the rest consume the
// provider's results sequentially (same order they were batched in).
let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
Some(embedding.clone())
} else {
embeddings.next()
};
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
span.embedding = Some(embedding);
}
Ok(())
}
} }
impl Entity for SemanticIndex { impl Entity for SemanticIndex {