1#[cfg(test)]
19mod tests {
20
21 use crate::datasource::MemTable;
22 use crate::datasource::{provider_as_source, DefaultTableSource};
23 use crate::physical_plan::collect;
24 use crate::prelude::SessionContext;
25 use arrow::array::{AsArray, Int32Array};
26 use arrow::datatypes::{DataType, Field, Schema, UInt64Type};
27 use arrow::error::ArrowError;
28 use arrow::record_batch::RecordBatch;
29 use arrow_schema::SchemaRef;
30 use datafusion_catalog::TableProvider;
31 use datafusion_common::{DataFusionError, Result};
32 use datafusion_expr::dml::InsertOp;
33 use datafusion_expr::LogicalPlanBuilder;
34 use futures::StreamExt;
35 use std::collections::HashMap;
36 use std::sync::Arc;
37
38 #[tokio::test]
39 async fn test_with_projection() -> Result<()> {
40 let session_ctx = SessionContext::new();
41 let task_ctx = session_ctx.task_ctx();
42 let schema = Arc::new(Schema::new(vec![
43 Field::new("a", DataType::Int32, false),
44 Field::new("b", DataType::Int32, false),
45 Field::new("c", DataType::Int32, false),
46 Field::new("d", DataType::Int32, true),
47 ]));
48
49 let batch = RecordBatch::try_new(
50 schema.clone(),
51 vec![
52 Arc::new(Int32Array::from(vec![1, 2, 3])),
53 Arc::new(Int32Array::from(vec![4, 5, 6])),
54 Arc::new(Int32Array::from(vec![7, 8, 9])),
55 Arc::new(Int32Array::from(vec![None, None, Some(9)])),
56 ],
57 )?;
58
59 let provider = MemTable::try_new(schema, vec![vec![batch]])?;
60
61 let exec = provider
63 .scan(&session_ctx.state(), Some(&vec![2, 1]), &[], None)
64 .await?;
65
66 let mut it = exec.execute(0, task_ctx)?;
67 let batch2 = it.next().await.unwrap()?;
68 assert_eq!(2, batch2.schema().fields().len());
69 assert_eq!("c", batch2.schema().field(0).name());
70 assert_eq!("b", batch2.schema().field(1).name());
71 assert_eq!(2, batch2.num_columns());
72
73 Ok(())
74 }
75
76 #[tokio::test]
77 async fn test_without_projection() -> Result<()> {
78 let session_ctx = SessionContext::new();
79 let task_ctx = session_ctx.task_ctx();
80 let schema = Arc::new(Schema::new(vec![
81 Field::new("a", DataType::Int32, false),
82 Field::new("b", DataType::Int32, false),
83 Field::new("c", DataType::Int32, false),
84 ]));
85
86 let batch = RecordBatch::try_new(
87 schema.clone(),
88 vec![
89 Arc::new(Int32Array::from(vec![1, 2, 3])),
90 Arc::new(Int32Array::from(vec![4, 5, 6])),
91 Arc::new(Int32Array::from(vec![7, 8, 9])),
92 ],
93 )?;
94
95 let provider = MemTable::try_new(schema, vec![vec![batch]])?;
96
97 let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
98 let mut it = exec.execute(0, task_ctx)?;
99 let batch1 = it.next().await.unwrap()?;
100 assert_eq!(3, batch1.schema().fields().len());
101 assert_eq!(3, batch1.num_columns());
102
103 Ok(())
104 }
105
106 #[tokio::test]
107 async fn test_invalid_projection() -> Result<()> {
108 let session_ctx = SessionContext::new();
109
110 let schema = Arc::new(Schema::new(vec![
111 Field::new("a", DataType::Int32, false),
112 Field::new("b", DataType::Int32, false),
113 Field::new("c", DataType::Int32, false),
114 ]));
115
116 let batch = RecordBatch::try_new(
117 schema.clone(),
118 vec![
119 Arc::new(Int32Array::from(vec![1, 2, 3])),
120 Arc::new(Int32Array::from(vec![4, 5, 6])),
121 Arc::new(Int32Array::from(vec![7, 8, 9])),
122 ],
123 )?;
124
125 let provider = MemTable::try_new(schema, vec![vec![batch]])?;
126
127 let projection: Vec<usize> = vec![0, 4];
128
129 match provider
130 .scan(&session_ctx.state(), Some(&projection), &[], None)
131 .await
132 {
133 Err(DataFusionError::ArrowError(err, _)) => match err.as_ref() {
134 ArrowError::SchemaError(e) => {
135 assert_eq!(
136 "\"project index 4 out of bounds, max field 3\"",
137 format!("{e:?}")
138 )
139 }
140 _ => panic!("unexpected error"),
141 },
142 res => panic!("Scan should failed on invalid projection, got {res:?}"),
143 };
144
145 Ok(())
146 }
147
148 #[test]
149 fn test_schema_validation_incompatible_column() -> Result<()> {
150 let schema1 = Arc::new(Schema::new(vec![
151 Field::new("a", DataType::Int32, false),
152 Field::new("b", DataType::Int32, false),
153 Field::new("c", DataType::Int32, false),
154 ]));
155
156 let schema2 = Arc::new(Schema::new(vec![
157 Field::new("a", DataType::Int32, false),
158 Field::new("b", DataType::Float64, false),
159 Field::new("c", DataType::Int32, false),
160 ]));
161
162 let batch = RecordBatch::try_new(
163 schema1,
164 vec![
165 Arc::new(Int32Array::from(vec![1, 2, 3])),
166 Arc::new(Int32Array::from(vec![4, 5, 6])),
167 Arc::new(Int32Array::from(vec![7, 8, 9])),
168 ],
169 )?;
170
171 let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err();
172 assert_eq!(
173 "Error during planning: Mismatch between schema and batches",
174 e.strip_backtrace()
175 );
176
177 Ok(())
178 }
179
180 #[test]
181 fn test_schema_validation_different_column_count() -> Result<()> {
182 let schema1 = Arc::new(Schema::new(vec![
183 Field::new("a", DataType::Int32, false),
184 Field::new("c", DataType::Int32, false),
185 ]));
186
187 let schema2 = Arc::new(Schema::new(vec![
188 Field::new("a", DataType::Int32, false),
189 Field::new("b", DataType::Int32, false),
190 Field::new("c", DataType::Int32, false),
191 ]));
192
193 let batch = RecordBatch::try_new(
194 schema1,
195 vec![
196 Arc::new(Int32Array::from(vec![1, 2, 3])),
197 Arc::new(Int32Array::from(vec![7, 5, 9])),
198 ],
199 )?;
200
201 let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err();
202 assert_eq!(
203 "Error during planning: Mismatch between schema and batches",
204 e.strip_backtrace()
205 );
206
207 Ok(())
208 }
209
210 #[tokio::test]
211 async fn test_merged_schema() -> Result<()> {
212 let session_ctx = SessionContext::new();
213 let task_ctx = session_ctx.task_ctx();
214 let mut metadata = HashMap::new();
215 metadata.insert("foo".to_string(), "bar".to_string());
216
217 let schema1 = Schema::new_with_metadata(
218 vec![
219 Field::new("a", DataType::Int32, false),
220 Field::new("b", DataType::Int32, false),
221 Field::new("c", DataType::Int32, false),
222 ],
223 metadata,
225 );
226
227 let schema2 = Schema::new(vec![
228 Field::new("a", DataType::Int32, true),
230 Field::new("b", DataType::Int32, false),
231 Field::new("c", DataType::Int32, false),
232 ]);
233
234 let merged_schema = Schema::try_merge(vec![schema1.clone(), schema2.clone()])?;
235
236 let batch1 = RecordBatch::try_new(
237 Arc::new(schema1),
238 vec![
239 Arc::new(Int32Array::from(vec![1, 2, 3])),
240 Arc::new(Int32Array::from(vec![4, 5, 6])),
241 Arc::new(Int32Array::from(vec![7, 8, 9])),
242 ],
243 )?;
244
245 let batch2 = RecordBatch::try_new(
246 Arc::new(schema2),
247 vec![
248 Arc::new(Int32Array::from(vec![1, 2, 3])),
249 Arc::new(Int32Array::from(vec![4, 5, 6])),
250 Arc::new(Int32Array::from(vec![7, 8, 9])),
251 ],
252 )?;
253
254 let provider =
255 MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
256
257 let exec = provider.scan(&session_ctx.state(), None, &[], None).await?;
258 let mut it = exec.execute(0, task_ctx)?;
259 let batch1 = it.next().await.unwrap()?;
260 assert_eq!(3, batch1.schema().fields().len());
261 assert_eq!(3, batch1.num_columns());
262
263 Ok(())
264 }
265
266 async fn experiment(
267 schema: SchemaRef,
268 initial_data: Vec<Vec<RecordBatch>>,
269 inserted_data: Vec<Vec<RecordBatch>>,
270 ) -> Result<Vec<Vec<RecordBatch>>> {
271 let expected_count: u64 = inserted_data
272 .iter()
273 .flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
274 .sum();
275
276 let session_ctx = SessionContext::new();
278 let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
280 session_ctx.register_table("t", initial_table.clone())?;
281 let target = Arc::new(DefaultTableSource::new(initial_table.clone()));
282 let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
284 session_ctx.register_table("source", source_table.clone())?;
285 let source = provider_as_source(source_table);
287 let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
289 let insert_into_table =
291 LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)?
292 .build()?;
293 let plan = session_ctx
295 .state()
296 .create_physical_plan(&insert_into_table)
297 .await?;
298
299 let res = collect(plan, session_ctx.task_ctx()).await?;
301 assert_eq!(extract_count(res), expected_count);
302
303 let mut partitions = vec![];
305 for partition in initial_table.batches.iter() {
306 let part = partition.read().await.clone();
307 partitions.push(part);
308 }
309 Ok(partitions)
310 }
311
312 fn extract_count(res: Vec<RecordBatch>) -> u64 {
322 assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
323 let batch = &res[0];
324 assert_eq!(
325 batch.num_columns(),
326 1,
327 "expected 1 column, got {}",
328 batch.num_columns()
329 );
330 let col = batch.column(0).as_primitive::<UInt64Type>();
331 assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
332 let val = col
333 .iter()
334 .next()
335 .expect("had value")
336 .expect("expected non null");
337 val
338 }
339
340 #[tokio::test]
342 async fn test_insert_into_single_partition() -> Result<()> {
343 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
345
346 let batch = RecordBatch::try_new(
348 schema.clone(),
349 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
350 )?;
351 let resulting_data_in_table =
353 experiment(schema, vec![vec![batch.clone()]], vec![vec![batch.clone()]])
354 .await?;
355 assert_eq!(resulting_data_in_table[0].len(), 2);
357 Ok(())
358 }
359
360 #[tokio::test]
362 async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> {
363 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
365
366 let batch = RecordBatch::try_new(
368 schema.clone(),
369 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
370 )?;
371 let resulting_data_in_table = experiment(
373 schema,
374 vec![vec![batch.clone()]],
375 vec![vec![batch.clone()], vec![batch]],
376 )
377 .await?;
378 assert_eq!(resulting_data_in_table[0].len(), 3);
380 Ok(())
381 }
382
383 #[tokio::test]
385 async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> {
386 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
388
389 let batch = RecordBatch::try_new(
391 schema.clone(),
392 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
393 )?;
394 let resulting_data_in_table = experiment(
396 schema,
397 vec![vec![batch.clone()], vec![batch.clone()]],
398 vec![
399 vec![batch.clone(), batch.clone()],
400 vec![batch.clone(), batch],
401 ],
402 )
403 .await?;
404 assert_eq!(resulting_data_in_table[0].len(), 3);
406 assert_eq!(resulting_data_in_table[1].len(), 3);
407 Ok(())
408 }
409
410 #[tokio::test]
411 async fn test_insert_from_empty_table() -> Result<()> {
412 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
414
415 let batch = RecordBatch::try_new(
417 schema.clone(),
418 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
419 )?;
420 let resulting_data_in_table = experiment(
422 schema,
423 vec![vec![batch.clone(), batch.clone()]],
424 vec![vec![]],
425 )
426 .await?;
427 assert_eq!(resulting_data_in_table[0].len(), 2);
429 Ok(())
430 }
431
432 #[tokio::test]
434 async fn test_insert_into_zero_partition() -> Result<()> {
435 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
437
438 let batch = RecordBatch::try_new(
440 schema.clone(),
441 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
442 )?;
443 let experiment_result = experiment(schema, vec![], vec![vec![batch.clone()]])
445 .await
446 .unwrap_err();
447 assert_eq!(
449 "Error during planning: No partitions provided, expected at least one partition",
450 experiment_result.strip_backtrace()
451 );
452 Ok(())
453 }
454}