calyx_opt/passes/
group_to_invoke.rs

1use crate::analysis::AssignmentAnalysis;
2use crate::traversal::{Action, ConstructVisitor, Named, VisResult, Visitor};
3use calyx_ir::{self as ir};
4use calyx_ir::{GetAttributes, RRC};
5use calyx_utils::CalyxResult;
6use ir::Nothing;
7use itertools::Itertools;
8use std::collections::{HashMap, HashSet};
9use std::rc::Rc;
10
11/// Transform groups that are structurally invoking components into equivalent
12/// [ir::Invoke] statements.
13///
14/// For a group to meet the requirements of this pass, it must
15/// 1. Only write to one non-combinational component (all other writes must be
16///    to combinational primitives)
17/// 2. That component is *not* a ref cell, nor does it have the external attribute,
18///    nor is it This Component
19/// 3. Assign component.go = 1'd1
20/// 4. Assign group[done] = component.done
21pub struct GroupToInvoke {
22    /// Primitives that have multiple @go-@done signals
23    blacklist: HashSet<ir::Id>,
24    /// Maps names of group to the invokes that will replace them
25    group_invoke_map: HashMap<ir::Id, ir::Control>,
26}
27
28impl ConstructVisitor for GroupToInvoke {
29    fn from(ctx: &ir::Context) -> CalyxResult<Self>
30    where
31        Self: Sized,
32    {
33        // Construct list of primitives that have multiple go-done signals
34        let blacklist = ctx
35            .lib
36            .signatures()
37            .filter(|p| p.find_all_with_attr(ir::NumAttr::Go).count() > 1)
38            .map(|p| p.name)
39            .collect();
40
41        Ok(Self {
42            blacklist,
43            group_invoke_map: HashMap::new(),
44        })
45    }
46
47    fn clear_data(&mut self) {
48        self.group_invoke_map = HashMap::new();
49    }
50}
51
52impl Named for GroupToInvoke {
53    fn name() -> &'static str {
54        "group2invoke"
55    }
56
57    fn description() -> &'static str {
58        "covert groups that structurally invoke one component into invoke statements"
59    }
60}
61
62/// Construct an [ir::Invoke] from an [ir::Group] that has been validated by this pass.
63fn construct_invoke(
64    assigns: &[ir::Assignment<Nothing>],
65    comp: RRC<ir::Cell>,
66    builder: &mut ir::Builder,
67) -> ir::Control {
68    // Check if port's parent is a combinational primitive
69    let parent_is_comb = |port: &ir::Port| -> bool {
70        if !port.is_hole() {
71            if let ir::CellType::Primitive { is_comb, .. } =
72                port.cell_parent().borrow().prototype
73            {
74                return is_comb;
75            }
76        }
77        false
78    };
79
80    // Check if port's parent is equal to comp
81    let parent_is_cell = |port: &ir::Port| -> bool {
82        match &port.parent {
83            ir::PortParent::Cell(cell_wref) => {
84                Rc::ptr_eq(&cell_wref.upgrade(), &comp)
85            }
86            _ => false,
87        }
88    };
89
90    let mut inputs = Vec::new();
91    let mut comb_assigns = Vec::new();
92    let mut wire_map: HashMap<ir::Id, ir::RRC<ir::Port>> = HashMap::new();
93
94    for assign in assigns {
95        // We know that all assignments in this group should write to either a)
96        // a combinational component or b) comp or c) the group's done port-- we
97        // should have checked for this condition before calling this function
98
99        // If a combinational component's port is being used as a dest, add
100        // it to comb_assigns
101        if parent_is_comb(&assign.dst.borrow()) {
102            comb_assigns.push(assign.clone());
103        }
104        // If the cell's port is being used as a dest, add the source to
105        // inputs. we can ignore the cell.go assignment, since that is not
106        // going to be part of the `invoke`.
107        else if parent_is_cell(&assign.dst.borrow())
108            && assign.dst
109                != comp.borrow().get_unique_with_attr(ir::NumAttr::Go).unwrap()
110        {
111            let name = assign.dst.borrow().name;
112            if assign.guard.is_true() {
113                inputs.push((name, Rc::clone(&assign.src)));
114            } else {
115                // assign has a guard condition,so need a wire
116                // We first check whether we have already built a wire
117                // for this port or not.
118                let wire_in = match wire_map.get(&assign.dst.borrow().name) {
119                    Some(w) => {
120                        // Already built a wire, so just need to return the
121                        // wire's input port (which wire_map stores)
122                        Rc::clone(w)
123                    }
124                    None => {
125                        // Need to create a new wire
126                        let width = assign.dst.borrow().width;
127                        let wire = builder.add_primitive(
128                            format!("{name}_guarded_wire"),
129                            "std_wire",
130                            &[width],
131                        );
132                        // Insert the wire's input port into wire_map
133                        let wire_in = wire.borrow().get("in");
134                        wire_map.insert(
135                            assign.dst.borrow().name,
136                            Rc::clone(&wire_in),
137                        );
138                        // add the wire's output port to the inputs of the
139                        // invoke statement we are building
140                        inputs.push((name, wire.borrow().get("out")));
141                        // return wire_in
142                        wire_in
143                    }
144                };
145                // Use wire_in to add another assignment to combinational group
146                let asmt = builder.build_assignment(
147                    wire_in,
148                    Rc::clone(&assign.src),
149                    *assign.guard.clone(),
150                );
151                comb_assigns.push(asmt);
152            }
153        }
154    }
155
156    let comb_group = if comb_assigns.is_empty() {
157        None
158    } else {
159        let comb_group_ref = builder.add_comb_group("comb_invoke");
160        comb_group_ref
161            .borrow_mut()
162            .assignments
163            .append(&mut comb_assigns);
164        Some(comb_group_ref)
165    };
166
167    ir::Control::Invoke(ir::Invoke {
168        comp,
169        inputs,
170        outputs: Vec::new(),
171        comb_group,
172        attributes: ir::Attributes::default(),
173        ref_cells: Vec::new(),
174    })
175}
176
177impl Visitor for GroupToInvoke {
178    fn start(
179        &mut self,
180        comp: &mut ir::Component,
181        sigs: &ir::LibrarySignatures,
182        _comps: &[ir::Component],
183    ) -> VisResult {
184        let groups = comp.get_groups_mut().drain().collect_vec();
185        let static_groups = comp.get_static_groups_mut().drain().collect_vec();
186        let mut builder = ir::Builder::new(comp, sigs);
187        for g in &groups {
188            self.analyze_group(
189                &mut builder,
190                g.borrow().name(),
191                &g.borrow().assignments,
192                &g.borrow().get("done"),
193            )
194        }
195        // Not transforming static groups rn
196        /*for g in &static_groups {
197            self.analyze_group(
198                &mut builder,
199                g.borrow().name(),
200                &g.borrow().assignments,
201                &g.borrow().get(ir::NumAttr::Done),
202            )
203        }*/
204
205        comp.get_groups_mut().append(groups.into_iter());
206        comp.get_static_groups_mut()
207            .append(static_groups.into_iter());
208
209        Ok(Action::Continue)
210    }
211
212    fn enable(
213        &mut self,
214        s: &mut ir::Enable,
215        _comp: &mut ir::Component,
216        _sigs: &ir::LibrarySignatures,
217        _comps: &[ir::Component],
218    ) -> VisResult {
219        match self.group_invoke_map.get(&s.group.borrow().name()) {
220            None => Ok(Action::Continue),
221            Some(invoke) => {
222                let mut inv = ir::Cloner::control(invoke);
223                let attrs = std::mem::take(&mut s.attributes);
224                *inv.get_mut_attributes() = attrs;
225                Ok(Action::Change(Box::new(inv)))
226            }
227        }
228    }
229}
230
231impl GroupToInvoke {
232    // if g is able to be turned into invoke, then add to self.group_invoke_map
233    fn analyze_group(
234        &mut self,
235        builder: &mut ir::Builder,
236        group_name: ir::Id,
237        assigns: &[ir::Assignment<Nothing>],
238        group_done_port: &ir::RRC<ir::Port>,
239    ) {
240        let mut writes = assigns
241            .iter()
242            .analysis()
243            .cell_writes()
244            .filter(|cell| match cell.borrow().prototype {
245                ir::CellType::Primitive { is_comb, .. } => !is_comb,
246                _ => true,
247            })
248            .collect_vec();
249        // Excluding writes to combinational components, should write to exactly
250        // one cell
251        if writes.len() != 1 {
252            return;
253        }
254
255        // If component is ThisComponent, Reference, or External, don't turn into invoke
256        let cr = writes.pop().unwrap();
257        let cell = cr.borrow();
258        match &cell.prototype {
259            ir::CellType::Primitive { name, .. }
260                if self.blacklist.contains(name) =>
261            {
262                return;
263            }
264            ir::CellType::ThisComponent => return,
265            _ => {}
266        }
267        if cell.is_reference() || cell.attributes.has(ir::BoolAttr::External) {
268            return;
269        }
270
271        // Component must define exactly one @go/@done interface
272        let Ok(Some(go_port)) = cell.find_unique_with_attr(ir::NumAttr::Go)
273        else {
274            return;
275        };
276        let Ok(Some(done_port)) = cell.find_unique_with_attr(ir::NumAttr::Done)
277        else {
278            return;
279        };
280
281        let mut go_wr_cnt = 0;
282        let mut done_wr_cnt = 0;
283
284        'assigns: for assign in assigns {
285            // @go port should have exactly one write and the src should be 1.
286            if assign.dst == go_port {
287                if go_wr_cnt > 0 {
288                    log::info!(
289                        "Cannot transform `{group_name}` due to multiple writes to @go port",
290                    );
291                    return;
292                } else if !assign.guard.is_true() {
293                    log::info!(
294                        "Cannot transform `{}` due to guarded write to @go port: {}",
295                        group_name,
296                        ir::Printer::assignment_to_str(assign)
297                    );
298                    return;
299                } else if assign.src.borrow().is_constant(1, 1) {
300                    go_wr_cnt += 1;
301                } else {
302                    // if go port's guard is not true, src is not (1,1), then
303                    // Continue
304                    continue 'assigns;
305                }
306            }
307            // @done port should have exactly one read and the dst should be
308            // group's done signal.
309            if assign.src == done_port {
310                if done_wr_cnt > 0 {
311                    log::info!(
312                        "Cannot transform `{group_name}` due to multiple writes to @done port",
313                    );
314                    return;
315                } else if !assign.guard.is_true() {
316                    log::info!(
317                        "Cannot transform `{}` due to guarded write to @done port: {}",
318                        group_name,
319                        ir::Printer::assignment_to_str(assign)
320                    );
321                    return;
322                } else if assign.dst == *group_done_port {
323                    done_wr_cnt += 1;
324                } else {
325                    // If done port's guard is not true and does not write to group's done
326                    // then Continue
327                    continue 'assigns;
328                }
329            }
330        }
331        drop(cell);
332
333        if go_wr_cnt != 1 {
334            log::info!(
335                "Cannot transform `{group_name}` because there are no writes to @go port"
336            );
337            return;
338        } else if done_wr_cnt != 1 {
339            log::info!(
340                "Cannot transform `{group_name}` because there are no writes to @done port"
341            );
342            return;
343        }
344
345        self.group_invoke_map
346            .insert(group_name, construct_invoke(assigns, cr, builder));
347    }
348}