Semantic Index (#10329)

This introduces semantic indexing in Zed, based on chunking text from
files in the developer's workspace and creating vector embeddings using
an embedding model. As part of this, we've created an embedding
provider trait that allows us to work with OpenAI, a local Ollama
model, or Zed-hosted embeddings.
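
The trait itself, as added in this PR, is intentionally small: texts
in, vectors out, plus a batch-size hint:

```rust
pub trait EmbeddingProvider: Sync + Send {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
    fn batch_size(&self) -> usize;
}
```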

The semantic index is built by breaking down text for known
(programming) languages into manageable chunks that stay under the
embedding model's maximum token size. Each chunk is then fed to the
embedding model to produce a high-dimensional vector, which is
normalized to a unit vector so it can be compared with other vectors
using a simple dot product. Alongside each vector, we store the path of
the file and the range within the document that the vector was sourced
from.
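
As a minimal sketch of why unit-normalization helps (mirroring what
`Embedding::new` and `Embedding::similarity` do in this PR): for unit
vectors, the dot product equals the cosine similarity.

```rust
// Normalize a vector to unit length, as `Embedding::new` does.
fn normalize(mut v: Vec<f32>) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    for x in &mut v {
        *x /= norm;
    }
    v
}

// For unit vectors, this dot product is their cosine similarity,
// which is what `Embedding::similarity` relies on.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```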

Zed will soon grok contextual similarity across different text
snippets, allowing natural-language search that goes beyond keyword
matching. This is being built both for human-driven search and to
provide results to large language models so they can refine how they
help developers.

Remaining todo:

* [x] Change `provider` to `model` within the Zed-hosted embeddings
database (as it's currently a combination of the provider and the model
in one name)


Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Antonio <antonio@zed.dev>


@@ -0,0 +1,48 @@
[package]
name = "semantic_index"
description = "Process, chunk, and embed text as vectors for semantic search."
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lib]
path = "src/semantic_index.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
heed.workspace = true
open_ai.workspace = true
project.workspace = true
settings.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
smol.workspace = true
util.workspace = true
worktree.workspace = true
[dev-dependencies]
env_logger.workspace = true
client = { workspace = true, features = ["test-support"] }
fs = { workspace = true, features = ["test-support"] }
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
languages.workspace = true
project = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
util = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
[lints]
workspace = true


@@ -0,0 +1 @@
../../LICENSE-GPL


@@ -0,0 +1,140 @@
use client::Client;
use futures::channel::oneshot;
use gpui::{App, Global, TestAppContext};
use language::language_settings::AllLanguageSettings;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use settings::SettingsStore;
use std::{path::Path, sync::Arc};
use util::http::HttpClientWithUrl;
pub fn init_test(cx: &mut TestAppContext) {
_ = cx.update(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
});
}
fn main() {
env_logger::init();
use clock::FakeSystemClock;
App::new().run(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
let clock = Arc::new(FakeSystemClock::default());
let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434"));
let client = client::Client::new(clock, http.clone(), cx);
Client::set_global(client.clone(), cx);
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: cargo run --example index -p semantic_index -- <project_path>");
cx.quit();
return;
}
// let embedding_provider = semantic_index::FakeEmbeddingProvider;
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let embedding_provider = OpenAiEmbeddingProvider::new(
http.clone(),
OpenAiEmbeddingModel::TextEmbedding3Small,
open_ai::OPEN_AI_API_URL.to_string(),
api_key,
);
let semantic_index = SemanticIndex::new(
Path::new("/tmp/semantic-index-db.mdb"),
Arc::new(embedding_provider),
cx,
);
cx.spawn(|mut cx| async move {
let mut semantic_index = semantic_index.await.unwrap();
let project_path = Path::new(&args[1]);
let project = Project::example([project_path], &mut cx).await;
cx.update(|cx| {
let language_registry = project.read(cx).languages().clone();
let node_runtime = project.read(cx).node_runtime().unwrap().clone();
languages::init(language_registry, node_runtime, cx);
})
.unwrap();
let project_index = cx
.update(|cx| semantic_index.project_index(project.clone(), cx))
.unwrap();
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let subscription = cx.update(|cx| {
cx.subscribe(&project_index, move |_, event, _| {
if let Some(tx) = tx.take() {
_ = tx.send(*event);
}
})
});
let index_start = std::time::Instant::now();
rx.await.expect("no event emitted");
drop(subscription);
println!("Index time: {:?}", index_start.elapsed());
let results = cx
.update(|cx| {
let project_index = project_index.read(cx);
let query = "converting an anchor to a point";
project_index.search(query, 4, cx)
})
.unwrap()
.await;
for search_result in results {
let path = search_result.path.clone();
let content = cx
.update(|cx| {
let worktree = search_result.worktree.read(cx);
let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
let fs = project.read(cx).fs().clone();
cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
})
.unwrap()
.await;
let range = search_result.range.clone();
let content = content[search_result.range].to_owned();
println!(
"✄✄✄✄✄✄✄✄✄✄✄✄✄✄ {:?} @ {} ✄✄✄✄✄✄✄✄✄✄✄✄✄✄",
path, search_result.score
);
println!("{:?}:{:?}:{:?}", path, range.start, range.end);
println!("{}", content);
}
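// Park the example for a very long time so the printed results stay
// visible; interrupt with Ctrl-C to exit sooner.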
cx.background_executor()
.timer(std::time::Duration::from_secs(100000))
.await;
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
}


@@ -0,0 +1,3 @@
fn main() {
println!("Hello Indexer!");
}


@@ -0,0 +1,43 @@
# Searching for a needle in a haystack
When you have a large amount of text, it can be useful to search for a specific word or phrase. This is often referred to as "finding a needle in a haystack." In this markdown document, we're "hiding" a key phrase for our text search to find. Can you find it?
## Instructions
1. Use the search functionality in your text editor or markdown viewer to find the hidden phrase in this document.
2. Once you've found the **phrase**, write it down and proceed to the next step.
Honestly, I just want to fill up plenty of characters so that we chunk this markdown into several chunks.
## Tips
- Relax
- Take a deep breath
- Focus on the task at hand
- Don't get distracted by other text
- Use the search functionality to your advantage
## Example code
```python
def search_for_needle(haystack, needle):
if needle in haystack:
return True
else:
return False
```
```javascript
function searchForNeedle(haystack, needle) {
return haystack.includes(needle);
}
```
## Background
When creating an index for a book or searching for a specific term in a large document, the ability to quickly find a specific word or phrase is essential. This is where search functionality comes in handy. However, one should _remember_ that the search is only as good as the index that was built. As they say, garbage in, garbage out!
## Conclusion
Searching for a needle in a haystack can be a challenging task, but with the right tools and techniques, it becomes much easier. Whether you're looking for a specific word in a document or trying to find a key piece of information in a large dataset, the ability to search efficiently is a valuable skill to have.


@@ -0,0 +1,409 @@
use language::{with_parser, Grammar, Tree};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{cmp, ops::Range, sync::Arc};
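/// Target maximum size of a chunk, in bytes of UTF-8 text.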
const CHUNK_THRESHOLD: usize = 1500;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
pub range: Range<usize>,
pub digest: [u8; 32],
}
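/// Split `text` into chunks of at most `CHUNK_THRESHOLD` bytes. When a
/// tree-sitter grammar is available, chunk along syntax-tree boundaries;
/// otherwise fall back to line-based chunking.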
pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
if let Some(grammar) = grammar {
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(&text, None).expect("invalid language")
});
chunk_parse_tree(tree, &text, CHUNK_THRESHOLD)
} else {
chunk_lines(&text)
}
}
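/// Walk the parse tree, greedily packing sibling nodes into the current
/// chunk. When appending a node would overflow the threshold, descend into
/// its children; if a childless node is still too large, split its text at
/// arbitrary (char-boundary) offsets via `split_text`.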
fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
let mut chunk_ranges = Vec::new();
let mut cursor = tree.walk();
let mut range = 0..0;
loop {
let node = cursor.node();
// If adding the node to the current chunk exceeds the threshold
if node.end_byte() - range.start > chunk_threshold {
// Try to descend into its first child. If we can't, flush the current
// range and try again.
if cursor.goto_first_child() {
continue;
} else if !range.is_empty() {
chunk_ranges.push(range.clone());
range.start = range.end;
continue;
}
// If we get here, the node itself has no children but is larger than the threshold.
// Break its text into arbitrary chunks.
split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
}
range.end = node.end_byte();
// If we get here, we consumed the node. Advance to the next child, ascending if there isn't one.
while !cursor.goto_next_sibling() {
if !cursor.goto_parent() {
if !range.is_empty() {
chunk_ranges.push(range);
}
return chunk_ranges
.into_iter()
.map(|range| {
let digest = Sha256::digest(&text[range.clone()]).into();
Chunk { range, digest }
})
.collect();
}
}
}
}
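/// Fallback chunking for text without a grammar: pack whole lines into
/// each chunk, splitting any single line that exceeds the threshold.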
fn chunk_lines(text: &str) -> Vec<Chunk> {
let mut chunk_ranges = Vec::new();
let mut range = 0..0;
let mut newlines = text.match_indices('\n').peekable();
while let Some((newline_ix, _)) = newlines.peek() {
let newline_ix = newline_ix + 1;
if newline_ix - range.start <= CHUNK_THRESHOLD {
range.end = newline_ix;
newlines.next();
} else {
if range.is_empty() {
split_text(text, range, newline_ix, &mut chunk_ranges);
range = newline_ix..newline_ix;
} else {
chunk_ranges.push(range.clone());
range.start = range.end;
}
}
}
if !range.is_empty() {
chunk_ranges.push(range);
}
chunk_ranges
.into_iter()
.map(|range| {
let mut hasher = Sha256::new();
hasher.update(&text[range.clone()]);
let mut digest = [0u8; 32];
digest.copy_from_slice(hasher.finalize().as_slice());
Chunk { range, digest }
})
.collect()
}
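/// Split `text` between `range.start` and `max_end` into pieces of at most
/// `CHUNK_THRESHOLD` bytes, backing each boundary up to a `char` boundary
/// so a chunk never splits a code point.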
fn split_text(
text: &str,
mut range: Range<usize>,
max_end: usize,
chunk_ranges: &mut Vec<Range<usize>>,
) {
while range.start < max_end {
range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
while !text.is_char_boundary(range.end) {
range.end -= 1;
}
chunk_ranges.push(range.clone());
range.start = range.end;
}
}
#[cfg(test)]
mod tests {
use super::*;
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
// This example comes from crates/gpui/examples/window_positioning.rs which
// has the property of being CHUNK_THRESHOLD < TEXT.len() < 2*CHUNK_THRESHOLD
static TEXT: &str = r#"
use gpui::*;
struct WindowContent {
text: SharedString,
}
impl Render for WindowContent {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
div()
.flex()
.bg(rgb(0x1e2025))
.size_full()
.justify_center()
.items_center()
.text_xl()
.text_color(rgb(0xffffff))
.child(self.text.clone())
}
}
fn main() {
App::new().run(|cx: &mut AppContext| {
// Create several new windows, positioned in the top right corner of each screen
for screen in cx.displays() {
let options = {
let popup_margin_width = DevicePixels::from(16);
let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);
let window_size = Size {
width: px(400.),
height: px(72.),
};
let screen_bounds = screen.bounds();
let size: Size<DevicePixels> = window_size.into();
let bounds = gpui::Bounds::<DevicePixels> {
origin: screen_bounds.upper_right()
- point(size.width + popup_margin_width, popup_margin_height),
size: window_size.into(),
};
WindowOptions {
// Set the bounds of the window in screen coordinates
bounds: Some(bounds),
// Specify the display_id to ensure the window is created on the correct screen
display_id: Some(screen.id()),
titlebar: None,
window_background: WindowBackgroundAppearance::default(),
focus: false,
show: true,
kind: WindowKind::PopUp,
is_movable: false,
fullscreen: false,
}
};
cx.open_window(options, |cx| {
cx.new_view(|_| WindowContent {
text: format!("{:?}", screen.id()).into(),
})
});
}
});
}"#;
fn setup_rust_language() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::language()),
)
}
#[test]
fn test_chunk_text() {
let text = "a\n".repeat(1000);
let chunks = chunk_text(&text, None);
assert_eq!(
chunks.len(),
((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
);
}
#[test]
fn test_chunk_text_grammar() {
// Let's set up a big text with some known segments
// We'll then chunk it and verify that the chunks are correct
let language = setup_rust_language();
let chunks = chunk_text(TEXT, language.grammar());
assert_eq!(chunks.len(), 2);
assert_eq!(chunks[0].range.start, 0);
assert_eq!(chunks[0].range.end, 1498);
// The break between chunks is right before the "Specify the display_id" comment
assert_eq!(chunks[1].range.start, 1498);
assert_eq!(chunks[1].range.end, 2396);
}
#[test]
fn test_chunk_parse_tree() {
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(TEXT, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, TEXT, 250);
assert_eq!(chunks.len(), 11);
}
#[test]
fn test_chunk_unparsable() {
// Even if a chunk is unparsable, we should still be able to chunk it
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let text = r#"fn main() {"#;
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, text, 250);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].range.start, 0);
assert_eq!(chunks[0].range.end, 11);
}
#[test]
fn test_empty_text() {
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse("", None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD);
assert!(chunks.is_empty(), "Chunks should be empty for empty text");
}
#[test]
fn test_single_large_node() {
let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(&large_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD);
assert_eq!(
chunks.len(),
3,
"Large chunks are broken up according to grammar as best as possible"
);
// Expect chunks to be static, aaaaaa..., and = 2
assert_eq!(chunks[0].range.start, 0);
assert_eq!(chunks[0].range.end, "static".len());
assert_eq!(chunks[1].range.start, "static".len());
assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);
assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
assert_eq!(chunks[2].range.end, large_text.len());
}
#[test]
fn test_multiple_small_nodes() {
let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(small_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, small_text, 5);
assert!(
chunks.len() > 1,
"Should have multiple chunks for multiple small nodes"
);
}
#[test]
fn test_node_with_children() {
let nested_text = "fn main() { let a = 1; let b = 2; }";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(nested_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, nested_text, 10);
assert!(
chunks.len() > 1,
"Should have multiple chunks for a node with children"
);
}
#[test]
fn test_text_with_unparsable_sections() {
// This test uses purposefully hit-or-miss sizing of 11 characters per likely chunk
let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
let language = setup_rust_language();
let grammar = language.grammar().unwrap();
let tree = with_parser(|parser| {
parser
.set_language(&grammar.ts_language)
.expect("incompatible grammar");
parser.parse(mixed_text, None).expect("invalid language")
});
let chunks = chunk_parse_tree(tree, mixed_text, 11);
assert!(
chunks.len() > 1,
"Should handle both parsable and unparsable sections correctly"
);
let expected_chunks = [
"fn main() {",
" let a = 1;",
" let b = 2;",
" }",
" unparsable",
" bits here",
];
for (i, chunk) in chunks.iter().enumerate() {
assert_eq!(
&mixed_text[chunk.range.clone()],
expected_chunks[i],
"Chunk {} should match",
i
);
}
}
}


@@ -0,0 +1,125 @@
mod cloud;
mod ollama;
mod open_ai;
pub use cloud::*;
pub use ollama::*;
pub use open_ai::*;
use sha2::{Digest, Sha256};
use anyhow::Result;
use futures::{future::BoxFuture, FutureExt};
use serde::{Deserialize, Serialize};
use std::{fmt, future};
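/// An embedding vector, normalized to unit length at construction so that
/// comparing two embeddings reduces to a dot product.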
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);
impl Embedding {
pub fn new(mut embedding: Vec<f32>) -> Self {
let len = embedding.len();
let mut norm = 0f32;
for i in 0..len {
norm += embedding[i] * embedding[i];
}
norm = norm.sqrt();
for dimension in &mut embedding {
*dimension /= norm;
}
Self(embedding)
}
fn len(&self) -> usize {
self.0.len()
}
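/// The dot product of two embeddings. Because embeddings are unit
/// vectors, this is also their cosine similarity.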
pub fn similarity(self, other: &Embedding) -> f32 {
debug_assert_eq!(self.0.len(), other.0.len());
self.0
.iter()
.copied()
.zip(other.0.iter().copied())
.map(|(a, b)| a * b)
.sum()
}
}
impl fmt::Display for Embedding {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let digits_to_display = 3;
// Start the Embedding display format
write!(f, "Embedding(sized: {}; values: [", self.len())?;
for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
// Lead with comma if not the first element
if index != 0 {
write!(f, ", ")?;
}
write!(f, "{:.3}", value)?;
}
if self.len() > digits_to_display {
write!(f, "...")?;
}
write!(f, "])")
}
}
/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
fn batch_size(&self) -> usize;
}
#[derive(Debug)]
pub struct TextToEmbed<'a> {
pub text: &'a str,
pub digest: [u8; 32],
}
impl<'a> TextToEmbed<'a> {
pub fn new(text: &'a str) -> Self {
let digest = Sha256::digest(text.as_bytes());
Self {
text,
digest: digest.into(),
}
}
}
pub struct FakeEmbeddingProvider;
impl EmbeddingProvider for FakeEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embeddings = texts
.iter()
.map(|_text| {
let mut embedding = vec![0f32; 1536];
for i in 0..embedding.len() {
embedding[i] = i as f32;
}
Embedding::new(embedding)
})
.collect();
future::ready(Ok(embeddings)).boxed()
}
fn batch_size(&self) -> usize {
16
}
}
#[cfg(test)]
mod test {
use super::*;
#[gpui::test]
fn test_normalize_embedding() {
let normalized = Embedding::new(vec![1.0, 1.0, 1.0]);
let value: f32 = 1.0 / 3.0_f32.sqrt();
assert_eq!(normalized, Embedding(vec![value; 3]));
}
}


@@ -0,0 +1,88 @@
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
use anyhow::{anyhow, Context, Result};
use client::{proto, Client};
use collections::HashMap;
use futures::{future::BoxFuture, FutureExt};
use std::sync::Arc;
pub struct CloudEmbeddingProvider {
model: String,
client: Arc<Client>,
}
impl CloudEmbeddingProvider {
pub fn new(client: Arc<Client>) -> Self {
Self {
model: "openai/text-embedding-3-small".into(),
client,
}
}
}
impl EmbeddingProvider for CloudEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
// First, fetch any embeddings that are cached based on the requested texts' digests
// Then compute any embeddings that are missing.
async move {
let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
model: self.model.clone(),
digests: texts
.iter()
.map(|to_embed| to_embed.digest.to_vec())
.collect(),
});
let mut embeddings = cached_embeddings
.await
.context("failed to fetch cached embeddings via cloud model")?
.embeddings
.into_iter()
.map(|embedding| {
let digest: [u8; 32] = embedding
.digest
.try_into()
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
Ok((digest, embedding.dimensions))
})
.collect::<Result<HashMap<_, _>>>()?;
let compute_embeddings_request = proto::ComputeEmbeddings {
model: self.model.clone(),
texts: texts
.iter()
.filter_map(|to_embed| {
if embeddings.contains_key(&to_embed.digest) {
None
} else {
Some(to_embed.text.to_string())
}
})
.collect(),
};
if !compute_embeddings_request.texts.is_empty() {
let missing_embeddings = self.client.request(compute_embeddings_request).await?;
for embedding in missing_embeddings.embeddings {
let digest: [u8; 32] = embedding
.digest
.try_into()
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
embeddings.insert(digest, embedding.dimensions);
}
}
texts
.iter()
.map(|to_embed| {
let dimensions = embeddings.remove(&to_embed.digest).with_context(|| {
format!("server did not return an embedding for {:?}", to_embed)
})?;
Ok(Embedding::new(dimensions))
})
.collect()
}
.boxed()
}
fn batch_size(&self) -> usize {
2048
}
}


@@ -0,0 +1,74 @@
use anyhow::{Context as _, Result};
use futures::{future::BoxFuture, AsyncReadExt, FutureExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use util::http::HttpClient;
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
pub enum OllamaEmbeddingModel {
NomicEmbedText,
MxbaiEmbedLarge,
}
pub struct OllamaEmbeddingProvider {
client: Arc<dyn HttpClient>,
model: OllamaEmbeddingModel,
}
#[derive(Serialize)]
struct OllamaEmbeddingRequest {
model: String,
prompt: String,
}
#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
embedding: Vec<f32>,
}
impl OllamaEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
Self { client, model }
}
}
impl EmbeddingProvider for OllamaEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
// Ollama's /api/embeddings endpoint takes a single prompt per request,
// so issue one request per text.
let model = match self.model {
OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
};
futures::future::try_join_all(texts.into_iter().map(|to_embed| {
let request = OllamaEmbeddingRequest {
model: model.to_string(),
prompt: to_embed.text.to_string(),
};
let request = serde_json::to_string(&request).unwrap();
async {
let response = self
.client
.post_json("http://localhost:11434/api/embeddings", request.into())
.await?;
let mut body = String::new();
response.into_body().read_to_string(&mut body).await?;
let response: OllamaEmbeddingResponse =
serde_json::from_str(&body).context("Unable to pull response")?;
Ok(Embedding::new(response.embedding))
}
}))
.boxed()
}
fn batch_size(&self) -> usize {
// TODO: Figure out decent value
10
}
}


@@ -0,0 +1,55 @@
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
use anyhow::Result;
use futures::{future::BoxFuture, FutureExt};
pub use open_ai::OpenAiEmbeddingModel;
use std::sync::Arc;
use util::http::HttpClient;
pub struct OpenAiEmbeddingProvider {
client: Arc<dyn HttpClient>,
model: OpenAiEmbeddingModel,
api_url: String,
api_key: String,
}
impl OpenAiEmbeddingProvider {
pub fn new(
client: Arc<dyn HttpClient>,
model: OpenAiEmbeddingModel,
api_url: String,
api_key: String,
) -> Self {
Self {
client,
model,
api_url,
api_key,
}
}
}
impl EmbeddingProvider for OpenAiEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embed = open_ai::embed(
self.client.as_ref(),
&self.api_url,
&self.api_key,
self.model,
texts.iter().map(|to_embed| to_embed.text),
);
async move {
let response = embed.await?;
Ok(response
.data
.into_iter()
.map(|data| Embedding::new(data.embedding))
.collect())
}
.boxed()
}
fn batch_size(&self) -> usize {
// From https://platform.openai.com/docs/api-reference/embeddings/create
2048
}
}


@@ -0,0 +1,954 @@
mod chunking;
mod embedding;
use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
use collections::{Bound, HashMap};
pub use embedding::*;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
Subscription, Task, WeakModel,
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use project::{Entry, Project, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
future::Future,
ops::Range,
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::LocalSnapshot;
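/// The top-level index: one heed (LMDB) environment shared by the
/// per-project indices it hands out.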
pub struct SemanticIndex {
embedding_provider: Arc<dyn EmbeddingProvider>,
db_connection: heed::Env,
project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
impl Global for SemanticIndex {}
impl SemanticIndex {
pub fn new(
db_path: &Path,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Self>> {
let db_path = db_path.to_path_buf();
cx.spawn(|cx| async move {
let db_connection = cx
.background_executor()
.spawn(async move {
unsafe {
heed::EnvOpenOptions::new()
.map_size(1024 * 1024 * 1024)
.max_dbs(3000)
.open(db_path)
}
})
.await?;
Ok(SemanticIndex {
db_connection,
embedding_provider,
project_indices: HashMap::default(),
})
})
}
pub fn project_index(
&mut self,
project: Model<Project>,
cx: &mut AppContext,
) -> Model<ProjectIndex> {
self.project_indices
.entry(project.downgrade())
.or_insert_with(|| {
cx.new_model(|cx| {
ProjectIndex::new(
project,
self.db_connection.clone(),
self.embedding_provider.clone(),
cx,
)
})
})
.clone()
}
}
pub struct ProjectIndex {
db_connection: heed::Env,
project: Model<Project>,
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
last_status: Status,
embedding_provider: Arc<dyn EmbeddingProvider>,
_subscription: Subscription,
}
enum WorktreeIndexHandle {
Loading {
_task: Task<Result<()>>,
},
Loaded {
index: Model<WorktreeIndex>,
_subscription: Subscription,
},
}
impl ProjectIndex {
fn new(
project: Model<Project>,
db_connection: heed::Env,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let language_registry = project.read(cx).languages().clone();
let fs = project.read(cx).fs().clone();
let mut this = ProjectIndex {
db_connection,
project: project.clone(),
worktree_indices: HashMap::default(),
language_registry,
fs,
last_status: Status::Idle,
embedding_provider,
_subscription: cx.subscribe(&project, Self::handle_project_event),
};
this.update_worktree_indices(cx);
this
}
fn handle_project_event(
&mut self,
_: Model<Project>,
event: &project::Event,
cx: &mut ModelContext<Self>,
) {
match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
self.update_worktree_indices(cx);
}
_ => {}
}
}
fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
let worktrees = self
.project
.read(cx)
.visible_worktrees(cx)
.filter_map(|worktree| {
if worktree.read(cx).is_local() {
Some((worktree.entity_id(), worktree))
} else {
None
}
})
.collect::<HashMap<_, _>>();
self.worktree_indices
.retain(|worktree_id, _| worktrees.contains_key(worktree_id));
for (worktree_id, worktree) in worktrees {
self.worktree_indices.entry(worktree_id).or_insert_with(|| {
let worktree_index = WorktreeIndex::load(
worktree.clone(),
self.db_connection.clone(),
self.language_registry.clone(),
self.fs.clone(),
self.embedding_provider.clone(),
cx,
);
let load_worktree = cx.spawn(|this, mut cx| async move {
if let Some(index) = worktree_index.await.log_err() {
this.update(&mut cx, |this, cx| {
this.worktree_indices.insert(
worktree_id,
WorktreeIndexHandle::Loaded {
_subscription: cx
.observe(&index, |this, _, cx| this.update_status(cx)),
index,
},
);
})?;
} else {
this.update(&mut cx, |this, _cx| {
this.worktree_indices.remove(&worktree_id)
})?;
}
this.update(&mut cx, |this, cx| this.update_status(cx))
});
WorktreeIndexHandle::Loading {
_task: load_worktree,
}
});
}
self.update_status(cx);
}
fn update_status(&mut self, cx: &mut ModelContext<Self>) {
let mut status = Status::Idle;
for index in self.worktree_indices.values() {
match index {
WorktreeIndexHandle::Loading { .. } => {
status = Status::Scanning;
break;
}
WorktreeIndexHandle::Loaded { index, .. } => {
if index.read(cx).status == Status::Scanning {
status = Status::Scanning;
break;
}
}
}
}
if status != self.last_status {
self.last_status = status;
cx.emit(status);
}
}
pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
let mut worktree_searches = Vec::new();
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
worktree_searches
.push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
}
}
cx.spawn(|_| async move {
let mut results = Vec::new();
let worktree_searches = futures::future::join_all(worktree_searches).await;
for worktree_search_results in worktree_searches {
if let Some(worktree_search_results) = worktree_search_results.log_err() {
results.extend(worktree_search_results);
}
}
results
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
results.truncate(limit);
results
})
}
}
pub struct SearchResult {
pub worktree: Model<Worktree>,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
Idle,
Scanning,
}
impl EventEmitter<Status> for ProjectIndex {}
struct WorktreeIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
status: Status,
_index_entries: Task<Result<()>>,
_subscription: Subscription,
}
impl WorktreeIndex {
pub fn load(
worktree: Model<Worktree>,
db_connection: heed::Env,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let worktree_abs_path = worktree.read(cx).abs_path();
cx.spawn(|mut cx| async move {
let db = cx
.background_executor()
.spawn({
let db_connection = db_connection.clone();
async move {
let mut txn = db_connection.write_txn()?;
let db_name = worktree_abs_path.to_string_lossy();
let db = db_connection.create_database(&mut txn, Some(&db_name))?;
txn.commit()?;
anyhow::Ok(db)
}
})
.await?;
cx.new_model(|cx| {
Self::new(
worktree,
db_connection,
db,
language_registry,
fs,
embedding_provider,
cx,
)
})
})
}
fn new(
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
if let worktree::Event::UpdatedEntries(update) = event {
_ = updated_entries_tx.try_send(update.clone());
}
});
Self {
db_connection,
db,
worktree,
language_registry,
fs,
embedding_provider,
status: Status::Idle,
_index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
_subscription,
}
}
async fn index_entries(
this: WeakModel<Self>,
updated_entries: channel::Receiver<UpdatedEntriesSet>,
mut cx: AsyncAppContext,
) -> Result<()> {
let index = this.update(&mut cx, |this, cx| {
cx.notify();
this.status = Status::Scanning;
this.index_entries_changed_on_disk(cx)
})?;
index.await.log_err();
this.update(&mut cx, |this, cx| {
this.status = Status::Idle;
cx.notify();
})?;
while let Ok(updated_entries) = updated_entries.recv().await {
let index = this.update(&mut cx, |this, cx| {
cx.notify();
this.status = Status::Scanning;
this.index_updated_entries(updated_entries, cx)
})?;
index.await.log_err();
this.update(&mut cx, |this, cx| {
this.status = Status::Idle;
cx.notify();
})?;
}
Ok(())
}
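// Indexing runs as a pipeline: scan entries, chunk files, embed chunks,
// persist embeddings. The stages are connected by bounded channels, so a
// slow stage exerts backpressure on the stages before it.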
fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = self.embed_files(chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
fn index_updated_entries(
&self,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries, cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = self.embed_files(chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
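// Merge-join the worktree's entries (sorted by path) against the
// database's keys (sorted the same way; see `db_key_for_path`) to find
// files that need re-embedding and key ranges that should be deleted.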
fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let db = self.db;
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let mut db_entries = db
.iter(&txn)
.context("failed to create iterator")?
.move_between_keys()
.peekable();
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
for entry in worktree.files(false, 0) {
let entry_db_key = db_key_for_path(&entry.path);
let mut saved_mtime = None;
while let Some(db_entry) = db_entries.peek() {
match db_entry {
Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
Ordering::Less => {
if let Some(deletion_range) = deletion_range.as_mut() {
deletion_range.1 = Bound::Included(db_path);
} else {
deletion_range =
Some((Bound::Included(db_path), Bound::Included(db_path)));
}
db_entries.next();
}
Ordering::Equal => {
if let Some(deletion_range) = deletion_range.take() {
deleted_entry_ranges_tx
.send((
deletion_range.0.map(ToString::to_string),
deletion_range.1.map(ToString::to_string),
))
.await?;
}
saved_mtime = db_embedded_file.mtime;
db_entries.next();
break;
}
Ordering::Greater => {
break;
}
},
Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
}
}
if entry.mtime != saved_mtime {
updated_entries_tx.send(entry.clone()).await?;
}
}
if let Some(db_entry) = db_entries.next() {
let (db_path, _) = db_entry?;
deleted_entry_ranges_tx
.send((Bound::Included(db_path.to_string()), Bound::Unbounded))
.await?;
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn scan_updated_entries(
&self,
worktree: LocalSnapshot,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {
match status {
project::PathChange::Added
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
updated_entries_tx.send(entry.clone()).await?;
}
}
project::PathChange::Removed => {
let db_path = db_key_for_path(path);
deleted_entry_ranges_tx
.send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
.await?;
}
project::PathChange::Loaded => {
// Do nothing.
}
}
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn chunk_files(
&self,
worktree_abs_path: Arc<Path>,
entries: channel::Receiver<Entry>,
cx: &AppContext,
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
let fs = self.fs.clone();
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
while let Ok(entry) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
continue;
};
let language = language_registry
.language_for_file_path(&entry.path)
.await
.ok();
let grammar =
language.as_ref().and_then(|language| language.grammar());
let chunked_file = ChunkedFile {
worktree_root: worktree_abs_path.clone(),
chunks: chunk_text(&text, grammar),
entry,
text,
};
if chunked_files_tx.send(chunked_file).await.is_err() {
return;
}
}
});
}
})
.await;
Ok(())
});
ChunkFiles {
files: chunked_files_rx,
task,
}
}
fn embed_files(
&self,
chunked_files: channel::Receiver<ChunkedFile>,
cx: &AppContext,
) -> EmbedFiles {
let embedding_provider = self.embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
chunked_files.chunks_timeout(512, Duration::from_secs(2));
while let Some(chunked_files) = chunked_file_batches.next().await {
// View the batch of files as a vec of chunks
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
// Once those are done, reassemble it back into which files they belong to
let chunks = chunked_files
.iter()
.flat_map(|file| {
file.chunks.iter().map(|chunk| TextToEmbed {
text: &file.text[chunk.range.clone()],
digest: chunk.digest,
})
})
.collect::<Vec<_>>();
let mut embeddings = Vec::new();
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
// todo!("add a retry facility")
embeddings.extend(embedding_provider.embed(embedding_batch).await?);
}
let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
let chunk_embeddings = embeddings
.by_ref()
.take(chunked_file.chunks.len())
.collect::<Vec<_>>();
let embedded_chunks = chunked_file
.chunks
.into_iter()
.zip(chunk_embeddings)
.map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
.collect();
let embedded_file = EmbeddedFile {
path: chunked_file.entry.path.clone(),
mtime: chunked_file.entry.mtime,
chunks: embedded_chunks,
};
embedded_files_tx.send(embedded_file).await?;
}
}
Ok(())
});
EmbedFiles {
files: embedded_files_rx,
task,
}
}
fn persist_embeddings(
&self,
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
embedded_files: channel::Receiver<EmbeddedFile>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
while let Some(deletion_range) = deleted_entry_ranges.next().await {
let mut txn = db_connection.write_txn()?;
let start = deletion_range.0.as_ref().map(|start| start.as_str());
let end = deletion_range.1.as_ref().map(|end| end.as_str());
log::debug!("deleting embeddings in range {:?}", &(start, end));
db.delete_range(&mut txn, &(start, end))?;
txn.commit()?;
}
let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
while let Some(embedded_files) = embedded_files.next().await {
let mut txn = db_connection.write_txn()?;
for file in embedded_files {
log::debug!("saving embedding for file {:?}", file.path);
let key = db_key_for_path(&file.path);
db.put(&mut txn, &key, &file)?;
}
txn.commit()?;
log::debug!("committed");
}
Ok(())
})
}
fn search(
&self,
query: &str,
limit: usize,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let db_connection = self.db_connection.clone();
let db = self.db;
let scan_chunks = cx.background_executor().spawn({
async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((db_embedded_file.path.clone(), chunk))
.await?;
}
}
anyhow::Ok(())
}
});
let query = query.to_string();
let embedding_provider = self.embedding_provider.clone();
let worktree = self.worktree.clone();
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
let mut query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
.await?;
let query_embedding = query_embeddings
.pop()
.ok_or_else(|| anyhow!("no embedding for query"))?;
let mut workers = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
workers.push(Vec::<SearchResult>::new());
}
#[cfg(debug_assertions)]
let search_start = std::time::Instant::now();
cx.background_executor()
.scoped(|cx| {
for worker_results in workers.iter_mut() {
cx.spawn(async {
while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
let score = embedded_chunk.embedding.similarity(&query_embedding);
let ix = match worker_results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
worker_results.insert(
ix,
SearchResult {
worktree: worktree.clone(),
path: path.clone(),
range: embedded_chunk.chunk.range.clone(),
score,
},
);
worker_results.truncate(limit);
}
});
}
})
.await;
scan_chunks.await?;
let mut search_results = Vec::with_capacity(workers.len() * limit);
for worker_results in workers {
search_results.extend(worker_results);
}
search_results
.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
search_results.truncate(limit);
#[cfg(debug_assertions)]
{
let search_elapsed = search_start.elapsed();
log::debug!(
"searched {} entries in {:?}",
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}
Ok(search_results)
})
}
}
struct ScanEntries {
updated_entries: channel::Receiver<Entry>,
deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
task: Task<Result<()>>,
}
struct ChunkFiles {
files: channel::Receiver<ChunkedFile>,
task: Task<Result<()>>,
}
struct ChunkedFile {
#[allow(dead_code)]
pub worktree_root: Arc<Path>,
pub entry: Entry,
pub text: String,
pub chunks: Vec<Chunk>,
}
struct EmbedFiles {
files: channel::Receiver<EmbeddedFile>,
task: Task<Result<()>>,
}
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
path: Arc<Path>,
mtime: Option<SystemTime>,
chunks: Vec<EmbeddedChunk>,
}
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
chunk: Chunk,
embedding: Embedding,
}
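// Encode paths with '\0' in place of '/' so that all keys beneath a
// directory occupy a contiguous, lexicographically sorted range.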
fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}
#[cfg(test)]
mod tests {
use super::*;
use futures::channel::oneshot;
use futures::{future::BoxFuture, FutureExt};
use gpui::{Global, TestAppContext};
use language::language_settings::AllLanguageSettings;
use project::Project;
use settings::SettingsStore;
use std::{future, path::Path, sync::Arc};
fn init_test(cx: &mut TestAppContext) {
_ = cx.update(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
});
}
pub struct TestEmbeddingProvider;
impl EmbeddingProvider for TestEmbeddingProvider {
fn embed<'a>(
&'a self,
texts: &'a [TextToEmbed<'a>],
) -> BoxFuture<'a, Result<Vec<Embedding>>> {
let embeddings = texts
.iter()
.map(|text| {
let mut embedding = vec![0f32; 2];
// if the text contains garbage, give it a 1 in the first dimension
if text.text.contains("garbage in") {
embedding[0] = 0.9;
} else {
embedding[0] = -0.9;
}
if text.text.contains("garbage out") {
embedding[1] = 0.9;
} else {
embedding[1] = -0.9;
}
Embedding::new(embedding)
})
.collect();
future::ready(Ok(embeddings)).boxed()
}
fn batch_size(&self) -> usize {
16
}
}
#[gpui::test]
async fn test_search(cx: &mut TestAppContext) {
cx.executor().allow_parking();
init_test(cx);
let temp_dir = tempfile::tempdir().unwrap();
let mut semantic_index = cx
.update(|cx| {
let semantic_index = SemanticIndex::new(
Path::new(temp_dir.path()),
Arc::new(TestEmbeddingProvider),
cx,
);
semantic_index
})
.await
.unwrap();
// todo!(): use a fixture
let project_path = Path::new("./fixture");
let project = cx
.spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
.await;
cx.update(|cx| {
let language_registry = project.read(cx).languages().clone();
let node_runtime = project.read(cx).node_runtime().unwrap().clone();
languages::init(language_registry, node_runtime, cx);
});
let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let subscription = cx.update(|cx| {
cx.subscribe(&project_index, move |_, event, _| {
if let Some(tx) = tx.take() {
_ = tx.send(*event);
}
})
});
rx.await.expect("no event emitted");
drop(subscription);
let results = cx
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
project_index.search(query, 4, cx)
})
.await;
assert!(results.len() > 1, "should have found some results");
for result in &results {
println!("result: {:?}", result.path);
println!("score: {:?}", result.score);
}
// Find result that is greater than 0.5
let search_result = results.iter().find(|result| result.score > 0.9).unwrap();
assert_eq!(search_result.path.to_string_lossy(), "needle.md");
let content = cx
.update(|cx| {
let worktree = search_result.worktree.read(cx);
let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
let fs = project.read(cx).fs().clone();
cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
})
.await;
let range = search_result.range.clone();
let content = content[range.clone()].to_owned();
assert!(content.contains("garbage in, garbage out"));
}
}