Skip to main content

dfir_lang/graph/
flat_to_partitioned.rs

//! Subgraph partitioning algorithm
2
3use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::union_find::UnionFind;
13
14/// Helper struct for tracking barrier crossers, see [`find_barrier_crossers`].
15struct BarrierCrossers {
16    /// Edge barrier crossers, including what type.
17    pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
18    /// Singleton reference barrier crossers — force subgraph boundaries.
19    pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
20}
21impl BarrierCrossers {
22    /// Iterate pairs of nodes that are across a barrier (subgraph boundary or tick boundary).
23    fn iter_node_pairs<'a>(
24        &'a self,
25        partitioned_graph: &'a DfirGraph,
26    ) -> impl 'a + Iterator<Item = (GraphNodeId, GraphNodeId)> {
27        let edge_pairs_iter = self
28            .edge_barrier_crossers
29            .iter()
30            .map(|(edge_id, &_delay_type)| partitioned_graph.edge(edge_id));
31        let singleton_pairs_iter = self.singleton_barrier_crossers.iter().copied();
32        edge_pairs_iter.chain(singleton_pairs_iter)
33    }
34
35    /// Insert/replace edge.
36    fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
37        if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
38            self.edge_barrier_crossers.insert(new_edge_id, delay_type);
39        }
40    }
41}
42
43/// Find all the barrier crossers.
44fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
45    let edge_barrier_crossers = partitioned_graph
46        .edges()
47        .filter(|&(_, (_src, dst))| {
48            // Ignore barriers within `loop {` blocks.
49            partitioned_graph.node_loop(dst).is_none()
50        })
51        .filter_map(|(edge_id, (_src, dst))| {
52            let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
53            let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
54            let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
55            Some((edge_id, input_barrier))
56        })
57        .collect();
58    let singleton_barrier_crossers = partitioned_graph
59        .node_ids()
60        .flat_map(|dst| {
61            partitioned_graph
62                .node_singleton_references(dst)
63                .iter()
64                .flatten()
65                .map(move |&src_ref| (src_ref, dst))
66        })
67        .collect();
68    BarrierCrossers {
69        edge_barrier_crossers,
70        singleton_barrier_crossers,
71    }
72}
73
74fn find_subgraph_unionfind(
75    partitioned_graph: &DfirGraph,
76    barrier_crossers: &BarrierCrossers,
77) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
78    // Modality (color) of nodes, push or pull.
79    // TODO(mingwei)? This does NOT consider `DelayType` barriers (which generally imply `Pull`),
80    // which makes it inconsistant with the final output in `as_code()`. But this doesn't create
81    // any bugs since we exclude `DelayType` edges from joining subgraphs anyway.
82    let mut node_color = partitioned_graph
83        .node_ids()
84        .filter_map(|node_id| {
85            let op_color = partitioned_graph.node_color(node_id)?;
86            Some((node_id, op_color))
87        })
88        .collect::<SparseSecondaryMap<_, _>>();
89
90    let mut subgraph_unionfind: UnionFind<GraphNodeId> =
91        UnionFind::with_capacity(partitioned_graph.nodes().len());
92
93    // Will contain all edges which are handoffs. Starts out with all edges and
94    // we remove from this set as we combine nodes into subgraphs.
95    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
96    // Would sort edges here for priority (for now, no sort/priority).
97
98    // Each edge gets looked at in order. However we may not know if a linear
99    // chain of operators is PUSH vs PULL until we look at the ends. A fancier
100    // algorithm would know to handle linear chains from the outside inward.
101    // But instead we just run through the edges in a loop until no more
102    // progress is made. Could have some sort of O(N^2) pathological worst
103    // case.
104    let mut progress = true;
105    while progress {
106        progress = false;
107        // TODO(mingwei): Could this iterate `handoff_edges` instead? (Modulo ownership). Then no case (1) below.
108        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
109            // Ignore (1) already added edges as well as (2) new self-cycles. (Unless reference edge).
110            if subgraph_unionfind.same_set(src, dst) {
111                // Note that the _edge_ `edge_id` might not be in the subgraph even when both `src` and `dst` are. This prevents case 2.
112                // Handoffs will be inserted later for this self-loop.
113                continue;
114            }
115
116            // Do not connect stratum crossers (next edges).
117            if barrier_crossers
118                .iter_node_pairs(partitioned_graph)
119                .any(|(x_src, x_dst)| {
120                    (subgraph_unionfind.same_set(x_src, src)
121                        && subgraph_unionfind.same_set(x_dst, dst))
122                        || (subgraph_unionfind.same_set(x_src, dst)
123                            && subgraph_unionfind.same_set(x_dst, src))
124                })
125            {
126                continue;
127            }
128
129            // Do not connect across loop contexts.
130            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
131                continue;
132            }
133            // Do not connect `next_iteration()`.
134            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
135                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
136            }) {
137                continue;
138            }
139
140            if can_connect_colorize(&mut node_color, src, dst) {
141                // At this point we have selected this edge and its src & dst to be
142                // within a single subgraph.
143                subgraph_unionfind.union(src, dst);
144                assert!(handoff_edges.remove(&edge_id));
145                progress = true;
146            }
147        }
148    }
149
150    (subgraph_unionfind, handoff_edges)
151}
152
153/// Builds the datastructures for checking which subgraph each node belongs to
154/// after handoffs have already been inserted to partition subgraphs.
155/// This list of nodes in each subgraph are returned in topological sort order.
156fn make_subgraph_collect(
157    partitioned_graph: &DfirGraph,
158    mut subgraph_unionfind: UnionFind<GraphNodeId>,
159) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
160    // We want the nodes of each subgraph to be listed in topo-sort order.
161    // We could do this on each subgraph, or we could do it all at once on the
162    // whole node graph by ignoring handoffs, which is what we do here:
163    let topo_sort = graph_algorithms::topo_sort(
164        partitioned_graph
165            .nodes()
166            .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
167            .map(|(node_id, _)| node_id),
168        |v| {
169            partitioned_graph
170                .node_predecessor_nodes(v)
171                .filter(|&pred_id| {
172                    let pred = partitioned_graph.node(pred_id);
173                    !matches!(pred, GraphNode::Handoff { .. })
174                })
175        },
176    )
177    .expect("Subgraphs are in-out trees.");
178
179    let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
180    for node_id in topo_sort {
181        let repr_node = subgraph_unionfind.find(node_id);
182        if !grouped_nodes.contains_key(repr_node) {
183            grouped_nodes.insert(repr_node, Default::default());
184        }
185        grouped_nodes[repr_node].push(node_id);
186    }
187    grouped_nodes
188}
189
190/// Find subgraph and insert handoffs.
191/// Modifies barrier_crossers so that the edge OUT of an inserted handoff has
192/// the DelayType data.
193fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
194    // Algorithm:
195    // 1. Each node begins as its own subgraph.
196    // 2. Collect edges. (Future optimization: sort so edges which should not be split across a handoff come first).
197    // 3. For each edge, try to join `(to, from)` into the same subgraph.
198
199    // TODO(mingwei):
200    // self.partitioned_graph.assert_valid();
201
202    let (subgraph_unionfind, handoff_edges) =
203        find_subgraph_unionfind(partitioned_graph, barrier_crossers);
204
205    // Insert handoffs between subgraphs (or on subgraph self-loop edges)
206    for edge_id in handoff_edges {
207        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
208
209        // Already has a handoff, no need to insert one.
210        let src_node = partitioned_graph.node(src_id);
211        let dst_node = partitioned_graph.node(dst_id);
212        if matches!(src_node, GraphNode::Handoff { .. })
213            || matches!(dst_node, GraphNode::Handoff { .. })
214        {
215            continue;
216        }
217
218        let hoff = GraphNode::Handoff {
219            src_span: src_node.span(),
220            dst_span: dst_node.span(),
221        };
222        let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
223
224        // Update barrier_crossers for inserted node.
225        barrier_crossers.replace_edge(edge_id, out_edge_id);
226    }
227
228    // Determine node's subgraph and subgraph's nodes.
229    // This list of nodes in each subgraph are to be in topological sort order.
230    // Eventually returned directly in the [`DfirGraph`].
231    let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
232    for (_repr_node, member_nodes) in grouped_nodes {
233        partitioned_graph.insert_subgraph(member_nodes).unwrap();
234    }
235}
236
237/// Set `src` or `dst` color if `None` based on the other (if possible):
238/// `None` indicates an op could be pull or push i.e. unary-in & unary-out.
239/// So in that case we color `src` or `dst` based on its newfound neighbor (the other one).
240///
241/// Returns if `src` and `dst` can be in the same subgraph.
242fn can_connect_colorize(
243    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
244    src: GraphNodeId,
245    dst: GraphNodeId,
246) -> bool {
247    // Pull -> Pull
248    // Push -> Push
249    // Pull -> [Computation] -> Push
250    // Push -> [Handoff] -> Pull
251    let can_connect = match (node_color.get(src), node_color.get(dst)) {
252        // Linear chain, can't connect because it may cause future conflicts.
253        // But if it doesn't in the _future_ we can connect it (once either/both ends are determined).
254        (None, None) => false,
255
256        // Infer left side.
257        (None, Some(Color::Pull | Color::Comp)) => {
258            node_color.insert(src, Color::Pull);
259            true
260        }
261        (None, Some(Color::Push | Color::Hoff)) => {
262            node_color.insert(src, Color::Push);
263            true
264        }
265
266        // Infer right side.
267        (Some(Color::Pull | Color::Hoff), None) => {
268            node_color.insert(dst, Color::Pull);
269            true
270        }
271        (Some(Color::Comp | Color::Push), None) => {
272            node_color.insert(dst, Color::Push);
273            true
274        }
275
276        // Both sides already specified.
277        (Some(Color::Pull), Some(Color::Pull)) => true,
278        (Some(Color::Pull), Some(Color::Comp)) => true,
279        (Some(Color::Pull), Some(Color::Push)) => true,
280
281        (Some(Color::Comp), Some(Color::Pull)) => false,
282        (Some(Color::Comp), Some(Color::Comp)) => false,
283        (Some(Color::Comp), Some(Color::Push)) => true,
284
285        (Some(Color::Push), Some(Color::Pull)) => false,
286        (Some(Color::Push), Some(Color::Comp)) => false,
287        (Some(Color::Push), Some(Color::Push)) => true,
288
289        // Handoffs are not part of subgraphs.
290        (Some(Color::Hoff), Some(_)) => false,
291        (Some(_), Some(Color::Hoff)) => false,
292    };
293    can_connect
294}
295
296/// Topologically sorts subgraphs and marks tick-boundary (`defer_tick` / `defer_tick_lazy`)
297/// handoffs with their delay type for double-buffered codegen in `as_code`.
298///
299/// Returns an error if there is an intra-tick cycle (i.e. the subgraph DAG has a cycle when
300/// tick-boundary edges are excluded).
301fn order_subgraphs(
302    partitioned_graph: &mut DfirGraph,
303    barrier_crossers: &BarrierCrossers,
304) -> Result<(), Diagnostic> {
305    // Build a subgraph-level directed graph, excluding tick-boundary edges.
306    let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
307
308    // Track which handoff edges are tick-boundary, keyed by (src_sg, dst_sg).
309    let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
310
311    // Iterate handoffs between subgraphs.
312    for (node_id, node) in partitioned_graph.nodes() {
313        if !matches!(node, GraphNode::Handoff { .. }) {
314            continue;
315        }
316        assert_eq!(1, partitioned_graph.node_successors(node_id).len());
317        let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();
318
319        let succ_edge_delaytype = barrier_crossers
320            .edge_barrier_crossers
321            .get(succ_edge)
322            .copied();
323        // Tick edges are excluded from the topo sort — they are cross-tick by design.
324        if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
325            tick_edges.push((succ_edge, delay_type));
326            continue;
327        }
328
329        assert_eq!(1, partitioned_graph.node_predecessors(node_id).len());
330        let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();
331
332        let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
333        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
334
335        sg_preds.entry(succ_sg).or_default().push(pred_sg);
336    }
337    // Include singleton reference edges.
338    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
339        assert_ne!(pred, succ, "TODO(mingwei)");
340        let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
341        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
342        assert_ne!(pred_sg, succ_sg);
343        sg_preds.entry(succ_sg).or_default().push(pred_sg);
344    }
345
346    // Topological sort — rejects intra-tick cycles.
347    if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
348        sg_preds.get(&v).into_iter().flatten().copied()
349    }) {
350        let span = cycle
351            .first()
352            .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
353            .map(|n| partitioned_graph.node(n).span())
354            .unwrap_or_else(Span::call_site);
355        return Err(Diagnostic::spanned(
356            span,
357            Level::Error,
358            "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
359        ));
360    }
361
362    // Mark tick-boundary handoffs with their delay type.
363    // These handoffs are excluded from the intra-tick topo ordering in
364    // `as_code`; instead, their double-buffered handoff semantics defer data
365    // across the tick boundary to the next tick.
366    for (edge_id, delay_type) in tick_edges {
367        let (hoff, _dst) = partitioned_graph.edge(edge_id);
368        assert!(matches!(
369            partitioned_graph.node(hoff),
370            GraphNode::Handoff { .. }
371        ));
372        partitioned_graph.set_handoff_delay_type(hoff, delay_type);
373    }
374    Ok(())
375}
376
377/// Main method for this module. Partitions a flat [`DfirGraph`] into one with subgraphs.
378///
379/// Returns an error if an intra-tick cycle exists in the graph.
380pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
381    // Pre-find barrier crossers (input edges with a `DelayType`).
382    let mut barrier_crossers = find_barrier_crossers(&flat_graph);
383    let mut partitioned_graph = flat_graph;
384
385    // Partition into subgraphs.
386    make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
387
388    // Topologically order subgraphs and mark tick-boundary handoffs for double-buffering.
389    order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
390
391    Ok(partitioned_graph)
392}