semantic index eval, indexing appropriately

This commit is contained in:
KCaverly 2023-09-13 20:02:15 -04:00
parent 6f29582fb0
commit eff44f9aa4
5 changed files with 168 additions and 46 deletions

4
Cargo.lock generated
View file

@ -6744,6 +6744,7 @@ dependencies = [
"anyhow",
"async-trait",
"bincode",
"client",
"collections",
"ctor",
"editor",
@ -6757,6 +6758,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
"node_runtime",
"parking_lot 0.11.2",
"parse_duration",
"picker",
@ -6766,6 +6768,7 @@ dependencies = [
"rand 0.8.5",
"rpc",
"rusqlite",
"rust-embed",
"schemars",
"serde",
"serde_json",
@ -6788,6 +6791,7 @@ dependencies = [
"unindent",
"util",
"workspace",
"zed",
]
[[package]]

View file

@ -51,6 +51,10 @@ rpc = { path = "../rpc", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"]}
git2 = { version = "0.15"}
rust-embed = { version = "8.0", features = ["include-exclude"] }
client = { path = "../client" }
zed = { path = "../zed"}
node_runtime = { path = "../node_runtime"}
pretty_assertions.workspace = true
rand.workspace = true

View file

@ -17,7 +17,7 @@
{
"query": "generate tags based on config",
"matches": [
"tags/src/lib.rs:261",
"tags/src/lib.rs:261"
]
},
{
@ -54,13 +54,13 @@
{
"query": "Match based on associativity of actions",
"matches": [
"cli/src/generate/build_tables/build_parse_table.rs:542",
"cli/src/generate/build_tables/build_parse_table.rs:542"
]
},
{
"query": "Format token set display",
"matches": [
"cli/src/generate/build_tables/item.rs:246",
"cli/src/generate/build_tables/item.rs:246"
]
},
{

View file

@ -1,8 +1,46 @@
use anyhow::{anyhow, Result};
use client::{self, UserStore};
use git2::{Object, Oid, Repository};
use semantic_index::SearchResult;
use gpui::{AppContext, AssetSource, ModelHandle, Task};
use language::LanguageRegistry;
use node_runtime::RealNodeRuntime;
use project::{Fs, Project, RealFs};
use rust_embed::RustEmbed;
use semantic_index::embedding::OpenAIEmbeddings;
use semantic_index::semantic_index_settings::SemanticIndexSettings;
use semantic_index::{SearchResult, SemanticIndex};
use serde::Deserialize;
use std::path::{Path, PathBuf};
use std::{env, fs};
use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
use std::path::{self, Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use std::{cmp, env, fs};
use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
use util::http::{self, HttpClient};
use util::paths::{self, EMBEDDINGS_DIR};
use zed::languages;
/// Static assets (fonts, icons, themes, sounds, and top-level markdown docs)
/// embedded into the binary at compile time from the repo-level `assets`
/// directory, mirroring what the main Zed binary ships.
// `.DS_Store` files are excluded so macOS Finder metadata never gets embedded.
#[derive(RustEmbed)]
#[folder = "../../assets"]
#[include = "fonts/**/*"]
#[include = "icons/**/*"]
#[include = "themes/**/*"]
#[include = "sounds/**/*"]
#[include = "*.md"]
#[exclude = "*.DS_Store"]
pub struct Assets;
/// Serves the embedded assets to GPUI by delegating to the accessors that the
/// `RustEmbed` derive generates on `Assets` (`get` and `iter`).
impl AssetSource for Assets {
    /// Look up a single embedded asset by path; errors when it is absent.
    fn load(&self, path: &str) -> Result<std::borrow::Cow<[u8]>> {
        match Self::get(path) {
            Some(file) => Ok(file.data),
            None => Err(anyhow!("could not find asset at path \"{}\"", path)),
        }
    }

    /// Enumerate every embedded asset whose path begins with `path`.
    fn list(&self, path: &str) -> Vec<std::borrow::Cow<'static, str>> {
        let mut matching = Vec::new();
        for asset_path in Self::iter() {
            if asset_path.starts_with(path) {
                matching.push(asset_path);
            }
        }
        matching
    }
}
#[derive(Deserialize, Clone)]
struct EvaluationQuery {
@ -13,15 +51,18 @@ struct EvaluationQuery {
impl EvaluationQuery {
fn match_pairs(&self) -> Vec<(PathBuf, usize)> {
let mut pairs = Vec::new();
for match_identifier in self.matches {
let match_parts = match_identifier.split(":");
for match_identifier in self.matches.iter() {
let mut match_parts = match_identifier.split(":");
if let Some(file_path) = match_parts.next() {
if let Some(row_number) = match_parts.next() {
pairs.push((PathBuf::from(file_path), from_str::<usize>(row_number)));
pairs.push((
PathBuf::from(file_path),
row_number.parse::<usize>().unwrap(),
));
}
}
}
pairs
}
}
@ -33,7 +74,7 @@ struct RepoEval {
assertions: Vec<EvaluationQuery>,
}
const TMP_REPO_PATH: &str = "./target/eval_repos";
const TMP_REPO_PATH: &str = "eval_repos";
fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
let eval_folder = env::current_dir()?
@ -74,7 +115,12 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
.unwrap()
.to_owned()
.replace(".git", "");
let clone_path = Path::new(TMP_REPO_PATH).join(&repo_name).to_path_buf();
let clone_path = fs::canonicalize(env::current_dir()?)?
.parent()
.ok_or(anyhow!("path canonicalization failed"))?
.join(TMP_REPO_PATH)
.join(&repo_name);
// Delete Clone Path if already exists
let _ = fs::remove_dir_all(&clone_path);
@ -105,7 +151,6 @@ fn dcg(hits: Vec<usize>) -> f32 {
}
fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
// NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
// items returned by the search engine relative to the hypothetical ideal.
// Relevance is represented as a series of booleans, in which each search result returned
@ -125,47 +170,118 @@ fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>,
// very high quality, whereas rank results quickly drop off after the first result.
let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
let hits = vec![1];
return dcg(hits) / dcg(ideal);
}
fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
}
fn evaluate_repo(repo_eval: RepoEval, clone_path: PathBuf) {
// Launch new repo as a new Zed workspace/project
// Index the project
// Search each eval_query
// Calculate Statistics
// fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {}
/// Initialize global logging via `env_logger` (filter level is taken from the
/// `RUST_LOG` environment variable, per env_logger's default behavior).
fn init_logger() {
    env_logger::init();
}
fn main() {
// Launch new repo as a new Zed workspace/project
let app = gpui::App::new(Assets).unwrap();
let fs = Arc::new(RealFs);
let http = http::client();
let user_settings_file_rx =
watch_config_file(app.background(), fs.clone(), paths::SETTINGS.clone());
let http_client = http::client();
init_logger();
// zed/main.rs
// creating an app and running it, gives you the context.
// create a project, find_or_create_local_worktree.
app.run(move |cx| {
cx.set_global(*RELEASE_CHANNEL);
if let Ok(repo_evals) = parse_eval() {
for repo in repo_evals {
let cloned = clone_repo(repo.clone());
match cloned {
Ok(clone_path) => {
println!(
"Cloned {:?} @ {:?} into {:?}",
repo.repo, repo.commit, &clone_path
);
let client = client::Client::new(http.clone(), cx);
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));
// Evaluate Repo
evaluate_repo(repo, clone_path);
// Initialize Settings
let mut store = SettingsStore::default();
store
.set_default_settings(default_settings().as_ref(), cx)
.unwrap();
cx.set_global(store);
handle_settings_file_changes(user_settings_file_rx, cx);
}
Err(err) => {
println!("Error Cloning: {:?}", err);
// Initialize Languages
let login_shell_env_loaded = Task::ready(());
let mut languages = LanguageRegistry::new(login_shell_env_loaded);
languages.set_executor(cx.background().clone());
let languages = Arc::new(languages);
let node_runtime = RealNodeRuntime::new(http.clone());
languages::init(languages.clone(), node_runtime.clone());
project::Project::init(&client, cx);
semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
settings::register::<SemanticIndexSettings>(cx);
let db_file_path = EMBEDDINGS_DIR
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
let languages = languages.clone();
let fs = fs.clone();
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
languages.clone(),
cx.clone(),
)
.await?;
if let Ok(repo_evals) = parse_eval() {
for repo in repo_evals {
let cloned = clone_repo(repo.clone());
match cloned {
Ok(clone_path) => {
log::trace!(
"Cloned {:?} @ {:?} into {:?}",
repo.repo,
repo.commit,
&clone_path
);
// Create Project
let project = cx.update(|cx| {
Project::local(
client.clone(),
user_store.clone(),
languages.clone(),
fs.clone(),
cx,
)
});
// Register Worktree
let _ = project
.update(&mut cx, |project, cx| {
println!(
"Creating worktree in project: {:?}",
clone_path.clone()
);
project.find_or_create_local_worktree(clone_path, true, cx)
})
.await;
let _ = semantic_index
.update(&mut cx, |index, cx| index.index_project(project, cx))
.await;
}
Err(err) => {
log::trace!("Error cloning: {:?}", err);
}
}
}
}
}
}
anyhow::Ok(())
})
.detach();
});
}

View file

@ -1,5 +1,5 @@
mod db;
mod embedding;
pub mod embedding;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;
@ -301,7 +301,7 @@ impl SemanticIndex {
}
}
async fn new(
pub async fn new(
fs: Arc<dyn Fs>,
database_path: PathBuf,
embedding_provider: Arc<dyn EmbeddingProvider>,
@ -837,8 +837,6 @@ impl SemanticIndex {
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.projects.contains_key(&project.downgrade()) {
log::trace!("Registering Project for Semantic Index");
let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
this.project_worktrees_changed(project.clone(), cx);