calyx_opt/analysis/
static_tree.rs

1use super::{FSMEncoding, StaticFSM};
2use calyx_ir::{self as ir};
3use calyx_ir::{Nothing, build_assignments};
4use calyx_ir::{guard, structure};
5use itertools::Itertools;
6use std::collections::{BTreeMap, HashMap};
7use std::ops::Not;
8use std::rc::Rc;
9
10use super::GraphColoring;
11
/// Optional Rc of a RefCell of a StaticFSM object.
type OptionalStaticFSM = Option<ir::RRC<StaticFSM>>;
/// Query (i, (j, k)) that corresponds to:
/// Am I in iteration i, and between cycles j and k within
/// that iteration?
type SingleIterQuery = (u64, (u64, u64));
/// Query (i, j) that corresponds to:
/// Am I between iterations i and j, inclusive?
type ItersQuery = (u64, u64);
21
/// Helpful for translating queries for the FSMTree structure.
/// Because of the tree structure, %[i:j] is no longer always equal to i <= fsm < j.
/// Offload(i) means the FSM is offloading when fsm == i: so if the fsm == i,
/// we need to look at the children to know what cycle we are in exactly.
/// Normal(i,j) means the FSM is outputting (i..j), incrementing each cycle (i.e.,
/// like normal) and not offloading. Note that even though the FSM is outputting
/// i..j each cycle, that does not necessarily mean we are in cycles i..j (due
/// to offloading performed in the past.)
#[derive(Debug)]
pub enum StateType {
    /// The FSM steps through the states `i..j`, one state per cycle.
    Normal((u64, u64)),
    /// The FSM stays in state `i` while a child is executing.
    Offload(u64),
}
35
/// Node can either be a SingleNode (i.e., a single node) or ParNodes (i.e., a group of
/// nodes that are executing in parallel).
/// Most methods in `Node` simply call the equivalent methods for each
/// of the two possible variants.
/// Perhaps could be more compactly implemented as a Trait.
pub enum Node {
    /// A single node.
    Single(SingleNode),
    /// A group of nodes that execute in parallel.
    Par(ParNodes),
}
45
46// The following methods are used to actually instantiate the FSMTree structure
47// and compile static groups/control to dynamic groups/control.
48impl Node {
49    /// Instantiate the necessary registers.
50    /// The equivalent methods for the two variants contain more implementation
51    /// details.
52    /// `coloring`, `colors_to_max_values`, and `colors_to_fsm` are necessary
53    /// to know whether we actually need to instantiate a new FSM, or we can
54    /// juse use another node's FSM.
55    pub fn instantiate_fsms(
56        &mut self,
57        builder: &mut ir::Builder,
58        coloring: &HashMap<ir::Id, ir::Id>,
59        colors_to_max_values: &HashMap<ir::Id, (u64, u64)>,
60        colors_to_fsm: &mut HashMap<
61            ir::Id,
62            (OptionalStaticFSM, OptionalStaticFSM),
63        >,
64        one_hot_cutoff: u64,
65    ) {
66        match self {
67            Node::Single(single_node) => single_node.instantiate_fsms(
68                builder,
69                coloring,
70                colors_to_max_values,
71                colors_to_fsm,
72                one_hot_cutoff,
73            ),
74            Node::Par(par_nodes) => par_nodes.instantiate_fsms(
75                builder,
76                coloring,
77                colors_to_max_values,
78                colors_to_fsm,
79                one_hot_cutoff,
80            ),
81        }
82    }
83
84    /// Count to n. Need to call `instantiate_fsms` before calling `count_to_n`.
85    /// The equivalent methods for the two variants contain more implementation
86    /// details.
87    /// `incr_start_cond` can optionally guard the 0->1 transition.
88    pub fn count_to_n(
89        &mut self,
90        builder: &mut ir::Builder,
91        incr_start_cond: Option<ir::Guard<Nothing>>,
92    ) {
93        match self {
94            Node::Single(single_node) => {
95                single_node.count_to_n(builder, incr_start_cond)
96            }
97            Node::Par(par_nodes) => {
98                par_nodes.count_to_n(builder, incr_start_cond)
99            }
100        }
101    }
102
103    /// "Realize" the static groups into dynamic groups.
104    /// The main challenge is converting %[i:j] into fsm guards.
105    /// Need to call `instantiate_fsms` and
106    /// `count_to_n` before calling `realize`.
107    /// The equivalent methods for the two variants contain more implementation
108    /// details.
109    /// `reset_early_map`, `fsm_info_map`, and `group_rewrites` are just metadata
110    /// to make it easier to rewrite control, add wrappers, etc.
111    pub fn realize(
112        &mut self,
113        ignore_timing_guards: bool,
114        static_groups: &Vec<ir::RRC<ir::StaticGroup>>,
115        reset_early_map: &mut HashMap<ir::Id, ir::Id>,
116        fsm_info_map: &mut HashMap<
117            ir::Id,
118            (ir::Id, ir::Guard<Nothing>, ir::Guard<Nothing>),
119        >,
120        group_rewrites: &mut ir::rewriter::PortRewriteMap,
121        builder: &mut ir::Builder,
122    ) {
123        match self {
124            Node::Single(single_node) => single_node.realize(
125                ignore_timing_guards,
126                static_groups,
127                reset_early_map,
128                fsm_info_map,
129                group_rewrites,
130                builder,
131            ),
132            Node::Par(par_nodes) => par_nodes.realize(
133                ignore_timing_guards,
134                static_groups,
135                reset_early_map,
136                fsm_info_map,
137                group_rewrites,
138                builder,
139            ),
140        }
141    }
142
143    /// Get the equivalent fsm guard when the tree is between cycles i and j, i.e.,
144    /// when i <= cycle_count < j.
145    /// The equivalent methods for the two variants contain more implementation
146    /// details.
147    pub fn query_between(
148        &mut self,
149        query: (u64, u64),
150        builder: &mut ir::Builder,
151    ) -> ir::Guard<Nothing> {
152        match self {
153            Node::Single(single_node) => {
154                single_node.query_between(query, builder)
155            }
156            Node::Par(par_nodes) => par_nodes.query_between(query, builder),
157        }
158    }
159}
160
161/// The following methods are used to help build the conflict graph for coloring
162/// to share FSMs
163impl Node {
164    /// Get the names of all nodes (i.e., the names of the groups for each node
165    /// in the tree).
166    pub fn get_all_nodes(&self) -> Vec<ir::Id> {
167        match self {
168            Node::Single(single_node) => single_node.get_all_nodes(),
169            Node::Par(par_nodes) => par_nodes.get_all_nodes(),
170        }
171    }
172
173    /// Adds conflicts between nodes in the tree that execute at the same time.
174    pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring<ir::Id>) {
175        match self {
176            Node::Single(single_node) => {
177                single_node.add_conflicts(conflict_graph)
178            }
179            Node::Par(par_nodes) => par_nodes.add_conflicts(conflict_graph),
180        }
181    }
182
183    /// Get max value of all nodes in the tree, according to some function f.
184    /// `f` takes in a Tree (i.e., a node type) and returns a `u64`.
185    pub fn get_max_value<F>(&self, name: &ir::Id, f: &F) -> u64
186    where
187        F: Fn(&SingleNode) -> u64,
188    {
189        match self {
190            Node::Single(single_node) => single_node.get_max_value(name, f),
191            Node::Par(par_nodes) => par_nodes.get_max_value(name, f),
192        }
193    }
194}
195
196// Used to compile static component interface. This is really annoying to do, since
197// for static components, they only need to be guarded for %0, while for static
198// groups, they need to be guarded for %[0:n]. This creates some annoying `if`
199// statements.
200impl Node {
201    // Helper to `preprocess_static_interface_assigns`
202    // Looks recursively thru guard to transform %[0:n] into %0 | %[1:n].
203    fn preprocess_static_interface_guard(
204        guard: ir::Guard<ir::StaticTiming>,
205        comp_sig: ir::RRC<ir::Cell>,
206    ) -> ir::Guard<ir::StaticTiming> {
207        match guard {
208            ir::Guard::Info(st) => {
209                let (beg, end) = st.get_interval();
210                if beg == 0 {
211                    // Replace %[0:n] -> (%0 & comp.go) | %[1:n]
212                    // Cannot just do comp.go | %[1:n] because we want
213                    // clients to be able to assert `comp.go` even after the first
214                    // cycle w/o affecting correctness.
215                    let first_cycle =
216                        ir::Guard::Info(ir::StaticTiming::new((0, 1)));
217                    let comp_go = guard!(comp_sig["go"]);
218                    let first_and_go = ir::Guard::and(comp_go, first_cycle);
219                    if end == 1 {
220                        return first_and_go;
221                    } else {
222                        let after =
223                            ir::Guard::Info(ir::StaticTiming::new((1, end)));
224                        let cong = ir::Guard::or(first_and_go, after);
225                        return cong;
226                    }
227                }
228                guard
229            }
230            ir::Guard::And(l, r) => {
231                let left = Self::preprocess_static_interface_guard(
232                    *l,
233                    Rc::clone(&comp_sig),
234                );
235                let right =
236                    Self::preprocess_static_interface_guard(*r, comp_sig);
237                ir::Guard::and(left, right)
238            }
239            ir::Guard::Or(l, r) => {
240                let left = Self::preprocess_static_interface_guard(
241                    *l,
242                    Rc::clone(&comp_sig),
243                );
244                let right =
245                    Self::preprocess_static_interface_guard(*r, comp_sig);
246                ir::Guard::or(left, right)
247            }
248            ir::Guard::Not(g) => {
249                let a = Self::preprocess_static_interface_guard(*g, comp_sig);
250                ir::Guard::Not(Box::new(a))
251            }
252            _ => guard,
253        }
254    }
255
256    // Looks recursively thru assignment's guard to %[0:n] into %0 | %[1:n].
257    pub fn preprocess_static_interface_assigns(
258        assign: &mut ir::Assignment<ir::StaticTiming>,
259        comp_sig: ir::RRC<ir::Cell>,
260    ) {
261        assign
262            .guard
263            .update(|g| Self::preprocess_static_interface_guard(g, comp_sig));
264    }
265}
266
267// The following are just standard `getter` methods.
268impl Node {
269    /// Take the assignments of the root of the tree and return them.
270    /// This only works on a single node (i.e., the `Tree`` variant).
271    pub fn take_root_assigns(&mut self) -> Vec<ir::Assignment<Nothing>> {
272        match self {
273            Node::Single(single_node) => {
274                std::mem::take(&mut single_node.root.1)
275            }
276            Node::Par(_) => {
277                unreachable!(
278                    "Cannot take root assignments of Node::Par variant"
279                )
280            }
281        }
282    }
283
284    /// Get the name of the root of the tree and return them.
285    /// This only works on a single node (i.e., the `Tree`` variant).
286    pub fn get_root_name(&mut self) -> ir::Id {
287        match self {
288            Node::Single(single_node) => single_node.root.0,
289            Node::Par(_) => {
290                unreachable!("Cannot take root name of Node::Par variant")
291            }
292        }
293    }
294
295    /// Get the name of the group at the root of the tree (if a `Tree` variant) or
296    /// of the equivalent `par` group (i.e., the name of the group that triggers
297    /// execution of all the trees) if a `Par` variant.
298    pub fn get_group_name(&self) -> ir::Id {
299        match self {
300            Node::Single(single_node) => single_node.root.0,
301            Node::Par(par_nodes) => par_nodes.group_name,
302        }
303    }
304
305    /// Gets latency of the overall tree.
306    pub fn get_latency(&self) -> u64 {
307        match self {
308            Node::Single(single_node) => single_node.latency,
309            Node::Par(par_nodes) => par_nodes.latency,
310        }
311    }
312
313    /// Gets the children of root of the tree (if a `Tree` variant) or
314    /// of the threads (i.e., trees) that are scheduled to execute (if a `Par`
315    /// variant.)
316    pub fn get_children(&mut self) -> &mut Vec<(Node, (u64, u64))> {
317        match self {
318            Node::Single(single_node) => &mut single_node.children,
319            Node::Par(par_nodes) => &mut par_nodes.threads,
320        }
321    }
322
323    /// Get number of repeats.
324    fn get_num_repeats(&self) -> u64 {
325        match self {
326            Node::Single(single_node) => single_node.num_repeats,
327            Node::Par(par_nodes) => par_nodes.num_repeats,
328        }
329    }
330}
331
/// `SingleNode` struct: one node of the FSM tree, together with its schedule,
/// its children, and the registers used to count cycles and iterations.
pub struct SingleNode {
    /// latency of one iteration.
    pub latency: u64,
    /// number of repeats. (So "total" latency = `latency` x `num_repeats`)
    pub num_repeats: u64,
    /// number of states in this node
    pub num_states: u64,
    /// (name of static group, assignments to build a corresponding dynamic group)
    pub root: (ir::Id, Vec<ir::Assignment<Nothing>>),
    ///  maps cycles (i,j) -> fsm state type.
    ///  Here is an example FSM schedule:
    ///                           Cycles      FSM State (i.e., `fsm.out`)
    ///                           (0..10)   -> Normal[0,10)
    ///                           (10..30)  -> Offload(10) // Offloading to child
    ///                           (30..40)  -> Normal[11, 21)
    ///                           (40..80)  -> Offload(21)
    ///                           (80..100) -> Normal[22, 42)
    pub fsm_schedule: BTreeMap<(u64, u64), StateType>,
    /// vec of (Node Object, cycles for which that child is executing).
    /// Note that you can build `fsm_schedule` from just this information,
    /// but it's convenient to have `fsm_schedule` available.
    pub children: Vec<(Node, (u64, u64))>,
    /// Keep track of where we are within a single iteration.
    /// If `latency` == 1, then we don't need an `fsm_cell`.
    /// NOTE(review): `instantiate_fsms` keys this decision on the number of
    /// states (of the node's color), not on `latency` directly.
    pub fsm_cell: Option<ir::RRC<StaticFSM>>,
    /// Keep track of which iteration we are on. If iteration count == 1, then
    /// we don't need an `iter_count_cell`.
    pub iter_count_cell: Option<ir::RRC<StaticFSM>>,
}
362
363impl SingleNode {
364    /// Instantiates the necessary registers.
365    /// Because we share FSM registers, it's possible that this register has already
366    /// been instantiated.
367    /// Therefore we take in a bunch of data structures to keep track of coloring:
368    ///   - `coloring` that maps group names -> colors,
369    ///   - `colors_to_max_values` which maps colors -> (max latency, max_num_repeats)
370    ///     (we need to make sure that when we instantiate a color,
371    ///     we give enough bits to support the maximum latency/num_repeats that will be
372    ///     used for that color)
373    ///   - `colors_to_fsm`
374    ///     which maps colors to (fsm_register, iter_count_register): fsm_register counts
375    ///     up for a single iteration, iter_count_register counts the number of iterations
376    ///     that have passed.
377    ///
378    /// Note that it is not always necessary to instantiate one or both registers (e.g.,
379    /// if num_repeats == 1 then you don't need an iter_count_register).
380    ///
381    /// `one_hot_cutoff` is the cutoff to choose between binary and one hot encoding.
382    /// Any number of states greater than the cutoff will get binary encoding.
383    fn instantiate_fsms(
384        &mut self,
385        builder: &mut ir::Builder,
386        coloring: &HashMap<ir::Id, ir::Id>,
387        colors_to_max_values: &HashMap<ir::Id, (u64, u64)>,
388        colors_to_fsm: &mut HashMap<
389            ir::Id,
390            (OptionalStaticFSM, OptionalStaticFSM),
391        >,
392        one_hot_cutoff: u64,
393    ) {
394        // Get color assigned to this node.
395        let color = coloring.get(&self.root.0).expect("couldn't find group");
396        // Check if we've already instantiated the registers for this color.
397        match colors_to_fsm.get(color) {
398            // We need to create the registers for the colors.
399            None => {
400                // First we get the maximum num_states and num_repeats
401                // for this color so we know how many bits each register needs.
402                let (num_states, num_repeats) = colors_to_max_values
403                    .get(color)
404                    .expect("Couldn't find color");
405                // Only need a `self.fsm_cell` if num_states > 1.
406                if *num_states != 1 {
407                    // Choose encoding based on one_hot_cutoff.
408                    let encoding = if *num_states > one_hot_cutoff {
409                        FSMEncoding::Binary
410                    } else {
411                        FSMEncoding::OneHot
412                    };
413                    let fsm_cell = ir::rrc(StaticFSM::from_basic_info(
414                        *num_states,
415                        encoding,
416                        builder,
417                    ));
418                    self.fsm_cell = Some(fsm_cell);
419                }
420                // Only need a `self.iter_count_cell` if num_states > 1.
421                if *num_repeats != 1 {
422                    let encoding = if *num_repeats > one_hot_cutoff {
423                        FSMEncoding::Binary
424                    } else {
425                        FSMEncoding::OneHot
426                    };
427                    let repeat_counter = ir::rrc(StaticFSM::from_basic_info(
428                        *num_repeats,
429                        encoding,
430                        builder,
431                    ));
432                    self.iter_count_cell = Some(repeat_counter);
433                }
434
435                // Insert into `colors_to_fsms` so the next time we call this method
436                // we see we've already instantiated the registers.
437                colors_to_fsm.insert(
438                    *color,
439                    (
440                        self.fsm_cell.as_ref().map(Rc::clone),
441                        self.iter_count_cell.as_ref().map(Rc::clone),
442                    ),
443                );
444            }
445            Some((fsm_option, repeat_option)) => {
446                // Trivially assign to `self.fsm_cell` and `self.iter_count_cell` since
447                // we've already created it.
448                self.fsm_cell = fsm_option.as_ref().map(Rc::clone);
449                self.iter_count_cell = repeat_option.as_ref().map(Rc::clone);
450            }
451        }
452
453        // Recursively instantiate fsms for all the children.
454        for (child, _) in &mut self.children {
455            child.instantiate_fsms(
456                builder,
457                coloring,
458                colors_to_max_values,
459                colors_to_fsm,
460                one_hot_cutoff,
461            );
462        }
463    }
464
465    /// Counts to n.
466    /// If `incr_start_cond.is_some()`, then we will add it as an extra
467    /// guard guarding the 0->1 transition.
468    fn count_to_n(
469        &mut self,
470        builder: &mut ir::Builder,
471        incr_start_cond: Option<ir::Guard<Nothing>>,
472    ) {
473        // res_vec will contain the assignments that count to n.
474        let mut res_vec: Vec<ir::Assignment<Nothing>> = Vec::new();
475
476        // Only need to count up to n if self.num_states > 1.
477        // If self.num_states == 1, then either a) latency is 1 cycle or b)
478        // we're just offloading the entire time (so the child will count).
479        // Either way, there's no need to instantiate a self.fsm_cell to count.
480        if self.num_states > 1 {
481            // `offload_states` are the fsm_states that last >1 cycles (i.e., states
482            // where children are executing, unless the child only lasts one cycle---
483            // then we can discount it as an "offload" state).
484            let offload_states: Vec<u64> = self
485                .fsm_schedule
486                .iter()
487                .filter_map(|((beg, end), state_type)| match state_type {
488                    StateType::Normal(_) => None,
489                    StateType::Offload(offload_state) => {
490                        // Only need to include the children that last more than one cycle.
491                        if beg + 1 == *end {
492                            None
493                        } else {
494                            Some(*offload_state)
495                        }
496                    }
497                })
498                .collect();
499
500            // There are two conditions under which we increment the FSM.
501            // 1) Increment when we are NOT in an offload state
502            // 2) Increment when we ARE in an offload state, but the child being offloaded
503            // is in its final state. (intuitively, we need to increment because
504            // we know the control is being passed back to parent in the next cycle).
505            // (when we are in the final state, we obviously should not increment:
506            // we should reset back to 0.)
507
508            let parent_fsm = Rc::clone(
509                self.fsm_cell
510                    .as_mut()
511                    .expect("should have set self.fsm_cell"),
512            );
513
514            // Build an adder to increment the parent fsm.
515            let (adder_asssigns, adder) =
516                parent_fsm.borrow_mut().build_incrementer(builder);
517            res_vec.extend(adder_asssigns);
518
519            // Handle situation 1). Increment when we are NOT in an offload state
520            res_vec.extend(self.increment_if_not_offloading(
521                incr_start_cond.clone(),
522                &offload_states,
523                Rc::clone(&adder),
524                Rc::clone(&parent_fsm),
525                builder,
526            ));
527
528            // Handle situation 2): Increment when we ARE in an offload state
529            // but the child being offloaded is in its final state.
530            res_vec.extend(self.increment_if_child_final_state(
531                &offload_states,
532                adder,
533                Rc::clone(&parent_fsm),
534                builder,
535            ));
536
537            // Reset the FSM when it is at its final fsm_state.
538            let final_fsm_state =
539                self.get_fsm_query((self.latency - 1, self.latency), builder);
540            res_vec.extend(
541                parent_fsm
542                    .borrow_mut()
543                    .conditional_reset(final_fsm_state, builder),
544            );
545        }
546
547        // If self.num_states > 1, then it's guaranteed that self.latency > 1.
548        // However, even if self.num_states == 1, self.latency might still be
549        // greater than 1 if we're just offloading the computation for the entire time.
550        // In this case, we still need the children to count to n.
551        if self.latency > 1 {
552            for (child, (beg, end)) in self.children.iter_mut() {
553                // If beg == 0 and end > 1 then we need to "transfer" the incr_start_condition
554                // to the child so it guards the 0->1 transition.
555                let cond = if *beg == 0 && *end > 1 {
556                    incr_start_cond.clone()
557                } else {
558                    None
559                };
560                // Recursively call `count_to_n`
561                child.count_to_n(builder, cond);
562            }
563        }
564
565        // Handle repeats (i.e., make sure we actually interate `self.num_repeats` times).
566        if self.num_repeats != 1 {
567            // If self.latency == 10, then we should increment the self.iter_count_cell
568            // each time fsm == 9, i.e., `final_fsm_state`.
569            let final_fsm_state =
570                self.get_fsm_query((self.latency - 1, self.latency), builder);
571
572            // `repeat_fsm` store number of iterations.
573            let repeat_fsm = Rc::clone(
574                self.iter_count_cell
575                    .as_mut()
576                    .expect("should have set self.iter_count_cell"),
577            );
578            // Build an incrementer to increment `self.iter_count_cell`.
579            let (repeat_adder_assigns, repeat_adder) =
580                repeat_fsm.borrow_mut().build_incrementer(builder);
581            // We shouldn't increment `self.iter_count_cell` if we are in the final iteration:
582            // we should reset it instead.
583            let final_repeat_state = *repeat_fsm.borrow_mut().query_between(
584                builder,
585                (self.num_repeats - 1, self.num_repeats),
586            );
587            let not_final_repeat_state = final_repeat_state.clone().not();
588            res_vec.extend(repeat_adder_assigns);
589            // Incrementing self.iter_count_cell when appropriate.
590            res_vec.extend(repeat_fsm.borrow_mut().conditional_increment(
591                final_fsm_state.clone().and(not_final_repeat_state),
592                repeat_adder,
593                builder,
594            ));
595            // Resetting self.iter_count_cell when appropriate.
596            res_vec.extend(repeat_fsm.borrow_mut().conditional_reset(
597                final_fsm_state.clone().and(final_repeat_state),
598                builder,
599            ));
600        }
601
602        // Extend root assigns to include `res_vec` (which counts to n).
603        self.root.1.extend(res_vec);
604    }
605
606    /// Helper to `count_to_n`
607    /// Increment when we are NOT in an offload state
608    /// e.g., if `offload_states` == [2,4,6] then
609    /// We should increment when !(fsm == 2 | fsm == 4 | fsm == 6).
610    /// There are a couple corner cases we need to think about (in particular,
611    /// we should guard the 0->1 transition differently if `incr_start_cond` is
612    /// some(), and we should reset rather than increment when we are in the final
613    /// fsm state).
614    fn increment_if_not_offloading(
615        &mut self,
616        incr_start_cond: Option<ir::Guard<Nothing>>,
617        offload_states: &[u64],
618        adder: ir::RRC<ir::Cell>,
619        parent_fsm: ir::RRC<StaticFSM>,
620        builder: &mut ir::Builder,
621    ) -> Vec<ir::Assignment<Nothing>> {
622        let mut res_vec = vec![];
623        let mut offload_state_guard: ir::Guard<Nothing> =
624            ir::Guard::Not(Box::new(ir::Guard::True));
625        for offload_state in offload_states {
626            // Creating a guard that checks whether the parent fsm is
627            // in an offload state.
628            offload_state_guard.update(|g| {
629                g.or(*parent_fsm.borrow_mut().query_between(
630                    builder,
631                    (*offload_state, offload_state + 1),
632                ))
633            });
634        }
635        let not_offload_state = offload_state_guard.not();
636
637        let mut incr_guard = not_offload_state;
638
639        // If incr_start_cond.is_some(), then we have to specially take into
640        // account this scenario when incrementing the FSM.
641        if let Some(g) = incr_start_cond.clone() {
642            // If we offload during the transition from cycle 0->1 transition
643            // then we don't need a special first transition guard.
644            // (we will make sure the child will add this guard when
645            // it is counting to n.)
646            if let Some(((beg, end), state_type)) =
647                self.fsm_schedule.iter().next()
648                && !(matches!(state_type, StateType::Offload(_))
649                    && *beg == 0
650                    && *end > 1)
651            {
652                let first_state = self.get_fsm_query((0, 1), builder);
653                // We must handle the 0->1 transition separately.
654                // fsm.in = fsm == 0 & incr_start_cond ? fsm + 1;
655                // fsm.write_en = fsm == 0 & incr_start_cond ? 1'd1;
656                res_vec.extend(parent_fsm.borrow_mut().conditional_increment(
657                    first_state.clone().and(g),
658                    Rc::clone(&adder),
659                    builder,
660                ));
661                // We also have to add fsm != 0 to incr_guard since
662                // we have just added assignments to handle this situation
663                // separately
664                incr_guard = incr_guard.and(first_state.not())
665            }
666        };
667
668        // We shouldn't increment when we are in the final state
669        // (we should be resetting instead).
670        // So we need to `& !in_final_state` to the guard.
671        let final_fsm_state =
672            self.get_fsm_query((self.latency - 1, self.latency), builder);
673        let not_final_state = final_fsm_state.not();
674
675        // However, if the final state is an offload state, then there's no need
676        // to make this extra check of not being in the last state.
677        if let Some((_, (_, end_final_child))) = self.children.last() {
678            // If the final state is not an offload state, then
679            // we need to add this check.
680            if *end_final_child != self.latency {
681                incr_guard = incr_guard.and(not_final_state);
682            }
683        } else {
684            // Also, if there is just no offloading, then we need to add this check.
685            incr_guard = incr_guard.and(not_final_state);
686        };
687
688        // Conditionally increment based on `incr_guard`
689        res_vec.extend(parent_fsm.borrow_mut().conditional_increment(
690            incr_guard,
691            Rc::clone(&adder),
692            builder,
693        ));
694
695        res_vec
696    }
697
698    /// Helper to `count_to_n`
699    /// Increment when we ARE in an offload state, but the child being
700    /// offloaded is in its final state.
701    fn increment_if_child_final_state(
702        &mut self,
703        offload_states: &[u64],
704        adder: ir::RRC<ir::Cell>,
705        parent_fsm: ir::RRC<StaticFSM>,
706        builder: &mut ir::Builder,
707    ) -> Vec<ir::Assignment<Nothing>> {
708        let mut res_vec = vec![];
709        for (i, (child, (_, end))) in self
710            .children
711            .iter_mut()
712            // If child only lasts a single cycle, then we can just unconditionally increment.
713            // We handle that case above (since `offload_states` only includes children that
714            // last more than one cycle).
715            .filter(|(_, (beg, end))| beg + 1 != *end)
716            .enumerate()
717        {
718            // We need to increment parent when child is in final state.
719            // For example, if the parent is offloading to `child_x` when it
720            // is in state 5, the guard would look like
721            // fsm.in = fsm == 5 && child_x_fsm_in_final_state ? fsm + 1;
722            // fsm.write_en == 5 && child_x_fsm_in_final_state ? 1'd1;
723
724            // The one exception:
725            // If the offload state is the last state (end == self.latency) then we don't
726            // increment, we need to reset to 0 (which we handle separately).
727            if *end != self.latency {
728                // We know each offload state corresponds to exactly one child.
729                let child_state = offload_states[i];
730                // Checking that we are in child state, e.g., `(fsm == 5)`
731                // in the above example.
732                let in_child_state = parent_fsm
733                    .borrow_mut()
734                    .query_between(builder, (child_state, child_state + 1));
735                // now we need to check `child_fsm_in_final_state`
736                let total_child_latency =
737                    child.get_latency() * child.get_num_repeats();
738                let child_final_state = child.query_between(
739                    (total_child_latency - 1, total_child_latency),
740                    builder,
741                );
742                // Conditionally increment when `fsm==5 & child_final_state`
743                let parent_fsm_incr =
744                    parent_fsm.borrow_mut().conditional_increment(
745                        in_child_state.and(child_final_state),
746                        Rc::clone(&adder),
747                        builder,
748                    );
749                res_vec.extend(parent_fsm_incr);
750            }
751        }
752        res_vec
753    }
754
755    /// `Realize` each static group in the tree into a dynamic group.
756    /// In particular, this involves converting %[i:j] guards into actual
757    /// fsm register queries (which can get complicated with out tree structure:
758    /// it's not just i <= fsm < j anymore).
759    ///
760    /// `reset_early_map`, `fsm_info_map`, and `group_rewrites` are all
761    /// metadata to make it more easier later on when we are rewriting control,
762    ///  adding wrapper groups when necessary, etc.
763    fn realize(
764        &mut self,
765        ignore_timing_guards: bool,
766        static_groups: &Vec<ir::RRC<ir::StaticGroup>>,
767        reset_early_map: &mut HashMap<ir::Id, ir::Id>,
768        fsm_info_map: &mut HashMap<
769            ir::Id,
770            (ir::Id, ir::Guard<Nothing>, ir::Guard<Nothing>),
771        >,
772        group_rewrites: &mut ir::rewriter::PortRewriteMap,
773        builder: &mut ir::Builder,
774    ) {
775        // Get static group we are "realizing".
776        let static_group = Rc::clone(
777            static_groups
778                .iter()
779                .find(|sgroup| sgroup.borrow().name() == self.root.0)
780                .expect("couldn't find static group"),
781        );
782        // Create the dynamic "early reset group" that will replace the static group.
783        let static_group_name = static_group.borrow().name();
784        let mut early_reset_name = static_group_name.to_string();
785        early_reset_name.insert_str(0, "early_reset_");
786        let early_reset_group = builder.add_group(early_reset_name);
787
788        // Realize the static %[i:j] guards to fsm queries.
789        // *This is the most of the difficult thing the function does*.
790        // This is significantly more complicated with a tree structure.
791        let mut assigns = static_group
792            .borrow()
793            .assignments
794            .clone()
795            .into_iter()
796            .map(|assign| {
797                self.make_assign_dyn(
798                    assign,
799                    false,
800                    ignore_timing_guards,
801                    builder,
802                )
803            })
804            .collect_vec();
805
806        // Add assignment `group[done] = ud.out`` to the new group.
807        structure!( builder; let ud = prim undef(1););
808        let early_reset_done_assign = build_assignments!(
809          builder;
810          early_reset_group["done"] = ? ud["out"];
811        );
812        assigns.extend(early_reset_done_assign);
813
814        // Adding the assignments of `self.root` (mainly the `count_to_n`
815        // assignments).
816        assigns.extend(std::mem::take(&mut self.root.1));
817        self.root.1 = assigns.clone();
818
819        early_reset_group.borrow_mut().assignments = assigns;
820        early_reset_group.borrow_mut().attributes =
821            static_group.borrow().attributes.clone();
822
823        // Now we have to update the fields with a bunch of information.
824        // This makes it easier when we have to build wrappers, rewrite ports, etc.
825
826        // Map the static group name -> early reset group name.
827        reset_early_map
828            .insert(static_group_name, early_reset_group.borrow().name());
829        // self.group_rewrite_map helps write static_group[go] to early_reset_group[go]
830        // Technically we could do this w/ early_reset_map but is easier w/
831        // group_rewrite, which is explicitly of type `PortRewriterMap`
832        group_rewrites.insert(
833            ir::Canonical::new(static_group_name, ir::Id::from("go")),
834            early_reset_group.borrow().find("go").unwrap_or_else(|| {
835                unreachable!(
836                    "group {} has no go port",
837                    early_reset_group.borrow().name()
838                )
839            }),
840        );
841
842        let fsm_identifier = match self.fsm_cell.as_ref() {
843            // If the tree does not have an fsm cell, then we can err on the
844            // side of giving it its own unique identifier.
845            None => self.root.0,
846            Some(fsm_rc) => fsm_rc.borrow().get_unique_id(),
847        };
848        let total_latency = self.latency * self.num_repeats;
849        fsm_info_map.insert(
850            early_reset_group.borrow().name(),
851            (
852                fsm_identifier,
853                self.query_between((0, 1), builder),
854                self.query_between((total_latency - 1, total_latency), builder),
855            ),
856        );
857
858        // Recursively realize each child.
859        self.children.iter_mut().for_each(|(child, _)| {
860            child.realize(
861                ignore_timing_guards,
862                static_groups,
863                reset_early_map,
864                fsm_info_map,
865                group_rewrites,
866                builder,
867            )
868        })
869    }
870
871    // Rephrasing an (i,j) query: this breaks up the guard and makes it easier
872    // to figure out what logic we need to instantiate to perform the query.
873    // Restructure an (i,j) query into:
874    // (beg, middle, end) query.
875    // This is best explained by example.
876    // Suppose latency = 5, num repeats = 10.
877    // Suppose we query %[3:32].
878    // beg = Some(0, 3-5). 0 bc we are on the 0th iteration,
879    // and only cycles 3-5 of that iteration.
880    // middle = Some([1,6)). These are the iterations for which the query is true
881    // throughout the entire iteration.
882    // end = Some(6,0-2). 6 bc 6th iteration, 0-2 because only cycles 0-2 of that
883    // iteration.
884    fn restructure_query(
885        &self,
886        query: (u64, u64),
887    ) -> (
888        Option<SingleIterQuery>,
889        Option<ItersQuery>,
890        Option<SingleIterQuery>,
891    ) {
892        // Splitting the query into an fsm query and and iteration query.
893        // (beg_iter_query, end_iter_query) is an inclusive (both sides) query
894        // on the iterations we are active for.
895        // (beg_fsm_query, end_fsm_query) is the fsm query we should be supporting.
896        let (beg_query, end_query) = query;
897        let (beg_iter_query, beg_fsm_query) =
898            (beg_query / self.latency, beg_query % self.latency);
899        let (end_iter_query, mut end_fsm_query) =
900            ((end_query - 1) / self.latency, (end_query) % self.latency);
901        if end_fsm_query == 0 {
902            end_fsm_query = self.latency;
903        }
904
905        // Scenario 1: the query spans only a single iteration.
906        // In this case, we set beg_query to
907        // `Some(<that single iteration>, (beg_fsm_query->end_fsm_query))``
908        // and set middle and end to None.
909        if beg_iter_query == end_iter_query {
910            let repeat_query = beg_iter_query;
911            let fsm_query = (beg_fsm_query, end_fsm_query);
912            let res = Some((repeat_query, fsm_query));
913            (res, None, None)
914        }
915        // Scenario 2: the query spans only 2 iterations.
916        // In this case, we only need a beg_query and an end_query, but no
917        // middle query.
918        else if beg_iter_query + 1 == end_iter_query {
919            let middle_res = None;
920
921            let repeat_query0 = beg_iter_query;
922            // We know the beg_query stretches into the next iteration,
923            // so we can end it at self.latency.
924            let fsm_query0 = (beg_fsm_query, self.latency);
925            let beg_res = Some((repeat_query0, fsm_query0));
926
927            let repeat_query1 = end_iter_query;
928            // We know the end_query stretches backwards into the previous iteration,
929            // so we can start it at 0.
930            let fsm_query1 = (0, end_fsm_query);
931            let end_res = Some((repeat_query1, fsm_query1));
932
933            (beg_res, middle_res, end_res)
934        }
935        // Scenario 3: the query spans 3 or more iterations.
936        // In this case, we need the middle_query for the middle iterations,
937        // and the beg and end queries for (parts of) the
938        // first and last iterations for this query.
939        else {
940            let mut unconditional_repeat_query =
941                (beg_iter_query + 1, end_iter_query);
942
943            let repeat_query0 = beg_iter_query;
944            // We know the beg_query stretches into the next iteration,
945            // so we can end it at self.latency.
946            let fsm_query0 = (beg_fsm_query, self.latency);
947            let mut beg_res = Some((repeat_query0, fsm_query0));
948            // if beg_fsm_query == 0, then beg_query spans the entire iterations,
949            // so we can just add it the unconditional_repeat_query (i.e., the middle_query).
950            if beg_fsm_query == 0 {
951                beg_res = None;
952                unconditional_repeat_query.0 -= 1;
953            }
954
955            let repeat_query1 = end_iter_query;
956            // We know the end_query stretches backwards into the previous iteration,
957            // so we can start it at 0.
958            let fsm_query1 = (0, end_fsm_query);
959            let mut end_res = Some((repeat_query1, fsm_query1));
960            // If end_fsm_query == self.latency, then end_res spans the entire iterations,
961            // so we can just add it the unconditional_repeat_query (i.e., the middle_query).
962            if end_fsm_query == self.latency {
963                end_res = None;
964                unconditional_repeat_query.1 += 1;
965            }
966
967            (beg_res, Some(unconditional_repeat_query), end_res)
968        }
969    }
970
    // Given query (i,j), build the guard that is true during cycles [i,j) of
    // a single iteration of this node.
    // Does NOT check the iteration number (see `check_iteration_and_fsm_state`
    // for that).
    // This is greatly complicated by the offloading to children: while the
    // FSM sits in an offload state, the exact cycle is only known by the
    // child, so the child has to be queried as well.
    // We use a restructuring that organizes the query into (beg, middle, end),
    // similar to (but not the same as) self.restructure_query():
    //   - beg: the schedule interval in which the query starts (partially
    //     covered),
    //   - middle: the schedule intervals that the query covers completely,
    //   - end: the schedule interval in which the query ends (partially
    //     covered).
    // The result is `beg || end || middle`.
    fn get_fsm_query(
        &mut self,
        query: (u64, u64),
        builder: &mut ir::Builder,
    ) -> ir::Guard<Nothing> {
        // If guard is true the entire execution, then return `true`.
        if 0 == query.0 && self.latency == query.1 {
            return ir::Guard::True;
        }

        let fsm_cell_opt = self.fsm_cell.as_ref();
        if fsm_cell_opt.is_none() {
            // If there is no fsm cell even though latency > 1, then we must
            // have offloaded the entire latency. Therefore we just need
            // to query the child.
            assert!(self.offload_entire_latency());
            let (only_child, _) = self.children.iter_mut().next().unwrap();
            return only_child.query_between(query, builder);
        }

        let fsm_cell: Rc<std::cell::RefCell<StaticFSM>> =
            Rc::clone(fsm_cell_opt.expect("just checked if None"));

        let (query_beg, query_end) = query;
        // `beg_interval` and `end_interval` start out as `false`; the
        // `is_false()` asserts below rely on this to check that each of
        // them is set at most once.
        let mut beg_interval = ir::Guard::True.not();
        let mut end_interval = ir::Guard::True.not();
        // Range of fsm values covered by the fully-contained intervals.
        let mut middle_interval = None;
        // Index into `self.children` of the child belonging to the next
        // offload state; advanced once per `Offload` entry of the schedule.
        let mut child_index = 0;
        // Suppose fsm_schedule =    Cycles     FSM State (i.e., `fsm.out`)
        //                           (0..10) ->  Normal[0,10)
        //                           (10..30) -> Offload(10) // Offloading to child
        //                           (30..40) -> Normal[11, 21)
        //                           (40,80) ->  Offload(21)
        //                           (80,100)->  Normal[22, 42)
        // And query = (15,95).
        // Then at the end of the following `for` loop we want:
        // `beg_interval` should be fsm == 10 && <child.query_between(5,20)>
        // `middle_interval` should be (11, 22)
        // `end_interval` should be 22 <= fsm < 37
        for ((beg, end), state_type) in self.fsm_schedule.iter() {
            // Check if the query encompasses the entire interval.
            // If so, we add it to the "middle" interval.
            if query_beg <= *beg && *end <= query_end {
                // Get the interval we have to add, based on `state_type`.
                // An offload state occupies exactly one fsm value.
                let interval_to_add = match state_type {
                    StateType::Normal(fsm_interval) => *fsm_interval,
                    StateType::Offload(offload_state) => {
                        (*offload_state, offload_state + 1)
                    }
                };
                // Add `interval_to_add` to `middle_interval`.
                match middle_interval {
                    None => middle_interval = Some(interval_to_add),
                    Some((cur_start, cur_end)) => {
                        // Fully covered intervals must be contiguous in
                        // terms of fsm values.
                        assert!(cur_end == interval_to_add.0);
                        middle_interval = Some((cur_start, interval_to_add.1));
                    }
                }
            }
            // Otherwise check if the beginning of the query lies within the
            // interval (This should only happen once). Add to `beg_interval`.
            else if *beg <= query_beg && query_beg < *end {
                assert!(beg_interval.is_false());
                // This is the query, but relativized to the start of the current interval.
                let relative_query = (query_beg - beg, query_end - beg);
                match state_type {
                    // If we are not offloading, then we can just produce a normal
                    // query.
                    StateType::Normal((beg_fsm_interval, end_fsm_interval)) => {
                        let translated_query = (
                            beg_fsm_interval + relative_query.0,
                            // This query either stretches into the next interval, or
                            // ends within the interval: we want to capture both of these choices.
                            std::cmp::min(
                                beg_fsm_interval + relative_query.1,
                                *end_fsm_interval,
                            ),
                        );
                        beg_interval = *fsm_cell
                            .borrow_mut()
                            .query_between(builder, translated_query);
                    }
                    // If we are offloading, then we first check that the fsm
                    // is in the offload state, and then we must query the
                    // corresponding child for the exact cycles.
                    StateType::Offload(offload_state) => {
                        let in_offload_state =
                            *fsm_cell.borrow_mut().query_between(
                                builder,
                                (*offload_state, offload_state + 1),
                            );
                        let (child, _) =
                            self.children.get_mut(child_index).unwrap();
                        let child_query = child.query_between(
                            (
                                relative_query.0,
                                // This query either stretches into another interval, or
                                // ends within the interval: we want to capture both of these choices.
                                std::cmp::min(
                                    relative_query.1,
                                    child.get_latency()
                                        * child.get_num_repeats(),
                                ),
                            ),
                            builder,
                        );
                        beg_interval = in_offload_state.and(child_query);
                    }
                };
            }
            // Check if the end of the query lies within the
            // interval (This should only happen once) Add to `end_interval`.
            else if *beg < query_end && query_end <= *end {
                // We only need the end of the relative query.
                // Computing the relative beginning (`query_beg - beg`) could
                // underflow here, since the query starts before this interval.
                let relative_query_end = query_end - beg;
                assert!(end_interval.is_false());
                match state_type {
                    StateType::Normal((beg_fsm_interval, _)) => {
                        end_interval = *fsm_cell.borrow_mut().query_between(
                            builder,
                            // This query must stretch backwards into a previous interval:
                            // otherwise it would have been caught by the
                            // previous `else if` branch.
                            // So beg_fsm_interval is a safe start.
                            (
                                *beg_fsm_interval,
                                beg_fsm_interval + relative_query_end,
                            ),
                        );
                    }
                    StateType::Offload(offload_state) => {
                        let in_offload_state =
                            *fsm_cell.borrow_mut().query_between(
                                builder,
                                (*offload_state, offload_state + 1),
                            );
                        let (child, _) =
                            self.children.get_mut(child_index).unwrap();
                        // We know this must stretch backwards
                        // into a previous interval: otherwise, it
                        // would have been caught by the previous `else if` condition.
                        // Therefore, we can start the child query at 0.
                        let child_query = child
                            .query_between((0, relative_query_end), builder);
                        end_interval = in_offload_state.and(child_query);
                    }
                };
            }
            // Every offload state consumes one child, whether or not the
            // query touched this interval.
            if matches!(state_type, StateType::Offload(_)) {
                child_index += 1;
            }
        }

        // Turn `middle_interval` into an actual `ir::Guard`
        // (`false` if no interval was fully covered).
        let middle_query = match middle_interval {
            None => Box::new(ir::Guard::True.not()),
            Some((i, j)) => self
                .fsm_cell
                .as_mut()
                .unwrap()
                .borrow_mut()
                .query_between(builder, (i, j)),
        };

        beg_interval.or(end_interval.or(*middle_query))
    }
1141
1142    // Produces a guard that checks whether query.0 <= self.iter_count_cell < query.1
1143    fn get_repeat_query(
1144        &mut self,
1145        query: (u64, u64),
1146        builder: &mut ir::Builder,
1147    ) -> Box<ir::Guard<Nothing>> {
1148        // If self.num_repeats == 1, then no need for a complicated query.
1149        match self.num_repeats {
1150            1 => {
1151                assert!(query.0 == 0 && query.1 == 1);
1152                Box::new(ir::Guard::True)
1153            }
1154            _ => self
1155                .iter_count_cell
1156                .as_mut()
1157                .expect("querying repeat implies cell exists")
1158                .borrow_mut()
1159                .query_between(builder, (query.0, query.1)),
1160        }
1161    }
1162
1163    // Produce a guard that checks:
1164    //   - whether iteration == repeat_query AND
1165    //   - whether %[fsm_query.0:fsm_query.1]
1166    fn check_iteration_and_fsm_state(
1167        &mut self,
1168        (repeat_query, fsm_query): (u64, (u64, u64)),
1169        builder: &mut ir::Builder,
1170    ) -> ir::Guard<Nothing> {
1171        let fsm_guard = self.get_fsm_query(fsm_query, builder);
1172
1173        // Checks `self.iter_count_cell`.
1174        let counter_guard =
1175            self.get_repeat_query((repeat_query, repeat_query + 1), builder);
1176        ir::Guard::And(Box::new(fsm_guard), counter_guard)
1177    }
1178
1179    // Converts a %[i:j] query into a query of `self`'s and its childrens
1180    // iteration registers.
1181    fn query_between(
1182        &mut self,
1183        query: (u64, u64),
1184        builder: &mut ir::Builder,
1185    ) -> ir::Guard<Nothing> {
1186        // See `restructure_query` to see what we're doing.
1187        // But basically:
1188        // beg_iter_query = Option(iteration number, cycles during that iteration the query is true).
1189        // middle_iter_query = Option(iterations during which the query is true the entire iteration).
1190        // end_iter_query = Option(iteration number, cycles during that iteration the query is true).
1191        let (beg_iter_query, middle_iter_query, end_iter_query) =
1192            self.restructure_query(query);
1193
1194        // Call `check_iteration_and_fsm_state` for beg and end queries.
1195        let g0 = match beg_iter_query {
1196            None => ir::Guard::True.not(),
1197            Some(q0) => self.check_iteration_and_fsm_state(q0, builder),
1198        };
1199        let g1 = match end_iter_query {
1200            None => ir::Guard::True.not(),
1201            Some(q1) => self.check_iteration_and_fsm_state(q1, builder),
1202        };
1203
1204        // Call `get_repeat_query` for middle_iter_queries.
1205        let rep_query = match middle_iter_query {
1206            None => Box::new(ir::Guard::True.not()),
1207            Some(rq) => self.get_repeat_query(rq, builder),
1208        };
1209        g0.or(g1.or(*rep_query))
1210    }
1211
1212    // Takes in a static guard `guard`, and returns equivalent dynamic guard
1213    // The only thing that actually changes is the Guard::Info case
1214    // We need to turn static_timing to dynamic guards using `fsm`.
1215    // See `make_assign_dyn` for explanations of `global_view` and `ignore_timing`
1216    // variable.
1217    fn make_guard_dyn(
1218        &mut self,
1219        guard: ir::Guard<ir::StaticTiming>,
1220        global_view: bool,
1221        ignore_timing: bool,
1222        builder: &mut ir::Builder,
1223    ) -> Box<ir::Guard<Nothing>> {
1224        match guard {
1225            ir::Guard::Or(l, r) => Box::new(ir::Guard::Or(
1226                self.make_guard_dyn(*l, global_view, ignore_timing, builder),
1227                self.make_guard_dyn(*r, global_view, ignore_timing, builder),
1228            )),
1229            ir::Guard::And(l, r) => Box::new(ir::Guard::And(
1230                self.make_guard_dyn(*l, global_view, ignore_timing, builder),
1231                self.make_guard_dyn(*r, global_view, ignore_timing, builder),
1232            )),
1233            ir::Guard::Not(g) => Box::new(ir::Guard::Not(self.make_guard_dyn(
1234                *g,
1235                global_view,
1236                ignore_timing,
1237                builder,
1238            ))),
1239            ir::Guard::CompOp(op, l, r) => {
1240                Box::new(ir::Guard::CompOp(op, l, r))
1241            }
1242            ir::Guard::Port(p) => Box::new(ir::Guard::Port(p)),
1243            ir::Guard::True => Box::new(ir::Guard::True),
1244            ir::Guard::Info(static_timing) => {
1245                // If `ignore_timing` is true, then just return a true guard.
1246                if ignore_timing {
1247                    assert!(static_timing.get_interval() == (0, 1));
1248                    return Box::new(ir::Guard::True);
1249                }
1250                if global_view {
1251                    // For global_view we call `query_between`
1252                    Box::new(
1253                        self.query_between(
1254                            static_timing.get_interval(),
1255                            builder,
1256                        ),
1257                    )
1258                } else {
1259                    // For local_view we call `get_fsm_query`
1260                    Box::new(
1261                        self.get_fsm_query(
1262                            static_timing.get_interval(),
1263                            builder,
1264                        ),
1265                    )
1266                }
1267            }
1268        }
1269    }
1270
1271    /// Takes in static assignment `assign` and returns a dynamic assignments
1272    /// For example, it could transform the guard %[2:3] -> fsm.out >= 2 & fsm.out <= 3
1273    /// `global_view`: are you just querying for a given iteration, or are
1274    /// you querying for the entire tree's execution?
1275    ///   - if `global_view` is true, then you have to include the iteration
1276    ///     count register in the assignment's guard.
1277    ///   - if `global_view` is false, then you dont' have to include it
1278    ///
1279    /// `ignore_timing`: remove static timing guards instead of transforming them
1280    /// into an FSM query. Note that in order to do this, the timing guard must
1281    /// equal %[0:1], otherwise we will throw an error. This option is here
1282    /// mainly to save resource usage.
1283    pub fn make_assign_dyn(
1284        &mut self,
1285        assign: ir::Assignment<ir::StaticTiming>,
1286        global_view: bool,
1287        ignore_timing: bool,
1288        builder: &mut ir::Builder,
1289    ) -> ir::Assignment<Nothing> {
1290        ir::Assignment {
1291            src: assign.src,
1292            dst: assign.dst,
1293            attributes: assign.attributes,
1294            guard: self.make_guard_dyn(
1295                *assign.guard,
1296                global_view,
1297                ignore_timing,
1298                builder,
1299            ),
1300        }
1301    }
1302
1303    // Helper function: checks
1304    // whether the tree offloads its entire latency, and returns the
1305    // appropriate `bool`.
1306    fn offload_entire_latency(&self) -> bool {
1307        self.children.len() == 1
1308            && self
1309                .children
1310                .iter()
1311                .any(|(_, (beg, end))| *beg == 0 && *end == self.latency)
1312                // This last check is prob unnecessary since it follows from the first two.
1313            && self.num_states == 1
1314    }
1315}
1316
1317/// These methods handle adding conflicts to the tree (to help coloring for
1318/// sharing FSMs)
1319impl SingleNode {
1320    // Get names of groups corresponding to all nodes
1321    pub fn get_all_nodes(&self) -> Vec<ir::Id> {
1322        let mut res = vec![self.root.0];
1323        for (child, _) in &self.children {
1324            res.extend(child.get_all_nodes())
1325        }
1326        res
1327    }
1328
1329    // Adds conflicts between children and any descendents.
1330    // Also add conflicts between any overlapping children. XXX(Caleb): normally
1331    // there shouldn't be overlapping children, but when we are doing the traditional
1332    // method in we don't offload (and therefore don't need this tree structure)
1333    // I have created dummy trees for the sole purpose of drawing conflicts
1334    pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring<ir::Id>) {
1335        let root_name = self.root.0;
1336        for (child, _) in &self.children {
1337            for sgroup in &child.get_all_nodes() {
1338                conflict_graph.insert_conflict(&root_name, sgroup);
1339            }
1340            child.add_conflicts(conflict_graph);
1341        }
1342        // Adding conflicts between overlapping children.
1343        for ((child_a, (beg_a, end_a)), (child_b, (beg_b, end_b))) in
1344            self.children.iter().tuple_combinations()
1345        {
1346            // Checking if children overlap: either b begins within a, it
1347            // ends within a, or it encompasses a's entire interval.
1348            if ((beg_a <= beg_b) & (beg_b < end_a))
1349                | ((beg_a < end_b) & (end_b <= end_a))
1350                | (beg_b <= beg_a && end_a <= end_b)
1351            {
1352                // Adding conflicts between all nodes of the children if
1353                // the children overlap.
1354                for a_node in child_a.get_all_nodes() {
1355                    for b_node in child_b.get_all_nodes() {
1356                        conflict_graph.insert_conflict(&a_node, &b_node);
1357                    }
1358                }
1359            }
1360        }
1361    }
1362
1363    // Gets max value according to some function f.
1364    pub fn get_max_value<F>(&self, name: &ir::Id, f: &F) -> u64
1365    where
1366        F: Fn(&SingleNode) -> u64,
1367    {
1368        let mut cur_max = 0;
1369        if self.root.0 == name {
1370            cur_max = std::cmp::max(cur_max, f(self));
1371        }
1372        for (child, _) in &self.children {
1373            cur_max = std::cmp::max(cur_max, child.get_max_value(name, f));
1374        }
1375        cur_max
1376    }
1377}
1378
/// Represents a group of `Nodes` that execute in parallel.
/// Operations on a `ParNodes` (instantiating FSMs, counting) are carried out
/// by applying the same operation to every thread.
pub struct ParNodes {
    /// Name of the `par_group` that fires off the threads.
    /// `realize` looks up the static group with this name.
    pub group_name: ir::Id,
    /// Latency of the par group.
    /// NOTE(review): presumably the latency of the longest thread -- confirm.
    pub latency: u64,
    /// Number of times the par group is repeated.
    pub num_repeats: u64,
    /// (Thread, interval thread is active).
    /// Interval thread is active should always start at 0.
    pub threads: Vec<(Node, (u64, u64))>,
}
1391
1392impl ParNodes {
1393    /// Instantiates FSMs by recursively instantiating FSM for each thread.
1394    pub fn instantiate_fsms(
1395        &mut self,
1396        builder: &mut ir::Builder,
1397        coloring: &HashMap<ir::Id, ir::Id>,
1398        colors_to_max_values: &HashMap<ir::Id, (u64, u64)>,
1399        colors_to_fsm: &mut HashMap<
1400            ir::Id,
1401            (OptionalStaticFSM, OptionalStaticFSM),
1402        >,
1403        one_hot_cutoff: u64,
1404    ) {
1405        for (thread, _) in &mut self.threads {
1406            thread.instantiate_fsms(
1407                builder,
1408                coloring,
1409                colors_to_max_values,
1410                colors_to_fsm,
1411                one_hot_cutoff,
1412            );
1413        }
1414    }
1415
1416    /// Counts to N by recursively calling `count_to_n` on each thread.
1417    pub fn count_to_n(
1418        &mut self,
1419        builder: &mut ir::Builder,
1420        incr_start_cond: Option<ir::Guard<Nothing>>,
1421    ) {
1422        for (thread, _) in &mut self.threads {
1423            thread.count_to_n(builder, incr_start_cond.clone());
1424        }
1425    }
1426
1427    /// Realizes static groups into dynamic group.
1428    pub fn realize(
1429        &mut self,
1430        ignore_timing_guards: bool,
1431        static_groups: &Vec<ir::RRC<ir::StaticGroup>>,
1432        reset_early_map: &mut HashMap<ir::Id, ir::Id>,
1433        fsm_info_map: &mut HashMap<
1434            ir::Id,
1435            (ir::Id, ir::Guard<Nothing>, ir::Guard<Nothing>),
1436        >,
1437        group_rewrites: &mut ir::rewriter::PortRewriteMap,
1438        builder: &mut ir::Builder,
1439    ) {
1440        // Get static grouo we are "realizing".
1441        let static_group = Rc::clone(
1442            static_groups
1443                .iter()
1444                .find(|sgroup| sgroup.borrow().name() == self.group_name)
1445                .expect("couldn't find static group"),
1446        );
1447        // Create the dynamic "early reset group" that will replace the static group.
1448        let static_group_name = static_group.borrow().name();
1449        let mut early_reset_name = static_group_name.to_string();
1450        early_reset_name.insert_str(0, "early_reset_");
1451        let early_reset_group = builder.add_group(early_reset_name);
1452
1453        // Get the longest node.
1454        let longest_node = self.get_longest_node();
1455
1456        // If one thread lasts 10 cycles, and another lasts 5 cycles, then the par_group
1457        // will look like this:
1458        // static<10> group par_group {
1459        //   thread1[go] = 1'd1;
1460        //   thread2[go] = %[0:5] ? 1'd1;
1461        // }
1462        // Therefore the %[0:5] needs to be realized using the FSMs from thread1 (the
1463        // longest FSM).
1464        let mut assigns = static_group
1465            .borrow()
1466            .assignments
1467            .clone()
1468            .into_iter()
1469            .map(|assign| {
1470                longest_node.make_assign_dyn(
1471                    assign,
1472                    true,
1473                    ignore_timing_guards,
1474                    builder,
1475                )
1476            })
1477            .collect_vec();
1478
1479        // Add assignment `group[done] = ud.out`` to the new group.
1480        structure!( builder; let ud = prim undef(1););
1481        let early_reset_done_assign = build_assignments!(
1482          builder;
1483          early_reset_group["done"] = ? ud["out"];
1484        );
1485        assigns.extend(early_reset_done_assign);
1486
1487        early_reset_group.borrow_mut().assignments = assigns;
1488        early_reset_group.borrow_mut().attributes =
1489            static_group.borrow().attributes.clone();
1490
1491        // Now we have to update the fields with a bunch of information.
1492        // This makes it easier when we have to build wrappers, rewrite ports, etc.
1493
1494        // Map the static group name -> early reset group name.
1495        // This is helpful for rewriting control
1496        reset_early_map
1497            .insert(static_group_name, early_reset_group.borrow().name());
1498        // self.group_rewrite_map helps write static_group[go] to early_reset_group[go]
1499        // Technically we could do this w/ early_reset_map but is easier w/
1500        // group_rewrite, which is explicitly of type `PortRewriterMap`
1501        group_rewrites.insert(
1502            ir::Canonical::new(static_group_name, ir::Id::from("go")),
1503            early_reset_group.borrow().find("go").unwrap_or_else(|| {
1504                unreachable!(
1505                    "group {} has no go port",
1506                    early_reset_group.borrow().name()
1507                )
1508            }),
1509        );
1510
1511        let fsm_identifier = match longest_node.fsm_cell.as_ref() {
1512            // If the tree does not have an fsm cell, then we can err on the
1513            // side of giving it its own unique identifier.
1514            None => longest_node.root.0,
1515            Some(fsm_rc) => fsm_rc.borrow().get_unique_id(),
1516        };
1517
1518        let total_latency = self.latency * self.num_repeats;
1519        fsm_info_map.insert(
1520            early_reset_group.borrow().name(),
1521            (
1522                fsm_identifier,
1523                self.query_between((0, 1), builder),
1524                self.query_between((total_latency - 1, total_latency), builder),
1525            ),
1526        );
1527
1528        // Recursively realize each child.
1529        self.threads.iter_mut().for_each(|(child, _)| {
1530            child.realize(
1531                ignore_timing_guards,
1532                static_groups,
1533                reset_early_map,
1534                fsm_info_map,
1535                group_rewrites,
1536                builder,
1537            )
1538        })
1539    }
1540
1541    /// Recursively searches each thread to get the longest (in terms of
1542    /// cycle counts) SingleNode.
1543    pub fn get_longest_node(&mut self) -> &mut SingleNode {
1544        let max = self.threads.iter_mut().max_by_key(|(child, _)| {
1545            (child.get_latency() * child.get_num_repeats()) as i64
1546        });
1547        if let Some((max_child, _)) = max {
1548            match max_child {
1549                Node::Par(par_nodes) => par_nodes.get_longest_node(),
1550                Node::Single(single_node) => single_node,
1551            }
1552        } else {
1553            unreachable!("self.children is empty/no maximum value found");
1554        }
1555    }
1556
1557    /// Use the longest node to query between.
1558    pub fn query_between(
1559        &mut self,
1560        query: (u64, u64),
1561        builder: &mut ir::Builder,
1562    ) -> ir::Guard<Nothing> {
1563        let longest_node = self.get_longest_node();
1564        longest_node.query_between(query, builder)
1565    }
1566}
1567
1568/// Used to add conflicts for graph coloring for sharing FSMs.
1569/// See the equivalent SingleNode implementation for more details.
1570impl ParNodes {
1571    pub fn get_all_nodes(&self) -> Vec<ir::Id> {
1572        let mut res = vec![];
1573        for (thread, _) in &self.threads {
1574            res.extend(thread.get_all_nodes())
1575        }
1576        res
1577    }
1578
1579    pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring<ir::Id>) {
1580        for ((thread1, _), (thread2, _)) in
1581            self.threads.iter().tuple_combinations()
1582        {
1583            for sgroup1 in thread1.get_all_nodes() {
1584                for sgroup2 in thread2.get_all_nodes() {
1585                    conflict_graph.insert_conflict(&sgroup1, &sgroup2);
1586                }
1587            }
1588            thread1.add_conflicts(conflict_graph);
1589            thread2.add_conflicts(conflict_graph);
1590        }
1591    }
1592
1593    pub fn get_max_value<F>(&self, name: &ir::Id, f: &F) -> u64
1594    where
1595        F: Fn(&SingleNode) -> u64,
1596    {
1597        let mut cur_max = 0;
1598        for (thread, _) in &self.threads {
1599            cur_max = std::cmp::max(cur_max, thread.get_max_value(name, f));
1600        }
1601        cur_max
1602    }
1603}