diff --git a/src/validation/func.rs b/src/validation/func.rs index 89dcf9f..266e064 100644 --- a/src/validation/func.rs +++ b/src/validation/func.rs @@ -173,7 +173,7 @@ impl Validator { let ins_size_estimate = body.code().elements().len(); let mut context = FunctionValidationContext::new( &module, - Locals::new(params, body.locals()), + Locals::new(params, body.locals())?, DEFAULT_VALUE_STACK_LIMIT, DEFAULT_FRAME_STACK_LIMIT, result_ty, @@ -375,9 +375,8 @@ impl Validator { )?; } - let locals_count = context.locals.count()?; let DropKeep { drop, keep } = drop_keep_return( - locals_count, + &context.locals, &context.value_stack, &context.frame_stack, ); @@ -442,9 +441,8 @@ impl Validator { tee_value(&mut context.value_stack, &context.frame_stack, value_type.into())?; } - let locals_count = context.locals.count()?; let DropKeep { drop, keep } = drop_keep_return( - locals_count, + &context.locals, &context.value_stack, &context.frame_stack ); @@ -1589,7 +1587,7 @@ fn require_target( } fn drop_keep_return( - locals_count: u32, + locals: &Locals, value_stack: &StackWithLimit, frame_stack: &StackWithLimit, ) -> DropKeep { @@ -1602,7 +1600,7 @@ fn drop_keep_return( let mut drop_keep = require_target(deepest, value_stack, frame_stack).drop_keep; // Drop all local variables and parameters upon exit. - drop_keep.drop += locals_count; + drop_keep.drop += locals.count(); drop_keep } @@ -1618,7 +1616,7 @@ fn relative_local_depth( ) -> Result { // TODO: Comment stack layout let value_stack_height = value_stack.len() as u32; - let locals_and_params_count = locals.count()?; + let locals_and_params_count = locals.count(); let depth = value_stack_height .checked_add(locals_and_params_count) diff --git a/src/validation/util.rs b/src/validation/util.rs index 8501801..ed2cb96 100644 --- a/src/validation/util.rs +++ b/src/validation/util.rs @@ -10,14 +10,26 @@ use validation::Error; pub struct Locals<'a> { params: &'a [ValueType], local_groups: &'a [Local], + count: u32, } impl<'a> Locals<'a> { - pub fn new(params: &'a [ValueType], local_groups: &'a [Local]) -> Locals<'a> { - Locals { + /// Create a new wrapper around declared variables and parameters. + pub fn new(params: &'a [ValueType], local_groups: &'a [Local]) -> Result, Error> { + let mut acc = params.len() as u32; + for locals_group in local_groups { + acc = acc + .checked_add(locals_group.count()) + .ok_or_else(|| + Error(String::from("Locals range no in 32-bit range")) + )?; + } + + Ok(Locals { params, local_groups, - } + count: acc, + }) } /// Returns parameter count. @@ -26,18 +38,8 @@ impl<'a> Locals<'a> { } /// Returns total count of all declared locals and paramaterers. - /// - /// Returns `Err` if count overflows 32-bit value. - pub fn count(&self) -> Result { - let mut acc = self.param_count(); - for locals_group in self.local_groups { - acc = acc - .checked_add(locals_group.count()) - .ok_or_else(|| - Error(String::from("Locals range no in 32-bit range")) - )?; - } - Ok(acc) + pub fn count(&self) -> u32 { + self.count } /// Returns the type of a local variable (either a declared local or a param). @@ -82,7 +84,7 @@ mod tests { fn locals_it_works() { let params = vec![ValueType::I32, ValueType::I64]; let local_groups = vec![Local::new(2, ValueType::F32), Local::new(2, ValueType::F64)]; - let locals = Locals::new(¶ms, &local_groups); + let locals = Locals::new(¶ms, &local_groups).unwrap(); assert_matches!(locals.type_of_local(0), Ok(ValueType::I32)); assert_matches!(locals.type_of_local(1), Ok(ValueType::I64)); @@ -96,7 +98,7 @@ mod tests { #[test] fn locals_no_declared_locals() { let params = vec![ValueType::I32]; - let locals = Locals::new(¶ms, &[]); + let locals = Locals::new(¶ms, &[]).unwrap(); assert_matches!(locals.type_of_local(0), Ok(ValueType::I32)); assert_matches!(locals.type_of_local(1), Err(_)); @@ -105,7 +107,7 @@ mod tests { #[test] fn locals_no_params() { let local_groups = vec![Local::new(2, ValueType::I32), Local::new(3, ValueType::I64)]; - let locals = Locals::new(&[], &local_groups); + let locals = Locals::new(&[], &local_groups).unwrap(); assert_matches!(locals.type_of_local(0), Ok(ValueType::I32)); assert_matches!(locals.type_of_local(1), Ok(ValueType::I32)); @@ -121,7 +123,7 @@ mod tests { Local::new(u32::max_value(), ValueType::I32), Local::new(1, ValueType::I64), ]; - let locals = Locals::new(&[], &local_groups); + let locals = Locals::new(&[], &local_groups).unwrap(); assert_matches!( locals.type_of_local(u32::max_value() - 1),