calyx_opt/analysis/
reaching_defns.rs

1//! Calculate the reaching definitions in a control program.
2use calyx_ir as ir;
3use itertools::Itertools;
4use std::cmp::Ordering;
5use std::cmp::{Ord, PartialOrd};
6use std::{
7    collections::{BTreeMap, BTreeSet, HashMap},
8    ops::BitOr,
9};
10
11use super::read_write_set::AssignmentAnalysis;
12
/// Prefix for the labels fabricated for invoke statements (which have no group
/// name of their own); a monotonically increasing counter is appended to it.
const INVOKE_PREFIX: &str = "__invoke_";

// Aliases clarifying which kind of identifier a signature expects.
type GroupName = ir::Id;
type InvokeName = ir::Id;
17
/// A wrapper enum to distinguish between Ids that come from groups and ids that
/// were fabricated during the analysis for individual invoke statements. This
/// prevents attempting to look up the ids used for the invoke statements as
/// there will be no corresponding group.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum GroupOrInvoke {
    /// The name of a real group in the component.
    Group(GroupName),
    /// A fabricated label (beginning with [INVOKE_PREFIX]) for an invoke node.
    Invoke(InvokeName),
}
27
28impl PartialOrd for GroupOrInvoke {
29    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
30        Some(self.cmp(other))
31    }
32}
33
34impl Ord for GroupOrInvoke {
35    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
36        match (self, other) {
37            (GroupOrInvoke::Group(a), GroupOrInvoke::Group(b))
38            | (GroupOrInvoke::Invoke(a), GroupOrInvoke::Invoke(b)) => {
39                ir::Id::cmp(a, b)
40            }
41            (GroupOrInvoke::Group(_), GroupOrInvoke::Invoke(_)) => {
42                Ordering::Greater
43            }
44            (GroupOrInvoke::Invoke(_), GroupOrInvoke::Group(_)) => {
45                Ordering::Less
46            }
47        }
48    }
49}
50
51#[allow(clippy::from_over_into)]
52impl Into<ir::Id> for GroupOrInvoke {
53    fn into(self) -> ir::Id {
54        match self {
55            GroupOrInvoke::Group(id) | GroupOrInvoke::Invoke(id) => id,
56        }
57    }
58}
59
/// Maps invoke nodes (identified by their address) to the unique label
/// fabricated for them during the analysis. The raw pointers are used purely
/// as map keys; they are never dereferenced.
#[derive(Debug, Default)]
pub struct MetadataMap {
    // Labels for dynamic `invoke` statements.
    map: HashMap<*const ir::Invoke, ir::Id>,
    // Labels for `static invoke` statements.
    static_map: HashMap<*const ir::StaticInvoke, ir::Id>,
}
65
66impl MetadataMap {
67    fn attach_label(&mut self, invoke: &ir::Invoke, label: ir::Id) {
68        self.map.insert(invoke as *const ir::Invoke, label);
69    }
70
71    fn attach_label_static(
72        &mut self,
73        invoke: &ir::StaticInvoke,
74        label: ir::Id,
75    ) {
76        self.static_map
77            .insert(invoke as *const ir::StaticInvoke, label);
78    }
79
80    pub fn fetch_label(&self, invoke: &ir::Invoke) -> Option<&ir::Id> {
81        self.map.get(&(invoke as *const ir::Invoke))
82    }
83
84    pub fn fetch_label_static(
85        &self,
86        invoke: &ir::StaticInvoke,
87    ) -> Option<&ir::Id> {
88        self.static_map.get(&(invoke as *const ir::StaticInvoke))
89    }
90}
/// A datastructure used to represent a set of definitions/uses. These are
/// represented as pairs of (Id, GroupOrInvoke) where the Id is the identifier
/// being defined, and the second term represents the defining location (or use
/// location). In the case of a group, this location is just the group Id. In
/// the case of an invoke the "location" is a unique label assigned to each
/// invoke statement that begins with the INVOKE_PREFIX.
///
/// Defsets are constructed based on the assignments in a group and the ports in
/// an invoke. If a group writes to a register then it corresponds to a
/// definition (REGID, GROUPNAME). Similarly, this can be used to represent a
/// use of the register REGID in the very same group.
///
/// These structs are used both to determine what definitions reach a given
/// location and are also used to ensure that uses of a given definition (or
/// group of definitions) are appropriately tied to any renaming that the
/// particular definition undergoes.
#[derive(Clone, Debug, Default)]
pub struct DefSet {
    // Ordered so iteration (and thus the analysis output) is deterministic.
    set: BTreeSet<(ir::Id, GroupOrInvoke)>,
}
111
112impl DefSet {
113    fn extend(&mut self, writes: BTreeSet<ir::Id>, grp: GroupName) {
114        for var in writes {
115            self.set.insert((var, GroupOrInvoke::Group(grp)));
116        }
117    }
118
119    fn kill_from_writeread(
120        &self,
121        writes: &BTreeSet<ir::Id>,
122        reads: &BTreeSet<ir::Id>,
123    ) -> (Self, KilledSet) {
124        let mut killed = KilledSet::new();
125        let def = DefSet {
126            set: self
127                .set
128                .iter()
129                .cloned()
130                .filter_map(|(name, grp)| {
131                    if !writes.contains(&name) || reads.contains(&name) {
132                        Some((name, grp))
133                    } else {
134                        killed.insert(name);
135                        None
136                    }
137                })
138                .collect(),
139        };
140        (def, killed)
141    }
142
143    fn kill_from_hashset(&self, killset: &BTreeSet<ir::Id>) -> Self {
144        DefSet {
145            set: self
146                .set
147                .iter()
148                .filter(|&(name, _)| !killset.contains(name))
149                .cloned()
150                .collect(),
151        }
152    }
153}
154
155impl BitOr<&DefSet> for &DefSet {
156    type Output = DefSet;
157
158    fn bitor(self, rhs: &DefSet) -> Self::Output {
159        DefSet {
160            set: &self.set | &rhs.set,
161        }
162    }
163}
164
/// For each register, a list of disjoint "overlap classes": sets of
/// (register, location) pairs whose definitions/uses overlap and must
/// therefore be renamed together.
type OverlapMap = BTreeMap<ir::Id, Vec<BTreeSet<(ir::Id, GroupOrInvoke)>>>;
166
/// A struct used to compute a reaching definition analysis. The only field is a
/// map between [GroupOrInvoke] labels and the definitions that exit the given
/// group or the given Invoke node. This analysis is conservative and will only
/// kill a definition if the group MUST write the given register and does not
/// read it. If this is not the case old definitions will remain in the reaching
/// set as we cannot be certain that they have been killed.
///
/// Note that this analysis assumes that groups do not appear more than once
/// within the control structure and will provide inaccurate results if this
/// expectation is violated.
///
/// Like [super::LiveRangeAnalysis] par blocks are treated via a parallel CFG approach.
/// Concretely this means that after a par block executes any id that is killed
/// by one arm is killed and all defs introduced (but not killed) by any arm are
/// defined. Note that this assumes separate arms are not writing the same
/// register or reading a register that is written by another arm.
#[derive(Debug, Default)]
pub struct ReachingDefinitionAnalysis {
    /// Definitions flowing out of each group/invoke label.
    pub reach: BTreeMap<GroupOrInvoke, DefSet>,
    /// Labels fabricated for invoke statements during the analysis.
    pub meta: MetadataMap,
}
188
impl ReachingDefinitionAnalysis {
    /// Constructs a reaching definition analysis for registers over the given
    /// control structure. Will include dummy "definitions" for invoke statements
    /// which can be ignored if one is not rewriting values.
    /// **NOTE**: Assumes that each group appears at only one place in the control
    /// structure.
    pub fn new(control: &ir::Control) -> Self {
        let initial_set = DefSet::default();
        let mut analysis = ReachingDefinitionAnalysis::default();
        // Monotonic counter used to fabricate a unique label per invoke.
        let mut counter: u64 = 0;

        build_reaching_def(
            control,
            initial_set,
            KilledSet::new(),
            &mut analysis,
            &mut counter,
        );
        analysis
    }

