// calyx_opt/analysis/dataflow_order.rs

1use super::read_write_set::ReadWriteSet;
2use crate::analysis;
3use calyx_ir::{self as ir};
4use calyx_utils::{CalyxResult, Error};
5use ir::RRC;
6use itertools::Itertools;
7use petgraph::{
8    algo,
9    graph::{DiGraph, NodeIndex},
10};
11use std::collections::{HashMap, HashSet};
12
13/// Mapping from the name output port to all the input ports that must be driven before it.
14type WriteMap = HashMap<ir::Id, HashSet<ir::Id>>;
15
16/// Given a set of assignment, generates an ordering that respects combinatinal
17/// dataflow.
18pub struct DataflowOrder {
19    // Mapping from name of a primitive to its [WriteMap].
20    write_map: HashMap<ir::Id, WriteMap>,
21}
22
23/// Generate a write map using a primitive definition.
24fn prim_to_write_map(prim: &ir::Primitive) -> CalyxResult<WriteMap> {
25    let read_together_spec = analysis::PortInterface::comb_path_spec(prim)?;
26    let mut inputs = HashSet::new();
27    let mut outputs: Vec<(ir::Id, bool)> = Vec::new();
28
29    // Handle ports not mentioned in read_together specs.
30    // Each remaining output ports are dependent on all remaining inputs unless it is marked as
31    // @stable or is an interface port in which case it does not depend on any inputs.
32    for port in &prim.signature {
33        let attrs = &port.attributes;
34        if attrs.get(ir::NumAttr::ReadTogether).is_some() {
35            continue;
36        }
37        match port.direction {
38            ir::Direction::Input => {
39                inputs.insert(port.name());
40            }
41            ir::Direction::Output => outputs.push((
42                port.name(),
43                attrs
44                    .get(ir::BoolAttr::Stable)
45                    .or_else(|| attrs.get(ir::NumAttr::Done))
46                    .is_some(),
47            )),
48            ir::Direction::Inout => {
49                unreachable!("Primitive ports should not be inout")
50            }
51        }
52    }
53    let all_ports: WriteMap = outputs
54        .into_iter()
55        .map(|(out, stable)| {
56            // Stable ports don't depend on anything
57            if stable {
58                (out, HashSet::new())
59            } else {
60                (out, inputs.clone())
61            }
62        })
63        .chain(read_together_spec)
64        .collect();
65    Ok(all_ports)
66}
67
68/// Get the name of the port's cell's prototype if it is a component.
69fn primitive_parent(pr: &RRC<ir::Port>) -> Option<ir::Id> {
70    let port = pr.borrow();
71    match &port.cell_parent().borrow().prototype {
72        ir::CellType::Primitive { name, .. } => Some(*name),
73        ir::CellType::Component { .. }
74        | ir::CellType::ThisComponent
75        | ir::CellType::Constant { .. } => None,
76    }
77}
78
79impl DataflowOrder {
80    pub fn new<'a>(
81        primitives: impl Iterator<Item = &'a ir::Primitive>,
82    ) -> CalyxResult<Self> {
83        let write_map = primitives
84            .map(|p| prim_to_write_map(p).map(|wm| (p.name, wm)))
85            .collect::<CalyxResult<_>>()?;
86        Ok(DataflowOrder { write_map })
87    }
88
89    pub fn dataflow_sort<T>(
90        &self,
91        assigns: Vec<ir::Assignment<T>>,
92    ) -> CalyxResult<Vec<ir::Assignment<T>>>
93    where
94        T: ToString + Clone + Eq,
95    {
96        // Construct a graph where a node is an assignment and there is edge between
97        // nodes if one should occur before another.
98        let mut gr: DiGraph<Option<ir::Assignment<T>>, ()> = DiGraph::new();
99
100        // Mapping from the index corresponding to an assignment to its read/write sets.
101        let mut writes: HashMap<ir::Canonical, Vec<NodeIndex>> = HashMap::new();
102        let mut reads: Vec<(NodeIndex, (ir::Id, ir::Canonical))> =
103            Vec::with_capacity(assigns.len());
104
105        // Assignments to the hole are not considered in the sorting.
106        let mut hole_writes: Vec<ir::Assignment<T>> = Vec::new();
107
108        // Construct the nodes that contain the assignments
109        for assign in assigns {
110            if assign.dst.borrow().is_hole() {
111                hole_writes.push(assign)
112            } else {
113                let rs = ReadWriteSet::port_reads(&assign)
114                    .filter_map(|p| {
115                        primitive_parent(&p)
116                            .map(|comp| (comp, p.borrow().canonical()))
117                    })
118                    .collect_vec();
119                let ws = {
120                    let dst = assign.dst.borrow();
121                    if dst.cell_parent().borrow().is_primitive::<&str>(None) {
122                        Some(dst.canonical())
123                    } else {
124                        None
125                    }
126                };
127                let idx = gr.add_node(Some(assign));
128                reads.extend(rs.into_iter().map(|r| (idx, r)));
129                if let Some(w_can) = ws {
130                    writes.entry(w_can).or_default().push(idx);
131                }
132            }
133        }
134
135        // Walk over the writes and add edges between all required reads
136        // XXX(rachit): This probably adds a bunch of duplicate edges and in the
137        // worst case makes this pass much slower than it needs to be.
138        for (r_idx, (comp, canonical_port)) in reads {
139            let ir::Canonical { cell: inst, port } = canonical_port;
140            let dep_ports = self
141                .write_map
142                .get(&comp)
143                .unwrap_or_else(|| {
144                    panic!("Component `{comp}` write map is not defined")
145                })
146                .get(&port)
147                .unwrap_or_else(|| {
148                    panic!(
149                        "Port `{}.{}` write map is not defined",
150                        comp,
151                        port.clone()
152                    )
153                });
154
155            dep_ports
156                .iter()
157                .cloned()
158                .flat_map(|port| writes.get(&ir::Canonical::new(inst, port)))
159                .flatten()
160                .try_for_each(|w_idx| {
161                    if *w_idx == r_idx {
162                        Err(Error::misc(format!(
163                            "Assignment depends on itself: {}",
164                            ir::Printer::assignment_to_str(
165                                gr[*w_idx].as_ref().unwrap()
166                            )
167                        )))
168                    } else {
169                        gr.add_edge(*w_idx, r_idx, ());
170                        Ok(())
171                    }
172                })?;
173        }
174
175        // Generate a topological ordering
176        if let Ok(order) = algo::toposort(&gr, None) {
177            let mut assigns = order
178                .into_iter()
179                .map(|idx| gr[idx].take().unwrap())
180                .collect_vec();
181            assigns.append(&mut hole_writes);
182            Ok(assigns)
183        } else {
184            // Compute strongly connected component of the graph
185            let sccs = algo::kosaraju_scc(&gr);
186            let scc = sccs
187                .iter()
188                .find(|cc| cc.len() > 1)
189                .expect("All combinational cycles are self loops");
190            let msg = scc
191                .iter()
192                .map(|idx| {
193                    ir::Printer::assignment_to_str(gr[*idx].as_ref().unwrap())
194                })
195                .join("\n");
196            Err(Error::misc(format!("Found combinational cycle:\n{msg}")))
197        }
198    }
199}