calyx_opt/passes_experimental/sync/
compile_sync.rs1use crate::traversal::{Action, Named, VisResult, Visitor};
2use calyx_ir::RRC;
3use calyx_ir::{self as ir, GetAttributes};
4use calyx_ir::{build_assignments, guard, structure};
5use calyx_utils::{CalyxResult, Error};
6use linked_hash_map::LinkedHashMap;
7use std::collections::{HashMap, HashSet};
8use std::rc::Rc;
9
10#[derive(Default)]
11pub struct CompileSync {
34 barriers: BarrierMap,
35}
36
37type BarrierMap = LinkedHashMap<u64, ([RRC<ir::Cell>; 2], [RRC<ir::Group>; 3])>;
39
40impl Named for CompileSync {
41 fn name() -> &'static str {
42 "compile-sync"
43 }
44
45 fn description() -> &'static str {
46 "Implement barriers for statements marked with @sync attribute"
47 }
48}
49
50fn count_barriers(
52 s: &ir::Control,
53 count: &mut HashSet<u64>,
54) -> CalyxResult<()> {
55 match s {
56 ir::Control::Empty(_) => {
57 if let Some(n) = s.get_attributes().get(ir::NumAttr::Sync) {
58 count.insert(n);
59 }
60 Ok(())
61 }
62 ir::Control::Seq(seq) => {
63 for stmt in seq.stmts.iter() {
64 count_barriers(stmt, count)?;
65 }
66 Ok(())
67 }
68 ir::Control::While(ir::While { body, .. })
69 | ir::Control::Repeat(ir::Repeat { body, .. }) => {
70 count_barriers(body, count)?;
71 Ok(())
72 }
73 ir::Control::If(i) => {
74 count_barriers(&i.tbranch, count)?;
75 count_barriers(&i.fbranch, count)?;
76 Ok(())
77 }
78 ir::Control::Enable(e) => {
79 if s.get_attributes().get(ir::NumAttr::Sync).is_some() {
80 return Err(Error::malformed_control(
81 "Enable or Invoke controls cannot be marked with @sync"
82 .to_string(),
83 )
84 .with_pos(&e.attributes));
85 }
86 Ok(())
87 }
88 ir::Control::Invoke(i) => {
89 if s.get_attributes().get(ir::NumAttr::Sync).is_some() {
90 return Err(Error::malformed_control(
91 "Enable or Invoke controls cannot be marked with @sync"
92 .to_string(),
93 )
94 .with_pos(&i.attributes));
95 }
96 Ok(())
97 }
98 ir::Control::Par(_) => Ok(()),
99 ir::Control::Static(_) => Ok(()),
100 ir::Control::FSMEnable(_) => {
101 todo!("should not encounter fsm nodes")
102 }
103 }
104}
105
106impl CompileSync {
107 fn build_barriers(
108 &mut self,
109 builder: &mut ir::Builder,
110 s: &mut ir::Control,
111 count: &mut HashMap<u64, u64>,
112 ) {
113 match s {
114 ir::Control::Empty(_) => {
115 if let Some(ref n) = s.get_attributes().get(ir::NumAttr::Sync) {
116 if self.barriers.get(n).is_none() {
117 self.add_shared_structure(builder, n);
118 }
119 let (cells, groups) = &self.barriers[n];
120 let member_idx = count[n];
121
122 let mut new_s =
123 build_member(builder, cells, groups, &member_idx);
124 std::mem::swap(s, &mut new_s);
125 }
126 }
127 ir::Control::Seq(seq) => {
128 for stmt in seq.stmts.iter_mut() {
138 self.build_barriers(builder, stmt, count);
139 }
140 }
141 ir::Control::While(w) => {
142 self.build_barriers(builder, &mut w.body, count);
143 }
144 ir::Control::If(i) => {
145 self.build_barriers(builder, &mut i.tbranch, count);
146 self.build_barriers(builder, &mut i.fbranch, count);
147 }
148 _ => {}
149 }
150 }
151
152 fn add_shared_structure(
153 &mut self,
154 builder: &mut ir::Builder,
155 barrier_idx: &u64,
156 ) {
157 structure!(builder;
158 let barrier = prim std_sync_reg(32);
159 let eq = prim std_eq(32);
160 );
161 let restore = build_restore(builder, &barrier);
162 let wait_restore = build_wait_restore(builder, &eq);
163 let clear_barrier = build_clear_barrier(builder, &barrier);
164 let shared_cells: [RRC<ir::Cell>; 2] = [barrier, eq];
165 let shared_groups: [RRC<ir::Group>; 3] =
166 [wait_restore, restore, clear_barrier];
167 let info = (shared_cells, shared_groups);
168 self.barriers.insert(*barrier_idx, info);
169 }
170}
171
172fn build_incr_barrier(
174 builder: &mut ir::Builder,
175 barrier: &RRC<ir::Cell>,
176 save: &RRC<ir::Cell>,
177 member_idx: &u64,
178) -> RRC<ir::Group> {
179 let group = builder.add_group("incr_barrier");
180 structure!(builder;
181 let incr = prim std_add(32);
182 let cst_1 = constant(1, 1);
183 let cst_2 = constant(1, 32););
184 let read_done_guard = guard!(barrier[format!("read_done_{member_idx}")]);
185 let assigns = build_assignments!(builder;
186 barrier[format!("read_en_{member_idx}")] = ?cst_1["out"];
188 incr["left"] = ? barrier[format!("out_{member_idx}")];
190 incr["right"] = ? cst_2["out"];
192 save["in"] = read_done_guard? incr["out"];
194 save["write_en"] = ? barrier[format!("read_done_{member_idx}")];
196 group["done"] = ?save["done"];
198 );
199
200 group.borrow_mut().assignments.extend(assigns);
201 group
202}
203
204fn build_write_barrier(
206 builder: &mut ir::Builder,
207 barrier: &RRC<ir::Cell>,
208 save: &RRC<ir::Cell>,
209 member_idx: &u64,
210) -> RRC<ir::Group> {
211 let group = builder.add_group("write_barrier");
212 structure!(builder;
213 let cst_1 = constant(1, 1););
214 let assigns = build_assignments!(builder;
215 barrier[format!("write_en_{member_idx}")] = ?cst_1["out"];
217 barrier[format!("in_{member_idx}")] = ?save["out"];
219 group["done"] = ?barrier[format!("write_done_{member_idx}")];
221 );
222 group.borrow_mut().assignments.extend(assigns);
223 group
224}
225
226fn build_wait(builder: &mut ir::Builder, eq: &RRC<ir::Cell>) -> RRC<ir::Group> {
231 let group = builder.add_group("wt");
232 structure!(builder;
233 let wait_reg = prim std_reg(1);
234 let cst_1 = constant(1, 1););
235 let eq_guard = guard!(eq["out"]);
236 let assigns = build_assignments!(builder;
237 wait_reg["in"] = ?eq["out"];
240 wait_reg["write_en"] = eq_guard? cst_1["out"];
242 group["done"] = ?wait_reg["done"];);
244 group.borrow_mut().assignments.extend(assigns);
245 group
246}
247
248fn build_clear_barrier(
250 builder: &mut ir::Builder,
251 barrier: &RRC<ir::Cell>,
252) -> RRC<ir::Group> {
253 let group = builder.add_group("clear_barrier");
254 structure!(builder;
255 let cst_1 = constant(1, 1););
256 let assigns = build_assignments!(builder;
257 barrier["read_en_0"] = ?cst_1["out"];
259 group["done"] = ?barrier["read_done_0"];
261 );
262 group.borrow_mut().assignments.extend(assigns);
263 group
264}
265
266fn build_restore(
268 builder: &mut ir::Builder,
269 barrier: &RRC<ir::Cell>,
270) -> RRC<ir::Group> {
271 let group = builder.add_group("restore");
272 structure!(builder;
273 let cst_1 = constant(1,1);
274 let cst_2 = constant(0, 32););
275 let assigns = build_assignments!(builder;
276 barrier["write_en_0"] = ?cst_1["out"];
278 barrier["in_0"] = ?cst_2["out"];
280 group["done"] = ?barrier["write_done_0"];
282 );
283 group.borrow_mut().assignments.extend(assigns);
284 group
285}
286
287fn build_wait_restore(
291 builder: &mut ir::Builder,
292 eq: &RRC<ir::Cell>,
293) -> RRC<ir::Group> {
294 let group = builder.add_group("wait_restore");
295 structure!(builder;
296 let wait_restore_reg = prim std_reg(1);
297 let cst_1 = constant(1, 1););
298 let eq_guard = !guard!(eq["out"]);
299 let assigns = build_assignments!(builder;
300 wait_restore_reg["in"] = eq_guard? cst_1["out"];
302 wait_restore_reg["write_en"] = eq_guard? cst_1["out"];
304 group["done"] = ?wait_restore_reg["done"];
306 );
307 group.borrow_mut().assignments.extend(assigns);
308 group
309}
310
311fn build_member(
313 builder: &mut ir::Builder,
314 cells: &[RRC<ir::Cell>; 2],
315 groups: &[RRC<ir::Group>; 3],
316 member_idx: &u64,
317) -> ir::Control {
318 let mut stmts: Vec<ir::Control> = Vec::new();
319
320 let barrier = Rc::clone(&cells[0]);
321 let eq = Rc::clone(&cells[1]);
322 let wait_restore = Rc::clone(&groups[0]);
323 let restore = Rc::clone(&groups[1]);
324 let clear_barrier = Rc::clone(&groups[2]);
325
326 structure!(builder;
327 let save = prim std_reg(32););
328 let incr_barrier =
329 build_incr_barrier(builder, &barrier, &save, &(member_idx - 1));
330 let write_barrier =
331 build_write_barrier(builder, &barrier, &save, &(member_idx - 1));
332 let wait = build_wait(builder, &eq);
333
334 stmts.push(ir::Control::enable(incr_barrier));
335 stmts.push(ir::Control::enable(write_barrier));
336 stmts.push(ir::Control::enable(wait));
337 if member_idx == &1 {
338 stmts.push(ir::Control::enable(clear_barrier));
339 stmts.push(ir::Control::enable(restore));
340 } else {
341 stmts.push(ir::Control::enable(wait_restore));
342 }
343 ir::Control::seq(stmts)
344}
345
346impl Visitor for CompileSync {
347 fn finish_par(
348 &mut self,
349 s: &mut ir::Par,
350 comp: &mut ir::Component,
351 sigs: &ir::LibrarySignatures,
352 _comps: &[ir::Component],
353 ) -> VisResult {
354 let mut builder = ir::Builder::new(comp, sigs);
355 let mut barrier_count: HashMap<u64, u64> = HashMap::new();
356 for stmt in s.stmts.iter_mut() {
357 let mut cnt: HashSet<u64> = HashSet::new();
358 count_barriers(stmt, &mut cnt)?;
359 for barrier in cnt {
360 barrier_count
361 .entry(barrier)
362 .and_modify(|count| *count += 1)
363 .or_insert(1);
364 }
365 self.build_barriers(&mut builder, stmt, &mut barrier_count);
366 }
367
368 if self.barriers.is_empty() {
369 return Ok(Action::Continue);
370 }
371
372 let mut init_barriers: Vec<ir::Control> = Vec::new();
373 for (n, (cells, groups)) in self.barriers.iter() {
374 let barrier = Rc::clone(&cells[0]);
375 let eq = Rc::clone(&cells[1]);
376 let restore = Rc::clone(&groups[1]);
377 let n_members = barrier_count.get(n).unwrap();
378 structure!(builder;
379 let num_members = constant(*n_members, 32);
380 );
381 let assigns = build_assignments!(builder;
383 eq["left"] = ?barrier["peek"];
385 eq["right"] = ?num_members["out"];
387 );
388 builder.component.continuous_assignments.extend(assigns);
389
390 init_barriers.push(ir::Control::enable(restore));
391 }
392
393 let mut changed_sequence: Vec<ir::Control> =
395 vec![ir::Control::par(init_barriers)];
396 let mut copied_par_stmts: Vec<ir::Control> = Vec::new();
397 for con in s.stmts.drain(..) {
398 copied_par_stmts.push(con);
399 }
400 changed_sequence.push(ir::Control::par(copied_par_stmts));
401
402 Ok(Action::change(ir::Control::seq(changed_sequence)))
403 }
404}