datafusion_datasource_avro/avro_to_arrow/
schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use apache_avro::schema::{
19    Alias, DecimalSchema, EnumSchema, FixedSchema, Name, RecordSchema,
20};
21use apache_avro::types::Value;
22use apache_avro::Schema as AvroSchema;
23use arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode};
24use arrow::datatypes::{Field, UnionFields};
25use datafusion_common::error::Result;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29/// Converts an avro schema to an arrow schema
30pub fn to_arrow_schema(avro_schema: &apache_avro::Schema) -> Result<Schema> {
31    let mut schema_fields = vec![];
32    match avro_schema {
33        AvroSchema::Record(RecordSchema { fields, .. }) => {
34            for field in fields {
35                schema_fields.push(schema_to_field_with_props(
36                    &field.schema,
37                    Some(&field.name),
38                    field.is_nullable(),
39                    Some(external_props(&field.schema)),
40                )?)
41            }
42        }
43        schema => schema_fields.push(schema_to_field(schema, Some(""), false)?),
44    }
45
46    let schema = Schema::new(schema_fields);
47    Ok(schema)
48}
49
50fn schema_to_field(
51    schema: &apache_avro::Schema,
52    name: Option<&str>,
53    nullable: bool,
54) -> Result<Field> {
55    schema_to_field_with_props(schema, name, nullable, Default::default())
56}
57
58fn schema_to_field_with_props(
59    schema: &AvroSchema,
60    name: Option<&str>,
61    nullable: bool,
62    props: Option<HashMap<String, String>>,
63) -> Result<Field> {
64    let mut nullable = nullable;
65    let field_type: DataType = match schema {
66        AvroSchema::Ref { .. } => todo!("Add support for AvroSchema::Ref"),
67        AvroSchema::Null => DataType::Null,
68        AvroSchema::Boolean => DataType::Boolean,
69        AvroSchema::Int => DataType::Int32,
70        AvroSchema::Long => DataType::Int64,
71        AvroSchema::Float => DataType::Float32,
72        AvroSchema::Double => DataType::Float64,
73        AvroSchema::Bytes => DataType::Binary,
74        AvroSchema::String => DataType::Utf8,
75        AvroSchema::Array(item_schema) => DataType::List(Arc::new(
76            schema_to_field_with_props(&item_schema.items, Some("element"), false, None)?,
77        )),
78        AvroSchema::Map(value_schema) => {
79            let value_field = schema_to_field_with_props(
80                &value_schema.types,
81                Some("value"),
82                false,
83                None,
84            )?;
85            DataType::Dictionary(
86                Box::new(DataType::Utf8),
87                Box::new(value_field.data_type().clone()),
88            )
89        }
90        AvroSchema::Union(us) => {
91            // If there are only two variants and one of them is null, set the other type as the field data type
92            let has_nullable = us
93                .find_schema_with_known_schemata::<apache_avro::Schema>(
94                    &Value::Null,
95                    None,
96                    &None,
97                )
98                .is_some();
99            let sub_schemas = us.variants();
100            if has_nullable && sub_schemas.len() == 2 {
101                nullable = true;
102                if let Some(schema) = sub_schemas
103                    .iter()
104                    .find(|&schema| !matches!(schema, AvroSchema::Null))
105                {
106                    schema_to_field_with_props(schema, None, has_nullable, None)?
107                        .data_type()
108                        .clone()
109                } else {
110                    return Err(apache_avro::Error::new(
111                        apache_avro::error::Details::GetUnionDuplicate,
112                    )
113                    .into());
114                }
115            } else {
116                let fields = sub_schemas
117                    .iter()
118                    .map(|s| schema_to_field_with_props(s, None, has_nullable, None))
119                    .collect::<Result<Vec<Field>>>()?;
120                let type_ids = 0_i8..fields.len() as i8;
121                DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
122            }
123        }
124        AvroSchema::Record(RecordSchema { fields, .. }) => {
125            let fields: Result<_> = fields
126                .iter()
127                .map(|field| {
128                    let mut props = HashMap::new();
129                    if let Some(doc) = &field.doc {
130                        props.insert("avro::doc".to_string(), doc.clone());
131                    }
132                    /*if let Some(aliases) = fields.aliases {
133                        props.insert("aliases", aliases);
134                    }*/
135                    schema_to_field_with_props(
136                        &field.schema,
137                        Some(&field.name),
138                        false,
139                        Some(props),
140                    )
141                })
142                .collect();
143            DataType::Struct(fields?)
144        }
145        AvroSchema::Enum(EnumSchema { .. }) => DataType::Utf8,
146        AvroSchema::Fixed(FixedSchema { size, .. }) => {
147            DataType::FixedSizeBinary(*size as i32)
148        }
149        AvroSchema::Decimal(DecimalSchema {
150            precision, scale, ..
151        }) => DataType::Decimal128(*precision as u8, *scale as i8),
152        AvroSchema::BigDecimal => DataType::LargeBinary,
153        AvroSchema::Uuid => DataType::FixedSizeBinary(16),
154        AvroSchema::Date => DataType::Date32,
155        AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
156        AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
157        AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None),
158        AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None),
159        AvroSchema::TimestampNanos => DataType::Timestamp(TimeUnit::Nanosecond, None),
160        AvroSchema::LocalTimestampMillis => todo!(),
161        AvroSchema::LocalTimestampMicros => todo!(),
162        AvroSchema::LocalTimestampNanos => todo!(),
163        AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond),
164    };
165
166    let data_type = field_type.clone();
167    let name = name.unwrap_or_else(|| default_field_name(&data_type));
168
169    let mut field = Field::new(name, field_type, nullable);
170    field.set_metadata(props.unwrap_or_default());
171    Ok(field)
172}
173
174fn default_field_name(dt: &DataType) -> &str {
175    match dt {
176        DataType::Null => "null",
177        DataType::Boolean => "bit",
178        DataType::Int8 => "tinyint",
179        DataType::Int16 => "smallint",
180        DataType::Int32 => "int",
181        DataType::Int64 => "bigint",
182        DataType::UInt8 => "uint1",
183        DataType::UInt16 => "uint2",
184        DataType::UInt32 => "uint4",
185        DataType::UInt64 => "uint8",
186        DataType::Float16 => "float2",
187        DataType::Float32 => "float4",
188        DataType::Float64 => "float8",
189        DataType::Date32 => "dateday",
190        DataType::Date64 => "datemilli",
191        DataType::Time32(tu) | DataType::Time64(tu) => match tu {
192            TimeUnit::Second => "timesec",
193            TimeUnit::Millisecond => "timemilli",
194            TimeUnit::Microsecond => "timemicro",
195            TimeUnit::Nanosecond => "timenano",
196        },
197        DataType::Timestamp(tu, tz) => {
198            if tz.is_some() {
199                match tu {
200                    TimeUnit::Second => "timestampsectz",
201                    TimeUnit::Millisecond => "timestampmillitz",
202                    TimeUnit::Microsecond => "timestampmicrotz",
203                    TimeUnit::Nanosecond => "timestampnanotz",
204                }
205            } else {
206                match tu {
207                    TimeUnit::Second => "timestampsec",
208                    TimeUnit::Millisecond => "timestampmilli",
209                    TimeUnit::Microsecond => "timestampmicro",
210                    TimeUnit::Nanosecond => "timestampnano",
211                }
212            }
213        }
214        DataType::Duration(_) => "duration",
215        DataType::Interval(unit) => match unit {
216            IntervalUnit::YearMonth => "intervalyear",
217            IntervalUnit::DayTime => "intervalmonth",
218            IntervalUnit::MonthDayNano => "intervalmonthdaynano",
219        },
220        DataType::Binary => "varbinary",
221        DataType::FixedSizeBinary(_) => "fixedsizebinary",
222        DataType::LargeBinary => "largevarbinary",
223        DataType::Utf8 => "varchar",
224        DataType::LargeUtf8 => "largevarchar",
225        DataType::List(_) => "list",
226        DataType::FixedSizeList(_, _) => "fixed_size_list",
227        DataType::LargeList(_) => "largelist",
228        DataType::Struct(_) => "struct",
229        DataType::Union(_, _) => "union",
230        DataType::Dictionary(_, _) => "map",
231        DataType::Map(_, _) => unimplemented!("Map support not implemented"),
232        DataType::RunEndEncoded(_, _) => {
233            unimplemented!("RunEndEncoded support not implemented")
234        }
235        DataType::Utf8View
236        | DataType::BinaryView
237        | DataType::ListView(_)
238        | DataType::LargeListView(_) => {
239            unimplemented!("View support not implemented")
240        }
241        DataType::Decimal32(_, _) => "decimal",
242        DataType::Decimal64(_, _) => "decimal",
243        DataType::Decimal128(_, _) => "decimal",
244        DataType::Decimal256(_, _) => "decimal",
245    }
246}
247
248fn external_props(schema: &AvroSchema) -> HashMap<String, String> {
249    let mut props = HashMap::new();
250    match &schema {
251        AvroSchema::Record(RecordSchema {
252            doc: Some(ref doc), ..
253        })
254        | AvroSchema::Enum(EnumSchema {
255            doc: Some(ref doc), ..
256        })
257        | AvroSchema::Fixed(FixedSchema {
258            doc: Some(ref doc), ..
259        }) => {
260            props.insert("avro::doc".to_string(), doc.clone());
261        }
262        _ => {}
263    }
264    match &schema {
265        AvroSchema::Record(RecordSchema {
266            name: Name { namespace, .. },
267            aliases: Some(aliases),
268            ..
269        })
270        | AvroSchema::Enum(EnumSchema {
271            name: Name { namespace, .. },
272            aliases: Some(aliases),
273            ..
274        })
275        | AvroSchema::Fixed(FixedSchema {
276            name: Name { namespace, .. },
277            aliases: Some(aliases),
278            ..
279        }) => {
280            let aliases: Vec<String> = aliases
281                .iter()
282                .map(|alias| aliased(alias, namespace.as_deref(), None))
283                .collect();
284            props.insert(
285                "avro::aliases".to_string(),
286                format!("[{}]", aliases.join(",")),
287            );
288        }
289        _ => {}
290    }
291    props
292}
293
294/// Returns the fully qualified name for a field
295pub fn aliased(
296    alias: &Alias,
297    namespace: Option<&str>,
298    default_namespace: Option<&str>,
299) -> String {
300    if alias.namespace().is_some() {
301        alias.fullname(None)
302    } else {
303        let namespace = namespace.as_ref().copied().or(default_namespace);
304
305        match namespace {
306            Some(ref namespace) => format!("{}.{}", namespace, alias.name()),
307            None => alias.fullname(None),
308        }
309    }
310}
311
#[cfg(test)]
mod test {
    use super::{aliased, external_props, to_arrow_schema};
    use apache_avro::schema::{Alias, EnumSchema, FixedSchema, Name, RecordSchema};
    use apache_avro::Schema as AvroSchema;
    use arrow::datatypes::DataType::{Binary, Float32, Float64, Timestamp, Utf8};
    use arrow::datatypes::DataType::{Boolean, Int32, Int64};
    use arrow::datatypes::TimeUnit::Microsecond;
    use arrow::datatypes::{Field, Schema};

