calyx_utils/
weight_graph.rs

1use itertools::Itertools;
2use petgraph::matrix_graph::{MatrixGraph, NodeIndex, UnMatrix, Zero};
3use petgraph::visit::IntoEdgeReferences;
4use std::{collections::HashMap, fmt::Display, hash::Hash};
5
6/// Index into a [WeightGraph]
7pub type Idx = NodeIndex;
8
/// Edge weight stored in the adjacency matrix of a [WeightGraph].
///
/// Wraps a single `bool`. The `Zero` implementation for this type treats
/// `false` as the zero value, and [WeightGraph] stores `true` for every edge
/// it adds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BoolIdx(bool);
11
12impl From<bool> for BoolIdx {
13    fn from(b: bool) -> Self {
14        BoolIdx(b)
15    }
16}
17
18impl Zero for BoolIdx {
19    fn zero() -> Self {
20        BoolIdx(false)
21    }
22
23    fn is_zero(&self) -> bool {
24        !self.0
25    }
26}
27
/// An undirected graph whose nodes are addressed by values of type `T`.
///
/// Weight graph provides a wrapper over a petgraph graph that allows adding
/// edges using the NodeWeight type `T` (petgraph only allows adding edges
/// using `NodeIndex`).
/// Additionally, the edges are not allowed to have any weights.
///
/// The internal representation stores a mapping from NodeWeight `T` to a
/// `NodeIndex` in the graph.
/// The underlying `petgraph::MatrixGraph` stores `()` for node weights and
/// a boolean ([BoolIdx]) to represent the edges.
pub struct WeightGraph<T> {
    /// Mapping from each `T` to the `NodeIndex` that identifies it in `graph`.
    pub index_map: HashMap<T, NodeIndex>,
    /// Graph represented using the identifiers stored in `index_map`.
    pub graph: UnMatrix<(), BoolIdx>,
}
42
43impl<T: Eq + Hash + Clone + Ord> Default for WeightGraph<T> {
44    fn default() -> Self {
45        WeightGraph {
46            index_map: HashMap::new(),
47            graph: MatrixGraph::new_undirected(),
48        }
49    }
50}
51
52impl<T, C> From<C> for WeightGraph<T>
53where
54    T: Eq + Hash + Ord,
55    C: Iterator<Item = T>,
56{
57    fn from(nodes: C) -> Self {
58        let mut graph = MatrixGraph::new_undirected();
59        let index_map: HashMap<_, _> =
60            nodes.map(|node| (node, graph.add_node(()))).collect();
61        WeightGraph { index_map, graph }
62    }
63}
64
65impl<'a, T> WeightGraph<T>
66where
67    T: 'a + Eq + Hash + Clone + Ord,
68{
69    /// Add an edge between `a` and `b`.
70    #[inline(always)]
71    pub fn add_edge(&mut self, a: &T, b: &T) {
72        self.graph.update_edge(
73            self.index_map[a],
74            self.index_map[b],
75            true.into(),
76        );
77    }
78
79    /// Add edges between all given items.
80    pub fn add_all_edges<C>(&mut self, items: C)
81    where
82        C: Iterator<Item = &'a T> + Clone,
83    {
84        items.tuple_combinations().for_each(|(src, dst)| {
85            self.add_edge(src, dst);
86        });
87    }
88
89    /// Checks if the node has already been added to the graph.
90    #[inline(always)]
91    pub fn contains_node(&self, node: &T) -> bool {
92        self.index_map.contains_key(node)
93    }
94
95    /// Add a new node to the graph. Client code should ensure that duplicate
96    /// edges are never added to graph.
97    /// Instead of using this method, consider constructing the graph using
98    /// `From<Iterator<T>>`.
99    ///
100    /// # Panics
101    /// (Debug build only) Panics if node is already present in the graph
102    pub fn add_node(&mut self, node: T) {
103        debug_assert!(
104            !self.index_map.contains_key(&node),
105            "Attempted to add pre-existing node to WeightGraph. Client code should ensure that this never happens."
106        );
107        let idx = self.graph.add_node(());
108        self.index_map.insert(node, idx);
109    }
110
111    /// Returns a Map from `NodeIndex` to `T` (the reverse of the index)
112    pub fn reverse_index(&self) -> HashMap<NodeIndex, T> {
113        self.index_map
114            .iter()
115            .map(|(k, v)| (*v, k.clone()))
116            .collect()
117    }
118
119    /// Returns an iterator over references to nodes in the Graph.
120    pub fn nodes(&self) -> impl Iterator<Item = &T> {
121        self.index_map.keys()
122    }
123
124    /// Return the degree of a given node (number of edges connected).
125    pub fn degree(&self, node: &T) -> usize {
126        self.graph.neighbors(self.index_map[node]).count()
127    }
128}
129
130impl<T: Eq + Hash + ToString + Clone + Ord> Display for WeightGraph<T> {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        let rev_map = self.reverse_index();
133        let keys: Vec<_> = self.index_map.keys().collect();
134        let nodes = keys
135            .iter()
136            .map(|key| {
137                format!(
138                    "  {} [label=\"{}\"];",
139                    key.to_string(),
140                    key.to_string()
141                )
142            })
143            .collect::<Vec<_>>()
144            .join("\n");
145        let edges = self
146            .graph
147            .edge_references()
148            .map(|(a_idx, b_idx, _)| {
149                format!(
150                    "  {} -- {};",
151                    rev_map[&a_idx].to_string(),
152                    rev_map[&b_idx].to_string()
153                )
154            })
155            .collect::<Vec<_>>()
156            .join("\n");
157        write!(f, "graph {{ \n{nodes}\n{edges}\n }}")
158    }
159}