calyx_opt/analysis/
static_tree.rs

1use super::{FSMEncoding, StaticFSM};
2use calyx_ir::{self as ir};
3use calyx_ir::{Nothing, build_assignments};
4use calyx_ir::{guard, structure};
5use itertools::Itertools;
6use std::collections::{BTreeMap, HashMap};
7use std::ops::Not;
8use std::rc::Rc;
9
10use super::GraphColoring;
11
/// Optional Rc of a RefCell of a StaticFSM object.
type OptionalStaticFSM = Option<ir::RRC<StaticFSM>>;
/// Query (i, (j, k)) that corresponds to:
/// Am I in iteration i, and between cycles j and k within
/// that iteration?
type SingleIterQuery = (u64, (u64, u64));
/// Query (i, j) that corresponds to:
/// Am I between iterations i and j, inclusive?
type ItersQuery = (u64, u64);
21
/// Helpful for translating queries for the FSMTree structure.
/// Because of the tree structure, %[i:j] is no longer always equal to i <= fsm < j.
/// Offload(i) means the FSM is offloading when fsm == i: so if the fsm == i,
/// we need to look at the children to know what cycle we are in exactly.
/// Normal(i,j) means the FSM is outputting (i..j), incrementing each cycle (i.e.,
/// like normal) and not offloading. Note that even though the FSM is outputting
/// i..j each cycle, that does not necessarily mean we are in cycles i..j (due
/// to offloading performed in the past.)
#[derive(Debug)]
pub enum StateType {
    /// FSM steps through states i..j itself, one state per cycle.
    Normal((u64, u64)),
    /// FSM holds state i while a child FSM tracks the cycle count.
    Offload(u64),
}
35
/// Node can either be a SingleNode (i.e., a single node) or ParNodes (i.e., a group of
/// nodes that are executing in parallel).
/// Most methods in `Node` simply call the equivalent methods for each
/// of the two possible variants.
/// Perhaps could be more compactly implemented as a Trait.
pub enum Node {
    /// A single FSM tree node (possibly with offloaded children).
    Single(SingleNode),
    /// A group of tree nodes scheduled to execute in parallel.
    Par(ParNodes),
}
45
46// The following methods are used to actually instantiate the FSMTree structure
47// and compile static groups/control to dynamic groups/control.
48impl Node {
49    /// Instantiate the necessary registers.
50    /// The equivalent methods for the two variants contain more implementation
51    /// details.
52    /// `coloring`, `colors_to_max_values`, and `colors_to_fsm` are necessary
53    /// to know whether we actually need to instantiate a new FSM, or we can
54    /// juse use another node's FSM.
55    pub fn instantiate_fsms(
56        &mut self,
57        builder: &mut ir::Builder,
58        coloring: &HashMap<ir::Id, ir::Id>,
59        colors_to_max_values: &HashMap<ir::Id, (u64, u64)>,
60        colors_to_fsm: &mut HashMap<
61            ir::Id,
62            (OptionalStaticFSM, OptionalStaticFSM),
63        >,
64        one_hot_cutoff: u64,
65    ) {
66        match self {
67            Node::Single(single_node) => single_node.instantiate_fsms(
68                builder,
69                coloring,
70                colors_to_max_values,
71                colors_to_fsm,
72                one_hot_cutoff,
73            ),
74            Node::Par(par_nodes) => par_nodes.instantiate_fsms(
75                builder,
76                coloring,
77                colors_to_max_values,
78                colors_to_fsm,
79                one_hot_cutoff,
80            ),
81        }
82    }
83
84    /// Count to n. Need to call `instantiate_fsms` before calling `count_to_n`.
85    /// The equivalent methods for the two variants contain more implementation
86    /// details.
87    /// `incr_start_cond` can optionally guard the 0->1 transition.
88    pub fn count_to_n(
89        &mut self,
90        builder: &mut ir::Builder,
91        incr_start_cond: Option<ir::Guard<Nothing>>,
92    ) {
93        match self {
94            Node::Single(single_node) => {
95                single_node.count_to_n(builder, incr_start_cond)
96            }
97            Node::Par(par_nodes) => {
98                par_nodes.count_to_n(builder, incr_start_cond)
99            }
100        }
101    }
102
103    /// "Realize" the static groups into dynamic groups.
104    /// The main challenge is converting %[i:j] into fsm guards.
105    /// Need to call `instantiate_fsms` and
106    /// `count_to_n` before calling `realize`.
107    /// The equivalent methods for the two variants contain more implementation
108    /// details.
109    /// `reset_early_map`, `fsm_info_map`, and `group_rewrites` are just metadata
110    /// to make it easier to rewrite control, add wrappers, etc.
111    pub fn realize(
112        &mut self,
113        ignore_timing_guards: bool,
114        static_groups: &Vec<ir::RRC<ir::StaticGroup>>,
115        reset_early_map: &mut HashMap<ir::Id, ir::Id>,
116        fsm_info_map: &mut HashMap<
117            ir::Id,
118            (ir::Id, ir::Guard<Nothing>, ir::Guard<Nothing>),
119        >,
120        group_rewrites: &mut ir::rewriter::PortRewriteMap,
121        builder: &mut ir::Builder,
122    ) {
123        match self {
124            Node::Single(single_node) => single_node.realize(
125                ignore_timing_guards,
126                static_groups,
127                reset_early_map,
128                fsm_info_map,
129                group_rewrites,
130                builder,
131            ),
132            Node::Par(par_nodes) => par_nodes.realize(
133                ignore_timing_guards,
134                static_groups,
135                reset_early_map,
136                fsm_info_map,
137                group_rewrites,
138                builder,
139            ),
140        }
141    }
142
143    /// Get the equivalent fsm guard when the tree is between cycles i and j, i.e.,
144    /// when i <= cycle_count < j.
145    /// The equivalent methods for the two variants contain more implementation
146    /// details.
147    pub fn query_between(
148        &mut self,
149        query: (u64, u64),
150        builder: &mut ir::Builder,
151    ) -> ir::Guard<Nothing> {
152        match self {
153            Node::Single(single_node) => {
154                single_node.query_between(query, builder)
155            }
156            Node::Par(par_nodes) => par_nodes.query_between(query, builder),
157        }
158    }
159}
160
161/// The following methods are used to help build the conflict graph for coloring
162/// to share FSMs
163impl Node {
164    /// Get the names of all nodes (i.e., the names of the groups for each node
165    /// in the tree).
166    pub fn get_all_nodes(&self) -> Vec<ir::Id> {
167        match self {
168            Node::Single(single_node) => single_node.get_all_nodes(),
169            Node::Par(par_nodes) => par_nodes.get_all_nodes(),
170        }
171    }
172
173    /// Adds conflicts between nodes in the tree that execute at the same time.
174    pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring<ir::Id>) {
175        match self {
176            Node::Single(single_node) => {
177                single_node.add_conflicts(conflict_graph)
178            }
179            Node::Par(par_nodes) => par_nodes.add_conflicts(conflict_graph),
180        }
181    }
182
183    /// Get max value of all nodes in the tree, according to some function f.
184    /// `f` takes in a Tree (i.e., a node type) and returns a `u64`.
185    pub fn get_max_value<F>(&self, name: &ir::Id, f: &F) -> u64
186    where
187        F: Fn(&SingleNode) -> u64,
188    {
189        match self {
190            Node::Single(single_node) => single_node.get_max_value(name, f),
191            Node::Par(par_nodes) => par_nodes.get_max_value(name, f),
192        }
193    }
194}
195
196// Used to compile static component interface. This is really annoying to do, since
197// for static components, they only need to be guarded for %0, while for static
198// groups, they need to be guarded for %[0:n]. This creates some annoying `if`
199// statements.
200impl Node {
201    // Helper to `preprocess_static_interface_assigns`
202    // Looks recursively thru guard to transform %[0:n] into %0 | %[1:n].
203    fn preprocess_static_interface_guard(
204        guard: ir::Guard<ir::StaticTiming>,
205        comp_sig: ir::RRC<ir::Cell>,
206    ) -> ir::Guard<ir::StaticTiming> {
207        match guard {
208            ir::Guard::Info(st) => {
209                let (beg, end) = st.get_interval();
210                if beg == 0 {
211                    // Replace %[0:n] -> (%0 & comp.go) | %[1:n]
212                    // Cannot just do comp.go | %[1:n] because we want
213                    // clients to be able to assert `comp.go` even after the first
214                    // cycle w/o affecting correctness.
215                    let first_cycle =
216                        ir::Guard::Info(ir::StaticTiming::new((0, 1)));
217                    let comp_go = guard!(comp_sig["go"]);
218                    let first_and_go = ir::Guard::and(comp_go, first_cycle);
219                    if end == 1 {
220                        return first_and_go;
221                    } else {
222                        let after =
223                            ir::Guard::Info(ir::StaticTiming::new((1, end)));
224                        let cong = ir::Guard::or(first_and_go, after);
225                        return cong;
226                    }
227                }
228                guard
229            }
230            ir::Guard::And(l, r) => {
231                let left = Self::preprocess_static_interface_guard(
232                    *l,
233                    Rc::clone(&comp_sig),
234                );
235                let right =
236                    Self::preprocess_static_interface_guard(*r, comp_sig);
237                ir::Guard::and(left, right)
238            }
239            ir::Guard::Or(l, r) => {
240                let left = Self::preprocess_static_interface_guard(
241                    *l,
242                    Rc::clone(&comp_sig),
243                );
244                let right =
245                    Self::preprocess_static_interface_guard(*r, comp_sig);
246                ir::Guard::or(left, right)
247            }
248            ir::Guard::Not(g) => {
249                let a = Self::preprocess_static_interface_guard(*g, comp_sig);
250                ir::Guard::Not(Box::new(a))
251            }
252            _ => guard,
253        }
254    }
255
256    // Looks recursively thru assignment's guard to %[0:n] into %0 | %[1:n].
257    pub fn preprocess_static_interface_assigns(
258        assign: &mut ir::Assignment<ir::StaticTiming>,
259        comp_sig: ir::RRC<ir::Cell>,
260    ) {
261        assign
262            .guard
263            .update(|g| Self::preprocess_static_interface_guard(g, comp_sig));
264    }
265}
266
267// The following are just standard `getter` methods.
268impl Node {
269    /// Take the assignments of the root of the tree and return them.
270    /// This only works on a single node (i.e., the `Tree`` variant).
271    pub fn take_root_assigns(&mut self) -> Vec<ir::Assignment<Nothing>> {
272        match self {
273            Node::Single(single_node) => {
274                std::mem::take(&mut single_node.root.1)
275            }
276            Node::Par(_) => {
277                unreachable!(
278                    "Cannot take root assignments of Node::Par variant"
279                )
280            }
281        }
282    }
283
284    /// Get the name of the root of the tree and return them.
285    /// This only works on a single node (i.e., the `Tree`` variant).
286    pub fn get_root_name(&mut self) -> ir::Id {
287        match self {
288            Node::Single(single_node) => single_node.root.0,
289            Node::Par(_) => {
290                unreachable!("Cannot take root name of Node::Par variant")
291            }
292        }
293    }
294
295    /// Get the name of the group at the root of the tree (if a `Tree` variant) or
296    /// of the equivalent `par` group (i.e., the name of the group that triggers
297    /// execution of all the trees) if a `Par` variant.
298    pub fn get_group_name(&self) -> ir::Id {
299        match self {
300            Node::Single(single_node) => single_node.root.0,
301            Node::Par(par_nodes) => par_nodes.group_name,
302        }
303    }
304
305    /// Gets latency of the overall tree.
306    pub fn get_latency(&self) -> u64 {
307        match self {
308            Node::Single(single_node) => single_node.latency,
309            Node::Par(par_nodes) => par_nodes.latency,
310        }
311    }
312
313    /// Gets the children of root of the tree (if a `Tree` variant) or
314    /// of the threads (i.e., trees) that are scheduled to execute (if a `Par`
315    /// variant.)
316    pub fn get_children(&mut self) -> &mut Vec<(Node, (u64, u64))> {
317        match self {
318            Node::Single(single_node) => &mut single_node.children,
319            Node::Par(par_nodes) => &mut par_nodes.threads,
320        }
321    }
322
323    /// Get number of repeats.
324    fn get_num_repeats(&self) -> u64 {
325        match self {
326            Node::Single(single_node) => single_node.num_repeats,
327            Node::Par(par_nodes) => par_nodes.num_repeats,
328        }
329    }
330}
331
/// `SingleNode` struct: a single node of the FSM tree.
pub struct SingleNode {
    /// latency of one iteration.
    pub latency: u64,
    /// number of repeats. (So "total" latency = `latency` x `num_repeats`)
    pub num_repeats: u64,
    /// number of states in this node
    pub num_states: u64,
    /// (name of static group, assignments to build a corresponding dynamic group)
    pub root: (ir::Id, Vec<ir::Assignment<Nothing>>),
    ///  maps cycles (i,j) -> fsm state type.
    ///  Here is an example FSM schedule:
    ///                           Cycles     FSM State (i.e., `fsm.out`)
    ///                           (0..10) ->  Normal[0,10)
    ///                           (10..30) -> Offload(10) // Offloading to child
    ///                           (30..40) -> Normal[11, 21)
    ///                           (40,80) ->  Offload(21)
    ///                           (80,100)->  Normal[22, 42)
    pub fsm_schedule: BTreeMap<(u64, u64), StateType>,
    /// vec of (Node Object, cycles for which that child is executing).
    /// Note that you can build `fsm_schedule` from just this information,
    /// but it's convenient to have `fsm_schedule` available.
    pub children: Vec<(Node, (u64, u64))>,
    /// Keep track of where we are within a single iteration.
    /// If `latency` == 1, then we don't need an `fsm_cell`.
    pub fsm_cell: Option<ir::RRC<StaticFSM>>,
    /// Keep track of which iteration we are on. If iteration count == 1, then
    /// we don't need an `iter_count_cell`.
    pub iter_count_cell: Option<ir::RRC<StaticFSM>>,
}
362
363impl SingleNode {
    /// Instantiates the necessary registers.
    /// Because we share FSM registers, it's possible that this register has already
    /// been instantiated.
    /// Therefore we take in a bunch of data structures to keep track of coloring:
    ///   - `coloring` that maps group names -> colors,
    ///   - `colors_to_max_values` which maps colors -> (max latency, max_num_repeats)
    ///     (we need to make sure that when we instantiate a color,
    ///     we give enough bits to support the maximum latency/num_repeats that will be
    ///     used for that color)
    ///   - `colors_to_fsm`
    ///     which maps colors to (fsm_register, iter_count_register): fsm_register counts
    ///     up for a single iteration, iter_count_register counts the number of iterations
    ///     that have passed.
    ///
    /// Note that it is not always necessary to instantiate one or both registers (e.g.,
    /// if num_repeats == 1 then you don't need an iter_count_register).
    ///
    /// `one_hot_cutoff` is the cutoff to choose between binary and one hot encoding.
    /// Any number of states greater than the cutoff will get binary encoding.
    fn instantiate_fsms(
        &mut self,
        builder: &mut ir::Builder,
        coloring: &HashMap<ir::Id, ir::Id>,
        colors_to_max_values: &HashMap<ir::Id, (u64, u64)>,
        colors_to_fsm: &mut HashMap<
            ir::Id,
            (OptionalStaticFSM, OptionalStaticFSM),
        >,
        one_hot_cutoff: u64,
    ) {
        // Get color assigned to this node.
        let color = coloring.get(&self.root.0).expect("couldn't find group");
        // Check if we've already instantiated the registers for this color.
        match colors_to_fsm.get(color) {
            // We need to create the registers for the colors.
            None => {
                // First we get the maximum num_states and num_repeats
                // for this color so we know how many bits each register needs.
                let (num_states, num_repeats) = colors_to_max_values
                    .get(color)
                    .expect("Couldn't find color");
                // Only need a `self.fsm_cell` if num_states > 1.
                if *num_states != 1 {
                    // Choose encoding based on one_hot_cutoff.
                    let encoding = if *num_states > one_hot_cutoff {
                        FSMEncoding::Binary
                    } else {
                        FSMEncoding::OneHot
                    };
                    let fsm_cell = ir::rrc(StaticFSM::from_basic_info(
                        *num_states,
                        encoding,
                        builder,
                    ));
                    self.fsm_cell = Some(fsm_cell);
                }
                // Only need a `self.iter_count_cell` if num_repeats > 1.
                if *num_repeats != 1 {
                    let encoding = if *num_repeats > one_hot_cutoff {
                        FSMEncoding::Binary
                    } else {
                        FSMEncoding::OneHot
                    };
                    let repeat_counter = ir::rrc(StaticFSM::from_basic_info(
                        *num_repeats,
                        encoding,
                        builder,
                    ));
                    self.iter_count_cell = Some(repeat_counter);
                }

                // Insert into `colors_to_fsm` so the next time we call this method
                // we see we've already instantiated the registers.
                colors_to_fsm.insert(
                    *color,
                    (
                        self.fsm_cell.as_ref().map(Rc::clone),
                        self.iter_count_cell.as_ref().map(Rc::clone),
                    ),
                );
            }
            Some((fsm_option, repeat_option)) => {
                // Trivially assign to `self.fsm_cell` and `self.iter_count_cell` since
                // we've already created it.
                self.fsm_cell = fsm_option.as_ref().map(Rc::clone);
                self.iter_count_cell = repeat_option.as_ref().map(Rc::clone);
            }
        }

        // Recursively instantiate fsms for all the children.
        for (child, _) in &mut self.children {
            child.instantiate_fsms(
                builder,
                coloring,
                colors_to_max_values,
                colors_to_fsm,
                one_hot_cutoff,
            );
        }
    }
464
    /// Counts to n.
    /// If `incr_start_cond.is_some()`, then we will add it as an extra
    /// guard guarding the 0->1 transition.
    fn count_to_n(
        &mut self,
        builder: &mut ir::Builder,
        incr_start_cond: Option<ir::Guard<Nothing>>,
    ) {
        // res_vec will contain the assignments that count to n.
        let mut res_vec: Vec<ir::Assignment<Nothing>> = Vec::new();

        // Only need to count up to n if self.num_states > 1.
        // If self.num_states == 1, then either a) latency is 1 cycle or b)
        // we're just offloading the entire time (so the child will count).
        // Either way, there's no need to instantiate a self.fsm_cell to count.
        if self.num_states > 1 {
            // `offload_states` are the fsm_states that last >1 cycles (i.e., states
            // where children are executing, unless the child only lasts one cycle---
            // then we can discount it as an "offload" state).
            let offload_states: Vec<u64> = self
                .fsm_schedule
                .iter()
                .filter_map(|((beg, end), state_type)| match state_type {
                    StateType::Normal(_) => None,
                    StateType::Offload(offload_state) => {
                        // Only need to include the children that last more than one cycle.
                        if beg + 1 == *end {
                            None
                        } else {
                            Some(*offload_state)
                        }
                    }
                })
                .collect();

            // There are two conditions under which we increment the FSM.
            // 1) Increment when we are NOT in an offload state
            // 2) Increment when we ARE in an offload state, but the child being offloaded
            // is in its final state. (intuitively, we need to increment because
            // we know the control is being passed back to parent in the next cycle).
            // (when we are in the final state, we obviously should not increment:
            // we should reset back to 0.)

            let parent_fsm = Rc::clone(
                self.fsm_cell
                    .as_mut()
                    .expect("should have set self.fsm_cell"),
            );

            // Build an adder to increment the parent fsm.
            let (adder_asssigns, adder) =
                parent_fsm.borrow_mut().build_incrementer(builder);
            res_vec.extend(adder_asssigns);

            // Handle situation 1). Increment when we are NOT in an offload state.
            res_vec.extend(self.increment_if_not_offloading(
                incr_start_cond.clone(),
                &offload_states,
                Rc::clone(&adder),
                Rc::clone(&parent_fsm),
                builder,
            ));

            // Handle situation 2): Increment when we ARE in an offload state
            // but the child being offloaded is in its final state.
            res_vec.extend(self.increment_if_child_final_state(
                &offload_states,
                adder,
                Rc::clone(&parent_fsm),
                builder,
            ));

            // Reset the FSM when it is at its final fsm_state.
            let final_fsm_state =
                self.get_fsm_query((self.latency - 1, self.latency), builder);
            res_vec.extend(
                parent_fsm
                    .borrow_mut()
                    .conditional_reset(final_fsm_state, builder),
            );
        }

        // If self.num_states > 1, then it's guaranteed that self.latency > 1.
        // However, even if self.num_states == 1, self.latency might still be
        // greater than 1 if we're just offloading the computation for the entire time.
        // In this case, we still need the children to count to n.
        if self.latency > 1 {
            for (child, (beg, end)) in self.children.iter_mut() {
                // If beg == 0 and end > 1 then we need to "transfer" the incr_start_cond
                // to the child so it guards the child's 0->1 transition.
                let cond = if *beg == 0 && *end > 1 {
                    incr_start_cond.clone()
                } else {
                    None
                };
                // Recursively call `count_to_n`.
                child.count_to_n(builder, cond);
            }
        }

        // Handle repeats (i.e., make sure we actually iterate `self.num_repeats` times).
        if self.num_repeats != 1 {
            // If self.latency == 10, then we should increment the self.iter_count_cell
            // each time fsm == 9, i.e., `final_fsm_state`.
            let final_fsm_state =
                self.get_fsm_query((self.latency - 1, self.latency), builder);

            // `repeat_fsm` stores the number of iterations.
            let repeat_fsm = Rc::clone(
                self.iter_count_cell
                    .as_mut()
                    .expect("should have set self.iter_count_cell"),
            );
            // Build an incrementer to increment `self.iter_count_cell`.
            let (repeat_adder_assigns, repeat_adder) =
                repeat_fsm.borrow_mut().build_incrementer(builder);
            // We shouldn't increment `self.iter_count_cell` if we are in the final iteration:
            // we should reset it instead.
            let final_repeat_state = *repeat_fsm.borrow_mut().query_between(
                builder,
                (self.num_repeats - 1, self.num_repeats),
            );
            let not_final_repeat_state = final_repeat_state.clone().not();
            res_vec.extend(repeat_adder_assigns);
            // Incrementing self.iter_count_cell when appropriate.
            res_vec.extend(repeat_fsm.borrow_mut().conditional_increment(
                final_fsm_state.clone().and(not_final_repeat_state),
                repeat_adder,
                builder,
            ));
            // Resetting self.iter_count_cell when appropriate.
            res_vec.extend(repeat_fsm.borrow_mut().conditional_reset(
                final_fsm_state.clone().and(final_repeat_state),
                builder,
            ));
        }

        // Extend root assigns to include `res_vec` (which counts to n).
        self.root.1.extend(res_vec);
    }
605
    /// Helper to `count_to_n`.
    /// Increment when we are NOT in an offload state.
    /// e.g., if `offload_states` == [2,4,6] then
    /// we should increment when !(fsm == 2 | fsm == 4 | fsm == 6).
    /// There are a couple corner cases we need to think about (in particular,
    /// we should guard the 0->1 transition differently if `incr_start_cond` is
    /// Some(), and we should reset rather than increment when we are in the final
    /// fsm state).
    fn increment_if_not_offloading(
        &mut self,
        incr_start_cond: Option<ir::Guard<Nothing>>,
        offload_states: &[u64],
        adder: ir::RRC<ir::Cell>,
        parent_fsm: ir::RRC<StaticFSM>,
        builder: &mut ir::Builder,
    ) -> Vec<ir::Assignment<Nothing>> {
        let mut res_vec = vec![];
        // Seed with `!True` (i.e., false) so we can `or` in one query per
        // offload state below.
        let mut offload_state_guard: ir::Guard<Nothing> =
            ir::Guard::Not(Box::new(ir::Guard::True));
        for offload_state in offload_states {
            // Creating a guard that checks whether the parent fsm is
            // in an offload state.
            offload_state_guard.update(|g| {
                g.or(*parent_fsm.borrow_mut().query_between(
                    builder,
                    (*offload_state, offload_state + 1),
                ))
            });
        }
        let not_offload_state = offload_state_guard.not();

        let mut incr_guard = not_offload_state;

        // If incr_start_cond.is_some(), then we have to specially take into
        // account this scenario when incrementing the FSM.
        if let Some(g) = incr_start_cond.clone() {
            // If we offload during the cycle 0->1 transition
            // then we don't need a special first transition guard.
            // (we will make sure the child will add this guard when
            // it is counting to n.)
            if let Some(((beg, end), state_type)) =
                self.fsm_schedule.iter().next()
            {
                if !(matches!(state_type, StateType::Offload(_))
                    && *beg == 0
                    && *end > 1)
                {
                    let first_state = self.get_fsm_query((0, 1), builder);
                    // We must handle the 0->1 transition separately.
                    // fsm.in = fsm == 0 & incr_start_cond ? fsm + 1;
                    // fsm.write_en = fsm == 0 & incr_start_cond ? 1'd1;
                    res_vec.extend(
                        parent_fsm.borrow_mut().conditional_increment(
                            first_state.clone().and(g),
                            Rc::clone(&adder),
                            builder,
                        ),
                    );
                    // We also have to add fsm != 0 to incr_guard since
                    // we have just added assignments to handle this situation
                    // separately.
                    incr_guard = incr_guard.and(first_state.not())
                }
            }
        };

        // We shouldn't increment when we are in the final state
        // (we should be resetting instead).
        // So we need to `& !in_final_state` to the guard.
        let final_fsm_state =
            self.get_fsm_query((self.latency - 1, self.latency), builder);
        let not_final_state = final_fsm_state.not();

        // However, if the final state is an offload state, then there's no need
        // to make this extra check of not being in the last state.
        if let Some((_, (_, end_final_child))) = self.children.last() {
            // If the final state is not an offload state, then
            // we need to add this check.
            if *end_final_child != self.latency {
                incr_guard = incr_guard.and(not_final_state);
            }
        } else {
            // Also, if there is just no offloading, then we need to add this check.
            incr_guard = incr_guard.and(not_final_state);
        };

        // Conditionally increment based on `incr_guard`.
        res_vec.extend(parent_fsm.borrow_mut().conditional_increment(
            incr_guard,
            Rc::clone(&adder),
            builder,
        ));

        res_vec
    }
701
    /// Helper to `count_to_n`.
    /// Increment when we ARE in an offload state, but the child being
    /// offloaded is in its final state.
    /// NOTE(review): this relies on `offload_states` appearing in the same
    /// order as the multi-cycle children filtered below, so the i-th
    /// surviving child pairs with `offload_states[i]` — confirm this
    /// invariant if either list's construction changes.
    fn increment_if_child_final_state(
        &mut self,
        offload_states: &[u64],
        adder: ir::RRC<ir::Cell>,
        parent_fsm: ir::RRC<StaticFSM>,
        builder: &mut ir::Builder,
    ) -> Vec<ir::Assignment<Nothing>> {
        let mut res_vec = vec![];
        for (i, (child, (_, end))) in self
            .children
            .iter_mut()
            // If child only lasts a single cycle, then we can just unconditionally increment.
            // We handle that case above (since `offload_states` only includes children that
            // last more than one cycle).
            .filter(|(_, (beg, end))| beg + 1 != *end)
            .enumerate()
        {
            // We need to increment parent when child is in final state.
            // For example, if the parent is offloading to `child_x` when it
            // is in state 5, the guard would look like
            // fsm.in = fsm == 5 && child_x_fsm_in_final_state ? fsm + 1;
            // fsm.write_en == 5 && child_x_fsm_in_final_state ? 1'd1;

            // The one exception:
            // If the offload state is the last state (end == self.latency) then we don't
            // increment, we need to reset to 0 (which we handle separately).
            if *end != self.latency {
                // We know each offload state corresponds to exactly one child.
                let child_state = offload_states[i];
                // Checking that we are in child state, e.g., `(fsm == 5)`
                // in the above example.
                let in_child_state = parent_fsm
                    .borrow_mut()
                    .query_between(builder, (child_state, child_state + 1));
                // now we need to check `child_fsm_in_final_state`
                let total_child_latency =
                    child.get_latency() * child.get_num_repeats();
                let child_final_state = child.query_between(
                    (total_child_latency - 1, total_child_latency),
                    builder,
                );
                // Conditionally increment when `fsm==5 & child_final_state`.
                let parent_fsm_incr =
                    parent_fsm.borrow_mut().conditional_increment(
                        in_child_state.and(child_final_state),
                        Rc::clone(&adder),
                        builder,
                    );
                res_vec.extend(parent_fsm_incr);
            }
        }
        res_vec
    }
758
    /// `Realize` each static group in the tree into a dynamic group.
    /// In particular, this involves converting %[i:j] guards into actual
    /// fsm register queries (which can get complicated with our tree structure:
    /// it's not just i <= fsm < j anymore).
    ///
    /// `reset_early_map`, `fsm_info_map`, and `group_rewrites` are all
    /// metadata to make it easier later on when we are rewriting control,
    /// adding wrapper groups when necessary, etc.
    fn realize(
        &mut self,
        ignore_timing_guards: bool,
        static_groups: &Vec<ir::RRC<ir::StaticGroup>>,
        reset_early_map: &mut HashMap<ir::Id, ir::Id>,
        fsm_info_map: &mut HashMap<
            ir::Id,
            (ir::Id, ir::Guard<Nothing>, ir::Guard<Nothing>),
        >,
        group_rewrites: &mut ir::rewriter::PortRewriteMap,
        builder: &mut ir::Builder,
    ) {
        // Get the static group we are "realizing"; it must be present in
        // `static_groups`.
        let static_group = Rc::clone(
            static_groups
                .iter()
                .find(|sgroup| sgroup.borrow().name() == self.root.0)
                .expect("couldn't find static group"),
        );
        // Create the dynamic "early reset group" that will replace the static group.
        let static_group_name = static_group.borrow().name();
        let mut early_reset_name = static_group_name.to_string();
        early_reset_name.insert_str(0, "early_reset_");
        let early_reset_group = builder.add_group(early_reset_name);

        // Realize the static %[i:j] guards to fsm queries.
        // *This is the most difficult thing this function does*.
        // It is significantly more complicated with a tree structure.
        let mut assigns = static_group
            .borrow()
            .assignments
            .clone()
            .into_iter()
            .map(|assign| {
                // `global_view: false` — guards are relative to a single
                // iteration; the iteration counter is not checked here.
                self.make_assign_dyn(
                    assign,
                    false,
                    ignore_timing_guards,
                    builder,
                )
            })
            .collect_vec();

        // Add assignment `group[done] = ud.out` to the new group, where `ud`
        // is an `undef` primitive.
        structure!( builder; let ud = prim undef(1););
        let early_reset_done_assign = build_assignments!(
          builder;
          early_reset_group["done"] = ? ud["out"];
        );
        assigns.extend(early_reset_done_assign);

        // Adding the assignments of `self.root` (mainly the `count_to_n`
        // assignments).
        assigns.extend(std::mem::take(&mut self.root.1));
        // Keep a copy of the fully realized assignments on the node itself.
        self.root.1 = assigns.clone();

        early_reset_group.borrow_mut().assignments = assigns;
        early_reset_group.borrow_mut().attributes =
            static_group.borrow().attributes.clone();

        // Now we have to update the fields with a bunch of information.
        // This makes it easier when we have to build wrappers, rewrite ports, etc.

        // Map the static group name -> early reset group name.
        reset_early_map
            .insert(static_group_name, early_reset_group.borrow().name());
        // group_rewrites helps rewrite static_group[go] to early_reset_group[go].
        // Technically we could do this w/ reset_early_map but it is easier w/
        // group_rewrites, which is explicitly of type `PortRewriteMap`.
        group_rewrites.insert(
            ir::Canonical::new(static_group_name, ir::Id::from("go")),
            early_reset_group.borrow().find("go").unwrap_or_else(|| {
                unreachable!(
                    "group {} has no go port",
                    early_reset_group.borrow().name()
                )
            }),
        );

        let fsm_identifier = match self.fsm_cell.as_ref() {
            // If the tree does not have an fsm cell, then we can err on the
            // side of giving it its own unique identifier.
            None => self.root.0,
            Some(fsm_rc) => fsm_rc.borrow().get_unique_id(),
        };
        // Record (fsm id, first-cycle guard, last-cycle guard) for the new
        // group; used later when building wrapper groups.
        let total_latency = self.latency * self.num_repeats;
        fsm_info_map.insert(
            early_reset_group.borrow().name(),
            (
                fsm_identifier,
                self.query_between((0, 1), builder),
                self.query_between((total_latency - 1, total_latency), builder),
            ),
        );

        // Recursively realize each child.
        self.children.iter_mut().for_each(|(child, _)| {
            child.realize(
                ignore_timing_guards,
                static_groups,
                reset_early_map,
                fsm_info_map,
                group_rewrites,
                builder,
            )
        })
    }
874
875    // Rephrasing an (i,j) query: this breaks up the guard and makes it easier
876    // to figure out what logic we need to instantiate to perform the query.
877    // Restructure an (i,j) query into:
878    // (beg, middle, end) query.
879    // This is best explained by example.
880    // Suppose latency = 5, num repeats = 10.
881    // Suppose we query %[3:32].
882    // beg = Some(0, 3-5). 0 bc we are on the 0th iteration,
883    // and only cycles 3-5 of that iteration.
884    // middle = Some([1,6)). These are the iterations for which the query is true
885    // throughout the entire iteration.
886    // end = Some(6,0-2). 6 bc 6th iteration, 0-2 because only cycles 0-2 of that
887    // iteration.
888    fn restructure_query(
889        &self,
890        query: (u64, u64),
891    ) -> (
892        Option<SingleIterQuery>,
893        Option<ItersQuery>,
894        Option<SingleIterQuery>,
895    ) {
896        // Splitting the query into an fsm query and and iteration query.
897        // (beg_iter_query, end_iter_query) is an inclusive (both sides) query
898        // on the iterations we are active for.
899        // (beg_fsm_query, end_fsm_query) is the fsm query we should be supporting.
900        let (beg_query, end_query) = query;
901        let (beg_iter_query, beg_fsm_query) =
902            (beg_query / self.latency, beg_query % self.latency);
903        let (end_iter_query, mut end_fsm_query) =
904            ((end_query - 1) / self.latency, (end_query) % self.latency);
905        if end_fsm_query == 0 {
906            end_fsm_query = self.latency;
907        }
908
909        // Scenario 1: the query spans only a single iteration.
910        // In this case, we set beg_query to
911        // `Some(<that single iteration>, (beg_fsm_query->end_fsm_query))``
912        // and set middle and end to None.
913        if beg_iter_query == end_iter_query {
914            let repeat_query = beg_iter_query;
915            let fsm_query = (beg_fsm_query, end_fsm_query);
916            let res = Some((repeat_query, fsm_query));
917            (res, None, None)
918        }
919        // Scenario 2: the query spans only 2 iterations.
920        // In this case, we only need a beg_query and an end_query, but no
921        // middle query.
922        else if beg_iter_query + 1 == end_iter_query {
923            let middle_res = None;
924
925            let repeat_query0 = beg_iter_query;
926            // We know the beg_query stretches into the next iteration,
927            // so we can end it at self.latency.
928            let fsm_query0 = (beg_fsm_query, self.latency);
929            let beg_res = Some((repeat_query0, fsm_query0));
930
931            let repeat_query1 = end_iter_query;
932            // We know the end_query stretches backwards into the previous iteration,
933            // so we can start it at 0.
934            let fsm_query1 = (0, end_fsm_query);
935            let end_res = Some((repeat_query1, fsm_query1));
936
937            (beg_res, middle_res, end_res)
938        }
939        // Scenario 3: the query spans 3 or more iterations.
940        // In this case, we need the middle_query for the middle iterations,
941        // and the beg and end queries for (parts of) the
942        // first and last iterations for this query.
943        else {
944            let mut unconditional_repeat_query =
945                (beg_iter_query + 1, end_iter_query);
946
947            let repeat_query0 = beg_iter_query;
948            // We know the beg_query stretches into the next iteration,
949            // so we can end it at self.latency.
950            let fsm_query0 = (beg_fsm_query, self.latency);
951            let mut beg_res = Some((repeat_query0, fsm_query0));
952            // if beg_fsm_query == 0, then beg_query spans the entire iterations,
953            // so we can just add it the unconditional_repeat_query (i.e., the middle_query).
954            if beg_fsm_query == 0 {
955                beg_res = None;
956                unconditional_repeat_query.0 -= 1;
957            }
958
959            let repeat_query1 = end_iter_query;
960            // We know the end_query stretches backwards into the previous iteration,
961            // so we can start it at 0.
962            let fsm_query1 = (0, end_fsm_query);
963            let mut end_res = Some((repeat_query1, fsm_query1));
964            // If end_fsm_query == self.latency, then end_res spans the entire iterations,
965            // so we can just add it the unconditional_repeat_query (i.e., the middle_query).
966            if end_fsm_query == self.latency {
967                end_res = None;
968                unconditional_repeat_query.1 += 1;
969            }
970
971            (beg_res, Some(unconditional_repeat_query), end_res)
972        }
973    }
974
    // Given query (i,j), get the fsm query for cycles (i,j).
    // Does NOT check the iteration number.
    // This is greatly complicated by the offloading to children.
    // We use a restructuring that organizes the query into (beg, middle, end),
    // similar to (but not the same as) self.restructure_query().
    fn get_fsm_query(
        &mut self,
        query: (u64, u64),
        builder: &mut ir::Builder,
    ) -> ir::Guard<Nothing> {
        // If guard is true the entire execution, then return `true`.
        if 0 == query.0 && self.latency == query.1 {
            return ir::Guard::True;
        }

        let fsm_cell_opt = self.fsm_cell.as_ref();
        if fsm_cell_opt.is_none() {
            // If there is no fsm cell even though latency > 1, then we must
            // have offloaded the entire latency. Therefore we just need
            // to query the (only) child.
            assert!(self.offload_entire_latency());
            let (only_child, _) = self.children.iter_mut().next().unwrap();
            return only_child.query_between(query, builder);
        }

        let fsm_cell: Rc<std::cell::RefCell<StaticFSM>> =
            Rc::clone(fsm_cell_opt.expect("just checked if None"));

        let (query_beg, query_end) = query;
        // `beg_interval`/`end_interval` start as `false` guards and are each
        // overwritten at most once while walking the schedule (asserted below).
        let mut beg_interval = ir::Guard::True.not();
        let mut end_interval = ir::Guard::True.not();
        let mut middle_interval = None;
        // Index of the child corresponding to the next `Offload` entry;
        // bumped at the bottom of the loop each time we pass an `Offload`
        // state in the schedule.
        let mut child_index = 0;
        // Suppose fsm_schedule =    Cycles     FSM State (i.e., `fsm.out`)
        //                           (0..10) ->  Normal[0,10)
        //                           (10..30) -> Offload(10) // Offloading to child
        //                           (30..40) -> Normal[11, 21)
        //                           (40,80) ->  Offload(21)
        //                           (80,100)->  Normal[22, 42)
        // And query = (15,95).
        // Then at the end of the following `for` loop we want:
        // `beg_interval` should be fsm == 10 && <child.query_between(5,20)>
        // `middle_interval` should be (11, 22)
        // `end_interval` should be 22 <= fsm < 37
        for ((beg, end), state_type) in self.fsm_schedule.iter() {
            // Check if the query encompasses the entire interval.
            // If so, we add it to the "middle" interval.
            if query_beg <= *beg && *end <= query_end {
                // Get the interval we have to add, based on `state_type`.
                let interval_to_add = match state_type {
                    StateType::Normal(fsm_interval) => *fsm_interval,
                    StateType::Offload(offload_state) => {
                        (*offload_state, offload_state + 1)
                    }
                };
                // Add `interval_to_add` to `middle_interval`. Since the
                // schedule is walked in order, each added interval must be
                // contiguous with the accumulated one (asserted).
                match middle_interval {
                    None => middle_interval = Some(interval_to_add),
                    Some((cur_start, cur_end)) => {
                        assert!(cur_end == interval_to_add.0);
                        middle_interval = Some((cur_start, interval_to_add.1));
                    }
                }
            }
            // Otherwise check if the beginning of the query lies within the
            // interval (This should only happen once). Add to `beg_interval`.
            else if *beg <= query_beg && query_beg < *end {
                assert!(beg_interval.is_false());
                // This is the query, but relativized to the start of the current interval.
                let relative_query = (query_beg - beg, query_end - beg);
                match state_type {
                    // If we are not offloading, then we can just produce a normal
                    // query.
                    StateType::Normal((beg_fsm_interval, end_fsm_interval)) => {
                        let translated_query = (
                            beg_fsm_interval + relative_query.0,
                            // This query either stretches into the next interval, or
                            // ends within the interval: we want to capture both of these choices.
                            std::cmp::min(
                                beg_fsm_interval + relative_query.1,
                                *end_fsm_interval,
                            ),
                        );
                        beg_interval = *fsm_cell
                            .borrow_mut()
                            .query_between(builder, translated_query);
                    }
                    // If we ARE offloading, then we first check that the fsm
                    // is in the offload state, then we query the
                    // corresponding child.
                    StateType::Offload(offload_state) => {
                        let in_offload_state =
                            *fsm_cell.borrow_mut().query_between(
                                builder,
                                (*offload_state, offload_state + 1),
                            );
                        let (child, _) =
                            self.children.get_mut(child_index).unwrap();
                        let child_query = child.query_between(
                            (
                                relative_query.0,
                                // This query either stretches into another interval, or
                                // ends within the interval: we want to capture both of these choices.
                                std::cmp::min(
                                    relative_query.1,
                                    child.get_latency()
                                        * child.get_num_repeats(),
                                ),
                            ),
                            builder,
                        );
                        beg_interval = in_offload_state.and(child_query);
                    }
                };
            }
            // Check if the end of the query lies within the
            // interval (This should only happen once) Add to `end_interval`.
            else if *beg < query_end && query_end <= *end {
                // We only need the end of the relative query.
                // If we try to get the beginning then we could get overflow error.
                let relative_query_end = query_end - beg;
                assert!(end_interval.is_false());
                match state_type {
                    StateType::Normal((beg_fsm_interval, _)) => {
                        end_interval = *fsm_cell.borrow_mut().query_between(
                            builder,
                            // This query must stretch backwards into a
                            // previous interval (otherwise it would have been
                            // caught by the previous branch), so
                            // `beg_fsm_interval` is a safe start.
                            (
                                *beg_fsm_interval,
                                beg_fsm_interval + relative_query_end,
                            ),
                        );
                    }
                    StateType::Offload(offload_state) => {
                        let in_offload_state =
                            *fsm_cell.borrow_mut().query_between(
                                builder,
                                (*offload_state, offload_state + 1),
                            );
                        let (child, _) =
                            self.children.get_mut(child_index).unwrap();
                        // We know this must stretch backwards
                        // into a previous interval: otherwise, it
                        // would have been caught by the previous elif condition.
                        // therefore, we can start the child query at 0.
                        let child_query = child
                            .query_between((0, relative_query_end), builder);
                        end_interval = in_offload_state.and(child_query);
                    }
                };
            }
            if matches!(state_type, StateType::Offload(_)) {
                child_index += 1;
            }
        }

        // Turn `middle_interval` into an actual `ir::Guard`.
        let middle_query = match middle_interval {
            None => Box::new(ir::Guard::True.not()),
            Some((i, j)) => self
                .fsm_cell
                .as_mut()
                .unwrap()
                .borrow_mut()
                .query_between(builder, (i, j)),
        };

        beg_interval.or(end_interval.or(*middle_query))
    }
1145
1146    // Produces a guard that checks whether query.0 <= self.iter_count_cell < query.1
1147    fn get_repeat_query(
1148        &mut self,
1149        query: (u64, u64),
1150        builder: &mut ir::Builder,
1151    ) -> Box<ir::Guard<Nothing>> {
1152        // If self.num_repeats == 1, then no need for a complicated query.
1153        match self.num_repeats {
1154            1 => {
1155                assert!(query.0 == 0 && query.1 == 1);
1156                Box::new(ir::Guard::True)
1157            }
1158            _ => self
1159                .iter_count_cell
1160                .as_mut()
1161                .expect("querying repeat implies cell exists")
1162                .borrow_mut()
1163                .query_between(builder, (query.0, query.1)),
1164        }
1165    }
1166
1167    // Produce a guard that checks:
1168    //   - whether iteration == repeat_query AND
1169    //   - whether %[fsm_query.0:fsm_query.1]
1170    fn check_iteration_and_fsm_state(
1171        &mut self,
1172        (repeat_query, fsm_query): (u64, (u64, u64)),
1173        builder: &mut ir::Builder,
1174    ) -> ir::Guard<Nothing> {
1175        let fsm_guard = self.get_fsm_query(fsm_query, builder);
1176
1177        // Checks `self.iter_count_cell`.
1178        let counter_guard =
1179            self.get_repeat_query((repeat_query, repeat_query + 1), builder);
1180        ir::Guard::And(Box::new(fsm_guard), counter_guard)
1181    }
1182
1183    // Converts a %[i:j] query into a query of `self`'s and its childrens
1184    // iteration registers.
1185    fn query_between(
1186        &mut self,
1187        query: (u64, u64),
1188        builder: &mut ir::Builder,
1189    ) -> ir::Guard<Nothing> {
1190        // See `restructure_query` to see what we're doing.
1191        // But basically:
1192        // beg_iter_query = Option(iteration number, cycles during that iteration the query is true).
1193        // middle_iter_query = Option(iterations during which the query is true the entire iteration).
1194        // end_iter_query = Option(iteration number, cycles during that iteration the query is true).
1195        let (beg_iter_query, middle_iter_query, end_iter_query) =
1196            self.restructure_query(query);
1197
1198        // Call `check_iteration_and_fsm_state` for beg and end queries.
1199        let g0 = match beg_iter_query {
1200            None => ir::Guard::True.not(),
1201            Some(q0) => self.check_iteration_and_fsm_state(q0, builder),
1202        };
1203        let g1 = match end_iter_query {
1204            None => ir::Guard::True.not(),
1205            Some(q1) => self.check_iteration_and_fsm_state(q1, builder),
1206        };
1207
1208        // Call `get_repeat_query` for middle_iter_queries.
1209        let rep_query = match middle_iter_query {
1210            None => Box::new(ir::Guard::True.not()),
1211            Some(rq) => self.get_repeat_query(rq, builder),
1212        };
1213        g0.or(g1.or(*rep_query))
1214    }
1215
1216    // Takes in a static guard `guard`, and returns equivalent dynamic guard
1217    // The only thing that actually changes is the Guard::Info case
1218    // We need to turn static_timing to dynamic guards using `fsm`.
1219    // See `make_assign_dyn` for explanations of `global_view` and `ignore_timing`
1220    // variable.
1221    fn make_guard_dyn(
1222        &mut self,
1223        guard: ir::Guard<ir::StaticTiming>,
1224        global_view: bool,
1225        ignore_timing: bool,
1226        builder: &mut ir::Builder,
1227    ) -> Box<ir::Guard<Nothing>> {
1228        match guard {
1229            ir::Guard::Or(l, r) => Box::new(ir::Guard::Or(
1230                self.make_guard_dyn(*l, global_view, ignore_timing, builder),
1231                self.make_guard_dyn(*r, global_view, ignore_timing, builder),
1232            )),
1233            ir::Guard::And(l, r) => Box::new(ir::Guard::And(
1234                self.make_guard_dyn(*l, global_view, ignore_timing, builder),
1235                self.make_guard_dyn(*r, global_view, ignore_timing, builder),
1236            )),
1237            ir::Guard::Not(g) => Box::new(ir::Guard::Not(self.make_guard_dyn(
1238                *g,
1239                global_view,
1240                ignore_timing,
1241                builder,
1242            ))),
1243            ir::Guard::CompOp(op, l, r) => {
1244                Box::new(ir::Guard::CompOp(op, l, r))
1245            }
1246            ir::Guard::Port(p) => Box::new(ir::Guard::Port(p)),
1247            ir::Guard::True => Box::new(ir::Guard::True),
1248            ir::Guard::Info(static_timing) => {
1249                // If `ignore_timing` is true, then just return a true guard.
1250                if ignore_timing {
1251                    assert!(static_timing.get_interval() == (0, 1));
1252                    return Box::new(ir::Guard::True);
1253                }
1254                if global_view {
1255                    // For global_view we call `query_between`
1256                    Box::new(
1257                        self.query_between(
1258                            static_timing.get_interval(),
1259                            builder,
1260                        ),
1261                    )
1262                } else {
1263                    // For local_view we call `get_fsm_query`
1264                    Box::new(
1265                        self.get_fsm_query(
1266                            static_timing.get_interval(),
1267                            builder,
1268                        ),
1269                    )
1270                }
1271            }
1272        }
1273    }
1274
1275    /// Takes in static assignment `assign` and returns a dynamic assignments
1276    /// For example, it could transform the guard %[2:3] -> fsm.out >= 2 & fsm.out <= 3
1277    /// `global_view`: are you just querying for a given iteration, or are
1278    /// you querying for the entire tree's execution?
1279    ///   - if `global_view` is true, then you have to include the iteration
1280    ///     count register in the assignment's guard.
1281    ///   - if `global_view` is false, then you dont' have to include it
1282    ///
1283    /// `ignore_timing`: remove static timing guards instead of transforming them
1284    /// into an FSM query. Note that in order to do this, the timing guard must
1285    /// equal %[0:1], otherwise we will throw an error. This option is here
1286    /// mainly to save resource usage.
1287    pub fn make_assign_dyn(
1288        &mut self,
1289        assign: ir::Assignment<ir::StaticTiming>,
1290        global_view: bool,
1291        ignore_timing: bool,
1292        builder: &mut ir::Builder,
1293    ) -> ir::Assignment<Nothing> {
1294        ir::Assignment {
1295            src: assign.src,
1296            dst: assign.dst,
1297            attributes: assign.attributes,
1298            guard: self.make_guard_dyn(
1299                *assign.guard,
1300                global_view,
1301                ignore_timing,
1302                builder,
1303            ),
1304        }
1305    }
1306
1307    // Helper function: checks
1308    // whether the tree offloads its entire latency, and returns the
1309    // appropriate `bool`.
1310    fn offload_entire_latency(&self) -> bool {
1311        self.children.len() == 1
1312            && self
1313                .children
1314                .iter()
1315                .any(|(_, (beg, end))| *beg == 0 && *end == self.latency)
1316                // This last check is prob unnecessary since it follows from the first two.
1317            && self.num_states == 1
1318    }
1319}
1320
1321/// These methods handle adding conflicts to the tree (to help coloring for
1322/// sharing FSMs)
1323impl SingleNode {
1324    // Get names of groups corresponding to all nodes
1325    pub fn get_all_nodes(&self) -> Vec<ir::Id> {
1326        let mut res = vec![self.root.0];
1327        for (child, _) in &self.children {
1328            res.extend(child.get_all_nodes())
1329        }
1330        res
1331    }
1332
1333    // Adds conflicts between children and any descendents.
1334    // Also add conflicts between any overlapping children. XXX(Caleb): normally
1335    // there shouldn't be overlapping children, but when we are doing the traditional
1336    // method in we don't offload (and therefore don't need this tree structure)
1337    // I have created dummy trees for the sole purpose of drawing conflicts
1338    pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring<ir::Id>) {
1339        let root_name = self.root.0;
1340        for (child, _) in &self.children {
1341            for sgroup in &child.get_all_nodes() {
1342                conflict_graph.insert_conflict(&root_name, sgroup);
1343            }
1344            child.add_conflicts(conflict_graph);
1345        }
1346        // Adding conflicts between overlapping children.
1347        for ((child_a, (beg_a, end_a)), (child_b, (beg_b, end_b))) in
1348            self.children.iter().tuple_combinations()
1349        {
1350            // Checking if children overlap: either b begins within a, it
1351            // ends within a, or it encompasses a's entire interval.
1352            if ((beg_a <= beg_b) & (beg_b < end_a))
1353                | ((beg_a < end_b) & (end_b <= end_a))
1354                | (beg_b <= beg_a && end_a <= end_b)
1355            {
1356                // Adding conflicts between all nodes of the children if
1357                // the children overlap.
1358                for a_node in child_a.get_all_nodes() {
1359                    for b_node in child_b.get_all_nodes() {
1360                        conflict_graph.insert_conflict(&a_node, &b_node);
1361                    }
1362                }
1363            }
1364        }
1365    }
1366
1367    // Gets max value according to some function f.
1368    pub fn get_max_value<F>(&self, name: &ir::Id, f: &F) -> u64
1369    where
1370        F: Fn(&SingleNode) -> u64,
1371    {
1372        let mut cur_max = 0;
1373        if self.root.0 == name {
1374            cur_max = std::cmp::max(cur_max, f(self));
1375        }
1376        for (child, _) in &self.children {
1377            cur_max = std::cmp::max(cur_max, child.get_max_value(name, f));
1378        }
1379        cur_max
1380    }
1381}
1382
/// Represents a group of `Nodes` that execute in parallel.
pub struct ParNodes {
    /// Name of the `par_group` that fires off the threads
    pub group_name: ir::Id,
    /// Latency of the `par` block (in cycles).
    pub latency: u64,
    /// Number of times the `par` block repeats.
    pub num_repeats: u64,
    /// (Thread, interval thread is active).
    /// The interval during which a thread is active should always start at 0.
    pub threads: Vec<(Node, (u64, u64))>,
}
1395
1396impl ParNodes {
1397    /// Instantiates FSMs by recursively instantiating FSM for each thread.
1398    pub fn instantiate_fsms(
1399        &mut self,
1400        builder: &mut ir::Builder,
1401        coloring: &HashMap<ir::Id, ir::Id>,
1402        colors_to_max_values: &HashMap<ir::Id, (u64, u64)>,
1403        colors_to_fsm: &mut HashMap<
1404            ir::Id,
1405            (OptionalStaticFSM, OptionalStaticFSM),
1406        >,
1407        one_hot_cutoff: u64,
1408    ) {
1409        for (thread, _) in &mut self.threads {
1410            thread.instantiate_fsms(
1411                builder,
1412                coloring,
1413                colors_to_max_values,
1414                colors_to_fsm,
1415                one_hot_cutoff,
1416            );
1417        }
1418    }
1419
1420    /// Counts to N by recursively calling `count_to_n` on each thread.
1421    pub fn count_to_n(
1422        &mut self,
1423        builder: &mut ir::Builder,
1424        incr_start_cond: Option<ir::Guard<Nothing>>,
1425    ) {
1426        for (thread, _) in &mut self.threads {
1427            thread.count_to_n(builder, incr_start_cond.clone());
1428        }
1429    }
1430
1431    /// Realizes static groups into dynamic group.
1432    pub fn realize(
1433        &mut self,
1434        ignore_timing_guards: bool,
1435        static_groups: &Vec<ir::RRC<ir::StaticGroup>>,
1436        reset_early_map: &mut HashMap<ir::Id, ir::Id>,
1437        fsm_info_map: &mut HashMap<
1438            ir::Id,
1439            (ir::Id, ir::Guard<Nothing>, ir::Guard<Nothing>),
1440        >,
1441        group_rewrites: &mut ir::rewriter::PortRewriteMap,
1442        builder: &mut ir::Builder,
1443    ) {
1444        // Get static grouo we are "realizing".
1445        let static_group = Rc::clone(
1446            static_groups
1447                .iter()
1448                .find(|sgroup| sgroup.borrow().name() == self.group_name)
1449                .expect("couldn't find static group"),
1450        );
1451        // Create the dynamic "early reset group" that will replace the static group.
1452        let static_group_name = static_group.borrow().name();
1453        let mut early_reset_name = static_group_name.to_string();
1454        early_reset_name.insert_str(0, "early_reset_");
1455        let early_reset_group = builder.add_group(early_reset_name);
1456
1457        // Get the longest node.
1458        let longest_node = self.get_longest_node();
1459
1460        // If one thread lasts 10 cycles, and another lasts 5 cycles, then the par_group
1461        // will look like this:
1462        // static<10> group par_group {
1463        //   thread1[go] = 1'd1;
1464        //   thread2[go] = %[0:5] ? 1'd1;
1465        // }
1466        // Therefore the %[0:5] needs to be realized using the FSMs from thread1 (the
1467        // longest FSM).
1468        let mut assigns = static_group
1469            .borrow()
1470            .assignments
1471            .clone()
1472            .into_iter()
1473            .map(|assign| {
1474                longest_node.make_assign_dyn(
1475                    assign,
1476                    true,
1477                    ignore_timing_guards,
1478                    builder,
1479                )
1480            })
1481            .collect_vec();
1482
1483        // Add assignment `group[done] = ud.out`` to the new group.
1484        structure!( builder; let ud = prim undef(1););
1485        let early_reset_done_assign = build_assignments!(
1486          builder;
1487          early_reset_group["done"] = ? ud["out"];
1488        );
1489        assigns.extend(early_reset_done_assign);
1490
1491        early_reset_group.borrow_mut().assignments = assigns;
1492        early_reset_group.borrow_mut().attributes =
1493            static_group.borrow().attributes.clone();
1494
1495        // Now we have to update the fields with a bunch of information.
1496        // This makes it easier when we have to build wrappers, rewrite ports, etc.
1497
1498        // Map the static group name -> early reset group name.
1499        // This is helpful for rewriting control
1500        reset_early_map
1501            .insert(static_group_name, early_reset_group.borrow().name());
1502        // self.group_rewrite_map helps write static_group[go] to early_reset_group[go]
1503        // Technically we could do this w/ early_reset_map but is easier w/
1504        // group_rewrite, which is explicitly of type `PortRewriterMap`
1505        group_rewrites.insert(
1506            ir::Canonical::new(static_group_name, ir::Id::from("go")),
1507            early_reset_group.borrow().find("go").unwrap_or_else(|| {
1508                unreachable!(
1509                    "group {} has no go port",
1510                    early_reset_group.borrow().name()
1511                )
1512            }),
1513        );
1514
1515        let fsm_identifier = match longest_node.fsm_cell.as_ref() {
1516            // If the tree does not have an fsm cell, then we can err on the
1517            // side of giving it its own unique identifier.
1518            None => longest_node.root.0,
1519            Some(fsm_rc) => fsm_rc.borrow().get_unique_id(),
1520        };
1521
1522        let total_latency = self.latency * self.num_repeats;
1523        fsm_info_map.insert(
1524            early_reset_group.borrow().name(),
1525            (
1526                fsm_identifier,
1527                self.query_between((0, 1), builder),
1528                self.query_between((total_latency - 1, total_latency), builder),
1529            ),
1530        );
1531
1532        // Recursively realize each child.
1533        self.threads.iter_mut().for_each(|(child, _)| {
1534            child.realize(
1535                ignore_timing_guards,
1536                static_groups,
1537                reset_early_map,
1538                fsm_info_map,
1539                group_rewrites,
1540                builder,
1541            )
1542        })
1543    }
1544
1545    /// Recursively searches each thread to get the longest (in terms of
1546    /// cycle counts) SingleNode.
1547    pub fn get_longest_node(&mut self) -> &mut SingleNode {
1548        let max = self.threads.iter_mut().max_by_key(|(child, _)| {
1549            (child.get_latency() * child.get_num_repeats()) as i64
1550        });
1551        if let Some((max_child, _)) = max {
1552            match max_child {
1553                Node::Par(par_nodes) => par_nodes.get_longest_node(),
1554                Node::Single(single_node) => single_node,
1555            }
1556        } else {
1557            unreachable!("self.children is empty/no maximum value found");
1558        }
1559    }
1560
1561    /// Use the longest node to query between.
1562    pub fn query_between(
1563        &mut self,
1564        query: (u64, u64),
1565        builder: &mut ir::Builder,
1566    ) -> ir::Guard<Nothing> {
1567        let longest_node = self.get_longest_node();
1568        longest_node.query_between(query, builder)
1569    }
1570}
1571
1572/// Used to add conflicts for graph coloring for sharing FSMs.
1573/// See the equivalent SingleNode implementation for more details.
1574impl ParNodes {
1575    pub fn get_all_nodes(&self) -> Vec<ir::Id> {
1576        let mut res = vec![];
1577        for (thread, _) in &self.threads {
1578            res.extend(thread.get_all_nodes())
1579        }
1580        res
1581    }
1582
1583    pub fn add_conflicts(&self, conflict_graph: &mut GraphColoring<ir::Id>) {
1584        for ((thread1, _), (thread2, _)) in
1585            self.threads.iter().tuple_combinations()
1586        {
1587            for sgroup1 in thread1.get_all_nodes() {
1588                for sgroup2 in thread2.get_all_nodes() {
1589                    conflict_graph.insert_conflict(&sgroup1, &sgroup2);
1590                }
1591            }
1592            thread1.add_conflicts(conflict_graph);
1593            thread2.add_conflicts(conflict_graph);
1594        }
1595    }
1596
1597    pub fn get_max_value<F>(&self, name: &ir::Id, f: &F) -> u64
1598    where
1599        F: Fn(&SingleNode) -> u64,
1600    {
1601        let mut cur_max = 0;
1602        for (thread, _) in &self.threads {
1603            cur_max = std::cmp::max(cur_max, thread.get_max_value(name, f));
1604        }
1605        cur_max
1606    }
1607}