From 332505f9acd4472612af9661b0ae8bf8d09e4af3 Mon Sep 17 00:00:00 2001 From: zrkn Date: Tue, 13 Feb 2018 15:12:22 +0300 Subject: [PATCH] Initial commit --- .gitignore | 4 + Cargo.toml | 11 ++ README.md | 3 + src/de.rs | 336 ++++++++++++++++++++++++++++++++++++++ src/error.rs | 56 +++++++ src/lib.rs | 25 +++ src/ser.rs | 446 +++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 881 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 src/de.rs create mode 100644 src/error.rs create mode 100644 src/lib.rs create mode 100644 src/ser.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..143b1ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ + +/target/ +**/*.rs.bk +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..82e9819 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "rlua_serde" +version = "0.1.0" +authors = ["zrkn "] + +[dependencies] +rlua = "0.10" +serde = "1.0" + +[dev-dependencies] +serde_derive = "1.0" diff --git a/README.md b/README.md new file mode 100644 index 0000000..4028ece --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# rlua Serde + +(De)serializer implementation for [rlua::Value](https://docs.rs/rlua/0.12/rlua/enum.Value.html) diff --git a/src/de.rs b/src/de.rs new file mode 100644 index 0000000..c468770 --- /dev/null +++ b/src/de.rs @@ -0,0 +1,336 @@ +use serde; +use serde::de::IntoDeserializer; + +use rlua::{Value, TablePairs, TableSequence}; + +use error::{Error, Result}; + + +pub struct Deserializer<'lua> { + pub value: Value<'lua>, +} + +impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { + type Error = Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where V: serde::de::Visitor<'de> + { + match self.value { + Value::Nil => visitor.visit_unit(), + Value::Boolean(v) => visitor.visit_bool(v), + Value::Integer(v) => visitor.visit_i64(v), + Value::Number(v) => visitor.visit_f64(v), + Value::String(v) => visitor.visit_str(v.to_str()?), + Value::Table(v) => if v.contains_key(1)? { + let len = v.len()? as usize; + let mut deserializer = SeqDeserializer(v.sequence_values()); + let seq = visitor.visit_seq(&mut deserializer)?; + let remaining = deserializer.0.count(); + if remaining == 0 { + Ok(seq) + } else { + Err(serde::de::Error::invalid_length(len, &"fewer elements in array")) + } + } else { + let len = v.len()? as usize; + let mut deserializer = MapDeserializer(v.pairs(), None); + let map = visitor.visit_map(&mut deserializer)?; + let remaining = deserializer.0.count(); + if remaining == 0 { + Ok(map) + } else { + Err(serde::de::Error::invalid_length(len, &"fewer elements in array")) + } + }, + _ => Err(serde::de::Error::custom("invalid value type")), + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where V: serde::de::Visitor<'de> + { + match self.value { + Value::Nil => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + #[inline] + fn deserialize_enum( + self, _name: &str, _variants: &'static [&'static str], visitor: V + ) -> Result + where V: serde::de::Visitor<'de> + { + let (variant, value) = match self.value { + Value::Table(value) => { + let mut iter = value.pairs::(); + let (variant, value) = match iter.next() { + Some(v) => v?, + None => return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )), + }; + + if iter.next().is_some() { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )) + } + (variant, Some(value)) + } + Value::String(variant) => (variant.to_str()?.to_owned(), None), + _ => return Err(serde::de::Error::custom("bad enum value")), + }; + + visitor.visit_enum(EnumDeserializer { variant, value }) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf unit unit_struct newtype_struct seq tuple + tuple_struct map struct identifier ignored_any + } +} + + +struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>); + +impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where T: serde::de::DeserializeSeed<'de> + { + match self.0.next() { + Some(value) => seed.deserialize(Deserializer { value: value? }) + .map(Some), + None => Ok(None) + } + } + + fn size_hint(&self) -> Option { + match self.0.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + + +struct MapDeserializer<'lua>( + TablePairs<'lua, Value<'lua>, Value<'lua>>, + Option> +); + +impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> { + type Error = Error; + + fn next_key_seed(&mut self, seed: T) -> Result> + where T: serde::de::DeserializeSeed<'de> + { + match self.0.next() { + Some(item) => { + let (key, value) = item?; + self.1 = Some(value); + let key_de = Deserializer { value: key }; + seed.deserialize(key_de).map(Some) + }, + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: T) -> Result + where T: serde::de::DeserializeSeed<'de> + { + match self.1.take() { + Some(value) => seed.deserialize(Deserializer { value }), + None => Err(serde::de::Error::custom("value is missing")), + } + } + + fn size_hint(&self) -> Option { + match self.0.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + + +struct EnumDeserializer<'lua> { + variant: String, + value: Option>, +} + +impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> { + type Error = Error; + type Variant = VariantDeserializer<'lua>; + + fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant)> + where T: serde::de::DeserializeSeed<'de> + { + let variant = self.variant.into_deserializer(); + let variant_access = VariantDeserializer { value: self.value }; + seed.deserialize(variant).map(|v| (v, variant_access)) + } +} + + +struct VariantDeserializer<'lua> { + value: Option>, +} + +impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + match self.value { + Some(_) => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::NewtypeVariant, + &"unit variant", + )), + None => Ok(()) + } + } + + fn newtype_variant_seed(self, seed: T) -> Result + where T: serde::de::DeserializeSeed<'de> + { + match self.value { + Some(value) => seed.deserialize(Deserializer { value }), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"newtype variant", + )) + } + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where V: serde::de::Visitor<'de> + { + match self.value { + Some(value) => serde::Deserializer::deserialize_seq( + Deserializer { value }, visitor + ), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"tuple variant", + )) + } + } + + fn struct_variant( + self, _fields: &'static [&'static str], visitor: V + ) -> Result + where V: serde::de::Visitor<'de> + { + match self.value { + Some(value) => serde::Deserializer::deserialize_map( + Deserializer { value }, visitor + ), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"struct variant", + )) + } + } +} + +#[cfg(test)] +mod tests { + use rlua::Lua; + + use from_value; + use super::*; + + #[test] + fn test_struct() { + #[derive(Deserialize, PartialEq, Debug)] + struct Test { + int: u32, + seq: Vec, + } + + let expected = Test { int: 1, seq: vec!["a".to_owned(), "b".to_owned()] }; + + let lua = Lua::new(); + let value = lua.exec::( + r#" + a = {} + a.int = 1 + a.seq = {"a", "b"} + return a + "#, + None, + ).unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + } + + + #[test] + fn test_enum() { + #[derive(Deserialize, PartialEq, Debug)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let lua = Lua::new(); + + let expected = E::Unit; + let value = lua.exec::( + r#" + return "Unit" + "#, + None, + ).unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + + let expected = E::Newtype(1); + let value = lua.exec::( + r#" + a = {} + a["Newtype"] = 1 + return a + "#, + None, + ).unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Tuple(1, 2); + let value = lua.exec::( + r#" + a = {} + a["Tuple"] = {1, 2} + return a + "#, + None, + ).unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Struct { a: 1 }; + let value = lua.exec::( + r#" + a = {} + a["Struct"] = {} + a["Struct"]["a"] = 1 + return a + "#, + None, + ).unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..aa9a234 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,56 @@ +use std::fmt; +use std::error::Error as StdError; +use std::result::Result as StdResult; + +use serde; +use rlua::Error as LuaError; + + +#[derive(Debug)] +pub struct Error(LuaError); + +pub type Result = StdResult; + +impl From for Error { + fn from(err: LuaError) -> Error { + Error(err) + } +} + +impl From for LuaError { + fn from(err: Error) -> LuaError { + err.0 + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(fmt) + } +} + +impl StdError for Error { + fn description(&self) -> &'static str { + "Failed to serialize to Lua value" + } +} + +impl serde::ser::Error for Error { + fn custom(msg: T) -> Self { + Error(LuaError::ToLuaConversionError { + from: "serialize", + to: "value", + message: Some(format!("{}", msg)) + }) + } +} + +impl serde::de::Error for Error { + fn custom(msg: T) -> Self { + Error(LuaError::FromLuaConversionError { + from: "value", + to: "deserialize", + message: Some(format!("{}", msg)) + }) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6f51428 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,25 @@ +extern crate rlua; +#[macro_use] +extern crate serde; +#[cfg(test)] +#[macro_use] +extern crate serde_derive; + +pub mod error; +pub mod ser; +pub mod de; + + +use rlua::{Lua, Value, Error}; + + +pub fn to_value(lua: &Lua, t: T) -> Result { + let serializer = ser::Serializer { lua }; + Ok(t.serialize(serializer)?) +} + + +pub fn from_value<'de, T: serde::Deserialize<'de>>(value: Value<'de>) -> Result { + let deserializer = de::Deserializer { value }; + Ok(T::deserialize(deserializer)?) +} diff --git a/src/ser.rs b/src/ser.rs new file mode 100644 index 0000000..8e893ac --- /dev/null +++ b/src/ser.rs @@ -0,0 +1,446 @@ +use serde; + +use rlua::{Lua, Value, Table, String as LuaString}; + +use to_value; +use error::{Error, Result}; + + +pub struct Serializer<'lua> { + pub lua: &'lua Lua, +} + +impl<'lua> serde::Serializer for Serializer<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + type SerializeSeq = SerializeVec<'lua>; + type SerializeTuple = SerializeVec<'lua>; + type SerializeTupleStruct = SerializeVec<'lua>; + type SerializeTupleVariant = SerializeTupleVariant<'lua>; + type SerializeMap = SerializeMap<'lua>; + type SerializeStruct = SerializeMap<'lua>; + type SerializeStructVariant = SerializeStructVariant<'lua>; + + #[inline] + fn serialize_bool(self, value: bool) -> Result> { + Ok(Value::Boolean(value)) + } + + #[inline] + fn serialize_i8(self, value: i8) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_i16(self, value: i16) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_i32(self, value: i32) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_i64(self, value: i64) -> Result> { + Ok(Value::Integer(value)) + } + + #[inline] + fn serialize_u8(self, value: u8) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_u16(self, value: u16) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_u32(self, value: u32) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_u64(self, value: u64) -> Result> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_f32(self, value: f32) -> Result> { + self.serialize_f64(value as f64) + } + + #[inline] + fn serialize_f64(self, value: f64) -> Result> { + Ok(Value::Number(value)) + } + + #[inline] + fn serialize_char(self, value: char) -> Result> { + let mut s = String::new(); + s.push(value); + self.serialize_str(&s) + } + + #[inline] + fn serialize_str(self, value: &str) -> Result> { + Ok(Value::String(self.lua.create_string(value)?)) + } + + #[inline] + fn serialize_bytes(self, value: &[u8]) -> Result> { + Ok(Value::Table(self.lua.create_sequence_from(value.iter().cloned())?)) + } + + #[inline] + fn serialize_unit(self) -> Result> { + Ok(Value::Nil) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result> { + self.serialize_unit() + } + + #[inline] + fn serialize_unit_variant( + self, _name: &'static str, _variant_index: u32, variant: &'static str + ) -> Result> { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct( + self, _name: &'static str, value: &T + ) -> Result> + where T: ?Sized + serde::Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, _name: &'static str, _variant_index: u32, + variant: &'static str, value: &T, + ) -> Result> + where T: ?Sized + serde::Serialize, + { + let table = self.lua.create_table()?; + let variant = self.lua.create_string(variant)?; + let value = to_value(self.lua, value)?; + table.set(variant, value)?; + Ok(Value::Table(table)) + } + + #[inline] + fn serialize_none(self) -> Result> { + self.serialize_unit() + } + + #[inline] + fn serialize_some(self, value: &T) -> Result> + where T: ?Sized + serde::Serialize, + { + value.serialize(self) + } + + fn serialize_seq(self, _len: Option) -> Result { + let table = self.lua.create_table()?; + Ok(SerializeVec { + lua: self.lua, + idx: 1, + table, + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_struct( + self, _name: &'static str, len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, _name: &'static str, _variant_index: u32, + variant: &'static str, _len: usize, + ) -> Result { + let name = self.lua.create_string(variant)?; + let table = self.lua.create_table()?; + Ok(SerializeTupleVariant { + lua: self.lua, + idx: 1, + name, + table + }) + } + + fn serialize_map(self, _len: Option) -> Result { + let table = self.lua.create_table()?; + Ok(SerializeMap { + lua: self.lua, + next_key: None, + table, + }) + } + + fn serialize_struct(self, _name: &'static str, len: usize) -> Result { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, _name: &'static str, _variant_index: u32, + variant: &'static str, _len: usize, + ) -> Result { + let name = self.lua.create_string(variant)?; + let table = self.lua.create_table()?; + Ok(SerializeStructVariant { + lua: self.lua, + name, + table, + }) + } + +} + + +pub struct SerializeVec<'lua> { + lua: &'lua Lua, + table: Table<'lua>, + idx: u64, +} + +impl<'lua> serde::ser::SerializeSeq for SerializeVec<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + self.table.set(self.idx, to_value(self.lua, value)?)?; + self.idx += 1; + Ok(()) + } + + fn end(self) -> Result> { + Ok(Value::Table(self.table)) + } +} + +impl<'lua> serde::ser::SerializeTuple for SerializeVec<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result> { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'lua> serde::ser::SerializeTupleStruct for SerializeVec<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result> { + serde::ser::SerializeSeq::end(self) + } +} + + +pub struct SerializeTupleVariant<'lua> { + lua: &'lua Lua, + name: LuaString<'lua>, + table: Table<'lua>, + idx: u64, +} + +impl<'lua> serde::ser::SerializeTupleVariant for SerializeTupleVariant<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + self.table.set(self.idx, to_value(self.lua, value)?)?; + self.idx += 1; + Ok(()) + } + + fn end(self) -> Result> { + let table = self.lua.create_table()?; + table.set(self.name, self.table)?; + Ok(Value::Table(table)) + } +} + + +pub struct SerializeMap<'lua> { + lua: &'lua Lua, + table: Table<'lua>, + next_key: Option> +} + +impl<'lua> serde::ser::SerializeMap for SerializeMap<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + self.next_key = Some(to_value(self.lua, key)?); + Ok(()) + } + + fn serialize_value(&mut self, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + let key = self.next_key.take(); + // Panic because this indicates a bug in the program rather than an + // expected failure. + let key = key.expect("serialize_value called before serialize_key"); + self.table.set(key, to_value(self.lua, value)?)?; + Ok(()) + } + + fn end(self) -> Result> { + Ok(Value::Table(self.table)) + } +} + +impl<'lua> serde::ser::SerializeStruct for SerializeMap<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + serde::ser::SerializeMap::serialize_key(self, key)?; + serde::ser::SerializeMap::serialize_value(self, value) + } + + fn end(self) -> Result> { + serde::ser::SerializeMap::end(self) + } +} + + +pub struct SerializeStructVariant<'lua> { + lua: &'lua Lua, + name: LuaString<'lua>, + table: Table<'lua>, +} + +impl<'lua> serde::ser::SerializeStructVariant for SerializeStructVariant<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where T: ?Sized + serde::Serialize, + { + self.table + .set(key, to_value(self.lua, value)?)?; + Ok(()) + } + + fn end(self) -> Result> { + let table = self.lua.create_table()?; + table.set(self.name, self.table)?; + Ok(Value::Table(table)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_struct() { + #[derive(Serialize)] + struct Test { + int: u32, + seq: Vec<&'static str>, + } + + let test = Test { int: 1, seq: vec!["a", "b"] }; + + let lua = Lua::new(); + let value = to_value(&lua, &test).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.exec::<()>( + r#" + assert(value["int"] == 1) + assert(value["seq"][1] == "a") + assert(value["seq"][2] == "b") + "#, + None, + ).unwrap(); + } + + #[test] + fn test_num() { + #[derive(Serialize)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32}, + } + + let lua = Lua::new(); + + let u = E::Unit; + let value = to_value(&lua, &u).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.exec::<()>( + r#" + assert(value == "Unit") + "#, + None, + ).unwrap(); + + let n = E::Newtype(1); + let value = to_value(&lua, &n).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.exec::<()>( + r#" + assert(value["Newtype"] == 1) + "#, + None, + ).unwrap(); + + let t = E::Tuple(1, 2); + let value = to_value(&lua, &t).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.exec::<()>( + r#" + assert(value["Tuple"][1] == 1) + assert(value["Tuple"][2] == 2) + "#, + None, + ).unwrap(); + + let s = E::Struct { a: 1 }; + let value = to_value(&lua, &s).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.exec::<()>( + r#" + assert(value["Struct"]["a"] == 1) + "#, + None, + ).unwrap(); + } +}