1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay an operator places on one of its inputs, as reported per input port by
/// [`OperatorConstraints::input_delaytype_fn`].
///
/// NOTE: variant order is significant because `PartialOrd`/`Ord` are derived.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum DelayType {
    /// Tick-boundary delay.
    /// NOTE(review): presumably the input is only observed in a later tick — confirm with the scheduler.
    Tick,
    /// Lazy variant of [`DelayType::Tick`].
    /// NOTE(review): the exact laziness semantics are defined by the scheduler, not visible here.
    TickLazy,
}
31
/// Set of port names an operator accepts on one side (inputs or outputs), as produced by
/// [`OperatorConstraints::ports_inn`] / [`OperatorConstraints::ports_out`].
pub enum PortListSpec {
    /// Any number of ports is accepted.
    /// NOTE(review): presumably with arbitrary port indices — confirm against the port checker.
    Variadic,
    /// Only the listed ports are accepted.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
39
/// Static description of one built-in operator: its name, categories, arity and argument
/// limits, port names, and code-generation hook. One value per operator appears in `OPERATORS`.
pub struct OperatorConstraints {
    /// Operator name; also the key used by [`operator_lookup`].
    pub name: &'static str,
    /// Categories this operator belongs to.
    pub categories: &'static [OperatorCategory],

    // NOTE(review): "hard" vs "soft" enforcement lives outside this file; presumably violating a
    // hard range is an error while a soft range is a weaker check — confirm.
    /// Hard limit on the number of inputs.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Soft limit on the number of inputs.
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Hard limit on the number of outputs.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Soft limit on the number of outputs.
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments the operator takes.
    /// NOTE(review): presumably the parenthesized `op(a, b)` arguments — confirm.
    pub num_args: usize,
    /// Allowed number of persistence generic arguments (written like `'tick`).
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Allowed number of type generic arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// NOTE(review): presumably marks operators that receive data from outside the graph — confirm.
    pub is_external_input: bool,
    /// NOTE(review): presumably the operator exposes a singleton value
    /// (cf. `WriteContextArgs::singleton_output_ident`) — confirm.
    pub has_singleton_output: bool,
    /// NOTE(review): presumably whether a singleton input stays a singleton through this operator — confirm.
    pub preserves_singleton: bool,
    /// Role of this operator in loop/windowing structure, if any.
    pub flo_type: Option<FloType>,

