datafusion_datasource_arrow/
file_format.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions
//!
//! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format)
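//!
//! A minimal sketch of schema inference with [`ArrowFormat`] (it mirrors the tests at
//! the bottom of this file; the in-memory store, the `example.arrow` path, and the
//! `state` session handle are illustrative assumptions):
//!
//! ```ignore
//! use std::sync::Arc;
//! use datafusion_datasource::file_format::FileFormat;
//! use object_store::{memory::InMemory, path::Path, ObjectStore};
//!
//! // Put an Arrow IPC file into an in-memory object store
//! let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
//! let location = Path::parse("example.arrow")?;
//! store.put(&location, std::fs::read("example.arrow")?.into()).await?;
//! let meta = store.head(&location).await?;
//!
//! // Only the IPC header is read when inferring the schema from a stream
//! let schema = ArrowFormat {}
//!     .infer_schema(&state, &store, std::slice::from_ref(&meta))
//!     .await?;
//! ```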

use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::sync::Arc;

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::ipc::convert::fb_to_schema;
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::IpcWriteOptions;
use arrow::ipc::{root_as_message, CompressionType};
use datafusion_common::error::Result;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
    internal_datafusion_err, not_impl_err, DataFusionError, GetExt, Statistics,
    DEFAULT_ARROW_EXTENSION,
};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::{
    get_writer_schema, ObjectWriterBuilder, SharedBuffer,
};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;

use crate::source::ArrowSource;
use async_trait::async_trait;
use bytes::Bytes;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{FileFormat, FileFormatFactory};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::source::DataSourceExec;
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;
use futures::stream::BoxStream;
use futures::StreamExt;
use object_store::{GetResultPayload, ObjectMeta, ObjectStore};
use tokio::io::AsyncWriteExt;

/// Initial write buffer size. Note this is just a size hint for efficiency; it
/// will grow beyond the set value if needed.
const INITIAL_BUFFER_BYTES: usize = 1048576;

/// If the buffered Arrow data exceeds this size, it is flushed to the object store
const BUFFER_FLUSH_BYTES: usize = 1024000;

#[derive(Default, Debug)]
/// Factory struct used to create [ArrowFormat]
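///
/// A brief usage sketch (not compiled here; the `state` session handle is an
/// assumption supplied by the caller):
///
/// ```ignore
/// let factory = ArrowFormatFactory::new();
/// assert_eq!(factory.get_ext(), "arrow");
///
/// // No Arrow-specific format options are consumed by `create`
/// let format = factory.create(state, &std::collections::HashMap::new())?;
/// ```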
pub struct ArrowFormatFactory;

impl ArrowFormatFactory {
    /// Creates an instance of [ArrowFormatFactory]
    pub fn new() -> Self {
        Self {}
    }
}

impl FileFormatFactory for ArrowFormatFactory {
    fn create(
        &self,
        _state: &dyn Session,
        _format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        Ok(Arc::new(ArrowFormat))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(ArrowFormat)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for ArrowFormatFactory {
    fn get_ext(&self) -> String {
        // Removes the dot, i.e. ".arrow" -> "arrow"
        DEFAULT_ARROW_EXTENSION[1..].to_string()
    }
}

/// Arrow `FileFormat` implementation.
#[derive(Default, Debug)]
pub struct ArrowFormat;

#[async_trait]
impl FileFormat for ArrowFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        ArrowFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        // File-level compression is not supported for Arrow IPC files; batch-level
        // compression is configured on the IPC writer instead (see `ArrowFileSink`)
        match file_compression_type.get_variant() {
            CompressionTypeVariant::UNCOMPRESSED => Ok(ext),
            _ => Err(internal_datafusion_err!(
                "Arrow FileFormat does not support compression."
            )),
        }
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        None
    }

    async fn infer_schema(
        &self,
        _state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];
        for object in objects {
            let r = store.as_ref().get(&object.location).await?;
            let schema = match r.payload {
                // Local files can use the seekable IPC `FileReader` directly
                #[cfg(not(target_arch = "wasm32"))]
                GetResultPayload::File(mut file, _) => {
                    let reader = FileReader::try_new(&mut file, None)?;
                    reader.schema()
                }
                // Remote objects are read incrementally so only the IPC header
                // needs to be fetched, not the whole file
                GetResultPayload::Stream(stream) => {
                    infer_schema_from_file_stream(stream).await?
                }
            };
            schemas.push(schema.as_ref().clone());
        }
        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        _state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let source = Arc::new(ArrowSource::default());
        let config = FileScanConfigBuilder::from(conf)
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        _state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for Arrow format");
        }

        let sink = Arc::new(ArrowFileSink::new(conf));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(ArrowSource::default())
    }
}

/// Implements [`FileSink`] for writing to Arrow IPC files.
///
/// Batches are encoded into an in-memory buffer and flushed to the object
/// store whenever the buffer grows beyond [`BUFFER_FLUSH_BYTES`].
struct ArrowFileSink {
    config: FileSinkConfig,
}

impl ArrowFileSink {
    fn new(config: FileSinkConfig) -> Self {
        Self { config }
    }
}

#[async_trait]
impl FileSink for ArrowFileSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        mut file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let mut file_write_tasks: JoinSet<std::result::Result<usize, DataFusionError>> =
            JoinSet::new();

        // 64-byte alignment, current (non-legacy) IPC format, V5 metadata, and
        // LZ4-compressed record batches
        let ipc_options =
            IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)?
                .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
        // Spawn one writer task per output file: batches are encoded into an
        // in-memory buffer that is flushed to the object store once it exceeds
        // `BUFFER_FLUSH_BYTES`
        while let Some((path, mut rx)) = file_stream_rx.recv().await {
            let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES);
            let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options(
                shared_buffer.clone(),
                &get_writer_schema(&self.config),
                ipc_options.clone(),
            )?;
            let mut object_store_writer = ObjectWriterBuilder::new(
                FileCompressionType::UNCOMPRESSED,
                &path,
                Arc::clone(&object_store),
            )
            .with_buffer_size(Some(
                context
                    .session_config()
                    .options()
                    .execution
                    .objectstore_writer_buffer_size,
            ))
            .build()?;
            file_write_tasks.spawn(async move {
                let mut row_count = 0;
                while let Some(batch) = rx.recv().await {
                    row_count += batch.num_rows();
                    arrow_writer.write(&batch)?;
                    let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap();
                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
                        object_store_writer
                            .write_all(buff_to_flush.as_slice())
                            .await?;
                        buff_to_flush.clear();
                    }
                }
                // Write the IPC footer and flush whatever remains in the buffer
                arrow_writer.finish()?;
                let final_buff = shared_buffer.buffer.try_lock().unwrap();

                object_store_writer.write_all(final_buff.as_slice()).await?;
                object_store_writer.shutdown().await?;
                Ok(row_count)
            });
        }

        let mut row_count = 0;
        while let Some(result) = file_write_tasks.join_next().await {
            match result {
                Ok(r) => {
                    row_count += r?;
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        demux_task
            .join_unwind()
            .await
            .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
        Ok(row_count as u64)
    }
}

impl Debug for ArrowFileSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ArrowFileSink").finish()
    }
}

impl DisplayAs for ArrowFileSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "ArrowFileSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: arrow")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

#[async_trait]
impl DataSink for ArrowFileSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

/// Magic bytes ("ARROW1") that begin an Arrow IPC file
const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1'];
/// Continuation marker (0xFFFFFFFF) preceding the metadata length (IPC format v0.15.0+)
const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];

/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs.
/// See <https://github.com/apache/arrow-rs/issues/5021>
async fn infer_schema_from_file_stream(
    mut stream: BoxStream<'static, object_store::Result<Bytes>>,
) -> Result<SchemaRef> {
    // Expected format:
    // <magic number "ARROW1"> - 6 bytes
    // <empty padding bytes [to 8 byte boundary]> - 2 bytes
    // <continuation: 0xFFFFFFFF> - 4 bytes, not present below v0.15.0
    // <metadata_size: int32> - 4 bytes
    // <metadata_flatbuffer: bytes>
    // <rest of file bytes>
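    //
    // For illustration (assuming a post-v0.15.0 file whose schema flatbuffer is
    // 256 bytes long), the first 16 bytes would be:
    //   41 52 52 4f 57 31    "ARROW1" magic
    //   00 00                padding to the 8 byte boundary
    //   ff ff ff ff          continuation marker
    //   00 01 00 00          metadata_size = 256 (little-endian i32)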

    // So in the first read we need at least all of the fixed-size sections,
    // which is 6 + 2 + 4 + 4 = 16 bytes.
    let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?;

    // Files should start with these magic bytes
    if bytes[0..6] != ARROW_MAGIC {
        return Err(ArrowError::ParseError(
            "Arrow file does not contain correct header".to_string(),
        ))?;
    }

    // The continuation marker was only added in later format versions, so check
    // whether it is present before locating the metadata length
    let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER {
        (&bytes[12..16], 16)
    } else {
        (&bytes[8..12], 12)
    };

    let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]];
    let meta_len = i32::from_le_bytes(meta_len);

    // Read bytes for Schema message
    let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize {
        // Need to read more bytes to decode Message
        let mut block_data = Vec::with_capacity(meta_len as usize);
        // In case we had some spare bytes in our initial read chunk
        block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]);
        let size_to_read = meta_len as usize - block_data.len();
        let block_data =
            collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?;
        Cow::Owned(block_data)
    } else {
        // Already have the bytes we need
        let end_index = meta_len as usize + rest_of_bytes_start_index;
        let block_data = &bytes[rest_of_bytes_start_index..end_index];
        Cow::Borrowed(block_data)
    };

    // Decode Schema message
    let message = root_as_message(&block_data).map_err(|err| {
        ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}"))
    })?;
    let ipc_schema = message.header_as_schema().ok_or_else(|| {
        ArrowError::IpcError("Unable to read IPC message as schema".to_string())
    })?;
    let schema = fb_to_schema(ipc_schema);

    Ok(Arc::new(schema))
}

/// Reads from `stream` until at least `n` additional bytes have been collected,
/// optionally appending to an existing buffer. Errors if the stream ends first.
async fn collect_at_least_n_bytes(
    stream: &mut BoxStream<'static, object_store::Result<Bytes>>,
    n: usize,
    extend_from: Option<Vec<u8>>,
) -> Result<Vec<u8>> {
    let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n));
    // If extending an existing buffer, ensure we read n additional bytes
    let n = n + buf.len();
    while let Some(bytes) = stream.next().await.transpose()? {
        buf.extend_from_slice(&bytes);
        if buf.len() >= n {
            break;
        }
    }
    if buf.len() < n {
        return Err(ArrowError::ParseError(
            "Unexpected end of byte stream for Arrow IPC file".to_string(),
        ))?;
    }
    Ok(buf)
}

#[cfg(test)]
mod tests {
    use super::*;

    use chrono::DateTime;
    use datafusion_common::config::TableOptions;
    use datafusion_common::DFSchema;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path};

    struct MockSession {
        config: SessionConfig,
        runtime_env: Arc<RuntimeEnv>,
    }

    impl MockSession {
        fn new() -> Self {
            Self {
                config: SessionConfig::new(),
                runtime_env: Arc::new(RuntimeEnv::default()),
            }
        }
    }

    #[async_trait::async_trait]
    impl Session for MockSession {
        fn session_id(&self) -> &str {
            unimplemented!()
        }

        fn config(&self) -> &SessionConfig {
            &self.config
        }

        async fn create_physical_plan(
            &self,
            _logical_plan: &LogicalPlan,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        fn create_physical_expr(
            &self,
            _expr: Expr,
            _df_schema: &DFSchema,
        ) -> Result<Arc<dyn PhysicalExpr>> {
            unimplemented!()
        }

        fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
            unimplemented!()
        }

        fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
            unimplemented!()
        }

        fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
            unimplemented!()
        }

        fn runtime_env(&self) -> &Arc<RuntimeEnv> {
            &self.runtime_env
        }

        fn execution_props(&self) -> &ExecutionProps {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn Any {
            unimplemented!()
        }

        fn table_options(&self) -> &TableOptions {
            unimplemented!()
        }

        fn table_options_mut(&mut self) -> &mut TableOptions {
            unimplemented!()
        }

        fn task_ctx(&self) -> Arc<TaskContext> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn test_infer_schema_stream() -> Result<()> {
        let mut bytes = std::fs::read("tests/data/example.arrow")?;
        bytes.truncate(bytes.len() - 20); // truncate the end to show we don't need to read the whole file
        let location = Path::parse("example.arrow")?;
        let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        in_memory_store.put(&location, bytes.into()).await?;

        let state = MockSession::new();
        let object_meta = ObjectMeta {
            location,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let arrow_format = ArrowFormat {};
        let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"];

        // Test both a chunk size that is too small (so we keep having to read more bytes)
        // and one large enough that the first read contains all we need
        for chunk_size in [7, 3000] {
            let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size));
            let inferred_schema = arrow_format
                .infer_schema(
                    &state,
                    &(store.clone() as Arc<dyn ObjectStore>),
                    std::slice::from_ref(&object_meta),
                )
                .await?;
            let actual_fields = inferred_schema
                .fields()
                .iter()
                .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
                .collect::<Vec<_>>();
            assert_eq!(expected, actual_fields);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_infer_schema_short_stream() -> Result<()> {
        let mut bytes = std::fs::read("tests/data/example.arrow")?;
        bytes.truncate(20); // should cause an error because the file is shorter than expected
        let location = Path::parse("example.arrow")?;
        let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        in_memory_store.put(&location, bytes.into()).await?;

        let state = MockSession::new();
        let object_meta = ObjectMeta {
            location,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let arrow_format = ArrowFormat {};

        let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7));
        let err = arrow_format
            .infer_schema(
                &state,
                &(store.clone() as Arc<dyn ObjectStore>),
                std::slice::from_ref(&object_meta),
            )
            .await;

        assert!(err.is_err());
        assert_eq!(
            "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file",
            err.unwrap_err().to_string().lines().next().unwrap()
        );

        Ok(())
    }
}