datafusion/datasource/
listing_table_factory.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Factory for creating ListingTables with default options
19
20use std::collections::HashSet;
21use std::path::Path;
22use std::sync::Arc;
23
24use crate::catalog::{TableProvider, TableProviderFactory};
25use crate::datasource::listing::{
26    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
27};
28use crate::execution::context::SessionState;
29
30use arrow::datatypes::DataType;
31use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema};
32use datafusion_common::{config_datafusion_err, Result};
33use datafusion_expr::CreateExternalTable;
34
35use async_trait::async_trait;
36use datafusion_catalog::Session;
37
/// A `TableProviderFactory` capable of creating new `ListingTable`s
///
/// Handles `CREATE EXTERNAL TABLE ... STORED AS <format>` statements by
/// resolving the `FileFormat` registered for the requested format and
/// building a [`ListingTable`] over the given location (see
/// [`TableProviderFactory::create`]).
// Stateless: `new()` and `Default::default()` are interchangeable.
#[derive(Debug, Default)]
pub struct ListingTableFactory {}
41
42impl ListingTableFactory {
43    /// Creates a new `ListingTableFactory`
44    pub fn new() -> Self {
45        Self::default()
46    }
47}
48
#[async_trait]
impl TableProviderFactory for ListingTableFactory {
    /// Builds a [`ListingTable`] provider for a `CREATE EXTERNAL TABLE` command.
    ///
    /// Steps:
    /// 1. Resolve the `FileFormat` factory registered for `cmd.file_type` and
    ///    create the format with the command's options.
    /// 2. Determine the file extension (single file) or glob pattern (folder)
    ///    used to list data files.
    /// 3. Resolve the table schema and partition columns — taken from `cmd`
    ///    when a schema was provided, otherwise inferred.
    /// 4. Construct the `ListingTable` with the command's sort order,
    ///    constraints, column defaults, and SQL definition attached.
    ///
    /// # Errors
    /// Fails if no `FileFormat` is registered for the file type, the location
    /// cannot be parsed, partition validation or schema inference fails, or an
    /// `ORDER BY` expression references a column missing from the schema.
    async fn create(
        &self,
        state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> Result<Arc<dyn TableProvider>> {
        // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state?
        let session_state = state.as_any().downcast_ref::<SessionState>().unwrap();
        let file_format = session_state
            .get_file_format_factory(cmd.file_type.as_str())
            .ok_or(config_datafusion_err!(
                "Unable to create table with format {}! Could not find FileFormat.",
                cmd.file_type
            ))?
            .create(session_state, &cmd.options)?;

        let mut table_path = ListingTableUrl::parse(&cmd.location)?;
        let file_extension = match table_path.is_collection() {
            // Setting the extension to be empty instead of allowing the default extension seems
            // odd, but was done to ensure existing behavior isn't modified. It seems like this
            // could be refactored to either use the default extension or set the fully expected
            // extension when compression is included (e.g. ".csv.gz")
            true => "",
            // Single file: honor whatever extension the location actually has
            // (possibly non-standard, e.g. ".tbl").
            false => &get_extension(cmd.location.as_str()),
        };
        let mut options = ListingOptions::new(file_format)
            .with_session_config_options(session_state.config())
            .with_file_extension(file_extension);

        // No explicit schema: it will be inferred below. Partition columns come
        // from the command, or are inferred from the directory layout when
        // `listing_table_factory_infer_partitions` is enabled and none were given.
        let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() {
            let infer_parts = session_state
                .config_options()
                .execution
                .listing_table_factory_infer_partitions;
            let part_cols = if cmd.table_partition_cols.is_empty() && infer_parts {
                options
                    .infer_partitions(session_state, &table_path)
                    .await?
                    .into_iter()
            } else {
                cmd.table_partition_cols.clone().into_iter()
            };

            (
                None,
                // Partition values are path strings; a UInt16-keyed dictionary
                // type keeps the repeated values compact.
                part_cols
                    .map(|p| {
                        (
                            p,
                            DataType::Dictionary(
                                Box::new(DataType::UInt16),
                                Box::new(DataType::Utf8),
                            ),
                        )
                    })
                    .collect::<Vec<_>>(),
            )
        } else {
            // Explicit schema: partition column types are taken from the
            // user-provided column definitions (erroring on unknown columns).
            let schema = Arc::clone(cmd.schema.inner());
            let table_partition_cols = cmd
                .table_partition_cols
                .iter()
                .map(|col| {
                    schema
                        .field_with_name(col)
                        .map_err(|e| arrow_datafusion_err!(e))
                })
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .map(|f| (f.name().to_owned(), f.data_type().to_owned()))
                .collect();
            // exclude partition columns to support creating partitioned external table
            // with a specified column definition like
            // `create external table a(c0 int, c1 int) stored as csv partitioned by (c1)...`
            let mut project_idx = Vec::new();
            for i in 0..schema.fields().len() {
                if !cmd.table_partition_cols.contains(schema.field(i).name()) {
                    project_idx.push(i);
                }
            }
            let schema = Arc::new(schema.project(&project_idx)?);
            (Some(schema), table_partition_cols)
        };

        options = options.with_table_partition_cols(table_partition_cols);

        options
            .validate_partitions(session_state, &table_path)
            .await?;

        let resolved_schema = match provided_schema {
            // We will need to check the table columns against the schema
            // this is done so that we can do an ORDER BY for external table creation
            // specifically for parquet file format.
            // See: https://github.com/apache/datafusion/issues/7317
            None => {
                // if the folder then rewrite a file path as 'path/*.parquet'
                // to only read the files the reader can understand
                if table_path.is_folder() && table_path.get_glob().is_none() {
                    // Since there are no files yet to infer an actual extension,
                    // derive the pattern based on compression type.
                    // So for gzipped CSV the pattern is `*.csv.gz`
                    let glob = match options.format.compression_type() {
                        Some(compression) => {
                            match options.format.get_ext_with_compression(&compression) {
                                // Use glob based on `FileFormat` extension
                                Ok(ext) => format!("*.{ext}"),
                                // Fallback to `file_type`, if not supported by `FileFormat`
                                Err(_) => format!("*.{}", cmd.file_type.to_lowercase()),
                            }
                        }
                        None => format!("*.{}", cmd.file_type.to_lowercase()),
                    };
                    table_path = table_path.with_glob(glob.as_ref())?;
                }
                let schema = options.infer_schema(session_state, &table_path).await?;
                let df_schema = Arc::clone(&schema).to_dfschema()?;
                // Reject ORDER BY expressions that reference columns absent
                // from the inferred schema.
                let column_refs: HashSet<_> = cmd
                    .order_exprs
                    .iter()
                    .flat_map(|sort| sort.iter())
                    .flat_map(|s| s.expr.column_refs())
                    .collect();

                for column in &column_refs {
                    if !df_schema.has_column(column) {
                        return plan_err!("Column {column} is not in schema");
                    }
                }

                schema
            }
            Some(s) => s,
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options.with_file_sort_order(cmd.order_exprs.clone()))
            .with_schema(resolved_schema);
        let provider = ListingTable::try_new(config)?
            .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache());
        let table = provider
            .with_definition(cmd.definition.clone())
            .with_constraints(cmd.constraints.clone())
            .with_column_defaults(cmd.column_defaults.clone());
        Ok(Arc::new(table))
    }
}
196
/// Returns the last extension of `path` with a leading dot (e.g. ".csv"),
/// or an empty string when the path has no extension.
///
/// Only the final component counts: `"a.csv.gz"` yields `".gz"`.
fn get_extension(path: &str) -> String {
    Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| format!(".{ext}"))
        .unwrap_or_default()
}
205
#[cfg(test)]
mod tests {
    use datafusion_execution::config::SessionConfig;
    use glob::Pattern;
    use std::collections::HashMap;
    use std::fs;
    use std::path::PathBuf;

