datafusion_physical_plan/sorts/cursor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::cmp::Ordering;
19use std::sync::Arc;
20
21use arrow::array::{
22    types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray,
23    GenericByteViewArray, OffsetSizeTrait, PrimitiveArray, StringViewArray,
24};
25use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
26use arrow::compute::SortOptions;
27use arrow::datatypes::ArrowNativeTypeOp;
28use arrow::row::Rows;
29use datafusion_execution::memory_pool::MemoryReservation;
30
/// A comparable collection of values that a [`Cursor`] walks over
///
/// This is a trait rather than a concrete type because there are several
/// specialized implementations, e.g. for single columns or for normalized
/// multi-column keys ([`Rows`]).
pub trait CursorValues {
    /// Returns the number of values in this collection
    fn len(&self) -> usize;

    /// Returns true if `l[l_idx] == r[r_idx]`
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool;

    /// Returns true if `cursor[idx] == cursor[idx - 1]`
    ///
    /// `idx` must be greater than 0.
    fn eq_to_previous(cursor: &Self, idx: usize) -> bool;

    /// Returns the [`Ordering`] of `l[l_idx]` relative to `r[r_idx]`
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering;
}
48
49/// A comparable cursor, used by sort operations
50///
51/// A `Cursor` is a pointer into a collection of rows, stored in
52/// [`CursorValues`]
53///
54/// ```text
55///
56/// ┌───────────────────────┐
57/// │                       │           ┌──────────────────────┐
58/// │ ┌─────────┐ ┌─────┐   │    ─ ─ ─ ─│      Cursor<T>       │
59/// │ │    1    │ │  A  │   │   │       └──────────────────────┘
60/// │ ├─────────┤ ├─────┤   │
61/// │ │    2    │ │  A  │◀─ ┼ ─ ┘          Cursor<T> tracks an
62/// │ └─────────┘ └─────┘   │                offset within a
63/// │     ...       ...     │                  CursorValues
64/// │                       │
65/// │ ┌─────────┐ ┌─────┐   │
66/// │ │    3    │ │  E  │   │
67/// │ └─────────┘ └─────┘   │
68/// │                       │
69/// │     CursorValues      │
70/// └───────────────────────┘
71///
72///
73/// Store logical rows using
74/// one of several  formats,
75/// with specialized
76/// implementations
77/// depending on the column
78/// types
79#[derive(Debug)]
80pub struct Cursor<T: CursorValues> {
81    offset: usize,
82    values: T,
83}
84
85impl<T: CursorValues> Cursor<T> {
86    /// Create a [`Cursor`] from the given [`CursorValues`]
87    pub fn new(values: T) -> Self {
88        Self { offset: 0, values }
89    }
90
91    /// Returns true if there are no more rows in this cursor
92    pub fn is_finished(&self) -> bool {
93        self.offset == self.values.len()
94    }
95
96    /// Advance the cursor, returning the previous row index
97    pub fn advance(&mut self) -> usize {
98        let t = self.offset;
99        self.offset += 1;
100        t
101    }
102
103    pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor<T>>) -> bool {
104        if self.offset > 0 {
105            self.is_eq_to_prev_row()
106        } else if let Some(prev_cursor) = prev_cursor {
107            self.is_eq_to_prev_row_in_prev_batch(prev_cursor)
108        } else {
109            false
110        }
111    }
112}
113
114impl<T: CursorValues> PartialEq for Cursor<T> {
115    fn eq(&self, other: &Self) -> bool {
116        T::eq(&self.values, self.offset, &other.values, other.offset)
117    }
118}
119
120impl<T: CursorValues> Cursor<T> {
121    fn is_eq_to_prev_row(&self) -> bool {
122        T::eq_to_previous(&self.values, self.offset)
123    }
124
125    fn is_eq_to_prev_row_in_prev_batch(&self, other: &Self) -> bool {
126        assert_eq!(self.offset, 0);
127        T::eq(
128            &self.values,
129            self.offset,
130            &other.values,
131            other.values.len() - 1,
132        )
133    }
134}
135
136impl<T: CursorValues> Eq for Cursor<T> {}
137
138impl<T: CursorValues> PartialOrd for Cursor<T> {
139    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140        Some(self.cmp(other))
141    }
142}
143
144impl<T: CursorValues> Ord for Cursor<T> {
145    fn cmp(&self, other: &Self) -> Ordering {
146        T::compare(&self.values, self.offset, &other.values, other.offset)
147    }
148}
149
150/// Implements [`CursorValues`] for [`Rows`]
151///
152/// Used for sorting when there are multiple columns in the sort key
153#[derive(Debug)]
154pub struct RowValues {
155    rows: Arc<Rows>,
156
157    /// Tracks for the memory used by in the `Rows` of this
158    /// cursor. Freed on drop
159    _reservation: MemoryReservation,
160}
161
162impl RowValues {
163    /// Create a new [`RowValues`] from `rows` and a `reservation`
164    /// that tracks its memory. There must be at least one row
165    ///
166    /// Panics if the reservation is not for exactly `rows.size()`
167    /// bytes or if `rows` is empty.
168    pub fn new(rows: Arc<Rows>, reservation: MemoryReservation) -> Self {
169        assert_eq!(
170            rows.size(),
171            reservation.size(),
172            "memory reservation mismatch"
173        );
174        assert!(rows.num_rows() > 0);
175        Self {
176            rows,
177            _reservation: reservation,
178        }
179    }
180}
181
182impl CursorValues for RowValues {
183    fn len(&self) -> usize {
184        self.rows.num_rows()
185    }
186
187    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
188        l.rows.row(l_idx) == r.rows.row(r_idx)
189    }
190
191    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
192        assert!(idx > 0);
193        cursor.rows.row(idx) == cursor.rows.row(idx - 1)
194    }
195
196    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
197        l.rows.row(l_idx).cmp(&r.rows.row(r_idx))
198    }
199}
200
201/// An [`Array`] that can be converted into [`CursorValues`]
202pub trait CursorArray: Array + 'static {
203    type Values: CursorValues;
204
205    fn values(&self) -> Self::Values;
206}
207
208impl<T: ArrowPrimitiveType> CursorArray for PrimitiveArray<T> {
209    type Values = PrimitiveValues<T::Native>;
210
211    fn values(&self) -> Self::Values {
212        PrimitiveValues(self.values().clone())
213    }
214}
215
216#[derive(Debug)]
217pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>);
218
219impl<T: ArrowNativeTypeOp> CursorValues for PrimitiveValues<T> {
220    fn len(&self) -> usize {
221        self.0.len()
222    }
223
224    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
225        l.0[l_idx].is_eq(r.0[r_idx])
226    }
227
228    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
229        assert!(idx > 0);
230        cursor.0[idx].is_eq(cursor.0[idx - 1])
231    }
232
233    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
234        l.0[l_idx].compare(r.0[r_idx])
235    }
236}
237
238pub struct ByteArrayValues<T: OffsetSizeTrait> {
239    offsets: OffsetBuffer<T>,
240    values: Buffer,
241}
242
243impl<T: OffsetSizeTrait> ByteArrayValues<T> {
244    fn value(&self, idx: usize) -> &[u8] {
245        assert!(idx < self.len());
246        // Safety: offsets are valid and checked bounds above
247        unsafe {
248            let start = self.offsets.get_unchecked(idx).as_usize();
249            let end = self.offsets.get_unchecked(idx + 1).as_usize();
250            self.values.get_unchecked(start..end)
251        }
252    }
253}
254
255impl<T: OffsetSizeTrait> CursorValues for ByteArrayValues<T> {
256    fn len(&self) -> usize {
257        self.offsets.len() - 1
258    }
259
260    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
261        l.value(l_idx) == r.value(r_idx)
262    }
263
264    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
265        assert!(idx > 0);
266        cursor.value(idx) == cursor.value(idx - 1)
267    }
268
269    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
270        l.value(l_idx).cmp(r.value(r_idx))
271    }
272}
273
274impl<T: ByteArrayType> CursorArray for GenericByteArray<T> {
275    type Values = ByteArrayValues<T::Offset>;
276
277    fn values(&self) -> Self::Values {
278        ByteArrayValues {
279            offsets: self.offsets().clone(),
280            values: self.values().clone(),
281        }
282    }
283}
284
285impl CursorArray for StringViewArray {
286    type Values = StringViewArray;
287    fn values(&self) -> Self {
288        self.gc()
289    }
290}
291
292impl CursorValues for StringViewArray {
293    fn len(&self) -> usize {
294        self.views().len()
295    }
296
297    #[inline(always)]
298    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
299        // SAFETY: Both l_idx and r_idx are guaranteed to be within bounds,
300        // and any null-checks are handled in the outer layers.
301        // Fast path: Compare the lengths before full byte comparison.
302        let l_view = unsafe { l.views().get_unchecked(l_idx) };
303        let r_view = unsafe { r.views().get_unchecked(r_idx) };
304
305        if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
306            return l_view == r_view;
307        }
308
309        let l_len = *l_view as u32;
310        let r_len = *r_view as u32;
311        if l_len != r_len {
312            return false;
313        }
314
315        unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx).is_eq() }
316    }
317
318    #[inline(always)]
319    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
320        // SAFETY: The caller guarantees that idx > 0 and the indices are valid.
321        // Already checked it in is_eq_to_prev_one function
322        // Fast path: Compare the lengths of the current and previous views.
323        let l_view = unsafe { cursor.views().get_unchecked(idx) };
324        let r_view = unsafe { cursor.views().get_unchecked(idx - 1) };
325        if cursor.data_buffers().is_empty() {
326            return l_view == r_view;
327        }
328
329        let l_len = *l_view as u32;
330        let r_len = *r_view as u32;
331
332        if l_len != r_len {
333            return false;
334        }
335
336        unsafe {
337            GenericByteViewArray::compare_unchecked(cursor, idx, cursor, idx - 1).is_eq()
338        }
339    }
340
341    #[inline(always)]
342    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
343        // SAFETY: Prior assertions guarantee that l_idx and r_idx are valid indices.
344        // Null-checks are assumed to have been handled in the wrapper (e.g., ArrayValues).
345        // And the bound is checked in is_finished, it is safe to call get_unchecked
346        if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
347            let l_view = unsafe { l.views().get_unchecked(l_idx) };
348            let r_view = unsafe { r.views().get_unchecked(r_idx) };
349            return StringViewArray::inline_key_fast(*l_view)
350                .cmp(&StringViewArray::inline_key_fast(*r_view));
351        }
352
353        unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) }
354    }
355}
356
357/// A collection of sorted, nullable [`CursorValues`]
358///
359/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
360#[derive(Debug)]
361pub struct ArrayValues<T: CursorValues> {
362    values: T,
363    // If nulls first, the first non-null index
364    // Otherwise, the first null index
365    null_threshold: usize,
366    options: SortOptions,
367
368    /// Tracks the memory used by the values array,
369    /// freed on drop.
370    _reservation: MemoryReservation,
371}
372
373impl<T: CursorValues> ArrayValues<T> {
374    /// Create a new [`ArrayValues`] from the provided `values` sorted according
375    /// to `options`.
376    ///
377    /// Panics if the array is empty
378    pub fn new<A: CursorArray<Values = T>>(
379        options: SortOptions,
380        array: &A,
381        reservation: MemoryReservation,
382    ) -> Self {
383        assert!(array.len() > 0, "Empty array passed to FieldCursor");
384        let null_threshold = match options.nulls_first {
385            true => array.null_count(),
386            false => array.len() - array.null_count(),
387        };
388
389        Self {
390            values: array.values(),
391            null_threshold,
392            options,
393            _reservation: reservation,
394        }
395    }
396
397    fn is_null(&self, idx: usize) -> bool {
398        (idx < self.null_threshold) == self.options.nulls_first
399    }
400}
401
402impl<T: CursorValues> CursorValues for ArrayValues<T> {
403    fn len(&self) -> usize {
404        self.values.len()
405    }
406
407    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
408        match (l.is_null(l_idx), r.is_null(r_idx)) {
409            (true, true) => true,
410            (false, false) => T::eq(&l.values, l_idx, &r.values, r_idx),
411            _ => false,
412        }
413    }
414
415    fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
416        assert!(idx > 0);
417        match (cursor.is_null(idx), cursor.is_null(idx - 1)) {
418            (true, true) => true,
419            (false, false) => T::eq(&cursor.values, idx, &cursor.values, idx - 1),
420            _ => false,
421        }
422    }
423
424    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
425        match (l.is_null(l_idx), r.is_null(r_idx)) {
426            (true, true) => Ordering::Equal,
427            (true, false) => match l.options.nulls_first {
428                true => Ordering::Less,
429                false => Ordering::Greater,
430            },
431            (false, true) => match l.options.nulls_first {
432                true => Ordering::Greater,
433                false => Ordering::Less,
434            },
435            (false, false) => match l.options.descending {
436                true => T::compare(&r.values, r_idx, &l.values, l_idx),
437                false => T::compare(&l.values, l_idx, &r.values, r_idx),
438            },
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryConsumer, MemoryPool,
    };
    use std::sync::Arc;

    use super::*;

    /// Builds a primitive cursor over `values`, treating the first
    /// (`nulls_first`) or last `null_count` slots as nulls. The placeholder
    /// values stored in null slots are never compared.
    fn new_primitive(
        options: SortOptions,
        values: ScalarBuffer<i32>,
        null_count: usize,
    ) -> Cursor<ArrayValues<PrimitiveValues<i32>>> {
        let null_threshold = match options.nulls_first {
            true => null_count,
            false => values.len() - null_count,
        };

        let memory_pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(10000));
        let consumer = MemoryConsumer::new("test");
        let reservation = consumer.register(&memory_pool);

        let values = ArrayValues {
            values: PrimitiveValues(values),
            null_threshold,
            options,
            _reservation: reservation,
        };

        Cursor::new(values)
    }

    #[test]
    fn test_primitive_nulls_first() {
        let options = SortOptions {
            descending: false,
            nulls_first: true,
        };

        let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
        let mut a = new_primitive(options, buffer, 1);
        let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
        let mut b = new_primitive(options, buffer, 2);

        // NULL == NULL
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL == NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL < -2
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 1 > -2
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 1 > -1
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 1 == 1
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // 1 < 9
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 2 < 9
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
        let mut a = new_primitive(options, buffer, 2);
        let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
        let mut b = new_primitive(options, buffer, 2);

        // 0 > -1
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 0 < NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 1 < NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // NULL = NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        let options = SortOptions {
            descending: true,
            nulls_first: false,
        };

        let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
        let mut a = new_primitive(options, buffer, 3);
        let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
        let mut b = new_primitive(options, buffer, 2);

        // 6 > 67
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 6 < -3
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 6 < NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 6 < NULL
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // NULL == NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        let options = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
        let mut a = new_primitive(options, buffer, 2);
        let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
        let mut b = new_primitive(options, buffer, 1);

        // NULL == NULL
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL == NULL
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Equal);
        assert_eq!(a, b);

        // NULL < 4546
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);

        // 6 > 4546
        a.advance();
        assert_eq!(a.cmp(&b), Ordering::Greater);

        // 6 < -3
        b.advance();
        assert_eq!(a.cmp(&b), Ordering::Less);
    }

    #[test]
    fn test_primitive_eq_to_prev() {
        let options = SortOptions {
            descending: false,
            nulls_first: true,
        };

        let buffer = ScalarBuffer::from(vec![i32::MAX, i32::MIN, 1, 1, 2]);
        let mut a = new_primitive(options, buffer, 2);

        // The first row of the first batch has no preceding row
        assert!(!a.is_eq_to_prev_one(None));

        // NULL == NULL
        a.advance();
        assert!(a.is_eq_to_prev_one(None));

        // 1 != NULL
        a.advance();
        assert!(!a.is_eq_to_prev_one(None));

        // 1 == 1
        a.advance();
        assert!(a.is_eq_to_prev_one(None));

        // 2 != 1
        a.advance();
        assert!(!a.is_eq_to_prev_one(None));

        // The first row of a batch is compared to the last row of the previous
        // batch: 2 == 2
        let buffer = ScalarBuffer::from(vec![2, 3]);
        let b = new_primitive(options, buffer, 0);
        assert!(b.is_eq_to_prev_one(Some(&a)));

        // 3 != 2
        let buffer = ScalarBuffer::from(vec![3]);
        let c = new_primitive(options, buffer, 0);
        assert!(!c.is_eq_to_prev_one(Some(&a)));
    }
}