use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::{Arc, Weak};

use super::options::ReadOptions;
use crate::datasource::dynamic_file::DynamicListTableFactory;
use crate::execution::session_state::SessionStateBuilder;
use crate::{
    catalog::listing_schema::ListingSchemaProvider,
    catalog::{
        CatalogProvider, CatalogProviderList, TableProvider, TableProviderFactory,
    },
    dataframe::DataFrame,
    datasource::listing::{
        ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
    },
    datasource::{provider_as_source, MemTable, ViewTable},
    error::Result,
    execution::{
        options::ArrowReadOptions,
        runtime_env::{RuntimeEnv, RuntimeEnvBuilder},
        FunctionRegistry,
    },
    logical_expr::AggregateUDF,
    logical_expr::ScalarUDF,
    logical_expr::{
        CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
        CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
        DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
        TableType, UNNAMED_TABLE,
    },
    physical_expr::PhysicalExpr,
    physical_plan::ExecutionPlan,
    variable::{VarProvider, VarType},
};

pub use crate::execution::session_state::SessionState;

use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_catalog::memory::MemorySchemaProvider;
use datafusion_catalog::MemoryCatalogProvider;
use datafusion_catalog::{
    DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory,
};
use datafusion_common::config::ConfigOptions;
use datafusion_common::metadata::ScalarAndMetadata;
use datafusion_common::{
    config::{ConfigExtension, TableOptions},
    exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err,
    plan_datafusion_err, plan_err,
    tree_node::{TreeNodeRecursion, TreeNodeVisitor},
    DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference,
};
pub use datafusion_execution::config::SessionConfig;
use datafusion_execution::registry::SerializerRegistry;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{
    expr_rewriter::FunctionRewrite,
    logical_plan::{DdlStatement, Statement},
    planner::ExprPlanner,
    Expr, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion_optimizer::Analyzer;
use datafusion_optimizer::{AnalyzerRule, OptimizerRule};
use datafusion_session::SessionStore;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use object_store::ObjectStore;
use parking_lot::RwLock;
use url::Url;

mod csv;
mod json;
#[cfg(feature = "parquet")]
mod parquet;

#[cfg(feature = "avro")]
mod avro;

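/// Types that can be converted into one or more [`ListingTableUrl`]s, such as
/// `&str`, `String`, `&String`, and `Vec<P: AsRef<str>>`.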
pub trait DataFilePaths {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>>;
}

impl DataFilePaths for &str {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        Ok(vec![ListingTableUrl::parse(self)?])
    }
}

impl DataFilePaths for String {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        Ok(vec![ListingTableUrl::parse(self)?])
    }
}

impl DataFilePaths for &String {
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        Ok(vec![ListingTableUrl::parse(self)?])
    }
}

impl<P> DataFilePaths for Vec<P>
where
    P: AsRef<str>,
{
    fn to_urls(self) -> Result<Vec<ListingTableUrl>> {
        self.iter()
            .map(ListingTableUrl::parse)
            .collect::<Result<Vec<ListingTableUrl>>>()
    }
}

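/// Main interface for executing SQL queries and registering data sources.
///
/// A `SessionContext` wraps a shared [`SessionState`] behind an `Arc<RwLock<_>>`,
/// so cloning the context is cheap and all clones operate on the same state.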
#[derive(Clone)]
pub struct SessionContext {
    session_id: String,
    session_start_time: DateTime<Utc>,
    state: Arc<RwLock<SessionState>>,
}

impl Default for SessionContext {
    fn default() -> Self {
        Self::new()
    }
}

impl SessionContext {
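    /// Creates a new `SessionContext` with the default [`SessionConfig`] and
    /// a default [`RuntimeEnv`].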
    pub fn new() -> Self {
        Self::new_with_config(SessionConfig::new())
    }

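    /// Walks every registered catalog and schema and refreshes any
    /// [`ListingSchemaProvider`]s so newly added locations become visible.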
    pub async fn refresh_catalogs(&self) -> Result<()> {
        let cat_names = self.catalog_names().clone();
        for cat_name in cat_names.iter() {
            let cat = self
                .catalog(cat_name.as_str())
                .ok_or_else(|| internal_datafusion_err!("Catalog not found!"))?;
            for schema_name in cat.schema_names() {
                let schema = cat
                    .schema(schema_name.as_str())
                    .ok_or_else(|| internal_datafusion_err!("Schema not found!"))?;
                let lister = schema.as_any().downcast_ref::<ListingSchemaProvider>();
                if let Some(lister) = lister {
                    lister.refresh(&self.state()).await?;
                }
            }
        }
        Ok(())
    }

    pub fn new_with_config(config: SessionConfig) -> Self {
        let runtime = Arc::new(RuntimeEnv::default());
        Self::new_with_config_rt(config, runtime)
    }

    pub fn new_with_config_rt(config: SessionConfig, runtime: Arc<RuntimeEnv>) -> Self {
        let state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Self::new_with_state(state)
    }

    pub fn new_with_state(state: SessionState) -> Self {
        Self {
            session_id: state.session_id().to_string(),
            session_start_time: Utc::now(),
            state: Arc::new(RwLock::new(state)),
        }
    }

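    /// Returns a new `SessionContext` (with the same session id) whose catalog list
    /// is wrapped in a [`DynamicFileCatalog`], allowing table references that are
    /// URLs or file paths to be resolved as listing tables.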
    pub fn enable_url_table(self) -> Self {
        let current_catalog_list = Arc::clone(self.state.read().catalog_list());
        let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new()));
        let catalog_list = Arc::new(DynamicFileCatalog::new(
            current_catalog_list,
            Arc::clone(&factory) as Arc<dyn UrlTableFactory>,
        ));

        let session_id = self.session_id.clone();
        let ctx: SessionContext = self
            .into_state_builder()
            .with_session_id(session_id)
            .with_catalog_list(catalog_list)
            .build()
            .into();
        factory.session_store().with_state(ctx.state_weak_ref());
        ctx
    }

    pub fn into_state_builder(self) -> SessionStateBuilder {
        let SessionContext {
            session_id: _,
            session_start_time: _,
            state,
        } = self;
        let state = match Arc::try_unwrap(state) {
            Ok(rwlock) => rwlock.into_inner(),
            Err(state) => state.read().clone(),
        };
        SessionStateBuilder::from(state)
    }

    pub fn session_start_time(&self) -> DateTime<Utc> {
        self.session_start_time
    }

    pub fn with_function_factory(
        self,
        function_factory: Arc<dyn FunctionFactory>,
    ) -> Self {
        self.state.write().set_function_factory(function_factory);
        self
    }

    pub fn add_optimizer_rule(
        &self,
        optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
    ) {
        self.state.write().append_optimizer_rule(optimizer_rule);
    }

    pub fn add_analyzer_rule(&self, analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>) {
        self.state.write().add_analyzer_rule(analyzer_rule);
    }

    pub fn register_object_store(
        &self,
        url: &Url,
        object_store: Arc<dyn ObjectStore>,
    ) -> Option<Arc<dyn ObjectStore>> {
        self.runtime_env().register_object_store(url, object_store)
    }

    pub fn deregister_object_store(&self, url: &Url) -> Result<Arc<dyn ObjectStore>> {
        self.runtime_env().deregister_object_store(url)
    }

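    /// Registers a single [`RecordBatch`] as a table named `table_name`, returning
    /// any table previously registered under that name.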
    pub fn register_batch(
        &self,
        table_name: &str,
        batch: RecordBatch,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
        self.register_table(
            TableReference::Bare {
                table: table_name.into(),
            },
            Arc::new(table),
        )
    }

    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
        Arc::clone(self.state.read().runtime_env())
    }

    pub fn session_id(&self) -> String {
        self.session_id.clone()
    }

    pub fn table_factory(
        &self,
        file_type: &str,
    ) -> Option<Arc<dyn TableProviderFactory>> {
        self.state.read().table_factories().get(file_type).cloned()
    }

    pub fn enable_ident_normalization(&self) -> bool {
        self.state
            .read()
            .config()
            .options()
            .sql_parser
            .enable_ident_normalization
    }

    pub fn copied_config(&self) -> SessionConfig {
        self.state.read().config().clone()
    }

    pub fn copied_table_options(&self) -> TableOptions {
        self.state.read().default_table_options()
    }

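    /// Plans and returns a [`DataFrame`] for a SQL statement. All statement kinds
    /// (DDL, DML, `SET`, ...) are permitted; use [`Self::sql_with_options`] to
    /// restrict what is allowed.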
    #[cfg(feature = "sql")]
    pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
        self.sql_with_options(sql, SQLOptions::new()).await
    }

    #[cfg(feature = "sql")]
    pub async fn sql_with_options(
        &self,
        sql: &str,
        options: SQLOptions,
    ) -> Result<DataFrame> {
        let plan = self.state().create_logical_plan(sql).await?;
        options.verify_plan(&plan)?;

        self.execute_logical_plan(plan).await
    }

    #[cfg(feature = "sql")]
    pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result<Expr> {
        self.state.read().create_logical_expr(sql, df_schema)
    }

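    /// Executes a [`LogicalPlan`], eagerly carrying out DDL and statements such as
    /// `SET`, `PREPARE`, `EXECUTE` and `DEALLOCATE`, and returning a [`DataFrame`]
    /// for all other plans.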
    pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
        match plan {
            LogicalPlan::Ddl(ddl) => {
                match ddl {
                    DdlStatement::CreateExternalTable(cmd) => {
                        (Box::pin(async move { self.create_external_table(&cmd).await })
                            as std::pin::Pin<Box<dyn futures::Future<Output = _> + Send>>)
                            .await
                    }
                    DdlStatement::CreateMemoryTable(cmd) => {
                        Box::pin(self.create_memory_table(cmd)).await
                    }
                    DdlStatement::CreateView(cmd) => {
                        Box::pin(self.create_view(cmd)).await
                    }
                    DdlStatement::CreateCatalogSchema(cmd) => {
                        Box::pin(self.create_catalog_schema(cmd)).await
                    }
                    DdlStatement::CreateCatalog(cmd) => {
                        Box::pin(self.create_catalog(cmd)).await
                    }
                    DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)).await,
                    DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)).await,
                    DdlStatement::DropCatalogSchema(cmd) => {
                        Box::pin(self.drop_schema(cmd)).await
                    }
                    DdlStatement::CreateFunction(cmd) => {
                        Box::pin(self.create_function(cmd)).await
                    }
                    DdlStatement::DropFunction(cmd) => {
                        Box::pin(self.drop_function(cmd)).await
                    }
                    ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))),
                }
            }
            LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
                self.set_variable(stmt).await
            }
            LogicalPlan::Statement(Statement::Prepare(Prepare {
                name,
                input,
                fields,
            })) => {
                if !fields.is_empty() {
                    let param_names = input.get_parameter_names()?;
                    if param_names.len() != fields.len() {
                        return plan_err!(
                            "Prepare specifies {} data types but query has {} parameters",
                            fields.len(),
                            param_names.len()
                        );
                    }
                }
                self.state.write().store_prepared(name, fields, input)?;
                self.return_empty_dataframe()
            }
            LogicalPlan::Statement(Statement::Execute(execute)) => {
                self.execute_prepared(execute)
            }
            LogicalPlan::Statement(Statement::Deallocate(deallocate)) => {
                self.state
                    .write()
                    .remove_prepared(deallocate.name.as_str())?;
                self.return_empty_dataframe()
            }
            plan => Ok(DataFrame::new(self.state(), plan)),
        }
    }

    pub fn create_physical_expr(
        &self,
        expr: Expr,
        df_schema: &DFSchema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        self.state.read().create_physical_expr(expr, df_schema)
    }

    fn return_empty_dataframe(&self) -> Result<DataFrame> {
        let plan = LogicalPlanBuilder::empty(false).build()?;
        Ok(DataFrame::new(self.state(), plan))
    }

    async fn create_external_table(
        &self,
        cmd: &CreateExternalTable,
    ) -> Result<DataFrame> {
        let exist = self.table_exist(cmd.name.clone())?;

        if cmd.temporary {
            return not_impl_err!("Temporary tables not supported");
        }

        match (cmd.if_not_exists, cmd.or_replace, exist) {
            (true, false, true) => self.return_empty_dataframe(),
            (false, true, true) => {
                let result = self
                    .find_and_deregister(cmd.name.clone(), TableType::Base)
                    .await;

                match result {
                    Ok(true) => {
                        let table_provider: Arc<dyn TableProvider> =
                            self.create_custom_table(cmd).await?;
                        self.register_table(cmd.name.clone(), table_provider)?;
                        self.return_empty_dataframe()
                    }
                    Ok(false) => {
                        let table_provider: Arc<dyn TableProvider> =
                            self.create_custom_table(cmd).await?;
                        self.register_table(cmd.name.clone(), table_provider)?;
                        self.return_empty_dataframe()
                    }
                    Err(e) => {
                        exec_err!("Errored while deregistering external table: {}", e)
                    }
                }
            }
            (true, true, true) => {
                exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
            }
            (_, _, false) => {
                let table_provider: Arc<dyn TableProvider> =
                    self.create_custom_table(cmd).await?;
                self.register_table(cmd.name.clone(), table_provider)?;
                self.return_empty_dataframe()
            }
            (false, false, true) => {
                exec_err!("External table '{}' already exists", cmd.name)
            }
        }
    }

    async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result<DataFrame> {
        let CreateMemoryTable {
            name,
            input,
            if_not_exists,
            or_replace,
            constraints,
            column_defaults,
            temporary,
        } = cmd;

        let input = Arc::unwrap_or_clone(input);
        let input = self.state().optimize(&input)?;
        if temporary {
            return not_impl_err!("Temporary tables not supported");
        }

        let table = self.table(name.clone()).await;
        match (if_not_exists, or_replace, table) {
            (true, false, Ok(_)) => self.return_empty_dataframe(),
            (false, true, Ok(_)) => {
                self.deregister_table(name.clone())?;
                let schema = Arc::clone(input.schema().inner());
                let physical = DataFrame::new(self.state(), input);

                let batches: Vec<_> = physical.collect_partitioned().await?;
                let table = Arc::new(
                    MemTable::try_new(schema, batches)?
                        .with_constraints(constraints)
                        .with_column_defaults(column_defaults.into_iter().collect()),
                );

                self.register_table(name.clone(), table)?;
                self.return_empty_dataframe()
            }
            (true, true, Ok(_)) => {
                exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'")
            }
            (_, _, Err(_)) => {
                let schema = Arc::clone(input.schema().inner());
                let physical = DataFrame::new(self.state(), input);

                let batches: Vec<_> = physical.collect_partitioned().await?;
                let table = Arc::new(
                    MemTable::try_new(schema, batches)?
                        .with_constraints(constraints)
                        .with_column_defaults(column_defaults.into_iter().collect()),
                );

                self.register_table(name, table)?;
                self.return_empty_dataframe()
            }
            (false, false, Ok(_)) => exec_err!("Table '{name}' already exists"),
        }
    }

    fn apply_type_coercion(logical_plan: LogicalPlan) -> Result<LogicalPlan> {
        let options = ConfigOptions::default();
        Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]).execute_and_check(
            logical_plan,
            &options,
            |_, _| {},
        )
    }

    async fn create_view(&self, cmd: CreateView) -> Result<DataFrame> {
        let CreateView {
            name,
            input,
            or_replace,
            definition,
            temporary,
        } = cmd;

        let view = self.table(name.clone()).await;

        if temporary {
            return not_impl_err!("Temporary views not supported");
        }

        match (or_replace, view) {
            (true, Ok(_)) => {
                self.deregister_table(name.clone())?;
                let input = Self::apply_type_coercion(input.as_ref().clone())?;
                let table = Arc::new(ViewTable::new(input, definition));
                self.register_table(name, table)?;
                self.return_empty_dataframe()
            }
            (_, Err(_)) => {
                let input = Self::apply_type_coercion(input.as_ref().clone())?;
                let table = Arc::new(ViewTable::new(input, definition));
                self.register_table(name, table)?;
                self.return_empty_dataframe()
            }
            (false, Ok(_)) => exec_err!("Table '{name}' already exists"),
        }
    }

    async fn create_catalog_schema(&self, cmd: CreateCatalogSchema) -> Result<DataFrame> {
        let CreateCatalogSchema {
            schema_name,
            if_not_exists,
            ..
        } = cmd;

        let tokens: Vec<&str> = schema_name.split('.').collect();
        let (catalog, schema_name) = match tokens.len() {
            1 => {
                let state = self.state.read();
                let name = &state.config().options().catalog.default_catalog;
                let catalog = state.catalog_list().catalog(name).ok_or_else(|| {
                    exec_datafusion_err!("Missing default catalog '{name}'")
                })?;
                (catalog, tokens[0])
            }
            2 => {
                let name = &tokens[0];
                let catalog = self
                    .catalog(name)
                    .ok_or_else(|| exec_datafusion_err!("Missing catalog '{name}'"))?;
                (catalog, tokens[1])
            }
            _ => return exec_err!("Unable to parse catalog from {schema_name}"),
        };
        let schema = catalog.schema(schema_name);

        match (if_not_exists, schema) {
            (true, Some(_)) => self.return_empty_dataframe(),
            (true, None) | (false, None) => {
                let schema = Arc::new(MemorySchemaProvider::new());
                catalog.register_schema(schema_name, schema)?;
                self.return_empty_dataframe()
            }
            (false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"),
        }
    }

    async fn create_catalog(&self, cmd: CreateCatalog) -> Result<DataFrame> {
        let CreateCatalog {
            catalog_name,
            if_not_exists,
            ..
        } = cmd;
        let catalog = self.catalog(catalog_name.as_str());

        match (if_not_exists, catalog) {
            (true, Some(_)) => self.return_empty_dataframe(),
            (true, None) | (false, None) => {
                let new_catalog = Arc::new(MemoryCatalogProvider::new());
                self.state
                    .write()
                    .catalog_list()
                    .register_catalog(catalog_name, new_catalog);
                self.return_empty_dataframe()
            }
            (false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"),
        }
    }

    async fn drop_table(&self, cmd: DropTable) -> Result<DataFrame> {
        let DropTable {
            name, if_exists, ..
        } = cmd;
        let result = self
            .find_and_deregister(name.clone(), TableType::Base)
            .await;
        match (result, if_exists) {
            (Ok(true), _) => self.return_empty_dataframe(),
            (_, true) => self.return_empty_dataframe(),
            (_, _) => exec_err!("Table '{name}' doesn't exist."),
        }
    }

    async fn drop_view(&self, cmd: DropView) -> Result<DataFrame> {
        let DropView {
            name, if_exists, ..
        } = cmd;
        let result = self
            .find_and_deregister(name.clone(), TableType::View)
            .await;
        match (result, if_exists) {
            (Ok(true), _) => self.return_empty_dataframe(),
            (_, true) => self.return_empty_dataframe(),
            (_, _) => exec_err!("View '{name}' doesn't exist."),
        }
    }

    async fn drop_schema(&self, cmd: DropCatalogSchema) -> Result<DataFrame> {
        let DropCatalogSchema {
            name,
            if_exists: allow_missing,
            cascade,
            schema: _,
        } = cmd;
        let catalog = {
            let state = self.state.read();
            let catalog_name = match &name {
                SchemaReference::Full { catalog, .. } => catalog.to_string(),
                SchemaReference::Bare { .. } => {
                    state.config_options().catalog.default_catalog.to_string()
                }
            };
            if let Some(catalog) = state.catalog_list().catalog(&catalog_name) {
                catalog
            } else if allow_missing {
                return self.return_empty_dataframe();
            } else {
                return self.schema_doesnt_exist_err(name);
            }
        };
        let dereg = catalog.deregister_schema(name.schema_name(), cascade)?;
        match (dereg, allow_missing) {
            (None, true) => self.return_empty_dataframe(),
            (None, false) => self.schema_doesnt_exist_err(name),
            (Some(_), _) => self.return_empty_dataframe(),
        }
    }

    fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result<DataFrame> {
        exec_err!("Schema '{schemaref}' doesn't exist.")
    }

    async fn set_variable(&self, stmt: SetVariable) -> Result<DataFrame> {
        let SetVariable {
            variable, value, ..
        } = stmt;

        if variable.starts_with("datafusion.runtime.") {
            self.set_runtime_variable(&variable, &value)?;
        } else {
            let mut state = self.state.write();
            state.config_mut().options_mut().set(&variable, &value)?;

            let config_options = state.config().options();

            let udfs_to_update: Vec<_> = state
                .scalar_functions()
                .values()
                .filter_map(|udf| {
                    udf.inner()
                        .with_updated_config(config_options)
                        .map(Arc::new)
                })
                .collect();

            for udf in udfs_to_update {
                state.register_udf(udf)?;
            }

            drop(state);
        }

        self.return_empty_dataframe()
    }

    fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> {
        let key = variable.strip_prefix("datafusion.runtime.").unwrap();

        let mut state = self.state.write();

        let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env());
        builder = match key {
            "memory_limit" => {
                let memory_limit = Self::parse_memory_limit(value)?;
                builder.with_memory_limit(memory_limit, 1.0)
            }
            "max_temp_directory_size" => {
                let directory_size = Self::parse_memory_limit(value)?;
                builder.with_max_temp_directory_size(directory_size as u64)
            }
            "temp_directory" => builder.with_temp_file_path(value),
            "metadata_cache_limit" => {
                let limit = Self::parse_memory_limit(value)?;
                builder.with_metadata_cache_limit(limit)
            }
            _ => return plan_err!("Unknown runtime configuration: {variable}"),
        };

        *state = SessionStateBuilder::from(state.clone())
            .with_runtime_env(Arc::new(builder.build()?))
            .build();

        Ok(())
    }

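    /// Parses a human readable memory limit such as `"512K"`, `"64M"` or `"1.5G"`
    /// into a number of bytes.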
    pub fn parse_memory_limit(limit: &str) -> Result<usize> {
        let (number, unit) = limit.split_at(limit.len() - 1);
        let number: f64 = number.parse().map_err(|_| {
            plan_datafusion_err!("Failed to parse number from memory limit '{limit}'")
        })?;

        match unit {
            "K" => Ok((number * 1024.0) as usize),
            "M" => Ok((number * 1024.0 * 1024.0) as usize),
            "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
            _ => plan_err!("Unsupported unit '{unit}' in memory limit '{limit}'"),
        }
    }

    async fn create_custom_table(
        &self,
        cmd: &CreateExternalTable,
    ) -> Result<Arc<dyn TableProvider>> {
        let state = self.state.read().clone();
        let file_type = cmd.file_type.to_uppercase();
        let factory =
            state
                .table_factories()
                .get(file_type.as_str())
                .ok_or_else(|| {
                    exec_datafusion_err!("Unable to find factory for {}", cmd.file_type)
                })?;
        let table = (*factory).create(&state, cmd).await?;
        Ok(table)
    }

    async fn find_and_deregister(
        &self,
        table_ref: impl Into<TableReference>,
        table_type: TableType,
    ) -> Result<bool> {
        let table_ref = table_ref.into();
        let table = table_ref.table().to_owned();
        let maybe_schema = {
            let state = self.state.read();
            let resolved = state.resolve_table_ref(table_ref);
            state
                .catalog_list()
                .catalog(&resolved.catalog)
                .and_then(|c| c.schema(&resolved.schema))
        };

        if let Some(schema) = maybe_schema {
            if let Some(table_provider) = schema.table(&table).await? {
                if table_provider.table_type() == table_type {
                    schema.deregister_table(&table)?;
                    return Ok(true);
                }
            }
        }

        Ok(false)
    }

    async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
        let function = {
            let state = self.state.read().clone();
            let function_factory = state.function_factory();

            match function_factory {
                Some(f) => f.create(&state, stmt).await?,
                _ => {
                    return Err(DataFusionError::Configuration(
                        "Function factory has not been configured".to_string(),
                    ))
                }
            }
        };

        match function {
            RegisterFunction::Scalar(f) => {
                self.state.write().register_udf(f)?;
            }
            RegisterFunction::Aggregate(f) => {
                self.state.write().register_udaf(f)?;
            }
            RegisterFunction::Window(f) => {
                self.state.write().register_udwf(f)?;
            }
            RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
        };

        self.return_empty_dataframe()
    }

    async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
        let mut dropped = false;
        dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some();
        dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some();
        dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some();
        dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some();

        if !stmt.if_exists && !dropped {
            exec_err!("Function does not exist")
        } else {
            self.return_empty_dataframe()
        }
    }

    fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
        let Execute {
            name, parameters, ..
        } = execute;
        let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
            exec_datafusion_err!("Prepared statement '{}' does not exist", name)
        })?;

        let mut params: Vec<ScalarAndMetadata> = parameters
            .into_iter()
            .map(|e| match e {
                Expr::Literal(scalar, metadata) => {
                    Ok(ScalarAndMetadata::new(scalar, metadata))
                }
                _ => not_impl_err!("Unsupported parameter type: {}", e),
            })
            .collect::<Result<_>>()?;

        if !prepared.fields.is_empty() {
            if params.len() != prepared.fields.len() {
                return exec_err!(
                    "Prepared statement '{}' expects {} parameters, but {} provided",
                    name,
                    prepared.fields.len(),
                    params.len()
                );
            }
            params = params
                .into_iter()
                .zip(prepared.fields.iter())
                .map(|(e, dt)| -> Result<_> { e.cast_storage_to(dt.data_type()) })
                .collect::<Result<_>>()?;
        }

        let params = ParamValues::List(params);
        let plan = prepared
            .plan
            .as_ref()
            .clone()
            .replace_params_with_values(&params)?;
        Ok(DataFrame::new(self.state(), plan))
    }

    pub fn register_variable(
        &self,
        variable_type: VarType,
        provider: Arc<dyn VarProvider + Send + Sync>,
    ) {
        self.state
            .write()
            .execution_props_mut()
            .add_var_provider(variable_type, provider);
    }

    pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
        self.state.write().register_udtf(name, fun)
    }

    pub fn register_udf(&self, f: ScalarUDF) {
        let mut state = self.state.write();
        state.register_udf(Arc::new(f)).ok();
    }

    pub fn register_udaf(&self, f: AggregateUDF) {
        self.state.write().register_udaf(Arc::new(f)).ok();
    }

    pub fn register_udwf(&self, f: WindowUDF) {
        self.state.write().register_udwf(Arc::new(f)).ok();
    }

    pub fn deregister_udf(&self, name: &str) {
        self.state.write().deregister_udf(name).ok();
    }

    pub fn deregister_udaf(&self, name: &str) {
        self.state.write().deregister_udaf(name).ok();
    }

    pub fn deregister_udwf(&self, name: &str) {
        self.state.write().deregister_udwf(name).ok();
    }

    pub fn deregister_udtf(&self, name: &str) {
        self.state.write().deregister_udtf(name).ok();
    }

    async fn _read_type<'a, P: DataFilePaths>(
        &self,
        table_paths: P,
        options: impl ReadOptions<'a>,
    ) -> Result<DataFrame> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str())
                && !path.is_collection()
            {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
            .await?;
        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let provider = ListingTable::try_new(config)?;
        self.read_table(Arc::new(provider))
    }

    pub async fn read_arrow<P: DataFilePaths>(
        &self,
        table_paths: P,
        options: ArrowReadOptions<'_>,
    ) -> Result<DataFrame> {
        self._read_type(table_paths, options).await
    }

    pub fn read_empty(&self) -> Result<DataFrame> {
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::empty(true).build()?,
        ))
    }

    pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
                .build()?,
        ))
    }

    pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
        let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::scan(
                UNNAMED_TABLE,
                provider_as_source(Arc::new(provider)),
                None,
            )?
            .build()?,
        ))
    }
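
    /// Creates a [`DataFrame`] from an iterator of [`RecordBatch`]es sharing the
    /// same schema; an empty iterator yields an empty `DataFrame`.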
    pub fn read_batches(
        &self,
        batches: impl IntoIterator<Item = RecordBatch>,
    ) -> Result<DataFrame> {
        let mut batches = batches.into_iter().peekable();
        let schema = if let Some(batch) = batches.peek() {
            batch.schema()
        } else {
            Arc::new(Schema::empty())
        };
        let provider = MemTable::try_new(schema, vec![batches.collect()])?;
        Ok(DataFrame::new(
            self.state(),
            LogicalPlanBuilder::scan(
                UNNAMED_TABLE,
                provider_as_source(Arc::new(provider)),
                None,
            )?
            .build()?,
        ))
    }
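
    /// Registers a [`ListingTable`] over the files at `table_path` under the name
    /// `table_ref`, inferring the schema if `provided_schema` is `None`.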
    pub async fn register_listing_table(
        &self,
        table_ref: impl Into<TableReference>,
        table_path: impl AsRef<str>,
        options: ListingOptions,
        provided_schema: Option<SchemaRef>,
        sql_definition: Option<String>,
    ) -> Result<()> {
        let table_path = ListingTableUrl::parse(table_path)?;
        let resolved_schema = match provided_schema {
            Some(s) => s,
            None => options.infer_schema(&self.state(), &table_path).await?,
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?.with_definition(sql_definition);
        self.register_table(table_ref, Arc::new(table))?;
        Ok(())
    }

    fn register_type_check<P: DataFilePaths>(
        &self,
        table_paths: P,
        extension: impl AsRef<str>,
    ) -> Result<()> {
        let table_paths = table_paths.to_urls()?;
        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        let extension = extension.as_ref();
        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(extension) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{extension}'"
                );
            }
        }
        Ok(())
    }

    pub async fn register_arrow(
        &self,
        name: &str,
        table_path: &str,
        options: ArrowReadOptions<'_>,
    ) -> Result<()> {
        let listing_options = options
            .to_listing_options(&self.copied_config(), self.copied_table_options());

        self.register_listing_table(
            name,
            table_path,
            listing_options,
            options.schema.map(|s| Arc::new(s.to_owned())),
            None,
        )
        .await?;
        Ok(())
    }

    pub fn register_catalog(
        &self,
        name: impl Into<String>,
        catalog: Arc<dyn CatalogProvider>,
    ) -> Option<Arc<dyn CatalogProvider>> {
        let name = name.into();
        self.state
            .read()
            .catalog_list()
            .register_catalog(name, catalog)
    }

    pub fn catalog_names(&self) -> Vec<String> {
        self.state.read().catalog_list().catalog_names()
    }

    pub fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
        self.state.read().catalog_list().catalog(name)
    }

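    /// Registers a [`TableProvider`] under the given table reference, returning any
    /// provider previously registered under that name.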
    pub fn register_table(
        &self,
        table_ref: impl Into<TableReference>,
        provider: Arc<dyn TableProvider>,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        let table_ref: TableReference = table_ref.into();
        let table = table_ref.table().to_owned();
        self.state
            .read()
            .schema_for_ref(table_ref)?
            .register_table(table, provider)
    }

    pub fn deregister_table(
        &self,
        table_ref: impl Into<TableReference>,
    ) -> Result<Option<Arc<dyn TableProvider>>> {
        let table_ref = table_ref.into();
        let table = table_ref.table().to_owned();
        self.state
            .read()
            .schema_for_ref(table_ref)?
            .deregister_table(&table)
    }

    pub fn table_exist(&self, table_ref: impl Into<TableReference>) -> Result<bool> {
        let table_ref: TableReference = table_ref.into();
        let table = table_ref.table();
        let table_ref = table_ref.clone();
        Ok(self
            .state
            .read()
            .schema_for_ref(table_ref)?
            .table_exist(table))
    }

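    /// Returns a [`DataFrame`] that scans the table registered as `table_ref`, or an
    /// error if no such table exists.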
    pub async fn table(&self, table_ref: impl Into<TableReference>) -> Result<DataFrame> {
        let table_ref: TableReference = table_ref.into();
        let provider = self.table_provider(table_ref.clone()).await?;
        let plan = LogicalPlanBuilder::scan(
            table_ref,
            provider_as_source(Arc::clone(&provider)),
            None,
        )?
        .build()?;
        Ok(DataFrame::new(self.state(), plan))
    }

    pub fn table_function(&self, name: &str) -> Result<Arc<TableFunction>> {
        self.state
            .read()
            .table_functions()
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found"))
    }

    pub async fn table_provider(
        &self,
        table_ref: impl Into<TableReference>,
    ) -> Result<Arc<dyn TableProvider>> {
        let table_ref = table_ref.into();
        let table = table_ref.table().to_string();
        let schema = self.state.read().schema_for_ref(table_ref)?;
        match schema.table(&table).await? {
            Some(ref provider) => Ok(Arc::clone(provider)),
            _ => plan_err!("No table named '{table}'"),
        }
    }

    pub fn task_ctx(&self) -> Arc<TaskContext> {
        Arc::new(TaskContext::from(self))
    }

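    /// Returns a snapshot (clone) of the current [`SessionState`], with execution
    /// start marked on the snapshot.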
    pub fn state(&self) -> SessionState {
        let mut state = self.state.read().clone();
        state.mark_start_execution();
        state
    }

    pub fn state_ref(&self) -> Arc<RwLock<SessionState>> {
        Arc::clone(&self.state)
    }

    pub fn state_weak_ref(&self) -> Weak<RwLock<SessionState>> {
        Arc::downgrade(&self.state)
    }

    pub fn register_catalog_list(&self, catalog_list: Arc<dyn CatalogProviderList>) {
        self.state.write().register_catalog_list(catalog_list)
    }

    pub fn register_table_options_extension<T: ConfigExtension>(&self, extension: T) {
        self.state
            .write()
            .register_table_options_extension(extension)
    }
}

impl FunctionRegistry for SessionContext {
    fn udfs(&self) -> HashSet<String> {
        self.state.read().udfs()
    }

    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        self.state.read().udf(name)
    }

    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.state.read().udaf(name)
    }

    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.state.read().udwf(name)
    }

    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        self.state.write().register_udf(udf)
    }

    fn register_udaf(
        &mut self,
        udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        self.state.write().register_udaf(udaf)
    }

    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        self.state.write().register_udwf(udwf)
    }

    fn register_function_rewrite(
        &mut self,
        rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
    ) -> Result<()> {
        self.state.write().register_function_rewrite(rewrite)
    }

    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
        self.state.read().expr_planners().to_vec()
    }

    fn register_expr_planner(
        &mut self,
        expr_planner: Arc<dyn ExprPlanner>,
    ) -> Result<()> {
        self.state.write().register_expr_planner(expr_planner)
    }

    fn udafs(&self) -> HashSet<String> {
        self.state.read().udafs()
    }

    fn udwfs(&self) -> HashSet<String> {
        self.state.read().udwfs()
    }
}

impl From<&SessionContext> for TaskContext {
    fn from(session: &SessionContext) -> Self {
        TaskContext::from(&*session.state.read())
    }
}

impl From<SessionState> for SessionContext {
    fn from(state: SessionState) -> Self {
        Self::new_with_state(state)
    }
}

impl From<SessionContext> for SessionStateBuilder {
    fn from(session: SessionContext) -> Self {
        session.into_state_builder()
    }
}

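/// Converts a [`LogicalPlan`] into an [`ExecutionPlan`]; implement this trait to
/// customize physical planning for a session.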
#[async_trait]
pub trait QueryPlanner: Debug {
    async fn create_physical_plan(
        &self,
        logical_plan: &LogicalPlan,
        session_state: &SessionState,
    ) -> Result<Arc<dyn ExecutionPlan>>;
}

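/// Handles `CREATE FUNCTION` statements by turning them into a [`RegisterFunction`];
/// install an implementation with [`SessionContext::with_function_factory`].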
#[async_trait]
pub trait FunctionFactory: Debug + Sync + Send {
    async fn create(
        &self,
        state: &SessionState,
        statement: CreateFunction,
    ) -> Result<RegisterFunction>;
}

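/// The kind of user defined function produced by a [`FunctionFactory`].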
#[derive(Debug, Clone)]
pub enum RegisterFunction {
    Scalar(Arc<ScalarUDF>),
    Aggregate(Arc<AggregateUDF>),
    Window(Arc<WindowUDF>),
    Table(String, Arc<dyn TableFunctionImpl>),
}

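/// A [`SerializerRegistry`] that supports no user defined logical plan nodes;
/// every call returns a "not implemented" error.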
#[derive(Debug)]
pub struct EmptySerializerRegistry;

impl SerializerRegistry for EmptySerializerRegistry {
    fn serialize_logical_plan(
        &self,
        node: &dyn UserDefinedLogicalNode,
    ) -> Result<Vec<u8>> {
        not_impl_err!(
            "Serializing user defined logical plan node `{}` is not supported",
            node.name()
        )
    }

    fn deserialize_logical_plan(
        &self,
        name: &str,
        _bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
        not_impl_err!(
            "Deserializing user defined logical plan node `{name}` is not supported"
        )
    }
}

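/// Controls which kinds of SQL statements [`SessionContext::sql_with_options`]
/// accepts: DDL, DML (including `COPY`), and other statements such as `SET`.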
#[derive(Clone, Debug, Copy)]
pub struct SQLOptions {
    allow_ddl: bool,
    allow_dml: bool,
    allow_statements: bool,
}

impl Default for SQLOptions {
    fn default() -> Self {
        Self {
            allow_ddl: true,
            allow_dml: true,
            allow_statements: true,
        }
    }
}

impl SQLOptions {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn with_allow_ddl(mut self, allow: bool) -> Self {
        self.allow_ddl = allow;
        self
    }

    pub fn with_allow_dml(mut self, allow: bool) -> Self {
        self.allow_dml = allow;
        self
    }

    pub fn with_allow_statements(mut self, allow: bool) -> Self {
        self.allow_statements = allow;
        self
    }

    pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
        plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?;
        Ok(())
    }
}

struct BadPlanVisitor<'a> {
    options: &'a SQLOptions,
}
impl<'a> BadPlanVisitor<'a> {
    fn new(options: &'a SQLOptions) -> Self {
        Self { options }
    }
}

impl<'n> TreeNodeVisitor<'n> for BadPlanVisitor<'_> {
    type Node = LogicalPlan;

    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
        match node {
            LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
                plan_err!("DDL not supported: {}", ddl.name())
            }
            LogicalPlan::Dml(dml) if !self.options.allow_dml => {
                plan_err!("DML not supported: {}", dml.op)
            }
            LogicalPlan::Copy(_) if !self.options.allow_dml => {
                plan_err!("DML not supported: COPY")
            }
            LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
                plan_err!("Statement not supported: {}", stmt.name())
            }
            _ => Ok(TreeNodeRecursion::Continue),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{super::options::CsvReadOptions, *};
    use crate::execution::memory_pool::MemoryConsumer;
    use crate::test;
    use crate::test_util::{plan_and_collect, populate_csv_partitions};
    use arrow::datatypes::{DataType, TimeUnit};
    use datafusion_common::DataFusionError;
    use std::error::Error;
    use std::path::PathBuf;

    use datafusion_common::test_util::batches_to_string;
    use datafusion_common_runtime::SpawnedTask;
    use insta::{allow_duplicates, assert_snapshot};

    use crate::catalog::SchemaProvider;
    use crate::execution::session_state::SessionStateBuilder;
    use crate::physical_planner::PhysicalPlanner;
    use async_trait::async_trait;
    use datafusion_expr::planner::TypePlanner;
    use sqlparser::ast;
    use tempfile::TempDir;

    #[tokio::test]
    async fn shared_memory_and_disk_manager() {
        let ctx1 = SessionContext::new();

        let memory_pool = ctx1.runtime_env().memory_pool.clone();

        let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
        reservation.grow(100);

        let disk_manager = ctx1.runtime_env().disk_manager.clone();

        let ctx2 =
            SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env());

        assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100);
        assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100);

        drop(reservation);

        assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 0);
        assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 0);

        assert!(std::ptr::eq(
            Arc::as_ptr(&disk_manager),
            Arc::as_ptr(&ctx1.runtime_env().disk_manager)
        ));
        assert!(std::ptr::eq(
            Arc::as_ptr(&disk_manager),
            Arc::as_ptr(&ctx2.runtime_env().disk_manager)
        ));
    }

    #[tokio::test]
    async fn create_variable_expr() -> Result<()> {
        let tmp_dir = TempDir::new()?;
        let partition_count = 4;
        let ctx = create_ctx(&tmp_dir, partition_count).await?;

        let variable_provider = test::variable::SystemVar::new();
        ctx.register_variable(VarType::System, Arc::new(variable_provider));
        let variable_provider = test::variable::UserDefinedVar::new();
        ctx.register_variable(VarType::UserDefined, Arc::new(variable_provider));

        let provider = test::create_table_dual();
        ctx.register_table("dual", provider)?;

        let results =
            plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
                .await?;

        assert_snapshot!(batches_to_string(&results), @r"
        +----------------------+------------------------+---------------------+
        | @@version            | @name                  | @integer + Int64(1) |
        +----------------------+------------------------+---------------------+
        | system-var-@@version | user-defined-var-@name | 42                  |
        +----------------------+------------------------+---------------------+
        ");

        Ok(())
    }

    #[tokio::test]
    async fn create_variable_err() -> Result<()> {
        let ctx = SessionContext::new();

        let err = plan_and_collect(&ctx, "SElECT @= X3").await.unwrap_err();
        assert_eq!(
            err.strip_backtrace(),
            "Error during planning: variable [\"@=\"] has no type information"
        );
        Ok(())
    }

    #[tokio::test]
    async fn register_deregister() -> Result<()> {
        let tmp_dir = TempDir::new()?;
        let partition_count = 4;
        let ctx = create_ctx(&tmp_dir, partition_count).await?;

        let provider = test::create_table_dual();
        ctx.register_table("dual", provider)?;

        assert!(ctx.deregister_table("dual")?.is_some());
        assert!(ctx.deregister_table("dual")?.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn send_context_to_threads() -> Result<()> {
        let tmp_dir = TempDir::new()?;
        let partition_count = 4;
        let ctx = Arc::new(create_ctx(&tmp_dir, partition_count).await?);

        let threads: Vec<_> = (0..2)
            .map(|_| ctx.clone())
            .map(|ctx| {
                SpawnedTask::spawn(async move {
                    ctx.sql("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")
                        .await
                })
            })
            .collect();

        for handle in threads {
            handle.join().await.unwrap().unwrap();
        }
        Ok(())
    }

    #[tokio::test]
    async fn with_listing_schema_provider() -> Result<()> {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let path = path.join("tests/tpch-csv");
        let url = format!("file://{}", path.display());

        let cfg = SessionConfig::new()
            .set_str("datafusion.catalog.location", url.as_str())
            .set_str("datafusion.catalog.format", "CSV")
            .set_str("datafusion.catalog.has_header", "true");
        let session_state = SessionStateBuilder::new()
            .with_config(cfg)
            .with_default_features()
            .build();
        let ctx = SessionContext::new_with_state(session_state);
        ctx.refresh_catalogs().await?;

        let result =
            plan_and_collect(&ctx, "select c_name from default.customer limit 3;")
                .await?;

        let actual = arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_snapshot!(actual, @r"
        +--------------------+
        | c_name             |
        +--------------------+
        | Customer#000000002 |
        | Customer#000000003 |
        | Customer#000000004 |
        +--------------------+
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_dynamic_file_query() -> Result<()> {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let path = path.join("tests/tpch-csv/customer.csv");
        let url = format!("file://{}", path.display());
        let cfg = SessionConfig::new();
        let session_state = SessionStateBuilder::new()
            .with_default_features()
            .with_config(cfg)
            .build();
        let ctx = SessionContext::new_with_state(session_state).enable_url_table();
        let result = plan_and_collect(
            &ctx,
            format!("select c_name from '{}' limit 3;", &url).as_str(),
        )
        .await?;

        let actual = arrow::util::pretty::pretty_format_batches(&result)
            .unwrap()
            .to_string();
        assert_snapshot!(actual, @r"
        +--------------------+
        | c_name             |
        +--------------------+
        | Customer#000000002 |
        | Customer#000000003 |
        | Customer#000000004 |
        +--------------------+
        ");

        Ok(())
    }

    #[tokio::test]
    async fn custom_query_planner() -> Result<()> {
        let runtime = Arc::new(RuntimeEnv::default());
        let session_state = SessionStateBuilder::new()
            .with_config(SessionConfig::new())
            .with_runtime_env(runtime)
            .with_default_features()
            .with_query_planner(Arc::new(MyQueryPlanner {}))
            .build();
        let ctx = SessionContext::new_with_state(session_state);

        let df = ctx.sql("SELECT 1").await?;
        df.collect().await.expect_err("query not supported");
        Ok(())
    }

    #[tokio::test]
    async fn disabled_default_catalog_and_schema() -> Result<()> {
        let ctx = SessionContext::new_with_config(
            SessionConfig::new().with_create_default_catalog_and_schema(false),
        );

        assert!(matches!(
            ctx.register_table("test", test::table_with_sequence(1, 1)?),
            Err(DataFusionError::Plan(_))
        ));

        let err = ctx
            .sql("select * from datafusion.public.test")
            .await
            .unwrap_err();
        let err = err
            .source()
            .and_then(|err| err.downcast_ref::<DataFusionError>())
            .unwrap();

        assert!(matches!(err, &DataFusionError::Plan(_)));

        Ok(())
    }

    #[tokio::test]
    async fn custom_catalog_and_schema() {
        let config = SessionConfig::new()
            .with_create_default_catalog_and_schema(true)
            .with_default_catalog_and_schema("my_catalog", "my_schema");
        catalog_and_schema_test(config).await;
    }

    #[tokio::test]
    async fn custom_catalog_and_schema_no_default() {
        let config = SessionConfig::new()
            .with_create_default_catalog_and_schema(false)
            .with_default_catalog_and_schema("my_catalog", "my_schema");
        catalog_and_schema_test(config).await;
    }

    #[tokio::test]
    async fn custom_catalog_and_schema_and_information_schema() {
        let config = SessionConfig::new()
            .with_create_default_catalog_and_schema(true)
            .with_information_schema(true)
            .with_default_catalog_and_schema("my_catalog", "my_schema");
        catalog_and_schema_test(config).await;
    }

    async fn catalog_and_schema_test(config: SessionConfig) {
        let ctx = SessionContext::new_with_config(config);
        let catalog = MemoryCatalogProvider::new();
        let schema = MemorySchemaProvider::new();
        schema
            .register_table("test".to_owned(), test::table_with_sequence(1, 1).unwrap())
            .unwrap();
        catalog
            .register_schema("my_schema", Arc::new(schema))
            .unwrap();
        ctx.register_catalog("my_catalog", Arc::new(catalog));

        let mut results = Vec::new();

        for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] {
            let result = plan_and_collect(
                &ctx,
                &format!("SELECT COUNT(*) AS count FROM {table_ref}"),
            )
            .await
            .unwrap();

            results.push(result);
        }
        allow_duplicates! {
            for result in &results {
                assert_snapshot!(batches_to_string(result), @r"
                +-------+
                | count |
                +-------+
                | 1     |
                +-------+
                ");
            }
        }
    }

    #[tokio::test]
    async fn cross_catalog_access() -> Result<()> {
        let ctx = SessionContext::new();

        let catalog_a = MemoryCatalogProvider::new();
        let schema_a = MemorySchemaProvider::new();
        schema_a
            .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
        catalog_a.register_schema("schema_a", Arc::new(schema_a))?;
        ctx.register_catalog("catalog_a", Arc::new(catalog_a));

        let catalog_b = MemoryCatalogProvider::new();
        let schema_b = MemorySchemaProvider::new();
        schema_b
            .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
        catalog_b.register_schema("schema_b", Arc::new(schema_b))?;
        ctx.register_catalog("catalog_b", Arc::new(catalog_b));

        let result = plan_and_collect(
            &ctx,
            "SELECT cat, SUM(i) AS total FROM (
                SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a
                UNION ALL
                SELECT i, 'b' AS cat FROM catalog_b.schema_b.table_b
            ) AS all
            GROUP BY cat
            ORDER BY cat
            ",
        )
        .await?;

        assert_snapshot!(batches_to_string(&result), @r"
        +-----+-------+
        | cat | total |
        +-----+-------+
        | a   | 1     |
        | b   | 3     |
        +-----+-------+
        ");

        Ok(())
    }

    #[tokio::test]
    async fn catalogs_not_leaked() {
        let ctx = SessionContext::new_with_config(
            SessionConfig::new().with_information_schema(true),
        );

        let catalog = Arc::new(MemoryCatalogProvider::new());
        let catalog_weak = Arc::downgrade(&catalog);
        ctx.register_catalog("my_catalog", catalog);

        let catalog_list_weak = {
            let state = ctx.state.read();
            Arc::downgrade(state.catalog_list())
        };

        drop(ctx);

        assert_eq!(Weak::strong_count(&catalog_list_weak), 0);
        assert_eq!(Weak::strong_count(&catalog_weak), 0);
    }

    #[tokio::test]
    async fn sql_create_schema() -> Result<()> {
        let ctx = SessionContext::new_with_config(
            SessionConfig::new().with_information_schema(true),
        );

        ctx.sql("CREATE SCHEMA abc").await?.collect().await?;

        ctx.sql("CREATE TABLE abc.y AS VALUES (1,2,3)")
            .await?
            .collect()
            .await?;

        let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();

        assert_eq!(results[0].num_rows(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn sql_create_catalog() -> Result<()> {
        let ctx = SessionContext::new_with_config(
            SessionConfig::new().with_information_schema(true),
        );

        ctx.sql("CREATE DATABASE test").await?.collect().await?;

        ctx.sql("CREATE SCHEMA test.abc").await?.collect().await?;

        ctx.sql("CREATE TABLE test.abc.y AS VALUES (1,2,3)")
            .await?
            .collect()
            .await?;

        let results = ctx.sql("SELECT * FROM information_schema.tables WHERE table_catalog='test' AND table_schema='abc' AND table_name = 'y'").await.unwrap().collect().await.unwrap();

        assert_eq!(results[0].num_rows(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn custom_type_planner() -> Result<()> {
        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_type_planner(Arc::new(MyTypePlanner {}))
            .build();
        let ctx = SessionContext::new_with_state(state);
        let result = ctx
            .sql("SELECT DATETIME '2021-01-01 00:00:00'")
            .await?
            .collect()
            .await?;
        assert_snapshot!(batches_to_string(&result), @r#"
        +-----------------------------+
        | Utf8("2021-01-01 00:00:00") |
        +-----------------------------+
        | 2021-01-01T00:00:00         |
        +-----------------------------+
        "#);
        Ok(())
    }

    #[test]
    fn preserve_session_context_id() -> Result<()> {
        let ctx = SessionContext::new();
        assert_eq!(ctx.session_id(), ctx.enable_url_table().session_id());
        Ok(())
    }

    struct MyPhysicalPlanner {}

    #[async_trait]
    impl PhysicalPlanner for MyPhysicalPlanner {
        async fn create_physical_plan(
            &self,
            _logical_plan: &LogicalPlan,
            _session_state: &SessionState,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            not_impl_err!("query not supported")
        }

        fn create_physical_expr(
            &self,
            _expr: &Expr,
            _input_dfschema: &DFSchema,
            _session_state: &SessionState,
        ) -> Result<Arc<dyn PhysicalExpr>> {
            unimplemented!()
        }
    }

    #[derive(Debug)]
    struct MyQueryPlanner {}

    #[async_trait]
    impl QueryPlanner for MyQueryPlanner {
        async fn create_physical_plan(
            &self,
            logical_plan: &LogicalPlan,
            session_state: &SessionState,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            let physical_planner = MyPhysicalPlanner {};
            physical_planner
                .create_physical_plan(logical_plan, session_state)
                .await
        }
    }

    async fn create_ctx(
        tmp_dir: &TempDir,
        partition_count: usize,
    ) -> Result<SessionContext> {
        let ctx = SessionContext::new_with_config(
            SessionConfig::new().with_target_partitions(8),
        );

        let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?;

        ctx.register_csv(
            "test",
            tmp_dir.path().to_str().unwrap(),
            CsvReadOptions::new().schema(&schema),
        )
        .await?;

        Ok(ctx)
    }

    #[derive(Debug)]
    struct MyTypePlanner {}

    impl TypePlanner for MyTypePlanner {
        fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
            match sql_type {
                ast::DataType::Datetime(precision) => {
                    let precision = match precision {
                        Some(0) => TimeUnit::Second,
                        Some(3) => TimeUnit::Millisecond,
                        Some(6) => TimeUnit::Microsecond,
                        None | Some(9) => TimeUnit::Nanosecond,
                        _ => unreachable!(),
                    };
                    Ok(Some(DataType::Timestamp(precision, None)))
                }
                _ => Ok(None),
            }
        }
    }
}
2534}