use crate::{
cenum,
types::{
serde::{builder::SetOnce, BoltKind},
BoltFloat, BoltInteger, BoltPoint2D, BoltPoint3D,
},
Point2D, Point3D,
};
use std::{fmt, marker::PhantomData, result::Result};
use serde::{
de::{
DeserializeSeed, Deserializer, EnumAccess, Error, IntoDeserializer, MapAccess, SeqAccess,
VariantAccess, Visitor,
},
forward_to_deserialize_any, Deserialize,
};
impl<'de> Deserialize<'de> for Point2D {
    /// Deserializes the underlying `BoltPoint2D` and wraps it in a `Point2D`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let inner = BoltPoint2D::deserialize(deserializer)?;
        Ok(Point2D::new(inner))
    }
}
impl<'de> Deserialize<'de> for Point3D {
    /// Deserializes the underlying `BoltPoint3D` and wraps it in a `Point3D`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let inner = BoltPoint3D::deserialize(deserializer)?;
        Ok(Point3D::new(inner))
    }
}
// Keys of a Bolt point. The `cenum!` macro (defined elsewhere in the crate)
// supplies the `NAMES`, `VARIANTS`, `name()` and `from_str` helpers used below.
// Declaration order matters: `Z` must stay last, because the 2D code paths take
// the first three entries (`NAMES[..3]` / `VARIANTS[..3]`) to exclude it.
cenum!(Field { SrId, X, Y, Z });
/// Collects point fields while a map is being visited.
///
/// Each slot is a `SetOnce`; when `try_insert_with` rejects an insertion the
/// setter reports it as a duplicate field.
#[derive(Clone, Debug, Default)]
struct BoltPointBuilder {
    sr_id: SetOnce<BoltInteger>,
    x: SetOnce<BoltFloat>,
    y: SetOnce<BoltFloat>,
    z: SetOnce<BoltFloat>,
}

impl BoltPointBuilder {
    /// Stores the spatial reference id produced by `sr_id`.
    ///
    /// Errors raised by `sr_id` itself propagate unchanged; a rejected
    /// insertion becomes a duplicate-field error naming `Field::SrId`.
    fn sr_id<E: Error>(&mut self, sr_id: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
        self.sr_id
            .try_insert_with(sr_id)?
            .map(|_| ())
            .map_err(|_| E::duplicate_field(Field::SrId.name()))
    }

    /// Stores the `x` coordinate; see `sr_id` for the error behaviour.
    fn x<E: Error>(&mut self, x: impl FnOnce() -> Result<BoltFloat, E>) -> Result<(), E> {
        self.x
            .try_insert_with(x)?
            .map(|_| ())
            .map_err(|_| E::duplicate_field(Field::X.name()))
    }

    /// Stores the `y` coordinate; see `sr_id` for the error behaviour.
    fn y<E: Error>(&mut self, y: impl FnOnce() -> Result<BoltFloat, E>) -> Result<(), E> {
        self.y
            .try_insert_with(y)?
            .map(|_| ())
            .map_err(|_| E::duplicate_field(Field::Y.name()))
    }

    /// Stores the `z` coordinate; see `sr_id` for the error behaviour.
    fn z<E: Error>(&mut self, z: impl FnOnce() -> Result<BoltFloat, E>) -> Result<(), E> {
        self.z
            .try_insert_with(z)?
            .map(|_| ())
            .map_err(|_| E::duplicate_field(Field::Z.name()))
    }

    /// Hands the collected fields to `P` for validation and assembly.
    fn build<P: FromBuilder, E: Error>(self) -> Result<P, E> {
        <P as FromBuilder>::build(self)
    }
}
trait FromBuilder: Sized {
fn build<E: Error>(builder: BoltPointBuilder) -> Result<Self, E>;
}
impl FromBuilder for BoltPoint2D {
fn build<E: Error>(builder: BoltPointBuilder) -> Result<Self, E> {
if builder.z.is_set() {
return Err(Error::unknown_field("z", &Field::NAMES[..3]));
}
let sr_id = builder
.sr_id
.ok_or_else(|| Error::missing_field(Field::SrId.name()))?;
let x = builder
.x
.ok_or_else(|| Error::missing_field(Field::X.name()))?;
let y = builder
.y
.ok_or_else(|| Error::missing_field(Field::Y.name()))?;
Ok(BoltPoint2D { sr_id, x, y })
}
}
impl FromBuilder for BoltPoint3D {
fn build<E: Error>(builder: BoltPointBuilder) -> Result<Self, E> {
let sr_id = builder
.sr_id
.ok_or_else(|| Error::missing_field(Field::SrId.name()))?;
let x = builder
.x
.ok_or_else(|| Error::missing_field(Field::X.name()))?;
let y = builder
.y
.ok_or_else(|| Error::missing_field(Field::Y.name()))?;
let z = builder
.z
.ok_or_else(|| Error::missing_field(Field::Z.name()))?;
Ok(BoltPoint3D { sr_id, x, y, z })
}
}
/// Serde visitor that builds a point of type `P` from a map, reporting
/// failures with error type `E`.
pub struct BoltPointVisitor<P, E>(PhantomData<(P, E)>);

