1use calyx_ir::{self as ir};
2use std::collections::HashMap;
3use std::rc::Rc;
4
5#[derive(Debug, Default)]
6pub struct PromotionAnalysis {
7 static_group_name: HashMap<ir::Id, ir::Id>,
9}
10
11impl PromotionAnalysis {
12 fn check_latencies_match(actual: u64, inferred: u64) {
13 assert_eq!(
14 actual, inferred,
15 "Inferred and Annotated Latencies do not match. Latency: {actual}. Inferred: {inferred}"
16 );
17 }
18
19 pub fn get_inferred_latency(c: &ir::Control) -> u64 {
20 let ir::Control::Static(sc) = c else {
21 let Some(latency) = c.get_attribute(ir::NumAttr::Promotable) else {
22 unreachable!(
23 "Called get_latency on control that is neither static nor promotable"
24 )
25 };
26 return latency;
27 };
28 sc.get_latency()
29 }
30
31 pub fn can_be_promoted(c: &ir::Control) -> bool {
34 c.is_static() || c.has_attribute(ir::NumAttr::Promotable)
35 }
36
37 fn construct_static_group(
40 &mut self,
41 builder: &mut ir::Builder,
42 group: ir::RRC<ir::Group>,
43 latency: u64,
44 ) -> ir::RRC<ir::StaticGroup> {
45 if let Some(s_name) = self.static_group_name.get(&group.borrow().name())
46 {
47 builder.component.find_static_group(*s_name).unwrap()
48 } else {
49 let sg = builder.add_static_group(group.borrow().name(), latency);
50 self.static_group_name
51 .insert(group.borrow().name(), sg.borrow().name());
52 for assignment in group.borrow().assignments.iter() {
53 if !(assignment.dst.borrow().is_hole()
55 && assignment.dst.borrow().name == "done")
56 {
57 sg.borrow_mut()
58 .assignments
59 .push(ir::Assignment::from(assignment.clone()));
60 }
61 }
62 Rc::clone(&sg)
63 }
64 }
65
66 pub fn convert_enable_to_static(
68 &mut self,
69 s: &mut ir::Enable,
70 builder: &mut ir::Builder,
71 ) -> ir::StaticControl {
72 s.attributes.remove(ir::NumAttr::Promotable);
73 ir::StaticControl::Enable(ir::StaticEnable {
74 group: self.construct_static_group(
76 builder,
77 Rc::clone(&s.group),
78 s.group
79 .borrow()
80 .get_attributes()
81 .unwrap()
82 .get(ir::NumAttr::Promotable)
83 .unwrap(),
84 ),
85 attributes: std::mem::take(&mut s.attributes),
86 })
87 }
88
89 pub fn convert_invoke_to_static(
91 &mut self,
92 s: &mut ir::Invoke,
93 ) -> ir::StaticControl {
94 assert!(
95 s.comb_group.is_none(),
96 "Shouldn't Promote to Static if there is a Comb Group",
97 );
98 let latency = s.attributes.get(ir::NumAttr::Promotable).unwrap();
99 s.attributes.remove(ir::NumAttr::Promotable);
100 let s_inv = ir::StaticInvoke {
101 comp: Rc::clone(&s.comp),
102 inputs: std::mem::take(&mut s.inputs),
103 outputs: std::mem::take(&mut s.outputs),
104 latency,
105 attributes: std::mem::take(&mut s.attributes),
106 ref_cells: std::mem::take(&mut s.ref_cells),
107 comb_group: std::mem::take(&mut s.comb_group),
108 };
109 ir::StaticControl::Invoke(s_inv)
110 }
111
112 pub fn convert_to_static(
115 &mut self,
116 c: &mut ir::Control,
117 builder: &mut ir::Builder,
118 ) -> ir::StaticControl {
119 assert!(
120 c.has_attribute(ir::NumAttr::Promotable) || c.is_static(),
121 "Called convert_to_static control that is neither static nor promotable"
122 );
123 let bound_attribute = c.get_attribute(ir::NumAttr::Bound);
126 let inferred_latency = Self::get_inferred_latency(c);
129 match c {
130 ir::Control::Empty(_) => ir::StaticControl::empty(),
131 ir::Control::Enable(s) => self.convert_enable_to_static(s, builder),
132 ir::Control::Seq(ir::Seq { stmts, attributes }) => {
133 attributes.remove(ir::NumAttr::Promotable);
135 attributes.insert(ir::NumAttr::Compactable, 1);
137 let static_stmts =
138 self.convert_vec_to_static(builder, std::mem::take(stmts));
139 let latency =
140 static_stmts.iter().map(|s| s.get_latency()).sum();
141 Self::check_latencies_match(latency, inferred_latency);
142 ir::StaticControl::Seq(ir::StaticSeq {
143 stmts: static_stmts,
144 attributes: std::mem::take(attributes),
145 latency,
146 })
147 }
148 ir::Control::Par(ir::Par { stmts, attributes }) => {
149 attributes.remove(ir::NumAttr::Promotable);
151 let static_stmts =
153 self.convert_vec_to_static(builder, std::mem::take(stmts));
154 let latency = static_stmts
156 .iter()
157 .map(|s| s.get_latency())
158 .max()
159 .unwrap_or_else(|| unreachable!("Empty Par Block"));
160 Self::check_latencies_match(latency, inferred_latency);
161 ir::StaticControl::Par(ir::StaticPar {
162 stmts: static_stmts,
163 attributes: ir::Attributes::default(),
164 latency,
165 })
166 }
167 ir::Control::Repeat(ir::Repeat {
168 body,
169 num_repeats,
170 attributes,
171 }) => {
172 attributes.remove(ir::NumAttr::Promotable);
174 let sc = self.convert_to_static(body, builder);
175 let latency = (*num_repeats) * sc.get_latency();
176 Self::check_latencies_match(latency, inferred_latency);
177 ir::StaticControl::Repeat(ir::StaticRepeat {
178 attributes: std::mem::take(attributes),
179 body: Box::new(sc),
180 num_repeats: *num_repeats,
181 latency,
182 })
183 }
184 ir::Control::While(ir::While {
185 body, attributes, ..
186 }) => {
187 attributes.remove(ir::NumAttr::Promotable);
189 attributes.remove(ir::NumAttr::Bound);
191 let sc = self.convert_to_static(body, builder);
192 let num_repeats = bound_attribute.unwrap_or_else(|| unreachable!("Called convert_to_static on a while loop without a bound"));
193 let latency = num_repeats * sc.get_latency();
194 Self::check_latencies_match(latency, inferred_latency);
195 ir::StaticControl::Repeat(ir::StaticRepeat {
196 attributes: std::mem::take(attributes),
197 body: Box::new(sc),
198 num_repeats,
199 latency,
200 })
201 }
202 ir::Control::If(ir::If {
203 port,
204 tbranch,
205 fbranch,
206 attributes,
207 ..
208 }) => {
209 attributes.remove(ir::NumAttr::Promotable);
211 let static_tbranch = self.convert_to_static(tbranch, builder);
212 let static_fbranch = self.convert_to_static(fbranch, builder);
213 let latency = std::cmp::max(
214 static_tbranch.get_latency(),
215 static_fbranch.get_latency(),
216 );
217 Self::check_latencies_match(latency, inferred_latency);
218 ir::StaticControl::static_if(
219 Rc::clone(port),
220 Box::new(static_tbranch),
221 Box::new(static_fbranch),
222 latency,
223 )
224 }
225 ir::Control::Static(_) => c.take_static_control(),
226 ir::Control::Invoke(s) => self.convert_invoke_to_static(s),
227 ir::Control::FSMEnable(_) => {
228 todo!("should not encounter fsm nodes")
229 }
230 }
231 }
232
233 pub fn convert_vec_to_static(
236 &mut self,
237 builder: &mut ir::Builder,
238 control_vec: Vec<ir::Control>,
239 ) -> Vec<ir::StaticControl> {
240 control_vec
241 .into_iter()
242 .map(|mut c| self.convert_to_static(&mut c, builder))
243 .collect()
244 }
245}