datafusion_common/
cse.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with
19//! a [`CSEController`], that defines how to eliminate common subtrees from a particular
20//! [`TreeNode`] tree.
21
22use crate::hash_utils::combine_hashes;
23use crate::tree_node::{
24    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
25    TreeNodeVisitor,
26};
27use crate::Result;
28use indexmap::IndexMap;
29use std::collections::HashMap;
30use std::hash::{BuildHasher, Hash, Hasher, RandomState};
31use std::marker::PhantomData;
32use std::sync::Arc;
33
/// Hashes only the direct content of a [`TreeNode`]; its children are not visited.
///
/// This makes it possible to compute hashes incrementally. [`CSE`] builds the deep hash
/// of a node and its descendants during the bottom-up phase of the first traversal by
/// combining the shallow hash of the node with the hashes of its children, so it never
/// has to hash a node and then its descendants separately.
///
/// For a node without children the value produced by `hash_node()` is similar to the
/// one produced by `.hash()`, but the two are not guaranteed to be equal.
pub trait HashNode {
    /// Feeds the content of this node, excluding its children, into `state`.
    fn hash_node<H>(&self, state: &mut H)
    where
        H: Hasher;
}
46
47impl<T: HashNode + ?Sized> HashNode for Arc<T> {
48    fn hash_node<H: Hasher>(&self, state: &mut H) {
49        (**self).hash_node(state);
50    }
51}
52
/// Tells whether a node can be brought into a normalized (canonical) form.
///
/// Normalization converts a node into a canonical form so that nodes can be compared
/// for equality by meaning rather than by shape. Optimizations such as Common
/// Subexpression Elimination (CSE) use it to treat semantically equivalent nodes
/// (e.g. `a + b` and `b + a`) as equal.
pub trait Normalizeable {
    /// Returns `true` if this node can be normalized.
    fn can_normalize(&self) -> bool;
}
61
62/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
63/// normalized nodes in optimizations like Common Subexpression Elimination (CSE).
64///
65/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization)
66/// are considered equal in CSE optimization, even if their original forms differ.
67///
68/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their
69/// internal representations.
70pub trait NormalizeEq: Eq + Normalizeable {
71    fn normalize_eq(&self, other: &Self) -> bool;
72}
73
74/// Identifier that represents a [`TreeNode`] tree.
75///
76/// This identifier is designed to be efficient and  "hash", "accumulate", "equal" and
77/// "have no collision (as low as possible)"
78#[derive(Debug, Eq)]
79struct Identifier<'n, N: NormalizeEq> {
80    // Hash of `node` built up incrementally during the first, visiting traversal.
81    // Its value is not necessarily equal to default hash of the node. E.g. it is not
82    // equal to `expr.hash()` if the node is `Expr`.
83    hash: u64,
84    node: &'n N,
85}
86
87impl<N: NormalizeEq> Clone for Identifier<'_, N> {
88    fn clone(&self) -> Self {
89        *self
90    }
91}
92impl<N: NormalizeEq> Copy for Identifier<'_, N> {}
93
94impl<N: NormalizeEq> Hash for Identifier<'_, N> {
95    fn hash<H: Hasher>(&self, state: &mut H) {
96        state.write_u64(self.hash);
97    }
98}
99
100impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
101    fn eq(&self, other: &Self) -> bool {
102        self.hash == other.hash && self.node.normalize_eq(other.node)
103    }
104}
105
106impl<'n, N> Identifier<'n, N>
107where
108    N: HashNode + NormalizeEq,
109{
110    fn new(node: &'n N, random_state: &RandomState) -> Self {
111        let mut hasher = random_state.build_hasher();
112        node.hash_node(&mut hasher);
113        let hash = hasher.finish();
114        Self { hash, node }
115    }
116
117    fn combine(mut self, other: Option<Self>) -> Self {
118        other.map_or(self, |other_id| {
119            self.hash = combine_hashes(self.hash, other_id.hash);
120            self
121        })
122    }
123}
124
125/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the
126/// preorder index of the nodes.
127///
128/// This cache is filled by [`CSEVisitor`] during the first traversal and is
129/// used by [`CSERewriter`] during the second traversal.
130///
131/// The purpose of this cache is to quickly find the identifier of a node during the
132/// second traversal.
133///
134/// Elements in this array are added during `f_down` so the indexes represent the preorder
135/// index of nodes and thus element 0 belongs to the root of the tree.
136///
137/// The elements of the array are tuples that contain:
138/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start
139///   from 0.
140/// - The optional [`Identifier`] of the node. If none the node should not be considered
141///   for CSE.
142///
143/// # Example
144/// An expression tree like `(a + b)` would have the following `IdArray`:
145/// ```text
146/// [
147///   (2, Some(Identifier(hash_of("a + b"), &"a + b"))),
148///   (1, Some(Identifier(hash_of("a"), &"a"))),
149///   (0, Some(Identifier(hash_of("b"), &"b")))
150/// ]
151/// ```
152type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
153
/// Records how often a node is evaluated. A node counts as common when it is surely
/// evaluated at least twice, or surely evaluated once and additionally evaluated
/// conditionally.
#[derive(PartialEq, Eq)]
enum NodeEvaluation {
    /// Seen exactly once so far, outside of any conditionally evaluated subtree.
    SurelyOnce,
    /// Seen so far only inside conditionally evaluated subtrees.
    ConditionallyAtLeastOnce,
    /// Evaluated often enough to be extracted as a common node.
    Common,
}
162
163/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
164type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
165
166/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
167/// extracted during the second, rewriting traversal.
168type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
169
170type ChildrenList<N> = (Vec<N>, Vec<N>);
171
172/// The [`TreeNode`] specific definition of elimination.
173pub trait CSEController {
174    /// The type of the tree nodes.
175    type Node;
176
177    /// Splits the children to normal and conditionally evaluated ones or returns `None`
178    /// if all are always evaluated.
179    fn conditional_children(node: &Self::Node) -> Option<ChildrenList<&Self::Node>>;
180
181    // A helper method called on each node before is_ignored, during top-down traversal during the first,
182    // visiting traversal of CSE.
183    fn visit_f_down(&mut self, _node: &Self::Node) {}
184    
185    // A helper method called on each node after is_ignored, during bottom-up traversal during the first,
186    // visiting traversal of CSE.
187    fn visit_f_up(&mut self, _node: &Self::Node) {}
188
189    // Returns true if a node is valid. If a node is invalid then it can't be eliminated.
190    // Validity is propagated up which means no subtree can be eliminated that contains
191    // an invalid node.
192    // (E.g. volatile expressions are not valid and subtrees containing such a node can't
193    // be extracted.)
194    fn is_valid(node: &Self::Node) -> bool;
195
196    // Returns true if a node should be ignored during CSE. Contrary to validity of a node,
197    // it is not propagated up.
198    fn is_ignored(&self, node: &Self::Node) -> bool;
199
200    // Generates a new name for the extracted subtree.
201    fn generate_alias(&self) -> String;
202
203    // Replaces a node to the generated alias.
204    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
205
206    // A helper method called on each node during top-down traversal during the second,
207    // rewriting traversal of CSE.
208    fn rewrite_f_down(&mut self, _node: &Self::Node) {}
209
210    // A helper method called on each node during bottom-up traversal during the second,
211    // rewriting traversal of CSE.
212    fn rewrite_f_up(&mut self, _node: &Self::Node) {}
213}
214
/// Outcome of trying to eliminate common subtrees from a list of [`TreeNode`]s.
#[derive(Debug)]
pub enum FoundCommonNodes<N> {
    /// Nothing could be extracted; the input is handed back untouched.
    No {
        /// original [`TreeNode`]s
        original_nodes_list: Vec<Vec<N>>,
    },