    use super::*;
    use crate::{
        datasource::file_format::csv::CsvFormat, execution::context::SessionContext,
    };

    use datafusion_common::parsers::CompressionTypeVariant;
    use datafusion_common::{Constraints, DFSchema, TableReference};

    /// A single file with a non-standard extension (`.tbl`) should have that
    /// extension taken from the location rather than the format default.
    #[tokio::test]
    async fn test_create_using_non_std_file_ext() {
        let csv_file = tempfile::Builder::new()
            .prefix("foo")
            .suffix(".tbl")
            .tempfile()
            .unwrap();

        let factory = ListingTableFactory::new();
        let context = SessionContext::new();
        let state = context.state();
        let name = TableReference::bare("foo");
        let cmd = CreateExternalTable {
            name,
            location: csv_file.path().to_str().unwrap().to_string(),
            file_type: "csv".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options: HashMap::from([("format.has_header".into(), "true".into())]),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();
        let listing_options = listing_table.options();
        assert_eq!(".tbl", listing_options.file_extension);
    }

    /// Like the test above, but also verifies that format options
    /// (`schema_infer_max_rec`) are forwarded to the resulting `CsvFormat`.
    #[tokio::test]
    async fn test_create_using_non_std_file_ext_csv_options() {
        let csv_file = tempfile::Builder::new()
            .prefix("foo")
            .suffix(".tbl")
            .tempfile()
            .unwrap();

        let factory = ListingTableFactory::new();
        let context = SessionContext::new();
        let state = context.state();
        let name = TableReference::bare("foo");

        let mut options = HashMap::new();
        options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
        options.insert("format.has_header".into(), "true".into());
        let cmd = CreateExternalTable {
            name,
            location: csv_file.path().to_str().unwrap().to_string(),
            file_type: "csv".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();

        let format = listing_table.options().format.clone();
        let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
        let csv_options = csv_format.options().clone();
        assert_eq!(csv_options.schema_infer_max_rec, Some(1000));
        let listing_options = listing_table.options();
        assert_eq!(".tbl", listing_options.file_extension);
    }

    /// Validates that CreateExternalTable with compression
    /// searches for gzipped files in a directory location
    #[tokio::test]
    async fn test_create_using_folder_with_compression() {
        let dir = tempfile::tempdir().unwrap();

        let factory = ListingTableFactory::new();
        let context = SessionContext::new();
        let state = context.state();
        let name = TableReference::bare("foo");

        let mut options = HashMap::new();
        options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
        options.insert("format.has_header".into(), "true".into());
        options.insert("format.compression".into(), "gzip".into());
        let cmd = CreateExternalTable {
            name,
            location: dir.path().to_str().unwrap().to_string(),
            file_type: "csv".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();

        // Verify compression is used
        let format = listing_table.options().format.clone();
        let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
        let csv_options = csv_format.options().clone();
        assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP);

        // Folder locations keep an empty file_extension (see factory comment)
        let listing_options = listing_table.options();
        assert_eq!("", listing_options.file_extension);
        // Glob pattern is set to search for gzipped files
        let table_path = listing_table.table_paths().first().unwrap();
        assert_eq!(
            table_path.get_glob().clone().unwrap(),
            Pattern::new("*.csv.gz").unwrap()
        );
    }

    /// Validates that CreateExternalTable without compression
    /// searches for normal files in a directory location
    #[tokio::test]
    async fn test_create_using_folder_without_compression() {
        let dir = tempfile::tempdir().unwrap();

        let factory = ListingTableFactory::new();
        let context = SessionContext::new();
        let state = context.state();
        let name = TableReference::bare("foo");

        let mut options = HashMap::new();
        options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned());
        options.insert("format.has_header".into(), "true".into());
        let cmd = CreateExternalTable {
            name,
            location: dir.path().to_str().unwrap().to_string(),
            file_type: "csv".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();

        let listing_options = listing_table.options();
        assert_eq!("", listing_options.file_extension);
        // Glob pattern is set to search for plain (uncompressed) csv files
        let table_path = listing_table.table_paths().first().unwrap();
        assert_eq!(
            table_path.get_glob().clone().unwrap(),
            Pattern::new("*.csv").unwrap()
        );
    }

    /// Directory names containing dots (e.g. `odd.v1`) must not be mistaken
    /// for a file extension when the location is a folder.
    #[tokio::test]
    async fn test_odd_directory_names() {
        let dir = tempfile::tempdir().unwrap();
        let mut path = PathBuf::from(dir.path());
        path.extend(["odd.v1", "odd.v2"]);
        fs::create_dir_all(&path).unwrap();

        let factory = ListingTableFactory::new();
        let context = SessionContext::new();
        let state = context.state();
        let name = TableReference::bare("foo");

        let cmd = CreateExternalTable {
            name,
            location: String::from(path.to_str().unwrap()),
            file_type: "parquet".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options: HashMap::new(),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();

        let listing_options = listing_table.options();
        assert_eq!("", listing_options.file_extension);
    }

    /// Hive-style `key=value` directories are inferred as dictionary-typed
    /// partition columns; inference can be disabled via the
    /// `listing_table_factory_infer_partitions` config option.
    #[tokio::test]
    async fn test_create_with_hive_partitions() {
        let dir = tempfile::tempdir().unwrap();
        let mut path = PathBuf::from(dir.path());
        path.extend(["key1=value1", "key2=value2"]);
        fs::create_dir_all(&path).unwrap();
        path.push("data.parquet");
        fs::File::create_new(&path).unwrap();

        let factory = ListingTableFactory::new();
        let context = SessionContext::new();
        let state = context.state();
        let name = TableReference::bare("foo");

        let cmd = CreateExternalTable {
            name,
            location: dir.path().to_str().unwrap().to_string(),
            file_type: "parquet".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options: HashMap::new(),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();

        let listing_options = listing_table.options();
        let dtype =
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));
        let expected_cols = vec![
            (String::from("key1"), dtype.clone()),
            (String::from("key2"), dtype.clone()),
        ];
        assert_eq!(expected_cols, listing_options.table_partition_cols);

        // Ensure partition detection can be disabled via config
        let factory = ListingTableFactory::new();
        let mut cfg = SessionConfig::new();
        cfg.options_mut()
            .execution
            .listing_table_factory_infer_partitions = false;
        let context = SessionContext::new_with_config(cfg);
        let state = context.state();
        let name = TableReference::bare("foo");

        let cmd = CreateExternalTable {
            name,
            location: dir.path().to_str().unwrap().to_string(),
            file_type: "parquet".to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options: HashMap::new(),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        };
        let table_provider = factory.create(&state, &cmd).await.unwrap();
        let listing_table = table_provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .unwrap();

        let listing_options = listing_table.options();
        assert!(listing_options.table_partition_cols.is_empty());
    }
}