1use arrow::array::{
21 cast::AsArray,
22 types::{IntervalDayTime, IntervalMonthDayNano},
23};
24use arrow::array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
25use arrow::buffer::ScalarBuffer;
26use arrow::datatypes::{i256, DataType};
27use datafusion_common::exec_datafusion_err;
28use datafusion_common::Result;
29
30use half::f16;
31use std::cmp::Ordering;
32use std::fmt::{Debug, Display, Formatter};
33use std::sync::Arc;
34
/// Ordering used to rank values held by the top-K heap.
///
/// This is separate from [`Ord`] so that floating-point types can take part:
/// the float implementations in this module rank with `total_cmp`.
pub trait Comparable {
    /// Returns how `self` ranks relative to `other`.
    fn comp(&self, other: &Self) -> Ordering;
}
39
40impl Comparable for Option<String> {
41 fn comp(&self, other: &Self) -> Ordering {
42 self.cmp(other)
43 }
44}
45
46pub trait ValueType: Comparable + Clone + Debug {}
48
49impl<T> ValueType for T where T: Comparable + Clone + Debug {}
50
51struct HeapItem<VAL: ValueType> {
53 val: VAL,
54 map_idx: usize,
55}
56
57struct TopKHeap<VAL: ValueType> {
64 desc: bool,
65 len: usize,
66 capacity: usize,
67 heap: Vec<Option<HeapItem<VAL>>>,
68}
69
70pub trait ArrowHeap {
72 fn set_batch(&mut self, vals: ArrayRef);
73 fn is_worse(&self, idx: usize) -> bool;
74 fn worst_map_idx(&self) -> usize;
75 fn renumber(&mut self, heap_to_map: &[(usize, usize)]);
76 fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>);
77 fn replace_if_better(
78 &mut self,
79 heap_idx: usize,
80 row_idx: usize,
81 map: &mut Vec<(usize, usize)>,
82 );
83 fn drain(&mut self) -> (ArrayRef, Vec<usize>);
84}
85
86pub struct PrimitiveHeap<VAL: ArrowPrimitiveType>
88where
89 <VAL as ArrowPrimitiveType>::Native: Comparable,
90{
91 batch: ArrayRef,
92 heap: TopKHeap<VAL::Native>,
93 desc: bool,
94 data_type: DataType,
95}
96
97impl<VAL: ArrowPrimitiveType> PrimitiveHeap<VAL>
98where
99 <VAL as ArrowPrimitiveType>::Native: Comparable,
100{
101 pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self {
102 let owned: ArrayRef = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
103 Self {
104 batch: owned,
105 heap: TopKHeap::new(limit, desc),
106 desc,
107 data_type,
108 }
109 }
110}
111
112impl<VAL: ArrowPrimitiveType> ArrowHeap for PrimitiveHeap<VAL>
113where
114 <VAL as ArrowPrimitiveType>::Native: Comparable,
115{
116 fn set_batch(&mut self, vals: ArrayRef) {
117 self.batch = vals;
118 }
119
120 fn is_worse(&self, row_idx: usize) -> bool {
121 if !self.heap.is_full() {
122 return false;
123 }
124 let vals = self.batch.as_primitive::<VAL>();
125 let new_val = vals.value(row_idx);
126 let worst_val = self.heap.worst_val().expect("Missing root");
127 (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val)
128 }
129
130 fn worst_map_idx(&self) -> usize {
131 self.heap.worst_map_idx()
132 }
133
134 fn renumber(&mut self, heap_to_map: &[(usize, usize)]) {
135 self.heap.renumber(heap_to_map);
136 }
137
138 fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) {
139 let vals = self.batch.as_primitive::<VAL>();
140 let new_val = vals.value(row_idx);
141 self.heap.append_or_replace(new_val, map_idx, map);
142 }
143
144 fn replace_if_better(
145 &mut self,
146 heap_idx: usize,
147 row_idx: usize,
148 map: &mut Vec<(usize, usize)>,
149 ) {
150 let vals = self.batch.as_primitive::<VAL>();
151 let new_val = vals.value(row_idx);
152 self.heap.replace_if_better(heap_idx, new_val, map);
153 }
154
155 fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
156 let nulls = None;
157 let (vals, map_idxs) = self.heap.drain();
158 let arr = PrimitiveArray::<VAL>::new(ScalarBuffer::from(vals), nulls)
159 .with_data_type(self.data_type.clone());
160 (Arc::new(arr), map_idxs)
161 }
162}
163
164impl<VAL: ValueType> TopKHeap<VAL> {
165 pub fn new(limit: usize, desc: bool) -> Self {
166 Self {
167 desc,
168 capacity: limit,
169 len: 0,
170 heap: (0..=limit).map(|_| None).collect::<Vec<_>>(),
171 }
172 }
173
174 pub fn worst_val(&self) -> Option<&VAL> {
175 let root = self.heap.first()?;
176 let hi = match root {
177 None => return None,
178 Some(hi) => hi,
179 };
180 Some(&hi.val)
181 }
182
183 pub fn worst_map_idx(&self) -> usize {
184 self.heap[0].as_ref().map(|hi| hi.map_idx).unwrap_or(0)
185 }
186
187 pub fn is_full(&self) -> bool {
188 self.len >= self.capacity
189 }
190
191 pub fn len(&self) -> usize {
192 self.len
193 }
194
195 pub fn append_or_replace(
196 &mut self,
197 new_val: VAL,
198 map_idx: usize,
199 map: &mut Vec<(usize, usize)>,
200 ) {
201 if self.is_full() {
202 self.replace_root(new_val, map_idx, map);
203 } else {
204 self.append(new_val, map_idx, map);
205 }
206 }
207
208 fn append(&mut self, new_val: VAL, map_idx: usize, mapper: &mut Vec<(usize, usize)>) {
209 let hi = HeapItem::new(new_val, map_idx);
210 self.heap[self.len] = Some(hi);
211 self.heapify_up(self.len, mapper);
212 self.len += 1;
213 }
214
215 fn pop(&mut self, map: &mut Vec<(usize, usize)>) -> Option<HeapItem<VAL>> {
216 if self.len() == 0 {
217 return None;
218 }
219 if self.len() == 1 {
220 self.len = 0;
221 return self.heap[0].take();
222 }
223 self.swap(0, self.len - 1, map);
224 let former_root = self.heap[self.len - 1].take();
225 self.len -= 1;
226 self.heapify_down(0, map);
227 former_root
228 }
229
230 pub fn drain(&mut self) -> (Vec<VAL>, Vec<usize>) {
231 let mut map = Vec::with_capacity(self.len);
232 let mut vals = Vec::with_capacity(self.len);
233 let mut map_idxs = Vec::with_capacity(self.len);
234 while let Some(worst_hi) = self.pop(&mut map) {
235 vals.push(worst_hi.val);
236 map_idxs.push(worst_hi.map_idx);
237 }
238 vals.reverse();
239 map_idxs.reverse();
240 (vals, map_idxs)
241 }
242
243 fn replace_root(
244 &mut self,
245 new_val: VAL,
246 map_idx: usize,
247 mapper: &mut Vec<(usize, usize)>,
248 ) {
249 let hi = self.heap[0].as_mut().expect("No root");
250 hi.val = new_val;
251 hi.map_idx = map_idx;
252 self.heapify_down(0, mapper);
253 }
254
255 pub fn replace_if_better(
256 &mut self,
257 heap_idx: usize,
258 new_val: VAL,
259 mapper: &mut Vec<(usize, usize)>,
260 ) {
261 let existing = self.heap[heap_idx].as_mut().expect("Missing heap item");
262 if (!self.desc && new_val.comp(&existing.val) != Ordering::Less)
263 || (self.desc && new_val.comp(&existing.val) != Ordering::Greater)
264 {
265 return;
266 }
267 existing.val = new_val;
268 self.heapify_down(heap_idx, mapper);
269 }
270
271 pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) {
272 for (heap_idx, map_idx) in heap_to_map.iter() {
273 if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) {
274 hi.map_idx = *map_idx;
275 }
276 }
277 }
278
279 fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) {
280 let desc = self.desc;
281 while idx != 0 {
282 let parent_idx = (idx - 1) / 2;
283 let node = self.heap[idx].as_ref().expect("No heap item");
284 let parent = self.heap[parent_idx].as_ref().expect("No heap item");
285 if (!desc && node.val.comp(&parent.val) != Ordering::Greater)
286 || (desc && node.val.comp(&parent.val) != Ordering::Less)
287 {
288 return;
289 }
290 self.swap(idx, parent_idx, mapper);
291 idx = parent_idx;
292 }
293 }
294
295 fn swap(&mut self, a_idx: usize, b_idx: usize, mapper: &mut Vec<(usize, usize)>) {
296 let a_hi = self.heap[a_idx].take().expect("Missing heap entry");
297 let b_hi = self.heap[b_idx].take().expect("Missing heap entry");
298
299 mapper.push((a_hi.map_idx, b_idx));
300 mapper.push((b_hi.map_idx, a_idx));
301
302 self.heap[a_idx] = Some(b_hi);
303 self.heap[b_idx] = Some(a_hi);
304 }
305
306 fn heapify_down(&mut self, node_idx: usize, mapper: &mut Vec<(usize, usize)>) {
307 let left_child = node_idx * 2 + 1;
308 let desc = self.desc;
309 let entry = self.heap.get(node_idx).expect("Missing node!");
310 let entry = entry.as_ref().expect("Missing node!");
311 let mut best_idx = node_idx;
312 let mut best_val = &entry.val;
313 for child_idx in left_child..=left_child + 1 {
314 if let Some(Some(child)) = self.heap.get(child_idx) {
315 if (!desc && child.val.comp(best_val) == Ordering::Greater)
316 || (desc && child.val.comp(best_val) == Ordering::Less)
317 {
318 best_val = &child.val;
319 best_idx = child_idx;
320 }
321 }
322 }
323 if best_val.comp(&entry.val) != Ordering::Equal {
324 self.swap(best_idx, node_idx, mapper);
325 self.heapify_down(best_idx, mapper);
326 }
327 }
328
329 fn _tree_print(
330 &self,
331 idx: usize,
332 prefix: String,
333 is_tail: bool,
334 output: &mut String,
335 ) {
336 if let Some(Some(hi)) = self.heap.get(idx) {
337 let connector = if idx != 0 {
338 if is_tail {
339 "└── "
340 } else {
341 "├── "
342 }
343 } else {
344 ""
345 };
346 output.push_str(&format!(
347 "{}{}val={:?} idx={}, bucket={}\n",
348 prefix, connector, hi.val, idx, hi.map_idx
349 ));
350 let new_prefix = if is_tail { "" } else { "│ " };
351 let child_prefix = format!("{prefix}{new_prefix}");
352
353 let left_idx = idx * 2 + 1;
354 let right_idx = idx * 2 + 2;
355
356 let left_exists = left_idx < self.len;
357 let right_exists = right_idx < self.len;
358
359 if left_exists {
360 self._tree_print(left_idx, child_prefix.clone(), !right_exists, output);
361 }
362 if right_exists {
363 self._tree_print(right_idx, child_prefix, true, output);
364 }
365 }
366 }
367}
368
369impl<VAL: ValueType> Display for TopKHeap<VAL> {
370 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
371 let mut output = String::new();
372 if !self.heap.is_empty() {
373 self._tree_print(0, String::new(), true, &mut output);
374 }
375 write!(f, "{output}")
376 }
377}
378
379impl<VAL: ValueType> HeapItem<VAL> {
380 pub fn new(val: VAL, buk_idx: usize) -> Self {
381 Self {
382 val,
383 map_idx: buk_idx,
384 }
385 }
386}
387
388impl<VAL: ValueType> Debug for HeapItem<VAL> {
389 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
390 f.write_str("bucket=")?;
391 Debug::fmt(&self.map_idx, f)?;
392 f.write_str(" val=")?;
393 Debug::fmt(&self.val, f)?;
394 f.write_str("\n")?;
395 Ok(())
396 }
397}
398
399impl<VAL: ValueType> Eq for HeapItem<VAL> {}
400
401impl<VAL: ValueType> PartialEq<Self> for HeapItem<VAL> {
402 fn eq(&self, other: &Self) -> bool {
403 self.cmp(other) == Ordering::Equal
404 }
405}
406
407impl<VAL: ValueType> PartialOrd<Self> for HeapItem<VAL> {
408 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
409 Some(self.cmp(other))
410 }
411}
412
413impl<VAL: ValueType> Ord for HeapItem<VAL> {
414 fn cmp(&self, other: &Self) -> Ordering {
415 let res = self.val.comp(&other.val);
416 if res != Ordering::Equal {
417 return res;
418 }
419 self.map_idx.cmp(&other.map_idx)
420 }
421}
422
/// Implements `Comparable` for each listed float type and for `Option` of it.
macro_rules! compare_float {
    ($($t:ty),+) => {
        $(
            // `None` ranks below every value; present values use `total_cmp`.
            impl Comparable for Option<$t> {
                fn comp(&self, other: &Self) -> Ordering {
                    match (self, other) {
                        (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs),
                        (lhs, rhs) => lhs.is_some().cmp(&rhs.is_some()),
                    }
                }
            }

            // Floats have no `Ord`; `total_cmp` provides a total order instead.
            impl Comparable for $t {
                fn comp(&self, other: &Self) -> Ordering {
                    self.total_cmp(other)
                }
            }
        )+
    };
}
443
/// Implements `Comparable` for each listed `Ord` type and for `Option` of it,
/// delegating to `Ord::cmp`.
macro_rules! compare_integer {
    ($($t:ty),+) => {
        $(
            impl Comparable for Option<$t> {
                fn comp(&self, other: &Self) -> Ordering {
                    Ord::cmp(self, other)
                }
            }

            impl Comparable for $t {
                fn comp(&self, other: &Self) -> Ordering {
                    Ord::cmp(self, other)
                }
            }
        )+
    };
}
459
460compare_integer!(i8, i16, i32, i64, i128, i256);
461compare_integer!(u8, u16, u32, u64);
462compare_integer!(IntervalDayTime, IntervalMonthDayNano);
463compare_float!(f16, f32, f64);
464
465pub fn new_heap(
466 limit: usize,
467 desc: bool,
468 vt: DataType,
469) -> Result<Box<dyn ArrowHeap + Send>> {
470 macro_rules! downcast_helper {
471 ($vt:ty, $d:ident) => {
472 return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt)))
473 };
474 }
475
476 downcast_primitive! {
477 vt => (downcast_helper, vt),
478 _ => {}
479 }
480
481 Err(exec_datafusion_err!("Can't group type: {vt:?}"))
482}
483
484#[cfg(test)]
485mod tests {
486 use insta::assert_snapshot;
487
488 use super::*;
489
490 #[test]
491 fn should_append() -> Result<()> {
492 let mut map = vec![];
493 let mut heap = TopKHeap::new(10, false);
494 heap.append_or_replace(1, 1, &mut map);
495
496 let actual = heap.to_string();
497 assert_snapshot!(actual, @r#"
498val=1 idx=0, bucket=1
499 "#);
500
501 Ok(())
502 }
503
504 #[test]
505 fn should_heapify_up() -> Result<()> {
506 let mut map = vec![];
507 let mut heap = TopKHeap::new(10, false);
508
509 heap.append_or_replace(1, 1, &mut map);
510 assert_eq!(map, vec![]);
511
512 heap.append_or_replace(2, 2, &mut map);
513 assert_eq!(map, vec![(2, 0), (1, 1)]);
514
515 let actual = heap.to_string();
516 assert_snapshot!(actual, @r#"
517val=2 idx=0, bucket=2
518└── val=1 idx=1, bucket=1
519 "#);
520
521 Ok(())
522 }
523
524 #[test]
525 fn should_heapify_down() -> Result<()> {
526 let mut map = vec![];
527 let mut heap = TopKHeap::new(3, false);
528
529 heap.append_or_replace(1, 1, &mut map);
530 heap.append_or_replace(2, 2, &mut map);
531 heap.append_or_replace(3, 3, &mut map);
532 let actual = heap.to_string();
533 assert_snapshot!(actual, @r#"
534val=3 idx=0, bucket=3
535├── val=1 idx=1, bucket=1
536└── val=2 idx=2, bucket=2
537 "#);
538
539 let mut map = vec![];
540 heap.append_or_replace(0, 0, &mut map);
541 let actual = heap.to_string();
542 assert_snapshot!(actual, @r#"
543val=2 idx=0, bucket=2
544├── val=1 idx=1, bucket=1
545└── val=0 idx=2, bucket=0
546 "#);
547 assert_eq!(map, vec![(2, 0), (0, 2)]);
548
549 Ok(())
550 }
551
552 #[test]
553 fn should_replace() -> Result<()> {
554 let mut map = vec![];
555 let mut heap = TopKHeap::new(4, false);
556
557 heap.append_or_replace(1, 1, &mut map);
558 heap.append_or_replace(2, 2, &mut map);
559 heap.append_or_replace(3, 3, &mut map);
560 heap.append_or_replace(4, 4, &mut map);
561 let actual = heap.to_string();
562 assert_snapshot!(actual, @r#"
563val=4 idx=0, bucket=4
564├── val=3 idx=1, bucket=3
565│ └── val=1 idx=3, bucket=1
566└── val=2 idx=2, bucket=2
567 "#);
568
569 let mut map = vec![];
570 heap.replace_if_better(1, 0, &mut map);
571 let actual = heap.to_string();
572 assert_snapshot!(actual, @r#"
573val=4 idx=0, bucket=4
574├── val=1 idx=1, bucket=1
575│ └── val=0 idx=3, bucket=3
576└── val=2 idx=2, bucket=2
577 "#);
578 assert_eq!(map, vec![(1, 1), (3, 3)]);
579
580 Ok(())
581 }
582
583 #[test]
584 fn should_find_worst() -> Result<()> {
585 let mut map = vec![];
586 let mut heap = TopKHeap::new(10, false);
587
588 heap.append_or_replace(1, 1, &mut map);
589 heap.append_or_replace(2, 2, &mut map);
590
591 let actual = heap.to_string();
592 assert_snapshot!(actual, @r#"
593val=2 idx=0, bucket=2
594└── val=1 idx=1, bucket=1
595 "#);
596
597 assert_eq!(heap.worst_val(), Some(&2));
598 assert_eq!(heap.worst_map_idx(), 2);
599
600 Ok(())
601 }
602
603 #[test]
604 fn should_drain() -> Result<()> {
605 let mut map = vec![];
606 let mut heap = TopKHeap::new(10, false);
607
608 heap.append_or_replace(1, 1, &mut map);
609 heap.append_or_replace(2, 2, &mut map);
610
611 let actual = heap.to_string();
612 assert_snapshot!(actual, @r#"
613val=2 idx=0, bucket=2
614└── val=1 idx=1, bucket=1
615 "#);
616
617 let (vals, map_idxs) = heap.drain();
618 assert_eq!(vals, vec![1, 2]);
619 assert_eq!(map_idxs, vec![1, 2]);
620 assert_eq!(heap.len(), 0);
621
622 Ok(())
623 }
624
625 #[test]
626 fn should_renumber() -> Result<()> {
627 let mut map = vec![];
628 let mut heap = TopKHeap::new(10, false);
629
630 heap.append_or_replace(1, 1, &mut map);
631 heap.append_or_replace(2, 2, &mut map);
632
633 let actual = heap.to_string();
634 assert_snapshot!(actual, @r#"
635val=2 idx=0, bucket=2
636└── val=1 idx=1, bucket=1
637 "#);
638
639 let numbers = vec![(0, 1), (1, 2)];
640 heap.renumber(numbers.as_slice());
641 let actual = heap.to_string();
642 assert_snapshot!(actual, @r#"
643val=2 idx=0, bucket=1
644└── val=1 idx=1, bucket=2
645 "#);
646
647 Ok(())
648 }
649}