calyx_opt/analysis/
compaction_analysis.rs

1use crate::analysis::{ControlOrder, PromotionAnalysis};
2use calyx_ir::{self as ir};
3use ir::GetAttributes;
4use itertools::Itertools;
5use petgraph::{algo, graph::NodeIndex};
6use std::collections::HashMap;
7
8use super::read_write_set::AssignmentAnalysis;
9
/// Struct to perform compaction on `seqs`: given the body of a `seq`, it
/// tries to overlap statements with no data dependencies, producing a
/// compacted static `par` when that shortens the overall latency.
///
/// It will only work correctly if you call `update_cont_read_writes` for each
/// component that you run it on, so that reads/writes performed by the
/// component's continuous assignments are treated as dependencies.
#[derive(Debug, Default)]
pub struct CompactionAnalysis {
    // Cells read by the component's continuous assignments.
    cont_reads: Vec<ir::RRC<ir::Cell>>,
    // Cells written by the component's continuous assignments.
    cont_writes: Vec<ir::RRC<ir::Cell>>,
}
18
19impl CompactionAnalysis {
20    /// Updates self so that compaction will take continuous assignments into account
21    pub fn update_cont_read_writes(&mut self, comp: &mut ir::Component) {
22        let (cont_reads, cont_writes) = (
23            comp.continuous_assignments
24                .iter()
25                .analysis()
26                .cell_reads()
27                .collect(),
28            comp.continuous_assignments
29                .iter()
30                .analysis()
31                .cell_writes()
32                .collect(),
33        );
34        self.cont_reads = cont_reads;
35        self.cont_writes = cont_writes;
36    }
37
38    // Given a total_order and sorted schedule, builds a vec of the original seq.
39    // Note that this function assumes the `total_order`` and `sorted_schedule`
40    // represent a completely sequential schedule.
41    fn recover_seq(
42        mut total_order: petgraph::graph::DiGraph<Option<ir::Control>, ()>,
43        sorted_schedule: Vec<(NodeIndex, u64)>,
44    ) -> Vec<ir::Control> {
45        sorted_schedule
46            .into_iter()
47            .map(|(i, _)| total_order[i].take().unwrap())
48            .collect_vec()
49    }
50
    /// Takes a vec of ctrl stmts and turns it into a compacted schedule.
    /// If compaction doesn't lead to any latency decreases, it just returns
    /// a vec of stmts in the original order.
    /// If it can compact, then it returns a vec with one
    /// element: a compacted static par.
    ///
    /// Panics if the dependency graph built from `stmts` contains a cycle.
    pub fn compact_control_vec(
        &mut self,
        stmts: Vec<ir::Control>,
        promotion_analysis: &mut PromotionAnalysis,
        builder: &mut ir::Builder,
    ) -> Vec<ir::Control> {
        // Records the corresponding node indices that each control program
        // has data dependency on.
        let mut dependency: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
        // Records the latency of corresponding control operator for each
        // node index.
        let mut latency_map: HashMap<NodeIndex, u64> = HashMap::new();
        // Records the scheduled start time of corresponding control operator
        // for each node index.
        let mut schedule: HashMap<NodeIndex, u64> = HashMap::new();

        // Latency of the original fully-sequential schedule: the sum of each
        // stmt's inferred latency. Used below to decide whether compaction
        // actually helped.
        let og_latency: u64 = stmts
            .iter()
            .map(PromotionAnalysis::get_inferred_latency)
            .sum();

        // Build the dependency graph over the stmts; the continuous
        // reads/writes recorded by `update_cont_read_writes` are passed in so
        // they are accounted for as well.
        let mut total_order = ControlOrder::<false>::get_dependency_graph_seq(
            stmts.into_iter(),
            (&self.cont_reads, &self.cont_writes),
            &mut dependency,
            &mut latency_map,
        );

        if let Ok(order) = algo::toposort(&total_order, None) {
            let mut total_time: u64 = 0;

            // First we build the schedule: visiting nodes in topological
            // order, each stmt starts as soon as all its dependencies finish.
            for i in order {
                // Start time is when the latest dependency finishes
                let start = dependency
                    .get(&i)
                    .unwrap()
                    .iter()
                    .map(|node| schedule[node] + latency_map[node])
                    .max()
                    .unwrap_or(0);
                schedule.insert(i, start);
                total_time = std::cmp::max(start + latency_map[&i], total_time);
            }

            // We sort the schedule by start time (ties broken by node index,
            // keeping the order deterministic).
            let mut sorted_schedule: Vec<(NodeIndex, u64)> =
                schedule.into_iter().collect();
            sorted_schedule
                .sort_by(|(k1, v1), (k2, v2)| (v1, k1).cmp(&(v2, k2)));

            if total_time == og_latency {
                // If we can't compact at all, then just recover and return
                // the original seq.
                return Self::recover_seq(total_order, sorted_schedule);
            }

            // Threads for the static par, where each entry is (thread, thread_latency)
            let mut par_threads: Vec<(Vec<ir::Control>, u64)> = Vec::new();

            // We encode the schedule while trying to minimize the number of
            // par threads: each stmt is appended to the first existing thread
            // that is free at its scheduled start time.
            'outer: for (i, start) in sorted_schedule {
                let control = total_order[i].take().unwrap();
                for (thread, thread_latency) in par_threads.iter_mut() {
                    if *thread_latency <= start {
                        if *thread_latency < start {
                            // Need a no-op group so the schedule starts correctly
                            let no_op = builder.add_static_group(
                                "no-op",
                                start - *thread_latency,
                            );
                            thread.push(ir::Control::Static(
                                ir::StaticControl::Enable(ir::StaticEnable {
                                    group: no_op,
                                    attributes: ir::Attributes::default(),
                                }),
                            ));
                            *thread_latency = start;
                        }
                        thread.push(control);
                        *thread_latency += latency_map[&i];
                        continue 'outer;
                    }
                }
                // No existing thread is free in time: we must create a new
                // par thread.
                if start > 0 {
                    // If start > 0, then we must add a delay to the start of the
                    // group.
                    let no_op = builder.add_static_group("no-op", start);
                    let no_op_enable = ir::Control::Static(
                        ir::StaticControl::Enable(ir::StaticEnable {
                            group: no_op,
                            attributes: ir::Attributes::default(),
                        }),
                    );
                    par_threads.push((
                        vec![no_op_enable, control],
                        start + latency_map[&i],
                    ));
                } else {
                    par_threads.push((vec![control], latency_map[&i]));
                }
            }
            // Turn Vec<ir::StaticControl> -> StaticSeq
            let mut par_control_threads: Vec<ir::StaticControl> = Vec::new();
            for (thread, thread_latency) in par_threads {
                let mut promoted_stmts = thread
                    .into_iter()
                    .map(|mut stmt| {
                        promotion_analysis.convert_to_static(&mut stmt, builder)
                    })
                    .collect_vec();
                if promoted_stmts.len() == 1 {
                    // Don't wrap in static seq if we don't need to.
                    par_control_threads.push(promoted_stmts.pop().unwrap());
                } else {
                    par_control_threads.push(ir::StaticControl::Seq(
                        ir::StaticSeq {
                            stmts: promoted_stmts,
                            attributes: ir::Attributes::default(),
                            latency: thread_latency,
                        },
                    ));
                }
            }
            // Double checking that we have built the static par correctly:
            // the longest thread must match the schedule's total time.
            let max: Option<u64> =
                par_control_threads.iter().map(|c| c.get_latency()).max();
            assert!(
                max.unwrap() == total_time,
                "The schedule expects latency {}. The static par that was built has latency {}",
                total_time,
                max.unwrap()
            );

            let mut s_par = ir::StaticControl::Par(ir::StaticPar {
                stmts: par_control_threads,
                attributes: ir::Attributes::default(),
                latency: total_time,
            });
            // Tag the generated par with the `Promoted` attribute.
            s_par.get_mut_attributes().insert(ir::BoolAttr::Promoted, 1);
            vec![ir::Control::Static(s_par)]
        } else {
            panic!(
                "Error when producing topo sort. Dependency graph has a cycle."
            );
        }
    }
205}