1use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::union_find::UnionFind;
13
14struct BarrierCrossers {
16 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
18 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
20}
21impl BarrierCrossers {
22 fn iter_node_pairs<'a>(
24 &'a self,
25 partitioned_graph: &'a DfirGraph,
26 ) -> impl 'a + Iterator<Item = (GraphNodeId, GraphNodeId)> {
27 let edge_pairs_iter = self
28 .edge_barrier_crossers
29 .iter()
30 .map(|(edge_id, &_delay_type)| partitioned_graph.edge(edge_id));
31 let singleton_pairs_iter = self.singleton_barrier_crossers.iter().copied();
32 edge_pairs_iter.chain(singleton_pairs_iter)
33 }
34
35 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
37 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
38 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
39 }
40 }
41}
42
43fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
45 let edge_barrier_crossers = partitioned_graph
46 .edges()
47 .filter(|&(_, (_src, dst))| {
48 partitioned_graph.node_loop(dst).is_none()
50 })
51 .filter_map(|(edge_id, (_src, dst))| {
52 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
53 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
54 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
55 Some((edge_id, input_barrier))
56 })
57 .collect();
58 let singleton_barrier_crossers = partitioned_graph
59 .node_ids()
60 .flat_map(|dst| {
61 partitioned_graph
62 .node_singleton_references(dst)
63 .iter()
64 .flatten()
65 .map(move |&src_ref| (src_ref, dst))
66 })
67 .collect();
68 BarrierCrossers {
69 edge_barrier_crossers,
70 singleton_barrier_crossers,
71 }
72}
73
74fn find_subgraph_unionfind(
75 partitioned_graph: &DfirGraph,
76 barrier_crossers: &BarrierCrossers,
77) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
78 let mut node_color = partitioned_graph
83 .node_ids()
84 .filter_map(|node_id| {
85 let op_color = partitioned_graph.node_color(node_id)?;
86 Some((node_id, op_color))
87 })
88 .collect::<SparseSecondaryMap<_, _>>();
89
90 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
91 UnionFind::with_capacity(partitioned_graph.nodes().len());
92
93 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
96 let mut progress = true;
105 while progress {
106 progress = false;
107 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
109 if subgraph_unionfind.same_set(src, dst) {
111 continue;
114 }
115
116 if barrier_crossers
118 .iter_node_pairs(partitioned_graph)
119 .any(|(x_src, x_dst)| {
120 (subgraph_unionfind.same_set(x_src, src)
121 && subgraph_unionfind.same_set(x_dst, dst))
122 || (subgraph_unionfind.same_set(x_src, dst)
123 && subgraph_unionfind.same_set(x_dst, src))
124 })
125 {
126 continue;
127 }
128
129 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
131 continue;
132 }
133 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
135 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
136 }) {
137 continue;
138 }
139
140 if can_connect_colorize(&mut node_color, src, dst) {
141 subgraph_unionfind.union(src, dst);
144 assert!(handoff_edges.remove(&edge_id));
145 progress = true;
146 }
147 }
148 }
149
150 (subgraph_unionfind, handoff_edges)
151}
152
153fn make_subgraph_collect(
157 partitioned_graph: &DfirGraph,
158 mut subgraph_unionfind: UnionFind<GraphNodeId>,
159) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
160 let topo_sort = graph_algorithms::topo_sort(
164 partitioned_graph
165 .nodes()
166 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
167 .map(|(node_id, _)| node_id),
168 |v| {
169 partitioned_graph
170 .node_predecessor_nodes(v)
171 .filter(|&pred_id| {
172 let pred = partitioned_graph.node(pred_id);
173 !matches!(pred, GraphNode::Handoff { .. })
174 })
175 },
176 )
177 .expect("Subgraphs are in-out trees.");
178
179 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
180 for node_id in topo_sort {
181 let repr_node = subgraph_unionfind.find(node_id);
182 if !grouped_nodes.contains_key(repr_node) {
183 grouped_nodes.insert(repr_node, Default::default());
184 }
185 grouped_nodes[repr_node].push(node_id);
186 }
187 grouped_nodes
188}
189
190fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
194 let (subgraph_unionfind, handoff_edges) =
203 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
204
205 for edge_id in handoff_edges {
207 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
208
209 let src_node = partitioned_graph.node(src_id);
211 let dst_node = partitioned_graph.node(dst_id);
212 if matches!(src_node, GraphNode::Handoff { .. })
213 || matches!(dst_node, GraphNode::Handoff { .. })
214 {
215 continue;
216 }
217
218 let hoff = GraphNode::Handoff {
219 src_span: src_node.span(),
220 dst_span: dst_node.span(),
221 };
222 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
223
224 barrier_crossers.replace_edge(edge_id, out_edge_id);
226 }
227
228 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
232 for (_repr_node, member_nodes) in grouped_nodes {
233 partitioned_graph.insert_subgraph(member_nodes).unwrap();
234 }
235}
236
237fn can_connect_colorize(
243 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
244 src: GraphNodeId,
245 dst: GraphNodeId,
246) -> bool {
247 let can_connect = match (node_color.get(src), node_color.get(dst)) {
252 (None, None) => false,
255
256 (None, Some(Color::Pull | Color::Comp)) => {
258 node_color.insert(src, Color::Pull);
259 true
260 }
261 (None, Some(Color::Push | Color::Hoff)) => {
262 node_color.insert(src, Color::Push);
263 true
264 }
265
266 (Some(Color::Pull | Color::Hoff), None) => {
268 node_color.insert(dst, Color::Pull);
269 true
270 }
271 (Some(Color::Comp | Color::Push), None) => {
272 node_color.insert(dst, Color::Push);
273 true
274 }
275
276 (Some(Color::Pull), Some(Color::Pull)) => true,
278 (Some(Color::Pull), Some(Color::Comp)) => true,
279 (Some(Color::Pull), Some(Color::Push)) => true,
280
281 (Some(Color::Comp), Some(Color::Pull)) => false,
282 (Some(Color::Comp), Some(Color::Comp)) => false,
283 (Some(Color::Comp), Some(Color::Push)) => true,
284
285 (Some(Color::Push), Some(Color::Pull)) => false,
286 (Some(Color::Push), Some(Color::Comp)) => false,
287 (Some(Color::Push), Some(Color::Push)) => true,
288
289 (Some(Color::Hoff), Some(_)) => false,
291 (Some(_), Some(Color::Hoff)) => false,
292 };
293 can_connect
294}
295
296fn order_subgraphs(
302 partitioned_graph: &mut DfirGraph,
303 barrier_crossers: &BarrierCrossers,
304) -> Result<(), Diagnostic> {
305 let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
307
308 let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
310
311 for (node_id, node) in partitioned_graph.nodes() {
313 if !matches!(node, GraphNode::Handoff { .. }) {
314 continue;
315 }
316 assert_eq!(1, partitioned_graph.node_successors(node_id).len());
317 let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();
318
319 let succ_edge_delaytype = barrier_crossers
320 .edge_barrier_crossers
321 .get(succ_edge)
322 .copied();
323 if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
325 tick_edges.push((succ_edge, delay_type));
326 continue;
327 }
328
329 assert_eq!(1, partitioned_graph.node_predecessors(node_id).len());
330 let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();
331
332 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
333 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
334
335 sg_preds.entry(succ_sg).or_default().push(pred_sg);
336 }
337 for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
339 assert_ne!(pred, succ, "TODO(mingwei)");
340 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
341 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
342 assert_ne!(pred_sg, succ_sg);
343 sg_preds.entry(succ_sg).or_default().push(pred_sg);
344 }
345
346 if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
348 sg_preds.get(&v).into_iter().flatten().copied()
349 }) {
350 let span = cycle
351 .first()
352 .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
353 .map(|n| partitioned_graph.node(n).span())
354 .unwrap_or_else(Span::call_site);
355 return Err(Diagnostic::spanned(
356 span,
357 Level::Error,
358 "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
359 ));
360 }
361
362 for (edge_id, delay_type) in tick_edges {
367 let (hoff, _dst) = partitioned_graph.edge(edge_id);
368 assert!(matches!(
369 partitioned_graph.node(hoff),
370 GraphNode::Handoff { .. }
371 ));
372 partitioned_graph.set_handoff_delay_type(hoff, delay_type);
373 }
374 Ok(())
375}
376
377pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
381 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
383 let mut partitioned_graph = flat_graph;
384
385 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
387
388 order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
390
391 Ok(partitioned_graph)
392}