    /// Accepted input port names, if constrained.
    /// NOTE(review): meaning of `None` is decided by the port checker — confirm.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Accepted output port names, if constrained.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Returns the delay, if any, this operator imposes on the given input port.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Code-generation hook for this operator.
    pub write_fn: WriteFn,
}
87
/// Code-generation hook for an operator: given the operator's context, returns the generated
/// code sections. `Err(())` carries no payload, so failures are presumably reported through the
/// `Diagnostics` argument.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
90
91impl Debug for OperatorConstraints {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("OperatorConstraints")
94 .field("name", &self.name)
95 .field("hard_range_inn", &self.hard_range_inn)
96 .field("soft_range_inn", &self.soft_range_inn)
97 .field("hard_range_out", &self.hard_range_out)
98 .field("soft_range_out", &self.soft_range_out)
99 .field("num_args", &self.num_args)
100 .field("persistence_args", &self.persistence_args)
101 .field("type_args", &self.type_args)
102 .field("is_external_input", &self.is_external_input)
103 .field("ports_inn", &self.ports_inn)
104 .field("ports_out", &self.ports_out)
105 .finish()
109 }
110}
111
/// Code generated by an operator's [`WriteFn`], split into sections that the code generator
/// splices into different places. Every section defaults to an empty token stream.
#[derive(Default)]
pub struct OperatorWriteOutput {
    /// NOTE(review): presumably code emitted once, before the subgraph runs — confirm placement.
    pub write_prologue: TokenStream,
    /// Statement(s) that bind the operator's pipe, e.g. the `let #ident = ...;` emitted by
    /// [`identity_write_iterator_fn`] and [`null_write_iterator_fn`].
    pub write_iterator: TokenStream,
    /// NOTE(review): presumably code emitted after the pipeline statements — confirm placement.
    pub write_iterator_after: TokenStream,
    /// NOTE(review): presumably code run at the end of each tick — confirm placement.
    pub write_tick_end: TokenStream,
}
133
/// Accepts any count (`0..`).
pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
/// Accepts exactly zero (`0..=0`).
pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
/// Accepts exactly one (`1..=1`).
pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
140
141pub fn identity_write_iterator_fn(
144 &WriteContextArgs {
145 root,
146 op_span,
147 ident,
148 inputs,
149 outputs,
150 is_pull,
151 op_inst:
152 OperatorInstance {
153 generics: OpInstGenerics { type_args, .. },
154 ..
155 },
156 ..
157 }: &WriteContextArgs,
158) -> TokenStream {
159 let generic_type = type_args
160 .first()
161 .map(quote::ToTokens::to_token_stream)
162 .unwrap_or(quote_spanned!(op_span=> _));
163
164 if is_pull {
165 let input = &inputs[0];
166 quote_spanned! {op_span=>
167 let #ident = {
168 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
169 where
170 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
171 {
172 pull
173 }
174 check_input::<_, #generic_type>(#input)
175 };
176 }
177 } else {
178 let output = &outputs[0];
179 quote_spanned! {op_span=>
180 let #ident = {
181 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
182 where
183 Psh: #root::dfir_pipes::push::Push<Item, ()>,
184 {
185 push
186 }
187 check_output::<_, #generic_type>(#output)
188 };
189 }
190 }
191}
192
193pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
195 let write_iterator = identity_write_iterator_fn(write_context_args);
196 Ok(OperatorWriteOutput {
197 write_iterator,
198 ..Default::default()
199 })
200};
201
202pub fn null_write_iterator_fn(
205 &WriteContextArgs {
206 root,
207 op_span,
208 ident,
209 inputs,
210 outputs,
211 is_pull,
212 op_inst:
213 OperatorInstance {
214 generics: OpInstGenerics { type_args, .. },
215 ..
216 },
217 ..
218 }: &WriteContextArgs,
219) -> TokenStream {
220 let default_type = parse_quote_spanned! {op_span=> _};
221 let iter_type = type_args.first().unwrap_or(&default_type);
222
223 if is_pull {
224 quote_spanned! {op_span=>
225 let #ident = #root::dfir_pipes::pull::poll_fn({
226 #(
227 let mut #inputs = ::std::boxed::Box::pin(#inputs);
228 )*
229 move |_cx| {
230 #(
234 let #inputs = #root::dfir_pipes::pull::Pull::pull(
235 ::std::pin::Pin::as_mut(&mut #inputs),
236 <_ as #root::dfir_pipes::Context>::from_task(_cx),
237 );
238 )*
239 #(
240 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
241 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
242 }
243 )*
244 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
245 }
246 });
247 }
248 } else {
249 quote_spanned! {op_span=>
250 #[allow(clippy::let_unit_value)]
251 let _ = (#(#outputs),*);
252 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
253 }
254 }
255}
256
257pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
260 let write_iterator = null_write_iterator_fn(write_context_args);
261 Ok(OperatorWriteOutput {
262 write_iterator,
263 ..Default::default()
264 })
265};
266
/// Declares one child module per operator and gathers each module's constraints constant into
/// `OPERATORS`, in invocation order.
///
/// Input: `module_name::CONSTANT_NAME,` entries, each terminated by a comma.
macro_rules! declare_ops {
    ( $( $module:ident :: $constraints:ident, )* ) => {
        $( pub(crate) mod $module; )*

        /// Every built-in operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $module :: $constraints, )*
        ];
    };
}
// Registry of built-in operators. Each `module::CONSTANT` entry declares the child module
// `module` and places its `CONSTANT` into `OPERATORS`. Entry order is the order of `OPERATORS`.
// `operator_lookup` keys by name, so order only matters to code that iterates `OPERATORS`
// directly; do not reorder casually.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
    flatten::FLATTEN,
    flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    join_multiset_half::JOIN_MULTISET_HALF,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    scan_async_blocking::SCAN_ASYNC_BLOCKING,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
