From a9bf01a60f8de7bbe5c71a7d057a05f929d786a4 Mon Sep 17 00:00:00 2001 From: Sergey Pepyakin Date: Wed, 13 Jun 2018 16:15:45 +0300 Subject: [PATCH] Working --- src/common/mod.rs | 37 ---- src/lib.rs | 1 - src/runner.rs | 2 - src/validation/func.rs | 300 ++++++++++++++++++++-------- src/validation/mod.rs | 2 +- src/validation/tests.rs | 426 +++++++++++++++++++++++++++++++++++++++- src/validation/util.rs | 9 +- 7 files changed, 641 insertions(+), 136 deletions(-) diff --git a/src/common/mod.rs b/src/common/mod.rs index 4b90dc9..49ff10c 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,4 +1,3 @@ -use parity_wasm::elements::BlockType; pub mod stack; @@ -8,39 +7,3 @@ pub const DEFAULT_MEMORY_INDEX: u32 = 0; pub const DEFAULT_TABLE_INDEX: u32 = 0; // TODO: Move BlockFrame under validation. - -/// Control stack frame. -#[derive(Debug, Clone)] -pub struct BlockFrame { - /// Frame type. - pub frame_type: BlockFrameType, - /// A signature, which is a block signature type indicating the number and types of result values of the region. - pub block_type: BlockType, - /// A label for reference to block instruction. - pub begin_position: usize, - /// A label for reference from branch instructions. - pub branch_position: usize, - /// A label for reference from end instructions. - pub end_position: usize, - /// A limit integer value, which is an index into the value stack indicating where to reset it to on a branch to that label. - pub value_stack_len: usize, - /// Boolean which signals whether value stack became polymorphic. Value stack starts in non-polymorphic state and - /// becomes polymorphic only after an instruction that never passes control further is executed, - /// i.e. `unreachable`, `br` (but not `br_if`!), etc. - pub polymorphic_stack: bool, -} - -/// Type of block frame. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum BlockFrameType { - /// Function frame. - Function, - /// Usual block frame. - Block, - /// Loop frame (branching to the beginning of block). - Loop, - /// True-subblock of if expression. - IfTrue, - /// False-subblock of if expression. - IfFalse, -} diff --git a/src/lib.rs b/src/lib.rs index c46eff5..c563c75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,7 +109,6 @@ extern crate nan_preserving_float; use std::fmt; use std::error; -use std::collections::HashMap; /// Error type which can thrown by wasm code or by host environment. /// diff --git a/src/runner.rs b/src/runner.rs index 4cbb17f..4877d33 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -21,8 +21,6 @@ use isa; /// Maximum number of entries in value stack. pub const DEFAULT_VALUE_STACK_LIMIT: usize = 16384; -/// Maximum number of entries in frame stack. -pub const DEFAULT_FRAME_STACK_LIMIT: usize = 16384; /// Function interpreter. pub struct Interpreter<'a, E: Externals + 'a> { diff --git a/src/validation/func.rs b/src/validation/func.rs index 76137de..e8dfecc 100644 --- a/src/validation/func.rs +++ b/src/validation/func.rs @@ -24,8 +24,6 @@ struct BlockFrame { block_type: BlockType, /// A label for reference to block instruction. begin_position: usize, - // TODO: - branch_label: LabelId, /// A limit integer value, which is an index into the value stack indicating where to reset it to on a branch to that label. value_stack_len: usize, /// Boolean which signals whether value stack became polymorphic. Value stack starts in non-polymorphic state and @@ -37,35 +35,63 @@ struct BlockFrame { /// Type of block frame. #[derive(Debug, Clone, Copy, PartialEq)] enum BlockFrameType { - /// Function frame. - Function, /// Usual block frame. - Block, + /// + /// Can be used for an implicit function block. + Block { + end_label: LabelId, + }, /// Loop frame (branching to the beginning of block). - Loop, + Loop { + header: LabelId, + }, /// True-subblock of if expression. - IfTrue, + IfTrue { + /// If jump happens inside the if-true block then control will + /// land on this label. + end_label: LabelId, + + /// If the condition of the `if` statement is unsatisfied, control + /// will land on this label. This label might point to `else` block if it + /// exists. Otherwise it equal to `end_label`. + if_not: LabelId, + }, /// False-subblock of if expression. - IfFalse, + IfFalse { + end_label: LabelId, + } } -/// Function validation context. -struct FunctionValidationContext<'a> { - /// Wasm module - module: &'a ModuleContext, - /// Current instruction position. - position: usize, - /// Local variables. - locals: Locals<'a>, - /// Value stack. - value_stack: StackWithLimit, - /// Frame stack. - frame_stack: StackWithLimit, - /// Function return type. None if validating expression. - return_type: Option, +impl BlockFrameType { + /// Returns a label which should be used as a branch destination. + fn br_destination(&self) -> LabelId { + match *self { + BlockFrameType::Block { end_label } => end_label, + BlockFrameType::Loop { header } => header, + BlockFrameType::IfTrue { end_label, .. } => end_label, + BlockFrameType::IfFalse { end_label } => end_label, + } + } - // TODO: comment - sink: Sink, + /// Returns a label which should be resolved at the `End` opcode. + /// + /// All block types have it except loops. Loops doesn't use end as a branch + /// destination. + fn end_label(&self) -> LabelId { + match *self { + BlockFrameType::Block { end_label } => end_label, + BlockFrameType::IfTrue { end_label, .. } => end_label, + BlockFrameType::IfFalse { end_label } => end_label, + BlockFrameType::Loop { .. } => panic!("loop doesn't use end label"), + } + } + + fn is_loop(&self) -> bool { + match *self { + BlockFrameType::Loop { .. } => true, + _ => false, + } + } } /// Value type on the stack. @@ -105,20 +131,19 @@ impl Validator { result_ty, ); - let func_label = context.sink.new_label(); - context.push_label(BlockFrameType::Function, result_ty, func_label)?; + let end_label = context.sink.new_label(); + context.push_label( + BlockFrameType::Block { + end_label, + }, + result_ty + )?; Validator::validate_function_block(&mut context, body.code().elements())?; while !context.frame_stack.is_empty() { - let branch_label = context.top_label()?.branch_label; - context.sink.resolve_label(branch_label); + context.pop_label()?; } - context.sink.emit(isa::Instruction::Return { - drop: 0, - keep: if result_ty == BlockType::NoResult { 0 } else { 1 }, - }); - Ok(context.into_code()) } @@ -145,6 +170,8 @@ impl Validator { fn validate_instruction(context: &mut FunctionValidationContext, opcode: &Opcode) -> Result { // TODO: use InstructionOutcome::*; + println!("opcode={:?}", opcode); + use self::Opcode::*; match *opcode { // Nop instruction doesn't do anything. It is safe to just skip it. @@ -155,15 +182,25 @@ impl Validator { }, Block(block_type) => { - let label = context.sink.new_label(); - context.push_label(BlockFrameType::Block, block_type, label)?; + let end_label = context.sink.new_label(); + context.push_label( + BlockFrameType::Block { + end_label + }, + block_type + )?; }, Loop(block_type) => { // Resolve loop header right away. - let loop_header = context.top_label()?.branch_label; - context.sink.resolve_label(loop_header); + let header = context.sink.new_label(); + context.sink.resolve_label(header); - context.push_label(BlockFrameType::Loop, block_type, loop_header)?; + context.push_label( + BlockFrameType::Loop { + header, + }, + block_type + )?; }, If(block_type) => { // if @@ -172,84 +209,143 @@ impl Validator { // // translates to -> // - // br_if_not $end + // br_if_not $if_not // .. - // $end + // $if_not: + + // if_not will be resolved whenever `end` or `else` operator will be met. + let if_not = context.sink.new_label(); let end_label = context.sink.new_label(); + context.pop_value(ValueType::I32.into())?; - context.push_label(BlockFrameType::IfTrue, block_type, end_label)?; + context.push_label( + BlockFrameType::IfTrue { + if_not, + end_label, + }, + block_type + )?; + + context.sink.emit_br_eqz(Target { + label: if_not, + drop: 0, + keep: 0, + }); }, Else => { + let (block_type, if_not, end_label) = { + let top_frame = context.top_label()?; + + let (if_not, end_label) = match top_frame.frame_type { + BlockFrameType::IfTrue { if_not, end_label } => (if_not, end_label), + _ => return Err(Error("Misplaced else instruction".into())), + }; + (top_frame.block_type, if_not, end_label) + }; + // First, we need to finish if-true block: add a jump from the end of the if-true block // to the "end_label" (it will be resolved at End). - let end_target = context.require_target(0)?; - let end_label = end_target.label; - context.sink.emit_br(end_target); + context.sink.emit_br(Target { + label: end_label, + drop: 0, + keep: 0, + }); + + // Resolve `if_not` to here so when if condition is unsatisfied control flow + // will jump to this label. + context.sink.resolve_label(if_not); // Then, we validate. Validator will pop the if..else block and the push else..end block. - let block_type = { - let top_frame = context.top_label()?; - if top_frame.frame_type != BlockFrameType::IfTrue { - return Err(Error("Misplaced else instruction".into())); - } - top_frame.block_type - }; context.pop_label()?; if let BlockType::Value(value_type) = block_type { context.pop_value(value_type.into())?; } - context.push_label(BlockFrameType::IfFalse, block_type, end_label)?; + context.push_label( + BlockFrameType::IfFalse { + end_label, + }, + block_type, + )?; }, End => { { let frame_type = context.top_label()?.frame_type; + if let BlockFrameType::IfTrue { if_not, .. } = frame_type { + if context.top_label()?.block_type != BlockType::NoResult { + return Err( + Error( + format!( + "If block without else required to have NoResult block type. But it have {:?} type", + context.top_label()?.block_type + ) + ) + ); + } - // If this end for a non-loop frame then we resolve it's label location to here. - if frame_type != BlockFrameType::Loop { - let loop_header = context.top_label()?.branch_label; - context.sink.resolve_label(loop_header); + context.sink.resolve_label(if_not); } } { - let top_frame = context.top_label()?; - if top_frame.frame_type == BlockFrameType::IfTrue { - if top_frame.block_type != BlockType::NoResult { - return Err(Error(format!("If block without else required to have NoResult block type. But it have {:?} type", top_frame.block_type))); - } + let frame_type = context.top_label()?.frame_type; + + // If this end for a non-loop frame then we resolve it's label location to here. + if !frame_type.is_loop() { + let end_label = frame_type.end_label(); + context.sink.resolve_label(end_label); } } + if context.frame_stack.len() == 1 { + // We are about to close the last frame. Insert + // an explicit return. + let (drop, keep) = context.drop_keep_return()?; + context.sink.emit(isa::Instruction::Return { + drop, + keep, + }); + } + context.pop_label()?; }, Br(depth) => { - Validator::validate_br(context, depth)?; let target = context.require_target(depth)?; context.sink.emit_br(target); + + Validator::validate_br(context, depth)?; + + return Ok(InstructionOutcome::Unreachable); }, BrIf(depth) => { - Validator::validate_br_if(context, depth)?; let target = context.require_target(depth)?; context.sink.emit_br_nez(target); + + Validator::validate_br_if(context, depth)?; }, BrTable(ref table, default) => { - Validator::validate_br_table(context, table, default)?; let mut targets = Vec::new(); for depth in table.iter() { let target = context.require_target(*depth)?; targets.push(target); } - let default = context.require_target(default)?; - context.sink.emit_br_table(&targets, default); + let default_target = context.require_target(default)?; + context.sink.emit_br_table(&targets, default_target); + + Validator::validate_br_table(context, table, default)?; + + return Ok(InstructionOutcome::Unreachable); }, Return => { - Validator::validate_return(context)?; let (drop, keep) = context.drop_keep_return()?; context.sink.emit(isa::Instruction::Return { drop, keep, }); + + Validator::validate_return(context)?; + + return Ok(InstructionOutcome::Unreachable); }, Call(index) => { @@ -280,10 +376,8 @@ impl Validator { ); }, SetLocal(index) => { - // We need to calculate relative depth before validation since - // it will change value stack size. - let depth = context.relative_local_depth(index)?; Validator::validate_set_local(context, index)?; + let depth = context.relative_local_depth(index)?; context.sink.emit( isa::Instruction::SetLocal(depth), ); @@ -1205,7 +1299,7 @@ impl Validator { let frame = context.require_label(idx)?; (frame.frame_type, frame.block_type) }; - if frame_type != BlockFrameType::Loop { + if !frame_type.is_loop() { if let BlockType::Value(value_type) = frame_block_type { context.tee_value(value_type.into())?; } @@ -1220,7 +1314,7 @@ impl Validator { let frame = context.require_label(idx)?; (frame.frame_type, frame.block_type) }; - if frame_type != BlockFrameType::Loop { + if !frame_type.is_loop() { if let BlockType::Value(value_type) = frame_block_type { context.tee_value(value_type.into())?; } @@ -1231,7 +1325,7 @@ impl Validator { fn validate_br_table(context: &mut FunctionValidationContext, table: &[u32], default: u32) -> Result { let required_block_type: BlockType = { let default_block = context.require_label(default)?; - let required_block_type = if default_block.frame_type != BlockFrameType::Loop { + let required_block_type = if !default_block.frame_type.is_loop() { default_block.block_type } else { BlockType::NoResult @@ -1239,7 +1333,7 @@ impl Validator { for label in table { let label_block = context.require_label(*label)?; - let label_block_type = if label_block.frame_type != BlockFrameType::Loop { + let label_block_type = if !label_block.frame_type.is_loop() { label_block.block_type } else { BlockType::NoResult @@ -1322,6 +1416,25 @@ impl Validator { } } +/// Function validation context. +struct FunctionValidationContext<'a> { + /// Wasm module + module: &'a ModuleContext, + /// Current instruction position. + position: usize, + /// Local variables. + locals: Locals<'a>, + /// Value stack. + value_stack: StackWithLimit, + /// Frame stack. + frame_stack: StackWithLimit, + /// Function return type. + return_type: BlockType, + + // TODO: comment + sink: Sink, +} + impl<'a> FunctionValidationContext<'a> { fn new( module: &'a ModuleContext, @@ -1336,7 +1449,7 @@ impl<'a> FunctionValidationContext<'a> { locals: locals, value_stack: StackWithLimit::with_limit(value_stack_limit), frame_stack: StackWithLimit::with_limit(frame_stack_limit), - return_type: Some(return_type), + return_type: return_type, sink: Sink::new(), } } @@ -1395,12 +1508,11 @@ impl<'a> FunctionValidationContext<'a> { Ok(self.frame_stack.top()?) } - fn push_label(&mut self, frame_type: BlockFrameType, block_type: BlockType, branch_label: LabelId) -> Result<(), Error> { + fn push_label(&mut self, frame_type: BlockFrameType, block_type: BlockType) -> Result<(), Error> { Ok(self.frame_stack.push(BlockFrame { frame_type: frame_type, block_type: block_type, begin_position: self.position, - branch_label, value_stack_len: self.value_stack.len(), polymorphic_stack: false, })?) @@ -1420,6 +1532,11 @@ impl<'a> FunctionValidationContext<'a> { let frame = self.frame_stack.pop()?; if self.value_stack.len() != frame.value_stack_len { + if true { + panic!("Unexpected stack height {}, expected {}", + self.value_stack.len(), + frame.value_stack_len) + } return Err(Error(format!( "Unexpected stack height {}, expected {}", self.value_stack.len(), @@ -1439,7 +1556,7 @@ impl<'a> FunctionValidationContext<'a> { } fn return_type(&self) -> Result { - self.return_type.ok_or(Error("Trying to return from expression".into())) + Ok(self.return_type) } fn require_local(&self, idx: u32) -> Result { @@ -1448,19 +1565,24 @@ impl<'a> FunctionValidationContext<'a> { fn require_target(&self, depth: u32) -> Result { let frame = self.require_label(depth)?; - let label = frame.branch_label; let keep: u8 = match (frame.frame_type, frame.block_type) { - (BlockFrameType::Loop, _) => 0, + (BlockFrameType::Loop { .. }, _) => 0, (_, BlockType::NoResult) => 0, (_, BlockType::Value(_)) => 1, }; let value_stack_height = self.value_stack.len() as u32; - let drop = value_stack_height - frame.value_stack_len as u32 - keep as u32; + let drop = if frame.polymorphic_stack { 0 } else { + // TODO + println!("value_stack_height = {}", value_stack_height); + println!("frame.value_stack_len = {}", frame.value_stack_len); + println!("keep = {}", keep); + (value_stack_height - frame.value_stack_len as u32) - keep as u32 + }; Ok(Target { - label, + label: frame.frame_type.br_destination(), keep, drop, }) @@ -1468,8 +1590,11 @@ impl<'a> FunctionValidationContext<'a> { fn drop_keep_return(&self) -> Result<(u32, u8), Error> { // TODO: Refactor - let deepest = self.frame_stack.len(); - let target = self.require_target(deepest as u32)?; + let deepest = (self.frame_stack.len() - 1) as u32; + let mut target = self.require_target(deepest)?; + + // Drop all local variables and parameters upon exit. + target.drop += self.locals.count()?; Ok((target.drop, target.keep)) } @@ -1684,8 +1809,10 @@ impl Sink { if let Label::Resolved(_) = self.labels[label.0] { panic!("Trying to resolve already resolved label"); } - let dst_pc = self.cur_pc(); + + // Patch all relocations that was previously recorded for this + // particular label. let unresolved_rels = self.unresolved.remove(&label).unwrap_or(Vec::new()); for reloc in unresolved_rels { match reloc { @@ -1701,6 +1828,9 @@ impl Sink { } } } + + // Mark this label as resolved. + self.labels[label.0] = Label::Resolved(dst_pc); } fn into_inner(self) -> Vec { diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 657b5fb..823cc09 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -1,6 +1,6 @@ use std::error; use std::fmt; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use parity_wasm::elements::{ BlockType, External, GlobalEntry, GlobalType, Internal, MemoryType, Module, Opcode, ResizableLimits, TableType, ValueType, InitExpr, Type diff --git a/src/validation/tests.rs b/src/validation/tests.rs index 834ff5c..83439b7 100644 --- a/src/validation/tests.rs +++ b/src/validation/tests.rs @@ -1,8 +1,8 @@ -use super::validate_module; +use super::{validate_module, ValidatedModule}; use parity_wasm::builder::module; use parity_wasm::elements::{ - External, GlobalEntry, GlobalType, ImportEntry, InitExpr, MemoryType, - Opcode, Opcodes, TableType, ValueType, BlockType, deserialize_buffer, + External, GlobalEntry, GlobalType, ImportEntry, InitExpr, MemoryType, + Opcode, Opcodes, TableType, ValueType, BlockType, deserialize_buffer, Module, }; use isa; @@ -303,16 +303,21 @@ fn if_else_with_return_type_validation() { validate_module(m).unwrap(); } -fn compile(wat: &str) -> Vec { +fn validate(wat: &str) -> ValidatedModule { let wasm = wabt::wat2wasm(wat).unwrap(); let module = deserialize_buffer::(&wasm).unwrap(); let validated_module = validate_module(module).unwrap(); + validated_module +} + +fn compile(wat: &str) -> Vec { + let validated_module = validate(wat); let code = &validated_module.code_map[0]; code.code.clone() } #[test] -fn explicit_return_no_value() { +fn implicit_return_no_value() { let code = compile(r#" (module (func (export "call") @@ -331,7 +336,7 @@ fn explicit_return_no_value() { } #[test] -fn explicit_return_with_value() { +fn implicit_return_with_value() { let code = compile(r#" (module (func (export "call") (result i32) @@ -352,8 +357,8 @@ fn explicit_return_with_value() { } #[test] -fn explicit_return_param() { - let code = compile(r#" +fn implicit_return_param() { + let code = compile(r#" (module (func (export "call") (param i32) ) @@ -369,3 +374,408 @@ fn explicit_return_param() { ] ) } + +#[test] +fn get_local() { + let code = compile(r#" + (module + (func (export "call") (param i32) (result i32) + get_local 0 + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::GetLocal(1), + isa::Instruction::Return { + drop: 1, + keep: 1, + } + ] + ) +} + +#[test] +fn explicit_return() { + let code = compile(r#" + (module + (func (export "call") (param i32) (result i32) + get_local 0 + return + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::GetLocal(1), + isa::Instruction::Return { + drop: 1, + keep: 1, + }, + isa::Instruction::Return { + drop: 1, + keep: 1, + } + ] + ) +} + +#[test] +fn add_params() { + let code = compile(r#" + (module + (func (export "call") (param i32) (param i32) (result i32) + get_local 0 + get_local 1 + i32.add + ) + ) + "#); + assert_eq!( + code, + vec![ + // This is tricky. Locals are now loaded from the stack. The load + // happens from address relative of the current stack pointer. The first load + // takes the value below the previous one (i.e the second argument) and then, it increments + // the stack pointer. And then the same thing hapens with the value below the previous one + // (which happens to be the value loaded by the first get_local). + isa::Instruction::GetLocal(2), + isa::Instruction::GetLocal(2), + isa::Instruction::I32Add, + isa::Instruction::Return { + drop: 2, + keep: 1, + } + ] + ) +} + +#[test] +fn drop_locals() { + let code = compile(r#" + (module + (func (export "call") (param i32) + (local i32) + get_local 0 + set_local 1 + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::GetLocal(2), + isa::Instruction::SetLocal(1), + isa::Instruction::Return { + drop: 2, + keep: 0, + } + ] + ) +} + +#[test] +fn if_without_else() { + let code = compile(r#" + (module + (func (export "call") (param i32) (result i32) + i32.const 1 + if + i32.const 2 + return + end + i32.const 3 + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::I32Const(1), + isa::Instruction::BrIfEqz(isa::Target { + dst_pc: 4, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(2), + isa::Instruction::Return { + drop: 1, // 1 param + keep: 1, // 1 result + }, + isa::Instruction::I32Const(3), + isa::Instruction::Return { + drop: 1, + keep: 1, + }, + ] + ) +} + +#[test] +fn if_else() { + let code = compile(r#" + (module + (func (export "call") + (local i32) + i32.const 1 + if + i32.const 2 + set_local 0 + else + i32.const 3 + set_local 0 + end + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::I32Const(1), + isa::Instruction::BrIfEqz(isa::Target { + dst_pc: 5, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(2), + isa::Instruction::SetLocal(1), + isa::Instruction::Br(isa::Target { + dst_pc: 7, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(3), + isa::Instruction::SetLocal(1), + isa::Instruction::Return { + drop: 1, + keep: 0, + }, + ] + ) +} + +#[test] +fn if_else_returns_result() { + let code = compile(r#" + (module + (func (export "call") + i32.const 1 + if (result i32) + i32.const 2 + else + i32.const 3 + end + drop + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::I32Const(1), + isa::Instruction::BrIfEqz(isa::Target { + dst_pc: 4, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(2), + isa::Instruction::Br(isa::Target { + dst_pc: 5, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(3), + isa::Instruction::Drop, + isa::Instruction::Return { + drop: 0, + keep: 0, + }, + ] + ) +} + +#[test] +fn if_else_branch_from_true_branch() { + let code = compile(r#" + (module + (func (export "call") + i32.const 1 + if (result i32) + i32.const 1 + i32.const 1 + br_if 0 + drop + i32.const 2 + else + i32.const 3 + end + drop + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::I32Const(1), + isa::Instruction::BrIfEqz(isa::Target { + dst_pc: 8, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(1), + isa::Instruction::I32Const(1), + isa::Instruction::BrIfNez(isa::Target { + dst_pc: 9, + drop: 1, // TODO: Is this correct? + keep: 1, // TODO: Is this correct? + }), + isa::Instruction::Drop, + isa::Instruction::I32Const(2), + isa::Instruction::Br(isa::Target { + dst_pc: 9, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(3), + isa::Instruction::Drop, + isa::Instruction::Return { + drop: 0, + keep: 0, + }, + ] + ) +} + +#[test] +fn if_else_branch_from_false_branch() { + let code = compile(r#" + (module + (func (export "call") + i32.const 1 + if (result i32) + i32.const 1 + else + i32.const 2 + i32.const 1 + br_if 0 + drop + i32.const 3 + end + drop + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::I32Const(1), + isa::Instruction::BrIfEqz(isa::Target { + dst_pc: 4, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(1), + isa::Instruction::Br(isa::Target { + dst_pc: 9, + drop: 0, + keep: 0, + }), + isa::Instruction::I32Const(2), + isa::Instruction::I32Const(1), + isa::Instruction::BrIfNez(isa::Target { + dst_pc: 9, + drop: 1, // TODO: Is this correct? + keep: 1, + }), + isa::Instruction::Drop, + isa::Instruction::I32Const(3), + isa::Instruction::Drop, + isa::Instruction::Return { + drop: 0, + keep: 0, + }, + ] + ) +} + +#[test] +fn empty_loop() { + let code = compile(r#" + (module + (func (export "call") + loop (result i32) + i32.const 1 + br_if 0 + i32.const 2 + end + drop + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::I32Const(1), + isa::Instruction::BrIfNez(isa::Target { + dst_pc: 0, + drop: 1, + keep: 0, + }), + isa::Instruction::I32Const(2), + isa::Instruction::Drop, + isa::Instruction::Return { + drop: 0, + keep: 0, + }, + ] + ) +} + +// TODO: Loop +// TODO: Empty loop? +// TODO: brtable + +#[test] +fn wabt_example() { + let code = compile(r#" + (module + (func (export "call") (param i32) (result i32) + block $exit + get_local 0 + br_if $exit + i32.const 1 + return + end + i32.const 2 + return + ) + ) + "#); + assert_eq!( + code, + vec![ + isa::Instruction::GetLocal(1), + isa::Instruction::BrIfNez(isa::Target { + dst_pc: 4, + keep: 0, + drop: 1, + }), + isa::Instruction::I32Const(1), + isa::Instruction::Return { + drop: 1, // 1 parameter + keep: 1, // return value + }, + isa::Instruction::I32Const(2), + isa::Instruction::Return { + drop: 1, + keep: 1, + }, + isa::Instruction::Return { + drop: 1, + keep: 1, + }, + ] + ) +} diff --git a/src/validation/util.rs b/src/validation/util.rs index 4fb3e1b..8501801 100644 --- a/src/validation/util.rs +++ b/src/validation/util.rs @@ -20,11 +20,16 @@ impl<'a> Locals<'a> { } } + /// Returns parameter count. + pub fn param_count(&self) -> u32 { + self.params.len() as u32 + } + /// Returns total count of all declared locals and paramaterers. /// /// Returns `Err` if count overflows 32-bit value. pub fn count(&self) -> Result { - let mut acc = self.params.len() as u32; + let mut acc = self.param_count(); for locals_group in self.local_groups { acc = acc .checked_add(locals_group.count()) @@ -44,7 +49,7 @@ impl<'a> Locals<'a> { } // If an index doesn't point to a param, then we have to look into local declarations. - let mut start_idx = self.params.len() as u32; + let mut start_idx = self.param_count(); for locals_group in self.local_groups { let end_idx = start_idx .checked_add(locals_group.count())