datafusion_physical_plan/coalesce/
mod.rs1use arrow::array::RecordBatch;
19use arrow::compute::BatchCoalescer;
20use arrow::datatypes::SchemaRef;
21use datafusion_common::{internal_err, Result};
22
23#[derive(Debug)]
27pub struct LimitedBatchCoalescer {
28 inner: BatchCoalescer,
30 total_rows: usize,
32 fetch: Option<usize>,
34 finished: bool,
36}
37
/// Outcome reported by [`LimitedBatchCoalescer::push_batch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PushBatchStatus {
    /// The fetch limit has not been reached (or there is none); the caller
    /// may keep pushing batches.
    Continue,
    /// The fetch limit was reached, either by this push or by an earlier one;
    /// rows beyond the limit were not buffered.
    LimitReached,
}
48
49impl LimitedBatchCoalescer {
50 pub fn new(
58 schema: SchemaRef,
59 target_batch_size: usize,
60 fetch: Option<usize>,
61 ) -> Self {
62 Self {
63 inner: BatchCoalescer::new(schema, target_batch_size)
64 .with_biggest_coalesce_batch_size(Some(target_batch_size / 2)),
65 total_rows: 0,
66 fetch,
67 finished: false,
68 }
69 }
70
71 pub fn schema(&self) -> SchemaRef {
73 self.inner.schema()
74 }
75
76 pub fn push_batch(&mut self, batch: RecordBatch) -> Result<PushBatchStatus> {
91 if self.finished {
92 return internal_err!(
93 "LimitedBatchCoalescer: cannot push batch after finish"
94 );
95 }
96
97 if let Some(fetch) = self.fetch {
99 if self.total_rows >= fetch {
101 return Ok(PushBatchStatus::LimitReached);
102 }
103
104 if self.total_rows + batch.num_rows() >= fetch {
106 let remaining_rows = fetch - self.total_rows;
108 debug_assert!(remaining_rows > 0);
109
110 let batch_head = batch.slice(0, remaining_rows);
111 self.total_rows += batch_head.num_rows();
112 self.inner.push_batch(batch_head)?;
113 return Ok(PushBatchStatus::LimitReached);
114 }
115 }
116
117 self.total_rows += batch.num_rows();
119 self.inner.push_batch(batch)?;
120
121 Ok(PushBatchStatus::Continue)
122 }
123
124 pub fn is_empty(&self) -> bool {
126 self.inner.is_empty()
127 }
128
129 pub fn finish(&mut self) -> Result<()> {
133 self.inner.finish_buffered_batch()?;
134 self.finished = true;
135 Ok(())
136 }
137
138 pub fn next_completed_batch(&mut self) -> Option<RecordBatch> {
140 self.inner.next_completed_batch()
141 }
142}
143
144#[cfg(test)]
145mod tests {
146 use super::*;
147 use std::ops::Range;
148 use std::sync::Arc;
149
150 use arrow::array::UInt32Array;
151 use arrow::compute::concat_batches;
152 use arrow::datatypes::{DataType, Field, Schema};
153
154 #[test]
155 fn test_coalesce() {
156 let batch = uint32_batch(0..8);
157 Test::new()
158 .with_batches(std::iter::repeat_n(batch, 10))
159 .with_target_batch_size(21)
161 .with_expected_output_sizes(vec![21, 21, 21, 17])
162 .run()
163 }
164
165 #[test]
166 fn test_coalesce_with_fetch_larger_than_input_size() {
167 let batch = uint32_batch(0..8);
168 Test::new()
169 .with_batches(std::iter::repeat_n(batch, 10))
170 .with_target_batch_size(21)
173 .with_fetch(Some(100))
174 .with_expected_output_sizes(vec![21, 21, 21, 17])
175 .run();
176 }
177
178 #[test]
179 fn test_coalesce_with_fetch_less_than_input_size() {
180 let batch = uint32_batch(0..8);
181 Test::new()
182 .with_batches(std::iter::repeat_n(batch, 10))
183 .with_target_batch_size(21)
185 .with_fetch(Some(50))
186 .with_expected_output_sizes(vec![21, 21, 8])
187 .run();
188 }
189
190 #[test]
191 fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() {
192 let batch = uint32_batch(0..8);
193 Test::new()
194 .with_batches(std::iter::repeat_n(batch, 10))
195 .with_target_batch_size(24)
197 .with_fetch(Some(48))
198 .with_expected_output_sizes(vec![24, 24])
199 .run();
200 }
201
202 #[test]
203 fn test_coalesce_with_fetch_less_target_batch_size() {
204 let batch = uint32_batch(0..8);
205 Test::new()
206 .with_batches(std::iter::repeat_n(batch, 10))
207 .with_target_batch_size(21)
209 .with_fetch(Some(10))
210 .with_expected_output_sizes(vec![10])
211 .run();
212 }
213
214 #[test]
215 fn test_coalesce_single_large_batch_over_fetch() {
216 let large_batch = uint32_batch(0..100);
217 Test::new()
218 .with_batch(large_batch)
219 .with_target_batch_size(20)
220 .with_fetch(Some(7))
221 .with_expected_output_sizes(vec![7])
222 .run()
223 }
224
225 #[derive(Debug, Clone, Default)]
230 struct Test {
231 input_batches: Vec<RecordBatch>,
234 expected_output_sizes: Vec<usize>,
236 target_batch_size: usize,
238 fetch: Option<usize>,
240 }
241
242 impl Test {
243 fn new() -> Self {
244 Self::default()
245 }
246
247 fn with_target_batch_size(mut self, target_batch_size: usize) -> Self {
249 self.target_batch_size = target_batch_size;
250 self
251 }
252
253 fn with_fetch(mut self, fetch: Option<usize>) -> Self {
255 self.fetch = fetch;
256 self
257 }
258
259 fn with_batch(mut self, batch: RecordBatch) -> Self {
261 self.input_batches.push(batch);
262 self
263 }
264
265 fn with_batches(
267 mut self,
268 batches: impl IntoIterator<Item = RecordBatch>,
269 ) -> Self {
270 self.input_batches.extend(batches);
271 self
272 }
273
274 fn with_expected_output_sizes(
276 mut self,
277 sizes: impl IntoIterator<Item = usize>,
278 ) -> Self {
279 self.expected_output_sizes.extend(sizes);
280 self
281 }
282
283 fn run(self) {
285 let Self {
286 input_batches,
287 target_batch_size,
288 fetch,
289 expected_output_sizes,
290 } = self;
291
292 let schema = input_batches[0].schema();
293
294 let single_input_batch = concat_batches(&schema, &input_batches).unwrap();
296
297 let mut coalescer =
298 LimitedBatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch);
299
300 let mut output_batches = vec![];
301 for batch in input_batches {
302 match coalescer.push_batch(batch).unwrap() {
303 PushBatchStatus::Continue => {
304 }
306 PushBatchStatus::LimitReached => {
307 break;
308 }
309 }
310 }
311 coalescer.finish().unwrap();
312 while let Some(batch) = coalescer.next_completed_batch() {
313 output_batches.push(batch);
314 }
315
316 let actual_output_sizes: Vec<usize> =
317 output_batches.iter().map(|b| b.num_rows()).collect();
318 assert_eq!(
319 expected_output_sizes, actual_output_sizes,
320 "Unexpected number of rows in output batches\n\
321 Expected\n{expected_output_sizes:#?}\nActual:{actual_output_sizes:#?}"
322 );
323
324 let mut starting_idx = 0;
326 assert_eq!(expected_output_sizes.len(), output_batches.len());
327 for (i, (expected_size, batch)) in
328 expected_output_sizes.iter().zip(output_batches).enumerate()
329 {
330 assert_eq!(
331 *expected_size,
332 batch.num_rows(),
333 "Unexpected number of rows in Batch {i}"
334 );
335
336 let expected_batch =
339 single_input_batch.slice(starting_idx, *expected_size);
340 let batch_strings = batch_to_pretty_strings(&batch);
341 let expected_batch_strings = batch_to_pretty_strings(&expected_batch);
342 let batch_strings = batch_strings.lines().collect::<Vec<_>>();
343 let expected_batch_strings =
344 expected_batch_strings.lines().collect::<Vec<_>>();
345 assert_eq!(
346 expected_batch_strings, batch_strings,
347 "Unexpected content in Batch {i}:\
348 \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}"
349 );
350 starting_idx += *expected_size;
351 }
352 }
353 }
354
355 fn uint32_batch(range: Range<u32>) -> RecordBatch {
357 let schema =
358 Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
359
360 RecordBatch::try_new(
361 Arc::clone(&schema),
362 vec![Arc::new(UInt32Array::from_iter_values(range))],
363 )
364 .unwrap()
365 }
366
367 fn batch_to_pretty_strings(batch: &RecordBatch) -> String {
368 arrow::util::pretty::pretty_format_batches(std::slice::from_ref(batch))
369 .unwrap()
370 .to_string()
371 }
372}