datafusion_physical_plan/
common.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines common code used in execution plans
19
20use std::fs;
21use std::fs::metadata;
22use std::sync::Arc;
23
24use super::SendableRecordBatchStream;
25use crate::stream::RecordBatchReceiverStream;
26use crate::{ColumnStatistics, Statistics};
27
28use arrow::array::Array;
29use arrow::datatypes::Schema;
30use arrow::record_batch::RecordBatch;
31use datafusion_common::stats::Precision;
32use datafusion_common::{plan_err, Result};
33use datafusion_execution::memory_pool::MemoryReservation;
34
35use futures::{StreamExt, TryStreamExt};
36use parking_lot::Mutex;
37
38/// [`MemoryReservation`] used across query execution streams
39pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
40
41/// Create a vector of record batches from a stream
42pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
43    stream.try_collect::<Vec<_>>().await
44}
45
46/// Recursively builds a list of files in a directory with a given extension
47pub fn build_checked_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
48    let mut filenames: Vec<String> = Vec::new();
49    build_file_list_recurse(dir, &mut filenames, ext)?;
50    if filenames.is_empty() {
51        return plan_err!("No files found at {dir} with file extension {ext}");
52    }
53    Ok(filenames)
54}
55
56/// Recursively builds a list of files in a directory with a given extension
57pub fn build_file_list(dir: &str, ext: &str) -> Result<Vec<String>> {
58    let mut filenames: Vec<String> = Vec::new();
59    build_file_list_recurse(dir, &mut filenames, ext)?;
60    Ok(filenames)
61}
62
63/// Recursively build a list of files in a directory with a given extension with an accumulator list
64fn build_file_list_recurse(
65    dir: &str,
66    filenames: &mut Vec<String>,
67    ext: &str,
68) -> Result<()> {
69    let metadata = metadata(dir)?;
70    if metadata.is_file() {
71        if dir.ends_with(ext) {
72            filenames.push(dir.to_string());
73        }
74    } else {
75        for entry in fs::read_dir(dir)? {
76            let entry = entry?;
77            let path = entry.path();
78            if let Some(path_name) = path.to_str() {
79                if path.is_dir() {
80                    build_file_list_recurse(path_name, filenames, ext)?;
81                } else if path_name.ends_with(ext) {
82                    filenames.push(path_name.to_string());
83                }
84            } else {
85                return plan_err!("Invalid path");
86            }
87        }
88    }
89    Ok(())
90}
91
92/// If running in a tokio context spawns the execution of `stream` to a separate task
93/// allowing it to execute in parallel with an intermediate buffer of size `buffer`
94pub fn spawn_buffered(
95    mut input: SendableRecordBatchStream,
96    buffer: usize,
97) -> SendableRecordBatchStream {
98    // Use tokio only if running from a multi-thread tokio context
99    match tokio::runtime::Handle::try_current() {
100        Ok(handle)
101            if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread =>
102        {
103            let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer);
104
105            let sender = builder.tx();
106
107            builder.spawn(async move {
108                while let Some(item) = input.next().await {
109                    if sender.send(item).await.is_err() {
110                        // Receiver dropped when query is shutdown early (e.g., limit) or error,
111                        // no need to return propagate the send error.
112                        return Ok(());
113                    }
114                }
115
116                Ok(())
117            });
118
119            builder.build()
120        }
121        _ => input,
122    }
123}
124
125/// Computes the statistics for an in-memory RecordBatch
126///
127/// Only computes statistics that are in arrows metadata (num rows, byte size and nulls)
128/// and does not apply any kernel on the actual data.
129pub fn compute_record_batch_statistics(
130    batches: &[Vec<RecordBatch>],
131    schema: &Schema,
132    projection: Option<Vec<usize>>,
133) -> Statistics {
134    let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum();
135
136    let projection = match projection {
137        Some(p) => p,
138        None => (0..schema.fields().len()).collect(),
139    };
140
141    let total_byte_size = batches
142        .iter()
143        .flatten()
144        .map(|b| {
145            projection
146                .iter()
147                .map(|index| b.column(*index).get_array_memory_size())
148                .sum::<usize>()
149        })
150        .sum();
151
152    let mut null_counts = vec![0; projection.len()];
153
154    for partition in batches.iter() {
155        for batch in partition {
156            for (stat_index, col_index) in projection.iter().enumerate() {
157                null_counts[stat_index] += batch
158                    .column(*col_index)
159                    .logical_nulls()
160                    .map(|nulls| nulls.null_count())
161                    .unwrap_or_default();
162            }
163        }
164    }
165    let column_statistics = null_counts
166        .into_iter()
167        .map(|null_count| {
168            let mut s = ColumnStatistics::new_unknown();
169            s.null_count = Precision::Exact(null_count);
170            s
171        })
172        .collect();
173
174    Statistics {
175        num_rows: Precision::Exact(nb_rows),
176        total_byte_size: Precision::Exact(total_byte_size),
177        column_statistics,
178    }
179}
180
181/// Checks if the given projection is valid for the given schema.
182pub fn can_project(
183    schema: &arrow::datatypes::SchemaRef,
184    projection: Option<&Vec<usize>>,
185) -> Result<()> {
186    match projection {
187        Some(columns) => {
188            if columns
189                .iter()
190                .max()
191                .is_some_and(|&i| i >= schema.fields().len())
192            {
193                Err(arrow::error::ArrowError::SchemaError(format!(
194                    "project index {} out of bounds, max field {}",
195                    columns.iter().max().unwrap(),
196                    schema.fields().len()
197                ))
198                .into())
199            } else {
200                Ok(())
201            }
202        }
203        None => Ok(()),
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::{
        array::{Float32Array, Float64Array, UInt64Array},
        datatypes::{DataType, Field},
    };

    /// Expected statistics for a column where only the null count is known.
    fn null_count_only(null_count: usize) -> ColumnStatistics {
        ColumnStatistics {
            null_count: Precision::Exact(null_count),
            distinct_count: Precision::Absent,
            max_value: Precision::Absent,
            min_value: Precision::Absent,
            sum_value: Precision::Absent,
        }
    }

    /// With no batches at all, row count and byte size are both exactly zero.
    #[test]
    fn test_compute_record_batch_statistics_empty() -> Result<()> {
        let fields = vec![
            Field::new("f32", DataType::Float32, false),
            Field::new("f64", DataType::Float64, false),
        ];
        let schema = Arc::new(Schema::new(fields));

        let Statistics {
            num_rows,
            total_byte_size,
            ..
        } = compute_record_batch_statistics(&[], &schema, Some(vec![0, 1]));

        assert_eq!(num_rows, Precision::Exact(0));
        assert_eq!(total_byte_size, Precision::Exact(0));
        Ok(())
    }

    /// A projection restricts both the byte size and the per-column entries
    /// to the selected columns.
    #[test]
    fn test_compute_record_batch_statistics() -> Result<()> {
        let fields = vec![
            Field::new("f32", DataType::Float32, false),
            Field::new("f64", DataType::Float64, false),
            Field::new("u64", DataType::UInt64, false),
        ];
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(vec![1., 2., 3.])),
                Arc::new(Float64Array::from(vec![9., 8., 7.])),
                Arc::new(UInt64Array::from(vec![4, 5, 6])),
            ],
        )?;

        // Just select f32,f64
        let selected = vec![0, 1];
        let byte_size = batch.project(&selected).unwrap().get_array_memory_size();

        let actual =
            compute_record_batch_statistics(&[vec![batch]], &schema, Some(selected));

        let expected = Statistics {
            num_rows: Precision::Exact(3),
            total_byte_size: Precision::Exact(byte_size),
            column_statistics: vec![null_count_only(0), null_count_only(0)],
        };

        assert_eq!(actual, expected);
        Ok(())
    }

    /// Null counts are summed over all batches of all partitions.
    #[test]
    fn test_compute_record_batch_statistics_null() -> Result<()> {
        let schema =
            Arc::new(Schema::new(vec![Field::new("u64", DataType::UInt64, true)]));
        let first = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt64Array::from(vec![Some(1), None, None]))],
        )?;
        let second = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(UInt64Array::from(vec![Some(1), Some(2), None]))],
        )?;
        let byte_size = first.get_array_memory_size() + second.get_array_memory_size();

        let actual =
            compute_record_batch_statistics(&[vec![first], vec![second]], &schema, None);

        let expected = Statistics {
            num_rows: Precision::Exact(6),
            total_byte_size: Precision::Exact(byte_size),
            column_statistics: vec![null_count_only(3)],
        };

        assert_eq!(actual, expected);
        Ok(())
    }
}
311}