datafusion_physical_plan/topk/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! TopK: Combination of Sort / LIMIT
19
20use arrow::{
21    array::{Array, AsArray},
22    compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder},
23    row::{RowConverter, Rows, SortField},
24};
25use datafusion_expr::{ColumnarValue, Operator};
26use std::mem::size_of;
27use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
28
29use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
30use crate::spill::get_record_batch_memory_size;
31use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
32
33use arrow::array::{ArrayRef, RecordBatch};
34use arrow::datatypes::SchemaRef;
35use datafusion_common::{
36    internal_datafusion_err, internal_err, HashMap, Result, ScalarValue,
37};
38use datafusion_execution::{
39    memory_pool::{MemoryConsumer, MemoryReservation},
40    runtime_env::RuntimeEnv,
41};
42use datafusion_physical_expr::{
43    expressions::{is_not_null, is_null, lit, BinaryExpr, DynamicFilterPhysicalExpr},
44    PhysicalExpr,
45};
46use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
47use parking_lot::RwLock;
48
49/// Global TopK
50///
51/// # Background
52///
53/// "Top K" is a common query optimization used for queries such as
54/// "find the top 3 customers by revenue". The (simplified) SQL for
55/// such a query might be:
56///
57/// ```sql
58/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3;
59/// ```
60///
61/// The simple plan would be:
62///
63/// ```sql
64/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3;
65/// +--------------+----------------------------------------+
66/// | plan_type    | plan                                   |
67/// +--------------+----------------------------------------+
68/// | logical_plan | Limit: 3                               |
69/// |              |   Sort: revenue DESC NULLS FIRST       |
70/// |              |     Projection: customer_id, revenue   |
71/// |              |       TableScan: sales                 |
72/// +--------------+----------------------------------------+
73/// ```
74///
75/// While this plan produces the correct answer, it will fully sorts the
76/// input before discarding everything other than the top 3 elements.
77///
78/// The same answer can be produced by simply keeping track of the top
79/// K=3 elements, reducing the total amount of required buffer memory.
80///
81/// # Partial Sort Optimization
82///
83/// This implementation additionally optimizes queries where the input is already
84/// partially sorted by a common prefix of the requested ordering. Once the top K
85/// heap is full, if subsequent rows are guaranteed to be strictly greater (in sort
86/// order) on this prefix than the largest row currently stored, the operator
87/// safely terminates early.
88///
89/// ## Example
90///
91/// For input sorted by `(day DESC)`, but not by `timestamp`, a query such as:
92///
93/// ```sql
94/// SELECT day, timestamp FROM sensor ORDER BY day DESC, timestamp DESC LIMIT 10;
95/// ```
96///
97/// can terminate scanning early once sufficient rows from the latest days have been
98/// collected, skipping older data.
99///
100/// # Structure
101///
102/// This operator tracks the top K items using a `TopKHeap`.
103pub struct TopK {
104    /// schema of the output (and the input)
105    schema: SchemaRef,
106    /// Runtime metrics
107    metrics: TopKMetrics,
108    /// Reservation
109    reservation: MemoryReservation,
110    /// The target number of rows for output batches
111    batch_size: usize,
112    /// sort expressions
113    expr: LexOrdering,
114    /// row converter, for sort keys
115    row_converter: RowConverter,
116    /// scratch space for converting rows
117    scratch_rows: Rows,
118    /// stores the top k values and their sort key values, in order
119    heap: TopKHeap,
120    /// row converter, for common keys between the sort keys and the input ordering
121    common_sort_prefix_converter: Option<RowConverter>,
122    /// Common sort prefix between the input and the sort expressions to allow early exit optimization
123    common_sort_prefix: Arc<[PhysicalSortExpr]>,
124    /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
125    filter: Arc<RwLock<TopKDynamicFilters>>,
126    /// If true, indicates that all rows of subsequent batches are guaranteed
127    /// to be greater (by byte order, after row conversion) than the top K,
128    /// which means the top K won't change and the computation can be finished early.
129    pub(crate) finished: bool,
130}
131
132#[derive(Debug, Clone)]
133pub struct TopKDynamicFilters {
134    /// The current *global* threshold for the dynamic filter.
135    /// This is shared across all partitions and is updated by any of them.
136    /// Stored as row bytes for efficient comparison.
137    threshold_row: Option<Vec<u8>>,
138    /// The expression used to evaluate the dynamic filter
139    /// Only updated when lock held for the duration of the update
140    expr: Arc<DynamicFilterPhysicalExpr>,
141}
142
143impl TopKDynamicFilters {
144    /// Create a new `TopKDynamicFilters` with the given expression
145    pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
146        Self {
147            threshold_row: None,
148            expr,
149        }
150    }
151
152    pub fn expr(&self) -> Arc<DynamicFilterPhysicalExpr> {
153        Arc::clone(&self.expr)
154    }
155}
156
/// Rough guess of how many bytes one converted row occupies in the
/// `RowConverter` output; only used to pre-size row buffers.
const ESTIMATED_BYTES_PER_ROW: usize = 20;
159
160fn build_sort_fields(
161    ordering: &[PhysicalSortExpr],
162    schema: &SchemaRef,
163) -> Result<Vec<SortField>> {
164    ordering
165        .iter()
166        .map(|e| {
167            Ok(SortField::new_with_options(
168                e.expr.data_type(schema)?,
169                e.options,
170            ))
171        })
172        .collect::<Result<_>>()
173}
174
175impl TopK {
176    /// Create a new [`TopK`] that stores the top `k` values, as
177    /// defined by the sort expressions in `expr`.
178    // TODO: make a builder or some other nicer API
179    #[allow(clippy::too_many_arguments)]
180    pub fn try_new(
181        partition_id: usize,
182        schema: SchemaRef,
183        common_sort_prefix: Vec<PhysicalSortExpr>,
184        expr: LexOrdering,
185        k: usize,
186        batch_size: usize,
187        runtime: Arc<RuntimeEnv>,
188        metrics: &ExecutionPlanMetricsSet,
189        filter: Arc<RwLock<TopKDynamicFilters>>,
190    ) -> Result<Self> {
191        let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
192            .register(&runtime.memory_pool);
193
194        let sort_fields = build_sort_fields(&expr, &schema)?;
195
196        // TODO there is potential to add special cases for single column sort fields
197        // to improve performance
198        let row_converter = RowConverter::new(sort_fields)?;
199        let scratch_rows =
200            row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);
201
202        let prefix_row_converter = if common_sort_prefix.is_empty() {
203            None
204        } else {
205            let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?;
206            Some(RowConverter::new(input_sort_fields)?)
207        };
208
209        Ok(Self {
210            schema: Arc::clone(&schema),
211            metrics: TopKMetrics::new(metrics, partition_id),
212            reservation,
213            batch_size,
214            expr,
215            row_converter,
216            scratch_rows,
217            heap: TopKHeap::new(k, batch_size),
218            common_sort_prefix_converter: prefix_row_converter,
219            common_sort_prefix: Arc::from(common_sort_prefix),
220            finished: false,
221            filter,
222        })
223    }
224
225    /// Insert `batch`, remembering if any of its values are among
226    /// the top k seen so far.
227    pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> {
228        // Updates on drop
229        let baseline = self.metrics.baseline.clone();
230        let _timer = baseline.elapsed_compute().timer();
231
232        let mut sort_keys: Vec<ArrayRef> = self
233            .expr
234            .iter()
235            .map(|expr| {
236                let value = expr.expr.evaluate(&batch)?;
237                value.into_array(batch.num_rows())
238            })
239            .collect::<Result<Vec<_>>>()?;
240
241        let mut selected_rows = None;
242
243        // If a filter is provided, update it with the new rows
244        let filter = self.filter.read().expr.current()?;
245        let filtered = filter.evaluate(&batch)?;
246        let num_rows = batch.num_rows();
247        let array = filtered.into_array(num_rows)?;
248        let mut filter = array.as_boolean().clone();
249        let true_count = filter.true_count();
250        if true_count == 0 {
251            // nothing to filter, so no need to update
252            return Ok(());
253        }
254        // only update the keys / rows if the filter does not match all rows
255        if true_count < num_rows {
256            // Indices in `set_indices` should be correct if filter contains nulls
257            // So we prepare the filter here. Note this is also done in the `FilterBuilder`
258            // so there is no overhead to do this here.
259            if filter.nulls().is_some() {
260                filter = prep_null_mask_filter(&filter);
261            }
262
263            let filter_predicate = FilterBuilder::new(&filter);
264            let filter_predicate = if sort_keys.len() > 1 {
265                // Optimize filter when it has multiple sort keys
266                filter_predicate.optimize().build()
267            } else {
268                filter_predicate.build()
269            };
270            selected_rows = Some(filter);
271            sort_keys = sort_keys
272                .iter()
273                .map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
274                .collect::<Result<Vec<_>>>()?;
275        }
276        // reuse existing `Rows` to avoid reallocations
277        let rows = &mut self.scratch_rows;
278        rows.clear();
279        self.row_converter.append(rows, &sort_keys)?;
280
281        let mut batch_entry = self.heap.register_batch(batch.clone());
282
283        let replacements = match selected_rows {
284            Some(filter) => {
285                self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
286            }
287            None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
288        };
289
290        if replacements > 0 {
291            self.metrics.row_replacements.add(replacements);
292
293            self.heap.insert_batch_entry(batch_entry);
294
295            // conserve memory
296            self.heap.maybe_compact()?;
297
298            // update memory reservation
299            self.reservation.try_resize(self.size())?;
300
301            // flag the topK as finished if we know that all
302            // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K,
303            // which means the top K won't change and the computation can be finished early.
304            self.attempt_early_completion(&batch)?;
305
306            // update the filter representation of our TopK heap
307            self.update_filter()?;
308        }
309
310        Ok(())
311    }
312
313    fn find_new_topk_items(
314        &mut self,
315        items: impl Iterator<Item = usize>,
316        batch_entry: &mut RecordBatchEntry,
317    ) -> usize {
318        let mut replacements = 0;
319        let rows = &mut self.scratch_rows;
320        for (index, row) in items.zip(rows.iter()) {
321            match self.heap.max() {
322                // heap has k items, and the new row is greater than the
323                // current max in the heap ==> it is not a new topk
324                Some(max_row) if row.as_ref() >= max_row.row() => {}
325                // don't yet have k items or new item is lower than the currently k low values
326                None | Some(_) => {
327                    self.heap.add(batch_entry, row, index);
328                    replacements += 1;
329                }
330            }
331        }
332        replacements
333    }
334
335    /// Update the filter representation of our TopK heap.
336    /// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`,
337    /// and the current heap values `[(1, 5), (1, 4), (2, 3)]`,
338    /// the filter will be updated to:
339    ///
340    /// ```sql
341    /// (a > 1 OR (a = 1 AND b < 5)) AND
342    /// (a > 1 OR (a = 1 AND b < 4)) AND
343    /// (a > 2 OR (a = 2 AND b < 3))
344    /// ```
345    fn update_filter(&mut self) -> Result<()> {
346        // If the heap doesn't have k elements yet, we can't create thresholds
347        let Some(max_row) = self.heap.max() else {
348            return Ok(());
349        };
350
351        let new_threshold_row = &max_row.row;
352
353        // Fast path: check if the current value in topk is better than what is
354        // currently set in the filter with a read only lock
355        let needs_update = self
356            .filter
357            .read()
358            .threshold_row
359            .as_ref()
360            .map(|current_row| {
361                // new < current means new threshold is more selective
362                new_threshold_row < current_row
363            })
364            .unwrap_or(true); // No current threshold, so we need to set one
365
366        // exit early if the current values are better
367        if !needs_update {
368            return Ok(());
369        }
370
371        // Extract scalar values BEFORE acquiring lock to reduce critical section
372        let thresholds = match self.heap.get_threshold_values(&self.expr)? {
373            Some(t) => t,
374            None => return Ok(()),
375        };
376
377        // Build the filter expression OUTSIDE any synchronization
378        let predicate = Self::build_filter_expression(&self.expr, thresholds)?;
379        let new_threshold = new_threshold_row.to_vec();
380
381        // update the threshold. Since there was a lock gap, we must check if it is still the best
382        // may have changed while we were building the expression without the lock
383        let mut filter = self.filter.write();
384        let old_threshold = filter.threshold_row.take();
385
386        // Update filter if we successfully updated the threshold
387        // (or if there was no previous threshold and we're the first)
388        match old_threshold {
389            Some(old_threshold) => {
390                // new threshold is still better than the old one
391                if new_threshold.as_slice() < old_threshold.as_slice() {
392                    filter.threshold_row = Some(new_threshold);
393                } else {
394                    // some other thread updated the threshold to a better
395                    // one while we were building so there is no need to
396                    // update the filter
397                    filter.threshold_row = Some(old_threshold);
398                    return Ok(());
399                }
400            }
401            None => {
402                // No previous threshold, so we can set the new one
403                filter.threshold_row = Some(new_threshold);
404            }
405        };
406
407        // Update the filter expression
408        if let Some(pred) = predicate {
409            if !pred.eq(&lit(true)) {
410                filter.expr.update(pred)?;
411            }
412        }
413
414        Ok(())
415    }
416
417    /// Build the filter expression with the given thresholds.
418    /// This is now called outside of any locks to reduce critical section time.
419    fn build_filter_expression(
420        sort_exprs: &[PhysicalSortExpr],
421        thresholds: Vec<ScalarValue>,
422    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
423        // Create filter expressions for each threshold
424        let mut filters: Vec<Arc<dyn PhysicalExpr>> =
425            Vec::with_capacity(thresholds.len());
426
427        let mut prev_sort_expr: Option<Arc<dyn PhysicalExpr>> = None;
428        for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) {
429            // Create the appropriate operator based on sort order
430            let op = if sort_expr.options.descending {
431                // For descending sort, we want col > threshold (exclude smaller values)
432                Operator::Gt
433            } else {
434                // For ascending sort, we want col < threshold (exclude larger values)
435                Operator::Lt
436            };
437
438            let value_null = value.is_null();
439
440            let comparison = Arc::new(BinaryExpr::new(
441                Arc::clone(&sort_expr.expr),
442                op,
443                lit(value.clone()),
444            ));
445
446            let comparison_with_null = match (sort_expr.options.nulls_first, value_null) {
447                // For nulls first, transform to (threshold.value is not null) and (threshold.expr is null or comparison)
448                (true, true) => lit(false),
449                (true, false) => Arc::new(BinaryExpr::new(
450                    is_null(Arc::clone(&sort_expr.expr))?,
451                    Operator::Or,
452                    comparison,
453                )),
454                // For nulls last, transform to (threshold.value is null and threshold.expr is not null)
455                // or (threshold.value is not null and comparison)
456                (false, true) => is_not_null(Arc::clone(&sort_expr.expr))?,
457                (false, false) => comparison,
458            };
459
460            let mut eq_expr = Arc::new(BinaryExpr::new(
461                Arc::clone(&sort_expr.expr),
462                Operator::Eq,
463                lit(value.clone()),
464            ));
465
466            if value_null {
467                eq_expr = Arc::new(BinaryExpr::new(
468                    is_null(Arc::clone(&sort_expr.expr))?,
469                    Operator::Or,
470                    eq_expr,
471                ));
472            }
473
474            // For a query like order by a, b, the filter for column `b` is only applied if
475            // the condition a = threshold.value (considering null equality) is met.
476            // Therefore, we add equality predicates for all preceding fields to the filter logic of the current field,
477            // and include the current field's equality predicate in `prev_sort_expr` for use with subsequent fields.
478            match prev_sort_expr.take() {
479                None => {
480                    prev_sort_expr = Some(eq_expr);
481                    filters.push(comparison_with_null);
482                }
483                Some(p) => {
484                    filters.push(Arc::new(BinaryExpr::new(
485                        Arc::clone(&p),
486                        Operator::And,
487                        comparison_with_null,
488                    )));
489
490                    prev_sort_expr =
491                        Some(Arc::new(BinaryExpr::new(p, Operator::And, eq_expr)));
492                }
493            }
494        }
495
496        let dynamic_predicate = filters
497            .into_iter()
498            .reduce(|a, b| Arc::new(BinaryExpr::new(a, Operator::Or, b)));
499
500        Ok(dynamic_predicate)
501    }
502
503    /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full,
504    /// check if the computation can be finished early.
505    /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
506    /// comparing only on the shared prefix columns.
507    fn attempt_early_completion(&mut self, batch: &RecordBatch) -> Result<()> {
508        // Early exit if the batch is empty as there is no last row to extract from it.
509        if batch.num_rows() == 0 {
510            return Ok(());
511        }
512
513        // prefix_row_converter is only `Some` if the input ordering has a common prefix with the TopK,
514        // so early exit if it is `None`.
515        let Some(prefix_converter) = &self.common_sort_prefix_converter else {
516            return Ok(());
517        };
518
519        // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
520        let Some(max_topk_row) = self.heap.max() else {
521            return Ok(());
522        };
523
524        // Evaluate the prefix for the last row of the current batch.
525        let last_row_idx = batch.num_rows() - 1;
526        let mut batch_prefix_scratch =
527            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
528
529        self.compute_common_sort_prefix(batch, last_row_idx, &mut batch_prefix_scratch)?;
530
531        // Retrieve the max row from the heap.
532        let store_entry = self
533            .heap
534            .store
535            .get(max_topk_row.batch_id)
536            .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
537        let max_batch = &store_entry.batch;
538        let mut heap_prefix_scratch =
539            prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
540        self.compute_common_sort_prefix(
541            max_batch,
542            max_topk_row.index,
543            &mut heap_prefix_scratch,
544        )?;
545
546        // If the last row's prefix is strictly greater than the max prefix, mark as finished.
547        if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() {
548            self.finished = true;
549        }
550
551        Ok(())
552    }
553
554    // Helper function to compute the prefix for a given batch and row index, storing the result in scratch.
555    fn compute_common_sort_prefix(
556        &self,
557        batch: &RecordBatch,
558        last_row_idx: usize,
559        scratch: &mut Rows,
560    ) -> Result<()> {
561        let last_row: Vec<ArrayRef> = self
562            .common_sort_prefix
563            .iter()
564            .map(|expr| {
565                expr.expr
566                    .evaluate(&batch.slice(last_row_idx, 1))?
567                    .into_array(1)
568            })
569            .collect::<Result<_>>()?;
570
571        self.common_sort_prefix_converter
572            .as_ref()
573            .unwrap()
574            .append(scratch, &last_row)?;
575        Ok(())
576    }
577
578    /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap
579    pub fn emit(self) -> Result<SendableRecordBatchStream> {
580        let Self {
581            schema,
582            metrics,
583            reservation: _,
584            batch_size,
585            expr: _,
586            row_converter: _,
587            scratch_rows: _,
588            mut heap,
589            common_sort_prefix_converter: _,
590            common_sort_prefix: _,
591            finished: _,
592            filter: _,
593        } = self;
594        let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
595
596        // break into record batches as needed
597        let mut batches = vec![];
598        if let Some(mut batch) = heap.emit()? {
599            metrics.baseline.output_rows().add(batch.num_rows());
600
601            loop {
602                if batch.num_rows() <= batch_size {
603                    batches.push(Ok(batch));
604                    break;
605                } else {
606                    batches.push(Ok(batch.slice(0, batch_size)));
607                    let remaining_length = batch.num_rows() - batch_size;
608                    batch = batch.slice(batch_size, remaining_length);
609                }
610            }
611        };
612        Ok(Box::pin(RecordBatchStreamAdapter::new(
613            schema,
614            futures::stream::iter(batches),
615        )))
616    }
617
618    /// return the size of memory used by this operator, in bytes
619    fn size(&self) -> usize {
620        size_of::<Self>()
621            + self.row_converter.size()
622            + self.scratch_rows.size()
623            + self.heap.size()
624    }
625}
626
627struct TopKMetrics {
628    /// metrics
629    pub baseline: BaselineMetrics,
630
631    /// count of how many rows were replaced in the heap
632    pub row_replacements: Count,
633}
634
635impl TopKMetrics {
636    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
637        Self {
638            baseline: BaselineMetrics::new(metrics, partition),
639            row_replacements: MetricBuilder::new(metrics)
640                .counter("row_replacements", partition),
641        }
642    }
643}
644
645/// This structure keeps at most the *smallest* k items, using the
646/// [arrow::row] format for sort keys. While it is called "topK" for
647/// values like `1, 2, 3, 4, 5` the "top 3" really means the
648/// *smallest* 3 , `1, 2, 3`, not the *largest* 3 `3, 4, 5`.
649///
650/// Using the `Row` format handles things such as ascending vs
651/// descending and nulls first vs nulls last.
652struct TopKHeap {
653    /// The maximum number of elements to store in this heap.
654    k: usize,
655    /// The target number of rows for output batches
656    batch_size: usize,
657    /// Storage for up at most `k` items using a BinaryHeap. Reversed
658    /// so that the smallest k so far is on the top
659    inner: BinaryHeap<TopKRow>,
660    /// Storage the original row values (TopKRow only has the sort key)
661    store: RecordBatchStore,
662    /// The size of all owned data held by this heap
663    owned_bytes: usize,
664}
665
666impl TopKHeap {
667    fn new(k: usize, batch_size: usize) -> Self {
668        assert!(k > 0);
669        Self {
670            k,
671            batch_size,
672            inner: BinaryHeap::new(),
673            store: RecordBatchStore::new(),
674            owned_bytes: 0,
675        }
676    }
677
678    /// Register a [`RecordBatch`] with the heap, returning the
679    /// appropriate entry
680    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
681        self.store.register(batch)
682    }
683
684    /// Insert a [`RecordBatchEntry`] created by a previous call to
685    /// [`Self::register_batch`] into storage.
686    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
687        self.store.insert(entry)
688    }
689
690    /// Returns the largest value stored by the heap if there are k
691    /// items, otherwise returns None. Remember this structure is
692    /// keeping the "smallest" k values
693    fn max(&self) -> Option<&TopKRow> {
694        if self.inner.len() < self.k {
695            None
696        } else {
697            self.inner.peek()
698        }
699    }
700
701    /// Adds `row` to this heap. If inserting this new item would
702    /// increase the size past `k`, removes the previously smallest
703    /// item.
704    fn add(
705        &mut self,
706        batch_entry: &mut RecordBatchEntry,
707        row: impl AsRef<[u8]>,
708        index: usize,
709    ) {
710        let batch_id = batch_entry.id;
711        batch_entry.uses += 1;
712
713        assert!(self.inner.len() <= self.k);
714        let row = row.as_ref();
715
716        // Reuse storage for evicted item if possible
717        let new_top_k = if self.inner.len() == self.k {
718            let prev_min = self.inner.pop().unwrap();
719
720            // Update batch use
721            if prev_min.batch_id == batch_entry.id {
722                batch_entry.uses -= 1;
723            } else {
724                self.store.unuse(prev_min.batch_id);
725            }
726
727            // update memory accounting
728            self.owned_bytes -= prev_min.owned_size();
729            prev_min.with_new_row(row, batch_id, index)
730        } else {
731            TopKRow::new(row, batch_id, index)
732        };
733
734        self.owned_bytes += new_top_k.owned_size();
735
736        // put the new row into the heap
737        self.inner.push(new_top_k)
738    }
739
740    /// Returns the values stored in this heap, from values low to
741    /// high, as a single [`RecordBatch`], resetting the inner heap
742    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
743        Ok(self.emit_with_state()?.0)
744    }
745
746    /// Returns the values stored in this heap, from values low to
747    /// high, as a single [`RecordBatch`], and a sorted vec of the
748    /// current heap's contents
749    pub fn emit_with_state(&mut self) -> Result<(Option<RecordBatch>, Vec<TopKRow>)> {
750        // generate sorted rows
751        let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec();
752
753        if self.store.is_empty() {
754            return Ok((None, topk_rows));
755        }
756
757        // Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then
758        // build the `indices` vec below. This is needed since the batch ids are not continuous.
759        let mut record_batches = Vec::new();
760        let mut batch_id_array_pos = HashMap::new();
761        for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() {
762            record_batches.push(&batch.batch);
763            batch_id_array_pos.insert(*batch_id, array_pos);
764        }
765
766        let indices: Vec<_> = topk_rows
767            .iter()
768            .map(|k| (batch_id_array_pos[&k.batch_id], k.index))
769            .collect();
770
771        // At this point `indices` contains indexes within the
772        // rows and `input_arrays` contains a reference to the
773        // relevant RecordBatch for that index. `interleave_record_batch` pulls
774        // them together into a single new batch
775        let new_batch = interleave_record_batch(&record_batches, &indices)?;
776
777        Ok((Some(new_batch), topk_rows))
778    }
779
780    /// Compact this heap, rewriting all stored batches into a single
781    /// input batch
782    pub fn maybe_compact(&mut self) -> Result<()> {
783        // we compact if the number of "unused" rows in the store is
784        // past some pre-defined threshold. Target holding up to
785        // around 20 batches, but handle cases of large k where some
786        // batches might be partially full
787        let max_unused_rows = (20 * self.batch_size) + self.k;
788        let unused_rows = self.store.unused_rows();
789
790        // don't compact if the store has one extra batch or
791        // unused rows is under the threshold
792        if self.store.len() <= 2 || unused_rows < max_unused_rows {
793            return Ok(());
794        }
795        // at first, compact the entire thing always into a new batch
796        // (maybe we can get fancier in the future about ignoring
797        // batches that have a high usage ratio already
798
799        // Note: new batch is in the same order as inner
800        let num_rows = self.inner.len();
801        let (new_batch, mut topk_rows) = self.emit_with_state()?;
802        let Some(new_batch) = new_batch else {
803            return Ok(());
804        };
805
806        // clear all old entries in store (this invalidates all
807        // store_ids in `inner`)
808        self.store.clear();
809
810        let mut batch_entry = self.register_batch(new_batch);
811        batch_entry.uses = num_rows;
812
813        // rewrite all existing entries to use the new batch, and
814        // remove old entries. The sortedness and their relative
815        // position do not change
816        for (i, topk_row) in topk_rows.iter_mut().enumerate() {
817            topk_row.batch_id = batch_entry.id;
818            topk_row.index = i;
819        }
820        self.insert_batch_entry(batch_entry);
821        // restore the heap
822        self.inner = BinaryHeap::from(topk_rows);
823
824        Ok(())
825    }
826
827    /// return the size of memory used by this heap, in bytes
828    fn size(&self) -> usize {
829        size_of::<Self>()
830            + (self.inner.capacity() * size_of::<TopKRow>())
831            + self.store.size()
832            + self.owned_bytes
833    }
834
835    fn get_threshold_values(
836        &self,
837        sort_exprs: &[PhysicalSortExpr],
838    ) -> Result<Option<Vec<ScalarValue>>> {
839        // If the heap doesn't have k elements yet, we can't create thresholds
840        let max_row = match self.max() {
841            Some(row) => row,
842            None => return Ok(None),
843        };
844
845        // Get the batch that contains the max row
846        let batch_entry = match self.store.get(max_row.batch_id) {
847            Some(entry) => entry,
848            None => return internal_err!("Invalid batch ID in TopKRow"),
849        };
850
851        // Extract threshold values for each sort expression
852        let mut scalar_values = Vec::with_capacity(sort_exprs.len());
853        for sort_expr in sort_exprs {
854            // Extract the value for this column from the max row
855            let expr = Arc::clone(&sort_expr.expr);
856            let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?;
857
858            // Convert to scalar value - should be a single value since we're evaluating on a single row batch
859            let scalar = match value {
860                ColumnarValue::Scalar(scalar) => scalar,
861                ColumnarValue::Array(array) if array.len() == 1 => {
862                    // Extract the first (and only) value from the array
863                    ScalarValue::try_from_array(&array, 0)?
864                }
865                array => {
866                    return internal_err!("Expected a scalar value, got {:?}", array)
867                }
868            };
869
870            scalar_values.push(scalar);
871        }
872
873        Ok(Some(scalar_values))
874    }
875}
876
/// Represents one of the top K rows held in this heap. Orders
/// according to memcmp of row (e.g. the arrow Row format, but could
/// also be primitive values)
///
/// Equality and ordering are both defined solely by the `row` bytes, so
/// `PartialEq`, `Eq`, `PartialOrd` and `Ord` all agree with one another.
/// `batch_id` and `index` only locate the row and never take part in
/// comparisons.
///
/// Reuses allocations to minimize runtime overhead of creating new Vecs
#[derive(Debug)]
struct TopKRow {
    /// the value of the sort key for this row. This contains the
    /// bytes that could be stored in `OwnedRow` but uses `Vec<u8>` to
    /// reuse allocations.
    row: Vec<u8>,
    /// the RecordBatch this row came from: an id into a [`RecordBatchStore`]
    batch_id: u32,
    /// the index in this record batch the row came from
    index: usize,
}

