// calyx_opt/passes/comb_prop.rs

1use crate::traversal::{
2    Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor,
3};
4use calyx_ir::{self as ir, RRC};
5use itertools::Itertools;
6use std::rc::Rc;
7
8/// A data structure to track rewrites of ports with added functionality to declare
9/// two wires to be "equal" when they are connected together.
10#[derive(Default, Clone)]
11struct WireRewriter {
12    rewrites: ir::rewriter::PortRewriteMap,
13}
14
15impl WireRewriter {
16    // If the destination is a wire, then we have something like:
17    // ```
18    // wire.in = c.out;
19    // ```
20    // Which means all instances of `wire.out` can be replaced with `c.out` because the
21    // wire is being used to forward values from `c.out`.
22    pub fn insert_src_rewrite(
23        &mut self,
24        wire: RRC<ir::Cell>,
25        src: RRC<ir::Port>,
26    ) {
27        let wire_out = wire.borrow().get("out");
28        log::debug!(
29            "src rewrite: {} -> {}",
30            wire_out.borrow().canonical(),
31            src.borrow().canonical(),
32        );
33        let old = self.insert(wire_out, Rc::clone(&src));
34        assert!(
35            old.is_none(),
36            "Attempting to add multiple sources to a wire"
37        );
38    }
39
40    // If the source is a wire, we have something like:
41    // ```
42    // c.in = wire.out;
43    // ```
44    // Which means all instances of `wire.in` can be replaced with `c.in` because the wire
45    // is being used to unconditionally forward values.
46    pub fn insert_dst_rewrite(
47        &mut self,
48        wire: RRC<ir::Cell>,
49        dst: RRC<ir::Port>,
50    ) {
51        let wire_in = wire.borrow().get("in");
52        log::debug!(
53            "dst rewrite: {} -> {}",
54            wire_in.borrow().canonical(),
55            dst.borrow().canonical(),
56        );
57        let old_v = self.insert(Rc::clone(&wire_in), dst);
58
59        // If the insertion process found an old key, we have something like:
60        // ```
61        // x.in = wire.out;
62        // y.in = wire.out;
63        // ```
64        // This means that `wire` is being used to forward values to many components and a
65        // simple inlining will not work.
66        if old_v.is_some() {
67            self.remove(wire_in);
68        }
69
70        // No forwading generated because the wire is used in dst position
71    }
72
73    /// Insert into rewrite map. If `v` is in current `rewrites`, then insert `k` -> `rewrites[v]`
74    /// and returns the previous rewrite if any.
75    fn insert(
76        &mut self,
77        from: RRC<ir::Port>,
78        to: RRC<ir::Port>,
79    ) -> Option<RRC<ir::Port>> {
80        let from_idx = from.borrow().canonical();
81        let old = self.rewrites.insert(from_idx, to);
82        if log::log_enabled!(log::Level::Debug) {
83            if let Some(ref old) = old {
84                log::debug!(
85                    "Previous rewrite: {} -> {}",
86                    from.borrow().canonical(),
87                    old.borrow().canonical()
88                );
89            }
90        }
91        old
92    }
93
94    // Removes the mapping associated with the key.
95    pub fn remove(&mut self, from: RRC<ir::Port>) {
96        log::debug!("Removing rewrite for `{}'", from.borrow().canonical());
97        let from_idx = from.borrow().canonical();
98        self.rewrites.remove(&from_idx);
99    }
100
101    /// Apply all the defined equalities to the current set of rewrites.
102    fn make_consistent(self) -> Self {
103        // Perform rewrites on the defined rewrites
104        let rewrites = self
105            .rewrites
106            .iter()
107            .map(|(from, to)| {
108                let to_idx = to.borrow().canonical();
109                let mut final_to = self.rewrites.get(&to_idx);
110                while let Some(new_to) = final_to {
111                    if let Some(new_new_to) =
112                        self.rewrites.get(&new_to.borrow().canonical())
113                    {
114                        final_to = Some(new_new_to);
115                    } else {
116                        break;
117                    }
118                }
119                (from.clone(), Rc::clone(final_to.unwrap_or(to)))
120            })
121            .collect();
122        Self { rewrites }
123    }
124}
125
126impl From<WireRewriter> for ir::rewriter::PortRewriteMap {
127    fn from(v: WireRewriter) -> Self {
128        v.make_consistent().rewrites
129    }
130}
131
132impl std::fmt::Debug for WireRewriter {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        for (ir::Canonical { cell, port }, port_ref) in &self.rewrites {
135            writeln!(
136                f,
137                "{}.{} -> {}",
138                cell.id,
139                port.id,
140                ir::Printer::port_to_str(&port_ref.borrow())
141            )?
142        }
143        Ok(())
144    }
145}
146
/// Propagate unconditional reads and writes from wires.
///
/// If the source is a wire, we have something like:
/// ```text
/// c.in = wire.out;
/// ```
/// Which means all instances of `wire.in` can be replaced with `c.in` because the wire
/// is being used to unconditionally forward values.
///
/// If the destination is a wire, then we have something like:
/// ```text
/// wire.in = c.out;
/// ```
/// Which means all instances of `wire.out` can be replaced with `c.out` because the
/// wire is being used to forward values from `c.out`.
///
/// For example, we can safely inline the value `c` wherever `w.out` is read.
/// ```text
/// w.in = c;
/// group g {
///   r.in = w.out;
/// }
/// ```
///
/// Gets rewritten to:
/// ```text
/// w.in = c;
/// group g {
///   r.in = c;
/// }
/// ```
///
/// Correctly propagates writes through multiple wires:
/// ```text
/// w1.in = c;
/// w2.in = w1.out;
/// r.in = w2.out;
/// ```
/// into:
/// ```text
/// w1.in = c;
/// w2.in = c;
/// r.in = c;
/// ```
pub struct CombProp {
    /// Do not delete the continuous assignments that this pass makes dead; mark them
    /// with `@dead` instead. Deleting them is normally needed for correctness, because
    /// they conflict with the rewritten assignments.
    /// NOTE: if this is enabled, the pass will not remove obviously conflicting assignments.
    no_eliminate: bool,
}
197
198impl ConstructVisitor for CombProp {
199    fn from(ctx: &ir::Context) -> calyx_utils::CalyxResult<Self>
200    where
201        Self: Sized,
202    {
203        let opts = Self::get_opts(ctx);
204        Ok(CombProp {
205            no_eliminate: opts[&"no-eliminate"].bool(),
206        })
207    }
208
209    fn clear_data(&mut self) {
210        /* do nothing */
211    }
212}
213
214impl Named for CombProp {
215    fn name() -> &'static str {
216        "comb-prop"
217    }
218
219    fn description() -> &'static str {
220        "propagate unconditional continuous assignments"
221    }
222
223    fn opts() -> Vec<PassOpt> {
224        vec![PassOpt::new(
225            "no-eliminate",
226            "mark dead assignments with @dead instead of removing them",
227            ParseVal::Bool(false),
228            PassOpt::parse_bool,
229        )]
230    }
231}
232
233impl CombProp {
234    /// Predicate for removing an assignment
235    #[inline]
236    fn remove_predicate<T>(
237        rewritten: &[RRC<ir::Port>],
238        assign: &ir::Assignment<T>,
239    ) -> bool
240    where
241        T: Clone + Eq + ToString,
242    {
243        let out = rewritten.iter().any(|v| Rc::ptr_eq(v, &assign.dst));
244        if log::log_enabled!(log::Level::Debug) && out {
245            log::debug!("Removing: {}", ir::Printer::assignment_to_str(assign));
246        }
247        out
248    }
249
250    /// Mark assignments for removal
251    fn remove_rewritten(
252        &self,
253        rewritten: &[RRC<ir::Port>],
254        comp: &mut ir::Component,
255    ) {
256        log::debug!(
257            "Rewritten: {}",
258            rewritten
259                .iter()
260                .map(|p| format!("{}", p.borrow().canonical()))
261                .collect::<Vec<_>>()
262                .join(", ")
263        );
264        // Remove writes to all the ports that show up in write position
265        if self.no_eliminate {
266            // If elimination is disabled, mark the assignments with the @dead attribute.
267            for assign in &mut comp.continuous_assignments {
268                if Self::remove_predicate(rewritten, assign) {
269                    assign.attributes.insert(ir::InternalAttr::DEAD, 1)
270                }
271            }
272        } else {
273            comp.continuous_assignments.retain_mut(|assign| {
274                !Self::remove_predicate(rewritten, assign)
275            });
276        }
277    }
278
279    fn parent_is_wire(parent: &ir::PortParent) -> bool {
280        match parent {
281            ir::PortParent::Cell(cell_wref) => {
282                let cr = cell_wref.upgrade();
283                let cell = cr.borrow();
284                cell.is_primitive(Some("std_wire"))
285            }
286            ir::PortParent::Group(_) => false,
287            ir::PortParent::StaticGroup(_) => false,
288            ir::PortParent::FSM(_) => false,
289        }
290    }
291
292    fn disable_rewrite<T>(
293        assign: &mut ir::Assignment<T>,
294        rewrites: &mut WireRewriter,
295    ) {
296        if assign.guard.is_true() {
297            return;
298        }
299        assign.for_each_port(|pr| {
300            let p = pr.borrow();
301            if p.direction == ir::Direction::Output
302                && Self::parent_is_wire(&p.parent)
303            {
304                let cell = p.cell_parent();
305                rewrites.remove(cell.borrow().get("in"));
306            }
307            // Never change the port
308            None
309        });
310    }
311}
312
313impl Visitor for CombProp {
314    fn start(
315        &mut self,
316        comp: &mut ir::Component,
317        _sigs: &ir::LibrarySignatures,
318        _comps: &[ir::Component],
319    ) -> VisResult {
320        let mut rewrites = WireRewriter::default();
321
322        for assign in &mut comp.continuous_assignments {
323            // Cannot add rewrites for conditional statements
324            if !assign.guard.is_true() {
325                continue;
326            }
327
328            let dst = assign.dst.borrow();
329            if Self::parent_is_wire(&dst.parent) {
330                rewrites.insert_src_rewrite(
331                    dst.cell_parent(),
332                    Rc::clone(&assign.src),
333                );
334            }
335
336            let src = assign.src.borrow();
337            if Self::parent_is_wire(&src.parent) {
338                rewrites.insert_dst_rewrite(
339                    src.cell_parent(),
340                    Rc::clone(&assign.dst),
341                );
342            }
343        }
344
345        // Disable all rewrites:
346        // If the statement uses a wire output (w.out) as a source, we
347        // cannot rewrite the wire's input (w.in) uses
348        comp.for_each_assignment(|assign| {
349            Self::disable_rewrite(assign, &mut rewrites)
350        });
351        comp.for_each_static_assignment(|assign| {
352            Self::disable_rewrite(assign, &mut rewrites)
353        });
354
355        // Rewrite assignments
356        // Make the set of rewrites consistent and transform into map
357        let rewrites: ir::rewriter::PortRewriteMap = rewrites.into();
358        let rewritten = rewrites.values().cloned().collect_vec();
359        self.remove_rewritten(&rewritten, comp);
360
361        comp.for_each_assignment(|assign| {
362            if !assign.attributes.has(ir::InternalAttr::DEAD) {
363                assign.for_each_port(|port| {
364                    rewrites.get(&port.borrow().canonical()).cloned()
365                })
366            }
367        });
368        comp.for_each_static_assignment(|assign| {
369            if !assign.attributes.has(ir::InternalAttr::DEAD) {
370                assign.for_each_port(|port| {
371                    rewrites.get(&port.borrow().canonical()).cloned()
372                })
373            }
374        });
375
376        let rewriter = ir::Rewriter {
377            port_map: rewrites,
378            ..Default::default()
379        };
380        rewriter.rewrite_control(&mut comp.control.borrow_mut());
381
382        Ok(Action::Stop)
383    }
384}