calyx_opt/analysis/
static_par_timing.rs

1use super::LiveRangeAnalysis;
2use crate::analysis::ControlId;
3use calyx_ir as ir;
4use std::{collections::HashMap, fmt::Debug};
5
/// Maps cell names to a vector of intervals `(i, j)`: the clock cycles
/// (relative to the start of the enclosing par) during which the cell is live.
/// Invariant: the intervals within each vec are always sorted in ascending order
/// (relied upon by `liveness_overlaps`).
type CellTimingMap = HashMap<ir::Id, Vec<(u64, u64)>>;
/// Maps threads (i.e., node ids of the *direct* children of a par block) to
/// their cell timing maps.
type ThreadTimingMap = HashMap<u64, CellTimingMap>;
13
#[derive(Default)]
/// Calculate live ranges across static par blocks.
/// Assumes control ids have already been given; it does not add its own
pub struct StaticParTiming {
    /// Map from static par block ids to thread timing maps
    /// (thread id -> cell -> sorted liveness intervals)
    cell_map: HashMap<u64, ThreadTimingMap>,
    /// name of the component this analysis was built for (used in Debug output)
    component_name: ir::Id,
}
23
24impl Debug for StaticParTiming {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        //must sort the hashmap and hashsets in order to get consistent ordering
27        writeln!(
28            f,
29            "This maps ids of par blocks to \" cell timing maps \", which map cells to intervals (i,j), that signify the clock cycles the group is active for, \n relative to the start of the given par block"
30        )?;
31        write!(
32            f,
33            "============ Map for Component \"{}\"",
34            self.component_name
35        )?;
36        writeln!(f, " ============")?;
37        let map = self.cell_map.clone();
38        // Sorting map to get deterministic ordering
39        let mut vec: Vec<(u64, ThreadTimingMap)> = map.into_iter().collect();
40        vec.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
41        for (par_id, thread_timing_map) in vec.into_iter() {
42            write!(f, "========")?;
43            write!(f, "Par Node ID: {par_id:?}")?;
44            writeln!(f, " ========")?;
45            let mut vec1: Vec<(u64, CellTimingMap)> =
46                thread_timing_map.into_iter().collect::<Vec<_>>();
47            vec1.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
48            for (thread_id, cell_timing_map) in vec1 {
49                write!(f, "====")?;
50                write!(f, "Child/Thread ID: {thread_id:?}")?;
51                writeln!(f, " ====")?;
52                let mut vec2: Vec<(ir::Id, Vec<(u64, u64)>)> =
53                    cell_timing_map.into_iter().collect::<Vec<_>>();
54                vec2.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
55                for (cell_name, clock_intervals) in vec2 {
56                    write!(f, "{cell_name:?} -- ")?;
57                    writeln!(f, "{clock_intervals:?}")?;
58                }
59            }
60            writeln!(f)?
61        }
62        write!(f, "")
63    }
64}
65
66impl StaticParTiming {
67    /// Construct a live range analysis.
68    pub fn new(
69        control: &mut ir::Control,
70        comp_name: ir::Id,
71        live: &LiveRangeAnalysis,
72    ) -> Self {
73        let mut time_map = StaticParTiming {
74            component_name: comp_name,
75            ..Default::default()
76        };
77
78        time_map.build_time_map(control, live);
79
80        time_map
81    }
82
83    /// `par_id` is the id of a par thread.
84    /// `thread_a` and `thread_b` are ids of direct children of par_id (if `thread_a` and
85    /// `thread_b` are *not* direct children of par_id, then the function will error)
86    /// `a` and `b` are cell names
87    /// liveness_overlaps checks if the liveness of `a` in `thread_a` ever overlaps
88    /// with the liveness of `b` in `thread_b`
89    /// if `par_id` is not static, then will automtically return true
90    pub fn liveness_overlaps(
91        &self,
92        par_id: &u64,
93        thread_a: &u64,
94        thread_b: &u64,
95        a: &ir::Id,
96        b: &ir::Id,
97    ) -> bool {
98        // unwrapping cell_map data structure, eventually getting to the two vecs
99        // a_liveness, b_liveness, that we actually care about
100        let thread_timing_map = match self.cell_map.get(par_id) {
101            Some(m) => m,
102            // not a static par block, so must assume overlap
103            None => return true,
104        };
105        let a_liveness = thread_timing_map
106            .get(thread_a)
107            .unwrap_or_else(|| {
108                unreachable!("{} not a thread in {}", thread_a, par_id)
109            })
110            .get(a);
111        let b_liveness = thread_timing_map
112            .get(thread_b)
113            .unwrap_or_else(|| {
114                unreachable!("{} not a thread in {}", thread_a, par_id)
115            })
116            .get(b);
117        match (a_liveness, b_liveness) {
118            (Some(a_intervals), Some(b_intervals)) => {
119                let mut a_iter = a_intervals.iter();
120                let mut b_iter = b_intervals.iter();
121                let mut cur_a = a_iter.next();
122                let mut cur_b = b_iter.next();
123                // this relies on the fact that a_iter and b_iter are sorted
124                // in ascending order
125                while cur_a.is_some() && cur_b.is_some() {
126                    let ((a1, a2), (b1, b2)) = (cur_a.unwrap(), cur_b.unwrap());
127                    // if a1 is smaller, checks if it overlaps with
128                    // b1. If it does, return true, otherwise, advance
129                    // a in the iteration
130                    match a1.cmp(b1) {
131                        std::cmp::Ordering::Less => {
132                            if a2 > b1 {
133                                return true;
134                            } else {
135                                cur_a = a_iter.next();
136                            }
137                        }
138                        std::cmp::Ordering::Greater => {
139                            if b2 > a1 {
140                                return true;
141                            } else {
142                                cur_b = b_iter.next();
143                            }
144                        }
145                        std::cmp::Ordering::Equal => return true,
146                    }
147                }
148                false
149            }
150            _ => false,
151        }
152    }
153
154    // updates self.cell_map, returns the state after the invoke/enable has occured
155    // assumes that there is a cur_state = (par_id, thread_id, cur_clock)
156    // also, id is the id of the invoke/enable, and latency is the latency of the
157    // invoke/enable
158    fn update_invoke_enable(
159        &mut self,
160        id: u64,
161        latency: u64,
162        live: &LiveRangeAnalysis,
163        cur_state: (u64, u64, u64),
164    ) -> (u64, u64, u64) {
165        let (par_id, thread_id, cur_clock) = cur_state;
166        // live set is all cells live at this invoke/enable, organized by cell type
167        let live_set = live.get(&id).clone();
168        // go thru all live cells in this enable add them to appropriate entry in
169        // self.cell_map
170        for (_, live_cells) in live_set {
171            for cell in live_cells {
172                let interval_vec = self
173                    .cell_map
174                    .entry(par_id)
175                    .or_default()
176                    .entry(thread_id)
177                    .or_default()
178                    .entry(cell)
179                    .or_default();
180                // we need to check whether we've already added this
181                // to vec before or not. If we haven't,
182                // then we can push
183                // This can sometimes occur if there is a par block,
184                // that contains a while loop, and that while loop
185                // contains another par block.
186                match interval_vec.last() {
187                    None => interval_vec.push((cur_clock, cur_clock + latency)),
188                    Some(interval) => {
189                        if *interval != (cur_clock, cur_clock + latency) {
190                            interval_vec.push((cur_clock, cur_clock + latency))
191                        }
192                    }
193                }
194            }
195        }
196        (par_id, thread_id, cur_clock + latency)
197    }
198
    // Recursively updates self.cell_map for the static control tree `sc`.
    // This is a helper function for fn `build_time_map`.
    // Read comment for that function to see what this function is doing.
    // Returns the state with the clock advanced past `sc` when inside a static
    // par, or None when not inside one.
    fn build_time_map_static(
        &mut self,
        sc: &ir::StaticControl,
        // cur_state = Some(parent_par_id, thread_id, cur_clock) if we're inside a static par, None otherwise.
        // parent_par_id = Node ID of the static par that we're analyzing
        // thread_id = Node ID of the thread that we're analyzing within the par
        // note that this thread_id only corresponds to "direct" children
        // cur_clock = current clock cycles we're at relative to the start of parent_par
        cur_state: Option<(u64, u64, u64)>,
        // LiveRangeAnalysis instance
        live: &LiveRangeAnalysis,
    ) -> Option<(u64, u64, u64)> {
        match sc {
            // empty takes no time: state unchanged
            ir::StaticControl::Empty(_) => cur_state,
            ir::StaticControl::If(ir::StaticIf {
                tbranch, fbranch, ..
            }) => match cur_state {
                Some((parent_par, thread_id, cur_clock)) => {
                    // we already know parent par + latency of the if stmt, so don't
                    // care about return type: we just want to add enables to the timing map
                    self.build_time_map_static(tbranch, cur_state, live);
                    self.build_time_map_static(fbranch, cur_state, live);
                    // advance the clock by the if's full latency regardless of branch
                    Some((parent_par, thread_id, cur_clock + sc.get_latency()))
                }
                None => {
                    // should still look thru the branches in case there are static pars
                    // inside the branches
                    self.build_time_map_static(tbranch, cur_state, live);
                    self.build_time_map_static(fbranch, cur_state, live);
                    None
                }
            },
            ir::StaticControl::Enable(ir::StaticEnable { group, .. }) => {
                match cur_state {
                    // inside a par: record liveness intervals and advance the clock
                    Some(cur_state_unwrapped) => {
                        let enable_id = ControlId::get_guaranteed_id_static(sc);
                        let latency = group.borrow().get_latency();
                        Some(self.update_invoke_enable(
                            enable_id,
                            latency,
                            live,
                            cur_state_unwrapped,
                        ))
                    }
                    // not inside a par: nothing to record
                    None => cur_state,
                }
            }
            // invokes are handled just like enables, using the invoke's latency
            ir::StaticControl::Invoke(inv) => match cur_state {
                Some(cur_state_unwrapped) => {
                    let invoke_id = ControlId::get_guaranteed_id_static(sc);
                    let latency = inv.latency;
                    Some(self.update_invoke_enable(
                        invoke_id,
                        latency,
                        live,
                        cur_state_unwrapped,
                    ))
                }
                None => cur_state,
            },
            ir::StaticControl::Repeat(ir::StaticRepeat {
                body,
                num_repeats,
                ..
            }) => {
                if cur_state.is_some() {
                    // essentially just unrolling the loop: analyze the body once
                    // per iteration, threading the clock through each pass
                    let mut new_state = cur_state;
                    for _ in 0..*num_repeats {
                        new_state =
                            self.build_time_map_static(body, new_state, live)
                    }
                    new_state
                } else {
                    // look thru the repeat body for static pars
                    self.build_time_map_static(body, cur_state, live);
                    None
                }
            }
            ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => {
                // this works whether or not cur_state is None or Some:
                // each stmt's returned state feeds the next stmt
                let mut new_state = cur_state;
                for stmt in stmts {
                    new_state =
                        self.build_time_map_static(stmt, new_state, live);
                }
                new_state
            }
            ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
                // We know that all children must be static
                // Analyze the Current Par: each direct child becomes a thread,
                // starting at clock 0 relative to this par
                for stmt in stmts {
                    self.build_time_map_static(
                        stmt,
                        Some((
                            ControlId::get_guaranteed_id_static(sc),
                            ControlId::get_guaranteed_id_static(stmt),
                            0,
                        )),
                        live,
                    );
                }
                // If we have nested pars, want to get the clock cycles relative
                // to the start of both the current par and the nested par.
                // So we have the following code to possibly get the clock cycles
                // relative to the parent par (i.e., a second traversal in the
                // parent's frame of reference).
                // Might be overkill, but trying to keep the analysis general.
                match cur_state {
                    Some((cur_parent_par, cur_thread, cur_clock)) => {
                        for stmt in stmts {
                            self.build_time_map_static(stmt, cur_state, live);
                        }
                        // advance the parent's clock by this par's full latency
                        Some((
                            cur_parent_par,
                            cur_thread,
                            cur_clock + sc.get_latency(),
                        ))
                    }
                    None => None,
                }
            }
        }
    }
