pub use datafusion_datasource_csv::source::*;

#[cfg(test)]
mod tests {

    use std::collections::HashMap;
    use std::fs::{self, File};
    use std::io::Write;
    use std::sync::Arc;

    use datafusion_datasource_csv::CsvFormat;
    use object_store::ObjectStore;

    use crate::prelude::CsvReadOptions;
    use crate::prelude::SessionContext;
    use crate::test::partitioned_file_groups;
    use datafusion_common::test_util::arrow_test_data;
    use datafusion_common::test_util::batches_to_string;
    use datafusion_common::{assert_batches_eq, Result};
    use datafusion_execution::config::SessionConfig;
    use datafusion_physical_plan::metrics::MetricsSet;
    use datafusion_physical_plan::ExecutionPlan;

    #[cfg(feature = "compression")]
    use datafusion_datasource::file_compression_type::FileCompressionType;
    use datafusion_datasource_csv::partitioned_csv_config;
    use datafusion_datasource_csv::source::CsvSource;
    use futures::{StreamExt, TryStreamExt};

    use arrow::datatypes::*;
    use bytes::Bytes;
    use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
    use datafusion_datasource::source::DataSourceExec;
    use insta::assert_snapshot;
    use object_store::chunked::ChunkedStore;
    use object_store::local::LocalFileSystem;
    use rstest::*;
    use tempfile::TempDir;
    use url::Url;

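    /// Builds the 13-column schema of `aggregate_test_100.csv` (c1..c13),
    /// attaching test metadata to the `c1` field.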
    fn aggr_test_schema() -> SchemaRef {
        let mut f1 = Field::new("c1", DataType::Utf8, false);
        f1.set_metadata(HashMap::from_iter(vec![("testing".into(), "test".into())]));
        let schema = Schema::new(vec![
            f1,
            Field::new("c2", DataType::UInt32, false),
            Field::new("c3", DataType::Int8, false),
            Field::new("c4", DataType::Int16, false),
            Field::new("c5", DataType::Int32, false),
            Field::new("c6", DataType::Int64, false),
            Field::new("c7", DataType::UInt8, false),
            Field::new("c8", DataType::UInt16, false),
            Field::new("c9", DataType::UInt32, false),
            Field::new("c10", DataType::UInt64, false),
            Field::new("c11", DataType::Float32, false),
            Field::new("c12", DataType::Float64, false),
            Field::new("c13", DataType::Utf8, false),
        ]);

        Arc::new(schema)
    }

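    /// Scans `aggregate_test_100.csv` with a projection of columns 0, 2, and 4
    /// (c1, c3, c5) and checks the projected schema, row count, and output,
    /// once per supported compression codec.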
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn csv_exec_with_projection(
        file_compression_type: FileCompressionType,
    ) -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let file_schema = aggr_test_schema();
        let path = format!("{}/csv", arrow_test_data());
        let filename = "aggregate_test_100.csv";
        let tmp_dir = TempDir::new()?;

        let file_groups = partitioned_file_groups(
            path.as_str(),
            filename,
            1,
            Arc::new(CsvFormat::default()),
            file_compression_type.to_owned(),
            tmp_dir.path(),
        )?;

        let source = Arc::new(CsvSource::new(true, b',', b'"'));
        let config = FileScanConfigBuilder::from(partitioned_csv_config(
            file_schema,
            file_groups,
            source,
        ))
        .with_file_compression_type(file_compression_type)
        .with_newlines_in_values(false)
        .with_projection_indices(Some(vec![0, 2, 4]))
        .build();

        assert_eq!(13, config.file_schema().fields().len());
        let csv = DataSourceExec::from_data_source(config);

        assert_eq!(3, csv.schema().fields().len());

        let mut stream = csv.execute(0, task_ctx)?;
        let batch = stream.next().await.unwrap()?;
        assert_eq!(3, batch.num_columns());
        assert_eq!(100, batch.num_rows());

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###"
            +----+-----+------------+
            | c1 | c3  | c5         |
            +----+-----+------------+
            | c  | 1   | 2033001162 |
            | d  | -40 | 706441268  |
            | b  | 29  | 994303988  |
            | a  | -85 | 1171968280 |
            | b  | -82 | 1824882165 |
            +----+-----+------------+
        "###);}
        Ok(())
    }

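    /// Same scan as above, but with the projection indices out of file order
    /// ([4, 0, 2]), so the output columns appear as c5, c1, c3.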
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn csv_exec_with_mixed_order_projection(
        file_compression_type: FileCompressionType,
    ) -> Result<()> {
        let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true");
        let session_ctx = SessionContext::new_with_config(cfg);
        let task_ctx = session_ctx.task_ctx();
        let file_schema = aggr_test_schema();
        let path = format!("{}/csv", arrow_test_data());
        let filename = "aggregate_test_100.csv";
        let tmp_dir = TempDir::new()?;

        let file_groups = partitioned_file_groups(
            path.as_str(),
            filename,
            1,
            Arc::new(CsvFormat::default()),
            file_compression_type.to_owned(),
            tmp_dir.path(),
        )?;

        let source = Arc::new(CsvSource::new(true, b',', b'"'));
        let config = FileScanConfigBuilder::from(partitioned_csv_config(
            file_schema,
            file_groups,
            source,
        ))
        .with_newlines_in_values(false)
        .with_file_compression_type(file_compression_type.to_owned())
        .with_projection_indices(Some(vec![4, 0, 2]))
        .build();
        assert_eq!(13, config.file_schema().fields().len());
        let csv = DataSourceExec::from_data_source(config);
        assert_eq!(3, csv.schema().fields().len());

        let mut stream = csv.execute(0, task_ctx)?;
        let batch = stream.next().await.unwrap()?;
        assert_eq!(3, batch.num_columns());
        assert_eq!(100, batch.num_rows());

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###"
            +------------+----+-----+
            | c5         | c1 | c3  |
            +------------+----+-----+
            | 2033001162 | c  | 1   |
            | 706441268  | d  | -40 |
            | 994303988  | b  | 29  |
            | 1171968280 | a  | -85 |
            | 1824882165 | b  | -82 |
            +------------+----+-----+
        "###);}
        Ok(())
    }

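    /// Scans the full 13-column file with a limit of 5 and verifies that only
    /// the first five rows are returned.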
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn csv_exec_with_limit(
        file_compression_type: FileCompressionType,
    ) -> Result<()> {
        use futures::StreamExt;

        let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true");
        let session_ctx = SessionContext::new_with_config(cfg);
        let task_ctx = session_ctx.task_ctx();
        let file_schema = aggr_test_schema();
        let path = format!("{}/csv", arrow_test_data());
        let filename = "aggregate_test_100.csv";
        let tmp_dir = TempDir::new()?;

        let file_groups = partitioned_file_groups(
            path.as_str(),
            filename,
            1,
            Arc::new(CsvFormat::default()),
            file_compression_type.to_owned(),
            tmp_dir.path(),
        )?;

        let source = Arc::new(CsvSource::new(true, b',', b'"'));
        let config = FileScanConfigBuilder::from(partitioned_csv_config(
            file_schema,
            file_groups,
            source,
        ))
        .with_newlines_in_values(false)
        .with_file_compression_type(file_compression_type.to_owned())
        .with_limit(Some(5))
        .build();
        assert_eq!(13, config.file_schema().fields().len());
        let csv = DataSourceExec::from_data_source(config);
        assert_eq!(13, csv.schema().fields().len());

        let mut it = csv.execute(0, task_ctx)?;
        let batch = it.next().await.unwrap()?;
        assert_eq!(13, batch.num_columns());
        assert_eq!(5, batch.num_rows());

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###"
            +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+
            | c1 | c2 | c3  | c4     | c5         | c6                   | c7  | c8    | c9         | c10                  | c11         | c12                 | c13                            |
            +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+
            | c  | 2  | 1   | 18109  | 2033001162 | -6513304855495910254 | 25  | 43062 | 1491205016 | 5863949479783605708  | 0.110830784 | 0.9294097332465232  | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW |
            | d  | 5  | -40 | 22614  | 706441268  | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107  | 0.3114712539863804  | C2GT5KVyOPZpgKVl110TyZO0NcJ434 |
            | b  | 1  | 29  | -18218 | 994303988  | 5983957848665088916  | 204 | 9489  | 3275293996 | 14857091259186476033 | 0.53840446  | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz |
            | a  | 1  | -85 | -15154 | 1171968280 | 1919439543497968449  | 77  | 52286 | 774637006  | 12101411955859039553 | 0.12285209  | 0.6864391962767343  | 0keZ5G8BffGwgF2RwQD59TFzMStxCB |
            | b  | 5  | -82 | 22080  | 1824882165 | 7373730676428214987  | 208 | 34331 | 3342719438 | 3330177516592499461  | 0.82634634  | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd |
            +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+
        "###);}

        Ok(())
    }

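    /// Uses a table schema with an extra `missing_col` field that the CSV file
    /// does not contain, and expects the scan to fail with an Arrow CSV error.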
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn csv_exec_with_missing_column(
        file_compression_type: FileCompressionType,
    ) -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let file_schema = aggr_test_schema_with_missing_col();
        let path = format!("{}/csv", arrow_test_data());
        let filename = "aggregate_test_100.csv";
        let tmp_dir = TempDir::new()?;

        let file_groups = partitioned_file_groups(
            path.as_str(),
            filename,
            1,
            Arc::new(CsvFormat::default()),
            file_compression_type.to_owned(),
            tmp_dir.path(),
        )?;

        let source = Arc::new(CsvSource::new(true, b',', b'"'));
        let config = FileScanConfigBuilder::from(partitioned_csv_config(
            file_schema,
            file_groups,
            source,
        ))
        .with_newlines_in_values(false)
        .with_file_compression_type(file_compression_type.to_owned())
        .with_limit(Some(5))
        .build();
        assert_eq!(14, config.file_schema().fields().len());
        let csv = DataSourceExec::from_data_source(config);
        assert_eq!(14, csv.schema().fields().len());

        let mut it = csv.execute(0, task_ctx)?;
        let err = it.next().await.unwrap().unwrap_err().strip_backtrace();
        assert_eq!(
            err,
            "Arrow error: Csv error: incorrect number of fields for line 1, expected 14 got 13"
        );
        Ok(())
    }

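    /// Attaches a `date` partition value to the file group, projects `c1` plus
    /// the partition column, and also checks that the `time_elapsed_processing`
    /// metric is recorded.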
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn csv_exec_with_partition(
        file_compression_type: FileCompressionType,
    ) -> Result<()> {
        use datafusion_common::ScalarValue;

        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let file_schema = aggr_test_schema();
        let path = format!("{}/csv", arrow_test_data());
        let filename = "aggregate_test_100.csv";
        let tmp_dir = TempDir::new()?;

        let mut file_groups = partitioned_file_groups(
            path.as_str(),
            filename,
            1,
            Arc::new(CsvFormat::default()),
            file_compression_type.to_owned(),
            tmp_dir.path(),
        )?;
        file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")];

        let num_file_schema_fields = file_schema.fields().len();

        let source = Arc::new(CsvSource::new(true, b',', b'"'));
        let config = FileScanConfigBuilder::from(partitioned_csv_config(
            file_schema,
            file_groups,
            source,
        ))
        .with_newlines_in_values(false)
        .with_file_compression_type(file_compression_type.to_owned())
        .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)])
        .with_projection_indices(Some(vec![0, num_file_schema_fields]))
        .build();

        assert_eq!(13, config.file_schema().fields().len());
        let csv = DataSourceExec::from_data_source(config);
        assert_eq!(2, csv.schema().fields().len());

        let mut it = csv.execute(0, task_ctx)?;
        let batch = it.next().await.unwrap()?;
        assert_eq!(2, batch.num_columns());
        assert_eq!(100, batch.num_rows());

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###"
            +----+------------+
            | c1 | date       |
            +----+------------+
            | c  | 2021-10-26 |
            | d  | 2021-10-26 |
            | b  | 2021-10-26 |
            | a  | 2021-10-26 |
            | b  | 2021-10-26 |
            +----+------------+
        "###);}

        let metrics = csv.metrics().expect("no metrics found");
        let time_elapsed_processing = get_value(&metrics, "time_elapsed_processing");
        assert!(
            time_elapsed_processing > 0,
            "Expected time_elapsed_processing greater than 0",
        );
        Ok(())
    }

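    /// Writes `partition_count` small CSV files (rows `partition,i,bool` for
    /// i in 0..=10) into `tmp_dir` and returns their shared three-column schema.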
    fn populate_csv_partitions(
        tmp_dir: &TempDir,
        partition_count: usize,
        file_extension: &str,
    ) -> Result<SchemaRef> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt64, false),
            Field::new("c3", DataType::Boolean, false),
        ]));

        for partition in 0..partition_count {
            let filename = format!("partition-{partition}.{file_extension}");
            let file_path = tmp_dir.path().join(filename);
            let mut file = File::create(file_path)?;

            for i in 0..=10 {
                let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
                file.write_all(data.as_bytes())?;
            }
        }

        Ok(schema)
    }

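    /// Registers `store` under the `file://` URL and runs a full scan of
    /// `aggregate_test_100.csv` through it, expecting all 100 rows back.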
    async fn test_additional_stores(
        file_compression_type: FileCompressionType,
        store: Arc<dyn ObjectStore>,
    ) -> Result<()> {
        let ctx = SessionContext::new();
        let url = Url::parse("file://").unwrap();
        ctx.register_object_store(&url, store.clone());

        let task_ctx = ctx.task_ctx();

        let file_schema = aggr_test_schema();
        let path = format!("{}/csv", arrow_test_data());
        let filename = "aggregate_test_100.csv";
        let tmp_dir = TempDir::new()?;

        let file_groups = partitioned_file_groups(
            path.as_str(),
            filename,
            1,
            Arc::new(CsvFormat::default()),
            file_compression_type.to_owned(),
            tmp_dir.path(),
        )
        .unwrap();

        let source = Arc::new(CsvSource::new(true, b',', b'"'));
        let config = FileScanConfigBuilder::from(partitioned_csv_config(
            file_schema,
            file_groups,
            source,
        ))
        .with_newlines_in_values(false)
        .with_file_compression_type(file_compression_type.to_owned())
        .build();
        let csv = DataSourceExec::from_data_source(config);

        let it = csv.execute(0, task_ctx).unwrap();
        let batches: Vec<_> = it.try_collect().await.unwrap();

        let total_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();

        assert_eq!(total_rows, 100);
        Ok(())
    }

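    /// Exercises `test_additional_stores` through a `ChunkedStore` that splits
    /// object reads into fixed-size chunks of varying sizes.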
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn test_chunked_csv(
        file_compression_type: FileCompressionType,
        #[values(10, 20, 30, 40)] chunk_size: usize,
    ) -> Result<()> {
        test_additional_stores(
            file_compression_type,
            Arc::new(ChunkedStore::new(
                Arc::new(LocalFileSystem::new()),
                chunk_size,
            )),
        )
        .await?;
        Ok(())
    }

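    /// Reads CSV data whose last line has no trailing newline and verifies both
    /// rows are still parsed.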
    #[tokio::test]
    async fn test_no_trailing_delimiter() {
        let session_ctx = SessionContext::new();
        let store = object_store::memory::InMemory::new();

        let data = Bytes::from("a,b\n1,2\n3,4");
        let path = object_store::path::Path::from("a.csv");
        store.put(&path, data.into()).await.unwrap();

        let url = Url::parse("memory://").unwrap();
        session_ctx.register_object_store(&url, Arc::new(store));

        let df = session_ctx
            .read_csv("memory:///", CsvReadOptions::new())
            .await
            .unwrap();

        let result = df.collect().await.unwrap();

        assert_snapshot!(batches_to_string(&result), @r###"
            +---+---+
            | a | b |
            +---+---+
            | 1 | 2 |
            | 3 | 4 |
            +---+---+
        "###);
    }

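    /// Reads CSV data that uses `\r` as the line terminator, then confirms that
    /// reading the same data with the default `\n` terminator fails.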
    #[tokio::test]
    async fn test_terminator() {
        let session_ctx = SessionContext::new();
        let store = object_store::memory::InMemory::new();

        let data = Bytes::from("a,b\r1,2\r3,4");
        let path = object_store::path::Path::from("a.csv");
        store.put(&path, data.into()).await.unwrap();

        let url = Url::parse("memory://").unwrap();
        session_ctx.register_object_store(&url, Arc::new(store));

        let df = session_ctx
            .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\r')))
            .await
            .unwrap();

        let result = df.collect().await.unwrap();

        assert_snapshot!(batches_to_string(&result), @r###"
            +---+---+
            | a | b |
            +---+---+
            | 1 | 2 |
            | 3 | 4 |
            +---+---+
        "###);

        let e = session_ctx
            .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\n')))
            .await
            .unwrap()
            .collect()
            .await
            .unwrap_err();
        assert_eq!(e.strip_backtrace(), "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2")
    }

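    /// Creates an external CSV table with `format.terminator` set to `\r` via
    /// SQL and checks the parsed contents.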
    #[tokio::test]
    async fn test_create_external_table_with_terminator() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.sql(
            r#"
            CREATE EXTERNAL TABLE t1 (
                col1 TEXT,
                col2 TEXT
            ) STORED AS CSV
            LOCATION 'tests/data/cr_terminator.csv'
            OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true');
            "#,
        )
        .await?
        .collect()
        .await?;

        let df = ctx.sql(r#"select * from t1"#).await?.collect().await?;
        assert_snapshot!(batches_to_string(&df), @r###"
            +------+--------+
            | col1 | col2   |
            +------+--------+
            | id0  | value0 |
            | id1  | value1 |
            | id2  | value2 |
            | id3  | value3 |
            +------+--------+
        "###);
        Ok(())
    }

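    /// Same as above, but with `format.newlines_in_values` enabled so quoted
    /// fields may themselves contain the `\r` terminator.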
    #[tokio::test]
    async fn test_create_external_table_with_terminator_with_newlines_in_values(
    ) -> Result<()> {
        let ctx = SessionContext::new();
        ctx.sql(r#"
            CREATE EXTERNAL TABLE t1 (
                col1 TEXT,
                col2 TEXT
            ) STORED AS CSV
            LOCATION 'tests/data/newlines_in_values_cr_terminator.csv'
            OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true', 'format.newlines_in_values' 'true');
            "#).await?.collect().await?;

        let df = ctx.sql(r#"select * from t1"#).await?.collect().await?;
        let expected = [
            "+-------+-----------------------------+",
            "| col1  | col2                        |",
            "+-------+-----------------------------+",
            "| 1     | hello\rworld                 |",
            "| 2     | something\relse              |",
            "| 3     | \rmany\rlines\rmake\rgood test\r |",
            "| 4     | unquoted                    |",
            "| value | end                         |",
            "+-------+-----------------------------+",
        ];
        assert_batches_eq!(expected, &df);
        Ok(())
    }

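    /// Writing a DataFrame read from `corrupt.csv` (whose data does not match
    /// the inferred schema) should surface the underlying Arrow parser error.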
    #[tokio::test]
    async fn write_csv_results_error_handling() -> Result<()> {
        let ctx = SessionContext::new();

        let tmp_dir = TempDir::new()?;
        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
        let local_url = Url::parse("file://local").unwrap();
        ctx.register_object_store(&local_url, local);
        let options = CsvReadOptions::default()
            .schema_infer_max_records(2)
            .has_header(true);
        let df = ctx.read_csv("tests/data/corrupt.csv", options).await?;

        let out_dir_url = "file://local/out";
        let e = df
            .write_csv(
                out_dir_url,
                crate::dataframe::DataFrameWriteOptions::new(),
                None,
            )
            .await
            .expect_err("should fail because input file does not match inferred schema");
        assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'");
        Ok(())
    }

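    /// Writes query results to partitioned CSV files via an object store URL,
    /// then reads one partition and all partitions back and compares schemas
    /// and row counts.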
    #[tokio::test]
    async fn write_csv_results() -> Result<()> {
        let tmp_dir = TempDir::new()?;
        let ctx = SessionContext::new_with_config(
            SessionConfig::new()
                .with_target_partitions(8)
                .set_str("datafusion.catalog.has_header", "false"),
        );

        let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?;

        ctx.register_csv(
            "test",
            tmp_dir.path().to_str().unwrap(),
            CsvReadOptions::new().schema(&schema),
        )
        .await?;

        let tmp_dir = TempDir::new()?;
        let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
        let local_url = Url::parse("file://local").unwrap();

        ctx.register_object_store(&local_url, local);

        let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/";
        let out_dir_url = "file://local/out/";
        let df = ctx.sql("SELECT c1, c2 FROM test").await?;
        df.write_csv(
            out_dir_url,
            crate::dataframe::DataFrameWriteOptions::new(),
            None,
        )
        .await?;

        let ctx = SessionContext::new_with_config(
            SessionConfig::new().set_str("datafusion.catalog.has_header", "false"),
        );

        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::UInt32, false),
            Field::new("c2", DataType::UInt64, false),
        ]));

        let paths = fs::read_dir(&out_dir).unwrap();
        let mut part_0_name: String = "".to_owned();
        for path in paths {
            let path = path.unwrap();
            let name = path
                .path()
                .file_name()
                .expect("Should be a file name")
                .to_str()
                .expect("Should be a str")
                .to_owned();
            if name.ends_with("_0.csv") {
                part_0_name = name;
                break;
            }
        }

        if part_0_name.is_empty() {
            panic!("Did not find part_0 in csv output files!")
        }
        let csv_read_option = CsvReadOptions::new().schema(&schema).has_header(false);
        ctx.register_csv(
            "part0",
            &format!("{out_dir}/{part_0_name}"),
            csv_read_option.clone(),
        )
        .await?;
        ctx.register_csv("allparts", &out_dir, csv_read_option)
            .await?;

        let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?;
        let allparts = ctx
            .sql("SELECT c1, c2 FROM allparts")
            .await?
            .collect()
            .await?;

        let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum();

        assert_eq!(part0[0].schema(), allparts[0].schema());

        assert_eq!(allparts_count, 80);

        Ok(())
    }

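    /// Sums the named metric across the `MetricsSet`, panicking if it is absent.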
    fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize {
        match metrics.sum_by_name(metric_name) {
            Some(v) => v.as_usize(),
            _ => {
                panic!(
                    "Expected metric not found. Looking for '{metric_name}' in\n\n{metrics:#?}"
                );
            }
        }
    }

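    /// `aggr_test_schema` extended with a nullable `missing_col` Int64 field
    /// that is not present in the CSV file.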
    fn aggr_test_schema_with_missing_col() -> SchemaRef {
        let fields =
            Fields::from_iter(aggr_test_schema().fields().iter().cloned().chain(
                std::iter::once(Arc::new(Field::new(
                    "missing_col",
                    DataType::Int64,
                    true,
                ))),
            ));

        let schema = Schema::new(fields);

        Arc::new(schema)
    }
}