1use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::{Expr, ExprPath, GenericArgument, Token, Type};
12
13use self::ops::{OperatorConstraints, Persistence};
14use crate::diagnostic::{Diagnostic, Diagnostics, Level};
15use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
16use crate::pretty_span::PrettySpan;
17
18mod di_mul_graph;
19mod eliminate_extra_unions_tees;
20mod flat_graph_builder;
21mod flat_to_partitioned;
22mod graph_write;
23mod meta_graph;
24mod meta_graph_debugging;
25
26use std::fmt::Display;
27
28pub use di_mul_graph::DiMulGraph;
29pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
30pub use flat_graph_builder::{FlatGraphBuilder, FlatGraphBuilderOutput};
31pub use flat_to_partitioned::partition_graph;
32pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
33
34pub use crate::graph_ids::{GraphEdgeId, GraphLoopId, GraphNodeId, GraphSubgraphId};
35
36pub mod graph_algorithms;
37pub mod ops;
38
39impl GraphSubgraphId {
40 pub fn as_ident(self, span: Span) -> Ident {
42 use slotmap::Key;
43 Ident::new(&format!("sgid_{:?}", self.data()), span)
44 }
45}
46
47impl GraphLoopId {
48 pub fn as_ident(self, span: Span) -> Ident {
50 use slotmap::Key;
51 Ident::new(&format!("loop_{:?}", self.data()), span)
52 }
53}
54
/// The string `"context"`.
/// NOTE(review): presumably the name of the context variable in generated code — not used in
/// this part of the file; confirm against usages.
const CONTEXT: &str = "context";
/// The string `"df"`.
/// NOTE(review): presumably the name of the dataflow graph variable in generated code — not
/// used in this part of the file; confirm against usages.
const GRAPH: &str = "df";

