datafusion_physical_plan/aggregates/topk/
heap.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A custom binary heap implementation for performant top K aggregation
19
20use arrow::array::{
21    cast::AsArray,
22    types::{IntervalDayTime, IntervalMonthDayNano},
23};
24use arrow::array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
25use arrow::buffer::ScalarBuffer;
26use arrow::datatypes::{i256, DataType};
27use datafusion_common::exec_datafusion_err;
28use datafusion_common::Result;
29
30use half::f16;
31use std::cmp::Ordering;
32use std::fmt::{Debug, Display, Formatter};
33use std::sync::Arc;
34
/// A custom version of `Ord` that only exists so we can implement it for the values
/// stored in our heap, including the float types handled by `compare_float!` below.
pub trait Comparable {
    /// Returns how `self` orders relative to `other`.
    fn comp(&self, other: &Self) -> Ordering;
}
39
40impl Comparable for Option<String> {
41    fn comp(&self, other: &Self) -> Ordering {
42        self.cmp(other)
43    }
44}
45
46/// A "type alias" for Values which are stored in our heap
47pub trait ValueType: Comparable + Clone + Debug {}
48
49impl<T> ValueType for T where T: Comparable + Clone + Debug {}
50
/// An entry in our heap, which contains both the value and an index into an external HashTable
struct HeapItem<VAL: ValueType> {
    /// The value this entry is ordered by inside the heap
    val: VAL,
    /// Index of the matching entry in the external hash table
    map_idx: usize,
}
56
/// A custom heap implementation that allows several things that couldn't be achieved with
/// `collections::BinaryHeap`:
/// 1. It allows values to be updated at arbitrary positions (when group values change)
/// 2. It can be either a min or max heap
/// 3. It can use our `HeapItem` type & `Comparable` trait
/// 4. It is specialized to grow to a certain limit, then always replace without grow & shrink
struct TopKHeap<VAL: ValueType> {
    /// Comparison direction: when `false` the greatest value sits at the root,
    /// when `true` the smallest value does
    desc: bool,
    /// Number of occupied slots at the front of `heap`
    len: usize,
    /// Number of items at which the heap reports itself full
    capacity: usize,
    /// Backing storage, allocated by `new` with `capacity + 1` slots;
    /// unoccupied slots hold `None`
    heap: Vec<Option<HeapItem<VAL>>>,
}
69
/// An interface to hide the generic type signature of TopKHeap behind arrow arrays
pub trait ArrowHeap {
    /// Sets the array that subsequent row index arguments read their values from.
    fn set_batch(&mut self, vals: ArrayRef);
    /// Returns `true` if the value at row `idx` of the current batch is worse than the
    /// worst value held by the heap. `PrimitiveHeap` returns `false` while the heap is
    /// not yet full.
    fn is_worse(&self, idx: usize) -> bool;
    /// Returns the map index stored with the worst (root) heap entry.
    /// `PrimitiveHeap` returns `0` when there is no root entry.
    fn worst_map_idx(&self) -> usize;
    /// Takes `(heap_idx, map_idx)` pairs and overwrites the map index stored at each
    /// given heap position.
    fn renumber(&mut self, heap_to_map: &[(usize, usize)]);
    /// Adds the value at `row_idx` of the current batch, associated with `map_idx`.
    /// `PrimitiveHeap` overwrites the worst entry once the heap is full, and appends a
    /// `(map_idx, heap_idx)` pair to `map` for every entry that changes position.
    fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>);
    /// Replaces the value at heap position `heap_idx` with the value at `row_idx` of the
    /// current batch, but only if the new value is better. Position changes are appended
    /// to `map` as `(map_idx, heap_idx)` pairs.
    fn replace_if_better(
        &mut self,
        heap_idx: usize,
        row_idx: usize,
        map: &mut Vec<(usize, usize)>,
    );
    /// Empties the heap, returning the held values as an array together with their map
    /// indexes. `PrimitiveHeap` returns both ordered from best to worst.
    fn drain(&mut self) -> (ArrayRef, Vec<usize>);
}
85
/// An implementation of `ArrowHeap` that deals with primitive values
pub struct PrimitiveHeap<VAL: ArrowPrimitiveType>
where
    <VAL as ArrowPrimitiveType>::Native: Comparable,
{
    /// The batch most recently supplied through `set_batch`; row indexes refer to it
    batch: ArrayRef,
    /// The heap of native values backing this wrapper
    heap: TopKHeap<VAL::Native>,
    /// Comparison direction, the same flag that was handed to `heap`
    desc: bool,
    /// Data type applied to the array produced by `drain`
    data_type: DataType,
}
96
97impl<VAL: ArrowPrimitiveType> PrimitiveHeap<VAL>
98where
99    <VAL as ArrowPrimitiveType>::Native: Comparable,
100{
101    pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self {
102        let owned: ArrayRef = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
103        Self {
104            batch: owned,
105            heap: TopKHeap::new(limit, desc),
106            desc,
107            data_type,
108        }
109    }
110}
111
112impl<VAL: ArrowPrimitiveType> ArrowHeap for PrimitiveHeap<VAL>
113where
114    <VAL as ArrowPrimitiveType>::Native: Comparable,
115{
116    fn set_batch(&mut self, vals: ArrayRef) {
117        self.batch = vals;
118    }
119
120    fn is_worse(&self, row_idx: usize) -> bool {
121        if !self.heap.is_full() {
122            return false;
123        }
124        let vals = self.batch.as_primitive::<VAL>();
125        let new_val = vals.value(row_idx);
126        let worst_val = self.heap.worst_val().expect("Missing root");
127        (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val)
128    }
129
130    fn worst_map_idx(&self) -> usize {
131        self.heap.worst_map_idx()
132    }
133
134    fn renumber(&mut self, heap_to_map: &[(usize, usize)]) {
135        self.heap.renumber(heap_to_map);
136    }
137
138    fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) {
139        let vals = self.batch.as_primitive::<VAL>();
140        let new_val = vals.value(row_idx);
141        self.heap.append_or_replace(new_val, map_idx, map);
142    }
143
144    fn replace_if_better(
145        &mut self,
146        heap_idx: usize,
147        row_idx: usize,
148        map: &mut Vec<(usize, usize)>,
149    ) {
150        let vals = self.batch.as_primitive::<VAL>();
151        let new_val = vals.value(row_idx);
152        self.heap.replace_if_better(heap_idx, new_val, map);
153    }
154
155    fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
156        let nulls = None;
157        let (vals, map_idxs) = self.heap.drain();
158        let arr = PrimitiveArray::<VAL>::new(ScalarBuffer::from(vals), nulls)
159            .with_data_type(self.data_type.clone());
160        (Arc::new(arr), map_idxs)
161    }
162}
163
164impl<VAL: ValueType> TopKHeap<VAL> {
165    pub fn new(limit: usize, desc: bool) -> Self {
166        Self {
167            desc,
168            capacity: limit,
169            len: 0,
170            heap: (0..=limit).map(|_| None).collect::<Vec<_>>(),
171        }
172    }
173
174    pub fn worst_val(&self) -> Option<&VAL> {
175        let root = self.heap.first()?;
176        let hi = match root {
177            None => return None,
178            Some(hi) => hi,
179        };
180        Some(&hi.val)
181    }
182
183    pub fn worst_map_idx(&self) -> usize {
184        self.heap[0].as_ref().map(|hi| hi.map_idx).unwrap_or(0)
185    }
186
187    pub fn is_full(&self) -> bool {
188        self.len >= self.capacity
189    }
190
191    pub fn len(&self) -> usize {
192        self.len
193    }
194
195    pub fn append_or_replace(
196        &mut self,
197        new_val: VAL,
198        map_idx: usize,
199        map: &mut Vec<(usize, usize)>,
200    ) {
201        if self.is_full() {
202            self.replace_root(new_val, map_idx, map);
203        } else {
204            self.append(new_val, map_idx, map);
205        }
206    }
207
208    fn append(&mut self, new_val: VAL, map_idx: usize, mapper: &mut Vec<(usize, usize)>) {
209        let hi = HeapItem::new(new_val, map_idx);
210        self.heap[self.len] = Some(hi);
211        self.heapify_up(self.len, mapper);
212        self.len += 1;
213    }
214
215    fn pop(&mut self, map: &mut Vec<(usize, usize)>) -> Option<HeapItem<VAL>> {
216        if self.len() == 0 {
217            return None;
218        }
219        if self.len() == 1 {
220            self.len = 0;
221            return self.heap[0].take();
222        }
223        self.swap(0, self.len - 1, map);
224        let former_root = self.heap[self.len - 1].take();
225        self.len -= 1;
226        self.heapify_down(0, map);
227        former_root
228    }
229
230    pub fn drain(&mut self) -> (Vec<VAL>, Vec<usize>) {
231        let mut map = Vec::with_capacity(self.len);
232        let mut vals = Vec::with_capacity(self.len);
233        let mut map_idxs = Vec::with_capacity(self.len);
234        while let Some(worst_hi) = self.pop(&mut map) {
235            vals.push(worst_hi.val);
236            map_idxs.push(worst_hi.map_idx);
237        }
238        vals.reverse();
239        map_idxs.reverse();
240        (vals, map_idxs)
241    }
242
243    fn replace_root(
244        &mut self,
245        new_val: VAL,
246        map_idx: usize,
247        mapper: &mut Vec<(usize, usize)>,
248    ) {
249        let hi = self.heap[0].as_mut().expect("No root");
250        hi.val = new_val;
251        hi.map_idx = map_idx;
252        self.heapify_down(0, mapper);
253    }
254
255    pub fn replace_if_better(
256        &mut self,
257        heap_idx: usize,
258        new_val: VAL,
259        mapper: &mut Vec<(usize, usize)>,
260    ) {
261        let existing = self.heap[heap_idx].as_mut().expect("Missing heap item");
262        if (!self.desc && new_val.comp(&existing.val) != Ordering::Less)
263            || (self.desc && new_val.comp(&existing.val) != Ordering::Greater)
264        {
265            return;
266        }
267        existing.val = new_val;
268        self.heapify_down(heap_idx, mapper);
269    }
270
271    pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) {
272        for (heap_idx, map_idx) in heap_to_map.iter() {
273            if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) {
274                hi.map_idx = *map_idx;
275            }
276        }
277    }
278
279    fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) {
280        let desc = self.desc;
281        while idx != 0 {
282            let parent_idx = (idx - 1) / 2;
283            let node = self.heap[idx].as_ref().expect("No heap item");
284            let parent = self.heap[parent_idx].as_ref().expect("No heap item");
285            if (!desc && node.val.comp(&parent.val) != Ordering::Greater)
286                || (desc && node.val.comp(&parent.val) != Ordering::Less)
287            {
288                return;
289            }
290            self.swap(idx, parent_idx, mapper);
291            idx = parent_idx;
292        }
293    }
294
295    fn swap(&mut self, a_idx: usize, b_idx: usize, mapper: &mut Vec<(usize, usize)>) {
296        let a_hi = self.heap[a_idx].take().expect("Missing heap entry");
297        let b_hi = self.heap[b_idx].take().expect("Missing heap entry");
298
299        mapper.push((a_hi.map_idx, b_idx));
300        mapper.push((b_hi.map_idx, a_idx));
301
302        self.heap[a_idx] = Some(b_hi);
303        self.heap[b_idx] = Some(a_hi);
304    }
305
306    fn heapify_down(&mut self, node_idx: usize, mapper: &mut Vec<(usize, usize)>) {
307        let left_child = node_idx * 2 + 1;
308        let desc = self.desc;
309        let entry = self.heap.get(node_idx).expect("Missing node!");
310        let entry = entry.as_ref().expect("Missing node!");
311        let mut best_idx = node_idx;
312        let mut best_val = &entry.val;
313        for child_idx in left_child..=left_child + 1 {
314            if let Some(Some(child)) = self.heap.get(child_idx) {
315                if (!desc && child.val.comp(best_val) == Ordering::Greater)
316                    || (desc && child.val.comp(best_val) == Ordering::Less)
317                {
318                    best_val = &child.val;
319                    best_idx = child_idx;
320                }
321            }
322        }
323        if best_val.comp(&entry.val) != Ordering::Equal {
324            self.swap(best_idx, node_idx, mapper);
325            self.heapify_down(best_idx, mapper);
326        }
327    }
328
329    fn _tree_print(
330        &self,
331        idx: usize,
332        prefix: String,
333        is_tail: bool,
334        output: &mut String,
335    ) {
336        if let Some(Some(hi)) = self.heap.get(idx) {
337            let connector = if idx != 0 {
338                if is_tail {
339                    "└── "
340                } else {
341                    "├── "
342                }
343            } else {
344                ""
345            };
346            output.push_str(&format!(
347                "{}{}val={:?} idx={}, bucket={}\n",
348                prefix, connector, hi.val, idx, hi.map_idx
349            ));
350            let new_prefix = if is_tail { "" } else { "│   " };
351            let child_prefix = format!("{prefix}{new_prefix}");
352
353            let left_idx = idx * 2 + 1;
354            let right_idx = idx * 2 + 2;
355
356            let left_exists = left_idx < self.len;
357            let right_exists = right_idx < self.len;
358
359            if left_exists {
360                self._tree_print(left_idx, child_prefix.clone(), !right_exists, output);
361            }
362            if right_exists {
363                self._tree_print(right_idx, child_prefix, true, output);
364            }
365        }
366    }
367}
368
369impl<VAL: ValueType> Display for TopKHeap<VAL> {
370    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
371        let mut output = String::new();
372        if !self.heap.is_empty() {
373            self._tree_print(0, String::new(), true, &mut output);
374        }
375        write!(f, "{output}")
376    }
377}
378
379impl<VAL: ValueType> HeapItem<VAL> {
380    pub fn new(val: VAL, buk_idx: usize) -> Self {
381        Self {
382            val,
383            map_idx: buk_idx,
384        }
385    }
386}
387
388impl<VAL: ValueType> Debug for HeapItem<VAL> {
389    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
390        f.write_str("bucket=")?;
391        Debug::fmt(&self.map_idx, f)?;
392        f.write_str(" val=")?;
393        Debug::fmt(&self.val, f)?;
394        f.write_str("\n")?;
395        Ok(())
396    }
397}
398
399impl<VAL: ValueType> Eq for HeapItem<VAL> {}
400
401impl<VAL: ValueType> PartialEq<Self> for HeapItem<VAL> {
402    fn eq(&self, other: &Self) -> bool {
403        self.cmp(other) == Ordering::Equal
404    }
405}
406
407impl<VAL: ValueType> PartialOrd<Self> for HeapItem<VAL> {
408    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
409        Some(self.cmp(other))
410    }
411}
412
413impl<VAL: ValueType> Ord for HeapItem<VAL> {
414    fn cmp(&self, other: &Self) -> Ordering {
415        let res = self.val.comp(&other.val);
416        if res != Ordering::Equal {
417            return res;
418        }
419        self.map_idx.cmp(&other.map_idx)
420    }
421}
422
// Implements `Comparable` for each listed float type and for its `Option`, using the
// type's `total_cmp`. For the `Option` form, `None` orders before every `Some`.
macro_rules! compare_float {
    ($($t:ty),+) => {
        $(
            impl Comparable for $t {
                fn comp(&self, other: &Self) -> Ordering {
                    self.total_cmp(other)
                }
            }

            impl Comparable for Option<$t> {
                fn comp(&self, other: &Self) -> Ordering {
                    match (self, other) {
                        (None, None) => Ordering::Equal,
                        (None, Some(_)) => Ordering::Less,
                        (Some(_), None) => Ordering::Greater,
                        (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs),
                    }
                }
            }
        )+
    };
}
443
// Implements `Comparable` for each listed type and for its `Option` by delegating to the
// type's existing `Ord` implementation.
macro_rules! compare_integer {
    ($($t:ty),+) => {
        $(
            impl Comparable for $t {
                fn comp(&self, other: &Self) -> Ordering {
                    Ord::cmp(self, other)
                }
            }

            impl Comparable for Option<$t> {
                fn comp(&self, other: &Self) -> Ordering {
                    Ord::cmp(self, other)
                }
            }
        )+
    };
}
459
// Generate the `Comparable` implementations (plain and `Option`-wrapped) for the
// supported native types: types with `Ord` go through `compare_integer!`, float
// types go through `compare_float!`, which orders them with `total_cmp`.
compare_integer!(i8, i16, i32, i64, i128, i256);
compare_integer!(u8, u16, u32, u64);
compare_integer!(IntervalDayTime, IntervalMonthDayNano);
compare_float!(f16, f32, f64);
464
/// Creates a boxed [`ArrowHeap`] for values of data type `vt`.
///
/// `limit` and `desc` are forwarded to [`PrimitiveHeap::new`] together with `vt`.
///
/// # Errors
/// Returns an execution error ("Can't group type: ...") when `vt` is not matched by
/// any arm that `downcast_primitive!` generates.
pub fn new_heap(
    limit: usize,
    desc: bool,
    vt: DataType,
) -> Result<Box<dyn ArrowHeap + Send>> {
    // Invoked by `downcast_primitive!` with the arrow primitive type chosen for `vt`;
    // the second argument (`$d`) is accepted but not used.
    macro_rules! downcast_helper {
        ($vt:ty, $d:ident) => {
            return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt)))
        };
    }

    // Each generated arm returns early through `downcast_helper`; any other data type
    // falls through to the error below.
    downcast_primitive! {
        vt => (downcast_helper, vt),
        _ => {}
    }

    Err(exec_datafusion_err!("Can't group type: {vt:?}"))
}
483
// Unit tests for `TopKHeap`. Heap layouts are checked against inline snapshots of the
// `Display` rendering; `map` collects the `(map_idx, heap_idx)` pairs recorded by swaps.
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;

    use super::*;

    // A single appended value becomes the root.
    #[test]
    fn should_append() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(10, false);
        heap.append_or_replace(1, 1, &mut map);

        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=1 idx=0, bucket=1
            "#);

        Ok(())
    }

    // Appending a greater value to an ascending heap swaps it up to the root and
    // records the new position of both items.
    #[test]
    fn should_heapify_up() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(10, false);

        heap.append_or_replace(1, 1, &mut map);
        assert_eq!(map, vec![]);

        heap.append_or_replace(2, 2, &mut map);
        assert_eq!(map, vec![(2, 0), (1, 1)]);

        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=2 idx=0, bucket=2
