use crate::{OutOfBoundError, SWHType};
use anyhow::{Context, Result};
use log::info;
use mmap_rs::{Mmap, MmapFlags, MmapMut};
use std::path::Path;
use sux::prelude::{BitFieldSlice, BitFieldSliceCore, BitFieldSliceMut, BitFieldVec};
/// Maps each node id to its [`SWHType`], packed as fixed-width bit fields.
///
/// Each entry occupies `SWHType::BITWIDTH` bits of a [`BitFieldVec`] whose `usize` words are
/// provided by the backend `B`. In this module that is either a memory-mapped file or an
/// in-memory byte buffer, wrapped in `UsizeMmap`.
pub struct Node2Type<B> {
    /// Packed node types, indexed by node id.
    data: BitFieldVec<usize, B>,
}
impl<B: AsRef<[usize]>> Node2Type<B> {
    /// Returns the type of the node with id `node_id`, without bounds checking.
    ///
    /// # Safety
    ///
    /// `node_id` must be strictly smaller than the number of nodes this map was built with.
    ///
    /// # Panics
    ///
    /// Panics if the stored value does not decode to a valid [`SWHType`] (corrupted data).
    #[inline]
    pub unsafe fn get_unchecked(&self, node_id: usize) -> SWHType {
        SWHType::try_from(self.data.get_unchecked(node_id) as u8).unwrap()
    }

    /// Returns the type of the node with id `node_id`.
    ///
    /// # Errors
    ///
    /// Returns [`OutOfBoundError`] if `node_id` is not smaller than the number of nodes.
    ///
    /// # Panics
    ///
    /// Panics if the stored value does not decode to a valid [`SWHType`]. Every value is
    /// written from an `SWHType` by `set`, so this can only happen when the backing data is
    /// corrupted.
    #[inline]
    pub fn get(&self, node_id: usize) -> Result<SWHType, OutOfBoundError> {
        let len = self.data.len();
        // Check explicitly: `BitFieldVec::get` panics on an out-of-range index rather than
        // reporting it, so the error has to be produced here.
        if node_id >= len {
            return Err(OutOfBoundError {
                index: node_id,
                len,
            });
        }
        let raw = self.data.get(node_id);
        Ok(SWHType::try_from(raw as u8).unwrap_or_else(|_| {
            panic!(
                "Invalid node type value {} stored for node {} (corrupted node2type data?)",
                raw, node_id
            )
        }))
    }
}
impl<B: AsRef<[usize]> + AsMut<[usize]>> Node2Type<B> {
    /// Stores `node_type` as the type of node `node_id`, skipping the bounds check.
    ///
    /// # Safety
    ///
    /// `node_id` must be strictly smaller than the number of nodes this map was built with.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, node_id: usize, node_type: SWHType) {
        let encoded = node_type as usize;
        self.data.set_unchecked(node_id, encoded);
    }

    /// Stores `node_type` as the type of node `node_id`.
    ///
    /// Bounds checking is delegated to the underlying [`BitFieldVec`]. Per the sux docs, it
    /// panics on an out-of-range `node_id`.
    #[inline]
    pub fn set(&mut self, node_id: usize, node_type: SWHType) {
        let encoded = node_type as usize;
        self.data.set(node_id, encoded);
    }
}
/// Adapter exposing a byte buffer `B` (a memory mapping or a `Vec<u8>` in this module) as a
/// slice of `usize` words, so that it can back a [`BitFieldVec`].
///
/// The reinterpretation is done with `bytemuck`. It panics if the bytes are not aligned to
/// `usize`, or if their length is not a multiple of `size_of::<usize>()`.
pub struct UsizeMmap<B>(B);

impl<B: AsRef<[u8]>> AsRef<[usize]> for UsizeMmap<B> {
    fn as_ref(&self) -> &[usize] {
        let bytes: &[u8] = self.0.as_ref();
        bytemuck::cast_slice(bytes)
    }
}

