diff --git a/Cargo.toml b/Cargo.toml index c85cb43..a383f85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ memory_units = "0.3.0" [dev-dependencies] wabt = "~0.2.2" +assert_matches = "1.1" [features] # 32-bit platforms are not supported and not tested. Use this flag if you really want to use diff --git a/src/lib.rs b/src/lib.rs index b36532c..34ca058 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,10 @@ #[cfg(test)] extern crate wabt; +#[cfg(test)] +#[macro_use] +extern crate assert_matches; + extern crate parity_wasm; extern crate byteorder; extern crate memory_units as memory_units_crate; diff --git a/src/validation/func.rs b/src/validation/func.rs index 1145644..b15acc1 100644 --- a/src/validation/func.rs +++ b/src/validation/func.rs @@ -1,11 +1,11 @@ use std::u32; -use std::iter::repeat; use std::collections::HashMap; use parity_wasm::elements::{Opcode, BlockType, ValueType, TableElementType, Func, FuncBody}; use common::{DEFAULT_MEMORY_INDEX, DEFAULT_TABLE_INDEX}; use validation::context::ModuleContext; use validation::Error; +use validation::util::Locals; use common::stack::StackWithLimit; use common::{BlockFrame, BlockFrameType}; @@ -22,7 +22,7 @@ struct FunctionValidationContext<'a> { /// Current instruction position. position: usize, /// Local variables. - locals: &'a [ValueType], + locals: Locals<'a>, /// Value stack. value_stack: StackWithLimit, /// Frame stack. @@ -62,19 +62,9 @@ impl Validator { ) -> Result, Error> { let (params, result_ty) = module.require_function_type(func.type_ref())?; - // locals = (params + vars) - let mut locals = params.to_vec(); - locals.extend( - body.locals() - .iter() - .flat_map(|l| repeat(l.value_type()) - .take(l.count() as usize) - ), - ); - let mut context = FunctionValidationContext::new( &module, - &locals, + Locals::new(params, body.locals()), DEFAULT_VALUE_STACK_LIMIT, DEFAULT_FRAME_STACK_LIMIT, result_ty, @@ -585,7 +575,7 @@ impl Validator { impl<'a> FunctionValidationContext<'a> { fn new( module: &'a ModuleContext, - locals: &'a [ValueType], + locals: Locals<'a>, value_stack_limit: usize, frame_stack_limit: usize, return_type: BlockType, @@ -707,10 +697,7 @@ impl<'a> FunctionValidationContext<'a> { } fn require_local(&self, idx: u32) -> Result { - self.locals.get(idx as usize) - .cloned() - .map(Into::into) - .ok_or(Error(format!("Trying to access local with index {} when there are only {} locals", idx, self.locals.len()))) + Ok(self.locals.type_of_local(idx).map(StackValueType::from)?) } fn into_labels(self) -> HashMap { diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 268f420..3244173 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -12,6 +12,7 @@ use memory_units::Pages; mod context; mod func; +mod util; #[cfg(test)] mod tests; diff --git a/src/validation/util.rs b/src/validation/util.rs new file mode 100644 index 0000000..a219e4c --- /dev/null +++ b/src/validation/util.rs @@ -0,0 +1,112 @@ +use parity_wasm::elements::{Local, ValueType}; +use validation::Error; + +/// Locals are the concatenation of a slice of function parameters +/// with function declared local variables. +/// +/// Local variables are given in the form of groups represented by pairs +/// of a value_type and a count. +#[derive(Debug)] +pub struct Locals<'a> { + params: &'a [ValueType], + local_groups: &'a [Local], +} + +impl<'a> Locals<'a> { + pub fn new(params: &'a [ValueType], local_groups: &'a [Local]) -> Locals<'a> { + Locals { + params, + local_groups, + } + } + + /// Returns the type of a local variable (either a declared local or a param). + /// + /// Returns `Err` in the case of overflow or when idx falls out of range. + pub fn type_of_local(&self, idx: u32) -> Result { + if let Some(param) = self.params.get(idx as usize) { + return Ok(*param); + } + + // If an index doesn't point to a param, then we have to look into local declarations. + let mut start_idx = self.params.len() as u32; + for locals_group in self.local_groups { + let end_idx = start_idx + .checked_add(locals_group.count()) + .ok_or_else(|| Error(String::from("Locals range not in 32-bit range")))?; + + if idx >= start_idx && idx < end_idx { + return Ok(locals_group.value_type()); + } + + start_idx = end_idx; + } + + // We didn't find anything, that's an error. + // At this moment `start_idx` should hold the count of all locals + // (since it's either set to the `end_idx` or equal to `params.len()`) + let total_count = start_idx; + + Err(Error(format!( + "Trying to access local with index {} when there are only {} locals", + idx, total_count + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn locals_it_works() { + let params = vec![ValueType::I32, ValueType::I64]; + let local_groups = vec![Local::new(2, ValueType::F32), Local::new(2, ValueType::F64)]; + let locals = Locals::new(¶ms, &local_groups); + + assert_matches!(locals.type_of_local(0), Ok(ValueType::I32)); + assert_matches!(locals.type_of_local(1), Ok(ValueType::I64)); + assert_matches!(locals.type_of_local(2), Ok(ValueType::F32)); + assert_matches!(locals.type_of_local(3), Ok(ValueType::F32)); + assert_matches!(locals.type_of_local(4), Ok(ValueType::F64)); + assert_matches!(locals.type_of_local(5), Ok(ValueType::F64)); + assert_matches!(locals.type_of_local(6), Err(_)); + } + + #[test] + fn locals_no_declared_locals() { + let params = vec![ValueType::I32]; + let locals = Locals::new(¶ms, &[]); + + assert_matches!(locals.type_of_local(0), Ok(ValueType::I32)); + assert_matches!(locals.type_of_local(1), Err(_)); + } + + #[test] + fn locals_no_params() { + let local_groups = vec![Local::new(2, ValueType::I32), Local::new(3, ValueType::I64)]; + let locals = Locals::new(&[], &local_groups); + + assert_matches!(locals.type_of_local(0), Ok(ValueType::I32)); + assert_matches!(locals.type_of_local(1), Ok(ValueType::I32)); + assert_matches!(locals.type_of_local(2), Ok(ValueType::I64)); + assert_matches!(locals.type_of_local(3), Ok(ValueType::I64)); + assert_matches!(locals.type_of_local(4), Ok(ValueType::I64)); + assert_matches!(locals.type_of_local(5), Err(_)); + } + + #[test] + fn locals_u32_overflow() { + let local_groups = vec![ + Local::new(u32::max_value(), ValueType::I32), + Local::new(1, ValueType::I64), + ]; + let locals = Locals::new(&[], &local_groups); + + assert_matches!( + locals.type_of_local(u32::max_value() - 1), + Ok(ValueType::I32) + ); + assert_matches!(locals.type_of_local(u32::max_value()), Err(_)); + } +}