calyx_utils/
weight_graph.rs1use itertools::Itertools;
2use petgraph::matrix_graph::{MatrixGraph, NodeIndex, UnMatrix, Zero};
3use petgraph::visit::IntoEdgeReferences;
4use std::{collections::HashMap, fmt::Display, hash::Hash};
5
6pub type Idx = NodeIndex;
8
/// Edge weight stored in the adjacency matrix of a [`WeightGraph`].
///
/// Edges are inserted with the value `true`; `BoolIdx(false)` is the value
/// reported as zero by this type's `Zero` implementation. The derived
/// `Default` is `BoolIdx(false)`, matching that zero value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct BoolIdx(bool);
11
12impl From<bool> for BoolIdx {
13 fn from(b: bool) -> Self {
14 BoolIdx(b)
15 }
16}
17
18impl Zero for BoolIdx {
19 fn zero() -> Self {
20 BoolIdx(false)
21 }
22
23 fn is_zero(&self) -> bool {
24 !self.0
25 }
26}
27
28pub struct WeightGraph<T> {
37 pub index_map: HashMap<T, NodeIndex>,
39 pub graph: UnMatrix<(), BoolIdx>,
41}
42
43impl<T: Eq + Hash + Clone + Ord> Default for WeightGraph<T> {
44 fn default() -> Self {
45 WeightGraph {
46 index_map: HashMap::new(),
47 graph: MatrixGraph::new_undirected(),
48 }
49 }
50}
51
52impl<T, C> From<C> for WeightGraph<T>
53where
54 T: Eq + Hash + Ord,
55 C: Iterator<Item = T>,
56{
57 fn from(nodes: C) -> Self {
58 let mut graph = MatrixGraph::new_undirected();
59 let index_map: HashMap<_, _> =
60 nodes.map(|node| (node, graph.add_node(()))).collect();
61 WeightGraph { index_map, graph }
62 }
63}
64
65impl<'a, T> WeightGraph<T>
66where
67 T: 'a + Eq + Hash + Clone + Ord,
68{
69 #[inline(always)]
71 pub fn add_edge(&mut self, a: &T, b: &T) {
72 self.graph.update_edge(
73 self.index_map[a],
74 self.index_map[b],
75 true.into(),
76 );
77 }
78
79 pub fn add_all_edges<C>(&mut self, items: C)
81 where
82 C: Iterator<Item = &'a T> + Clone,
83 {
84 items.tuple_combinations().for_each(|(src, dst)| {
85 self.add_edge(src, dst);
86 });
87 }
88
89 #[inline(always)]
91 pub fn contains_node(&self, node: &T) -> bool {
92 self.index_map.contains_key(node)
93 }
94
95 pub fn add_node(&mut self, node: T) {
103 debug_assert!(
104 !self.index_map.contains_key(&node),
105 "Attempted to add pre-existing node to WeightGraph. Client code should ensure that this never happens."
106 );
107 let idx = self.graph.add_node(());
108 self.index_map.insert(node, idx);
109 }
110
111 pub fn reverse_index(&self) -> HashMap<NodeIndex, T> {
113 self.index_map
114 .iter()
115 .map(|(k, v)| (*v, k.clone()))
116 .collect()
117 }
118
119 pub fn nodes(&self) -> impl Iterator<Item = &T> {
121 self.index_map.keys()
122 }
123
124 pub fn degree(&self, node: &T) -> usize {
126 self.graph.neighbors(self.index_map[node]).count()
127 }
128}
129
130impl<T: Eq + Hash + ToString + Clone + Ord> Display for WeightGraph<T> {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 let rev_map = self.reverse_index();
133 let keys: Vec<_> = self.index_map.keys().collect();
134 let nodes = keys
135 .iter()
136 .map(|key| {
137 format!(
138 " {} [label=\"{}\"];",
139 key.to_string(),
140 key.to_string()
141 )
142 })
143 .collect::<Vec<_>>()
144 .join("\n");
145 let edges = self
146 .graph
147 .edge_references()
148 .map(|(a_idx, b_idx, _)| {
149 format!(
150 " {} -- {};",
151 rev_map[&a_idx].to_string(),
152 rev_map[&b_idx].to_string()
153 )
154 })
155 .collect::<Vec<_>>()
156 .join("\n");
157 write!(f, "graph {{ \n{nodes}\n{edges}\n }}")
158 }
159}