361
362pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
364 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
365 OnceLock::new();
366 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
367}
368pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
370 if let GraphNode::Operator(operator) = node {
371 find_op_op_constraints(operator)
372 } else {
373 None
374 }
375}
376pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
378 let name = &*operator.name_string();
379 operator_lookup().get(name).copied()
380}
381
/// Everything an operator's [`WriteFn`] needs to generate code for one operator instance.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path prefix for the runtime crate; generated code refers to items as `#root::...`.
    pub root: &'a TokenStream,
    /// NOTE(review): identifier of the runtime context available to generated code — confirm.
    pub context: &'a Ident,
    /// NOTE(review): identifier of the dataflow instance in generated code — confirm.
    pub df_ident: &'a Ident,
    /// Subgraph containing this operator (used by [`WriteContextArgs::make_ident`]).
    pub subgraph_id: GraphSubgraphId,
    /// Graph node of this operator (used by [`WriteContextArgs::make_ident`]).
    pub node_id: GraphNodeId,
    /// Enclosing loop, if any. `Some` makes `Persistence::None` the default persistence.
    pub loop_id: Option<GraphLoopId>,
    /// Span used for generated tokens and diagnostics.
    pub op_span: Span,
    /// NOTE(review): optional user-facing tag for this operator — confirm usage.
    pub op_tag: Option<String>,
    /// NOTE(review): identifier of a helper wrapper for generated work closures — confirm.
    pub work_fn: &'a Ident,
    /// NOTE(review): async counterpart of `work_fn` — confirm.
    pub work_fn_async: &'a Ident,

    /// Name the operator's generated pipe is bound to (`let #ident = ...`).
    pub ident: &'a Ident,
    /// `true` to generate pull-based code (reads `inputs`), `false` for push-based (reads `outputs`).
    pub is_pull: bool,
    /// Identifiers of the operator's input pipes.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's output pipes.
    pub outputs: &'a [Ident],
    /// NOTE(review): identifier for the operator's singleton output, if it has one — confirm.
    pub singleton_output_ident: &'a Ident,

    /// Operator name, used in diagnostic messages.
    pub op_name: &'static str,
    /// The parsed operator instance, including its generic (type/persistence) arguments.
    pub op_inst: &'a OperatorInstance,
    /// The operator's arguments.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    /// NOTE(review): a transformed copy of `arguments` (presumably with handle references
    /// rewritten) — confirm against the code that builds this struct.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
