calyx_opt/analysis/
port_interface.rs

1use calyx_ir as ir;
2use calyx_utils::{CalyxResult, Error};
3use itertools::Itertools;
4use std::collections::{HashMap, HashSet};
5
/// Tuple containing (port, set of ports).
/// When the first port is read from, all of the ports in the set must be written to.
type ReadTogether = (ir::Id, HashSet<ir::Id>);
/// Read together specs map the name of a primitive to its [ReadTogether] specs.
type ReadTogetherSpecs = HashMap<ir::Id, Vec<ReadTogether>>;

/// Set of ports that need to be driven together.
type WriteTogether = HashSet<ir::Id>;
/// Write together specs map the name of a primitive to its [WriteTogether]
/// specs, i.e., the sets of ports that need to be driven together.
type WriteTogetherSpecs = HashMap<ir::Id, Vec<WriteTogether>>;
17
/// Helper methods to parse `@read_together` and `@write_together` specifications.
///
/// This is a unit struct used purely as a namespace: all functionality is
/// exposed through associated functions that take primitive definitions as
/// arguments.
pub struct PortInterface;
20
21impl PortInterface {
22    /// Construct @write_together specs from the primitive definitions.
23    pub fn write_together_specs<'a>(
24        primitives: impl Iterator<Item = &'a ir::Primitive>,
25    ) -> WriteTogetherSpecs {
26        let mut write_together = HashMap::new();
27        for prim in primitives {
28            let writes: Vec<HashSet<ir::Id>> = prim
29                .find_all_with_attr(ir::NumAttr::WriteTogether)
30                .map(|pd| {
31                    (
32                        pd.attributes.get(ir::NumAttr::WriteTogether).unwrap(),
33                        pd.name(),
34                    )
35                })
36                .into_group_map()
37                .into_values()
38                .map(|writes| writes.into_iter().collect::<HashSet<_>>())
39                .collect();
40            if !writes.is_empty() {
41                write_together.insert(prim.name, writes);
42            }
43        }
44        write_together
45    }
46
47    /// Construct `@read_together` spec from the definition of a primitive.
48    /// Each spec is allowed to have exactly one output port along with one
49    /// or more input ports.
50    /// The specification dictates that before reading the output port, the
51    /// input ports must be driven, i.e., the output port is combinationally
52    /// related to the input ports and only those ports.
53    pub fn comb_path_spec(
54        prim: &ir::Primitive,
55    ) -> CalyxResult<Vec<ReadTogether>> {
56        prim
57                .find_all_with_attr(ir::NumAttr::ReadTogether)
58                .map(|pd| (pd.attributes.get(ir::NumAttr::ReadTogether).unwrap(), pd))
59                .into_group_map()
60                .into_values()
61                .map(|ports| {
62                    let (outputs, inputs): (Vec<_>, Vec<_>) =
63                        ports.into_iter().partition(|&port| {
64                            matches!(port.direction, ir::Direction::Output)
65                        });
66                    // There should only be one port in the read_together specification.
67                    if outputs.len() != 1 {
68                        return Err(Error::papercut(format!("Invalid @read_together specification for primitive `{}`. Each specification group is only allowed to have one output port specified.", prim.name)))
69                    }
70                    assert!(outputs.len() == 1);
71                    Ok((
72                        outputs[0].name(),
73                        inputs
74                            .into_iter()
75                            .map(|port| port.name())
76                            .collect::<HashSet<_>>(),
77                    ))
78                })
79                .collect::<CalyxResult<_>>()
80    }
81
82    /// Construct @read_together specs from the primitive definitions.
83    pub fn comb_path_specs<'a>(
84        primitives: impl Iterator<Item = &'a ir::Primitive>,
85    ) -> CalyxResult<ReadTogetherSpecs> {
86        let mut read_together = HashMap::new();
87        for prim in primitives {
88            let reads = Self::comb_path_spec(prim)?;
89            if !reads.is_empty() {
90                read_together.insert(prim.name, reads);
91            }
92        }
93        Ok(read_together)
94    }
95}