1use std::collections::HashSet;
21use std::path::Path;
22use std::sync::Arc;
23
24use crate::catalog::{TableProvider, TableProviderFactory};
25use crate::datasource::listing::{
26 ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
27};
28use crate::execution::context::SessionState;
29
30use arrow::datatypes::DataType;
31use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema};
32use datafusion_common::{config_datafusion_err, Result};
33use datafusion_expr::CreateExternalTable;
34
35use async_trait::async_trait;
36use datafusion_catalog::Session;
37
/// Stateless [`TableProviderFactory`] that turns `CREATE EXTERNAL TABLE`
/// commands into [`ListingTable`] providers.
#[derive(Default, Debug)]
pub struct ListingTableFactory {}
41
42impl ListingTableFactory {
43 pub fn new() -> Self {
45 Self::default()
46 }
47}
48
49#[async_trait]
50impl TableProviderFactory for ListingTableFactory {
51 async fn create(
52 &self,
53 state: &dyn Session,
54 cmd: &CreateExternalTable,
55 ) -> Result<Arc<dyn TableProvider>> {
56 let session_state = state.as_any().downcast_ref::<SessionState>().unwrap();
58 let file_format = session_state
59 .get_file_format_factory(cmd.file_type.as_str())
60 .ok_or(config_datafusion_err!(
61 "Unable to create table with format {}! Could not find FileFormat.",
62 cmd.file_type
63 ))?
64 .create(session_state, &cmd.options)?;
65
66 let mut table_path = ListingTableUrl::parse(&cmd.location)?;
67 let file_extension = match table_path.is_collection() {
68 true => "",
73 false => &get_extension(cmd.location.as_str()),
74 };
75 let mut options = ListingOptions::new(file_format)
76 .with_session_config_options(session_state.config())
77 .with_file_extension(file_extension);
78
79 let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() {
80 let infer_parts = session_state
81 .config_options()
82 .execution
83 .listing_table_factory_infer_partitions;
84 let part_cols = if cmd.table_partition_cols.is_empty() && infer_parts {
85 options
86 .infer_partitions(session_state, &table_path)
87 .await?
88 .into_iter()
89 } else {
90 cmd.table_partition_cols.clone().into_iter()
91 };
92
93 (
94 None,
95 part_cols
96 .map(|p| {
97 (
98 p,
99 DataType::Dictionary(
100 Box::new(DataType::UInt16),
101 Box::new(DataType::Utf8),
102 ),
103 )
104 })
105 .collect::<Vec<_>>(),
106 )
107 } else {
108 let schema = Arc::clone(cmd.schema.inner());
109 let table_partition_cols = cmd
110 .table_partition_cols
111 .iter()
112 .map(|col| {
113 schema
114 .field_with_name(col)
115 .map_err(|e| arrow_datafusion_err!(e))
116 })
117 .collect::<Result<Vec<_>>>()?
118 .into_iter()
119 .map(|f| (f.name().to_owned(), f.data_type().to_owned()))
120 .collect();
121 let mut project_idx = Vec::new();
125 for i in 0..schema.fields().len() {
126 if !cmd.table_partition_cols.contains(schema.field(i).name()) {
127 project_idx.push(i);
128 }
129 }
130 let schema = Arc::new(schema.project(&project_idx)?);
131 (Some(schema), table_partition_cols)
132 };
133
134 options = options.with_table_partition_cols(table_partition_cols);
135
136 options
137 .validate_partitions(session_state, &table_path)
138 .await?;
139
140 let resolved_schema = match provided_schema {
141 None => {
146 if table_path.is_folder() && table_path.get_glob().is_none() {
149 let glob = match options.format.compression_type() {
153 Some(compression) => {
154 match options.format.get_ext_with_compression(&compression) {
155 Ok(ext) => format!("*.{ext}"),
157 Err(_) => format!("*.{}", cmd.file_type.to_lowercase()),
159 }
160 }
161 None => format!("*.{}", cmd.file_type.to_lowercase()),
162 };
163 table_path = table_path.with_glob(glob.as_ref())?;
164 }
165 let schema = options.infer_schema(session_state, &table_path).await?;
166 let df_schema = Arc::clone(&schema).to_dfschema()?;
167 let column_refs: HashSet<_> = cmd
168 .order_exprs
169 .iter()
170 .flat_map(|sort| sort.iter())
171 .flat_map(|s| s.expr.column_refs())
172 .collect();
173
174 for column in &column_refs {
175 if !df_schema.has_column(column) {
176 return plan_err!("Column {column} is not in schema");
177 }
178 }
179
180 schema
181 }
182 Some(s) => s,
183 };
184 let config = ListingTableConfig::new(table_path)
185 .with_listing_options(options.with_file_sort_order(cmd.order_exprs.clone()))
186 .with_schema(resolved_schema);
187 let provider = ListingTable::try_new(config)?
188 .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache());
189 let table = provider
190 .with_definition(cmd.definition.clone())
191 .with_constraints(cmd.constraints.clone())
192 .with_column_defaults(cmd.column_defaults.clone());
193 Ok(Arc::new(table))
194 }
195}
196
/// Returns the extension of the last path component of `path`, prefixed with
/// `.` (e.g. `"data.tbl"` -> `".tbl"`).
///
/// Only the final extension is returned (`"a.csv.gz"` -> `".gz"`). An empty
/// string is returned when the path has no extension.
fn get_extension(path: &str) -> String {
    Path::new(path)
        .extension()
        .and_then(|os_ext| os_ext.to_str())
        .map_or_else(String::new, |ext| format!(".{ext}"))
}
205
#[cfg(test)]
mod tests {
    use datafusion_execution::config::SessionConfig;
    use glob::Pattern;
    use std::collections::HashMap;
    use std::fs;
    use std::path::PathBuf;

