1use super::AssignmentAnalysis;
2use crate::analysis::{GraphAnalysis, compute_static::WithStatic};
3use calyx_ir::{self as ir, GetAttributes, RRC};
4use ir::CellType;
5use itertools::Itertools;
6use std::collections::{HashMap, HashSet};
7
/// Go/done port pairs of a single primitive, cell, or component, together
/// with the latency recorded for each pair.
#[derive(Debug)]
pub struct GoDone {
    /// `(go port name, done port name, latency)` triples.
    ports: Vec<(ir::Id, ir::Id, u64)>,
}
15
16impl GoDone {
17 pub fn new(ports: Vec<(ir::Id, ir::Id, u64)>) -> Self {
18 Self { ports }
19 }
20
21 pub fn is_go(&self, name: &ir::Id) -> bool {
23 self.ports.iter().any(|(go, _, _)| name == go)
24 }
25
26 pub fn is_done(&self, name: &ir::Id) -> bool {
28 self.ports.iter().any(|(_, done, _)| name == done)
29 }
30
31 pub fn get_latency(&self, go_port: &ir::Id) -> Option<u64> {
33 self.ports.iter().find_map(|(go, _, lat)| {
34 if go == go_port { Some(*lat) } else { None }
35 })
36 }
37
38 pub fn iter(&self) -> impl Iterator<Item = &(ir::Id, ir::Id, u64)> {
40 self.ports.iter()
41 }
42
43 pub fn is_empty(&self) -> bool {
45 self.ports.is_empty()
46 }
47
48 pub fn len(&self) -> usize {
50 self.ports.len()
51 }
52
53 pub fn get_ports(&self) -> &Vec<(ir::Id, ir::Id, u64)> {
55 &self.ports
56 }
57}
58
59impl From<&ir::Primitive> for GoDone {
60 fn from(prim: &ir::Primitive) -> Self {
61 let done_ports: HashMap<_, _> = prim
62 .find_all_with_attr(ir::NumAttr::Done)
63 .map(|pd| (pd.attributes.get(ir::NumAttr::Done), pd.name()))
64 .collect();
65
66 let go_ports = prim
67 .find_all_with_attr(ir::NumAttr::Go)
68 .filter_map(|pd| {
69 pd.attributes.get(ir::NumAttr::Interval).and_then(|st| {
71 done_ports
72 .get(&pd.attributes.get(ir::NumAttr::Go))
73 .map(|done_port| (pd.name(), *done_port, st))
74 })
75 })
76 .collect_vec();
77 GoDone::new(go_ports)
78 }
79}
80
81impl From<&ir::Cell> for GoDone {
82 fn from(cell: &ir::Cell) -> Self {
83 let done_ports: HashMap<_, _> = cell
84 .find_all_with_attr(ir::NumAttr::Done)
85 .map(|pr| {
86 let port = pr.borrow();
87 (port.attributes.get(ir::NumAttr::Done), port.name)
88 })
89 .collect();
90
91 let go_ports = cell
92 .find_all_with_attr(ir::NumAttr::Go)
93 .filter_map(|pr| {
94 let port = pr.borrow();
95 let st = match port.attributes.get(ir::NumAttr::Interval) {
97 Some(st) => Some(st),
98 None => port.attributes.get(ir::NumAttr::Promotable),
99 };
100 if let Some(static_latency) = st {
101 return done_ports
102 .get(&port.attributes.get(ir::NumAttr::Go))
103 .map(|done_port| {
104 (port.name, *done_port, static_latency)
105 });
106 }
107 None
108 })
109 .collect_vec();
110 GoDone::new(go_ports)
111 }
112}
113
/// Latency information for primitives and components, used to infer the
/// latency of groups and to refresh the timing attributes of control programs.
pub struct InferenceAnalysis {
    /// Go/done pairs (with latencies) for each primitive or component name.
    pub latency_data: HashMap<ir::Id, GoDone>,
    /// A single latency per primitive or component name. `from_ctx` only
    /// records an entry when the definition has exactly one go/done pair.
    pub static_component_latencies: HashMap<ir::Id, u64>,

    /// Components whose latency information was removed, or changed by
    /// `adjust_component`; consulted by `fixup_timing`.
    updated_components: HashSet<ir::Id>,
}
127
128impl InferenceAnalysis {
129 pub fn from_ctx(ctx: &ir::Context) -> Self {
132 let mut latency_data = HashMap::new();
133 let mut static_component_latencies = HashMap::new();
134 for prim in ctx.lib.signatures() {
136 let prim_go_done = GoDone::from(prim);
137 if prim_go_done.len() == 1 {
138 static_component_latencies
139 .insert(prim.name, prim_go_done.get_ports()[0].2);
140 }
141 latency_data.insert(prim.name, GoDone::from(prim));
142 }
143 for comp in &ctx.components {
144 let comp_sig = comp.signature.borrow();
145
146 let done_ports: HashMap<_, _> = comp_sig
147 .find_all_with_attr(ir::NumAttr::Done)
148 .map(|pd| {
149 let pd_ref = pd.borrow();
150 (pd_ref.attributes.get(ir::NumAttr::Done), pd_ref.name)
151 })
152 .collect();
153
154 let go_ports = comp_sig
155 .find_all_with_attr(ir::NumAttr::Go)
156 .filter_map(|pd| {
157 let pd_ref = pd.borrow();
158 let st = match pd_ref.attributes.get(ir::NumAttr::Interval)
160 {
161 Some(st) => Some(st),
162 None => pd_ref.attributes.get(ir::NumAttr::Promotable),
163 };
164 if let Some(static_latency) = st {
165 return done_ports
166 .get(&pd_ref.attributes.get(ir::NumAttr::Go))
167 .map(|done_port| {
168 (pd_ref.name, *done_port, static_latency)
169 });
170 }
171 None
172 })
173 .collect_vec();
174
175 let go_done_comp = GoDone::new(go_ports);
176
177 if go_done_comp.len() == 1 {
178 static_component_latencies
179 .insert(comp.name, go_done_comp.get_ports()[0].2);
180 }
181 latency_data.insert(comp.name, go_done_comp);
182 }
183 InferenceAnalysis {
184 latency_data,
185 static_component_latencies,
186 updated_components: HashSet::new(),
187 }
188 }
189
190 pub fn add_component(
192 &mut self,
193 (comp_name, latency, go_done): (ir::Id, u64, GoDone),
194 ) {
195 self.latency_data.insert(comp_name, go_done);
196 self.static_component_latencies.insert(comp_name, latency);
197 }
198
199 pub fn remove_component(&mut self, comp_name: ir::Id) {
203 if self.latency_data.contains_key(&comp_name) {
204 self.updated_components.insert(comp_name);
207 }
208 self.latency_data.remove(&comp_name);
209 self.static_component_latencies.remove(&comp_name);
210 }
211
212 pub fn adjust_component(
216 &mut self,
217 (comp_name, adjusted_latency): (ir::Id, u64),
218 ) {
219 let mut updated = false;
221 self.latency_data.entry(comp_name).and_modify(|go_done| {
222 for (_, _, cur_latency) in &mut go_done.ports {
223 if *cur_latency != adjusted_latency {
225 *cur_latency = adjusted_latency;
226 updated = true;
227 }
228 }
229 });
230 self.static_component_latencies
231 .insert(comp_name, adjusted_latency);
232 if updated {
233 self.updated_components.insert(comp_name);
234 }
235 }
236
237 fn mem_wrt_dep_graph(&self, src: &ir::Port, dst: &ir::Port) -> bool {
242 match (&src.parent, &dst.parent) {
243 (
244 ir::PortParent::Cell(src_cell_wrf),
245 ir::PortParent::Cell(dst_cell_wrf),
246 ) => {
247 let src_rf = src_cell_wrf.upgrade();
248 let src_cell = src_rf.borrow();
249 let dst_rf = dst_cell_wrf.upgrade();
250 let dst_cell = dst_rf.borrow();
251 if let (Some(s_name), Some(d_name)) =
252 (src_cell.type_name(), dst_cell.type_name())
253 {
254 let data_src = self.latency_data.get(&s_name);
255 let data_dst = self.latency_data.get(&d_name);
256 if let (Some(dst_ports), Some(src_ports)) =
257 (data_dst, data_src)
258 {
259 return src_ports.is_done(&src.name)
260 && dst_ports.is_go(&dst.name);
261 }
262 }
263
264 if let (Some(d_name), ir::CellType::Constant { .. }) =
266 (dst_cell.type_name(), &src_cell.prototype)
267 && let Some(ports) = self.latency_data.get(&d_name)
268 {
269 return ports.is_go(&dst.name);
270 }
271
272 false
273 }
274
275 (_, ir::PortParent::Group(_)) => dst.name == "done",
277
278 _ => false,
280 }
281 }
282
283 fn find_go_done_edges(
286 &self,
287 group: &ir::Group,
288 ) -> Vec<(RRC<ir::Port>, RRC<ir::Port>)> {
289 let rw_set = group.assignments.iter().analysis().cell_uses();
290 let mut go_done_edges: Vec<(RRC<ir::Port>, RRC<ir::Port>)> = Vec::new();
291
292 for cell_ref in rw_set {
293 let cell = cell_ref.borrow();
294 if let Some(ports) =
295 cell.type_name().and_then(|c| self.latency_data.get(&c))
296 {
297 go_done_edges.extend(
298 ports
299 .iter()
300 .map(|(go, done, _)| (cell.get(go), cell.get(done))),
301 )
302 }
303 }
304 go_done_edges
305 }
306
307 fn is_done_port_or_const(&self, port: &ir::Port) -> bool {
310 if let ir::PortParent::Cell(cwrf) = &port.parent {
311 let cr = cwrf.upgrade();
312 let cell = cr.borrow();
313 if let ir::CellType::Constant { val, .. } = &cell.prototype {
314 if *val > 0 {
315 return true;
316 }
317 } else if let Some(ports) =
318 cell.type_name().and_then(|c| self.latency_data.get(&c))
319 {
320 return ports.is_done(&port.name);
321 }
322 }
323 false
324 }
325
326 fn contains_dyn_writes(&self, graph: &GraphAnalysis) -> bool {
329 for port in &graph.ports() {
330 match &port.borrow().parent {
331 ir::PortParent::Cell(cell_wrf) => {
332 let cr = cell_wrf.upgrade();
333 let cell = cr.borrow();
334 if let Some(ports) =
335 cell.type_name().and_then(|c| self.latency_data.get(&c))
336 {
337 let name = &port.borrow().name;
338 if ports.is_go(name) {
339 for write_port in graph.writes_to(&port.borrow()) {
340 if !self
341 .is_done_port_or_const(&write_port.borrow())
342 {
343 log::debug!(
344 "`{}` is not a done port",
345 write_port.borrow().canonical(),
346 );
347 return true;
348 }
349 }
350 }
351 }
352 }
353 ir::PortParent::Group(_) => {
354 if port.borrow().name == "done" {
355 for write_port in graph.writes_to(&port.borrow()) {
356 if !self.is_done_port_or_const(&write_port.borrow())
357 {
358 log::debug!(
359 "`{}` is not a done port",
360 write_port.borrow().canonical(),
361 );
362 return true;
363 }
364 }
365 }
366 }
367
368 ir::PortParent::FSM(_) => {
369 if port.borrow().name == "done" {
370 for write_port in graph.writes_to(&port.borrow()) {
371 if !self.is_done_port_or_const(&write_port.borrow())
372 {
373 log::debug!(
374 "`{}` is not a done port",
375 write_port.borrow().canonical(),
376 );
377 return true;
378 }
379 }
380 }
381 }
382
383 ir::PortParent::StaticGroup(_) =>
384 {
386 panic!(
387 "Have not decided how to handle static groups in infer-static-timing"
388 )
389 }
390 }
391 }
392 false
393 }
394
395 fn contains_node_deg_gt_one(graph: &GraphAnalysis) -> bool {
397 for port in graph.ports() {
398 if graph.writes_to(&port.borrow()).count() > 1 {
399 return true;
400 }
401 }
402 false
403 }
404
405 fn infer_latency(&self, group: &ir::Group) -> Option<u64> {
409 log::debug!("Checking group `{}`", group.name());
431 let graph_unprocessed = GraphAnalysis::from(group);
432 if self.contains_dyn_writes(&graph_unprocessed) {
433 log::debug!("FAIL: contains dynamic writes");
434 return None;
435 }
436
437 let go_done_edges = self.find_go_done_edges(group);
438 let graph = graph_unprocessed
439 .edge_induced_subgraph(|src, dst| self.mem_wrt_dep_graph(src, dst))
440 .add_edges(&go_done_edges)
441 .remove_isolated_vertices();
442
443 if Self::contains_node_deg_gt_one(&graph) {
445 log::debug!("FAIL: Group contains multiple writes");
446 return None;
447 }
448
449 let mut tsort = graph.toposort();
450 let start = tsort.next()?;
451 let finish = tsort.last()?;
452
453 let paths = graph.paths(&start.borrow(), &finish.borrow());
454 if paths.is_empty() {
456 log::debug!("FAIL: No path between @go and @done port");
457 return None;
458 }
459 let first_path = paths.first().unwrap();
460
461 let mut latency_sum = 0;
463 for port in first_path {
464 if let ir::PortParent::Cell(cwrf) = &port.borrow().parent {
465 let cr = cwrf.upgrade();
466 let cell = cr.borrow();
467 if let Some(ports) =
468 cell.type_name().and_then(|c| self.latency_data.get(&c))
469 && let Some(latency) =
470 ports.get_latency(&port.borrow().name)
471 {
472 latency_sum += latency;
473 }
474 }
475 }
476
477 log::debug!("SUCCESS: Latency = {latency_sum}");
478 Some(latency_sum)
479 }
480
481 pub fn get_possible_latency(c: &ir::Control) -> Option<u64> {
484 match c {
485 ir::Control::Static(sc) => Some(sc.get_latency()),
486 _ => c.get_attribute(ir::NumAttr::Promotable),
487 }
488 }
489
490 pub fn remove_promotable_from_seq(seq: &mut ir::Seq) {
491 for stmt in &mut seq.stmts {
492 Self::remove_promotable_attribute(stmt);
493 }
494 seq.get_mut_attributes().remove(ir::NumAttr::Promotable);
495 }
496
497 pub fn remove_promotable_attribute(c: &mut ir::Control) {
500 c.get_mut_attributes().remove(ir::NumAttr::Promotable);
501 match c {
502 ir::Control::Empty(_)
503 | ir::Control::Invoke(_)
504 | ir::Control::Enable(_)
505 | ir::Control::Static(_)
506 | ir::Control::FSMEnable(_) => (),
507 ir::Control::While(ir::While { body, .. })
508 | ir::Control::Repeat(ir::Repeat { body, .. }) => {
509 Self::remove_promotable_attribute(body);
510 }
511 ir::Control::If(ir::If {
512 tbranch, fbranch, ..
513 }) => {
514 Self::remove_promotable_attribute(tbranch);
515 Self::remove_promotable_attribute(fbranch);
516 }
517 ir::Control::Seq(ir::Seq { stmts, .. })
518 | ir::Control::Par(ir::Par { stmts, .. }) => {
519 for stmt in stmts {
520 Self::remove_promotable_attribute(stmt);
521 }
522 }
523 }
524 }
525
526 pub fn fixup_seq(&self, seq: &mut ir::Seq) {
527 seq.update_static(&self.static_component_latencies);
528 }
529
530 pub fn fixup_par(&self, par: &mut ir::Par) {
531 par.update_static(&self.static_component_latencies);
532 }
533
534 pub fn fixup_if(&self, _if: &mut ir::If) {
535 _if.update_static(&self.static_component_latencies);
536 }
537
538 pub fn fixup_while(&self, _while: &mut ir::While) {
539 _while.update_static(&self.static_component_latencies);
540 }
541
542 pub fn fixup_repeat(&self, repeat: &mut ir::Repeat) {
543 repeat.update_static(&self.static_component_latencies);
544 }
545
546 pub fn fixup_ctrl(&self, ctrl: &mut ir::Control) {
547 ctrl.update_static(&self.static_component_latencies);
548 }
549
550 pub fn fixup_timing(&self, comp: &mut ir::Component) {
560 for group in comp.groups.iter() {
563 if group
567 .borrow_mut()
568 .assignments
569 .iter()
570 .analysis()
571 .cell_writes()
572 .any(|cell| match cell.borrow().prototype {
573 CellType::Component { name } => {
574 self.updated_components.contains(&name)
575 }
576 _ => false,
577 })
578 {
579 group
581 .borrow_mut()
582 .attributes
583 .remove(ir::NumAttr::Promotable);
584 }
585 }
586
587 for group in &mut comp.groups.iter() {
588 let latency_result = self.infer_latency(&group.borrow());
590 if let Some(latency) = latency_result {
591 group
592 .borrow_mut()
593 .attributes
594 .insert(ir::NumAttr::Promotable, latency);
595 }
596 }
597
598 Self::remove_promotable_attribute(&mut comp.control.borrow_mut());
601 comp.control
602 .borrow_mut()
603 .update_static(&self.static_component_latencies);
604 }
605}