433impl WriteContextArgs<'_> {
434 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
440 Ident::new(
441 &format!(
442 "sg_{:?}_node_{:?}_{}",
443 self.subgraph_id.data(),
444 self.node_id.data(),
445 suffix.as_ref(),
446 ),
447 self.op_span,
448 )
449 }
450
451 pub fn persistence_args_disallow_mutable<const N: usize>(
453 &self,
454 diagnostics: &mut Diagnostics,
455 ) -> [Persistence; N] {
456 let len = self.op_inst.generics.persistence_args.len();
457 if 0 != len && 1 != len && N != len {
458 diagnostics.push(Diagnostic::spanned(
459 self.op_span,
460 Level::Error,
461 format!(
462 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
463 self.op_name, N
464 ),
465 ));
466 }
467
468 let default_persistence = if self.loop_id.is_some() {
469 Persistence::None
470 } else {
471 Persistence::Tick
472 };
473 let mut out = [default_persistence; N];
474 self.op_inst
475 .generics
476 .persistence_args
477 .iter()
478 .copied()
479 .cycle() .take(N)
481 .enumerate()
482 .filter(|&(_i, p)| {
483 if p == Persistence::Mutable {
484 diagnostics.push(Diagnostic::spanned(
485 self.op_span,
486 Level::Error,
487 format!(
488 "An implementation of `'{}` does not exist",
489 p.to_str_lowercase()
490 ),
491 ));
492 false
493 } else {
494 true
495 }
496 })
497 .for_each(|(i, p)| {
498 out[i] = p;
499 });
500 out
501 }
502}
503
/// Object-safe view of a range of `T`.
///
/// Operator arity and argument-count limits are stored as `&'static dyn RangeTrait<usize>`.
/// This module blanket-implements the trait for every `RangeBounds<T>` type, so literal ranges
/// such as `&(0..)` can be used directly.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;

    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;

    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// English phrase describing the range, for use in messages.
    ///
    /// Examples: `"any number of"`, `"exactly 1"`, `"at least 2"`,
    /// `"more than 1 and less than 5"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let start = self.start_bound();
        let end = self.end_bound();

        // A closed range whose endpoints coincide describes a single value.
        if let (Bound::Included(lo), Bound::Included(hi)) = (start, end) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        let lower = match start {
            Bound::Included(n) => Some(format!("at least {}", n)),
            Bound::Excluded(n) => Some(format!("more than {}", n)),
            Bound::Unbounded => None,
        };
        let upper = match end {
            Bound::Included(x) => Some(format!("at most {}", x)),
            Bound::Excluded(x) => Some(format!("less than {}", x)),
            Bound::Unbounded => None,
        };

        match (lower, upper) {
            (Some(lower), Some(upper)) => format!("{} and {}", lower, upper),
            (Some(only), None) | (None, Some(only)) => only,
            (None, None) => "any number of".to_owned(),
        }
    }
}
548
549impl<R, T> RangeTrait<T> for R
550where
551 R: RangeBounds<T> + Send + Sync + Debug,
552{
553 fn start_bound(&self) -> Bound<&T> {
554 self.start_bound()
555 }
556
557 fn end_bound(&self) -> Bound<&T> {
558 self.end_bound()
559 }
560
561 fn contains(&self, item: &T) -> bool
562 where
563 T: PartialOrd<T>,
564 {
565 self.contains(item)
566 }
567}
568
/// Persistence of an operator's state, written as a lifetime-style generic argument
/// (e.g. `'tick`; see [`Persistence::to_str_lowercase`] for the keywords).
///
/// NOTE: variant order is significant because `PartialOrd`/`Ord` are derived.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// No persistence; the default for operators inside a loop.
    None,
    /// NOTE(review): presumably state persists across iterations of the enclosing loop — confirm.
    Loop,
    /// State lasts for the tick; the default for operators outside a loop.
    Tick,
    /// NOTE(review): presumably state persists across ticks — confirm.
    Static,
    /// Mutable persistence; rejected by `WriteContextArgs::persistence_args_disallow_mutable`.
    Mutable,
}
583impl Persistence {
584 pub fn to_str_lowercase(self) -> &'static str {
586 match self {
587 Persistence::None => "none",
588 Persistence::Tick => "tick",
589 Persistence::Loop => "loop",
590 Persistence::Static => "static",
591 Persistence::Mutable => "mutable",
592 }
593 }
594}
595
596fn make_missing_runtime_msg(op_name: &str) -> Literal {
598 Literal::string(&format!(
599 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
600 op_name
601 ))
602}
603
604#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
606pub enum OperatorCategory {
607 Map,
609 Filter,
611 Flatten,
613 Fold,
615 KeyedFold,
617 LatticeFold,
619 Persistence,
621 MultiIn,
623 MultiOut,
625 Source,
627 Sink,
629 Control,
631 CompilerFusionOperator,
633 Windowing,
635 Unwindowing,
637}
638impl OperatorCategory {
639 pub fn name(self) -> &'static str {
641 self.get_variant_docs().split_once(":").unwrap().0
642 }
643 pub fn description(self) -> &'static str {
645 self.get_variant_docs().split_once(":").unwrap().1
646 }
647}
648
/// Role an operator plays in loop ("flo") structure; stored in `OperatorConstraints::flo_type`.
///
/// NOTE: variant order is significant because `PartialOrd`/`Ord` are derived.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// NOTE(review): presumably an operator that originates data (no loop inputs) — confirm.
    Source,
    /// NOTE(review): presumably an operator that windows data entering a `loop` — confirm.
    Windowing,
    /// NOTE(review): presumably an operator that collects data leaving a `loop` — confirm.
    Unwindowing,
    /// NOTE(review): presumably the `next_iteration` operator, feeding data to the next loop iteration — confirm.
    NextIteration,
}