use apache_avro::schema::RecordSchema;
use apache_avro::{
    error::Details as AvroErrorDetails,
    schema::{Schema as AvroSchema, SchemaKind},
    types::Value,
    Error as AvroError, Reader as AvroReader,
};
use arrow::array::{
    make_array, Array, ArrayBuilder, ArrayData, ArrayDataBuilder, ArrayRef,
    BooleanBuilder, LargeStringArray, ListBuilder, NullArray, OffsetSizeTrait,
    PrimitiveArray, StringArray, StringBuilder, StringDictionaryBuilder,
};
use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray};
use arrow::buffer::{Buffer, MutableBuffer};
use arrow::datatypes::{
    ArrowDictionaryKeyType, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type,
    Date64Type, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
    Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
    Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
    UInt8Type,
};
use arrow::datatypes::{Fields, SchemaRef};
use arrow::error::ArrowError;
use arrow::error::ArrowError::SchemaError;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;
use datafusion_common::arrow_err;
use datafusion_common::error::{DataFusionError, Result};
use num_traits::NumCast;
use std::collections::BTreeMap;
use std::io::Read;
use std::sync::Arc;

type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>];

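/// Reads Avro records from an underlying reader and converts them into Arrow
/// [`RecordBatch`]es matching the supplied [`SchemaRef`].
///
/// A minimal usage sketch (not taken from the original source; `avro_file` and
/// `schema` are assumed to be provided by the caller):
///
/// ```ignore
/// let mut reader = AvroArrowArrayReader::try_new(avro_file, schema)?;
/// while let Some(batch) = reader.next_batch(1024) {
///     let batch = batch?;
///     // process the RecordBatch ...
/// }
/// ```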
pub struct AvroArrowArrayReader<'a, R: Read> {
    reader: AvroReader<'a, R>,
    schema: SchemaRef,
    schema_lookup: BTreeMap<String, usize>,
}

impl<R: Read> AvroArrowArrayReader<'_, R> {
    pub fn try_new(reader: R, schema: SchemaRef) -> Result<Self> {
        let reader = AvroReader::new(reader)?;
        let writer_schema = reader.writer_schema().clone();
        let schema_lookup = Self::schema_lookup(writer_schema)?;
        Ok(Self {
            reader,
            schema,
            schema_lookup,
        })
    }

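    /// Builds a lookup table from fully qualified field names (e.g. `a.b.c`) to the
    /// position of that field inside its parent Avro record, flattening nested
    /// records, arrays and nullable unions along the way.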
    pub fn schema_lookup(schema: AvroSchema) -> Result<BTreeMap<String, usize>> {
        match schema {
            AvroSchema::Record(RecordSchema {
                fields, mut lookup, ..
            }) => {
                for field in fields {
                    Self::child_schema_lookup(&field.name, &field.schema, &mut lookup)?;
                }
                Ok(lookup)
            }
            _ => arrow_err!(SchemaError(
                "expected avro schema to be a record".to_string(),
            )),
        }
    }

    fn child_schema_lookup<'b>(
        parent_field_name: &str,
        schema: &AvroSchema,
        schema_lookup: &'b mut BTreeMap<String, usize>,
    ) -> Result<&'b BTreeMap<String, usize>> {
        match schema {
            AvroSchema::Union(us) => {
                let has_nullable = us
                    .find_schema_with_known_schemata::<apache_avro::Schema>(
                        &Value::Null,
                        None,
                        &None,
                    )
                    .is_some();
                let sub_schemas = us.variants();
                if has_nullable && sub_schemas.len() == 2 {
                    if let Some(sub_schema) =
                        sub_schemas.iter().find(|&s| !matches!(s, AvroSchema::Null))
                    {
                        Self::child_schema_lookup(
                            parent_field_name,
                            sub_schema,
                            schema_lookup,
                        )?;
                    }
                }
            }
            AvroSchema::Record(RecordSchema { fields, lookup, .. }) => {
                lookup.iter().for_each(|(field_name, pos)| {
                    schema_lookup
                        .insert(format!("{parent_field_name}.{field_name}"), *pos);
                });

                for field in fields {
                    let sub_parent_field_name =
                        format!("{}.{}", parent_field_name, field.name);
                    Self::child_schema_lookup(
                        &sub_parent_field_name,
                        &field.schema,
                        schema_lookup,
                    )?;
                }
            }
            AvroSchema::Array(schema) => {
                let sub_parent_field_name = format!("{parent_field_name}.element");
                Self::child_schema_lookup(
                    &sub_parent_field_name,
                    &schema.items,
                    schema_lookup,
                )?;
            }
            _ => (),
        }
        Ok(schema_lookup)
    }

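    /// Reads up to `batch_size` Avro records from the underlying reader and converts
    /// them into a single [`RecordBatch`]. Returns `None` once the reader is exhausted.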
    pub fn next_batch(&mut self, batch_size: usize) -> Option<ArrowResult<RecordBatch>> {
        let rows_result = self
            .reader
            .by_ref()
            .take(batch_size)
            .map(|value| match value {
                Ok(Value::Record(v)) => Ok(v),
                Err(e) => Err(ArrowError::ParseError(format!(
                    "Failed to parse avro value: {e}"
                ))),
                other => Err(ArrowError::ParseError(format!(
                    "Row needs to be of type object, got: {other:?}"
                ))),
            })
            .collect::<ArrowResult<Vec<Vec<(String, Value)>>>>();

        let rows = match rows_result {
            Err(e) => return Some(Err(e)),
            Ok(rows) if rows.is_empty() => return None,
            Ok(rows) => rows,
        };

        let rows = rows.iter().collect::<Vec<&Vec<(String, Value)>>>();
        let arrays = self.build_struct_array(&rows, "", self.schema.fields());

        Some(arrays.and_then(|arr| RecordBatch::try_new(Arc::clone(&self.schema), arr)))
    }

    fn build_boolean_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef {
        let mut builder = BooleanBuilder::with_capacity(rows.len());
        for row in rows {
            if let Some(value) = self.field_lookup(col_name, row) {
                if let Some(boolean) = resolve_boolean(value) {
                    builder.append_value(boolean)
                } else {
                    builder.append_null();
                }
            } else {
                builder.append_null();
            }
        }
        Arc::new(builder.finish())
    }

    fn build_primitive_array<T>(&self, rows: RecordSlice, col_name: &str) -> ArrayRef
    where
        T: ArrowNumericType + Resolver,
        T::Native: NumCast,
    {
        Arc::new(
            rows.iter()
                .map(|row| {
                    self.field_lookup(col_name, row)
                        .and_then(|value| resolve_item::<T>(value))
                })
                .collect::<PrimitiveArray<T>>(),
        )
    }

    #[inline(always)]
    fn build_string_dictionary_builder<T>(
        &self,
        row_len: usize,
    ) -> StringDictionaryBuilder<T>
    where
        T: ArrowPrimitiveType + ArrowDictionaryKeyType,
    {
        StringDictionaryBuilder::with_capacity(row_len, row_len, row_len)
    }

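    /// Builds a list array whose values are string dictionaries, dispatching on the
    /// dictionary key type of the list's element field.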
    fn build_wrapped_list_array(
        &self,
        rows: RecordSlice,
        col_name: &str,
        key_type: &DataType,
    ) -> ArrowResult<ArrayRef> {
        match *key_type {
            DataType::Int8 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::Int8),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<Int8Type>(&dtype, col_name, rows)
            }
            DataType::Int16 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::Int16),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<Int16Type>(&dtype, col_name, rows)
            }
            DataType::Int32 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<Int32Type>(&dtype, col_name, rows)
            }
            DataType::Int64 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::Int64),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<Int64Type>(&dtype, col_name, rows)
            }
            DataType::UInt8 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::UInt8),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<UInt8Type>(&dtype, col_name, rows)
            }
            DataType::UInt16 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::UInt16),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<UInt16Type>(&dtype, col_name, rows)
            }
            DataType::UInt32 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::UInt32),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<UInt32Type>(&dtype, col_name, rows)
            }
            DataType::UInt64 => {
                let dtype = DataType::Dictionary(
                    Box::new(DataType::UInt64),
                    Box::new(DataType::Utf8),
                );
                self.list_array_string_array_builder::<UInt64Type>(&dtype, col_name, rows)
            }
            ref e => Err(SchemaError(format!(
                "Data type is currently not supported for dictionaries in list : {e}"
            ))),
        }
    }

    #[inline(always)]
    fn list_array_string_array_builder<D>(
        &self,
        data_type: &DataType,
        col_name: &str,
        rows: RecordSlice,
    ) -> ArrowResult<ArrayRef>
    where
        D: ArrowPrimitiveType + ArrowDictionaryKeyType,
    {
        let mut builder: Box<dyn ArrayBuilder> = match data_type {
            DataType::Utf8 => {
                let values_builder = StringBuilder::with_capacity(rows.len(), 5);
                Box::new(ListBuilder::new(values_builder))
            }
            DataType::Dictionary(_, _) => {
                let values_builder =
                    self.build_string_dictionary_builder::<D>(rows.len() * 5);
                Box::new(ListBuilder::new(values_builder))
            }
            e => {
                return Err(SchemaError(format!(
                    "Nested list data builder type is not supported: {e}"
                )))
            }
        };

        for row in rows {
            if let Some(value) = self.field_lookup(col_name, row) {
                let value = maybe_resolve_union(value);
                let vals: Vec<Option<String>> = if let Value::String(v) = value {
                    vec![Some(v.to_string())]
                } else if let Value::Array(n) = value {
                    n.iter()
                        .map(resolve_string)
                        .collect::<ArrowResult<Vec<Option<String>>>>()?
                        .into_iter()
                        .collect::<Vec<Option<String>>>()
                } else if let Value::Null = value {
                    vec![None]
                } else if !matches!(value, Value::Record(_)) {
                    vec![resolve_string(value)?]
                } else {
                    return Err(SchemaError(
                        "Only scalars are currently supported in Avro arrays".to_string(),
                    ));
                };

                match data_type {
                    DataType::Utf8 => {
                        let builder = builder
                            .as_any_mut()
                            .downcast_mut::<ListBuilder<StringBuilder>>()
                            .ok_or_else(|| SchemaError(
                                "Cast failed for ListBuilder<StringBuilder> during nested data parsing".to_string(),
                            ))?;
                        for val in vals {
                            if let Some(v) = val {
                                builder.values().append_value(&v)
                            } else {
                                builder.values().append_null()
                            };
                        }

                        builder.append(true);
                    }
                    DataType::Dictionary(_, _) => {
                        let builder = builder
                            .as_any_mut()
                            .downcast_mut::<ListBuilder<StringDictionaryBuilder<D>>>()
                            .ok_or_else(|| SchemaError(
                                "Cast failed for ListBuilder<StringDictionaryBuilder> during nested data parsing".to_string(),
                            ))?;
                        for val in vals {
                            if let Some(v) = val {
                                let _ = builder.values().append(&v)?;
                            } else {
                                builder.values().append_null()
                            };
                        }

                        builder.append(true);
                    }
                    e => {
                        return Err(SchemaError(format!(
                            "Nested list data builder type is not supported: {e}"
                        )))
                    }
                }
            }
        }

        Ok(builder.finish() as ArrayRef)
    }

    #[inline(always)]
    fn build_dictionary_array<T>(
        &self,
        rows: RecordSlice,
        col_name: &str,
    ) -> ArrowResult<ArrayRef>
    where
        T::Native: NumCast,
        T: ArrowPrimitiveType + ArrowDictionaryKeyType,
    {
        let mut builder: StringDictionaryBuilder<T> =
            self.build_string_dictionary_builder(rows.len());
        for row in rows {
            if let Some(value) = self.field_lookup(col_name, row) {
                if let Ok(Some(str_v)) = resolve_string(value) {
                    builder.append(str_v).map(drop)?
                } else {
                    builder.append_null()
                }
            } else {
                builder.append_null()
            }
        }
        Ok(Arc::new(builder.finish()) as ArrayRef)
    }

    #[inline(always)]
    fn build_string_dictionary_array(
        &self,
        rows: RecordSlice,
        col_name: &str,
        key_type: &DataType,
        value_type: &DataType,
    ) -> ArrowResult<ArrayRef> {
        if let DataType::Utf8 = *value_type {
            match *key_type {
                DataType::Int8 => self.build_dictionary_array::<Int8Type>(rows, col_name),
                DataType::Int16 => {
                    self.build_dictionary_array::<Int16Type>(rows, col_name)
                }
                DataType::Int32 => {
                    self.build_dictionary_array::<Int32Type>(rows, col_name)
                }
                DataType::Int64 => {
                    self.build_dictionary_array::<Int64Type>(rows, col_name)
                }
                DataType::UInt8 => {
                    self.build_dictionary_array::<UInt8Type>(rows, col_name)
                }
                DataType::UInt16 => {
                    self.build_dictionary_array::<UInt16Type>(rows, col_name)
                }
                DataType::UInt32 => {
                    self.build_dictionary_array::<UInt32Type>(rows, col_name)
                }
                DataType::UInt64 => {
                    self.build_dictionary_array::<UInt64Type>(rows, col_name)
                }
                _ => Err(SchemaError("unsupported dictionary key type".to_string())),
            }
        } else {
            Err(SchemaError(
                "dictionary types other than UTF-8 not yet supported".to_string(),
            ))
        }
    }

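    /// Builds a [`GenericListArray`] from rows whose values are Avro arrays, recursing
    /// into nested lists and structs as needed.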
    fn build_nested_list_array<OffsetSize: OffsetSizeTrait>(
        &self,
        parent_field_name: &str,
        rows: &[&Value],
        list_field: &Field,
    ) -> ArrowResult<ArrayRef> {
        let mut cur_offset = OffsetSize::zero();
        let list_len = rows.len();
        let num_list_bytes = bit_util::ceil(list_len, 8);
        let mut offsets = Vec::with_capacity(list_len + 1);
        let mut list_nulls = MutableBuffer::from_len_zeroed(num_list_bytes);
        offsets.push(cur_offset);
        rows.iter().enumerate().for_each(|(i, v)| {
            let v = maybe_resolve_union(v);
            if let Value::Array(a) = v {
                cur_offset += OffsetSize::from_usize(a.len()).unwrap();
                bit_util::set_bit(&mut list_nulls, i);
            } else if let Value::Null = v {
            } else {
                cur_offset += OffsetSize::one();
            }
            offsets.push(cur_offset);
        });
        let valid_len = cur_offset.to_usize().unwrap();
        let array_data = match list_field.data_type() {
            DataType::Null => NullArray::new(valid_len).into_data(),
            DataType::Boolean => {
                let num_bytes = bit_util::ceil(valid_len, 8);
                let mut bool_values = MutableBuffer::from_len_zeroed(num_bytes);
                let mut bool_nulls =
                    MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
                let mut curr_index = 0;
                rows.iter().for_each(|v| {
                    if let Value::Array(vs) = v {
                        vs.iter().for_each(|value| {
                            if let Value::Boolean(child) = value {
                                if *child {
                                    bit_util::set_bit(&mut bool_values, curr_index);
                                }
                            } else {
                                bit_util::unset_bit(&mut bool_nulls, curr_index);
                            }
                            curr_index += 1;
                        });
                    }
                });
                ArrayData::builder(list_field.data_type().clone())
                    .len(valid_len)
                    .add_buffer(bool_values.into())
                    .null_bit_buffer(Some(bool_nulls.into()))
                    .build()
                    .unwrap()
            }
            DataType::Int8 => self.read_primitive_list_values::<Int8Type>(rows),
            DataType::Int16 => self.read_primitive_list_values::<Int16Type>(rows),
            DataType::Int32 => self.read_primitive_list_values::<Int32Type>(rows),
            DataType::Int64 => self.read_primitive_list_values::<Int64Type>(rows),
            DataType::UInt8 => self.read_primitive_list_values::<UInt8Type>(rows),
            DataType::UInt16 => self.read_primitive_list_values::<UInt16Type>(rows),
            DataType::UInt32 => self.read_primitive_list_values::<UInt32Type>(rows),
            DataType::UInt64 => self.read_primitive_list_values::<UInt64Type>(rows),
            DataType::Float16 => {
                return Err(SchemaError("Float16 not supported".to_string()))
            }
            DataType::Float32 => self.read_primitive_list_values::<Float32Type>(rows),
            DataType::Float64 => self.read_primitive_list_values::<Float64Type>(rows),
            DataType::Timestamp(_, _)
            | DataType::Date32
            | DataType::Date64
            | DataType::Time32(_)
            | DataType::Time64(_) => {
                return Err(SchemaError(
                    "Temporal types are not yet supported, see ARROW-4803".to_string(),
                ))
            }
            DataType::Utf8 => flatten_string_values(rows)
                .into_iter()
                .collect::<StringArray>()
                .into_data(),
            DataType::LargeUtf8 => flatten_string_values(rows)
                .into_iter()
                .collect::<LargeStringArray>()
                .into_data(),
            DataType::List(field) => {
                let child = self.build_nested_list_array::<i32>(
                    parent_field_name,
                    &flatten_values(rows),
                    field,
                )?;
                child.to_data()
            }
            DataType::LargeList(field) => {
                let child = self.build_nested_list_array::<i64>(
                    parent_field_name,
                    &flatten_values(rows),
                    field,
                )?;
                child.to_data()
            }
            DataType::Struct(fields) => {
                let array_item_count = rows
                    .iter()
                    .map(|row| match maybe_resolve_union(row) {
                        Value::Array(values) => values.len(),
                        _ => 1,
                    })
                    .sum();
                let num_bytes = bit_util::ceil(array_item_count, 8);
                let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes);
                let mut struct_index = 0;
                let null_struct_array = vec![("null".to_string(), Value::Null)];
                let rows: Vec<&Vec<(String, Value)>> = rows
                    .iter()
                    .map(|v| maybe_resolve_union(v))
                    .flat_map(|row| {
                        if let Value::Array(values) = row {
                            values
                                .iter()
                                .map(maybe_resolve_union)
                                .map(|v| match v {
                                    Value::Record(record) => {
                                        bit_util::set_bit(&mut null_buffer, struct_index);
                                        struct_index += 1;
                                        record
                                    }
                                    Value::Null => {
                                        struct_index += 1;
                                        &null_struct_array
                                    }
                                    other => panic!("expected Record, got {other:?}"),
                                })
                                .collect::<Vec<&Vec<(String, Value)>>>()
                        } else {
                            struct_index += 1;
                            vec![&null_struct_array]
                        }
                    })
                    .collect();

                let sub_parent_field_name =
                    format!("{}.{}", parent_field_name, list_field.name());
                let arrays =
                    self.build_struct_array(&rows, &sub_parent_field_name, fields)?;
                let data_type = DataType::Struct(fields.clone());
                ArrayDataBuilder::new(data_type)
                    .len(rows.len())
                    .null_bit_buffer(Some(null_buffer.into()))
                    .child_data(arrays.into_iter().map(|a| a.to_data()).collect())
                    .build()
                    .unwrap()
            }
            datatype => {
                return Err(SchemaError(format!(
                    "Nested list of {datatype} not supported"
                )));
            }
        };
        let list_data = ArrayData::builder(DataType::List(Arc::new(list_field.clone())))
            .len(list_len)
            .add_buffer(Buffer::from_slice_ref(&offsets))
            .add_child_data(array_data)
            .null_bit_buffer(Some(list_nulls.into()))
            .build()
            .unwrap();
        Ok(Arc::new(GenericListArray::<OffsetSize>::from(list_data)))
    }

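    /// Builds the child arrays for a struct from `rows`, one [`ArrayRef`] per field in
    /// `struct_fields`. Field values are located via the dotted path rooted at
    /// `parent_field_name`; the function recurses into nested structs and lists.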
    fn build_struct_array(
        &self,
        rows: RecordSlice,
        parent_field_name: &str,
        struct_fields: &Fields,
    ) -> ArrowResult<Vec<ArrayRef>> {
        let arrays: ArrowResult<Vec<ArrayRef>> = struct_fields
            .iter()
            .map(|field| {
                let field_path = if parent_field_name.is_empty() {
                    field.name().to_string()
                } else {
                    format!("{}.{}", parent_field_name, field.name())
                };
                let arr = match field.data_type() {
                    DataType::Null => Arc::new(NullArray::new(rows.len())) as ArrayRef,
                    DataType::Boolean => self.build_boolean_array(rows, &field_path),
                    DataType::Float64 => {
                        self.build_primitive_array::<Float64Type>(rows, &field_path)
                    }
                    DataType::Float32 => {
                        self.build_primitive_array::<Float32Type>(rows, &field_path)
                    }
                    DataType::Int64 => {
                        self.build_primitive_array::<Int64Type>(rows, &field_path)
                    }
                    DataType::Int32 => {
                        self.build_primitive_array::<Int32Type>(rows, &field_path)
                    }
                    DataType::Int16 => {
                        self.build_primitive_array::<Int16Type>(rows, &field_path)
                    }
                    DataType::Int8 => {
                        self.build_primitive_array::<Int8Type>(rows, &field_path)
                    }
                    DataType::UInt64 => {
                        self.build_primitive_array::<UInt64Type>(rows, &field_path)
                    }
                    DataType::UInt32 => {
                        self.build_primitive_array::<UInt32Type>(rows, &field_path)
                    }
                    DataType::UInt16 => {
                        self.build_primitive_array::<UInt16Type>(rows, &field_path)
                    }
                    DataType::UInt8 => {
                        self.build_primitive_array::<UInt8Type>(rows, &field_path)
                    }
                    DataType::Timestamp(unit, _) => match unit {
                        TimeUnit::Second => self
                            .build_primitive_array::<TimestampSecondType>(
                                rows,
                                &field_path,
                            ),
                        TimeUnit::Microsecond => self
                            .build_primitive_array::<TimestampMicrosecondType>(
                                rows,
                                &field_path,
                            ),
                        TimeUnit::Millisecond => self
                            .build_primitive_array::<TimestampMillisecondType>(
                                rows,
                                &field_path,
                            ),
                        TimeUnit::Nanosecond => self
                            .build_primitive_array::<TimestampNanosecondType>(
                                rows,
                                &field_path,
                            ),
                    },
                    DataType::Date64 => {
                        self.build_primitive_array::<Date64Type>(rows, &field_path)
                    }
                    DataType::Date32 => {
                        self.build_primitive_array::<Date32Type>(rows, &field_path)
                    }
                    DataType::Time64(unit) => match unit {
                        TimeUnit::Microsecond => self
                            .build_primitive_array::<Time64MicrosecondType>(
                                rows,
                                &field_path,
                            ),
                        TimeUnit::Nanosecond => self
                            .build_primitive_array::<Time64NanosecondType>(
                                rows,
                                &field_path,
                            ),
                        t => {
                            return Err(SchemaError(format!(
                                "TimeUnit {t:?} not supported with Time64"
                            )))
                        }
                    },
                    DataType::Time32(unit) => match unit {
                        TimeUnit::Second => self
                            .build_primitive_array::<Time32SecondType>(rows, &field_path),
                        TimeUnit::Millisecond => self
                            .build_primitive_array::<Time32MillisecondType>(
                                rows,
                                &field_path,
                            ),
                        t => {
                            return Err(SchemaError(format!(
                                "TimeUnit {t:?} not supported with Time32"
                            )))
                        }
                    },
                    DataType::Utf8 | DataType::LargeUtf8 => Arc::new(
                        rows.iter()
                            .map(|row| {
                                let maybe_value = self.field_lookup(&field_path, row);
                                match maybe_value {
                                    None => Ok(None),
                                    Some(v) => resolve_string(v),
                                }
                            })
                            .collect::<ArrowResult<StringArray>>()?,
                    ) as ArrayRef,
                    DataType::Binary | DataType::LargeBinary => Arc::new(
                        rows.iter()
                            .map(|row| {
                                let maybe_value = self.field_lookup(&field_path, row);
                                maybe_value.and_then(resolve_bytes)
                            })
                            .collect::<BinaryArray>(),
                    ) as ArrayRef,
                    DataType::FixedSizeBinary(ref size) => {
                        Arc::new(FixedSizeBinaryArray::try_from_sparse_iter_with_size(
                            rows.iter().map(|row| {
                                let maybe_value = self.field_lookup(&field_path, row);
                                maybe_value.and_then(|v| resolve_fixed(v, *size as usize))
                            }),
                            *size,
                        )?) as ArrayRef
                    }
                    DataType::List(ref list_field) => {
                        match list_field.data_type() {
                            DataType::Dictionary(ref key_ty, _) => {
                                self.build_wrapped_list_array(rows, &field_path, key_ty)?
                            }
                            _ => {
                                let extracted_rows = rows
                                    .iter()
                                    .map(|row| {
                                        self.field_lookup(&field_path, row)
                                            .unwrap_or(&Value::Null)
                                    })
                                    .collect::<Vec<&Value>>();
                                self.build_nested_list_array::<i32>(
                                    &field_path,
                                    &extracted_rows,
                                    list_field,
                                )?
                            }
                        }
                    }
                    DataType::Dictionary(ref key_ty, ref val_ty) => self
                        .build_string_dictionary_array(
                            rows,
                            &field_path,
                            key_ty,
                            val_ty,
                        )?,
                    DataType::Struct(fields) => {
                        let len = rows.len();
                        let num_bytes = bit_util::ceil(len, 8);
                        let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes);
                        let empty_vec = vec![];
                        let struct_rows = rows
                            .iter()
                            .enumerate()
                            .map(|(i, row)| (i, self.field_lookup(&field_path, row)))
                            .map(|(i, v)| {
                                let v = v.map(maybe_resolve_union);
                                match v {
                                    Some(Value::Record(value)) => {
                                        bit_util::set_bit(&mut null_buffer, i);
                                        value
                                    }
                                    None | Some(Value::Null) => &empty_vec,
                                    other => {
                                        panic!("expected struct got {other:?}");
                                    }
                                }
                            })
                            .collect::<Vec<&Vec<(String, Value)>>>();
                        let arrays =
                            self.build_struct_array(&struct_rows, &field_path, fields)?;
                        let data_type = DataType::Struct(fields.clone());
                        let data = ArrayDataBuilder::new(data_type)
                            .len(len)
                            .null_bit_buffer(Some(null_buffer.into()))
                            .child_data(arrays.into_iter().map(|a| a.to_data()).collect())
                            .build()?;
                        make_array(data)
                    }
                    _ => {
                        return Err(SchemaError(format!(
                            "type {} not supported",
                            field.data_type()
                        )))
                    }
                };
                Ok(arr)
            })
            .collect();
        arrays
    }

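    /// Reads the values of a primitive-typed list into a flat [`ArrayData`], resolving
    /// unions and treating scalar values as single-element lists.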
    fn read_primitive_list_values<T>(&self, rows: &[&Value]) -> ArrayData
    where
        T: ArrowPrimitiveType + ArrowNumericType,
        T::Native: NumCast,
    {
        let values = rows
            .iter()
            .flat_map(|row| {
                let row = maybe_resolve_union(row);
                if let Value::Array(values) = row {
                    values
                        .iter()
                        .map(resolve_item::<T>)
                        .collect::<Vec<Option<T::Native>>>()
                } else if let Some(f) = resolve_item::<T>(row) {
                    vec![Some(f)]
                } else {
                    vec![]
                }
            })
            .collect::<Vec<Option<T::Native>>>();
        let array = values.iter().collect::<PrimitiveArray<T>>();
        array.to_data()
    }

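    /// Looks up a field value in `row` by its fully qualified name, using the
    /// positions recorded in `schema_lookup`.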
    fn field_lookup<'b>(
        &self,
        name: &str,
        row: &'b [(String, Value)],
    ) -> Option<&'b Value> {
        self.schema_lookup
            .get(name)
            .and_then(|i| row.get(*i))
            .map(|o| &o.1)
    }
}

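/// Flattens a slice of Avro values: array values contribute their elements, all other
/// values are kept as single items.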
#[inline]
fn flatten_values<'a>(values: &[&'a Value]) -> Vec<&'a Value> {
    values
        .iter()
        .flat_map(|row| {
            let v = maybe_resolve_union(row);
            if let Value::Array(values) = v {
                values.iter().collect()
            } else {
                vec![v]
            }
        })
        .collect()
}

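/// Flattens a slice of Avro values into optional strings: array values contribute one
/// entry per element, nulls contribute nothing, and scalars are resolved to strings.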
#[inline]
fn flatten_string_values(values: &[&Value]) -> Vec<Option<String>> {
    values
        .iter()
        .flat_map(|row| {
            let row = maybe_resolve_union(row);
            if let Value::Array(values) = row {
                values
                    .iter()
                    .map(|s| resolve_string(s).ok().flatten())
                    .collect::<Vec<Option<_>>>()
            } else if let Value::Null = row {
                vec![]
            } else {
                vec![resolve_string(row).ok().flatten()]
            }
        })
        .collect::<Vec<Option<_>>>()
}

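/// Resolves an Avro value to an optional string, unwrapping unions and accepting
/// strings, UTF-8 bytes, and enum symbols; `Null` resolves to `None`.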
fn resolve_string(v: &Value) -> ArrowResult<Option<String>> {
    let v = if let Value::Union(_, b) = v { b } else { v };
    match v {
        Value::String(s) => Ok(Some(s.clone())),
        Value::Bytes(bytes) => String::from_utf8(bytes.to_vec())
            .map_err(|e| AvroError::new(AvroErrorDetails::ConvertToUtf8(e)))
            .map(Some),
        Value::Enum(_, s) => Ok(Some(s.clone())),
        Value::Null => Ok(None),
        other => Err(AvroError::new(AvroErrorDetails::GetString(other.clone()))),
    }
    .map_err(|e| SchemaError(format!("expected resolvable string : {e}")))
}

fn resolve_u8(v: &Value) -> Option<u8> {
    let v = match v {
        Value::Union(_, inner) => inner.as_ref(),
        _ => v,
    };

    match v {
        Value::Int(n) => u8::try_from(*n).ok(),
        Value::Long(n) => u8::try_from(*n).ok(),
        _ => None,
    }
}

fn resolve_bytes(v: &Value) -> Option<Vec<u8>> {
    let v = match v {
        Value::Union(_, inner) => inner.as_ref(),
        _ => v,
    };

    match v {
        Value::Bytes(bytes) => Some(bytes.clone()),
        Value::String(s) => Some(s.as_bytes().to_vec()),
        Value::Array(items) => items.iter().map(resolve_u8).collect::<Option<Vec<u8>>>(),
        _ => None,
    }
}

fn resolve_fixed(v: &Value, size: usize) -> Option<Vec<u8>> {
    let v = if let Value::Union(_, b) = v { b } else { v };
    match v {
        Value::Fixed(n, bytes) => {
            if *n == size {
                Some(bytes.clone())
            } else {
                None
            }
        }
        _ => None,
    }
}

fn resolve_boolean(value: &Value) -> Option<bool> {
    let v = if let Value::Union(_, b) = value {
        b
    } else {
        value
    };
    match v {
        Value::Boolean(boolean) => Some(*boolean),
        _ => None,
    }
}

trait Resolver: ArrowPrimitiveType {
    fn resolve(value: &Value) -> Option<Self::Native>;
}

fn resolve_item<T: Resolver>(value: &Value) -> Option<T::Native> {
    T::resolve(value)
}

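/// Returns the inner value of a union, or the value itself if it is not a union.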
fn maybe_resolve_union(value: &Value) -> &Value {
    if SchemaKind::from(value) == SchemaKind::Union {
        match value {
            Value::Union(_, b) => b,
            _ => unreachable!(),
        }
    } else {
        value
    }
}

impl<N> Resolver for N
where
    N: ArrowNumericType,
    N::Native: NumCast,
{
    fn resolve(value: &Value) -> Option<Self::Native> {
        let value = maybe_resolve_union(value);
        match value {
            Value::Int(i) | Value::TimeMillis(i) | Value::Date(i) => NumCast::from(*i),
            Value::Long(l)
            | Value::TimeMicros(l)
            | Value::TimestampMillis(l)
            | Value::TimestampMicros(l) => NumCast::from(*l),
            Value::Float(f) => NumCast::from(*f),
            Value::Double(f) => NumCast::from(*f),
            Value::Duration(_d) => unimplemented!(),
            Value::Null => None,
            _ => unreachable!(),
        }
    }
}

#[cfg(test)]
mod test {
    use crate::avro_to_arrow::{Reader, ReaderBuilder};
    use arrow::array::Array;
    use arrow::datatypes::DataType;
    use arrow::datatypes::{Field, TimeUnit};
    use datafusion_common::assert_batches_eq;
    use datafusion_common::cast::{
        as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array,
    };
    use std::fs::File;
    use std::sync::Arc;

    fn build_reader(name: &'_ str, batch_size: usize) -> Reader<'_, File> {
        let testdata = datafusion_common::test_util::arrow_test_data();
        let filename = format!("{testdata}/avro/{name}");
        let builder = ReaderBuilder::new()
            .read_schema()
            .with_batch_size(batch_size);
        builder.build(File::open(filename).unwrap()).unwrap()
    }

    #[test]
    fn test_time_avro_milliseconds() {
        let mut reader = build_reader("alltypes_plain.avro", 10);
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(11, batch.num_columns());
        assert_eq!(8, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Microsecond, None),
            timestamp_col.1.data_type()
        );
        let timestamp_array =
            as_timestamp_microsecond_array(batch.column(timestamp_col.0)).unwrap();
        for i in 0..timestamp_array.len() {
            assert!(timestamp_array.is_valid(i));
        }
        assert_eq!(1235865600000000, timestamp_array.value(0));
        assert_eq!(1235865660000000, timestamp_array.value(1));
        assert_eq!(1238544000000000, timestamp_array.value(2));
        assert_eq!(1238544060000000, timestamp_array.value(3));
        assert_eq!(1233446400000000, timestamp_array.value(4));
        assert_eq!(1233446460000000, timestamp_array.value(5));
        assert_eq!(1230768000000000, timestamp_array.value(6));
        assert_eq!(1230768060000000, timestamp_array.value(7));
    }

    #[test]
    fn test_avro_read_list() {
        let mut reader = build_reader("list_columns.avro", 3);
        let schema = reader.schema();
        let (col_id_index, _) = schema.column_with_name("int64_list").unwrap();
        let batch = reader.next().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        let a_array = as_list_array(batch.column(col_id_index)).unwrap();
        assert_eq!(
            *a_array.data_type(),
            DataType::List(Arc::new(Field::new("element", DataType::Int64, true)))
        );
        let array = a_array.value(0);
        assert_eq!(*array.data_type(), DataType::Int64);

        assert_eq!(
            6,
            as_int64_array(&array)
                .unwrap()
                .iter()
                .flatten()
                .sum::<i64>()
        );
    }

    #[test]
    fn test_avro_read_nested_list() {
        let mut reader = build_reader("nested_lists.snappy.avro", 3);
        let batch = reader.next().unwrap().unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
    }

    #[test]
    fn test_complex_list() {
        let schema = apache_avro::Schema::parse_str(
            r#"
            {
              "type": "record",
              "name": "r1",
              "fields": [
                {
                  "name": "headers",
                  "type": ["null", {
                    "type": "array",
                    "items": ["null", {
                      "name": "r2",
                      "type": "record",
                      "fields": [
                        {"name": "name", "type": ["null", "string"], "default": null},
                        {"name": "value", "type": ["null", "string"], "default": null}
                      ]
                    }]
                  }],
                  "default": null
                }
              ]
            }"#,
        )
        .unwrap();
        let r1 = apache_avro::to_value(serde_json::json!({
            "headers": [
                {
                    "name": "a",
                    "value": "b"
                }
            ]
        }))
        .unwrap()
        .resolve(&schema)
        .unwrap();

        let mut w = apache_avro::Writer::new(&schema, vec![]);
        w.append(r1).unwrap();
        let bytes = w.into_inner().unwrap();

        let mut reader = ReaderBuilder::new()
            .read_schema()
            .with_batch_size(2)
            .build(std::io::Cursor::new(bytes))
            .unwrap();

        let batch = reader.next().unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 1);
        let expected = [
            "+-----------------------+",
            "| headers               |",
            "+-----------------------+",
            "| [{name: a, value: b}] |",
            "+-----------------------+",
        ];
        assert_batches_eq!(expected, &[batch]);
    }

    #[test]
    fn test_complex_struct() {
        let schema = apache_avro::Schema::parse_str(
            r#"
            {
              "type": "record",
              "name": "r1",
              "fields": [
                {
                  "name": "dns",
                  "type": [
                    "null",
                    {
                      "type": "record",
                      "name": "r13",
                      "fields": [
                        {
                          "name": "answers",
                          "type": [
                            "null",
                            {
                              "type": "array",
                              "items": [
                                "null",
                                {
                                  "type": "record",
                                  "name": "r292",
                                  "fields": [
                                    {"name": "class", "type": ["null", "string"], "default": null},
                                    {"name": "data", "type": ["null", "string"], "default": null},
                                    {"name": "name", "type": ["null", "string"], "default": null},
                                    {"name": "ttl", "type": ["null", "long"], "default": null},
                                    {"name": "type", "type": ["null", "string"], "default": null}
                                  ]
                                }
                              ]
                            }
                          ],
                          "default": null
                        },
                        {
                          "name": "header_flags",
                          "type": [
                            "null",
                            {"type": "array", "items": ["null", "string"]}
                          ],
                          "default": null
                        },
                        {"name": "id", "type": ["null", "string"], "default": null},
                        {"name": "op_code", "type": ["null", "string"], "default": null},
                        {
                          "name": "question",
                          "type": [
                            "null",
                            {
                              "type": "record",
                              "name": "r288",
                              "fields": [
                                {"name": "class", "type": ["null", "string"], "default": null},
                                {"name": "name", "type": ["null", "string"], "default": null},
                                {"name": "registered_domain", "type": ["null", "string"], "default": null},
                                {"name": "subdomain", "type": ["null", "string"], "default": null},
                                {"name": "top_level_domain", "type": ["null", "string"], "default": null},
                                {"name": "type", "type": ["null", "string"], "default": null}
                              ]
                            }
                          ],
                          "default": null
                        },
                        {
                          "name": "resolved_ip",
                          "type": [
                            "null",
                            {"type": "array", "items": ["null", "string"]}
                          ],
                          "default": null
                        },
                        {"name": "response_code", "type": ["null", "string"], "default": null},
                        {"name": "type", "type": ["null", "string"], "default": null}
                      ]
                    }
                  ],
                  "default": null
                }
              ]
            }"#,
        )
        .unwrap();

        let jv1 = serde_json::json!({
            "dns": {
                "answers": [
                    {
                        "data": "CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=",
                        "type": "1"
                    },
                    {
                        "data": "CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=",
                        "type": "1"
                    },
                    {
                        "data": "CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=",
                        "type": "1"
                    }
                ],
                "question": {
                    "name": "security.ubuntu.com",
                    "type": "A"
                },
                "resolved_ip": [
                    "67.43.156.1",
                    "67.43.156.2",
                    "67.43.156.3"
                ],
                "response_code": "0"
            }
        });
        let r1 = apache_avro::to_value(jv1)
            .unwrap()
            .resolve(&schema)
            .unwrap();

        let mut w = apache_avro::Writer::new(&schema, vec![]);
        w.append(r1).unwrap();
        let bytes = w.into_inner().unwrap();

        let mut reader = ReaderBuilder::new()
            .read_schema()
            .with_batch_size(1)
            .build(std::io::Cursor::new(bytes))
            .unwrap();

        let batch = reader.next().unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 1);

        let expected = [
1391 "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
1392 "| dns |",
1393 "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
1394 "| {answers: [{class: , data: CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=, name: , ttl: , type: 1}, {class: , data: CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=, name: , ttl: , type: 1}, {class: , data: CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=, name: , ttl: , type: 1}], header_flags: , id: , op_code: , question: {class: , name: security.ubuntu.com, registered_domain: , subdomain: , top_level_domain: , type: A}, resolved_ip: [67.43.156.1, 67.43.156.2, 67.43.156.3], response_code: 0, type: } |",
1395 "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
        ];
        assert_batches_eq!(expected, &[batch]);
    }

    #[test]
    fn test_deep_nullable_struct() {
        let schema = apache_avro::Schema::parse_str(
            r#"
            {
              "type": "record",
              "name": "r1",
              "fields": [
                {
                  "name": "col1",
                  "type": [
                    "null",
                    {
                      "type": "record",
                      "name": "r2",
                      "fields": [
                        {
                          "name": "col2",
                          "type": [
                            "null",
                            {
                              "type": "record",
                              "name": "r3",
                              "fields": [
                                {
                                  "name": "col3",
                                  "type": [
                                    "null",
                                    {
                                      "type": "record",
                                      "name": "r4",
                                      "fields": [
                                        {
                                          "name": "col4",
                                          "type": [
                                            "null",
                                            {
                                              "type": "record",
                                              "name": "r5",
                                              "fields": [
                                                {"name": "col5", "type": ["null", "string"]}
                                              ]
                                            }
                                          ]
                                        }
                                      ]
                                    }
                                  ]
                                }
                              ]
                            }
                          ]
                        }
                      ]
                    }
                  ]
                }
              ]
            }
            "#,
        )
        .unwrap();
        let r1 = apache_avro::to_value(serde_json::json!({
            "col1": {
                "col2": {
                    "col3": {
                        "col4": {
                            "col5": "hello"
                        }
                    }
                }
            }
        }))
        .unwrap()
        .resolve(&schema)
        .unwrap();
        let r2 = apache_avro::to_value(serde_json::json!({
            "col1": {
                "col2": {
                    "col3": {
                        "col4": {
                            "col5": null
                        }
                    }
                }
            }
        }))
        .unwrap()
        .resolve(&schema)
        .unwrap();
        let r3 = apache_avro::to_value(serde_json::json!({
            "col1": {
                "col2": {
                    "col3": null
                }
            }
        }))
        .unwrap()
        .resolve(&schema)
        .unwrap();
        let r4 = apache_avro::to_value(serde_json::json!({ "col1": null }))
            .unwrap()
            .resolve(&schema)
            .unwrap();

        let mut w = apache_avro::Writer::new(&schema, vec![]);
        w.append(r1).unwrap();
        w.append(r2).unwrap();
        w.append(r3).unwrap();
        w.append(r4).unwrap();
        let bytes = w.into_inner().unwrap();

        let mut reader = ReaderBuilder::new()
            .read_schema()
            .with_batch_size(4)
            .build(std::io::Cursor::new(bytes))
            .unwrap();

        let batch = reader.next().unwrap().unwrap();

        let expected = [
            "+---------------------------------------+",
            "| col1                                  |",
            "+---------------------------------------+",
            "| {col2: {col3: {col4: {col5: hello}}}} |",
            "| {col2: {col3: {col4: {col5: }}}}      |",
            "| {col2: {col3: }}                      |",
            "|                                       |",
            "+---------------------------------------+",
        ];
        assert_batches_eq!(expected, &[batch]);
    }

    #[test]
    fn test_avro_nullable_struct() {
        let schema = apache_avro::Schema::parse_str(
            r#"
            {
              "type": "record",
              "name": "r1",
              "fields": [
                {
                  "name": "col1",
                  "type": [
                    "null",
                    {
                      "type": "record",
                      "name": "r2",
                      "fields": [
                        {
                          "name": "col2",
                          "type": ["null", "string"]
                        }
                      ]
                    }
                  ],
                  "default": null
                }
              ]
            }"#,
        )
        .unwrap();
        let r1 = apache_avro::to_value(serde_json::json!({ "col1": null }))
            .unwrap()
            .resolve(&schema)
            .unwrap();
        let r2 = apache_avro::to_value(serde_json::json!({
            "col1": {
                "col2": "hello"
            }
        }))
        .unwrap()
        .resolve(&schema)
        .unwrap();
        let r3 = apache_avro::to_value(serde_json::json!({
            "col1": {
                "col2": null
            }
        }))
        .unwrap()
        .resolve(&schema)
        .unwrap();

        let mut w = apache_avro::Writer::new(&schema, vec![]);
        w.append(r1).unwrap();
        w.append(r2).unwrap();
        w.append(r3).unwrap();
        let bytes = w.into_inner().unwrap();

        let mut reader = ReaderBuilder::new()
            .read_schema()
            .with_batch_size(3)
            .build(std::io::Cursor::new(bytes))
            .unwrap();
        let batch = reader.next().unwrap().unwrap();
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 1);

        let expected = [
            "+---------------+",
            "| col1          |",
            "+---------------+",
            "|               |",
            "| {col2: hello} |",
            "| {col2: }      |",
            "+---------------+",
        ];
        assert_batches_eq!(expected, &[batch]);
    }

    #[test]
    fn test_avro_nullable_struct_array() {
        let schema = apache_avro::Schema::parse_str(
            r#"
            {
              "type": "record",
              "name": "r1",
              "fields": [
                {
                  "name": "col1",
                  "type": [
                    "null",
                    {
                      "type": "array",
                      "items": {
                        "type": [
                          "null",
                          {
                            "type": "record",
                            "name": "Item",
                            "fields": [
                              {"name": "id", "type": "long"}
                            ]
                          }
                        ]
                      }
                    }
                  ],
                  "default": null
                }
              ]
            }"#,
        )
        .unwrap();
        let jv1 = serde_json::json!({
            "col1": [
                {
                    "id": 234
                },
                {
                    "id": 345
                }
            ]
        });
        let r1 = apache_avro::to_value(jv1)
            .unwrap()
            .resolve(&schema)
            .unwrap();
        let r2 = apache_avro::to_value(serde_json::json!({ "col1": null }))
            .unwrap()
            .resolve(&schema)
            .unwrap();

        let mut w = apache_avro::Writer::new(&schema, vec![]);
        for _i in 0..5 {
            w.append(r1.clone()).unwrap();
        }
        w.append(r2).unwrap();
        let bytes = w.into_inner().unwrap();

        let mut reader = ReaderBuilder::new()
            .read_schema()
            .with_batch_size(20)
            .build(std::io::Cursor::new(bytes))
            .unwrap();
        let batch = reader.next().unwrap().unwrap();
        assert_eq!(batch.num_rows(), 6);
        assert_eq!(batch.num_columns(), 1);

        let expected = [
            "+------------------------+",
            "| col1                   |",
            "+------------------------+",
            "| [{id: 234}, {id: 345}] |",
            "| [{id: 234}, {id: 345}] |",
            "| [{id: 234}, {id: 345}] |",
            "| [{id: 234}, {id: 345}] |",
            "| [{id: 234}, {id: 345}] |",
            "|                        |",
            "+------------------------+",
        ];
        assert_batches_eq!(expected, &[batch]);
    }

    #[test]
    fn test_avro_iterator() {
        let reader = build_reader("alltypes_plain.avro", 5);
        let schema = reader.schema();
        let (col_id_index, _) = schema.column_with_name("id").unwrap();

        let mut sum_num_rows = 0;
        let mut num_batches = 0;
        let mut sum_id = 0;
        for batch in reader {
            let batch = batch.unwrap();
            assert_eq!(11, batch.num_columns());
            sum_num_rows += batch.num_rows();
            num_batches += 1;
            let batch_schema = batch.schema();
            assert_eq!(schema, batch_schema);
            let a_array = as_int32_array(batch.column(col_id_index)).unwrap();
            sum_id += (0..a_array.len()).map(|i| a_array.value(i)).sum::<i32>();
        }
        assert_eq!(8, sum_num_rows);
        assert_eq!(2, num_batches);
        assert_eq!(28, sum_id);
    }
}