calyx_opt/analysis/compute_static.rs

1use calyx_ir::{self as ir, GetAttributes};
2use std::collections::HashMap;
3use std::rc::Rc;
4
5/// Trait to propagate and extra "static" attributes through [ir::Control].
6/// Calling the update function ensures that the current program, as well as all
7/// sub-programs have a "static" attribute on them.
8/// Usage:
9/// ```
10/// use calyx::analysis::compute_static::WithStatic;
11/// let con: ir::Control = todo!(); // A complex control program
12/// con.update(&HashMap::new());    // Compute the static information for the program
13/// ```
14pub trait WithStatic
15where
16    Self: GetAttributes,
17{
18    /// Extra information needed to compute static information for this type.
19    type Info;
20
21    /// Compute the static information for the type if possible and add it to its attribute.
22    /// Implementors should instead implement [WithStatic::compute_static] and call this function
23    /// on sub-programs.
24    /// **Ensures**: All sub-programs of the type will also be updated.
25    fn update_static(&mut self, extra: &Self::Info) -> Option<u64> {
26        if let Some(time) = self.compute_static(extra) {
27            self.get_mut_attributes()
28                .insert(ir::NumAttr::Promotable, time);
29            Some(time)
30        } else {
31            None
32        }
33    }
34
35    /// Compute the static information for the type if possible and update all sub-programs.
36    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64>;
37}
38
39type CompTime = HashMap<ir::Id, u64>;
40
41impl WithStatic for ir::Control {
42    // Mapping from name of components to their latency information
43    type Info = CompTime;
44
45    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
46        match self {
47            ir::Control::Seq(seq) => seq.update_static(extra),
48            ir::Control::Par(par) => par.update_static(extra),
49            ir::Control::If(if_) => if_.update_static(extra),
50            ir::Control::While(wh) => wh.update_static(extra),
51            ir::Control::Repeat(rep) => rep.update_static(extra),
52            ir::Control::Invoke(inv) => inv.update_static(extra),
53            ir::Control::Enable(en) => en.update_static(&()),
54            ir::Control::Empty(_) => Some(0),
55            ir::Control::Static(sc) => Some(sc.get_latency()),
56            ir::Control::FSMEnable(_) => None,
57        }
58    }
59}
60
61impl WithStatic for ir::Enable {
62    type Info = ();
63    fn compute_static(&mut self, _: &Self::Info) -> Option<u64> {
64        // Attempt to get the latency from the attribute on the enable first, or
65        // failing that, from the group.
66        self.attributes.get(ir::NumAttr::Promotable).or_else(|| {
67            self.group.borrow().attributes.get(ir::NumAttr::Promotable)
68        })
69    }
70}
71
72impl WithStatic for ir::Invoke {
73    type Info = CompTime;
74    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
75        self.attributes.get(ir::NumAttr::Promotable).or_else(|| {
76            let comp = self.comp.borrow().type_name()?;
77            extra.get(&comp).cloned()
78        })
79    }
80}
81
82/// Walk over a set of control statements and call `update_static` on each of them.
83/// Use a merge function to merge the results of the `update_static` calls.
84fn walk_static<T, F>(stmts: &mut [T], extra: &T::Info, merge: F) -> Option<u64>
85where
86    T: WithStatic,
87    F: Fn(u64, u64) -> u64,
88{
89    let mut latency = Some(0);
90    // This is implemented as a loop because we want to call `update_static` on
91    // each statement even if we cannot compute a total latency anymore.
92    for stmt in stmts.iter_mut() {
93        let stmt_latency = stmt.update_static(extra);
94        latency = match (latency, stmt_latency) {
95            (Some(l), Some(s)) => Some(merge(l, s)),
96            (_, _) => None,
97        }
98    }
99    latency
100}
101
102impl WithStatic for ir::Seq {
103    type Info = CompTime;
104    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
105        walk_static(&mut self.stmts, extra, |x, y| x + y)
106    }
107}
108
109impl WithStatic for ir::Par {
110    type Info = CompTime;
111    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
112        walk_static(&mut self.stmts, extra, std::cmp::max)
113    }
114}
115
116impl WithStatic for ir::If {
117    type Info = CompTime;
118    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
119        // Cannot compute latency information for `if`-`with`
120        let t_latency = self.tbranch.update_static(extra);
121        let f_latency = self.fbranch.update_static(extra);
122        if self.cond.is_some() {
123            log::debug!("Cannot compute latency for while-with");
124            return None;
125        }
126        match (t_latency, f_latency) {
127            (Some(t), Some(f)) => Some(std::cmp::max(t, f)),
128            (_, _) => None,
129        }
130    }
131}
132
133impl WithStatic for ir::While {
134    type Info = CompTime;
135    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
136        let b_time = self.body.update_static(extra)?;
137        // Cannot compute latency information for `while`-`with`
138        if self.cond.is_some() {
139            log::debug!("Cannot compute latency for while-with");
140            return None;
141        }
142        let bound = self.attributes.get(ir::NumAttr::Bound)?;
143        Some(bound * b_time)
144    }
145}
146
147impl WithStatic for ir::Repeat {
148    type Info = CompTime;
149    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
150        let b_time = self.body.update_static(extra)?;
151        let num_repeats = self.num_repeats;
152        Some(num_repeats * b_time)
153    }
154}
155
/// Conversion of a dynamic control statement into its `static` counterpart.
pub trait IntoStatic {
    /// Attempt the conversion, returning `None` when it is not possible.
    ///
    /// Takes `&mut self` so implementations can move children out of `self` rather than
    /// cloning them.
    fn make_static(&mut self) -> Option<Self::StaticCon>;

