// calyx_opt/passes_experimental/discover_external.rs

1use crate::traversal::{Action, ConstructVisitor, Named, Visitor};
2use calyx_ir as ir;
3use calyx_utils::CalyxResult;
4use ir::RRC;
5use itertools::Itertools;
6use linked_hash_map::LinkedHashMap;
7use std::collections::HashSet;
8
/// A pass to detect cells that have been inlined into the top-level component
/// and turn them into real cells marked with [ir::BoolAttr::External].
pub struct DiscoverExternal {
    /// The default value used for parameters that cannot be inferred.
    default: u64,
    /// The suffix to be removed from the inferred cell names, if any.
    suffix: Option<String>,
}

18impl Named for DiscoverExternal {
19    fn name() -> &'static str {
20        "discover-external"
21    }
22
23    fn description() -> &'static str {
24        "Detect cells that have been inlined into a component's interface and turn them into @external cells"
25    }
26}
27
28impl ConstructVisitor for DiscoverExternal {
29    fn from(ctx: &ir::Context) -> CalyxResult<Self>
30    where
31        Self: Sized,
32    {
33        // Manual parsing because our options are not flags
34        let n = Self::name();
35        let given_opts: HashSet<_> = ctx
36            .extra_opts
37            .iter()
38            .filter_map(|opt| {
39                let mut splits = opt.split(':');
40                if splits.next() == Some(n) {
41                    splits.next()
42                } else {
43                    None
44                }
45            })
46            .collect();
47
48        let mut default = None;
49        let mut suffix = None;
50        for opt in given_opts {
51            let mut splits = opt.split('=');
52            let spl = splits.next();
53            // Search for the "default=<n>" option
54            if spl == Some("default") {
55                let Some(val) = splits.next().and_then(|v| v.parse().ok())
56                else {
57                    log::warn!(
58                        "Failed to parse default value. Please specify using -x {n}:default=<n>"
59                    );
60                    continue;
61                };
62                log::info!("Setting default value to {val}");
63
64                default = Some(val);
65            }
66            // Search for "strip-suffix=<str>" option
67            else if spl == Some("strip-suffix") {
68                let Some(suff) = splits.next() else {
69                    log::warn!(
70                        "Failed to parse suffix. Please specify using -x {n}:strip-suffix=<str>"
71                    );
72                    continue;
73                };
74                log::info!("Setting suffix to {suff}");
75
76                suffix = Some(suff.to_string());
77            }
78        }
79
80        Ok(Self {
81            default: default.unwrap_or(32),
82            suffix,
83        })
84    }
85
86    fn clear_data(&mut self) {
87        /* All data is shared */
88    }
89}
90
91impl Visitor for DiscoverExternal {
92    fn start(
93        &mut self,
94        comp: &mut ir::Component,
95        sigs: &ir::LibrarySignatures,
96        _comps: &[ir::Component],
97    ) -> crate::traversal::VisResult {
98        // Ignore non-toplevel components
99        if !comp.attributes.has(ir::BoolAttr::TopLevel) {
100            return Ok(Action::Stop);
101        }
102
103        // Group ports by longest common prefix
104        // NOTE(rachit): This is an awfully inefficient representation. We really
105        // want a TrieMap here.
106        let mut prefix_map: LinkedHashMap<String, HashSet<ir::Id>> =
107            LinkedHashMap::new();
108        for port in comp.signature.borrow().ports() {
109            let name = port.borrow().name;
110            let mut prefix = String::new();
111            // Walk over the port name and add it to the prefix map
112            for c in name.as_ref().chars() {
113                prefix.push(c);
114                if prefix == name.as_ref() {
115                    // We have reached the end of the name
116                    break;
117                }
118                // Remove prefix from name
119                let name = name.as_ref().strip_prefix(&prefix).unwrap();
120                prefix_map
121                    .entry(prefix.clone())
122                    .or_default()
123                    .insert(name.into());
124            }
125        }
126
127        // For all cells in the library, build a set of port names.
128        let mut prim_ports: LinkedHashMap<ir::Id, HashSet<ir::Id>> =
129            LinkedHashMap::new();
130        for prim in sigs.signatures() {
131            let hs = prim
132                .signature
133                .iter()
134                .filter(|p| {
135                    // Ignore clk and reset cells
136                    !p.attributes.has(ir::BoolAttr::Clk)
137                        && !p.attributes.has(ir::BoolAttr::Reset)
138                })
139                .map(|p| p.name())
140                .collect::<HashSet<_>>();
141            prim_ports.insert(prim.name, hs);
142        }
143
144        // For all prefixes, check if there is a primitive that matches the
145        // prefix. If there is, then we have an external cell.
146        let mut pre_to_prim: LinkedHashMap<String, ir::Id> =
147            LinkedHashMap::new();
148        for (prefix, ports) in prefix_map.iter() {
149            for (&prim, prim_ports) in prim_ports.iter() {
150                if prim_ports == ports {
151                    pre_to_prim.insert(prefix.clone(), prim);
152                }
153            }
154        }
155
156        // Collect all ports associated with a specific prefix
157        let mut port_map: LinkedHashMap<String, Vec<RRC<ir::Port>>> =
158            LinkedHashMap::new();
159        'outer: for port in &comp.signature.borrow().ports {
160            // If this matches a prefix, add it to the corresponding port map
161            for pre in pre_to_prim.keys() {
162                if port.borrow().name.as_ref().starts_with(pre) {
163                    port_map.entry(pre.clone()).or_default().push(port.clone());
164                    continue 'outer;
165                }
166            }
167        }
168
169        // Add external cells for all matching prefixes
170        let mut pre_to_cells = LinkedHashMap::new();
171        for (pre, &prim) in &pre_to_prim {
172            log::info!("Prefix {pre} matches primitive {prim}");
173            // Attempt to infer the parameters for the external cell
174            let prim_sig = sigs.get_primitive(prim);
175            let ports = &port_map[pre];
176            let mut params: LinkedHashMap<_, Option<u64>> = prim_sig
177                .params
178                .clone()
179                .into_iter()
180                .map(|p| (p, None))
181                .collect();
182
183            // Walk over the abstract port definition and attempt to match the bitwidths
184            for abs in &prim_sig.signature {
185                if let ir::Width::Param { value } = abs.width {
186                    // Find the corresponding port
187                    let port = ports
188                        .iter()
189                        .find(|p| {
190                            p.borrow()
191                                .name
192                                .as_ref()
193                                .ends_with(abs.name().as_ref())
194                        })
195                        .unwrap_or_else(|| {
196                            panic!("No port found for {}", abs.name())
197                        });
198                    // Update the value of the parameter
199                    let v = params.get_mut(&value).unwrap();
200                    if let Some(v) = v {
201                        if *v != port.borrow().width {
202                            log::warn!(
203                                "Mismatched bitwidths for {} in {}, defaulting to {}",
204                                pre,
205                                prim,
206                                self.default
207                            );
208                            *v = self.default;
209                        }
210                    } else {
211                        *v = Some(port.borrow().width);
212                    }
213                }
214            }
215
216            let param_values = params
217                .into_iter()
218                .map(|(_, v)| {
219                    if let Some(v) = v {
220                        v
221                    } else {
222                        log::warn!(
223                            "Unable to infer parameter value for {} in {}, defaulting to {}",
224                            pre,
225                            prim,
226                            self.default
227                        );
228                        self.default
229                    }
230                })
231                .collect_vec();
232
233            let mut builder = ir::Builder::new(comp, sigs);
234            // Remove the suffix from the cell name
235            let name = if let Some(suf) = &self.suffix {
236                pre.strip_suffix(suf).unwrap_or(pre)
237            } else {
238                pre
239            };
240            let cell = builder.add_primitive(name, prim, &param_values);
241            cell.borrow_mut()
242                .attributes
243                .insert(ir::BoolAttr::External, 1);
244            pre_to_cells.insert(pre.clone(), cell);
245        }
246
247        // Rewrite the ports mentioned in the component signature and remove them
248        let mut rewrites: ir::rewriter::PortRewriteMap = LinkedHashMap::new();
249        for (pre, ports) in port_map {
250            // let prim = sigs.get_primitive(pre_to_prim[&pre]);
251            let cr = pre_to_cells[&pre].clone();
252            let cell = cr.borrow();
253            let cell_ports = cell.ports();
254            // Iterate over ports with the same names.
255            for pr in ports {
256                let port = pr.borrow();
257                let cp = cell_ports
258                    .iter()
259                    .find(|p| {
260                        port.name.as_ref().ends_with(p.borrow().name.as_ref())
261                    })
262                    .unwrap_or_else(|| {
263                        panic!("No port found for {}", port.name)
264                    });
265                rewrites.insert(port.canonical(), cp.clone());
266            }
267        }
268
269        comp.for_each_assignment(|assign| {
270            assign.for_each_port(|port| {
271                rewrites.get(&port.borrow().canonical()).cloned()
272            })
273        });
274        comp.for_each_static_assignment(|assign| {
275            assign.for_each_port(|port| {
276                rewrites.get(&port.borrow().canonical()).cloned()
277            })
278        });
279
280        // Remove all ports from the signature that match a prefix
281        comp.signature.borrow_mut().ports.retain(|p| {
282            !pre_to_prim
283                .keys()
284                .any(|pre| p.borrow().name.as_ref().starts_with(pre))
285        });
286
287        // Purely structural pass
288        Ok(Action::Stop)
289    }
290}