calyx_opt/analysis/
inference_analysis.rs

1use super::AssignmentAnalysis;
2use crate::analysis::{GraphAnalysis, compute_static::WithStatic};
3use calyx_ir::{self as ir, GetAttributes, RRC};
4use ir::CellType;
5use itertools::Itertools;
6use std::collections::{HashMap, HashSet};
7
8/// Struct to store information about the go-done interfaces defined by a primitive.
9/// There is no default implementation because it will almost certainly be very
10/// unhelpful: you will want to use `from_ctx`.
11#[derive(Debug)]
12pub struct GoDone {
13    ports: Vec<(ir::Id, ir::Id, u64)>,
14}
15
16impl GoDone {
17    pub fn new(ports: Vec<(ir::Id, ir::Id, u64)>) -> Self {
18        Self { ports }
19    }
20
21    /// Returns true if this is @go port
22    pub fn is_go(&self, name: &ir::Id) -> bool {
23        self.ports.iter().any(|(go, _, _)| name == go)
24    }
25
26    /// Returns true if this is a @done port
27    pub fn is_done(&self, name: &ir::Id) -> bool {
28        self.ports.iter().any(|(_, done, _)| name == done)
29    }
30
31    /// Returns the latency associated with the provided @go port if present
32    pub fn get_latency(&self, go_port: &ir::Id) -> Option<u64> {
33        self.ports.iter().find_map(|(go, _, lat)| {
34            if go == go_port { Some(*lat) } else { None }
35        })
36    }
37
38    /// Iterate over the defined ports
39    pub fn iter(&self) -> impl Iterator<Item = &(ir::Id, ir::Id, u64)> {
40        self.ports.iter()
41    }
42
43    /// Iterate over the defined ports
44    pub fn is_empty(&self) -> bool {
45        self.ports.is_empty()
46    }
47
48    /// Iterate over the defined ports
49    pub fn len(&self) -> usize {
50        self.ports.len()
51    }
52
53    /// Iterate over the defined ports
54    pub fn get_ports(&self) -> &Vec<(ir::Id, ir::Id, u64)> {
55        &self.ports
56    }
57}
58
59impl From<&ir::Primitive> for GoDone {
60    fn from(prim: &ir::Primitive) -> Self {
61        let done_ports: HashMap<_, _> = prim
62            .find_all_with_attr(ir::NumAttr::Done)
63            .map(|pd| (pd.attributes.get(ir::NumAttr::Done), pd.name()))
64            .collect();
65
66        let go_ports = prim
67            .find_all_with_attr(ir::NumAttr::Go)
68            .filter_map(|pd| {
69                // Primitives only have @interval.
70                pd.attributes.get(ir::NumAttr::Interval).and_then(|st| {
71                    done_ports
72                        .get(&pd.attributes.get(ir::NumAttr::Go))
73                        .map(|done_port| (pd.name(), *done_port, st))
74                })
75            })
76            .collect_vec();
77        GoDone::new(go_ports)
78    }
79}
80
81impl From<&ir::Cell> for GoDone {
82    fn from(cell: &ir::Cell) -> Self {
83        let done_ports: HashMap<_, _> = cell
84            .find_all_with_attr(ir::NumAttr::Done)
85            .map(|pr| {
86                let port = pr.borrow();
87                (port.attributes.get(ir::NumAttr::Done), port.name)
88            })
89            .collect();
90
91        let go_ports = cell
92            .find_all_with_attr(ir::NumAttr::Go)
93            .filter_map(|pr| {
94                let port = pr.borrow();
95                // Get static interval thru either @interval or @promotable.
96                let st = match port.attributes.get(ir::NumAttr::Interval) {
97                    Some(st) => Some(st),
98                    None => port.attributes.get(ir::NumAttr::Promotable),
99                };
100                if let Some(static_latency) = st {
101                    return done_ports
102                        .get(&port.attributes.get(ir::NumAttr::Go))
103                        .map(|done_port| {
104                            (port.name, *done_port, static_latency)
105                        });
106                }
107                None
108            })
109            .collect_vec();
110        GoDone::new(go_ports)
111    }
112}
113
114/// Default implemnetation is almost certainly not helpful.
115/// You should probably use `from_ctx` instead.
116pub struct InferenceAnalysis {
117    /// component name -> vec<(go signal, done signal, latency)>
118    pub latency_data: HashMap<ir::Id, GoDone>,
119    /// Maps static component names to their latencies, but there can only
120    /// be one go port on the component. (This is a subset of the information
121    /// given by latency_data), and is helpful for inferring invokes.
122    /// Perhaps someday we should get rid of it and only make it one field.
123    pub static_component_latencies: HashMap<ir::Id, u64>,
124
125    updated_components: HashSet<ir::Id>,
126}
127
128impl InferenceAnalysis {
129    /// Builds FixUp struct from a ctx. Looks at all primitives and component
130    /// signatures to get latency information.
131    pub fn from_ctx(ctx: &ir::Context) -> Self {
132        let mut latency_data = HashMap::new();
133        let mut static_component_latencies = HashMap::new();
134        // Construct latency_data for each primitive
135        for prim in ctx.lib.signatures() {
136            let prim_go_done = GoDone::from(prim);
137            if prim_go_done.len() == 1 {
138                static_component_latencies
139                    .insert(prim.name, prim_go_done.get_ports()[0].2);
140            }
141            latency_data.insert(prim.name, GoDone::from(prim));
142        }
143        for comp in &ctx.components {
144            let comp_sig = comp.signature.borrow();
145
146            let done_ports: HashMap<_, _> = comp_sig
147                .find_all_with_attr(ir::NumAttr::Done)
148                .map(|pd| {
149                    let pd_ref = pd.borrow();
150                    (pd_ref.attributes.get(ir::NumAttr::Done), pd_ref.name)
151                })
152                .collect();
153
154            let go_ports = comp_sig
155                .find_all_with_attr(ir::NumAttr::Go)
156                .filter_map(|pd| {
157                    let pd_ref = pd.borrow();
158                    // Get static interval thru either @interval or @promotable.
159                    let st = match pd_ref.attributes.get(ir::NumAttr::Interval)
160                    {
161                        Some(st) => Some(st),
162                        None => pd_ref.attributes.get(ir::NumAttr::Promotable),
163                    };
164                    if let Some(static_latency) = st {
165                        return done_ports
166                            .get(&pd_ref.attributes.get(ir::NumAttr::Go))
167                            .map(|done_port| {
168                                (pd_ref.name, *done_port, static_latency)
169                            });
170                    }
171                    None
172                })
173                .collect_vec();
174
175            let go_done_comp = GoDone::new(go_ports);
176
177            if go_done_comp.len() == 1 {
178                static_component_latencies
179                    .insert(comp.name, go_done_comp.get_ports()[0].2);
180            }
181            latency_data.insert(comp.name, go_done_comp);
182        }
183        InferenceAnalysis {
184            latency_data,
185            static_component_latencies,
186            updated_components: HashSet::new(),
187        }
188    }
189
190    /// Updates the component, given a component name and a new latency and GoDone object.
191    pub fn add_component(
192        &mut self,
193        (comp_name, latency, go_done): (ir::Id, u64, GoDone),
194    ) {
195        self.latency_data.insert(comp_name, go_done);
196        self.static_component_latencies.insert(comp_name, latency);
197    }
198
199    /// Updates the component, given a component name and a new latency.
200    /// Note that this expects that the component already is accounted for
201    /// in self.latency_data and self.static_component_latencies.
202    pub fn remove_component(&mut self, comp_name: ir::Id) {
203        if self.latency_data.contains_key(&comp_name) {
204            // To make inference as strong as possible, only update updated_components
205            // if we actually updated it.
206            self.updated_components.insert(comp_name);
207        }
208        self.latency_data.remove(&comp_name);
209        self.static_component_latencies.remove(&comp_name);
210    }
211
212    /// Updates the component, given a component name and a new latency.
213    /// Note that this expects that the component already is accounted for
214    /// in self.latency_data and self.static_component_latencies.
215    pub fn adjust_component(
216        &mut self,
217        (comp_name, adjusted_latency): (ir::Id, u64),
218    ) {
219        // Check whether we actually updated the component's latency.
220        let mut updated = false;
221        self.latency_data.entry(comp_name).and_modify(|go_done| {
222            for (_, _, cur_latency) in &mut go_done.ports {
223                // Updating components with latency data.
224                if *cur_latency != adjusted_latency {
225                    *cur_latency = adjusted_latency;
226                    updated = true;
227                }
228            }
229        });
230        self.static_component_latencies
231            .insert(comp_name, adjusted_latency);
232        if updated {
233            self.updated_components.insert(comp_name);
234        }
235    }
236
237    /// Return true if the edge (`src`, `dst`) meet one these criteria, and false otherwise:
238    ///   - `src` is an "out" port of a constant, and `dst` is a "go" port
239    ///   - `src` is a "done" port, and `dst` is a "go" port
240    ///   - `src` is a "done" port, and `dst` is the "done" port of a group
241    fn mem_wrt_dep_graph(&self, src: &ir::Port, dst: &ir::Port) -> bool {
242        match (&src.parent, &dst.parent) {
243            (
244                ir::PortParent::Cell(src_cell_wrf),
245                ir::PortParent::Cell(dst_cell_wrf),
246            ) => {
247                let src_rf = src_cell_wrf.upgrade();
248                let src_cell = src_rf.borrow();
249                let dst_rf = dst_cell_wrf.upgrade();
250                let dst_cell = dst_rf.borrow();
251                if let (Some(s_name), Some(d_name)) =
252                    (src_cell.type_name(), dst_cell.type_name())
253                {
254                    let data_src = self.latency_data.get(&s_name);
255                    let data_dst = self.latency_data.get(&d_name);
256                    if let (Some(dst_ports), Some(src_ports)) =
257                        (data_dst, data_src)
258                    {
259                        return src_ports.is_done(&src.name)
260                            && dst_ports.is_go(&dst.name);
261                    }
262                }
263
264                // A constant writes to a cell: to be added to the graph, the cell needs to be a "done" port.
265                if let (Some(d_name), ir::CellType::Constant { .. }) =
266                    (dst_cell.type_name(), &src_cell.prototype)
267                    && let Some(ports) = self.latency_data.get(&d_name)
268                {
269                    return ports.is_go(&dst.name);
270                }
271
272                false
273            }
274
275            // Something is written to a group: to be added to the graph, this needs to be a "done" port.
276            (_, ir::PortParent::Group(_)) => dst.name == "done",
277
278            // If we encounter anything else, no need to add it to the graph.
279            _ => false,
280        }
281    }
282
283    /// Return a Vec of edges (`a`, `b`), where `a` is a "go" port and `b`
284    /// is a "done" port, and `a` and `b` have the same parent cell.
285    fn find_go_done_edges(
286        &self,
287        group: &ir::Group,
288    ) -> Vec<(RRC<ir::Port>, RRC<ir::Port>)> {
289        let rw_set = group.assignments.iter().analysis().cell_uses();
290        let mut go_done_edges: Vec<(RRC<ir::Port>, RRC<ir::Port>)> = Vec::new();
291
292        for cell_ref in rw_set {
293            let cell = cell_ref.borrow();
294            if let Some(ports) =
295                cell.type_name().and_then(|c| self.latency_data.get(&c))
296            {
297                go_done_edges.extend(
298                    ports
299                        .iter()
300                        .map(|(go, done, _)| (cell.get(go), cell.get(done))),
301                )
302            }
303        }
304        go_done_edges
305    }
306
307    /// Returns true if `port` is a "done" port, and we know the latency data
308    /// about `port`, or is a constant.
309    fn is_done_port_or_const(&self, port: &ir::Port) -> bool {
310        if let ir::PortParent::Cell(cwrf) = &port.parent {
311            let cr = cwrf.upgrade();
312            let cell = cr.borrow();
313            if let ir::CellType::Constant { val, .. } = &cell.prototype {
314                if *val > 0 {
315                    return true;
316                }
317            } else if let Some(ports) =
318                cell.type_name().and_then(|c| self.latency_data.get(&c))
319            {
320                return ports.is_done(&port.name);
321            }
322        }
323        false
324    }
325
326    /// Returns true if `graph` contains writes to "done" ports
327    /// that could have dynamic latencies, false otherwise.
328    fn contains_dyn_writes(&self, graph: &GraphAnalysis) -> bool {
329        for port in &graph.ports() {
330            match &port.borrow().parent {
331                ir::PortParent::Cell(cell_wrf) => {
332                    let cr = cell_wrf.upgrade();
333                    let cell = cr.borrow();
334                    if let Some(ports) =
335                        cell.type_name().and_then(|c| self.latency_data.get(&c))
336                    {
337                        let name = &port.borrow().name;
338                        if ports.is_go(name) {
339                            for write_port in graph.writes_to(&port.borrow()) {
340                                if !self
341                                    .is_done_port_or_const(&write_port.borrow())
342                                {
343                                    log::debug!(
344                                        "`{}` is not a done port",
345                                        write_port.borrow().canonical(),
346                                    );
347                                    return true;
348                                }
349                            }
350                        }
351                    }
352                }
353                ir::PortParent::Group(_) => {
354                    if port.borrow().name == "done" {
355                        for write_port in graph.writes_to(&port.borrow()) {
356                            if !self.is_done_port_or_const(&write_port.borrow())
357                            {
358                                log::debug!(
359                                    "`{}` is not a done port",
360                                    write_port.borrow().canonical(),
361                                );
362                                return true;
363                            }
364                        }
365                    }
366                }
367
368                ir::PortParent::FSM(_) => {
369                    if port.borrow().name == "done" {
370                        for write_port in graph.writes_to(&port.borrow()) {
371                            if !self.is_done_port_or_const(&write_port.borrow())
372                            {
373                                log::debug!(
374                                    "`{}` is not a done port",
375                                    write_port.borrow().canonical(),
376                                );
377                                return true;
378                            }
379                        }
380                    }
381                }
382
383                ir::PortParent::StaticGroup(_) =>
384                // done ports of static groups should clearly NOT have static latencies
385                {
386                    panic!(
387                        "Have not decided how to handle static groups in infer-static-timing"
388                    )
389                }
390            }
391        }
392        false
393    }
394
395    /// Returns true if `graph` contains any nodes with degree > 1.
396    fn contains_node_deg_gt_one(graph: &GraphAnalysis) -> bool {
397        for port in graph.ports() {
398            if graph.writes_to(&port.borrow()).count() > 1 {
399                return true;
400            }
401        }
402        false
403    }
404
405    /// Attempts to infer the number of cycles starting when
406    /// `group[go]` is high, and port is high. If inference is
407    /// not possible, returns None.
408    fn infer_latency(&self, group: &ir::Group) -> Option<u64> {
409        // Creates a write dependency graph, which contains an edge (`a`, `b`) if:
410        //   - `a` is a "done" port, and writes to `b`, which is a "go" port
411        //   - `a` is a "done" port, and writes to `b`, which is the "done" port of this group
412        //   - `a` is an "out" port, and is a constant, and writes to `b`, a "go" port
413        //   - `a` is a "go" port, and `b` is a "done" port, and `a` and `b` share a parent cell
414        // Nodes that are not part of any edges that meet these criteria are excluded.
415        //
416        // For example, this group:
417        // ```
418        // group g1 {
419        //   a.in = 32'd1;
420        //   a.write_en = 1'd1;
421        //   g1[done] = a.done;
422        // }
423        // ```
424        // corresponds to this graph:
425        // ```
426        // constant(1) -> a.write_en
427        // a.write_en -> a.done
428        // a.done -> g1[done]
429        // ```
430        log::debug!("Checking group `{}`", group.name());
431        let graph_unprocessed = GraphAnalysis::from(group);
432        if self.contains_dyn_writes(&graph_unprocessed) {
433            log::debug!("FAIL: contains dynamic writes");
434            return None;
435        }
436
437        let go_done_edges = self.find_go_done_edges(group);
438        let graph = graph_unprocessed
439            .edge_induced_subgraph(|src, dst| self.mem_wrt_dep_graph(src, dst))
440            .add_edges(&go_done_edges)
441            .remove_isolated_vertices();
442
443        // Give up if a port has multiple writes to it.
444        if Self::contains_node_deg_gt_one(&graph) {
445            log::debug!("FAIL: Group contains multiple writes");
446            return None;
447        }
448
449        let mut tsort = graph.toposort();
450        let start = tsort.next()?;
451        let finish = tsort.last()?;
452
453        let paths = graph.paths(&start.borrow(), &finish.borrow());
454        // If there are no paths, give up.
455        if paths.is_empty() {
456            log::debug!("FAIL: No path between @go and @done port");
457            return None;
458        }
459        let first_path = paths.first().unwrap();
460
461        // Sum the latencies of each primitive along the path.
462        let mut latency_sum = 0;
463        for port in first_path {
464            if let ir::PortParent::Cell(cwrf) = &port.borrow().parent {
465                let cr = cwrf.upgrade();
466                let cell = cr.borrow();
467                if let Some(ports) =
468                    cell.type_name().and_then(|c| self.latency_data.get(&c))
469                    && let Some(latency) =
470                        ports.get_latency(&port.borrow().name)
471                {
472                    latency_sum += latency;
473                }
474            }
475        }
476
477        log::debug!("SUCCESS: Latency = {latency_sum}");
478        Some(latency_sum)
479    }
480
481    /// Returns Some(latency) if a control statement has a latency, because
482    /// it is static or is has the @promotable attribute
483    pub fn get_possible_latency(c: &ir::Control) -> Option<u64> {
484        match c {
485            ir::Control::Static(sc) => Some(sc.get_latency()),
486            _ => c.get_attribute(ir::NumAttr::Promotable),
487        }
488    }
489
490    pub fn remove_promotable_from_seq(seq: &mut ir::Seq) {
491        for stmt in &mut seq.stmts {
492            Self::remove_promotable_attribute(stmt);
493        }
494        seq.get_mut_attributes().remove(ir::NumAttr::Promotable);
495    }
496
497    /// Removes the @promotable attribute from the control program.
498    /// Recursively visits the children of the control.
499    pub fn remove_promotable_attribute(c: &mut ir::Control) {
500        c.get_mut_attributes().remove(ir::NumAttr::Promotable);
501        match c {
502            ir::Control::Empty(_)
503            | ir::Control::Invoke(_)
504            | ir::Control::Enable(_)
505            | ir::Control::Static(_)
506            | ir::Control::FSMEnable(_) => (),
507            ir::Control::While(ir::While { body, .. })
508            | ir::Control::Repeat(ir::Repeat { body, .. }) => {
509                Self::remove_promotable_attribute(body);
510            }
511            ir::Control::If(ir::If {
512                tbranch, fbranch, ..
513            }) => {
514                Self::remove_promotable_attribute(tbranch);
515                Self::remove_promotable_attribute(fbranch);
516            }
517            ir::Control::Seq(ir::Seq { stmts, .. })
518            | ir::Control::Par(ir::Par { stmts, .. }) => {
519                for stmt in stmts {
520                    Self::remove_promotable_attribute(stmt);
521                }
522            }
523        }
524    }
525
526    pub fn fixup_seq(&self, seq: &mut ir::Seq) {
527        seq.update_static(&self.static_component_latencies);
528    }
529
530    pub fn fixup_par(&self, par: &mut ir::Par) {
531        par.update_static(&self.static_component_latencies);
532    }
533
534    pub fn fixup_if(&self, _if: &mut ir::If) {
535        _if.update_static(&self.static_component_latencies);
536    }
537
538    pub fn fixup_while(&self, _while: &mut ir::While) {
539        _while.update_static(&self.static_component_latencies);
540    }
541
542    pub fn fixup_repeat(&self, repeat: &mut ir::Repeat) {
543        repeat.update_static(&self.static_component_latencies);
544    }
545
546    pub fn fixup_ctrl(&self, ctrl: &mut ir::Control) {
547        ctrl.update_static(&self.static_component_latencies);
548    }
549
550    /// "Fixes Up" the component. In particular:
551    /// 1. Removes @promotable annotations for any groups that write to any
552    ///    `updated_components`.
553    /// 2. Try to re-infer groups' latencies.
554    /// 3. Removes all @promotable annotation from the control program.
555    /// 4. Re-infers the @promotable annotations for any groups or control.
556    ///
557    /// Note that this only fixes up the component's ``internals''.
558    /// It does *not* fix the component's signature.
559    pub fn fixup_timing(&self, comp: &mut ir::Component) {
560        // Removing @promotable annotations for any groups that write to an updated_component,
561        // then try to re-infer the latency.
562        for group in comp.groups.iter() {
563            // This checks any group that writes to the component:
564            // We can probably switch this to any group that writes to the component's
565            // `go` port to be more precise analysis.
566            if group
567                .borrow_mut()
568                .assignments
569                .iter()
570                .analysis()
571                .cell_writes()
572                .any(|cell| match cell.borrow().prototype {
573                    CellType::Component { name } => {
574                        self.updated_components.contains(&name)
575                    }
576                    _ => false,
577                })
578            {
579                // Remove attribute from group.
580                group
581                    .borrow_mut()
582                    .attributes
583                    .remove(ir::NumAttr::Promotable);
584            }
585        }
586
587        for group in &mut comp.groups.iter() {
588            // Immediately try to re-infer the latency of the group.
589            let latency_result = self.infer_latency(&group.borrow());
590            if let Some(latency) = latency_result {
591                group
592                    .borrow_mut()
593                    .attributes
594                    .insert(ir::NumAttr::Promotable, latency);
595            }
596        }
597
598        // Removing @promotable annotations for the control flow, then trying
599        // to re-infer them.
600        Self::remove_promotable_attribute(&mut comp.control.borrow_mut());
601        comp.control
602            .borrow_mut()
603            .update_static(&self.static_component_latencies);
604    }
605}