OpenAI indexing on open for Rust files

KCaverly 2023-06-22 16:50:07 -04:00
parent d4a4db42aa
commit dd309070eb
7 changed files with 252 additions and 55 deletions

Cargo.lock (generated, 57 lines changed)

@@ -1389,15 +1389,6 @@ dependencies = [
  "theme",
 ]
 
-[[package]]
-name = "conv"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299"
-dependencies = [
- "custom_derive",
-]
-
 [[package]]
 name = "copilot"
 version = "0.1.0"
@@ -1775,12 +1766,6 @@ dependencies = [
  "winapi 0.3.9",
 ]
 
-[[package]]
-name = "custom_derive"
-version = "0.1.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9"
-
 [[package]]
 name = "cxx"
 version = "1.0.94"
@@ -2219,6 +2204,12 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
 
+[[package]]
+name = "fallible-streaming-iterator"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
+
 [[package]]
 name = "fancy-regex"
 version = "0.11.0"
@@ -2909,6 +2900,15 @@ dependencies = [
  "ahash 0.8.3",
 ]
 
+[[package]]
+name = "hashlink"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf"
+dependencies = [
+ "hashbrown 0.11.2",
+]
+
 [[package]]
 name = "hashlink"
 version = "0.8.1"
@@ -5600,6 +5600,21 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "rusqlite"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a"
+dependencies = [
+ "bitflags",
+ "fallible-iterator",
+ "fallible-streaming-iterator",
+ "hashlink 0.7.0",
+ "libsqlite3-sys",
+ "memchr",
+ "smallvec",
+]
+
 [[package]]
 name = "rust-embed"
 version = "6.6.1"
@@ -6531,7 +6546,7 @@ dependencies = [
  "futures-executor",
  "futures-intrusive",
  "futures-util",
- "hashlink",
+ "hashlink 0.8.1",
  "hex",
  "hkdf",
  "hmac 0.12.1",
@@ -7898,14 +7913,20 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-compat",
- "conv",
+ "async-trait",
  "futures 0.3.28",
  "gpui",
+ "isahc",
  "language",
+ "lazy_static",
+ "log",
  "project",
- "rand 0.8.5",
+ "rusqlite",
+ "serde",
+ "serde_json",
  "smol",
  "sqlx",
+ "tree-sitter",
  "util",
  "workspace",
 ]


@@ -476,12 +476,12 @@ pub struct Language {
 pub struct Grammar {
     id: usize,
-    pub(crate) ts_language: tree_sitter::Language,
+    pub ts_language: tree_sitter::Language,
     pub(crate) error_query: Query,
     pub(crate) highlights_query: Option<Query>,
     pub(crate) brackets_config: Option<BracketConfig>,
     pub(crate) indents_config: Option<IndentConfig>,
-    pub(crate) outline_config: Option<OutlineConfig>,
+    pub outline_config: Option<OutlineConfig>,
     pub(crate) injection_config: Option<InjectionConfig>,
     pub(crate) override_config: Option<OverrideConfig>,
     pub(crate) highlight_map: Mutex<HighlightMap>,
@@ -495,12 +495,12 @@ struct IndentConfig {
     outdent_capture_ix: Option<u32>,
 }
 
-struct OutlineConfig {
-    query: Query,
-    item_capture_ix: u32,
-    name_capture_ix: u32,
-    context_capture_ix: Option<u32>,
-    extra_context_capture_ix: Option<u32>,
+pub struct OutlineConfig {
+    pub query: Query,
+    pub item_capture_ix: u32,
+    pub name_capture_ix: u32,
+    pub context_capture_ix: Option<u32>,
+    pub extra_context_capture_ix: Option<u32>,
 }
 
 struct InjectionConfig {
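
These two hunks only widen visibility: `ts_language`, `outline_config`, and the `OutlineConfig` fields become `pub` so code outside the `language` crate can run a grammar's outline query itself, which is what the new vector_store indexing below does. A minimal sketch of that kind of caller (the function and its plain-string input are illustrative, not part of this commit; the field and method names follow the diff):

use anyhow::anyhow;
use tree_sitter::{Parser, QueryCursor};

// Count the outline items (functions, structs, etc.) in a source string by
// driving tree-sitter with the grammar's now-public outline query.
fn count_outline_items(language: &language::Language, source: &str) -> anyhow::Result<usize> {
    let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
    let outline = grammar
        .outline_config
        .as_ref()
        .ok_or_else(|| anyhow!("no outline query"))?;

    let mut parser = Parser::new();
    parser.set_language(grammar.ts_language).unwrap();
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| anyhow!("parsing failed"))?;

    let mut cursor = QueryCursor::new();
    Ok(cursor
        .matches(&outline.query, tree.root_node(), source.as_bytes())
        .count())
}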


@@ -19,8 +19,14 @@ futures.workspace = true
 smol.workspace = true
 sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] }
 async-compat = "0.2.1"
-conv = "0.3.3"
-rand.workspace = true
+rusqlite = "0.27.0"
+isahc.workspace = true
+log.workspace = true
+tree-sitter.workspace = true
+lazy_static.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+async-trait.workspace = true
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }


@@ -1,8 +1,6 @@
 use anyhow::Result;
 use async_compat::{Compat, CompatExt};
-use conv::ValueFrom;
-use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool};
-use std::time::{Duration, Instant};
+use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool};
 
 use crate::IndexedFile;
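
The surviving imports hint at how the database layer runs: sqlx is built with its `runtime-tokio-rustls` feature (see the Cargo.toml hunk above), so its futures expect a tokio context, and `async_compat::Compat` bridges them onto the executor used by the rest of the crate. A rough sketch of that pattern, with a hypothetical `open_db` helper that is not part of this commit:

use async_compat::CompatExt;
use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool};

// Create the sqlite database if it does not exist yet, then open a pool.
// `.compat()` wraps each tokio-flavored sqlx future so it can be awaited
// outside of a tokio runtime.
async fn open_db(db_url: &str) -> anyhow::Result<SqlitePool> {
    if !Sqlite::database_exists(db_url).compat().await? {
        Sqlite::create_database(db_url).compat().await?;
    }
    Ok(SqlitePool::connect(db_url).compat().await?)
}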


@@ -0,0 +1,100 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
use isahc::prelude::Configurable;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use util::http::{HttpClient, Request};

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

pub struct OpenAIEmbeddings {
    pub client: Arc<dyn HttpClient>,
}

#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}

#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}

#[async_trait]
pub trait EmbeddingProvider: Sync {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let api_key = OPENAI_API_KEY
            .as_ref()
            .ok_or_else(|| anyhow!("no api key"))?;

        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans,
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;

        let mut response = self.client.send(request).await?;
        if !response.status().is_success() {
            return Err(anyhow!("openai embedding failed {}", response.status()));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
        log::info!(
            "openai embedding completed. tokens: {:?}",
            response.usage.total_tokens
        );

        // do we need to re-order these based on the `index` field?
        eprintln!(
            "indices: {:?}",
            response
                .data
                .iter()
                .map(|embedding| embedding.index)
                .collect::<Vec<_>>()
        );

        Ok(response
            .data
            .into_iter()
            .map(|embedding| embedding.embedding)
            .collect())
    }
}
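
One answer to the question in the comment above: each row of the OpenAI response carries an `index`, so a defensive follow-up (a sketch, not part of this commit) is to sort on that field before dropping it, guaranteeing the returned vectors line up with the input spans:

// Sort response rows by their reported `index` so embeddings match the order
// of the spans that were sent, then keep only the vectors.
fn embeddings_in_input_order(mut data: Vec<OpenAIEmbedding>) -> Vec<Vec<f32>> {
    data.sort_unstable_by_key(|row| row.index);
    data.into_iter().map(|row| row.embedding).collect()
}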


@@ -1,17 +1,25 @@
 mod db;
+mod embedding;
 
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use db::VectorDatabase;
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle};
 use language::LanguageRegistry;
 use project::{Fs, Project};
-use rand::Rng;
 use smol::channel;
 use std::{path::PathBuf, sync::Arc, time::Instant};
-use util::ResultExt;
+use tree_sitter::{Parser, QueryCursor};
+use util::{http::HttpClient, ResultExt};
 use workspace::WorkspaceCreated;
 
-pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) {
-    let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry));
+pub fn init(
+    fs: Arc<dyn Fs>,
+    http_client: Arc<dyn HttpClient>,
+    language_registry: Arc<LanguageRegistry>,
+    cx: &mut AppContext,
+) {
+    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
 
     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
@@ -53,38 +61,86 @@ struct SearchResult {
 struct VectorStore {
     fs: Arc<dyn Fs>,
+    http_client: Arc<dyn HttpClient>,
     language_registry: Arc<LanguageRegistry>,
 }
 
 impl VectorStore {
-    fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self {
+    fn new(
+        fs: Arc<dyn Fs>,
+        http_client: Arc<dyn HttpClient>,
+        language_registry: Arc<LanguageRegistry>,
+    ) -> Self {
         Self {
             fs,
+            http_client,
             language_registry,
         }
     }
 
     async fn index_file(
+        cursor: &mut QueryCursor,
+        parser: &mut Parser,
+        embedding_provider: &dyn EmbeddingProvider,
         fs: &Arc<dyn Fs>,
         language_registry: &Arc<LanguageRegistry>,
         file_path: PathBuf,
     ) -> Result<IndexedFile> {
-        // This is creating dummy documents to test the database writes.
-        let mut documents = vec![];
-        let mut rng = rand::thread_rng();
-        let rand_num_of_documents: u8 = rng.gen_range(0..200);
-        for _ in 0..rand_num_of_documents {
-            let doc = Document {
-                offset: 0,
-                name: "test symbol".to_string(),
-                embedding: vec![0.32 as f32; 768],
-            };
-            documents.push(doc);
+        let language = language_registry
+            .language_for_file(&file_path, None)
+            .await?;
+
+        if language.name().as_ref() != "Rust" {
+            Err(anyhow!("unsupported language"))?;
+        }
+
+        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
+        let outline_config = grammar
+            .outline_config
+            .as_ref()
+            .ok_or_else(|| anyhow!("no outline query"))?;
+
+        let content = fs.load(&file_path).await?;
+
+        parser.set_language(grammar.ts_language).unwrap();
+        let tree = parser
+            .parse(&content, None)
+            .ok_or_else(|| anyhow!("parsing failed"))?;
+
+        let mut documents = Vec::new();
+        let mut context_spans = Vec::new();
+        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
+            let mut item_range = None;
+            let mut name_range = None;
+            for capture in mat.captures {
+                if capture.index == outline_config.item_capture_ix {
+                    item_range = Some(capture.node.byte_range());
+                } else if capture.index == outline_config.name_capture_ix {
+                    name_range = Some(capture.node.byte_range());
+                }
+            }
+
+            if let Some((item_range, name_range)) = item_range.zip(name_range) {
+                if let Some((item, name)) =
+                    content.get(item_range.clone()).zip(content.get(name_range))
+                {
+                    context_spans.push(item);
+                    documents.push(Document {
+                        name: name.to_string(),
+                        offset: item_range.start,
+                        embedding: Vec::new(),
+                    });
+                }
+            }
+        }
+
+        let embeddings = embedding_provider.embed_batch(context_spans).await?;
+        for (document, embedding) in documents.iter_mut().zip(embeddings) {
+            document.embedding = embedding;
         }
 
         return Ok(IndexedFile {
             path: file_path,
-            sha1: "asdfasdfasdf".to_string(),
+            sha1: String::new(),
             documents,
         });
     }
@@ -98,8 +154,9 @@ impl VectorStore {
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
+        let client = self.http_client.clone();
 
-        cx.spawn(|this, cx| async move {
+        cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
             let worktrees = project.read_with(&cx, |project, cx| {
@@ -131,15 +188,27 @@ impl VectorStore {
                 })
                 .detach();
 
+            let provider = OpenAIEmbeddings { client };
+            let t0 = Instant::now();
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
                         scope.spawn(async {
+                            let mut parser = Parser::new();
+                            let mut cursor = QueryCursor::new();
                             while let Ok(file_path) = paths_rx.recv().await {
-                                if let Some(indexed_file) =
-                                    Self::index_file(&fs, &language_registry, file_path)
-                                        .await
-                                        .log_err()
+                                if let Some(indexed_file) = Self::index_file(
+                                    &mut cursor,
+                                    &mut parser,
+                                    &provider,
+                                    &fs,
+                                    &language_registry,
+                                    file_path,
+                                )
+                                .await
+                                .log_err()
                                 {
                                     indexed_files_tx.try_send(indexed_file).unwrap();
                                 }
@@ -148,6 +217,9 @@ impl VectorStore {
                     }
                 })
                 .await;
+
+            let duration = t0.elapsed();
+            log::info!("indexed project in {duration:?}");
         })
         .detach();
     }
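
Because `index_file` takes the provider as `&dyn EmbeddingProvider`, the OpenAI client can be swapped out in tests. A hypothetical fake provider (not part of this commit) that satisfies the trait without network traffic could look like this:

use async_trait::async_trait;

struct FakeEmbeddings;

#[async_trait]
impl EmbeddingProvider for FakeEmbeddings {
    // Return one zero vector per span (1536 dimensions, the size of
    // text-embedding-ada-002 embeddings) so the indexing and database paths
    // can be exercised without calling the OpenAI API.
    async fn embed_batch(&self, spans: Vec<&str>) -> anyhow::Result<Vec<Vec<f32>>> {
        Ok(spans.iter().map(|_| vec![0.0; 1536]).collect())
    }
}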


@@ -152,7 +152,7 @@ fn main() {
         project_panel::init(cx);
         diagnostics::init(cx);
         search::init(cx);
-        vector_store::init(fs.clone(), languages.clone(), cx);
+        vector_store::init(fs.clone(), http.clone(), languages.clone(), cx);
         vim::init(cx);
         terminal_view::init(cx);
         theme_testbench::init(cx);