use std::cell::{RefCell, RefMut};
use std::fs::File;
use std::num::NonZeroU16;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};

use anyhow::{ensure, Context, Result};
#[cfg(feature = "arrow")]
use arrow::array::StructArray;
use rayon::prelude::*;
use thread_local::ThreadLocal;

#[cfg(feature = "arrow-ipc")]
mod ipc;
#[cfg(feature = "arrow-ipc")]
pub use ipc::*;
#[cfg(feature = "parquet")]
mod parquet;
#[cfg(feature = "parquet")]
pub use parquet::*;
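
/// Interface of builders that accumulate rows in memory and produce an Arrow
/// [`StructArray`] when full.
///
/// A minimal sketch of an implementor wrapping Arrow's own `StructBuilder`
/// (hypothetical; real builders are typically written per table schema):
///
/// ```ignore
/// use anyhow::Result;
/// use arrow::array::{StructArray, StructBuilder};
///
/// struct MyBuilder(StructBuilder);
///
/// impl StructArrayBuilder for MyBuilder {
///     fn len(&self) -> usize {
///         self.0.len()
///     }
///     fn finish(mut self) -> Result<StructArray> {
///         Ok(self.0.finish())
///     }
/// }
/// ```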
#[cfg(feature = "arrow")]
#[allow(clippy::len_without_is_empty)]
pub trait StructArrayBuilder {
    /// Returns the number of rows currently buffered in the builder.
    fn len(&self) -> usize;

    /// Consumes the builder and builds the final [`StructArray`].
    fn finish(self) -> Result<StructArray>;
}
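
/// Writes a dataset to a directory, with one file per thread that writes to
/// it.
///
/// A minimal usage sketch with the zstd-compressed CSV backend defined below
/// (the `dataset/` path and the record contents are placeholders):
///
/// ```ignore
/// use std::path::PathBuf;
///
/// let writer: ParallelDatasetWriter<CsvZstTableWriter> =
///     ParallelDatasetWriter::new(PathBuf::from("dataset/"))?;
///
/// // From any thread, e.g. inside a Rayon parallel iterator:
/// writer.get_thread_writer()?.write_record(&["foo", "bar"])?;
///
/// // Closes every per-thread file and surfaces any error:
/// writer.close()?;
/// ```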
pub struct ParallelDatasetWriter<W: TableWriter + Send> {
num_files: AtomicU64,
schema: W::Schema,
path: PathBuf,
writers: ThreadLocal<RefCell<W>>,
    /// When set, threshold at which the per-thread [`TableWriter`]s flush
    /// their buffered rows to disk; its exact interpretation is
    /// writer-specific.
    pub flush_threshold: Option<usize>,
}
impl<W: TableWriter<Schema = ()> + Send> ParallelDatasetWriter<W> {
    /// Creates the directory at `path` and returns a dataset writer for
    /// [`TableWriter`]s that need no schema.
    pub fn new(path: PathBuf) -> Result<Self> {
std::fs::create_dir_all(&path)
.with_context(|| format!("Could not create {}", path.display()))?;
Ok(ParallelDatasetWriter {
num_files: AtomicU64::new(0),
schema: (),
path,
writers: ThreadLocal::new(),
flush_threshold: None,
})
}
}
impl<W: TableWriter + Send> ParallelDatasetWriter<W> {
    /// Creates the directory at `path` and returns a dataset writer whose
    /// per-thread [`TableWriter`]s are initialized with the given schema.
    pub fn new_with_schema(path: PathBuf, schema: W::Schema) -> Result<Self> {
std::fs::create_dir_all(&path)
.with_context(|| format!("Could not create {}", path.display()))?;
Ok(ParallelDatasetWriter {
num_files: AtomicU64::new(0),
schema,
path,
writers: ThreadLocal::new(),
flush_threshold: None,
})
}
    /// Opens a sequential writer on the next numbered file of the dataset.
    fn get_new_seq_writer(&self) -> Result<RefCell<W>> {
let path = self
.path
.join(self.num_files.fetch_add(1, Ordering::Relaxed).to_string());
Ok(RefCell::new(W::new(
path,
self.schema.clone(),
self.flush_threshold,
)?))
}
    /// Returns the calling thread's [`TableWriter`], creating it (and the
    /// file it writes to) on first use.
    pub fn get_thread_writer(&self) -> Result<RefMut<W>> {
self.writers
.get_or_try(|| self.get_new_seq_writer())
.map(|writer| writer.borrow_mut())
}
    /// Flushes all thread writers to disk, in parallel.
    pub fn flush(&mut self) -> Result<()> {
        self.writers
            .iter_mut()
            // ThreadLocal only iterates sequentially; collect the mutable
            // references first so Rayon can flush the writers in parallel.
            .collect::<Vec<_>>()
            .into_par_iter()
            .map(|writer| writer.get_mut().flush())
.collect::<Result<Vec<()>>>()
.map(|_: Vec<()>| ())
}
    /// Closes all thread writers in parallel and returns their results.
    ///
    /// Calling this is preferable to relying on [`Drop`], which panics on
    /// error and discards the writers' close results.
    pub fn close(mut self) -> Result<Vec<W::CloseResult>> {
        // Move the writers out of self so that Drop later sees an empty
        // ThreadLocal and does not try to close them a second time.
        let mut tmp = ThreadLocal::new();
        std::mem::swap(&mut tmp, &mut self.writers);
tmp.into_iter()
.collect::<Vec<_>>()
.into_par_iter()
.map(|writer| writer.into_inner().close())
.collect()
}
}
impl<W: TableWriter + Send> Drop for ParallelDatasetWriter<W> {
    fn drop(&mut self) {
        // No-op when close() already ran: the writers were moved out then, so
        // the ThreadLocal is empty and there is nothing left to close.
        let mut tmp = ThreadLocal::new();
        std::mem::swap(&mut tmp, &mut self.writers);
tmp.into_iter()
.collect::<Vec<_>>()
.into_par_iter()
.try_for_each(|writer| writer.into_inner().close().map(|_| ()))
.expect("Could not close ParallelDatasetWriter");
}
}
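
/// A sequential writer for a single table file; [`ParallelDatasetWriter`]
/// creates one per thread.
///
/// A minimal sketch of an implementor that discards all rows (the name
/// `NullTableWriter` is illustrative, not part of this crate):
///
/// ```ignore
/// struct NullTableWriter;
///
/// impl TableWriter for NullTableWriter {
///     type Schema = ();
///     type CloseResult = ();
///
///     fn new(_path: PathBuf, _schema: (), _flush_threshold: Option<usize>) -> Result<Self> {
///         Ok(NullTableWriter)
///     }
///     fn flush(&mut self) -> Result<()> {
///         Ok(())
///     }
///     fn close(self) -> Result<()> {
///         Ok(())
///     }
/// }
/// ```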
pub trait TableWriter {
    /// Writer-specific configuration, e.g. the table's column types.
    type Schema: Clone;
    /// Value returned by [`Self::close`], e.g. metadata of the written file.
    type CloseResult: Send;

    /// Creates a writer for a new table file at `path`.
    fn new(path: PathBuf, schema: Self::Schema, flush_threshold: Option<usize>) -> Result<Self>
    where
        Self: Sized;

    /// Writes any buffered rows to disk.
    fn flush(&mut self) -> Result<()>;

    /// Flushes, then finalizes the underlying file.
    fn close(self) -> Result<Self::CloseResult>;
}
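
/// Wraps one [`TableWriter`] per partition; callers pick the target partition
/// through [`Self::partitions`]. Partitions are laid out on disk as
/// Hive-style `column=value/` directories.
///
/// For example, with two partitions on column `p`, the CSV backend, and a
/// single writer thread, the resulting layout is:
///
/// ```text
/// dataset/
/// ├── p=0/
/// │   └── 0.csv.zst
/// └── p=1/
///     └── 0.csv.zst
/// ```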
pub struct PartitionedTableWriter<PartitionWriter: TableWriter + Send> {
partition_writers: Vec<PartitionWriter>,
}
impl<PartitionWriter: TableWriter + Send> TableWriter for PartitionedTableWriter<PartitionWriter> {
    /// `(partition_column, num_partitions, inner_schema)`; `None` partitions
    /// disables partitioning and writes a single output.
    type Schema = (String, Option<NonZeroU16>, PartitionWriter::Schema);
type CloseResult = Vec<PartitionWriter::CloseResult>;
fn new(
mut path: PathBuf,
(partition_column, num_partitions, schema): Self::Schema,
flush_threshold: Option<usize>,
) -> Result<Self> {
        // The last path component is the per-thread file name chosen by
        // ParallelDatasetWriter; it is re-appended under each partition
        // directory below.
        let thread_id = path
            .file_name()
            .map(|p| p.to_owned())
            .with_context(|| {
                format!(
                    "Unexpected root path for partitioned writer: {}",
                    path.display()
                )
            })?;
        ensure!(
            path.pop(),
            "Unexpected root path for partitioned writer: {}",
            path.display()
        );
Ok(PartitionedTableWriter {
partition_writers: (0..num_partitions.map(NonZeroU16::get).unwrap_or(1))
.map(|partition_id| {
                    let partition_path = if num_partitions.is_some() {
                        // Hive-style layout: one <column>=<value>/ directory
                        // per partition.
                        path.join(format!("{}={}", partition_column, partition_id))
                    } else {
                        path.to_owned()
                    };
std::fs::create_dir_all(&partition_path).with_context(|| {
format!("Could not create {}", partition_path.display())
})?;
PartitionWriter::new(
partition_path.join(&thread_id),
schema.clone(),
flush_threshold,
)
})
.collect::<Result<_>>()?,
})
}
fn flush(&mut self) -> Result<()> {
self.partition_writers
.par_iter_mut()
.try_for_each(|writer| writer.flush())
}
fn close(self) -> Result<Self::CloseResult> {
self.partition_writers
.into_par_iter()
.map(|writer| writer.close())
.collect()
}
}
impl<PartitionWriter: TableWriter + Send> PartitionedTableWriter<PartitionWriter> {
    /// Returns the underlying writers, one per partition.
    pub fn partitions(&mut self) -> &mut [PartitionWriter] {
&mut self.partition_writers
}
}
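
/// [`TableWriter`] that writes zstd-compressed CSV files (CRLF-terminated,
/// with a header row when serializing named structs).
///
/// A minimal standalone sketch (the `table` path is a placeholder; under
/// [`ParallelDatasetWriter`], `new` and `close` are called for you):
///
/// ```ignore
/// use std::path::PathBuf;
///
/// let mut writer = CsvZstTableWriter::new(PathBuf::from("table"), (), None)?;
/// writer.serialize(("foo", 42))?; // any csv::Writer method works
/// writer.close()?;
/// ```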
pub type CsvZstTableWriter<'a> = csv::Writer<zstd::stream::AutoFinishEncoder<'a, File>>;
impl<'a> TableWriter for CsvZstTableWriter<'a> {
type Schema = ();
type CloseResult = ();
fn new(
mut path: PathBuf,
_schema: Self::Schema,
_flush_threshold: Option<usize>,
) -> Result<Self> {
path.set_extension("csv.zst");
let file =
File::create(&path).with_context(|| format!("Could not create {}", path.display()))?;
        let compression_level = 3; // zstd's default level
let zstd_encoder = zstd::stream::write::Encoder::new(file, compression_level)
.with_context(|| format!("Could not create ZSTD encoder for {}", path.display()))?
.auto_finish();
Ok(csv::WriterBuilder::new()
.has_headers(true)
.terminator(csv::Terminator::CRLF)
.from_writer(zstd_encoder))
}
    fn flush(&mut self) -> Result<()> {
        // Explicitly the inherent csv::Writer::flush, not a recursive call to
        // this trait method (inherent methods take precedence in resolution).
        csv::Writer::flush(self).context("Could not flush CsvZst writer")
    }

    fn close(mut self) -> Result<()> {
        // Flush the CSV buffer; dropping the writer afterwards finishes the
        // zstd frame through AutoFinishEncoder.
        csv::Writer::flush(&mut self).context("Could not close CsvZst writer")
    }
}