// Copyright 2017 The go-interpreter Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package exec provides functions for executing WebAssembly bytecode. package exec import ( "encoding/binary" "errors" "fmt" "math" "github.com/go-interpreter/wagon/disasm" "github.com/go-interpreter/wagon/exec/internal/compile" "github.com/go-interpreter/wagon/wasm" ops "github.com/go-interpreter/wagon/wasm/operators" ) var ( // ErrMultipleLinearMemories is returned by (*VM).NewVM when the module // has more then one entries in the linear memory space. ErrMultipleLinearMemories = errors.New("exec: more than one linear memories in module") // ErrInvalidArgumentCount is returned by (*VM).ExecCode when an invalid // number of arguments to the WebAssembly function are passed to it. ErrInvalidArgumentCount = errors.New("exec: invalid number of arguments to function") ) // InvalidReturnTypeError is returned by (*VM).ExecCode when the module // specifies an invalid return type value for the executed function. type InvalidReturnTypeError int8 func (e InvalidReturnTypeError) Error() string { return fmt.Sprintf("Function has invalid return value_type: %d", int8(e)) } // InvalidFunctionIndexError is returned by (*VM).ExecCode when the function // index provided is invalid. type InvalidFunctionIndexError int64 func (e InvalidFunctionIndexError) Error() string { return fmt.Sprintf("Invalid index to function index space: %d", int64(e)) } type context struct { stack []uint64 locals []uint64 code []byte pc int64 curFunc int64 } // VM is the execution context for executing WebAssembly bytecode. type VM struct { ctx context module *wasm.Module globals []uint64 memory []byte funcs []function funcTable [256]func() // RecoverPanic controls whether the `ExecCode` method // recovers from a panic and returns it as an error // instead. // A panic can occur either when executing an invalid VM // or encountering an invalid instruction, e.g. `unreachable`. RecoverPanic bool } // As per the WebAssembly spec: https://github.com/WebAssembly/design/blob/27ac254c854994103c24834a994be16f74f54186/Semantics.md#linear-memory const wasmPageSize = 65536 // (64 KB) var endianess = binary.LittleEndian // NewVM creates a new VM from a given module. If the module defines a // start function, it will be executed. func NewVM(module *wasm.Module) (*VM, error) { var vm VM if module.Memory != nil && len(module.Memory.Entries) != 0 { if len(module.Memory.Entries) > 1 { return nil, ErrMultipleLinearMemories } vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize) copy(vm.memory, module.LinearMemoryIndexSpace[0]) } vm.funcs = make([]function, len(module.FunctionIndexSpace)) vm.globals = make([]uint64, len(module.GlobalIndexSpace)) vm.newFuncTable() vm.module = module nNatives := 0 for i, fn := range module.FunctionIndexSpace { // Skip native methods as they need not be // disassembled; simply add them at the end // of the `funcs` array as is, as specified // in the spec. See the "host functions" // section of: // https://webassembly.github.io/spec/core/exec/modules.html#allocation if fn.IsHost() { vm.funcs[i] = goFunction{ typ: fn.Host.Type(), val: fn.Host, } nNatives++ continue } disassembly, err := disasm.Disassemble(fn, module) if err != nil { return nil, err } totalLocalVars := 0 totalLocalVars += len(fn.Sig.ParamTypes) for _, entry := range fn.Body.Locals { totalLocalVars += int(entry.Count) } code, table := compile.Compile(disassembly.Code) vm.funcs[i] = compiledFunction{ code: code, branchTables: table, maxDepth: disassembly.MaxDepth, totalLocalVars: totalLocalVars, args: len(fn.Sig.ParamTypes), returns: len(fn.Sig.ReturnTypes) != 0, } } for i, global := range module.GlobalIndexSpace { val, err := module.ExecInitExpr(global.Init) if err != nil { return nil, err } switch v := val.(type) { case int32: vm.globals[i] = uint64(v) case int64: vm.globals[i] = uint64(v) case float32: vm.globals[i] = uint64(math.Float32bits(v)) case float64: vm.globals[i] = uint64(math.Float64bits(v)) } } if module.Start != nil { _, err := vm.ExecCode(int64(module.Start.Index)) if err != nil { return nil, err } } return &vm, nil } // Memory returns the linear memory space for the VM. func (vm *VM) Memory() []byte { return vm.memory } func (vm *VM) pushBool(v bool) { if v { vm.pushUint64(1) } else { vm.pushUint64(0) } } func (vm *VM) fetchBool() bool { return vm.fetchInt8() != 0 } func (vm *VM) fetchInt8() int8 { i := int8(vm.ctx.code[vm.ctx.pc]) vm.ctx.pc++ return i } func (vm *VM) fetchUint32() uint32 { v := endianess.Uint32(vm.ctx.code[vm.ctx.pc:]) vm.ctx.pc += 4 return v } func (vm *VM) fetchInt32() int32 { return int32(vm.fetchUint32()) } func (vm *VM) fetchFloat32() float32 { return math.Float32frombits(vm.fetchUint32()) } func (vm *VM) fetchUint64() uint64 { v := endianess.Uint64(vm.ctx.code[vm.ctx.pc:]) vm.ctx.pc += 8 return v } func (vm *VM) fetchInt64() int64 { return int64(vm.fetchUint64()) } func (vm *VM) fetchFloat64() float64 { return math.Float64frombits(vm.fetchUint64()) } func (vm *VM) popUint64() uint64 { i := vm.ctx.stack[len(vm.ctx.stack)-1] vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-1] return i } func (vm *VM) popInt64() int64 { return int64(vm.popUint64()) } func (vm *VM) popFloat64() float64 { return math.Float64frombits(vm.popUint64()) } func (vm *VM) popUint32() uint32 { return uint32(vm.popUint64()) } func (vm *VM) popInt32() int32 { return int32(vm.popUint32()) } func (vm *VM) popFloat32() float32 { return math.Float32frombits(vm.popUint32()) } func (vm *VM) pushUint64(i uint64) { vm.ctx.stack = append(vm.ctx.stack, i) } func (vm *VM) pushInt64(i int64) { vm.pushUint64(uint64(i)) } func (vm *VM) pushFloat64(f float64) { vm.pushUint64(math.Float64bits(f)) } func (vm *VM) pushUint32(i uint32) { vm.pushUint64(uint64(i)) } func (vm *VM) pushInt32(i int32) { vm.pushUint64(uint64(i)) } func (vm *VM) pushFloat32(f float32) { vm.pushUint32(math.Float32bits(f)) } // ExecCode calls the function with the given index and arguments. // fnIndex should be a valid index into the function index space of // the VM's module. func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) { // If used as a library, client code should set vm.RecoverPanic to true // in order to have an error returned. if vm.RecoverPanic { defer func() { if r := recover(); r != nil { switch e := r.(type) { case error: err = e default: err = fmt.Errorf("exec: %v", e) } } }() } if int(fnIndex) > len(vm.funcs) { return nil, InvalidFunctionIndexError(fnIndex) } if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) { return nil, ErrInvalidArgumentCount } compiled, ok := vm.funcs[fnIndex].(compiledFunction) if !ok { panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex)) } if len(vm.ctx.stack) < compiled.maxDepth { vm.ctx.stack = make([]uint64, 0, compiled.maxDepth) } vm.ctx.locals = make([]uint64, compiled.totalLocalVars) vm.ctx.pc = 0 vm.ctx.code = compiled.code vm.ctx.curFunc = fnIndex for i, arg := range args { vm.ctx.locals[i] = arg } res := vm.execCode(compiled) if compiled.returns { rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0] switch rtrnType { case wasm.ValueTypeI32: rtrn = uint32(res) case wasm.ValueTypeI64: rtrn = uint64(res) case wasm.ValueTypeF32: rtrn = math.Float32frombits(uint32(res)) case wasm.ValueTypeF64: rtrn = math.Float64frombits(res) default: return nil, InvalidReturnTypeError(rtrnType) } } return rtrn, nil } func (vm *VM) execCode(compiled compiledFunction) uint64 { outer: for int(vm.ctx.pc) < len(vm.ctx.code) { op := vm.ctx.code[vm.ctx.pc] vm.ctx.pc++ switch op { case ops.Return: break outer case compile.OpJmp: vm.ctx.pc = vm.fetchInt64() continue case compile.OpJmpZ: target := vm.fetchInt64() if vm.popUint32() == 0 { vm.ctx.pc = target continue } case compile.OpJmpNz: target := vm.fetchInt64() preserveTop := vm.fetchBool() discard := vm.fetchInt64() if vm.popUint32() != 0 { vm.ctx.pc = target var top uint64 if preserveTop { top = vm.ctx.stack[len(vm.ctx.stack)-1] } vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(discard)] if preserveTop { vm.pushUint64(top) } continue } case ops.BrTable: index := vm.fetchInt64() label := vm.popInt32() cf, ok := vm.funcs[vm.ctx.curFunc].(compiledFunction) if !ok { panic(fmt.Sprintf("exec: function at index %d is not a compiled function", vm.ctx.curFunc)) } table := cf.branchTables[index] var target compile.Target if label >= 0 && label < int32(len(table.Targets)) { target = table.Targets[int32(label)] } else { target = table.DefaultTarget } if target.Return { break outer } vm.ctx.pc = target.Addr var top uint64 if target.PreserveTop { top = vm.ctx.stack[len(vm.ctx.stack)-1] } vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(target.Discard)] if target.PreserveTop { vm.pushUint64(top) } continue case compile.OpDiscard: place := vm.fetchInt64() vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(place)] case compile.OpDiscardPreserveTop: top := vm.ctx.stack[len(vm.ctx.stack)-1] place := vm.fetchInt64() vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(place)] vm.pushUint64(top) default: vm.funcTable[op]() } } if compiled.returns { return vm.ctx.stack[len(vm.ctx.stack)-1] } return 0 }