land/vendor/github.com/go-interpreter/wagon/exec/vm.go

400 lines
9.7 KiB
Go

// Copyright 2017 The go-interpreter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package exec provides functions for executing WebAssembly bytecode.
package exec
import (
"encoding/binary"
"errors"
"fmt"
"math"
"github.com/go-interpreter/wagon/disasm"
"github.com/go-interpreter/wagon/exec/internal/compile"
"github.com/go-interpreter/wagon/wasm"
ops "github.com/go-interpreter/wagon/wasm/operators"
)
var (
	// ErrMultipleLinearMemories is returned by NewVM when the module
	// has more than one entry in the linear memory space.
	ErrMultipleLinearMemories = errors.New("exec: more than one linear memories in module")
	// ErrInvalidArgumentCount is returned by (*VM).ExecCode when the number
	// of arguments passed to it does not match the number of parameters
	// declared by the WebAssembly function's signature.
	ErrInvalidArgumentCount = errors.New("exec: invalid number of arguments to function")
)
// InvalidReturnTypeError is the error (*VM).ExecCode reports when the
// executed function's signature declares a return value type the VM does
// not know how to convert. The underlying int8 is the offending value
// type as stored in the module.
type InvalidReturnTypeError int8

// Error implements the error interface; the message includes the raw
// value type in decimal.
func (e InvalidReturnTypeError) Error() string {
	valueType := int8(e)
	return "Function has invalid return value_type: " + fmt.Sprint(valueType)
}
// InvalidFunctionIndexError is the error (*VM).ExecCode reports when the
// requested function index does not refer to a function of the VM. The
// underlying int64 is the index that was rejected.
type InvalidFunctionIndexError int64

// Error implements the error interface; the message includes the rejected
// index in decimal.
func (e InvalidFunctionIndexError) Error() string {
	index := int64(e)
	return "Invalid index to function index space: " + fmt.Sprint(index)
}
// context holds the state of the function invocation the VM is currently
// executing.
type context struct {
	// stack is the operand stack; the push*/pop* helpers append to and
	// remove from its end.
	stack []uint64
	// locals holds the local variable slots of the running function.
	// ExecCode sizes it to the function's total local count and stores
	// the call arguments in the leading slots.
	locals []uint64
	// code is the bytecode being executed.
	code []byte
	// pc is the index into code of the next byte to be read.
	pc int64
	// curFunc is the index into VM.funcs of the running function.
	curFunc int64
}
// VM is the execution context for executing WebAssembly bytecode.
type VM struct {
	// ctx is the state of the function invocation currently running.
	ctx context

	// module is the module this VM was created from.
	module *wasm.Module
	// globals holds one 64-bit slot per entry of the module's global
	// index space, filled in by NewVM from each global's init expression.
	globals []uint64
	// memory is the VM's linear memory; NewVM sizes it to the initial
	// page count of the module's single memory entry, if there is one.
	memory []byte
	// funcs holds one entry per function in the module's function index
	// space: either a compiled function or a host function wrapper.
	funcs []function
	// funcTable is indexed by opcode byte; execCode dispatches every
	// opcode it does not handle itself through this table.
	// NOTE(review): populated by newFuncTable, which is outside this view.
	funcTable [256]func()

	// RecoverPanic controls whether the `ExecCode` method
	// recovers from a panic and returns it as an error
	// instead.
	// A panic can occur either when executing an invalid VM
	// or encountering an invalid instruction, e.g. `unreachable`.
	RecoverPanic bool
}
// wasmPageSize is the size in bytes of one page of linear memory.
// As per the WebAssembly spec: https://github.com/WebAssembly/design/blob/27ac254c854994103c24834a994be16f74f54186/Semantics.md#linear-memory
const wasmPageSize = 65536 // (64 KiB)

