// calyx_opt/passes/compile_invoke.rs

1use crate::traversal::{
2    self, Action, ConstructVisitor, Named, VisResult, Visitor,
3};
4use calyx_ir::structure;
5use calyx_ir::{self as ir, Attributes, LibrarySignatures};
6use calyx_utils::{CalyxResult, Error};
7use ir::{Assignment, RRC, WRC};
8use itertools::Itertools;
9use linked_hash_map::LinkedHashMap;
10use std::collections::{HashMap, HashSet};
11use std::rc::Rc;
12
13use super::dump_ports;
14
15// given `cell_ref` returns the `go` port of the cell (if it only has one `go` port),
16// or an error otherwise
17fn get_go_port(cell_ref: ir::RRC<ir::Cell>) -> CalyxResult<ir::RRC<ir::Port>> {
18    let cell = cell_ref.borrow();
19    let name = cell.name();
20
21    // Get the go port
22    match cell.find_unique_with_attr(ir::NumAttr::Go) {
23        Ok(Some(p)) => Ok(p),
24        Ok(None) => Err(Error::malformed_control(format!(
25            "Invoked component `{name}` does not define a @go signal. Cannot compile the invoke",
26        ))),
27        Err(_) => Err(Error::malformed_control(format!(
28            "Invoked component `{name}` defines multiple @go signals. Cannot compile the invoke",
29        ))),
30    }
31}
32
33// given inputs and outputs (of the invoke), and the `enable_assignments` (e.g., invoked_component.go = 1'd1)
34// and a cell, builds the assignments for the corresponding group
35fn build_assignments<T>(
36    inputs: &mut Vec<(ir::Id, ir::RRC<ir::Port>)>,
37    outputs: &mut Vec<(ir::Id, ir::RRC<ir::Port>)>,
38    builder: &mut ir::Builder,
39    cell: &ir::Cell,
40) -> Vec<ir::Assignment<T>> {
41    inputs
42        .drain(..)
43        .map(|(inp, p)| {
44            builder.build_assignment(cell.get(inp), p, ir::Guard::True)
45        })
46        .chain(outputs.drain(..).map(|(out, p)| {
47            builder.build_assignment(p, cell.get(out), ir::Guard::True)
48        }))
49        .collect()
50}
51
52#[derive(Default)]
53/// Map for storing added ports for each ref cell
54/// level of Hashmap represents:
55/// HashMap<-component name-, Hashmap<(-ref cell name-,-port name-), port>>;
56struct RefPortMap(HashMap<ir::Id, LinkedHashMap<ir::Canonical, RRC<ir::Port>>>);
57
58impl RefPortMap {
59    fn insert(
60        &mut self,
61        comp_name: ir::Id,
62        ports: LinkedHashMap<ir::Canonical, RRC<ir::Port>>,
63    ) {
64        self.0.insert(comp_name, ports);
65    }
66
67    fn get(
68        &self,
69        comp_name: &ir::Id,
70    ) -> Option<&LinkedHashMap<ir::Canonical, RRC<ir::Port>>> {
71        self.0.get(comp_name)
72    }
73
74    /// Get all of the newly added ports associated with a component that had
75    /// ref cells
76    fn get_ports(&self, comp_name: &ir::Id) -> Option<Vec<RRC<ir::Port>>> {
77        self.0.get(comp_name).map(|map| {
78            map.values()
79                .cloned()
80                .sorted_by(|a, b| a.borrow().name.cmp(&b.borrow().name))
81                .collect()
82        })
83    }
84}
85
86/// Compiles [`ir::Invoke`] statements into an [`ir::Enable`] that runs the
87/// invoked component.
88pub struct CompileInvoke {
89    /// Mapping from component to the canonical port name of ref cell o
90    port_names: RefPortMap,
91    /// Mapping from the ports of cells that were removed to the new port on the
92    /// component signature.
93    removed: LinkedHashMap<ir::Canonical, ir::RRC<ir::Port>>,
94    /// Ref cells in the component. We hold onto these so that our references don't get invalidated
95    ref_cells: Vec<ir::RRC<ir::Cell>>,
96}
97
98impl ConstructVisitor for CompileInvoke {
99    fn from(_ctx: &ir::Context) -> CalyxResult<Self>
100    where
101        Self: Sized,
102    {
103        Ok(CompileInvoke {
104            port_names: RefPortMap::default(),
105            removed: LinkedHashMap::new(),
106            ref_cells: Vec::new(),
107        })
108    }
109
110    fn clear_data(&mut self) {
111        self.removed.clear();
112        self.ref_cells.clear()
113    }
114}
115
116impl Named for CompileInvoke {
117    fn name() -> &'static str {
118        "compile-invoke"
119    }
120
121    fn description() -> &'static str {
122        "Rewrites invoke statements to group enables"
123    }
124}
125
126impl CompileInvoke {
127    /// Given `ref_cells` of an invoke, returns `(inputs, outputs)` where
128    /// inputs are the corresponding inputs to the `invoke` and
129    /// outputs are the corresponding outputs to the `invoke` that are used
130    /// in the component the ref_cell is.
131    /// (i.e. If a component only reads from a register,
132    /// only assignments for `reg.out` will be returned.)
133    ///
134    /// Since this pass eliminates all ref cells in post order, we expect that
135    /// invoked component already had all of its ref cells removed.
136    fn ref_cells_to_ports_assignments<T>(
137        &mut self,
138        inv_cell: RRC<ir::Cell>,
139        ref_cells: impl Iterator<Item = (ir::Id, ir::RRC<ir::Cell>)>,
140        invoked_comp: Option<&ir::Component>, //i.e. in invoke reader[]()(); this is `reader`
141    ) -> Vec<ir::Assignment<T>> {
142        let inv_comp_id = inv_cell.borrow().type_name().unwrap();
143        let mut assigns = Vec::new();
144        for (ref_cell_name, concrete_cell) in ref_cells {
145            log::debug!(
146                "Removing ref cell `{}` with {} ports",
147                ref_cell_name,
148                concrete_cell.borrow().ports.len()
149            );
150
151            // comp_ports is mapping from canonical names of the ports of the ref cell to the
152            // new port defined on the signature of the higher level component.
153            // i.e. ref_reg.in -> ref_reg_in
154            // These have name of ref cell, not the cell passed in as an arugment
155            let Some(comp_ports) = self.port_names.get(&inv_comp_id) else {
156                unreachable!(
157                    "component `{}` invoked but not already visited by the pass",
158                    inv_comp_id
159                )
160            };
161
162            // tracks ports used in assigments of the invoked component
163            let mut used_ports: HashSet<ir::Id> = HashSet::new();
164            if let Some(invoked_comp) = invoked_comp {
165                invoked_comp.iter_assignments(|a| {
166                    for port in a.iter_ports() {
167                        used_ports.insert(port.borrow().name);
168                    }
169                });
170                invoked_comp.iter_static_assignments(|a| {
171                    for port in a.iter_ports() {
172                        used_ports.insert(port.borrow().name);
173                    }
174                });
175            // If the `invoked_comp` passed to the function is `None`,
176            // then the component being invoked is a primitive.
177            } else {
178                unreachable!(
179                    "Primitives should not have ref cells passed into them at invocation. However ref cells were found at the invocation of {}.",
180                    inv_comp_id
181                );
182            }
183
184            //contains the newly added ports that result from ref cells removal/dump_ports
185            let new_comp_ports = comp_ports
186                .values()
187                .map(|p| p.borrow().name)
188                .collect::<HashSet<_>>();
189
190            let to_assign: HashSet<&ir::Id> =
191                new_comp_ports.intersection(&used_ports).collect();
192
193            // We expect each canonical port in `comp_ports` to exactly match with a port in
194            //`concrete_cell` based on well-formedness subtype checks.
195            // `canon` is `ref_reg.in`, for example.
196            for (ref_cell_canon, new_sig_port) in comp_ports.iter() {
197                //only interested in ports attached to the ref cell
198                if ref_cell_canon.cell != ref_cell_name {
199                    continue;
200                }
201
202                // For example, if we have a reader component that only reads from a ref_reg,
203                // we will not have `ref_reg.in = ...` in the invoke* group because the
204                // reader component does not access `ref_reg.in`.
205                if !to_assign.contains(&new_sig_port.borrow().name) {
206                    continue;
207                }
208
209                // The given port of the actual, concrete cell passed in
210                let concrete_port = Self::get_concrete_port(
211                    concrete_cell.clone(),
212                    &ref_cell_canon.port,
213                );
214
215                if concrete_port.borrow().has_attribute(ir::BoolAttr::Clk)
216                    || concrete_port.borrow().has_attribute(ir::BoolAttr::Reset)
217                {
218                    continue;
219                }
220
221                let Some(comp_port) = comp_ports.get(ref_cell_canon) else {
222                    unreachable!(
223                        "port `{}` not found in the signature of {}. Known ports are: {}",
224                        ref_cell_canon,
225                        inv_comp_id,
226                        comp_ports
227                            .keys()
228                            .map(|c| c.port.as_ref())
229                            .collect_vec()
230                            .join(", ")
231                    )
232                };
233                // Get the port on the new cell with the same name as ref_port
234                let ref_port = inv_cell.borrow().get(comp_port.borrow().name);
235                log::debug!(
236                    "Port `{}` -> `{}`",
237                    ref_cell_canon,
238                    ref_port.borrow().name
239                );
240
241                let old_port = concrete_port.borrow().canonical();
242                // If the port has been removed already, get the new port from the component's signature
243                let arg_port = if let Some(sig_pr) = self.removed.get(&old_port)
244                {
245                    log::debug!(
246                        "Port `{}` has been removed. Using `{}`",
247                        old_port,
248                        sig_pr.borrow().name
249                    );
250                    Rc::clone(sig_pr)
251                } else {
252                    Rc::clone(&concrete_port)
253                };
254
255                //Create assignments from dst to src
256                let dst: RRC<ir::Port>;
257                let src: RRC<ir::Port>;
258                match concrete_port.borrow().direction {
259                    ir::Direction::Output => {
260                        dst = ref_port.clone();
261                        src = arg_port;
262                    }
263                    ir::Direction::Input => {
264                        dst = arg_port;
265                        src = ref_port.clone();
266                    }
267                    _ => {
268                        unreachable!("Cell should have inout ports");
269                    }
270                };
271                log::debug!(
272                    "constructing: {} = {}",
273                    dst.borrow().canonical(),
274                    src.borrow().canonical(),
275                );
276                assigns.push(ir::Assignment::new(dst, src));
277            }
278        }
279        assigns
280    }
281
282    /// Takes in a concrete cell (aka an in_cell/what is passed in to a ref cell at invocation)
283    /// and returns the concrete port based on just the port of a canonical id.
284    fn get_concrete_port(
285        concrete_cell: RRC<ir::Cell>,
286        canonical_port: &ir::Id,
287    ) -> RRC<ir::Port> {
288        let concrete_cell = concrete_cell.borrow();
289        concrete_cell
290            .ports
291            .iter()
292            .find(|&concrete_cell_port| {
293                concrete_cell_port.borrow().name == canonical_port
294            })
295            .unwrap_or_else(|| {
296                unreachable!(
297                    "port `{}` not found in the cell `{}`",
298                    canonical_port,
299                    concrete_cell.name()
300                )
301            })
302            .clone()
303    }
304}
305
306impl Visitor for CompileInvoke {
307    fn iteration_order() -> crate::traversal::Order
308    where
309        Self: Sized,
310    {
311        traversal::Order::Post
312    }
313
314    fn start(
315        &mut self,
316        comp: &mut ir::Component,
317        _sigs: &LibrarySignatures,
318        _comps: &[ir::Component],
319    ) -> VisResult {
320        log::debug!("Visiting `{}`", comp.name);
321        // For all subcomponents that had a `ref` cell in them, we need to
322        // update their cell to have the new ports added from inlining the
323        // signatures of all the ref cells.
324        for cell in comp.cells.iter() {
325            let mut new_ports: Vec<RRC<ir::Port>> = Vec::new();
326            if let Some(name) = cell.borrow().type_name() {
327                if let Some(ports) = self.port_names.get_ports(&name) {
328                    log::debug!(
329                        "Updating ports of cell `{}' (type `{name}')",
330                        cell.borrow().name()
331                    );
332                    for p in ports.iter() {
333                        let new_port = ir::rrc(ir::Port {
334                            name: p.borrow().name,
335                            width: p.borrow().width,
336                            direction: p.borrow().direction.reverse(),
337                            parent: ir::PortParent::Cell(WRC::from(cell)),
338                            attributes: Attributes::default(),
339                        });
340                        new_ports.push(new_port);
341                    }
342                }
343            }
344            cell.borrow_mut().ports.extend(new_ports);
345        }
346
347        let dump_ports::DumpResults { cells, rewrites } =
348            dump_ports::dump_ports_to_signature(
349                comp,
350                |cell| cell.borrow().is_reference(),
351                true,
352            );
353
354        // Hold onto the cells so they don't get dropped.
355        self.ref_cells = cells;
356        self.removed = rewrites;
357
358        Ok(Action::Continue)
359    }
360
361    fn invoke(
362        &mut self,
363        s: &mut ir::Invoke,
364        comp: &mut ir::Component,
365        ctx: &LibrarySignatures,
366        comps: &[ir::Component],
367    ) -> VisResult {
368        let mut builder = ir::Builder::new(comp, ctx);
369        let invoke_group = builder.add_group("invoke");
370
371        //get iterator of comps of ref_cells used in the invoke
372        let invoked_comp: Option<&ir::Component> = comps
373            .iter()
374            .find(|&c| s.comp.borrow().prototype.get_name().unwrap() == c.name);
375
376        // Assigns representing the ref cell connections
377        invoke_group.borrow_mut().assignments.extend(
378            self.ref_cells_to_ports_assignments(
379                Rc::clone(&s.comp),
380                s.ref_cells.drain(..),
381                invoked_comp,
382            ),
383            //the clone here is questionable? but lets things type check? Maybe change ref_cells_to_ports to expect a reference?
384        );
385
386        // comp.go = 1'd1;
387        // invoke[done] = comp.done;
388        structure!(builder;
389            let one = constant(1, 1);
390        );
391        let cell = s.comp.borrow();
392        let go_port = get_go_port(Rc::clone(&s.comp))?;
393        let done_port = cell.find_unique_with_attr(ir::NumAttr::Done)?.unwrap();
394
395        // Build assignemnts
396        let go_assign = builder.build_assignment(
397            go_port,
398            one.borrow().get("out"),
399            ir::Guard::True,
400        );
401        let done_assign = builder.build_assignment(
402            invoke_group.borrow().get("done"),
403            done_port,
404            ir::Guard::True,
405        );
406
407        invoke_group
408            .borrow_mut()
409            .assignments
410            .extend(vec![go_assign, done_assign]);
411
412        // Generate argument assignments
413        let cell = &*s.comp.borrow();
414        let assigns = build_assignments(
415            &mut s.inputs,
416            &mut s.outputs,
417            &mut builder,
418            cell,
419        );
420        invoke_group.borrow_mut().assignments.extend(assigns);
421        // Add assignments from the attached combinational group
422        if let Some(cgr) = &s.comb_group {
423            let cg = &*cgr.borrow();
424            invoke_group
425                .borrow_mut()
426                .assignments
427                .extend(cg.assignments.iter().cloned())
428        }
429
430        // Copy "promotable" annotation from the `invoke` statement if present
431        if let Some(time) = s.attributes.get(ir::NumAttr::Promotable) {
432            invoke_group
433                .borrow_mut()
434                .attributes
435                .insert(ir::NumAttr::Promotable, time);
436        }
437
438        let mut en = ir::Enable {
439            group: invoke_group,
440            attributes: std::mem::take(&mut s.attributes),
441        };
442        if let Some(time) = s.attributes.get(ir::NumAttr::Promotable) {
443            en.attributes.insert(ir::NumAttr::Promotable, time);
444        }
445
446        Ok(Action::change(ir::Control::Enable(en)))
447    }
448
449    fn static_invoke(
450        &mut self,
451        s: &mut ir::StaticInvoke,
452        comp: &mut ir::Component,
453        ctx: &LibrarySignatures,
454        comps: &[ir::Component],
455    ) -> VisResult {
456        let mut builder = ir::Builder::new(comp, ctx);
457        let invoke_group = builder.add_static_group("static_invoke", s.latency);
458
459        //If the component is not a primitive, pass along the component to `ref_cells_to_ports``
460        let invoked_comp: Option<&ir::Component> = comps
461            .iter()
462            .find(|&c| s.comp.borrow().prototype.get_name().unwrap() == c.name);
463
464        invoke_group.borrow_mut().assignments.extend(
465            self.ref_cells_to_ports_assignments(
466                Rc::clone(&s.comp),
467                s.ref_cells.drain(..),
468                invoked_comp,
469            ),
470        );
471
472        // comp.go = 1'd1;
473        structure!(builder;
474            let one = constant(1, 1);
475        );
476
477        // Get the go port
478        let go_port = get_go_port(Rc::clone(&s.comp))?;
479
480        // Checks whether compe is a static<n> component or an @interval(n) component.
481        let go_guard = if s
482            .comp
483            .borrow()
484            .ports
485            .iter()
486            .any(|port| port.borrow().attributes.has(ir::NumAttr::Interval))
487        {
488            // For @interval(n) components, we do not guard the comp.go
489            // We trigger the go signal for the entire interval.
490            ir::Guard::True
491        } else {
492            // For static<n> components, we guard the comp.go with %[0:1]
493            ir::Guard::Info(ir::StaticTiming::new((0, 1)))
494        };
495
496        // Build assignemnts
497        let go_assign: ir::Assignment<ir::StaticTiming> = builder
498            .build_assignment(go_port, one.borrow().get("out"), go_guard);
499        invoke_group.borrow_mut().assignments.push(go_assign);
500
501        // Generate argument assignments
502        let cell = &*s.comp.borrow();
503        let assigns = build_assignments(
504            &mut s.inputs,
505            &mut s.outputs,
506            &mut builder,
507            cell,
508        );
509        invoke_group.borrow_mut().assignments.extend(assigns);
510
511        if let Some(cgr) = &s.comb_group {
512            let cg = &*cgr.borrow();
513            invoke_group.borrow_mut().assignments.extend(
514                cg.assignments
515                    .iter()
516                    .cloned()
517                    .map(Assignment::from)
518                    .collect_vec(),
519            );
520        }
521
522        let en = ir::StaticEnable {
523            group: invoke_group,
524            attributes: std::mem::take(&mut s.attributes),
525        };
526
527        Ok(Action::StaticChange(Box::new(ir::StaticControl::Enable(
528            en,
529        ))))
530    }
531
532    fn finish(
533        &mut self,
534        comp: &mut ir::Component,
535        _sigs: &LibrarySignatures,
536        _comps: &[ir::Component],
537    ) -> VisResult {
538        let port_map = std::mem::take(&mut self.removed);
539        // Add the newly added port to the global port map
540        // Rewrite all of the ref cell ports
541        let rw = ir::Rewriter {
542            port_map,
543            ..Default::default()
544        };
545        rw.rewrite(comp);
546        self.port_names.insert(comp.name, rw.port_map);
547        Ok(Action::Continue)
548    }
549}