use std::{cmp::Ordering, fmt::Debug}; use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary}; #[derive(Clone, Debug, PartialEq, Eq)] pub struct TreeMap(SumTree>) where K: Clone + Debug + Default + Ord, V: Clone + Debug; #[derive(Clone, Debug, PartialEq, Eq)] pub struct MapEntry { key: K, value: V, } #[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] pub struct MapKey(K); #[derive(Clone, Debug, Default)] pub struct MapKeyRef<'a, K>(Option<&'a K>); #[derive(Clone)] pub struct TreeSet(TreeMap) where K: Clone + Debug + Default + Ord; impl TreeMap { pub fn from_ordered_entries(entries: impl IntoIterator) -> Self { let tree = SumTree::from_iter( entries .into_iter() .map(|(key, value)| MapEntry { key, value }), &(), ); Self(tree) } pub fn is_empty(&self) -> bool { self.0.is_empty() } pub fn get<'a>(&self, key: &'a K) -> Option<&V> { let mut cursor = self.0.cursor::>(); cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &()); if let Some(item) = cursor.item() { if *key == item.key().0 { Some(&item.value) } else { None } } else { None } } pub fn insert(&mut self, key: K, value: V) { self.0.insert_or_replace(MapEntry { key, value }, &()); } pub fn remove(&mut self, key: &K) -> Option { let mut removed = None; let mut cursor = self.0.cursor::>(); let key = MapKeyRef(Some(key)); let mut new_tree = cursor.slice(&key, Bias::Left, &()); if key.cmp(&cursor.end(&()), &()) == Ordering::Equal { removed = Some(cursor.item().unwrap().value.clone()); cursor.next(&()); } new_tree.push_tree(cursor.suffix(&()), &()); drop(cursor); self.0 = new_tree; removed } /// Returns the key-value pair with the greatest key less than or equal to the given key. pub fn closest(&self, key: &K) -> Option<(&K, &V)> { let mut cursor = self.0.cursor::>(); let key = MapKeyRef(Some(key)); cursor.seek(&key, Bias::Right, &()); cursor.prev(&()); cursor.item().map(|item| (&item.key, &item.value)) } pub fn remove_between(&mut self, from: &K, until: &K) { let mut cursor = self.0.cursor::>(); let from_key = MapKeyRef(Some(from)); let mut new_tree = cursor.slice(&from_key, Bias::Left, &()); let until_key = MapKeyRef(Some(until)); cursor.seek_forward(&until_key, Bias::Left, &()); new_tree.push_tree(cursor.suffix(&()), &()); drop(cursor); self.0 = new_tree; } pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator + '_ { let mut cursor = self.0.cursor::>(); let from_key = MapKeyRef(Some(from)); cursor.seek(&from_key, Bias::Left, &()); cursor .into_iter() .map(|map_entry| (&map_entry.key, &map_entry.value)) } pub fn update(&mut self, key: &K, f: F) -> Option where F: FnOnce(&mut V) -> T, { let mut cursor = self.0.cursor::>(); let key = MapKeyRef(Some(key)); let mut new_tree = cursor.slice(&key, Bias::Left, &()); let mut result = None; if key.cmp(&cursor.end(&()), &()) == Ordering::Equal { let mut updated = cursor.item().unwrap().clone(); result = Some(f(&mut updated.value)); new_tree.push(updated, &()); cursor.next(&()); } new_tree.push_tree(cursor.suffix(&()), &()); drop(cursor); self.0 = new_tree; result } pub fn retain bool>(&mut self, mut predicate: F) { let mut new_map = SumTree::>::default(); let mut cursor = self.0.cursor::>(); cursor.next(&()); while let Some(item) = cursor.item() { if predicate(&item.key, &item.value) { new_map.push(item.clone(), &()); } cursor.next(&()); } drop(cursor); self.0 = new_map; } pub fn iter(&self) -> impl Iterator + '_ { self.0.iter().map(|entry| (&entry.key, &entry.value)) } pub fn values(&self) -> impl Iterator + '_ { self.0.iter().map(|entry| &entry.value) } pub fn insert_tree(&mut self, other: TreeMap) { let edits = other .iter() .map(|(key, value)| { Edit::Insert(MapEntry { key: key.to_owned(), value: value.to_owned(), }) }) .collect(); self.0.edit(edits, &()); } pub fn remove_by(&mut self, key: &K, f: F) where F: Fn(&K) -> bool, { let mut cursor = self.0.cursor::>(); let key = MapKeyRef(Some(key)); let mut new_tree = cursor.slice(&key, Bias::Left, &()); let until = RemoveByTarget(key, &f); cursor.seek_forward(&until, Bias::Right, &()); new_tree.push_tree(cursor.suffix(&()), &()); drop(cursor); self.0 = new_tree; } } struct RemoveByTarget<'a, K>(MapKeyRef<'a, K>, &'a dyn Fn(&K) -> bool); impl<'a, K: Debug> Debug for RemoveByTarget<'a, K> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RemoveByTarget") .field("key", &self.0) .field("F", &"<...>") .finish() } } impl<'a, K: Debug + Clone + Default + Ord> SeekTarget<'a, MapKey, MapKeyRef<'a, K>> for RemoveByTarget<'_, K> { fn cmp( &self, cursor_location: &MapKeyRef<'a, K>, _cx: & as Summary>::Context, ) -> Ordering { if let Some(cursor_location) = cursor_location.0 { if (self.1)(cursor_location) { Ordering::Equal } else { self.0 .0.unwrap().cmp(cursor_location) } } else { Ordering::Greater } } } impl Default for TreeMap where K: Clone + Debug + Default + Ord, V: Clone + Debug, { fn default() -> Self { Self(Default::default()) } } impl Item for MapEntry where K: Clone + Debug + Default + Ord, V: Clone, { type Summary = MapKey; fn summary(&self) -> Self::Summary { self.key() } } impl KeyedItem for MapEntry where K: Clone + Debug + Default + Ord, V: Clone, { type Key = MapKey; fn key(&self) -> Self::Key { MapKey(self.key.clone()) } } impl Summary for MapKey where K: Clone + Debug + Default, { type Context = (); fn add_summary(&mut self, summary: &Self, _: &()) { *self = summary.clone() } } impl<'a, K> Dimension<'a, MapKey> for MapKeyRef<'a, K> where K: Clone + Debug + Default + Ord, { fn add_summary(&mut self, summary: &'a MapKey, _: &()) { self.0 = Some(&summary.0) } } impl<'a, K> SeekTarget<'a, MapKey, MapKeyRef<'a, K>> for MapKeyRef<'_, K> where K: Clone + Debug + Default + Ord, { fn cmp(&self, cursor_location: &MapKeyRef, _: &()) -> Ordering { self.0.cmp(&cursor_location.0) } } impl Default for TreeSet where K: Clone + Debug + Default + Ord, { fn default() -> Self { Self(Default::default()) } } impl TreeSet where K: Clone + Debug + Default + Ord, { pub fn from_ordered_entries(entries: impl IntoIterator) -> Self { Self(TreeMap::from_ordered_entries( entries.into_iter().map(|key| (key, ())), )) } pub fn insert(&mut self, key: K) { self.0.insert(key, ()); } pub fn contains(&self, key: &K) -> bool { self.0.get(key).is_some() } pub fn iter(&self) -> impl Iterator + '_ { self.0.iter().map(|(k, _)| k) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_basic() { let mut map = TreeMap::default(); assert_eq!(map.iter().collect::>(), vec![]); map.insert(3, "c"); assert_eq!(map.get(&3), Some(&"c")); assert_eq!(map.iter().collect::>(), vec![(&3, &"c")]); map.insert(1, "a"); assert_eq!(map.get(&1), Some(&"a")); assert_eq!(map.iter().collect::>(), vec![(&1, &"a"), (&3, &"c")]); map.insert(2, "b"); assert_eq!(map.get(&2), Some(&"b")); assert_eq!(map.get(&1), Some(&"a")); assert_eq!(map.get(&3), Some(&"c")); assert_eq!( map.iter().collect::>(), vec![(&1, &"a"), (&2, &"b"), (&3, &"c")] ); assert_eq!(map.closest(&0), None); assert_eq!(map.closest(&1), Some((&1, &"a"))); assert_eq!(map.closest(&10), Some((&3, &"c"))); map.remove(&2); assert_eq!(map.get(&2), None); assert_eq!(map.iter().collect::>(), vec![(&1, &"a"), (&3, &"c")]); assert_eq!(map.closest(&2), Some((&1, &"a"))); map.remove(&3); assert_eq!(map.get(&3), None); assert_eq!(map.iter().collect::>(), vec![(&1, &"a")]); map.remove(&1); assert_eq!(map.get(&1), None); assert_eq!(map.iter().collect::>(), vec![]); map.insert(4, "d"); map.insert(5, "e"); map.insert(6, "f"); map.retain(|key, _| *key % 2 == 0); assert_eq!(map.iter().collect::>(), vec![(&4, &"d"), (&6, &"f")]); } #[test] fn test_remove_between() { let mut map = TreeMap::default(); map.insert("a", 1); map.insert("b", 2); map.insert("baa", 3); map.insert("baaab", 4); map.insert("c", 5); map.remove_between(&"ba", &"bb"); assert_eq!(map.get(&"a"), Some(&1)); assert_eq!(map.get(&"b"), Some(&2)); assert_eq!(map.get(&"baaa"), None); assert_eq!(map.get(&"baaaab"), None); assert_eq!(map.get(&"c"), Some(&5)); } #[test] fn test_remove_by() { let mut map = TreeMap::default(); map.insert("a", 1); map.insert("aa", 1); map.insert("b", 2); map.insert("baa", 3); map.insert("baaab", 4); map.insert("c", 5); map.insert("ca", 6); map.remove_by(&"ba", |key| key.starts_with("ba")); assert_eq!(map.get(&"a"), Some(&1)); assert_eq!(map.get(&"aa"), Some(&1)); assert_eq!(map.get(&"b"), Some(&2)); assert_eq!(map.get(&"baaa"), None); assert_eq!(map.get(&"baaaab"), None); assert_eq!(map.get(&"c"), Some(&5)); assert_eq!(map.get(&"ca"), Some(&6)); map.remove_by(&"c", |key| key.starts_with("c")); assert_eq!(map.get(&"a"), Some(&1)); assert_eq!(map.get(&"aa"), Some(&1)); assert_eq!(map.get(&"b"), Some(&2)); assert_eq!(map.get(&"c"), None); assert_eq!(map.get(&"ca"), None); map.remove_by(&"a", |key| key.starts_with("a")); assert_eq!(map.get(&"a"), None); assert_eq!(map.get(&"aa"), None); assert_eq!(map.get(&"b"), Some(&2)); map.remove_by(&"b", |key| key.starts_with("b")); assert_eq!(map.get(&"b"), None); } #[test] fn test_iter_from() { let mut map = TreeMap::default(); map.insert("a", 1); map.insert("b", 2); map.insert("baa", 3); map.insert("baaab", 4); map.insert("c", 5); let result = map .iter_from(&"ba") .take_while(|(key, _)| key.starts_with(&"ba")) .collect::>(); assert_eq!(result.len(), 2); assert!(result.iter().find(|(k, _)| k == &&"baa").is_some()); assert!(result.iter().find(|(k, _)| k == &&"baaab").is_some()); let result = map .iter_from(&"c") .take_while(|(key, _)| key.starts_with(&"c")) .collect::>(); assert_eq!(result.len(), 1); assert!(result.iter().find(|(k, _)| k == &&"c").is_some()); } #[test] fn test_insert_tree() { let mut map = TreeMap::default(); map.insert("a", 1); map.insert("b", 2); map.insert("c", 3); let mut other = TreeMap::default(); other.insert("a", 2); other.insert("b", 2); other.insert("d", 4); map.insert_tree(other); assert_eq!(map.iter().count(), 4); assert_eq!(map.get(&"a"), Some(&2)); assert_eq!(map.get(&"b"), Some(&2)); assert_eq!(map.get(&"c"), Some(&3)); assert_eq!(map.get(&"d"), Some(&4)); } }