    /// At least one common [`TreeNode`] was found and extracted.
    Yes {
        /// the extracted common [`TreeNode`]s, each paired with its alias
        common_nodes: Vec<(N, String)>,

        /// the rewritten [`TreeNode`]s in which common subtrees have been replaced
        new_nodes_list: Vec<Vec<N>>,

        /// original [`TreeNode`]s
        original_nodes_list: Vec<Vec<N>>,
    },
}
234
235/// Go through a [`TreeNode`] tree and generate identifiers for each subtrees.
236///
237/// An identifier contains information of the [`TreeNode`] itself and its subtrees.
238/// This visitor implementation use a stack `visit_stack` to track traversal, which
239/// lets us know when a subtree's visiting is finished. When `pre_visit` is called
240/// (traversing to a new node), an `EnterMark` and an `NodeItem` will be pushed into stack.
241/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `NodeItem`
242/// before the first `EnterMark` is considered to be sub-tree of the leaving node.
243///
244/// This visitor also records identifier in `id_array`. Makes the following traverse
245/// pass can get the identifier of a node without recalculate it. We assign each node
246/// in the tree a series number, start from 1, maintained by `series_number`.
247/// Series number represents the order we left (`f_up()`) a node. Has the property
248/// that child node's series number always smaller than parent's. While `id_array` is
249/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
250/// get the index of `id_array` for each node.
251///
252/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
253/// because they should not be recognized as common subtree.
254struct CSEVisitor<'a, 'n, N, C>
255where
256    N: NormalizeEq,
257    C: CSEController<Node = N>,
258{
259    /// statistics of [`TreeNode`]s
260    node_stats: &'a mut NodeStats<'n, N>,
261
262    /// cache to speed up second traversal
263    id_array: &'a mut IdArray<'n, N>,
264
265    /// inner states
266    visit_stack: Vec<VisitRecord<'n, N>>,
267
268    /// preorder index, start from 0.
269    down_index: usize,
270
271    /// postorder index, start from 0.
272    up_index: usize,
273
274    /// a [`RandomState`] to generate hashes during the first traversal
275    random_state: &'a RandomState,
276
277    /// a flag to indicate that common [`TreeNode`]s found
278    found_common: bool,
279
280    /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`]
281    /// might not be executed depending on the runtime values of other [`TreeNode`]s, and
282    /// thus can not be extracted as a common [`TreeNode`].
283    conditional: bool,
284
285    controller: &'a mut C,
286}
287
288/// Record item that used when traversing a [`TreeNode`] tree.
289enum VisitRecord<'n, N>
290where
291    N: NormalizeEq,
292{
293    /// Marks the beginning of [`TreeNode`]. It contains:
294    /// - The post-order index assigned during the first, visiting traversal.
295    EnterMark(usize),
296
297    /// Marks an accumulated subtree. It contains:
298    /// - The accumulated identifier of a subtree.
299    /// - A accumulated boolean flag if the subtree is valid for CSE.
300    ///   The flag is propagated up from children to parent. (E.g. volatile expressions
301    ///   are not valid and can't be extracted, but non-volatile children of volatile
302    ///   expressions can be extracted.)
303    NodeItem(Identifier<'n, N>, bool),
304}
305
306impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
307where
308    N: TreeNode + HashNode + NormalizeEq,
309    C: CSEController<Node = N>,
310{
311    /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
312    /// it. Returns a tuple that contains:
313    /// - The pre-order index of the [`TreeNode`] we marked.
314    /// - The accumulated identifier of the children of the marked [`TreeNode`].
315    /// - An accumulated boolean flag from the children of the marked [`TreeNode`] if all
316    ///   children are valid for CSE (i.e. it is safe to extract the [`TreeNode`] as a
317    ///   common [`TreeNode`] from its children POV).
318    ///   (E.g. if any of the children of the marked expression is not valid (e.g. is
319    ///   volatile) then the expression is also not valid, so we can propagate this
320    ///   information up from children to parents via `visit_stack` during the first,
321    ///   visiting traversal and no need to test the expression's validity beforehand with
322    ///   an extra traversal).
323    fn pop_enter_mark(
324        &mut self,
325        can_normalize: bool,
326    ) -> (usize, Option<Identifier<'n, N>>, bool) {
327        let mut node_ids: Vec<Identifier<'n, N>> = vec![];
328        let mut is_valid = true;
329
330        while let Some(item) = self.visit_stack.pop() {
331            match item {
332                VisitRecord::EnterMark(down_index) => {
333                    if can_normalize {
334                        node_ids.sort_by_key(|i| i.hash);
335                    }
336                    let node_id = node_ids
337                        .into_iter()
338                        .fold(None, |accum, item| Some(item.combine(accum)));
339                    return (down_index, node_id, is_valid);
340                }
341                VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
342                    node_ids.push(sub_node_id);
343                    is_valid &= sub_node_is_valid;
344                }
345            }
346        }
347        unreachable!("EnterMark should paired with NodeItem");
348    }
349}
350
351impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
352where
353    N: TreeNode + HashNode + NormalizeEq,
354    C: CSEController<Node = N>,
355{
356    type Node = N;
357
358    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
359        self.id_array.push((0, None));
360        self.visit_stack
361            .push(VisitRecord::EnterMark(self.down_index));
362        self.down_index += 1;
363        self.controller.visit_f_down(node);
364
365        // If a node can short-circuit then some of its children might not be executed so
366        // count the occurrence either normal or conditional.
367        Ok(if self.conditional {
368            // If we are already in a conditionally evaluated subtree then continue
369            // traversal.
370            TreeNodeRecursion::Continue
371        } else {
372            // If we are already in a node that can short-circuit then start new
373            // traversals on its normal conditional children.
374            match C::conditional_children(node) {
375                Some((normal, conditional)) => {
376                    normal
377                        .into_iter()
378                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
379                    self.conditional = true;
380                    conditional
381                        .into_iter()
382                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
383                    self.conditional = false;
384
385                    TreeNodeRecursion::Jump
386                }
387
388                // In case of non-short-circuit node continue the traversal.
389                _ => TreeNodeRecursion::Continue,
390            }
391        })
392    }
393
394    fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
395        let (down_index, sub_node_id, sub_node_is_valid) =
396            self.pop_enter_mark(node.can_normalize());
397
398        let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
399        let is_valid = C::is_valid(node) && sub_node_is_valid;
400
401        self.id_array[down_index].0 = self.up_index;
402        if is_valid && !self.controller.is_ignored(node) {
403            self.id_array[down_index].1 = Some(node_id);
404            self.node_stats
405                .entry(node_id)
406                .and_modify(|evaluation| {
407                    if *evaluation == NodeEvaluation::SurelyOnce
408                        || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
409                            && !self.conditional
410                    {
411                        *evaluation = NodeEvaluation::Common;
412                        self.found_common = true;
413                    }
414                })
415                .or_insert_with(|| {
416                    if self.conditional {
417                        NodeEvaluation::ConditionallyAtLeastOnce
418                    } else {
419                        NodeEvaluation::SurelyOnce
420                    }
421                });
422        }
423        self.visit_stack
424            .push(VisitRecord::NodeItem(node_id, is_valid));
425        self.up_index += 1;
426        self.controller.visit_f_up(node);
427
428        Ok(TreeNodeRecursion::Continue)
429    }
430}
431
432/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
433/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
434/// replaced [`TreeNode`] tree.
435struct CSERewriter<'a, 'n, N, C>
436where
437    N: NormalizeEq,
438    C: CSEController<Node = N>,
439{
440    /// statistics of [`TreeNode`]s
441    node_stats: &'a NodeStats<'n, N>,
442
443    /// cache to speed up second traversal
444    id_array: &'a IdArray<'n, N>,
445
446    /// common [`TreeNode`]s, that are replaced during the second traversal, are collected
447    /// to this map
448    common_nodes: &'a mut CommonNodes<'n, N>,
449
450    // preorder index, starts from 0.
451    down_index: usize,
452
453    controller: &'a mut C,
454}
455
456impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
457where
458    N: TreeNode + NormalizeEq,
459    C: CSEController<Node = N>,
460{
461    type Node = N;
462
463    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
464        self.controller.rewrite_f_down(&node);
465
466        let (up_index, node_id) = self.id_array[self.down_index];
467        self.down_index += 1;
468
469        // Handle nodes with identifiers only
470        if let Some(node_id) = node_id {
471            let evaluation = self.node_stats.get(&node_id).unwrap();
472            if *evaluation == NodeEvaluation::Common {
473                // step index to skip all sub-node (which has smaller series number).
474                while self.down_index < self.id_array.len()
475                    && self.id_array[self.down_index].0 < up_index
476                {
477                    self.down_index += 1;
478                }
479
480                // We *must* replace all original nodes with same `node_id`, not just the first
481                // node which is inserted into the common_nodes. This is because nodes with the same
482                // `node_id` are semantically equivalent, but not exactly the same.
483                //
484                // For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
485                // In this case, we should replace the common expression `1 + a` with a new variable
486                // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
487                // `__common_cse_1`.
488                //
489                // The final result would be:
490                // - `__common_cse_1 as a + 1`
491                // - `__common_cse_1 as 1 + a`
492                //
493                // This way, we can efficiently handle semantically equivalent expressions without
494                // incorrectly treating them as identical.
495                let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
496                {
497                    self.controller.rewrite(&node, alias)
498                } else {
499                    let node_alias = self.controller.generate_alias();
500                    let rewritten = self.controller.rewrite(&node, &node_alias);
501                    self.common_nodes.insert(node_id, (node, node_alias));
502                    rewritten
503                };
504
505                return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
506            }
507        }
508
509        Ok(Transformed::no(node))
510    }
511
512    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
513        self.controller.rewrite_f_up(&node);
514
515        Ok(Transformed::no(node))
516    }
517}
518
519/// The main entry point of Common Subexpression Elimination.
520///
521/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular
522/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the
523/// [`CSE::extract_common_nodes()`] method.
524pub struct CSE<N, C: CSEController<Node = N>> {
525    random_state: RandomState,
526    phantom_data: PhantomData<N>,
527    controller: C,
528}
529
530impl<N, C> CSE<N, C>
531where
532    N: TreeNode + HashNode + Clone + NormalizeEq,
533    C: CSEController<Node = N>,
534{
535    pub fn new(controller: C) -> Self {
536        Self {
537            random_state: RandomState::new(),
538            phantom_data: PhantomData,
539            controller,
540        }
541    }
542
543    /// Add an identifier to `id_array` for every [`TreeNode`] in this tree.
544    fn node_to_id_array<'n>(
545        &mut self,
546        node: &'n N,
547        node_stats: &mut NodeStats<'n, N>,
548        id_array: &mut IdArray<'n, N>,
549    ) -> Result<bool> {
550        let mut visitor = CSEVisitor {
551            node_stats,
552            id_array,
553            visit_stack: vec![],
554            down_index: 0,
555            up_index: 0,
556            random_state: &self.random_state,
557            found_common: false,
558            conditional: false,
559            controller: &mut self.controller,
560        };
561        node.visit(&mut visitor)?;
562
563        Ok(visitor.found_common)
564    }
565
566    /// Returns the identifier list for each element in `nodes` and a flag to indicate if
567    /// rewrite phase of CSE make sense.
568    ///
569    /// Returns and array with 1 element for each input node in `nodes`
570    ///
571    /// Each element is itself the result of [`CSE::node_to_id_array`] for that node
572    /// (e.g. the identifiers for each node in the tree)
573    fn to_arrays<'n>(
574        &mut self,
575        nodes: &'n [N],
576        node_stats: &mut NodeStats<'n, N>,
577    ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
578        let mut found_common = false;
579        nodes
580            .iter()
581            .map(|n| {
582                let mut id_array = vec![];
583                self.node_to_id_array(n, node_stats, &mut id_array)
584                    .map(|fc| {
585                        found_common |= fc;
586
587                        id_array
588                    })
589            })
590            .collect::<Result<Vec<_>>>()
591            .map(|id_arrays| (found_common, id_arrays))
592    }
593
594    /// Replace common subtrees in `node` with the corresponding temporary
595    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]
596    fn replace_common_node<'n>(
597        &mut self,
598        node: N,
599        id_array: &IdArray<'n, N>,
600        node_stats: &NodeStats<'n, N>,
601        common_nodes: &mut CommonNodes<'n, N>,
602    ) -> Result<N> {
603        if id_array.is_empty() {
604            Ok(Transformed::no(node))
605        } else {
606            node.rewrite(&mut CSERewriter {
607                node_stats,
608                id_array,
609                common_nodes,
610                down_index: 0,
611                controller: &mut self.controller,
612            })
613        }
614        .data()
615    }
616
617    /// Replace common subtrees in `nodes_list` with the corresponding temporary
618    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
619    fn rewrite_nodes_list<'n>(
620        &mut self,
621        nodes_list: Vec<Vec<N>>,
622        arrays_list: &[Vec<IdArray<'n, N>>],
623        node_stats: &NodeStats<'n, N>,
624        common_nodes: &mut CommonNodes<'n, N>,
625    ) -> Result<Vec<Vec<N>>> {
626        nodes_list
627            .into_iter()
628            .zip(arrays_list.iter())
629            .map(|(nodes, arrays)| {
630                nodes
631                    .into_iter()
632                    .zip(arrays.iter())
633                    .map(|(node, id_array)| {
634                        self.replace_common_node(node, id_array, node_stats, common_nodes)
635                    })
636                    .collect::<Result<Vec<_>>>()
637            })
638            .collect::<Result<Vec<_>>>()
639    }
640
641    /// Extracts common [`TreeNode`]s and rewrites `nodes_list`.
642    ///
643    /// Returns [`FoundCommonNodes`] recording the result of the extraction.
644    pub fn extract_common_nodes(
645        &mut self,
646        nodes_list: Vec<Vec<N>>,
647    ) -> Result<FoundCommonNodes<N>> {
648        let mut found_common = false;
649        let mut node_stats = NodeStats::new();
650
651        let id_arrays_list = nodes_list
652            .iter()
653            .map(|nodes| {
654                self.to_arrays(nodes, &mut node_stats)
655                    .map(|(fc, id_arrays)| {
656                        found_common |= fc;
657
658                        id_arrays
659                    })
660            })
661            .collect::<Result<Vec<_>>>()?;
662        if found_common {
663            let mut common_nodes = CommonNodes::new();
664            let new_nodes_list = self.rewrite_nodes_list(
665                // Must clone the list of nodes as Identifiers use references to original
666                // nodes so we have to keep them intact.
667                nodes_list.clone(),
668                &id_arrays_list,
669                &node_stats,
670                &mut common_nodes,
671            )?;
672            assert!(!common_nodes.is_empty());
673
674            Ok(FoundCommonNodes::Yes {
675                common_nodes: common_nodes.into_values().collect(),
676                new_nodes_list,
677                original_nodes_list: nodes_list,
678            })
679        } else {
680            Ok(FoundCommonNodes::No {
681                original_nodes_list: nodes_list,
682            })
683        }
684    }
685}
686
#[cfg(test)]
mod test {
    use crate::alias::AliasGenerator;
    use crate::cse::{
        CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
        Normalizeable, CSE,
    };
    use crate::tree_node::tests::TestTreeNode;
    use crate::Result;
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};

    const CSE_PREFIX: &str = "__common_node";

    /// Selects which inner nodes the test controller lets CSE consider.
    #[derive(Clone, Copy)]
    pub enum TestTreeNodeMask {
        /// every inner node except the aggregates (`avg`, `sum`)
        Normal,
        /// every inner node, aggregates included
        NormalAndAggregates,
    }

    /// A [`CSEController`] for `TestTreeNode<String>` trees.
    pub struct TestTreeNodeCSEController<'a> {
        alias_generator: &'a AliasGenerator,
        mask: TestTreeNodeMask,
    }

    impl<'a> TestTreeNodeCSEController<'a> {
        fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self {
            Self {
                alias_generator,
                mask,
            }
        }
    }

    impl CSEController for TestTreeNodeCSEController<'_> {
        type Node = TestTreeNode<String>;

        // No node of the test tree evaluates its children conditionally.
        fn conditional_children(
            _: &Self::Node,
        ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> {
            None
        }

        fn is_valid(_node: &Self::Node) -> bool {
            true
        }

        // Leaves are always ignored, aggregates only with the `Normal` mask.
        fn is_ignored(&self, node: &Self::Node) -> bool {
            if node.is_leaf() {
                return true;
            }

            match self.mask {
                TestTreeNodeMask::Normal => node.data == "avg" || node.data == "sum",
                TestTreeNodeMask::NormalAndAggregates => false,
            }
        }

        fn generate_alias(&self) -> String {
            self.alias_generator.next(CSE_PREFIX)
        }

        fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
            TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
        }
    }

    impl HashNode for TestTreeNode<String> {
        fn hash_node<H: Hasher>(&self, state: &mut H) {
            self.data.hash(state);
        }
    }

    impl Normalizeable for TestTreeNode<String> {
        fn can_normalize(&self) -> bool {
            false
        }
    }

    impl NormalizeEq for TestTreeNode<String> {
        fn normalize_eq(&self, other: &Self) -> bool {
            self == other
        }
    }

    #[test]
    fn id_array_visitor() -> Result<()> {
        // Builds a leaf node.
        fn leaf(data: &str) -> TestTreeNode<String> {
            TestTreeNode::new_leaf(data.to_string())
        }

        // Builds an inner node.
        fn branch(
            children: Vec<TestTreeNode<String>>,
            data: &str,
        ) -> TestTreeNode<String> {
            TestTreeNode::new(children, data.to_string())
        }

        // The identifier expected for `node` once its hash has been reset to 0.
        fn zeroed(
            node: &TestTreeNode<String>,
        ) -> Option<Identifier<'_, TestTreeNode<String>>> {
            Some(Identifier { hash: 0, node })
        }

        // Collect distinct hashes and set them to 0 in `id_array`
        fn collect_hashes(
            id_array: &mut IdArray<'_, TestTreeNode<String>>,
        ) -> HashSet<u64> {
            id_array
                .iter_mut()
                .filter_map(|(_, id_option)| {
                    id_option
                        .as_mut()
                        .map(|node_id| std::mem::take(&mut node_id.hash))
                })
                .collect()
        }

        let alias_generator = AliasGenerator::new();

        // (sum(a + 1) - avg(c)) * 2
        let root = branch(
            vec![
                branch(
                    vec![
                        branch(vec![branch(vec![leaf("a"), leaf("1")], "+")], "sum"),
                        branch(vec![leaf("c")], "avg"),
                    ],
                    "-",
                ),
                leaf("2"),
            ],
            "*",
        );

        let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };
        let [a_plus_1] = sum_a_plus_1.children.as_slice() else {
            panic!("Cannot extract subtree references")
        };

        // skip aggregates
        let mut eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::Normal,
        ));

        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;

        let hashes = collect_hashes(&mut id_array);
        assert_eq!(hashes.len(), 3);

        let expected = vec![
            (8, zeroed(&root)),
            (6, zeroed(sum_a_plus_1_minus_avg_c)),
            (3, None),
            (2, zeroed(a_plus_1)),
            (0, None),
            (1, None),
            (5, None),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);

        // include aggregates
        let mut eliminator = CSE::new(TestTreeNodeCSEController::new(
            &alias_generator,
            TestTreeNodeMask::NormalAndAggregates,
        ));

        let mut id_array = vec![];
        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?;

        let hashes = collect_hashes(&mut id_array);
        assert_eq!(hashes.len(), 5);

        let expected = vec![
            (8, zeroed(&root)),
            (6, zeroed(sum_a_plus_1_minus_avg_c)),
            (3, zeroed(sum_a_plus_1)),
            (2, zeroed(a_plus_1)),
            (0, None),
            (1, None),
            (5, zeroed(avg_c)),
            (4, None),
            (7, None),
        ];
        assert_eq!(expected, id_array);

        Ok(())
    }
}