325
326    // Recursively updates self.time_map
327    // Takes in Control block `c`, Live Range Analyss `live`
328    // self.time_map maps par ids -> (maps of thread ids -> (maps of cells -> intervals for which
329    // cells are live))
330    fn build_time_map(
331        &mut self,
332        c: &ir::Control,
333        // LiveRangeAnalysis instance
334        live: &LiveRangeAnalysis,
335    ) {
336        match c {
337            ir::Control::Invoke(_)
338            | ir::Control::Empty(_)
339            | ir::Control::Enable(_) => (),
340            ir::Control::Par(ir::Par { stmts, .. })
341            | ir::Control::Seq(ir::Seq { stmts, .. }) => {
342                for stmt in stmts {
343                    self.build_time_map(stmt, live)
344                }
345            }
346            ir::Control::If(ir::If {
347                tbranch, fbranch, ..
348            }) => {
349                self.build_time_map(tbranch, live);
350                self.build_time_map(fbranch, live);
351            }
352            ir::Control::While(ir::While { body, .. })
353            | ir::Control::Repeat(ir::Repeat { body, .. }) => {
354                self.build_time_map(body, live);
355            }
356            ir::Control::Static(sc) => {
357                self.build_time_map_static(sc, None, live);
358            }
359            ir::Control::FSMEnable(_) => {
360                todo!("should not encounter fsm nodes")
361            }
362        }
363    }
364}