calyx_opt/passes/
group_to_invoke.rs

1use crate::analysis::AssignmentAnalysis;
2use crate::traversal::{Action, ConstructVisitor, Named, VisResult, Visitor};
3use calyx_ir::{self as ir};
4use calyx_ir::{GetAttributes, RRC};
5use calyx_utils::CalyxResult;
6use ir::Nothing;
7use itertools::Itertools;
8use std::collections::{HashMap, HashSet};
9use std::rc::Rc;
10
/// Transform groups that are structurally invoking components into equivalent
/// [ir::Invoke] statements.
///
/// For a group to meet the requirements of this pass, it must
/// 1. Only write to one non-combinational component (all other writes must be
///    to combinational primitives)
/// 2. That component is *not* a ref cell, nor does it have the external attribute,
///    nor is it This Component
/// 3. Assign component.go = 1'd1
/// 4. Assign group[done] = component.done
///
/// In addition, the component must not be one of the blacklisted primitives,
/// must expose exactly one @go port and exactly one @done port, and the
/// assignments in (3) and (4) must be unguarded. Only dynamic groups are
/// analyzed; static groups are currently left untouched.
pub struct GroupToInvoke {
    /// Library primitives that have more than one @go port (collected when the
    /// pass is constructed); groups that write to such a primitive are never
    /// transformed.
    blacklist: HashSet<ir::Id>,
    /// Maps group names to the invokes that will replace enables of them
    group_invoke_map: HashMap<ir::Id, ir::Control>,
}
27
28impl ConstructVisitor for GroupToInvoke {
29    fn from(ctx: &ir::Context) -> CalyxResult<Self>
30    where
31        Self: Sized,
32    {
33        // Construct list of primitives that have multiple go-done signals
34        let blacklist = ctx
35            .lib
36            .signatures()
37            .filter(|p| p.find_all_with_attr(ir::NumAttr::Go).count() > 1)
38            .map(|p| p.name)
39            .collect();
40
41        Ok(Self {
42            blacklist,
43            group_invoke_map: HashMap::new(),
44        })
45    }
46
47    fn clear_data(&mut self) {
48        self.group_invoke_map = HashMap::new();
49    }
50}
51
52impl Named for GroupToInvoke {
53    fn name() -> &'static str {
54        "group2invoke"
55    }
56
57    fn description() -> &'static str {
58        "covert groups that structurally invoke one component into invoke statements"
59    }
60}
61
62/// Construct an [ir::Invoke] from an [ir::Group] that has been validated by this pass.
63fn construct_invoke(
64    assigns: &[ir::Assignment<Nothing>],
65    comp: RRC<ir::Cell>,
66    builder: &mut ir::Builder,
67) -> ir::Control {
68    // Check if port's parent is a combinational primitive
69    let parent_is_comb = |port: &ir::Port| -> bool {
70        if !port.is_hole()
71            && let ir::CellType::Primitive { is_comb, .. } =
72                port.cell_parent().borrow().prototype
73        {
74            return is_comb;
75        }
76        false
77    };
78
79    // Check if port's parent is equal to comp
80    let parent_is_cell = |port: &ir::Port| -> bool {
81        match &port.parent {
82            ir::PortParent::Cell(cell_wref) => {
83                Rc::ptr_eq(&cell_wref.upgrade(), &comp)
84            }
85            _ => false,
86        }
87    };
88
89    let mut inputs = Vec::new();
90    let mut comb_assigns = Vec::new();
91    let mut wire_map: HashMap<ir::Id, ir::RRC<ir::Port>> = HashMap::new();
92
93    for assign in assigns {
94        // We know that all assignments in this group should write to either a)
95        // a combinational component or b) comp or c) the group's done port-- we
96        // should have checked for this condition before calling this function
97
98        // If a combinational component's port is being used as a dest, add
99        // it to comb_assigns
100        if parent_is_comb(&assign.dst.borrow()) {
101            comb_assigns.push(assign.clone());
102        }
103        // If the cell's port is being used as a dest, add the source to
104        // inputs. we can ignore the cell.go assignment, since that is not
105        // going to be part of the `invoke`.
106        else if parent_is_cell(&assign.dst.borrow())
107            && assign.dst
108                != comp.borrow().get_unique_with_attr(ir::NumAttr::Go).unwrap()
109        {
110            let name = assign.dst.borrow().name;
111            if assign.guard.is_true() {
112                inputs.push((name, Rc::clone(&assign.src)));
113            } else {
114                // assign has a guard condition,so need a wire
115                // We first check whether we have already built a wire
116                // for this port or not.
117                let wire_in = match wire_map.get(&assign.dst.borrow().name) {
118                    Some(w) => {
119                        // Already built a wire, so just need to return the
120                        // wire's input port (which wire_map stores)
121                        Rc::clone(w)
122                    }
123                    None => {
124                        // Need to create a new wire
125                        let width = assign.dst.borrow().width;
126                        let wire = builder.add_primitive(
127                            format!("{name}_guarded_wire"),
128                            "std_wire",
129                            &[width],
130                        );
131                        // Insert the wire's input port into wire_map
132                        let wire_in = wire.borrow().get("in");
133                        wire_map.insert(
134                            assign.dst.borrow().name,
135                            Rc::clone(&wire_in),
136                        );
137                        // add the wire's output port to the inputs of the
138                        // invoke statement we are building
139                        inputs.push((name, wire.borrow().get("out")));
140                        // return wire_in
141                        wire_in
142                    }
143                };
144                // Use wire_in to add another assignment to combinational group
145                let asmt = builder.build_assignment(
146                    wire_in,
147                    Rc::clone(&assign.src),
148                    *assign.guard.clone(),
149                );
150                comb_assigns.push(asmt);
151            }
152        }
153    }
154
155    let comb_group = if comb_assigns.is_empty() {
156        None
157    } else {
158        let comb_group_ref = builder.add_comb_group("comb_invoke");
159        comb_group_ref
160            .borrow_mut()
161            .assignments
162            .append(&mut comb_assigns);
163        Some(comb_group_ref)
164    };
165
166    ir::Control::Invoke(ir::Invoke {
167        comp,
168        inputs,
169        outputs: Vec::new(),
170        comb_group,
171        attributes: ir::Attributes::default(),
172        ref_cells: Vec::new(),
173    })
174}
175
176impl Visitor for GroupToInvoke {
177    fn start(
178        &mut self,
179        comp: &mut ir::Component,
180        sigs: &ir::LibrarySignatures,
181        _comps: &[ir::Component],
182    ) -> VisResult {
183        let groups = comp.get_groups_mut().drain().collect_vec();
184        let static_groups = comp.get_static_groups_mut().drain().collect_vec();
185        let mut builder = ir::Builder::new(comp, sigs);
186        for g in &groups {
187            self.analyze_group(
188                &mut builder,
189                g.borrow().name(),
190                &g.borrow().assignments,
191                &g.borrow().get("done"),
192            )
193        }
194        // Not transforming static groups rn
195        /*for g in &static_groups {
196            self.analyze_group(
197                &mut builder,
198                g.borrow().name(),
199                &g.borrow().assignments,
200                &g.borrow().get(ir::NumAttr::Done),
201            )
202        }*/
203
204        comp.get_groups_mut().append(groups.into_iter());
205        comp.get_static_groups_mut()
206            .append(static_groups.into_iter());
207
208        Ok(Action::Continue)
209    }
210
211    fn enable(
212        &mut self,
213        s: &mut ir::Enable,
214        _comp: &mut ir::Component,
215        _sigs: &ir::LibrarySignatures,
216        _comps: &[ir::Component],
217    ) -> VisResult {
218        match self.group_invoke_map.get(&s.group.borrow().name()) {
219            None => Ok(Action::Continue),
220            Some(invoke) => {
221                let mut inv = ir::Cloner::control(invoke);
222                let attrs = std::mem::take(&mut s.attributes);
223                *inv.get_mut_attributes() = attrs;
224                Ok(Action::Change(Box::new(inv)))
225            }
226        }
227    }
228}
229
230impl GroupToInvoke {
231    // if g is able to be turned into invoke, then add to self.group_invoke_map
232    fn analyze_group(
233        &mut self,
234        builder: &mut ir::Builder,
235        group_name: ir::Id,
236        assigns: &[ir::Assignment<Nothing>],
237        group_done_port: &ir::RRC<ir::Port>,
238    ) {
239        let mut writes = assigns
240            .iter()
241            .analysis()
242            .cell_writes()
243            .filter(|cell| match cell.borrow().prototype {
244                ir::CellType::Primitive { is_comb, .. } => !is_comb,
245                _ => true,
246            })
247            .collect_vec();
248        // Excluding writes to combinational components, should write to exactly
249        // one cell
250        if writes.len() != 1 {
251            return;
252        }
253
254        // If component is ThisComponent, Reference, or External, don't turn into invoke
255        let cr = writes.pop().unwrap();
256        let cell = cr.borrow();
257        match &cell.prototype {
258            ir::CellType::Primitive { name, .. }
259                if self.blacklist.contains(name) =>
260            {
261                return;
262            }
263            ir::CellType::ThisComponent => return,
264            _ => {}
265        }
266        if cell.is_reference() || cell.attributes.has(ir::BoolAttr::External) {
267            return;
268        }
269
270        // Component must define exactly one @go/@done interface
271        let Ok(Some(go_port)) = cell.find_unique_with_attr(ir::NumAttr::Go)
272        else {
273            return;
274        };
275        let Ok(Some(done_port)) = cell.find_unique_with_attr(ir::NumAttr::Done)
276        else {
277            return;
278        };
279
280        let mut go_wr_cnt = 0;
281        let mut done_wr_cnt = 0;
282
283        'assigns: for assign in assigns {
284            // @go port should have exactly one write and the src should be 1.
285            if assign.dst == go_port {
286                if go_wr_cnt > 0 {
287                    log::info!(
288                        "Cannot transform `{group_name}` due to multiple writes to @go port",
289                    );
290                    return;
291                } else if !assign.guard.is_true() {
292                    log::info!(
293                        "Cannot transform `{}` due to guarded write to @go port: {}",
294                        group_name,
295                        ir::Printer::assignment_to_str(assign)
296                    );
297                    return;
298                } else if assign.src.borrow().is_constant(1, 1) {
299                    go_wr_cnt += 1;
300                } else {
301                    // if go port's guard is not true, src is not (1,1), then
302                    // Continue
303                    continue 'assigns;
304                }
305            }
306            // @done port should have exactly one read and the dst should be
307            // group's done signal.
308            if assign.src == done_port {
309                if done_wr_cnt > 0 {
310                    log::info!(
311                        "Cannot transform `{group_name}` due to multiple writes to @done port",
312                    );
313                    return;
314                } else if !assign.guard.is_true() {
315                    log::info!(
316                        "Cannot transform `{}` due to guarded write to @done port: {}",
317                        group_name,
318                        ir::Printer::assignment_to_str(assign)
319                    );
320                    return;
321                } else if assign.dst == *group_done_port {
322                    done_wr_cnt += 1;
323                } else {
324                    // If done port's guard is not true and does not write to group's done
325                    // then Continue
326                    continue 'assigns;
327                }
328            }
329        }
330        drop(cell);
331
332        if go_wr_cnt != 1 {
333            log::info!(
334                "Cannot transform `{group_name}` because there are no writes to @go port"
335            );
336            return;
337        } else if done_wr_cnt != 1 {
338            log::info!(
339                "Cannot transform `{group_name}` because there are no writes to @done port"
340            );
341            return;
342        }
343
344        self.group_invoke_map
345            .insert(group_name, construct_invoke(assigns, cr, builder));
346    }
347}