calyx_opt/analysis/
inference_analysis.rs

1use super::AssignmentAnalysis;
2use crate::analysis::{GraphAnalysis, compute_static::WithStatic};
3use calyx_ir::{self as ir, GetAttributes, RRC};
4use ir::CellType;
5use itertools::Itertools;
6use std::collections::{HashMap, HashSet};
7
/// Struct to store information about go-done interfaces.
///
/// Each entry of `ports` is a `(go port, done port, latency)` triple.
/// There is no default implementation because it will almost certainly be very
/// unhelpful: build a value with [`GoDone::new`] or one of the `From`
/// conversions defined in this file.
#[derive(Debug)]
pub struct GoDone {
    /// `(go port name, done port name, latency)` triples.
    ports: Vec<(ir::Id, ir::Id, u64)>,
}
15
16impl GoDone {
17    pub fn new(ports: Vec<(ir::Id, ir::Id, u64)>) -> Self {
18        Self { ports }
19    }
20
21    /// Returns true if this is @go port
22    pub fn is_go(&self, name: &ir::Id) -> bool {
23        self.ports.iter().any(|(go, _, _)| name == go)
24    }
25
26    /// Returns true if this is a @done port
27    pub fn is_done(&self, name: &ir::Id) -> bool {
28        self.ports.iter().any(|(_, done, _)| name == done)
29    }
30
31    /// Returns the latency associated with the provided @go port if present
32    pub fn get_latency(&self, go_port: &ir::Id) -> Option<u64> {
33        self.ports.iter().find_map(|(go, _, lat)| {
34            if go == go_port { Some(*lat) } else { None }
35        })
36    }
37
38    /// Iterate over the defined ports
39    pub fn iter(&self) -> impl Iterator<Item = &(ir::Id, ir::Id, u64)> {
40        self.ports.iter()
41    }
42
43    /// Iterate over the defined ports
44    pub fn is_empty(&self) -> bool {
45        self.ports.is_empty()
46    }
47
48    /// Iterate over the defined ports
49    pub fn len(&self) -> usize {
50        self.ports.len()
51    }
52
53    /// Iterate over the defined ports
54    pub fn get_ports(&self) -> &Vec<(ir::Id, ir::Id, u64)> {
55        &self.ports
56    }
57}
58
59impl From<&ir::Primitive> for GoDone {
60    fn from(prim: &ir::Primitive) -> Self {
61        let done_ports: HashMap<_, _> = prim
62            .find_all_with_attr(ir::NumAttr::Done)
63            .map(|pd| (pd.attributes.get(ir::NumAttr::Done), pd.name()))
64            .collect();
65
66        let go_ports = prim
67            .find_all_with_attr(ir::NumAttr::Go)
68            .filter_map(|pd| {
69                // Primitives only have @interval.
70                pd.attributes.get(ir::NumAttr::Interval).and_then(|st| {
71                    done_ports
72                        .get(&pd.attributes.get(ir::NumAttr::Go))
73                        .map(|done_port| (pd.name(), *done_port, st))
74                })
75            })
76            .collect_vec();
77        GoDone::new(go_ports)
78    }
79}
80
81impl From<&ir::Cell> for GoDone {
82    fn from(cell: &ir::Cell) -> Self {
83        let done_ports: HashMap<_, _> = cell
84            .find_all_with_attr(ir::NumAttr::Done)
85            .map(|pr| {
86                let port = pr.borrow();
87                (port.attributes.get(ir::NumAttr::Done), port.name)
88            })
89            .collect();
90
91        let go_ports = cell
92            .find_all_with_attr(ir::NumAttr::Go)
93            .filter_map(|pr| {
94                let port = pr.borrow();
95                // Get static interval thru either @interval or @promotable.
96                let st = match port.attributes.get(ir::NumAttr::Interval) {
97                    Some(st) => Some(st),
98                    None => port.attributes.get(ir::NumAttr::Promotable),
99                };
100                if let Some(static_latency) = st {
101                    return done_ports
102                        .get(&port.attributes.get(ir::NumAttr::Go))
103                        .map(|done_port| {
104                            (port.name, *done_port, static_latency)
105                        });
106                }
107                None
108            })
109            .collect_vec();
110        GoDone::new(go_ports)
111    }
112}
113
/// Latency information used to infer and fix up static timing.
///
/// A default implementation is almost certainly not helpful.
/// You should probably use `from_ctx` instead.
pub struct InferenceAnalysis {
    /// component name -> vec<(go signal, done signal, latency)>
    pub latency_data: HashMap<ir::Id, GoDone>,
    /// Maps static component names to their latencies, but there can only
    /// be one go port on the component. (This is a subset of the information
    /// given by latency_data), and is helpful for inferring invokes.
    /// Perhaps someday we should get rid of it and only make it one field.
    pub static_component_latencies: HashMap<ir::Id, u64>,

