datafusion_datasource_avro/avro_to_arrow/
arrow_array_reader.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Avro to Arrow array readers
19
20use apache_avro::schema::RecordSchema;
21use apache_avro::{
22    error::Details as AvroErrorDetails,
23    schema::{Schema as AvroSchema, SchemaKind},
24    types::Value,
25    Error as AvroError, Reader as AvroReader,
26};
27use arrow::array::{
28    make_array, Array, ArrayBuilder, ArrayData, ArrayDataBuilder, ArrayRef,
29    BooleanBuilder, LargeStringArray, ListBuilder, NullArray, OffsetSizeTrait,
30    PrimitiveArray, StringArray, StringBuilder, StringDictionaryBuilder,
31};
32use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray};
33use arrow::buffer::{Buffer, MutableBuffer};
34use arrow::datatypes::{
35    ArrowDictionaryKeyType, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type,
36    Date64Type, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
37    Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
38    Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
39    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
40    UInt8Type,
41};
42use arrow::datatypes::{Fields, SchemaRef};
43use arrow::error::ArrowError;
44use arrow::error::ArrowError::SchemaError;
45use arrow::error::Result as ArrowResult;
46use arrow::record_batch::RecordBatch;
47use arrow::util::bit_util;
48use datafusion_common::arrow_err;
49use datafusion_common::error::{DataFusionError, Result};
50use num_traits::NumCast;
51use std::collections::BTreeMap;
52use std::io::Read;
53use std::sync::Arc;
54
/// A borrowed batch of Avro records, where each record is the list of
/// `(field name, value)` pairs carried by a `Value::Record`.
type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>];
56
/// Reads Avro records from an underlying reader and converts them into Arrow
/// `RecordBatch`es.
pub struct AvroArrowArrayReader<'a, R: Read> {
    /// Avro reader that yields the decoded records.
    reader: AvroReader<'a, R>,
    /// Arrow schema used to build the output record batches.
    schema: SchemaRef,
    /// Maps field names (nested fields use dotted paths such as
    /// `parent.child`) to field positions taken from the Avro writer schema.
    schema_lookup: BTreeMap<String, usize>,
}
62
63impl<R: Read> AvroArrowArrayReader<'_, R> {
64    pub fn try_new(reader: R, schema: SchemaRef) -> Result<Self> {
65        let reader = AvroReader::new(reader)?;
66        let writer_schema = reader.writer_schema().clone();
67        let schema_lookup = Self::schema_lookup(writer_schema)?;
68        Ok(Self {
69            reader,
70            schema,
71            schema_lookup,
72        })
73    }
74
75    pub fn schema_lookup(schema: AvroSchema) -> Result<BTreeMap<String, usize>> {
76        match schema {
77            AvroSchema::Record(RecordSchema {
78                fields, mut lookup, ..
79            }) => {
80                for field in fields {
81                    Self::child_schema_lookup(&field.name, &field.schema, &mut lookup)?;
82                }
83                Ok(lookup)
84            }
85            _ => arrow_err!(SchemaError(
86                "expected avro schema to be a record".to_string(),
87            )),
88        }
89    }
90
91    fn child_schema_lookup<'b>(
92        parent_field_name: &str,
93        schema: &AvroSchema,
94        schema_lookup: &'b mut BTreeMap<String, usize>,
95    ) -> Result<&'b BTreeMap<String, usize>> {
96        match schema {
97            AvroSchema::Union(us) => {
98                let has_nullable = us
99                    .find_schema_with_known_schemata::<apache_avro::Schema>(
100                        &Value::Null,
101                        None,
102                        &None,
103                    )
104                    .is_some();
105                let sub_schemas = us.variants();
106                if has_nullable && sub_schemas.len() == 2 {
107                    if let Some(sub_schema) =
108                        sub_schemas.iter().find(|&s| !matches!(s, AvroSchema::Null))
109                    {
110                        Self::child_schema_lookup(
111                            parent_field_name,
112                            sub_schema,
113                            schema_lookup,
114                        )?;
115                    }
116                }
117            }
118            AvroSchema::Record(RecordSchema { fields, lookup, .. }) => {
119                lookup.iter().for_each(|(field_name, pos)| {
120                    schema_lookup
121                        .insert(format!("{parent_field_name}.{field_name}"), *pos);
122                });
123
124                for field in fields {
125                    let sub_parent_field_name =
126                        format!("{}.{}", parent_field_name, field.name);
127                    Self::child_schema_lookup(
128                        &sub_parent_field_name,
129                        &field.schema,
130                        schema_lookup,
131                    )?;
132                }
133            }
134            AvroSchema::Array(schema) => {
135                let sub_parent_field_name = format!("{parent_field_name}.element");
136                Self::child_schema_lookup(
137                    &sub_parent_field_name,
138                    &schema.items,
139                    schema_lookup,
140                )?;
141            }
142            _ => (),
143        }
144        Ok(schema_lookup)
145    }
146
147    /// Read the next batch of records
148    pub fn next_batch(&mut self, batch_size: usize) -> Option<ArrowResult<RecordBatch>> {
149        let rows_result = self
150            .reader
151            .by_ref()
152            .take(batch_size)
153            .map(|value| match value {
154                Ok(Value::Record(v)) => Ok(v),
155                Err(e) => Err(ArrowError::ParseError(format!(
156                    "Failed to parse avro value: {e}"
157                ))),
158                other => Err(ArrowError::ParseError(format!(
159                    "Row needs to be of type object, got: {other:?}"
160                ))),
161            })
162            .collect::<ArrowResult<Vec<Vec<(String, Value)>>>>();
163
164        let rows = match rows_result {
165            // Return error early
166            Err(e) => return Some(Err(e)),
167            // No rows: return None early
168            Ok(rows) if rows.is_empty() => return None,
169            Ok(rows) => rows,
170        };
171
172        let rows = rows.iter().collect::<Vec<&Vec<(String, Value)>>>();
173        let arrays = self.build_struct_array(&rows, "", self.schema.fields());
174
175        Some(arrays.and_then(|arr| RecordBatch::try_new(Arc::clone(&self.schema), arr)))
176    }
177
178    fn build_boolean_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef {
179        let mut builder = BooleanBuilder::with_capacity(rows.len());
180        for row in rows {
181            if let Some(value) = self.field_lookup(col_name, row) {
182                if let Some(boolean) = resolve_boolean(value) {
183                    builder.append_value(boolean)
184                } else {
185                    builder.append_null();
186                }
187            } else {
188                builder.append_null();
189            }
190        }
191        Arc::new(builder.finish())
192    }
193
194    fn build_primitive_array<T>(&self, rows: RecordSlice, col_name: &str) -> ArrayRef
195    where
196        T: ArrowNumericType + Resolver,
197        T::Native: NumCast,
198    {
199        Arc::new(
200            rows.iter()
201                .map(|row| {
202                    self.field_lookup(col_name, row)
203                        .and_then(|value| resolve_item::<T>(value))
204                })
205                .collect::<PrimitiveArray<T>>(),
206        )
207    }
208
209    #[inline(always)]
210    fn build_string_dictionary_builder<T>(
211        &self,
212        row_len: usize,
213    ) -> StringDictionaryBuilder<T>
214    where
215        T: ArrowPrimitiveType + ArrowDictionaryKeyType,
216    {
217        StringDictionaryBuilder::with_capacity(row_len, row_len, row_len)
218    }
219
220    fn build_wrapped_list_array(
221        &self,
222        rows: RecordSlice,
223        col_name: &str,
224        key_type: &DataType,
225    ) -> ArrowResult<ArrayRef> {
226        match *key_type {
227            DataType::Int8 => {
228                let dtype = DataType::Dictionary(
229                    Box::new(DataType::Int8),
230                    Box::new(DataType::Utf8),
231                );
232                self.list_array_string_array_builder::<Int8Type>(&dtype, col_name, rows)
233            }
234            DataType::Int16 => {
235                let dtype = DataType::Dictionary(
236                    Box::new(DataType::Int16),
237                    Box::new(DataType::Utf8),
238                );
239                self.list_array_string_array_builder::<Int16Type>(&dtype, col_name, rows)
240            }
241            DataType::Int32 => {
242                let dtype = DataType::Dictionary(
243                    Box::new(DataType::Int32),
244                    Box::new(DataType::Utf8),
245                );
246                self.list_array_string_array_builder::<Int32Type>(&dtype, col_name, rows)
247            }
248            DataType::Int64 => {
249                let dtype = DataType::Dictionary(
250                    Box::new(DataType::Int64),
251                    Box::new(DataType::Utf8),
252                );
253                self.list_array_string_array_builder::<Int64Type>(&dtype, col_name, rows)
254            }
255            DataType::UInt8 => {
256                let dtype = DataType::Dictionary(
257                    Box::new(DataType::UInt8),
258                    Box::new(DataType::Utf8),
259                );
260                self.list_array_string_array_builder::<UInt8Type>(&dtype, col_name, rows)
261            }
262            DataType::UInt16 => {
263                let dtype = DataType::Dictionary(
264                    Box::new(DataType::UInt16),
265                    Box::new(DataType::Utf8),
266                );
267                self.list_array_string_array_builder::<UInt16Type>(&dtype, col_name, rows)
268            }
269            DataType::UInt32 => {
270                let dtype = DataType::Dictionary(
271                    Box::new(DataType::UInt32),
272                    Box::new(DataType::Utf8),
273                );
274                self.list_array_string_array_builder::<UInt32Type>(&dtype, col_name, rows)
275            }
276            DataType::UInt64 => {
277                let dtype = DataType::Dictionary(
278                    Box::new(DataType::UInt64),
279                    Box::new(DataType::Utf8),
280                );
281                self.list_array_string_array_builder::<UInt64Type>(&dtype, col_name, rows)
282            }
283            ref e => Err(SchemaError(format!(
284                "Data type is currently not supported for dictionaries in list : {e}"
285            ))),
286        }
287    }
288
289    #[inline(always)]
290    fn list_array_string_array_builder<D>(
291        &self,
292        data_type: &DataType,
293        col_name: &str,
294        rows: RecordSlice,
295    ) -> ArrowResult<ArrayRef>
296    where
297        D: ArrowPrimitiveType + ArrowDictionaryKeyType,
298    {
299        let mut builder: Box<dyn ArrayBuilder> = match data_type {
300            DataType::Utf8 => {
301                let values_builder = StringBuilder::with_capacity(rows.len(), 5);
302                Box::new(ListBuilder::new(values_builder))
303            }
304            DataType::Dictionary(_, _) => {
305                let values_builder =
306                    self.build_string_dictionary_builder::<D>(rows.len() * 5);
307                Box::new(ListBuilder::new(values_builder))
308            }
309            e => {
310                return Err(SchemaError(format!(
311                    "Nested list data builder type is not supported: {e}"
312                )))
313            }
314        };
315
316        for row in rows {
317            if let Some(value) = self.field_lookup(col_name, row) {
318                let value = maybe_resolve_union(value);
319                // value can be an array or a scalar
320                let vals: Vec<Option<String>> = if let Value::String(v) = value {
321                    vec![Some(v.to_string())]
322                } else if let Value::Array(n) = value {
323                    n.iter()
324                        .map(resolve_string)
325                        .collect::<ArrowResult<Vec<Option<String>>>>()?
326                        .into_iter()
327                        .collect::<Vec<Option<String>>>()
328                } else if let Value::Null = value {
329                    vec![None]
330                } else if !matches!(value, Value::Record(_)) {
331                    vec![resolve_string(value)?]
332                } else {
333                    return Err(SchemaError(
334                        "Only scalars are currently supported in Avro arrays".to_string(),
335                    ));
336                };
337
338                // TODO: ARROW-10335: APIs of dictionary arrays and others are different. Unify
339                // them.
340                match data_type {
341                    DataType::Utf8 => {
342                        let builder = builder
343                            .as_any_mut()
344                            .downcast_mut::<ListBuilder<StringBuilder>>()
345                            .ok_or_else(||SchemaError(
346                                "Cast failed for ListBuilder<StringBuilder> during nested data parsing".to_string(),
347                            ))?;
348                        for val in vals {
349                            if let Some(v) = val {
350                                builder.values().append_value(&v)
351                            } else {
352                                builder.values().append_null()
353                            };
354                        }
355
356                        // Append to the list
357                        builder.append(true);
358                    }
359                    DataType::Dictionary(_, _) => {
360                        let builder = builder.as_any_mut().downcast_mut::<ListBuilder<StringDictionaryBuilder<D>>>().ok_or_else(||SchemaError(
361                            "Cast failed for ListBuilder<StringDictionaryBuilder> during nested data parsing".to_string(),
362                        ))?;
363                        for val in vals {
364                            if let Some(v) = val {
365                                let _ = builder.values().append(&v)?;
366                            } else {
367                                builder.values().append_null()
368                            };
369                        }
370
371                        // Append to the list
372                        builder.append(true);
373                    }
374                    e => {
375                        return Err(SchemaError(format!(
376                            "Nested list data builder type is not supported: {e}"
377                        )))
378                    }
379                }
380            }
381        }
382
383        Ok(builder.finish() as ArrayRef)
384    }
385
386    #[inline(always)]
387    fn build_dictionary_array<T>(
388        &self,
389        rows: RecordSlice,
390        col_name: &str,
391    ) -> ArrowResult<ArrayRef>
392    where
393        T::Native: NumCast,
394        T: ArrowPrimitiveType + ArrowDictionaryKeyType,
395    {
396        let mut builder: StringDictionaryBuilder<T> =
397            self.build_string_dictionary_builder(rows.len());
398        for row in rows {
399            if let Some(value) = self.field_lookup(col_name, row) {
400                if let Ok(Some(str_v)) = resolve_string(value) {
401                    builder.append(str_v).map(drop)?
402                } else {
403                    builder.append_null()
404                }
405            } else {
406                builder.append_null()
407            }
408        }
409        Ok(Arc::new(builder.finish()) as ArrayRef)
410    }
411
412    #[inline(always)]
413    fn build_string_dictionary_array(
414        &self,
415        rows: RecordSlice,
416        col_name: &str,
417        key_type: &DataType,
418        value_type: &DataType,
419    ) -> ArrowResult<ArrayRef> {
420        if let DataType::Utf8 = *value_type {
421            match *key_type {
422                DataType::Int8 => self.build_dictionary_array::<Int8Type>(rows, col_name),
423                DataType::Int16 => {
424                    self.build_dictionary_array::<Int16Type>(rows, col_name)
425                }
426                DataType::Int32 => {
427                    self.build_dictionary_array::<Int32Type>(rows, col_name)
428                }
429                DataType::Int64 => {
430                    self.build_dictionary_array::<Int64Type>(rows, col_name)
431                }
432                DataType::UInt8 => {
433                    self.build_dictionary_array::<UInt8Type>(rows, col_name)
434                }
435                DataType::UInt16 => {
436                    self.build_dictionary_array::<UInt16Type>(rows, col_name)
437                }
438                DataType::UInt32 => {
439                    self.build_dictionary_array::<UInt32Type>(rows, col_name)
440                }
441                DataType::UInt64 => {
442                    self.build_dictionary_array::<UInt64Type>(rows, col_name)
443                }
444                _ => Err(SchemaError("unsupported dictionary key type".to_string())),
445            }
446        } else {
447            Err(SchemaError(
448                "dictionary types other than UTF-8 not yet supported".to_string(),
449            ))
450        }
451    }
452
453    /// Build a nested GenericListArray from a list of unnested `Value`s
454    fn build_nested_list_array<OffsetSize: OffsetSizeTrait>(
455        &self,
456        parent_field_name: &str,
457        rows: &[&Value],
458        list_field: &Field,
459    ) -> ArrowResult<ArrayRef> {
460        // build list offsets
461        let mut cur_offset = OffsetSize::zero();
462        let list_len = rows.len();
463        let num_list_bytes = bit_util::ceil(list_len, 8);
464        let mut offsets = Vec::with_capacity(list_len + 1);
465        let mut list_nulls = MutableBuffer::from_len_zeroed(num_list_bytes);
466        offsets.push(cur_offset);
467        rows.iter().enumerate().for_each(|(i, v)| {
468            // TODO: unboxing Union(Array(Union(...))) should probably be done earlier
469            let v = maybe_resolve_union(v);
470            if let Value::Array(a) = v {
471                cur_offset += OffsetSize::from_usize(a.len()).unwrap();
472                bit_util::set_bit(&mut list_nulls, i);
473            } else if let Value::Null = v {
474                // value is null, not incremented
475            } else {
476                cur_offset += OffsetSize::one();
477            }
478            offsets.push(cur_offset);
479        });
480        let valid_len = cur_offset.to_usize().unwrap();
481        let array_data = match list_field.data_type() {
482            DataType::Null => NullArray::new(valid_len).into_data(),
483            DataType::Boolean => {
484                let num_bytes = bit_util::ceil(valid_len, 8);
485                let mut bool_values = MutableBuffer::from_len_zeroed(num_bytes);
486                let mut bool_nulls =
487                    MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
488                let mut curr_index = 0;
489                rows.iter().for_each(|v| {
490                    if let Value::Array(vs) = v {
491                        vs.iter().for_each(|value| {
492                            if let Value::Boolean(child) = value {
493                                // if valid boolean, append value
494                                if *child {
495                                    bit_util::set_bit(&mut bool_values, curr_index);
496                                }
497                            } else {
498                                // null slot
499                                bit_util::unset_bit(&mut bool_nulls, curr_index);
500                            }
501                            curr_index += 1;
502                        });
503                    }
504                });
505                ArrayData::builder(list_field.data_type().clone())
506                    .len(valid_len)
507                    .add_buffer(bool_values.into())
508                    .null_bit_buffer(Some(bool_nulls.into()))
509                    .build()
510                    .unwrap()
511            }
512            DataType::Int8 => self.read_primitive_list_values::<Int8Type>(rows),
513            DataType::Int16 => self.read_primitive_list_values::<Int16Type>(rows),
514            DataType::Int32 => self.read_primitive_list_values::<Int32Type>(rows),
515            DataType::Int64 => self.read_primitive_list_values::<Int64Type>(rows),
516            DataType::UInt8 => self.read_primitive_list_values::<UInt8Type>(rows),
517            DataType::UInt16 => self.read_primitive_list_values::<UInt16Type>(rows),
518            DataType::UInt32 => self.read_primitive_list_values::<UInt32Type>(rows),
519            DataType::UInt64 => self.read_primitive_list_values::<UInt64Type>(rows),
520            DataType::Float16 => {
521                return Err(SchemaError("Float16 not supported".to_string()))
522            }
523            DataType::Float32 => self.read_primitive_list_values::<Float32Type>(rows),
524            DataType::Float64 => self.read_primitive_list_values::<Float64Type>(rows),
525            DataType::Timestamp(_, _)
526            | DataType::Date32
527            | DataType::Date64
528            | DataType::Time32(_)
529            | DataType::Time64(_) => {
530                return Err(SchemaError(
531                    "Temporal types are not yet supported, see ARROW-4803".to_string(),
532                ))
533            }
534            DataType::Utf8 => flatten_string_values(rows)
535                .into_iter()
536                .collect::<StringArray>()
537                .into_data(),
538            DataType::LargeUtf8 => flatten_string_values(rows)
539                .into_iter()
540                .collect::<LargeStringArray>()
541                .into_data(),
542            DataType::List(field) => {
543                let child = self.build_nested_list_array::<i32>(
544                    parent_field_name,
545                    &flatten_values(rows),
546                    field,
547                )?;
548                child.to_data()
549            }
550            DataType::LargeList(field) => {
551                let child = self.build_nested_list_array::<i64>(
552                    parent_field_name,
553                    &flatten_values(rows),
554                    field,
555                )?;
556                child.to_data()
557            }
558            DataType::Struct(fields) => {
559                // extract list values, with non-lists converted to Value::Null
560                let array_item_count = rows
561                    .iter()
562                    .map(|row| match maybe_resolve_union(row) {
563                        Value::Array(values) => values.len(),
564                        _ => 1,
565                    })
566                    .sum();
567                let num_bytes = bit_util::ceil(array_item_count, 8);
568                let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes);
569                let mut struct_index = 0;
570                let null_struct_array = vec![("null".to_string(), Value::Null)];
571                let rows: Vec<&Vec<(String, Value)>> = rows
572                    .iter()
573                    .map(|v| maybe_resolve_union(v))
574                    .flat_map(|row| {
575                        if let Value::Array(values) = row {
576                            values
577                                .iter()
578                                .map(maybe_resolve_union)
579                                .map(|v| match v {
580                                    Value::Record(record) => {
581                                        bit_util::set_bit(&mut null_buffer, struct_index);
582                                        struct_index += 1;
583                                        record
584                                    }
585                                    Value::Null => {
586                                        struct_index += 1;
587                                        &null_struct_array
588                                    }
589                                    other => panic!("expected Record, got {other:?}"),
590                                })
591                                .collect::<Vec<&Vec<(String, Value)>>>()
592                        } else {
593                            struct_index += 1;
594                            vec![&null_struct_array]
595                        }
596                    })
597                    .collect();
598
599                let sub_parent_field_name =
600                    format!("{}.{}", parent_field_name, list_field.name());
601                let arrays =
602                    self.build_struct_array(&rows, &sub_parent_field_name, fields)?;
603                let data_type = DataType::Struct(fields.clone());
604                ArrayDataBuilder::new(data_type)
605                    .len(rows.len())
606                    .null_bit_buffer(Some(null_buffer.into()))
607                    .child_data(arrays.into_iter().map(|a| a.to_data()).collect())
608                    .build()
609                    .unwrap()
610            }
611            datatype => {
612                return Err(SchemaError(format!(
613                    "Nested list of {datatype} not supported"
614                )));
615            }
616        };
617        // build list
618        let list_data = ArrayData::builder(DataType::List(Arc::new(list_field.clone())))
619            .len(list_len)
620            .add_buffer(Buffer::from_slice_ref(&offsets))
621            .add_child_data(array_data)
622            .null_bit_buffer(Some(list_nulls.into()))
623            .build()
624            .unwrap();
625        Ok(Arc::new(GenericListArray::<OffsetSize>::from(list_data)))
626    }
627
628    /// Builds the child values of a `StructArray`, falling short of constructing the StructArray.
629    /// The function does not construct the StructArray as some callers would want the child arrays.
630    ///
631    /// *Note*: The function is recursive, and will read nested structs.
632    fn build_struct_array(
633        &self,
634        rows: RecordSlice,
635        parent_field_name: &str,
636        struct_fields: &Fields,
637    ) -> ArrowResult<Vec<ArrayRef>> {
638        let arrays: ArrowResult<Vec<ArrayRef>> = struct_fields
639            .iter()
640            .map(|field| {
641                let field_path = if parent_field_name.is_empty() {
642                    field.name().to_string()
643                } else {
644                    format!("{}.{}", parent_field_name, field.name())
645                };
646                let arr = match field.data_type() {
647                    DataType::Null => Arc::new(NullArray::new(rows.len())) as ArrayRef,
648                    DataType::Boolean => self.build_boolean_array(rows, &field_path),
649                    DataType::Float64 => {
650                        self.build_primitive_array::<Float64Type>(rows, &field_path)
651                    }
652                    DataType::Float32 => {
653                        self.build_primitive_array::<Float32Type>(rows, &field_path)
654                    }
655                    DataType::Int64 => {
656                        self.build_primitive_array::<Int64Type>(rows, &field_path)
657                    }
658                    DataType::Int32 => {
659                        self.build_primitive_array::<Int32Type>(rows, &field_path)
660                    }
661                    DataType::Int16 => {
662                        self.build_primitive_array::<Int16Type>(rows, &field_path)
663                    }
664                    DataType::Int8 => {
665                        self.build_primitive_array::<Int8Type>(rows, &field_path)
666                    }
667                    DataType::UInt64 => {
668                        self.build_primitive_array::<UInt64Type>(rows, &field_path)
669                    }
670                    DataType::UInt32 => {
671                        self.build_primitive_array::<UInt32Type>(rows, &field_path)
672                    }
673                    DataType::UInt16 => {
674                        self.build_primitive_array::<UInt16Type>(rows, &field_path)
675                    }
676                    DataType::UInt8 => {
677                        self.build_primitive_array::<UInt8Type>(rows, &field_path)
678                    }
679                    // TODO: this is incomplete
680                    DataType::Timestamp(unit, _) => match unit {
681                        TimeUnit::Second => self
682                            .build_primitive_array::<TimestampSecondType>(
683                                rows,
684                                &field_path,
685                            ),
686                        TimeUnit::Microsecond => self
687                            .build_primitive_array::<TimestampMicrosecondType>(
688                                rows,
689                                &field_path,
690                            ),
691                        TimeUnit::Millisecond => self
692                            .build_primitive_array::<TimestampMillisecondType>(
693                                rows,
694                                &field_path,
695                            ),
696                        TimeUnit::Nanosecond => self
697                            .build_primitive_array::<TimestampNanosecondType>(
698                                rows,
699                                &field_path,
700                            ),
701                    },
702                    DataType::Date64 => {
703                        self.build_primitive_array::<Date64Type>(rows, &field_path)
704                    }
705                    DataType::Date32 => {
706                        self.build_primitive_array::<Date32Type>(rows, &field_path)
707                    }
708                    DataType::Time64(unit) => match unit {
709                        TimeUnit::Microsecond => self
710                            .build_primitive_array::<Time64MicrosecondType>(
711                                rows,
712                                &field_path,
713                            ),
714                        TimeUnit::Nanosecond => self
715                            .build_primitive_array::<Time64NanosecondType>(
716                                rows,
717                                &field_path,
718                            ),
719                        t => {
720                            return Err(SchemaError(format!(
721                                "TimeUnit {t:?} not supported with Time64"
722                            )))
723                        }
724                    },
725                    DataType::Time32(unit) => match unit {
726                        TimeUnit::Second => self
727                            .build_primitive_array::<Time32SecondType>(rows, &field_path),
728                        TimeUnit::Millisecond => self
729                            .build_primitive_array::<Time32MillisecondType>(
730                                rows,
731                                &field_path,
732                            ),
733                        t => {
734                            return Err(SchemaError(format!(
735                                "TimeUnit {t:?} not supported with Time32"
736                            )))
737                        }
738                    },
739                    DataType::Utf8 | DataType::LargeUtf8 => Arc::new(
740                        rows.iter()
741                            .map(|row| {
742                                let maybe_value = self.field_lookup(&field_path, row);
743                                match maybe_value {
744                                    None => Ok(None),
745                                    Some(v) => resolve_string(v),
746                                }
747                            })
748                            .collect::<ArrowResult<StringArray>>()?,
749                    )
750                        as ArrayRef,
751                    DataType::Binary | DataType::LargeBinary => Arc::new(
752                        rows.iter()
753                            .map(|row| {
754                                let maybe_value = self.field_lookup(&field_path, row);
755                                maybe_value.and_then(resolve_bytes)
756                            })
757                            .collect::<BinaryArray>(),
758                    )
759                        as ArrayRef,
760                    DataType::FixedSizeBinary(ref size) => {
761                        Arc::new(FixedSizeBinaryArray::try_from_sparse_iter_with_size(
762                            rows.iter().map(|row| {
763                                let maybe_value = self.field_lookup(&field_path, row);
764                                maybe_value.and_then(|v| resolve_fixed(v, *size as usize))
765                            }),
766                            *size,
767                        )?) as ArrayRef
768                    }
769                    DataType::List(ref list_field) => {
770                        match list_field.data_type() {
771                            DataType::Dictionary(ref key_ty, _) => {
772                                self.build_wrapped_list_array(rows, &field_path, key_ty)?
773                            }
774                            _ => {
775                                // extract rows by name
776                                let extracted_rows = rows
777                                    .iter()
778                                    .map(|row| {
779                                        self.field_lookup(&field_path, row)
780                                            .unwrap_or(&Value::Null)
781                                    })
782                                    .collect::<Vec<&Value>>();
783                                self.build_nested_list_array::<i32>(
784                                    &field_path,
785                                    &extracted_rows,
786                                    list_field,
787                                )?
788                            }
789                        }
790                    }
791                    DataType::Dictionary(ref key_ty, ref val_ty) => self
792                        .build_string_dictionary_array(
793                            rows,
794                            &field_path,
795                            key_ty,
796                            val_ty,
797                        )?,
798                    DataType::Struct(fields) => {
799                        let len = rows.len();
800                        let num_bytes = bit_util::ceil(len, 8);
801                        let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes);
802                        let empty_vec = vec![];
803                        let struct_rows = rows
804                            .iter()
805                            .enumerate()
806                            .map(|(i, row)| (i, self.field_lookup(&field_path, row)))
807                            .map(|(i, v)| {
808                                let v = v.map(maybe_resolve_union);
809                                match v {
810                                    Some(Value::Record(value)) => {
811                                        bit_util::set_bit(&mut null_buffer, i);
812                                        value
813                                    }
814                                    None | Some(Value::Null) => &empty_vec,
815                                    other => {
816                                        panic!("expected struct got {other:?}");
817                                    }
818                                }
819                            })
820                            .collect::<Vec<&Vec<(String, Value)>>>();
821                        let arrays =
822                            self.build_struct_array(&struct_rows, &field_path, fields)?;
823                        // construct a struct array's data in order to set null buffer
824                        let data_type = DataType::Struct(fields.clone());
825                        let data = ArrayDataBuilder::new(data_type)
826                            .len(len)
827                            .null_bit_buffer(Some(null_buffer.into()))
828                            .child_data(arrays.into_iter().map(|a| a.to_data()).collect())
829                            .build()?;
830                        make_array(data)
831                    }
832                    _ => {
833                        return Err(SchemaError(format!(
834                            "type {} not supported",
835                            field.data_type()
836                        )))
837                    }
838                };
839                Ok(arr)
840            })
841            .collect();
842        arrays
843    }
844
845    /// Read the primitive list's values into ArrayData
846    fn read_primitive_list_values<T>(&self, rows: &[&Value]) -> ArrayData
847    where
848        T: ArrowPrimitiveType + ArrowNumericType,
849        T::Native: NumCast,
850    {
851        let values = rows
852            .iter()
853            .flat_map(|row| {
854                let row = maybe_resolve_union(row);
855                if let Value::Array(values) = row {
856                    values
857                        .iter()
858                        .map(resolve_item::<T>)
859                        .collect::<Vec<Option<T::Native>>>()
860                } else if let Some(f) = resolve_item::<T>(row) {
861                    vec![Some(f)]
862                } else {
863                    vec![]
864                }
865            })
866            .collect::<Vec<Option<T::Native>>>();
867        let array = values.iter().collect::<PrimitiveArray<T>>();
868        array.to_data()
869    }
870
871    fn field_lookup<'b>(
872        &self,
873        name: &str,
874        row: &'b [(String, Value)],
875    ) -> Option<&'b Value> {
876        self.schema_lookup
877            .get(name)
878            .and_then(|i| row.get(*i))
879            .map(|o| &o.1)
880    }
881}
882
883/// Flattens a list of Avro values, by flattening lists, and treating all other values as
884/// single-value lists.
885/// This is used to read into nested lists (list of list, list of struct) and non-dictionary lists.
886#[inline]
887fn flatten_values<'a>(values: &[&'a Value]) -> Vec<&'a Value> {
888    values
889        .iter()
890        .flat_map(|row| {
891            let v = maybe_resolve_union(row);
892            if let Value::Array(values) = v {
893                values.iter().collect()
894            } else {
895                // we interpret a scalar as a single-value list to minimise data loss
896                vec![v]
897            }
898        })
899        .collect()
900}
901
902/// Flattens a list into string values, dropping Value::Null in the process.
903/// This is useful for interpreting any Avro array as string, dropping nulls.
904/// See `value_as_string`.
905#[inline]
906fn flatten_string_values(values: &[&Value]) -> Vec<Option<String>> {
907    values
908        .iter()
909        .flat_map(|row| {
910            let row = maybe_resolve_union(row);
911            if let Value::Array(values) = row {
912                values
913                    .iter()
914                    .map(|s| resolve_string(s).ok().flatten())
915                    .collect::<Vec<Option<_>>>()
916            } else if let Value::Null = row {
917                vec![]
918            } else {
919                vec![resolve_string(row).ok().flatten()]
920            }
921        })
922        .collect::<Vec<Option<_>>>()
923}
924
925/// Reads an Avro value as a string, regardless of its type.
926/// This is useful if the expected datatype is a string, in which case we preserve
927/// all the values regardless of they type.
928fn resolve_string(v: &Value) -> ArrowResult<Option<String>> {
929    let v = if let Value::Union(_, b) = v { b } else { v };
930    match v {
931        Value::String(s) => Ok(Some(s.clone())),
932        Value::Bytes(bytes) => String::from_utf8(bytes.to_vec())
933            .map_err(|e| AvroError::new(AvroErrorDetails::ConvertToUtf8(e)))
934            .map(Some),
935        Value::Enum(_, s) => Ok(Some(s.clone())),
936        Value::Null => Ok(None),
937        other => Err(AvroError::new(AvroErrorDetails::GetString(other.clone()))),
938    }
939    .map_err(|e| SchemaError(format!("expected resolvable string : {e}")))
940}
941
942fn resolve_u8(v: &Value) -> Option<u8> {
943    let v = match v {
944        Value::Union(_, inner) => inner.as_ref(),
945        _ => v,
946    };
947
948    match v {
949        Value::Int(n) => u8::try_from(*n).ok(),
950        Value::Long(n) => u8::try_from(*n).ok(),
951        _ => None,
952    }
953}
954
955fn resolve_bytes(v: &Value) -> Option<Vec<u8>> {
956    let v = match v {
957        Value::Union(_, inner) => inner.as_ref(),
958        _ => v,
959    };
960
961    match v {
962        Value::Bytes(bytes) => Some(bytes.clone()),
963        Value::String(s) => Some(s.as_bytes().to_vec()),
964        Value::Array(items) => items.iter().map(resolve_u8).collect::<Option<Vec<u8>>>(),
965        _ => None,
966    }
967}
968
969fn resolve_fixed(v: &Value, size: usize) -> Option<Vec<u8>> {
970    let v = if let Value::Union(_, b) = v { b } else { v };
971    match v {
972        Value::Fixed(n, bytes) => {
973            if *n == size {
974                Some(bytes.clone())
975            } else {
976                None
977            }
978        }
979        _ => None,
980    }
981}
982
983fn resolve_boolean(value: &Value) -> Option<bool> {
984    let v = if let Value::Union(_, b) = value {
985        b
986    } else {
987        value
988    };
989    match v {
990        Value::Boolean(boolean) => Some(*boolean),
991        _ => None,
992    }
993}
994
/// Conversion from an Avro [`Value`] to the native value of an Arrow
/// primitive type.
trait Resolver: ArrowPrimitiveType {
    /// Returns the native value read from `value`; `None` stands for an
    /// absent (null) value.
    fn resolve(value: &Value) -> Option<Self::Native>;
}
998
999fn resolve_item<T: Resolver>(value: &Value) -> Option<T::Native> {
1000    T::resolve(value)
1001}
1002
1003fn maybe_resolve_union(value: &Value) -> &Value {
1004    if SchemaKind::from(value) == SchemaKind::Union {
1005        // Pull out the Union, and attempt to resolve against it.
1006        match value {
1007            Value::Union(_, b) => b,
1008            _ => unreachable!(),
1009        }
1010    } else {
1011        value
1012    }
1013}
1014
1015impl<N> Resolver for N
1016where
1017    N: ArrowNumericType,
1018    N::Native: NumCast,
1019{
1020    fn resolve(value: &Value) -> Option<Self::Native> {
1021        let value = maybe_resolve_union(value);
1022        match value {
1023            Value::Int(i) | Value::TimeMillis(i) | Value::Date(i) => NumCast::from(*i),
1024            Value::Long(l)
1025            | Value::TimeMicros(l)
1026            | Value::TimestampMillis(l)
1027            | Value::TimestampMicros(l) => NumCast::from(*l),
1028            Value::Float(f) => NumCast::from(*f),
1029            Value::Double(f) => NumCast::from(*f),
1030            Value::Duration(_d) => unimplemented!(), // shenanigans type
1031            Value::Null => None,
1032            _ => unreachable!(),
1033        }
1034    }
1035}
1036
1037#[cfg(test)]
1038mod test {
1039    use crate::avro_to_arrow::{Reader, ReaderBuilder};
1040    use arrow::array::Array;
1041    use arrow::datatypes::DataType;
1042    use arrow::datatypes::{Field, TimeUnit};
1043    use datafusion_common::assert_batches_eq;
1044    use datafusion_common::cast::{
1045        as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array,
1046    };
1047    use std::fs::File;
1048    use std::sync::Arc;
1049
1050    fn build_reader(name: &'_ str, batch_size: usize) -> Reader<'_, File> {
1051        let testdata = datafusion_common::test_util::arrow_test_data();
1052        let filename = format!("{testdata}/avro/{name}");
1053        let builder = ReaderBuilder::new()
1054            .read_schema()
1055            .with_batch_size(batch_size);
1056        builder.build(File::open(filename).unwrap()).unwrap()
1057    }
1058
1059    // TODO: Fixed, Enum, Dictionary
1060
1061    #[test]
1062    fn test_time_avro_milliseconds() {
1063        let mut reader = build_reader("alltypes_plain.avro", 10);
1064        let batch = reader.next().unwrap().unwrap();
1065
1066        assert_eq!(11, batch.num_columns());
1067        assert_eq!(8, batch.num_rows());
1068
1069        let schema = reader.schema();
1070        let batch_schema = batch.schema();
1071        assert_eq!(schema, batch_schema);
1072
1073        let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
1074        assert_eq!(
1075            &DataType::Timestamp(TimeUnit::Microsecond, None),
1076            timestamp_col.1.data_type()
1077        );
1078        let timestamp_array =
1079            as_timestamp_microsecond_array(batch.column(timestamp_col.0)).unwrap();
1080        for i in 0..timestamp_array.len() {
1081            assert!(timestamp_array.is_valid(i));
1082        }
1083        assert_eq!(1235865600000000, timestamp_array.value(0));
1084        assert_eq!(1235865660000000, timestamp_array.value(1));
1085        assert_eq!(1238544000000000, timestamp_array.value(2));
1086        assert_eq!(1238544060000000, timestamp_array.value(3));
1087        assert_eq!(1233446400000000, timestamp_array.value(4));
1088        assert_eq!(1233446460000000, timestamp_array.value(5));
1089        assert_eq!(1230768000000000, timestamp_array.value(6));
1090        assert_eq!(1230768060000000, timestamp_array.value(7));
1091    }
1092
1093    #[test]
1094    fn test_avro_read_list() {
1095        let mut reader = build_reader("list_columns.avro", 3);
1096        let schema = reader.schema();
1097        let (col_id_index, _) = schema.column_with_name("int64_list").unwrap();
1098        let batch = reader.next().unwrap().unwrap();
1099        assert_eq!(batch.num_columns(), 2);
1100        assert_eq!(batch.num_rows(), 3);
1101        let a_array = as_list_array(batch.column(col_id_index)).unwrap();
1102        assert_eq!(
1103            *a_array.data_type(),
1104            DataType::List(Arc::new(Field::new("element", DataType::Int64, true)))
1105        );
1106        let array = a_array.value(0);
1107        assert_eq!(*array.data_type(), DataType::Int64);
1108
1109        assert_eq!(
1110            6,
1111            as_int64_array(&array)
1112                .unwrap()
1113                .iter()
1114                .flatten()
1115                .sum::<i64>()
1116        );
1117    }
1118    #[test]
1119    fn test_avro_read_nested_list() {
1120        let mut reader = build_reader("nested_lists.snappy.avro", 3);
1121        let batch = reader.next().unwrap().unwrap();
1122        assert_eq!(batch.num_columns(), 2);
1123        assert_eq!(batch.num_rows(), 3);
1124    }
1125
1126    #[test]
1127    fn test_complex_list() {
1128        let schema = apache_avro::Schema::parse_str(
1129            r#"
1130            {
1131              "type": "record",
1132              "name": "r1",
1133              "fields": [
1134                {
1135                  "name": "headers",
1136                  "type": ["null", {
1137                        "type": "array",
1138                        "items": ["null",{
1139                            "name":"r2",
1140                            "type": "record",
1141                            "fields":[
1142                                {"name":"name", "type": ["null", "string"], "default": null},
1143                                {"name":"value", "type": ["null", "string"], "default": null}
1144                            ]
1145                        }]
1146                    }],
1147                    "default": null
1148                }
1149              ]
1150            }"#,
1151        )
1152        .unwrap();
1153        let r1 = apache_avro::to_value(serde_json::json!({
1154            "headers": [
1155                {
1156                    "name": "a",
1157                    "value": "b"
1158                }
1159            ]
1160        }))
1161        .unwrap()
1162        .resolve(&schema)
1163        .unwrap();
1164
1165        let mut w = apache_avro::Writer::new(&schema, vec![]);
1166        w.append(r1).unwrap();
1167        let bytes = w.into_inner().unwrap();
1168
1169        let mut reader = ReaderBuilder::new()
1170            .read_schema()
1171            .with_batch_size(2)
1172            .build(std::io::Cursor::new(bytes))
1173            .unwrap();
1174
1175        let batch = reader.next().unwrap().unwrap();
1176        assert_eq!(batch.num_rows(), 1);
1177        assert_eq!(batch.num_columns(), 1);
1178        let expected = [
1179            "+-----------------------+",
1180            "| headers               |",
1181            "+-----------------------+",
1182            "| [{name: a, value: b}] |",
1183            "+-----------------------+",
1184        ];
1185        assert_batches_eq!(expected, &[batch]);
1186    }
1187
1188    #[test]
1189    fn test_complex_struct() {
1190        let schema = apache_avro::Schema::parse_str(
1191            r#"
1192        {
1193          "type": "record",
1194          "name": "r1",
1195          "fields": [
1196            {
1197              "name": "dns",
1198              "type": [
1199                "null",
1200                {
1201                  "type": "record",
1202                  "name": "r13",
1203                  "fields": [
1204                    {
1205                      "name": "answers",
1206                      "type": [
1207                        "null",
1208                        {
1209                          "type": "array",
1210                          "items": [
1211                            "null",
1212                            {
1213                              "type": "record",
1214                              "name": "r292",
1215                              "fields": [
1216                                {
1217                                  "name": "class",
1218                                  "type": ["null", "string"],
1219                                  "default": null
1220                                },
1221                                {
1222                                  "name": "data",
1223                                  "type": ["null", "string"],
1224                                  "default": null
1225                                },
1226                                {
1227                                  "name": "name",
1228                                  "type": ["null", "string"],
1229                                  "default": null
1230                                },
1231                                {
1232                                  "name": "ttl",
1233                                  "type": ["null", "long"],
1234                                  "default": null
1235                                },
1236                                {
1237                                  "name": "type",
1238                                  "type": ["null", "string"],
1239                                  "default": null
1240                                }
1241                              ]
1242                            }
1243                          ]
1244                        }
1245                      ],
1246                      "default": null
1247                    },
1248                    {
1249                      "name": "header_flags",
1250                      "type": [
1251                        "null",
1252                        {
1253                          "type": "array",
1254                          "items": ["null", "string"]
1255                        }
1256                      ],
1257                      "default": null
1258                    },
1259                    {
1260                      "name": "id",
1261                      "type": ["null", "string"],
1262                      "default": null
1263                    },
1264                    {
1265                      "name": "op_code",
1266                      "type": ["null", "string"],
1267                      "default": null
1268                    },
1269                    {
1270                      "name": "question",
1271                      "type": [
1272                        "null",
1273                        {
1274                          "type": "record",
1275                          "name": "r288",
1276                          "fields": [
1277                            {
1278                              "name": "class",
1279                              "type": ["null", "string"],
1280                              "default": null
1281                            },
1282                            {
1283                              "name": "name",
1284                              "type": ["null", "string"],
1285                              "default": null
1286                            },
1287                            {
1288                              "name": "registered_domain",
1289                              "type": ["null", "string"],
1290                              "default": null
1291                            },
1292                            {
1293                              "name": "subdomain",
1294                              "type": ["null", "string"],
1295                              "default": null
1296                            },
1297                            {
1298                              "name": "top_level_domain",
1299                              "type": ["null", "string"],
1300                              "default": null
1301                            },
1302                            {
1303                              "name": "type",
1304                              "type": ["null", "string"],
1305                              "default": null
1306                            }
1307                          ]
1308                        }
1309                      ],
1310                      "default": null
1311                    },
1312                    {
1313                      "name": "resolved_ip",
1314                      "type": [
1315                        "null",
1316                        {
1317                          "type": "array",
1318                          "items": ["null", "string"]
1319                        }
1320                      ],
1321                      "default": null
1322                    },
1323                    {
1324                      "name": "response_code",
1325                      "type": ["null", "string"],
1326                      "default": null
1327                    },
1328                    {
1329                      "name": "type",
1330                      "type": ["null", "string"],
1331                      "default": null
1332                    }
1333                  ]
1334                }
1335              ],
1336              "default": null
1337            }
1338          ]
1339        }"#,
1340        )
1341        .unwrap();
1342
1343        let jv1 = serde_json::json!({
1344          "dns": {
1345            "answers": [
1346                {
1347                    "data": "CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=",
1348                    "type": "1"
1349                },
1350                {
1351                    "data": "CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=",
1352                    "type": "1"
1353                },
1354                {
1355                    "data": "CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=",
1356                    "type": "1"
1357                }
1358            ],
1359            "question": {
1360                "name": "security.ubuntu.com",
1361                "type": "A"
1362            },
1363            "resolved_ip": [
1364                "67.43.156.1",
1365                "67.43.156.2",
1366                "67.43.156.3"
1367            ],
1368            "response_code": "0"
1369          }
1370        });
1371        let r1 = apache_avro::to_value(jv1)
1372            .unwrap()
1373            .resolve(&schema)
1374            .unwrap();
1375
1376        let mut w = apache_avro::Writer::new(&schema, vec![]);
1377        w.append(r1).unwrap();
1378        let bytes = w.into_inner().unwrap();
1379
1380        let mut reader = ReaderBuilder::new()
1381            .read_schema()
1382            .with_batch_size(1)
1383            .build(std::io::Cursor::new(bytes))
1384            .unwrap();
1385
1386        let batch = reader.next().unwrap().unwrap();
1387        assert_eq!(batch.num_rows(), 1);
1388        assert_eq!(batch.num_columns(), 1);
1389
1390        let expected = [
1391            "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
1392            "| dns                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          |",
1393            "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
1394            "| {answers: [{class: , data: CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=, name: , ttl: , type: 1}, {class: , data: CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=, name: , ttl: , type: 1}, {class: , data: CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=, name: , ttl: , type: 1}], header_flags: , id: , op_code: , question: {class: , name: security.ubuntu.com, registered_domain: , subdomain: , top_level_domain: , type: A}, resolved_ip: [67.43.156.1, 67.43.156.2, 67.43.156.3], response_code: 0, type: } |",
1395            "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
1396        ];
1397        assert_batches_eq!(expected, &[batch]);
1398    }
1399
1400    #[test]
1401    fn test_deep_nullable_struct() {
1402        let schema = apache_avro::Schema::parse_str(
1403            r#"
1404            {
1405                "type": "record",
1406                "name": "r1",
1407                "fields": [
1408                  {
1409                    "name": "col1",
1410                    "type": [
1411                      "null",
1412                      {
1413                        "type": "record",
1414                        "name": "r2",
1415                        "fields": [
1416                          {
1417                            "name": "col2",
1418                            "type": [
1419                              "null",
1420                              {
1421                                "type": "record",
1422                                "name": "r3",
1423                                "fields": [
1424                                  {
1425                                    "name": "col3",
1426                                    "type": [
1427                                      "null",
1428                                      {
1429                                        "type": "record",
1430                                        "name": "r4",
1431                                        "fields": [
1432                                          {
1433                                            "name": "col4",
1434                                            "type": [
1435                                              "null",
1436                                              {
1437                                                "type": "record",
1438                                                "name": "r5",
1439                                                "fields": [
1440                                                  {
1441                                                    "name": "col5",
1442                                                    "type": ["null", "string"]
1443                                                  }
1444                                                ]
1445                                              }
1446                                            ]
1447                                          }
1448                                        ]
1449                                      }
1450                                    ]
1451                                  }
1452                                ]
1453                              }
1454                            ]
1455                          }
1456                        ]
1457                      }
1458                    ]
1459                  }
1460                ]
1461              }
1462            "#,
1463        )
1464        .unwrap();
1465        let r1 = apache_avro::to_value(serde_json::json!({
1466            "col1": {
1467                "col2": {
1468                    "col3": {
1469                        "col4": {
1470                            "col5": "hello"
1471                        }
1472                    }
1473                }
1474            }
1475        }))
1476        .unwrap()
1477        .resolve(&schema)
1478        .unwrap();
1479        let r2 = apache_avro::to_value(serde_json::json!({
1480            "col1": {
1481                "col2": {
1482                    "col3": {
1483                        "col4": {
1484                            "col5": null
1485                        }
1486                    }
1487                }
1488            }
1489        }))
1490        .unwrap()
1491        .resolve(&schema)
1492        .unwrap();
1493        let r3 = apache_avro::to_value(serde_json::json!({
1494            "col1": {
1495                "col2": {
1496                    "col3": null
1497                }
1498            }
1499        }))
1500        .unwrap()
1501        .resolve(&schema)
1502        .unwrap();
1503        let r4 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1504            .unwrap()
1505            .resolve(&schema)
1506            .unwrap();
1507
1508        let mut w = apache_avro::Writer::new(&schema, vec![]);
1509        w.append(r1).unwrap();
1510        w.append(r2).unwrap();
1511        w.append(r3).unwrap();
1512        w.append(r4).unwrap();
1513        let bytes = w.into_inner().unwrap();
1514
1515        let mut reader = ReaderBuilder::new()
1516            .read_schema()
1517            .with_batch_size(4)
1518            .build(std::io::Cursor::new(bytes))
1519            .unwrap();
1520
1521        let batch = reader.next().unwrap().unwrap();
1522
1523        let expected = [
1524            "+---------------------------------------+",
1525            "| col1                                  |",
1526            "+---------------------------------------+",
1527            "| {col2: {col3: {col4: {col5: hello}}}} |",
1528            "| {col2: {col3: {col4: {col5: }}}}      |",
1529            "| {col2: {col3: }}                      |",
1530            "|                                       |",
1531            "+---------------------------------------+",
1532        ];
1533        assert_batches_eq!(expected, &[batch]);
1534    }
1535
1536    #[test]
1537    fn test_avro_nullable_struct() {
1538        let schema = apache_avro::Schema::parse_str(
1539            r#"
1540            {
1541              "type": "record",
1542              "name": "r1",
1543              "fields": [
1544                {
1545                  "name": "col1",
1546                  "type": [
1547                    "null",
1548                    {
1549                      "type": "record",
1550                      "name": "r2",
1551                      "fields": [
1552                        {
1553                          "name": "col2",
1554                          "type": ["null", "string"]
1555                        }
1556                      ]
1557                    }
1558                  ],
1559                  "default": null
1560                }
1561              ]
1562            }"#,
1563        )
1564        .unwrap();
1565        let r1 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1566            .unwrap()
1567            .resolve(&schema)
1568            .unwrap();
1569        let r2 = apache_avro::to_value(serde_json::json!({
1570            "col1": {
1571                "col2": "hello"
1572            }
1573        }))
1574        .unwrap()
1575        .resolve(&schema)
1576        .unwrap();
1577        let r3 = apache_avro::to_value(serde_json::json!({
1578            "col1": {
1579                "col2": null
1580            }
1581        }))
1582        .unwrap()
1583        .resolve(&schema)
1584        .unwrap();
1585
1586        let mut w = apache_avro::Writer::new(&schema, vec![]);
1587        w.append(r1).unwrap();
1588        w.append(r2).unwrap();
1589        w.append(r3).unwrap();
1590        let bytes = w.into_inner().unwrap();
1591
1592        let mut reader = ReaderBuilder::new()
1593            .read_schema()
1594            .with_batch_size(3)
1595            .build(std::io::Cursor::new(bytes))
1596            .unwrap();
1597        let batch = reader.next().unwrap().unwrap();
1598        assert_eq!(batch.num_rows(), 3);
1599        assert_eq!(batch.num_columns(), 1);
1600
1601        let expected = [
1602            "+---------------+",
1603            "| col1          |",
1604            "+---------------+",
1605            "|               |",
1606            "| {col2: hello} |",
1607            "| {col2: }      |",
1608            "+---------------+",
1609        ];
1610        assert_batches_eq!(expected, &[batch]);
1611    }
1612
1613    #[test]
1614    fn test_avro_nullable_struct_array() {
1615        let schema = apache_avro::Schema::parse_str(
1616            r#"
1617            {
1618              "type": "record",
1619              "name": "r1",
1620              "fields": [
1621                {
1622                  "name": "col1",
1623                  "type": [
1624                    "null",
1625                    {
1626                        "type": "array",
1627                        "items": {
1628                            "type": [
1629                                "null",
1630                                {
1631                                    "type": "record",
1632                                    "name": "Item",
1633                                    "fields": [
1634                                        {
1635                                            "name": "id",
1636                                            "type": "long"
1637                                        }
1638                                    ]
1639                                }
1640                            ]
1641                        }
1642                    }
1643                  ],
1644                  "default": null
1645                }
1646              ]
1647            }"#,
1648        )
1649        .unwrap();
1650        let jv1 = serde_json::json!({
1651            "col1": [
1652                {
1653                    "id": 234
1654                },
1655                {
1656                    "id": 345
1657                }
1658            ]
1659        });
1660        let r1 = apache_avro::to_value(jv1)
1661            .unwrap()
1662            .resolve(&schema)
1663            .unwrap();
1664        let r2 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1665            .unwrap()
1666            .resolve(&schema)
1667            .unwrap();
1668
1669        let mut w = apache_avro::Writer::new(&schema, vec![]);
1670        for _i in 0..5 {
1671            w.append(r1.clone()).unwrap();
1672        }
1673        w.append(r2).unwrap();
1674        let bytes = w.into_inner().unwrap();
1675
1676        let mut reader = ReaderBuilder::new()
1677            .read_schema()
1678            .with_batch_size(20)
1679            .build(std::io::Cursor::new(bytes))
1680            .unwrap();
1681        let batch = reader.next().unwrap().unwrap();
1682        assert_eq!(batch.num_rows(), 6);
1683        assert_eq!(batch.num_columns(), 1);
1684
1685        let expected = [
1686            "+------------------------+",
1687            "| col1                   |",
1688            "+------------------------+",
1689            "| [{id: 234}, {id: 345}] |",
1690            "| [{id: 234}, {id: 345}] |",
1691            "| [{id: 234}, {id: 345}] |",
1692            "| [{id: 234}, {id: 345}] |",
1693            "| [{id: 234}, {id: 345}] |",
1694            "|                        |",
1695            "+------------------------+",
1696        ];
1697        assert_batches_eq!(expected, &[batch]);
1698    }
1699
1700    #[test]
1701    fn test_avro_iterator() {
1702        let reader = build_reader("alltypes_plain.avro", 5);
1703        let schema = reader.schema();
1704        let (col_id_index, _) = schema.column_with_name("id").unwrap();
1705
1706        let mut sum_num_rows = 0;
1707        let mut num_batches = 0;
1708        let mut sum_id = 0;
1709        for batch in reader {
1710            let batch = batch.unwrap();
1711            assert_eq!(11, batch.num_columns());
1712            sum_num_rows += batch.num_rows();
1713            num_batches += 1;
1714            let batch_schema = batch.schema();
1715            assert_eq!(schema, batch_schema);
1716            let a_array = as_int32_array(batch.column(col_id_index)).unwrap();
1717            sum_id += (0..a_array.len()).map(|i| a_array.value(i)).sum::<i32>();
1718        }
1719        assert_eq!(8, sum_num_rows);
1720        assert_eq!(2, num_batches);
1721        assert_eq!(28, sum_id);
1722    }
1723}