datafusion_datasource_avro/avro_to_arrow/
reader.rs1use super::arrow_array_reader::AvroArrowArrayReader;
19use arrow::datatypes::{Fields, SchemaRef};
20use arrow::error::Result as ArrowResult;
21use arrow::record_batch::RecordBatch;
22use datafusion_common::Result;
23use std::io::{Read, Seek};
24use std::sync::Arc;
25
26#[derive(Debug)]
28pub struct ReaderBuilder {
29 schema: Option<SchemaRef>,
33 batch_size: usize,
37 projection: Option<Vec<String>>,
39}
40
41impl Default for ReaderBuilder {
42 fn default() -> Self {
43 Self {
44 schema: None,
45 batch_size: 1024,
46 projection: None,
47 }
48 }
49}
50
51impl ReaderBuilder {
52 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn with_schema(mut self, schema: SchemaRef) -> Self {
80 self.schema = Some(schema);
81 self
82 }
83
84 pub fn read_schema(mut self) -> Self {
86 self.schema = None;
88 self
89 }
90
91 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
93 self.batch_size = batch_size;
94 self
95 }
96
97 pub fn with_projection(mut self, projection: Vec<String>) -> Self {
99 self.projection = Some(projection);
100 self
101 }
102
103 pub fn build<'a, R>(self, source: R) -> Result<Reader<'a, R>>
105 where
106 R: Read + Seek,
107 {
108 let mut source = source;
109
110 let schema = match self.schema {
112 Some(schema) => schema,
113 None => Arc::new(super::read_avro_schema_from_reader(&mut source)?),
114 };
115 source.rewind()?;
116 Reader::try_new(source, schema, self.batch_size, self.projection)
117 }
118}
119
120pub struct Reader<'a, R: Read> {
122 array_reader: AvroArrowArrayReader<'a, R>,
123 schema: SchemaRef,
124 batch_size: usize,
125}
126
127impl<R: Read> Reader<'_, R> {
128 pub fn try_new(
137 reader: R,
138 schema: SchemaRef,
139 batch_size: usize,
140 projection: Option<Vec<String>>,
141 ) -> Result<Self> {
142 let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else(
143 || Arc::clone(&schema),
144 |proj| {
145 Arc::new(arrow::datatypes::Schema::new(
146 proj.iter()
147 .filter_map(|name| {
148 schema.column_with_name(name).map(|(_, f)| f.clone())
149 })
150 .collect::<Fields>(),
151 ))
152 },
153 );
154
155 Ok(Self {
156 array_reader: AvroArrowArrayReader::try_new(
157 reader,
158 Arc::clone(&projected_schema),
159 )?,
160 schema: projected_schema,
161 batch_size,
162 })
163 }
164
165 pub fn schema(&self) -> SchemaRef {
168 Arc::clone(&self.schema)
169 }
170}
171
172impl<R: Read> Iterator for Reader<'_, R> {
173 type Item = ArrowResult<RecordBatch>;
174
175 fn next(&mut self) -> Option<Self::Item> {
178 self.array_reader.next_batch(self.batch_size)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use arrow::array::{
        BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
        TimestampMicrosecondArray,
    };
    use arrow::datatypes::TimeUnit;
    use arrow::datatypes::{DataType, Field};
    use std::fs::File;

    /// Opens `avro/{name}` under the Arrow test-data directory and builds a
    /// reader with the schema read from the file, a batch size of 64 and the
    /// optional `projection`.
    fn build_reader(name: &'_ str, projection: Option<Vec<String>>) -> Reader<'_, File> {
        let testdata = datafusion_common::test_util::arrow_test_data();
        let filename = format!("{testdata}/avro/{name}");
        let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64);
        if let Some(projection) = projection {
            builder = builder.with_projection(projection);
        }
        builder.build(File::open(filename).unwrap()).unwrap()
    }

    /// Downcasts the column at index `col.0` of `batch` to the concrete array
    /// type `T`, returning `None` when the column has a different type.
    fn get_col<'a, T: 'static>(
        batch: &'a RecordBatch,
        col: (usize, &Field),
    ) -> Option<&'a T> {
        batch.column(col.0).as_any().downcast_ref::<T>()
    }

    /// Reads the first batch without a projection and checks, for every
    /// column, its index, its Arrow data type and its two values.
    #[test]
    fn test_avro_basic() {
        let mut reader = build_reader("alltypes_dictionary.avro", None);
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(11, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        let id = schema.column_with_name("id").unwrap();
        assert_eq!(0, id.0);
        assert_eq!(&DataType::Int32, id.1.data_type());
        let col = get_col::<Int32Array>(&batch, id).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let bool_col = schema.column_with_name("bool_col").unwrap();
        assert_eq!(1, bool_col.0);
        assert_eq!(&DataType::Boolean, bool_col.1.data_type());
        let col = get_col::<BooleanArray>(&batch, bool_col).unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));

        let tinyint_col = schema.column_with_name("tinyint_col").unwrap();
        assert_eq!(2, tinyint_col.0);
        assert_eq!(&DataType::Int32, tinyint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, tinyint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let smallint_col = schema.column_with_name("smallint_col").unwrap();
        assert_eq!(3, smallint_col.0);
        assert_eq!(&DataType::Int32, smallint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, smallint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let int_col = schema.column_with_name("int_col").unwrap();
        assert_eq!(4, int_col.0);
        assert_eq!(&DataType::Int32, int_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, int_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let bigint_col = schema.column_with_name("bigint_col").unwrap();
        assert_eq!(5, bigint_col.0);
        assert_eq!(&DataType::Int64, bigint_col.1.data_type());
        let col = get_col::<Int64Array>(&batch, bigint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(10, col.value(1));

        let float_col = schema.column_with_name("float_col").unwrap();
        assert_eq!(6, float_col.0);
        assert_eq!(&DataType::Float32, float_col.1.data_type());
        let col = get_col::<Float32Array>(&batch, float_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(1.1, col.value(1));

        let double_col = schema.column_with_name("double_col").unwrap();
        assert_eq!(7, double_col.0);
        assert_eq!(&DataType::Float64, double_col.1.data_type());
        let col = get_col::<Float64Array>(&batch, double_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));

        let date_string_col = schema.column_with_name("date_string_col").unwrap();
        assert_eq!(8, date_string_col.0);
        assert_eq!(&DataType::Binary, date_string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, date_string_col).unwrap();
        assert_eq!("01/01/09".as_bytes(), col.value(0));
        assert_eq!("01/01/09".as_bytes(), col.value(1));

        let string_col = schema.column_with_name("string_col").unwrap();
        assert_eq!(9, string_col.0);
        assert_eq!(&DataType::Binary, string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, string_col).unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));

        let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
        assert_eq!(10, timestamp_col.0);
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Microsecond, None),
            timestamp_col.1.data_type()
        );
        let col = get_col::<TimestampMicrosecondArray>(&batch, timestamp_col).unwrap();
        assert_eq!(1230768000000000, col.value(0));
        assert_eq!(1230768060000000, col.value(1));
    }

    /// Reads the first batch with a three-column projection and checks that
    /// the columns come back in projection order with the expected types and
    /// values.
    #[test]
    fn test_avro_with_projection() {
        // Deliberately listed in a different order than the unprojected
        // schema (string_col = 9, double_col = 7, bool_col = 1 above).
        let projection = Some(vec![
            "string_col".to_string(),
            "double_col".to_string(),
            "bool_col".to_string(),
        ]);
        let mut reader = build_reader("alltypes_dictionary.avro", projection);
        let batch = reader.next().unwrap().unwrap();

        // Only the projected columns are present.
        assert_eq!(3, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        assert_eq!("string_col", schema.field(0).name());
        assert_eq!(&DataType::Binary, schema.field(0).data_type());
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));

        assert_eq!("double_col", schema.field(1).name());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        let col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));

        assert_eq!("bool_col", schema.field(2).name());
        assert_eq!(&DataType::Boolean, schema.field(2).data_type());
        let col = batch
            .column(2)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
    }
}