    use super::*;
    use crate::{
        datasource::file_format::csv::CsvFormat, execution::context::SessionContext,
    };

    use datafusion_common::parsers::CompressionTypeVariant;
    use datafusion_common::{Constraints, DFSchema, TableReference};

    /// Builds a `CREATE EXTERNAL TABLE foo` command for `location`.
    ///
    /// The schema is empty, so the factory infers it. There are no partition
    /// columns, ordering, constraints or column defaults.
    fn external_table_cmd(
        location: &str,
        file_type: &str,
        options: HashMap<String, String>,
    ) -> CreateExternalTable {
        CreateExternalTable {
            name: TableReference::bare("foo"),
            location: location.to_string(),
            file_type: file_type.to_string(),
            schema: Arc::new(DFSchema::empty()),
            table_partition_cols: vec![],
            if_not_exists: false,
            or_replace: false,
            temporary: false,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options,
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
        }
    }

    /// Runs a fresh `ListingTableFactory` for `cmd`, panicking on failure.
    async fn create_provider(
        state: &SessionState,
        cmd: &CreateExternalTable,
    ) -> Arc<dyn TableProvider> {
        ListingTableFactory::new().create(state, cmd).await.unwrap()
    }

    /// Downcasts a provider produced by the factory to the concrete `ListingTable`.
    fn as_listing_table(provider: &Arc<dyn TableProvider>) -> &ListingTable {
        provider
            .as_any()
            .downcast_ref::<ListingTable>()
            .expect("ListingTableFactory should produce a ListingTable")
    }

    #[tokio::test]
    async fn test_create_using_non_std_file_ext() {
        let csv_file = tempfile::Builder::new()
            .prefix("foo")
            .suffix(".tbl")
            .tempfile()
            .unwrap();

        let context = SessionContext::new();
        let state = context.state();
        let cmd = external_table_cmd(
            csv_file.path().to_str().unwrap(),
            "csv",
            HashMap::from([("format.has_header".into(), "true".into())]),
        );
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        assert_eq!(".tbl", listing_table.options().file_extension);
    }

