rename vector_store crate to semantic_index
parent e630ff38c4
commit 8b42f5b1b3
14 changed files with 186 additions and 183 deletions
58  crates/semantic_index/Cargo.toml  Normal file
@@ -0,0 +1,58 @@
[package]
name = "semantic_index"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
path = "src/semantic_index.rs"
doctest = false

[dependencies]
gpui = { path = "../gpui" }
language = { path = "../language" }
project = { path = "../project" }
workspace = { path = "../workspace" }
util = { path = "../util" }
picker = { path = "../picker" }
theme = { path = "../theme" }
editor = { path = "../editor" }
rpc = { path = "../rpc" }
settings = { path = "../settings" }
anyhow.workspace = true
futures.workspace = true
smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
matrixmultiply = "0.3.7"
tiktoken-rs = "0.5.0"
parking_lot.workspace = true
rand.workspace = true
schemars.workspace = true

[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }

rand.workspace = true
unindent.workspace = true
tempdir.workspace = true
ctor.workspace = true
env_logger.workspace = true

tree-sitter-typescript = "*"
tree-sitter-rust = "*"
tree-sitter-toml = "*"
tree-sitter-cpp = "*"
tree-sitter-elixir = "*"
31  crates/semantic_index/README.md  Normal file
@@ -0,0 +1,31 @@

WIP: Sample SQL Queries
/*

create table "files" (
    "id" INTEGER PRIMARY KEY,
    "path" VARCHAR,
    "sha1" VARCHAR
);

create table symbols (
    "file_id" INTEGER REFERENCES "files" ("id") ON DELETE CASCADE,
    "offset" INTEGER,
    "embedding" VECTOR
);

insert into "files" ("path", "sha1") values ("src/main.rs", "sha1") returning id;
insert into symbols (
    "file_id",
    "start",
    "end",
    "embedding"
) values (
    (id,),
    (id,),
    (id,),
    (id,)
)

*/
375  crates/semantic_index/src/db.rs  Normal file
@@ -0,0 +1,375 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Result};
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

#[derive(Debug)]
struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
        if embedding.is_err() {
            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
        }
        return Ok(Embedding(embedding.unwrap()));
    }
}

pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let this = Self {
            db: rusqlite::Connection::open(path.as_path())?,
        };
        this.initialize_database()?;
        Ok(this)
    }

    fn get_existing_version(&self) -> Result<i64> {
        let mut version_query = self
            .db
            .prepare("SELECT version from semantic_index_config")?;
        version_query
            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
            .map_err(|err| anyhow!("version query failed: {err}"))
    }

    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        if self
            .get_existing_version()
            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
        {
            return Ok(());
        }

        self.db
            .execute(
                "
                DROP TABLE semantic_index_config;
                DROP TABLE worktrees;
                DROP TABLE files;
                DROP TABLE documents;
                ",
                [],
            )
            .ok();

        // Initialize vector database tables
        self.db.execute(
            "CREATE TABLE semantic_index_config (
                version INTEGER NOT NULL
            )",
            [],
        )?;

        self.db.execute(
            "INSERT INTO semantic_index_config (version) VALUES (?1)",
            params![SEMANTIC_INDEX_VERSION],
        )?;

        self.db.execute(
            "CREATE TABLE worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
            ",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                mtime_seconds INTEGER NOT NULL,
                mtime_nanos INTEGER NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                start_byte INTEGER NOT NULL,
                end_byte INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }

    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    ) -> Result<()> {
        // Write to the files table, and return the generated id.
        self.db.execute(
            "
            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
            ",
            params![worktree_id, path.to_str()],
        )?;
        let mtime = Timestamp::from(mtime);
        self.db.execute(
            "
            INSERT INTO files
            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
            VALUES
            (?1, ?2, ?3, ?4);
            ",
            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
        )?;

        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3400 documents a second.
        // I imagine we can speed this up with a bulk insert of some kind.
        for document in documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, ?5)",
                params![
                    file_id,
                    document.range.start.to_string(),
                    document.range.end.to_string(),
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }

    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Check whether the absolute path already exists
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                Ok(row.get::<_, i64>(0)?)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // If worktree_id is Err, insert a new worktree
        self.db.execute(
            "
            INSERT into worktrees (absolute_path) VALUES (?1)
            ",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
        let mut statement = self.db.prepare(
            "
            SELECT relative_path, mtime_seconds, mtime_nanos
            FROM files
            WHERE worktree_id = ?1
            ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((
                row.get::<_, String>(0)?.into(),
                Timestamp {
                    seconds: row.get(1)?,
                    nanos: row.get(2)?,
                }
                .into(),
            ))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

    pub fn top_k_search(
        &self,
        worktree_ids: &[i64],
        query_embedding: &Vec<f32>,
        limit: usize,
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(&worktree_ids, |id, embedding| {
            let similarity = dot(&embedding, &query_embedding);
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            results.truncate(limit);
        })?;

        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
        self.get_documents_by_ids(&ids)
    }

    fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Vec<f32>),
    ) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }

    fn get_documents_by_ids(
        &self,
        ids: &[i64],
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut statement = self.db.prepare(
            "
            SELECT
                documents.id,
                files.worktree_id,
                files.relative_path,
                documents.start_byte,
                documents.end_byte, documents.name
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                documents.id in rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?..row.get(4)?,
                row.get(5)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, range, name) = row?;
            values_by_id.insert(id, (worktree_id, path, range, name));
        }

        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or(anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(|v| rusqlite::types::Value::from(v))
            .collect::<Vec<_>>(),
    )
}

pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    // Compute the dot product as a 1×len by len×1 matrix product using matrixmultiply's sgemm.
    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
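// Editor's note, not part of this change: the per-document inserts in `insert_file` above are
// a plausible target for the "bulk insert" mentioned there. Assuming the rusqlite version in
// use exposes `Connection::unchecked_transaction`, a sketch of that idea would be:
//
//     let tx = self.db.unchecked_transaction()?;
//     let mut stmt = tx.prepare(
//         "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding)
//          VALUES (?1, ?2, ?3, ?4, ?5)",
//     )?;
//     for document in documents {
//         let embedding_blob = bincode::serialize(&document.embedding)?;
//         stmt.execute(params![
//             file_id,
//             document.range.start.to_string(),
//             document.range.end.to_string(),
//             document.name,
//             embedding_blob
//         ])?;
//     }
//     drop(stmt);
//     tx.commit()?;
//
// This reuses one prepared statement and commits all document rows in a single transaction.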
166  crates/semantic_index/src/embedding.rs  Normal file
@@ -0,0 +1,166 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use std::time::Duration;
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}

#[derive(Clone)]
pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

pub struct DummyEmbeddings {}

#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // 1536 is the embedding size of OpenAI's ada models,
        // the model we will likely be starting with.
        let dummy_vec = vec![0.32 as f32; 1536];
        return Ok(vec![dummy_vec; spans.len()]);
    }
}

impl OpenAIEmbeddings {
    async fn truncate(span: String) -> String {
        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
        if tokens.len() > 8190 {
            tokens.truncate(8190);
            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
            if result.is_ok() {
                let transformed = result.unwrap();
                // assert_ne!(transformed, span);
                return transformed;
            }
        }

        return span.to_string();
    }

    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        Ok(self.client.send(request).await?)
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
        const MAX_RETRIES: usize = 3;

        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let mut request_number = 0;
        let mut response: Response<AsyncBody>;
        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
        while request_number < MAX_RETRIES {
            response = self
                .send_request(api_key, spans.iter().map(|x| &**x).collect())
                .await?;
            request_number += 1;

            if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
                return Err(anyhow!(
                    "openai max retries, error: {:?}",
                    &response.status()
                ));
            }

            match response.status() {
                StatusCode::TOO_MANY_REQUESTS => {
                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                    self.executor.timer(delay).await;
                }
                StatusCode::BAD_REQUEST => {
                    log::info!("BAD REQUEST: {:?}", &response.status());
                    // Don't worry about delaying a bad request, as we can assume
                    // we haven't been rate limited yet.
                    for span in spans.iter_mut() {
                        *span = Self::truncate(span.to_string()).await;
                    }
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;

                    log::info!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| embedding.embedding)
                        .collect());
                }
                _ => {
                    return Err(anyhow!("openai embedding failed {}", response.status()));
                }
            }
        }

        Err(anyhow!("openai embedding failed"))
    }
}
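// Editor's note, not part of this change: the request/response structs above (de)serialize
// JSON of roughly this shape (abridged; inferred from the struct fields, not from this commit):
//
//     request:  {"model": "text-embedding-ada-002", "input": ["span one", "span two"]}
//     response: {"data": [{"embedding": [0.1, ...], "index": 0, "object": "embedding"}],
//                "usage": {"prompt_tokens": 8, "total_tokens": 8}}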
172  crates/semantic_index/src/modal.rs  Normal file
@@ -0,0 +1,172 @@
use crate::{SearchResult, SemanticIndex};
use editor::{scroll::autoscroll::Autoscroll, Editor};
use gpui::{
    actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext,
    WeakViewHandle,
};
use picker::{Picker, PickerDelegate, PickerEvent};
use project::{Project, ProjectPath};
use std::{collections::HashMap, sync::Arc, time::Duration};
use util::ResultExt;
use workspace::Workspace;

const MIN_QUERY_LEN: usize = 5;
const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500);

actions!(semantic_search, [Toggle]);

pub type SemanticSearch = Picker<SemanticSearchDelegate>;

pub struct SemanticSearchDelegate {
    workspace: WeakViewHandle<Workspace>,
    project: ModelHandle<Project>,
    semantic_index: ModelHandle<SemanticIndex>,
    selected_match_index: usize,
    matches: Vec<SearchResult>,
    history: HashMap<String, Vec<SearchResult>>,
}

impl SemanticSearchDelegate {
    // This currently searches on every keystroke, which is wildly overkill and has the
    // potential to get expensive. We will need to update this to throttle searching.
    pub fn new(
        workspace: WeakViewHandle<Workspace>,
        project: ModelHandle<Project>,
        semantic_index: ModelHandle<SemanticIndex>,
    ) -> Self {
        Self {
            workspace,
            project,
            semantic_index,
            selected_match_index: 0,
            matches: vec![],
            history: HashMap::new(),
        }
    }
}

impl PickerDelegate for SemanticSearchDelegate {
    fn placeholder_text(&self) -> Arc<str> {
        "Search repository in natural language...".into()
    }

    fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
        if let Some(search_result) = self.matches.get(self.selected_match_index) {
            // Open buffer
            let search_result = search_result.clone();
            let buffer = self.project.update(cx, |project, cx| {
                project.open_buffer(
                    ProjectPath {
                        worktree_id: search_result.worktree_id,
                        path: search_result.file_path.clone().into(),
                    },
                    cx,
                )
            });

            let workspace = self.workspace.clone();
            let position = search_result.clone().byte_range.start;
            cx.spawn(|_, mut cx| async move {
                let buffer = buffer.await?;
                workspace.update(&mut cx, |workspace, cx| {
                    let editor = workspace.open_project_item::<Editor>(buffer, cx);
                    editor.update(cx, |editor, cx| {
                        editor.change_selections(Some(Autoscroll::center()), cx, |s| {
                            s.select_ranges([position..position])
                        });
                    });
                })?;
                Ok::<_, anyhow::Error>(())
            })
            .detach_and_log_err(cx);
            cx.emit(PickerEvent::Dismiss);
        }
    }

    fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}

    fn match_count(&self) -> usize {
        self.matches.len()
    }

    fn selected_index(&self) -> usize {
        self.selected_match_index
    }

    fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext<SemanticSearch>) {
        self.selected_match_index = ix;
    }

    fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
        log::info!("Searching for {:?}...", query);
        if query.len() < MIN_QUERY_LEN {
            log::info!("Query below minimum length");
            return Task::ready(());
        }

        let semantic_index = self.semantic_index.clone();
        let project = self.project.clone();
        cx.spawn(|this, mut cx| async move {
            cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await;

            let retrieved_cached = this.update(&mut cx, |this, _| {
                let delegate = this.delegate_mut();
                if delegate.history.contains_key(&query) {
                    let historic_results = delegate.history.get(&query).unwrap().to_owned();
                    delegate.matches = historic_results.clone();
                    true
                } else {
                    false
                }
            });

            if let Some(retrieved) = retrieved_cached.log_err() {
                if !retrieved {
                    let task = semantic_index.update(&mut cx, |store, cx| {
                        store.search_project(project.clone(), query.to_string(), 10, cx)
                    });

                    if let Some(results) = task.await.log_err() {
                        log::info!("Not queried previously, searching...");
                        this.update(&mut cx, |this, _| {
                            let delegate = this.delegate_mut();
                            delegate.matches = results.clone();
                            delegate.history.insert(query, results);
                        })
                        .ok();
                    }
                } else {
                    log::info!("Already queried, retrieved directly from cached history");
                }
            }
        })
    }

    fn render_match(
        &self,
        ix: usize,
        mouse_state: &mut MouseState,
        selected: bool,
        cx: &AppContext,
    ) -> AnyElement<Picker<Self>> {
        let theme = theme::current(cx);
        let style = &theme.picker.item;
        let current_style = style.in_state(selected).style_for(mouse_state);

        let search_result = &self.matches[ix];

        let path = search_result.file_path.to_string_lossy();
        let name = search_result.name.clone();

        Flex::column()
            .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
            .with_child(Label::new(
                path.to_string(),
                style.inactive_state().default.label.clone(),
            ))
            .contained()
            .with_style(current_style.container)
            .into_any()
    }
}
139  crates/semantic_index/src/parsing.rs  Normal file
@@ -0,0 +1,139 @@
use anyhow::{anyhow, Ok, Result};
use language::Language;
use std::{ops::Range, path::Path, sync::Arc};
use tree_sitter::{Parser, QueryCursor};

#[derive(Debug, PartialEq, Clone)]
pub struct Document {
    pub name: String,
    pub range: Range<usize>,
    pub content: String,
    pub embedding: Vec<f32>,
}

const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];

pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
}

impl CodeContextRetriever {
    pub fn new() -> Self {
        Self {
            parser: Parser::new(),
            cursor: QueryCursor::new(),
        }
    }

    fn _parse_entire_file(
        &self,
        relative_path: &Path,
        language_name: Arc<str>,
        content: &str,
    ) -> Result<Vec<Document>> {
        let document_span = ENTIRE_FILE_TEMPLATE
            .replace("<path>", relative_path.to_string_lossy().as_ref())
            .replace("<language>", language_name.as_ref())
            .replace("<item>", &content);

        Ok(vec![Document {
            range: 0..content.len(),
            content: document_span,
            embedding: Vec::new(),
            name: language_name.to_string(),
        }])
    }

    pub fn parse_file(
        &mut self,
        relative_path: &Path,
        content: &str,
        language: Arc<Language>,
    ) -> Result<Vec<Document>> {
        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
            return self._parse_entire_file(relative_path, language.name(), &content);
        }

        let grammar = language
            .grammar()
            .ok_or_else(|| anyhow!("no grammar for language"))?;
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding queries"))?;

        self.parser.set_language(grammar.ts_language).unwrap();

        let tree = self
            .parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();

        // Iterate through query matches
        let mut name_ranges: Vec<Range<usize>> = vec![];
        for mat in self.cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            let mut name: Vec<&str> = vec![];
            let mut item: Option<&str> = None;
            let mut byte_range: Option<Range<usize>> = None;
            let mut context_spans: Vec<&str> = vec![];
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    byte_range = Some(capture.node.byte_range());
                    item = content.get(capture.node.byte_range());
                } else if capture.index == embedding_config.name_capture_ix {
                    let name_range = capture.node.byte_range();
                    if name_ranges.contains(&name_range) {
                        continue;
                    }
                    name_ranges.push(name_range.clone());
                    if let Some(name_content) = content.get(name_range.clone()) {
                        name.push(name_content);
                    }
                }

                if let Some(context_capture_ix) = embedding_config.context_capture_ix {
                    if capture.index == context_capture_ix {
                        if let Some(context) = content.get(capture.node.byte_range()) {
                            context_spans.push(context);
                        }
                    }
                }
            }

            if let Some((item, byte_range)) = item.zip(byte_range) {
                if !name.is_empty() {
                    let item = if context_spans.is_empty() {
                        item.to_string()
                    } else {
                        format!("{}\n{}", context_spans.join("\n"), item)
                    };

                    let document_text = CODE_CONTEXT_TEMPLATE
                        .replace("<path>", relative_path.to_str().unwrap())
                        .replace("<language>", &language.name().to_lowercase())
                        .replace("<item>", item.as_str());

                    documents.push(Document {
                        range: byte_range,
                        content: document_text,
                        embedding: Vec::new(),
                        name: name.join(" ").to_string(),
                    });
                }
            }
        }

        return Ok(documents);
    }
}
686  crates/semantic_index/src/semantic_index.rs  Normal file
@@ -0,0 +1,686 @@
mod db;
mod embedding;
mod modal;
mod parsing;
mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
    WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
    collections::{HashMap, HashSet},
    ops::Range,
    path::{Path, PathBuf},
    sync::{
        atomic::{self, AtomicUsize},
        Arc, Weak,
    },
    time::{Instant, SystemTime},
};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};

const SEMANTIC_INDEX_VERSION: usize = 1;
const EMBEDDINGS_BATCH_SIZE: usize = 150;

pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    SemanticSearch::init(cx);
    cx.add_action(
        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            if cx.has_global::<ModelHandle<SemanticIndex>>() {
                let semantic_index = cx.global::<ModelHandle<SemanticIndex>>().clone();
                workspace.toggle_modal(cx, |workspace, cx| {
                    let project = workspace.project().clone();
                    let workspace = cx.weak_handle();
                    cx.add_view(|cx| {
                        SemanticSearch::new(
                            SemanticSearchDelegate::new(workspace, project, semantic_index),
                            cx,
                        )
                    })
                });
            }
        },
    );

    if *RELEASE_CHANNEL == ReleaseChannel::Stable
        || !settings::get::<SemanticIndexSettings>(cx).enabled
    {
        log::info!("NOT ENABLED");
        return;
    }

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
                executor: cx.background(),
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
            cx.subscribe_global::<WorkspaceCreated, _>({
                let semantic_index = semantic_index.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            semantic_index.update(cx, |store, cx| {
                                store.index_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();
        });

        anyhow::Ok(())
    })
    .detach();
}

pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    next_job_id: Arc<AtomicUsize>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
}

type JobId = usize;

struct JobHandle {
    id: JobId,
    set: Weak<Mutex<HashSet<JobId>>>,
}

impl ProjectState {
    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *worktree_id == id {
                    Some(*db_id)
                } else {
                    None
                }
            })
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *db_id == id {
                    Some(*worktree_id)
                } else {
                    None
                }
            })
    }
}

pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub byte_range: Range<usize>,
    pub file_path: PathBuf,
}

enum DbOperation {
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        job_handle: JobHandle,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}

enum EmbeddingJob {
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    Flush,
}

impl SemanticIndex {
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        Ok(cx.add_model(|cx| {
            // paths_tx -> embeddings_tx -> db_update_tx

            // db_update_tx/rx: updating the database
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbOperation::InsertFile {
                            worktree_id,
                            documents,
                            path,
                            mtime,
                            job_handle,
                        } => {
                            db.insert_file(worktree_id, path, mtime, documents)
                                .log_err();
                            drop(job_handle)
                        }
                        DbOperation::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbOperation::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                        DbOperation::FileMTimes {
                            worktree_id: worktree_db_id,
                            sender,
                        } => {
                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            sender.send(file_mtimes).ok();
                        }
                    }
                }
            });

            // embed_tx/rx: embed a batch and send it to the database
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
            let _embed_batch_task = cx.background().spawn({
                let db_update_tx = db_update_tx.clone();
                let embedding_provider = embedding_provider.clone();
                async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Construct the batch
                        let mut batch_documents = vec![];
                        for (_, documents, _, _, _) in embeddings_queue.iter() {
                            batch_documents
                                .extend(documents.iter().map(|document| document.content.as_str()));
                        }

                        if let Ok(embeddings) =
                            embedding_provider.embed_batch(batch_documents).await
                        {
                            log::trace!(
                                "created {} embeddings for {} files",
                                embeddings.len(),
                                embeddings_queue.len(),
                            );

                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, documents, path, mtime, job_handle) in
                                embeddings_queue.into_iter()
                            {
                                for document in documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        document.embedding.len() > 0,
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbOperation::InsertFile {
                                        worktree_id,
                                        documents,
                                        path,
                                        mtime,
                                        job_handle,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }
            });

            // batch_tx/rx: batch files to send for embedding
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];

                while let Ok(job) = batch_files_rx.recv().await {
                    let should_flush = match job {
                        EmbeddingJob::Enqueue {
                            documents,
                            worktree_id,
                            path,
                            mtime,
                            job_handle,
                        } => {
                            queue_len += documents.len();
                            embeddings_queue.push((
                                worktree_id,
                                documents,
                                path,
                                mtime,
                                job_handle,
                            ));
                            queue_len >= EMBEDDINGS_BATCH_SIZE
                        }
                        EmbeddingJob::Flush => true,
                    };

                    if should_flush {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
            });

            // parsing_files_tx/rx: parse files into embeddable documents
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
                        {
                            if let Some(documents) = retriever
                                .parse_file(
                                    &pending_file.relative_path,
                                    &content,
                                    pending_file.language,
                                )
                                .log_err()
                            {
                                log::trace!(
                                    "parsed path {:?}: {} documents",
                                    pending_file.relative_path,
                                    documents.len()
                                );

                                batch_files_tx
                                    .try_send(EmbeddingJob::Enqueue {
                                        worktree_id: pending_file.worktree_db_id,
                                        path: pending_file.relative_path,
                                        mtime: pending_file.modified_time,
                                        job_handle: pending_file.job_handle,
                                        documents,
                                    })
                                    .unwrap();
                            }
                        }

                        if parsing_files_rx.len() == 0 {
                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                next_job_id: Default::default(),
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

    fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FileMTimes {
                worktree_id,
                sender: tx,
            })
            .unwrap();
        async move { rx.await? }
    }

    fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<usize>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let language_registry = self.language_registry.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let next_job_id = self.next_job_id.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let mut worktree_file_mtimes = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_mtimes.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
            let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
            this.update(&mut cx, |this, _| {
                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        worktree_db_ids: db_ids_by_worktree_id
                            .iter()
                            .map(|(a, b)| (*a, *b))
                            .collect(),
                        outstanding_jobs: outstanding_jobs.clone(),
                    },
                );
            });

            cx.background()
                .spawn(async move {
                    let mut count = 0;
                    let t0 = Instant::now();
                    for worktree in worktrees.into_iter() {
                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                        for file in worktree.files(false, 0) {
                            let absolute_path = worktree.absolutize(&file.path);

                            if let Ok(language) = language_registry
                                .language_for_file(&absolute_path, None)
                                .await
                            {
                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                    && language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                {
                                    continue;
                                }

                                let path_buf = file.path.to_path_buf();
                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                let already_stored = stored_mtime
                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);

                                if !already_stored {
                                    log::trace!("sending for parsing: {:?}", path_buf);
                                    count += 1;
                                    let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
                                    let job_handle = JobHandle {
                                        id: job_id,
                                        set: Arc::downgrade(&outstanding_jobs),
                                    };
                                    outstanding_jobs.lock().insert(job_id);
                                    parsing_files_tx
                                        .try_send(PendingFile {
                                            worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
                                            relative_path: path_buf,
                                            absolute_path,
                                            language,
                                            job_handle,
                                            modified_time: file.mtime,
                                        })
                                        .unwrap();
                                }
                            }
                        }
                        for file in file_mtimes.keys() {
                            db_update_tx
                                .try_send(DbOperation::Delete {
                                    worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                    path: file.to_owned(),
                                })
                                .unwrap();
                        }
                    }
                    log::trace!(
                        "parsing worktree completed in {:?}",
                        t0.elapsed().as_millis()
                    );

                    Ok(count)
                })
                .await
        })
    }

    pub fn remaining_files_to_index_for_project(
        &self,
        project: &ModelHandle<Project>,
    ) -> Option<usize> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .outstanding_jobs
                .lock()
                .len(),
        )
    }

    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state.db_id_for_worktree_id(worktree_id)
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        let fs = self.fs.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(fs, database_url).await?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            byte_range,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
}

impl Entity for SemanticIndex {
    type Event = ();
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(set) = self.set.upgrade() {
            set.lock().remove(&self.id);
        }
    }
}
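// Editor's summary, not part of this change: `SemanticIndex::new` above wires up a pipeline of
// channels and background tasks:
//
//     parsing_files_tx -> worker tasks parse files into `Document`s
//     batch_files_tx   -> accumulates documents until EMBEDDINGS_BATCH_SIZE, then flushes
//     embed_batch_tx   -> calls the `EmbeddingProvider` for each batch
//     db_update_tx     -> writes files/documents (and other `DbOperation`s) to SQLite
//
// A `JobHandle` travels with each pending file and removes its id from the project's
// `outstanding_jobs` set when dropped, which is what `remaining_files_to_index_for_project`
// reports.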
30  crates/semantic_index/src/semantic_index_settings.rs  Normal file
@@ -0,0 +1,30 @@
use anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Setting;

#[derive(Deserialize, Debug)]
pub struct SemanticIndexSettings {
    pub enabled: bool,
    pub reindexing_delay_seconds: usize,
}

#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct SemanticIndexSettingsContent {
    pub enabled: Option<bool>,
    pub reindexing_delay_seconds: Option<usize>,
}

impl Setting for SemanticIndexSettings {
    const KEY: Option<&'static str> = Some("semantic_index");

    type FileContent = SemanticIndexSettingsContent;

    fn load(
        default_value: &Self::FileContent,
        user_values: &[&Self::FileContent],
        _: &gpui::AppContext,
    ) -> anyhow::Result<Self> {
        Self::load_via_json_merge(default_value, user_values)
    }
}
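// Editor's note, not part of this change: given the "semantic_index" key registered above, a
// user would opt in via settings JSON shaped roughly like the following (the delay value is
// illustrative, not taken from this commit):
//
//     {
//         "semantic_index": {
//             "enabled": true,
//             "reindexing_delay_seconds": 600
//         }
//     }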
1034  crates/semantic_index/src/semantic_index_tests.rs  Normal file
File diff suppressed because it is too large