calyx_opt/passes/
top_down_compile_control.rs

1use crate::passes;
2use crate::traversal::{
3    Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor,
4};
5use calyx_frontend::SetAttr;
6use calyx_ir::{
7    self as ir, BoolAttr, Cell, GetAttributes, Id, LibrarySignatures, Printer,
8    RRC, build_assignments, guard, structure,
9};
10use calyx_utils::{CalyxResult, Error, OutputFile, math::bits_needed_for};
11use ir::Nothing;
12use itertools::Itertools;
13use petgraph::graph::DiGraph;
14use serde::Serialize;
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::io::Write;
17use std::rc::Rc;
18
19const NODE_ID: ir::Attribute =
20    ir::Attribute::Internal(ir::InternalAttr::NODE_ID);
21const DUPLICATE_NUM_REG: u64 = 2;
22
23/// Computes the exit edges of a given [ir::Control] program.
24///
25/// ## Example
26/// In the following Calyx program:
27/// ```
28/// while comb_reg.out {
29///   seq {
30///     @NODE_ID(4) incr;
31///     @NODE_ID(5) cond0;
32///   }
33/// }
34/// ```
35/// The exit edge is is `[(5, cond0[done])]` indicating that the state 5 exits when the guard
36/// `cond0[done]` is true.
37///
38/// Multiple exit points are created when conditions are used:
39/// ```
40/// while comb_reg.out {
41///   @NODE_ID(7) incr;
42///   if comb_reg2.out {
43///     @NODE_ID(8) tru;
44///   } else {
45///     @NODE_ID(9) fal;
46///   }
47/// }
48/// ```
49/// The exit set is `[(8, tru[done] & !comb_reg.out), (9, fal & !comb_reg.out)]`.
50fn control_exits(con: &ir::Control, exits: &mut Vec<PredEdge>) {
51    match con {
52        ir::Control::Empty(_) => {}
53        ir::Control::Enable(ir::Enable { group, attributes }) => {
54            let cur_state = attributes.get(NODE_ID).unwrap();
55            exits.push((cur_state, guard!(group["done"])))
56        }
57        ir::Control::FSMEnable(ir::FSMEnable { attributes, fsm }) => {
58            let cur_state = attributes.get(NODE_ID).unwrap();
59            exits.push((cur_state, guard!(fsm["done"])))
60        }
61        ir::Control::Seq(ir::Seq { stmts, .. }) => {
62            if let Some(stmt) = stmts.last() {
63                control_exits(stmt, exits)
64            }
65        }
66        ir::Control::If(ir::If {
67            tbranch, fbranch, ..
68        }) => {
69            control_exits(tbranch, exits);
70            control_exits(fbranch, exits)
71        }
72        ir::Control::While(ir::While { body, port, .. }) => {
73            let mut loop_exits = vec![];
74            control_exits(body, &mut loop_exits);
75            // Loop exits only happen when the loop guard is false
76            exits.extend(
77                loop_exits
78                    .into_iter()
79                    .map(|(s, g)| (s, g & !ir::Guard::from(port.clone()))),
80            );
81        }
82        ir::Control::Repeat(_) => unreachable!(
83            "`repeat` statements should have been compiled away. Run `{}` before this pass.",
84            passes::CompileRepeat::name()
85        ),
86        ir::Control::Invoke(_) => unreachable!(
87            "`invoke` statements should have been compiled away. Run `{}` before this pass.",
88            passes::CompileInvoke::name()
89        ),
90        ir::Control::Par(_) => unreachable!(),
91        ir::Control::Static(_) => unreachable!(
92            " static control should have been compiled away. Run the static compilation passes before this pass"
93        ),
94    }
95}
96
97/// Adds the @NODE_ID attribute to [ir::Enable] and [ir::Par].
98/// Each [ir::Enable] gets a unique label within the context of a child of
99/// a [ir::Par] node.
100/// Furthermore, if an if/while/seq statement is labeled with a `new_fsm` attribute,
101/// then it will get its own unique label. Within that if/while/seq, each enable
102/// will get its own unique label within the context of that if/while/seq (see
103/// example for clarification).
104///
105/// ## Example:
106/// ```
107/// seq { A; B; par { C; D; }; E; @new_fsm seq {F; G; H}}
108/// ```
109/// gets the labels:
110/// ```
111/// seq {
112///   @NODE_ID(1) A; @NODE_ID(2) B;
113///   @NODE_ID(3) par {
114///     @NODE_ID(0) C;
115///     @NODE_ID(0) D;
116///   }
117///   @NODE_ID(4) E;
118///   @NODE_ID(5) seq{
119///     @NODE_ID(0) F;
120///     @NODE_ID(1) G;
121///     @NODE_ID(2) H;
122///   }
123/// }
124/// ```
125///
126/// These identifiers are used by the compilation methods [calculate_states_recur]
127/// and [control_exits].
128fn compute_unique_ids(con: &mut ir::Control, cur_state: u64) -> u64 {
129    match con {
130        ir::Control::Enable(ir::Enable { attributes, .. }) => {
131            attributes.insert(NODE_ID, cur_state);
132            cur_state + 1
133        }
134        ir::Control::FSMEnable(ir::FSMEnable { attributes, .. }) => {
135            attributes.insert(NODE_ID, cur_state);
136            cur_state + 1
137        }
138        ir::Control::Par(ir::Par { stmts, attributes }) => {
139            attributes.insert(NODE_ID, cur_state);
140            stmts.iter_mut().for_each(|stmt| {
141                compute_unique_ids(stmt, 0);
142            });
143            cur_state + 1
144        }
145        ir::Control::Seq(ir::Seq { stmts, attributes }) => {
146            let new_fsm = attributes.has(ir::BoolAttr::NewFSM);
147            // if new_fsm is true, then insert attribute at the seq, and then
148            // start over counting states from 0
149            let mut cur = if new_fsm {
150                attributes.insert(NODE_ID, cur_state);
151                0
152            } else {
153                cur_state
154            };
155            stmts.iter_mut().for_each(|stmt| {
156                cur = compute_unique_ids(stmt, cur);
157            });
158            // If new_fsm is true then we want to return cur_state + 1, since this
159            // seq should really only take up 1 "state" on the "outer" fsm
160            if new_fsm { cur_state + 1 } else { cur }
161        }
162        ir::Control::If(ir::If {
163            tbranch,
164            fbranch,
165            attributes,
166            ..
167        }) => {
168            let new_fsm = attributes.has(ir::BoolAttr::NewFSM);
169            // if new_fsm is true, then we want to add an attribute to this
170            // control statement
171            if new_fsm {
172                attributes.insert(NODE_ID, cur_state);
173            }
174            // If the program starts with a branch then branches can't get
175            // the initial state.
176            // Also, if new_fsm is true, we want to start with state 1 as well:
177            // we can't start at 0 for the reason mentioned above
178            let cur = if new_fsm || cur_state == 0 {
179                1
180            } else {
181                cur_state
182            };
183            let tru_nxt = compute_unique_ids(tbranch, cur);
184            let false_nxt = compute_unique_ids(fbranch, tru_nxt);
185            // If new_fsm is true then we want to return cur_state + 1, since this
186            // if stmt should really only take up 1 "state" on the "outer" fsm
187            if new_fsm { cur_state + 1 } else { false_nxt }
188        }
189        ir::Control::While(ir::While {
190            body, attributes, ..
191        }) => {
192            let new_fsm = attributes.has(ir::BoolAttr::NewFSM);
193            // if new_fsm is true, then we want to add an attribute to this
194            // control statement
195            if new_fsm {
196                attributes.insert(NODE_ID, cur_state);
197            }
198            // If the program starts with a branch then branches can't get
199            // the initial state.
200            // Also, if new_fsm is true, we want to start with state 1 as well:
201            // we can't start at 0 for the reason mentioned above
202            let cur = if new_fsm || cur_state == 0 {
203                1
204            } else {
205                cur_state
206            };
207            let body_nxt = compute_unique_ids(body, cur);
208            // If new_fsm is true then we want to return cur_state + 1, since this
209            // while loop should really only take up 1 "state" on the "outer" fsm
210            if new_fsm { cur_state + 1 } else { body_nxt }
211        }
212        ir::Control::Empty(_) => cur_state,
213        ir::Control::Repeat(_) => unreachable!(
214            "`repeat` statements should have been compiled away. Run `{}` before this pass.",
215            passes::CompileRepeat::name()
216        ),
217        ir::Control::Invoke(_) => unreachable!(
218            "`invoke` statements should have been compiled away. Run `{}` before this pass.",
219            passes::CompileInvoke::name()
220        ),
221        ir::Control::Static(_) => unreachable!(
222            "static control should have been compiled away. Run the static compilation passes before this pass"
223        ),
224    }
225}
226
/// Returns the index of the register in `fsms` that should be queried for
/// `state`.
///
/// `num_states` is the index of the FSM's *last* state (callers pass the
/// representation's `last_state`). The `n = num_states + 1` states
/// `0..=num_states` are therefore split across `r = num_registers` registers
/// in contiguous buckets:
/// `{0, ..., n/r - 1} -> reg. @ index 0`,
/// `{n/r, ..., 2n/r - 1} -> reg. @ index 1`,
/// ...,
/// `{(r-1)n/r, ..., n - 1} -> reg. @ index r - 1`.
/// Concretely the index is `floor(state * r / n)`, computed with unsigned
/// integer division.
///
/// Counting `num_states + 1` states matters because the done condition queries
/// the last state itself (e.g. it checks `fsm == 3` when `num_states == 3`),
/// and that query must still land on a valid register.
///
/// When `distribute` is false every query goes to register 0.
fn register_to_query(
    state: u64,
    num_states: u64,
    num_registers: u64,
    distribute: bool,
) -> usize {
    if !distribute {
        return 0;
    }
    // Widen to u128 so neither `state * num_registers` nor `num_states + 1`
    // can overflow for large inputs.
    let index = u128::from(state) * u128::from(num_registers)
        / (u128::from(num_states) + 1);
    usize::try_from(index).expect("register index does not fit in usize")
}
256
/// How a single FSM state is encoded inside a register.
#[derive(Clone, Copy)]
enum RegisterEncoding {
    /// State `i` is stored as the binary number `i`.
    Binary,
    /// State `i` is stored as a word with only bit `i` set.
    OneHot,
}

