use super::Attribute;
use crate::{InlineAttributes, attribute::SetAttribute};
use calyx_utils::{CalyxResult, GPosIdx, WithPos};
use itertools::Itertools;
use linked_hash_map::LinkedHashMap;
use smallvec::SmallVec;
use std::{collections::HashMap, convert::TryFrom};
/// The out-of-line part of [`Attributes`], stored behind a `Box`.
#[derive(Debug, Clone, Default)]
struct HeapAttrInfo {
    /// Non-boolean attributes and their values (boolean ones live inline in
    /// `Attributes::inl`). `LinkedHashMap` keeps insertion order.
    attrs: LinkedHashMap<Attribute, u64>,
    /// Set-valued attributes, each mapping to a deduplicated set of `u32`s.
    set_attrs: HashMap<SetAttribute, VecSet<u32, 4>>,
    /// Position returned by `WithPos::copy_span`; set through `add_span`.
    span: GPosIdx,
}
/// A collection of attributes.
///
/// Boolean attributes are stored inline in `inl`. Valued attributes, set
/// attributes and the source span are kept in a boxed `HeapAttrInfo`
/// (NOTE(review): presumably boxed to keep this struct small — confirm).
#[derive(Default, Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct Attributes {
    /// Boolean (presence-only) attributes.
    inl: InlineAttributes,
    /// Valued attributes, set attributes and the source span.
    hinfo: Box<HeapAttrInfo>,
}
/// One attribute as produced while parsing, before the whole list is
/// converted into [`Attributes`] with `TryFrom`.
pub enum ParseAttributeWrapper {
    /// An attribute together with its numeric value.
    Attribute(Attribute, u64),
    /// A set attribute together with its members; duplicate members are
    /// dropped when the list is converted into [`Attributes`].
    Set(SetAttribute, Vec<u32>),
}
impl From<(Attribute, u64)> for ParseAttributeWrapper {
fn from(value: (Attribute, u64)) -> Self {
Self::Attribute(value.0, value.1)
}
}
impl From<(SetAttribute, Vec<u32>)> for ParseAttributeWrapper {
fn from(value: (SetAttribute, Vec<u32>)) -> Self {
Self::Set(value.0, value.1)
}
}
impl TryFrom<Vec<ParseAttributeWrapper>> for Attributes {
type Error = calyx_utils::Error;
fn try_from(v: Vec<ParseAttributeWrapper>) -> CalyxResult<Self> {
let mut attrs = Attributes::default();
for item in v {
match item {
ParseAttributeWrapper::Attribute(k, v) => {
if attrs.has(k) {
return Err(Self::Error::malformed_structure(format!(
"Multiple entries for attribute: {}",
k
)));
}
attrs.insert(k, v);
}
ParseAttributeWrapper::Set(set_attr, vec) => {
if attrs.hinfo.set_attrs.contains_key(&set_attr) {
return Err(Self::Error::malformed_structure(format!(
"Multiple entries for attribute: {}",
set_attr
)));
}
attrs
.hinfo
.set_attrs
.insert(set_attr, vec.into_iter().collect());
}
}
}
Ok(attrs)
}
}
impl WithPos for Attributes {
fn copy_span(&self) -> GPosIdx {
self.hinfo.span
}
}
/// Implemented by types that own an [`Attributes`] collection.
pub trait GetAttributes {
    /// Returns a shared reference to the attributes.
    fn get_attributes(&self) -> &Attributes;
    /// Returns a mutable reference to the attributes.
    fn get_mut_attributes(&mut self) -> &mut Attributes;
}
impl Attributes {
    /// Sets attribute `key` to `val`.
    ///
    /// Boolean attributes are recorded in the inline set; every other
    /// attribute is stored in the heap map.
    ///
    /// # Panics
    /// Panics if `key` is a boolean attribute and `val` is not `1`.
    pub fn insert<A>(&mut self, key: A, val: u64)
    where
        A: Into<Attribute>,
    {
        match key.into() {
            Attribute::Bool(b) => {
                assert!(
                    val == 1,
                    "{} is a boolean attribute and can only have a value of 1",
                    b.as_ref(),
                );
                self.inl.insert(b);
            }
            attr => {
                self.hinfo.attrs.insert(attr, val);
            }
        }
    }

    /// Returns the value of `key`, or `None` if it is absent.
    ///
    /// A present boolean attribute reports the value `1`.
    pub fn get<A>(&self, key: A) -> Option<u64>
    where
        A: Into<Attribute>,
    {
        match key.into() {
            Attribute::Bool(b) => self.inl.has(b).then_some(1),
            attr => self.hinfo.attrs.get(&attr).copied(),
        }
    }

    /// Returns the set stored under set attribute `key`, if any.
    pub fn get_set<S>(&self, key: S) -> Option<&VecSet<u32>>
    where
        S: Into<SetAttribute>,
    {
        self.hinfo.set_attrs.get(&key.into())
    }

    /// Returns `true` if attribute `key` is present. Set attributes are not
    /// consulted.
    pub fn has<A>(&self, key: A) -> bool
    where
        A: Into<Attribute>,
    {
        match key.into() {
            Attribute::Bool(b) => self.inl.has(b),
            attr => self.hinfo.attrs.contains_key(&attr),
        }
    }