└── val=1 idx=1, bucket=1
            "#);

        Ok(())
    }

    // Once the heap is full, a new value overwrites the root and is sifted down.
    #[test]
    fn should_heapify_down() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(3, false);

        heap.append_or_replace(1, 1, &mut map);
        heap.append_or_replace(2, 2, &mut map);
        heap.append_or_replace(3, 3, &mut map);
        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=3 idx=0, bucket=3
├── val=1 idx=1, bucket=1
└── val=2 idx=2, bucket=2
            "#);

        // Fresh map so that only the moves caused by the replacement are observed
        let mut map = vec![];
        heap.append_or_replace(0, 0, &mut map);
        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=2 idx=0, bucket=2
├── val=1 idx=1, bucket=1
└── val=0 idx=2, bucket=0
            "#);
        assert_eq!(map, vec![(2, 0), (0, 2)]);

        Ok(())
    }

    // Replacing the value at an interior position with a better one sifts it down.
    #[test]
    fn should_replace() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(4, false);

        heap.append_or_replace(1, 1, &mut map);
        heap.append_or_replace(2, 2, &mut map);
        heap.append_or_replace(3, 3, &mut map);
        heap.append_or_replace(4, 4, &mut map);
        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=4 idx=0, bucket=4
├── val=3 idx=1, bucket=3
│   └── val=1 idx=3, bucket=1
└── val=2 idx=2, bucket=2
            "#);

        // Fresh map so that only the moves caused by the replacement are observed
        let mut map = vec![];
        heap.replace_if_better(1, 0, &mut map);
        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=4 idx=0, bucket=4
