calyx_opt/passes/
simplify_with_control.rs

1use crate::analysis;
2use crate::traversal::{Action, Named, VisResult, Visitor};
3use calyx_frontend::SetAttr;
4use calyx_ir::{self as ir, GetAttributes, LibrarySignatures, RRC, structure};
5use calyx_utils::{CalyxResult, Error};
6use std::collections::HashMap;
7use std::rc::Rc;
8
9#[derive(Default)]
10/// Transforms combinational groups into normal groups by registering the values
11/// read from the ports of cells used within the combinational group.
12///
13/// It also transforms (if,while)-with into semantically equivalent control programs that
14/// first enable a group that calculates and registers the ports defined by the combinational group
15/// execute the respective cond group and then execute the control operator.
16///
17/// # Example
18/// ```
19/// group comb_cond<"static"=0> {
20///     lt.right = 32'd10;
21///     lt.left = 32'd1;
22///     eq.right = r.out;
23///     eq.left = x.out;
24///     comb_cond[done] = 1'd1;
25/// }
26/// control {
27///     if lt.out with comb_cond {
28///         ...
29///     }
30///     while eq.out with comb_cond {
31///         ...
32///     }
33/// }
34/// ```
35/// into:
36/// ```
37/// group comb_cond<"static"=1> {
38///     lt.right = 32'd10;
39///     lt.left = 32'd1;
40///     eq.right = r.out;
41///     eq.left = x.out;
42///     lt_reg.in = lt.out
43///     lt_reg.write_en = 1'd1;
44///     eq_reg.in = eq.out;
45///     eq_reg.write_en = 1'd1;
46///     comb_cond[done] = lt_reg.done & eq_reg.done ? 1'd1;
47/// }
48/// control {
49///     seq {
50///       comb_cond;
51///       if lt_reg.out {
52///           ...
53///       }
54///     }
55///     seq {
56///       comb_cond;
57///       while eq_reg.out {
58///           ...
59///           comb_cond;
60///       }
61///     }
62/// }
63/// ```
64pub struct SimplifyWithControl {
65    // Mapping from (group_name, (cell_name, port_name)) -> (port, static_group).
66    port_rewrite: HashMap<PortInGroup, (RRC<ir::Port>, RRC<ir::StaticGroup>)>,
67}
68
69/// Represents (group_name, (cell_name, port_name))
70type PortInGroup = (ir::Id, ir::Canonical);
71
72impl Named for SimplifyWithControl {
73    fn name() -> &'static str {
74        "simplify-with-control"
75    }
76
77    fn description() -> &'static str {
78        "Transforms if-with and while-with to if and while"
79    }
80}
81
82impl Visitor for SimplifyWithControl {
83    fn start(
84        &mut self,
85        comp: &mut ir::Component,
86        sigs: &LibrarySignatures,
87        _comps: &[ir::Component],
88    ) -> VisResult {
89        let mut used_ports =
90            analysis::ControlPorts::<false>::from(&*comp.control.borrow());
91
92        // Early return if there are no combinational groups
93        if comp.comb_groups.is_empty() {
94            return Ok(Action::Stop);
95        }
96
97        // Detach the combinational groups from the component
98        let comb_groups = std::mem::take(&mut comp.comb_groups);
99        let mut builder = ir::Builder::new(comp, sigs);
100
101        // Groups generated by transforming combinational groups
102        let groups = comb_groups
103            .iter()
104            .map(|cg_ref| {
105                let name = cg_ref.borrow().name();
106                // Register the ports read by the combinational group's usages.
107                let used_ports = used_ports.remove(&name).ok_or_else(|| {
108                    Error::malformed_structure(format!(
109                        "values from combinational group `{name}` never used"
110                    ))
111                })?;
112
113                // Group generated to replace this comb group.
114                let group_ref = builder.add_static_group(name, 1);
115                let mut group = group_ref.borrow_mut();
116                // Attach assignmens from comb group
117                group.assignments = cg_ref
118                    .borrow_mut()
119                    .assignments
120                    .clone()
121                    .into_iter()
122                    .map(|x| x.into())
123                    .collect();
124
125                // Registers to save value for the group
126                let mut save_regs = Vec::with_capacity(used_ports.len());
127                for port in used_ports {
128                    // Register to save port value
129                    structure!(builder;
130                        let comb_reg = prim std_reg(port.borrow().width);
131                        let signal_on = constant(1, 1);
132                    );
133                    let write = builder.build_assignment(
134                        comb_reg.borrow().get("in"),
135                        Rc::clone(&port),
136                        ir::Guard::True,
137                    );
138                    let en = builder.build_assignment(
139                        comb_reg.borrow().get("write_en"),
140                        signal_on.borrow().get("out"),
141                        ir::Guard::True,
142                    );
143                    group.assignments.push(write);
144                    group.assignments.push(en);
145
146                    // Define mapping from this port to the register's output
147                    // value.
148                    self.port_rewrite.insert(
149                        (name, port.borrow().canonical().clone()),
150                        (
151                            Rc::clone(&comb_reg.borrow().get("out")),
152                            Rc::clone(&group_ref),
153                        ),
154                    );
155
156                    save_regs.push(comb_reg);
157                }
158
159                // No need for a done condition
160                drop(group);
161
162                Ok(group_ref)
163            })
164            .collect::<CalyxResult<Vec<_>>>()?;
165
166        for group in groups {
167            comp.get_static_groups_mut().add(group)
168        }
169
170        // Restore the combinational groups
171        comp.comb_groups = comb_groups;
172
173        Ok(Action::Continue)
174    }
175
176    fn finish_while(
177        &mut self,
178        s: &mut ir::While,
179        _comp: &mut ir::Component,
180        _sigs: &LibrarySignatures,
181        _comps: &[ir::Component],
182    ) -> VisResult {
183        if s.cond.is_none() {
184            return Ok(Action::Continue);
185        }
186
187        // Construct a new `while` statement
188        let key = (
189            s.cond.as_ref().unwrap().borrow().name(),
190            s.port.borrow().canonical(),
191        );
192        let (port_ref, cond_ref) = self.port_rewrite.get(&key).unwrap();
193        let mut cond_in_body = ir::Control::static_enable(Rc::clone(cond_ref));
194        cond_in_body
195            .get_mut_attributes()
196            .copy_from_set(&s.attributes, vec![SetAttr::Pos]);
197        let body = std::mem::replace(s.body.as_mut(), ir::Control::empty());
198        let mut new_body = ir::Control::seq(vec![body, cond_in_body]);
199        new_body
200            .get_mut_attributes()
201            .copy_from_set(&s.attributes, vec![SetAttr::Pos]);
202        let mut while_ =
203            ir::Control::while_(Rc::clone(port_ref), None, Box::new(new_body));
204        let attrs = while_.get_mut_attributes();
205        *attrs = s.attributes.clone();
206        let mut cond_before_body =
207            ir::Control::static_enable(Rc::clone(cond_ref));
208        cond_before_body
209            .get_mut_attributes()
210            .copy_from_set(&s.attributes, vec![SetAttr::Pos]);
211
212        let mut new_seq = ir::Control::seq(vec![cond_before_body, while_]);
213        new_seq
214            .get_mut_attributes()
215            .copy_from_set(&s.attributes, vec![SetAttr::Pos]);
216
217        Ok(Action::change(new_seq))
218    }
219
220    /// Transforms a `if-with` into a `seq-if` which first runs the cond group
221    /// and then the branch.
222    fn finish_if(
223        &mut self,
224        s: &mut ir::If,
225        _comp: &mut ir::Component,
226        _sigs: &LibrarySignatures,
227        _comps: &[ir::Component],
228    ) -> VisResult {
229        if s.cond.is_none() {
230            return Ok(Action::Continue);
231        }
232        // Construct a new `if` statement
233        let key = (
234            s.cond.as_ref().unwrap().borrow().name(),
235            s.port.borrow().canonical(),
236        );
237        let (port_ref, cond_ref) =
238            self.port_rewrite.get(&key).unwrap_or_else(|| {
239                panic!(
240                    "{}: Port `{}` in group `{}` doesn't have a rewrite",
241                    Self::name(),
242                    key.1,
243                    key.0
244                )
245            });
246        let tbranch =
247            std::mem::replace(s.tbranch.as_mut(), ir::Control::empty());
248        let fbranch =
249            std::mem::replace(s.fbranch.as_mut(), ir::Control::empty());
250        let mut if_ = ir::Control::if_(
251            Rc::clone(port_ref),
252            None,
253            Box::new(tbranch),
254            Box::new(fbranch),
255        );
256        let attrs = if_.get_mut_attributes();
257        *attrs = s.attributes.clone();
258
259        let mut cond = ir::Control::static_enable(Rc::clone(cond_ref));
260        cond.get_mut_attributes()
261            .copy_from_set(&s.attributes, vec![SetAttr::Pos]);
262
263        let mut new_seq = ir::Control::seq(vec![cond, if_]);
264        new_seq
265            .get_mut_attributes()
266            .copy_from_set(&s.attributes, vec![SetAttr::Pos]);
267
268        Ok(Action::change(new_seq))
269    }
270
271    fn finish(
272        &mut self,
273        comp: &mut ir::Component,
274        _sigs: &LibrarySignatures,
275        _comps: &[ir::Component],
276    ) -> VisResult {
277        if comp.is_static() {
278            let msg = format!(
279                "Static Component {} has combinational groups which is not supported",
280                comp.name
281            );
282            return Err(Error::pass_assumption(Self::name(), msg)
283                .with_pos(&comp.attributes));
284        }
285        Ok(Action::Continue)
286    }
287}