/// How many physical registers hold the FSM state.
#[derive(Clone, Copy)]
enum RegisterSpread {
    /// Default option: a single register.
    Single,
    /// Several identical copies of the register, to reduce fanout when the
    /// state is queried. Every copy holds every state.
    Duplicate,
}

/// Describes how an FSM should be implemented in hardware.
#[derive(Clone, Copy)]
struct FSMRepresentation {
    /// Encoding of a state within each register (one-hot, binary, ...).
    encoding: RegisterEncoding,
    /// How many registers hold the dynamic finite state machine.
    spread: RegisterSpread,
    /// Index of the final state; the FSM has `last_state + 1` states in total.
    last_state: u64,
}
281
282/// Represents the dyanmic execution schedule of a control program.
283struct Schedule<'b, 'a: 'b> {
284    /// A mapping from groups to corresponding FSM state ids
285    pub groups_to_states: HashSet<FSMStateInfo>,
286    /// Assigments that should be enabled in a given state.
287    pub enables: HashMap<u64, Vec<ir::Assignment<Nothing>>>,
288    /// FSMs that should be triggered in a given state.
289    pub fsm_enables: HashMap<u64, Vec<ir::Assignment<Nothing>>>,
290    /// Transition from one state to another when the guard is true.
291    pub transitions: Vec<(u64, u64, ir::Guard<Nothing>)>,
292    /// The component builder. The reference has a shorter lifetime than the builder itself
293    /// to allow multiple schedules to use the same builder.
294    pub builder: &'b mut ir::Builder<'a>,
295    /// The topmost control node's position ids. This is to identify the control node that
296    /// the schedule (and the control group generated by the schedule) corresponds to.
297    pub topmost_node_pos: Vec<u32>,
298}
299
300/// Information to serialize for profiling purposes
301#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
302enum ProfilingInfo {
303    Fsm(FSMInfo),
304    Par(ParInfo),
305    SingleEnable(SingleEnableInfo),
306}
307
308/// Information to be serialized for a group that isn't managed by a FSM
309/// This can happen if the group is the only group in a control block or a par arm
310#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
311struct SingleEnableInfo {
312    #[serde(serialize_with = "id_serialize_passthrough")]
313    pub component: Id,
314    #[serde(serialize_with = "id_serialize_passthrough")]
315    pub group: Id,
316}
317
318#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
319struct ParInfo {
320    #[serde(serialize_with = "id_serialize_passthrough")]
321    pub component: Id,
322    #[serde(serialize_with = "id_serialize_passthrough")]
323    pub par_group: Id,
324    pub child_groups: Vec<ParChildInfo>,
325    /// Values in the position set attribute that was associated with the control par node
326    /// that generated this par group.
327    pub pos: Vec<u32>,
328}
329
330#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
331struct ParChildInfo {
332    #[serde(serialize_with = "id_serialize_passthrough")]
333    pub group: Id,
334    #[serde(serialize_with = "id_serialize_passthrough")]
335    pub register: Id,
336}
337
338/// Information to be serialized for a single FSM
339#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
340struct FSMInfo {
341    #[serde(serialize_with = "id_serialize_passthrough")]
342    pub component: Id,
343    #[serde(serialize_with = "id_serialize_passthrough")]
344    pub group: Id,
345    #[serde(serialize_with = "id_serialize_passthrough")]
346    pub fsm: Id,
347    pub states: Vec<FSMStateInfo>,
348    /// Values in the position set attribute that was associated with the control node
349    /// that generated the TDCC group.
350    pub pos: Vec<u32>,
351}
352
353/// Mapping of FSM state ids to corresponding group names
354#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize)]
355struct FSMStateInfo {
356    id: u64,
357    #[serde(serialize_with = "id_serialize_passthrough")]
358    group: Id,
359}
360
361fn id_serialize_passthrough<S>(id: &Id, ser: S) -> Result<S::Ok, S::Error>
362where
363    S: serde::Serializer,
364{
365    id.to_string().serialize(ser)
366}
367
368impl<'b, 'a> From<&'b mut ir::Builder<'a>> for Schedule<'b, 'a> {
369    fn from(builder: &'b mut ir::Builder<'a>) -> Self {
370        Schedule {
371            groups_to_states: HashSet::new(),
372            enables: HashMap::new(),
373            fsm_enables: HashMap::new(),
374            transitions: Vec::new(),
375            builder,
376            topmost_node_pos: Vec::new(),
377        }
378    }
379}
380
381impl Schedule<'_, '_> {
382    /// Validate that all states are reachable in the transition graph.
383    fn validate(&self) {
384        let graph = DiGraph::<(), u32>::from_edges(
385            self.transitions
386                .iter()
387                .map(|(s, e, _)| (*s as u32, *e as u32)),
388        );
389
390        debug_assert!(
391            petgraph::algo::connected_components(&graph) == 1,
392            "State transition graph has unreachable states (graph has more than one connected component)."
393        );
394    }
395
396    /// Return the max state in the transition graph
397    fn last_state(&self) -> u64 {
398        self.transitions
399            .iter()
400            .max_by_key(|(_, s, _)| s)
401            .expect("Schedule::transition is empty!")
402            .1
403    }
404
405    /// Print out the current schedule
406    fn display(&self, group: String) {
407        let out = &mut std::io::stdout();
408        writeln!(out, "======== {group} =========").unwrap();
409        self.enables
410            .iter()
411            .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2))
412            .for_each(|(state, assigns)| {
413                writeln!(out, "{state}:").unwrap();
414                assigns.iter().for_each(|assign| {
415                    Printer::write_assignment(assign, 2, out).unwrap();
416                    writeln!(out).unwrap();
417                })
418            });
419        writeln!(out, "{}:\n  <end>", self.last_state()).unwrap();
420        writeln!(out, "transitions:").unwrap();
421        self.transitions
422            .iter()
423            .sorted_by(|(k1, _, _), (k2, _, _)| k1.cmp(k2))
424            .for_each(|(i, f, g)| {
425                writeln!(out, "  ({}, {}): {}", i, f, Printer::guard_str(g))
426                    .unwrap();
427            });
428    }
429
430    /// First chooses which register to query from (only relevant in the duplication case.)
431    /// Then queries the FSM by building a new slicer and corresponding assignments if
432    /// the query hasn't yet been made. If this query has been made before with one-hot
433    /// encoding, it reuses the old query, but always returns a new guard representing the query.
434    fn query_state(
435        builder: &mut ir::Builder,
436        used_slicers_vec: &mut [HashMap<u64, RRC<Cell>>],
437        fsm_rep: &FSMRepresentation,
438        hardware: (&[RRC<ir::Cell>], &RRC<Cell>),
439        state: &u64,
440        fsm_size: &u64,
441        distribute: bool,
442    ) -> ir::Guard<Nothing> {
443        let (fsms, signal_on) = hardware;
444        let (fsm, used_slicers) = {
445            let reg_to_query = register_to_query(
446                *state,
447                fsm_rep.last_state,
448                fsms.len().try_into().unwrap(),
449                distribute,
450            );
451            (
452                fsms.get(reg_to_query)
453                    .expect("the register at this index does not exist"),
454                used_slicers_vec
455                    .get_mut(reg_to_query)
456                    .expect("the used slicer map at this index does not exist"),
457            )
458        };
459        match fsm_rep.encoding {
460            RegisterEncoding::Binary => {
461                let state_const = builder.add_constant(*state, *fsm_size);
462                let state_guard = guard!(fsm["out"] == state_const["out"]);
463                state_guard
464            }
465            RegisterEncoding::OneHot => {
466                match used_slicers.get(state) {
467                    None => {
468                        // construct slicer for this bit query
469                        structure!(
470                            builder;
471                            let slicer = prim std_bit_slice(*fsm_size, *state, *state, 1);
472                        );
473                        // build wire from fsm to slicer
474                        let fsm_to_slicer = builder.build_assignment(
475                            slicer.borrow().get("in"),
476                            fsm.borrow().get("out"),
477                            ir::Guard::True,
478                        );
479                        // add continuous assignments to slicer
480                        builder
481                            .component
482                            .continuous_assignments
483                            .push(fsm_to_slicer);
484                        // create a guard representing when to allow next-state transition
485                        let state_guard =
486                            guard!(slicer["out"] == signal_on["out"]);
487                        used_slicers.insert(*state, slicer);
488                        state_guard
489                    }
490                    Some(slicer) => {
491                        let state_guard =
492                            guard!(slicer["out"] == signal_on["out"]);
493                        state_guard
494                    }
495                }
496            }
497        }
498    }
499
500    /// Builds the register(s) and constants needed for a given encoding and spread type.
501    fn build_fsm_infrastructure(
502        builder: &mut ir::Builder,
503        fsm_rep: &FSMRepresentation,
504    ) -> (Vec<RRC<Cell>>, RRC<Cell>, u64) {
505        // get fsm bit width and build constant emitting fsm first state
506        let (fsm_size, first_state) = match fsm_rep.encoding {
507            RegisterEncoding::Binary => {
508                let fsm_size = bits_needed_for(fsm_rep.last_state + 1);
509                (fsm_size, builder.add_constant(0, fsm_size))
510            }
511            RegisterEncoding::OneHot => {
512                let fsm_size = fsm_rep.last_state + 1;
513                (fsm_size, builder.add_constant(1, fsm_size))
514            }
515        };
516
517        // for the given number of fsm registers to read from, add a primitive register to the design for each
518        let mut add_fsm_regs = |prim_name: &str, num_regs: u64| {
519            (0..num_regs)
520                .map(|n| {
521                    let fsm_name = if num_regs == 1 {
522                        "fsm".to_string()
523                    } else {
524                        format!("fsm{n}")
525                    };
526                    builder.add_primitive(
527                        fsm_name.as_str(),
528                        prim_name,
529                        &[fsm_size],
530                    )
531                })
532                .collect_vec()
533        };
534
535        let fsms = match (fsm_rep.encoding, fsm_rep.spread) {
536            (RegisterEncoding::Binary, RegisterSpread::Single) => {
537                add_fsm_regs("std_reg", 1)
538            }
539            (RegisterEncoding::OneHot, RegisterSpread::Single) => {
540                add_fsm_regs("init_one_reg", 1)
541            }
542            (RegisterEncoding::Binary, RegisterSpread::Duplicate) => {
543                add_fsm_regs("std_reg", DUPLICATE_NUM_REG)
544            }
545            (RegisterEncoding::OneHot, RegisterSpread::Duplicate) => {
546                add_fsm_regs("init_one_reg", DUPLICATE_NUM_REG)
547            }
548        };
549
550        (fsms, first_state, fsm_size)
551    }
552
553    /// Implement a given [Schedule] and return the name of the [ir::Group] that
554    /// implements it.
555    fn realize_schedule(
556        self,
557        dump_fsm: bool,
558        fsm_groups: &mut HashSet<ProfilingInfo>,
559        fsm_rep: FSMRepresentation,
560    ) -> RRC<ir::Group> {
561        // confirm all states are reachable
562        self.validate();
563
564        // build tdcc group
565        let group = self.builder.add_group("tdcc");
566        if dump_fsm {
567            self.display(format!(
568                "{}:{}",
569                self.builder.component.name,
570                group.borrow().name()
571            ));
572        }
573        // add position attributes to generated tdcc group
574        for pos in &self.topmost_node_pos {
575            group.borrow_mut().attributes.insert_set(SetAttr::Pos, *pos);
576        }
577
578        // build necessary primitives dependent on encoding and spread
579        let signal_on = self.builder.add_constant(1, 1);
580        let (fsms, first_state, fsm_size) =
581            Self::build_fsm_infrastructure(self.builder, &fsm_rep);
582
583        // get first fsm register
584        let fsm1 = fsms.first().expect("first fsm register does not exist");
585
586        // Add last state to JSON info
587        let mut states = self.groups_to_states.iter().cloned().collect_vec();
588        states.push(FSMStateInfo {
589            id: fsm_rep.last_state, // check that this register (fsm.0) is the correct one to use
590            group: Id::new(format!("{}_END", fsm1.borrow().name())),
591        });
592
593        // Keep track of groups to FSM state id information for dumping to json
594        fsm_groups.insert(ProfilingInfo::Fsm(FSMInfo {
595            component: self.builder.component.name,
596            fsm: fsm1.borrow().name(),
597            group: group.borrow().name(),
598            states: states.into_iter().sorted().collect(),
599            pos: self.topmost_node_pos.into_iter().sorted().collect(),
600        }));
601
602        // keep track of used slicers if using one hot encoding. one for each register
603        let mut used_slicers_vec =
604            fsms.iter().map(|_| HashMap::new()).collect_vec();
605
606        // enable assignments
607        // the following enable queries; we can decide which register to query for state-dependent assignments
608        // because we know all registers precisely agree at each cycle
609        group.borrow_mut().assignments.extend(
610            self.enables
611                .into_iter()
612                .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2))
613                .flat_map(|(state, mut assigns)| {
614                    // for every assignment dependent on current fsm state, `&` new guard with existing guard
615                    let state_guard = Self::query_state(
616                        self.builder,
617                        &mut used_slicers_vec,
618                        &fsm_rep,
619                        (&fsms, &signal_on),
620                        &state,
621                        &fsm_size,
622                        true, // by default attempt to distribute across regs if >=2 exist
623                    );
624                    assigns.iter_mut().for_each(|asgn| {
625                        asgn.guard.update(|g| g.and(state_guard.clone()))
626                    });
627                    assigns
628                }),
629        );
630
631        // transition assignments
632        // the following updates are meant to ensure agreement between the two
633        // fsm registers; hence, all registers must be updated if `duplicate` is chosen
634        group.borrow_mut().assignments.extend(
635            self.transitions.into_iter().flat_map(|(s, e, guard)| {
636                // get a transition guard for the first fsm register, and apply it to every fsm register
637                let state_guard = Self::query_state(
638                    self.builder,
639                    &mut used_slicers_vec,
640                    &fsm_rep,
641                    (&fsms, &signal_on),
642                    &s,
643                    &fsm_size,
644                    false, // by default do not distribute transition queries across regs; choose first
645                );
646
647                // add transitions for every fsm register to ensure consistency between each
648                fsms.iter()
649                    .flat_map(|fsm| {
650                        let trans_guard =
651                            state_guard.clone().and(guard.clone());
652                        let end_const = match fsm_rep.encoding {
653                            RegisterEncoding::Binary => {
654                                self.builder.add_constant(e, fsm_size)
655                            }
656                            RegisterEncoding::OneHot => {
657                                self.builder.add_constant(
658                                    u64::pow(
659                                        2,
660                                        e.try_into()
661                                            .expect("failed to convert to u32"),
662                                    ),
663                                    fsm_size,
664                                )
665                            }
666                        };
667                        let ec_borrow = end_const.borrow();
668                        vec![
669                            self.builder.build_assignment(
670                                fsm.borrow().get("in"),
671                                ec_borrow.get("out"),
672                                trans_guard.clone(),
673                            ),
674                            self.builder.build_assignment(
675                                fsm.borrow().get("write_en"),
676                                signal_on.borrow().get("out"),
677                                trans_guard,
678                            ),
679                        ]
680                    })
681                    .collect_vec()
682            }),
683        );
684
685        // done condition for group
686        // arbitrarily look at first fsm register, since all are identical
687        let first_fsm_last_guard = Self::query_state(
688            self.builder,
689            &mut used_slicers_vec,
690            &fsm_rep,
691            (&fsms, &signal_on),
692            &fsm_rep.last_state,
693            &fsm_size,
694            false,
695        );
696
697        let done_assign = self.builder.build_assignment(
698            group.borrow().get("done"),
699            signal_on.borrow().get("out"),
700            first_fsm_last_guard.clone(),
701        );
702
703        group.borrow_mut().assignments.push(done_assign);
704
705        // Cleanup: Add a transition from last state to the first state for each register
706        let reset_fsms = fsms
707            .iter()
708            .flat_map(|fsm| {
709                // by default, query first register
710                let fsm_last_guard = Self::query_state(
711                    self.builder,
712                    &mut used_slicers_vec,
713                    &fsm_rep,
714                    (&fsms, &signal_on),
715                    &fsm_rep.last_state,
716                    &fsm_size,
717                    false,
718                );
719                let reset_fsm = build_assignments!(self.builder;
720                    fsm["in"] = fsm_last_guard ? first_state["out"];
721                    fsm["write_en"] = fsm_last_guard ? signal_on["out"];
722                );
723                reset_fsm.to_vec()
724            })
725            .collect_vec();
726
727        // extend with conditions to set all fsms to initial state
728        self.builder
729            .component
730            .continuous_assignments
731            .extend(reset_fsms);
732
733        group
734    }
735
736    fn realize_fsm(self, dump_fsm: bool) -> RRC<ir::FSM> {
737        // ensure schedule is valid
738        self.validate();
739
740        // compute final state and fsm_size, and register initial fsm
741        let fsm = self.builder.add_fsm("fsm");
742
743        if dump_fsm {
744            self.display(format!(
745                "{}:{}",
746                self.builder.component.name,
747                fsm.borrow().name()
748            ));
749        }
750
751        // map each source state to a list of conditional transitions
752        let mut transitions_map: HashMap<u64, Vec<(ir::Guard<Nothing>, u64)>> =
753            HashMap::new();
754        self.transitions.into_iter().for_each(
755            |(s, e, g)| match transitions_map.get_mut(&(s + 1)) {
756                Some(next_states) => next_states.push((g, e + 1)),
757                None => {
758                    transitions_map.insert(s + 1, vec![(g, e + 1)]);
759                }
760            },
761        );
762
763        // push the cases of the fsm to the fsm instantiation
764        let (mut transitions, mut assignments): (
765            VecDeque<ir::Transition>,
766            VecDeque<Vec<ir::Assignment<Nothing>>>,
767        ) = transitions_map
768            .drain()
769            .sorted_by(|(s1, _), (s2, _)| s1.cmp(s2))
770            .map(|(state, mut cond_dsts)| {
771                let assigns = match self.fsm_enables.get(&(state - 1)) {
772                    None => vec![],
773                    Some(assigns) => assigns.clone(),
774                };
775                // self-loop if all other guards are not met;
776                // should be at the end of the conditional destinations vec!
777                cond_dsts.push((ir::Guard::True, state));
778
779                (ir::Transition::Conditional(cond_dsts), assigns)
780            })
781            .unzip();
782
783        // insert transition condition from 0 to 1
784        let true_guard = ir::Guard::True;
785        assignments.push_front(vec![]);
786        transitions.push_front(ir::Transition::Conditional(vec![
787            (guard!(fsm["start"]), 1),
788            (true_guard.clone(), 0),
789        ]));
790
791        // insert transition from final calc state to `done` state
792        let signal_on = self.builder.add_constant(1, 1);
793        let assign = build_assignments!(self.builder;
794            fsm["done"] = true_guard ? signal_on["out"];
795        );
796        assignments.push_back(assign.to_vec());
797        transitions.push_back(ir::Transition::Unconditional(0));
798
799        fsm.borrow_mut().assignments.extend(assignments);
800        fsm.borrow_mut().transitions.extend(transitions);
801
802        // register group enables dependent on fsm state as assignments in the
803        // relevant state's assignment section
804        self.enables.into_iter().for_each(|(state, state_enables)| {
805            fsm.borrow_mut()
806                .extend_state_assignments(state + 1, state_enables);
807        });
808        fsm
809    }
810}
811
812/// Represents an edge from a predeccesor to the current control node.
813/// The `u64` represents the FSM state of the predeccesor and the guard needs
814/// to be true for the predeccesor to transition to the current state.
815type PredEdge = (u64, ir::Guard<Nothing>);
816
817impl Schedule<'_, '_> {
818    /// Recursively build an dynamic finite state machine represented by a [Schedule].
819    /// Does the following, given an [ir::Control]:
820    ///     1. If needed, add transitions from predeccesors to the current state.
821    ///     2. Enable the groups in the current state
822    ///     3. Calculate [PredEdge] implied by this state
823    ///     4. Return [PredEdge] and the next state.
824    /// Another note: the functions calc_seq_recur, calc_while_recur, and calc_if_recur
825    /// are functions that `calculate_states_recur` uses for when con is a seq, while,
826    /// and if respectively. The reason why they are defined as separate functions is because we
827    /// need to call `calculate_seq_recur` (for example) directly when we are in `finish_seq`
828    /// since `finish_seq` only gives us access to a `& mut seq` type, not a `& Control`
829    /// type.
830    fn calculate_states_recur(
831        // Current schedule.
832        &mut self,
833        con: &ir::Control,
834        // The set of previous states that want to transition into cur_state
835        preds: Vec<PredEdge>,
836        // True if early_transitions are allowed
837        early_transitions: bool,
838        // True if the `@fast` attribute has successfully been applied to the parent of this control
839        has_fast_guarantee: bool,
840    ) -> CalyxResult<Vec<PredEdge>> {
841        match con {
842            ir::Control::FSMEnable(ir::FSMEnable { fsm, attributes }) => {
843                let cur_state = attributes.get(NODE_ID).unwrap_or_else(|| {
844                    panic!(
845                        "Group `{}` does not have state_id information",
846                        fsm.borrow().name()
847                    )
848                });
849                let (cur_state, prev_states) =
850                    if preds.len() == 1 && preds[0].1.is_true() {
851                        (preds[0].0, vec![])
852                    } else {
853                        (cur_state, preds)
854                    };
855                // Add group to mapping for emitting group JSON info
856                self.groups_to_states.insert(FSMStateInfo {
857                    id: cur_state,
858                    group: fsm.borrow().name(),
859                });
860
861                let not_done = !guard!(fsm["done"]);
862                let signal_on = self.builder.add_constant(1, 1);
863
864                // Activate this fsm in the current state
865                let en_go: [ir::Assignment<Nothing>; 1] = build_assignments!(self.builder;
866                    fsm["start"] = not_done ? signal_on["out"];
867                );
868
869                // store enable conditions for this FSM
870                self.fsm_enables.entry(cur_state).or_default().extend(en_go);
871
872                // Enable FSM to be triggered by states besides the most recent
873                if early_transitions || has_fast_guarantee {
874                    for (st, g) in &prev_states {
875                        let early_go = build_assignments!(self.builder;
876                            fsm["start"] = g ? signal_on["out"];
877                        );
878                        self.fsm_enables
879                            .entry(*st)
880                            .or_default()
881                            .extend(early_go);
882                    }
883                }
884
885                let transitions = prev_states
886                    .into_iter()
887                    .map(|(st, guard)| (st, cur_state, guard));
888                self.transitions.extend(transitions);
889
890                let done_cond = guard!(fsm["done"]);
891                Ok(vec![(cur_state, done_cond)])
892            }
893            // See explanation of FSM states generated in [ir::TopDownCompileControl].
894            ir::Control::Enable(ir::Enable { group, attributes }) => {
895                let cur_state = attributes.get(NODE_ID).unwrap_or_else(|| {
896                    panic!(
897                        "Group `{}` does not have node_id information",
898                        group.borrow().name()
899                    )
900                });
901                // If there is exactly one previous transition state with a `true`
902                // guard, then merge this state into previous state.
903                // This happens when the first control statement is an enable not
904                // inside a branch.
905                let (cur_state, prev_states) =
906                    if preds.len() == 1 && preds[0].1.is_true() {
907                        (preds[0].0, vec![])
908                    } else {
909                        (cur_state, preds)
910                    };
911
912                // Add group to mapping for emitting group JSON info
913                self.groups_to_states.insert(FSMStateInfo {
914                    id: cur_state,
915                    group: group.borrow().name(),
916                });
917
918                let not_done = !guard!(group["done"]);
919                let signal_on = self.builder.add_constant(1, 1);
920
921                // Activate this group in the current state
922                let en_go = build_assignments!(self.builder;
923                    group["go"] = not_done ? signal_on["out"];
924                );
925                self.enables.entry(cur_state).or_default().extend(en_go);
926
927                // Activate group in the cycle when previous state signals done.
928                // NOTE: We explicilty do not add `not_done` to the guard.
929                // See explanation in [ir::TopDownCompileControl] to understand
930                // why.
931                if early_transitions || has_fast_guarantee {
932                    for (st, g) in &prev_states {
933                        let early_go = build_assignments!(self.builder;
934                            group["go"] = g ? signal_on["out"];
935                        );
936                        self.enables.entry(*st).or_default().extend(early_go);
937                    }
938                }
939
940                let transitions = prev_states
941                    .into_iter()
942                    .map(|(st, guard)| (st, cur_state, guard));
943                self.transitions.extend(transitions);
944
945                let done_cond = guard!(group["done"]);
946                Ok(vec![(cur_state, done_cond)])
947            }
948            ir::Control::Seq(seq) => {
949                self.calc_seq_recur(seq, preds, early_transitions)
950            }
951            ir::Control::If(if_stmt) => {
952                self.calc_if_recur(if_stmt, preds, early_transitions)
953            }
954            ir::Control::While(while_stmt) => {
955                self.calc_while_recur(while_stmt, preds, early_transitions)
956            }
957            ir::Control::Par(_) => unreachable!(),
958            ir::Control::Repeat(_) => unreachable!(
959                "`repeat` statements should have been compiled away. Run `{}` before this pass.",
960                passes::CompileRepeat::name()
961            ),
962            ir::Control::Invoke(_) => unreachable!(
963                "`invoke` statements should have been compiled away. Run `{}` before this pass.",
964                passes::CompileInvoke::name()
965            ),
966            ir::Control::Empty(_) => unreachable!(
967                "`calculate_states_recur` should not see an `empty` control."
968            ),
969            ir::Control::Static(_) => unreachable!(
970                "static control should have been compiled away. Run the static compilation passes before this pass"
971            ),
972        }
973    }
974
975    /// Builds a finite state machine for `seq` represented by a [Schedule].
976    /// At a high level, it iterates through each stmt in the seq's control, using the
977    /// previous stmt's [PredEdge] as the `preds` for the current stmt, and returns
978    /// the [PredEdge] implied by the last stmt in `seq`'s control.
979    fn calc_seq_recur(
980        &mut self,
981        seq: &ir::Seq,
982        // The set of previous states that want to transition into cur_state
983        preds: Vec<PredEdge>,
984        // True if early_transitions are allowed
985        early_transitions: bool,
986    ) -> CalyxResult<Vec<PredEdge>> {
987        let mut prev = preds;
988        for (i, stmt) in seq.stmts.iter().enumerate() {
989            prev = self.calculate_states_recur(
990                stmt,
991                prev,
992                early_transitions,
993                i > 0 && seq.get_attributes().has(BoolAttr::Fast),
994            )?;
995        }
996        Ok(prev)
997    }
998
999    /// Builds a finite state machine for `if_stmt` represented by a [Schedule].
1000    /// First generates the transitions into the true branch + the transitions that exist
1001    /// inside the true branch. Then generates the transitions into the false branch + the transitions
1002    /// that exist inside the false branch. Then calculates the transitions needed to
1003    /// exit the if statmement (which include edges from both the true and false branches).
1004    fn calc_if_recur(
1005        &mut self,
1006        if_stmt: &ir::If,
1007        // The set of previous states that want to transition into cur_state
1008        preds: Vec<PredEdge>,
1009        // True if early_transitions are allowed
1010        early_transitions: bool,
1011    ) -> CalyxResult<Vec<PredEdge>> {
1012        if if_stmt.cond.is_some() {
1013            return Err(Error::malformed_structure(format!(
1014                "{}: Found group `{}` in with position of if. This should have compiled away.",
1015                TopDownCompileControl::name(),
1016                if_stmt.cond.as_ref().unwrap().borrow().name()
1017            )));
1018        }
1019        let port_guard: ir::Guard<Nothing> = Rc::clone(&if_stmt.port).into();
1020        // Previous states transitioning into true branch need the conditional
1021        // to be true.
1022        let tru_transitions = preds
1023            .clone()
1024            .into_iter()
1025            .map(|(s, g)| (s, g & port_guard.clone()))
1026            .collect();
1027        let tru_prev = self.calculate_states_recur(
1028            &if_stmt.tbranch,
1029            tru_transitions,
1030            early_transitions,
1031            false,
1032        )?;
1033        // Previous states transitioning into false branch need the conditional
1034        // to be false.
1035        let fal_transitions = preds
1036            .into_iter()
1037            .map(|(s, g)| (s, g & !port_guard.clone()))
1038            .collect();
1039
1040        let fal_prev = if let ir::Control::Empty(..) = *if_stmt.fbranch {
1041            // If the false branch is empty, then all the prevs to this node will become prevs
1042            // to the next node.
1043            fal_transitions
1044        } else {
1045            self.calculate_states_recur(
1046                &if_stmt.fbranch,
1047                fal_transitions,
1048                early_transitions,
1049                false,
1050            )?
1051        };
1052
1053        let prevs = tru_prev.into_iter().chain(fal_prev).collect();
1054        Ok(prevs)
1055    }
1056
1057    /// Builds a finite state machine for `while_stmt` represented by a [Schedule].
1058    /// It first generates the backwards edges (i.e., edges from the end of the while
1059    /// body back to the beginning of the while body), then generates the forwards
1060    /// edges in the body, then generates the edges that exit the while loop.
1061    fn calc_while_recur(
1062        &mut self,
1063        while_stmt: &ir::While,
1064        // The set of previous states that want to transition into cur_state
1065        preds: Vec<PredEdge>,
1066        // True if early_transitions are allowed
1067        early_transitions: bool,
1068    ) -> CalyxResult<Vec<PredEdge>> {
1069        if while_stmt.cond.is_some() {
1070            return Err(Error::malformed_structure(format!(
1071                "{}: Found group `{}` in with position of if. This should have compiled away.",
1072                TopDownCompileControl::name(),
1073                while_stmt.cond.as_ref().unwrap().borrow().name()
1074            )));
1075        }
1076
1077        let port_guard: ir::Guard<Nothing> = Rc::clone(&while_stmt.port).into();
1078
1079        // Step 1: Generate the backward edges by computing the exit nodes.
1080        let mut exits = vec![];
1081        control_exits(&while_stmt.body, &mut exits);
1082
1083        // Step 2: Generate the forward edges normally.
1084        // Previous transitions into the body require the condition to be
1085        // true.
1086        let transitions: Vec<PredEdge> = preds
1087            .clone()
1088            .into_iter()
1089            .chain(exits)
1090            .map(|(s, g)| (s, g & port_guard.clone()))
1091            .collect();
1092
1093        let prevs = self.calculate_states_recur(
1094            &while_stmt.body,
1095            transitions,
1096            early_transitions,
1097            false,
1098        )?;
1099
1100        // Step 3: The final out edges from the while come from:
1101        //   - Before the body when the condition is false
1102        //   - Inside the body when the condition is false
1103        let not_port_guard = !port_guard;
1104        let all_prevs = preds
1105            .into_iter()
1106            .chain(prevs)
1107            .map(|(st, guard)| (st, guard & not_port_guard.clone()))
1108            .collect();
1109
1110        Ok(all_prevs)
1111    }
1112
1113    /// Creates a Schedule that represents `seq`, mainly relying on `calc_seq_recur()`.
1114    fn calculate_states_seq(
1115        &mut self,
1116        seq: &ir::Seq,
1117        early_transitions: bool,
1118    ) -> CalyxResult<()> {
1119        let first_state = (0, ir::Guard::True);
1120        // We create an empty first state in case the control program starts with
1121        // a branch (if, while).
1122        // If the program doesn't branch, then the initial state is merged into
1123        // the first group.
1124        let prev =
1125            self.calc_seq_recur(seq, vec![first_state], early_transitions)?;
1126        self.add_nxt_transition(prev);
1127        Ok(())
1128    }
1129
1130    /// Creates a Schedule that represents `if`, mainly relying on `calc_if_recur()`.
1131    fn calculate_states_if(
1132        &mut self,
1133        if_stmt: &ir::If,
1134        early_transitions: bool,
1135    ) -> CalyxResult<()> {
1136        let first_state = (0, ir::Guard::True);
1137        // We create an empty first state in case the control program starts with
1138        // a branch (if, while).
1139        // If the program doesn't branch, then the initial state is merged into
1140        // the first group.
1141        let prev =
1142            self.calc_if_recur(if_stmt, vec![first_state], early_transitions)?;
1143        self.add_nxt_transition(prev);
1144        Ok(())
1145    }
1146
1147    /// Creates a Schedule that represents `while`, mainly relying on `calc_while_recur()`.
1148    fn calculate_states_while(
1149        &mut self,
1150        while_stmt: &ir::While,
1151        early_transitions: bool,
1152    ) -> CalyxResult<()> {
1153        let first_state = (0, ir::Guard::True);
1154        // We create an empty first state in case the control program starts with
1155        // a branch (if, while).
1156        // If the program doesn't branch, then the initial state is merged into
1157        // the first group.
1158        let prev = self.calc_while_recur(
1159            while_stmt,
1160            vec![first_state],
1161            early_transitions,
1162        )?;
1163        self.add_nxt_transition(prev);
1164        Ok(())
1165    }
1166
1167    /// Given predecessors prev, creates a new "next" state and transitions from
1168    /// each state in prev to the next state.
1169    /// In other words, it just adds an "end" state to [Schedule] and the
1170    /// appropriate transitions to that "end" state.
1171    fn add_nxt_transition(&mut self, prev: Vec<PredEdge>) {
1172        let nxt = prev
1173            .iter()
1174            .max_by(|(st1, _), (st2, _)| st1.cmp(st2))
1175            .unwrap()
1176            .0
1177            + 1;
1178        let transitions = prev.into_iter().map(|(st, guard)| (st, nxt, guard));
1179        self.transitions.extend(transitions);
1180    }
1181
1182    /// Note: the functions calculate_states_seq, calculate_states_while, and calculate_states_if
1183    /// are functions that basically do what `calculate_states` would do if `calculate_states` knew (for certain)
1184    /// that its input parameter would be a seq/while/if.
1185    /// The reason why we need to define these as separate functions is because `finish_seq`
1186    /// (for example) we only gives us access to a `& mut seq` type, not a `& Control`
1187    /// type.
1188    fn calculate_states(
1189        &mut self,
1190        con: &ir::Control,
1191        early_transitions: bool,
1192    ) -> CalyxResult<()> {
1193        // Collect position ids from the topmost node in the schedule
1194        // (assumes that either all or no control nodes contain position ids,
1195        //  and that the topmost node is either a seq, while, or if.)
1196        if self.topmost_node_pos.is_empty() {
1197            let attrs_opt = match con {
1198                ir::Control::Seq(_)
1199                | ir::Control::While(_)
1200                | ir::Control::If(_) => Some(con.get_attributes()),
1201                _ => None,
1202            };
1203            if let Some(attrs) = attrs_opt {
1204                if let Some(pos_set) = attrs.get_set(SetAttr::Pos) {
1205                    self.topmost_node_pos.extend(pos_set.iter());
1206                }
1207            }
1208        }
1209        let first_state = (0, ir::Guard::True);
1210        // We create an empty first state in case the control program starts with
1211        // a branch (if, while).
1212        // If the program doesn't branch, then the initial state is merged into
1213        // the first group.
1214        let prev = self.calculate_states_recur(
1215            con,
1216            vec![first_state],
1217            early_transitions,
1218            false,
1219        )?;
1220        self.add_nxt_transition(prev);
1221        Ok(())
1222    }
1223}
1224
/// **Core lowering pass.**
/// Compiles away the control programs in components into purely structural code using a
/// finite-state machine (FSM).
///
/// Lowering operates in two steps:
/// 1. Compile all [ir::Par] control sub-programs into a single [ir::Enable] of a group that runs
///    all children to completion.
/// 2. Compile the top-level control program into a single [ir::Enable].
///
/// ## Compiling non-`par` programs
/// At a very high level, the pass assigns an FSM state to each [ir::Enable] in the program and
/// generates transitions to the state to activate the groups contained within the [ir::Enable].
///
/// The compilation process calculates all predecessors of the [ir::Enable] while walking over the
/// control program. A predecessor is any enable statement that can directly "jump" to the current
/// [ir::Enable]. The compilation process computes all such predecessors and the guards that need
/// to be true for the predecessor to jump into this enable statement.
///
/// ```text
/// cond0;
/// while lt.out {
///   if gt.out { true } else { false }
/// }
/// next;
/// ```
/// The predecessor sets are shown below. Each guard is additionally conjoined with the
/// predecessor's own `done` signal, which is omitted here for brevity.
/// ```text
/// cond0 -> []
/// true -> [(cond0, lt.out & gt.out); (true, lt.out & gt.out); (false, lt.out & gt.out)]
/// false -> [(cond0, lt.out & !gt.out); (true, lt.out & !gt.out); (false, lt.out & !gt.out)]
/// next -> [(cond0, !lt.out); (true, !lt.out); (false, !lt.out)]
/// ```
///
/// ### Compiling [ir::Enable]
/// The process first takes all edges from predecessors and transitions to the state for this
/// enable and enables the group in this state:
/// ```text
/// let cur_state; // state of this enable
/// for (state, guard) in predecessors:
///   transitions.insert(state, cur_state, guard)
/// enables.insert(cur_state, group)
/// ```
///
/// While this process will generate a functioning FSM, the FSM spends unnecessary cycles on FSM
/// transitions.
///
/// For example:
/// ```text
/// seq { one; two; }
/// ```
/// The FSM generated will look like this (where `f` is the FSM register):
/// ```text
/// f.in = one[done] ? 1;
/// f.in = two[done] ? 2;
/// one[go] = !one[done] & f.out == 0;
/// two[go] = !two[done] & f.out == 1;
/// ```
///
/// The cycle-level timing for this FSM will look like:
/// - cycle 0: (`f.out` == 0), enable one
/// - cycle t: (`f.out` == 0), (`one[done]` == 1), disable one
/// - cycle t+1: (`f.out` == 1), enable two
/// - cycle t+l: (`f.out` == 1), (`two[done]` == 1), disable two
/// - cycle t+l+1: finish
///
/// The transition t -> t+1 represents one where group one is done but group two hasn't started
/// executing.
///
/// To address this specific problem, there is an additional enable added to run all groups within
/// an enable *while the FSM is transitioning*. This early enable is only emitted when the
/// `early-transitions` option is set or the enclosing `seq` carries a successfully applied
/// `@fast` attribute.
/// The final transition will look like this:
/// ```text
/// f.in = one[done] ? 1;
/// f.in = two[done] ? 2;
/// one[go] = !one[done] & f.out == 0;
/// two[go] = (!two[done] & f.out == 1) || (one[done] & f.out == 0);
/// ```
///
/// Note that `!two[done]` isn't present in the second disjunct because all groups are guaranteed
/// to run for at least one cycle and the second disjunct will only be true for one cycle before
/// the first disjunct becomes true.
///
/// ## Compiling `par` programs
/// We have to generate a new FSM-based controller for each child of a `par` node so that each
/// child can independently make progress.
/// If we tie the children to one top-level FSM, their transitions would become interdependent and
/// reduce available concurrency.
///
/// ## Compilation guarantee
/// At the end of this pass, the control program will have no more than one
/// group enable in it.
pub struct TopDownCompileControl {
    /// Print out the state machine implementing each schedule (`dump-fsm` option).
    dump_fsm: bool,
    /// Output a JSON FSM representation to this file, if specified (`dump-fsm-json` option).
    dump_fsm_json: Option<OutputFile>,
    /// Enable early transitions for group enables (`early-transitions` option).
    early_transitions: bool,
    /// Profiling: Bookkeeping for TDCC-generated register/group information (FSMs, par groups)
    profiling_info: HashSet<ProfilingInfo>,
    /// When set, schedules are lowered to [ir::FSM]s intended for inference and
    /// optimization by the synthesis tool, instead of inlined register-based
    /// groups (`infer-fsms` option).
    infer_fsms: bool,
    /// One-hot encoding is chosen when the schedule's last state index is at or
    /// below this value; binary otherwise (`one-hot-cutoff` option).
    one_hot_cutoff: u64,
    /// The FSM state register is duplicated when the number of states
    /// (last state index + 1) exceeds this value (`duplicate-cutoff` option).
    duplicate_cutoff: u64,
}
1332
1333impl TopDownCompileControl {
1334    /// Given a dynamic schedule and attributes, selects a representation for
1335    /// the finite state machine in hardware.
1336    fn get_representation(
1337        &self,
1338        sch: &Schedule,
1339        attrs: &ir::Attributes,
1340    ) -> FSMRepresentation {
1341        let last_state = sch.last_state();
1342        FSMRepresentation {
1343            encoding: {
1344                match (
1345                    attrs.has(BoolAttr::OneHot),
1346                    last_state <= self.one_hot_cutoff,
1347                ) {
1348                    (true, _) | (false, true) => RegisterEncoding::OneHot,
1349                    (false, false) => RegisterEncoding::Binary,
1350                }
1351            },
1352            spread: {
1353                match (last_state + 1) <= self.duplicate_cutoff {
1354                    true => RegisterSpread::Single,
1355                    false => RegisterSpread::Duplicate,
1356                }
1357            },
1358            last_state,
1359        }
1360    }
1361}
1362
1363impl ConstructVisitor for TopDownCompileControl {
1364    fn from(ctx: &ir::Context) -> CalyxResult<Self>
1365    where
1366        Self: Sized + Named,
1367    {
1368        let opts = Self::get_opts(ctx);
1369
1370        Ok(TopDownCompileControl {
1371            dump_fsm: opts[&"dump-fsm"].bool(),
1372            dump_fsm_json: opts[&"dump-fsm-json"].not_null_outstream(),
1373            early_transitions: opts[&"early-transitions"].bool(),
1374            profiling_info: HashSet::new(),
1375            infer_fsms: opts[&"infer-fsms"].bool(),
1376            one_hot_cutoff: opts[&"one-hot-cutoff"]
1377                .pos_num()
1378                .expect("requires non-negative OHE cutoff parameter"),
1379            duplicate_cutoff: opts[&"duplicate-cutoff"]
1380                .pos_num()
1381                .expect("requires non-negative duplicate cutoff parameter"),
1382        })
1383    }
1384
1385    fn clear_data(&mut self) {
1386        /* All data can be transferred between components */
1387    }
1388}
1389
1390impl Named for TopDownCompileControl {
1391    fn name() -> &'static str {
1392        "tdcc"
1393    }
1394
1395    fn description() -> &'static str {
1396        "Top-down compilation for removing control constructs"
1397    }
1398
1399    fn opts() -> Vec<PassOpt> {
1400        vec![
1401            PassOpt::new(
1402                "dump-fsm",
1403                "Print out the state machine implementing the schedule",
1404                ParseVal::Bool(false),
1405                PassOpt::parse_bool,
1406            ),
1407            PassOpt::new(
1408                "dump-fsm-json",
1409                "Write the state machine implementing the schedule to a JSON file",
1410                ParseVal::OutStream(OutputFile::Null),
1411                PassOpt::parse_outstream,
1412            ),
1413            PassOpt::new(
1414                "early-transitions",
1415                "Experimental: Enable early transitions for group enables",
1416                ParseVal::Bool(false),
1417                PassOpt::parse_bool,
1418            ),
1419            PassOpt::new(
1420                "infer-fsms",
1421                "Emits FSMs that are inferred and optimized by Vivado toolchain",
1422                ParseVal::Bool(false),
1423                PassOpt::parse_bool,
1424            ),
1425            PassOpt::new(
1426                "one-hot-cutoff",
1427                "Threshold at and below which a one-hot encoding is used for dynamic group scheduling",
1428                ParseVal::Num(0),
1429                PassOpt::parse_num,
1430            ),
1431            PassOpt::new(
1432                "duplicate-cutoff",
1433                "Threshold above which the dynamic fsm register is replicated into a second, identical register",
1434                ParseVal::Num(i64::MAX),
1435                PassOpt::parse_num,
1436            ),
1437        ]
1438    }
1439}
1440
1441/// Helper function to emit profiling information when the control consists of a single group.
1442fn extract_single_enable(
1443    con: &mut ir::Control,
1444    component: Id,
1445) -> Option<SingleEnableInfo> {
1446    if let ir::Control::Enable(enable) = con {
1447        return Some(SingleEnableInfo {
1448            component,
1449            group: enable.group.borrow().name(),
1450        });
1451    } else {
1452        None
1453    }
1454}
1455
1456impl Visitor for TopDownCompileControl {
1457    fn start(
1458        &mut self,
1459        comp: &mut ir::Component,
1460        _sigs: &LibrarySignatures,
1461        _comps: &[ir::Component],
1462    ) -> VisResult {
1463        let mut con = comp.control.borrow_mut();
1464        match *con {
1465            // If there's one top-level FSM at the beginning of the traversal,
1466            // the control tree is likely entirely static and has already been compiled.
1467            // In that case, just move on without wrapping that FSM in another dynamic FSM.
1468            ir::Control::FSMEnable(_) => Ok(Action::Stop),
1469            ir::Control::Empty(_) | ir::Control::Enable(_) => {
1470                if let Some(enable_info) =
1471                    extract_single_enable(&mut con, comp.name)
1472                {
1473                    self.profiling_info
1474                        .insert(ProfilingInfo::SingleEnable(enable_info));
1475                }
1476                Ok(Action::Stop)
1477            }
1478            _ => {
1479                compute_unique_ids(&mut con, 0);
1480                Ok(Action::Continue)
1481            }
1482        }
1483    }
1484
1485    fn finish_seq(
1486        &mut self,
1487        s: &mut ir::Seq,
1488        comp: &mut ir::Component,
1489        sigs: &LibrarySignatures,
1490        _comps: &[ir::Component],
1491    ) -> VisResult {
1492        // only compile using new fsm if has new_fsm attribute
1493        if !s.attributes.has(ir::BoolAttr::NewFSM) {
1494            return Ok(Action::Continue);
1495        }
1496        let mut builder = ir::Builder::new(comp, sigs);
1497        let mut sch = Schedule::from(&mut builder);
1498
1499        sch.calculate_states_seq(s, self.early_transitions)?;
1500
1501        // compile schedule and return the enable node
1502        let mut seq_enable = if self.infer_fsms {
1503            ir::Control::fsm_enable(sch.realize_fsm(self.dump_fsm))
1504        } else {
1505            let fsm_impl = self.get_representation(&sch, &s.attributes);
1506            let seq_group = sch.realize_schedule(
1507                self.dump_fsm,
1508                &mut self.profiling_info,
1509                fsm_impl,
1510            );
1511            seq_group
1512                .borrow_mut()
1513                .attributes
1514                .copy_from_set(s.get_attributes(), vec![SetAttr::Pos]);
1515            seq_group.borrow_mut().attributes = s.attributes.clone();
1516            ir::Control::enable(seq_group)
1517        };
1518
1519        // Add NODE_ID to compiled enable.
1520        let state_id = s.attributes.get(NODE_ID).unwrap();
1521        seq_enable.get_mut_attributes().insert(NODE_ID, state_id);
1522
1523        Ok(Action::change(seq_enable))
1524    }
1525
1526    fn finish_if(
1527        &mut self,
1528        i: &mut ir::If,
1529        comp: &mut ir::Component,
1530        sigs: &LibrarySignatures,
1531        _comps: &[ir::Component],
1532    ) -> VisResult {
1533        // only compile using new fsm if has new_fsm attribute
1534        if !i.attributes.has(ir::BoolAttr::NewFSM) {
1535            return Ok(Action::Continue);
1536        }
1537        let mut builder = ir::Builder::new(comp, sigs);
1538        let mut sch = Schedule::from(&mut builder);
1539
1540        sch.calculate_states_if(i, self.early_transitions)?;
1541
1542        // Compile schedule and return the group.
1543        let mut if_enable = if self.infer_fsms {
1544            ir::Control::fsm_enable(sch.realize_fsm(self.dump_fsm))
1545        } else {
1546            let fsm_impl = self.get_representation(&sch, &i.attributes);
1547            ir::Control::enable(sch.realize_schedule(
1548                self.dump_fsm,
1549                &mut self.profiling_info,
1550                fsm_impl,
1551            ))
1552        };
1553
1554        // Add NODE_ID to compiled group.
1555        let node_id = i.attributes.get(NODE_ID).unwrap();
1556        if_enable.get_mut_attributes().insert(NODE_ID, node_id);
1557
1558        Ok(Action::change(if_enable))
1559    }
1560
1561    fn finish_while(
1562        &mut self,
1563        w: &mut ir::While,
1564        comp: &mut ir::Component,
1565        sigs: &LibrarySignatures,
1566        _comps: &[ir::Component],
1567    ) -> VisResult {
1568        // only compile using new fsm if has attribute
1569        if !w.attributes.has(ir::BoolAttr::NewFSM) {
1570            return Ok(Action::Continue);
1571        }
1572        let mut builder = ir::Builder::new(comp, sigs);
1573        let mut sch = Schedule::from(&mut builder);
1574        sch.calculate_states_while(w, self.early_transitions)?;
1575
1576        // compile schedule and return the enable node
1577        let mut while_enable = if self.infer_fsms {
1578            ir::Control::fsm_enable(sch.realize_fsm(self.dump_fsm))
1579        } else {
1580            let fsm_impl = self.get_representation(&sch, &w.attributes);
1581            ir::Control::enable(sch.realize_schedule(
1582                self.dump_fsm,
1583                &mut self.profiling_info,
1584                fsm_impl,
1585            ))
1586        };
1587
1588        // Add NODE_ID to compiled enable.
1589        let node_id = w.attributes.get(NODE_ID).unwrap();
1590        while_enable.get_mut_attributes().insert(NODE_ID, node_id);
1591
1592        Ok(Action::change(while_enable))
1593    }
1594
1595    /// Compile each child in `par` block separately so each child can make
1596    /// progress independently.
1597    fn finish_par(
1598        &mut self,
1599        s: &mut ir::Par,
1600        comp: &mut ir::Component,
1601        sigs: &LibrarySignatures,
1602        _comps: &[ir::Component],
1603    ) -> VisResult {
1604        let mut builder = ir::Builder::new(comp, sigs);
1605
1606        // Compilation group
1607        let par_group = builder.add_group("par");
1608        structure!(builder;
1609            let signal_on = constant(1, 1);
1610            let signal_off = constant(0, 1);
1611        );
1612
1613        // Registers to save the done signal from each child.
1614        let mut done_regs = Vec::with_capacity(s.stmts.len());
1615
1616        // Profiling: record each par child (arm)'s group and done register
1617        let mut child_infos = Vec::with_capacity(s.stmts.len());
1618
1619        // For each child, build the enabling logic.
1620        for con in &s.stmts {
1621            // Build circuitry to enable and disable this group.
1622
1623            if self.infer_fsms {
1624                structure!(builder;
1625                    let pd = prim std_reg(1);
1626                );
1627
1628                let fsm = {
1629                    let mut sch = Schedule::from(&mut builder);
1630                    sch.calculate_states(con, self.early_transitions)?;
1631                    sch.realize_fsm(self.dump_fsm)
1632                };
1633                let fsm_start = !(guard!(pd["out"] | fsm["done"]));
1634                let fsm_done = guard!(fsm["done"]);
1635
1636                // Save the done condition of fsm in `pd`.
1637                let assigns = build_assignments!(builder;
1638                    fsm["start"] = fsm_start ? signal_on["out"];
1639                    pd["in"] = fsm_done ? signal_on["out"];
1640                    pd["write_en"] = fsm_done ? signal_on["out"];
1641                );
1642
1643                par_group.borrow_mut().assignments.extend(assigns);
1644                done_regs.push(pd)
1645            } else {
1646                let group = match con {
1647                    // Do not compile enables
1648                    ir::Control::Enable(ir::Enable { group, .. }) => {
1649                        self.profiling_info.insert(
1650                            ProfilingInfo::SingleEnable(SingleEnableInfo {
1651                                group: group.borrow().name(),
1652                                component: builder.component.name,
1653                            }),
1654                        );
1655                        Rc::clone(group)
1656                    }
1657                    // Compile complex schedule and return the group.
1658                    _ => {
1659                        let mut sch = Schedule::from(&mut builder);
1660                        sch.calculate_states(con, self.early_transitions)?;
1661                        let fsm_impl =
1662                            self.get_representation(&sch, &s.attributes);
1663                        sch.realize_schedule(
1664                            self.dump_fsm,
1665                            &mut self.profiling_info,
1666                            fsm_impl,
1667                        )
1668                    }
1669                };
1670
1671                structure!(builder;
1672                    let pd = prim std_reg(1);
1673                );
1674
1675                let group_go = !(guard!(pd["out"] | group["done"]));
1676                let group_done = guard!(group["done"]);
1677
1678                // Save the done condition of group in `pd`.
1679                let assigns = build_assignments!(builder;
1680                    group["go"] = group_go ? signal_on["out"];
1681                    pd["in"] = group_done ? signal_on["out"];
1682                    pd["write_en"] = group_done ? signal_on["out"];
1683                );
1684
1685                child_infos.push(ParChildInfo {
1686                    group: group.borrow().name(),
1687                    register: pd.borrow().name(),
1688                });
1689                par_group.borrow_mut().assignments.extend(assigns);
1690                done_regs.push(pd)
1691            };
1692        }
1693        let pos = if let Some(pos_set) =
1694            s.get_mut_attributes().get_set(calyx_frontend::SetAttr::Pos)
1695        {
1696            pos_set.iter().copied().collect()
1697        } else {
1698            Vec::new()
1699        };
1700        // Profiling: save collected information about this par
1701        self.profiling_info.insert(ProfilingInfo::Par(ParInfo {
1702            component: builder.component.name,
1703            par_group: par_group.borrow().name(),
1704            child_groups: child_infos,
1705            pos,
1706        }));
1707
1708        // Done condition for this group
1709        let done_guard = done_regs
1710            .clone()
1711            .into_iter()
1712            .map(|r| guard!(r["out"]))
1713            .fold(ir::Guard::True, ir::Guard::and);
1714
1715        // CLEANUP: Reset the registers once the group is finished.
1716        let mut cleanup = done_regs
1717            .into_iter()
1718            .flat_map(|r| {
1719                build_assignments!(builder;
1720                    r["in"] = done_guard ? signal_off["out"];
1721                    r["write_en"] = done_guard ? signal_on["out"];
1722                )
1723            })
1724            .collect::<Vec<_>>();
1725        builder
1726            .component
1727            .continuous_assignments
1728            .append(&mut cleanup);
1729
1730        // Done conditional for this group.
1731        let done = builder.build_assignment(
1732            par_group.borrow().get("done"),
1733            signal_on.borrow().get("out"),
1734            done_guard,
1735        );
1736        par_group.borrow_mut().assignments.push(done);
1737
1738        // Add NODE_ID to compiled group.
1739        let mut en = ir::Control::enable(par_group);
1740        let node_id = s.attributes.get(NODE_ID).unwrap();
1741        en.get_mut_attributes().insert(NODE_ID, node_id);
1742
1743        Ok(Action::change(en))
1744    }
1745
1746    fn finish(
1747        &mut self,
1748        comp: &mut ir::Component,
1749        sigs: &LibrarySignatures,
1750        _comps: &[ir::Component],
1751    ) -> VisResult {
1752        let control = Rc::clone(&comp.control);
1753        let attrs = comp.attributes.clone();
1754        let mut builder = ir::Builder::new(comp, sigs);
1755        let mut sch = Schedule::from(&mut builder);
1756
1757        // Add assignments for the final states
1758        sch.calculate_states(&control.borrow(), self.early_transitions)?;
1759
1760        let control_node = if self.infer_fsms {
1761            ir::Control::fsm_enable(sch.realize_fsm(self.dump_fsm))
1762        } else {
1763            let fsm_rep = self.get_representation(&sch, &attrs);
1764            ir::Control::enable(sch.realize_schedule(
1765                self.dump_fsm,
1766                &mut self.profiling_info,
1767                fsm_rep,
1768            ))
1769            // Retaining set attributes from original control node in the generated Par group
1770        };
1771
1772        Ok(Action::change(control_node))
1773    }
1774
1775    /// If requested, emit FSM json after all components are processed
1776    fn finish_context(&mut self, _ctx: &mut calyx_ir::Context) -> VisResult {
1777        if let Some(json_out_file) = &mut self.dump_fsm_json {
1778            let _ = serde_json::to_writer_pretty(
1779                json_out_file.get_write(),
1780                &self.profiling_info,
1781            );
1782        }
1783        Ok(Action::Continue)
1784    }
1785}