semantic index eval, indexing appropriately
This commit is contained in:
parent
6f29582fb0
commit
eff44f9aa4
5 changed files with 168 additions and 46 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -6744,6 +6744,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"editor",
|
||||
|
@ -6757,6 +6758,7 @@ dependencies = [
|
|||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"node_runtime",
|
||||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"picker",
|
||||
|
@ -6766,6 +6768,7 @@ dependencies = [
|
|||
"rand 0.8.5",
|
||||
"rpc",
|
||||
"rusqlite",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -6788,6 +6791,7 @@ dependencies = [
|
|||
"unindent",
|
||||
"util",
|
||||
"workspace",
|
||||
"zed",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
@ -51,6 +51,10 @@ rpc = { path = "../rpc", features = ["test-support"] }
|
|||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
settings = { path = "../settings", features = ["test-support"]}
|
||||
git2 = { version = "0.15"}
|
||||
rust-embed = { version = "8.0", features = ["include-exclude"] }
|
||||
client = { path = "../client" }
|
||||
zed = { path = "../zed"}
|
||||
node_runtime = { path = "../node_runtime"}
|
||||
|
||||
pretty_assertions.workspace = true
|
||||
rand.workspace = true
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
{
|
||||
"query": "generate tags based on config",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:261",
|
||||
"tags/src/lib.rs:261"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -54,13 +54,13 @@
|
|||
{
|
||||
"query": "Match based on associativity of actions",
|
||||
"matches": [
|
||||
"cri/src/generate/build_tables/build_parse_table.rs:542",
|
||||
"cri/src/generate/build_tables/build_parse_table.rs:542"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Format token set display",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/item.rs:246",
|
||||
"cli/src/generate/build_tables/item.rs:246"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -1,8 +1,46 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use client::{self, UserStore};
|
||||
use git2::{Object, Oid, Repository};
|
||||
use semantic_index::SearchResult;
|
||||
use gpui::{AppContext, AssetSource, ModelHandle, Task};
|
||||
use language::LanguageRegistry;
|
||||
use node_runtime::RealNodeRuntime;
|
||||
use project::{Fs, Project, RealFs};
|
||||
use rust_embed::RustEmbed;
|
||||
use semantic_index::embedding::OpenAIEmbeddings;
|
||||
use semantic_index::semantic_index_settings::SemanticIndexSettings;
|
||||
use semantic_index::{SearchResult, SemanticIndex};
|
||||
use serde::Deserialize;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{env, fs};
|
||||
use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
|
||||
use std::path::{self, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{cmp, env, fs};
|
||||
use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
|
||||
use util::http::{self, HttpClient};
|
||||
use util::paths::{self, EMBEDDINGS_DIR};
|
||||
use zed::languages;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "../../assets"]
|
||||
#[include = "fonts/**/*"]
|
||||
#[include = "icons/**/*"]
|
||||
#[include = "themes/**/*"]
|
||||
#[include = "sounds/**/*"]
|
||||
#[include = "*.md"]
|
||||
#[exclude = "*.DS_Store"]
|
||||
pub struct Assets;
|
||||
|
||||
impl AssetSource for Assets {
|
||||
fn load(&self, path: &str) -> Result<std::borrow::Cow<[u8]>> {
|
||||
Self::get(path)
|
||||
.map(|f| f.data)
|
||||
.ok_or_else(|| anyhow!("could not find asset at path \"{}\"", path))
|
||||
}
|
||||
|
||||
fn list(&self, path: &str) -> Vec<std::borrow::Cow<'static, str>> {
|
||||
Self::iter().filter(|p| p.starts_with(path)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
struct EvaluationQuery {
|
||||
|
@ -13,15 +51,18 @@ struct EvaluationQuery {
|
|||
impl EvaluationQuery {
|
||||
fn match_pairs(&self) -> Vec<(PathBuf, usize)> {
|
||||
let mut pairs = Vec::new();
|
||||
for match_identifier in self.matches {
|
||||
let match_parts = match_identifier.split(":");
|
||||
for match_identifier in self.matches.iter() {
|
||||
let mut match_parts = match_identifier.split(":");
|
||||
|
||||
if let Some(file_path) = match_parts.next() {
|
||||
if let Some(row_number) = match_parts.next() {
|
||||
pairs.push((PathBuf::from(file_path), from_str::<usize>(row_number)));
|
||||
pairs.push((
|
||||
PathBuf::from(file_path),
|
||||
row_number.parse::<usize>().unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
pairs
|
||||
}
|
||||
}
|
||||
|
@ -33,7 +74,7 @@ struct RepoEval {
|
|||
assertions: Vec<EvaluationQuery>,
|
||||
}
|
||||
|
||||
const TMP_REPO_PATH: &str = "./target/eval_repos";
|
||||
const TMP_REPO_PATH: &str = "eval_repos";
|
||||
|
||||
fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
|
||||
let eval_folder = env::current_dir()?
|
||||
|
@ -74,7 +115,12 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
|
|||
.unwrap()
|
||||
.to_owned()
|
||||
.replace(".git", "");
|
||||
let clone_path = Path::new(TMP_REPO_PATH).join(&repo_name).to_path_buf();
|
||||
|
||||
let clone_path = fs::canonicalize(env::current_dir()?)?
|
||||
.parent()
|
||||
.ok_or(anyhow!("path canonicalization failed"))?
|
||||
.join(TMP_REPO_PATH)
|
||||
.join(&repo_name);
|
||||
|
||||
// Delete Clone Path if already exists
|
||||
let _ = fs::remove_dir_all(&clone_path);
|
||||
|
@ -105,7 +151,6 @@ fn dcg(hits: Vec<usize>) -> f32 {
|
|||
}
|
||||
|
||||
fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
|
||||
|
||||
// NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
|
||||
// items returned by the search engine relative to the hypothetical ideal.
|
||||
// Relevance is represented as a series of booleans, in which each search result returned
|
||||
|
@ -125,47 +170,118 @@ fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>,
|
|||
// very high quality, whereas rank results quickly drop off after the first result.
|
||||
|
||||
let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
|
||||
let hits = vec![1];
|
||||
|
||||
return dcg(hits) / dcg(ideal);
|
||||
}
|
||||
|
||||
fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
|
||||
|
||||
}
|
||||
|
||||
fn evaluate_repo(repo_eval: RepoEval, clone_path: PathBuf) {
|
||||
|
||||
// Launch new repo as a new Zed workspace/project
|
||||
// Index the project
|
||||
// Search each eval_query
|
||||
// Calculate Statistics
|
||||
// fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {}
|
||||
|
||||
fn init_logger() {
|
||||
env_logger::init();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Launch new repo as a new Zed workspace/project
|
||||
let app = gpui::App::new(Assets).unwrap();
|
||||
let fs = Arc::new(RealFs);
|
||||
let http = http::client();
|
||||
let user_settings_file_rx =
|
||||
watch_config_file(app.background(), fs.clone(), paths::SETTINGS.clone());
|
||||
let http_client = http::client();
|
||||
init_logger();
|
||||
|
||||
// zed/main.rs
|
||||
// creating an app and running it, gives you the context.
|
||||
// create a project, find_or_create_local_worktree.
|
||||
app.run(move |cx| {
|
||||
cx.set_global(*RELEASE_CHANNEL);
|
||||
|
||||
if let Ok(repo_evals) = parse_eval() {
|
||||
for repo in repo_evals {
|
||||
let cloned = clone_repo(repo.clone());
|
||||
match cloned {
|
||||
Ok(clone_path) => {
|
||||
println!(
|
||||
"Cloned {:?} @ {:?} into {:?}",
|
||||
repo.repo, repo.commit, &clone_path
|
||||
);
|
||||
let client = client::Client::new(http.clone(), cx);
|
||||
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));
|
||||
|
||||
// Evaluate Repo
|
||||
evaluate_repo(repo, clone_path);
|
||||
// Initialize Settings
|
||||
let mut store = SettingsStore::default();
|
||||
store
|
||||
.set_default_settings(default_settings().as_ref(), cx)
|
||||
.unwrap();
|
||||
cx.set_global(store);
|
||||
handle_settings_file_changes(user_settings_file_rx, cx);
|
||||
|
||||
}
|
||||
Err(err) => {
|
||||
println!("Error Cloning: {:?}", err);
|
||||
// Initialize Languages
|
||||
let login_shell_env_loaded = Task::ready(());
|
||||
let mut languages = LanguageRegistry::new(login_shell_env_loaded);
|
||||
languages.set_executor(cx.background().clone());
|
||||
let languages = Arc::new(languages);
|
||||
|
||||
let node_runtime = RealNodeRuntime::new(http.clone());
|
||||
languages::init(languages.clone(), node_runtime.clone());
|
||||
|
||||
project::Project::init(&client, cx);
|
||||
semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
|
||||
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
|
||||
let db_file_path = EMBEDDINGS_DIR
|
||||
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
|
||||
.join("embeddings_db");
|
||||
|
||||
let languages = languages.clone();
|
||||
let fs = fs.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let semantic_index = SemanticIndex::new(
|
||||
fs.clone(),
|
||||
db_file_path,
|
||||
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
|
||||
languages.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Ok(repo_evals) = parse_eval() {
|
||||
for repo in repo_evals {
|
||||
let cloned = clone_repo(repo.clone());
|
||||
match cloned {
|
||||
Ok(clone_path) => {
|
||||
log::trace!(
|
||||
"Cloned {:?} @ {:?} into {:?}",
|
||||
repo.repo,
|
||||
repo.commit,
|
||||
&clone_path
|
||||
);
|
||||
|
||||
// Create Project
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
client.clone(),
|
||||
user_store.clone(),
|
||||
languages.clone(),
|
||||
fs.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Register Worktree
|
||||
let _ = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
println!(
|
||||
"Creating worktree in project: {:?}",
|
||||
clone_path.clone()
|
||||
);
|
||||
project.find_or_create_local_worktree(clone_path, true, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
let _ = semantic_index
|
||||
.update(&mut cx, |index, cx| index.index_project(project, cx))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
log::trace!("Error cloning: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mod db;
|
||||
mod embedding;
|
||||
pub mod embedding;
|
||||
mod embedding_queue;
|
||||
mod parsing;
|
||||
pub mod semantic_index_settings;
|
||||
|
@ -301,7 +301,7 @@ impl SemanticIndex {
|
|||
}
|
||||
}
|
||||
|
||||
async fn new(
|
||||
pub async fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
database_path: PathBuf,
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
|
@ -837,8 +837,6 @@ impl SemanticIndex {
|
|||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if !self.projects.contains_key(&project.downgrade()) {
|
||||
log::trace!("Registering Project for Semantic Index");
|
||||
|
||||
let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
|
||||
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
|
||||
this.project_worktrees_changed(project.clone(), cx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue