datafusion_physical_plan/repartition/
distributor_channels.rs1use std::{
41 collections::VecDeque,
42 future::Future,
43 ops::DerefMut,
44 pin::Pin,
45 sync::{
46 atomic::{AtomicUsize, Ordering},
47 Arc,
48 },
49 task::{Context, Poll, Waker},
50};
51
52use parking_lot::Mutex;
53
54pub fn channels<T>(
56 n: usize,
57) -> (Vec<DistributionSender<T>>, Vec<DistributionReceiver<T>>) {
58 let channels = (0..n)
59 .map(|id| Arc::new(Channel::new_with_one_sender(id)))
60 .collect::<Vec<_>>();
61 let gate = Arc::new(Gate {
62 empty_channels: AtomicUsize::new(n),
63 send_wakers: Mutex::new(None),
64 });
65 let senders = channels
66 .iter()
67 .map(|channel| DistributionSender {
68 channel: Arc::clone(channel),
69 gate: Arc::clone(&gate),
70 })
71 .collect();
72 let receivers = channels
73 .into_iter()
74 .map(|channel| DistributionReceiver {
75 channel,
76 gate: Arc::clone(&gate),
77 })
78 .collect();
79 (senders, receivers)
80}
81
82type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>;
83type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>;
84
85pub fn partition_aware_channels<T>(
89 n_in: usize,
90 n_out: usize,
91) -> (PartitionAwareSenders<T>, PartitionAwareReceivers<T>) {
92 (0..n_in).map(|_| channels(n_out)).unzip()
93}
94
/// Error produced when an element cannot be sent because the receiving side of the channel is gone.
///
/// The element that could not be delivered is handed back to the caller in field `0`.
#[derive(PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> std::fmt::Debug for SendError<T> {
    /// Formats the error without requiring `T: Debug`.
    ///
    /// The payload is elided and rendered as `..` so the output does not look like a field-less unit struct.
    /// This mirrors how the standard library formats its own channel `SendError`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SendError").finish_non_exhaustive()
    }
}

impl<T> std::fmt::Display for SendError<T> {
    /// Writes a fixed, payload-independent message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("cannot send data, receiver is gone")
    }
}

