datafusion_physical_plan/aggregates/topk/priority_map.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! A `Map<K, V>` / `PriorityQueue` combo that evicts the worst values after reaching `capacity`

use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable};
use crate::aggregates::topk::heap::{new_heap, ArrowHeap};
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion_common::Result;

/// A `Map<K, V>` / `PriorityQueue` combo that evicts the worst values after reaching `capacity`
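///
/// # Example
///
/// A minimal usage sketch mirroring the unit tests at the bottom of this file;
/// the keys `"a"`/`"b"` and the values are illustrative, and the doctest is
/// marked `ignore` because this type is internal to the crate.
///
/// ```ignore
/// # fn example() -> datafusion_common::Result<()> {
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int64Array, StringArray};
/// use arrow::datatypes::DataType;
///
/// // keep the single best (smallest, since descending = false) group
/// let mut map = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
/// let ids: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
/// let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
/// map.set_batch(ids, vals);
/// map.insert(0)?;
/// map.insert(1)?;
/// let cols = map.emit()?; // two columns: ids (["b"]) then vals ([1])
/// # Ok(())
/// # }
/// ```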
pub struct PriorityMap {
    map: Box<dyn ArrowHashTable + Send>,
    heap: Box<dyn ArrowHeap + Send>,
    capacity: usize,
    mapper: Vec<(usize, usize)>,
}

impl PriorityMap {
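    /// Creates a `PriorityMap` that retains at most `capacity` groups, with keys of
    /// `key_type` and values of `val_type`. When `descending` is true the largest
    /// values are kept, otherwise the smallest (see the ordering tests below).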
    pub fn new(
        key_type: DataType,
        val_type: DataType,
        capacity: usize,
        descending: bool,
    ) -> Result<Self> {
        Ok(Self {
            map: new_hash_table(capacity, key_type)?,
            heap: new_heap(capacity, descending, val_type)?,
            capacity,
            mapper: Vec::with_capacity(capacity),
        })
    }

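    /// Registers the key column of the current input batch with the hash table and
    /// its value column with the heap, so that subsequent [`Self::insert`] calls can
    /// refer to rows by index.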
    pub fn set_batch(&mut self, ids: ArrayRef, vals: ArrayRef) {
        self.map.set_batch(ids);
        self.heap.set_batch(vals);
    }

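    /// Considers the row at `row_idx` of the current batch: if the heap is full and
    /// the value is worse than everything retained, the row is skipped; if the key is
    /// a new group, it is added to the map and the heap (replacing the worst retained
    /// group when at capacity, as exercised by the tests below); otherwise the
    /// existing group's value is replaced only if the new value is better.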
    pub fn insert(&mut self, row_idx: usize) -> Result<()> {
        assert!(self.map.len() <= self.capacity, "Overflow");

        // if we're full and the new value is worse than all retained values, just bail
        if self.heap.is_worse(row_idx) {
            return Ok(());
        }
        let map = &mut self.mapper;

        // handle group keys we haven't seen yet
        map.clear();
        let replace_idx = self.heap.worst_map_idx();
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: replace_idx kept valid during resizes
        let (map_idx, did_insert) =
            unsafe { self.map.find_or_insert(row_idx, replace_idx, map) };
        if did_insert {
            self.heap.renumber(map);
            map.clear();
            self.heap.insert(row_idx, map_idx, map);
            // JUSTIFICATION
            //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
            //  Soundness: the map was created on the line above, so all the indexes should be valid
            unsafe { self.map.update_heap_idx(map) };
            return Ok(());
        };

        // this is a value for an existing group
        map.clear();
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: map_idx was just found, so it is valid
        let heap_idx = unsafe { self.map.heap_idx_at(map_idx) };
        self.heap.replace_if_better(heap_idx, row_idx, map);
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: the index map was just built, so it will be valid
        unsafe { self.map.update_heap_idx(map) };

        Ok(())
    }

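    /// Drains the heap and the hash table, returning the retained groups as two
    /// columns: the group keys (ids) followed by their values.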
    pub fn emit(&mut self) -> Result<Vec<ArrayRef>> {
        let (vals, map_idxs) = self.heap.drain();
        let ids = unsafe { self.map.take_all(map_idxs) };
        Ok(vec![ids, vals])
    }

    pub fn is_empty(&self) -> bool {
        self.map.len() == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{
        Int64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray,
    };
    use arrow::datatypes::{Field, Schema, SchemaRef};
    use arrow::util::pretty::pretty_format_batches;
    use insta::assert_snapshot;
    use std::sync::Arc;

    #[test]
    fn should_append_with_utf8view() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let mut agg = PriorityMap::new(DataType::Utf8View, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema_utf8view(), cols)?;
        let batch_schema = batch.schema();
        assert_eq!(batch_schema.fields[0].data_type(), &DataType::Utf8View);

        let actual = format!("{}", pretty_format_batches(&[batch])?);
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_append_with_large_utf8() -> Result<()> {
        let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let mut agg = PriorityMap::new(DataType::LargeUtf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_large_schema(), cols)?;
        let batch_schema = batch.schema();
        assert_eq!(batch_schema.fields[0].data_type(), &DataType::LargeUtf8);

        let actual = format!("{}", pretty_format_batches(&[batch])?);
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn should_append() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);

        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_ignore_higher_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);

        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_ignore_lower_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 2        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_ignore_higher_same_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_ignore_lower_same_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_accept_lower_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_accept_higher_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 2        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_accept_lower_for_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_accept_higher_for_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    #[test]
    fn should_handle_null_ids() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;
        agg.insert(2)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = format!("{}", pretty_format_batches(&[batch])?);
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
|          | 3            |
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::Utf8, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }

    fn test_schema_utf8view() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::Utf8View, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }

    fn test_large_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::LargeUtf8, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }
}