datafusion_catalog/memory/table.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`MemTable`] for querying `Vec<RecordBatch>` by DataFusion.

use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;

use crate::TableProvider;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::error::Result;
use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt};
use datafusion_common_runtime::JoinSet;
use datafusion_datasource::memory::{MemSink, MemorySourceConfig};
use datafusion_datasource::sink::DataSinkExec;
use datafusion_datasource::source::DataSourceExec;
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{Expr, SortExpr, TableType};
use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering};
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::{
    common, ExecutionPlan, ExecutionPlanProperties, Partitioning,
};
use datafusion_session::Session;

use async_trait::async_trait;
use futures::StreamExt;
use log::debug;
use parking_lot::Mutex;
use tokio::sync::RwLock;

// backward compatibility
pub use datafusion_datasource::memory::PartitionData;

/// In-memory data source for presenting a `Vec<RecordBatch>` as a
/// data source that can be queried by DataFusion. This allows data to
/// be pre-loaded into memory and then repeatedly queried without
/// incurring additional file I/O overhead.
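///
/// # Example
///
/// A minimal sketch, assuming a `SessionContext` from the `datafusion` crate
/// and a `schema` and `batch` already built with `arrow`:
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::prelude::SessionContext;
///
/// // wrap pre-built record batches in a `MemTable` and query them with SQL
/// let table = MemTable::try_new(schema, vec![vec![batch]])?;
///
/// let ctx = SessionContext::new();
/// ctx.register_table("t", Arc::new(table))?;
/// let df = ctx.sql("SELECT * FROM t").await?;
/// df.show().await?;
/// ```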
#[derive(Debug)]
pub struct MemTable {
    schema: SchemaRef,
    // batches used to be pub(crate), but it needs to be public for the tests
    pub batches: Vec<PartitionData>,
    constraints: Constraints,
    column_defaults: HashMap<String, Expr>,
    /// Optional pre-known sort order(s). Must be `SortExpr`s.
    /// Inserting data into this table clears the order.
    pub sort_order: Arc<Mutex<Vec<Vec<SortExpr>>>>,
}

impl MemTable {
    /// Create a new in-memory table from the provided schema and record batches.
    ///
    /// Requires at least one partition. To construct an empty `MemTable`, pass
    /// `vec![vec![]]` as the `partitions` argument; this represents one partition
    /// with no batches.
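    ///
    /// # Example
    ///
    /// A minimal sketch building a single-partition table from one
    /// `RecordBatch` (the column name `a` is illustrative):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use arrow::array::Int32Array;
    /// use arrow::datatypes::{DataType, Field, Schema};
    /// use arrow::record_batch::RecordBatch;
    ///
    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    /// let batch = RecordBatch::try_new(
    ///     Arc::clone(&schema),
    ///     vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    /// )?;
    /// // one partition containing one batch
    /// let table = MemTable::try_new(schema, vec![vec![batch]])?;
    /// ```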
    pub fn try_new(schema: SchemaRef, partitions: Vec<Vec<RecordBatch>>) -> Result<Self> {
        if partitions.is_empty() {
            return plan_err!("No partitions provided, expected at least one partition");
        }

        for batches in partitions.iter().flatten() {
            let batches_schema = batches.schema();
            if !schema.contains(&batches_schema) {
                debug!(
                    "mem table schema does not contain batches schema. \
                        Target_schema: {schema:?}. Batches Schema: {batches_schema:?}"
                );
                return plan_err!("Mismatch between schema and batches");
            }
        }

        Ok(Self {
            schema,
            batches: partitions
                .into_iter()
                .map(|e| Arc::new(RwLock::new(e)))
                .collect::<Vec<_>>(),
            constraints: Constraints::default(),
            column_defaults: HashMap::new(),
            sort_order: Arc::new(Mutex::new(vec![])),
        })
    }

    /// Assign constraints
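    ///
    /// A minimal sketch, assuming `schema`/`batch` as above and using
    /// `Constraints::new_unverified` from `datafusion_common`:
    ///
    /// ```ignore
    /// use datafusion_common::{Constraint, Constraints};
    ///
    /// // declare (without verifying) that column 0 is a primary key
    /// let table = MemTable::try_new(schema, vec![vec![batch]])?
    ///     .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]));
    /// ```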
    pub fn with_constraints(mut self, constraints: Constraints) -> Self {
        self.constraints = constraints;
        self
    }

    /// Assign column defaults
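    ///
    /// A minimal sketch, assuming an `Int32` column named `a`:
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    /// use datafusion_expr::lit;
    ///
    /// // rows inserted without a value for `a` default to 0
    /// let table = table.with_column_defaults(HashMap::from([("a".to_string(), lit(0i32))]));
    /// ```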
    pub fn with_column_defaults(
        mut self,
        column_defaults: HashMap<String, Expr>,
    ) -> Self {
        self.column_defaults = column_defaults;
        self
    }

    /// Specify optional pre-known sort order(s). Must be `SortExpr`s.
    ///
    /// If the data is not sorted by this order, DataFusion may produce
    /// incorrect results.
    ///
    /// DataFusion may take advantage of this ordering to omit sorts
    /// or use more efficient algorithms.
    ///
    /// Note that multiple sort orders are supported if some are known to be
    /// equivalent.
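    ///
    /// # Example
    ///
    /// A minimal sketch declaring that the data is already sorted by a column
    /// `a`, ascending with nulls last (the column name is illustrative):
    ///
    /// ```ignore
    /// use datafusion_expr::col;
    ///
    /// let table = MemTable::try_new(schema, vec![vec![batch]])?
    ///     .with_sort_order(vec![vec![col("a").sort(true, false)]]);
    /// ```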
    pub fn with_sort_order(self, mut sort_order: Vec<Vec<SortExpr>>) -> Self {
        std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
        self
    }

    /// Create a `MemTable` by reading from another data source.
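    ///
    /// A minimal sketch, assuming `ctx` is a `SessionContext` from the
    /// `datafusion` crate and `provider` is another `Arc<dyn TableProvider>`:
    ///
    /// ```ignore
    /// // materialize `provider` into memory, round-robin repartitioned into 4 partitions
    /// let state = ctx.state();
    /// let mem_table = MemTable::load(provider, Some(4), &state).await?;
    /// ```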
    pub async fn load(
        t: Arc<dyn TableProvider>,
        output_partitions: Option<usize>,
        state: &dyn Session,
    ) -> Result<Self> {
        let schema = t.schema();
        let constraints = t.constraints();
        let exec = t.scan(state, None, &[], None).await?;
        let partition_count = exec.output_partitioning().partition_count();

        let mut join_set = JoinSet::new();

        // Execute each input partition in its own task and collect its batches
        for part_idx in 0..partition_count {
            let task = state.task_ctx();
            let exec = Arc::clone(&exec);
            join_set.spawn(async move {
                let stream = exec.execute(part_idx, task)?;
                common::collect(stream).await
            });
        }

        let mut data: Vec<Vec<RecordBatch>> =
            Vec::with_capacity(exec.output_partitioning().partition_count());

        // Gather the batches collected by each task as it completes
        while let Some(result) = join_set.join_next().await {
            match result {
                Ok(res) => data.push(res?),
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        // tasks are never cancelled here, so a non-panic join
                        // error cannot occur
                        unreachable!();
                    }
                }
            }
        }

        let mut exec = DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
            &data,
            Arc::clone(&schema),
            None,
        )?));
        if let Some(cons) = constraints {
            exec = exec.with_constraints(cons.clone());
        }

        if let Some(num_partitions) = output_partitions {
            let exec = RepartitionExec::try_new(
                Arc::new(exec),
                Partitioning::RoundRobinBatch(num_partitions),
            )?;

            // execute and collect results
            let mut output_partitions = vec![];
            for i in 0..exec.properties().output_partitioning().partition_count() {
                // execute this *output* partition and collect all batches
                let task_ctx = state.task_ctx();
                let mut stream = exec.execute(i, task_ctx)?;
                let mut batches = vec![];
                while let Some(result) = stream.next().await {
                    batches.push(result?);
                }
                output_partitions.push(batches);
            }

            return MemTable::try_new(Arc::clone(&schema), output_partitions);
        }
        MemTable::try_new(Arc::clone(&schema), data)
    }
}

#[async_trait]
impl TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn constraints(&self) -> Option<&Constraints> {
        Some(&self.constraints)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // snapshot the current batches of each partition under its read lock
        let mut partitions = vec![];
        for arc_inner_vec in self.batches.iter() {
            let inner_vec = arc_inner_vec.read().await;
            partitions.push(inner_vec.clone())
        }

        let mut source =
            MemorySourceConfig::try_new(&partitions, self.schema(), projection.cloned())?;

        let show_sizes = state.config_options().explain.show_sizes;
        source = source.with_show_sizes(show_sizes);

        // add sort information if present
        let sort_order = self.sort_order.lock();
        if !sort_order.is_empty() {
            let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?;

            let eqp = state.execution_props();
            let mut file_sort_order = vec![];
            for sort_exprs in sort_order.iter() {
                let physical_exprs =
                    create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?;
                file_sort_order.extend(LexOrdering::new(physical_exprs));
            }
            source = source.try_with_sort_information(file_sort_order)?;
        }

        Ok(DataSourceExec::from_data_source(source))
    }

    /// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
    ///
    /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
    ///
    /// # Arguments
    ///
    /// * `state` - The [`SessionState`] containing the context for executing the plan.
    /// * `input` - The [`ExecutionPlan`] to execute and insert.
    ///
    /// # Returns
    ///
    /// * A plan that returns the number of rows written.
    ///
    /// [`SessionState`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html
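    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the table has been registered as `t` with a
    /// `SessionContext` named `ctx`; appends are normally issued through SQL
    /// rather than by calling this method directly:
    ///
    /// ```ignore
    /// ctx.register_table("t", Arc::new(table))?;
    /// ctx.sql("INSERT INTO t VALUES (4)").await?.collect().await?;
    /// ```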
    async fn insert_into(
        &self,
        _state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Inserting may invalidate any declared sort order, so clear it here
        *self.sort_order.lock() = vec![];

        // Check that the schema of the input plan matches the schema of this table
        self.schema()
            .logically_equivalent_names_and_types(&input.schema())?;

        if insert_op != InsertOp::Append {
            return not_impl_err!("{insert_op} not implemented for MemoryTable yet");
        }
        let sink = MemSink::try_new(self.batches.clone(), Arc::clone(&self.schema))?;
        Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None)))
    }

    fn get_column_default(&self, column: &str) -> Option<&Expr> {
        self.column_defaults.get(column)
    }
}