// The `()` parameters only give the constructors below a concrete impl to live in.
impl BoltPointVisitor<(), ()> {
    /// Visitor that produces a `BoltPoint2D`.
    pub fn _2d<E: Error>() -> BoltPointVisitor<BoltPoint2D, E> {
        BoltPointVisitor(PhantomData::<(BoltPoint2D, E)>)
    }

    /// Visitor that produces a `BoltPoint3D`.
    pub fn _3d<E: Error>() -> BoltPointVisitor<BoltPoint3D, E> {
        BoltPointVisitor(PhantomData::<(BoltPoint3D, E)>)
    }
}
impl<'de, P: FromBuilder, E: Error> Visitor<'de> for BoltPointVisitor<P, E> {
type Value = P;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "struct {}", std::any::type_name::<P>())
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: ::serde::de::MapAccess<'de>,
{
let mut point = BoltPointBuilder::default();
while let Some(key) = map.next_key::<Field>()? {
match key {
Field::SrId => point.sr_id(|| map.next_value())?,
Field::X => point.x(|| map.next_value())?,
Field::Y => point.y(|| map.next_value())?,
Field::Z => point.z(|| map.next_value())?,
}
}
let point = point.build()?;
Ok(point)
}
}
impl<'de> Deserialize<'de> for BoltPoint2D {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_struct(
"BoltPoint2D",
&Field::NAMES[..3],
BoltPointVisitor::_2d::<D::Error>(),
)
}
}
impl<'de> Deserialize<'de> for BoltPoint3D {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_struct(
"BoltPoint3D",
Field::NAMES,
BoltPointVisitor::_3d::<D::Error>(),
)
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct BoltPoint<'de> {
sr_id: &'de BoltInteger,
x: &'de BoltFloat,
y: &'de BoltFloat,
z: Option<&'de BoltFloat>,
}
impl<'de> From<&'de BoltPoint2D> for BoltPoint<'de> {
fn from(point: &'de BoltPoint2D) -> Self {
Self {
sr_id: &point.sr_id,
x: &point.x,
y: &point.y,
z: None,
}
}
}
impl<'de> From<&'de BoltPoint3D> for BoltPoint<'de> {
fn from(point: &'de BoltPoint3D) -> Self {
Self {
sr_id: &point.sr_id,
x: &point.x,
y: &point.y,
z: Some(&point.z),
}
}
}
/// Cursor used to expose a point's fields as map entries or sequence elements.
///
/// `fields` yields the keys to present; an `Err` carries a requested key name
/// that is not a point field. `next_field` holds the most recent recognised
/// key until its value is requested.
struct BoltPointData<'de, I, E> {
    point: BoltPoint<'de>,
    fields: I,
    next_field: Option<Field>,
    _error: PhantomData<E>,
}

