calyx_opt/analysis/
schedule_conflicts.rs

1use calyx_ir as ir;
2use calyx_utils::{Idx, WeightGraph};
3use petgraph::visit::IntoEdgeReferences;
4use std::collections::{HashMap, HashSet};
5
6#[derive(Default)]
7/// A conflict graph that describes which nodes (i.e. groups/invokes) are being run in parallel
8/// to each other.
9pub struct ScheduleConflicts {
10    graph: WeightGraph<ir::Id>,
11    /// Reverse mapping from node indices to node (i.e. group/invoke) names.
12    /// We can store this because we don't expect nodes or edges to be added
13    /// once a conflict graph is constructed.
14    rev_map: HashMap<Idx, ir::Id>,
15}
16
17/// A conflict between two nodes is specified using the name of the nodes
18/// involved
19type Conflict = (ir::Id, ir::Id);
20
21impl ScheduleConflicts {
22    /// Return a vector of all nodes that conflict with this nodes.
23    pub fn conflicts_with(&self, node: &ir::Id) -> HashSet<ir::Id> {
24        self.graph
25            .graph
26            .neighbors(self.graph.index_map[node])
27            .map(|idx| self.rev_map[&idx])
28            .collect()
29    }
30
31    /// Returns an iterator containing all conflict edges,
32    /// `(src node: ir::Id, dst node: ir::Id)`, in this graph.
33    pub fn all_conflicts(&self) -> impl Iterator<Item = Conflict> + '_ {
34        self.graph
35            .graph
36            .edge_references()
37            .map(move |(src, dst, _)| (self.rev_map[&src], self.rev_map[&dst]))
38    }
39
40    /////////////// Internal Methods //////////////////
41    /// Adds a node to the CurrentConflict set.
42    fn add_node(&mut self, node: ir::Id) {
43        if !self.graph.contains_node(&node) {
44            self.graph.add_node(node)
45        }
46    }
47
48    fn add_edge(&mut self, g1: &ir::Id, g2: &ir::Id) {
49        self.graph.add_edge(g1, g2)
50    }
51}
52
53/// Given a set of vectors of nodes, adds edges between all nodes in one
54/// vector to all nodes in every other vector.
55///
56/// For example:
57/// ```
58/// vec![
59///     vec!["a", "b"],
60///     vec!["c", "d"]
61/// ]
62/// ```
63/// will create the edges:
64/// ```
65/// a --- c
66/// b --- c
67/// a --- d
68/// b --- d
69/// ```
70fn all_conflicting(
71    groups: &[Vec<ir::Id>],
72    current_conflicts: &mut ScheduleConflicts,
73) {
74    for group1 in 0..groups.len() {
75        for group2 in group1 + 1..groups.len() {
76            for node1 in &groups[group1] {
77                for node2 in &groups[group2] {
78                    current_conflicts.add_edge(node1, node2);
79                }
80            }
81        }
82    }
83}
84
85fn build_conflict_graph_static(
86    sc: &ir::StaticControl,
87    confs: &mut ScheduleConflicts,
88    all_nodes: &mut Vec<ir::Id>,
89) {
90    match sc {
91        ir::StaticControl::Enable(ir::StaticEnable { group, .. }) => {
92            confs.add_node(group.borrow().name());
93            all_nodes.push(group.borrow().name());
94        }
95        ir::StaticControl::Repeat(ir::StaticRepeat { body, .. }) => {
96            build_conflict_graph_static(body, confs, all_nodes);
97        }
98        ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => stmts
99            .iter()
100            .for_each(|c| build_conflict_graph_static(c, confs, all_nodes)),
101        ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
102            let par_nodes = stmts
103                .iter()
104                .map(|c| {
105                    // Visit this child and add conflict edges.
106                    // Collect the enables in this into a new vector.
107                    let mut nodes = Vec::new();
108                    build_conflict_graph_static(c, confs, &mut nodes);
109                    nodes
110                })
111                .collect::<Vec<_>>();
112
113            // Add conflict edges between all children.
114            all_conflicting(&par_nodes, confs);
115
116            // Add the enables from visiting the children to the current
117            // set of enables.
118            all_nodes.append(&mut par_nodes.into_iter().flatten().collect());
119        }
120        ir::StaticControl::If(ir::StaticIf {
121            tbranch, fbranch, ..
122        }) => {
123            build_conflict_graph_static(tbranch, confs, all_nodes);
124            build_conflict_graph_static(fbranch, confs, all_nodes);
125        }
126        ir::StaticControl::Invoke(ir::StaticInvoke { comp, .. }) => {
127            confs.add_node(comp.borrow().name());
128            all_nodes.push(comp.borrow().name());
129        }
130        ir::StaticControl::Empty(_) => (),
131    }
132}
133/// Construct a conflict graph by traversing the Control program.
134fn build_conflict_graph(
135    c: &ir::Control,
136    confs: &mut ScheduleConflicts,
137    all_nodes: &mut Vec<ir::Id>,
138) {
139    match c {
140        ir::Control::Empty(_) => (),
141        ir::Control::Invoke(ir::Invoke { comp, .. }) => {
142            confs.add_node(comp.borrow().name());
143            all_nodes.push(comp.borrow().name());
144        }
145        ir::Control::Enable(ir::Enable { group, .. }) => {
146            confs.add_node(group.borrow().name());
147            all_nodes.push(group.borrow().name());
148        }
149        ir::Control::Seq(ir::Seq { stmts, .. }) => stmts
150            .iter()
151            .for_each(|c| build_conflict_graph(c, confs, all_nodes)),
152        ir::Control::If(ir::If {
153            cond,
154            tbranch,
155            fbranch,
156            ..
157        }) => {
158            // XXX (rachit): This might be incorrect since cond is a combinational
159            // group
160            if let Some(c) = cond {
161                all_nodes.push(c.borrow().name());
162                confs.add_node(c.borrow().name());
163            }
164            build_conflict_graph(tbranch, confs, all_nodes);
165            build_conflict_graph(fbranch, confs, all_nodes);
166        }
167        ir::Control::While(ir::While { cond, body, .. }) => {
168            // XXX (rachit): This might be incorrect since cond is a combinational
169            // group
170            if let Some(c) = cond {
171                all_nodes.push(c.borrow().name());
172                confs.add_node(c.borrow().name());
173            }
174            build_conflict_graph(body, confs, all_nodes);
175        }
176        ir::Control::Repeat(ir::Repeat { body, .. }) => {
177            build_conflict_graph(body, confs, all_nodes);
178        }
179        ir::Control::Par(ir::Par { stmts, .. }) => {
180            let par_nodes = stmts
181                .iter()
182                .map(|c| {
183                    // Visit this child and add conflict edges.
184                    // Collect the enables in this into a new vector.
185                    let mut nodes = Vec::new();
186                    build_conflict_graph(c, confs, &mut nodes);
187                    nodes
188                })
189                .collect::<Vec<_>>();
190
191            // Add conflict edges between all children.
192            all_conflicting(&par_nodes, confs);
193
194            // Add the enables from visiting the children to the current
195            // set of enables.
196            all_nodes.append(&mut par_nodes.into_iter().flatten().collect());
197        }
198        ir::Control::Static(sc) => {
199            build_conflict_graph_static(sc, confs, all_nodes)
200        }
201        ir::Control::FSMEnable(_) => {
202            todo!("should not encounter fsm nodes")
203        }
204    }
205}
206
207/// Construct ScheduleConflicts from a ir::Control.
208impl From<&ir::Control> for ScheduleConflicts {
209    fn from(control: &ir::Control) -> Self {
210        let mut confs = ScheduleConflicts::default();
211        build_conflict_graph(control, &mut confs, &mut vec![]);
212        // Build the reverse index
213        confs.rev_map = confs.graph.reverse_index();
214        confs
215    }
216}