    /// Returns `true` when there are no boolean, valued, or set attributes.
    ///
    /// A set attribute whose set is empty still counts as an entry.
    pub fn is_empty(&self) -> bool {
        self.inl.is_empty()
            && self.hinfo.attrs.is_empty()
            && self.hinfo.set_attrs.is_empty()
    }

    /// Removes attribute `key` if present. Set attributes are not affected.
    pub fn remove<A>(&mut self, key: A)
    where
        A: Into<Attribute>,
    {
        match key.into() {
            Attribute::Bool(b) => {
                self.inl.remove(b);
            }
            attr => {
                self.hinfo.attrs.remove(&attr);
            }
        }
    }

    /// Copies each attribute listed in `keys` that is present in `other`
    /// into `self`, overwriting via [`Attributes::insert`].
    pub fn copy_from<A>(&mut self, other: Self, keys: Vec<A>)
    where
        A: Into<Attribute> + Clone,
    {
        for key in keys {
            if let Some(val) = other.get(key.clone()) {
                self.insert(key, val);
            }
        }
    }

    /// Builder-style setter for the source span.
    pub fn add_span(mut self, span: GPosIdx) -> Self {
        self.hinfo.span = span;
        self
    }

    /// Renders all attributes as strings joined by `sep`.
    ///
    /// The output is produced in this order:
    /// 1. Valued attributes, in insertion order, rendered with `fmt`.
    /// 2. Boolean attributes, rendered with `fmt(name, 1)`.
    /// 3. Set attributes, sorted by key and rendered with `set_fmt`.
    ///
    /// Empty sets are skipped, and so are empty strings returned by `set_fmt`.
    pub fn to_string_with<F, S>(&self, sep: &str, fmt: F, set_fmt: S) -> String
    where
        F: Fn(String, u64) -> String,
        S: Fn(String, &[u32]) -> String,
    {
        // Fast path: nothing to render.
        if self.is_empty() {
            return String::default();
        }
        let valued = self
            .hinfo
            .attrs
            .iter()
            .map(|(k, v)| fmt(k.to_string(), *v));
        let flags = self.inl.iter().map(|k| fmt(k.as_ref().to_string(), 1));
        let sets = self
            .hinfo
            .set_attrs
            .iter()
            // HashMap iteration order is unspecified; sort for stable output.
            .sorted_by_key(|(k, _)| *k)
            .filter(|(_, v)| !v.is_empty())
            .map(|(k, v)| set_fmt(k.to_string(), v.as_slice()))
            .filter(|s| !s.is_empty());
        valued
            .chain(flags)
            .chain(sets)
            .collect::<Vec<_>>()
            .join(sep)
    }
}
impl PartialEq for Attributes {
    /// Compares boolean, valued and set attributes.
    ///
    /// * Valued attributes are compared by key lookup, so insertion order is
    ///   ignored.
    /// * Set attributes are compared with `VecSet`'s derived `PartialEq`,
    ///   which is element-order-sensitive.
    /// * The source span is not compared.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.hinfo, &other.hinfo);
        // Each containment check only walks `lhs`, so the lengths must also
        // match. Without this, an `rhs` holding extra set attributes would
        // compare equal and `==` would not be symmetric.
        self.inl == other.inl
            && lhs.attrs.len() == rhs.attrs.len()
            && lhs.attrs.iter().all(|(k, v)| rhs.attrs.get(k) == Some(v))
            && lhs.set_attrs.len() == rhs.set_attrs.len()
            && lhs
                .set_attrs
                .iter()
                .all(|(k, v)| rhs.set_attrs.get(k) == Some(v))
    }
}
// With the length checks above `eq` is symmetric, as `Eq` requires.
impl Eq for Attributes {}
#[cfg(feature = "serialize")]
impl serde::Serialize for HeapAttrInfo {
    /// Serializes the valued attributes as a map from attribute to value.
    ///
    /// Only `attrs` is emitted; `set_attrs` and `span` are not serialized.
    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Iterate the map in place. Cloning the whole struct first (as
        // `to_owned()` did) copied both maps for no benefit.
        ser.collect_map(self.attrs.iter())
    }
}
/// A small insertion-ordered set backed by a `SmallVec`.
///
/// Up to `ALLOC` elements are stored inline before spilling to the heap.
/// Membership checks are linear scans, so this is intended for small sets.
///
/// Note: the derived `PartialEq`/`Eq` compare elements in order, so two sets
/// with the same members inserted in a different order are not equal.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VecSet<D, const ALLOC: usize = 4>
where
    D: Eq + std::hash::Hash + Clone,
{
    /// Distinct elements in insertion order.
    inner: SmallVec<[D; ALLOC]>,
}
impl<D, const ALLOC: usize> VecSet<D, ALLOC>
where
D: Eq + std::hash::Hash + Clone,
{
pub fn new() -> Self {
Self {
inner: SmallVec::new(),
}
}
pub fn insert(&mut self, d: D) {
if !self.inner.contains(&d) {
self.inner.push(d);
}
}
pub fn contains(&self, d: &D) -> bool {
self.inner.contains(d)
}
pub fn iter(&self) -> impl Iterator<Item = &D> {
self.inner.iter()
}
pub fn as_slice(&self) -> &[D] {
self.inner.as_slice()
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
impl<D, const ALLOC: usize> FromIterator<D> for VecSet<D, ALLOC>
where
D: Eq + std::hash::Hash + Clone,
{
fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
Self {
inner: iter.into_iter().unique().collect(),
}
}
}