    #[tokio::test]
    async fn test_create_using_non_std_file_ext_csv_options() {
        let csv_file = tempfile::Builder::new()
            .prefix("foo")
            .suffix(".tbl")
            .tempfile()
            .unwrap();

        let context = SessionContext::new();
        let state = context.state();
        let options = HashMap::from([
            ("format.schema_infer_max_rec".to_owned(), "1000".to_owned()),
            ("format.has_header".to_owned(), "true".to_owned()),
        ]);
        let cmd = external_table_cmd(csv_file.path().to_str().unwrap(), "csv", options);
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        let format = listing_table.options().format.clone();
        let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
        let csv_options = csv_format.options().clone();
        assert_eq!(csv_options.schema_infer_max_rec, Some(1000));
        assert_eq!(".tbl", listing_table.options().file_extension);
    }

    #[tokio::test]
    async fn test_create_using_folder_with_compression() {
        let dir = tempfile::tempdir().unwrap();

        let context = SessionContext::new();
        let state = context.state();
        let options = HashMap::from([
            ("format.schema_infer_max_rec".to_owned(), "1000".to_owned()),
            ("format.has_header".to_owned(), "true".to_owned()),
            ("format.compression".to_owned(), "gzip".to_owned()),
        ]);
        let cmd = external_table_cmd(dir.path().to_str().unwrap(), "csv", options);
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        let format = listing_table.options().format.clone();
        let csv_format = format.as_any().downcast_ref::<CsvFormat>().unwrap();
        let csv_options = csv_format.options().clone();
        assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP);

        // A folder gets no file extension, but a compression-aware glob.
        assert_eq!("", listing_table.options().file_extension);
        let table_path = listing_table.table_paths().first().unwrap();
        assert_eq!(
            table_path.get_glob().clone().unwrap(),
            Pattern::new("*.csv.gz").unwrap()
        );
    }

    #[tokio::test]
    async fn test_create_using_folder_without_compression() {
        let dir = tempfile::tempdir().unwrap();

        let context = SessionContext::new();
        let state = context.state();
        let options = HashMap::from([
            ("format.schema_infer_max_rec".to_owned(), "1000".to_owned()),
            ("format.has_header".to_owned(), "true".to_owned()),
        ]);
        let cmd = external_table_cmd(dir.path().to_str().unwrap(), "csv", options);
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        assert_eq!("", listing_table.options().file_extension);
        let table_path = listing_table.table_paths().first().unwrap();
        assert_eq!(
            table_path.get_glob().clone().unwrap(),
            Pattern::new("*.csv").unwrap()
        );
    }

    #[tokio::test]
    async fn test_odd_directory_names() {
        let dir = tempfile::tempdir().unwrap();
        let mut path = PathBuf::from(dir.path());
        path.extend(["odd.v1", "odd.v2"]);
        fs::create_dir_all(&path).unwrap();

        let context = SessionContext::new();
        let state = context.state();
        let cmd = external_table_cmd(path.to_str().unwrap(), "parquet", HashMap::new());
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        // A dotted directory name must not be mistaken for a file extension.
        assert_eq!("", listing_table.options().file_extension);
    }

    #[tokio::test]
    async fn test_create_with_hive_partitions() {
        let dir = tempfile::tempdir().unwrap();
        let mut path = PathBuf::from(dir.path());
        path.extend(["key1=value1", "key2=value2"]);
        fs::create_dir_all(&path).unwrap();
        path.push("data.parquet");
        fs::File::create_new(&path).unwrap();

        // With default settings, partition columns are inferred from the layout.
        let context = SessionContext::new();
        let state = context.state();
        let cmd = external_table_cmd(dir.path().to_str().unwrap(), "parquet", HashMap::new());
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        let dtype =
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));
        let expected_cols = vec![
            (String::from("key1"), dtype.clone()),
            (String::from("key2"), dtype.clone()),
        ];
        assert_eq!(expected_cols, listing_table.options().table_partition_cols);

        // With inference disabled, no partition columns are added.
        let mut cfg = SessionConfig::new();
        cfg.options_mut()
            .execution
            .listing_table_factory_infer_partitions = false;
        let context = SessionContext::new_with_config(cfg);
        let state = context.state();
        let cmd = external_table_cmd(dir.path().to_str().unwrap(), "parquet", HashMap::new());
        let table_provider = create_provider(&state, &cmd).await;
        let listing_table = as_listing_table(&table_provider);

        assert!(listing_table.options().table_partition_cols.is_empty());
    }
}