    /// The `static` control type produced by [IntoStatic::make_static].
    type StaticCon;
}
160
161impl IntoStatic for ir::Seq {
162    type StaticCon = ir::StaticSeq;
163    fn make_static(&mut self) -> Option<Self::StaticCon> {
164        let mut static_stmts: Vec<ir::StaticControl> = Vec::new();
165        let mut latency = 0;
166        for stmt in self.stmts.iter() {
167            if !matches!(stmt, ir::Control::Static(_)) {
168                log::debug!(
169                    "Cannot build `static seq`. Control statement inside `seq` is not static"
170                );
171                return None;
172            }
173        }
174
175        for stmt in self.stmts.drain(..) {
176            let ir::Control::Static(sc) = stmt else {
177                unreachable!(
178                    "We have already checked that all control statements are static"
179                )
180            };
181            latency += sc.get_latency();
182            static_stmts.push(sc);
183        }
184        Some(ir::StaticSeq {
185            stmts: static_stmts,
186            attributes: self.attributes.clone(),
187            latency,
188        })
189    }
190}
191
192impl IntoStatic for ir::Par {
193    type StaticCon = ir::StaticPar;
194    fn make_static(&mut self) -> Option<Self::StaticCon> {
195        let mut static_stmts: Vec<ir::StaticControl> = Vec::new();
196        let mut latency = 0;
197        for stmt in self.stmts.iter() {
198            if !matches!(stmt, ir::Control::Static(_)) {
199                log::debug!(
200                    "Cannot build `static seq`. Control statement inside `seq` is not static"
201                );
202                return None;
203            }
204        }
205
206        for stmt in self.stmts.drain(..) {
207            let ir::Control::Static(sc) = stmt else {
208                unreachable!(
209                    "We have already checked that all control statements are static"
210                )
211            };
212            latency = std::cmp::max(latency, sc.get_latency());
213            static_stmts.push(sc);
214        }
215        Some(ir::StaticPar {
216            stmts: static_stmts,
217            attributes: self.attributes.clone(),
218            latency,
219        })
220    }
221}
222
223impl IntoStatic for ir::If {
224    type StaticCon = ir::StaticIf;
225    fn make_static(&mut self) -> Option<Self::StaticCon> {
226        if !(self.tbranch.is_static() && self.fbranch.is_static()) {
227            return None;
228        };
229        let tb = std::mem::replace(&mut *self.tbranch, ir::Control::empty());
230        let fb = std::mem::replace(&mut *self.fbranch, ir::Control::empty());
231        let ir::Control::Static(sc_t) = tb else {
232            unreachable!("we have already checked tbranch to be static")
233        };
234        let ir::Control::Static(sc_f) = fb else {
235            unreachable!("we have already checker fbranch to be static")
236        };
237        let latency = std::cmp::max(sc_t.get_latency(), sc_f.get_latency());
238        Some(ir::StaticIf {
239            tbranch: Box::new(sc_t),
240            fbranch: Box::new(sc_f),
241            attributes: ir::Attributes::default(),
242            port: Rc::clone(&self.port),
243            latency,
244        })
245    }
246}