// Source: datafusion_physical_plan/aggregates/topk/priority_map.rs

use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable};
use crate::aggregates::topk::heap::{new_heap, ArrowHeap};
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion_common::Result;

26pub struct PriorityMap {
28 map: Box<dyn ArrowHashTable + Send>,
29 heap: Box<dyn ArrowHeap + Send>,
30 capacity: usize,
31 mapper: Vec<(usize, usize)>,
32}
33
34impl PriorityMap {
35 pub fn new(
36 key_type: DataType,
37 val_type: DataType,
38 capacity: usize,
39 descending: bool,
40 ) -> Result<Self> {
41 Ok(Self {
42 map: new_hash_table(capacity, key_type)?,
43 heap: new_heap(capacity, descending, val_type)?,
44 capacity,
45 mapper: Vec::with_capacity(capacity),
46 })
47 }
48
49 pub fn set_batch(&mut self, ids: ArrayRef, vals: ArrayRef) {
50 self.map.set_batch(ids);
51 self.heap.set_batch(vals);
52 }
53
54 pub fn insert(&mut self, row_idx: usize) -> Result<()> {
55 assert!(self.map.len() <= self.capacity, "Overflow");
56
57 if self.heap.is_worse(row_idx) {
59 return Ok(());
60 }
61 let map = &mut self.mapper;
62
63 map.clear();
65 let replace_idx = self.heap.worst_map_idx();
66 let (map_idx, did_insert) =
70 unsafe { self.map.find_or_insert(row_idx, replace_idx, map) };
71 if did_insert {
72 self.heap.renumber(map);
73 map.clear();
74 self.heap.insert(row_idx, map_idx, map);
75 unsafe { self.map.update_heap_idx(map) };
79 return Ok(());
80 };
81
82 map.clear();
84 let heap_idx = unsafe { self.map.heap_idx_at(map_idx) };
88 self.heap.replace_if_better(heap_idx, row_idx, map);
89 unsafe { self.map.update_heap_idx(map) };
93
94 Ok(())
95 }
96
97 pub fn emit(&mut self) -> Result<Vec<ArrayRef>> {
98 let (vals, map_idxs) = self.heap.drain();
99 let ids = unsafe { self.map.take_all(map_idxs) };
100 Ok(vec![ids, vals])
101 }
102
103 pub fn is_empty(&self) -> bool {
104 self.map.len() == 0
105 }
106}
107
#[cfg(test)]
mod tests {
    //! Unit tests for `PriorityMap`.
    //!
    //! Every test loads one batch of `(trace_id, timestamp_ms)` rows, inserts each
    //! row in order, emits the result and compares the pretty-printed table.
    //! The expected tables must keep the exact column padding produced by
    //! `pretty_format_batches`, and must start at column 0 inside the raw strings.

    use super::*;
    use arrow::array::{
        Int64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray,
    };
    use arrow::datatypes::{Field, Schema, SchemaRef};
    use arrow::util::pretty::pretty_format_batches;
    use insta::assert_snapshot;
    use std::sync::Arc;

    /// A single row with a `Utf8View` key is emitted and keeps its key type.
    #[test]
    fn should_append_with_utf8view() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let mut agg = PriorityMap::new(DataType::Utf8View, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema_utf8view(), cols)?;
        let batch_schema = batch.schema();
        assert_eq!(batch_schema.fields[0].data_type(), &DataType::Utf8View);

        let actual = pretty_format_batches(&[batch])?.to_string();
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    /// A single row with a `LargeUtf8` key is emitted and keeps its key type.
    #[test]
    fn should_append_with_large_utf8() -> Result<()> {
        let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let mut agg = PriorityMap::new(DataType::LargeUtf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_large_schema(), cols)?;
        let batch_schema = batch.schema();
        assert_eq!(batch_schema.fields[0].data_type(), &DataType::LargeUtf8);

        let actual = pretty_format_batches(&[batch])?.to_string();
        let expected = r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        .trim();
        assert_eq!(actual, expected);

        Ok(())
    }

    /// A single `Utf8` row inserted into an empty map is emitted as-is.
    #[test]
    fn should_append() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();

        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Capacity 1, ascending: a second group with a higher value is dropped.
    #[test]
    fn should_ignore_higher_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();

        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Capacity 1, descending: a second group with a lower value is dropped.
    #[test]
    fn should_ignore_lower_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 2        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Ascending: a higher value for an already-tracked group does not replace it.
    #[test]
    fn should_ignore_higher_same_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Descending: a lower value for an already-tracked group does not replace it.
    #[test]
    fn should_ignore_lower_same_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Capacity 1, ascending: a new group with a lower value evicts the old one.
    #[test]
    fn should_accept_lower_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Capacity 1, descending: a new group with a higher value evicts the old one.
    #[test]
    fn should_accept_higher_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 2        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Ascending: a lower value for an already-tracked group replaces the old value.
    #[test]
    fn should_accept_lower_for_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Descending: a higher value for an already-tracked group replaces the old value.
    #[test]
    fn should_accept_higher_for_group() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
| 1        | 2            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Null ids form one group of their own: both null rows collapse into a single
    /// entry that keeps the best (here: highest) value.
    #[test]
    fn should_handle_null_ids() -> Result<()> {
        let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None]));
        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?;
        agg.set_batch(ids, vals);
        agg.insert(0)?;
        agg.insert(1)?;
        agg.insert(2)?;

        let cols = agg.emit()?;
        let batch = RecordBatch::try_new(test_schema(), cols)?;
        let actual = pretty_format_batches(&[batch])?.to_string();
        assert_snapshot!(actual, @r#"
+----------+--------------+
| trace_id | timestamp_ms |
+----------+--------------+
|          | 3            |
| 1        | 1            |
+----------+--------------+
        "#
        );

        Ok(())
    }

    /// Schema with a nullable `Utf8` key column and a nullable `Int64` value column.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::Utf8, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }

    /// Same layout as `test_schema`, but the key column is `Utf8View`.
    fn test_schema_utf8view() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::Utf8View, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }

    /// Same layout as `test_schema`, but the key column is `LargeUtf8`.
    fn test_large_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("trace_id", DataType::LargeUtf8, true),
            Field::new("timestamp_ms", DataType::Int64, true),
        ]))
    }
}