datafusion_physical_expr/equivalence/properties/
union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::iter::Peekable;
19use std::sync::Arc;
20
21use super::EquivalenceProperties;
22use crate::equivalence::class::AcrossPartitions;
23use crate::{ConstExpr, PhysicalSortExpr};
24
25use arrow::datatypes::SchemaRef;
26use datafusion_common::{internal_err, Result};
27use datafusion_physical_expr_common::sort_expr::LexOrdering;
28
29/// Computes the union (in the sense of `UnionExec`) `EquivalenceProperties`
30/// of `lhs` and `rhs` according to the schema of `lhs`.
31///
32/// Rules: The `UnionExec` does not interleave its inputs, instead it passes
33/// each input partition from the children as its own output.
34///
35/// Since the output equivalence properties are properties that are true for
36/// *all* output partitions, that is the same as being true for all *input*
37/// partitions.
38fn calculate_union_binary(
39    lhs: EquivalenceProperties,
40    mut rhs: EquivalenceProperties,
41) -> Result<EquivalenceProperties> {
42    // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema):
43    if !rhs.schema.eq(&lhs.schema) {
44        rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?;
45    }
46
47    // First, calculate valid constants for the union. An expression is constant
48    // at the output of the union if it is constant in both sides with matching values.
49    let rhs_constants = rhs.constants();
50    let constants = lhs
51        .constants()
52        .into_iter()
53        .filter_map(|lhs_const| {
54            // Find matching constant expression in RHS
55            rhs_constants
56                .iter()
57                .find(|rhs_const| rhs_const.expr.eq(&lhs_const.expr))
58                .map(|rhs_const| {
59                    let mut const_expr = lhs_const.clone();
60                    // If both sides have matching constant values, preserve it.
61                    // Otherwise, set fall back to heterogeneous values.
62                    if lhs_const.across_partitions != rhs_const.across_partitions {
63                        const_expr.across_partitions = AcrossPartitions::Heterogeneous;
64                    }
65                    const_expr
66                })
67        })
68        .collect::<Vec<_>>();
69
70    // Next, calculate valid orderings for the union by searching for prefixes
71    // in both sides.
72    let mut orderings = UnionEquivalentOrderingBuilder::new();
73    orderings.add_satisfied_orderings(&lhs, &rhs)?;
74    orderings.add_satisfied_orderings(&rhs, &lhs)?;
75    let orderings = orderings.build();
76
77    let mut eq_properties = EquivalenceProperties::new(lhs.schema);
78    eq_properties.add_constants(constants)?;
79    eq_properties.add_orderings(orderings);
80    Ok(eq_properties)
81}
82
83/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties`
84/// of the given `EquivalenceProperties` in `eqps` according to the given
85/// output `schema` (which need not be the same with those of `lhs` and `rhs`
86/// as details such as nullability may be different).
87pub fn calculate_union(
88    eqps: Vec<EquivalenceProperties>,
89    schema: SchemaRef,
90) -> Result<EquivalenceProperties> {
91    // TODO: In some cases, we should be able to preserve some equivalence
92    //       classes. Add support for such cases.
93    let mut iter = eqps.into_iter();
94    let Some(mut acc) = iter.next() else {
95        return internal_err!(
96            "Cannot calculate EquivalenceProperties for a union with no inputs"
97        );
98    };
99
100    // Harmonize the schema of the init with the schema of the union:
101    if !acc.schema.eq(&schema) {
102        acc = acc.with_new_schema(schema)?;
103    }
104    // Fold in the rest of the EquivalenceProperties:
105    for props in iter {
106        acc = calculate_union_binary(acc, props)?;
107    }
108    Ok(acc)
109}
110
111#[derive(Debug)]
112enum AddedOrdering {
113    /// The ordering was added to the in progress result
114    Yes,
115    /// The ordering was not added
116    No(LexOrdering),
117}
118
119/// Builds valid output orderings of a `UnionExec`
120#[derive(Debug)]
121struct UnionEquivalentOrderingBuilder {
122    orderings: Vec<LexOrdering>,
123}
124
125impl UnionEquivalentOrderingBuilder {
126    fn new() -> Self {
127        Self { orderings: vec![] }
128    }
129
130    /// Add all orderings from `source` that satisfy `properties`,
131    /// potentially augmented with the constants in `source`.
132    ///
133    /// Note: Any column that is known to be constant can be inserted into the
134    /// ordering without changing its meaning.
135    ///
136    /// For example:
137    /// * Orderings in `source` contains `[a ASC, c ASC]` and constants contains
138    ///   `b`,
139    /// * `properties` has the ordering `[a ASC, b ASC]`.
140    ///
141    /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was
142    /// in the sort order and `b` was a constant).
143    fn add_satisfied_orderings(
144        &mut self,
145        source: &EquivalenceProperties,
146        properties: &EquivalenceProperties,
147    ) -> Result<()> {
148        let constants = source.constants();
149        let properties_constants = properties.constants();
150        for mut ordering in source.oeq_cache.normal_cls.clone() {
151            // Progressively shorten the ordering to search for a satisfied prefix:
152            loop {
153                ordering = match self.try_add_ordering(
154                    ordering,
155                    &constants,
156                    properties,
157                    &properties_constants,
158                )? {
159                    AddedOrdering::Yes => break,
160                    AddedOrdering::No(ordering) => {
161                        let mut sort_exprs: Vec<_> = ordering.into();
162                        sort_exprs.pop();
163                        if let Some(ordering) = LexOrdering::new(sort_exprs) {
164                            ordering
165                        } else {
166                            break;
167                        }
168                    }
169                }
170            }
171        }
172        Ok(())
173    }
174
175    /// Adds `ordering`, potentially augmented with `constants`, if it satisfies
176    /// the given `properties`.
177    ///
178    /// # Returns
179    ///
180    /// An [`AddedOrdering::Yes`] instance if the ordering was added (either
181    /// directly or augmented), or was empty. An [`AddedOrdering::No`] instance
182    /// otherwise.
183    fn try_add_ordering(
184        &mut self,
185        ordering: LexOrdering,
186        constants: &[ConstExpr],
187        properties: &EquivalenceProperties,
188        properties_constants: &[ConstExpr],
189    ) -> Result<AddedOrdering> {
190        if properties.ordering_satisfy(ordering.clone())? {
191            // If the ordering satisfies the target properties, no need to
192            // augment it with constants.
193            self.orderings.push(ordering);
194            Ok(AddedOrdering::Yes)
195        } else if self.try_find_augmented_ordering(
196            &ordering,
197            constants,
198            properties,
199            properties_constants,
200        ) {
201            // Augmented with constants to match the properties.
202            Ok(AddedOrdering::Yes)
203        } else {
204            Ok(AddedOrdering::No(ordering))
205        }
206    }
207
208    /// Attempts to add `constants` to `ordering` to satisfy the properties.
209    /// Returns `true` if augmentation took place, `false` otherwise.
210    fn try_find_augmented_ordering(
211        &mut self,
212        ordering: &LexOrdering,
213        constants: &[ConstExpr],
214        properties: &EquivalenceProperties,
215        properties_constants: &[ConstExpr],
216    ) -> bool {
217        let mut result = false;
218        // Can only augment if there are constants.
219        if !constants.is_empty() {
220            // For each equivalent ordering in properties, try and augment
221            // `ordering` with the constants to match `existing_ordering`:
222            for existing_ordering in properties.oeq_class.iter() {
223                if let Some(augmented_ordering) = Self::augment_ordering(
224                    ordering,
225                    constants,
226                    existing_ordering,
227                    properties_constants,
228                ) {
229                    self.orderings.push(augmented_ordering);
230                    result = true;
231                }
232            }
233        }
234        result
235    }
236
237    /// Attempts to augment the ordering with constants to match `existing_ordering`.
238    /// Returns `Some(ordering)` if an augmented ordering was found, `None` otherwise.
239    fn augment_ordering(
240        ordering: &LexOrdering,
241        constants: &[ConstExpr],
242        existing_ordering: &LexOrdering,
243        existing_constants: &[ConstExpr],
244    ) -> Option<LexOrdering> {
245        let mut augmented_ordering = vec![];
246        let mut sort_exprs = ordering.iter().peekable();
247        let mut existing_sort_exprs = existing_ordering.iter().peekable();
248
249        // Walk in parallel down the two orderings, trying to match them up:
250        while sort_exprs.peek().is_some() || existing_sort_exprs.peek().is_some() {
251            // If the next expressions are equal, add the next match. Otherwise,
252            // try and match with a constant.
253            if let Some(expr) =
254                advance_if_match(&mut sort_exprs, &mut existing_sort_exprs)
255            {
256                augmented_ordering.push(expr);
257            } else if let Some(expr) =
258                advance_if_matches_constant(&mut sort_exprs, existing_constants)
259            {
260                augmented_ordering.push(expr);
261            } else if let Some(expr) =
262                advance_if_matches_constant(&mut existing_sort_exprs, constants)
263            {
264                augmented_ordering.push(expr);
265            } else {
266                // no match, can't continue the ordering, return what we have
267                break;
268            }
269        }
270
271        LexOrdering::new(augmented_ordering)
272    }
273
274    fn build(self) -> Vec<LexOrdering> {
275        self.orderings
276    }
277}
278
279/// Advances two iterators in parallel if the next expressions are equal.
280/// Otherwise, the iterators are left unchanged and returns `None`.
281fn advance_if_match<'a>(
282    iter1: &mut Peekable<impl Iterator<Item = &'a PhysicalSortExpr>>,
283    iter2: &mut Peekable<impl Iterator<Item = &'a PhysicalSortExpr>>,
284) -> Option<PhysicalSortExpr> {
285    let (expr1, expr2) = (iter1.peek()?, iter2.peek()?);
286    if expr1.eq(expr2) {
287        iter1.next();
288        iter2.next().cloned()
289    } else {
290        None
291    }
292}
293
294/// Advances the iterator with a constant if the next expression matches one of
295/// the constants. Otherwise, the iterator is left unchanged and returns `None`.
296fn advance_if_matches_constant<'a>(
297    iter: &mut Peekable<impl Iterator<Item = &'a PhysicalSortExpr>>,
298    constants: &[ConstExpr],
299) -> Option<PhysicalSortExpr> {
300    let expr = iter.peek()?;
301    let const_expr = constants.iter().find(|c| expr.expr.eq(&c.expr))?;
302    let found_expr = PhysicalSortExpr::new(Arc::clone(&const_expr.expr), expr.options);
303    iter.next();
304    Some(found_expr)
305}
306
307#[cfg(test)]
308mod tests {
309    use super::*;
310    use crate::equivalence::tests::{create_test_schema, parse_sort_expr};
311    use crate::expressions::col;
312    use crate::PhysicalExpr;
313
314    use arrow::datatypes::{DataType, Field, Schema};
315    use datafusion_common::ScalarValue;
316
317    use itertools::Itertools;
318
319    /// Checks whether `expr` is among in the `const_exprs`.
320    fn const_exprs_contains(
321        const_exprs: &[ConstExpr],
322        expr: &Arc<dyn PhysicalExpr>,
323    ) -> bool {
324        const_exprs
325            .iter()
326            .any(|const_expr| const_expr.expr.eq(expr))
327    }
328
329    #[test]
330    fn test_union_equivalence_properties_multi_children_1() -> Result<()> {
331        let schema = create_test_schema().unwrap();
332        let schema2 = append_fields(&schema, "1");
333        let schema3 = append_fields(&schema, "2");
334        UnionEquivalenceTest::new(&schema)
335            // Children 1
336            .with_child_sort(vec![vec!["a", "b", "c"]], &schema)?
337            // Children 2
338            .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)?
339            // Children 3
340            .with_child_sort(vec![vec!["a2", "b2"]], &schema3)?
341            .with_expected_sort(vec![vec!["a", "b"]])?
342            .run()
343    }
344
345    #[test]
346    fn test_union_equivalence_properties_multi_children_2() -> Result<()> {
347        let schema = create_test_schema().unwrap();
348        let schema2 = append_fields(&schema, "1");
349        let schema3 = append_fields(&schema, "2");
350        UnionEquivalenceTest::new(&schema)
351            // Children 1
352            .with_child_sort(vec![vec!["a", "b", "c"]], &schema)?
353            // Children 2
354            .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)?
355            // Children 3
356            .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)?
357            .with_expected_sort(vec![vec!["a", "b", "c"]])?
358            .run()
359    }
360
361    #[test]
362    fn test_union_equivalence_properties_multi_children_3() -> Result<()> {
363        let schema = create_test_schema().unwrap();
364        let schema2 = append_fields(&schema, "1");
365        let schema3 = append_fields(&schema, "2");
366        UnionEquivalenceTest::new(&schema)
367            // Children 1
368            .with_child_sort(vec![vec!["a", "b"]], &schema)?
369            // Children 2
370            .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)?
371            // Children 3
372            .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)?
373            .with_expected_sort(vec![vec!["a", "b"]])?
374            .run()
375    }
376
377    #[test]
378    fn test_union_equivalence_properties_multi_children_4() -> Result<()> {
379        let schema = create_test_schema().unwrap();
380        let schema2 = append_fields(&schema, "1");
381        let schema3 = append_fields(&schema, "2");
382        UnionEquivalenceTest::new(&schema)
383            // Children 1
384            .with_child_sort(vec![vec!["a", "b"]], &schema)?
385            // Children 2
386            .with_child_sort(vec![vec!["a1", "b1"]], &schema2)?
387            // Children 3
388            .with_child_sort(vec![vec!["b2", "c2"]], &schema3)?
389            .with_expected_sort(vec![])?
390            .run()
391    }
392
393    #[test]
394    fn test_union_equivalence_properties_multi_children_5() -> Result<()> {
395        let schema = create_test_schema().unwrap();
396        let schema2 = append_fields(&schema, "1");
397        UnionEquivalenceTest::new(&schema)
398            // Children 1
399            .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema)?
400            // Children 2
401            .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2)?
402            .with_expected_sort(vec![vec!["a", "b"], vec!["c"]])?
403            .run()
404    }
405
406    #[test]
407    fn test_union_equivalence_properties_constants_common_constants() -> Result<()> {
408        let schema = create_test_schema().unwrap();
409        UnionEquivalenceTest::new(&schema)
410            .with_child_sort_and_const_exprs(
411                // First child: [a ASC], const [b, c]
412                vec![vec!["a"]],
413                vec!["b", "c"],
414                &schema,
415            )?
416            .with_child_sort_and_const_exprs(
417                // Second child: [b ASC], const [a, c]
418                vec![vec!["b"]],
419                vec!["a", "c"],
420                &schema,
421            )?
422            .with_expected_sort_and_const_exprs(
423                // Union expected orderings: [[a ASC], [b ASC]], const [c]
424                vec![vec!["a"], vec!["b"]],
425                vec!["c"],
426            )?
427            .run()
428    }
429
430    #[test]
431    fn test_union_equivalence_properties_constants_prefix() -> Result<()> {
432        let schema = create_test_schema().unwrap();
433        UnionEquivalenceTest::new(&schema)
434            .with_child_sort_and_const_exprs(
435                // First child: [a ASC], const []
436                vec![vec!["a"]],
437                vec![],
438                &schema,
439            )?
440            .with_child_sort_and_const_exprs(
441                // Second child: [a ASC, b ASC], const []
442                vec![vec!["a", "b"]],
443                vec![],
444                &schema,
445            )?
446            .with_expected_sort_and_const_exprs(
447                // Union orderings: [a ASC], const []
448                vec![vec!["a"]],
449                vec![],
450            )?
451            .run()
452    }
453
454    #[test]
455    fn test_union_equivalence_properties_constants_asc_desc_mismatch() -> Result<()> {
456        let schema = create_test_schema().unwrap();
457        UnionEquivalenceTest::new(&schema)
458            .with_child_sort_and_const_exprs(
459                // First child: [a ASC], const []
460                vec![vec!["a"]],
461                vec![],
462                &schema,
463            )?
464            .with_child_sort_and_const_exprs(
465                // Second child orderings: [a DESC], const []
466                vec![vec!["a DESC"]],
467                vec![],
468                &schema,
469            )?
470            .with_expected_sort_and_const_exprs(
471                // Union doesn't have any ordering or constant
472                vec![],
473                vec![],
474            )?
475            .run()
476    }
477
478    #[test]
479    fn test_union_equivalence_properties_constants_different_schemas() -> Result<()> {
480        let schema = create_test_schema().unwrap();
481        let schema2 = append_fields(&schema, "1");
482        UnionEquivalenceTest::new(&schema)
483            .with_child_sort_and_const_exprs(
484                // First child orderings: [a ASC], const []
485                vec![vec!["a"]],
486                vec![],
487                &schema,
488            )?
489            .with_child_sort_and_const_exprs(
490                // Second child orderings: [a1 ASC, b1 ASC], const []
491                vec![vec!["a1", "b1"]],
492                vec![],
493                &schema2,
494            )?
495            .with_expected_sort_and_const_exprs(
496                // Union orderings: [a ASC]
497                //
498                // Note that a, and a1 are at the same index for their
499                // corresponding schemas.
500                vec![vec!["a"]],
501                vec![],
502            )?
503            .run()
504    }
505
506    #[test]
507    fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> {
508        let schema = create_test_schema().unwrap();
509        UnionEquivalenceTest::new(&schema)
510            .with_child_sort_and_const_exprs(
511                // First child orderings: [a ASC, c ASC], const [b]
512                vec![vec!["a", "c"]],
513                vec!["b"],
514                &schema,
515            )?
516            .with_child_sort_and_const_exprs(
517                // Second child orderings: [b ASC, c ASC], const [a]
518                vec![vec!["b", "c"]],
519                vec!["a"],
520                &schema,
521            )?
522            .with_expected_sort_and_const_exprs(
523                // Union orderings: [
524                //   [a ASC, b ASC, c ASC],
525                //   [b ASC, a ASC, c ASC]
526                // ], const []
527                vec![vec!["a", "b", "c"], vec!["b", "a", "c"]],
528                vec![],
529            )?
530            .run()
531    }
532
533    #[test]
534    fn test_union_equivalence_properties_constants_no_fill_gaps() -> Result<()> {
535        let schema = create_test_schema().unwrap();
536        UnionEquivalenceTest::new(&schema)
537            .with_child_sort_and_const_exprs(
538                // First child orderings: [a ASC, c ASC], const [d] // some other constant
539                vec![vec!["a", "c"]],
540                vec!["d"],
541                &schema,
542            )?
543            .with_child_sort_and_const_exprs(
544                // Second child orderings: [b ASC, c ASC], const [a]
545                vec![vec!["b", "c"]],
546                vec!["a"],
547                &schema,
548            )?
549            .with_expected_sort_and_const_exprs(
550                // Union orderings: [[a]] (only a is constant)
551                vec![vec!["a"]],
552                vec![],
553            )?
554            .run()
555    }
556
557    #[test]
558    fn test_union_equivalence_properties_constants_fill_some_gaps() -> Result<()> {
559        let schema = create_test_schema().unwrap();
560        UnionEquivalenceTest::new(&schema)
561            .with_child_sort_and_const_exprs(
562                // First child orderings: [c ASC], const [a, b] // some other constant
563                vec![vec!["c"]],
564                vec!["a", "b"],
565                &schema,
566            )?
567            .with_child_sort_and_const_exprs(
568                // Second child orderings: [a DESC, b], const []
569                vec![vec!["a DESC", "b"]],
570                vec![],
571                &schema,
572            )?
573            .with_expected_sort_and_const_exprs(
574                // Union orderings: [[a, b]] (can fill in the a/b with constants)
575                vec![vec!["a DESC", "b"]],
576                vec![],
577            )?
578            .run()
579    }
580
581    #[test]
582    fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()>
583    {
584        let schema = create_test_schema().unwrap();
585        UnionEquivalenceTest::new(&schema)
586            .with_child_sort_and_const_exprs(
587                // First child orderings: [a ASC, c ASC], const [b]
588                vec![vec!["a", "c"]],
589                vec!["b"],
590                &schema,
591            )?
592            .with_child_sort_and_const_exprs(
593                // Second child orderings: [b ASC, c ASC], const [a]
594                vec![vec!["b DESC", "c"]],
595                vec!["a"],
596                &schema,
597            )?
598            .with_expected_sort_and_const_exprs(
599                // Union orderings: [
600                //   [a ASC, b ASC, c ASC],
601                //   [b ASC, a ASC, c ASC]
602                // ], const []
603                vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]],
604                vec![],
605            )?
606            .run()
607    }
608
609    #[test]
610    fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> {
611        let schema = create_test_schema().unwrap();
612        UnionEquivalenceTest::new(&schema)
613            .with_child_sort_and_const_exprs(
614                // First child: [a ASC, b ASC, d ASC], const [c]
615                vec![vec!["a", "b", "d"]],
616                vec!["c"],
617                &schema,
618            )?
619            .with_child_sort_and_const_exprs(
620                // Second child: [a ASC, c ASC, d ASC], const [b]
621                vec![vec!["a", "c", "d"]],
622                vec!["b"],
623                &schema,
624            )?
625            .with_expected_sort_and_const_exprs(
626                // Union orderings:
627                // [a, b, c, d]
628                // [a, c, b, d]
629                vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]],
630                vec![],
631            )?
632            .run()
633    }
634
635    #[test]
636    fn test_union_equivalence_properties_constants_gap_fill_and_common() -> Result<()> {
637        let schema = create_test_schema().unwrap();
638        UnionEquivalenceTest::new(&schema)
639            .with_child_sort_and_const_exprs(
640                // First child: [a DESC, d ASC], const [b, c]
641                vec![vec!["a DESC", "d"]],
642                vec!["b", "c"],
643                &schema,
644            )?
645            .with_child_sort_and_const_exprs(
646                // Second child: [a DESC, c ASC, d ASC], const [b]
647                vec![vec!["a DESC", "c", "d"]],
648                vec!["b"],
649                &schema,
650            )?
651            .with_expected_sort_and_const_exprs(
652                // Union orderings:
653                // [a DESC, c, d]  [b]
654                vec![vec!["a DESC", "c", "d"]],
655                vec!["b"],
656            )?
657            .run()
658    }
659
660    #[test]
661    fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> {
662        let schema = create_test_schema().unwrap();
663        UnionEquivalenceTest::new(&schema)
664            .with_child_sort_and_const_exprs(
665                // NB `b DESC` in the first child
666                //
667                // First child: [a ASC, b DESC, d ASC], const [c]
668                vec![vec!["a", "b DESC", "d"]],
669                vec!["c"],
670                &schema,
671            )?
672            .with_child_sort_and_const_exprs(
673                // Second child: [a ASC, c ASC, d ASC], const [b]
674                vec![vec!["a", "c", "d"]],
675                vec!["b"],
676                &schema,
677            )?
678            .with_expected_sort_and_const_exprs(
679                // Union orderings:
680                // [a, b, d] (c constant)
681                // [a, c, d] (b constant)
682                vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]],
683                vec![],
684            )?
685            .run()
686    }
687
688    // TODO tests with multiple constants
689
690    #[derive(Debug)]
691    struct UnionEquivalenceTest {
692        /// The schema of the output of the Union
693        output_schema: SchemaRef,
694        /// The equivalence properties of each child to the union
695        child_properties: Vec<EquivalenceProperties>,
696        /// The expected output properties of the union. Must be set before
697        /// running `build`
698        expected_properties: Option<EquivalenceProperties>,
699    }
700
701    impl UnionEquivalenceTest {
702        fn new(output_schema: &SchemaRef) -> Self {
703            Self {
704                output_schema: Arc::clone(output_schema),
705                child_properties: vec![],
706                expected_properties: None,
707            }
708        }
709
710        /// Add a union input with the specified orderings
711        ///
712        /// See [`Self::make_props`] for the format of the strings in `orderings`
713        fn with_child_sort(
714            mut self,
715            orderings: Vec<Vec<&str>>,
716            schema: &SchemaRef,
717        ) -> Result<Self> {
718            let properties = self.make_props(orderings, vec![], schema)?;
719            self.child_properties.push(properties);
720            Ok(self)
721        }
722
723        /// Add a union input with the specified orderings and constant
724        /// equivalences
725        ///
726        /// See [`Self::make_props`] for the format of the strings in
727        /// `orderings` and `constants`
728        fn with_child_sort_and_const_exprs(
729            mut self,
730            orderings: Vec<Vec<&str>>,
731            constants: Vec<&str>,
732            schema: &SchemaRef,
733        ) -> Result<Self> {
734            let properties = self.make_props(orderings, constants, schema)?;
735            self.child_properties.push(properties);
736            Ok(self)
737        }
738
739        /// Set the expected output sort order for the union of the children
740        ///
741        /// See [`Self::make_props`] for the format of the strings in `orderings`
742        fn with_expected_sort(mut self, orderings: Vec<Vec<&str>>) -> Result<Self> {
743            let properties = self.make_props(orderings, vec![], &self.output_schema)?;
744            self.expected_properties = Some(properties);
745            Ok(self)
746        }
747
748        /// Set the expected output sort order and constant expressions for the
749        /// union of the children
750        ///
751        /// See [`Self::make_props`] for the format of the strings in
752        /// `orderings` and `constants`.
753        fn with_expected_sort_and_const_exprs(
754            mut self,
755            orderings: Vec<Vec<&str>>,
756            constants: Vec<&str>,
757        ) -> Result<Self> {
758            let properties =
759                self.make_props(orderings, constants, &self.output_schema)?;
760            self.expected_properties = Some(properties);
761            Ok(self)
762        }
763
764        /// compute the union's output equivalence properties from the child
765        /// properties, and compare them to the expected properties
766        fn run(self) -> Result<()> {
767            let Self {
768                output_schema,
769                child_properties,
770                expected_properties,
771            } = self;
772
773            let expected_properties =
774                expected_properties.expect("expected_properties not set");
775
776            // try all permutations of the children
777            // as the code treats lhs and rhs differently
778            for child_properties in child_properties
779                .iter()
780                .cloned()
781                .permutations(child_properties.len())
782            {
783                println!("--- permutation ---");
784                for c in &child_properties {
785                    println!("{c}");
786                }
787                let actual_properties =
788                    calculate_union(child_properties, Arc::clone(&output_schema))
789                        .expect("failed to calculate union equivalence properties");
790                Self::assert_eq_properties_same(
791                    &actual_properties,
792                    &expected_properties,
793                    format!(
794                        "expected: {expected_properties:?}\nactual:  {actual_properties:?}"
795                    ),
796                );
797            }
798            Ok(())
799        }
800
801        fn assert_eq_properties_same(
802            lhs: &EquivalenceProperties,
803            rhs: &EquivalenceProperties,
804            err_msg: String,
805        ) {
806            // Check whether constants are same
807            let lhs_constants = lhs.constants();
808            let rhs_constants = rhs.constants();
809            for rhs_constant in &rhs_constants {
810                assert!(
811                    const_exprs_contains(&lhs_constants, &rhs_constant.expr),
812                    "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
813                );
814            }
815            assert_eq!(
816                lhs_constants.len(),
817                rhs_constants.len(),
818                "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
819            );
820
821            // Check whether orderings are same.
822            let lhs_orderings = lhs.oeq_class();
823            let rhs_orderings = rhs.oeq_class();
824            for rhs_ordering in rhs_orderings.iter() {
825                assert!(
826                    lhs_orderings.contains(rhs_ordering),
827                    "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
828                );
829            }
830            assert_eq!(
831                lhs_orderings.len(),
832                rhs_orderings.len(),
833                "{err_msg}\nlhs: {lhs}\nrhs: {rhs}"
834            );
835        }
836
837        /// Make equivalence properties for the specified columns named in orderings and constants
838        ///
839        /// orderings: strings formatted like `"a"` or `"a DESC"`. See [`parse_sort_expr`]
840        /// constants: strings formatted like `"a"`.
841        fn make_props(
842            &self,
843            orderings: Vec<Vec<&str>>,
844            constants: Vec<&str>,
845            schema: &SchemaRef,
846        ) -> Result<EquivalenceProperties> {
847            let orderings = orderings.iter().map(|ordering| {
848                ordering.iter().map(|name| parse_sort_expr(name, schema))
849            });
850
851            let constants = constants
852                .iter()
853                .map(|col_name| ConstExpr::from(col(col_name, schema).unwrap()));
854
855            let mut props =
856                EquivalenceProperties::new_with_orderings(Arc::clone(schema), orderings);
857            props.add_constants(constants)?;
858            Ok(props)
859        }
860    }
861
862    #[test]
863    fn test_union_constant_value_preservation() -> Result<()> {
864        let schema = Arc::new(Schema::new(vec![
865            Field::new("a", DataType::Int32, true),
866            Field::new("b", DataType::Int32, true),
867        ]));
868
869        let col_a = col("a", &schema)?;
870        let literal_10 = ScalarValue::Int32(Some(10));
871
872        // Create first input with a=10
873        let const_expr1 = ConstExpr::new(
874            Arc::clone(&col_a),
875            AcrossPartitions::Uniform(Some(literal_10.clone())),
876        );
877        let mut input1 = EquivalenceProperties::new(Arc::clone(&schema));
878        input1.add_constants(vec![const_expr1])?;
879
880        // Create second input with a=10
881        let const_expr2 = ConstExpr::new(
882            Arc::clone(&col_a),
883            AcrossPartitions::Uniform(Some(literal_10.clone())),
884        );
885        let mut input2 = EquivalenceProperties::new(Arc::clone(&schema));
886        input2.add_constants(vec![const_expr2])?;
887
888        // Calculate union properties
889        let union_props = calculate_union(vec![input1, input2], schema)?;
890
891        // Verify column 'a' remains constant with value 10
892        let const_a = &union_props.constants()[0];
893        assert!(const_a.expr.eq(&col_a));
894        assert_eq!(
895            const_a.across_partitions,
896            AcrossPartitions::Uniform(Some(literal_10))
897        );
898
899        Ok(())
900    }
901
902    /// Return a new schema with the same types, but new field names
903    ///
904    /// The new field names are the old field names with `text` appended.
905    ///
906    /// For example, the schema "a", "b", "c" becomes "a1", "b1", "c1"
907    /// if `text` is "1".
908    fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef {
909        Arc::new(Schema::new(
910            schema
911                .fields()
912                .iter()
913                .map(|field| {
914                    Field::new(
915                        // Annotate name with `text`:
916                        format!("{}{}", field.name(), text),
917                        field.data_type().clone(),
918                        field.is_nullable(),
919                    )
920                })
921                .collect::<Vec<_>>(),
922        ))
923    }
924
925    #[test]
926    fn test_constants_share_values() -> Result<()> {
927        let schema = Arc::new(Schema::new(vec![
928            Field::new("const_1", DataType::Utf8, false),
929            Field::new("const_2", DataType::Utf8, false),
930        ]));
931
932        let col_const_1 = col("const_1", &schema)?;
933        let col_const_2 = col("const_2", &schema)?;
934
935        let literal_foo = ScalarValue::Utf8(Some("foo".to_owned()));
936        let literal_bar = ScalarValue::Utf8(Some("bar".to_owned()));
937
938        let const_expr_1_foo = ConstExpr::new(
939            Arc::clone(&col_const_1),
940            AcrossPartitions::Uniform(Some(literal_foo.clone())),
941        );
942        let const_expr_2_foo = ConstExpr::new(
943            Arc::clone(&col_const_2),
944            AcrossPartitions::Uniform(Some(literal_foo.clone())),
945        );
946        let const_expr_2_bar = ConstExpr::new(
947            Arc::clone(&col_const_2),
948            AcrossPartitions::Uniform(Some(literal_bar.clone())),
949        );
950
951        let mut input1 = EquivalenceProperties::new(Arc::clone(&schema));
952        let mut input2 = EquivalenceProperties::new(Arc::clone(&schema));
953
954        // | Input | Const_1 | Const_2 |
955        // | ----- | ------- | ------- |
956        // |     1 | foo     | foo     |
957        // |     2 | foo     | bar     |
958        input1.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_foo.clone()])?;
959        input2.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_bar.clone()])?;
960
961        // Calculate union properties
962        let union_props = calculate_union(vec![input1, input2], schema)?;
963
964        // This should result in:
965        //   const_1 = Uniform("foo")
966        //   const_2 = Heterogeneous
967        assert_eq!(union_props.constants().len(), 2);
968        let union_const_1 = &union_props.constants()[0];
969        assert!(union_const_1.expr.eq(&col_const_1));
970        assert_eq!(
971            union_const_1.across_partitions,
972            AcrossPartitions::Uniform(Some(literal_foo)),
973        );
974        let union_const_2 = &union_props.constants()[1];
975        assert!(union_const_2.expr.eq(&col_const_2));
976        assert_eq!(
977            union_const_2.across_partitions,
978            AcrossPartitions::Heterogeneous,
979        );
980
981        Ok(())
982    }
983}