impl<T> std::error::Error for SendError<T> {}
114
115#[derive(Debug)]
121pub struct DistributionSender<T> {
122 channel: SharedChannel<T>,
124 gate: SharedGate,
125}
126
127impl<T> DistributionSender<T> {
128 pub fn send(&self, element: T) -> SendFuture<'_, T> {
132 SendFuture {
133 channel: &self.channel,
134 gate: &self.gate,
135 element: Box::new(Some(element)),
136 }
137 }
138}
139
140impl<T> Clone for DistributionSender<T> {
141 fn clone(&self) -> Self {
142 self.channel.n_senders.fetch_add(1, Ordering::SeqCst);
143
144 Self {
145 channel: Arc::clone(&self.channel),
146 gate: Arc::clone(&self.gate),
147 }
148 }
149}
150
151impl<T> Drop for DistributionSender<T> {
152 fn drop(&mut self) {
153 let n_senders_pre = self.channel.n_senders.fetch_sub(1, Ordering::SeqCst);
154 if n_senders_pre > 1 {
156 return;
157 }
158
159 let receivers = {
160 let mut state = self.channel.state.lock();
161
162 if state
177 .data
178 .as_ref()
179 .map(|data| data.is_empty())
180 .unwrap_or_default()
181 {
182 self.gate.decr_empty_channels();
184 }
185
186 state.recv_wakers.take().expect("not closed yet")
188 };
189
190 for recv in receivers {
192 recv.wake();
193 }
194 }
195}
196
197#[derive(Debug)]
199pub struct SendFuture<'a, T> {
200 channel: &'a SharedChannel<T>,
201 gate: &'a SharedGate,
202 element: Box<Option<T>>,
204}
205
206impl<T> Future for SendFuture<'_, T> {
207 type Output = Result<(), SendError<T>>;
208
209 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210 let this = &mut *self;
211 assert!(this.element.is_some(), "polled ready future");
212
213 let to_wake = {
215 let mut guard_channel_state = this.channel.state.lock();
216
217 let Some(data) = guard_channel_state.data.as_mut() else {
218 return Poll::Ready(Err(SendError(
220 this.element.take().expect("just checked"),
221 )));
222 };
223
224 if this.gate.empty_channels.load(Ordering::SeqCst) == 0 {
227 let mut guard = this.gate.send_wakers.lock();
228 if let Some(send_wakers) = guard.deref_mut() {
229 send_wakers.push((cx.waker().clone(), this.channel.id));
230 return Poll::Pending;
231 }
232 }
233
234 let was_empty = data.is_empty();
235 data.push_back(this.element.take().expect("just checked"));
236
237 if was_empty {
238 this.gate.decr_empty_channels();
239 guard_channel_state.take_recv_wakers()
240 } else {
241 Vec::with_capacity(0)
242 }
243 };
244
245 for receiver in to_wake {
247 receiver.wake();
248 }
249
250 Poll::Ready(Ok(()))
251 }
252}
253
254#[derive(Debug)]
256pub struct DistributionReceiver<T> {
257 channel: SharedChannel<T>,
258 gate: SharedGate,
259}
260
261impl<T> DistributionReceiver<T> {
262 pub fn recv(&mut self) -> RecvFuture<'_, T> {
266 RecvFuture {
267 channel: &mut self.channel,
268 gate: &mut self.gate,
269 rdy: false,
270 }
271 }
272}
273
274impl<T> Drop for DistributionReceiver<T> {
275 fn drop(&mut self) {
276 let mut guard_channel_state = self.channel.state.lock();
277 let data = guard_channel_state.data.take().expect("not dropped yet");
278
279 if data.is_empty() && (self.channel.n_senders.load(Ordering::SeqCst) > 0) {
282 self.gate.decr_empty_channels();
284 }
285
286 self.gate.wake_channel_senders(self.channel.id);
288 }
289}
290
291pub struct RecvFuture<'a, T> {
293 channel: &'a mut SharedChannel<T>,
294 gate: &'a mut SharedGate,
295 rdy: bool,
296}
297
298impl<T> Future for RecvFuture<'_, T> {
299 type Output = Option<T>;
300
301 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
302 let this = &mut *self;
303 assert!(!this.rdy, "polled ready future");
304
305 let mut guard_channel_state = this.channel.state.lock();
306 let channel_state = guard_channel_state.deref_mut();
307 let data = channel_state.data.as_mut().expect("not dropped yet");
308
309 match data.pop_front() {
310 Some(element) => {
311 if data.is_empty() && channel_state.recv_wakers.is_some() {
313 let old_counter =
315 this.gate.empty_channels.fetch_add(1, Ordering::SeqCst);
316
317 let to_wake = if old_counter == 0 {
319 let mut guard = this.gate.send_wakers.lock();
320
321 if this.gate.empty_channels.load(Ordering::SeqCst) > 0 {
323 guard.take().unwrap_or_default()
324 } else {
325 Vec::with_capacity(0)
326 }
327 } else {
328 Vec::with_capacity(0)
329 };
330
331 drop(guard_channel_state);
332
333 for (waker, _channel_id) in to_wake {
335 waker.wake();
336 }
337 }
338
339 this.rdy = true;
340 Poll::Ready(Some(element))
341 }
342 None => {
343 if let Some(recv_wakers) = channel_state.recv_wakers.as_mut() {
344 recv_wakers.push(cx.waker().clone());
345 Poll::Pending
346 } else {
347 this.rdy = true;
348 Poll::Ready(None)
349 }
350 }
351 }
352 }
353}
354
355#[derive(Debug)]
357struct Channel<T> {
358 n_senders: AtomicUsize,
360
361 id: usize,
365
366 state: Mutex<ChannelState<T>>,
368}
369
370impl<T> Channel<T> {
371 fn new_with_one_sender(id: usize) -> Self {
373 Channel {
374 n_senders: AtomicUsize::new(1),
375 id,
376 state: Mutex::new(ChannelState {
377 data: Some(VecDeque::default()),
378 recv_wakers: Some(Vec::default()),
379 }),
380 }
381 }
382}
383
/// Mutable part of a channel.
#[derive(Debug)]
struct ChannelState<T> {
    /// Queued elements; `None` once the receiver has been dropped.
    data: Option<VecDeque<T>>,

