datafusion_sql/
cte.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
21
22use datafusion_common::{
23    not_impl_err, plan_err,
24    tree_node::{TreeNode, TreeNodeRecursion},
25    Result,
26};
27use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource};
28use sqlparser::ast::{Query, SetExpr, SetOperator, With};
29
30impl<S: ContextProvider> SqlToRel<'_, S> {
31    pub(super) fn plan_with_clause(
32        &self,
33        with: With,
34        planner_context: &mut PlannerContext,
35    ) -> Result<()> {
36        let is_recursive = with.recursive;
37        // Process CTEs from top to bottom
38        for cte in with.cte_tables {
39            // A `WITH` block can't use the same name more than once
40            let cte_name = self.ident_normalizer.normalize(cte.alias.name.clone());
41            if planner_context.contains_cte(&cte_name) {
42                return plan_err!(
43                    "WITH query name {cte_name:?} specified more than once"
44                );
45            }
46
47            // Create a logical plan for the CTE
48            let cte_plan = if is_recursive {
49                self.recursive_cte(cte_name.clone(), *cte.query, planner_context)?
50            } else {
51                self.non_recursive_cte(*cte.query, planner_context)?
52            };
53
54            // Each `WITH` block can change the column names in the last
55            // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
56            let final_plan = self.apply_table_alias(cte_plan, cte.alias)?;
57            // Export the CTE to the outer query
58            planner_context.insert_cte(cte_name, final_plan);
59        }
60        Ok(())
61    }
62
63    fn non_recursive_cte(
64        &self,
65        cte_query: Query,
66        planner_context: &mut PlannerContext,
67    ) -> Result<LogicalPlan> {
68        self.query_to_plan(cte_query, planner_context)
69    }
70
71    fn recursive_cte(
72        &self,
73        cte_name: String,
74        mut cte_query: Query,
75        planner_context: &mut PlannerContext,
76    ) -> Result<LogicalPlan> {
77        if !self
78            .context_provider
79            .options()
80            .execution
81            .enable_recursive_ctes
82        {
83            return not_impl_err!("Recursive CTEs are not enabled");
84        }
85
86        let (left_expr, right_expr, set_quantifier) = match *cte_query.body {
87            SetExpr::SetOperation {
88                op: SetOperator::Union,
89                left,
90                right,
91                set_quantifier,
92            } => (left, right, set_quantifier),
93            other => {
94                // If the query is not a UNION, then it is not a recursive CTE
95                cte_query.body = Box::new(other);
96                return self.non_recursive_cte(cte_query, planner_context);
97            }
98        };
99
100        // Each recursive CTE consists of two parts in the logical plan:
101        //   1. A static term   (the left-hand side on the SQL, where the
102        //                       referencing to the same CTE is not allowed)
103        //
104        //   2. A recursive term (the right hand side, and the recursive
105        //                       part)
106
107        // Since static term does not have any specific properties, it can
108        // be compiled as if it was a regular expression. This will
109        // allow us to infer the schema to be used in the recursive term.
110
111        // ---------- Step 1: Compile the static term ------------------
112        let static_plan = self.set_expr_to_plan(*left_expr, planner_context)?;
113
114        // Since the recursive CTEs include a component that references a
115        // table with its name, like the example below:
116        //
117        // WITH RECURSIVE values(n) AS (
118        //      SELECT 1 as n -- static term
119        //    UNION ALL
120        //      SELECT n + 1
121        //      FROM values -- self reference
122        //      WHERE n < 100
123        // )
124        //
125        // We need a temporary 'relation' to be referenced and used. PostgreSQL
126        // calls this a 'working table', but it is entirely an implementation
127        // detail and a 'real' table with that name might not even exist (as
128        // in the case of DataFusion).
129        //
130        // Since we can't simply register a table during planning stage (it is
131        // an execution problem), we'll use a relation object that preserves the
132        // schema of the input perfectly and also knows which recursive CTE it is
133        // bound to.
134
135        // ---------- Step 2: Create a temporary relation ------------------
136        // Step 2.1: Create a table source for the temporary relation
137        let work_table_source = self
138            .context_provider
139            .create_cte_work_table(&cte_name, Arc::clone(static_plan.schema().inner()))?;
140
141        // Step 2.2: Create a temporary relation logical plan that will be used
142        // as the input to the recursive term
143        let work_table_plan = LogicalPlanBuilder::scan(
144            cte_name.to_string(),
145            Arc::clone(&work_table_source),
146            None,
147        )?
148        .build()?;
149
150        let name = cte_name.clone();
151
152        // Step 2.3: Register the temporary relation in the planning context
153        // For all the self references in the variadic term, we'll replace it
154        // with the temporary relation we created above by temporarily registering
155        // it as a CTE. This temporary relation in the planning context will be
156        // replaced by the actual CTE plan once we're done with the planning.
157        planner_context.insert_cte(cte_name.clone(), work_table_plan);
158
159        // ---------- Step 3: Compile the recursive term ------------------
160        // this uses the named_relation we inserted above to resolve the
161        // relation. This ensures that the recursive term uses the named relation logical plan
162        // and thus the 'continuance' physical plan as its input and source
163        let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?;
164
165        // Check if the recursive term references the CTE itself,
166        // if not, it is a non-recursive CTE
167        if !has_work_table_reference(&recursive_plan, &work_table_source) {
168            // Remove the work table plan from the context
169            planner_context.remove_cte(&cte_name);
170            // Compile it as a non-recursive CTE
171            return self.set_operation_to_plan(
172                SetOperator::Union,
173                static_plan,
174                recursive_plan,
175                set_quantifier,
176            );
177        }
178
179        // ---------- Step 4: Create the final plan ------------------
180        // Step 4.1: Compile the final plan
181        let distinct = !Self::is_union_all(set_quantifier)?;
182        LogicalPlanBuilder::from(static_plan)
183            .to_recursive_query(name, recursive_plan, distinct)?
184            .build()
185    }
186}
187
188fn has_work_table_reference(
189    plan: &LogicalPlan,
190    work_table_source: &Arc<dyn TableSource>,
191) -> bool {
192    let mut has_reference = false;
193    plan.apply(|node| {
194        if let LogicalPlan::TableScan(scan) = node {
195            if Arc::ptr_eq(&scan.source, work_table_source) {
196                has_reference = true;
197                return Ok(TreeNodeRecursion::Stop);
198            }
199        }
200        Ok(TreeNodeRecursion::Continue)
201    })
202    // Closure always return Ok
203    .unwrap();
204    has_reference
205}