Skip to content

Instantly share code, notes, and snippets.

@segeljakt
Last active November 11, 2023 22:47
Show Gist options
  • Select an option

  • Save segeljakt/c490d15444eee55cc4096d821a1565c0 to your computer and use it in GitHub Desktop.

Select an option

Save segeljakt/c490d15444eee55cc4096d821a1565c0 to your computer and use it in GitHub Desktop.
deserialize-bids-file
[package]
name = "csv-bids-deserializer"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1.34.0", features = ["full"] }
serde = { version = "1.0.192", features = ["derive"] }
csv-core = "0.1.11"
lexical-parse-float = "0.8.5"
atoi = "2.0.0"
use serde::Deserialize;
use tokio::io::AsyncBufReadExt;
fn main() {
let time = std::time::Instant::now();
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let f = tokio::fs::File::open("../data/bids.csv").await.unwrap();
let mut input = tokio::io::BufReader::new(f);
let mut buf = Vec::with_capacity(1024 * 30);
let mut reader = Reader::<1024>::new(',');
loop {
match input.read_until(b'\n', &mut buf).await {
Ok(0) => break,
Ok(n) => {
let mut deserializer = Deserializer::new(&mut reader, &buf[0..n]);
if let Ok(_) = Bid::deserialize(&mut deserializer) {
// do nothing (just test the reading + deserializing)
}
}
Err(e) => panic!("{e}"),
}
}
});
println!("Job Execution Time {:?}", time.elapsed());
}
// Example bids.csv (10M lines):
//
// 1000,1001,73134520,channel-7568,https://www.nexmark.com/rswp/bsu/_gzj/item.htm?query=1&channel_id=163053568,1699628784800,tjegpemlelrhcglaovelrtxwcwcintpbwbhwemkngirkduwbqfbwmnrtegvmrittzvxgswwdln
// ...
#[allow(unused)]
#[derive(Debug, serde::Deserialize)]
pub struct Bid {
auction: u64,
bidder: u64,
price: u64,
channel: String,
url: String,
date_time: u64,
extra: String,
}
// Deserialization using csv-score
pub struct Reader<const N: usize> {
inner: csv_core::Reader,
buffer: [u8; N],
record_ends: [usize; 30], // To hold the ends of each field
}
struct Deserializer<'a, const N: usize> {
reader: &'a mut Reader<N>,
input: &'a [u8],
nread: usize,
current_field_index: usize, // To keep track of the current field index
total_fields: usize, // Total number of fields in the current record
}
impl<const N: usize> Reader<N> {
#[allow(clippy::new_without_default)]
pub fn new(sep: char) -> Self {
Self {
inner: csv_core::ReaderBuilder::new().delimiter(sep as u8).build(),
buffer: [0; N],
record_ends: [0; 30],
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
/// Buffer overflow.
Overflow,
/// Expected an empty field.
ExpectedEmpty,
/// Invalid boolean value. Expected either `true` or `false`.
InvalidBool(String),
/// Invalid integer.
InvalidInt(String),
/// Invalid floating-point number.
InvalidFloat(lexical_parse_float::Error),
/// Invalid UTF-8 encoded character.
InvalidChar(String),
/// Invalid UTF-8 encoded string.
InvalidStr(std::str::Utf8Error),
/// Invalid UTF-8 encoded string.
InvalidString(std::string::FromUtf8Error),
/// Error with a custom message had to be discard.
Custom(String),
BufferTooSmall,
EndOfRecord,
}
pub type Result<T> = std::result::Result<T, Error>;
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Overflow => write!(f, "Buffer overflow."),
Error::ExpectedEmpty => write!(f, "Expected an empty field."),
Error::InvalidBool(s) => write!(f, "Invalid bool: {s}"),
Error::InvalidInt(s) => write!(f, "Invalid integer: {s}"),
Error::InvalidFloat(e) => write!(f, "Invalid float: {e}"),
Error::InvalidChar(s) => write!(f, "Invalid character: {s}"),
Error::InvalidStr(e) => write!(f, "Invalid string: {e}"),
Error::InvalidString(e) => write!(f, "Invalid string: {e}"),
Error::Custom(s) => write!(f, "CSV does not match deserializer's expected format: {s}"),
Error::BufferTooSmall => write!(f, "Buffer too small"),
Error::EndOfRecord => write!(f, "End of record"),
}
}
}
impl serde::de::StdError for Error {}
impl serde::de::Error for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Self::Custom(msg.to_string())
}
}
impl<'a, const N: usize> Deserializer<'a, N> {
pub fn new(reader: &'a mut Reader<N>, input: &'a [u8]) -> Self {
Self {
reader,
input,
nread: 0,
current_field_index: 0,
total_fields: 0,
}
}
fn advance_record(&mut self) -> Result<()> {
self.current_field_index = 0;
let (result, r, _w, ends) = self.reader.inner.read_record(
&self.input[self.nread..],
&mut self.reader.buffer[0..],
&mut self.reader.record_ends[0..],
);
self.nread += r;
self.total_fields = ends;
match result {
csv_core::ReadRecordResult::InputEmpty => {}
csv_core::ReadRecordResult::End => {}
csv_core::ReadRecordResult::OutputFull => return Err(Error::Overflow),
csv_core::ReadRecordResult::OutputEndsFull => return Err(Error::BufferTooSmall),
csv_core::ReadRecordResult::Record => {}
}
Ok(())
}
/// Read a record from the CSV input.
// fn advance_record(&mut self) -> Result<()> {
// let (result, r) = self.reader.inner.read_record(&self.input[self.nread..]);
// self.nread += r;
// match result {
// csv_core::ReadRecordResult::InputEmpty => {}
// csv_core::ReadRecordResult::RecordEnd => self.record_end = true,
// csv_core::ReadRecordResult::End => {}
// }
// Ok(())
// }
fn read_bytes(&mut self) -> Result<&[u8]> {
if self.current_field_index >= self.total_fields {
return Err(Error::EndOfRecord); // New error variant for end of record
}
let start = if self.current_field_index == 0 {
0
} else {
self.reader.record_ends[self.current_field_index - 1]
};
let end = self.reader.record_ends[self.current_field_index];
self.current_field_index += 1;
Ok(&self.reader.buffer[start..end])
}
fn read_int<T: atoi::FromRadix10SignedChecked>(&mut self) -> Result<T> {
let bytes = self.read_bytes()?;
atoi::atoi(bytes)
.ok_or_else(|| Error::InvalidInt(std::str::from_utf8(bytes).unwrap().to_string()))
}
fn read_float<T: lexical_parse_float::FromLexical>(&mut self) -> Result<T> {
T::from_lexical(self.read_bytes()?)
.map_err(|e: lexical_parse_float::Error| Error::InvalidFloat(e))
}
fn read_bool(&mut self) -> Result<bool> {
let bytes = self.read_bytes()?;
match bytes {
b"true" => Ok(true),
b"false" => Ok(false),
_ => Err(Error::InvalidBool(
std::str::from_utf8(bytes).unwrap().to_string(),
)),
}
}
fn read_char(&mut self) -> Result<char> {
let str = self.read_str()?;
let mut iter = str.chars();
let c = iter
.next()
.ok_or_else(|| Error::InvalidChar(str.to_string()))?;
if iter.next().is_some() {
return Err(Error::InvalidChar(str.to_string()));
} else {
Ok(c)
}
}
fn read_str(&mut self) -> Result<&str> {
std::str::from_utf8(self.read_bytes()?)
.map_err(|e: std::str::Utf8Error| Error::InvalidStr(e))
}
fn read_string(&mut self) -> Result<String> {
std::string::String::from_utf8(self.read_bytes()?.to_vec())
.map_err(|e| Error::InvalidString(e))
}
}
impl<'de, 'a, 'b, const N: usize> serde::de::Deserializer<'de> for &'a mut Deserializer<'b, N> {
type Error = Error;
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
unreachable!("`Deserializer::deserialize_any` is not supported")
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_bool(self.read_bool()?)
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_i8(self.read_int()?)
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_i16(self.read_int()?)
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_i32(self.read_int()?)
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_i64(self.read_int()?)
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_u8(self.read_int()?)
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_u16(self.read_int()?)
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_u32(self.read_int()?)
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_u64(self.read_int()?)
}
fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_f32(self.read_float()?)
}
fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_f64(self.read_float()?)
}
fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_char(self.read_char()?)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_str(self.read_str()?)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_string(self.read_string()?)
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_bytes(self.read_bytes()?)
}
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_byte_buf(self.read_bytes()?.to_vec())
}
fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
todo!()
// if self.peek_bytes()?.is_empty() {
// visitor.visit_none()
// } else {
// visitor.visit_some(self)
// }
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
if !self.read_bytes()?.is_empty() {
return Err(Error::ExpectedEmpty);
}
visitor.visit_unit()
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_unit(visitor)
}
fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_struct<V>(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(self)
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_enum(self)
}
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
unimplemented!("`Deserializer::deserialize_identifier` is not supported");
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
let _ = self.read_bytes()?;
visitor.visit_unit()
}
}
impl<'de, 'a, 'b, const N: usize> serde::de::VariantAccess<'de> for &'a mut Deserializer<'b, N> {
type Error = Error;
fn unit_variant(self) -> Result<()> {
Ok(())
}
fn newtype_variant_seed<U: serde::de::DeserializeSeed<'de>>(
self,
_seed: U,
) -> Result<U::Value> {
unimplemented!("`VariantAccess::newtype_variant_seed` is not implemented");
}
fn tuple_variant<V: serde::de::Visitor<'de>>(
self,
_len: usize,
_visitor: V,
) -> Result<V::Value> {
unimplemented!("`VariantAccess::tuple_variant` is not implemented");
}
fn struct_variant<V: serde::de::Visitor<'de>>(
self,
_fields: &'static [&'static str],
_visitor: V,
) -> Result<V::Value> {
unimplemented!("`VariantAccess::struct_variant` is not implemented");
}
}
impl<'de, 'a, 'b, const N: usize> serde::de::EnumAccess<'de> for &'a mut Deserializer<'b, N> {
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
where
V: serde::de::DeserializeSeed<'de>,
{
use serde::de::IntoDeserializer;
let variant_name = self.read_bytes()?;
seed.deserialize(variant_name.into_deserializer())
.map(|v| (v, self))
}
}
impl<'de, 'a, 'b, const N: usize> serde::de::SeqAccess<'de> for &'a mut Deserializer<'b, N> {
type Error = Error;
fn next_element_seed<V>(&mut self, seed: V) -> Result<Option<V::Value>>
where
V: serde::de::DeserializeSeed<'de>,
{
if self.current_field_index >= self.total_fields {
self.advance_record()?;
if self.total_fields == 0 {
return Ok(None);
}
}
seed.deserialize(&mut **self).map(Some)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment