datafusion/datasource/physical_plan/
csv.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Re-exports the [`datafusion_datasource_csv::source`] module, containing the CSV-based [`FileSource`].
19//!
20//! [`FileSource`]: datafusion_datasource::file::FileSource
21
22pub use datafusion_datasource_csv::source::*;
23
24#[cfg(test)]
25mod tests {
26
27    use std::collections::HashMap;
28    use std::fs::{self, File};
29    use std::io::Write;
30    use std::sync::Arc;
31
32    use datafusion_datasource_csv::CsvFormat;
33    use object_store::ObjectStore;
34
35    use crate::prelude::CsvReadOptions;
36    use crate::prelude::SessionContext;
37    use crate::test::partitioned_file_groups;
38    use datafusion_common::test_util::arrow_test_data;
39    use datafusion_common::test_util::batches_to_string;
40    use datafusion_common::{assert_batches_eq, Result};
41    use datafusion_execution::config::SessionConfig;
42    use datafusion_physical_plan::metrics::MetricsSet;
43    use datafusion_physical_plan::ExecutionPlan;
44
45    #[cfg(feature = "compression")]
46    use datafusion_datasource::file_compression_type::FileCompressionType;
47    use datafusion_datasource_csv::partitioned_csv_config;
48    use datafusion_datasource_csv::source::CsvSource;
49    use futures::{StreamExt, TryStreamExt};
50
51    use arrow::datatypes::*;
52    use bytes::Bytes;
53    use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
54    use datafusion_datasource::source::DataSourceExec;
55    use insta::assert_snapshot;
56    use object_store::chunked::ChunkedStore;
57    use object_store::local::LocalFileSystem;
58    use rstest::*;
59    use tempfile::TempDir;
60    use url::Url;
61
62    fn aggr_test_schema() -> SchemaRef {
63        let mut f1 = Field::new("c1", DataType::Utf8, false);
64        f1.set_metadata(HashMap::from_iter(vec![("testing".into(), "test".into())]));
65        let schema = Schema::new(vec![
66            f1,
67            Field::new("c2", DataType::UInt32, false),
68            Field::new("c3", DataType::Int8, false),
69            Field::new("c4", DataType::Int16, false),
70            Field::new("c5", DataType::Int32, false),
71            Field::new("c6", DataType::Int64, false),
72            Field::new("c7", DataType::UInt8, false),
73            Field::new("c8", DataType::UInt16, false),
74            Field::new("c9", DataType::UInt32, false),
75            Field::new("c10", DataType::UInt64, false),
76            Field::new("c11", DataType::Float32, false),
77            Field::new("c12", DataType::Float64, false),
78            Field::new("c13", DataType::Utf8, false),
79        ]);
80
81        Arc::new(schema)
82    }
83
84    #[rstest(
85        file_compression_type,
86        case(FileCompressionType::UNCOMPRESSED),
87        case(FileCompressionType::GZIP),
88        case(FileCompressionType::BZIP2),
89        case(FileCompressionType::XZ),
90        case(FileCompressionType::ZSTD)
91    )]
92    #[cfg(feature = "compression")]
93    #[tokio::test]
94    async fn csv_exec_with_projection(
95        file_compression_type: FileCompressionType,
96    ) -> Result<()> {
97        let session_ctx = SessionContext::new();
98        let task_ctx = session_ctx.task_ctx();
99        let file_schema = aggr_test_schema();
100        let path = format!("{}/csv", arrow_test_data());
101        let filename = "aggregate_test_100.csv";
102        let tmp_dir = TempDir::new()?;
103
104        let file_groups = partitioned_file_groups(
105            path.as_str(),
106            filename,
107            1,
108            Arc::new(CsvFormat::default()),
109            file_compression_type.to_owned(),
110            tmp_dir.path(),
111        )?;
112
113        let source = Arc::new(CsvSource::new(true, b',', b'"'));
114        let config = FileScanConfigBuilder::from(partitioned_csv_config(
115            file_schema,
116            file_groups,
117            source,
118        ))
119        .with_file_compression_type(file_compression_type)
120        .with_newlines_in_values(false)
121        .with_projection_indices(Some(vec![0, 2, 4]))
122        .build();
123
124        assert_eq!(13, config.file_schema().fields().len());
125        let csv = DataSourceExec::from_data_source(config);
126
127        assert_eq!(3, csv.schema().fields().len());
128
129        let mut stream = csv.execute(0, task_ctx)?;
130        let batch = stream.next().await.unwrap()?;
131        assert_eq!(3, batch.num_columns());
132        assert_eq!(100, batch.num_rows());
133
134        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###"
135            +----+-----+------------+
136            | c1 | c3  | c5         |
137            +----+-----+------------+
138            | c  | 1   | 2033001162 |
139            | d  | -40 | 706441268  |
140            | b  | 29  | 994303988  |
141            | a  | -85 | 1171968280 |
142            | b  | -82 | 1824882165 |
143            +----+-----+------------+
144        "###);}
145        Ok(())
146    }
147
148    #[rstest(
149        file_compression_type,
150        case(FileCompressionType::UNCOMPRESSED),
151        case(FileCompressionType::GZIP),
152        case(FileCompressionType::BZIP2),
153        case(FileCompressionType::XZ),
154        case(FileCompressionType::ZSTD)
155    )]
156    #[cfg(feature = "compression")]
157    #[tokio::test]
158    async fn csv_exec_with_mixed_order_projection(
159        file_compression_type: FileCompressionType,
160    ) -> Result<()> {
161        let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true");
162        let session_ctx = SessionContext::new_with_config(cfg);
163        let task_ctx = session_ctx.task_ctx();
164        let file_schema = aggr_test_schema();
165        let path = format!("{}/csv", arrow_test_data());
166        let filename = "aggregate_test_100.csv";
167        let tmp_dir = TempDir::new()?;
168
169        let file_groups = partitioned_file_groups(
170            path.as_str(),
171            filename,
172            1,
173            Arc::new(CsvFormat::default()),
174            file_compression_type.to_owned(),
175            tmp_dir.path(),
176        )?;
177
178        let source = Arc::new(CsvSource::new(true, b',', b'"'));
179        let config = FileScanConfigBuilder::from(partitioned_csv_config(
180            file_schema,
181            file_groups,
182            source,
183        ))
184        .with_newlines_in_values(false)
185        .with_file_compression_type(file_compression_type.to_owned())
186        .with_projection_indices(Some(vec![4, 0, 2]))
187        .build();
188        assert_eq!(13, config.file_schema().fields().len());
189        let csv = DataSourceExec::from_data_source(config);
190        assert_eq!(3, csv.schema().fields().len());
191
192        let mut stream = csv.execute(0, task_ctx)?;
193        let batch = stream.next().await.unwrap()?;
194        assert_eq!(3, batch.num_columns());
195        assert_eq!(100, batch.num_rows());
196
197        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###"
198            +------------+----+-----+
199            | c5         | c1 | c3  |
200            +------------+----+-----+
201            | 2033001162 | c  | 1   |
202            | 706441268  | d  | -40 |
203            | 994303988  | b  | 29  |
204            | 1171968280 | a  | -85 |
205            | 1824882165 | b  | -82 |
206            +------------+----+-----+
207        "###);}
208        Ok(())
209    }
210
211    #[rstest(
212        file_compression_type,
213        case(FileCompressionType::UNCOMPRESSED),
214        case(FileCompressionType::GZIP),
215        case(FileCompressionType::BZIP2),
216        case(FileCompressionType::XZ),
217        case(FileCompressionType::ZSTD)
218    )]
219    #[cfg(feature = "compression")]
220    #[tokio::test]
221    async fn csv_exec_with_limit(
222        file_compression_type: FileCompressionType,
223    ) -> Result<()> {
224        use futures::StreamExt;
225
226        let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true");
227        let session_ctx = SessionContext::new_with_config(cfg);
228        let task_ctx = session_ctx.task_ctx();
229        let file_schema = aggr_test_schema();
230        let path = format!("{}/csv", arrow_test_data());
231        let filename = "aggregate_test_100.csv";
232        let tmp_dir = TempDir::new()?;
233
234        let file_groups = partitioned_file_groups(
235            path.as_str(),
236            filename,
237            1,
238            Arc::new(CsvFormat::default()),
239            file_compression_type.to_owned(),
240            tmp_dir.path(),
241        )?;
242
243        let source = Arc::new(CsvSource::new(true, b',', b'"'));
244        let config = FileScanConfigBuilder::from(partitioned_csv_config(
245            file_schema,
246            file_groups,
247            source,
248        ))
249        .with_newlines_in_values(false)
250        .with_file_compression_type(file_compression_type.to_owned())
251        .with_limit(Some(5))
252        .build();
253        assert_eq!(13, config.file_schema().fields().len());
254        let csv = DataSourceExec::from_data_source(config);
255        assert_eq!(13, csv.schema().fields().len());
256
257        let mut it = csv.execute(0, task_ctx)?;
258        let batch = it.next().await.unwrap()?;
259        assert_eq!(13, batch.num_columns());
260        assert_eq!(5, batch.num_rows());
261
262        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###"
263            +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+
264            | c1 | c2 | c3  | c4     | c5         | c6                   | c7  | c8    | c9         | c10                  | c11         | c12                 | c13                            |
265            +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+
266            | c  | 2  | 1   | 18109  | 2033001162 | -6513304855495910254 | 25  | 43062 | 1491205016 | 5863949479783605708  | 0.110830784 | 0.9294097332465232  | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW |
267            | d  | 5  | -40 | 22614  | 706441268  | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107  | 0.3114712539863804  | C2GT5KVyOPZpgKVl110TyZO0NcJ434 |
268            | b  | 1  | 29  | -18218 | 994303988  | 5983957848665088916  | 204 | 9489  | 3275293996 | 14857091259186476033 | 0.53840446  | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz |
269            | a  | 1  | -85 | -15154 | 1171968280 | 1919439543497968449  | 77  | 52286 | 774637006  | 12101411955859039553 | 0.12285209  | 0.6864391962767343  | 0keZ5G8BffGwgF2RwQD59TFzMStxCB |
270            | b  | 5  | -82 | 22080  | 1824882165 | 7373730676428214987  | 208 | 34331 | 3342719438 | 3330177516592499461  | 0.82634634  | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd |
271            +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+
272        "###);}
273
274        Ok(())
275    }
276
277    #[rstest(
278        file_compression_type,
279        case(FileCompressionType::UNCOMPRESSED),
280        case(FileCompressionType::GZIP),
281        case(FileCompressionType::BZIP2),
282        case(FileCompressionType::XZ),
283        case(FileCompressionType::ZSTD)
284    )]
285    #[cfg(feature = "compression")]
286    #[tokio::test]
287    async fn csv_exec_with_missing_column(
288        file_compression_type: FileCompressionType,
289    ) -> Result<()> {
290        let session_ctx = SessionContext::new();
291        let task_ctx = session_ctx.task_ctx();
292        let file_schema = aggr_test_schema_with_missing_col();
293        let path = format!("{}/csv", arrow_test_data());
294        let filename = "aggregate_test_100.csv";
295        let tmp_dir = TempDir::new()?;
296
297        let file_groups = partitioned_file_groups(
298            path.as_str(),
299            filename,
300            1,
301            Arc::new(CsvFormat::default()),
302            file_compression_type.to_owned(),
303            tmp_dir.path(),
304        )?;
305
306        let source = Arc::new(CsvSource::new(true, b',', b'"'));
307        let config = FileScanConfigBuilder::from(partitioned_csv_config(
308            file_schema,
309            file_groups,
310            source,
311        ))
312        .with_newlines_in_values(false)
313        .with_file_compression_type(file_compression_type.to_owned())
314        .with_limit(Some(5))
315        .build();
316        assert_eq!(14, config.file_schema().fields().len());
317        let csv = DataSourceExec::from_data_source(config);
318        assert_eq!(14, csv.schema().fields().len());
319
320        // errors due to https://github.com/apache/datafusion/issues/4918
321        let mut it = csv.execute(0, task_ctx)?;
322        let err = it.next().await.unwrap().unwrap_err().strip_backtrace();
323        assert_eq!(
324            err,
325            "Arrow error: Csv error: incorrect number of fields for line 1, expected 14 got 13"
326        );
327        Ok(())
328    }
329
330    #[rstest(
331        file_compression_type,
332        case(FileCompressionType::UNCOMPRESSED),
333        case(FileCompressionType::GZIP),
334        case(FileCompressionType::BZIP2),
335        case(FileCompressionType::XZ),
336        case(FileCompressionType::ZSTD)
337    )]
338    #[cfg(feature = "compression")]
339    #[tokio::test]
340    async fn csv_exec_with_partition(
341        file_compression_type: FileCompressionType,
342    ) -> Result<()> {
343        use datafusion_common::ScalarValue;
344
345        let session_ctx = SessionContext::new();
346        let task_ctx = session_ctx.task_ctx();
347        let file_schema = aggr_test_schema();
348        let path = format!("{}/csv", arrow_test_data());
349        let filename = "aggregate_test_100.csv";
350        let tmp_dir = TempDir::new()?;
351
352        let mut file_groups = partitioned_file_groups(
353            path.as_str(),
354            filename,
355            1,
356            Arc::new(CsvFormat::default()),
357            file_compression_type.to_owned(),
358            tmp_dir.path(),
359        )?;
360        // Add partition columns / values
361        file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")];
362
363        let num_file_schema_fields = file_schema.fields().len();
364
365        let source = Arc::new(CsvSource::new(true, b',', b'"'));
366        let config = FileScanConfigBuilder::from(partitioned_csv_config(
367            file_schema,
368            file_groups,
369            source,
370        ))
371        .with_newlines_in_values(false)
372        .with_file_compression_type(file_compression_type.to_owned())
373        .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)])
374        // We should be able to project on the partition column
375        // Which is supposed to be after the file fields
376        .with_projection_indices(Some(vec![0, num_file_schema_fields]))
377        .build();
378
379        // we don't have `/date=xx/` in the path but that is ok because
380        // partitions are resolved during scan anyway
381
382        assert_eq!(13, config.file_schema().fields().len());
383        let csv = DataSourceExec::from_data_source(config);
384        assert_eq!(2, csv.schema().fields().len());
385
386        let mut it = csv.execute(0, task_ctx)?;
387        let batch = it.next().await.unwrap()?;
388        assert_eq!(2, batch.num_columns());
389        assert_eq!(100, batch.num_rows());
390
391        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###"
392            +----+------------+
393            | c1 | date       |
394            +----+------------+
395            | c  | 2021-10-26 |
396            | d  | 2021-10-26 |
397            | b  | 2021-10-26 |
398            | a  | 2021-10-26 |
399            | b  | 2021-10-26 |
400            +----+------------+
401        "###);}
402
403        let metrics = csv.metrics().expect("doesn't found metrics");
404        let time_elapsed_processing = get_value(&metrics, "time_elapsed_processing");
405        assert!(
406            time_elapsed_processing > 0,
407            "Expected time_elapsed_processing greater than 0",
408        );
409        Ok(())
410    }
411
412    /// Generate CSV partitions within the supplied directory
413    fn populate_csv_partitions(
414        tmp_dir: &TempDir,
415        partition_count: usize,
416        file_extension: &str,
417    ) -> Result<SchemaRef> {
418        // define schema for data source (csv file)
419        let schema = Arc::new(Schema::new(vec![
420            Field::new("c1", DataType::UInt32, false),
421            Field::new("c2", DataType::UInt64, false),
422            Field::new("c3", DataType::Boolean, false),
423        ]));
424
425        // generate a partitioned file
426        for partition in 0..partition_count {
427            let filename = format!("partition-{partition}.{file_extension}");
428            let file_path = tmp_dir.path().join(filename);
429            let mut file = File::create(file_path)?;
430
431            // generate some data
432            for i in 0..=10 {
433                let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
434                file.write_all(data.as_bytes())?;
435            }
436        }
437
438        Ok(schema)
439    }
440
441    async fn test_additional_stores(
442        file_compression_type: FileCompressionType,
443        store: Arc<dyn ObjectStore>,
444    ) -> Result<()> {
445        let ctx = SessionContext::new();
446        let url = Url::parse("file://").unwrap();
447        ctx.register_object_store(&url, store.clone());
448
449        let task_ctx = ctx.task_ctx();
450
451        let file_schema = aggr_test_schema();
452        let path = format!("{}/csv", arrow_test_data());
453        let filename = "aggregate_test_100.csv";
454        let tmp_dir = TempDir::new()?;
455
456        let file_groups = partitioned_file_groups(
457            path.as_str(),
458            filename,
459            1,
460            Arc::new(CsvFormat::default()),
461            file_compression_type.to_owned(),
462            tmp_dir.path(),
463        )
464        .unwrap();
465
466        let source = Arc::new(CsvSource::new(true, b',', b'"'));
467        let config = FileScanConfigBuilder::from(partitioned_csv_config(
468            file_schema,
469            file_groups,
470            source,
471        ))
472        .with_newlines_in_values(false)
473        .with_file_compression_type(file_compression_type.to_owned())
474        .build();
475        let csv = DataSourceExec::from_data_source(config);
476
477        let it = csv.execute(0, task_ctx).unwrap();
478        let batches: Vec<_> = it.try_collect().await.unwrap();
479
480        let total_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
481
482        assert_eq!(total_rows, 100);
483        Ok(())
484    }
485
486    #[rstest(
487        file_compression_type,
488        case(FileCompressionType::UNCOMPRESSED),
489        case(FileCompressionType::GZIP),
490        case(FileCompressionType::BZIP2),
491        case(FileCompressionType::XZ),
492        case(FileCompressionType::ZSTD)
493    )]
494    #[cfg(feature = "compression")]
495    #[tokio::test]
496    async fn test_chunked_csv(
497        file_compression_type: FileCompressionType,
498        #[values(10, 20, 30, 40)] chunk_size: usize,
499    ) -> Result<()> {
500        test_additional_stores(
501            file_compression_type,
502            Arc::new(ChunkedStore::new(
503                Arc::new(LocalFileSystem::new()),
504                chunk_size,
505            )),
506        )
507        .await?;
508        Ok(())
509    }
510
511    #[tokio::test]
512    async fn test_no_trailing_delimiter() {
513        let session_ctx = SessionContext::new();
514        let store = object_store::memory::InMemory::new();
515
516        let data = Bytes::from("a,b\n1,2\n3,4");
517        let path = object_store::path::Path::from("a.csv");
518        store.put(&path, data.into()).await.unwrap();
519
520        let url = Url::parse("memory://").unwrap();
521        session_ctx.register_object_store(&url, Arc::new(store));
522
523        let df = session_ctx
524            .read_csv("memory:///", CsvReadOptions::new())
525            .await
526            .unwrap();
527
528        let result = df.collect().await.unwrap();
529
530        assert_snapshot!(batches_to_string(&result), @r###"
531            +---+---+
532            | a | b |
533            +---+---+
534            | 1 | 2 |
535            | 3 | 4 |
536            +---+---+
537        "###);
538    }
539
540    #[tokio::test]
541    async fn test_terminator() {
542        let session_ctx = SessionContext::new();
543        let store = object_store::memory::InMemory::new();
544
545        let data = Bytes::from("a,b\r1,2\r3,4");
546        let path = object_store::path::Path::from("a.csv");
547        store.put(&path, data.into()).await.unwrap();
548
549        let url = Url::parse("memory://").unwrap();
550        session_ctx.register_object_store(&url, Arc::new(store));
551
552        let df = session_ctx
553            .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\r')))
554            .await
555            .unwrap();
556
557        let result = df.collect().await.unwrap();
558
559        assert_snapshot!(batches_to_string(&result),@r###"
560            +---+---+
561            | a | b |
562            +---+---+
563            | 1 | 2 |
564            | 3 | 4 |
565            +---+---+
566        "###);
567
568        let e = session_ctx
569            .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\n')))
570            .await
571            .unwrap()
572            .collect()
573            .await
574            .unwrap_err();
575        assert_eq!(e.strip_backtrace(), "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2")
576    }
577
578    #[tokio::test]
579    async fn test_create_external_table_with_terminator() -> Result<()> {
580        let ctx = SessionContext::new();
581        ctx.sql(
582            r#"
583            CREATE EXTERNAL TABLE t1 (
584            col1 TEXT,
585            col2 TEXT
586            ) STORED AS CSV
587            LOCATION 'tests/data/cr_terminator.csv'
588            OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true');
589    "#,
590        )
591        .await?
592        .collect()
593        .await?;
594
595        let df = ctx.sql(r#"select * from t1"#).await?.collect().await?;
596        assert_snapshot!(batches_to_string(&df),@r###"
597            +------+--------+
598            | col1 | col2   |
599            +------+--------+
600            | id0  | value0 |
601            | id1  | value1 |
602            | id2  | value2 |
603            | id3  | value3 |
604            +------+--------+
605        "###);
606        Ok(())
607    }
608
609    #[tokio::test]
610    async fn test_create_external_table_with_terminator_with_newlines_in_values(
611    ) -> Result<()> {
612        let ctx = SessionContext::new();
613        ctx.sql(r#"
614            CREATE EXTERNAL TABLE t1 (
615            col1 TEXT,
616            col2 TEXT
617            ) STORED AS CSV
618            LOCATION 'tests/data/newlines_in_values_cr_terminator.csv'
619            OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true', 'format.newlines_in_values' 'true');
620    "#).await?.collect().await?;
621
622        let df = ctx.sql(r#"select * from t1"#).await?.collect().await?;
623        let expected = [
624            "+-------+-----------------------------+",
625            "| col1  | col2                        |",
626            "+-------+-----------------------------+",
627            "| 1     | hello\rworld                 |",
628            "| 2     | something\relse              |",
629            "| 3     | \rmany\rlines\rmake\rgood test\r |",
630            "| 4     | unquoted                    |",
631            "| value | end                         |",
632            "+-------+-----------------------------+",
633        ];
634        assert_batches_eq!(expected, &df);
635        Ok(())
636    }
637
638    #[tokio::test]
639    async fn write_csv_results_error_handling() -> Result<()> {
640        let ctx = SessionContext::new();
641
642        // register a local file system object store
643        let tmp_dir = TempDir::new()?;
644        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
645        let local_url = Url::parse("file://local").unwrap();
646        ctx.register_object_store(&local_url, local);
647        let options = CsvReadOptions::default()
648            .schema_infer_max_records(2)
649            .has_header(true);
650        let df = ctx.read_csv("tests/data/corrupt.csv", options).await?;
651
652        let out_dir_url = "file://local/out";
653        let e = df
654            .write_csv(
655                out_dir_url,
656                crate::dataframe::DataFrameWriteOptions::new(),
657                None,
658            )
659            .await
660            .expect_err("should fail because input file does not match inferred schema");
661        assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'");
662        Ok(())
663    }
664
665    #[tokio::test]
666    async fn write_csv_results() -> Result<()> {
667        // create partitioned input file and context
668        let tmp_dir = TempDir::new()?;
669        let ctx = SessionContext::new_with_config(
670            SessionConfig::new()
671                .with_target_partitions(8)
672                .set_str("datafusion.catalog.has_header", "false"),
673        );
674
675        let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?;
676
677        // register csv file with the execution context
678        ctx.register_csv(
679            "test",
680            tmp_dir.path().to_str().unwrap(),
681            CsvReadOptions::new().schema(&schema),
682        )
683        .await?;
684
685        // register a local file system object store
686        let tmp_dir = TempDir::new()?;
687        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
688        let local_url = Url::parse("file://local").unwrap();
689
690        ctx.register_object_store(&local_url, local);
691
692        // execute a simple query and write the results to CSV
693        let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/";
694        let out_dir_url = "file://local/out/";
695        let df = ctx.sql("SELECT c1, c2 FROM test").await?;
696        df.write_csv(
697            out_dir_url,
698            crate::dataframe::DataFrameWriteOptions::new(),
699            None,
700        )
701        .await?;
702
703        // create a new context and verify that the results were saved to a partitioned csv file
704        let ctx = SessionContext::new_with_config(
705            SessionConfig::new().set_str("datafusion.catalog.has_header", "false"),
706        );
707
708        let schema = Arc::new(Schema::new(vec![
709            Field::new("c1", DataType::UInt32, false),
710            Field::new("c2", DataType::UInt64, false),
711        ]));
712
713        // get name of first part
714        let paths = fs::read_dir(&out_dir).unwrap();
715        let mut part_0_name: String = "".to_owned();
716        for path in paths {
717            let path = path.unwrap();
718            let name = path
719                .path()
720                .file_name()
721                .expect("Should be a file name")
722                .to_str()
723                .expect("Should be a str")
724                .to_owned();
725            if name.ends_with("_0.csv") {
726                part_0_name = name;
727                break;
728            }
729        }
730
731        if part_0_name.is_empty() {
732            panic!("Did not find part_0 in csv output files!")
733        }
734        // register each partition as well as the top level dir
735        let csv_read_option = CsvReadOptions::new().schema(&schema).has_header(false);
736        ctx.register_csv(
737            "part0",
738            &format!("{out_dir}/{part_0_name}"),
739            csv_read_option.clone(),
740        )
741        .await?;
742        ctx.register_csv("allparts", &out_dir, csv_read_option)
743            .await?;
744
745        let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?;
746        let allparts = ctx
747            .sql("SELECT c1, c2 FROM allparts")
748            .await?
749            .collect()
750            .await?;
751
752        let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum();
753
754        assert_eq!(part0[0].schema(), allparts[0].schema());
755
756        assert_eq!(allparts_count, 80);
757
758        Ok(())
759    }
760
761    fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize {
762        match metrics.sum_by_name(metric_name) {
763            Some(v) => v.as_usize(),
764            _ => {
765                panic!(
766                    "Expected metric not found. Looking for '{metric_name}' in\n\n{metrics:#?}"
767                );
768            }
769        }
770    }
771
772    /// Get the schema for the aggregate_test_* csv files with an additional filed not present in the files.
773    fn aggr_test_schema_with_missing_col() -> SchemaRef {
774        let fields =
775            Fields::from_iter(aggr_test_schema().fields().iter().cloned().chain(
776                std::iter::once(Arc::new(Field::new(
777                    "missing_col",
778                    DataType::Int64,
779                    true,
780                ))),
781            ));
782
783        let schema = Schema::new(fields);
784
785        Arc::new(schema)
786    }
787}