// endianess is the byte order used to decode multi-byte immediates from
// the code stream (see fetchUint32 and fetchUint64).
var endianess = binary.LittleEndian
// NewVM creates a new VM from a given module. If the module defines a
// start function, it will be executed.
//
// Setup proceeds in this order:
//   - linear memory is allocated from the initial page count of the
//     module's single memory entry and seeded from the first entry of the
//     module's linear memory index space, when one is present;
//   - every function in the function index space is either wrapped as a
//     host function or disassembled and compiled;
//   - every global is initialised by evaluating its init expression;
//   - the start function, if any, is run through ExecCode.
//
// It returns ErrMultipleLinearMemories when the module declares more than
// one memory, and otherwise forwards the first error reported by
// disassembly, init expression evaluation or the start function.
func NewVM(module *wasm.Module) (*VM, error) {
	var vm VM

	if module.Memory != nil && len(module.Memory.Entries) != 0 {
		if len(module.Memory.Entries) > 1 {
			return nil, ErrMultipleLinearMemories
		}
		vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize)
		// A module may declare a memory without carrying any initial
		// contents for it; only index the linear memory index space when
		// it actually has an entry, and leave the memory zeroed otherwise.
		if len(module.LinearMemoryIndexSpace) > 0 {
			copy(vm.memory, module.LinearMemoryIndexSpace[0])
		}
	}

	vm.funcs = make([]function, len(module.FunctionIndexSpace))
	vm.globals = make([]uint64, len(module.GlobalIndexSpace))
	vm.newFuncTable()
	vm.module = module

	for i, fn := range module.FunctionIndexSpace {
		// Host (native) functions need not be disassembled; they are
		// stored as-is at their own position in `funcs`. See the
		// "host functions" section of:
		// https://webassembly.github.io/spec/core/exec/modules.html#allocation
		if fn.IsHost() {
			vm.funcs[i] = goFunction{
				typ: fn.Host.Type(),
				val: fn.Host,
			}
			continue
		}

		disassembly, err := disasm.Disassemble(fn, module)
		if err != nil {
			return nil, err
		}

		// Parameters occupy the leading local slots, followed by the
		// locals declared in the function body.
		totalLocalVars := len(fn.Sig.ParamTypes)
		for _, entry := range fn.Body.Locals {
			totalLocalVars += int(entry.Count)
		}

		code, table := compile.Compile(disassembly.Code)
		vm.funcs[i] = compiledFunction{
			code:           code,
			branchTables:   table,
			maxDepth:       disassembly.MaxDepth,
			totalLocalVars: totalLocalVars,
			args:           len(fn.Sig.ParamTypes),
			returns:        len(fn.Sig.ReturnTypes) != 0,
		}
	}

	for i, global := range module.GlobalIndexSpace {
		val, err := module.ExecInitExpr(global.Init)
		if err != nil {
			return nil, err
		}
		// Globals are stored as raw 64-bit words. A value of any other
		// dynamic type leaves the slot at its zero value.
		switch v := val.(type) {
		case int32:
			vm.globals[i] = uint64(v)
		case int64:
			vm.globals[i] = uint64(v)
		case float32:
			vm.globals[i] = uint64(math.Float32bits(v))
		case float64:
			vm.globals[i] = uint64(math.Float64bits(v))
		}
	}

	if module.Start != nil {
		if _, err := vm.ExecCode(int64(module.Start.Index)); err != nil {
			return nil, err
		}
	}

	return &vm, nil
}
// Memory exposes the VM's linear memory. The slice returned is the VM's
// own buffer rather than a copy.
func (vm *VM) Memory() []byte {
	linear := vm.memory
	return linear
}
// pushBool pushes v onto the operand stack, encoded as 1 for true and 0
// for false.
func (vm *VM) pushBool(v bool) {
	var word uint64
	if v {
		word = 1
	}
	vm.pushUint64(word)
}
// fetchBool reads a one-byte immediate from the code stream and reports
// whether it is non-zero.
func (vm *VM) fetchBool() bool {
	flag := vm.fetchInt8()
	return flag != 0
}
// fetchInt8 reads the byte at the program counter as a signed 8-bit
// value and advances the program counter by one.
func (vm *VM) fetchInt8() int8 {
	raw := vm.ctx.code[vm.ctx.pc]
	vm.ctx.pc += 1
	return int8(raw)
}
// fetchUint32 decodes a 4-byte immediate at the program counter using the
// package byte order and advances the program counter past it.
func (vm *VM) fetchUint32() uint32 {
	operand := vm.ctx.code[vm.ctx.pc:]
	word := endianess.Uint32(operand)
	vm.ctx.pc += 4
	return word
}
// fetchInt32 reads a 4-byte immediate from the code stream and
// reinterprets it as a signed 32-bit integer.
func (vm *VM) fetchInt32() int32 {
	word := vm.fetchUint32()
	return int32(word)
}
// fetchFloat32 reads a 4-byte immediate from the code stream and
// interprets its bits as an IEEE 754 single-precision float.
func (vm *VM) fetchFloat32() float32 {
	bits := vm.fetchUint32()
	return math.Float32frombits(bits)
}
// fetchUint64 decodes an 8-byte immediate at the program counter using the
// package byte order and advances the program counter past it.
func (vm *VM) fetchUint64() uint64 {
	operand := vm.ctx.code[vm.ctx.pc:]
	word := endianess.Uint64(operand)
	vm.ctx.pc += 8
	return word
}
// fetchInt64 reads an 8-byte immediate from the code stream and
// reinterprets it as a signed 64-bit integer.
func (vm *VM) fetchInt64() int64 {
	word := vm.fetchUint64()
	return int64(word)
}
// fetchFloat64 reads an 8-byte immediate from the code stream and
// interprets its bits as an IEEE 754 double-precision float.
func (vm *VM) fetchFloat64() float64 {
	bits := vm.fetchUint64()
	return math.Float64frombits(bits)
}
// popUint64 removes the top word of the operand stack and returns it.
// It panics if the stack is empty.
func (vm *VM) popUint64() uint64 {
	last := len(vm.ctx.stack) - 1
	top := vm.ctx.stack[last]
	vm.ctx.stack = vm.ctx.stack[:last]
	return top
}
// popInt64 pops the top word of the operand stack and reinterprets it as
// a signed 64-bit integer.
func (vm *VM) popInt64() int64 {
	word := vm.popUint64()
	return int64(word)
}
// popFloat64 pops the top word of the operand stack and interprets its
// bits as an IEEE 754 double-precision float.
func (vm *VM) popFloat64() float64 {
	bits := vm.popUint64()
	return math.Float64frombits(bits)
}
// popUint32 pops the top word of the operand stack and returns its low
// 32 bits; the high 32 bits are discarded.
func (vm *VM) popUint32() uint32 {
	word := vm.popUint64()
	return uint32(word)
}
// popInt32 pops the top word of the operand stack and reinterprets its
// low 32 bits as a signed 32-bit integer.
func (vm *VM) popInt32() int32 {
	low := vm.popUint32()
	return int32(low)
}
// popFloat32 pops the top word of the operand stack and interprets its
// low 32 bits as an IEEE 754 single-precision float.
func (vm *VM) popFloat32() float32 {
	bits := vm.popUint32()
	return math.Float32frombits(bits)
}
// pushUint64 appends i to the operand stack as its new top word.
func (vm *VM) pushUint64(i uint64) {
	grown := append(vm.ctx.stack, i)
	vm.ctx.stack = grown
}
// pushInt64 pushes the two's-complement bit pattern of i onto the operand
// stack.
func (vm *VM) pushInt64(i int64) {
	word := uint64(i)
	vm.pushUint64(word)
}
// pushFloat64 pushes the IEEE 754 bit pattern of f onto the operand stack.
func (vm *VM) pushFloat64(f float64) {
	bits := math.Float64bits(f)
	vm.pushUint64(bits)
}
// pushUint32 pushes i onto the operand stack, zero-extended to a full
// 64-bit word.
func (vm *VM) pushUint32(i uint32) {
	word := uint64(i)
	vm.pushUint64(word)
}
// pushInt32 pushes i onto the operand stack. The conversion from int32 to
// uint64 sign-extends, so a negative i sets the upper 32 bits of the
// stored word.
func (vm *VM) pushInt32(i int32) {
	word := uint64(i)
	vm.pushUint64(word)
}
// pushFloat32 pushes the IEEE 754 bit pattern of f onto the operand
// stack, zero-extended to a full 64-bit word.
func (vm *VM) pushFloat32(f float32) {
	bits := math.Float32bits(f)
	vm.pushUint32(bits)
}
// ExecCode calls the function with the given index and arguments.
// fnIndex should be a valid index into the function index space of
// the VM's module.
//
// Each argument is passed as a raw 64-bit word and stored in the leading
// local slots of the called function.
//
// The result is nil for a function without a return value. Otherwise it
// is converted according to the declared return type: uint32 for i32,
// uint64 for i64, float32 for f32 and float64 for f64.
//
// Errors:
//   - InvalidFunctionIndexError if fnIndex is negative or not smaller
//     than the number of functions in the VM;
//   - ErrInvalidArgumentCount if len(args) differs from the number of
//     parameters in the function's signature;
//   - InvalidReturnTypeError if the declared return type is none of the
//     four types above.
//
// ExecCode panics if the function at fnIndex is not a compiled function,
// and any panic raised during execution propagates to the caller unless
// vm.RecoverPanic is set, in which case the panic is returned as err.
func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) {
	// If used as a library, client code should set vm.RecoverPanic to true
	// in order to have an error returned.
	if vm.RecoverPanic {
		defer func() {
			if r := recover(); r != nil {
				switch e := r.(type) {
				case error:
					err = e
				default:
					err = fmt.Errorf("exec: %v", e)
				}
			}
		}()
	}

	// Valid indices are 0 .. len(vm.funcs)-1. The comparison is done in
	// int64 so that neither a negative index nor one equal to the length
	// slips through to the lookups below.
	if fnIndex < 0 || fnIndex >= int64(len(vm.funcs)) {
		return nil, InvalidFunctionIndexError(fnIndex)
	}
	if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) {
		return nil, ErrInvalidArgumentCount
	}

	compiled, ok := vm.funcs[fnIndex].(compiledFunction)
	if !ok {
		panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex))
	}

	if len(vm.ctx.stack) < compiled.maxDepth {
		vm.ctx.stack = make([]uint64, 0, compiled.maxDepth)
	}
	vm.ctx.locals = make([]uint64, compiled.totalLocalVars)
	vm.ctx.pc = 0
	vm.ctx.code = compiled.code
	vm.ctx.curFunc = fnIndex

	// Arguments occupy the first len(args) local slots.
	copy(vm.ctx.locals, args)

	res := vm.execCode(compiled)
	if !compiled.returns {
		return nil, nil
	}

	// res is the raw 64-bit word left on the stack; convert it to the Go
	// type matching the declared return type.
	rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0]
	switch rtrnType {
	case wasm.ValueTypeI32:
		rtrn = uint32(res)
	case wasm.ValueTypeI64:
		rtrn = uint64(res)
	case wasm.ValueTypeF32:
		rtrn = math.Float32frombits(uint32(res))
	case wasm.ValueTypeF64:
		rtrn = math.Float64frombits(res)
	default:
		return nil, InvalidReturnTypeError(rtrnType)
	}

	return rtrn, nil
}
// execCode interprets vm.ctx.code from the current program counter until
// a return is reached or the program counter moves past the end of the
// code. Control-flow opcodes (return, the jump and discard opcodes from
// the compile package, and br_table) are handled inline; every other
// opcode is dispatched through vm.funcTable.
//
// When compiled.returns is set the word on top of the operand stack is
// returned (it is read, not popped); otherwise the result is 0.
func (vm *VM) execCode(compiled compiledFunction) uint64 {
run:
	for int(vm.ctx.pc) < len(vm.ctx.code) {
		opcode := vm.ctx.code[vm.ctx.pc]
		vm.ctx.pc++

		switch opcode {
		case ops.Return:
			break run

		case compile.OpJmp:
			// Unconditional jump: the immediate is the new program counter.
			vm.ctx.pc = vm.fetchInt64()

		case compile.OpJmpZ:
			// Jump when the popped condition is zero; otherwise execution
			// continues with the instruction after the immediate.
			dest := vm.fetchInt64()
			if vm.popUint32() == 0 {
				vm.ctx.pc = dest
			}

		case compile.OpJmpNz:
			// All three immediates are consumed before the condition is
			// popped, whether or not the jump is taken.
			dest := vm.fetchInt64()
			keepTop := vm.fetchBool()
			drop := vm.fetchInt64()
			if vm.popUint32() != 0 {
				vm.ctx.pc = dest
				if keepTop {
					// Unwind the stack but carry the top value across.
					saved := vm.ctx.stack[len(vm.ctx.stack)-1]
					vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(drop)]
					vm.pushUint64(saved)
				} else {
					vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(drop)]
				}
			}

		case ops.BrTable:
			tableIdx := vm.fetchInt64()
			selector := vm.popInt32()
			fn, isCompiled := vm.funcs[vm.ctx.curFunc].(compiledFunction)
			if !isCompiled {
				panic(fmt.Sprintf("exec: function at index %d is not a compiled function", vm.ctx.curFunc))
			}
			brTable := fn.branchTables[tableIdx]

			// An out-of-range selector takes the table's default target.
			var dest compile.Target
			if selector < 0 || selector >= int32(len(brTable.Targets)) {
				dest = brTable.DefaultTarget
			} else {
				dest = brTable.Targets[selector]
			}

			if dest.Return {
				break run
			}
			vm.ctx.pc = dest.Addr
			if dest.PreserveTop {
				saved := vm.ctx.stack[len(vm.ctx.stack)-1]
				vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(dest.Discard)]
				vm.pushUint64(saved)
			} else {
				vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(dest.Discard)]
			}

		case compile.OpDiscard:
			// Drop the given number of words from the top of the stack.
			count := vm.fetchInt64()
			vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(count)]

		case compile.OpDiscardPreserveTop:
			// Same as OpDiscard, but the current top value is pushed back
			// after the words have been dropped.
			saved := vm.ctx.stack[len(vm.ctx.stack)-1]
			count := vm.fetchInt64()
			vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(count)]
			vm.pushUint64(saved)

		default:
			vm.funcTable[opcode]()
		}
	}

	if !compiled.returns {
		return 0
	}
	return vm.ctx.stack[len(vm.ctx.stack)-1]
}