impl<'de, I, E> BoltPointData<'de, I, E> {
    /// Starts a walk over `fields` for `point` with no key pending.
    fn new(point: BoltPoint<'de>, fields: I) -> Self {
        BoltPointData {
            _error: PhantomData,
            next_field: None,
            fields,
            point,
        }
    }
}
impl<'de, E: Error, I: Iterator<Item = Result<Field, &'static str>>> MapAccess<'de>
    for BoltPointData<'de, I, E>
{
    type Error = E;

    /// Emits the next key from the field iterator.
    ///
    /// Recognised fields are remembered so that `next_value_seed` can return
    /// their value; unrecognised names are emitted as plain strings.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
    where
        K: DeserializeSeed<'de>,
    {
        let entry = match self.fields.next() {
            Some(entry) => entry,
            None => return Ok(None),
        };
        match entry {
            Ok(field) => {
                self.next_field = Some(field);
                seed.deserialize(field.into_deserializer()).map(Some)
            }
            Err(name) => seed.deserialize(name.into_deserializer()).map(Some),
        }
    }

    /// Returns the value of the pending field.
    ///
    /// When no recognised field is pending (e.g. after an unknown key) a unit
    /// value is produced. Asking a 2D point for `z` is an unknown-field error.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        let field = match self.next_field.take() {
            Some(field) => field,
            None => return seed.deserialize(().into_deserializer()),
        };
        match field {
            Field::SrId => seed.deserialize(self.point.sr_id.value.into_deserializer()),
            Field::X => seed.deserialize(self.point.x.value.into_deserializer()),
            Field::Y => seed.deserialize(self.point.y.value.into_deserializer()),
            Field::Z => match self.point.z {
                None => Err(Error::unknown_field("z", &Field::NAMES[..3])),
                Some(z) => seed.deserialize(z.value.into_deserializer()),
            },
        }
    }
}
/// Exposes the point's fields as a sequence, in the order of the field iterator.
///
/// Each element is produced by advancing the key iterator and then reading the
/// matching value, so a sequence built from `[X, Y]` yields `x` and then `y`.
impl<'de, E: Error, I: Iterator<Item = Result<Field, &'static str>>> SeqAccess<'de>
    for BoltPointData<'de, I, E>
{
    type Error = E;

    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        // Once the key iterator is exhausted there is no element left, and the
        // `SeqAccess` contract requires reporting that with `Ok(None)`. Falling
        // through to `next_value_seed` would instead deserialize a unit value,
        // making the sequence endless for visitors that drain until `None`.
        match self.next_key::<Field>()? {
            Some(_) => self.next_value_seed(seed).map(Some),
            None => Ok(None),
        }
    }
}
impl<'de, E: Error> IntoDeserializer<'de, E> for &'de BoltPoint2D {
type Deserializer = BoltPointDeserializer<'de, E>;
fn into_deserializer(self) -> Self::Deserializer {
BoltPointDeserializer::new(self)
}
}
impl<'de, E: Error> IntoDeserializer<'de, E> for &'de BoltPoint3D {
type Deserializer = BoltPointDeserializer<'de, E>;
fn into_deserializer(self) -> Self::Deserializer {
BoltPointDeserializer::new(self)
}
}
/// Serde `Deserializer` over a borrowed Bolt point, reporting errors as `E`.
pub struct BoltPointDeserializer<'de, E>(BoltPoint<'de>, PhantomData<E>);