    /// Provides a map containing a vector of sets for each register. The sets
    /// within contain separate groupings of definitions for the given register.
    /// If the vector contains one set, then all the definitions for the given
    /// register name must have the same name.
    /// **NOTE:** Includes dummy "definitions" for continuous assignments and
    /// uses within groups and invoke statements. This is to ensure that all
    /// uses of a given register are rewritten with the appropriate name.
    pub fn calculate_overlap<'a, I, T: 'a>(
        &'a self,
        continuous_assignments: I,
    ) -> OverlapMap
    where
        I: Iterator<Item = &'a ir::Assignment<T>> + Clone + 'a,
    {
        // Registers mentioned by continuous assignments: these are in use at
        // every point of the program, so a dummy use is attached to every
        // group's overlap sets below.
        let continuous_regs: Vec<ir::Id> = continuous_assignments
            .analysis()
            .cell_uses()
            .filter_map(|cell| {
                let cell_ref = cell.borrow();
                if let Some(name) = cell_ref.type_name() {
                    if name == "std_reg" {
                        return Some(cell_ref.name());
                    }
                }
                None
            })
            .collect();

        let mut overlap_map: BTreeMap<
            ir::Id,
            Vec<BTreeSet<(ir::Id, GroupOrInvoke)>>,
        > = BTreeMap::new();
        for (grp, defset) in &self.reach {
            // For each register: the definitions of it that reach `grp`,
            // plus a dummy use of it located in `grp` itself.
            let mut group_overlaps: BTreeMap<
                &ir::Id,
                BTreeSet<(ir::Id, GroupOrInvoke)>,
            > = BTreeMap::new();

            for (defname, group_name) in &defset.set {
                let set = group_overlaps.entry(defname).or_default();
                set.insert((*defname, group_name.clone()));
                // Dummy entry tying the reaching definition to this location.
                set.insert((*defname, grp.clone()));
            }

            // Dummy definitions for registers used by continuous assignments.
            for name in &continuous_regs {
                let set = group_overlaps.entry(name).or_default();
                set.insert((
                    *name,
                    GroupOrInvoke::Group("__continuous".into()),
                ));
            }

            // Merge this group's sets into the per-register overlap classes:
            // any existing class that intersects the new set is unioned with
            // it; classes disjoint from it are kept separate.
            for (defname, set) in group_overlaps {
                let overlap_vec = overlap_map.entry(*defname).or_default();

                if overlap_vec.is_empty() {
                    overlap_vec.push(set)
                } else {
                    let mut no_overlap = vec![];
                    let mut overlap = vec![];

                    for entry in overlap_vec.drain(..) {
                        if set.is_disjoint(&entry) {
                            no_overlap.push(entry)
                        } else {
                            overlap.push(entry)
                        }
                    }

                    *overlap_vec = no_overlap;

                    if overlap.is_empty() {
                        overlap_vec.push(set);
                    } else {
                        // Every intersecting class collapses into one.
                        overlap_vec.push(
                            overlap
                                .into_iter()
                                .fold(set, |acc, entry| &acc | &entry),
                        )
                    }
                }
            }
        }
        overlap_map
    }
}
296
297type KilledSet = BTreeSet<ir::Id>;
298
299fn remove_entries_defined_by(set: &mut KilledSet, defs: &DefSet) {
300    let tmp_set: BTreeSet<_> = defs.set.iter().map(|(id, _)| id).collect();
301    *set = std::mem::take(set)
302        .into_iter()
303        .filter(|x| !tmp_set.contains(x))
304        .collect();
305}
306
307/// Returns the register cells whose out port is read anywhere in the given
308/// assignments
309fn register_reads<T>(assigns: &[ir::Assignment<T>]) -> BTreeSet<ir::Id> {
310    assigns
311        .iter()
312        .analysis()
313        .reads()
314        .filter_map(|p| {
315            let port = p.borrow();
316            let ir::PortParent::Cell(cell_wref) = &port.parent else {
317                unreachable!("Port not part of a cell");
318            };
319            // Skip this if the port is not an output
320            if &port.name != "out" {
321                return None;
322            };
323            let cr = cell_wref.upgrade();
324            let cell = cr.borrow();
325            if cell.is_primitive(Some("std_reg")) {
326                Some(cr.borrow().name())
327            } else {
328                None
329            }
330        })
331        .unique()
332        .collect()
333}
334
335// handles `build_reaching_defns` for the enable/static_enables case.
336// asgns are the assignments in the group (either static or dynamic)
337fn handle_reaching_def_enables<T>(
338    asgns: &[ir::Assignment<T>],
339    reach: DefSet,
340    rd: &mut ReachingDefinitionAnalysis,
341    group_name: ir::Id,
342) -> (DefSet, KilledSet) {
343    let writes = asgns.iter().analysis().must_writes().cells();
344    // for each write:
345    // Killing all other reaching defns for that var
346    // generating a new defn (Id, GROUP)
347    let write_set = writes
348        .filter(|x| match &x.borrow().prototype {
349            ir::CellType::Primitive { name, .. } => name == "std_reg",
350            _ => false,
351        })
352        .map(|x| x.borrow().name())
353        .collect::<BTreeSet<_>>();
354
355    let read_set = register_reads(asgns);
356
357    // only kill a def if the value is not read.
358    let (mut cur_reach, killed) =
359        reach.kill_from_writeread(&write_set, &read_set);
360    cur_reach.extend(write_set, group_name);
361
362    rd.reach
363        .insert(GroupOrInvoke::Group(group_name), cur_reach.clone());
364
365    (cur_reach, killed)
366}
367
/// Static-control counterpart of [build_reaching_def]: threads the incoming
/// reaching definitions (`reach`) and killed set through the static control
/// node `sc`, recording per-group/per-invoke def-sets in `rd` along the way,
/// and returns the definitions and kills that flow out of the node.
fn build_reaching_def_static(
    sc: &ir::StaticControl,
    reach: DefSet,
    killed: KilledSet,
    rd: &mut ReachingDefinitionAnalysis,
    counter: &mut u64,
) -> (DefSet, KilledSet) {
    match sc {
        // Nothing executes: defs/kills pass through unchanged.
        ir::StaticControl::Empty(_) => (reach, killed),
        ir::StaticControl::Enable(sen) => handle_reaching_def_enables(
            &sen.group.borrow().assignments,
            reach,
            rd,
            sen.group.borrow().name(),
        ),
        // Mirrors [handle_repeat_while_body]: run the body twice so that
        // definitions flowing around the back edge reach a fixed point.
        ir::StaticControl::Repeat(ir::StaticRepeat { body, .. }) => {
            let (post_cond_def, post_cond_killed) = build_reaching_def_static(
                &ir::StaticControl::empty(),
                reach.clone(),
                killed,
                rd,
                counter,
            );

            let (round_1_def, mut round_1_killed) = build_reaching_def_static(
                body,
                post_cond_def,
                post_cond_killed,
                rd,
                counter,
            );

            // Defs live on loop entry also reach the head via the back edge,
            // so they may not be treated as killed.
            remove_entries_defined_by(&mut round_1_killed, &reach);

            // NOTE(review): this arm calls the *dynamic* `build_reaching_def`
            // on an empty node here, while the first post-cond above used the
            // static variant; both are no-ops on Empty, so the asymmetry
            // appears harmless — confirm intent.
            let (post_cond2_def, post_cond2_killed) = build_reaching_def(
                &ir::Control::empty(),
                &round_1_def | &reach,
                round_1_killed,
                rd,
                counter,
            );
            // Run the analysis a second time to get the fixed point of the
            // while loop using the defsets calculated during the first iteration
            let (final_def, mut final_kill) = build_reaching_def_static(
                body,
                post_cond2_def.clone(),
                post_cond2_killed,
                rd,
                counter,
            );

            remove_entries_defined_by(&mut final_kill, &post_cond2_def);

            (&final_def | &post_cond2_def, final_kill)
        }

        // Sequence: fold the defs/kills through the statements in order.
        ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => stmts
            .iter()
            .fold((reach, killed), |(acc, killed), inner_c| {
                build_reaching_def_static(inner_c, acc, killed, rd, counter)
            }),
        // Parallel-CFG treatment: every arm starts from the same incoming
        // defs with an empty kill set; afterwards anything killed by any arm
        // is killed, and all defs surviving any arm are defined.
        ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
            let (defs, par_killed): (Vec<DefSet>, Vec<KilledSet>) = stmts
                .iter()
                .map(|ctrl| {
                    build_reaching_def_static(
                        ctrl,
                        reach.clone(),
                        KilledSet::new(),
                        rd,
                        counter,
                    )
                })
                .unzip();

            let global_killed = par_killed
                .iter()
                .fold(KilledSet::new(), |acc, set| &acc | set);

            // For each arm, drop the defs killed by the *other* arms
            // (global_killed minus the arm's own kills), then union them all.
            let par_exit_defs = defs
                .iter()
                .zip(par_killed.iter())
                .map(|(defs, kills)| {
                    defs.kill_from_hashset(&(&global_killed - kills))
                })
                .fold(DefSet::default(), |acc, element| &acc | &element);
            (par_exit_defs, &global_killed | &killed)
        }
        // Branches are analyzed independently from the same entry state and
        // their exits unioned (conservative: either branch may run).
        ir::StaticControl::If(ir::StaticIf {
            tbranch, fbranch, ..
        }) => {
            let (post_cond_def, post_cond_killed) = build_reaching_def_static(
                &ir::StaticControl::empty(),
                reach,
                killed,
                rd,
                counter,
            );
            let (t_case_def, t_case_killed) = build_reaching_def_static(
                tbranch,
                post_cond_def.clone(),
                post_cond_killed.clone(),
                rd,
                counter,
            );
            let (f_case_def, f_case_killed) = build_reaching_def_static(
                fbranch,
                post_cond_def,
                post_cond_killed,
                rd,
                counter,
            );
            (&t_case_def | &f_case_def, &t_case_killed | &f_case_killed)
        }
        // Invoke: fabricate a unique label and add a dummy definition for
        // every register passed through an input or output port.
        ir::StaticControl::Invoke(invoke) => {
            *counter += 1;

            let iterator = invoke
                .inputs
                .iter()
                .chain(invoke.outputs.iter())
                .filter_map(|(_, port)| {
                    if let ir::PortParent::Cell(wc) = &port.borrow().parent {
                        let rc = wc.upgrade();
                        let parent = rc.borrow();
                        if parent
                            .type_name()
                            .unwrap_or_else(|| ir::Id::from(""))
                            == "std_reg"
                        {
                            let name = format!("{INVOKE_PREFIX}{counter}");
                            // (Re-)attached once per matching port; the label
                            // is identical each time, so this is harmless.
                            rd.meta.attach_label_static(
                                invoke,
                                ir::Id::from(name.clone()),
                            );
                            return Some((
                                parent.name(),
                                GroupOrInvoke::Invoke(ir::Id::from(name)),
                            ));
                        }
                    }
                    None
                });

            let mut new_reach = reach;
            new_reach.set.extend(iterator);

            (new_reach, killed)
        }
    }
}
519
// Handles both `repeat` and `while` bodies when building reaching defs.
// Runs the body twice: the first pass discovers the definitions that can flow
// around the back edge; the second pass, seeded with those defs, yields the
// fixed point.
fn handle_repeat_while_body(
    body: &ir::Control,
    reach: DefSet,
    killed: KilledSet,
    rd: &mut ReachingDefinitionAnalysis,
    counter: &mut u64,
) -> (DefSet, KilledSet) {
    // The empty node stands in for the loop-condition evaluation point.
    let (post_cond_def, post_cond_killed) = build_reaching_def(
        &ir::Control::empty(),
        reach.clone(),
        killed,
        rd,
        counter,
    );

    // First pass over the loop body.
    let (round_1_def, mut round_1_killed) =
        build_reaching_def(body, post_cond_def, post_cond_killed, rd, counter);

    // Defs live on loop entry also reach the head via the back edge, so they
    // may not be treated as killed.
    remove_entries_defined_by(&mut round_1_killed, &reach);

    let (post_cond2_def, post_cond2_killed) = build_reaching_def(
        &ir::Control::empty(),
        &round_1_def | &reach,
        round_1_killed,
        rd,
        counter,
    );
    // Run the analysis a second time to get the fixed point of the
    // while loop using the defsets calculated during the first iteration
    let (final_def, mut final_kill) = build_reaching_def(
        body,
        post_cond2_def.clone(),
        post_cond2_killed,
        rd,
        counter,
    );

    remove_entries_defined_by(&mut final_kill, &post_cond2_def);

    (&final_def | &post_cond2_def, final_kill)
}
562
/// Threads the incoming reaching definitions (`reach`) and killed set through
/// the control node `c`, recording per-group/per-invoke def-sets in `rd`, and
/// returns the definitions and kills that flow out of the node. `counter` is
/// bumped once per invoke to fabricate its unique label.
fn build_reaching_def(
    c: &ir::Control,
    reach: DefSet,
    killed: KilledSet,
    rd: &mut ReachingDefinitionAnalysis,
    counter: &mut u64,
) -> (DefSet, KilledSet) {
    match c {
        // Sequence: fold defs/kills through the statements in order.
        ir::Control::Seq(ir::Seq { stmts, .. }) => {
            stmts
                .iter()
                .fold((reach, killed), |(acc, killed), inner_c| {
                    build_reaching_def(inner_c, acc, killed, rd, counter)
                })
        }
        // Parallel-CFG treatment: every arm starts from the same incoming
        // defs with an empty kill set; afterwards anything killed by any arm
        // is killed, and all defs surviving any arm are defined.
        ir::Control::Par(ir::Par { stmts, .. }) => {
            let (defs, par_killed): (Vec<DefSet>, Vec<KilledSet>) = stmts
                .iter()
                .map(|ctrl| {
                    build_reaching_def(
                        ctrl,
                        reach.clone(),
                        KilledSet::new(),
                        rd,
                        counter,
                    )
                })
                .unzip();

            let global_killed = par_killed
                .iter()
                .fold(KilledSet::new(), |acc, set| &acc | set);

            // For each arm, drop the defs killed by the *other* arms
            // (global_killed minus the arm's own kills), then union them all.
            let par_exit_defs = defs
                .iter()
                .zip(par_killed.iter())
                .map(|(defs, kills)| {
                    defs.kill_from_hashset(&(&global_killed - kills))
                })
                .fold(DefSet::default(), |acc, element| &acc | &element);
            (par_exit_defs, &global_killed | &killed)
        }
        // Branches are analyzed independently from the same entry state and
        // their exits unioned (conservative: either branch may run).
        ir::Control::If(ir::If {
            tbranch, fbranch, ..
        }) => {
            // The empty node stands in for the condition evaluation point.
            let (post_cond_def, post_cond_killed) = build_reaching_def(
                &ir::Control::empty(),
                reach,
                killed,
                rd,
                counter,
            );
            let (t_case_def, t_case_killed) = build_reaching_def(
                tbranch,
                post_cond_def.clone(),
                post_cond_killed.clone(),
                rd,
                counter,
            );
            let (f_case_def, f_case_killed) = build_reaching_def(
                fbranch,
                post_cond_def,
                post_cond_killed,
                rd,
                counter,
            );
            (&t_case_def | &f_case_def, &t_case_killed | &f_case_killed)
        }
        ir::Control::While(ir::While { body, .. }) => {
            handle_repeat_while_body(body, reach, killed, rd, counter)
        }
        // Invoke: fabricate a unique label and add a dummy definition for
        // every register passed through an input or output port.
        ir::Control::Invoke(invoke) => {
            *counter += 1;

            let iterator = invoke
                .inputs
                .iter()
                .chain(invoke.outputs.iter())
                .filter_map(|(_, port)| {
                    if let ir::PortParent::Cell(wc) = &port.borrow().parent {
                        let rc = wc.upgrade();
                        let parent = rc.borrow();
                        if parent
                            .type_name()
                            .unwrap_or_else(|| ir::Id::from(""))
                            == "std_reg"
                        {
                            let name = format!("{INVOKE_PREFIX}{counter}");
                            // (Re-)attached once per matching port; the label
                            // is identical each time, so this is harmless.
                            rd.meta.attach_label(
                                invoke,
                                ir::Id::from(name.clone()),
                            );
                            return Some((
                                parent.name(),
                                GroupOrInvoke::Invoke(ir::Id::from(name)),
                            ));
                        }
                    }
                    None
                });

            let mut new_reach = reach;
            new_reach.set.extend(iterator);

            (new_reach, killed)
        }
        ir::Control::Enable(en) => handle_reaching_def_enables(
            &en.group.borrow().assignments,
            reach,
            rd,
            en.group.borrow().name(),
        ),
        // Nothing executes: defs/kills pass through unchanged.
        ir::Control::Empty(_) => (reach, killed),
        ir::Control::Repeat(ir::Repeat { body, .. }) => {
            handle_repeat_while_body(body, reach, killed, rd, counter)
        }
        ir::Control::Static(sc) => {
            build_reaching_def_static(sc, reach, killed, rd, counter)
        }
        // FSM nodes are introduced by later compilation stages and are not
        // expected to appear when this analysis runs.
        ir::Control::FSMEnable(_) => {
            todo!("should not encounter fsm nodes")
        }
    }
}