commit ce76955068 (parent 82760d6d1a)
wip

Co-Authored-By: Antonio <antonio@zed.dev>

2 changed files with 131 additions and 60 deletions
cross_encoder.rs

@@ -1,9 +1,9 @@
-use ndarray::{Array1, Array2, Axis, CowArray};
+use ndarray::CowArray;
 use ort::{Environment, ExecutionProvider, GraphOptimizationLevel, Session, SessionBuilder, Value};
 use tokenizers::Tokenizer;
 use util::paths::MODELS_DIR;

-struct CrossEncoder {
+pub struct CrossEncoder {
     session: Session,
     tokenizer: Tokenizer,
 }
@@ -25,15 +25,22 @@ impl CrossEncoder {

         let session = SessionBuilder::new(&environment)?
             .with_optimization_level(GraphOptimizationLevel::Level1)?
-            .with_intra_threads(1)?
             .with_model_from_file(model_path)?;

-        let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
+        let mut tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
+        tokenizer
+            .with_truncation(Some(tokenizers::TruncationParams {
+                direction: Default::default(),
+                max_length: 512,
+                strategy: Default::default(),
+                stride: 0,
+            }))
+            .unwrap();

         Ok(Self { session, tokenizer })
     }

-    pub fn score(&self, query: &str, candidates: Vec<&str>) -> anyhow::Result<Vec<f32>> {
+    pub fn score(&self, query: &str, candidates: &[String]) -> anyhow::Result<Vec<f32>> {
         let spans = candidates
             .into_iter()
             .map(|candidate| format!("{}. {}", query, candidate))
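Note on the truncation hunk above: capping each encoded pair at 512 tokens matches the usual sequence limit for BERT/MiniLM-style cross-encoders. Below is a minimal standalone sketch of the same configuration, assuming a tokenizers-crate version where with_truncation returns a Result (as the .unwrap() in the diff implies); the tokenizer.json path is a placeholder for whatever actually lives under MODELS_DIR.

    use tokenizers::{Tokenizer, TruncationParams};

    fn main() {
        // Placeholder path; the commit resolves its files under MODELS_DIR.
        let mut tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();
        tokenizer
            .with_truncation(Some(TruncationParams {
                direction: Default::default(),
                max_length: 512, // hard cap per encoded query/candidate pair
                strategy: Default::default(),
                stride: 0,
            }))
            .unwrap();

        // `score` concatenates query and candidate into one span, so each
        // pair is encoded as a single (possibly truncated) sequence.
        let encoding = tokenizer.encode("I like you. I love you.", true).unwrap();
        println!("{} token ids", encoding.get_ids().len());
    }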
@@ -91,9 +98,18 @@ mod tests {
     #[test]
     fn test_cross_encoder() {
         let cross_encoder = CrossEncoder::load().unwrap();
-        let sample_candidates = vec!["I love you.", "I hate you."];
-        let results = cross_encoder.score("I like you", sample_candidates.clone());
-        assert_eq!(results.unwrap().len(), sample_candidates.len());
+        let results = cross_encoder
+            .score(
+                "I like you",
+                &[
+                    "I hate you.".into(),
+                    "I love you.".into(),
+                    "my name is kyle".into(),
+                ],
+            )
+            .unwrap();
+        assert_eq!(results.len(), 3);
+        assert!(results[1] > results[0]);
+        assert!(results[0] > results[2]);
     }
 }
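The body of score past the hunk boundary is not shown, but the imports (CowArray, Session, Value) pin down the rough shape of the inference step. The following is a sketch only, assuming the ort 1.x API and a MiniLM-style ONNX export that takes input_ids and attention_mask and emits one relevance logit per pair; model.onnx, the input order, and the dummy shapes are placeholders, not the commit's actual code.

    use ndarray::{Array2, CowArray};
    use ort::{Environment, GraphOptimizationLevel, SessionBuilder, Value};

    fn main() -> anyhow::Result<()> {
        let environment = Environment::builder().build()?.into_arc();
        let session = SessionBuilder::new(&environment)?
            .with_optimization_level(GraphOptimizationLevel::Level1)?
            .with_model_from_file("model.onnx")?; // placeholder path

        // Pretend these came from the tokenizer for one query/candidate pair.
        let input_ids = CowArray::from(Array2::<i64>::zeros((1, 16))).into_dyn();
        let attention_mask = CowArray::from(Array2::<i64>::ones((1, 16))).into_dyn();

        let outputs = session.run(vec![
            Value::from_array(session.allocator(), &input_ids)?,
            Value::from_array(session.allocator(), &attention_mask)?,
        ])?;

        // One logit per pair; higher means the candidate better matches the query.
        let logits = outputs[0].try_extract::<f32>()?;
        println!("score: {}", logits.view()[[0, 0]]);
        Ok(())
    }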
semantic_index.rs

@@ -8,7 +8,7 @@ pub mod semantic_index_settings;
 #[cfg(test)]
 mod semantic_index_tests;

-use crate::semantic_index_settings::SemanticIndexSettings;
+use crate::{cross_encoder::CrossEncoder, semantic_index_settings::SemanticIndexSettings};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -266,6 +266,7 @@ pub struct PendingFile
 pub struct SearchResult {
     pub buffer: ModelHandle<Buffer>,
     pub range: Range<Anchor>,
+    pub similarity: f32,
 }

 impl SemanticIndex {
@@ -697,7 +698,7 @@ impl SemanticIndex {
         let embedding_provider = self.embedding_provider.clone();
         let db_path = self.db.path().clone();
         let fs = self.fs.clone();
-        cx.spawn(|this, mut cx| async move {
+        cx.spawn(|this, cx| async move {
             index.await?;

             let t0 = Instant::now();
@@ -709,7 +710,7 @@ impl SemanticIndex {
             }

             let phrase_embedding = embedding_provider
-                .embed_batch(vec![phrase])
+                .embed_batch(vec![phrase.clone()])
                 .await?
                 .into_iter()
                 .next()
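For orientation: the phrase embedding produced above is what VectorDatabase::top_k_search ranks stored span embeddings against in the hunks below. Conceptually that ranking is a dot-product top-k, sketched here with hand-rolled vectors; the ids, shapes, and assumption of L2-normalized embeddings are illustrative, not the database's actual implementation.

    // Rank span embeddings by dot-product similarity and keep the k best ids.
    fn top_k(phrase: &[f32], spans: &[(i64, Vec<f32>)], k: usize) -> Vec<(i64, f32)> {
        let mut scored: Vec<(i64, f32)> = spans
            .iter()
            .map(|(id, emb)| {
                let dot = emb.iter().zip(phrase).map(|(a, b)| a * b).sum::<f32>();
                (*id, dot)
            })
            .collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(k);
        scored
    }

    fn main() {
        let phrase = vec![1.0, 0.0];
        let spans = vec![(1, vec![0.6, 0.8]), (2, vec![1.0, 0.0]), (3, vec![0.0, 1.0])];
        assert_eq!(top_k(&phrase, &spans, 2), vec![(2, 1.0), (1, 0.6)]);
    }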
@@ -750,6 +751,11 @@ impl SemanticIndex {
                 ids_len / batch_n
             };

+            let cross_encoder = Arc::new(
+                cx.background()
+                    .spawn(async move { CrossEncoder::load() })
+                    .await?,
+            );
             let mut batch_results = Vec::new();
             for batch in file_ids.chunks(batch_size) {
                 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
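The Arc::new around the background load above exists so the one loaded encoder can be handed cheaply to every per-batch task spawned in the next hunk. The same pattern in plain std, as a sketch (threads standing in for gpui's cx.background(); Model, load, and score are stand-ins, not the real CrossEncoder API):

    use std::sync::Arc;
    use std::thread;

    struct Model; // stand-in for the loaded cross-encoder

    impl Model {
        fn load() -> Model {
            Model // imagine an expensive ONNX/tokenizer load here
        }
        fn score(&self, batch: usize) -> usize {
            batch * 2 // placeholder work
        }
    }

    fn main() {
        let model = Arc::new(Model::load()); // pay the load cost once
        let handles: Vec<_> = (0..4)
            .map(|batch| {
                let model = Arc::clone(&model); // cheap pointer clone per task
                thread::spawn(move || model.score(batch))
            })
            .collect();
        for handle in handles {
            println!("{}", handle.join().unwrap());
        }
    }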
@@ -757,77 +763,126 @@ impl SemanticIndex {
                 let fs = fs.clone();
                 let db_path = db_path.clone();
                 let phrase_embedding = phrase_embedding.clone();
+                let phrase = phrase.clone();
+                let cross_encoder = cross_encoder.clone();
+                let project = project.clone();
                 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
                     .await
                     .log_err()
                 {
-                    batch_results.push(async move {
-                        db.top_k_search(&phrase_embedding, limit, batch.as_slice())
-                            .await
-                    });
+                    let this = this.clone();
+                    batch_results.push(cx.spawn(|mut cx| async move {
+                        let span_ids = db
+                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
+                            .await?
+                            .into_iter()
+                            .map(|(span_id, _)| span_id)
+                            .collect::<Vec<_>>();
+                        let mut spans_by_buffer = HashMap::default();
+                        for (worktree_db_id, path, range) in db.spans_for_ids(&span_ids).await? {
+                            let worktree_id = this.read_with(&cx, |this, _| {
+                                let project_state = this
+                                    .projects
+                                    .get(&project.downgrade())
+                                    .ok_or_else(|| anyhow!("project not added"))?;
+
+                                anyhow::Ok(project_state.worktree_id_for_db_id(worktree_db_id))
+                            })?;
+
+                            if let Some(worktree_id) = worktree_id {
+                                let buffer = project
+                                    .update(&mut cx, |project, cx| {
+                                        project.open_buffer((worktree_id, path), cx)
+                                    })
+                                    .await
+                                    .log_err();
+                                if let Some(buffer) = buffer {
+                                    let range = buffer.read_with(&cx, |buffer, _| {
+                                        let range = buffer.clip_offset(range.start, Bias::Left)
+                                            ..buffer.clip_offset(range.end, Bias::Right);
+                                        buffer.anchor_before(range.start)
+                                            ..buffer.anchor_after(range.end)
+                                    });
+
+                                    spans_by_buffer
+                                        .entry(buffer)
+                                        .or_insert(Vec::new())
+                                        .push(range);
+                                }
+                            }
+                        }
+
+                        let mut spans = Vec::new();
+                        for (buffer, ranges) in &spans_by_buffer {
+                            buffer.read_with(&cx, |buffer, _| {
+                                for range in ranges {
+                                    let span =
+                                        buffer.text_for_range(range.clone()).collect::<String>();
+                                    spans.push(span);
+                                }
+                            });
+                        }
+
+                        // Cross Encoder
+                        // TODO: move background.spawn into cross_encoder.
+                        let results = cx
+                            .background()
+                            .spawn(async move {
+                                let mut results = Vec::new();
+                                let mut scores = cross_encoder.score(&phrase, &spans)?.into_iter();
+                                for (buffer, ranges) in spans_by_buffer {
+                                    for range in ranges {
+                                        let similarity = if let Some(similarity) = scores.next() {
+                                            similarity
+                                        } else {
+                                            log::error!("cross encoder returned too few scores");
+                                            f32::NEG_INFINITY
+                                        };
+
+                                        results.push(SearchResult {
+                                            buffer: buffer.clone(),
+                                            range,
+                                            similarity,
+                                        });
+                                    }
+                                }
+                                anyhow::Ok(results)
+                            })
+                            .await?;
+
+                        anyhow::Ok(results)
+                    }));
                 }
             }
             let batch_results = futures::future::join_all(batch_results).await;

-            let mut results = Vec::new();
+            let mut results = Vec::<SearchResult>::new();
             for batch_result in batch_results {
-                if batch_result.is_ok() {
-                    for (id, similarity) in batch_result.unwrap() {
-                        let ix = match results.binary_search_by(|(_, s)| {
-                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                if let Some(batch_result) = batch_result.log_err() {
+                    for new_result in batch_result {
+                        let ix = match results.binary_search_by(|old_result| {
+                            new_result
+                                .similarity
+                                .partial_cmp(&old_result.similarity)
+                                .unwrap_or(Ordering::Equal)
                         }) {
                             Ok(ix) => ix,
                             Err(ix) => ix,
                         };
-                        results.insert(ix, (id, similarity));
+                        dbg!(ix);
+                        dbg!(new_result.similarity);
+                        results.insert(ix, new_result);
                         results.truncate(limit);
                     }
                 }
             }

-            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
-            let spans = database.spans_for_ids(ids.as_slice()).await?;
-
-            let mut tasks = Vec::new();
-            let mut ranges = Vec::new();
-            let weak_project = project.downgrade();
-            project.update(&mut cx, |project, cx| {
-                for (worktree_db_id, file_path, byte_range) in spans {
-                    let project_state =
-                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
-                            state
-                        } else {
-                            return Err(anyhow!("project not added"));
-                        };
-                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
-                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
-                        ranges.push(byte_range);
-                    }
-                }
-
-                Ok(())
-            })?;
-
-            let buffers = futures::future::join_all(tasks).await;
-
             log::trace!(
                 "Semantic Searching took: {:?} milliseconds in total",
                 t0.elapsed().as_millis()
             );

-            Ok(buffers
-                .into_iter()
-                .zip(ranges)
-                .filter_map(|(buffer, range)| {
-                    let buffer = buffer.log_err()?;
-                    let range = buffer.read_with(&cx, |buffer, _| {
-                        let start = buffer.clip_offset(range.start, Bias::Left);
-                        let end = buffer.clip_offset(range.end, Bias::Right);
-                        buffer.anchor_before(start)..buffer.anchor_after(end)
-                    });
-                    Some(SearchResult { buffer, range })
-                })
-                .collect::<Vec<_>>())
+            Ok(results)
         })
     }
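One subtlety in the merge loop of the last hunk: the comparator hands binary_search_by the new score compared against the probed element (the reverse of the usual probe-versus-target order), which keeps results sorted in descending order of similarity, so results.truncate(limit) always drops the worst-scoring tail. A self-contained demonstration with bare f32 scores (the match over Ok/Err in the diff is equivalent to unwrap_or_else here):

    use std::cmp::Ordering;

    fn main() {
        let limit = 3;
        let mut results: Vec<f32> = Vec::new();
        for score in [0.2_f32, 0.9, 0.5, 0.7, 0.1] {
            // Comparing `score` against the probe keeps the vec sorted descending.
            let ix = results
                .binary_search_by(|probe| score.partial_cmp(probe).unwrap_or(Ordering::Equal))
                .unwrap_or_else(|ix| ix);
            results.insert(ix, score);
            results.truncate(limit); // keep only the top `limit` scores
        }
        assert_eq!(results, vec![0.9, 0.7, 0.5]);
    }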