catchup with main
This commit is contained in:
commit
3a661c5977
67 changed files with 2965 additions and 2922 deletions
|
@ -23,6 +23,7 @@ settings = { path = "../settings" }
|
|||
anyhow.workspace = true
|
||||
postage.workspace = true
|
||||
futures.workspace = true
|
||||
ordered-float.workspace = true
|
||||
smol.workspace = true
|
||||
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
isahc.workspace = true
|
||||
|
|
|
@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
|
|||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::executor;
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::{search::PathMatcher, Fs};
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::params;
|
||||
use rusqlite::types::Value;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
cmp::Reverse,
|
||||
future::Future,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
|
@ -190,6 +191,10 @@ impl VectorDatabase {
|
|||
)",
|
||||
[],
|
||||
)?;
|
||||
db.execute(
|
||||
"CREATE INDEX spans_digest ON spans (digest)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
log::trace!("vector database initialized with updated schema.");
|
||||
Ok(())
|
||||
|
@ -274,6 +279,39 @@ impl VectorDatabase {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_digests(
|
||||
&self,
|
||||
digests: Vec<SpanDigest>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
WHERE digest IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
let digests = Rc::new(
|
||||
digests
|
||||
.into_iter()
|
||||
.map(|p| Value::Blob(p.0.to_vec()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![digests], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for row in rows {
|
||||
if let Ok(row) = row {
|
||||
embeddings_by_digest.insert(row.0, row.1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_files(
|
||||
&self,
|
||||
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
|
||||
|
@ -370,16 +408,16 @@ impl VectorDatabase {
|
|||
query_embedding: &Embedding,
|
||||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
|
||||
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
|
||||
let query_embedding = query_embedding.clone();
|
||||
let file_ids = file_ids.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
|
||||
Self::for_each_span(db, &file_ids, |id, embedding| {
|
||||
let similarity = embedding.similarity(&query_embedding);
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
|
|
|
@ -7,6 +7,7 @@ use isahc::http::StatusCode;
|
|||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::Mutex;
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
|
@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
|
|||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn similarity(&self, other: &Self) -> f32 {
|
||||
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||
let len = self.0.len();
|
||||
assert_eq!(len, other.0.len());
|
||||
|
||||
|
@ -58,7 +59,7 @@ impl Embedding {
|
|||
1,
|
||||
);
|
||||
}
|
||||
result
|
||||
OrderedFloat(result)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -379,13 +380,13 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
|
||||
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||
let factor = (10.0 as f32).powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ use rusqlite::{
|
|||
};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::{self, Reverse},
|
||||
collections::HashSet,
|
||||
ops::Range,
|
||||
|
@ -16,7 +17,7 @@ use std::{
|
|||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
pub struct SpanDigest([u8; 20]);
|
||||
pub struct SpanDigest(pub [u8; 20]);
|
||||
|
||||
impl FromSql for SpanDigest {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
|
@ -94,12 +95,15 @@ impl CodeContextRetriever {
|
|||
|
||||
fn parse_entire_file(
|
||||
&self,
|
||||
relative_path: &Path,
|
||||
relative_path: Option<&Path>,
|
||||
language_name: Arc<str>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = ENTIRE_FILE_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
|
@ -114,9 +118,16 @@ impl CodeContextRetriever {
|
|||
}])
|
||||
}
|
||||
|
||||
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
|
||||
fn parse_markdown_file(
|
||||
&self,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = MARKDOWN_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
|
@ -188,7 +199,7 @@ impl CodeContextRetriever {
|
|||
|
||||
pub fn parse_file_with_template(
|
||||
&mut self,
|
||||
relative_path: &Path,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Span>> {
|
||||
|
@ -196,14 +207,17 @@ impl CodeContextRetriever {
|
|||
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||
return self.parse_entire_file(relative_path, language_name, &content);
|
||||
} else if language_name.as_ref() == "Markdown" {
|
||||
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
|
||||
return self.parse_markdown_file(relative_path, &content);
|
||||
}
|
||||
|
||||
let mut spans = self.parse_file(content, language)?;
|
||||
for span in &mut spans {
|
||||
let document_content = CODE_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
|
|
|
@ -16,14 +16,16 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
|
|||
use futures::{future, FutureExt, StreamExt};
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::Mutex;
|
||||
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
||||
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
||||
use postage::watch;
|
||||
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
cmp::Reverse,
|
||||
future::Future,
|
||||
mem,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, Weak},
|
||||
|
@ -37,7 +39,7 @@ use util::{
|
|||
};
|
||||
use workspace::WorkspaceCreated;
|
||||
|
||||
const SEMANTIC_INDEX_VERSION: usize = 10;
|
||||
const SEMANTIC_INDEX_VERSION: usize = 11;
|
||||
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
|
||||
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
|
||||
|
||||
|
@ -262,9 +264,11 @@ pub struct PendingFile {
|
|||
job_handle: JobHandle,
|
||||
}
|
||||
|
||||
/// A single semantic-search hit: a location inside a buffer together with its score.
#[derive(Clone)]
|
||||
pub struct SearchResult {
|
||||
/// Handle to the buffer that contains the matching span.
pub buffer: ModelHandle<Buffer>,
|
||||
/// Anchored range of the matching span within `buffer`.
pub range: Range<Anchor>,
|
||||
/// Score produced by `Embedding::similarity` for this span against the query;
/// search results are ordered by this value, highest first.
pub similarity: OrderedFloat<f32>,
|
||||
}
|
||||
|
||||
impl SemanticIndex {
|
||||
|
@ -402,7 +406,7 @@ impl SemanticIndex {
|
|||
|
||||
if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
|
||||
if let Some(mut spans) = retriever
|
||||
.parse_file_with_template(&pending_file.relative_path, &content, language)
|
||||
.parse_file_with_template(Some(&pending_file.relative_path), &content, language)
|
||||
.log_err()
|
||||
{
|
||||
log::trace!(
|
||||
|
@ -422,7 +426,7 @@ impl SemanticIndex {
|
|||
path: pending_file.relative_path,
|
||||
mtime: pending_file.modified_time,
|
||||
job_handle: pending_file.job_handle,
|
||||
spans: spans,
|
||||
spans,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -687,39 +691,71 @@ impl SemanticIndex {
|
|||
pub fn search_project(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
phrase: String,
|
||||
query: String,
|
||||
limit: usize,
|
||||
includes: Vec<PathMatcher>,
|
||||
excludes: Vec<PathMatcher>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
if query.is_empty() {
|
||||
return Task::ready(Ok(Vec::new()));
|
||||
}
|
||||
|
||||
let index = self.index_project(project.clone(), cx);
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
index.await?;
|
||||
|
||||
let search_start = Instant::now();
|
||||
let modified_buffer_results = this.update(&mut cx, |this, cx| {
|
||||
this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
|
||||
});
|
||||
let file_results = this.update(&mut cx, |this, cx| {
|
||||
this.search_files(project, query, limit, includes, excludes, cx)
|
||||
});
|
||||
let (modified_buffer_results, file_results) =
|
||||
futures::join!(modified_buffer_results, file_results);
|
||||
|
||||
// Weave together the results from modified buffers and files.
|
||||
let mut results = Vec::new();
|
||||
let mut modified_buffers = HashSet::default();
|
||||
for result in modified_buffer_results.log_err().unwrap_or_default() {
|
||||
modified_buffers.insert(result.buffer.clone());
|
||||
results.push(result);
|
||||
}
|
||||
for result in file_results.log_err().unwrap_or_default() {
|
||||
if !modified_buffers.contains(&result.buffer) {
|
||||
results.push(result);
|
||||
}
|
||||
}
|
||||
results.sort_by_key(|result| Reverse(result.similarity));
|
||||
results.truncate(limit);
|
||||
log::trace!("Semantic search took {:?}", search_start.elapsed());
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn search_files(
|
||||
&mut self,
|
||||
project: ModelHandle<Project>,
|
||||
query: Embedding,
|
||||
limit: usize,
|
||||
includes: Vec<PathMatcher>,
|
||||
excludes: Vec<PathMatcher>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let db_path = self.db.path().clone();
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
|
||||
let t0 = Instant::now();
|
||||
let database =
|
||||
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
|
||||
|
||||
if phrase.len() == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let phrase_embedding = embedding_provider
|
||||
.embed_batch(vec![phrase])
|
||||
.await?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
log::trace!(
|
||||
"Embedding search phrase took: {:?} milliseconds",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
|
||||
let worktree_db_ids = this.read_with(&cx, |this, _| {
|
||||
let project_state = this
|
||||
.projects
|
||||
|
@ -738,6 +774,7 @@ impl SemanticIndex {
|
|||
.collect::<Vec<i64>>();
|
||||
anyhow::Ok(worktree_db_ids)
|
||||
})?;
|
||||
|
||||
let file_ids = database
|
||||
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
|
||||
.await?;
|
||||
|
@ -756,26 +793,26 @@ impl SemanticIndex {
|
|||
let limit = limit.clone();
|
||||
let fs = fs.clone();
|
||||
let db_path = db_path.clone();
|
||||
let phrase_embedding = phrase_embedding.clone();
|
||||
let query = query.clone();
|
||||
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
batch_results.push(async move {
|
||||
db.top_k_search(&phrase_embedding, limit, batch.as_slice())
|
||||
.await
|
||||
db.top_k_search(&query, limit, batch.as_slice()).await
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let batch_results = futures::future::join_all(batch_results).await;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for batch_result in batch_results {
|
||||
if batch_result.is_ok() {
|
||||
for (id, similarity) in batch_result.unwrap() {
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
|
@ -785,7 +822,11 @@ impl SemanticIndex {
|
|||
}
|
||||
}
|
||||
|
||||
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
|
||||
let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
|
||||
let scores = results
|
||||
.into_iter()
|
||||
.map(|(_, score)| score)
|
||||
.collect::<Vec<_>>();
|
||||
let spans = database.spans_for_ids(ids.as_slice()).await?;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
|
@ -810,24 +851,106 @@ impl SemanticIndex {
|
|||
|
||||
let buffers = futures::future::join_all(tasks).await;
|
||||
|
||||
log::trace!(
|
||||
"Semantic Searching took: {:?} milliseconds in total",
|
||||
t0.elapsed().as_millis()
|
||||
);
|
||||
|
||||
Ok(buffers
|
||||
.into_iter()
|
||||
.zip(ranges)
|
||||
.filter_map(|(buffer, range)| {
|
||||
.zip(scores)
|
||||
.filter_map(|((buffer, range), similarity)| {
|
||||
let buffer = buffer.log_err()?;
|
||||
let range = buffer.read_with(&cx, |buffer, _| {
|
||||
let start = buffer.clip_offset(range.start, Bias::Left);
|
||||
let end = buffer.clip_offset(range.end, Bias::Right);
|
||||
buffer.anchor_before(start)..buffer.anchor_after(end)
|
||||
});
|
||||
Some(SearchResult { buffer, range })
|
||||
Some(SearchResult {
|
||||
buffer,
|
||||
range,
|
||||
similarity,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
fn search_modified_buffers(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
query: Embedding,
|
||||
limit: usize,
|
||||
excludes: &[PathMatcher],
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
let modified_buffers = project
|
||||
.read(cx)
|
||||
.opened_buffers(cx)
|
||||
.into_iter()
|
||||
.filter_map(|buffer_handle| {
|
||||
let buffer = buffer_handle.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
|
||||
excludes.iter().any(|matcher| matcher.is_match(&path))
|
||||
});
|
||||
if buffer.is_dirty() && !excluded {
|
||||
Some((buffer_handle, snapshot))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
cx.background().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
||||
for (buffer, snapshot) in modified_buffers {
|
||||
let language = snapshot
|
||||
.language_at(0)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
|
||||
let mut spans = retriever
|
||||
.parse_file_with_template(None, &snapshot.text(), language)
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
for span in spans {
|
||||
let similarity = span.embedding.unwrap().similarity(&query);
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |result| {
|
||||
Reverse(result.similarity)
|
||||
}) {
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
|
||||
let range = {
|
||||
let start = snapshot.clip_offset(span.range.start, Bias::Left);
|
||||
let end = snapshot.clip_offset(span.range.end, Bias::Right);
|
||||
snapshot.anchor_before(start)..snapshot.anchor_after(end)
|
||||
};
|
||||
|
||||
results.insert(
|
||||
ix,
|
||||
SearchResult {
|
||||
buffer: buffer.clone(),
|
||||
range,
|
||||
similarity,
|
||||
},
|
||||
);
|
||||
results.truncate(limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1009,6 +1132,63 @@ impl SemanticIndex {
|
|||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
async fn embed_spans(
|
||||
spans: &mut [Span],
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
db: &VectorDatabase,
|
||||
) -> Result<()> {
|
||||
let mut batch = Vec::new();
|
||||
let mut batch_tokens = 0;
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
let digests = spans
|
||||
.iter()
|
||||
.map(|span| span.digest.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let embeddings_for_digests = db
|
||||
.embeddings_for_digests(digests)
|
||||
.await
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
|
||||
for span in &*spans {
|
||||
if embeddings_for_digests.contains_key(&span.digest) {
|
||||
continue;
|
||||
};
|
||||
|
||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
embeddings.extend(batch_embeddings);
|
||||
batch_tokens = 0;
|
||||
}
|
||||
|
||||
batch_tokens += span.token_count;
|
||||
batch.push(span.content.clone());
|
||||
}
|
||||
|
||||
if !batch.is_empty() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
|
||||
embeddings.extend(batch_embeddings);
|
||||
}
|
||||
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for span in spans {
|
||||
let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
|
||||
Some(embedding.clone())
|
||||
} else {
|
||||
embeddings.next()
|
||||
};
|
||||
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
|
||||
span.embedding = Some(embedding);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Entity for SemanticIndex {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue