1use super::AssignmentAnalysis;
2use crate::analysis::{GraphAnalysis, compute_static::WithStatic};
3use calyx_ir::{self as ir, GetAttributes, RRC};
4use ir::CellType;
5use itertools::Itertools;
6use std::collections::{HashMap, HashSet};
7
8#[derive(Debug)]
12pub struct GoDone {
13 ports: Vec<(ir::Id, ir::Id, u64)>,
14}
15
16impl GoDone {
17 pub fn new(ports: Vec<(ir::Id, ir::Id, u64)>) -> Self {
18 Self { ports }
19 }
20
21 pub fn is_go(&self, name: &ir::Id) -> bool {
23 self.ports.iter().any(|(go, _, _)| name == go)
24 }
25
26 pub fn is_done(&self, name: &ir::Id) -> bool {
28 self.ports.iter().any(|(_, done, _)| name == done)
29 }
30
31 pub fn get_latency(&self, go_port: &ir::Id) -> Option<u64> {
33 self.ports.iter().find_map(|(go, _, lat)| {
34 if go == go_port { Some(*lat) } else { None }
35 })
36 }
37
38 pub fn iter(&self) -> impl Iterator<Item = &(ir::Id, ir::Id, u64)> {
40 self.ports.iter()
41 }
42
43 pub fn is_empty(&self) -> bool {
45 self.ports.is_empty()
46 }
47
48 pub fn len(&self) -> usize {
50 self.ports.len()
51 }
52
53 pub fn get_ports(&self) -> &Vec<(ir::Id, ir::Id, u64)> {
55 &self.ports
56 }
57}
58
59impl From<&ir::Primitive> for GoDone {
60 fn from(prim: &ir::Primitive) -> Self {
61 let done_ports: HashMap<_, _> = prim
62 .find_all_with_attr(ir::NumAttr::Done)
63 .map(|pd| (pd.attributes.get(ir::NumAttr::Done), pd.name()))
64 .collect();
65
66 let go_ports = prim
67 .find_all_with_attr(ir::NumAttr::Go)
68 .filter_map(|pd| {
69 pd.attributes.get(ir::NumAttr::Interval).and_then(|st| {
71 done_ports
72 .get(&pd.attributes.get(ir::NumAttr::Go))
73 .map(|done_port| (pd.name(), *done_port, st))
74 })
75 })
76 .collect_vec();
77 GoDone::new(go_ports)
78 }
79}
80
81impl From<&ir::Cell> for GoDone {
82 fn from(cell: &ir::Cell) -> Self {
83 let done_ports: HashMap<_, _> = cell
84 .find_all_with_attr(ir::NumAttr::Done)
85 .map(|pr| {
86 let port = pr.borrow();
87 (port.attributes.get(ir::NumAttr::Done), port.name)
88 })
89 .collect();
90
91 let go_ports = cell
92 .find_all_with_attr(ir::NumAttr::Go)
93 .filter_map(|pr| {
94 let port = pr.borrow();
95 let st = match port.attributes.get(ir::NumAttr::Interval) {
97 Some(st) => Some(st),
98 None => port.attributes.get(ir::NumAttr::Promotable),
99 };
100 if let Some(static_latency) = st {
101 return done_ports
102 .get(&port.attributes.get(ir::NumAttr::Go))
103 .map(|done_port| {
104 (port.name, *done_port, static_latency)
105 });
106 }
107 None
108 })
109 .collect_vec();
110 GoDone::new(go_ports)
111 }
112}
113
114pub struct InferenceAnalysis {
117 pub latency_data: HashMap<ir::Id, GoDone>,
119 pub static_component_latencies: HashMap<ir::Id, u64>,
124
125 updated_components: HashSet<ir::Id>,
126}
127
128impl InferenceAnalysis {
129 pub fn from_ctx(ctx: &ir::Context) -> Self {
132 let mut latency_data = HashMap::new();
133 let mut static_component_latencies = HashMap::new();
134 for prim in ctx.lib.signatures() {
136 let prim_go_done = GoDone::from(prim);
137 if prim_go_done.len() == 1 {
138 static_component_latencies
139 .insert(prim.name, prim_go_done.get_ports()[0].2);
140 }
141 latency_data.insert(prim.name, GoDone::from(prim));
142 }
143 for comp in &ctx.components {
144 let comp_sig = comp.signature.borrow();
145
146 let done_ports: HashMap<_, _> = comp_sig
147 .find_all_with_attr(ir::NumAttr::Done)
148 .map(|pd| {
149 let pd_ref = pd.borrow();
150 (pd_ref.attributes.get(ir::NumAttr::Done), pd_ref.name)
151 })
152 .collect();
153
154 let go_ports = comp_sig
155 .find_all_with_attr(ir::NumAttr::Go)
156 .filter_map(|pd| {
157 let pd_ref = pd.borrow();
158 let st = match pd_ref.attributes.get(ir::NumAttr::Interval)
160 {
161 Some(st) => Some(st),
162 None => pd_ref.attributes.get(ir::NumAttr::Promotable),
163 };
164 if let Some(static_latency) = st {
165 return done_ports
166 .get(&pd_ref.attributes.get(ir::NumAttr::Go))
167 .map(|done_port| {
168 (pd_ref.name, *done_port, static_latency)
169 });
170 }
171 None
172 })
173 .collect_vec();
174
175 let go_done_comp = GoDone::new(go_ports);
176
177 if go_done_comp.len() == 1 {
178 static_component_latencies
179 .insert(comp.name, go_done_comp.get_ports()[0].2);
180 }
181 latency_data.insert(comp.name, go_done_comp);
182 }
183 InferenceAnalysis {
184 latency_data,
185 static_component_latencies,
186 updated_components: HashSet::new(),
187 }
188 }
189
190 pub fn add_component(
192 &mut self,
193 (comp_name, latency, go_done): (ir::Id, u64, GoDone),
194 ) {
195 self.latency_data.insert(comp_name, go_done);
196 self.static_component_latencies.insert(comp_name, latency);
197 }
198
199 pub fn remove_component(&mut self, comp_name: ir::Id) {
203 if self.latency_data.contains_key(&comp_name) {
204 self.updated_components.insert(comp_name);
207 }
208 self.latency_data.remove(&comp_name);
209 self.static_component_latencies.remove(&comp_name);
210 }
211
212 pub fn adjust_component(
216 &mut self,
217 (comp_name, adjusted_latency): (ir::Id, u64),
218 ) {
219 let mut updated = false;
221 self.latency_data.entry(comp_name).and_modify(|go_done| {
222 for (_, _, cur_latency) in &mut go_done.ports {
223 if *cur_latency != adjusted_latency {
225 *cur_latency = adjusted_latency;
226 updated = true;
227 }
228 }
229 });
230 self.static_component_latencies
231 .insert(comp_name, adjusted_latency);
232 if updated {
233 self.updated_components.insert(comp_name);
234 }
235 }
236
237 fn mem_wrt_dep_graph(&self, src: &ir::Port, dst: &ir::Port) -> bool {
242 match (&src.parent, &dst.parent) {
243 (
244 ir::PortParent::Cell(src_cell_wrf),
245 ir::PortParent::Cell(dst_cell_wrf),
246 ) => {
247 let src_rf = src_cell_wrf.upgrade();
248 let src_cell = src_rf.borrow();
249 let dst_rf = dst_cell_wrf.upgrade();
250 let dst_cell = dst_rf.borrow();
251 if let (Some(s_name), Some(d_name)) =
252 (src_cell.type_name(), dst_cell.type_name())
253 {
254 let data_src = self.latency_data.get(&s_name);
255 let data_dst = self.latency_data.get(&d_name);
256 if let (Some(dst_ports), Some(src_ports)) =
257 (data_dst, data_src)
258 {
259 return src_ports.is_done(&src.name)
260 && dst_ports.is_go(&dst.name);
261 }
262 }
263
264 if let (Some(d_name), ir::CellType::Constant { .. }) =
266 (dst_cell.type_name(), &src_cell.prototype)
267 {
268 if let Some(ports) = self.latency_data.get(&d_name) {
269 return ports.is_go(&dst.name);
270 }
271 }
272
273 false
274 }
275
276 (_, ir::PortParent::Group(_)) => dst.name == "done",
278
279 _ => false,
281 }
282 }
283
284 fn find_go_done_edges(
287 &self,
288 group: &ir::Group,
289 ) -> Vec<(RRC<ir::Port>, RRC<ir::Port>)> {
290 let rw_set = group.assignments.iter().analysis().cell_uses();
291 let mut go_done_edges: Vec<(RRC<ir::Port>, RRC<ir::Port>)> = Vec::new();
292
293 for cell_ref in rw_set {
294 let cell = cell_ref.borrow();
295 if let Some(ports) =
296 cell.type_name().and_then(|c| self.latency_data.get(&c))
297 {
298 go_done_edges.extend(
299 ports
300 .iter()
301 .map(|(go, done, _)| (cell.get(go), cell.get(done))),
302 )
303 }
304 }
305 go_done_edges
306 }
307
308 fn is_done_port_or_const(&self, port: &ir::Port) -> bool {
311 if let ir::PortParent::Cell(cwrf) = &port.parent {
312 let cr = cwrf.upgrade();
313 let cell = cr.borrow();
314 if let ir::CellType::Constant { val, .. } = &cell.prototype {
315 if *val > 0 {
316 return true;
317 }
318 } else if let Some(ports) =
319 cell.type_name().and_then(|c| self.latency_data.get(&c))
320 {
321 return ports.is_done(&port.name);
322 }
323 }
324 false
325 }
326
327 fn contains_dyn_writes(&self, graph: &GraphAnalysis) -> bool {
330 for port in &graph.ports() {
331 match &port.borrow().parent {
332 ir::PortParent::Cell(cell_wrf) => {
333 let cr = cell_wrf.upgrade();
334 let cell = cr.borrow();
335 if let Some(ports) =
336 cell.type_name().and_then(|c| self.latency_data.get(&c))
337 {
338 let name = &port.borrow().name;
339 if ports.is_go(name) {
340 for write_port in graph.writes_to(&port.borrow()) {
341 if !self
342 .is_done_port_or_const(&write_port.borrow())
343 {
344 log::debug!(
345 "`{}` is not a done port",
346 write_port.borrow().canonical(),
347 );
348 return true;
349 }
350 }
351 }
352 }
353 }
354 ir::PortParent::Group(_) => {
355 if port.borrow().name == "done" {
356 for write_port in graph.writes_to(&port.borrow()) {
357 if !self.is_done_port_or_const(&write_port.borrow())
358 {
359 log::debug!(
360 "`{}` is not a done port",
361 write_port.borrow().canonical(),
362 );
363 return true;
364 }
365 }
366 }
367 }
368
369 ir::PortParent::FSM(_) => {
370 if port.borrow().name == "done" {
371 for write_port in graph.writes_to(&port.borrow()) {
372 if !self.is_done_port_or_const(&write_port.borrow())
373 {
374 log::debug!(
375 "`{}` is not a done port",
376 write_port.borrow().canonical(),
377 );
378 return true;
379 }
380 }
381 }
382 }
383
384 ir::PortParent::StaticGroup(_) =>
385 {
387 panic!(
388 "Have not decided how to handle static groups in infer-static-timing"
389 )
390 }
391 }
392 }
393 false
394 }
395
396 fn contains_node_deg_gt_one(graph: &GraphAnalysis) -> bool {
398 for port in graph.ports() {
399 if graph.writes_to(&port.borrow()).count() > 1 {
400 return true;
401 }
402 }
403 false
404 }
405
406 fn infer_latency(&self, group: &ir::Group) -> Option<u64> {
410 log::debug!("Checking group `{}`", group.name());
432 let graph_unprocessed = GraphAnalysis::from(group);
433 if self.contains_dyn_writes(&graph_unprocessed) {
434 log::debug!("FAIL: contains dynamic writes");
435 return None;
436 }
437
438 let go_done_edges = self.find_go_done_edges(group);
439 let graph = graph_unprocessed
440 .edge_induced_subgraph(|src, dst| self.mem_wrt_dep_graph(src, dst))
441 .add_edges(&go_done_edges)
442 .remove_isolated_vertices();
443
444 if Self::contains_node_deg_gt_one(&graph) {
446 log::debug!("FAIL: Group contains multiple writes");
447 return None;
448 }
449
450 let mut tsort = graph.toposort();
451 let start = tsort.next()?;
452 let finish = tsort.last()?;
453
454 let paths = graph.paths(&start.borrow(), &finish.borrow());
455 if paths.is_empty() {
457 log::debug!("FAIL: No path between @go and @done port");
458 return None;
459 }
460 let first_path = paths.first().unwrap();
461
462 let mut latency_sum = 0;
464 for port in first_path {
465 if let ir::PortParent::Cell(cwrf) = &port.borrow().parent {
466 let cr = cwrf.upgrade();
467 let cell = cr.borrow();
468 if let Some(ports) =
469 cell.type_name().and_then(|c| self.latency_data.get(&c))
470 {
471 if let Some(latency) =
472 ports.get_latency(&port.borrow().name)
473 {
474 latency_sum += latency;
475 }
476 }
477 }
478 }
479
480 log::debug!("SUCCESS: Latency = {latency_sum}");
481 Some(latency_sum)
482 }
483
484 pub fn get_possible_latency(c: &ir::Control) -> Option<u64> {
487 match c {
488 ir::Control::Static(sc) => Some(sc.get_latency()),
489 _ => c.get_attribute(ir::NumAttr::Promotable),
490 }
491 }
492
493 pub fn remove_promotable_from_seq(seq: &mut ir::Seq) {
494 for stmt in &mut seq.stmts {
495 Self::remove_promotable_attribute(stmt);
496 }
497 seq.get_mut_attributes().remove(ir::NumAttr::Promotable);
498 }
499
500 pub fn remove_promotable_attribute(c: &mut ir::Control) {
503 c.get_mut_attributes().remove(ir::NumAttr::Promotable);
504 match c {
505 ir::Control::Empty(_)
506 | ir::Control::Invoke(_)
507 | ir::Control::Enable(_)
508 | ir::Control::Static(_)
509 | ir::Control::FSMEnable(_) => (),
510 ir::Control::While(ir::While { body, .. })
511 | ir::Control::Repeat(ir::Repeat { body, .. }) => {
512 Self::remove_promotable_attribute(body);
513 }
514 ir::Control::If(ir::If {
515 tbranch, fbranch, ..
516 }) => {
517 Self::remove_promotable_attribute(tbranch);
518 Self::remove_promotable_attribute(fbranch);
519 }
520 ir::Control::Seq(ir::Seq { stmts, .. })
521 | ir::Control::Par(ir::Par { stmts, .. }) => {
522 for stmt in stmts {
523 Self::remove_promotable_attribute(stmt);
524 }
525 }
526 }
527 }
528
529 pub fn fixup_seq(&self, seq: &mut ir::Seq) {
530 seq.update_static(&self.static_component_latencies);
531 }
532
533 pub fn fixup_par(&self, par: &mut ir::Par) {
534 par.update_static(&self.static_component_latencies);
535 }
536
537 pub fn fixup_if(&self, _if: &mut ir::If) {
538 _if.update_static(&self.static_component_latencies);
539 }
540
541 pub fn fixup_while(&self, _while: &mut ir::While) {
542 _while.update_static(&self.static_component_latencies);
543 }
544
545 pub fn fixup_repeat(&self, repeat: &mut ir::Repeat) {
546 repeat.update_static(&self.static_component_latencies);
547 }
548
549 pub fn fixup_ctrl(&self, ctrl: &mut ir::Control) {
550 ctrl.update_static(&self.static_component_latencies);
551 }
552
553 pub fn fixup_timing(&self, comp: &mut ir::Component) {
563 for group in comp.groups.iter() {
566 if group
570 .borrow_mut()
571 .assignments
572 .iter()
573 .analysis()
574 .cell_writes()
575 .any(|cell| match cell.borrow().prototype {
576 CellType::Component { name } => {
577 self.updated_components.contains(&name)
578 }
579 _ => false,
580 })
581 {
582 group
584 .borrow_mut()
585 .attributes
586 .remove(ir::NumAttr::Promotable);
587 }
588 }
589
590 for group in &mut comp.groups.iter() {
591 let latency_result = self.infer_latency(&group.borrow());
593 if let Some(latency) = latency_result {
594 group
595 .borrow_mut()
596 .attributes
597 .insert(ir::NumAttr::Promotable, latency);
598 }
599 }
600
601 Self::remove_promotable_attribute(&mut comp.control.borrow_mut());
604 comp.control
605 .borrow_mut()
606 .update_static(&self.static_component_latencies);
607 }
608}