add recall and precision to semantic index
This commit is contained in:
parent
566bb9f71b
commit
25bd357426
2 changed files with 179 additions and 35 deletions
|
@ -1,5 +1,5 @@
|
||||||
{
|
{
|
||||||
"repo": "https://github.com/AntonOsika/gpt_engineer.git",
|
"repo": "https://github.com/AntonOsika/gpt-engineer.git",
|
||||||
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
|
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
|
||||||
"assertions": [
|
"assertions": [
|
||||||
{
|
{
|
||||||
|
|
|
@ -10,7 +10,7 @@ use rust_embed::RustEmbed;
|
||||||
use semantic_index::embedding::OpenAIEmbeddings;
|
use semantic_index::embedding::OpenAIEmbeddings;
|
||||||
use semantic_index::semantic_index_settings::SemanticIndexSettings;
|
use semantic_index::semantic_index_settings::SemanticIndexSettings;
|
||||||
use semantic_index::{SearchResult, SemanticIndex};
|
use semantic_index::{SearchResult, SemanticIndex};
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
|
use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -43,7 +43,7 @@ impl AssetSource for Assets {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Clone)]
|
#[derive(Deserialize, Clone, Serialize)]
|
||||||
struct EvaluationQuery {
|
struct EvaluationQuery {
|
||||||
query: String,
|
query: String,
|
||||||
matches: Vec<String>,
|
matches: Vec<String>,
|
||||||
|
@ -72,15 +72,6 @@ struct RepoEval {
|
||||||
assertions: Vec<EvaluationQuery>,
|
assertions: Vec<EvaluationQuery>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct EvaluationResults {
|
|
||||||
token_count: usize,
|
|
||||||
span_count: usize,
|
|
||||||
time_to_index: Duration,
|
|
||||||
time_to_search: Vec<Duration>,
|
|
||||||
ndcg: HashMap<usize, f32>,
|
|
||||||
map: HashMap<usize, f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
const TMP_REPO_PATH: &str = "eval_repos";
|
const TMP_REPO_PATH: &str = "eval_repos";
|
||||||
|
|
||||||
fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
|
fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
|
||||||
|
@ -114,7 +105,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
|
||||||
Ok(repo_evals)
|
Ok(repo_evals)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
|
fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<(String, PathBuf)> {
|
||||||
let repo_name = Path::new(repo_eval.repo.as_str())
|
let repo_name = Path::new(repo_eval.repo.as_str())
|
||||||
.file_name()
|
.file_name()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -146,7 +137,7 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
|
||||||
repo.checkout_tree(&obj, None)?;
|
repo.checkout_tree(&obj, None)?;
|
||||||
repo.set_head_detached(obj.id())?;
|
repo.set_head_detached(obj.id())?;
|
||||||
|
|
||||||
Ok(clone_path)
|
Ok((repo_name, clone_path))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dcg(hits: Vec<usize>) -> f32 {
|
fn dcg(hits: Vec<usize>) -> f32 {
|
||||||
|
@ -253,30 +244,165 @@ fn evaluate_map(hits: Vec<usize>) -> Vec<f32> {
|
||||||
let mut rolling_map = 0.0;
|
let mut rolling_map = 0.0;
|
||||||
for (idx, h) in hits.into_iter().enumerate() {
|
for (idx, h) in hits.into_iter().enumerate() {
|
||||||
rolling_non_zero += h as f32;
|
rolling_non_zero += h as f32;
|
||||||
rolling_map += rolling_non_zero / (idx + 1) as f32;
|
if h == 1 {
|
||||||
|
rolling_map += rolling_non_zero / (idx + 1) as f32;
|
||||||
|
}
|
||||||
map_at_k.push(rolling_map / non_zero);
|
map_at_k.push(rolling_map / non_zero);
|
||||||
}
|
}
|
||||||
|
|
||||||
map_at_k
|
map_at_k
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the reciprocal rank of the first relevant result.
///
/// `hits` is the ranked list of hit flags, where an entry equal to `1` marks
/// a relevant result. Returns `1 / rank` (with a 1-based rank) of the first
/// such entry, or `0.0` when `hits` is empty or contains no `1`.
fn evaluate_mrr(hits: Vec<usize>) -> f32 {
    hits.iter()
        .position(|&hit| hit == 1)
        // `position` is 0-based; the reciprocal rank uses a 1-based rank.
        .map_or(0.0, |idx| 1.0 / (idx + 1) as f32)
}
|
||||||
|
|
||||||
fn init_logger() {
|
fn init_logger() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Metrics recorded for one evaluation query; one value is pushed onto
/// `RepoEvaluationMetrics::query_metrics` per query in `evaluate_repo`.
#[derive(Serialize)]
|
||||||
|
struct QueryMetrics {
|
||||||
|
/// The query (and its expected matches) that was evaluated.
query: EvaluationQuery,
|
||||||
|
/// Elapsed time of the search call. NOTE(review): despite the name this is
/// a `Duration`, not a millisecond count; `summarize` converts it with
/// `as_millis()`.
millis_to_search: Duration,
|
||||||
|
/// Values returned by `evaluate_ndcg(hits, ideal)` for this query.
ndcg: Vec<f32>,
|
||||||
|
/// Values returned by `evaluate_map(hits)` for this query.
map: Vec<f32>,
|
||||||
|
/// Reciprocal rank returned by `evaluate_mrr(hits)`.
mrr: f32,
|
||||||
|
/// Hit flags produced by `get_hits` (called with k = 10); the metric
/// functions treat `1` as a relevant result.
/// Presumably one flag per ranked search result -- confirm against `get_hits`.
hits: Vec<usize>,
|
||||||
|
/// Running precision after each entry of `hits` (`evaluate_precision`).
precision: Vec<f32>,
|
||||||
|
/// Running recall after each entry of `hits` (`evaluate_recall`).
recall: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-repository averages over all evaluated queries, built by
/// `RepoEvaluationMetrics::summarize`.
#[derive(Serialize)]
|
||||||
|
struct SummaryMetrics {
|
||||||
|
/// Mean search time in milliseconds (sum of `as_millis()` over the queries
/// divided by the number of queries).
millis_to_search: f32,
|
||||||
|
/// Element-wise mean of the per-query `ndcg` vectors (10 entries).
ndcg: Vec<f32>,
|
||||||
|
/// Element-wise mean of the per-query `map` vectors (10 entries).
map: Vec<f32>,
|
||||||
|
/// Mean of the per-query reciprocal ranks.
mrr: f32,
|
||||||
|
/// Element-wise mean of the per-query `precision` vectors (10 entries).
precision: Vec<f32>,
|
||||||
|
/// Element-wise mean of the per-query `recall` vectors (10 entries).
recall: Vec<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// All evaluation results for one repository; serialized to JSON by `save`.
#[derive(Serialize)]
|
||||||
|
struct RepoEvaluationMetrics {
|
||||||
|
/// Time the value passed to `new` measured; `evaluate_repo` passes the
/// elapsed time of indexing the project. NOTE(review): despite the name
/// this is a `Duration`, not a millisecond count.
millis_to_index: Duration,
|
||||||
|
/// One entry per evaluated query, in evaluation order.
query_metrics: Vec<QueryMetrics>,
|
||||||
|
/// Averages over `query_metrics`; `None` until `summarize` has been called.
repo_metrics: Option<SummaryMetrics>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RepoEvaluationMetrics {
|
||||||
|
fn new(millis_to_index: Duration) -> Self {
|
||||||
|
RepoEvaluationMetrics {
|
||||||
|
millis_to_index,
|
||||||
|
query_metrics: Vec::new(),
|
||||||
|
repo_metrics: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save(&self, repo_name: String) -> Result<()> {
|
||||||
|
let results_string = serde_json::to_string(&self)?;
|
||||||
|
fs::write(format!("./{}_evaluation.json", repo_name), results_string)
|
||||||
|
.expect("Unable to write file");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn summarize(&mut self) {
|
||||||
|
let l = self.query_metrics.len() as f32;
|
||||||
|
let millis_to_search: f32 = self
|
||||||
|
.query_metrics
|
||||||
|
.iter()
|
||||||
|
.map(|metrics| metrics.millis_to_search.as_millis())
|
||||||
|
.sum::<u128>() as f32
|
||||||
|
/ l;
|
||||||
|
|
||||||
|
let mut ndcg_sum = vec![0.0; 10];
|
||||||
|
let mut map_sum = vec![0.0; 10];
|
||||||
|
let mut precision_sum = vec![0.0; 10];
|
||||||
|
let mut recall_sum = vec![0.0; 10];
|
||||||
|
let mut mmr_sum = 0.0;
|
||||||
|
|
||||||
|
for query_metric in self.query_metrics.iter() {
|
||||||
|
for (ndcg, query_ndcg) in ndcg_sum.iter_mut().zip(query_metric.ndcg.clone()) {
|
||||||
|
*ndcg += query_ndcg;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (mapp, query_map) in map_sum.iter_mut().zip(query_metric.map.clone()) {
|
||||||
|
*mapp += query_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (pre, query_pre) in precision_sum.iter_mut().zip(query_metric.precision.clone()) {
|
||||||
|
*pre += query_pre;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (rec, query_rec) in recall_sum.iter_mut().zip(query_metric.recall.clone()) {
|
||||||
|
*rec += query_rec;
|
||||||
|
}
|
||||||
|
|
||||||
|
mmr_sum += query_metric.mrr;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndcg = ndcg_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
|
||||||
|
let map = map_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
|
||||||
|
let precision = precision_sum
|
||||||
|
.iter()
|
||||||
|
.map(|val| val / l)
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
|
let recall = recall_sum.iter().map(|val| val / l).collect::<Vec<f32>>();
|
||||||
|
let mrr = mmr_sum / l;
|
||||||
|
|
||||||
|
self.repo_metrics = Some(SummaryMetrics {
|
||||||
|
millis_to_search,
|
||||||
|
ndcg,
|
||||||
|
map,
|
||||||
|
mrr,
|
||||||
|
precision,
|
||||||
|
recall,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the running precision over a ranked list of hit flags.
///
/// Entry `i` of the result is the sum of `hits[0..=i]` divided by `i + 1`,
/// so the output has the same length as `hits`.
fn evaluate_precision(hits: Vec<usize>) -> Vec<f32> {
    let mut relevant_so_far: f32 = 0.0;
    hits.into_iter()
        .enumerate()
        .map(|(rank, hit)| {
            relevant_so_far += hit as f32;
            relevant_so_far / ((rank as f32) + 1.0)
        })
        .collect()
}
|
||||||
|
|
||||||
|
/// Computes the running recall over a ranked list of hit flags.
///
/// `ideal` describes the relevant items for the query; its sum is used as the
/// total number of relevant items. Entry `i` of the result is the sum of
/// `hits[0..=i]` divided by that total, so the output has the same length as
/// `hits`.
///
/// When `ideal` contains no relevant item the recall is undefined; `0.0` is
/// reported for every position instead of the NaN/inf a division by zero
/// would yield, so downstream averages and the JSON output stay finite.
fn evaluate_recall(hits: Vec<usize>, ideal: Vec<usize>) -> Vec<f32> {
    let total_relevant = ideal.iter().sum::<usize>() as f32;
    if total_relevant == 0.0 {
        return vec![0.0; hits.len()];
    }

    let mut relevant_so_far: f32 = 0.0;
    hits.into_iter()
        .map(|hit| {
            relevant_so_far += hit as f32;
            relevant_so_far / total_relevant
        })
        .collect()
}
|
||||||
|
|
||||||
async fn evaluate_repo(
|
async fn evaluate_repo(
|
||||||
|
repo_name: String,
|
||||||
index: ModelHandle<SemanticIndex>,
|
index: ModelHandle<SemanticIndex>,
|
||||||
project: ModelHandle<Project>,
|
project: ModelHandle<Project>,
|
||||||
query_matches: Vec<EvaluationQuery>,
|
query_matches: Vec<EvaluationQuery>,
|
||||||
cx: &mut AsyncAppContext,
|
cx: &mut AsyncAppContext,
|
||||||
) -> Result<()> {
|
) -> Result<RepoEvaluationMetrics> {
|
||||||
// Index Project
|
// Index Project
|
||||||
let index_t0 = Instant::now();
|
let index_t0 = Instant::now();
|
||||||
index
|
index
|
||||||
.update(cx, |index, cx| index.index_project(project.clone(), cx))
|
.update(cx, |index, cx| index.index_project(project.clone(), cx))
|
||||||
.await?;
|
.await?;
|
||||||
let index_time = index_t0.elapsed();
|
let mut repo_metrics = RepoEvaluationMetrics::new(index_t0.elapsed());
|
||||||
println!("Time to Index: {:?}", index_time.as_millis());
|
|
||||||
|
|
||||||
for query in query_matches {
|
for query in query_matches {
|
||||||
// Query each match in order
|
// Query each match in order
|
||||||
|
@ -286,26 +412,45 @@ async fn evaluate_repo(
|
||||||
index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
|
index.search_project(project.clone(), query.clone().query, 10, vec![], vec![], cx)
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
let search_time = search_t0.elapsed();
|
let millis_to_search = search_t0.elapsed();
|
||||||
println!("Time to Search: {:?}", search_time.as_millis());
|
|
||||||
|
|
||||||
// Get Hits/Ideal
|
// Get Hits/Ideal
|
||||||
let k = 10;
|
let k = 10;
|
||||||
let (ideal, hits) = self::get_hits(query, search_results, k, cx);
|
let (ideal, hits) = self::get_hits(query.clone(), search_results, k, cx);
|
||||||
|
|
||||||
// Evaluate ndcg@k, for k = 1, 3, 5, 10
|
// Evaluate ndcg@k, for k = 1, 3, 5, 10
|
||||||
let ndcg = evaluate_ndcg(hits.clone(), ideal);
|
let ndcg = evaluate_ndcg(hits.clone(), ideal.clone());
|
||||||
println!("NDCG: {:?}", ndcg);
|
|
||||||
|
|
||||||
// Evaluate map@k, for k = 1, 3, 5, 10
|
// Evaluate map@k, for k = 1, 3, 5, 10
|
||||||
let map = evaluate_map(hits);
|
let map = evaluate_map(hits.clone());
|
||||||
println!("MAP: {:?}", map);
|
|
||||||
|
|
||||||
// Evaluate span count
|
// Evaluate mrr
|
||||||
// Evaluate token count
|
let mrr = evaluate_mrr(hits.clone());
|
||||||
|
|
||||||
|
// Evaluate precision
|
||||||
|
let precision = evaluate_precision(hits.clone());
|
||||||
|
|
||||||
|
// Evaluate Recall
|
||||||
|
let recall = evaluate_recall(hits.clone(), ideal);
|
||||||
|
|
||||||
|
let query_metrics = QueryMetrics {
|
||||||
|
query,
|
||||||
|
millis_to_search,
|
||||||
|
ndcg,
|
||||||
|
map,
|
||||||
|
mrr,
|
||||||
|
hits,
|
||||||
|
precision,
|
||||||
|
recall,
|
||||||
|
};
|
||||||
|
|
||||||
|
repo_metrics.query_metrics.push(query_metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::Ok(())
|
repo_metrics.summarize();
|
||||||
|
repo_metrics.save(repo_name);
|
||||||
|
|
||||||
|
anyhow::Ok(repo_metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
@ -367,12 +512,10 @@ fn main() {
|
||||||
for repo in repo_evals {
|
for repo in repo_evals {
|
||||||
let cloned = clone_repo(repo.clone());
|
let cloned = clone_repo(repo.clone());
|
||||||
match cloned {
|
match cloned {
|
||||||
Ok(clone_path) => {
|
Ok((repo_name, clone_path)) => {
|
||||||
log::trace!(
|
println!(
|
||||||
"Cloned {:?} @ {:?} into {:?}",
|
"Cloned {:?} @ {:?} into {:?}",
|
||||||
repo.repo,
|
repo.repo, repo.commit, &clone_path
|
||||||
repo.commit,
|
|
||||||
&clone_path
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create Project
|
// Create Project
|
||||||
|
@ -393,7 +536,8 @@ fn main() {
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
evaluate_repo(
|
let repo_metrics = evaluate_repo(
|
||||||
|
repo_name,
|
||||||
semantic_index.clone(),
|
semantic_index.clone(),
|
||||||
project,
|
project,
|
||||||
repo.assertions,
|
repo.assertions,
|
||||||
|
@ -402,7 +546,7 @@ fn main() {
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::trace!("Error cloning: {:?}", err);
|
println!("Error cloning: {:?}", err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue