datafusion/test_util/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions to make testing DataFusion based crates easier
19
20#[cfg(feature = "parquet")]
21pub mod parquet;
22
23pub mod csv;
24
25use futures::Stream;
26use std::any::Any;
27use std::collections::HashMap;
28use std::fs::File;
29use std::io::Write;
30use std::path::Path;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33
34use crate::catalog::{TableProvider, TableProviderFactory};
35use crate::dataframe::DataFrame;
36use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
37use crate::datasource::{empty::EmptyTable, provider_as_source};
38use crate::error::Result;
39use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
40use crate::physical_plan::ExecutionPlan;
41use crate::prelude::{CsvReadOptions, SessionContext};
42
43use crate::execution::SendableRecordBatchStream;
44use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
45use arrow::record_batch::RecordBatch;
46use datafusion_catalog::Session;
47use datafusion_common::TableReference;
48use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
49use std::pin::Pin;
50
51use async_trait::async_trait;
52
53use tempfile::TempDir;
54// backwards compatibility
55#[cfg(feature = "parquet")]
56pub use datafusion_common::test_util::parquet_test_data;
57pub use datafusion_common::test_util::{arrow_test_data, get_data_dir};
58
59use crate::execution::RecordBatchStream;
60
61/// Scan an empty data source, mainly used in tests
62pub fn scan_empty(
63    name: Option<&str>,
64    table_schema: &Schema,
65    projection: Option<Vec<usize>>,
66) -> Result<LogicalPlanBuilder> {
67    let table_schema = Arc::new(table_schema.clone());
68    let provider = Arc::new(EmptyTable::new(table_schema));
69    let name = TableReference::bare(name.unwrap_or(UNNAMED_TABLE));
70    LogicalPlanBuilder::scan(name, provider_as_source(provider), projection)
71}
72
73/// Scan an empty data source with configured partition, mainly used in tests.
74pub fn scan_empty_with_partitions(
75    name: Option<&str>,
76    table_schema: &Schema,
77    projection: Option<Vec<usize>>,
78    partitions: usize,
79) -> Result<LogicalPlanBuilder> {
80    let table_schema = Arc::new(table_schema.clone());
81    let provider = Arc::new(EmptyTable::new(table_schema).with_partitions(partitions));
82    let name = TableReference::bare(name.unwrap_or(UNNAMED_TABLE));
83    LogicalPlanBuilder::scan(name, provider_as_source(provider), projection)
84}
85
86/// Get the schema for the aggregate_test_* csv files
87pub fn aggr_test_schema() -> SchemaRef {
88    let mut f1 = Field::new("c1", DataType::Utf8, false);
89    f1.set_metadata(HashMap::from_iter(vec![("testing".into(), "test".into())]));
90    let schema = Schema::new(vec![
91        f1,
92        Field::new("c2", DataType::UInt32, false),
93        Field::new("c3", DataType::Int8, false),
94        Field::new("c4", DataType::Int16, false),
95        Field::new("c5", DataType::Int32, false),
96        Field::new("c6", DataType::Int64, false),
97        Field::new("c7", DataType::UInt8, false),
98        Field::new("c8", DataType::UInt16, false),
99        Field::new("c9", DataType::UInt32, false),
100        Field::new("c10", DataType::UInt64, false),
101        Field::new("c11", DataType::Float32, false),
102        Field::new("c12", DataType::Float64, false),
103        Field::new("c13", DataType::Utf8, false),
104    ]);
105
106    Arc::new(schema)
107}
108
109/// Register session context for the aggregate_test_100.csv file
110pub async fn register_aggregate_csv(
111    ctx: &SessionContext,
112    table_name: &str,
113) -> Result<()> {
114    let schema = aggr_test_schema();
115    let testdata = arrow_test_data();
116    ctx.register_csv(
117        table_name,
118        &format!("{testdata}/csv/aggregate_test_100.csv"),
119        CsvReadOptions::new().schema(schema.as_ref()),
120    )
121    .await?;
122    Ok(())
123}
124
125/// Create a table from the aggregate_test_100.csv file with the specified name
126pub async fn test_table_with_name(name: &str) -> Result<DataFrame> {
127    let ctx = SessionContext::new();
128    register_aggregate_csv(&ctx, name).await?;
129    ctx.table(name).await
130}
131
132/// Create a table from the aggregate_test_100.csv file with the name "aggregate_test_100"
133pub async fn test_table() -> Result<DataFrame> {
134    test_table_with_name("aggregate_test_100").await
135}
136
/// Plan `sql` against `ctx`, execute it, and collect every output batch.
///
/// # Errors
/// Returns any planning error from [`SessionContext::sql`] or execution
/// error from collecting the resulting [`DataFrame`].
#[cfg(feature = "sql")]
pub async fn plan_and_collect(
    ctx: &SessionContext,
    sql: &str,
) -> Result<Vec<RecordBatch>> {
    let frame = ctx.sql(sql).await?;
    frame.collect().await
}
145
146/// Generate CSV partitions within the supplied directory
147pub fn populate_csv_partitions(
148    tmp_dir: &TempDir,
149    partition_count: usize,
150    file_extension: &str,
151) -> Result<SchemaRef> {
152    // define schema for data source (csv file)
153    let schema = Arc::new(Schema::new(vec![
154        Field::new("c1", DataType::UInt32, false),
155        Field::new("c2", DataType::UInt64, false),
156        Field::new("c3", DataType::Boolean, false),
157    ]));
158
159    // generate a partitioned file
160    for partition in 0..partition_count {
161        let filename = format!("partition-{partition}.{file_extension}");
162        let file_path = tmp_dir.path().join(filename);
163        let mut file = File::create(file_path)?;
164
165        // generate some data
166        for i in 0..=10 {
167            let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
168            file.write_all(data.as_bytes())?;
169        }
170    }
171
172    Ok(schema)
173}
174
175/// TableFactory for tests
176#[derive(Default, Debug)]
177pub struct TestTableFactory {}
178
179#[async_trait]
180impl TableProviderFactory for TestTableFactory {
181    async fn create(
182        &self,
183        _: &dyn Session,
184        cmd: &CreateExternalTable,
185    ) -> Result<Arc<dyn TableProvider>> {
186        Ok(Arc::new(TestTableProvider {
187            url: cmd.location.to_string(),
188            schema: Arc::clone(cmd.schema.inner()),
189        }))
190    }
191}
192
193/// TableProvider for testing purposes
194#[derive(Debug)]
195pub struct TestTableProvider {
196    /// URL of table files or folder
197    pub url: String,
198    /// test table schema
199    pub schema: SchemaRef,
200}
201
202impl TestTableProvider {}
203
204#[async_trait]
205impl TableProvider for TestTableProvider {
206    fn as_any(&self) -> &dyn Any {
207        self
208    }
209
210    fn schema(&self) -> SchemaRef {
211        Arc::clone(&self.schema)
212    }
213
214    fn table_type(&self) -> TableType {
215        unimplemented!("TestTableProvider is a stub for testing.")
216    }
217
218    async fn scan(
219        &self,
220        _state: &dyn Session,
221        _projection: Option<&Vec<usize>>,
222        _filters: &[Expr],
223        _limit: Option<usize>,
224    ) -> Result<Arc<dyn ExecutionPlan>> {
225        unimplemented!("TestTableProvider is a stub for testing.")
226    }
227}
228
229/// This function creates an unbounded sorted file for testing purposes.
230pub fn register_unbounded_file_with_ordering(
231    ctx: &SessionContext,
232    schema: SchemaRef,
233    file_path: &Path,
234    table_name: &str,
235    file_sort_order: Vec<Vec<SortExpr>>,
236) -> Result<()> {
237    let source = FileStreamProvider::new_file(schema, file_path.into());
238    let config = StreamConfig::new(Arc::new(source)).with_order(file_sort_order);
239
240    // Register table:
241    ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
242    Ok(())
243}
244
245/// Creates a bounded stream that emits the same record batch a specified number of times.
246/// This is useful for testing purposes.
247pub fn bounded_stream(
248    record_batch: RecordBatch,
249    limit: usize,
250) -> SendableRecordBatchStream {
251    Box::pin(BoundedStream {
252        record_batch,
253        count: 0,
254        limit,
255    })
256}
257
258struct BoundedStream {
259    record_batch: RecordBatch,
260    count: usize,
261    limit: usize,
262}
263
264impl Stream for BoundedStream {
265    type Item = Result<RecordBatch, crate::error::DataFusionError>;
266
267    fn poll_next(
268        mut self: Pin<&mut Self>,
269        _cx: &mut Context<'_>,
270    ) -> Poll<Option<Self::Item>> {
271        if self.count >= self.limit {
272            Poll::Ready(None)
273        } else {
274            self.count += 1;
275            Poll::Ready(Some(Ok(self.record_batch.clone())))
276        }
277    }
278}
279
280impl RecordBatchStream for BoundedStream {
281    fn schema(&self) -> SchemaRef {
282        self.record_batch.schema()
283    }
284}