// datafusion/datasource/memory_test.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

// Unit tests for `MemTable`: scanning with and without projections, schema
// validation at construction time, and appending rows via `INSERT INTO`.
#[cfg(test)]
mod tests {

    use crate::datasource::MemTable;
    use crate::datasource::{provider_as_source, DefaultTableSource};
    use crate::physical_plan::collect;
    use crate::prelude::SessionContext;
    use arrow::array::{AsArray, Int32Array};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema, UInt64Type};
    use arrow::error::ArrowError;
    use arrow::record_batch::RecordBatch;
    use arrow_schema::SchemaRef;
    use datafusion_catalog::TableProvider;
    use datafusion_common::{DataFusionError, Result};
    use datafusion_expr::dml::InsertOp;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::StreamExt;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Scanning with a projection returns only the requested columns, in the
    // requested order, carrying the original data.
    #[tokio::test]
    async fn test_with_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Int32, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
                Arc::new(Int32Array::from(vec![None, None, Some(9)])),
            ],
        )?;

        let provider = MemTable::try_new(schema, vec![vec![batch]])?;

        // scan with projection
        let exec = provider
            .scan(&session_ctx.state(), Some(&vec![2, 1]), &[], None)
            .await?;

        let mut it = exec.execute(0, task_ctx)?;
        let batch2 = it.next().await.unwrap()?;
        assert_eq!(2, batch2.schema().fields().len());
        assert_eq!("c", batch2.schema().field(0).name());
        assert_eq!("b", batch2.schema().field(1).name());
        assert_eq!(2, batch2.num_columns());
        // The projected columns must hold the data of `c` and `b`, not just their names.
        assert_eq!(
            batch2.column(0).as_primitive::<Int32Type>(),
            &Int32Array::from(vec![7, 8, 9])
        );
        assert_eq!(
            batch2.column(1).as_primitive::<Int32Type>(),
            &Int32Array::from(vec![4, 5, 6])
        );

        Ok(())
    }

    // Scanning without a projection returns every column of the table.
    #[tokio::test]
    async fn test_without_projection() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let provider = MemTable::try_new(schema, vec![vec![batch]])?;

        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());

        Ok(())
    }

    // A projection index past the last column must make `scan` fail with an
    // Arrow schema error rather than succeed or panic.
    #[tokio::test]
    async fn test_invalid_projection() -> Result<()> {
        let session_ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let provider = MemTable::try_new(schema, vec![vec![batch]])?;

        let projection: Vec<usize> = vec![0, 4];

        match provider
            .scan(&session_ctx.state(), Some(&projection), &[], None)
            .await
        {
            Err(DataFusionError::ArrowError(err, _)) => match err.as_ref() {
                ArrowError::SchemaError(e) => {
                    assert_eq!(
                        "\"project index 4 out of bounds, max field 3\"",
                        format!("{e:?}")
                    )
                }
                // Report which Arrow error surfaced, to make failures diagnosable.
                other => panic!("unexpected arrow error: {other:?}"),
            },
            res => panic!("Scan should fail on invalid projection, got {res:?}"),
        };

        Ok(())
    }

    // Constructing a table whose declared schema has a different column type
    // than its batches is rejected at planning time.
    #[test]
    fn test_schema_validation_incompatible_column() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Float64, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err();
        assert_eq!(
            "Error during planning: Mismatch between schema and batches",
            e.strip_backtrace()
        );

        Ok(())
    }

    // Constructing a table whose declared schema has a different number of
    // columns than its batches is rejected at planning time.
    #[test]
    fn test_schema_validation_different_column_count() -> Result<()> {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let schema2 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]));

        let batch = RecordBatch::try_new(
            schema1,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![7, 5, 9])),
            ],
        )?;

        let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err();
        assert_eq!(
            "Error during planning: Mismatch between schema and batches",
            e.strip_backtrace()
        );

        Ok(())
    }

    // Batches whose schemas differ only in metadata / nullability are accepted
    // by a table declared with their merged schema.
    #[tokio::test]
    async fn test_merged_schema() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());

        let schema1 = Schema::new_with_metadata(
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Int32, false),
                Field::new("c", DataType::Int32, false),
            ],
            // test for comparing metadata
            metadata,
        );

        let schema2 = Schema::new(vec![
            // test for comparing nullability
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, false),
            Field::new("c", DataType::Int32, false),
        ]);

        let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![4, 5, 6])),
                Arc::new(Int32Array::from(vec![7, 8, 9])),
            ],
        )?;

        let provider =
            MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;

        let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
        let mut it = exec.execute(0, task_ctx)?;
        let batch1 = it.next().await.unwrap()?;
        assert_eq!(3, batch1.schema().fields().len());
        assert_eq!(3, batch1.num_columns());

        Ok(())
    }

    /// Creates a `MemTable` registered as `t` holding `initial_data`, and a
    /// second `MemTable` registered as `source` holding `inserted_data` (both
    /// with `schema`). It then plans and executes an `INSERT INTO t` (append)
    /// from a scan of `source`.
    ///
    /// Asserts that the row count reported by the insert equals the number of
    /// rows in `inserted_data`. Returns a snapshot of `t`'s partitions after
    /// the insert. Errors from table creation, planning or execution are
    /// propagated to the caller.
    async fn experiment(
        schema: SchemaRef,
        initial_data: Vec<Vec<RecordBatch>>,
        inserted_data: Vec<Vec<RecordBatch>>,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let expected_count: u64 = inserted_data
            .iter()
            .flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
            .sum();

        // Create a new session context
        let session_ctx = SessionContext::new();
        // Create and register the initial table with the provided schema and data
        let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
        session_ctx.register_table("t", initial_table.clone())?;
        let target = Arc::new(DefaultTableSource::new(initial_table.clone()));
        // Create and register the source table with the provided schema and inserted data
        let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
        session_ctx.register_table("source", source_table.clone())?;
        // Convert the source table into a provider so that it can be used in a query
        let source = provider_as_source(source_table);
        // Create a table scan logical plan to read from the source table
        let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
        // Create an insert plan to insert the source data into the initial table
        let insert_into_table =
            LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)?
                .build()?;
        // Create a physical plan from the insert plan
        let plan = session_ctx
            .state()
            .create_physical_plan(&insert_into_table)
            .await?;

        // Execute the physical plan and collect the results
        let res = collect(plan, session_ctx.task_ctx()).await?;
        assert_eq!(extract_count(res), expected_count);

        // Read the data from the initial table and store it in a vector of partitions
        let mut partitions = vec![];
        for partition in initial_table.batches.iter() {
            let part = partition.read().await.clone();
            partitions.push(part);
        }
        Ok(partitions)
    }

    /// Returns the single value of a one-row, one-column `UInt64` result, such
    /// as the count produced by an insert. For example, returns 6 given:
    ///
    /// ```text
    /// +-------+
    /// | count |
    /// +-------+
    /// | 6     |
    /// +-------+
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `res` is not exactly one batch with one column and one row,
    /// if that column is not `UInt64`, or if the value is null.
    fn extract_count(res: Vec<RecordBatch>) -> u64 {
        assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
        let batch = &res[0];
        assert_eq!(
            batch.num_columns(),
            1,
            "expected 1 column, got {}",
            batch.num_columns()
        );
        let col = batch.column(0).as_primitive::<UInt64Type>();
        assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
        col.iter()
            .next()
            .expect("had value")
            .expect("expected non null")
    }

    // Test inserting a single batch of data into a single partition
    #[tokio::test]
    async fn test_insert_into_single_partition() -> Result<()> {
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Run the experiment and obtain the resulting data in the table
        let resulting_data_in_table =
            experiment(schema, vec![vec![batch.clone()]], vec![vec![batch]]).await?;
        // Ensure that the table now contains two batches of data in the same partition
        assert_eq!(resulting_data_in_table[0].len(), 2);
        Ok(())
    }

    // Test inserting data from a multi-partition source into a table that has a
    // single partition: every inserted batch lands in that one partition
    #[tokio::test]
    async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> {
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Run the experiment and obtain the resulting data in the table
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone()]],
            vec![vec![batch.clone()], vec![batch]],
        )
        .await?;
        // Ensure that the table now contains three batches of data in the same partition
        assert_eq!(resulting_data_in_table[0].len(), 3);
        Ok(())
    }

    // Test inserting multiple batches of data into multiple partitions
    #[tokio::test]
    async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> {
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Run the experiment and obtain the resulting data in the table
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone()], vec![batch.clone()]],
            vec![
                vec![batch.clone(), batch.clone()],
                vec![batch.clone(), batch],
            ],
        )
        .await?;
        // Ensure that each partition in the table now contains three batches of data
        assert_eq!(resulting_data_in_table[0].len(), 3);
        assert_eq!(resulting_data_in_table[1].len(), 3);
        Ok(())
    }

    // Test inserting from a source whose only partition holds no batches: the
    // insert reports zero rows and the target keeps its original batches
    #[tokio::test]
    async fn test_insert_from_empty_table() -> Result<()> {
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to seed the target table with
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Run the experiment and obtain the resulting data in the table
        let resulting_data_in_table = experiment(
            schema,
            vec![vec![batch.clone(), batch]],
            vec![vec![]],
        )
        .await?;
        // Ensure that the table still contains just its two original batches
        assert_eq!(resulting_data_in_table[0].len(), 2);
        Ok(())
    }

    // Test inserting a batch into a MemTable without any partitions
    #[tokio::test]
    async fn test_insert_into_zero_partition() -> Result<()> {
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Run the experiment and expect an error
        let experiment_result = experiment(schema, vec![], vec![vec![batch]])
            .await
            .unwrap_err();
        // Ensure that there is a descriptive error message
        assert_eq!(
            "Error during planning: No partitions provided, expected at least one partition",
            experiment_result.strip_backtrace()
        );
        Ok(())
    }
}