1use calyx_ir as ir;
3use itertools::Itertools;
4use std::cmp::Ordering;
5use std::cmp::{Ord, PartialOrd};
6use std::{
7 collections::{BTreeMap, BTreeSet, HashMap},
8 ops::BitOr,
9};
10
11use super::read_write_set::AssignmentAnalysis;
12
/// Prefix of the synthetic labels created for invoke statements. The value of
/// a running counter is appended to it when a label is built.
13const INVOKE_PREFIX: &str = "__invoke_";
14
/// Name of a group.
15type GroupName = ir::Id;
/// Label that stands in for an invoke statement.
16type InvokeName = ir::Id;
17
/// Identifies where a definition comes from: either a group or an invoke
/// statement. Both variants carry an `ir::Id`.
18#[derive(Clone, Debug, Hash, Eq, PartialEq)]
23pub enum GroupOrInvoke {
 /// A group, referred to by its name.
 24 Group(GroupName),
 /// An invoke statement, referred to by the label generated for it.
 25 Invoke(InvokeName),
26}
27
28impl PartialOrd for GroupOrInvoke {
29 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
30 Some(self.cmp(other))
31 }
32}
33
34impl Ord for GroupOrInvoke {
35 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
36 match (self, other) {
37 (GroupOrInvoke::Group(a), GroupOrInvoke::Group(b))
38 | (GroupOrInvoke::Invoke(a), GroupOrInvoke::Invoke(b)) => {
39 ir::Id::cmp(a, b)
40 }
41 (GroupOrInvoke::Group(_), GroupOrInvoke::Invoke(_)) => {
42 Ordering::Greater
43 }
44 (GroupOrInvoke::Invoke(_), GroupOrInvoke::Group(_)) => {
45 Ordering::Less
46 }
47 }
48 }
49}
50
51#[allow(clippy::from_over_into)]
52impl Into<ir::Id> for GroupOrInvoke {
53 fn into(self) -> ir::Id {
54 match self {
55 GroupOrInvoke::Group(id) | GroupOrInvoke::Invoke(id) => id,
56 }
57 }
58}
59
60#[derive(Debug, Default)]
61pub struct MetadataMap {
62 map: HashMap<*const ir::Invoke, ir::Id>,
63 static_map: HashMap<*const ir::StaticInvoke, ir::Id>,
64}
65
66impl MetadataMap {
67 fn attach_label(&mut self, invoke: &ir::Invoke, label: ir::Id) {
68 self.map.insert(invoke as *const ir::Invoke, label);
69 }
70
71 fn attach_label_static(
72 &mut self,
73 invoke: &ir::StaticInvoke,
74 label: ir::Id,
75 ) {
76 self.static_map
77 .insert(invoke as *const ir::StaticInvoke, label);
78 }
79
80 pub fn fetch_label(&self, invoke: &ir::Invoke) -> Option<&ir::Id> {
81 self.map.get(&(invoke as *const ir::Invoke))
82 }
83
84 pub fn fetch_label_static(
85 &self,
86 invoke: &ir::StaticInvoke,
87 ) -> Option<&ir::Id> {
88 self.static_map.get(&(invoke as *const ir::StaticInvoke))
89 }
90}
91#[derive(Clone, Debug, Default)]
108pub struct DefSet {
109 set: BTreeSet<(ir::Id, GroupOrInvoke)>,
110}
111
112impl DefSet {
113 fn extend(&mut self, writes: BTreeSet<ir::Id>, grp: GroupName) {
114 for var in writes {
115 self.set.insert((var, GroupOrInvoke::Group(grp)));
116 }
117 }
118
119 fn kill_from_writeread(
120 &self,
121 writes: &BTreeSet<ir::Id>,
122 reads: &BTreeSet<ir::Id>,
123 ) -> (Self, KilledSet) {
124 let mut killed = KilledSet::new();
125 let def = DefSet {
126 set: self
127 .set
128 .iter()
129 .cloned()
130 .filter_map(|(name, grp)| {
131 if !writes.contains(&name) || reads.contains(&name) {
132 Some((name, grp))
133 } else {
134 killed.insert(name);
135 None
136 }
137 })
138 .collect(),
139 };
140 (def, killed)
141 }
142
143 fn kill_from_hashset(&self, killset: &BTreeSet<ir::Id>) -> Self {
144 DefSet {
145 set: self
146 .set
147 .iter()
148 .filter(|&(name, _)| !killset.contains(name))
149 .cloned()
150 .collect(),
151 }
152 }
153}
154
155impl BitOr<&DefSet> for &DefSet {
156 type Output = DefSet;
157
158 fn bitor(self, rhs: &DefSet) -> Self::Output {
159 DefSet {
160 set: &self.set | &rhs.set,
161 }
162 }
163}
164
/// Maps a register name to a list of sets of `(register, group or invoke)`
/// pairs; this is the return type of
/// `ReachingDefinitionAnalysis::calculate_overlap`.
165type OverlapMap = BTreeMap<ir::Id, Vec<BTreeSet<(ir::Id, GroupOrInvoke)>>>;
166
167#[derive(Debug, Default)]
184pub struct ReachingDefinitionAnalysis {
185 pub reach: BTreeMap<GroupOrInvoke, DefSet>,
186 pub meta: MetadataMap,
187}
188
189impl ReachingDefinitionAnalysis {
190 pub fn new(control: &ir::Control) -> Self {
196 let initial_set = DefSet::default();
197 let mut analysis = ReachingDefinitionAnalysis::default();
198 let mut counter: u64 = 0;
199
200 build_reaching_def(
201 control,
202 initial_set,
203 KilledSet::new(),
204 &mut analysis,
205 &mut counter,
206 );
207 analysis
208 }
209
210 pub fn calculate_overlap<'a, I, T: 'a>(
218 &'a self,
219 continuous_assignments: I,
220 ) -> OverlapMap
221 where
222 I: Iterator<Item = &'a ir::Assignment<T>> + Clone + 'a,
223 {
224 let continuous_regs: Vec<ir::Id> = continuous_assignments
225 .analysis()
226 .cell_uses()
227 .filter_map(|cell| {
228 let cell_ref = cell.borrow();
229 if let Some(name) = cell_ref.type_name() {
230 if name == "std_reg" {
231 return Some(cell_ref.name());
232 }
233 }
234 None
235 })
236 .collect();
237
238 let mut overlap_map: BTreeMap<
239 ir::Id,
240 Vec<BTreeSet<(ir::Id, GroupOrInvoke)>>,
241 > = BTreeMap::new();
242 for (grp, defset) in &self.reach {
243 let mut group_overlaps: BTreeMap<
244 &ir::Id,
245 BTreeSet<(ir::Id, GroupOrInvoke)>,
246 > = BTreeMap::new();
247
248 for (defname, group_name) in &defset.set {
249 let set = group_overlaps.entry(defname).or_default();
250 set.insert((*defname, group_name.clone()));
251 set.insert((*defname, grp.clone()));
252 }
253
254 for name in &continuous_regs {
255 let set = group_overlaps.entry(name).or_default();
256 set.insert((
257 *name,
258 GroupOrInvoke::Group("__continuous".into()),
259 ));
260 }
261
262 for (defname, set) in group_overlaps {
263 let overlap_vec = overlap_map.entry(*defname).or_default();
264
265 if overlap_vec.is_empty() {
266 overlap_vec.push(set)
267 } else {
268 let mut no_overlap = vec![];
269 let mut overlap = vec![];
270
271 for entry in overlap_vec.drain(..) {
272 if set.is_disjoint(&entry) {
273 no_overlap.push(entry)
274 } else {
275 overlap.push(entry)
276 }
277 }
278
279 *overlap_vec = no_overlap;
280
281 if overlap.is_empty() {
282 overlap_vec.push(set);
283 } else {
284 overlap_vec.push(
285 overlap
286 .into_iter()
287 .fold(set, |acc, entry| &acc | &entry),
288 )
289 }
290 }
291 }
292 }
293 overlap_map
294 }
295}
296
/// Names of registers for which a previously reaching definition was killed.
297type KilledSet = BTreeSet<ir::Id>;
298
299fn remove_entries_defined_by(set: &mut KilledSet, defs: &DefSet) {
300 let tmp_set: BTreeSet<_> = defs.set.iter().map(|(id, _)| id).collect();
301 *set = std::mem::take(set)
302 .into_iter()
303 .filter(|x| !tmp_set.contains(x))
304 .collect();
305}
306
307fn register_reads<T>(assigns: &[ir::Assignment<T>]) -> BTreeSet<ir::Id> {
310 assigns
311 .iter()
312 .analysis()
313 .reads()
314 .filter_map(|p| {
315 let port = p.borrow();
316 let ir::PortParent::Cell(cell_wref) = &port.parent else {
317 unreachable!("Port not part of a cell");
318 };
319 if &port.name != "out" {
321 return None;
322 };
323 let cr = cell_wref.upgrade();
324 let cell = cr.borrow();
325 if cell.is_primitive(Some("std_reg")) {
326 Some(cr.borrow().name())
327 } else {
328 None
329 }
330 })
331 .unique()
332 .collect()
333}
334
335fn handle_reaching_def_enables<T>(
338 asgns: &[ir::Assignment<T>],
339 reach: DefSet,
340 rd: &mut ReachingDefinitionAnalysis,
341 group_name: ir::Id,
342) -> (DefSet, KilledSet) {
343 let writes = asgns.iter().analysis().must_writes().cells();
344 let write_set = writes
348 .filter(|x| match &x.borrow().prototype {
349 ir::CellType::Primitive { name, .. } => name == "std_reg",
350 _ => false,
351 })
352 .map(|x| x.borrow().name())
353 .collect::<BTreeSet<_>>();
354
355 let read_set = register_reads(asgns);
356
357 let (mut cur_reach, killed) =
359 reach.kill_from_writeread(&write_set, &read_set);
360 cur_reach.extend(write_set, group_name);
361
362 rd.reach
363 .insert(GroupOrInvoke::Group(group_name), cur_reach.clone());
364
365 (cur_reach, killed)
366}
367
368fn build_reaching_def_static(
369 sc: &ir::StaticControl,
370 reach: DefSet,
371 killed: KilledSet,
372 rd: &mut ReachingDefinitionAnalysis,
373 counter: &mut u64,
374) -> (DefSet, KilledSet) {
375 match sc {
376 ir::StaticControl::Empty(_) => (reach, killed),
377 ir::StaticControl::Enable(sen) => handle_reaching_def_enables(
378 &sen.group.borrow().assignments,
379 reach,
380 rd,
381 sen.group.borrow().name(),
382 ),
383 ir::StaticControl::Repeat(ir::StaticRepeat { body, .. }) => {
384 let (post_cond_def, post_cond_killed) = build_reaching_def_static(
385 &ir::StaticControl::empty(),
386 reach.clone(),
387 killed,
388 rd,
389 counter,
390 );
391
392 let (round_1_def, mut round_1_killed) = build_reaching_def_static(
393 body,
394 post_cond_def,
395 post_cond_killed,
396 rd,
397 counter,
398 );
399
400 remove_entries_defined_by(&mut round_1_killed, &reach);
401
402 let (post_cond2_def, post_cond2_killed) = build_reaching_def(
403 &ir::Control::empty(),
404 &round_1_def | &reach,
405 round_1_killed,
406 rd,
407 counter,
408 );
409 let (final_def, mut final_kill) = build_reaching_def_static(
412 body,
413 post_cond2_def.clone(),
414 post_cond2_killed,
415 rd,
416 counter,
417 );
418
419 remove_entries_defined_by(&mut final_kill, &post_cond2_def);
420
421 (&final_def | &post_cond2_def, final_kill)
422 }
423
424 ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => stmts
425 .iter()
426 .fold((reach, killed), |(acc, killed), inner_c| {
427 build_reaching_def_static(inner_c, acc, killed, rd, counter)
428 }),
429 ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
430 let (defs, par_killed): (Vec<DefSet>, Vec<KilledSet>) = stmts
431 .iter()
432 .map(|ctrl| {
433 build_reaching_def_static(
434 ctrl,
435 reach.clone(),
436 KilledSet::new(),
437 rd,
438 counter,
439 )
440 })
441 .unzip();
442
443 let global_killed = par_killed
444 .iter()
445 .fold(KilledSet::new(), |acc, set| &acc | set);
446
447 let par_exit_defs = defs
448 .iter()
449 .zip(par_killed.iter())
450 .map(|(defs, kills)| {
451 defs.kill_from_hashset(&(&global_killed - kills))
452 })
453 .fold(DefSet::default(), |acc, element| &acc | &element);
454 (par_exit_defs, &global_killed | &killed)
455 }
456 ir::StaticControl::If(ir::StaticIf {
457 tbranch, fbranch, ..
458 }) => {
459 let (post_cond_def, post_cond_killed) = build_reaching_def_static(
460 &ir::StaticControl::empty(),
461 reach,
462 killed,
463 rd,
464 counter,
465 );
466 let (t_case_def, t_case_killed) = build_reaching_def_static(
467 tbranch,
468 post_cond_def.clone(),
469 post_cond_killed.clone(),
470 rd,
471 counter,
472 );
473 let (f_case_def, f_case_killed) = build_reaching_def_static(
474 fbranch,
475 post_cond_def,
476 post_cond_killed,
477 rd,
478 counter,
479 );
480 (&t_case_def | &f_case_def, &t_case_killed | &f_case_killed)
481 }
482 ir::StaticControl::Invoke(invoke) => {
483 *counter += 1;
484
485 let iterator = invoke
486 .inputs
487 .iter()
488 .chain(invoke.outputs.iter())
489 .filter_map(|(_, port)| {
490 if let ir::PortParent::Cell(wc) = &port.borrow().parent {
491 let rc = wc.upgrade();
492 let parent = rc.borrow();
493 if parent
494 .type_name()
495 .unwrap_or_else(|| ir::Id::from(""))
496 == "std_reg"
497 {
498 let name = format!("{INVOKE_PREFIX}{counter}");
499 rd.meta.attach_label_static(
500 invoke,
501 ir::Id::from(name.clone()),
502 );
503 return Some((
504 parent.name(),
505 GroupOrInvoke::Invoke(ir::Id::from(name)),
506 ));
507 }
508 }
509 None
510 });
511
512 let mut new_reach = reach;
513 new_reach.set.extend(iterator);
514
515 (new_reach, killed)
516 }
517 }
518}
519
520fn handle_repeat_while_body(
522 body: &ir::Control,
523 reach: DefSet,
524 killed: KilledSet,
525 rd: &mut ReachingDefinitionAnalysis,
526 counter: &mut u64,
527) -> (DefSet, KilledSet) {
528 let (post_cond_def, post_cond_killed) = build_reaching_def(
529 &ir::Control::empty(),
530 reach.clone(),
531 killed,
532 rd,
533 counter,
534 );
535
536 let (round_1_def, mut round_1_killed) =
537 build_reaching_def(body, post_cond_def, post_cond_killed, rd, counter);
538
539 remove_entries_defined_by(&mut round_1_killed, &reach);
540
541 let (post_cond2_def, post_cond2_killed) = build_reaching_def(
542 &ir::Control::empty(),
543 &round_1_def | &reach,
544 round_1_killed,
545 rd,
546 counter,
547 );
548 let (final_def, mut final_kill) = build_reaching_def(
551 body,
552 post_cond2_def.clone(),
553 post_cond2_killed,
554 rd,
555 counter,
556 );
557
558 remove_entries_defined_by(&mut final_kill, &post_cond2_def);
559
560 (&final_def | &post_cond2_def, final_kill)
561}
562
563fn build_reaching_def(
564 c: &ir::Control,
565 reach: DefSet,
566 killed: KilledSet,
567 rd: &mut ReachingDefinitionAnalysis,
568 counter: &mut u64,
569) -> (DefSet, KilledSet) {
570 match c {
571 ir::Control::Seq(ir::Seq { stmts, .. }) => {
572 stmts
573 .iter()
574 .fold((reach, killed), |(acc, killed), inner_c| {
575 build_reaching_def(inner_c, acc, killed, rd, counter)
576 })
577 }
578 ir::Control::Par(ir::Par { stmts, .. }) => {
579 let (defs, par_killed): (Vec<DefSet>, Vec<KilledSet>) = stmts
580 .iter()
581 .map(|ctrl| {
582 build_reaching_def(
583 ctrl,
584 reach.clone(),
585 KilledSet::new(),
586 rd,
587 counter,
588 )
589 })
590 .unzip();
591
592 let global_killed = par_killed
593 .iter()
594 .fold(KilledSet::new(), |acc, set| &acc | set);
595
596 let par_exit_defs = defs
597 .iter()
598 .zip(par_killed.iter())
599 .map(|(defs, kills)| {
600 defs.kill_from_hashset(&(&global_killed - kills))
601 })
602 .fold(DefSet::default(), |acc, element| &acc | &element);
603 (par_exit_defs, &global_killed | &killed)
604 }
605 ir::Control::If(ir::If {
606 tbranch, fbranch, ..
607 }) => {
608 let (post_cond_def, post_cond_killed) = build_reaching_def(
609 &ir::Control::empty(),
610 reach,
611 killed,
612 rd,
613 counter,
614 );
615 let (t_case_def, t_case_killed) = build_reaching_def(
616 tbranch,
617 post_cond_def.clone(),
618 post_cond_killed.clone(),
619 rd,
620 counter,
621 );
622 let (f_case_def, f_case_killed) = build_reaching_def(
623 fbranch,
624 post_cond_def,
625 post_cond_killed,
626 rd,
627 counter,
628 );
629 (&t_case_def | &f_case_def, &t_case_killed | &f_case_killed)
630 }
631 ir::Control::While(ir::While { body, .. }) => {
632 handle_repeat_while_body(body, reach, killed, rd, counter)
633 }
634 ir::Control::Invoke(invoke) => {
635 *counter += 1;
636
637 let iterator = invoke
638 .inputs
639 .iter()
640 .chain(invoke.outputs.iter())
641 .filter_map(|(_, port)| {
642 if let ir::PortParent::Cell(wc) = &port.borrow().parent {
643 let rc = wc.upgrade();
644 let parent = rc.borrow();
645 if parent
646 .type_name()
647 .unwrap_or_else(|| ir::Id::from(""))
648 == "std_reg"
649 {
650 let name = format!("{INVOKE_PREFIX}{counter}");
651 rd.meta.attach_label(
652 invoke,
653 ir::Id::from(name.clone()),
654 );
655 return Some((
656 parent.name(),
657 GroupOrInvoke::Invoke(ir::Id::from(name)),
658 ));
659 }
660 }
661 None
662 });
663
664 let mut new_reach = reach;
665 new_reach.set.extend(iterator);
666
667 (new_reach, killed)
668 }
669 ir::Control::Enable(en) => handle_reaching_def_enables(
670 &en.group.borrow().assignments,
671 reach,
672 rd,
673 en.group.borrow().name(),
674 ),
675 ir::Control::Empty(_) => (reach, killed),
676 ir::Control::Repeat(ir::Repeat { body, .. }) => {
677 handle_repeat_while_body(body, reach, killed, rd, counter)
678 }
679 ir::Control::Static(sc) => {
680 build_reaching_def_static(sc, reach, killed, rd, counter)
681 }
682 ir::Control::FSMEnable(_) => {
683 todo!("should not encounter fsm nodes")
684 }
685 }
686}