├── val=1 idx=1, bucket=1
│   └── val=0 idx=3, bucket=3
└── val=2 idx=2, bucket=2
            "#);
        assert_eq!(map, vec![(1, 1), (3, 3)]);

        Ok(())
    }

    // The root exposes the worst value and its map index.
    #[test]
    fn should_find_worst() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(10, false);

        heap.append_or_replace(1, 1, &mut map);
        heap.append_or_replace(2, 2, &mut map);

        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=2 idx=0, bucket=2
└── val=1 idx=1, bucket=1
            "#);

        assert_eq!(heap.worst_val(), Some(&2));
        assert_eq!(heap.worst_map_idx(), 2);

        Ok(())
    }

    // Draining returns values and map indexes best-first and leaves the heap empty.
    #[test]
    fn should_drain() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(10, false);

        heap.append_or_replace(1, 1, &mut map);
        heap.append_or_replace(2, 2, &mut map);

        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=2 idx=0, bucket=2
└── val=1 idx=1, bucket=1
            "#);

        let (vals, map_idxs) = heap.drain();
        assert_eq!(vals, vec![1, 2]);
        assert_eq!(map_idxs, vec![1, 2]);
        assert_eq!(heap.len(), 0);

        Ok(())
    }

    // Renumbering overwrites the map index stored at the given heap positions.
    #[test]
    fn should_renumber() -> Result<()> {
        let mut map = vec![];
        let mut heap = TopKHeap::new(10, false);

        heap.append_or_replace(1, 1, &mut map);
        heap.append_or_replace(2, 2, &mut map);

        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=2 idx=0, bucket=2
└── val=1 idx=1, bucket=1
            "#);

        // Pairs are (heap_idx, map_idx)
        let numbers = vec![(0, 1), (1, 2)];
        heap.renumber(numbers.as_slice());
        let actual = heap.to_string();
        assert_snapshot!(actual, @r#"
val=2 idx=0, bucket=1
└── val=1 idx=1, bucket=2
            "#);

        Ok(())
    }
}