Make scoring more precise by using floats when diffing AI refactors

This commit is contained in:
Antonio Scandurra 2023-08-24 12:45:44 +02:00
parent 481bcbf204
commit 9674b03855
3 changed files with 19 additions and 13 deletions

1
Cargo.lock generated
View file

@@ -113,6 +113,7 @@ dependencies = [
"language",
"log",
"menu",
"ordered-float",
"project",
"rand 0.8.5",
"regex",

View file

@@ -26,6 +26,7 @@ chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
indoc.workspace = true
isahc.workspace = true
ordered-float.workspace = true
regex.workspace = true
schemars.workspace = true
serde.workspace = true

View file

@@ -1,11 +1,13 @@
use collections::HashMap;
use ordered_float::OrderedFloat;
use std::{
cmp,
fmt::{self, Debug},
ops::Range,
};
/// Two-dimensional grid of scores kept in a single flat vector.
///
/// A cell is addressed as `col * rows + row`, so the entries of one column
/// sit next to each other in `cells`. `rows` and `cols` record the current
/// logical dimensions of the grid.
///
/// This diff hunk shows both sides of the change: the element type of
/// `cells` goes from `isize` to `f64`.
struct Matrix {
cells: Vec<isize>,
cells: Vec<f64>,
rows: usize,
cols: usize,
}
@@ -20,12 +22,12 @@ impl Matrix {
}
fn resize(&mut self, rows: usize, cols: usize) {
self.cells.resize(rows * cols, 0);
self.cells.resize(rows * cols, 0.);
self.rows = rows;
self.cols = cols;
}
fn get(&self, row: usize, col: usize) -> isize {
fn get(&self, row: usize, col: usize) -> f64 {
if row >= self.rows {
panic!("row out of bounds")
}
@@ -36,7 +38,7 @@ impl Matrix {
self.cells[col * self.rows + row]
}
fn set(&mut self, row: usize, col: usize, value: isize) {
fn set(&mut self, row: usize, col: usize, value: f64) {
if row >= self.rows {
panic!("row out of bounds")
}
@@ -79,16 +81,17 @@ pub struct Diff {
}
impl Diff {
const INSERTION_SCORE: isize = -1;
const DELETION_SCORE: isize = -5;
const EQUALITY_BASE: isize = 2;
const INSERTION_SCORE: f64 = -1.;
const DELETION_SCORE: f64 = -5.;
const EQUALITY_BASE: f64 = 1.618;
const MAX_EQUALITY_EXPONENT: i32 = 32;
pub fn new(old: String) -> Self {
let old = old.chars().collect::<Vec<_>>();
let mut scores = Matrix::new();
scores.resize(old.len() + 1, 1);
for i in 0..=old.len() {
scores.set(i, 0, i as isize * Self::DELETION_SCORE);
scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
}
Self {
old,
@@ -105,7 +108,7 @@ impl Diff {
self.scores.resize(self.old.len() + 1, self.new.len() + 1);
for j in self.new_text_ix + 1..=self.new.len() {
self.scores.set(0, j, j as isize * Self::INSERTION_SCORE);
self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
for i in 1..=self.old.len() {
let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
@@ -117,10 +120,11 @@ impl Diff {
if self.old[i - 1] == ' ' {
self.scores.get(i - 1, j - 1)
} else {
self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.pow(equal_run / 3)
let exponent = cmp::min(equal_run as i32 / 3, Self::MAX_EQUALITY_EXPONENT);
self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
}
} else {
isize::MIN
f64::NEG_INFINITY
};
let score = insertion_score.max(deletion_score).max(equality_score);
@@ -128,7 +132,7 @@ impl Diff {
}
}
let mut max_score = isize::MIN;
let mut max_score = f64::NEG_INFINITY;
let mut next_old_text_ix = self.old_text_ix;
let next_new_text_ix = self.new.len();
for i in self.old_text_ix..=self.old.len() {
@@ -173,7 +177,7 @@ impl Diff {
let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
.iter()
.max_by_key(|cell| cell.map(|(i, j)| self.scores.get(i, j)))
.max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
.unwrap()
.unwrap();