datafusion/datasource/file_format/
csv.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Re-exports the [`datafusion_datasource_csv::file_format`] module, and contains tests for it.
19pub use datafusion_datasource_csv::file_format::*;
20
21#[cfg(test)]
22mod tests {
23    use std::fmt::{self, Display};
24    use std::ops::Range;
25    use std::sync::{Arc, Mutex};
26
27    use super::*;
28
29    use crate::datasource::file_format::test_util::scan_format;
30    use crate::datasource::listing::ListingOptions;
31    use crate::execution::session_state::SessionStateBuilder;
32    use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
33    use arrow_schema::{DataType, Field, Schema, SchemaRef};
34    use datafusion_catalog::Session;
35    use datafusion_common::cast::as_string_array;
36    use datafusion_common::config::CsvOptions;
37    use datafusion_common::internal_err;
38    use datafusion_common::stats::Precision;
39    use datafusion_common::test_util::{arrow_test_data, batches_to_string};
40    use datafusion_common::Result;
41    use datafusion_datasource::decoder::{
42        BatchDeserializer, DecoderDeserializer, DeserializerOutput,
43    };
44    use datafusion_datasource::file_compression_type::FileCompressionType;
45    use datafusion_datasource::file_format::FileFormat;
46    use datafusion_datasource::write::BatchSerializer;
47    use datafusion_expr::{col, lit};
48    use datafusion_physical_plan::{collect, ExecutionPlan};
49
50    use arrow::array::{
51        Array, BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray,
52    };
53    use arrow::compute::concat_batches;
54    use arrow::csv::ReaderBuilder;
55    use arrow::util::pretty::pretty_format_batches;
56    use async_trait::async_trait;
57    use bytes::Bytes;
58    use chrono::DateTime;
59    use datafusion_common::parsers::CompressionTypeVariant;
60    use futures::stream::BoxStream;
61    use futures::StreamExt;
62    use insta::assert_snapshot;
63    use object_store::chunked::ChunkedStore;
64    use object_store::local::LocalFileSystem;
65    use object_store::path::Path;
66    use object_store::{
67        Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload,
68        ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult,
69    };
70    use regex::Regex;
71    use rstest::*;
72
73    /// Mock ObjectStore to provide an variable stream of bytes on get
74    /// Able to keep track of how many iterations of the provided bytes were repeated
75    #[derive(Debug)]
76    struct VariableStream {
77        bytes_to_repeat: Bytes,
78        max_iterations: u64,
79        iterations_detected: Arc<Mutex<usize>>,
80    }
81
82    impl Display for VariableStream {
83        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84            write!(f, "VariableStream")
85        }
86    }
87
88    #[async_trait]
89    impl ObjectStore for VariableStream {
90        async fn put_opts(
91            &self,
92            _location: &Path,
93            _payload: PutPayload,
94            _opts: PutOptions,
95        ) -> object_store::Result<PutResult> {
96            unimplemented!()
97        }
98
99        async fn put_multipart_opts(
100            &self,
101            _location: &Path,
102            _opts: PutMultipartOptions,
103        ) -> object_store::Result<Box<dyn MultipartUpload>> {
104            unimplemented!()
105        }
106
107        async fn get(&self, location: &Path) -> object_store::Result<GetResult> {
108            self.get_opts(location, GetOptions::default()).await
109        }
110
111        async fn get_opts(
112            &self,
113            location: &Path,
114            _opts: GetOptions,
115        ) -> object_store::Result<GetResult> {
116            let bytes = self.bytes_to_repeat.clone();
117            let len = bytes.len() as u64;
118            let range = 0..len * self.max_iterations;
119            let arc = self.iterations_detected.clone();
120            let stream = futures::stream::repeat_with(move || {
121                let arc_inner = arc.clone();
122                *arc_inner.lock().unwrap() += 1;
123                Ok(bytes.clone())
124            })
125            .take(self.max_iterations as usize)
126            .boxed();
127
128            Ok(GetResult {
129                payload: GetResultPayload::Stream(stream),
130                meta: ObjectMeta {
131                    location: location.clone(),
132                    last_modified: Default::default(),
133                    size: range.end,
134                    e_tag: None,
135                    version: None,
136                },
137                range: Default::default(),
138                attributes: Attributes::default(),
139            })
140        }
141
142        async fn get_ranges(
143            &self,
144            _location: &Path,
145            _ranges: &[Range<u64>],
146        ) -> object_store::Result<Vec<Bytes>> {
147            unimplemented!()
148        }
149
150        async fn head(&self, _location: &Path) -> object_store::Result<ObjectMeta> {
151            unimplemented!()
152        }
153
154        async fn delete(&self, _location: &Path) -> object_store::Result<()> {
155            unimplemented!()
156        }
157
158        fn list(
159            &self,
160            _prefix: Option<&Path>,
161        ) -> BoxStream<'static, object_store::Result<ObjectMeta>> {
162            unimplemented!()
163        }
164
165        async fn list_with_delimiter(
166            &self,
167            _prefix: Option<&Path>,
168        ) -> object_store::Result<ListResult> {
169            unimplemented!()
170        }
171
172        async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> {
173            unimplemented!()
174        }
175
176        async fn copy_if_not_exists(
177            &self,
178            _from: &Path,
179            _to: &Path,
180        ) -> object_store::Result<()> {
181            unimplemented!()
182        }
183    }
184
185    impl VariableStream {
186        pub fn new(bytes_to_repeat: Bytes, max_iterations: u64) -> Self {
187            Self {
188                bytes_to_repeat,
189                max_iterations,
190                iterations_detected: Arc::new(Mutex::new(0)),
191            }
192        }
193
194        pub fn get_iterations_detected(&self) -> usize {
195            *self.iterations_detected.lock().unwrap()
196        }
197    }
198
199    #[tokio::test]
200    async fn read_small_batches() -> Result<()> {
201        let config = SessionConfig::new().with_batch_size(2);
202        let session_ctx = SessionContext::new_with_config(config);
203        let state = session_ctx.state();
204        let task_ctx = state.task_ctx();
205        // skip column 9 that overflows the automatically discovered column type of i64 (u64 would work)
206        let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
207        let exec =
208            get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?;
209        let stream = exec.execute(0, task_ctx)?;
210
211        let tt_batches: i32 = stream
212            .map(|batch| {
213                let batch = batch.unwrap();
214                assert_eq!(12, batch.num_columns());
215                assert_eq!(2, batch.num_rows());
216            })
217            .fold(0, |acc, _| async move { acc + 1i32 })
218            .await;
219
220        assert_eq!(tt_batches, 50 /* 100/2 */);
221
222        // test metadata
223        assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent);
224        assert_eq!(
225            exec.partition_statistics(None)?.total_byte_size,
226            Precision::Absent
227        );
228
229        Ok(())
230    }
231
232    #[tokio::test]
233    async fn read_limit() -> Result<()> {
234        let session_ctx = SessionContext::new();
235        let state = session_ctx.state();
236        let task_ctx = session_ctx.task_ctx();
237        let projection = Some(vec![0, 1, 2, 3]);
238        let exec =
239            get_exec(&state, "aggregate_test_100.csv", projection, Some(1), true).await?;
240        let batches = collect(exec, task_ctx).await?;
241        assert_eq!(1, batches.len());
242        assert_eq!(4, batches[0].num_columns());
243        assert_eq!(1, batches[0].num_rows());
244
245        Ok(())
246    }
247
248    #[tokio::test]
249    async fn infer_schema() -> Result<()> {
250        let session_ctx = SessionContext::new();
251        let state = session_ctx.state();
252
253        let projection = None;
254        let root = "./tests/data/csv";
255        let format = CsvFormat::default().with_has_header(true);
256        let exec = scan_format(
257            &state,
258            &format,
259            None,
260            root,
261            "aggregate_test_100_with_nulls.csv",
262            projection,
263            None,
264        )
265        .await?;
266
267        let x: Vec<String> = exec
268            .schema()
269            .fields()
270            .iter()
271            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
272            .collect();
273        assert_eq!(
274            vec![
275                "c1: Utf8",
276                "c2: Int64",
277                "c3: Int64",
278                "c4: Int64",
279                "c5: Int64",
280                "c6: Int64",
281                "c7: Int64",
282                "c8: Int64",
283                "c9: Int64",
284                "c10: Utf8",
285                "c11: Float64",
286                "c12: Float64",
287                "c13: Utf8",
288                "c14: Null",
289                "c15: Utf8"
290            ],
291            x
292        );
293
294        Ok(())
295    }
296
297    #[tokio::test]
298    async fn infer_schema_with_null_regex() -> Result<()> {
299        let session_ctx = SessionContext::new();
300        let state = session_ctx.state();
301
302        let projection = None;
303        let root = "./tests/data/csv";
304        let format = CsvFormat::default()
305            .with_has_header(true)
306            .with_null_regex(Some("^NULL$|^$".to_string()));
307        let exec = scan_format(
308            &state,
309            &format,
310            None,
311            root,
312            "aggregate_test_100_with_nulls.csv",
313            projection,
314            None,
315        )
316        .await?;
317
318        let x: Vec<String> = exec
319            .schema()
320            .fields()
321            .iter()
322            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
323            .collect();
324        assert_eq!(
325            vec![
326                "c1: Utf8",
327                "c2: Int64",
328                "c3: Int64",
329                "c4: Int64",
330                "c5: Int64",
331                "c6: Int64",
332                "c7: Int64",
333                "c8: Int64",
334                "c9: Int64",
335                "c10: Utf8",
336                "c11: Float64",
337                "c12: Float64",
338                "c13: Utf8",
339                "c14: Null",
340                "c15: Null"
341            ],
342            x
343        );
344
345        Ok(())
346    }
347
348    #[tokio::test]
349    async fn read_char_column() -> Result<()> {
350        let session_ctx = SessionContext::new();
351        let state = session_ctx.state();
352        let task_ctx = session_ctx.task_ctx();
353        let projection = Some(vec![0]);
354        let exec =
355            get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?;
356
357        let batches = collect(exec, task_ctx).await.expect("Collect batches");
358
359        assert_eq!(1, batches.len());
360        assert_eq!(1, batches[0].num_columns());
361        assert_eq!(100, batches[0].num_rows());
362
363        let array = as_string_array(batches[0].column(0))?;
364        let mut values: Vec<&str> = vec![];
365        for i in 0..5 {
366            values.push(array.value(i));
367        }
368
369        assert_eq!(vec!["c", "d", "b", "a", "b"], values);
370
371        Ok(())
372    }
373
374    #[tokio::test]
375    async fn test_infer_schema_stream() -> Result<()> {
376        let session_ctx = SessionContext::new();
377        let state = session_ctx.state();
378        let variable_object_store =
379            Arc::new(VariableStream::new(Bytes::from("1,2,3,4,5\n"), 200));
380        let object_meta = ObjectMeta {
381            location: Path::parse("/")?,
382            last_modified: DateTime::default(),
383            size: u64::MAX,
384            e_tag: None,
385            version: None,
386        };
387
388        let num_rows_to_read = 100;
389        let csv_format = CsvFormat::default()
390            .with_has_header(false)
391            .with_schema_infer_max_rec(num_rows_to_read);
392        let inferred_schema = csv_format
393            .infer_schema(
394                &state,
395                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
396                &[object_meta],
397            )
398            .await?;
399
400        let actual_fields: Vec<_> = inferred_schema
401            .fields()
402            .iter()
403            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
404            .collect();
405        assert_eq!(
406            vec![
407                "column_1: Int64",
408                "column_2: Int64",
409                "column_3: Int64",
410                "column_4: Int64",
411                "column_5: Int64"
412            ],
413            actual_fields
414        );
415        // ensuring on csv infer that it won't try to read entire file
416        // should only read as many rows as was configured in the CsvFormat
417        assert_eq!(
418            num_rows_to_read,
419            variable_object_store.get_iterations_detected()
420        );
421
422        Ok(())
423    }
424
425    #[tokio::test]
426    async fn test_infer_schema_escape_chars() -> Result<()> {
427        let session_ctx = SessionContext::new();
428        let state = session_ctx.state();
429        let variable_object_store = Arc::new(VariableStream::new(
430            Bytes::from(
431                r#"c1,c2,c3,c4
4320.3,"Here, is a comma\"",third,3
4330.31,"double quotes are ok, "" quote",third again,9
4340.314,abc,xyz,27"#,
435            ),
436            1,
437        ));
438        let object_meta = ObjectMeta {
439            location: Path::parse("/")?,
440            last_modified: DateTime::default(),
441            size: u64::MAX,
442            e_tag: None,
443            version: None,
444        };
445
446        let num_rows_to_read = 3;
447        let csv_format = CsvFormat::default()
448            .with_has_header(true)
449            .with_schema_infer_max_rec(num_rows_to_read)
450            .with_quote(b'"')
451            .with_escape(Some(b'\\'));
452
453        let inferred_schema = csv_format
454            .infer_schema(
455                &state,
456                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
457                &[object_meta],
458            )
459            .await?;
460
461        let actual_fields: Vec<_> = inferred_schema
462            .fields()
463            .iter()
464            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
465            .collect();
466
467        assert_eq!(
468            vec!["c1: Float64", "c2: Utf8", "c3: Utf8", "c4: Int64",],
469            actual_fields
470        );
471        Ok(())
472    }
473
474    #[tokio::test]
475    async fn test_infer_schema_stream_null_chunks() -> Result<()> {
476        let session_ctx = SessionContext::new();
477        let state = session_ctx.state();
478
479        // a stream where each line is read as a separate chunk,
480        // data type for each chunk is inferred separately.
481        // +----+-----+----+
482        // | c1 | c2  | c3 |
483        // +----+-----+----+
484        // | 1  | 1.0 |    |  type: Int64, Float64, Null
485        // |    |     |    |  type: Null, Null, Null
486        // +----+-----+----+
487        let chunked_object_store = Arc::new(ChunkedStore::new(
488            Arc::new(VariableStream::new(
489                Bytes::from(
490                    r#"c1,c2,c3
4911,1.0,
492,,
493"#,
494                ),
495                1,
496            )),
497            1,
498        ));
499        let object_meta = ObjectMeta {
500            location: Path::parse("/")?,
501            last_modified: DateTime::default(),
502            size: u64::MAX,
503            e_tag: None,
504            version: None,
505        };
506
507        let csv_format = CsvFormat::default().with_has_header(true);
508        let inferred_schema = csv_format
509            .infer_schema(
510                &state,
511                &(chunked_object_store as Arc<dyn ObjectStore>),
512                &[object_meta],
513            )
514            .await?;
515
516        let actual_fields: Vec<_> = inferred_schema
517            .fields()
518            .iter()
519            .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
520            .collect();
521
522        // ensure null chunks don't skew type inference
523        assert_eq!(vec!["c1: Int64", "c2: Float64", "c3: Null"], actual_fields);
524        Ok(())
525    }
526
#[rstest(
    file_compression_type,
    case(FileCompressionType::UNCOMPRESSED),
    case(FileCompressionType::GZIP),
    case(FileCompressionType::BZIP2),
    case(FileCompressionType::XZ),
    case(FileCompressionType::ZSTD)
)]
#[cfg(feature = "compression")]
#[tokio::test]
async fn query_compress_data(
    file_compression_type: FileCompressionType,
) -> Result<()> {
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::DataFusionError;
    use datafusion_datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
    use futures::TryStreamExt;

    let mut cfg = SessionConfig::new();
    cfg.options_mut().catalog.has_header = true;
    let session_state = SessionStateBuilder::new()
        .with_config(cfg)
        .with_default_features()
        .build();

    let csv = CsvFormat::default().with_has_header(true);
    let records_to_read = csv
        .options()
        .schema_infer_max_rec
        .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

    // Load the plain fixture from the arrow-testing data directory as a stream.
    let store: Arc<dyn ObjectStore> =
        Arc::new(LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap());
    let path = Path::from("csv/aggregate_test_100.csv");
    let raw_bytes = store
        .get(&path)
        .await?
        .into_stream()
        .map_err(DataFusionError::from)
        .boxed();

    // Compress the stream with the codec under test.
    let compressed_stream = file_compression_type
        .to_owned()
        .convert_to_compress_stream(raw_bytes);

    // Thirteen nullable columns: c1, c10 and c13 are strings, c11 and c12 are
    // floats, and the rest are 64-bit integers.
    let expected = Schema::new(
        [
            ("c1", DataType::Utf8),
            ("c2", DataType::Int64),
            ("c3", DataType::Int64),
            ("c4", DataType::Int64),
            ("c5", DataType::Int64),
            ("c6", DataType::Int64),
            ("c7", DataType::Int64),
            ("c8", DataType::Int64),
            ("c9", DataType::Int64),
            ("c10", DataType::Utf8),
            ("c11", DataType::Float64),
            ("c12", DataType::Float64),
            ("c13", DataType::Utf8),
        ]
        .into_iter()
        .map(|(name, data_type)| Field::new(name, data_type, true))
        .collect::<Vec<_>>(),
    );

    // Decode again through the format configured for the same codec, then
    // infer the schema from the decoded chunks.
    let compressed_csv = csv.with_file_compression_type(file_compression_type);
    let decoded_stream = compressed_csv
        .read_to_delimited_chunks_from_stream(compressed_stream.unwrap())
        .await;
    let (schema, records_read) = compressed_csv
        .infer_schema_from_stream(&session_state, records_to_read, decoded_stream)
        .await?;

    assert_eq!(expected, schema);
    assert_eq!(100, records_read);
    Ok(())
}
601
#[cfg(feature = "compression")]
#[tokio::test]
async fn query_compress_csv() -> Result<()> {
    let ctx = SessionContext::new();
    let gz_path = format!("{}/csv/aggregate_test_100.csv.gz", arrow_test_data());
    let csv_options = CsvReadOptions::default()
        .has_header(true)
        .file_compression_type(FileCompressionType::GZIP)
        .file_extension("csv.gz");

    let predicate = col("c1").eq(lit("a")).and(col("c2").gt(lit("4")));
    let record_batch = ctx
        .read_csv(&gz_path, csv_options)
        .await?
        .filter(predicate)?
        .select_columns(&["c2", "c3"])?
        .collect()
        .await?;

    assert_snapshot!(batches_to_string(&record_batch), @r###"
        +----+------+
        | c2 | c3   |
        +----+------+
        | 5  | 36   |
        | 5  | -31  |
        | 5  | -101 |
        +----+------+
    "###);

    Ok(())
}
636
637    async fn get_exec(
638        state: &dyn Session,
639        file_name: &str,
640        projection: Option<Vec<usize>>,
641        limit: Option<usize>,
642        has_header: bool,
643    ) -> Result<Arc<dyn ExecutionPlan>> {
644        let root = format!("{}/csv", arrow_test_data());
645        let format = CsvFormat::default().with_has_header(has_header);
646        scan_format(state, &format, None, &root, file_name, projection, limit).await
647    }
648
649    #[tokio::test]
650    async fn test_csv_serializer() -> Result<()> {
651        let ctx = SessionContext::new();
652        let df = ctx
653            .read_csv(
654                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
655                CsvReadOptions::default().has_header(true),
656            )
657            .await?;
658        let batches = df
659            .select_columns(&["c2", "c3"])?
660            .limit(0, Some(10))?
661            .collect()
662            .await?;
663        let batch = concat_batches(&batches[0].schema(), &batches)?;
664        let serializer = CsvSerializer::new();
665        let bytes = serializer.serialize(batch, true)?;
666        assert_eq!(
667            "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
668            String::from_utf8(bytes.into()).unwrap()
669        );
670        Ok(())
671    }
672
673    #[tokio::test]
674    async fn test_csv_serializer_no_header() -> Result<()> {
675        let ctx = SessionContext::new();
676        let df = ctx
677            .read_csv(
678                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
679                CsvReadOptions::default().has_header(true),
680            )
681            .await?;
682        let batches = df
683            .select_columns(&["c2", "c3"])?
684            .limit(0, Some(10))?
685            .collect()
686            .await?;
687        let batch = concat_batches(&batches[0].schema(), &batches)?;
688        let serializer = CsvSerializer::new().with_header(false);
689        let bytes = serializer.serialize(batch, true)?;
690        assert_eq!(
691            "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
692            String::from_utf8(bytes.into()).unwrap()
693        );
694        Ok(())
695    }
696
697    /// Explain the `sql` query under `ctx` to make sure the underlying csv scan is parallelized
698    /// e.g. "DataSourceExec: file_groups={2 groups:" in plan means 2 DataSourceExec runs concurrently
699    async fn count_query_csv_partitions(
700        ctx: &SessionContext,
701        sql: &str,
702    ) -> Result<usize> {
703        let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
704        let result = df.collect().await?;
705        let plan = format!("{}", &pretty_format_batches(&result)?);
706
707        let re = Regex::new(r"DataSourceExec: file_groups=\{(\d+) group").unwrap();
708
709        if let Some(captures) = re.captures(&plan) {
710            if let Some(match_) = captures.get(1) {
711                let n_partitions = match_.as_str().parse::<usize>().unwrap();
712                return Ok(n_partitions);
713            }
714        }
715
716        internal_err!("query contains no DataSourceExec")
717    }
718
719    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
720    #[tokio::test]
721    async fn test_csv_parallel_basic(n_partitions: usize) -> Result<()> {
722        let config = SessionConfig::new()
723            .with_repartition_file_scans(true)
724            .with_repartition_file_min_size(0)
725            .with_target_partitions(n_partitions);
726        let ctx = SessionContext::new_with_config(config);
727        let testdata = arrow_test_data();
728        ctx.register_csv(
729            "aggr",
730            &format!("{testdata}/csv/aggregate_test_100.csv"),
731            CsvReadOptions::new().has_header(true),
732        )
733        .await?;
734
735        let query = "select sum(c2) from aggr;";
736        let query_result = ctx.sql(query).await?.collect().await?;
737        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;
738
739        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
740        +--------------+
741        | sum(aggr.c2) |
742        +--------------+
743        | 285          |
744        +--------------+
745        "###);
746        }
747
748        assert_eq!(n_partitions, actual_partitions);
749
750        Ok(())
751    }
752
#[rstest(n_partitions, case(1), case(2), case(3), case(4))]
#[cfg(feature = "compression")]
#[tokio::test]
async fn test_csv_parallel_compressed(n_partitions: usize) -> Result<()> {
    let ctx = SessionContext::new_with_config(
        SessionConfig::new()
            .with_repartition_file_scans(true)
            .with_repartition_file_min_size(0)
            .with_target_partitions(n_partitions),
    );
    let gz_path = format!("{}/csv/aggregate_test_100.csv.gz", arrow_test_data());
    ctx.register_csv(
        "aggr",
        &gz_path,
        CsvReadOptions::default()
            .has_header(true)
            .file_compression_type(FileCompressionType::GZIP)
            .file_extension("csv.gz"),
    )
    .await?;

    let query = "select sum(c3) from aggr;";
    let query_result = ctx.sql(query).await?.collect().await?;
    let actual_partitions = count_query_csv_partitions(&ctx, query).await?;

    insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
    +--------------+
    | sum(aggr.c3) |
    +--------------+
    | 781          |
    +--------------+
    "###);
    }

    // A compressed csv is never scanned in parallel, whatever
    // `target_partitions` says.
    assert_eq!(1, actual_partitions);

    Ok(())
}
791
792    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
793    #[tokio::test]
794    async fn test_csv_parallel_newlines_in_values(n_partitions: usize) -> Result<()> {
795        let config = SessionConfig::new()
796            .with_repartition_file_scans(true)
797            .with_repartition_file_min_size(0)
798            .with_target_partitions(n_partitions);
799        let csv_options = CsvReadOptions::default()
800            .has_header(true)
801            .newlines_in_values(true);
802        let ctx = SessionContext::new_with_config(config);
803        let testdata = arrow_test_data();
804        ctx.register_csv(
805            "aggr",
806            &format!("{testdata}/csv/aggregate_test_100.csv"),
807            csv_options,
808        )
809        .await?;
810
811        let query = "select sum(c3) from aggr;";
812        let query_result = ctx.sql(query).await?.collect().await?;
813        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;
814
815        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
816        +--------------+
817        | sum(aggr.c3) |
818        +--------------+
819        | 781          |
820        +--------------+
821        "###);
822        }
823
824        assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set
825
826        Ok(())
827    }
828
829    /// Read a single empty csv file
830    ///
831    /// empty_0_byte.csv:
832    /// (file is empty)
833    #[tokio::test]
834    async fn test_csv_empty_file() -> Result<()> {
835        let ctx = SessionContext::new();
836        ctx.register_csv(
837            "empty",
838            "tests/data/empty_0_byte.csv",
839            CsvReadOptions::new().has_header(false),
840        )
841        .await?;
842
843        let query = "select * from empty where random() > 0.5;";
844        let query_result = ctx.sql(query).await?.collect().await?;
845
846        assert_snapshot!(batches_to_string(&query_result),@r###"
847            ++
848            ++
849        "###);
850
851        Ok(())
852    }
853
854    /// Read a single empty csv file with header
855    ///
856    /// empty.csv:
857    /// c1,c2,c3
858    #[tokio::test]
859    async fn test_csv_empty_with_header() -> Result<()> {
860        let ctx = SessionContext::new();
861        ctx.register_csv(
862            "empty",
863            "tests/data/empty.csv",
864            CsvReadOptions::new().has_header(true),
865        )
866        .await?;
867
868        let query = "select * from empty where random() > 0.5;";
869        let query_result = ctx.sql(query).await?.collect().await?;
870
871        assert_snapshot!(batches_to_string(&query_result),@r###"
872            ++
873            ++
874        "###);
875
876        Ok(())
877    }
878
879    /// Read multiple csv files (some are empty) with header
880    ///
881    /// some_empty_with_header
882    /// ├── a_empty.csv
883    /// ├── b.csv
884    /// └── c_nulls_column.csv
885    ///
886    /// a_empty.csv:
887    /// c1,c2,c3
888    ///
889    /// b.csv:
890    /// c1,c2,c3
891    /// 1,1,1
892    /// 2,2,2
893    ///
894    /// c_nulls_column.csv:
895    /// c1,c2,c3
896    /// 3,3,
897    #[tokio::test]
898    async fn test_csv_some_empty_with_header() -> Result<()> {
899        let ctx = SessionContext::new();
900        ctx.register_csv(
901            "some_empty_with_header",
902            "tests/data/empty_files/some_empty_with_header",
903            CsvReadOptions::new().has_header(true),
904        )
905        .await?;
906
907        let query = "select sum(c3) from some_empty_with_header;";
908        let query_result = ctx.sql(query).await?.collect().await?;
909
910        assert_snapshot!(batches_to_string(&query_result),@r"
911        +--------------------------------+
912        | sum(some_empty_with_header.c3) |
913        +--------------------------------+
914        | 3                              |
915        +--------------------------------+
916        ");
917
918        Ok(())
919    }
920
921    #[tokio::test]
922    async fn test_csv_extension_compressed() -> Result<()> {
923        // Write compressed CSV files
924        // Expect: under the directory, a file is created with ".csv.gz" extension
925        let ctx = SessionContext::new();
926
927        let df = ctx
928            .read_csv(
929                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
930                CsvReadOptions::default().has_header(true),
931            )
932            .await?;
933
934        let tmp_dir = tempfile::TempDir::new().unwrap();
935        let path = format!("{}", tmp_dir.path().to_string_lossy());
936
937        let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
938        let cfg2 = CsvOptions::default()
939            .with_has_header(true)
940            .with_compression(CompressionTypeVariant::GZIP);
941
942        df.write_csv(&path, cfg1, Some(cfg2)).await?;
943        assert!(std::path::Path::new(&path).exists());
944
945        let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
946        assert_eq!(files.len(), 1);
947        assert!(files
948            .last()
949            .unwrap()
950            .as_ref()
951            .unwrap()
952            .path()
953            .file_name()
954            .unwrap()
955            .to_str()
956            .unwrap()
957            .ends_with(".csv.gz"));
958
959        Ok(())
960    }
961
962    #[tokio::test]
963    async fn test_csv_extension_uncompressed() -> Result<()> {
964        // Write plain uncompressed CSV files
965        // Expect: under the directory, a file is created with ".csv" extension
966        let ctx = SessionContext::new();
967
968        let df = ctx
969            .read_csv(
970                &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()),
971                CsvReadOptions::default().has_header(true),
972            )
973            .await?;
974
975        let tmp_dir = tempfile::TempDir::new().unwrap();
976        let path = format!("{}", tmp_dir.path().to_string_lossy());
977
978        let cfg1 = crate::dataframe::DataFrameWriteOptions::new();
979        let cfg2 = CsvOptions::default().with_has_header(true);
980
981        df.write_csv(&path, cfg1, Some(cfg2)).await?;
982        assert!(std::path::Path::new(&path).exists());
983
984        let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect();
985        assert_eq!(files.len(), 1);
986        assert!(files
987            .last()
988            .unwrap()
989            .as_ref()
990            .unwrap()
991            .path()
992            .file_name()
993            .unwrap()
994            .to_str()
995            .unwrap()
996            .ends_with(".csv"));
997
998        Ok(())
999    }
1000
1001    /// Read multiple empty csv files
1002    ///
1003    /// all_empty
1004    /// ├── empty0.csv
1005    /// ├── empty1.csv
1006    /// └── empty2.csv
1007    ///
1008    /// empty0.csv/empty1.csv/empty2.csv:
1009    /// (file is empty)
1010    #[tokio::test]
1011    async fn test_csv_multiple_empty_files() -> Result<()> {
1012        // Testing that partitioning doesn't break with empty files
1013        let config = SessionConfig::new()
1014            .with_repartition_file_scans(true)
1015            .with_repartition_file_min_size(0)
1016            .with_target_partitions(4);
1017        let ctx = SessionContext::new_with_config(config);
1018        let file_format = Arc::new(CsvFormat::default().with_has_header(false));
1019        let listing_options = ListingOptions::new(file_format.clone())
1020            .with_file_extension(file_format.get_ext());
1021        ctx.register_listing_table(
1022            "empty",
1023            "tests/data/empty_files/all_empty/",
1024            listing_options,
1025            None,
1026            None,
1027        )
1028        .await
1029        .unwrap();
1030
1031        // Require a predicate to enable repartition for the optimizer
1032        let query = "select * from empty where random() > 0.5;";
1033        let query_result = ctx.sql(query).await?.collect().await?;
1034
1035        assert_snapshot!(batches_to_string(&query_result),@r###"
1036            ++
1037            ++
1038        "###);
1039
1040        Ok(())
1041    }
1042
1043    /// Read multiple csv files (some are empty) in parallel
1044    ///
1045    /// some_empty
1046    /// ├── a_empty.csv
1047    /// ├── b.csv
1048    /// ├── c_empty.csv
1049    /// ├── d.csv
1050    /// └── e_empty.csv
1051    ///
1052    /// a_empty.csv/c_empty.csv/e_empty.csv:
1053    /// (file is empty)
1054    ///
1055    /// b.csv/d.csv:
1056    /// 1\n
1057    /// 1\n
1058    /// 1\n
1059    /// 1\n
1060    /// 1\n
1061    #[rstest(n_partitions, case(1), case(2), case(3), case(4))]
1062    #[tokio::test]
1063    async fn test_csv_parallel_some_file_empty(n_partitions: usize) -> Result<()> {
1064        let config = SessionConfig::new()
1065            .with_repartition_file_scans(true)
1066            .with_repartition_file_min_size(0)
1067            .with_target_partitions(n_partitions);
1068        let ctx = SessionContext::new_with_config(config);
1069        let file_format = Arc::new(CsvFormat::default().with_has_header(false));
1070        let listing_options = ListingOptions::new(file_format.clone())
1071            .with_file_extension(file_format.get_ext());
1072        ctx.register_listing_table(
1073            "empty",
1074            "tests/data/empty_files/some_empty",
1075            listing_options,
1076            None,
1077            None,
1078        )
1079        .await
1080        .unwrap();
1081
1082        // Require a predicate to enable repartition for the optimizer
1083        let query = "select sum(column_1) from empty where column_1 > 0;";
1084        let query_result = ctx.sql(query).await?.collect().await?;
1085        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;
1086
1087        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
1088            +---------------------+
1089            | sum(empty.column_1) |
1090            +---------------------+
1091            | 10                  |
1092            +---------------------+
1093        "###);}
1094
1095        assert_eq!(n_partitions, actual_partitions); // Won't get partitioned if all files are empty
1096
1097        Ok(())
1098    }
1099
1100    /// Parallel scan on a csv file with only 1 byte in each line
1101    /// Testing partition byte range land on line boundaries
1102    ///
1103    /// one_col.csv:
1104    /// 5\n
1105    /// 5\n
1106    /// (...10 rows total)
1107    #[rstest(n_partitions, case(1), case(2), case(3), case(5), case(10), case(32))]
1108    #[tokio::test]
1109    async fn test_csv_parallel_one_col(n_partitions: usize) -> Result<()> {
1110        let config = SessionConfig::new()
1111            .with_repartition_file_scans(true)
1112            .with_repartition_file_min_size(0)
1113            .with_target_partitions(n_partitions);
1114        let ctx = SessionContext::new_with_config(config);
1115
1116        ctx.register_csv(
1117            "one_col",
1118            "tests/data/one_col.csv",
1119            CsvReadOptions::new().has_header(false),
1120        )
1121        .await?;
1122
1123        let query = "select sum(column_1) from one_col where column_1 > 0;";
1124        let query_result = ctx.sql(query).await?.collect().await?;
1125        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;
1126
1127        let file_size = std::fs::metadata("tests/data/one_col.csv")?.len() as usize;
1128        // A 20-Byte file at most get partitioned into 20 chunks
1129        let expected_partitions = if n_partitions <= file_size {
1130            n_partitions
1131        } else {
1132            file_size
1133        };
1134
1135        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
1136        +-----------------------+
1137        | sum(one_col.column_1) |
1138        +-----------------------+
1139        | 50                    |
1140        +-----------------------+
1141        "###);
1142        }
1143
1144        assert_eq!(expected_partitions, actual_partitions);
1145
1146        Ok(())
1147    }
1148
1149    /// Parallel scan on a csv file with 2 wide rows
1150    /// The byte range of a partition might be within some line
1151    ///
1152    /// wode_rows.csv:
1153    /// 1, 1, ..., 1\n (100 columns total)
1154    /// 2, 2, ..., 2\n
1155    #[rstest(n_partitions, case(1), case(2), case(10), case(16))]
1156    #[tokio::test]
1157    async fn test_csv_parallel_wide_rows(n_partitions: usize) -> Result<()> {
1158        let config = SessionConfig::new()
1159            .with_repartition_file_scans(true)
1160            .with_repartition_file_min_size(0)
1161            .with_target_partitions(n_partitions);
1162        let ctx = SessionContext::new_with_config(config);
1163        ctx.register_csv(
1164            "wide_rows",
1165            "tests/data/wide_rows.csv",
1166            CsvReadOptions::new().has_header(false),
1167        )
1168        .await?;
1169
1170        let query = "select sum(column_1) + sum(column_33) + sum(column_50) + sum(column_77) + sum(column_100) as sum_of_5_cols from wide_rows where column_1 > 0;";
1171        let query_result = ctx.sql(query).await?.collect().await?;
1172        let actual_partitions = count_query_csv_partitions(&ctx, query).await?;
1173
1174        insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###"
1175            +---------------+
1176            | sum_of_5_cols |
1177            +---------------+
1178            | 15            |
1179            +---------------+
1180        "###);}
1181
1182        assert_eq!(n_partitions, actual_partitions);
1183
1184        Ok(())
1185    }
1186
1187    #[rstest]
1188    fn test_csv_deserializer_with_finish(
1189        #[values(1, 5, 17)] batch_size: usize,
1190        #[values(0, 5, 93)] line_count: usize,
1191    ) -> Result<()> {
1192        let schema = csv_schema();
1193        let generator = CsvBatchGenerator::new(batch_size, line_count);
1194        let mut deserializer = csv_deserializer(batch_size, &schema);
1195
1196        for data in generator {
1197            deserializer.digest(data);
1198        }
1199        deserializer.finish();
1200
1201        let batch_count = line_count.div_ceil(batch_size);
1202
1203        let mut all_batches = RecordBatch::new_empty(schema.clone());
1204        for _ in 0..batch_count {
1205            let output = deserializer.next()?;
1206            let DeserializerOutput::RecordBatch(batch) = output else {
1207                panic!("Expected RecordBatch, got {output:?}");
1208            };
1209            all_batches = concat_batches(&schema, &[all_batches, batch])?;
1210        }
1211        assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);
1212
1213        let expected = csv_expected_batch(schema, line_count)?;
1214
1215        assert_eq!(
1216            expected.clone(),
1217            all_batches.clone(),
1218            "Expected:\n{}\nActual:\n{}",
1219            pretty_format_batches(&[expected])?,
1220            pretty_format_batches(&[all_batches])?,
1221        );
1222
1223        Ok(())
1224    }
1225
1226    #[rstest]
1227    fn test_csv_deserializer_without_finish(
1228        #[values(1, 5, 17)] batch_size: usize,
1229        #[values(0, 5, 93)] line_count: usize,
1230    ) -> Result<()> {
1231        let schema = csv_schema();
1232        let generator = CsvBatchGenerator::new(batch_size, line_count);
1233        let mut deserializer = csv_deserializer(batch_size, &schema);
1234
1235        for data in generator {
1236            deserializer.digest(data);
1237        }
1238
1239        let batch_count = line_count / batch_size;
1240
1241        let mut all_batches = RecordBatch::new_empty(schema.clone());
1242        for _ in 0..batch_count {
1243            let output = deserializer.next()?;
1244            let DeserializerOutput::RecordBatch(batch) = output else {
1245                panic!("Expected RecordBatch, got {output:?}");
1246            };
1247            all_batches = concat_batches(&schema, &[all_batches, batch])?;
1248        }
1249        assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);
1250
1251        let expected = csv_expected_batch(schema, batch_count * batch_size)?;
1252
1253        assert_eq!(
1254            expected.clone(),
1255            all_batches.clone(),
1256            "Expected:\n{}\nActual:\n{}",
1257            pretty_format_batches(&[expected])?,
1258            pretty_format_batches(&[all_batches])?,
1259        );
1260
1261        Ok(())
1262    }
1263
1264    struct CsvBatchGenerator {
1265        batch_size: usize,
1266        line_count: usize,
1267        offset: usize,
1268    }
1269
1270    impl CsvBatchGenerator {
1271        fn new(batch_size: usize, line_count: usize) -> Self {
1272            Self {
1273                batch_size,
1274                line_count,
1275                offset: 0,
1276            }
1277        }
1278    }
1279
1280    impl Iterator for CsvBatchGenerator {
1281        type Item = Bytes;
1282
1283        fn next(&mut self) -> Option<Self::Item> {
1284            // Return `batch_size` rows per batch:
1285            let mut buffer = Vec::new();
1286            for _ in 0..self.batch_size {
1287                if self.offset >= self.line_count {
1288                    break;
1289                }
1290                buffer.extend_from_slice(&csv_line(self.offset));
1291                self.offset += 1;
1292            }
1293
1294            (!buffer.is_empty()).then(|| buffer.into())
1295        }
1296    }
1297
1298    fn csv_expected_batch(schema: SchemaRef, line_count: usize) -> Result<RecordBatch> {
1299        let mut c1 = Vec::with_capacity(line_count);
1300        let mut c2 = Vec::with_capacity(line_count);
1301        let mut c3 = Vec::with_capacity(line_count);
1302        let mut c4 = Vec::with_capacity(line_count);
1303
1304        for i in 0..line_count {
1305            let (int_value, float_value, bool_value, char_value) = csv_values(i);
1306            c1.push(int_value);
1307            c2.push(float_value);
1308            c3.push(bool_value);
1309            c4.push(char_value);
1310        }
1311
1312        let expected = RecordBatch::try_new(
1313            schema.clone(),
1314            vec![
1315                Arc::new(Int32Array::from(c1)),
1316                Arc::new(Float64Array::from(c2)),
1317                Arc::new(BooleanArray::from(c3)),
1318                Arc::new(StringArray::from(c4)),
1319            ],
1320        )?;
1321        Ok(expected)
1322    }
1323
1324    fn csv_line(line_number: usize) -> Bytes {
1325        let (int_value, float_value, bool_value, char_value) = csv_values(line_number);
1326        format!("{int_value},{float_value},{bool_value},{char_value}\n").into()
1327    }
1328
1329    fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
1330        let int_value = line_number as i32;
1331        let float_value = line_number as f64;
1332        let bool_value = line_number.is_multiple_of(2);
1333        let char_value = format!("{line_number}-string");
1334        (int_value, float_value, bool_value, char_value)
1335    }
1336
1337    fn csv_schema() -> SchemaRef {
1338        Arc::new(Schema::new(vec![
1339            Field::new("c1", DataType::Int32, true),
1340            Field::new("c2", DataType::Float64, true),
1341            Field::new("c3", DataType::Boolean, true),
1342            Field::new("c4", DataType::Utf8, true),
1343        ]))
1344    }
1345
1346    fn csv_deserializer(
1347        batch_size: usize,
1348        schema: &Arc<Schema>,
1349    ) -> impl BatchDeserializer<Bytes> {
1350        let decoder = ReaderBuilder::new(schema.clone())
1351            .with_batch_size(batch_size)
1352            .build_decoder();
1353        DecoderDeserializer::new(CsvDecoder::new(decoder))
1354    }
1355
1356    fn csv_deserializer_with_truncated(
1357        batch_size: usize,
1358        schema: &Arc<Schema>,
1359    ) -> impl BatchDeserializer<Bytes> {
1360        // using Arrow's ReaderBuilder and enabling truncated_rows
1361        let decoder = ReaderBuilder::new(schema.clone())
1362            .with_batch_size(batch_size)
1363            .with_truncated_rows(true) // <- enable runtime truncated_rows
1364            .build_decoder();
1365        DecoderDeserializer::new(CsvDecoder::new(decoder))
1366    }
1367
1368    #[tokio::test]
1369    async fn infer_schema_with_truncated_rows_true() -> Result<()> {
1370        let session_ctx = SessionContext::new();
1371        let state = session_ctx.state();
1372
1373        // CSV: header has 3 columns, but first data row has only 2 columns, second row has 3
1374        let csv_data = Bytes::from("a,b,c\n1,2\n3,4,5\n");
1375        let variable_object_store = Arc::new(VariableStream::new(csv_data, 1));
1376        let object_meta = ObjectMeta {
1377            location: Path::parse("/")?,
1378            last_modified: DateTime::default(),
1379            size: u64::MAX,
1380            e_tag: None,
1381            version: None,
1382        };
1383
1384        // Construct CsvFormat and enable truncated_rows via CsvOptions
1385        let csv_options = CsvOptions::default().with_truncated_rows(true);
1386        let csv_format = CsvFormat::default()
1387            .with_has_header(true)
1388            .with_options(csv_options)
1389            .with_schema_infer_max_rec(10);
1390
1391        let inferred_schema = csv_format
1392            .infer_schema(
1393                &state,
1394                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
1395                &[object_meta],
1396            )
1397            .await?;
1398
1399        // header has 3 columns; inferred schema should also have 3
1400        assert_eq!(inferred_schema.fields().len(), 3);
1401
1402        // inferred columns should be nullable
1403        for f in inferred_schema.fields() {
1404            assert!(f.is_nullable());
1405        }
1406
1407        Ok(())
1408    }
1409    #[test]
1410    fn test_decoder_truncated_rows_runtime() -> Result<()> {
1411        // Synchronous test: Decoder API used here is synchronous
1412        let schema = csv_schema(); // helper already defined in file
1413
1414        // Construct a decoder that enables truncated_rows at runtime
1415        let mut deserializer = csv_deserializer_with_truncated(10, &schema);
1416
1417        // Provide two rows: first row complete, second row missing last column
1418        let input = Bytes::from("0,0.0,true,0-string\n1,1.0,true\n");
1419        deserializer.digest(input);
1420
1421        // Finish and collect output
1422        deserializer.finish();
1423
1424        let output = deserializer.next()?;
1425        match output {
1426            DeserializerOutput::RecordBatch(batch) => {
1427                // ensure at least two rows present
1428                assert!(batch.num_rows() >= 2);
1429                // column 4 (index 3) should be a StringArray where second row is NULL
1430                let col4 = batch
1431                    .column(3)
1432                    .as_any()
1433                    .downcast_ref::<StringArray>()
1434                    .expect("column 4 should be StringArray");
1435
1436                // first row present, second row should be null
1437                assert!(!col4.is_null(0));
1438                assert!(col4.is_null(1));
1439            }
1440            other => panic!("expected RecordBatch but got {other:?}"),
1441        }
1442        Ok(())
1443    }
1444
1445    #[tokio::test]
1446    async fn infer_schema_truncated_rows_false_error() -> Result<()> {
1447        let session_ctx = SessionContext::new();
1448        let state = session_ctx.state();
1449
1450        // CSV: header has 4 cols, first data row has 3 cols -> truncated at end
1451        let csv_data = Bytes::from("id,a,b,c\n1,foo,bar\n2,foo,bar,baz\n");
1452        let variable_object_store = Arc::new(VariableStream::new(csv_data, 1));
1453        let object_meta = ObjectMeta {
1454            location: Path::parse("/")?,
1455            last_modified: DateTime::default(),
1456            size: u64::MAX,
1457            e_tag: None,
1458            version: None,
1459        };
1460
1461        // CsvFormat without enabling truncated_rows (default behavior = false)
1462        let csv_format = CsvFormat::default()
1463            .with_has_header(true)
1464            .with_schema_infer_max_rec(10);
1465
1466        let res = csv_format
1467            .infer_schema(
1468                &state,
1469                &(variable_object_store.clone() as Arc<dyn ObjectStore>),
1470                &[object_meta],
1471            )
1472            .await;
1473
1474        // Expect an error due to unequal lengths / incorrect number of fields
1475        assert!(
1476            res.is_err(),
1477            "expected infer_schema to error on truncated rows when disabled"
1478        );
1479
1480        // Optional: check message contains indicative text (two known possibilities)
1481        if let Err(err) = res {
1482            let msg = format!("{err}");
1483            assert!(
1484                msg.contains("Encountered unequal lengths")
1485                    || msg.contains("incorrect number of fields"),
1486                "unexpected error message: {msg}",
1487            );
1488        }
1489
1490        Ok(())
1491    }
1492
1493    #[tokio::test]
1494    async fn test_read_csv_truncated_rows_via_tempfile() -> Result<()> {
1495        use std::io::Write;
1496
1497        // create a SessionContext
1498        let ctx = SessionContext::new();
1499
1500        // Create a temp file with a .csv suffix so the reader accepts it
1501        let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?; // ensures path ends with .csv
1502                                                                           // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete.
1503        write!(tmp, "a,b,c\n1,2\n3,4,5\n")?;
1504        let path = tmp.path().to_str().unwrap().to_string();
1505
1506        // Build CsvReadOptions: header present, enable truncated_rows.
1507        // (Use the exact builder method your crate exposes: `truncated_rows(true)` here,
1508        //  if the method name differs in your codebase use the appropriate one.)
1509        let options = CsvReadOptions::default().truncated_rows(true);
1510
1511        println!("options: {}, path: {path}", options.truncated_rows);
1512
1513        // Call the API under test
1514        let df = ctx.read_csv(&path, options).await?;
1515
1516        // Collect the results and combine batches so we can inspect columns
1517        let batches = df.collect().await?;
1518        let combined = concat_batches(&batches[0].schema(), &batches)?;
1519
1520        // Column 'c' is the 3rd column (index 2). The first data row was truncated -> should be NULL.
1521        let col_c = combined.column(2);
1522        assert!(
1523            col_c.is_null(0),
1524            "expected first row column 'c' to be NULL due to truncated row"
1525        );
1526
1527        // Also ensure we read at least one row
1528        assert!(combined.num_rows() >= 2);
1529
1530        Ok(())
1531    }
1532}