calyx_opt/passes/
uniquefy_enables.rs

1use std::{cmp, collections::BTreeMap, collections::BTreeSet};
2
3use crate::traversal::{
4    Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor,
5};
6use calyx_frontend::SetAttr;
7use calyx_ir::{self as ir, Nothing};
8use calyx_utils::{CalyxResult, OutputFile};
9use serde::Serialize;
10
11// Converts each dynamic and static enable to an enable of a unique group.
12// Also (1) computes path descriptors for each unique enable group and par (outputted to `path_descriptor_json` if provided); and
13// (2) statically assigns par thread ids to each unique enable group (outputted to `par_thread_json` if provided).
14// Used by the profiler.
15
16pub struct UniquefyEnables {
17    path_descriptor_json: Option<OutputFile>,
18    path_descriptor_infos: BTreeMap<String, PathDescriptorInfo>,
19    par_thread_json: Option<OutputFile>,
20    par_thread_info: BTreeMap<String, BTreeMap<String, u32>>,
21}
22
23impl Named for UniquefyEnables {
24    fn name() -> &'static str {
25        "uniquefy-enables"
26    }
27
28    fn description() -> &'static str {
29        "Make all control (dynamic and static) enables unique."
30    }
31
32    fn opts() -> Vec<crate::traversal::PassOpt> {
33        vec![
34            PassOpt::new(
35                "path-descriptor-json",
36                "Write the path descriptor of each enable and par to a JSON file",
37                ParseVal::OutStream(OutputFile::Null),
38                PassOpt::parse_outstream,
39            ),
40            PassOpt::new(
41                "par-thread-json",
42                "Write an assigned thread ID of each enable to a JSON file",
43                ParseVal::OutStream(OutputFile::Null),
44                PassOpt::parse_outstream,
45            ),
46        ]
47    }
48}
49
/// Information to serialize for locating path descriptors.
///
/// One instance is built per component in `Visitor::finish` and stored under
/// the component's name.
#[derive(Serialize)]
struct PathDescriptorInfo {
    /// enable id --> descriptor
    ///
    /// Keys are group names: the groups of (static) enables, plus the
    /// combinational groups used as `if`/`while` conditions.
    pub enables: BTreeMap<String, String>,
    /// descriptor --> position set
    /// (Ideally I'd do a position set --> descriptor mapping but
    /// a set shouldn't be a key.)
    ///
    /// Keys are the descriptors of control nodes (seq, par, if, while,
    /// repeat); each value holds the contents of that node's `@pos` set
    /// attribute, or is empty when the attribute is absent.
    pub control_pos: BTreeMap<String, BTreeSet<u32>>,
}
60
61impl ConstructVisitor for UniquefyEnables {
62    fn from(ctx: &ir::Context) -> CalyxResult<Self>
63    where
64        Self: Sized + Named,
65    {
66        let opts = Self::get_opts(ctx);
67        Ok(UniquefyEnables {
68            path_descriptor_json: opts[&"path-descriptor-json"]
69                .not_null_outstream(),
70            path_descriptor_infos: BTreeMap::new(),
71            par_thread_json: opts[&"par-thread-json"].not_null_outstream(),
72            par_thread_info: BTreeMap::new(),
73        })
74    }
75
76    fn clear_data(&mut self) {}
77}
78
79fn assign_par_threads_static(
80    control: &ir::StaticControl,
81    start_idx: u32,
82    next_idx: u32,
83    enable_to_track: &mut BTreeMap<String, u32>,
84) -> u32 {
85    match control {
86        ir::StaticControl::Repeat(ir::StaticRepeat { body, .. }) => {
87            assign_par_threads_static(
88                body,
89                start_idx,
90                next_idx,
91                enable_to_track,
92            )
93        }
94        ir::StaticControl::Enable(ir::StaticEnable { group, .. }) => {
95            let group_name = group.borrow().name().to_string();
96            enable_to_track.insert(group_name, start_idx);
97            start_idx + 1
98        }
99        ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
100            let mut idx = next_idx;
101            for stmt in stmts {
102                idx = assign_par_threads_static(
103                    stmt,
104                    idx,
105                    idx + 1,
106                    enable_to_track,
107                );
108            }
109            idx
110        }
111        ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => {
112            let mut new_next_idx = next_idx;
113            for stmt in stmts {
114                let potential_new_idx = assign_par_threads_static(
115                    stmt,
116                    start_idx,
117                    new_next_idx,
118                    enable_to_track,
119                );
120                new_next_idx = cmp::max(new_next_idx, potential_new_idx)
121            }
122            new_next_idx
123        }
124        ir::StaticControl::If(ir::StaticIf {
125            tbranch, fbranch, ..
126        }) => {
127            let false_next_idx = assign_par_threads_static(
128                tbranch,
129                start_idx,
130                next_idx,
131                enable_to_track,
132            );
133            assign_par_threads_static(
134                fbranch,
135                start_idx,
136                false_next_idx,
137                enable_to_track,
138            )
139        }
140        _ => next_idx,
141    }
142}
143
144fn assign_par_threads(
145    control: &ir::Control,
146    start_idx: u32,
147    next_idx: u32,
148    enable_to_track: &mut BTreeMap<String, u32>,
149) -> u32 {
150    match control {
151        ir::Control::Seq(ir::Seq { stmts, .. }) => {
152            let mut new_next_idx = next_idx;
153            for stmt in stmts {
154                let potential_new_idx = assign_par_threads(
155                    stmt,
156                    start_idx,
157                    new_next_idx,
158                    enable_to_track,
159                );
160                new_next_idx = cmp::max(new_next_idx, potential_new_idx)
161            }
162            new_next_idx
163        }
164        ir::Control::Enable(enable) => {
165            let group_name = enable.group.borrow().name().to_string();
166            enable_to_track.insert(group_name, start_idx);
167            start_idx + 1
168        }
169        ir::Control::Par(ir::Par { stmts, .. }) => {
170            let mut idx = next_idx;
171            for stmt in stmts {
172                idx = assign_par_threads(stmt, idx, idx + 1, enable_to_track);
173            }
174            idx
175        }
176        ir::Control::If(ir::If {
177            tbranch,
178            fbranch,
179            cond,
180            ..
181        }) => {
182            let true_next_idx = if let Some(comb_group) = cond {
183                enable_to_track
184                    .insert(comb_group.borrow().name().to_string(), start_idx);
185                start_idx + 1
186            } else {
187                start_idx
188            };
189            let false_next_idx = assign_par_threads(
190                tbranch,
191                true_next_idx,
192                next_idx,
193                enable_to_track,
194            );
195            assign_par_threads(
196                fbranch,
197                start_idx,
198                false_next_idx,
199                enable_to_track,
200            )
201        }
202        ir::Control::While(ir::While { body, cond, .. }) => {
203            let body_start_idx = if let Some(comb_group) = cond {
204                enable_to_track
205                    .insert(comb_group.borrow().name().to_string(), start_idx);
206                start_idx + 1
207            } else {
208                start_idx
209            };
210            assign_par_threads(body, body_start_idx, next_idx, enable_to_track)
211        }
212        ir::Control::Repeat(ir::Repeat { body, .. }) => {
213            assign_par_threads(body, start_idx, next_idx, enable_to_track)
214        }
215        ir::Control::Static(static_control) => assign_par_threads_static(
216            static_control,
217            start_idx,
218            next_idx,
219            enable_to_track,
220        ),
221        ir::Control::Invoke(_) => {
222            panic!("compile-invoke should be run before uniquefy-enables!")
223        }
224        _ => next_idx,
225    }
226}
227
228fn compute_path_descriptors_static(
229    control: &ir::StaticControl,
230    current_id: String,
231    path_descriptor_info: &mut PathDescriptorInfo,
232    parent_is_component: bool,
233) {
234    match control {
235        ir::StaticControl::Repeat(ir::StaticRepeat {
236            body,
237            attributes,
238            ..
239        }) => {
240            let repeat_id = format!("{current_id}-");
241            let body_id = format!("{repeat_id}b");
242            compute_path_descriptors_static(
243                body,
244                body_id,
245                path_descriptor_info,
246                false,
247            );
248            let new_pos_set = retrieve_pos_set(attributes);
249            path_descriptor_info
250                .control_pos
251                .insert(repeat_id, new_pos_set);
252        }
253        ir::StaticControl::Enable(ir::StaticEnable { group, .. }) => {
254            let group_id = if parent_is_component {
255                // edge case: the entire control is just one static enable
256                format!("{current_id}0")
257            } else {
258                current_id
259            };
260            let group_name = group.borrow().name();
261            path_descriptor_info
262                .enables
263                .insert(group_name.to_string(), group_id);
264        }
265        ir::StaticControl::Par(ir::StaticPar {
266            stmts, attributes, ..
267        }) => {
268            let par_id = format!("{current_id}-");
269            for (acc, stmt) in stmts.iter().enumerate() {
270                let stmt_id = format!("{par_id}{acc}");
271                compute_path_descriptors_static(
272                    stmt,
273                    stmt_id,
274                    path_descriptor_info,
275                    false,
276                );
277            }
278            let new_pos_set: BTreeSet<u32> = retrieve_pos_set(attributes);
279            path_descriptor_info.control_pos.insert(par_id, new_pos_set);
280        }
281        ir::StaticControl::Seq(ir::StaticSeq {
282            stmts, attributes, ..
283        }) => {
284            let seq_id = format!("{current_id}-");
285            for (acc, stmt) in stmts.iter().enumerate() {
286                let stmt_id = format!("{seq_id}{acc}");
287                compute_path_descriptors_static(
288                    stmt,
289                    stmt_id,
290                    path_descriptor_info,
291                    false,
292                );
293            }
294            let new_pos_set: BTreeSet<u32> = retrieve_pos_set(attributes);
295            path_descriptor_info.control_pos.insert(seq_id, new_pos_set);
296        }
297        ir::StaticControl::If(ir::StaticIf {
298            tbranch,
299            fbranch,
300            attributes,
301            ..
302        }) => {
303            let if_id = format!("{current_id}-");
304            // process true branch
305            let true_id = format!("{if_id}t");
306            compute_path_descriptors_static(
307                tbranch,
308                true_id,
309                path_descriptor_info,
310                false,
311            );
312            // process false branch
313            let false_id = format!("{if_id}f");
314            compute_path_descriptors_static(
315                fbranch,
316                false_id,
317                path_descriptor_info,
318                false,
319            );
320            path_descriptor_info
321                .control_pos
322                .insert(if_id, retrieve_pos_set(attributes));
323        }
324        ir::StaticControl::Empty(_empty) => (),
325        ir::StaticControl::Invoke(_static_invoke) => {
326            panic!("compile-invoke should be run before unique-control!")
327        }
328    }
329}
330
331fn compute_path_descriptors(
332    control: &ir::Control,
333    current_id: String,
334    path_descriptor_info: &mut PathDescriptorInfo,
335    parent_is_component: bool,
336) {
337    match control {
338        ir::Control::Seq(ir::Seq {
339            stmts, attributes, ..
340        }) => {
341            let seq_id = format!("{current_id}-");
342            for (acc, stmt) in stmts.iter().enumerate() {
343                let stmt_id = format!("{current_id}-{acc}");
344                compute_path_descriptors(
345                    stmt,
346                    stmt_id,
347                    path_descriptor_info,
348                    false,
349                );
350            }
351            let new_pos_set = retrieve_pos_set(attributes);
352            path_descriptor_info.control_pos.insert(seq_id, new_pos_set);
353        }
354        ir::Control::Par(ir::Par {
355            stmts, attributes, ..
356        }) => {
357            let par_id = format!("{current_id}-");
358            for (acc, stmt) in stmts.iter().enumerate() {
359                let stmt_id = format!("{par_id}{acc}");
360                compute_path_descriptors(
361                    stmt,
362                    stmt_id,
363                    path_descriptor_info,
364                    false,
365                );
366            }
367            // add this node to path_descriptor_info
368            let new_pos_set = retrieve_pos_set(attributes);
369            path_descriptor_info.control_pos.insert(par_id, new_pos_set);
370        }
371        ir::Control::If(ir::If {
372            tbranch,
373            fbranch,
374            attributes,
375            cond,
376            ..
377        }) => {
378            let if_id = format!("{current_id}-");
379            // process condition if it exists
380            if let Some(comb_group) = cond {
381                let comb_id = format!("{if_id}c");
382                path_descriptor_info
383                    .enables
384                    .insert(comb_group.borrow().name().to_string(), comb_id);
385            }
386
387            // process true branch
388            let true_id = format!("{if_id}t");
389            compute_path_descriptors(
390                tbranch,
391                true_id,
392                path_descriptor_info,
393                false,
394            );
395            // process false branch
396            let false_id = format!("{if_id}f");
397            compute_path_descriptors(
398                fbranch,
399                false_id,
400                path_descriptor_info,
401                false,
402            );
403            // add this node to path_descriptor_info
404            let new_pos_set = retrieve_pos_set(attributes);
405            path_descriptor_info.control_pos.insert(if_id, new_pos_set);
406        }
407        ir::Control::While(ir::While {
408            body,
409            attributes,
410            cond,
411            ..
412        }) => {
413            let while_id = format!("{current_id}-");
414            let body_id = format!("{while_id}b");
415            // FIXME: we need to create unique enables for comb groups associated with `while`s and `if`s`
416
417            // add path descriptor for comb group associated with while if exists
418            if let Some(comb_group) = cond {
419                let comb_id = format!("{while_id}c");
420                path_descriptor_info
421                    .enables
422                    .insert(comb_group.borrow().name().to_string(), comb_id);
423            }
424            compute_path_descriptors(
425                body,
426                body_id,
427                path_descriptor_info,
428                false,
429            );
430            // add this node to path_descriptor_info
431            let new_pos_set = retrieve_pos_set(attributes);
432            path_descriptor_info
433                .control_pos
434                .insert(while_id, new_pos_set);
435        }
436        ir::Control::Enable(ir::Enable { group, .. }) => {
437            let group_id = if parent_is_component {
438                // edge case: the entire control is just one enable
439                format!("{current_id}0")
440            } else {
441                current_id
442            };
443            let group_name = group.borrow().name();
444            path_descriptor_info
445                .enables
446                .insert(group_name.to_string(), group_id);
447        }
448        ir::Control::Repeat(ir::Repeat {
449            body, attributes, ..
450        }) => {
451            let repeat_id = format!("{current_id}-");
452            let body_id = format!("{repeat_id}b");
453            compute_path_descriptors(
454                body,
455                body_id,
456                path_descriptor_info,
457                false,
458            );
459            // add this node to path_descriptor_info
460            let new_pos_set = retrieve_pos_set(attributes);
461            path_descriptor_info
462                .control_pos
463                .insert(repeat_id, new_pos_set);
464        }
465        ir::Control::Static(static_control) => {
466            compute_path_descriptors_static(
467                static_control,
468                current_id,
469                path_descriptor_info,
470                parent_is_component,
471            );
472        }
473        ir::Control::Empty(_) => (),
474        ir::Control::FSMEnable(_) => todo!(),
475        ir::Control::Invoke(_) => {
476            panic!("compile-invoke should be run before unique-control!")
477        }
478    }
479}
480
481/// Returns a BTreeSet with the elements contained in the @pos set attribute.
482fn retrieve_pos_set(attributes: &calyx_ir::Attributes) -> BTreeSet<u32> {
483    let new_pos_set: BTreeSet<u32> =
484        if let Some(pos_set) = attributes.get_set(SetAttr::Pos) {
485            pos_set.iter().copied().collect()
486        } else {
487            BTreeSet::new()
488        };
489    new_pos_set
490}
491
492/// Helper function to construct a unique version of a combinational group used as the
493/// condition in an if or a while, if one exists. Otherwise returns None.
494fn create_unique_comb_group(
495    cond: &Option<std::rc::Rc<std::cell::RefCell<calyx_ir::CombGroup>>>,
496    comp: &mut calyx_ir::Component,
497    sigs: &calyx_ir::LibrarySignatures,
498) -> Option<std::rc::Rc<std::cell::RefCell<calyx_ir::CombGroup>>> {
499    if let Some(comb_group) = cond {
500        // UG stands for "unique group". This is to separate these names from the original group names
501        let unique_comb_group_name: String =
502            format!("{}UG", comb_group.borrow().name());
503        let mut builder = ir::Builder::new(comp, sigs);
504        let unique_comb_group = builder.add_comb_group(unique_comb_group_name);
505        unique_comb_group.borrow_mut().assignments =
506            comb_group.borrow().assignments.clone();
507        unique_comb_group.borrow_mut().attributes =
508            comb_group.borrow().attributes.clone();
509        Some(unique_comb_group)
510    } else {
511        None
512    }
513}
514
515impl Visitor for UniquefyEnables {
516    fn finish_while(
517        &mut self,
518        s: &mut calyx_ir::While,
519        comp: &mut calyx_ir::Component,
520        sigs: &calyx_ir::LibrarySignatures,
521        _comps: &[calyx_ir::Component],
522    ) -> VisResult {
523        // create a freshly named version of the condition comb group if one exists.
524        s.cond = create_unique_comb_group(&s.cond, comp, sigs);
525        Ok(Action::Continue)
526    }
527
528    fn finish_if(
529        &mut self,
530        s: &mut calyx_ir::If,
531        comp: &mut calyx_ir::Component,
532        sigs: &calyx_ir::LibrarySignatures,
533        _comps: &[calyx_ir::Component],
534    ) -> VisResult {
535        // create a freshly named version of the condition comb group if one exists.
536        s.cond = create_unique_comb_group(&s.cond, comp, sigs);
537        Ok(Action::Continue)
538    }
539
540    fn enable(
541        &mut self,
542        s: &mut calyx_ir::Enable,
543        comp: &mut calyx_ir::Component,
544        sigs: &calyx_ir::LibrarySignatures,
545        _comps: &[calyx_ir::Component],
546    ) -> VisResult {
547        // create a unique group for this particular enable.
548        let group_name = s.group.borrow().name();
549        // UG stands for "unique group". This is to separate these names from the original group names
550        let unique_group_name: String = format!("{group_name}UG");
551        // create an unique-ified version of the group
552        let mut builder = ir::Builder::new(comp, sigs);
553        let unique_group = builder.add_group(unique_group_name);
554        let mut unique_group_assignments: Vec<calyx_ir::Assignment<Nothing>> =
555            Vec::new();
556        for asgn in s.group.borrow().assignments.iter() {
557            if asgn.dst.borrow().get_parent_name() == group_name
558                && asgn.dst.borrow().name == "done"
559            {
560                // done needs to be reassigned
561                let new_done_asgn = builder.build_assignment(
562                    unique_group.borrow().get("done"),
563                    asgn.src.clone(),
564                    *asgn.guard.clone(),
565                );
566                unique_group_assignments.push(new_done_asgn);
567            } else {
568                unique_group_assignments.push(asgn.clone());
569            }
570        }
571        unique_group
572            .borrow_mut()
573            .assignments
574            .append(&mut unique_group_assignments);
575        // copy over all attributes that were in the original group.
576        unique_group.borrow_mut().attributes =
577            s.group.borrow().attributes.clone();
578        Ok(Action::Change(Box::new(ir::Control::enable(unique_group))))
579    }
580
581    fn static_enable(
582        &mut self,
583        s: &mut calyx_ir::StaticEnable,
584        comp: &mut calyx_ir::Component,
585        sigs: &calyx_ir::LibrarySignatures,
586        _comps: &[calyx_ir::Component],
587    ) -> VisResult {
588        // create a unique group for this particular static enable.
589        let group_name = s.group.borrow().name();
590        // UG stands for "unique group". This is to separate these names from the original group names
591        let unique_group_name = format!("{group_name}UG");
592        // create an unique-ified version of the group
593        let mut builder = ir::Builder::new(comp, sigs);
594        let unique_group = builder.add_static_group(
595            unique_group_name,
596            s.group.borrow().get_latency(),
597        );
598        // Since we don't need to worry about setting the `done` signal, the assignments of unique_group are
599        // a straight copy of the original group's assignments
600        unique_group.borrow_mut().assignments =
601            s.group.borrow().assignments.clone();
602        // copy over all attributes that were in the original group.
603        unique_group.borrow_mut().attributes =
604            s.group.borrow().attributes.clone();
605        Ok(Action::Change(Box::new(ir::Control::static_enable(
606            unique_group,
607        ))))
608    }
609
610    fn finish(
611        &mut self,
612        comp: &mut calyx_ir::Component,
613        _sigs: &calyx_ir::LibrarySignatures,
614        _comps: &[calyx_ir::Component],
615    ) -> VisResult {
616        // Compute path descriptors for each enable and par block in the component.
617        let control = comp.control.borrow();
618        let mut path_descriptor_info = PathDescriptorInfo {
619            enables: BTreeMap::new(),
620            control_pos: BTreeMap::new(),
621        };
622        compute_path_descriptors(
623            &control,
624            format!("{}.", comp.name),
625            &mut path_descriptor_info,
626            true,
627        );
628        self.path_descriptor_infos
629            .insert(comp.name.to_string(), path_descriptor_info);
630        // Compute par thread ids for each enable in the component.
631        let mut enable_to_track: BTreeMap<String, u32> = BTreeMap::new();
632        assign_par_threads(&control, 0, 1, &mut enable_to_track);
633        self.par_thread_info
634            .insert(comp.name.to_string(), enable_to_track);
635        Ok(Action::Continue)
636    }
637
638    fn finish_context(&mut self, _ctx: &mut calyx_ir::Context) -> VisResult {
639        // Write path descriptors to file if prompted.
640        if let Some(json_out_file) = &mut self.path_descriptor_json {
641            let _ = serde_json::to_writer_pretty(
642                json_out_file.get_write(),
643                &self.path_descriptor_infos,
644            );
645        }
646        // Write par thread assignments to file if prompted.
647        if let Some(json_out_file) = &mut self.par_thread_json {
648            let _ = serde_json::to_writer_pretty(
649                json_out_file.get_write(),
650                &self.par_thread_info,
651            );
652        }
653        Ok(Action::Continue)
654    }
655}