datafusion_physical_plan/spill/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the spilling functions

pub(crate) mod in_progress_spill_file;
pub(crate) mod spill_manager;

use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow::array::{layout, ArrayData, BufferSpec};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::ipc::{
    reader::StreamReader,
    writer::{IpcWriteOptions, StreamWriter},
    MetadataVersion,
};
use arrow::record_batch::RecordBatch;

use datafusion_common::config::SpillCompression;
use datafusion_common::{exec_datafusion_err, DataFusionError, HashSet, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::RecordBatchStream;
use futures::{FutureExt as _, Stream};
use log::warn;

/// Stream that reads spill files from disk, where each batch is read in a spawned blocking task.
/// It reads one batch at a time and does no buffering; to buffer data use [`crate::common::spawn_buffered`].
///
/// A simpler solution would be spawning a long-running blocking task for each
/// file read (instead of each batch). This approach does not work because when
/// the number of concurrent reads exceeds the Tokio blocking thread pool limit,
/// deadlocks can occur and block progress.
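///
/// A minimal usage sketch (illustrative only; within DataFusion this stream is
/// normally constructed via `SpillManager::read_spill_as_stream`, and the
/// variable names below are hypothetical):
///
/// ```ignore
/// let stream = SpillReaderStream::new(Arc::clone(&schema), spill_file, None);
/// // Optionally pre-fetch batches in a separate task:
/// let buffered = crate::common::spawn_buffered(Box::pin(stream), 2);
/// ```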
struct SpillReaderStream {
    schema: SchemaRef,
    state: SpillReaderStreamState,
    /// Maximum memory size observed among the spilled sorted record batches.
    /// It is used to validate each `RecordBatch` read back from spill.
    /// For context on why this value is recorded and validated,
    /// see `physical_plan/sort/multi_level_merge.rs`.
    max_record_batch_memory: Option<usize>,
}

// Small margin allowed to accommodate slight memory accounting variation
const SPILL_BATCH_MEMORY_MARGIN: usize = 4096;

/// When we poll for the next batch, we get back both the batch and the reader,
/// so we can call `next` again.
type NextRecordBatchResult = Result<(StreamReader<BufReader<File>>, Option<RecordBatch>)>;

enum SpillReaderStreamState {
    /// Initial state: the stream has not been initialized yet
    /// and the file has not been opened.
    Uninitialized(RefCountedTempFile),

    /// A read is in progress in a spawned blocking task for which we hold the handle.
    ReadInProgress(SpawnedTask<NextRecordBatchResult>),

    /// A read has finished and we are waiting to be polled again in order to start reading the next batch.
    Waiting(StreamReader<BufReader<File>>),

    /// The stream has finished, successfully or not.
    Done,
}

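// State transitions driven by `poll_next_inner` (a descriptive summary of the
// implementation below):
//
//   Uninitialized --(spawn first read)--> ReadInProgress
//   ReadInProgress --(batch ready)------> Waiting
//   ReadInProgress --(EOF or error)-----> Done
//   Waiting --------(spawn next read)---> ReadInProgress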
impl SpillReaderStream {
    fn new(
        schema: SchemaRef,
        spill_file: RefCountedTempFile,
        max_record_batch_memory: Option<usize>,
    ) -> Self {
        Self {
            schema,
            state: SpillReaderStreamState::Uninitialized(spill_file),
            max_record_batch_memory,
        }
    }

    fn poll_next_inner(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        match &mut self.state {
            SpillReaderStreamState::Uninitialized(_) => {
                // Temporarily replace with `Done` to be able to pass the file to the task.
                let SpillReaderStreamState::Uninitialized(spill_file) =
                    std::mem::replace(&mut self.state, SpillReaderStreamState::Done)
                else {
                    unreachable!()
                };

                let task = SpawnedTask::spawn_blocking(move || {
                    let file = BufReader::new(File::open(spill_file.path())?);
                    // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
                    // with validated schemas and buffers. Skip redundant validation during read
                    // to speed up the read operation. This is safe for DataFusion because the
                    // input is guaranteed to be correct when written.
                    let mut reader = unsafe {
                        StreamReader::try_new(file, None)?.with_skip_validation(true)
                    };

                    let next_batch = reader.next().transpose()?;

                    Ok((reader, next_batch))
                });

                self.state = SpillReaderStreamState::ReadInProgress(task);

                // Poll again immediately so the inner task is polled and the waker is
                // registered.
                self.poll_next_inner(cx)
            }

            SpillReaderStreamState::ReadInProgress(task) => {
                let result = futures::ready!(task.poll_unpin(cx))
                    .unwrap_or_else(|err| Err(DataFusionError::External(Box::new(err))));

                match result {
                    Ok((reader, batch)) => {
                        match batch {
                            Some(batch) => {
                                if let Some(max_record_batch_memory) =
                                    self.max_record_batch_memory
                                {
                                    let actual_size =
                                        get_record_batch_memory_size(&batch);
                                    if actual_size
                                        > max_record_batch_memory
                                            + SPILL_BATCH_MEMORY_MARGIN
                                    {
                                        warn!(
                                            "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\
                                            by more than the allowed tolerance ({SPILL_BATCH_MEMORY_MARGIN} bytes).\n\
                                            This likely indicates a bug in memory accounting during spilling.\n\
                                            Please report this issue at https://github.com/apache/datafusion/issues/17340."
                                        );
                                    }
                                }
                                self.state = SpillReaderStreamState::Waiting(reader);

                                Poll::Ready(Some(Ok(batch)))
                            }
                            None => {
                                // Stream is done
                                self.state = SpillReaderStreamState::Done;

                                Poll::Ready(None)
                            }
                        }
                    }
                    Err(err) => {
                        self.state = SpillReaderStreamState::Done;

                        Poll::Ready(Some(Err(err)))
                    }
                }
            }

            SpillReaderStreamState::Waiting(_) => {
                // Temporarily replace with `Done` to be able to pass the reader to the task.
                let SpillReaderStreamState::Waiting(mut reader) =
                    std::mem::replace(&mut self.state, SpillReaderStreamState::Done)
                else {
                    unreachable!()
                };

                let task = SpawnedTask::spawn_blocking(move || {
                    let next_batch = reader.next().transpose()?;

                    Ok((reader, next_batch))
                });

                self.state = SpillReaderStreamState::ReadInProgress(task);

                // Poll again immediately so the inner task is polled and the waker is
                // registered.
                self.poll_next_inner(cx)
            }

            SpillReaderStreamState::Done => Poll::Ready(None),
        }
    }
}

impl Stream for SpillReaderStream {
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut().poll_next_inner(cx)
    }
}

impl RecordBatchStream for SpillReaderStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// Spill the `RecordBatch` to disk as smaller batches
/// split by `batch_size_rows`
#[deprecated(
    since = "46.0.0",
    note = "This method is deprecated. Use `SpillManager::spill_record_batch_by_size` instead."
)]
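/// A sketch of the replacement path (illustrative; it mirrors the usage in the
/// tests below, where the second argument is a human-readable request
/// description and `spill_manager`/`batch` are assumed to be in scope):
///
/// ```ignore
/// if let Some((spill_file, max_batch_mem)) = spill_manager
///     .spill_record_batch_by_size_and_return_max_batch_memory(&batch, "Spill", batch_size_rows)?
/// {
///     // `spill_file` holds the temp file; `max_batch_mem` feeds validation on read.
/// }
/// ```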
pub fn spill_record_batch_by_size(
    batch: &RecordBatch,
    path: PathBuf,
    schema: SchemaRef,
    batch_size_rows: usize,
) -> Result<()> {
    let mut offset = 0;
    let total_rows = batch.num_rows();
    let mut writer =
        IPCStreamWriter::new(&path, schema.as_ref(), SpillCompression::Uncompressed)?;

    while offset < total_rows {
        let length = std::cmp::min(total_rows - offset, batch_size_rows);
        let batch = batch.slice(offset, length);
        offset += batch.num_rows();
        writer.write(&batch)?;
    }
    writer.finish()?;

    Ok(())
}

/// Calculate the total used memory of this batch.
///
/// This function is used to estimate the physical memory usage of the `RecordBatch`.
/// It only counts the memory of large data `Buffer`s, and ignores metadata like
/// types and pointers.
/// The implementation adds up each unique `Buffer`'s memory size, because:
/// - The data pointers inside `Buffer`s are memory regions returned by the global
///   memory allocator, so those regions can't overlap.
/// - The actual used ranges of the `ArrayRef`s inside a `RecordBatch` can overlap
///   or reuse the same `Buffer`, for example when taking a slice from an `Array`.
///
/// Example:
/// For a `RecordBatch` with two columns `col1` and `col2`, both columns point
/// to sub-regions of the same buffer.
///
/// ```text
/// {xxxxxxxxxxxxxxxxxxx} <--- buffer
///       ^    ^  ^    ^
///       |    |  |    |
/// col1->{    }  |    |
/// col2--------->{    }
/// ```
///
/// In the above case, `get_record_batch_memory_size` will return the size of
/// the buffer, instead of the sum of `col1` and `col2`'s actual memory sizes.
///
/// Note: the current `RecordBatch::get_array_memory_size()` will double count the
/// buffer memory size if multiple arrays within the batch are sharing the same
/// `Buffer`. This method provides a temporary fix until the issue is resolved:
/// <https://github.com/apache/arrow-rs/issues/6439>
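/// A sketch of the shared-buffer behavior (mirroring
/// `test_get_record_batch_memory_size_shared_buffer` below; `slice` reuses the
/// original array's buffer, so it adds nothing to the deduplicated total):
///
/// ```ignore
/// let original = Int32Array::from(vec![1, 2, 3, 4, 5]);
/// let slice = original.slice(0, 3); // shares `original`'s data buffer
/// // A batch containing both arrays reports the same size as one containing
/// // only `original`, because the shared buffer is counted once.
/// ```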
pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize {
    // Store pointers to each `Buffer`'s start memory address (instead of the
    // actual used data region's pointer represented by the current `Array`)
    let mut counted_buffers: HashSet<NonNull<u8>> = HashSet::new();
    let mut total_size = 0;

    for array in batch.columns() {
        let array_data = array.to_data();
        count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size);
    }

    total_size
}

/// Count the memory usage of `array_data` and its children recursively.
fn count_array_data_memory_size(
    array_data: &ArrayData,
    counted_buffers: &mut HashSet<NonNull<u8>>,
    total_size: &mut usize,
) {
    // Count memory usage for `array_data`
    for buffer in array_data.buffers() {
        if counted_buffers.insert(buffer.data_ptr()) {
            *total_size += buffer.capacity();
        } // Otherwise the buffer's memory is already counted
    }

    if let Some(null_buffer) = array_data.nulls() {
        if counted_buffers.insert(null_buffer.inner().inner().data_ptr()) {
            *total_size += null_buffer.inner().inner().capacity();
        }
    }

    // Count all children `ArrayData` recursively
    for child in array_data.child_data() {
        count_array_data_memory_size(child, counted_buffers, total_size);
    }
}

/// Write in Arrow IPC Stream format to a file.
///
/// The Stream format is used for spilling because it supports dictionary replacement;
/// the random access offered by the IPC File format (which doesn't support dictionary
/// replacement) is not needed here.
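///
/// A minimal usage sketch (names as used elsewhere in this module; `path`,
/// `schema`, and `batch` are assumed to be in scope):
///
/// ```ignore
/// let mut writer = IPCStreamWriter::new(&path, schema.as_ref(), SpillCompression::Uncompressed)?;
/// let (delta_rows, delta_bytes) = writer.write(&batch)?; // per-batch metric deltas
/// writer.finish()?; // finalize the stream so readers see a complete file
/// ```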
struct IPCStreamWriter {
    /// Inner writer
    pub writer: StreamWriter<File>,
    /// Batches written
    pub num_batches: usize,
    /// Rows written
    pub num_rows: usize,
    /// Bytes written
    pub num_bytes: usize,
}

impl IPCStreamWriter {
    /// Create a new writer
    pub fn new(
        path: &Path,
        schema: &Schema,
        compression_type: SpillCompression,
    ) -> Result<Self> {
        let file = File::create(path).map_err(|e| {
            exec_datafusion_err!("(Hint: you may increase the file descriptor limit with the shell command 'ulimit -n 4096') Failed to create partition file at {path:?}: {e:?}")
        })?;

        let metadata_version = MetadataVersion::V5;
        // Depending on the schema, some array types such as StringViewArray require a larger
        // alignment (16 bytes in this case).
        // If the actual buffer layout after an IPC read does not satisfy the alignment
        // requirement, Arrow's ArrayBuilder will copy the buffer into a newly allocated,
        // properly aligned buffer. This copying may lead to memory blowup during IPC reads
        // due to duplicated buffers. To avoid this, we compute the maximum required alignment
        // based on the schema and configure the IPCStreamWriter accordingly.
        let alignment = get_max_alignment_for_schema(schema);
        let mut write_options =
            IpcWriteOptions::try_new(alignment, false, metadata_version)?;
        write_options = write_options.try_with_compression(compression_type.into())?;

        let writer = StreamWriter::try_new_with_options(file, schema, write_options)?;
        Ok(Self {
            num_batches: 0,
            num_rows: 0,
            num_bytes: 0,
            writer,
        })
    }

    /// Writes a single batch to the IPC stream and updates the internal counters.
    ///
    /// Returns a tuple containing the change in the number of rows and bytes written.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(usize, usize)> {
        self.writer.write(batch)?;
        self.num_batches += 1;
        let delta_num_rows = batch.num_rows();
        self.num_rows += delta_num_rows;
        let delta_num_bytes: usize = batch.get_array_memory_size();
        self.num_bytes += delta_num_bytes;
        Ok((delta_num_rows, delta_num_bytes))
    }

    /// Finish the writer
    pub fn finish(&mut self) -> Result<()> {
        self.writer.finish().map_err(Into::into)
    }
}

// Returns the maximum byte alignment required by any field in the schema (>= 8),
// derived from Arrow buffer layouts.
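// For example (mirroring `test_alignment_for_schema` below), a schema containing a
// `Utf8View` column yields 16, while plain `Int32`/`Int64` columns yield the
// 8-byte minimum.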
fn get_max_alignment_for_schema(schema: &Schema) -> usize {
    let minimum_alignment = 8;
    let mut max_alignment = minimum_alignment;
    for field in schema.fields() {
        let layout = layout(field.data_type());
        let required_alignment = layout
            .buffers
            .iter()
            .map(|buffer_spec| {
                if let BufferSpec::FixedWidth { alignment, .. } = buffer_spec {
                    *alignment
                } else {
                    minimum_alignment
                }
            })
            .max()
            .unwrap_or(minimum_alignment);
        max_alignment = std::cmp::max(max_alignment, required_alignment);
    }
    max_alignment
}

#[cfg(test)]
mod tests {
    use super::in_progress_spill_file::InProgressSpillFile;
    use super::*;
    use crate::common::collect;
    use crate::metrics::ExecutionPlanMetricsSet;
    use crate::metrics::SpillMetrics;
    use crate::spill::spill_manager::SpillManager;
    use crate::test::build_table_i32;
    use arrow::array::{ArrayRef, Float64Array, Int32Array, ListArray, StringArray};
    use arrow::compute::cast;
    use arrow::datatypes::{DataType, Field, Int32Type, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion_common::Result;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use futures::StreamExt as _;

    use std::sync::Arc;

    #[tokio::test]
    async fn test_batch_spill_and_read() -> Result<()> {
        let batch1 = build_table_i32(
            ("a2", &vec![0, 1, 2]),
            ("b2", &vec![3, 4, 5]),
            ("c2", &vec![4, 5, 6]),
        );

        let batch2 = build_table_i32(
            ("a2", &vec![10, 11, 12]),
            ("b2", &vec![13, 14, 15]),
            ("c2", &vec![14, 15, 16]),
        );

        let schema = batch1.schema();
        let num_rows = batch1.num_rows() + batch2.num_rows();

        // Construct SpillManager
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));

        let spill_file = spill_manager
            .spill_record_batch_and_finish(&[batch1, batch2], "Test")?
            .unwrap();
        assert!(spill_file.path().exists());
        let spilled_rows = spill_manager.metrics.spilled_rows.value();
        assert_eq!(spilled_rows, num_rows);

        let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
        assert_eq!(stream.schema(), schema);

        let batches = collect(stream).await?;
        assert_eq!(batches.len(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> {
        // See https://github.com/apache/datafusion/issues/4658

        let batch1 = build_table_i32(
            ("a2", &vec![0, 1, 2]),
            ("b2", &vec![3, 4, 5]),
            ("c2", &vec![4, 5, 6]),
        );

        let batch2 = build_table_i32(
            ("a2", &vec![10, 11, 12]),
            ("b2", &vec![13, 14, 15]),
            ("c2", &vec![14, 15, 16]),
        );

        // Dictionary encode the arrays
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
        let dict_schema = Arc::new(Schema::new(vec![
            Field::new("a2", dict_type.clone(), true),
            Field::new("b2", dict_type.clone(), true),
            Field::new("c2", dict_type.clone(), true),
        ]));

        let batch1 = RecordBatch::try_new(
            Arc::clone(&dict_schema),
            batch1
                .columns()
                .iter()
                .map(|array| cast(array, &dict_type))
                .collect::<Result<_, _>>()?,
        )?;

        let batch2 = RecordBatch::try_new(
            Arc::clone(&dict_schema),
            batch2
                .columns()
                .iter()
                .map(|array| cast(array, &dict_type))
                .collect::<Result<_, _>>()?,
        )?;

        // Construct SpillManager
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let spill_manager = SpillManager::new(env, metrics, Arc::clone(&dict_schema));

        let num_rows = batch1.num_rows() + batch2.num_rows();
        let spill_file = spill_manager
            .spill_record_batch_and_finish(&[batch1, batch2], "Test")?
            .unwrap();
        let spilled_rows = spill_manager.metrics.spilled_rows.value();
        assert_eq!(spilled_rows, num_rows);

        let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
        assert_eq!(stream.schema(), dict_schema);
        let batches = collect(stream).await?;
        assert_eq!(batches.len(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_batch_spill_by_size() -> Result<()> {
        let batch1 = build_table_i32(
            ("a2", &vec![0, 1, 2, 3]),
            ("b2", &vec![3, 4, 5, 6]),
            ("c2", &vec![4, 5, 6, 7]),
        );

        let schema = batch1.schema();
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));

        let (spill_file, max_batch_mem) = spill_manager
            .spill_record_batch_by_size_and_return_max_batch_memory(
                &batch1,
                "Test Spill",
                1,
            )?
            .unwrap();
        assert!(spill_file.path().exists());
        assert!(max_batch_mem > 0);

        let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
        assert_eq!(stream.schema(), schema);

        let batches = collect(stream).await?;
        assert_eq!(batches.len(), 4);

        Ok(())
    }

    fn build_compressible_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, true),
        ]));

        let a: ArrayRef = Arc::new(StringArray::from_iter_values(std::iter::repeat_n(
            "repeated", 100,
        )));
        let b: ArrayRef = Arc::new(Int32Array::from(vec![1; 100]));
        let c: ArrayRef = Arc::new(Int32Array::from(vec![2; 100]));

        RecordBatch::try_new(schema, vec![a, b, c]).unwrap()
    }

    async fn validate(
        spill_manager: &SpillManager,
        spill_file: RefCountedTempFile,
        num_rows: usize,
        schema: SchemaRef,
        batch_count: usize,
    ) -> Result<()> {
        let spilled_rows = spill_manager.metrics.spilled_rows.value();
        assert_eq!(spilled_rows, num_rows);

        let stream = spill_manager.read_spill_as_stream(spill_file, None)?;
        assert_eq!(stream.schema(), schema);

        let batches = collect(stream).await?;
        assert_eq!(batches.len(), batch_count);

        Ok(())
    }

    #[tokio::test]
    async fn test_spill_compression() -> Result<()> {
        let batch = build_compressible_batch();
        let num_rows = batch.num_rows();
        let schema = batch.schema();
        let batch_count = 1;
        let batches = [batch];

        // Construct SpillManager
        let env = Arc::new(RuntimeEnv::default());
        let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let uncompressed_spill_manager = SpillManager::new(
            Arc::clone(&env),
            uncompressed_metrics,
            Arc::clone(&schema),
        );
        let lz4_spill_manager =
            SpillManager::new(Arc::clone(&env), lz4_metrics, Arc::clone(&schema))
                .with_compression_type(SpillCompression::Lz4Frame);
        let zstd_spill_manager =
            SpillManager::new(env, zstd_metrics, Arc::clone(&schema))
                .with_compression_type(SpillCompression::Zstd);
        let uncompressed_spill_file = uncompressed_spill_manager
            .spill_record_batch_and_finish(&batches, "Test")?
            .unwrap();
        let lz4_spill_file = lz4_spill_manager
            .spill_record_batch_and_finish(&batches, "Lz4_Test")?
            .unwrap();
        let zstd_spill_file = zstd_spill_manager
            .spill_record_batch_and_finish(&batches, "ZSTD_Test")?
            .unwrap();
        assert!(uncompressed_spill_file.path().exists());
        assert!(lz4_spill_file.path().exists());
        assert!(zstd_spill_file.path().exists());

        let lz4_spill_size = std::fs::metadata(lz4_spill_file.path())?.len();
        let zstd_spill_size = std::fs::metadata(zstd_spill_file.path())?.len();
        let uncompressed_spill_size =
            std::fs::metadata(uncompressed_spill_file.path())?.len();

        assert!(uncompressed_spill_size > lz4_spill_size);
        assert!(uncompressed_spill_size > zstd_spill_size);

        validate(
            &lz4_spill_manager,
            lz4_spill_file,
            num_rows,
            Arc::clone(&schema),
            batch_count,
        )
        .await?;
        validate(
            &zstd_spill_manager,
            zstd_spill_file,
            num_rows,
            Arc::clone(&schema),
            batch_count,
        )
        .await?;
        validate(
            &uncompressed_spill_manager,
            uncompressed_spill_file,
            num_rows,
            schema,
            batch_count,
        )
        .await?;
        Ok(())
    }

    #[test]
    fn test_get_record_batch_memory_size() {
        // Create a simple record batch with two columns
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));

        let int_array =
            Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]);
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(int_array), Arc::new(float64_array)],
        )
        .unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 60);
    }

    #[test]
    fn test_get_record_batch_memory_size_with_null() {
        // Create a simple record batch with two columns
        let schema = Arc::new(Schema::new(vec![
            Field::new("ints", DataType::Int32, true),
            Field::new("float64", DataType::Float64, false),
        ]));

        let int_array = Int32Array::from(vec![None, Some(2), Some(3)]);
        let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(int_array), Arc::new(float64_array)],
        )
        .unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 100);
    }

    #[test]
    fn test_get_record_batch_memory_size_empty() {
        // Test with empty record batch
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ints",
            DataType::Int32,
            false,
        )]));

        let int_array: Int32Array = Int32Array::from(vec![] as Vec<i32>);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 0, "Empty batch should have 0 memory size");
    }

    #[test]
    fn test_get_record_batch_memory_size_shared_buffer() {
        // Test with slices that share the same underlying buffer
        let original = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let slice1 = original.slice(0, 3);
        let slice2 = original.slice(2, 3);

        // `RecordBatch` with the `original` array
        // ----
        let schema_origin = Arc::new(Schema::new(vec![Field::new(
            "origin_col",
            DataType::Int32,
            false,
        )]));
        let batch_origin =
            RecordBatch::try_new(schema_origin, vec![Arc::new(original)]).unwrap();

        // `RecordBatch` where all columns are references to the `original` array
        // ----
        let schema = Arc::new(Schema::new(vec![
            Field::new("slice1", DataType::Int32, false),
            Field::new("slice2", DataType::Int32, false),
        ]));

        let batch_sliced =
            RecordBatch::try_new(schema, vec![Arc::new(slice1), Arc::new(slice2)])
                .unwrap();

        // Both sizes should count only the buffer in the `original` array
        let size_origin = get_record_batch_memory_size(&batch_origin);
        let size_sliced = get_record_batch_memory_size(&batch_sliced);

        assert_eq!(size_origin, size_sliced);
    }

    #[test]
    fn test_get_record_batch_memory_size_nested_array() {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "nested_int",
                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
                false,
            ),
            Field::new(
                "nested_int2",
                DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))),
                false,
            ),
        ]));

        let int_list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
        ]);

        let int_list_array2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(4), Some(5), Some(6)]),
        ]);

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(int_list_array), Arc::new(int_list_array2)],
        )
        .unwrap();

        let size = get_record_batch_memory_size(&batch);
        assert_eq!(size, 8208);
    }

    // ==== Spill manager tests ====

    #[test]
    fn test_spill_manager_spill_record_batch_and_finish() -> Result<()> {
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));

        let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ],
        )?;

        let temp_file = spill_manager.spill_record_batch_and_finish(&[batch], "Test")?;
        assert!(temp_file.is_some());
        assert!(temp_file.unwrap().path().exists());
        Ok(())
    }

    fn verify_metrics(
        in_progress_file: &InProgressSpillFile,
        expected_spill_file_count: usize,
        expected_spilled_bytes: usize,
        expected_spilled_rows: usize,
    ) -> Result<()> {
        let actual_spill_file_count = in_progress_file
            .spill_writer
            .metrics
            .spill_file_count
            .value();
        let actual_spilled_bytes =
            in_progress_file.spill_writer.metrics.spilled_bytes.value();
        let actual_spilled_rows =
            in_progress_file.spill_writer.metrics.spilled_rows.value();

        assert_eq!(
            actual_spill_file_count, expected_spill_file_count,
            "Spill file count mismatch"
        );
        assert_eq!(
            actual_spilled_bytes, expected_spilled_bytes,
            "Spilled bytes mismatch"
        );
        assert_eq!(
            actual_spilled_rows, expected_spilled_rows,
            "Spilled rows mismatch"
        );

        Ok(())
    }

    #[test]
    fn test_in_progress_spill_file_append_and_finish() -> Result<()> {
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));

        let spill_manager =
            Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema)));
        let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;

        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ],
        )?;

        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(StringArray::from(vec!["d", "e", "f"])),
            ],
        )?;
        // After appending each batch, spilled_rows should increase, while spill_file_count and
        // spilled_bytes remain the same (spilled_bytes is updated only after finish() is called)
        in_progress_file.append_batch(&batch1)?;
        verify_metrics(&in_progress_file, 1, 0, 3)?;

        in_progress_file.append_batch(&batch2)?;
        verify_metrics(&in_progress_file, 1, 0, 6)?;

        let completed_file = in_progress_file.finish()?;
        assert!(completed_file.is_some());
        assert!(completed_file.unwrap().path().exists());
        verify_metrics(&in_progress_file, 1, 712, 6)?;
        // A second finish() call produces an error
        let result = in_progress_file.finish();
        assert!(result.is_err());

        Ok(())
    }

    // Test writing no batches
    #[test]
    fn test_in_progress_spill_file_write_no_batches() -> Result<()> {
        let env = Arc::new(RuntimeEnv::default());
        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));

        let spill_manager =
            Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema)));

        // Test writing nothing via the `InProgressSpillFile` interface
        // (`append_batch()` is never called)
        let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
        let completed_file = in_progress_file.finish()?;
        assert!(completed_file.is_none());

        // Test writing no batches via `spill_record_batch_and_finish()`
        let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?;
        assert!(completed_file.is_none());

        // Test writing an empty batch via `spill_record_batch_by_size_and_return_max_batch_memory()`
        let empty_batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())),
                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
            ],
        )?;
        let completed_file = spill_manager
            .spill_record_batch_by_size_and_return_max_batch_memory(
                &empty_batch,
                "Test",
                1,
            )?;
        assert!(completed_file.is_none());

        Ok(())
    }

    #[test]
    fn test_reading_more_spills_than_tokio_blocking_threads() -> Result<()> {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .max_blocking_threads(1)
            .build()
            .unwrap()
            .block_on(async {
                let batch = build_table_i32(
                    ("a2", &vec![0, 1, 2]),
                    ("b2", &vec![3, 4, 5]),
                    ("c2", &vec![4, 5, 6]),
                );

                let schema = batch.schema();

                // Construct SpillManager
                let env = Arc::new(RuntimeEnv::default());
                let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
                let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema));
                let batches: [_; 10] = std::array::from_fn(|_| batch.clone());

                let spill_file_1 = spill_manager
                    .spill_record_batch_and_finish(&batches, "Test1")?
                    .unwrap();
                let spill_file_2 = spill_manager
                    .spill_record_batch_and_finish(&batches, "Test2")?
                    .unwrap();

                let mut stream_1 =
                    spill_manager.read_spill_as_stream(spill_file_1, None)?;
                let mut stream_2 =
                    spill_manager.read_spill_as_stream(spill_file_2, None)?;
                stream_1.next().await;
                stream_2.next().await;

                Ok(())
            })
    }

    #[test]
    fn test_alignment_for_schema() -> Result<()> {
        let schema = Schema::new(vec![Field::new("strings", DataType::Utf8View, false)]);
        let alignment = get_max_alignment_for_schema(&schema);
        assert_eq!(alignment, 16);

        let schema = Schema::new(vec![
            Field::new("int32", DataType::Int32, false),
            Field::new("int64", DataType::Int64, false),
        ]);
        let alignment = get_max_alignment_for_schema(&schema);
        assert_eq!(alignment, 8);
        Ok(())
    }
}