pub use datafusion_datasource_csv::file_format::*;

#[cfg(test)]
mod tests {
    use std::fmt::{self, Display};
    use std::ops::Range;
    use std::sync::{Arc, Mutex};

    use super::*;

    use crate::datasource::file_format::test_util::scan_format;
    use crate::datasource::listing::ListingOptions;
    use crate::execution::session_state::SessionStateBuilder;
    use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
    use arrow_schema::{DataType, Field, Schema, SchemaRef};
    use datafusion_catalog::Session;
    use datafusion_common::cast::as_string_array;
    use datafusion_common::config::CsvOptions;
    use datafusion_common::internal_err;
    use datafusion_common::stats::Precision;
    use datafusion_common::test_util::{arrow_test_data, batches_to_string};
    use datafusion_common::Result;
    use datafusion_datasource::decoder::{
        BatchDeserializer, DecoderDeserializer, DeserializerOutput,
    };
    use datafusion_datasource::file_compression_type::FileCompressionType;
    use datafusion_datasource::file_format::FileFormat;
    use datafusion_datasource::write::BatchSerializer;
    use datafusion_expr::{col, lit};
    use datafusion_physical_plan::{collect, ExecutionPlan};

    use arrow::array::{
        Array, BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray,
    };
    use arrow::compute::concat_batches;
    use arrow::csv::ReaderBuilder;
    use arrow::util::pretty::pretty_format_batches;
    use async_trait::async_trait;
    use bytes::Bytes;
    use chrono::DateTime;
    use datafusion_common::parsers::CompressionTypeVariant;
    use futures::stream::BoxStream;
    use futures::StreamExt;
    use insta::assert_snapshot;
    use object_store::chunked::ChunkedStore;
    use object_store::local::LocalFileSystem;
    use object_store::path::Path;
    use object_store::{
        Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload,
        ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult,
    };
    use regex::Regex;
    use rstest::*;

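    /// Test-only [`ObjectStore`] whose `get` streams `bytes_to_repeat` back
    /// `max_iterations` times, counting how many chunks were actually pulled.
    /// Every other method is `unimplemented!()`.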
    #[derive(Debug)]
    struct VariableStream {
        bytes_to_repeat: Bytes,
        max_iterations: u64,
        iterations_detected: Arc<Mutex<usize>>,
    }

    impl Display for VariableStream {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "VariableStream")
        }
    }

    #[async_trait]
    impl ObjectStore for VariableStream {
        async fn put_opts(
            &self,
            _location: &Path,
            _payload: PutPayload,
            _opts: PutOptions,
        ) -> object_store::Result<PutResult> {
            unimplemented!()
        }

        async fn put_multipart_opts(
            &self,
            _location: &Path,
            _opts: PutMultipartOptions,
        ) -> object_store::Result<Box<dyn MultipartUpload>> {
            unimplemented!()
        }

        async fn get(&self, location: &Path) -> object_store::Result<GetResult> {
            self.get_opts(location, GetOptions::default()).await
        }

        async fn get_opts(
            &self,
            location: &Path,
            _opts: GetOptions,
        ) -> object_store::Result<GetResult> {
            let bytes = self.bytes_to_repeat.clone();
            let len = bytes.len() as u64;
            let range = 0..len * self.max_iterations;
            let arc = self.iterations_detected.clone();
            let stream = futures::stream::repeat_with(move || {
                let arc_inner = arc.clone();
                *arc_inner.lock().unwrap() += 1;
                Ok(bytes.clone())
            })
            .take(self.max_iterations as usize)
            .boxed();

            Ok(GetResult {
                payload: GetResultPayload::Stream(stream),
                meta: ObjectMeta {
                    location: location.clone(),
                    last_modified: Default::default(),
                    size: range.end,
                    e_tag: None,
                    version: None,
                },
                range: Default::default(),
                attributes: Attributes::default(),
            })
        }

        async fn get_ranges(
            &self,
            _location: &Path,
            _ranges: &[Range<u64>],
        ) -> object_store::Result<Vec<Bytes>> {
            unimplemented!()
        }

        async fn head(&self, _location: &Path) -> object_store::Result<ObjectMeta> {
            unimplemented!()
        }

        async fn delete(&self, _location: &Path) -> object_store::Result<()> {
            unimplemented!()
        }

        fn list(
            &self,
            _prefix: Option<&Path>,
        ) -> BoxStream<'static, object_store::Result<ObjectMeta>> {
            unimplemented!()
        }

        async fn list_with_delimiter(
            &self,
            _prefix: Option<&Path>,
        ) -> object_store::Result<ListResult> {
            unimplemented!()
        }

        async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> {
            unimplemented!()
        }

        async fn copy_if_not_exists(
            &self,
            _from: &Path,
            _to: &Path,
        ) -> object_store::Result<()> {
            unimplemented!()
        }
    }

    impl VariableStream {
        pub fn new(bytes_to_repeat: Bytes, max_iterations: u64) -> Self {
            Self {
                bytes_to_repeat,
                max_iterations,
                iterations_detected: Arc::new(Mutex::new(0)),
            }
        }

        pub fn get_iterations_detected(&self) -> usize {
            *self.iterations_detected.lock().unwrap()
        }
    }

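    /// With `batch_size = 2`, scanning aggregate_test_100.csv should yield 50 batches
    /// of 2 rows each, and per-partition statistics should be absent.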
    #[tokio::test]
    async fn read_small_batches() -> Result<()> {
        let config = SessionConfig::new().with_batch_size(2);
        let session_ctx = SessionContext::new_with_config(config);
        let state = session_ctx.state();
        let task_ctx = state.task_ctx();
        let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
        let exec =
            get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?;
        let stream = exec.execute(0, task_ctx)?;

        let tt_batches: i32 = stream
            .map(|batch| {
                let batch = batch.unwrap();
                assert_eq!(12, batch.num_columns());
                assert_eq!(2, batch.num_rows());
            })
            .fold(0, |acc, _| async move { acc + 1i32 })
            .await;

        assert_eq!(tt_batches, 50);

        assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent);
        assert_eq!(
            exec.partition_statistics(None)?.total_byte_size,
            Precision::Absent
        );

        Ok(())
    }

    #[tokio::test]
    async fn read_limit() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let task_ctx = session_ctx.task_ctx();
        let projection = Some(vec![0, 1, 2, 3]);
        let exec =
            get_exec(&state, "aggregate_test_100.csv", projection, Some(1), true).await?;
        let batches = collect(exec, task_ctx).await?;
        assert_eq!(1, batches.len());
        assert_eq!(4, batches[0].num_columns());
        assert_eq!(1, batches[0].num_rows());

        Ok(())
    }

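    /// Schema inference over aggregate_test_100_with_nulls.csv without a null regex:
    /// `c14` is inferred as `Null`, `c15` as `Utf8`.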
    #[tokio::test]
    async fn infer_schema() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();

        let projection = None;
        let root = "./tests/data/csv";
        let format = CsvFormat::default().with_has_header(true);
        let exec = scan_format(
            &state,
            &format,
            None,
            root,
            "aggregate_test_100_with_nulls.csv",
            projection,
            None,
        )
        .await?;

        let x: Vec<String> = exec
            .schema()
            .fields()
            .iter()
            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
            .collect();
        assert_eq!(
            vec![
                "c1: Utf8",
                "c2: Int64",
                "c3: Int64",
                "c4: Int64",
                "c5: Int64",
                "c6: Int64",
                "c7: Int64",
                "c8: Int64",
                "c9: Int64",
                "c10: Utf8",
                "c11: Float64",
                "c12: Float64",
                "c13: Utf8",
                "c14: Null",
                "c15: Utf8"
            ],
            x
        );

        Ok(())
    }

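    /// Same file as above, but with a null regex (`^NULL$|^$`) both `c14` and `c15`
    /// are inferred as `Null`.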
    #[tokio::test]
    async fn infer_schema_with_null_regex() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();

        let projection = None;
        let root = "./tests/data/csv";
        let format = CsvFormat::default()
            .with_has_header(true)
            .with_null_regex(Some("^NULL$|^$".to_string()));
        let exec = scan_format(
            &state,
            &format,
            None,
            root,
            "aggregate_test_100_with_nulls.csv",
            projection,
            None,
        )
        .await?;

        let x: Vec<String> = exec
            .schema()
            .fields()
            .iter()
            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
            .collect();
        assert_eq!(
            vec![
                "c1: Utf8",
                "c2: Int64",
                "c3: Int64",
                "c4: Int64",
                "c5: Int64",
                "c6: Int64",
                "c7: Int64",
                "c8: Int64",
                "c9: Int64",
                "c10: Utf8",
                "c11: Float64",
                "c12: Float64",
                "c13: Utf8",
                "c14: Null",
                "c15: Null"
            ],
            x
        );

        Ok(())
    }

    #[tokio::test]
    async fn read_char_column() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let task_ctx = session_ctx.task_ctx();
        let projection = Some(vec![0]);
        let exec =
            get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?;

        let batches = collect(exec, task_ctx).await.expect("Collect batches");

        assert_eq!(1, batches.len());
        assert_eq!(1, batches[0].num_columns());
        assert_eq!(100, batches[0].num_rows());

        let array = as_string_array(batches[0].column(0))?;
        let mut values: Vec<&str> = vec![];
        for i in 0..5 {
            values.push(array.value(i));
        }

        assert_eq!(vec!["c", "d", "b", "a", "b"], values);

        Ok(())
    }

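    /// Schema inference should stop pulling data once `schema_infer_max_rec` records
    /// have been read; the iteration counter on [`VariableStream`] verifies this.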
    #[tokio::test]
    async fn test_infer_schema_stream() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let variable_object_store =
            Arc::new(VariableStream::new(Bytes::from("1,2,3,4,5\n"), 200));
        let object_meta = ObjectMeta {
            location: Path::parse("/")?,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let num_rows_to_read = 100;
        let csv_format = CsvFormat::default()
            .with_has_header(false)
            .with_schema_infer_max_rec(num_rows_to_read);
        let inferred_schema = csv_format
            .infer_schema(
                &state,
                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
                &[object_meta],
            )
            .await?;

        let actual_fields: Vec<_> = inferred_schema
            .fields()
            .iter()
            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
            .collect();
        assert_eq!(
            vec![
                "column_1: Int64",
                "column_2: Int64",
                "column_3: Int64",
                "column_4: Int64",
                "column_5: Int64"
            ],
            actual_fields
        );
        assert_eq!(
            num_rows_to_read,
            variable_object_store.get_iterations_detected()
        );

        Ok(())
    }

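    /// Inference honors the configured quote and escape characters, so quoted fields
    /// containing commas and escaped quotes do not split into extra columns.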
    #[tokio::test]
    async fn test_infer_schema_escape_chars() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let variable_object_store = Arc::new(VariableStream::new(
            Bytes::from(
                r#"c1,c2,c3,c4
0.3,"Here, is a comma\"",third,3
0.31,"double quotes are ok, "" quote",third again,9
0.314,abc,xyz,27"#,
            ),
            1,
        ));
        let object_meta = ObjectMeta {
            location: Path::parse("/")?,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let num_rows_to_read = 3;
        let csv_format = CsvFormat::default()
            .with_has_header(true)
            .with_schema_infer_max_rec(num_rows_to_read)
            .with_quote(b'"')
            .with_escape(Some(b'\\'));

        let inferred_schema = csv_format
            .infer_schema(
                &state,
                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
                &[object_meta],
            )
            .await?;

        let actual_fields: Vec<_> = inferred_schema
            .fields()
            .iter()
            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
            .collect();

        assert_eq!(
            vec!["c1: Float64", "c2: Utf8", "c3: Utf8", "c4: Int64",],
            actual_fields
        );
        Ok(())
    }

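    /// Wraps the data in a [`ChunkedStore`] with 1-byte chunks so inference sees many
    /// tiny partial records; an all-empty column should still be inferred as `Null`.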
    #[tokio::test]
    async fn test_infer_schema_stream_null_chunks() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();

        let chunked_object_store = Arc::new(ChunkedStore::new(
            Arc::new(VariableStream::new(
                Bytes::from(
                    r#"c1,c2,c3
1,1.0,
,,
"#,
                ),
                1,
            )),
            1,
        ));
        let object_meta = ObjectMeta {
            location: Path::parse("/")?,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let csv_format = CsvFormat::default().with_has_header(true);
        let inferred_schema = csv_format
            .infer_schema(
                &state,
                &(chunked_object_store as Arc<dyn ObjectStore>),
                &[object_meta],
            )
            .await?;

        let actual_fields: Vec<_> = inferred_schema
            .fields()
            .iter()
            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
            .collect();

        assert_eq!(vec!["c1: Int64", "c2: Float64", "c3: Null"], actual_fields);
        Ok(())
    }

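    /// Compresses aggregate_test_100.csv in memory with each supported codec and
    /// checks that schema inference over the compressed stream matches the plain file.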
    #[rstest(
        file_compression_type,
        case(FileCompressionType::UNCOMPRESSED),
        case(FileCompressionType::GZIP),
        case(FileCompressionType::BZIP2),
        case(FileCompressionType::XZ),
        case(FileCompressionType::ZSTD)
    )]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn query_compress_data(
        file_compression_type: FileCompressionType,
    ) -> Result<()> {
        use arrow_schema::{DataType, Field, Schema};
        use datafusion_common::DataFusionError;
        use datafusion_datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
        use futures::TryStreamExt;

        let mut cfg = SessionConfig::new();
        cfg.options_mut().catalog.has_header = true;
        let session_state = SessionStateBuilder::new()
            .with_config(cfg)
            .with_default_features()
            .build();
        let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap();
        let path = Path::from("csv/aggregate_test_100.csv");
        let csv = CsvFormat::default().with_has_header(true);
        let records_to_read = csv
            .options()
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);
        let store = Arc::new(integration) as Arc<dyn ObjectStore>;
        let original_stream = store.get(&path).await?;

        let compressed_stream =
            file_compression_type.to_owned().convert_to_compress_stream(
                original_stream
                    .into_stream()
                    .map_err(DataFusionError::from)
                    .boxed(),
            );

        let expected = Schema::new(vec![
            Field::new("c1", DataType::Utf8, true),
            Field::new("c2", DataType::Int64, true),
            Field::new("c3", DataType::Int64, true),
            Field::new("c4", DataType::Int64, true),
            Field::new("c5", DataType::Int64, true),
            Field::new("c6", DataType::Int64, true),
            Field::new("c7", DataType::Int64, true),
            Field::new("c8", DataType::Int64, true),
            Field::new("c9", DataType::Int64, true),
            Field::new("c10", DataType::Utf8, true),
            Field::new("c11", DataType::Float64, true),
            Field::new("c12", DataType::Float64, true),
            Field::new("c13", DataType::Utf8, true),
        ]);

        let compressed_csv = csv.with_file_compression_type(file_compression_type);

        let decoded_stream = compressed_csv
            .read_to_delimited_chunks_from_stream(compressed_stream.unwrap())
            .await;
        let (schema, records_read) = compressed_csv
            .infer_schema_from_stream(&session_state, records_to_read, decoded_stream)
            .await?;

        assert_eq!(expected, schema);
        assert_eq!(100, records_read);
        Ok(())
    }

    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn query_compress_csv() -> Result<()> {
        let ctx = SessionContext::new();

        let csv_options = CsvReadOptions::default()
            .has_header(true)
            .file_compression_type(FileCompressionType::GZIP)
            .file_extension("csv.gz");
        let df = ctx
            .read_csv(
                &format!("{}/csv/aggregate_test_100.csv.gz", arrow_test_data()),
                csv_options,
            )
            .await?;

        let record_batch = df
            .filter(col("c1").eq(lit("a")).and(col("c2").gt(lit("4"))))?
            .select_columns(&["c2", "c3"])?
            .collect()
            .await?;

        assert_snapshot!(batches_to_string(&record_batch), @r###"
        +----+------+
        | c2 | c3   |
        +----+------+
        | 5  | 36   |
        | 5  | -31  |
        | 5  | -101 |
        +----+------+
        "###);

        Ok(())
    }

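    /// Builds a CSV scan over a file under `{arrow_test_data()}/csv` for use in the
    /// tests above.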
    async fn get_exec(
        state: &dyn Session,
        file_name: &str,
        projection: Option<Vec<usize>>,
        limit: Option<usize>,
        has_header: bool,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let root = format!("{}/csv", arrow_test_data());
        let format = CsvFormat::default().with_has_header(has_header);
        scan_format(state, &format, None, &root, file_name, projection, limit).await
    }

    #[tokio::test]
    async fn test_csv_serializer() -> Result<()> {
        let ctx = SessionContext::new();
        let df = ctx
            .read_csv(
                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
                CsvReadOptions::default().has_header(true),
            )
            .await?;
        let batches = df
            .select_columns(&["c2", "c3"])?
            .limit(0, Some(10))?
            .collect()
            .await?;
        let batch = concat_batches(&batches[0].schema(), &batches)?;
        let serializer = CsvSerializer::new();
        let bytes = serializer.serialize(batch, true)?;
        assert_eq!(
            "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
            String::from_utf8(bytes.into()).unwrap()
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_csv_serializer_no_header() -> Result<()> {
        let ctx = SessionContext::new();
        let df = ctx
            .read_csv(
                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
                CsvReadOptions::default().has_header(true),
            )
            .await?;
        let batches = df
            .select_columns(&["c2", "c3"])?
            .limit(0, Some(10))?
            .collect()
            .await?;
        let batch = concat_batches(&batches[0].schema(), &batches)?;
        let serializer = CsvSerializer::new().with_header(false);
        let bytes = serializer.serialize(batch, true)?;
        assert_eq!(
            "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
            String::from_utf8(bytes.into()).unwrap()
        );
        Ok(())
    }

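    /// Runs `EXPLAIN` on `sql` and extracts the number of file groups from the
    /// `DataSourceExec` line, i.e. how many partitions the CSV scan was split into.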
    async fn count_query_csv_partitions(
        ctx: &SessionContext,
        sql: &str,
    ) -> Result<usize> {
        let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
        let result = df.collect().await?;
        let plan = format!("{}", &pretty_format_batches(&result)?);

        let re = Regex::new(r"DataSourceExec: file_groups=\{(\d+) group").unwrap();

        if let Some(captures) = re.captures(&plan) {
            if let Some(match_) = captures.get(1) {
                let n_partitions = match_.as_str().parse::<usize>().unwrap();
                return Ok(n_partitions);
            }
        }

        internal_err!("query contains no DataSourceExec")
    }

    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
    #[tokio::test]
    async fn test_csv_parallel_basic(n_partitions: usize) -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions);
        let ctx = SessionContext::new_with_config(config);
        let testdata = arrow_test_data();
        ctx.register_csv(
            "aggr",
            &format!("{testdata}/csv/aggregate_test_100.csv"),
            CsvReadOptions::new().has_header(true),
        )
        .await?;

        let query = "select sum(c2) from aggr;";
        let query_result = ctx.sql(query).await?.collect().await?;
        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
        +--------------+
        | sum(aggr.c2) |
        +--------------+
        | 285          |
        +--------------+
        "###);
        }

        assert_eq!(n_partitions, actual_partitions);

        Ok(())
    }

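    /// Gzip-compressed CSV cannot be split into byte ranges, so the scan should always
    /// end up with a single partition regardless of `target_partitions`.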
    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
    #[cfg(feature = "compression")]
    #[tokio::test]
    async fn test_csv_parallel_compressed(n_partitions: usize) -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions);
        let csv_options = CsvReadOptions::default()
            .has_header(true)
            .file_compression_type(FileCompressionType::GZIP)
            .file_extension("csv.gz");
        let ctx = SessionContext::new_with_config(config);
        let testdata = arrow_test_data();
        ctx.register_csv(
            "aggr",
            &format!("{testdata}/csv/aggregate_test_100.csv.gz"),
            csv_options,
        )
        .await?;

        let query = "select sum(c3) from aggr;";
        let query_result = ctx.sql(query).await?.collect().await?;
        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
        +--------------+
        | sum(aggr.c3) |
        +--------------+
        | 781          |
        +--------------+
        "###);
        }

        assert_eq!(1, actual_partitions);

        Ok(())
    }

    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
    #[tokio::test]
    async fn test_csv_parallel_newlines_in_values(n_partitions: usize) -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions);
        let csv_options = CsvReadOptions::default()
            .has_header(true)
            .newlines_in_values(true);
        let ctx = SessionContext::new_with_config(config);
        let testdata = arrow_test_data();
        ctx.register_csv(
            "aggr",
            &format!("{testdata}/csv/aggregate_test_100.csv"),
            csv_options,
        )
        .await?;

        let query = "select sum(c3) from aggr;";
        let query_result = ctx.sql(query).await?.collect().await?;
        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
        +--------------+
        | sum(aggr.c3) |
        +--------------+
        | 781          |
        +--------------+
        "###);
        }

        assert_eq!(1, actual_partitions);

        Ok(())
    }

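    /// Scanning a zero-byte CSV file should produce an empty result rather than an error.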
    #[tokio::test]
    async fn test_csv_empty_file() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.register_csv(
            "empty",
            "tests/data/empty_0_byte.csv",
            CsvReadOptions::new().has_header(false),
        )
        .await?;

        let query = "select * from empty where random() > 0.5;";
        let query_result = ctx.sql(query).await?.collect().await?;

        assert_snapshot!(batches_to_string(&query_result),@r###"
        ++
        ++
        "###);

        Ok(())
    }

    #[tokio::test]
    async fn test_csv_empty_with_header() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.register_csv(
            "empty",
            "tests/data/empty.csv",
            CsvReadOptions::new().has_header(true),
        )
        .await?;

        let query = "select * from empty where random() > 0.5;";
        let query_result = ctx.sql(query).await?.collect().await?;

        assert_snapshot!(batches_to_string(&query_result),@r###"
        ++
        ++
        "###);

        Ok(())
    }

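    /// Queries a directory where some CSV files are empty; the aggregate over `c3`
    /// should still be computed from the non-empty files.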
    #[tokio::test]
    async fn test_csv_some_empty_with_header() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.register_csv(
            "some_empty_with_header",
            "tests/data/empty_files/some_empty_with_header",
            CsvReadOptions::new().has_header(true),
        )
        .await?;

        let query = "select sum(c3) from some_empty_with_header;";
        let query_result = ctx.sql(query).await?.collect().await?;

        assert_snapshot!(batches_to_string(&query_result),@r"
        +--------------------------------+
        | sum(some_empty_with_header.c3) |
        +--------------------------------+
        | 3                              |
        +--------------------------------+
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_csv_extension_compressed() -> Result<()> {
        let ctx = SessionContext::new();

        let df = ctx
            .read_csv(
                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
                CsvReadOptions::default().has_header(true),
            )
            .await?;

        let tmp_dir = tempfile::TempDir::new().unwrap();
        let path = format!("{}", tmp_dir.path().to_string_lossy());

        let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
        let cfg2 = CsvOptions::default()
            .with_has_header(true)
            .with_compression(CompressionTypeVariant::GZIP);

        df.write_csv(&path, cfg1, Some(cfg2)).await?;
        assert!(std::path::Path::new(&path).exists());

        let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
        assert_eq!(files.len(), 1);
        assert!(files
            .last()
            .unwrap()
            .as_ref()
            .unwrap()
            .path()
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .ends_with(".csv.gz"));

        Ok(())
    }

    #[tokio::test]
    async fn test_csv_extension_uncompressed() -> Result<()> {
        let ctx = SessionContext::new();

        let df = ctx
            .read_csv(
                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
                CsvReadOptions::default().has_header(true),
            )
            .await?;

        let tmp_dir = tempfile::TempDir::new().unwrap();
        let path = format!("{}", tmp_dir.path().to_string_lossy());

        let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
        let cfg2 = CsvOptions::default().with_has_header(true);

        df.write_csv(&path, cfg1, Some(cfg2)).await?;
        assert!(std::path::Path::new(&path).exists());

        let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
        assert_eq!(files.len(), 1);
        assert!(files
            .last()
            .unwrap()
            .as_ref()
            .unwrap()
            .path()
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .ends_with(".csv"));

        Ok(())
    }

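    /// A listing table made up entirely of empty CSV files should scan to an empty
    /// result, even with file repartitioning enabled.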
    #[tokio::test]
    async fn test_csv_multiple_empty_files() -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(4);
        let ctx = SessionContext::new_with_config(config);
        let file_format = Arc::new(CsvFormat::default().with_has_header(false));
        let listing_options = ListingOptions::new(file_format.clone())
            .with_file_extension(file_format.get_ext());
        ctx.register_listing_table(
            "empty",
            "tests/data/empty_files/all_empty/",
            listing_options,
            None,
            None,
        )
        .await
        .unwrap();

        let query = "select * from empty where random() > 0.5;";
        let query_result = ctx.sql(query).await?.collect().await?;

        assert_snapshot!(batches_to_string(&query_result),@r###"
        ++
        ++
        "###);

        Ok(())
    }

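    /// Repartitioned scan over a directory where some files are empty: the plan should
    /// still use `n_partitions` file groups and the aggregate should be correct.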
    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
    #[tokio::test]
    async fn test_csv_parallel_some_file_empty(n_partitions: usize) -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions);
        let ctx = SessionContext::new_with_config(config);
        let file_format = Arc::new(CsvFormat::default().with_has_header(false));
        let listing_options = ListingOptions::new(file_format.clone())
            .with_file_extension(file_format.get_ext());
        ctx.register_listing_table(
            "empty",
            "tests/data/empty_files/some_empty",
            listing_options,
            None,
            None,
        )
        .await
        .unwrap();

        let query = "select sum(column_1) from empty where column_1 > 0;";
        let query_result = ctx.sql(query).await?.collect().await?;
        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
        +---------------------+
        | sum(empty.column_1) |
        +---------------------+
        | 10                  |
        +---------------------+
        "###);}

        assert_eq!(n_partitions, actual_partitions);

        Ok(())
    }

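    /// A file with a single narrow column cannot be split into more ranges than it has
    /// bytes, so the expected partition count is capped at the file size.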
    #[rstest(n_partitions, case(1), case(2), case(3), case(5), case(10), case(32))]
    #[tokio::test]
    async fn test_csv_parallel_one_col(n_partitions: usize) -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions);
        let ctx = SessionContext::new_with_config(config);

        ctx.register_csv(
            "one_col",
            "tests/data/one_col.csv",
            CsvReadOptions::new().has_header(false),
        )
        .await?;

        let query = "select sum(column_1) from one_col where column_1 > 0;";
        let query_result = ctx.sql(query).await?.collect().await?;
        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

        let file_size = std::fs::metadata("tests/data/one_col.csv")?.len() as usize;
        let expected_partitions = if n_partitions <= file_size {
            n_partitions
        } else {
            file_size
        };

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
        +-----------------------+
        | sum(one_col.column_1) |
        +-----------------------+
        | 50                    |
        +-----------------------+
        "###);
        }

        assert_eq!(expected_partitions, actual_partitions);

        Ok(())
    }

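    /// Repartitioned scan over a CSV with very wide rows (100+ columns); expects
    /// `n_partitions` file groups and a correct aggregate.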
    #[rstest(n_partitions, case(1), case(2), case(10), case(16))]
    #[tokio::test]
    async fn test_csv_parallel_wide_rows(n_partitions: usize) -> Result<()> {
        let config = SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions);
        let ctx = SessionContext::new_with_config(config);
        ctx.register_csv(
            "wide_rows",
            "tests/data/wide_rows.csv",
            CsvReadOptions::new().has_header(false),
        )
        .await?;

        let query = "select sum(column_1) + sum(column_33) + sum(column_50) + sum(column_77) + sum(column_100) as sum_of_5_cols from wide_rows where column_1 > 0;";
        let query_result = ctx.sql(query).await?.collect().await?;
        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
        +---------------+
        | sum_of_5_cols |
        +---------------+
        | 15            |
        +---------------+
        "###);}

        assert_eq!(n_partitions, actual_partitions);

        Ok(())
    }

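    /// Feeds generated CSV chunks into the decoder-backed deserializer, calls
    /// `finish()`, and checks that the concatenated output matches the expected batch;
    /// the companion test below omits `finish()` and so only gets complete batches.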
    #[rstest]
    fn test_csv_deserializer_with_finish(
        #[values(1, 5, 17)] batch_size: usize,
        #[values(0, 5, 93)] line_count: usize,
    ) -> Result<()> {
        let schema = csv_schema();
        let generator = CsvBatchGenerator::new(batch_size, line_count);
        let mut deserializer = csv_deserializer(batch_size, &schema);

        for data in generator {
            deserializer.digest(data);
        }
        deserializer.finish();

        let batch_count = line_count.div_ceil(batch_size);

        let mut all_batches = RecordBatch::new_empty(schema.clone());
        for _ in 0..batch_count {
            let output = deserializer.next()?;
            let DeserializerOutput::RecordBatch(batch) = output else {
                panic!("Expected RecordBatch, got {output:?}");
            };
            all_batches = concat_batches(&schema, &[all_batches, batch])?;
        }
        assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);

        let expected = csv_expected_batch(schema, line_count)?;

        assert_eq!(
            expected.clone(),
            all_batches.clone(),
            "Expected:\n{}\nActual:\n{}",
            pretty_format_batches(&[expected])?,
            pretty_format_batches(&[all_batches])?,
        );

        Ok(())
    }

    #[rstest]
    fn test_csv_deserializer_without_finish(
        #[values(1, 5, 17)] batch_size: usize,
        #[values(0, 5, 93)] line_count: usize,
    ) -> Result<()> {
        let schema = csv_schema();
        let generator = CsvBatchGenerator::new(batch_size, line_count);
        let mut deserializer = csv_deserializer(batch_size, &schema);

        for data in generator {
            deserializer.digest(data);
        }

        let batch_count = line_count / batch_size;

        let mut all_batches = RecordBatch::new_empty(schema.clone());
        for _ in 0..batch_count {
            let output = deserializer.next()?;
            let DeserializerOutput::RecordBatch(batch) = output else {
                panic!("Expected RecordBatch, got {output:?}");
            };
            all_batches = concat_batches(&schema, &[all_batches, batch])?;
        }
        assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);

        let expected = csv_expected_batch(schema, batch_count * batch_size)?;

        assert_eq!(
            expected.clone(),
            all_batches.clone(),
            "Expected:\n{}\nActual:\n{}",
            pretty_format_batches(&[expected])?,
            pretty_format_batches(&[all_batches])?,
        );

        Ok(())
    }

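    /// Iterator that yields `batch_size` generated CSV lines per chunk, up to
    /// `line_count` lines in total.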
    struct CsvBatchGenerator {
        batch_size: usize,
        line_count: usize,
        offset: usize,
    }

    impl CsvBatchGenerator {
        fn new(batch_size: usize, line_count: usize) -> Self {
            Self {
                batch_size,
                line_count,
                offset: 0,
            }
        }
    }

    impl Iterator for CsvBatchGenerator {
        type Item = Bytes;

        fn next(&mut self) -> Option<Self::Item> {
            let mut buffer = Vec::new();
            for _ in 0..self.batch_size {
                if self.offset >= self.line_count {
                    break;
                }
                buffer.extend_from_slice(&csv_line(self.offset));
                self.offset += 1;
            }

            (!buffer.is_empty()).then(|| buffer.into())
        }
    }

    fn csv_expected_batch(schema: SchemaRef, line_count: usize) -> Result<RecordBatch> {
        let mut c1 = Vec::with_capacity(line_count);
        let mut c2 = Vec::with_capacity(line_count);
        let mut c3 = Vec::with_capacity(line_count);
        let mut c4 = Vec::with_capacity(line_count);

        for i in 0..line_count {
            let (int_value, float_value, bool_value, char_value) = csv_values(i);
            c1.push(int_value);
            c2.push(float_value);
            c3.push(bool_value);
            c4.push(char_value);
        }

        let expected = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(c1)),
                Arc::new(Float64Array::from(c2)),
                Arc::new(BooleanArray::from(c3)),
                Arc::new(StringArray::from(c4)),
            ],
        )?;
        Ok(expected)
    }

    fn csv_line(line_number: usize) -> Bytes {
        let (int_value, float_value, bool_value, char_value) = csv_values(line_number);
        format!("{int_value},{float_value},{bool_value},{char_value}\n").into()
    }

    fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
        let int_value = line_number as i32;
        let float_value = line_number as f64;
        let bool_value = line_number.is_multiple_of(2);
        let char_value = format!("{line_number}-string");
        (int_value, float_value, bool_value, char_value)
    }

    fn csv_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Float64, true),
            Field::new("c3", DataType::Boolean, true),
            Field::new("c4", DataType::Utf8, true),
        ]))
    }

    fn csv_deserializer(
        batch_size: usize,
        schema: &Arc<Schema>,
    ) -> impl BatchDeserializer<Bytes> {
        let decoder = ReaderBuilder::new(schema.clone())
            .with_batch_size(batch_size)
            .build_decoder();
        DecoderDeserializer::new(CsvDecoder::new(decoder))
    }

    fn csv_deserializer_with_truncated(
        batch_size: usize,
        schema: &Arc<Schema>,
    ) -> impl BatchDeserializer<Bytes> {
        let decoder = ReaderBuilder::new(schema.clone())
            .with_batch_size(batch_size)
            .with_truncated_rows(true)
            .build_decoder();
        DecoderDeserializer::new(CsvDecoder::new(decoder))
    }

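    /// With `truncated_rows(true)`, schema inference tolerates the short row `1,2`
    /// and still reports three nullable columns.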
    #[tokio::test]
    async fn infer_schema_with_truncated_rows_true() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();

        let csv_data = Bytes::from("a,b,c\n1,2\n3,4,5\n");
        let variable_object_store = Arc::new(VariableStream::new(csv_data, 1));
        let object_meta = ObjectMeta {
            location: Path::parse("/")?,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let csv_options = CsvOptions::default().with_truncated_rows(true);
        let csv_format = CsvFormat::default()
            .with_has_header(true)
            .with_options(csv_options)
            .with_schema_infer_max_rec(10);

        let inferred_schema = csv_format
            .infer_schema(
                &state,
                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
                &[object_meta],
            )
            .await?;

        assert_eq!(inferred_schema.fields().len(), 3);

        for f in inferred_schema.fields() {
            assert!(f.is_nullable());
        }

        Ok(())
    }
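
    /// At decode time, `with_truncated_rows(true)` lets the second (short) row through
    /// and fills its missing last column with null.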
    #[test]
    fn test_decoder_truncated_rows_runtime() -> Result<()> {
        let schema = csv_schema();
        let mut deserializer = csv_deserializer_with_truncated(10, &schema);

        let input = Bytes::from("0,0.0,true,0-string\n1,1.0,true\n");
        deserializer.digest(input);

        deserializer.finish();

        let output = deserializer.next()?;
        match output {
            DeserializerOutput::RecordBatch(batch) => {
                assert!(batch.num_rows() >= 2);
                let col4 = batch
                    .column(3)
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .expect("column 4 should be StringArray");

                assert!(!col4.is_null(0));
                assert!(col4.is_null(1));
            }
            other => panic!("expected RecordBatch but got {other:?}"),
        }
        Ok(())
    }

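    /// Without truncated-row support, schema inference over data containing a short
    /// row must return an error.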
    #[tokio::test]
    async fn infer_schema_truncated_rows_false_error() -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();

        let csv_data = Bytes::from("id,a,b,c\n1,foo,bar\n2,foo,bar,baz\n");
        let variable_object_store = Arc::new(VariableStream::new(csv_data, 1));
        let object_meta = ObjectMeta {
            location: Path::parse("/")?,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let csv_format = CsvFormat::default()
            .with_has_header(true)
            .with_schema_infer_max_rec(10);

        let res = csv_format
            .infer_schema(
                &state,
                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
                &[object_meta],
            )
            .await;

        assert!(
            res.is_err(),
            "expected infer_schema to error on truncated rows when disabled"
        );

        if let Err(err) = res {
            let msg = format!("{err}");
            assert!(
                msg.contains("Encountered unequal lengths")
                    || msg.contains("incorrect number of fields"),
                "unexpected error message: {msg}",
            );
        }

        Ok(())
    }

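    /// End-to-end: `read_csv` with `truncated_rows(true)` reads a temp file whose
    /// first data row is short, yielding a null in the missing column instead of an
    /// error.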
    #[tokio::test]
    async fn test_read_csv_truncated_rows_via_tempfile() -> Result<()> {
        use std::io::Write;

        let ctx = SessionContext::new();

        let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?;
        write!(tmp, "a,b,c\n1,2\n3,4,5\n")?;
        let path = tmp.path().to_str().unwrap().to_string();

        let options = CsvReadOptions::default().truncated_rows(true);

        println!("options: {}, path: {path}", options.truncated_rows);

        let df = ctx.read_csv(&path, options).await?;

        let batches = df.collect().await?;
        let combined = concat_batches(&batches[0].schema(), &batches)?;

        let col_c = combined.column(2);
        assert!(
            col_c.is_null(0),
            "expected first row column 'c' to be NULL due to truncated row"
        );

        assert!(combined.num_rows() >= 2);

        Ok(())
    }
}