diff --git a/src/lib.rs b/src/lib.rs index 618ae75..b36532c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -385,7 +385,6 @@ pub struct Module { } impl Module { - /// Create `Module` from `parity_wasm::elements::Module`. /// /// This function will load, validate and prepare a `parity_wasm`'s `Module`. @@ -431,6 +430,66 @@ impl Module { }) } + /// Fail if the module contains any floating-point operations + /// + /// # Errors + /// + /// Returns `Err` if provided `Module` is not valid. + /// + /// # Examples + /// + /// ```rust + /// # extern crate wasmi; + /// # extern crate wabt; + /// + /// let wasm_binary: Vec = + /// wabt::wat2wasm( + /// r#" + /// (module + /// (func $add (param $lhs i32) (param $rhs i32) (result i32) + /// get_local $lhs + /// get_local $rhs + /// i32.add)) + /// "#, + /// ) + /// .expect("failed to parse wat"); + /// + /// // Load wasm binary and prepare it for instantiation. + /// let module = wasmi::Module::from_buffer(&wasm_binary).expect("Parsing failed"); + /// assert!(module.deny_floating_point().is_ok()); + /// + /// let wasm_binary: Vec = + /// wabt::wat2wasm( + /// r#" + /// (module + /// (func $add (param $lhs f32) (param $rhs f32) (result f32) + /// get_local $lhs + /// get_local $rhs + /// f32.add)) + /// "#, + /// ) + /// .expect("failed to parse wat"); + /// + /// let module = wasmi::Module::from_buffer(&wasm_binary).expect("Parsing failed"); + /// assert!(module.deny_floating_point().is_err()); + /// + /// let wasm_binary: Vec = + /// wabt::wat2wasm( + /// r#" + /// (module + /// (func $add (param $lhs f32) (param $rhs f32) (result f32) + /// get_local $lhs)) + /// "#, + /// ) + /// .expect("failed to parse wat"); + /// + /// let module = wasmi::Module::from_buffer(&wasm_binary).expect("Parsing failed"); + /// assert!(module.deny_floating_point().is_err()); + /// ``` + pub fn deny_floating_point(&self) -> Result<(), Error> { + validation::deny_floating_point(&self.module).map_err(Into::into) + } + /// Create `Module` from a given buffer. /// /// This function will deserialize wasm module from a given module, diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 5e4bc1e..268f420 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -50,6 +50,119 @@ impl ::std::ops::Deref for ValidatedModule { } } +pub fn deny_floating_point(module: &Module) -> Result<(), Error> { + if let Some(code) = module.code_section() { + for op in code.bodies().iter().flat_map(|body| body.code().elements()) { + use parity_wasm::elements::Opcode::*; + + macro_rules! match_eq { + ($pattern:pat) => { + |val| if let $pattern = *val { true } else { false } + }; + } + + const DENIED: &[fn(&Opcode) -> bool] = &[ + match_eq!(F32Load(_, _)), + match_eq!(F64Load(_, _)), + match_eq!(F32Store(_, _)), + match_eq!(F64Store(_, _)), + match_eq!(F32Const(_)), + match_eq!(F64Const(_)), + match_eq!(F32Eq), + match_eq!(F32Ne), + match_eq!(F32Lt), + match_eq!(F32Gt), + match_eq!(F32Le), + match_eq!(F32Ge), + match_eq!(F64Eq), + match_eq!(F64Ne), + match_eq!(F64Lt), + match_eq!(F64Gt), + match_eq!(F64Le), + match_eq!(F64Ge), + match_eq!(F32Abs), + match_eq!(F32Neg), + match_eq!(F32Ceil), + match_eq!(F32Floor), + match_eq!(F32Trunc), + match_eq!(F32Nearest), + match_eq!(F32Sqrt), + match_eq!(F32Add), + match_eq!(F32Sub), + match_eq!(F32Mul), + match_eq!(F32Div), + match_eq!(F32Min), + match_eq!(F32Max), + match_eq!(F32Copysign), + match_eq!(F64Abs), + match_eq!(F64Neg), + match_eq!(F64Ceil), + match_eq!(F64Floor), + match_eq!(F64Trunc), + match_eq!(F64Nearest), + match_eq!(F64Sqrt), + match_eq!(F64Add), + match_eq!(F64Sub), + match_eq!(F64Mul), + match_eq!(F64Div), + match_eq!(F64Min), + match_eq!(F64Max), + match_eq!(F64Copysign), + match_eq!(F32ConvertSI32), + match_eq!(F32ConvertUI32), + match_eq!(F32ConvertSI64), + match_eq!(F32ConvertUI64), + match_eq!(F32DemoteF64), + match_eq!(F64ConvertSI32), + match_eq!(F64ConvertUI32), + match_eq!(F64ConvertSI64), + match_eq!(F64ConvertUI64), + match_eq!(F64PromoteF32), + match_eq!(F32ReinterpretI32), + match_eq!(F64ReinterpretI64), + match_eq!(I32TruncSF32), + match_eq!(I32TruncUF32), + match_eq!(I32TruncSF64), + match_eq!(I32TruncUF64), + match_eq!(I64TruncSF32), + match_eq!(I64TruncUF32), + match_eq!(I64TruncSF64), + match_eq!(I64TruncUF64), + match_eq!(I32ReinterpretF32), + match_eq!(I64ReinterpretF64), + ]; + + if DENIED.iter().any(|is_denied| is_denied(op)) { + return Err(Error(format!("Floating point operation denied: {:?}", op))); + } + } + } + + if let (Some(sec), Some(types)) = (module.function_section(), module.type_section()) { + use parity_wasm::elements::{Type, ValueType}; + + let types = types.types(); + + for sig in sec.entries() { + if let Some(typ) = types.get(sig.type_ref() as usize) { + match *typ { + Type::Function(ref func) => { + if func.params() + .iter() + .chain(func.return_type().as_ref()) + .any(|&typ| typ == ValueType::F32 || typ == ValueType::F64) + { + return Err(Error(format!("Use of floating point types denied"))); + } + } + } + } + } + } + + Ok(()) +} + pub fn validate_module(module: Module) -> Result { let mut context_builder = ModuleContextBuilder::new(); let mut imported_globals = Vec::new();