impl TopKRow {
    /// Create a new TopKRow with new allocation (the bytes of `row` are
    /// copied into a fresh `Vec<u8>`)
    fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self {
        Self {
            row: row.as_ref().to_vec(),
            batch_id,
            index,
        }
    }

    /// Create a new TopKRow reusing the existing allocation: the key buffer
    /// of `self` is cleared and refilled with `new_row`, and the location
    /// fields are replaced by `batch_id` and `index`
    fn with_new_row(
        mut self,
        new_row: impl AsRef<[u8]>,
        batch_id: u32,
        index: usize,
    ) -> Self {
        // `clear` keeps the capacity, so no reallocation happens unless the
        // new key is longer than the buffer
        self.row.clear();
        self.row.extend_from_slice(new_row.as_ref());
        self.batch_id = batch_id;
        self.index = index;
        self
    }

    /// Returns the number of bytes owned by this row in the heap (not
    /// including itself)
    fn owned_size(&self) -> usize {
        self.row.capacity()
    }

    /// Returns a slice to the owned row value
    fn row(&self) -> &[u8] {
        self.row.as_slice()
    }
}

impl PartialEq for TopKRow {
    /// Two rows are equal when their sort keys are equal, matching
    /// [`Ord::cmp`] returning [`Ordering::Equal`]
    fn eq(&self, other: &Self) -> bool {
        self.row == other.row
    }
}

