
This pull request introduces semantic search to the assistant via a slash command:

https://github.com/zed-industries/zed/assets/482957/62f39eae-d7d5-46bf-a356-dd081ff88312

It also adds a status to pending slash commands, so that we can show when a query is still running or has failed:

<img width="1588" alt="image" src="https://github.com/zed-industries/zed/assets/482957/e8d85960-6275-4552-a068-85efb74cfde1">

The design could still be refined, but this seems like a solid start.

Release Notes:

- N/A
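The status handling itself isn't visible beyond the screenshot; as a rough sketch (the names below are hypothetical, not the actual types from this PR), a pending slash command's status boils down to a small enum the UI can render:

```rust
// Hypothetical sketch only; the real type in this PR may differ.
// A pending slash command is either still running, finished, or
// failed with a message the UI can surface.
enum PendingSlashCommandStatus {
    Running,
    Done,
    Error(String),
}
```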
`CloudEmbeddingProvider` — 93 lines, 3.3 KiB, Rust:
```rust
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
use anyhow::{anyhow, Context, Result};
use client::{proto, Client};
use collections::HashMap;
use futures::{future::BoxFuture, FutureExt};
use std::sync::Arc;

pub struct CloudEmbeddingProvider {
    model: String,
    client: Arc<Client>,
}

impl CloudEmbeddingProvider {
    pub fn new(client: Arc<Client>) -> Self {
        Self {
            model: "openai/text-embedding-3-small".into(),
            client,
        }
    }
}

impl EmbeddingProvider for CloudEmbeddingProvider {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        // First, fetch any embeddings that are cached based on the requested texts' digests.
        // Then compute any embeddings that are missing.
        async move {
            if !self.client.status().borrow().is_connected() {
                return Err(anyhow!("sign in required"));
            }

            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
                model: self.model.clone(),
                digests: texts
                    .iter()
                    .map(|to_embed| to_embed.digest.to_vec())
                    .collect(),
            });
            let mut embeddings = cached_embeddings
                .await
                .context("failed to fetch cached embeddings via cloud model")?
                .embeddings
                .into_iter()
                .map(|embedding| {
                    let digest: [u8; 32] = embedding
                        .digest
                        .try_into()
                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
                    Ok((digest, embedding.dimensions))
                })
                .collect::<Result<HashMap<_, _>>>()?;

            // Only texts whose digests missed the cache are sent to the model.
            let compute_embeddings_request = proto::ComputeEmbeddings {
                model: self.model.clone(),
                texts: texts
                    .iter()
                    .filter_map(|to_embed| {
                        if embeddings.contains_key(&to_embed.digest) {
                            None
                        } else {
                            Some(to_embed.text.to_string())
                        }
                    })
                    .collect(),
            };
            if !compute_embeddings_request.texts.is_empty() {
                let missing_embeddings = self.client.request(compute_embeddings_request).await?;
                for embedding in missing_embeddings.embeddings {
                    let digest: [u8; 32] = embedding
                        .digest
                        .try_into()
                        .map_err(|_| anyhow!("invalid digest for computed embedding"))?;
                    embeddings.insert(digest, embedding.dimensions);
                }
            }

            // Reassemble the results in the order of the input texts.
            texts
                .iter()
                .map(|to_embed| {
                    let embedding =
                        embeddings.get(&to_embed.digest).cloned().with_context(|| {
                            format!("server did not return an embedding for {:?}", to_embed)
                        })?;
                    Ok(Embedding::new(embedding))
                })
                .collect()
        }
        .boxed()
    }

    fn batch_size(&self) -> usize {
        2048
    }
}
```
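Since `batch_size()` caps a single request at 2048 texts, a caller embedding a larger corpus has to chunk its input. Here is a minimal sketch of such a caller, assuming it lives in the same crate and that `TextToEmbed::new` exists and hashes the text into the digest field (that constructor is an assumption, not confirmed by this file):

```rust
// Hedged sketch: drive the provider over an arbitrarily large input.
// Assumes `CloudEmbeddingProvider`, `EmbeddingProvider`, `TextToEmbed`,
// and `Embedding` are in scope, and that `TextToEmbed::new` computes
// the content digest from the text.
async fn embed_all(
    provider: &CloudEmbeddingProvider,
    texts: &[&str],
) -> anyhow::Result<Vec<Embedding>> {
    let to_embed: Vec<TextToEmbed> =
        texts.iter().map(|text| TextToEmbed::new(text)).collect();
    let mut embeddings = Vec::with_capacity(to_embed.len());
    // Respect the provider's per-request cap.
    for chunk in to_embed.chunks(provider.batch_size()) {
        embeddings.extend(provider.embed(chunk).await?);
    }
    Ok(embeddings)
}
```

Because the cache is keyed by content digest, re-running this over an unchanged corpus costs one `GetCachedEmbeddings` round trip and no model calls.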