/// Label returned by [`GraphNode::to_pretty_string`] and [`GraphNode::to_name_string`] for
/// handoff nodes.
const HANDOFF_NODE_STR: &str = "handoff";
/// Label returned by [`GraphNode::to_pretty_string`] and [`GraphNode::to_name_string`] for
/// module boundary nodes.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
62
63mod serde_syn {
64 use serde::{Deserialize, Deserializer, Serializer};
65
66 pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
67 where
68 S: Serializer,
69 T: quote::ToTokens,
70 {
71 serializer.serialize_str(&value.to_token_stream().to_string())
72 }
73
74 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
75 where
76 D: Deserializer<'de>,
77 T: syn::parse::Parse,
78 {
79 let s = String::deserialize(deserializer)?;
80 syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
81 }
82}
83
/// Newtype wrapper around an [`Ident`] that can be serialized and deserialized.
///
/// The identifier goes through `serde_syn`, i.e. it is stored as its token string and parsed
/// back on deserialization.
/// NOTE(review): presumably a variable name from DFIR source — confirm against usages.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
89
/// A node in the graph: an operator, a handoff, or a module boundary.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator node holding the parsed [`Operator`]; serialized via `serde_syn`.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff node.
    Handoff {
        /// First of the two spans joined by [`GraphNode::span`], and the fallback if joining
        /// fails. Skipped by serde; deserialization fills in `Span::call_site()`.
        /// NOTE(review): presumably the span of the source side of the handoff — confirm.
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// Second of the two spans joined by [`GraphNode::span`].
        /// Skipped by serde; deserialization fills in `Span::call_site()`.
        /// NOTE(review): presumably the span of the destination side of the handoff — confirm.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// A module boundary node.
    ModuleBoundary {
        /// Flag printed by the `Debug` impl as `input: ...`.
        /// NOTE(review): presumably `true` for an input boundary and `false` for an output
        /// boundary — confirm.
        input: bool,

        /// Span returned by [`GraphNode::span`] for this node.
        /// Skipped by serde; deserialization fills in `Span::call_site()`.
        /// NOTE(review): presumably the span of the import expression — confirm.
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
117impl GraphNode {
118 pub fn to_pretty_string(&self) -> Cow<'static, str> {
120 match self {
121 GraphNode::Operator(op) => op.to_pretty_string().into(),
122 GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
123 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
124 }
125 }
126
127 pub fn to_name_string(&self) -> Cow<'static, str> {
129 match self {
130 GraphNode::Operator(op) => op.name_string().into(),
131 GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
132 GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
133 }
134 }
135
136 pub fn span(&self) -> Span {
138 match self {
139 Self::Operator(op) => op.span(),
140 &Self::Handoff {
141 src_span, dst_span, ..
142 } => src_span.join(dst_span).unwrap_or(src_span),
143 Self::ModuleBoundary { import_expr, .. } => *import_expr,
144 }
145 }
146}
147impl std::fmt::Debug for GraphNode {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Operator(operator) => {
151 write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
152 }
153 Self::Handoff { .. } => write!(f, "Node::Handoff"),
154 Self::ModuleBoundary { input, .. } => {
155 write!(f, "Node::ModuleBoundary{{input: {}}}", input)
156 }
157 }
158 }
159}
160
/// Per-operator data bundled together: the operator's constraints, its ports, and its generic
/// and regular arguments.
/// NOTE(review): presumably pre-computed from the graph so later compilation passes do not have
/// to re-derive it — confirm against where this struct is constructed.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// Static reference to the [`OperatorConstraints`] for this operator.
    pub op_constraints: &'static OperatorConstraints,
    /// Port index values on the operator's input side.
    /// NOTE(review): ordering is not established in this file — confirm.
    pub input_ports: Vec<PortIndexValue>,
    /// Port index values on the operator's output side.
    /// NOTE(review): ordering is not established in this file — confirm.
    pub output_ports: Vec<PortIndexValue>,
    /// Identifiers of singletons referenced by this operator.
    /// NOTE(review): presumably those appearing within the operator's arguments — confirm.
    pub singletons_referenced: Vec<Ident>,

    /// The operator's generic arguments, in the form produced by [`get_operator_generics`].
    pub generics: OpInstGenerics,
    /// Comma-separated argument expressions.
    /// NOTE(review): presumably the user-written arguments before any singleton
    /// post-processing — confirm.
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Argument tokens as a raw [`TokenStream`].
    /// NOTE(review): presumably the unparsed arguments, kept for singleton handling — confirm.
    pub arguments_raw: TokenStream,
}
191
/// The generic arguments of an operator, split into persistence (lifetime) arguments and type
/// arguments. Built by [`get_operator_generics`].
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// The operator's generic arguments as written, or `None` if it has none.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Persistences parsed from the leading lifetime generic arguments.
    pub persistence_args: Vec<Persistence>,
    /// Type generic arguments that directly follow the persistence arguments.
    pub type_args: Vec<Type>,
}
202
203impl OpInstGenerics {
204 fn join_spans<I>(mut spans: I) -> Option<Span>
209 where
210 I: Iterator<Item = Span>,
211 {
212 let mut span = spans.next()?;
213 for s in spans {
214 span = span.join(s)?;
215 }
216 Some(span)
217 }
218
219 pub fn persistence_args_span(&self) -> Option<Span> {
221 self.generic_args.as_ref().and_then(|args| {
222 Self::join_spans(
223 args.iter()
224 .filter(|a| matches!(a, GenericArgument::Lifetime(_)))
225 .map(|a| a.span()),
226 )
227 })
228 }
229
230 pub fn type_args_span(&self) -> Option<Span> {
232 self.generic_args.as_ref().and_then(|args| {
233 Self::join_spans(
234 args.iter()
235 .filter(|a| matches!(a, GenericArgument::Type(_)))
236 .map(|a| a.span()),
237 )
238 })
239 }
240}
241
242pub fn get_operator_generics(diagnostics: &mut Diagnostics, operator: &Operator) -> OpInstGenerics {
247 let generic_args = operator.type_arguments().cloned();
249 let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
250 GenericArgument::Lifetime(lifetime) => {
251 match &*lifetime.ident.to_string() {
252 "none" => Some(Persistence::None),
253 "loop" => Some(Persistence::Loop),
254 "tick" => Some(Persistence::Tick),
255 "static" => Some(Persistence::Static),
256 "mutable" => Some(Persistence::Mutable),
257 _ => {
258 diagnostics.push(Diagnostic::spanned(
259 generic_arg.span(),
260 Level::Error,
261 format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
262 ));
263 None
265 }
266 }
267 },
268 _ => None,
269 }).collect::<Vec<_>>();
270 let type_args = generic_args
271 .iter()
272 .flatten()
273 .skip(persistence_args.len())
274 .map_while(|generic_arg| match generic_arg {
275 GenericArgument::Type(typ) => Some(typ),
276 _ => None,
277 })
278 .cloned()
279 .collect::<Vec<_>>();
280
281 OpInstGenerics {
282 generic_args,
283 persistence_args,
284 type_args,
285 }
286}
287
/// A four-valued classification; the derived ordering follows declaration order.
/// NOTE(review): presumably the push/pull "color" assigned to graph nodes — confirm against
/// usages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// NOTE(review): presumably pull-based — confirm.
    Pull,
    /// NOTE(review): presumably push-based — confirm.
    Push,
    /// NOTE(review): presumably a computation node between pull and push — confirm.
    Comp,
    /// NOTE(review): presumably a handoff — confirm.
    Hoff,
}
300
/// The index of a port: an integer, a path, or elided (not specified).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer port index; serialized via `serde_syn`.
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A path port index; serialized via `serde_syn`.
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// No port index was specified. The optional span is skipped by serde;
    /// [`PortIndexValue::from_ported`] sets it to the span of the ported item.
    Elided(#[serde(skip)] Option<Span>),
}
312impl PortIndexValue {
313 pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
316 where
317 Inner: Spanned,
318 {
319 let ported_span = Some(ported.inner.span());
320 let port_inn = ported
321 .inn
322 .map(|idx| idx.index.into())
323 .unwrap_or_else(|| Self::Elided(ported_span));
324 let inner = ported.inner;
325 let port_out = ported
326 .out
327 .map(|idx| idx.index.into())
328 .unwrap_or_else(|| Self::Elided(ported_span));
329 (port_inn, inner, port_out)
330 }
331
332 pub fn is_specified(&self) -> bool {
334 !matches!(self, Self::Elided(_))
335 }
336
337 #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
341 #[allow(
342 clippy::result_large_err,
343 reason = "variants are same size, error isn't to be propagated."
344 )]
345 pub fn combine(self, other: Self) -> Result<Self, Self> {
346 match (self.is_specified(), other.is_specified()) {
347 (false, _other) => Ok(other),
348 (true, false) => Ok(self),
349 (true, true) => Err(self),
350 }
351 }
352
353 pub fn as_error_message_string(&self) -> String {
355 match self {
356 PortIndexValue::Int(n) => format!("`{}`", n.value),
357 PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
358 PortIndexValue::Elided(_) => "<elided>".to_owned(),
359 }
360 }
361
362 pub fn span(&self) -> Span {
364 match self {
365 PortIndexValue::Int(x) => x.span(),
366 PortIndexValue::Path(x) => x.span(),
367 PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
368 }
369 }
370}
371impl From<PortIndex> for PortIndexValue {
372 fn from(value: PortIndex) -> Self {
373 match value {
374 PortIndex::Int(x) => Self::Int(x),
375 PortIndex::Path(x) => Self::Path(x),
376 }
377 }
378}
379impl PartialEq for PortIndexValue {
380 fn eq(&self, other: &Self) -> bool {
381 match (self, other) {
382 (Self::Int(l0), Self::Int(r0)) => l0 == r0,
383 (Self::Path(l0), Self::Path(r0)) => l0 == r0,
384 (Self::Elided(_), Self::Elided(_)) => true,
385 _else => false,
386 }
387 }
388}
389impl Eq for PortIndexValue {}
390impl PartialOrd for PortIndexValue {
391 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
392 Some(self.cmp(other))
393 }
394}
395impl Ord for PortIndexValue {
396 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
397 match (self, other) {
398 (Self::Int(s), Self::Int(o)) => s.cmp(o),
399 (Self::Path(s), Self::Path(o)) => s
400 .to_token_stream()
401 .to_string()
402 .cmp(&o.to_token_stream().to_string()),
403 (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
404 (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
405 (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
406 (_, Self::Elided(_)) => std::cmp::Ordering::Less,
407 (Self::Elided(_), _) => std::cmp::Ordering::Greater,
408 }
409 }
410}
411
412impl Display for PortIndexValue {
413 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414 match self {
415 PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
416 PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
417 PortIndexValue::Elided(_) => write!(f, "[]"),
418 }
419 }
420}
421
/// The successful output of [`build_dfir_code`].
pub struct BuildDfirCodeOutput {
    /// The graph returned by `partition_graph`.
    pub partitioned_graph: DfirGraph,
    /// The tokens produced by calling `as_code` on the partitioned graph.
    pub code: TokenStream,
    /// Diagnostics accumulated while building.
    pub diagnostics: Diagnostics,
}
431
432pub fn build_dfir_code(
434 dfir_code: DfirCode,
435 root: &TokenStream,
436) -> Result<BuildDfirCodeOutput, Diagnostics> {
437 let flat_graph_builder = FlatGraphBuilder::from_dfir(dfir_code);
438
439 let FlatGraphBuilderOutput {
440 mut flat_graph,
441 uses,
442 mut diagnostics,
443 } = flat_graph_builder.build()?;
444
445 let () = match flat_graph.merge_modules() {
446 Ok(()) => (),
447 Err(d) => {
448 diagnostics.push(d);
449 return Err(diagnostics);
450 }
451 };
452
453 eliminate_extra_unions_tees(&mut flat_graph);
454
455 for (_loop_id, nodes) in flat_graph.loops() {
459 let span = nodes
460 .first()
461 .map_or_else(Span::call_site, |&n| flat_graph.node(n).span());
462 diagnostics.push(Diagnostic::spanned(
463 span,
464 Level::Error,
465 "`loop { }` blocks are not (yet) supported in `dfir_syntax!`.",
466 ));
467 }
468 if diagnostics.has_error() {
469 return Err(diagnostics);
470 }
471
472 let partitioned_graph = match partition_graph(flat_graph) {
473 Ok(partitioned_graph) => partitioned_graph,
474 Err(d) => {
475 diagnostics.push(d);
476 return Err(diagnostics);
477 }
478 };
479
480 let code =
481 partitioned_graph.as_code(root, true, quote::quote! { #( #uses )* }, &mut diagnostics)?;
482
483 Ok(BuildDfirCodeOutput {
484 partitioned_graph,
485 code,
486 diagnostics,
487 })
488}
489
490fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
492 use proc_macro2::{Group, TokenTree};
493 tokens
494 .into_iter()
495 .map(|token| match token {
496 TokenTree::Group(mut group) => {
497 group.set_span(span);
498 TokenTree::Group(Group::new(
499 group.delimiter(),
500 change_spans(group.stream(), span),
501 ))
502 }
503 TokenTree::Ident(mut ident) => {
504 ident.set_span(span.resolved_at(ident.span()));
505 TokenTree::Ident(ident)
506 }
507 TokenTree::Punct(mut punct) => {
508 punct.set_span(span);
509 TokenTree::Punct(punct)
510 }
511 TokenTree::Literal(mut literal) => {
512 literal.set_span(span);
513 TokenTree::Literal(literal)
514 }
515 })
516 .collect()
517}