    /// Builds an [`Alias`] from a literal known to be valid.
    fn alias(name: &str) -> Alias {
        Alias::new(name).unwrap()
    }

    /// Namespace resolution order: the alias's own namespace, then the
    /// explicit namespace, then the default namespace.
    #[test]
    fn test_alias() {
        assert_eq!(aliased(&alias("foo.bar"), None, None), "foo.bar");
        assert_eq!(aliased(&alias("bar"), Some("foo"), None), "foo.bar");
        assert_eq!(aliased(&alias("bar"), Some("foo"), Some("cat")), "foo.bar");
        assert_eq!(aliased(&alias("bar"), None, Some("cat")), "cat.bar");
    }

    /// Record, enum and fixed schemas expose their doc and aliases as metadata.
    #[test]
    fn test_external_props() {
        let record_schema = AvroSchema::Record(RecordSchema {
            name: Name {
                name: "record".to_string(),
                namespace: None,
            },
            aliases: Some(vec![alias("fooalias"), alias("baralias")]),
            doc: Some("record documentation".to_string()),
            fields: vec![],
            lookup: Default::default(),
            attributes: Default::default(),
        });
        let props = external_props(&record_schema);
        assert_eq!(
            props.get("avro::doc"),
            Some(&"record documentation".to_string())
        );
        assert_eq!(
            props.get("avro::aliases"),
            Some(&"[fooalias,baralias]".to_string())
        );

        let enum_schema = AvroSchema::Enum(EnumSchema {
            name: Name {
                name: "enum".to_string(),
                namespace: None,
            },
            aliases: Some(vec![alias("fooenum"), alias("barenum")]),
            doc: Some("enum documentation".to_string()),
            symbols: vec![],
            default: None,
            attributes: Default::default(),
        });
        let props = external_props(&enum_schema);
        assert_eq!(
            props.get("avro::doc"),
            Some(&"enum documentation".to_string())
        );
        assert_eq!(
            props.get("avro::aliases"),
            Some(&"[fooenum,barenum]".to_string())
        );

        let fixed_schema = AvroSchema::Fixed(FixedSchema {
            name: Name {
                name: "fixed".to_string(),
                namespace: None,
            },
            aliases: Some(vec![alias("foofixed"), alias("barfixed")]),
            size: 1,
            doc: None,
            default: None,
            attributes: Default::default(),
        });
        let props = external_props(&fixed_schema);
        assert_eq!(
            props.get("avro::aliases"),
            Some(&"[foofixed,barfixed]".to_string())
        );
    }

