mod cloud; mod lmstudio; mod ollama; mod open_ai; pub use cloud::*; pub use lmstudio::*; pub use ollama::*; pub use open_ai::*; use sha2::{Digest, Sha256}; use anyhow::Result; use futures::{future::BoxFuture, FutureExt}; use serde::{Deserialize, Serialize}; use std::{fmt, future}; /// Trait for embedding providers. Texts in, vectors out. pub trait EmbeddingProvider: Sync + Send { fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>>; fn batch_size(&self) -> usize; } #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub struct Embedding(Vec); impl Embedding { pub fn new(mut embedding: Vec) -> Self { let len = embedding.len(); let mut norm = 0f32; for i in 0..len { norm += embedding[i] * embedding[i]; } norm = norm.sqrt(); for dimension in &mut embedding { *dimension /= norm; } Self(embedding) } fn len(&self) -> usize { self.0.len() } pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) { debug_assert!(others.iter().all(|other| self.0.len() == other.0.len())); others .iter() .enumerate() .map(|(index, other)| { let dot_product: f32 = self .0 .iter() .copied() .zip(other.0.iter().copied()) .map(|(a, b)| a * b) .sum(); (dot_product, index) }) .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)) .unwrap_or((0.0, 0)) } } impl fmt::Display for Embedding { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let digits_to_display = 3; // Start the Embedding display format write!(f, "Embedding(sized: {}; values: [", self.len())?; for (index, value) in self.0.iter().enumerate().take(digits_to_display) { // Lead with comma if not the first element if index != 0 { write!(f, ", ")?; } write!(f, "{:.3}", value)?; } if self.len() > digits_to_display { write!(f, "...")?; } write!(f, "])") } } #[derive(Debug)] pub struct TextToEmbed<'a> { pub text: &'a str, pub digest: [u8; 32], } impl<'a> TextToEmbed<'a> { pub fn new(text: &'a str) -> Self { let digest = Sha256::digest(text.as_bytes()); Self { text, digest: digest.into(), } } } pub struct FakeEmbeddingProvider; impl EmbeddingProvider for FakeEmbeddingProvider { fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { let embeddings = texts .iter() .map(|_text| { let mut embedding = vec![0f32; 1536]; for i in 0..embedding.len() { embedding[i] = i as f32; } Embedding::new(embedding) }) .collect(); future::ready(Ok(embeddings)).boxed() } fn batch_size(&self) -> usize { 16 } } #[cfg(test)] mod test { use super::*; #[gpui::test] fn test_normalize_embedding() { let normalized = Embedding::new(vec![1.0, 1.0, 1.0]); let value: f32 = 1.0 / 3.0_f32.sqrt(); assert_eq!(normalized, Embedding(vec![value; 3])); } }