    /// Names recorded by `remove_component` and `adjust_component` when they
    /// actually change an entry; `fixup_timing` consults this set to decide
    /// which groups lose their @promotable attribute.
    updated_components: HashSet<ir::Id>,
}
127
128impl InferenceAnalysis {
129    /// Builds FixUp struct from a ctx. Looks at all primitives and component
130    /// signatures to get latency information.
131    pub fn from_ctx(ctx: &ir::Context) -> Self {
132        let mut latency_data = HashMap::new();
133        let mut static_component_latencies = HashMap::new();
134        // Construct latency_data for each primitive
135        for prim in ctx.lib.signatures() {
136            let prim_go_done = GoDone::from(prim);
137            if prim_go_done.len() == 1 {
138                static_component_latencies
139                    .insert(prim.name, prim_go_done.get_ports()[0].2);
140            }
141            latency_data.insert(prim.name, GoDone::from(prim));
142        }
143        for comp in &ctx.components {
144            let comp_sig = comp.signature.borrow();
145
146            let done_ports: HashMap<_, _> = comp_sig
147                .find_all_with_attr(ir::NumAttr::Done)
148                .map(|pd| {
149                    let pd_ref = pd.borrow();
150                    (pd_ref.attributes.get(ir::NumAttr::Done), pd_ref.name)
151                })
152                .collect();
153
154            let go_ports = comp_sig
155                .find_all_with_attr(ir::NumAttr::Go)
156                .filter_map(|pd| {
157                    let pd_ref = pd.borrow();
158                    // Get static interval thru either @interval or @promotable.
159                    let st = match pd_ref.attributes.get(ir::NumAttr::Interval)
160                    {
161                        Some(st) => Some(st),
162                        None => pd_ref.attributes.get(ir::NumAttr::Promotable),
163                    };
164                    if let Some(static_latency) = st {
165                        return done_ports
166                            .get(&pd_ref.attributes.get(ir::NumAttr::Go))
167                            .map(|done_port| {
168                                (pd_ref.name, *done_port, static_latency)
169                            });
170                    }
171                    None
172                })
173                .collect_vec();
174
175            let go_done_comp = GoDone::new(go_ports);
176
177            if go_done_comp.len() == 1 {
178                static_component_latencies
179                    .insert(comp.name, go_done_comp.get_ports()[0].2);
180            }
181            latency_data.insert(comp.name, go_done_comp);
182        }
183        InferenceAnalysis {
184            latency_data,
185            static_component_latencies,
186            updated_components: HashSet::new(),
187        }
188    }
189
190    /// Updates the component, given a component name and a new latency and GoDone object.
191    pub fn add_component(
192        &mut self,
193        (comp_name, latency, go_done): (ir::Id, u64, GoDone),
194    ) {
195        self.latency_data.insert(comp_name, go_done);
196        self.static_component_latencies.insert(comp_name, latency);
197    }
198
199    /// Updates the component, given a component name and a new latency.
200    /// Note that this expects that the component already is accounted for
201    /// in self.latency_data and self.static_component_latencies.
202    pub fn remove_component(&mut self, comp_name: ir::Id) {
203        if self.latency_data.contains_key(&comp_name) {
204            // To make inference as strong as possible, only update updated_components
205            // if we actually updated it.
206            self.updated_components.insert(comp_name);
207        }
208        self.latency_data.remove(&comp_name);
209        self.static_component_latencies.remove(&comp_name);
210    }
211
212    /// Updates the component, given a component name and a new latency.
213    /// Note that this expects that the component already is accounted for
214    /// in self.latency_data and self.static_component_latencies.
215    pub fn adjust_component(
216        &mut self,
217        (comp_name, adjusted_latency): (ir::Id, u64),
218    ) {
219        // Check whether we actually updated the component's latency.
220        let mut updated = false;
221        self.latency_data.entry(comp_name).and_modify(|go_done| {
222            for (_, _, cur_latency) in &mut go_done.ports {
223                // Updating components with latency data.
224                if *cur_latency != adjusted_latency {
225                    *cur_latency = adjusted_latency;
226                    updated = true;
227                }
228            }
229        });
230        self.static_component_latencies
231            .insert(comp_name, adjusted_latency);
232        if updated {
233            self.updated_components.insert(comp_name);
234        }
235    }
236
237    /// Return true if the edge (`src`, `dst`) meet one these criteria, and false otherwise:
238    ///   - `src` is an "out" port of a constant, and `dst` is a "go" port
239    ///   - `src` is a "done" port, and `dst` is a "go" port
240    ///   - `src` is a "done" port, and `dst` is the "done" port of a group
241    fn mem_wrt_dep_graph(&self, src: &ir::Port, dst: &ir::Port) -> bool {
242        match (&src.parent, &dst.parent) {
243            (
244                ir::PortParent::Cell(src_cell_wrf),
245                ir::PortParent::Cell(dst_cell_wrf),
246            ) => {
247                let src_rf = src_cell_wrf.upgrade();
248                let src_cell = src_rf.borrow();
249                let dst_rf = dst_cell_wrf.upgrade();
250                let dst_cell = dst_rf.borrow();
251                if let (Some(s_name), Some(d_name)) =
252                    (src_cell.type_name(), dst_cell.type_name())
253                {
254                    let data_src = self.latency_data.get(&s_name);
255                    let data_dst = self.latency_data.get(&d_name);
256                    if let (Some(dst_ports), Some(src_ports)) =
257                        (data_dst, data_src)
258                    {
259                        return src_ports.is_done(&src.name)
260                            && dst_ports.is_go(&dst.name);
261                    }
262                }
263
264                // A constant writes to a cell: to be added to the graph, the cell needs to be a "done" port.
265                if let (Some(d_name), ir::CellType::Constant { .. }) =
266                    (dst_cell.type_name(), &src_cell.prototype)
267                {
268                    if let Some(ports) = self.latency_data.get(&d_name) {
269                        return ports.is_go(&dst.name);
270                    }
271                }
272
273                false
274            }
275
276            // Something is written to a group: to be added to the graph, this needs to be a "done" port.
277            (_, ir::PortParent::Group(_)) => dst.name == "done",
278
279            // If we encounter anything else, no need to add it to the graph.
280            _ => false,
281        }
282    }
283
284    /// Return a Vec of edges (`a`, `b`), where `a` is a "go" port and `b`
285    /// is a "done" port, and `a` and `b` have the same parent cell.
286    fn find_go_done_edges(
287        &self,
288        group: &ir::Group,
289    ) -> Vec<(RRC<ir::Port>, RRC<ir::Port>)> {
290        let rw_set = group.assignments.iter().analysis().cell_uses();
291        let mut go_done_edges: Vec<(RRC<ir::Port>, RRC<ir::Port>)> = Vec::new();
292
293        for cell_ref in rw_set {
294            let cell = cell_ref.borrow();
295            if let Some(ports) =
296                cell.type_name().and_then(|c| self.latency_data.get(&c))
297            {
298                go_done_edges.extend(
299                    ports
300                        .iter()
301                        .map(|(go, done, _)| (cell.get(go), cell.get(done))),
302                )
303            }
304        }
305        go_done_edges
306    }
307
308    /// Returns true if `port` is a "done" port, and we know the latency data
309    /// about `port`, or is a constant.
310    fn is_done_port_or_const(&self, port: &ir::Port) -> bool {
311        if let ir::PortParent::Cell(cwrf) = &port.parent {
312            let cr = cwrf.upgrade();
313            let cell = cr.borrow();
314            if let ir::CellType::Constant { val, .. } = &cell.prototype {
315                if *val > 0 {
316                    return true;
317                }
318            } else if let Some(ports) =
319                cell.type_name().and_then(|c| self.latency_data.get(&c))
320            {
321                return ports.is_done(&port.name);
322            }
323        }
324        false
325    }
326
327    /// Returns true if `graph` contains writes to "done" ports
328    /// that could have dynamic latencies, false otherwise.
329    fn contains_dyn_writes(&self, graph: &GraphAnalysis) -> bool {
330        for port in &graph.ports() {
331            match &port.borrow().parent {
332                ir::PortParent::Cell(cell_wrf) => {
333                    let cr = cell_wrf.upgrade();
334                    let cell = cr.borrow();
335                    if let Some(ports) =
336                        cell.type_name().and_then(|c| self.latency_data.get(&c))
337                    {
338                        let name = &port.borrow().name;
339                        if ports.is_go(name) {
340                            for write_port in graph.writes_to(&port.borrow()) {
341                                if !self
342                                    .is_done_port_or_const(&write_port.borrow())
343                                {
344                                    log::debug!(
345                                        "`{}` is not a done port",
346                                        write_port.borrow().canonical(),
347                                    );
348                                    return true;
349                                }
350                            }
351                        }
352                    }
353                }
354                ir::PortParent::Group(_) => {
355                    if port.borrow().name == "done" {
356                        for write_port in graph.writes_to(&port.borrow()) {
357                            if !self.is_done_port_or_const(&write_port.borrow())
358                            {
359                                log::debug!(
360                                    "`{}` is not a done port",
361                                    write_port.borrow().canonical(),
362                                );
363                                return true;
364                            }
365                        }
366                    }
367                }
368
369                ir::PortParent::FSM(_) => {
370                    if port.borrow().name == "done" {
371                        for write_port in graph.writes_to(&port.borrow()) {
372                            if !self.is_done_port_or_const(&write_port.borrow())
373                            {
374                                log::debug!(
375                                    "`{}` is not a done port",
376                                    write_port.borrow().canonical(),
377                                );
378                                return true;
379                            }
380                        }
381                    }
382                }
383
384                ir::PortParent::StaticGroup(_) =>
385                // done ports of static groups should clearly NOT have static latencies
386                {
387                    panic!(
388                        "Have not decided how to handle static groups in infer-static-timing"
389                    )
390                }
391            }
392        }
393        false
394    }
395
396    /// Returns true if `graph` contains any nodes with degree > 1.
397    fn contains_node_deg_gt_one(graph: &GraphAnalysis) -> bool {
398        for port in graph.ports() {
399            if graph.writes_to(&port.borrow()).count() > 1 {
400                return true;
401            }
402        }
403        false
404    }
405
406    /// Attempts to infer the number of cycles starting when
407    /// `group[go]` is high, and port is high. If inference is
408    /// not possible, returns None.
409    fn infer_latency(&self, group: &ir::Group) -> Option<u64> {
410        // Creates a write dependency graph, which contains an edge (`a`, `b`) if:
411        //   - `a` is a "done" port, and writes to `b`, which is a "go" port
412        //   - `a` is a "done" port, and writes to `b`, which is the "done" port of this group
413        //   - `a` is an "out" port, and is a constant, and writes to `b`, a "go" port
414        //   - `a` is a "go" port, and `b` is a "done" port, and `a` and `b` share a parent cell
415        // Nodes that are not part of any edges that meet these criteria are excluded.
416        //
417        // For example, this group:
418        // ```
419        // group g1 {
420        //   a.in = 32'd1;
421        //   a.write_en = 1'd1;
422        //   g1[done] = a.done;
423        // }
424        // ```
425        // corresponds to this graph:
426        // ```
427        // constant(1) -> a.write_en
428        // a.write_en -> a.done
429        // a.done -> g1[done]
430        // ```
431        log::debug!("Checking group `{}`", group.name());
432        let graph_unprocessed = GraphAnalysis::from(group);
433        if self.contains_dyn_writes(&graph_unprocessed) {
434            log::debug!("FAIL: contains dynamic writes");
435            return None;
436        }
437
438        let go_done_edges = self.find_go_done_edges(group);
439        let graph = graph_unprocessed
440            .edge_induced_subgraph(|src, dst| self.mem_wrt_dep_graph(src, dst))
441            .add_edges(&go_done_edges)
442            .remove_isolated_vertices();
443
444        // Give up if a port has multiple writes to it.
445        if Self::contains_node_deg_gt_one(&graph) {
446            log::debug!("FAIL: Group contains multiple writes");
447            return None;
448        }
449
450        let mut tsort = graph.toposort();
451        let start = tsort.next()?;
452        let finish = tsort.last()?;
453
454        let paths = graph.paths(&start.borrow(), &finish.borrow());
455        // If there are no paths, give up.
456        if paths.is_empty() {
457            log::debug!("FAIL: No path between @go and @done port");
458            return None;
459        }
460        let first_path = paths.first().unwrap();
461
462        // Sum the latencies of each primitive along the path.
463        let mut latency_sum = 0;
464        for port in first_path {
465            if let ir::PortParent::Cell(cwrf) = &port.borrow().parent {
466                let cr = cwrf.upgrade();
467                let cell = cr.borrow();
468                if let Some(ports) =
469                    cell.type_name().and_then(|c| self.latency_data.get(&c))
470                {
471                    if let Some(latency) =
472                        ports.get_latency(&port.borrow().name)
473                    {
474                        latency_sum += latency;
475                    }
476                }
477            }
478        }
479
480        log::debug!("SUCCESS: Latency = {latency_sum}");
481        Some(latency_sum)
482    }
483
484    /// Returns Some(latency) if a control statement has a latency, because
485    /// it is static or is has the @promotable attribute
486    pub fn get_possible_latency(c: &ir::Control) -> Option<u64> {
487        match c {
488            ir::Control::Static(sc) => Some(sc.get_latency()),
489            _ => c.get_attribute(ir::NumAttr::Promotable),
490        }
491    }
492
493    pub fn remove_promotable_from_seq(seq: &mut ir::Seq) {
494        for stmt in &mut seq.stmts {
495            Self::remove_promotable_attribute(stmt);
496        }
497        seq.get_mut_attributes().remove(ir::NumAttr::Promotable);
498    }
499
500    /// Removes the @promotable attribute from the control program.
501    /// Recursively visits the children of the control.
502    pub fn remove_promotable_attribute(c: &mut ir::Control) {
503        c.get_mut_attributes().remove(ir::NumAttr::Promotable);
504        match c {
505            ir::Control::Empty(_)
506            | ir::Control::Invoke(_)
507            | ir::Control::Enable(_)
508            | ir::Control::Static(_)
509            | ir::Control::FSMEnable(_) => (),
510            ir::Control::While(ir::While { body, .. })
511            | ir::Control::Repeat(ir::Repeat { body, .. }) => {
512                Self::remove_promotable_attribute(body);
513            }
514            ir::Control::If(ir::If {
515                tbranch, fbranch, ..
516            }) => {
517                Self::remove_promotable_attribute(tbranch);
518                Self::remove_promotable_attribute(fbranch);
519            }
520            ir::Control::Seq(ir::Seq { stmts, .. })
521            | ir::Control::Par(ir::Par { stmts, .. }) => {
522                for stmt in stmts {
523                    Self::remove_promotable_attribute(stmt);
524                }
525            }
526        }
527    }
528
529    pub fn fixup_seq(&self, seq: &mut ir::Seq) {
530        seq.update_static(&self.static_component_latencies);
531    }
532
533    pub fn fixup_par(&self, par: &mut ir::Par) {
534        par.update_static(&self.static_component_latencies);
535    }
536
537    pub fn fixup_if(&self, _if: &mut ir::If) {
538        _if.update_static(&self.static_component_latencies);
539    }
540
541    pub fn fixup_while(&self, _while: &mut ir::While) {
542        _while.update_static(&self.static_component_latencies);
543    }
544
545    pub fn fixup_repeat(&self, repeat: &mut ir::Repeat) {
546        repeat.update_static(&self.static_component_latencies);
547    }
548
549    pub fn fixup_ctrl(&self, ctrl: &mut ir::Control) {
550        ctrl.update_static(&self.static_component_latencies);
551    }
552
553    /// "Fixes Up" the component. In particular:
554    /// 1. Removes @promotable annotations for any groups that write to any
555    ///    `updated_components`.
556    /// 2. Try to re-infer groups' latencies.
557    /// 3. Removes all @promotable annotation from the control program.
558    /// 4. Re-infers the @promotable annotations for any groups or control.
559    ///
560    /// Note that this only fixes up the component's ``internals''.
561    /// It does *not* fix the component's signature.
562    pub fn fixup_timing(&self, comp: &mut ir::Component) {
563        // Removing @promotable annotations for any groups that write to an updated_component,
564        // then try to re-infer the latency.
565        for group in comp.groups.iter() {
566            // This checks any group that writes to the component:
567            // We can probably switch this to any group that writes to the component's
568            // `go` port to be more precise analysis.
569            if group
570                .borrow_mut()
571                .assignments
572                .iter()
573                .analysis()
574                .cell_writes()
575                .any(|cell| match cell.borrow().prototype {
576                    CellType::Component { name } => {
577                        self.updated_components.contains(&name)
578                    }
579                    _ => false,
580                })
581            {
582                // Remove attribute from group.
583                group
584                    .borrow_mut()
585                    .attributes
586                    .remove(ir::NumAttr::Promotable);
587            }
588        }
589
590        for group in &mut comp.groups.iter() {
591            // Immediately try to re-infer the latency of the group.
592            let latency_result = self.infer_latency(&group.borrow());
593            if let Some(latency) = latency_result {
594                group
595                    .borrow_mut()
596                    .attributes
597                    .insert(ir::NumAttr::Promotable, latency);
598            }
599        }
600
601        // Removing @promotable annotations for the control flow, then trying
602        // to re-infer them.
603        Self::remove_promotable_attribute(&mut comp.control.borrow_mut());
604        comp.control
605            .borrow_mut()
606            .update_static(&self.static_component_latencies);
607    }
608}