add recall and precision to semantic index

This commit is contained in:
KCaverly 2023-09-18 18:25:02 -04:00
parent 566bb9f71b
commit 25bd357426
2 changed files with 179 additions and 35 deletions

View file

@ -1,5 +1,5 @@
{ {
"repo": "https://github.com/AntonOsika/gpt_engineer.git", "repo": "https://github.com/AntonOsika/gpt-engineer.git",
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2", "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
"assertions": [ "assertions": [
{ {

View file

@ -10,7 +10,7 @@ use rust_embed::RustEmbed;
use semantic_index::embedding::OpenAIEmbeddings; use semantic_index::embedding::OpenAIEmbeddings;
use semantic_index::semantic_index_settings::SemanticIndexSettings; use semantic_index::semantic_index_settings::SemanticIndexSettings;
use semantic_index::{SearchResult, SemanticIndex}; use semantic_index::{SearchResult, SemanticIndex};
use serde::Deserialize; use serde::{Deserialize, Serialize};
use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore}; use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
@ -43,7 +43,7 @@ impl AssetSource for Assets {
} }
} }
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone, Serialize)]
struct EvaluationQuery { struct EvaluationQuery {
query: String, query: String,
matches: Vec<String>, matches: Vec<String>,
@ -72,15 +72,6 @@ struct RepoEval {
assertions: Vec<EvaluationQuery>, assertions: Vec<EvaluationQuery>,
} }
struct EvaluationResults {
token_count: usize,
span_count: usize,
time_to_index: Duration,
time_to_search: Vec<Duration>,
ndcg: HashMap<usize, f32>,
map: HashMap<usize, f32>,
}
const TMP_REPO_PATH: &str = "eval_repos"; const TMP_REPO_PATH: &str = "eval_repos";
fn parse_eval() -> anyhow::Result<Vec<RepoEval>> { fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
@ -114,7 +105,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
Ok(repo_evals) Ok(repo_evals)
} }
fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> { fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<(String, PathBuf)> {
let repo_name = Path::new(repo_eval.repo.as_str()) let repo_name = Path::new(repo_eval.repo.as_str())
.file_name() .file_name()
.unwrap() .unwrap()
@ -146,7 +137,7 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
repo.checkout_tree(&obj, None)?; repo.checkout_tree(&obj, None)?;
repo.set_head_detached(obj.id())?; repo.set_head_detached(obj.id())?;
Ok(clone_path) Ok((repo_name, clone_path))
} }
fn dcg(hits: Vec<usize>) -> f32 { fn dcg(hits: Vec<usize>) -> f32 {
@ -253,30 +244,165 @@ fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
let mut rolling_map = 0.0; let mut rolling_map = 0.0;
for (idx, h) in hits.into_iter().enumerate() { for (idx, h) in hits.into_iter().enumerate() {
rolling_non_zero += h as f32; rolling_non_zero += h as f32;
rolling_map += rolling_non_zero / (idx + 1) as f32; if h == 1 {
rolling_map += rolling_non_zero / (idx + 1) as f32;
}
map_at_k.push(rolling_map / non_zero); map_at_k.push(rolling_map / non_zero);
} }
map_at_k map_at_k
} }
fn evaluate_mrr(hits: Vec<usize>) -> f32 {
for (idx, h) in hits.into_iter().enumerate() {
if h == 1 {
return 1.0 / (idx + 1) as f32;
}
}
return 0.0;
}
fn init_logger() { fn init_logger() {
env_logger::init(); env_logger::init();
} }
#[derive(Serialize)]
struct QueryMetrics {
query: EvaluationQuery,
millis_to_search: Duration,
ndcg: Vec<f32>,
map: Vec<f32>,
mrr: f32,
hits: Vec<usize>,
precision: Vec<f32>,
recall: Vec<f32>,
}
#[derive(Serialize)]
struct SummaryMetrics {
millis_to_search: f32,
ndcg: Vec<f32>,
map: Vec<f32>,
mrr: f32,
precision: Vec<f32>,
recall: Vec<f32>,
}
#[derive(Serialize)]
struct RepoEvaluationMetrics {
millis_to_index: Duration,
query_metrics: Vec<QueryMetrics>,
repo_metrics: Option<SummaryMetrics>,
}
impl RepoEvaluationMetrics {
fn new(millis_to_index: Duration) -> Self {
RepoEvaluationMetrics {
millis_to_index,
query_metrics: Vec::new(),
repo_metrics: None,
}
}
fn save(&self, repo_name: String) -> Result<()> {
let results_string = serde_json::to_string(&self)?;
fs::write(format!("./{}_evaluation.json", repo_name), results_string)
.expect("Unable to write file");
Ok(())
}
fn summarize(&mut self) {
let l = self.query_metrics.len() as f32;
let millis_to_search: f32 = self
.query_metrics
.iter()
.map(|metrics| metrics.millis_to_search.as_millis())
.sum::<u128>() as f32
/ l;
let mut ndcg_sum = vec![0.0; 10];
let mut map_sum = vec![0.0; 10];
let mut precision_sum = vec![0.0; 10];
let mut recall_sum = vec![0.0; 10];
let mut mmr_sum = 0.0;
for query_metric in self.query_metrics.iter() {
for (ndcg, query_ndcg) in ndcg_sum.iter_mut().zip(query_metric.ndcg.clone()) {
*ndcg += query_ndcg;
}
for (mapp, query_map) in map_sum.iter_mut().zip(query_metric.map.clone()) {
*mapp += query_map;
}
for (pre, query_pre) in precision_sum.iter_mut().zip(query_metric.precision.clone()) {
*pre += query_pre;
}
for (rec, query_rec) in recall_sum.iter_mut().zip(query_metric.recall.clone()) {
*rec += query_rec;
}
mmr_sum += query_metric.mrr;
}
let ndcg = ndcg_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
let map = map_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
let precision = precision_sum
.iter()
.map(|val| val / l)
.collect::<Vec<f32>>();
let recall = recall_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
let mrr = mmr_sum / l;
self.repo_metrics = Some(SummaryMetrics {
millis_to_search,
ndcg,
map,
mrr,
precision,
recall,
})
}
}
fn evaluate_precision(hits: Vec<usize>) -> Vec<f32> {
let mut rolling_hit: f32 = 0.0;
let mut precision = Vec::new();
for (idx, hit) in hits.into_iter().enumerate() {
rolling_hit += hit as f32;
precision.push(rolling_hit / ((idx as f32) + 1.0));
}
precision
}
fn evaluate_recall(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
let total_relevant = ideal.iter().sum::<usize>() as f32;
let mut recall = Vec::new();
let mut rolling_hit: f32 = 0.0;
for hit in hits {
rolling_hit += hit as f32;
recall.push(rolling_hit / total_relevant);
}
recall
}
async fn evaluate_repo( async fn evaluate_repo(
repo_name: String,
index: ModelHandle<SemanticIndex>, index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>, project: ModelHandle<Project>,
query_matches: Vec<EvaluationQuery>, query_matches: Vec<EvaluationQuery>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<()> { ) -> Result<RepoEvaluationMetrics> {
// Index Project // Index Project
let index_t0 = Instant::now(); let index_t0 = Instant::now();
index index
.update(cx, |index, cx| index.index_project(project.clone(), cx)) .update(cx, |index, cx| index.index_project(project.clone(), cx))
.await?; .await?;
let index_time = index_t0.elapsed(); let mut repo_metrics = RepoEvaluationMetrics::new(index_t0.elapsed());
println!("Time to Index: {:?}", index_time.as_millis());
for query in query_matches { for query in query_matches {
// Query each match in order // Query each match in order
@ -286,26 +412,45 @@ async fn evaluate_repo(
index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx) index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
}) })
.await?; .await?;
let search_time = search_t0.elapsed(); let millis_to_search = search_t0.elapsed();
println!("Time to Search: {:?}", search_time.as_millis());
// Get Hits/Ideal // Get Hits/Ideal
let k = 10; let k = 10;
let (ideal, hits) = self::get_hits(query, search_results, k, cx); let (ideal, hits) = self::get_hits(query.clone(), search_results, k, cx);
// Evaluate ndcg@k, for k = 1, 3, 5, 10 // Evaluate ndcg@k, for k = 1, 3, 5, 10
let ndcg = evaluate_ndcg(hits.clone(), ideal); let ndcg = evaluate_ndcg(hits.clone(), ideal.clone());
println!("NDCG: {:?}", ndcg);
// Evaluate map@k, for k = 1, 3, 5, 10 // Evaluate map@k, for k = 1, 3, 5, 10
let map = evaluate_map(hits); let map = evaluate_map(hits.clone());
println!("MAP: {:?}", map);
// Evaluate span count // Evaluate mrr
// Evaluate token count let mrr = evaluate_mrr(hits.clone());
// Evaluate precision
let precision = evaluate_precision(hits.clone());
// Evaluate Recall
let recall = evaluate_recall(hits.clone(), ideal);
let query_metrics = QueryMetrics {
query,
millis_to_search,
ndcg,
map,
mrr,
hits,
precision,
recall,
};
repo_metrics.query_metrics.push(query_metrics);
} }
anyhow::Ok(()) repo_metrics.summarize();
repo_metrics.save(repo_name);
anyhow::Ok(repo_metrics)
} }
fn main() { fn main() {
@ -367,12 +512,10 @@ fn main() {
for repo in repo_evals { for repo in repo_evals {
let cloned = clone_repo(repo.clone()); let cloned = clone_repo(repo.clone());
match cloned { match cloned {
Ok(clone_path) => { Ok((repo_name, clone_path)) => {
log::trace!( println!(
"Cloned {:?} @ {:?} into {:?}", "Cloned {:?} @ {:?} into {:?}",
repo.repo, repo.repo, repo.commit, &clone_path
repo.commit,
&clone_path
); );
// Create Project // Create Project
@ -393,7 +536,8 @@ fn main() {
}) })
.await; .await;
evaluate_repo( let repo_metrics = evaluate_repo(
repo_name,
semantic_index.clone(), semantic_index.clone(),
project, project,
repo.assertions, repo.assertions,
@ -402,7 +546,7 @@ fn main() {
.await?; .await?;
} }
Err(err) => { Err(err) => {
log::trace!("Error cloning: {:?}", err); println!("Error cloning: {:?}", err);
} }
} }
} }