    /// Malformed schema definitions must be rejected by the parser rather than
    /// reaching the Arrow conversion.
    #[test]
    fn test_invalid_avro_schema() {
        // Not valid JSON at all.
        let not_json = AvroSchema::parse_str("{ not a schema");
        assert!(not_json.is_err(), "{not_json:?}");

        // Valid JSON, but a record without a name.
        let unnamed_record = AvroSchema::parse_str(r#"{"type": "record", "fields": []}"#);
        assert!(unnamed_record.is_err(), "{unnamed_record:?}");
    }

    #[test]
    fn test_plain_types_schema() {
        let schema = AvroSchema::parse_str(
            r#"
            {
              "type" : "record",
              "name" : "topLevelRecord",
              "fields" : [ {
                "name" : "id",
                "type" : [ "int", "null" ]
              }, {
                "name" : "bool_col",
                "type" : [ "boolean", "null" ]
              }, {
                "name" : "tinyint_col",
                "type" : [ "int", "null" ]
              }, {
                "name" : "smallint_col",
                "type" : [ "int", "null" ]
              }, {
                "name" : "int_col",
                "type" : [ "int", "null" ]
              }, {
                "name" : "bigint_col",
                "type" : [ "long", "null" ]
              }, {
                "name" : "float_col",
                "type" : [ "float", "null" ]
              }, {
                "name" : "double_col",
                "type" : [ "double", "null" ]
              }, {
                "name" : "date_string_col",
                "type" : [ "bytes", "null" ]
              }, {
                "name" : "string_col",
                "type" : [ "bytes", "null" ]
              }, {
                "name" : "timestamp_col",
                "type" : [ {
                  "type" : "long",
                  "logicalType" : "timestamp-micros"
                }, "null" ]
              } ]
            }"#,
        );
        assert!(schema.is_ok(), "{schema:?}");
        let arrow_schema = to_arrow_schema(&schema.unwrap());
        assert!(arrow_schema.is_ok(), "{arrow_schema:?}");
        let expected = Schema::new(vec![
            Field::new("id", Int32, true),
            Field::new("bool_col", Boolean, true),
            Field::new("tinyint_col", Int32, true),
            Field::new("smallint_col", Int32, true),
            Field::new("int_col", Int32, true),
            Field::new("bigint_col", Int64, true),
            Field::new("float_col", Float32, true),
            Field::new("double_col", Float64, true),
            Field::new("date_string_col", Binary, true),
            Field::new("string_col", Binary, true),
            Field::new("timestamp_col", Timestamp(Microsecond, None), true),
        ]);
        assert_eq!(arrow_schema.unwrap(), expected);
    }

    #[test]
    fn test_nested_schema() {
        let avro_schema = apache_avro::Schema::parse_str(
            r#"
            {
              "type": "record",
              "name": "r1",
              "fields": [
                {
                  "name": "col1",
                  "type": [
                    "null",
                    {
                      "type": "record",
                      "name": "r2",
                      "fields": [
                        {
                          "name": "col2",
                          "type": "string"
                        },
                        {
                          "name": "col3",
                          "type": ["null", "string"],
                          "default": null
                        }
                      ]
                    }
                  ],
                  "default": null
                }
              ]
            }"#,
        )
        .unwrap();
        // should not use Avro Record names.
        let expected_arrow_schema = Schema::new(vec![Field::new(
            "col1",
            arrow::datatypes::DataType::Struct(
                vec![
                    Field::new("col2", Utf8, false),
                    Field::new("col3", Utf8, true),
                ]
                .into(),
            ),
            true,
        )]);
        assert_eq!(
            to_arrow_schema(&avro_schema).unwrap(),
            expected_arrow_schema
        );
    }

    /// A non-record top-level schema becomes one unnamed, non-nullable field.
    #[test]
    fn test_non_record_schema() {
        let arrow_schema = to_arrow_schema(&AvroSchema::String);
        assert!(arrow_schema.is_ok(), "{arrow_schema:?}");
        assert_eq!(
            arrow_schema.unwrap(),
            Schema::new(vec![Field::new("", Utf8, false)])
        );
    }
}