diff --git a/src/lib.rs b/src/lib.rs index 6b0604f..caa89d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -517,7 +517,7 @@ impl Module { /// assert!(module.deny_floating_point().is_err()); /// ``` pub fn deny_floating_point(&self) -> Result<(), Error> { - validation::deny_floating_point(&self.module).map_err(Into::into) + prepare::deny_floating_point(&self.module).map_err(Into::into) } /// Create `Module` from a given buffer. diff --git a/src/prepare/mod.rs b/src/prepare/mod.rs index e3c0eae..e5b2b80 100644 --- a/src/prepare/mod.rs +++ b/src/prepare/mod.rs @@ -41,3 +41,125 @@ pub fn compile_module(module: Module) -> Result { let code_map = validate_module2::(&module)?; Ok(CompiledModule { module, code_map }) } + +/// Verify that the module doesn't use floating point instructions or types. +/// +/// Returns `Err` if +/// +/// - Any of function bodies uses a floating pointer instruction (an instruction that +/// consumes or produces a value of a floating point type) +/// - If a floating point type used in a definition of a function. +pub fn deny_floating_point(module: &Module) -> Result<(), Error> { + use parity_wasm::elements::{ + Instruction::{self, *}, + Type, ValueType, + }; + + if let Some(code) = module.code_section() { + for op in code.bodies().iter().flat_map(|body| body.code().elements()) { + macro_rules! match_eq { + ($pattern:pat) => { + |val| if let $pattern = *val { true } else { false } + }; + } + + const DENIED: &[fn(&Instruction) -> bool] = &[ + match_eq!(F32Load(_, _)), + match_eq!(F64Load(_, _)), + match_eq!(F32Store(_, _)), + match_eq!(F64Store(_, _)), + match_eq!(F32Const(_)), + match_eq!(F64Const(_)), + match_eq!(F32Eq), + match_eq!(F32Ne), + match_eq!(F32Lt), + match_eq!(F32Gt), + match_eq!(F32Le), + match_eq!(F32Ge), + match_eq!(F64Eq), + match_eq!(F64Ne), + match_eq!(F64Lt), + match_eq!(F64Gt), + match_eq!(F64Le), + match_eq!(F64Ge), + match_eq!(F32Abs), + match_eq!(F32Neg), + match_eq!(F32Ceil), + match_eq!(F32Floor), + match_eq!(F32Trunc), + match_eq!(F32Nearest), + match_eq!(F32Sqrt), + match_eq!(F32Add), + match_eq!(F32Sub), + match_eq!(F32Mul), + match_eq!(F32Div), + match_eq!(F32Min), + match_eq!(F32Max), + match_eq!(F32Copysign), + match_eq!(F64Abs), + match_eq!(F64Neg), + match_eq!(F64Ceil), + match_eq!(F64Floor), + match_eq!(F64Trunc), + match_eq!(F64Nearest), + match_eq!(F64Sqrt), + match_eq!(F64Add), + match_eq!(F64Sub), + match_eq!(F64Mul), + match_eq!(F64Div), + match_eq!(F64Min), + match_eq!(F64Max), + match_eq!(F64Copysign), + match_eq!(F32ConvertSI32), + match_eq!(F32ConvertUI32), + match_eq!(F32ConvertSI64), + match_eq!(F32ConvertUI64), + match_eq!(F32DemoteF64), + match_eq!(F64ConvertSI32), + match_eq!(F64ConvertUI32), + match_eq!(F64ConvertSI64), + match_eq!(F64ConvertUI64), + match_eq!(F64PromoteF32), + match_eq!(F32ReinterpretI32), + match_eq!(F64ReinterpretI64), + match_eq!(I32TruncSF32), + match_eq!(I32TruncUF32), + match_eq!(I32TruncSF64), + match_eq!(I32TruncUF64), + match_eq!(I64TruncSF32), + match_eq!(I64TruncUF32), + match_eq!(I64TruncSF64), + match_eq!(I64TruncUF64), + match_eq!(I32ReinterpretF32), + match_eq!(I64ReinterpretF64), + ]; + + if DENIED.iter().any(|is_denied| is_denied(op)) { + return Err(Error(format!("Floating point operation denied: {:?}", op))); + } + } + } + + if let (Some(sec), Some(types)) = (module.function_section(), module.type_section()) { + let types = types.types(); + + for sig in sec.entries() { + if let Some(typ) = types.get(sig.type_ref() as usize) { + match *typ { + Type::Function(ref func) => { + if func + .params() + .iter() + .chain(func.return_type().as_ref()) + .any(|&typ| typ == ValueType::F32 || typ == ValueType::F64) + { + return Err(Error(format!("Use of floating point types denied"))); + } + } + } + } + } + } + + Ok(()) +} diff --git a/validation/src/lib.rs b/validation/src/lib.rs index f46c5a5..292755d 100644 --- a/validation/src/lib.rs +++ b/validation/src/lib.rs @@ -78,118 +78,6 @@ impl From for Error { } } -pub fn deny_floating_point(module: &Module) -> Result<(), Error> { - if let Some(code) = module.code_section() { - for op in code.bodies().iter().flat_map(|body| body.code().elements()) { - use parity_wasm::elements::Instruction::*; - - macro_rules! match_eq { - ($pattern:pat) => { - |val| if let $pattern = *val { true } else { false } - }; - } - - const DENIED: &[fn(&Instruction) -> bool] = &[ - match_eq!(F32Load(_, _)), - match_eq!(F64Load(_, _)), - match_eq!(F32Store(_, _)), - match_eq!(F64Store(_, _)), - match_eq!(F32Const(_)), - match_eq!(F64Const(_)), - match_eq!(F32Eq), - match_eq!(F32Ne), - match_eq!(F32Lt), - match_eq!(F32Gt), - match_eq!(F32Le), - match_eq!(F32Ge), - match_eq!(F64Eq), - match_eq!(F64Ne), - match_eq!(F64Lt), - match_eq!(F64Gt), - match_eq!(F64Le), - match_eq!(F64Ge), - match_eq!(F32Abs), - match_eq!(F32Neg), - match_eq!(F32Ceil), - match_eq!(F32Floor), - match_eq!(F32Trunc), - match_eq!(F32Nearest), - match_eq!(F32Sqrt), - match_eq!(F32Add), - match_eq!(F32Sub), - match_eq!(F32Mul), - match_eq!(F32Div), - match_eq!(F32Min), - match_eq!(F32Max), - match_eq!(F32Copysign), - match_eq!(F64Abs), - match_eq!(F64Neg), - match_eq!(F64Ceil), - match_eq!(F64Floor), - match_eq!(F64Trunc), - match_eq!(F64Nearest), - match_eq!(F64Sqrt), - match_eq!(F64Add), - match_eq!(F64Sub), - match_eq!(F64Mul), - match_eq!(F64Div), - match_eq!(F64Min), - match_eq!(F64Max), - match_eq!(F64Copysign), - match_eq!(F32ConvertSI32), - match_eq!(F32ConvertUI32), - match_eq!(F32ConvertSI64), - match_eq!(F32ConvertUI64), - match_eq!(F32DemoteF64), - match_eq!(F64ConvertSI32), - match_eq!(F64ConvertUI32), - match_eq!(F64ConvertSI64), - match_eq!(F64ConvertUI64), - match_eq!(F64PromoteF32), - match_eq!(F32ReinterpretI32), - match_eq!(F64ReinterpretI64), - match_eq!(I32TruncSF32), - match_eq!(I32TruncUF32), - match_eq!(I32TruncSF64), - match_eq!(I32TruncUF64), - match_eq!(I64TruncSF32), - match_eq!(I64TruncUF32), - match_eq!(I64TruncSF64), - match_eq!(I64TruncUF64), - match_eq!(I32ReinterpretF32), - match_eq!(I64ReinterpretF64), - ]; - - if DENIED.iter().any(|is_denied| is_denied(op)) { - return Err(Error(format!("Floating point operation denied: {:?}", op))); - } - } - } - - if let (Some(sec), Some(types)) = (module.function_section(), module.type_section()) { - let types = types.types(); - - for sig in sec.entries() { - if let Some(typ) = types.get(sig.type_ref() as usize) { - match *typ { - Type::Function(ref func) => { - if func - .params() - .iter() - .chain(func.return_type().as_ref()) - .any(|&typ| typ == ValueType::F32 || typ == ValueType::F64) - { - return Err(Error(format!("Use of floating point types denied"))); - } - } - } - } - } - } - - Ok(()) -} - pub trait Validation { type Output; type FunctionValidator: FunctionValidator;