    /// Wakers of receive operations waiting for data; `None` once the last sender has been dropped.
    recv_wakers: Option<Vec<Waker>>,
}

impl<T> ChannelState<T> {
    /// Moves all registered receiver wakers out of the state and returns them.
    ///
    /// An empty list with the same capacity is left behind, so new wakers can still be registered.
    ///
    /// # Panics
    /// Panics with `"not closed"` if the waker list has already been removed.
    fn take_recv_wakers(&mut self) -> Vec<Waker> {
        let registered = self.recv_wakers.as_mut().expect("not closed");
        let replacement = Vec::with_capacity(registered.capacity());
        std::mem::replace(registered, replacement)
    }
}
412
413type SharedChannel<T> = Arc<Channel<T>>;
417
418#[derive(Debug)]
420struct Gate {
421 empty_channels: AtomicUsize,
423
424 send_wakers: Mutex<Option<Vec<(Waker, usize)>>>,
428}
429
430impl Gate {
431 fn wake_channel_senders(&self, id: usize) {
435 let to_wake = {
437 let mut guard = self.send_wakers.lock();
438
439 if let Some(send_wakers) = guard.deref_mut() {
440 let (wake, keep) =
442 send_wakers.drain(..).partition(|(_waker, id2)| id == *id2);
443
444 *send_wakers = keep;
445
446 wake
447 } else {
448 Vec::with_capacity(0)
449 }
450 };
451
452 for (waker, _id) in to_wake {
454 waker.wake();
455 }
456 }
457
458 fn decr_empty_channels(&self) {
459 let old_count = self.empty_channels.fetch_sub(1, Ordering::SeqCst);
460
461 if old_count == 1 {
462 let mut guard = self.send_wakers.lock();
463
464 if self.empty_channels.load(Ordering::SeqCst) == 0 && guard.is_none() {
466 *guard = Some(Vec::new());
467 }
468 }
469 }
470}
471
472type SharedGate = Arc<Gate>;
474
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicBool;

    use futures::{task::ArcWake, FutureExt};

    use super::*;

    /// Channel 1 is never written to, so there is always an empty channel and channel 0 can buffer freely.
    #[test]
    fn test_single_channel_no_gate() {
        let (mut senders, mut receivers) = channels(2);

        // nothing queued yet: the receive operation has to wait
        let mut recv_fut = receivers[0].recv();
        let waker = poll_pending(&mut recv_fut);

        // the first element wakes the waiting receiver
        poll_ready(&mut senders[0].send("foo")).unwrap();
        assert!(waker.woken());
        assert_eq!(poll_ready(&mut recv_fut), Some("foo"));

        // several elements can be queued back to back and come out in order
        poll_ready(&mut senders[0].send("bar")).unwrap();
        poll_ready(&mut senders[0].send("baz")).unwrap();
        poll_ready(&mut senders[0].send("end")).unwrap();
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("bar"));
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("baz"));

        // dropping the only sender of channel 0 closes it; queued data is still delivered first
        drop(senders.remove(0));
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("end"));
        assert_eq!(poll_ready(&mut receivers[0].recv()), None);
        assert_eq!(poll_ready(&mut receivers[0].recv()), None);
    }

    /// A cloned sender feeds the same channel as the original.
    #[test]
    fn test_multi_sender() {
        let (senders, mut receivers) = channels(2);

        let cloned_sender = senders[0].clone();

        poll_ready(&mut senders[0].send("foo")).unwrap();
        poll_ready(&mut cloned_sender.send("bar")).unwrap();

        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("foo"));
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("bar"));
    }

    /// Sends are held back once no channel is empty and resume when a channel has been drained.
    #[test]
    fn test_gate() {
        let (senders, mut receivers) = channels(2);

        // channel 1 is still empty, so both sends to channel 0 go through
        poll_ready(&mut senders[0].send("0_a")).unwrap();
        poll_ready(&mut senders[0].send("0_b")).unwrap();

        // channel 0 is non-empty but channel 1 was empty until now
        poll_ready(&mut senders[1].send("1_a")).unwrap();

        // no empty channel left: the next send has to wait
        let mut send_fut = senders[1].send("1_b");
        let waker = poll_pending(&mut send_fut);

        // channel 0 still holds one element after this, so the send keeps waiting
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("0_a"));
        poll_pending(&mut send_fut);
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("0_b"));

        // channel 0 is drained now, which wakes the waiting sender
        assert!(waker.woken());
        poll_ready(&mut send_fut).unwrap();
    }

    /// A channel is closed only after the last clone of its sender has been dropped.
    #[test]
    fn test_close_channel_by_dropping_tx() {
        let (mut senders, mut receivers) = channels(2);

        let tx0 = senders.remove(0);
        let tx1 = senders.remove(0);
        let tx0_clone = tx0.clone();

        let mut recv_fut = receivers[0].recv();

        poll_ready(&mut tx1.send("a")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // a clone of tx0 is still alive
        drop(tx0);

        assert!(!recv_waker.woken());
        poll_ready(&mut tx1.send("b")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // cloning does not disturb the waiting receiver
        let tx0_clone2 = tx0_clone.clone();
        assert!(!recv_waker.woken());
        poll_ready(&mut tx1.send("c")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // tx0_clone2 is still alive
        drop(tx0_clone);
        assert!(!recv_waker.woken());
        poll_ready(&mut tx1.send("d")).unwrap();
        let recv_waker = poll_pending(&mut recv_fut);

        // last sender of channel 0 is gone
        drop(tx0_clone2);

        // the closed channel 0 no longer counts as empty, so channel 1 has to wait now
        poll_pending(&mut tx1.send("e"));
        assert!(recv_waker.woken());
        assert_eq!(poll_ready(&mut recv_fut), None);
    }

    /// Dropping the receiver of an empty channel removes its "empty" signal and makes sends to it fail.
    #[test]
    fn test_close_channel_by_dropping_rx_on_open_gate() {
        let (senders, mut receivers) = channels(2);

        let rx0 = receivers.remove(0);
        let _rx1 = receivers.remove(0);

        poll_ready(&mut senders[1].send("a")).unwrap();

        // channel 0 was the only empty channel
        drop(rx0);

        poll_pending(&mut senders[1].send("b"));
        assert_eq!(
            poll_ready(&mut senders[0].send("foo")),
            Err(SendError("foo"))
        );
    }

    /// Dropping a receiver while senders wait wakes only the senders of that channel.
    #[test]
    fn test_close_channel_by_dropping_rx_on_closed_gate() {
        let (senders, mut receivers) = channels(2);

        let rx0 = receivers.remove(0);
        let mut rx1 = receivers.remove(0);

        // fill both channels so that no empty channel is left
        poll_ready(&mut senders[0].send("0_a")).unwrap();
        poll_ready(&mut senders[1].send("1_a")).unwrap();

        let mut send_fut0 = senders[0].send("0_b");
        let mut send_fut1 = senders[1].send("1_b");
        let waker0 = poll_pending(&mut send_fut0);
        let waker1 = poll_pending(&mut send_fut1);

        drop(rx0);

        // only the sender of the closed channel is woken, and it gets its element back
        assert!(waker0.woken());
        assert!(!waker1.woken());
        assert_eq!(poll_ready(&mut send_fut0), Err(SendError("0_b")));

        // the other sender still has to wait
        poll_pending(&mut send_fut1);

        // channel 1 itself is unaffected
        assert_eq!(poll_ready(&mut rx1.recv()), Some("1_a"));
    }

    /// Three channels, one of which loses its receiver while it still holds data.
    #[test]
    fn test_drop_rx_three_channels() {
        let (mut senders, mut receivers) = channels(3);

        let tx0 = senders.remove(0);
        let tx1 = senders.remove(0);
        let tx2 = senders.remove(0);
        let mut rx0 = receivers.remove(0);
        let rx1 = receivers.remove(0);
        let _rx2 = receivers.remove(0);

        // fill all channels
        poll_ready(&mut tx0.send("0_a")).unwrap();
        poll_ready(&mut tx1.send("1_a")).unwrap();
        poll_ready(&mut tx2.send("2_a")).unwrap();

        // channel 1 is closed while non-empty
        drop(rx1);

        // draining channel 0 makes it the only empty channel
        assert_eq!(poll_ready(&mut rx0.recv()), Some("0_a"));

        // refilling channel 0 leaves no empty channel again
        poll_ready(&mut tx0.send("0_b")).unwrap();
        assert_eq!(poll_ready(&mut tx1.send("1_b")), Err(SendError("1_b")));
        poll_pending(&mut tx2.send("2_b"));
    }

    /// Queued elements are dropped together with the receiver, even while senders still exist.
    #[test]
    fn test_close_channel_by_dropping_rx_clears_data() {
        let (senders, receivers) = channels(1);

        let obj = Arc::new(());
        let counter = Arc::downgrade(&obj);
        assert_eq!(counter.strong_count(), 1);

        // the object now lives inside the channel
        poll_ready(&mut senders[0].send(obj)).unwrap();
        assert_eq!(counter.strong_count(), 1);

        drop(receivers);

        assert_eq!(counter.strong_count(), 0);
    }

    /// Polling a pending future repeatedly registers every waker, and all of them get woken.
    #[test]
    fn test_poll_empty_channel_twice() {
        let (senders, mut receivers) = channels(1);

        // same receive future polled twice
        let mut recv_fut = receivers[0].recv();
        let waker_1a = poll_pending(&mut recv_fut);
        let waker_1b = poll_pending(&mut recv_fut);

        // a fresh receive future on the same channel
        let mut recv_fut = receivers[0].recv();
        let waker_2 = poll_pending(&mut recv_fut);

        poll_ready(&mut senders[0].send("a")).unwrap();
        assert!(waker_1a.woken());
        assert!(waker_1b.woken());
        assert!(waker_2.woken());
        assert_eq!(poll_ready(&mut recv_fut), Some("a"));

        // with a single channel, one queued element is enough to hold back the next send
        poll_ready(&mut senders[0].send("b")).unwrap();
        let mut send_fut = senders[0].send("c");
        let waker_3 = poll_pending(&mut send_fut);
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("b"));
        assert!(waker_3.woken());
        poll_ready(&mut send_fut).unwrap();
        assert_eq!(poll_ready(&mut receivers[0].recv()), Some("c"));

        // wakers left over from abandoned receive futures are woken as well
        let mut recv_fut = receivers[0].recv();
        let waker_4 = poll_pending(&mut recv_fut);

        let mut recv_fut = receivers[0].recv();
        let waker_5 = poll_pending(&mut recv_fut);

        // same send future polled twice while it has to wait
        poll_ready(&mut senders[0].send("d")).unwrap();
        let mut send_fut = senders[0].send("e");
        let waker_6a = poll_pending(&mut send_fut);
        let waker_6b = poll_pending(&mut send_fut);

        assert!(waker_4.woken());
        assert!(waker_5.woken());
        assert_eq!(poll_ready(&mut recv_fut), Some("d"));

        assert!(waker_6a.woken());
        assert!(waker_6b.woken());
        poll_ready(&mut send_fut).unwrap();
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_send_future_after_ready_ok() {
        let (senders, _receivers) = channels(1);
        let mut send_fut = senders[0].send("foo");
        poll_ready(&mut send_fut).unwrap();
        poll_ready(&mut send_fut).ok();
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_send_future_after_ready_err() {
        let (senders, receivers) = channels(1);

        drop(receivers);

        let mut send_fut = senders[0].send("foo");
        poll_ready(&mut send_fut).unwrap_err();
        poll_ready(&mut send_fut).ok();
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_recv_future_after_ready_some() {
        let (senders, mut receivers) = channels(1);

        poll_ready(&mut senders[0].send("foo")).unwrap();

        let mut recv_fut = receivers[0].recv();
        poll_ready(&mut recv_fut).unwrap();
        poll_ready(&mut recv_fut);
    }

    #[test]
    #[should_panic(expected = "polled ready future")]
    fn test_panic_poll_recv_future_after_ready_none() {
        let (senders, mut receivers) = channels::<u8>(1);

        drop(senders);

        let mut recv_fut = receivers[0].recv();
        assert!(poll_ready(&mut recv_fut).is_none());
        poll_ready(&mut recv_fut);
    }

    /// The `poll_ready` helper must reject a pending future.
    #[test]
    #[should_panic(expected = "future is pending")]
    fn test_meta_poll_ready_wrong_state() {
        let mut never_ready = futures::future::pending::<u8>();
        poll_ready(&mut never_ready);
    }

    /// The `poll_pending` helper must reject a ready future.
    #[test]
    #[should_panic(expected = "future is ready")]
    fn test_meta_poll_pending_wrong_state() {
        let mut already_ready = futures::future::ready(1);
        poll_pending(&mut already_ready);
    }

    /// The waker handed out by `poll_pending` really observes wake-ups.
    #[test]
    fn test_meta_poll_pending_waker() {
        let (tx, mut rx) = futures::channel::oneshot::channel();
        let waker = poll_pending(&mut rx);
        assert!(!waker.woken());
        tx.send(1).unwrap();
        assert!(waker.woken());
    }

    /// Polls `fut` once and returns its output.
    ///
    /// # Panics
    /// Panics with `"future is pending"` if the future is not ready.
    #[track_caller]
    fn poll_ready<F>(fut: &mut F) -> F::Output
    where
        F: Future + Unpin,
    {
        let (outcome, _waker) = poll(fut);
        match outcome {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future is pending"),
        }
    }

    /// Polls `fut` once and returns the waker that was passed to it.
    ///
    /// # Panics
    /// Panics with `"future is ready"` if the future is not pending.
    #[track_caller]
    fn poll_pending<F>(fut: &mut F) -> Arc<TestWaker>
    where
        F: Future + Unpin,
    {
        match poll(fut) {
            (Poll::Pending, waker) => waker,
            (Poll::Ready(_), _) => panic!("future is ready"),
        }
    }

    /// Polls `fut` once with a fresh [`TestWaker`] and returns both the poll result and that waker.
    fn poll<F>(fut: &mut F) -> (Poll<F::Output>, Arc<TestWaker>)
    where
        F: Future + Unpin,
    {
        let test_waker = Arc::new(TestWaker::default());
        let waker = futures::task::waker(Arc::clone(&test_waker));
        let mut cx = Context::from_waker(&waker);
        (fut.poll_unpin(&mut cx), test_waker)
    }

    /// Waker that merely remembers whether it has been woken.
    #[derive(Debug, Default)]
    struct TestWaker {
        woken: AtomicBool,
    }

    impl TestWaker {
        /// Returns `true` once the waker has been woken at least once.
        fn woken(&self) -> bool {
            self.woken.load(Ordering::SeqCst)
        }
    }

    impl ArcWake for TestWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.woken.store(true, Ordering::SeqCst);
        }
    }
}