impl Eq for TopKRow {}

impl PartialOrd for TopKRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TopKRow {
    /// Compares the sort key bytes only
    fn cmp(&self, other: &Self) -> Ordering {
        self.row.cmp(&other.row)
    }
}
952
/// A [`RecordBatch`] together with the id it is stored under in a
/// [`RecordBatchStore`] and a count of how many of its rows are in use.
#[derive(Debug)]
struct RecordBatchEntry {
    /// the id this entry is keyed by when inserted into a store
    id: u32,
    /// the tracked batch
    batch: RecordBatch,
    /// for this batch, how many times has it been used. A value of 0 means
    /// that none of the rows in the batch were stored in the topk
    uses: usize,
}
960
/// This structure tracks [`RecordBatch`] by an id so that:
///
/// 1. The batches can be tracked via an id that can be copied cheaply
/// 2. The total memory held by all batches is tracked
#[derive(Debug)]
struct RecordBatchStore {
    /// id generator: the id handed out to the next registered batch
    next_id: u32,
    /// storage: the entries currently held, keyed by their id
    batches: HashMap<u32, RecordBatchEntry>,
    /// total size of all record batches tracked by this store, in bytes
    batches_size: usize,
}
974
975impl RecordBatchStore {
976    fn new() -> Self {
977        Self {
978            next_id: 0,
979            batches: HashMap::new(),
980            batches_size: 0,
981        }
982    }
983
984    /// Register this batch with the store and assign an ID. No
985    /// attempt is made to compare this batch to other batches
986    pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry {
987        let id = self.next_id;
988        self.next_id += 1;
989        RecordBatchEntry { id, batch, uses: 0 }
990    }
991
992    /// Insert a record batch entry into this store, tracking its
993    /// memory use, if it has any uses
994    pub fn insert(&mut self, entry: RecordBatchEntry) {
995        // uses of 0 means that none of the rows in the batch were stored in the topk
996        if entry.uses > 0 {
997            self.batches_size += get_record_batch_memory_size(&entry.batch);
998            self.batches.insert(entry.id, entry);
999        }
1000    }
1001
1002    /// Clear all values in this store, invalidating all previous batch ids
1003    fn clear(&mut self) {
1004        self.batches.clear();
1005        self.batches_size = 0;
1006    }
1007
1008    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
1009        self.batches.get(&id)
1010    }
1011
1012    /// returns the total number of batches stored in this store
1013    fn len(&self) -> usize {
1014        self.batches.len()
1015    }
1016
1017    /// Returns the total number of rows in batches minus the number
1018    /// which are in use
1019    fn unused_rows(&self) -> usize {
1020        self.batches
1021            .values()
1022            .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
1023            .sum()
1024    }
1025
1026    /// returns true if the store has nothing stored
1027    fn is_empty(&self) -> bool {
1028        self.batches.is_empty()
1029    }
1030
1031    /// remove a use from the specified batch id. If the use count
1032    /// reaches zero the batch entry is removed from the store
1033    ///
1034    /// panics if there were no remaining uses of id
1035    pub fn unuse(&mut self, id: u32) {
1036        let remove = if let Some(batch_entry) = self.batches.get_mut(&id) {
1037            batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow");
1038            batch_entry.uses == 0
1039        } else {
1040            panic!("No entry for id {id}");
1041        };
1042
1043        if remove {
1044            let old_entry = self.batches.remove(&id).unwrap();
1045            self.batches_size = self
1046                .batches_size
1047                .checked_sub(get_record_batch_memory_size(&old_entry.batch))
1048                .unwrap();
1049        }
1050    }
1051
1052    /// returns the size of memory used by this store, including all
1053    /// referenced `RecordBatch`es, in bytes
1054    pub fn size(&self) -> usize {
1055        size_of::<Self>()
1056            + self.batches.capacity() * (size_of::<u32>() + size_of::<RecordBatchEntry>())
1057            + self.batches_size
1058    }
1059}
1060
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow_schema::SortOptions;
    use datafusion_common::assert_batches_eq;
    use datafusion_physical_expr::expressions::col;
    use futures::TryStreamExt;

    /// Builds a batch for a two-column (Int32, Float64) `schema` from the
    /// given column values.
    fn int_float_batch(
        schema: &SchemaRef,
        ints: Vec<Option<i32>>,
        floats: Vec<f64>,
    ) -> Result<RecordBatch> {
        let ints: ArrayRef = Arc::new(Int32Array::from(ints));
        let floats: ArrayRef = Arc::new(Float64Array::from(floats));
        Ok(RecordBatch::try_new(Arc::clone(schema), vec![ints, floats])?)
    }

    /// This test ensures the size calculation is correct for RecordBatches with multiple columns.
    #[test]
    fn test_record_batch_store_size() {
        // given: a five-row batch with an Int32 and a Float64 column
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));
        let batch = int_float_batch(
            &schema,
            vec![Some(1), Some(2), Some(3), Some(4), Some(5)], // 5 * 4 = 20
            vec![1.0, 2.0, 3.0, 4.0, 5.0],                     // 5 * 8 = 40
        )
        .unwrap();
        let mut store = RecordBatchStore::new();

        // when the entry is inserted, its memory is accounted for
        store.insert(RecordBatchEntry {
            id: 0,
            batch,
            uses: 1,
        });
        assert_eq!(store.batches_size, 60);

        // when its only use is released, the accounting drops back to zero
        store.unuse(0);
        assert_eq!(store.batches_size, 0);
    }

    /// This test validates that the `try_finish` method marks the TopK operator as finished
    /// when the prefix (on column "a") of the last row in the current batch is strictly greater
    /// than the max top-k row.
    /// The full sort expression is defined on both columns ("a", "b"), but the input ordering is only on "a".
    #[tokio::test]
    async fn test_try_finish_marks_finished_with_prefix() -> Result<()> {
        // Schema with two non-nullable columns.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Sort expression on a named column using the default sort options.
        let default_sort_on = |column: &str| -> Result<PhysicalSortExpr> {
            Ok(PhysicalSortExpr {
                expr: col(column, schema.as_ref())?,
                options: SortOptions::default(),
            })
        };
        let sort_expr_a = default_sort_on("a")?;
        let sort_expr_b = default_sort_on("b")?;

        // The input is ordered on "a" only, which is a prefix of the full
        // sort on ("a", "b").
        let prefix = vec![sort_expr_a.clone()];
        let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]);

        // Dummy runtime environment and metrics.
        let runtime = Arc::new(RuntimeEnv::default());
        let metrics = ExecutionPlanMetricsSet::new();

        // TopK instance with k = 3 and batch_size = 2.
        let mut topk = TopK::try_new(
            0,
            Arc::clone(&schema),
            prefix,
            full_expr,
            3,
            2,
            runtime,
            &metrics,
            Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new(
                DynamicFilterPhysicalExpr::new(vec![], lit(true)),
            )))),
        )?;

        // First batch: a = [1, 1, 2], b = [20.0, 15.0, 30.0].
        let batch1 = int_float_batch(
            &schema,
            vec![Some(1), Some(1), Some(2)],
            vec![20.0, 15.0, 30.0],
        )?;

        // After the first batch the heap must not be finished yet: the
        // prefix of the batch's last row is not strictly greater than the
        // prefix of the max top-k row (both are `2`).
        topk.insert_batch(batch1)?;
        assert!(
            !topk.finished,
            "Expected 'finished' to be false after the first batch."
        );

        // Second batch: a = [2, 3], b = [10.0, 20.0].
        let batch2 =
            int_float_batch(&schema, vec![Some(2), Some(3)], vec![10.0, 20.0])?;

        // The last row of this batch has prefix `3`, strictly greater than
        // the max top-k row's `2`, so try_finish should mark the TopK as
        // finished.
        topk.insert_batch(batch2)?;
        assert!(
            topk.finished,
            "Expected 'finished' to be true after the second batch."
        );

        // The emitted top k rows come from both batches (the row with
        // b = 10.0 is from the second batch).
        let results: Vec<_> = topk.emit()?.try_collect().await?;
        assert_batches_eq!(
            &[
                "+---+------+",
                "| a | b    |",
                "+---+------+",
                "| 1 | 15.0 |",
                "| 1 | 20.0 |",
                "| 2 | 10.0 |",
                "+---+------+",
            ],
            &results
        );

        Ok(())
    }
}