1use std::cmp::Ordering;
19use std::sync::Arc;
20
21use arrow::array::{
22 types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray,
23 GenericByteViewArray, OffsetSizeTrait, PrimitiveArray, StringViewArray,
24};
25use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
26use arrow::compute::SortOptions;
27use arrow::datatypes::ArrowNativeTypeOp;
28use arrow::row::Rows;
29use datafusion_execution::memory_pool::MemoryReservation;
30
/// A comparable, indexable collection of sort-key values that a [`Cursor`]
/// walks over.
///
/// The comparison operations are associated functions that take explicit
/// indices, so values from two different collections can be compared
/// position by position.
pub trait CursorValues {
    /// Number of values held by this collection.
    fn len(&self) -> usize;

    /// Returns `true` when `l[l_idx]` is equal to `r[r_idx]`.
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool;

    /// Returns `true` when `cursor[idx]` is equal to `cursor[idx - 1]`.
    ///
    /// `idx` must be greater than zero.
    fn eq_to_previous(cursor: &Self, idx: usize) -> bool;

    /// Returns the ordering of `l[l_idx]` relative to `r[r_idx]`.
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering;
}
48
49#[derive(Debug)]
80pub struct Cursor<T: CursorValues> {
81 offset: usize,
82 values: T,
83}
84
85impl<T: CursorValues> Cursor<T> {
86 pub fn new(values: T) -> Self {
88 Self { offset: 0, values }
89 }
90
91 pub fn is_finished(&self) -> bool {
93 self.offset == self.values.len()
94 }
95
96 pub fn advance(&mut self) -> usize {
98 let t = self.offset;
99 self.offset += 1;
100 t
101 }
102
103 pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor<T>>) -> bool {
104 if self.offset > 0 {
105 self.is_eq_to_prev_row()
106 } else if let Some(prev_cursor) = prev_cursor {
107 self.is_eq_to_prev_row_in_prev_batch(prev_cursor)
108 } else {
109 false
110 }
111 }
112}
113
114impl<T: CursorValues> PartialEq for Cursor<T> {
115 fn eq(&self, other: &Self) -> bool {
116 T::eq(&self.values, self.offset, &other.values, other.offset)
117 }
118}
119
120impl<T: CursorValues> Cursor<T> {
121 fn is_eq_to_prev_row(&self) -> bool {
122 T::eq_to_previous(&self.values, self.offset)
123 }
124
125 fn is_eq_to_prev_row_in_prev_batch(&self, other: &Self) -> bool {
126 assert_eq!(self.offset, 0);
127 T::eq(
128 &self.values,
129 self.offset,
130 &other.values,
131 other.values.len() - 1,
132 )
133 }
134}
135
136impl<T: CursorValues> Eq for Cursor<T> {}
137
138impl<T: CursorValues> PartialOrd for Cursor<T> {
139 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140 Some(self.cmp(other))
141 }
142}
143
144impl<T: CursorValues> Ord for Cursor<T> {
145 fn cmp(&self, other: &Self) -> Ordering {
146 T::compare(&self.values, self.offset, &other.values, other.offset)
147 }
148}
149
150#[derive(Debug)]
154pub struct RowValues {
155 rows: Arc<Rows>,
156
157 _reservation: MemoryReservation,
160}
161
162impl RowValues {
163 pub fn new(rows: Arc<Rows>, reservation: MemoryReservation) -> Self {
169 assert_eq!(
170 rows.size(),
171 reservation.size(),
172 "memory reservation mismatch"
173 );
174 assert!(rows.num_rows() > 0);
175 Self {
176 rows,
177 _reservation: reservation,
178 }
179 }
180}
181
182impl CursorValues for RowValues {
183 fn len(&self) -> usize {
184 self.rows.num_rows()
185 }
186
187 fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
188 l.rows.row(l_idx) == r.rows.row(r_idx)
189 }
190
191 fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
192 assert!(idx > 0);
193 cursor.rows.row(idx) == cursor.rows.row(idx - 1)
194 }
195
196 fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
197 l.rows.row(l_idx).cmp(&r.rows.row(r_idx))
198 }
199}
200
201pub trait CursorArray: Array + 'static {
203 type Values: CursorValues;
204
205 fn values(&self) -> Self::Values;
206}
207
208impl<T: ArrowPrimitiveType> CursorArray for PrimitiveArray<T> {
209 type Values = PrimitiveValues<T::Native>;
210
211 fn values(&self) -> Self::Values {
212 PrimitiveValues(self.values().clone())
213 }
214}
215
216#[derive(Debug)]
217pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>);
218
219impl<T: ArrowNativeTypeOp> CursorValues for PrimitiveValues<T> {
220 fn len(&self) -> usize {
221 self.0.len()
222 }
223
224 fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
225 l.0[l_idx].is_eq(r.0[r_idx])
226 }
227
228 fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
229 assert!(idx > 0);
230 cursor.0[idx].is_eq(cursor.0[idx - 1])
231 }
232
233 fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
234 l.0[l_idx].compare(r.0[r_idx])
235 }
236}
237
238pub struct ByteArrayValues<T: OffsetSizeTrait> {
239 offsets: OffsetBuffer<T>,
240 values: Buffer,
241}
242
243impl<T: OffsetSizeTrait> ByteArrayValues<T> {
244 fn value(&self, idx: usize) -> &[u8] {
245 assert!(idx < self.len());
246 unsafe {
248 let start = self.offsets.get_unchecked(idx).as_usize();
249 let end = self.offsets.get_unchecked(idx + 1).as_usize();
250 self.values.get_unchecked(start..end)
251 }
252 }
253}
254
255impl<T: OffsetSizeTrait> CursorValues for ByteArrayValues<T> {
256 fn len(&self) -> usize {
257 self.offsets.len() - 1
258 }
259
260 fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
261 l.value(l_idx) == r.value(r_idx)
262 }
263
264 fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
265 assert!(idx > 0);
266 cursor.value(idx) == cursor.value(idx - 1)
267 }
268
269 fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
270 l.value(l_idx).cmp(r.value(r_idx))
271 }
272}
273
274impl<T: ByteArrayType> CursorArray for GenericByteArray<T> {
275 type Values = ByteArrayValues<T::Offset>;
276
277 fn values(&self) -> Self::Values {
278 ByteArrayValues {
279 offsets: self.offsets().clone(),
280 values: self.values().clone(),
281 }
282 }
283}
284
285impl CursorArray for StringViewArray {
286 type Values = StringViewArray;
287 fn values(&self) -> Self {
288 self.gc()
289 }
290}
291
292impl CursorValues for StringViewArray {
293 fn len(&self) -> usize {
294 self.views().len()
295 }
296
297 #[inline(always)]
298 fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
299 let l_view = unsafe { l.views().get_unchecked(l_idx) };
303 let r_view = unsafe { r.views().get_unchecked(r_idx) };
304
305 if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
306 return l_view == r_view;
307 }
308
309 let l_len = *l_view as u32;
310 let r_len = *r_view as u32;
311 if l_len != r_len {
312 return false;
313 }
314
315 unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx).is_eq() }
316 }
317
318 #[inline(always)]
319 fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
320 let l_view = unsafe { cursor.views().get_unchecked(idx) };
324 let r_view = unsafe { cursor.views().get_unchecked(idx - 1) };
325 if cursor.data_buffers().is_empty() {
326 return l_view == r_view;
327 }
328
329 let l_len = *l_view as u32;
330 let r_len = *r_view as u32;
331
332 if l_len != r_len {
333 return false;
334 }
335
336 unsafe {
337 GenericByteViewArray::compare_unchecked(cursor, idx, cursor, idx - 1).is_eq()
338 }
339 }
340
341 #[inline(always)]
342 fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
343 if l.data_buffers().is_empty() && r.data_buffers().is_empty() {
347 let l_view = unsafe { l.views().get_unchecked(l_idx) };
348 let r_view = unsafe { r.views().get_unchecked(r_idx) };
349 return StringViewArray::inline_key_fast(*l_view)
350 .cmp(&StringViewArray::inline_key_fast(*r_view));
351 }
352
353 unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) }
354 }
355}
356
357#[derive(Debug)]
361pub struct ArrayValues<T: CursorValues> {
362 values: T,
363 null_threshold: usize,
366 options: SortOptions,
367
368 _reservation: MemoryReservation,
371}
372
373impl<T: CursorValues> ArrayValues<T> {
374 pub fn new<A: CursorArray<Values = T>>(
379 options: SortOptions,
380 array: &A,
381 reservation: MemoryReservation,
382 ) -> Self {
383 assert!(array.len() > 0, "Empty array passed to FieldCursor");
384 let null_threshold = match options.nulls_first {
385 true => array.null_count(),
386 false => array.len() - array.null_count(),
387 };
388
389 Self {
390 values: array.values(),
391 null_threshold,
392 options,
393 _reservation: reservation,
394 }
395 }
396
397 fn is_null(&self, idx: usize) -> bool {
398 (idx < self.null_threshold) == self.options.nulls_first
399 }
400}
401
402impl<T: CursorValues> CursorValues for ArrayValues<T> {
403 fn len(&self) -> usize {
404 self.values.len()
405 }
406
407 fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
408 match (l.is_null(l_idx), r.is_null(r_idx)) {
409 (true, true) => true,
410 (false, false) => T::eq(&l.values, l_idx, &r.values, r_idx),
411 _ => false,
412 }
413 }
414
415 fn eq_to_previous(cursor: &Self, idx: usize) -> bool {
416 assert!(idx > 0);
417 match (cursor.is_null(idx), cursor.is_null(idx - 1)) {
418 (true, true) => true,
419 (false, false) => T::eq(&cursor.values, idx, &cursor.values, idx - 1),
420 _ => false,
421 }
422 }
423
424 fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
425 match (l.is_null(l_idx), r.is_null(r_idx)) {
426 (true, true) => Ordering::Equal,
427 (true, false) => match l.options.nulls_first {
428 true => Ordering::Less,
429 false => Ordering::Greater,
430 },
431 (false, true) => match l.options.nulls_first {
432 true => Ordering::Greater,
433 false => Ordering::Less,
434 },
435 (false, false) => match l.options.descending {
436 true => T::compare(&r.values, r_idx, &l.values, l_idx),
437 false => T::compare(&l.values, l_idx, &r.values, r_idx),
438 },
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use datafusion_execution::memory_pool::{
        GreedyMemoryPool, MemoryConsumer, MemoryPool,
    };
    use std::sync::Arc;

    use super::*;

    /// Builds a cursor over `values` in which `null_count` positions, at the
    /// end selected by `options.nulls_first`, are treated as null.
    fn new_primitive(
        options: SortOptions,
        values: ScalarBuffer<i32>,
        null_count: usize,
    ) -> Cursor<ArrayValues<PrimitiveValues<i32>>> {
        let null_threshold = if options.nulls_first {
            null_count
        } else {
            values.len() - null_count
        };

        let memory_pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(10000));
        let reservation = MemoryConsumer::new("test").register(&memory_pool);

        Cursor::new(ArrayValues {
            values: PrimitiveValues(values),
            null_threshold,
            options,
            _reservation: reservation,
        })
    }

    #[test]
    fn test_primitive_nulls_first() {
        // ---- ascending, nulls first ----
        let options = SortOptions {
            descending: false,
            nulls_first: true,
        };

        let mut lhs =
            new_primitive(options, ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]), 1);
        let mut rhs =
            new_primitive(options, ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]), 2);

        // NULL == NULL
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // NULL == NULL
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // NULL < -2
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // 1 > -2
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Greater);

        // 1 > -1
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Greater);

        // 1 == 1
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // 1 < 9
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // 2 < 9
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // ---- ascending, nulls last ----
        let options = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let mut lhs =
            new_primitive(options, ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]), 2);
        let mut rhs =
            new_primitive(options, ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]), 2);

        // 0 > -1
        assert_eq!(lhs.cmp(&rhs), Ordering::Greater);

        // 0 < NULL
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // 1 < NULL
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // NULL == NULL
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // ---- descending, nulls last ----
        let options = SortOptions {
            descending: true,
            nulls_first: false,
        };

        let mut lhs =
            new_primitive(options, ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]), 3);
        let mut rhs =
            new_primitive(options, ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]), 2);

        // 6 sorts after 67 when descending
        assert_eq!(lhs.cmp(&rhs), Ordering::Greater);

        // 6 sorts before -3 when descending
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // 6 < NULL
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // 6 < NULL
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // NULL == NULL
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // ---- descending, nulls first ----
        let options = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let mut lhs =
            new_primitive(options, ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]), 2);
        let mut rhs =
            new_primitive(options, ScalarBuffer::from(vec![i32::MAX, 4546, -3]), 1);

        // NULL == NULL
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // NULL == NULL
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Equal);
        assert_eq!(lhs, rhs);

        // NULL < 4546
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);

        // 6 sorts after 4546 when descending
        lhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Greater);

        // 6 sorts before -3 when descending
        rhs.advance();
        assert_eq!(lhs.cmp(&rhs), Ordering::Less);
    }
}