Allow seeking SumTree cursor by multiple seek types per dimension

Also, remove the cursor's sum_dimension. Replace it with a
blanket implementation of Dimension for two-element tuples
of dimensions.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Max Brunsfeld 2021-09-24 18:04:43 -07:00
parent 39fbf7d4d1
commit ab31ddfc31
11 changed files with 458 additions and 509 deletions

View file

@ -3,51 +3,31 @@ use arrayvec::ArrayVec;
use std::{cmp::Ordering, sync::Arc};
#[derive(Clone)]
struct StackEntry<'a, T: Item, S, U> {
struct StackEntry<'a, T: Item, D> {
tree: &'a SumTree<T>,
index: usize,
seek_dimension: S,
sum_dimension: U,
}
impl<'a, T, S, U> StackEntry<'a, T, S, U>
where
T: Item,
S: SeekDimension<'a, T::Summary>,
U: SeekDimension<'a, T::Summary>,
{
fn swap_dimensions(self) -> StackEntry<'a, T, U, S> {
StackEntry {
tree: self.tree,
index: self.index,
seek_dimension: self.sum_dimension,
sum_dimension: self.seek_dimension,
}
}
position: D,
}
#[derive(Clone)]
pub struct Cursor<'a, T: Item, S, U> {
pub struct Cursor<'a, T: Item, D> {
tree: &'a SumTree<T>,
stack: ArrayVec<StackEntry<'a, T, S, U>, 16>,
seek_dimension: S,
sum_dimension: U,
stack: ArrayVec<StackEntry<'a, T, D>, 16>,
position: D,
did_seek: bool,
at_end: bool,
}
impl<'a, T, S, U> Cursor<'a, T, S, U>
impl<'a, T, D> Cursor<'a, T, D>
where
T: Item,
S: Dimension<'a, T::Summary>,
U: Dimension<'a, T::Summary>,
D: Dimension<'a, T::Summary>,
{
pub fn new(tree: &'a SumTree<T>) -> Self {
Self {
tree,
stack: ArrayVec::new(),
seek_dimension: S::default(),
sum_dimension: U::default(),
position: D::default(),
did_seek: false,
at_end: false,
}
@ -57,35 +37,20 @@ where
self.did_seek = false;
self.at_end = false;
self.stack.truncate(0);
self.seek_dimension = S::default();
self.sum_dimension = U::default();
self.position = D::default();
}
pub fn seek_start(&self) -> &S {
&self.seek_dimension
pub fn start(&self) -> &D {
&self.position
}
pub fn seek_end(&self, cx: &<T::Summary as Summary>::Context) -> S {
pub fn end(&self, cx: &<T::Summary as Summary>::Context) -> D {
if let Some(item_summary) = self.item_summary() {
let mut end = self.seek_start().clone();
let mut end = self.start().clone();
end.add_summary(item_summary, cx);
end
} else {
self.seek_start().clone()
}
}
pub fn sum_start(&self) -> &U {
&self.sum_dimension
}
pub fn sum_end(&self, cx: &<T::Summary as Summary>::Context) -> U {
if let Some(item_summary) = self.item_summary() {
let mut end = self.sum_start().clone();
end.add_summary(item_summary, cx);
end
} else {
self.sum_start().clone()
self.start().clone()
}
}
@ -167,8 +132,7 @@ where
assert!(self.did_seek, "Must seek before calling this method");
if self.at_end {
self.seek_dimension = S::default();
self.sum_dimension = U::default();
self.position = D::default();
self.descend_to_last_item(self.tree, cx);
self.at_end = false;
} else {
@ -176,17 +140,10 @@ where
if entry.index > 0 {
let new_index = entry.index - 1;
if let Some(StackEntry {
seek_dimension,
sum_dimension,
..
}) = self.stack.last()
{
self.seek_dimension = seek_dimension.clone();
self.sum_dimension = sum_dimension.clone();
if let Some(StackEntry { position, .. }) = self.stack.last() {
self.position = position.clone();
} else {
self.seek_dimension = S::default();
self.sum_dimension = U::default();
self.position = D::default();
}
match entry.tree.0.as_ref() {
@ -196,27 +153,23 @@ where
..
} => {
for summary in &child_summaries[0..new_index] {
self.seek_dimension.add_summary(summary, cx);
self.sum_dimension.add_summary(summary, cx);
self.position.add_summary(summary, cx);
}
self.stack.push(StackEntry {
tree: entry.tree,
index: new_index,
seek_dimension: self.seek_dimension.clone(),
sum_dimension: self.sum_dimension.clone(),
position: self.position.clone(),
});
self.descend_to_last_item(&child_trees[new_index], cx);
}
Node::Leaf { item_summaries, .. } => {
for item_summary in &item_summaries[0..new_index] {
self.seek_dimension.add_summary(item_summary, cx);
self.sum_dimension.add_summary(item_summary, cx);
self.position.add_summary(item_summary, cx);
}
self.stack.push(StackEntry {
tree: entry.tree,
index: new_index,
seek_dimension: self.seek_dimension.clone(),
sum_dimension: self.sum_dimension.clone(),
position: self.position.clone(),
});
}
}
@ -241,8 +194,7 @@ where
self.stack.push(StackEntry {
tree: self.tree,
index: 0,
seek_dimension: S::default(),
sum_dimension: U::default(),
position: D::default(),
});
descend = true;
self.did_seek = true;
@ -258,8 +210,7 @@ where
..
} => {
if !descend {
entry.seek_dimension = self.seek_dimension.clone();
entry.sum_dimension = self.sum_dimension.clone();
entry.position = self.position.clone();
entry.index += 1;
}
@ -268,8 +219,7 @@ where
if filter_node(next_summary) {
break;
} else {
self.seek_dimension.add_summary(next_summary, cx);
self.sum_dimension.add_summary(next_summary, cx);
self.position.add_summary(next_summary, cx);
}
entry.index += 1;
}
@ -279,10 +229,8 @@ where
Node::Leaf { item_summaries, .. } => {
if !descend {
let item_summary = &item_summaries[entry.index];
self.seek_dimension.add_summary(item_summary, cx);
entry.seek_dimension.add_summary(item_summary, cx);
self.sum_dimension.add_summary(item_summary, cx);
entry.sum_dimension.add_summary(item_summary, cx);
self.position.add_summary(item_summary, cx);
entry.position.add_summary(item_summary, cx);
entry.index += 1;
}
@ -291,10 +239,8 @@ where
if filter_node(next_item_summary) {
return;
} else {
self.seek_dimension.add_summary(next_item_summary, cx);
entry.seek_dimension.add_summary(next_item_summary, cx);
self.sum_dimension.add_summary(next_item_summary, cx);
entry.sum_dimension.add_summary(next_item_summary, cx);
self.position.add_summary(next_item_summary, cx);
entry.position.add_summary(next_item_summary, cx);
entry.index += 1;
}
} else {
@ -310,8 +256,7 @@ where
self.stack.push(StackEntry {
tree: subtree,
index: 0,
seek_dimension: self.seek_dimension.clone(),
sum_dimension: self.sum_dimension.clone(),
position: self.position.clone(),
});
} else {
descend = false;
@ -337,29 +282,25 @@ where
..
} => {
for summary in &child_summaries[0..child_summaries.len() - 1] {
self.seek_dimension.add_summary(summary, cx);
self.sum_dimension.add_summary(summary, cx);
self.position.add_summary(summary, cx);
}
self.stack.push(StackEntry {
tree: subtree,
index: child_trees.len() - 1,
seek_dimension: self.seek_dimension.clone(),
sum_dimension: self.sum_dimension.clone(),
position: self.position.clone(),
});
subtree = child_trees.last().unwrap();
}
Node::Leaf { item_summaries, .. } => {
let last_index = item_summaries.len().saturating_sub(1);
for item_summary in &item_summaries[0..last_index] {
self.seek_dimension.add_summary(item_summary, cx);
self.sum_dimension.add_summary(item_summary, cx);
self.position.add_summary(item_summary, cx);
}
self.stack.push(StackEntry {
tree: subtree,
index: last_index,
seek_dimension: self.seek_dimension.clone(),
sum_dimension: self.sum_dimension.clone(),
position: self.position.clone(),
});
break;
}
@ -368,34 +309,47 @@ where
}
}
impl<'a, T, S, U> Cursor<'a, T, S, U>
impl<'a, T, D> Cursor<'a, T, D>
where
T: Item,
S: SeekDimension<'a, T::Summary>,
U: Dimension<'a, T::Summary>,
D: Dimension<'a, T::Summary>,
{
pub fn seek(&mut self, pos: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> bool {
pub fn seek<Target>(
&mut self,
pos: &Target,
bias: Bias,
cx: &<T::Summary as Summary>::Context,
) -> bool
where
Target: SeekTarget<'a, T::Summary, D>,
{
self.reset();
self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
self.seek_internal::<_, ()>(pos, bias, &mut SeekAggregate::None, cx)
}
pub fn seek_forward(
pub fn seek_forward<Target>(
&mut self,
pos: &S,
pos: &Target,
bias: Bias,
cx: &<T::Summary as Summary>::Context,
) -> bool {
self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
) -> bool
where
Target: SeekTarget<'a, T::Summary, D>,
{
self.seek_internal::<_, ()>(pos, bias, &mut SeekAggregate::None, cx)
}
pub fn slice(
pub fn slice<Target>(
&mut self,
end: &S,
end: &Target,
bias: Bias,
cx: &<T::Summary as Summary>::Context,
) -> SumTree<T> {
) -> SumTree<T>
where
Target: SeekTarget<'a, T::Summary, D>,
{
let mut slice = SeekAggregate::Slice(SumTree::new());
self.seek_internal::<()>(Some(end), bias, &mut slice, cx);
self.seek_internal::<_, ()>(end, bias, &mut slice, cx);
if let SeekAggregate::Slice(slice) = slice {
slice
} else {
@ -405,7 +359,7 @@ where
pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> {
let mut slice = SeekAggregate::Slice(SumTree::new());
self.seek_internal::<()>(None, Bias::Right, &mut slice, cx);
self.seek_internal::<_, ()>(&End::new(), Bias::Right, &mut slice, cx);
if let SeekAggregate::Slice(slice) = slice {
slice
} else {
@ -413,12 +367,18 @@ where
}
}
pub fn summary<D>(&mut self, end: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> D
pub fn summary<Target, Output>(
&mut self,
end: &Target,
bias: Bias,
cx: &<T::Summary as Summary>::Context,
) -> Output
where
D: Dimension<'a, T::Summary>,
Target: SeekTarget<'a, T::Summary, D>,
Output: Dimension<'a, T::Summary>,
{
let mut summary = SeekAggregate::Summary(D::default());
self.seek_internal(Some(end), bias, &mut summary, cx);
let mut summary = SeekAggregate::Summary(Output::default());
self.seek_internal(end, bias, &mut summary, cx);
if let SeekAggregate::Summary(summary) = summary {
summary
} else {
@ -426,32 +386,30 @@ where
}
}
fn seek_internal<D>(
fn seek_internal<Target, Output>(
&mut self,
target: Option<&S>,
target: &Target,
bias: Bias,
aggregate: &mut SeekAggregate<T, D>,
aggregate: &mut SeekAggregate<T, Output>,
cx: &<T::Summary as Summary>::Context,
) -> bool
where
D: Dimension<'a, T::Summary>,
Target: SeekTarget<'a, T::Summary, D>,
Output: Dimension<'a, T::Summary>,
{
if let Some(target) = target {
debug_assert!(
target.cmp(&self.seek_dimension, cx) >= Ordering::Equal,
"cannot seek backward from {:?} to {:?}",
self.seek_dimension,
target
);
}
debug_assert!(
target.cmp(&self.position, cx) >= Ordering::Equal,
"cannot seek backward from {:?} to {:?}",
self.position,
target
);
if !self.did_seek {
self.did_seek = true;
self.stack.push(StackEntry {
tree: self.tree,
index: 0,
seek_dimension: Default::default(),
sum_dimension: Default::default(),
position: Default::default(),
});
}
@ -471,16 +429,14 @@ where
.iter()
.zip(&child_summaries[entry.index..])
{
let mut child_end = self.seek_dimension.clone();
let mut child_end = self.position.clone();
child_end.add_summary(&child_summary, cx);
let comparison =
target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
let comparison = target.cmp(&child_end, cx);
if comparison == Ordering::Greater
|| (comparison == Ordering::Equal && bias == Bias::Right)
{
self.seek_dimension = child_end;
self.sum_dimension.add_summary(child_summary, cx);
self.position = child_end;
match aggregate {
SeekAggregate::None => {}
SeekAggregate::Slice(slice) => {
@ -491,14 +447,12 @@ where
}
}
entry.index += 1;
entry.seek_dimension = self.seek_dimension.clone();
entry.sum_dimension = self.sum_dimension.clone();
entry.position = self.position.clone();
} else {
self.stack.push(StackEntry {
tree: child_tree,
index: 0,
seek_dimension: self.seek_dimension.clone(),
sum_dimension: self.sum_dimension.clone(),
position: self.position.clone(),
});
ascending = false;
continue 'outer;
@ -521,25 +475,24 @@ where
.iter()
.zip(&item_summaries[entry.index..])
{
let mut child_end = self.seek_dimension.clone();
let mut child_end = self.position.clone();
child_end.add_summary(item_summary, cx);
let comparison =
target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
let comparison = target.cmp(&child_end, cx);
if comparison == Ordering::Greater
|| (comparison == Ordering::Equal && bias == Bias::Right)
{
self.seek_dimension = child_end;
self.sum_dimension.add_summary(item_summary, cx);
self.position = child_end;
match aggregate {
SeekAggregate::None => {}
SeekAggregate::Slice(_) => {
slice_items.push(item.clone());
slice_item_summaries.push(item_summary.clone());
slice_items_summary
.as_mut()
.unwrap()
.add_summary(item_summary, cx);
<T::Summary as Summary>::add_summary(
slice_items_summary.as_mut().unwrap(),
item_summary,
cx,
);
}
SeekAggregate::Summary(summary) => {
summary.add_summary(item_summary, cx);
@ -583,23 +536,22 @@ where
self.at_end = self.stack.is_empty();
debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
let mut end = self.seek_dimension.clone();
let mut end = self.position.clone();
if bias == Bias::Left {
if let Some(summary) = self.item_summary() {
end.add_summary(summary, cx);
}
}
target.map_or(false, |t| t.cmp(&end, cx) == Ordering::Equal)
target.cmp(&end, cx) == Ordering::Equal
}
}
impl<'a, T, S, Seek, Sum> Iterator for Cursor<'a, T, Seek, Sum>
impl<'a, T, S, D> Iterator for Cursor<'a, T, D>
where
T: Item<Summary = S>,
S: Summary<Context = ()>,
Seek: Dimension<'a, T::Summary>,
Sum: Dimension<'a, T::Summary>,
D: Dimension<'a, T::Summary>,
{
type Item = &'a T;
@ -617,45 +569,23 @@ where
}
}
impl<'a, T, S, U> Cursor<'a, T, S, U>
where
T: Item,
S: SeekDimension<'a, T::Summary>,
U: SeekDimension<'a, T::Summary>,
{
pub fn swap_dimensions(self) -> Cursor<'a, T, U, S> {
Cursor {
tree: self.tree,
stack: self
.stack
.into_iter()
.map(StackEntry::swap_dimensions)
.collect(),
seek_dimension: self.sum_dimension,
sum_dimension: self.seek_dimension,
did_seek: self.did_seek,
at_end: self.at_end,
}
}
}
pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, U> {
cursor: Cursor<'a, T, (), U>,
pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, D> {
cursor: Cursor<'a, T, D>,
filter_node: F,
}
impl<'a, F, T, U> FilterCursor<'a, F, T, U>
impl<'a, F, T, D> FilterCursor<'a, F, T, D>
where
F: Fn(&T::Summary) -> bool,
T: Item,
U: Dimension<'a, T::Summary>,
D: Dimension<'a, T::Summary>,
{
pub fn new(
tree: &'a SumTree<T>,
filter_node: F,
cx: &<T::Summary as Summary>::Context,
) -> Self {
let mut cursor = tree.cursor::<(), U>();
let mut cursor = tree.cursor::<D>();
cursor.next_internal(&filter_node, cx);
Self {
cursor,
@ -663,8 +593,8 @@ where
}
}
pub fn start(&self) -> &U {
self.cursor.sum_start()
pub fn start(&self) -> &D {
self.cursor.start()
}
pub fn item(&self) -> Option<&'a T> {