ZIm/crates/semantic_index/src/embedding.rs
Kyle Kelley 49371b44cb
Semantic Index (#10329)
This introduces semantic indexing in Zed based on chunking text from
files in the developer's workspace and creating vector embeddings using
an embedding model. As part of this, we've created an embeddings
provider trait that allows us to work with OpenAI, a local Ollama model,
or a Zed-hosted embedding.

The semantic index is built by breaking down text for known
(programming) languages into manageable chunks that are smaller than the
max token size. Each chunk is then fed to a language model to create a
high-dimensional vector, which is then normalized to a unit vector to
allow fast comparison with other vectors with a simple dot product.
Alongside the vector, we store the path of the file and the range within
the document where the vector was sourced from.
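
As a rough illustration of the paragraph above, here is a minimal sketch of normalizing vectors, storing them with their source location, and ranking them by dot product. `IndexedChunk`, `normalize`, `dot`, and `rank` are hypothetical names chosen for this example, not the types used in the crate, and the zero-vector handling is an assumption of the sketch.

```rust
use std::ops::Range;
use std::path::PathBuf;

/// Hypothetical record: a unit vector plus the file and range it came from.
#[derive(Debug, Clone)]
pub struct IndexedChunk {
    pub path: PathBuf,
    pub range: Range<usize>,
    pub vector: Vec<f32>,
}

/// Scale `v` to unit length. A zero vector is returned unchanged
/// (an assumption of this sketch) to avoid dividing by zero.
pub fn normalize(mut v: Vec<f32>) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in &mut v {
            *x /= norm;
        }
    }
    v
}

/// Dot product; equals cosine similarity when both inputs are unit vectors.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Return the chunks paired with their scores, most similar to `query` first.
pub fn rank<'a>(query: &[f32], chunks: &'a [IndexedChunk]) -> Vec<(&'a IndexedChunk, f32)> {
    let mut scored: Vec<_> = chunks
        .iter()
        .map(|chunk| (chunk, dot(query, &chunk.vector)))
        .collect();
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored
}
```

Because every stored vector already has unit length, ranking needs only one multiply-add pass per chunk, with no per-comparison square roots or divisions.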

Zed will soon grok contextual similarity across different text snippets,
allowing for natural language search beyond keyword matching. This is
being put together both for human-based search as well as providing
results to Large Language Models to allow them to refine how they help
developers.

Remaining todo:

* [x] Change `provider` to `model` within the Zed-hosted embeddings
database (as it's currently a combo of the provider and the model in one
name)


Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Antonio <antonio@zed.dev>
2024-04-12 11:40:59 -06:00

125 lines
3 KiB
Rust

mod cloud;
mod ollama;
mod open_ai;

pub use cloud::*;
pub use ollama::*;
pub use open_ai::*;

use anyhow::Result;
use futures::{future::BoxFuture, FutureExt};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{fmt, future};

#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);

impl Embedding {
    /// Creates an embedding, scaling the input to unit length so that the
    /// dot product of two embeddings equals their cosine similarity.
    pub fn new(mut embedding: Vec<f32>) -> Self {
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        // Leave an all-zero vector untouched; dividing by a zero norm would
        // fill it with NaNs.
        if norm > 0.0 {
            for dimension in &mut embedding {
                *dimension /= norm;
            }
        }
        Self(embedding)
    }

    fn len(&self) -> usize {
        self.0.len()
    }

    /// Dot product of two embeddings of equal length. Borrows `self` so the
    /// embedding can be reused after the comparison.
    pub fn similarity(&self, other: &Embedding) -> f32 {
        debug_assert_eq!(self.0.len(), other.0.len());
        self.0
            .iter()
            .zip(other.0.iter())
            .map(|(a, b)| a * b)
            .sum()
    }
}

impl fmt::Display for Embedding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let digits_to_display = 3;

        // Start the Embedding display format
        write!(f, "Embedding(sized: {}; values: [", self.len())?;

        for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
            // Lead with comma if not the first element
            if index != 0 {
                write!(f, ", ")?;
            }
            write!(f, "{:.3}", value)?;
        }
        if self.len() > digits_to_display {
            write!(f, "...")?;
        }
        write!(f, "])")
    }
}

/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
    fn batch_size(&self) -> usize;
}

#[derive(Debug)]
pub struct TextToEmbed<'a> {
    pub text: &'a str,
    pub digest: [u8; 32],
}

impl<'a> TextToEmbed<'a> {
    pub fn new(text: &'a str) -> Self {
        let digest = Sha256::digest(text.as_bytes());
        Self {
            text,
            digest: digest.into(),
        }
    }
}

pub struct FakeEmbeddingProvider;

impl EmbeddingProvider for FakeEmbeddingProvider {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        // Every text gets the same deterministic vector: 0, 1, 2, ... normalized.
        let embeddings = texts
            .iter()
            .map(|_text| Embedding::new((0..1536).map(|i| i as f32).collect()))
            .collect();
        future::ready(Ok(embeddings)).boxed()
    }

    fn batch_size(&self) -> usize {
        16
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[gpui::test]
    fn test_normalize_embedding() {
        let normalized = Embedding::new(vec![1.0, 1.0, 1.0]);
        let value: f32 = 1.0 / 3.0_f32.sqrt();
        assert_eq!(normalized, Embedding(vec![value; 3]));
    }
}