impl<B: AsRef<[u8]> + AsMut<[u8]>> AsMut<[usize]> for UsizeMmap<B> {
    fn as_mut(&mut self) -> &mut [usize] {
        let bytes: &mut [u8] = self.0.as_mut();
        bytemuck::cast_slice_mut(bytes)
    }
}
impl Node2Type<UsizeMmap<MmapMut>> {
pub fn new<P: AsRef<Path>>(path: P, num_nodes: usize) -> Result<Self> {
let path = path.as_ref();
let file_len = ((num_nodes * SWHType::BITWIDTH) as u64).div_ceil(64) * 8;
info!("The resulting file will be {} bytes long.", file_len);
let node2type_file = std::fs::File::options()
.read(true)
.write(true)
.create(true)
.open(path)
.with_context(|| {
format!(
"While creating the .node2type.bin file: {}",
path.to_string_lossy()
)
})?;
node2type_file
.set_len(file_len)
.with_context(|| "While fallocating the file with zeros")?;
let mmap = unsafe {
mmap_rs::MmapOptions::new(file_len as _)
.context("Could not initialize mmap")?
.with_file(node2type_file, 0)
.map_mut()
.with_context(|| "While mmapping the file")?
};
let mmap = UsizeMmap(mmap);
let node2type = unsafe { BitFieldVec::from_raw_parts(mmap, SWHType::BITWIDTH, num_nodes) };
Ok(Self { data: node2type })
}
pub fn load_mut<P: AsRef<Path>>(path: P, num_nodes: usize) -> Result<Self> {
let path = path.as_ref();
let file_len = path
.metadata()
.with_context(|| format!("Could not stat {}", path.display()))?
.len();
let file = std::fs::File::open(path)
.with_context(|| format!("Could not open {}", path.display()))?;
let data = unsafe {
mmap_rs::MmapOptions::new(file_len as _)
.context("Could not initialize mmap")?
.with_flags(MmapFlags::TRANSPARENT_HUGE_PAGES)
.with_file(file, 0)
.map_mut()?
};
#[cfg(target_os = "linux")]
unsafe {
libc::madvise(data.as_ptr() as *mut _, data.len(), libc::MADV_RANDOM)
};
let data = UsizeMmap(data);
let node2type = unsafe { BitFieldVec::from_raw_parts(data, SWHType::BITWIDTH, num_nodes) };
Ok(Self { data: node2type })
}
}
impl Node2Type<UsizeMmap<Mmap>> {
pub fn load<P: AsRef<Path>>(path: P, num_nodes: usize) -> Result<Self> {
let path = path.as_ref();
let file_len = path
.metadata()
.with_context(|| format!("Could not stat {}", path.display()))?
.len();
let expected_file_len = ((num_nodes * SWHType::BITWIDTH).div_ceil(64) * 8) as u64;
assert_eq!(
file_len,
expected_file_len,
"Expected {} to have size {} (because graph has {} nodes), but it has size {}",
path.display(),
expected_file_len,
num_nodes,
file_len,
);
let file = std::fs::File::open(path)
.with_context(|| format!("Could not open {}", path.display()))?;
let data = unsafe {
mmap_rs::MmapOptions::new(file_len as _)?
.with_flags(MmapFlags::TRANSPARENT_HUGE_PAGES)
.with_file(file, 0)
.map()?
};
#[cfg(target_os = "linux")]
unsafe {
libc::madvise(data.as_ptr() as *mut _, data.len(), libc::MADV_RANDOM)
};
let data = UsizeMmap(data);
let node2type = unsafe { BitFieldVec::from_raw_parts(data, SWHType::BITWIDTH, num_nodes) };
Ok(Self { data: node2type })
}
}
impl Node2Type<UsizeMmap<Vec<u8>>> {
    /// Builds an in-memory node-type map from an iterator whose `i`-th item is the type of
    /// node `i`.
    ///
    /// The number of nodes is taken from `types.len()`. Nodes the iterator does not yield
    /// keep the all-zero encoding.
    ///
    /// # Panics
    ///
    /// Panics in the following cases:
    /// - the size of the backing buffer overflows `usize`;
    /// - the iterator yields more items than its reported length, via the bounds check in
    ///   `set`.
    pub fn new_from_iter(types: impl ExactSizeIterator<Item = SWHType>) -> Self {
        let num_nodes = types.len();
        // Round the bit count up to whole 64-bit words, then convert to bytes. The
        // multiplication is checked: an unchecked `usize` product would overflow before
        // any later conversion could notice.
        let file_len = num_nodes
            .checked_mul(SWHType::BITWIDTH)
            .map(|num_bits| num_bits.div_ceil(64) * 8)
            .expect("num_nodes overflowed usize");
        // NOTE(review): `UsizeMmap` reinterprets these bytes as `usize` words with
        // `bytemuck::cast_slice`, which panics on a misaligned buffer. A `Vec<u8>` allocation
        // only guarantees 1-byte alignment. This works with common allocators but is not
        // guaranteed; consider a `usize`-backed buffer. TODO confirm.
        let data = UsizeMmap(vec![0; file_len]);
        // SAFETY: the buffer holds `file_len` bytes, i.e. enough whole words for `num_nodes`
        // entries of `SWHType::BITWIDTH` bits each.
        let data = unsafe { BitFieldVec::from_raw_parts(data, SWHType::BITWIDTH, num_nodes) };
        let mut node2type = Node2Type { data };
        for (node_id, node_type) in types.enumerate() {
            node2type.set(node_id, node_type);
        }
        node2type
    }
}