calyx_opt/passes/
static_inference.rs

1use crate::analysis::{GoDone, InferenceAnalysis};
2use crate::traversal::{
3    Action, ConstructVisitor, Named, Order, VisResult, Visitor,
4};
5use calyx_ir::{self as ir, LibrarySignatures};
6use calyx_utils::CalyxResult;
7use itertools::Itertools;
8
9/// Infer @promotable annotation
10/// for groups and control.
11/// Inference occurs whenever possible.
12pub struct StaticInference {
13    /// Takes static information.
14    inference_analysis: InferenceAnalysis,
15}
16
17// Override constructor to build latency_data information from the primitives
18// library.
19impl ConstructVisitor for StaticInference {
20    fn from(ctx: &ir::Context) -> CalyxResult<Self> {
21        Ok(StaticInference {
22            inference_analysis: InferenceAnalysis::from_ctx(ctx),
23        })
24    }
25
26    // This pass shared information between components
27    fn clear_data(&mut self) {}
28}
29
30impl Named for StaticInference {
31    fn name() -> &'static str {
32        "static-inference"
33    }
34
35    fn description() -> &'static str {
36        "infer when dynamic control programs are promotable"
37    }
38}
39
40impl Visitor for StaticInference {
41    // Require post order traversal of components to ensure `invoke` nodes
42    // get timing information for components.
43    fn iteration_order() -> Order {
44        Order::Post
45    }
46
47    fn finish(
48        &mut self,
49        comp: &mut ir::Component,
50        _lib: &LibrarySignatures,
51        _comps: &[ir::Component],
52    ) -> VisResult {
53        if comp.name != "main" {
54            // If the entire component's control is promotable.
55            if let Some(val) =
56                InferenceAnalysis::get_possible_latency(&comp.control.borrow())
57            {
58                let comp_sig = comp.signature.borrow();
59                let mut go_ports: Vec<_> =
60                    comp_sig.find_all_with_attr(ir::NumAttr::Go).collect();
61                // Insert @promotable attribute on the go ports.
62                for go_port in &mut go_ports {
63                    go_port
64                        .borrow_mut()
65                        .attributes
66                        .insert(ir::NumAttr::Promotable, val);
67                }
68                let mut done_ports: Vec<_> =
69                    comp_sig.find_all_with_attr(ir::NumAttr::Done).collect();
70                // Update `latency_data`.
71                go_ports.sort_by_key(|port| {
72                    port.borrow().attributes.get(ir::NumAttr::Go).unwrap()
73                });
74                done_ports.sort_by_key(|port| {
75                    port.borrow().attributes.get(ir::NumAttr::Done).unwrap()
76                });
77                let zipped: Vec<_> =
78                    go_ports.iter().zip(done_ports.iter()).collect();
79                let go_done_ports = zipped
80                    .into_iter()
81                    .map(|(go_port, done_port)| {
82                        (go_port.borrow().name, done_port.borrow().name, val)
83                    })
84                    .collect_vec();
85                self.inference_analysis.add_component((
86                    comp.name,
87                    val,
88                    GoDone::new(go_done_ports),
89                ));
90            }
91        }
92        Ok(Action::Continue)
93    }
94
95    fn start(
96        &mut self,
97        comp: &mut ir::Component,
98        _sigs: &LibrarySignatures,
99        _comps: &[ir::Component],
100    ) -> VisResult {
101        // ``Fix up the timing'', but with the updated_components argument as
102        // and empty HashMap. This just performs inference.
103        self.inference_analysis.fixup_timing(comp);
104        Ok(Action::Continue)
105    }
106}