datafusion_datasource_avro/avro_to_arrow/reader.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use super::arrow_array_reader::AvroArrowArrayReader;
use arrow::datatypes::{Fields, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use std::io::{Read, Seek};
use std::sync::Arc;

/// Avro file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Optional schema for the Avro file
    ///
    /// If a schema is not supplied, the reader reads it from the Avro file itself.
    schema: Option<SchemaRef>,
    /// Batch size (number of records to load each time)
    ///
    /// The default batch size when using the `ReaderBuilder` is 1024 records
    batch_size: usize,
    /// Optional projection of which columns to load (by column name)
    projection: Option<Vec<String>>,
}

impl Default for ReaderBuilder {
    fn default() -> Self {
        Self {
            schema: None,
            batch_size: 1024,
            projection: None,
        }
    }
}

impl ReaderBuilder {
    /// Create a new builder for configuring Avro parsing options.
    ///
    /// To convert a builder into a reader, call `ReaderBuilder::build`
    ///
    /// # Example
    ///
    /// ```
    /// use std::fs::File;
    ///
    /// use datafusion_datasource_avro::avro_to_arrow::{Reader, ReaderBuilder};
    ///
    /// fn example() -> Reader<'static, File> {
    ///     let file = File::open("test/data/basic.avro").unwrap();
    ///
    ///     // create a builder that reads the schema from the file
    ///     // and loads 100 records per batch
    ///     let builder = ReaderBuilder::new().read_schema().with_batch_size(100);
    ///
    ///     let reader = builder.build::<File>(file).unwrap();
    ///
    ///     reader
    /// }
    /// ```
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the Avro file's schema
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Set the Avro reader to read the schema from the file
    pub fn read_schema(mut self) -> Self {
        // remove any schema that is set
        self.schema = None;
        self
    }

    /// Set the batch size (number of records to load at one time)
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the reader's column projection (a list of column names)
    pub fn with_projection(mut self, projection: Vec<String>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Create a new `Reader` from the `ReaderBuilder`
    pub fn build<'a, R>(self, source: R) -> Result<Reader<'a, R>>
    where
        R: Read + Seek,
    {
        let mut source = source;

        // check if the schema should be read from the file
        let schema = match self.schema {
            Some(schema) => schema,
            None => Arc::new(super::read_avro_schema_from_reader(&mut source)?),
        };
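        // Reading the schema above consumed bytes from `source`; rewind so the
        // record reader starts from the beginning of the file.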
        source.rewind()?;
        Reader::try_new(source, schema, self.batch_size, self.projection)
    }
}

/// Avro file record reader
pub struct Reader<'a, R: Read> {
    array_reader: AvroArrowArrayReader<'a, R>,
    schema: SchemaRef,
    batch_size: usize,
}

impl<R: Read> Reader<'_, R> {
    /// Create a new Avro Reader from any value that implements the `Read` trait.
    ///
    /// If reading a `File` and you want to customize the reader, for example to
    /// read the schema from the file, use `ReaderBuilder` instead.
    ///
    /// If a projection is provided, the reader's schema contains only the projected
    /// fields, in the order given. Only the first level of projection is handled;
    /// deeper projection does not currently occur, but would be useful for plucking
    /// values out of a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`.
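    ///
    /// # Example
    ///
    /// A minimal sketch of a projected reader. The schema and column names here
    /// are illustrative only; the function is compiled but never called:
    ///
    /// ```
    /// use std::fs::File;
    /// use std::sync::Arc;
    /// use arrow::datatypes::{DataType, Field, Schema};
    /// use datafusion_datasource_avro::avro_to_arrow::Reader;
    ///
    /// fn example() -> Reader<'static, File> {
    ///     let file = File::open("test/data/basic.avro").unwrap();
    ///     // a hypothetical two-column schema
    ///     let schema = Arc::new(Schema::new(vec![
    ///         Field::new("a", DataType::Int64, true),
    ///         Field::new("b", DataType::Utf8, true),
    ///     ]));
    ///     // keep only column "b": the reader's schema will have a single field
    ///     Reader::try_new(file, schema, 1024, Some(vec!["b".to_string()])).unwrap()
    /// }
    /// ```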
    pub fn try_new(
        reader: R,
        schema: SchemaRef,
        batch_size: usize,
        projection: Option<Vec<String>>,
    ) -> Result<Self> {
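        // Build the projected schema: keep only the requested columns, in the
        // order they were requested. Names that do not match any column in the
        // schema are silently skipped.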
        let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else(
            || Arc::clone(&schema),
            |proj| {
                Arc::new(arrow::datatypes::Schema::new(
                    proj.iter()
                        .filter_map(|name| {
                            schema.column_with_name(name).map(|(_, f)| f.clone())
                        })
                        .collect::<Fields>(),
                ))
            },
        );

        Ok(Self {
            array_reader: AvroArrowArrayReader::try_new(
                reader,
                Arc::clone(&projected_schema),
            )?,
            schema: projected_schema,
            batch_size,
        })
    }

    /// Returns the schema of the reader, useful for getting the schema without reading
    /// record batches
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl<R: Read> Iterator for Reader<'_, R> {
    type Item = ArrowResult<RecordBatch>;

    /// Returns the next batch of results (at most `self.batch_size` records), or
    /// `None` if there are no more results.
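    ///
    /// # Example
    ///
    /// A sketch of draining a reader batch by batch (the function is compiled
    /// but never called):
    ///
    /// ```
    /// use std::fs::File;
    /// use datafusion_datasource_avro::avro_to_arrow::Reader;
    ///
    /// fn example(reader: Reader<'_, File>) -> usize {
    ///     // count the total rows; each item is an `ArrowResult<RecordBatch>`
    ///     reader.map(|batch| batch.unwrap().num_rows()).sum()
    /// }
    /// ```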
    fn next(&mut self) -> Option<Self::Item> {
        self.array_reader.next_batch(self.batch_size)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use arrow::datatypes::{DataType, Field, TimeUnit};
    use std::fs::File;

    fn build_reader(name: &'_ str, projection: Option<Vec<String>>) -> Reader<'_, File> {
        let testdata = datafusion_common::test_util::arrow_test_data();
        let filename = format!("{testdata}/avro/{name}");
        let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64);
        if let Some(projection) = projection {
            builder = builder.with_projection(projection);
        }
        builder.build(File::open(filename).unwrap()).unwrap()
    }

    fn get_col<'a, T: 'static>(
        batch: &'a RecordBatch,
        col: (usize, &Field),
    ) -> Option<&'a T> {
        batch.column(col.0).as_any().downcast_ref::<T>()
    }

    #[test]
    fn test_avro_basic() {
        let mut reader = build_reader("alltypes_dictionary.avro", None);
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(11, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        let id = schema.column_with_name("id").unwrap();
        assert_eq!(0, id.0);
        assert_eq!(&DataType::Int32, id.1.data_type());
        let col = get_col::<Int32Array>(&batch, id).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let bool_col = schema.column_with_name("bool_col").unwrap();
        assert_eq!(1, bool_col.0);
        assert_eq!(&DataType::Boolean, bool_col.1.data_type());
        let col = get_col::<BooleanArray>(&batch, bool_col).unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
        let tinyint_col = schema.column_with_name("tinyint_col").unwrap();
        assert_eq!(2, tinyint_col.0);
        assert_eq!(&DataType::Int32, tinyint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, tinyint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let smallint_col = schema.column_with_name("smallint_col").unwrap();
        assert_eq!(3, smallint_col.0);
        assert_eq!(&DataType::Int32, smallint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, smallint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let int_col = schema.column_with_name("int_col").unwrap();
        assert_eq!(4, int_col.0);
        let col = get_col::<Int32Array>(&batch, int_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        assert_eq!(&DataType::Int32, int_col.1.data_type());
        let bigint_col = schema.column_with_name("bigint_col").unwrap();
        assert_eq!(5, bigint_col.0);
        let col = get_col::<Int64Array>(&batch, bigint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(10, col.value(1));
        assert_eq!(&DataType::Int64, bigint_col.1.data_type());
        let float_col = schema.column_with_name("float_col").unwrap();
        assert_eq!(6, float_col.0);
        let col = get_col::<Float32Array>(&batch, float_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(1.1, col.value(1));
        assert_eq!(&DataType::Float32, float_col.1.data_type());
        let double_col = schema.column_with_name("double_col").unwrap();
        assert_eq!(7, double_col.0);
        assert_eq!(&DataType::Float64, double_col.1.data_type());
        let col = get_col::<Float64Array>(&batch, double_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));
        let date_string_col = schema.column_with_name("date_string_col").unwrap();
        assert_eq!(8, date_string_col.0);
        assert_eq!(&DataType::Binary, date_string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, date_string_col).unwrap();
        assert_eq!("01/01/09".as_bytes(), col.value(0));
        assert_eq!("01/01/09".as_bytes(), col.value(1));
        let string_col = schema.column_with_name("string_col").unwrap();
        assert_eq!(9, string_col.0);
        assert_eq!(&DataType::Binary, string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, string_col).unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));
        let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
        assert_eq!(10, timestamp_col.0);
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Microsecond, None),
            timestamp_col.1.data_type()
        );
        let col = get_col::<TimestampMicrosecondArray>(&batch, timestamp_col).unwrap();
        assert_eq!(1230768000000000, col.value(0));
        assert_eq!(1230768060000000, col.value(1));
    }

    #[test]
    fn test_avro_with_projection() {
        // Test projection to filter and reorder columns
        let projection = Some(vec![
            "string_col".to_string(),
            "double_col".to_string(),
            "bool_col".to_string(),
        ]);
        let mut reader = build_reader("alltypes_dictionary.avro", projection);
        let batch = reader.next().unwrap().unwrap();

        // Only 3 columns should be present (not all 11)
        assert_eq!(3, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        // Verify columns are in the order specified in projection
        // First column should be string_col (was at index 9 in original)
        assert_eq!("string_col", schema.field(0).name());
        assert_eq!(&DataType::Binary, schema.field(0).data_type());
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));

        // Second column should be double_col (was at index 7 in original)
        assert_eq!("double_col", schema.field(1).name());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        let col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));

        // Third column should be bool_col (was at index 1 in original)
        assert_eq!("bool_col", schema.field(2).name());
        assert_eq!(&DataType::Boolean, schema.field(2).data_type());
        let col = batch
            .column(2)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
    }
}