impl<'de, E: Error> BoltPointDeserializer<'de, E> {
    /// Creates a deserializer from anything convertible into the borrowed point view.
    pub(crate) fn new(point: impl Into<BoltPoint<'de>>) -> Self {
        let view: BoltPoint<'de> = point.into();
        Self(view, PhantomData)
    }
}
impl<'de, E: Error> Deserializer<'de> for BoltPointDeserializer<'de, E> {
type Error = E;
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let fields = match self.0.z {
None => &Field::VARIANTS[..3],
Some(_) => Field::VARIANTS,
};
visitor.visit_map(BoltPointData::new(self.0, fields.iter().copied().map(Ok)))
}
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let fields = match len {
2 => [Field::X, Field::Y].as_slice(),
3 => match self.0.z {
None => &[Field::SrId, Field::X, Field::Y],
Some(_) => &[Field::X, Field::Y, Field::Z],
},
4 => Field::VARIANTS,
_ => return Err(Error::invalid_length(len, &"2, 3 or 4")),
};
visitor.visit_seq(BoltPointData::new(self.0, fields.iter().copied().map(Ok)))
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_tuple(len, visitor)
}
fn deserialize_struct<V>(
self,
_name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let fields = fields.iter().map(|o| match Field::from_str(o) {
Some(field) => Ok(field),
None => Err(*o),
});
visitor.visit_map(BoltPointData::new(self.0, fields))
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_enum(self)
}
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_map(visitor)
}
forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct seq identifier newtype_struct
}
}
impl<'de, E: Error> EnumAccess<'de> for BoltPointDeserializer<'de, E> {
    type Error = E;
    type Variant = BoltPointDeserializer<'de, E>;

    /// Identifies the variant by the point's dimension (`Point3D` when `z` is
    /// present, `Point2D` otherwise) and hands itself back as the payload.
    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        let kind = if self.0.z.is_some() {
            BoltKind::Point3D
        } else {
            BoltKind::Point2D
        };
        seed.deserialize(kind.into_deserializer())
            .map(move |variant| (variant, self))
    }
}
impl<'de, E: Error> VariantAccess<'de> for BoltPointDeserializer<'de, E> {
    type Error = E;

    /// A unit variant carries no payload, so there is nothing to check.
    fn unit_variant(self) -> Result<(), Self::Error> {
        Ok(())
    }

    /// The newtype payload is the point itself.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        seed.deserialize(self)
    }

    /// Tuple payloads use the same layouts as `deserialize_tuple`.
    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        self.deserialize_tuple(len, visitor)
    }

    /// Struct payloads use `deserialize_struct`, named after the point's dimension.
    fn struct_variant<V>(
        self,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let name = if self.0.z.is_none() {
            "Point2D"
        } else {
            "Point3D"
        };
        self.deserialize_struct(name, fields, visitor)
    }
}
#[cfg(test)]
mod tests {
    use std::{fmt::Debug, marker::PhantomData};
    use crate::{types::BoltType, DeError};
    use super::*;
    impl BoltPoint2D {
        /// Test helper: deserializes `T` from this point via its `IntoDeserializer` impl.
        pub(crate) fn to<'this, T>(&'this self) -> Result<T, DeError>
        where
            T: Deserialize<'this>,
        {
            T::deserialize(self.into_deserializer())
        }
    }
    impl BoltPoint3D {
        /// Test helper: deserializes `T` from this point via its `IntoDeserializer` impl.
        pub(crate) fn to<'this, T>(&'this self) -> Result<T, DeError>
        where
            T: Deserialize<'this>,
        {
            T::deserialize(self.into_deserializer())
        }
    }
    fn test_point2d() -> BoltPoint2D {
        BoltPoint2D {
            sr_id: 420.into(),
            x: BoltFloat::new(42.0),
            y: BoltFloat::new(13.37),
        }
    }
    #[test]
    fn point2d_full_struct() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P {
            sr_id: u64,
            x: f64,
            y: f64,
        }
        test_extract_point2d(P {
            sr_id: 420,
            x: 42.0,
            y: 13.37,
        });
    }
    #[test]
    fn point2d_xy_struct() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P {
            x: f64,
            y: f64,
        }
        test_extract_point2d(P { x: 42.0, y: 13.37 });
    }
    #[test]
    fn point2d_with_unit_types() {
        #[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
        struct P<T> {
            _t: PhantomData<T>,
            _u: (),
        }
        test_extract_point2d(P {
            _t: PhantomData::<i32>,
            _u: (),
        });
    }
    #[test]
    fn point2d_tuple_struct_full() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P(u64, f64, f64);
        test_extract_point2d(P(420, 42.0, 13.37));
    }
    #[test]
    fn point2d_tuple_struct_xy() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P(f64, f64);
        test_extract_point2d(P(42.0, 13.37));
    }
    #[test]
    fn point2d_tuple_full() {
        test_extract_point2d((420, 42.0, 13.37));
    }
    #[test]
    fn point2d_tuple_xy() {
        test_extract_point2d((42.0, 13.37));
    }
    fn test_extract_point2d<P: Debug + PartialEq + for<'a> Deserialize<'a>>(expected: P) {
        let point = test_point2d();
        let actual = point.to::<P>().unwrap();
        assert_eq!(actual, expected);
    }
    fn test_point3d() -> BoltPoint3D {
        BoltPoint3D {
            sr_id: 420.into(),
            x: BoltFloat::new(42.0),
            y: BoltFloat::new(13.37),
            z: BoltFloat::new(84.0),
        }
    }
    #[test]
    fn point3d_full_struct() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P {
            sr_id: u64,
            x: f64,
            y: f64,
            z: f64,
        }
        test_extract_point3d(P {
            sr_id: 420,
            x: 42.0,
            y: 13.37,
            z: 84.0,
        });
    }
    #[test]
    fn point3d_xy_struct() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P {
            x: f64,
            y: f64,
        }
        test_extract_point3d(P { x: 42.0, y: 13.37 });
    }
    #[test]
    fn point3d_xyz_struct() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P {
            x: f64,
            y: f64,
            z: f64,
        }
        test_extract_point3d(P {
            x: 42.0,
            y: 13.37,
            z: 84.0,
        });
    }
    #[test]
    fn point3d_with_unit_types() {
        #[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
        struct P<T> {
            _t: PhantomData<T>,
            _u: (),
        }
        test_extract_point3d(P {
            _t: PhantomData::<i32>,
            _u: (),
        });
    }
    #[test]
    fn point3d_tuple_struct_full() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P(u64, f64, f64, f64);
        test_extract_point3d(P(420, 42.0, 13.37, 84.0));
    }
    #[test]
    fn point3d_tuple_struct_xy() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P(f64, f64);
        test_extract_point3d(P(42.0, 13.37));
    }
    #[test]
    fn point3d_tuple_struct_xyz() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct P(f64, f64, f64);
        test_extract_point3d(P(42.0, 13.37, 84.0));
    }
    #[test]
    fn point3d_tuple_full() {
        test_extract_point3d((420, 42.0, 13.37, 84.0));
    }
    #[test]
    fn point3d_tuple_xy() {
        test_extract_point3d((42.0, 13.37));
    }
    #[test]
    fn point3d_tuple_xyz() {
        test_extract_point3d((42.0, 13.37, 84.0));
    }
    fn test_extract_point3d<P: Debug + PartialEq + for<'a> Deserialize<'a>>(expected: P) {
        let point = test_point3d();
        let actual = point.to::<P>().unwrap();
        assert_eq!(actual, expected);
    }
    #[test]
    fn point2d_to_bolt_type() {
        let point = test_point2d();
        let actual = point.to::<BoltType>().unwrap();
        assert_eq!(actual, BoltType::Point2D(point));
    }
    #[test]
    fn point2d_to_bolt_point() {
        let point = test_point2d();
        let actual = point.to::<BoltPoint2D>().unwrap();
        assert_eq!(actual, point);
    }
    #[test]
    fn point2d_to_point() {
        let point = test_point2d();
        let actual = point.to::<Point2D>().unwrap();
        assert_eq!(actual.sr_id(), point.sr_id.value);
        assert_eq!(actual.x(), point.x.value);
        assert_eq!(actual.y(), point.y.value);
    }
    #[test]
    fn point3d_to_bolt_type() {
        let point = test_point3d();
        let actual = point.to::<BoltType>().unwrap();
        assert_eq!(actual, BoltType::Point3D(point));
    }
    #[test]
    fn point3d_to_bolt_point() {
        let point = test_point3d();
        let actual = point.to::<BoltPoint3D>().unwrap();
        assert_eq!(actual, point);
    }
    #[test]
    fn point3d_to_point() {
        let point = test_point3d();
        let actual = point.to::<Point3D>().unwrap();
        assert_eq!(actual.sr_id(), point.sr_id.value);
        assert_eq!(actual.x(), point.x.value);
        assert_eq!(actual.y(), point.y.value);
        // The third coordinate must survive the round trip too.
        assert_eq!(actual.z(), point.z.value);
    }

    /// Visitor that keeps pulling `f64` elements until the sequence reports its
    /// end, so a point sequence that fails to terminate surfaces as an error.
    struct DrainFloats;

    impl<'de> Visitor<'de> for DrainFloats {
        type Value = Vec<f64>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a sequence of floats")
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            let mut out = Vec::new();
            while let Some(value) = seq.next_element::<f64>()? {
                out.push(value);
            }
            Ok(out)
        }
    }

    #[test]
    fn point2d_tuple_sequence_terminates() {
        let point = test_point2d();
        let actual = BoltPointDeserializer::<DeError>::new(&point)
            .deserialize_tuple(2, DrainFloats)
            .unwrap();
        assert_eq!(actual, vec![42.0, 13.37]);
    }

    #[test]
    fn point3d_tuple_sequence_terminates() {
        let point = test_point3d();
        let actual = BoltPointDeserializer::<DeError>::new(&point)
            .deserialize_tuple(3, DrainFloats)
            .unwrap();
        assert_eq!(actual, vec![42.0, 13.37, 84.0]);
    }
}