400 lines
9.7 KiB
Go
400 lines
9.7 KiB
Go
// Copyright 2017 The go-interpreter Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package exec provides functions for executing WebAssembly bytecode.
|
|
package exec
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
|
|
"github.com/go-interpreter/wagon/disasm"
|
|
"github.com/go-interpreter/wagon/exec/internal/compile"
|
|
"github.com/go-interpreter/wagon/wasm"
|
|
ops "github.com/go-interpreter/wagon/wasm/operators"
|
|
)
|
|
|
|
var (
	// ErrMultipleLinearMemories is returned by NewVM when the module declares
	// more than one linear memory; only a single memory is supported.
	ErrMultipleLinearMemories = errors.New("exec: more than one linear memories in module")

	// ErrInvalidArgumentCount is returned by (*VM).ExecCode when the number of
	// arguments supplied differs from the callee's parameter count.
	ErrInvalidArgumentCount = errors.New("exec: invalid number of arguments to function")
)
|
|
|
|
// InvalidReturnTypeError is returned by (*VM).ExecCode when the executed
// function's signature declares a result value type the VM cannot convert.
// The error value carries the offending value-type code.
type InvalidReturnTypeError int8

// Error implements the error interface.
func (e InvalidReturnTypeError) Error() string {
	code := int8(e)
	return fmt.Sprintf("Function has invalid return value_type: %d", code)
}
|
|
|
|
// InvalidFunctionIndexError is returned by (*VM).ExecCode when the requested
// function index lies outside the module's function index space. The error
// value is the rejected index.
type InvalidFunctionIndexError int64

// Error implements the error interface.
func (e InvalidFunctionIndexError) Error() string {
	idx := int64(e)
	return fmt.Sprintf("Invalid index to function index space: %d", idx)
}
|
|
|
|
// context is the execution state of the function currently being run.
type context struct {
	stack, locals []uint64 // operand stack; local variables (parameters first)
	code          []byte   // compiled bytecode being executed
	pc, curFunc   int64    // offset of the next byte in code; index of the running function
}
|
|
|
|
// VM is the execution context for executing WebAssembly bytecode.
|
|
type VM struct {
|
|
ctx context
|
|
|
|
module *wasm.Module
|
|
globals []uint64
|
|
memory []byte
|
|
funcs []function
|
|
|
|
funcTable [256]func()
|
|
|
|
// RecoverPanic controls whether the `ExecCode` method
|
|
// recovers from a panic and returns it as an error
|
|
// instead.
|
|
// A panic can occur either when executing an invalid VM
|
|
// or encountering an invalid instruction, e.g. `unreachable`.
|
|
RecoverPanic bool
|
|
}
|
|
|
|
// wasmPageSize is the size in bytes of one page of linear memory (64 KiB), as
// defined by the WebAssembly spec:
// https://github.com/WebAssembly/design/blob/27ac254c854994103c24834a994be16f74f54186/Semantics.md#linear-memory
const wasmPageSize = 64 << 10

// endianess is the byte order used to decode multi-byte immediates from the
// compiled code stream (see fetchUint32 and fetchUint64).
var endianess = binary.LittleEndian
|
|
|
|
// NewVM creates a new VM from a given module. If the module defines a
|
|
// start function, it will be executed.
|
|
func NewVM(module *wasm.Module) (*VM, error) {
|
|
var vm VM
|
|
|
|
if module.Memory != nil && len(module.Memory.Entries) != 0 {
|
|
if len(module.Memory.Entries) > 1 {
|
|
return nil, ErrMultipleLinearMemories
|
|
}
|
|
vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize)
|
|
copy(vm.memory, module.LinearMemoryIndexSpace[0])
|
|
}
|
|
|
|
vm.funcs = make([]function, len(module.FunctionIndexSpace))
|
|
vm.globals = make([]uint64, len(module.GlobalIndexSpace))
|
|
vm.newFuncTable()
|
|
vm.module = module
|
|
|
|
nNatives := 0
|
|
for i, fn := range module.FunctionIndexSpace {
|
|
// Skip native methods as they need not be
|
|
// disassembled; simply add them at the end
|
|
// of the `funcs` array as is, as specified
|
|
// in the spec. See the "host functions"
|
|
// section of:
|
|
// https://webassembly.github.io/spec/core/exec/modules.html#allocation
|
|
if fn.IsHost() {
|
|
vm.funcs[i] = goFunction{
|
|
typ: fn.Host.Type(),
|
|
val: fn.Host,
|
|
}
|
|
nNatives++
|
|
continue
|
|
}
|
|
|
|
disassembly, err := disasm.Disassemble(fn, module)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
totalLocalVars := 0
|
|
totalLocalVars += len(fn.Sig.ParamTypes)
|
|
for _, entry := range fn.Body.Locals {
|
|
totalLocalVars += int(entry.Count)
|
|
}
|
|
code, table := compile.Compile(disassembly.Code)
|
|
vm.funcs[i] = compiledFunction{
|
|
code: code,
|
|
branchTables: table,
|
|
maxDepth: disassembly.MaxDepth,
|
|
totalLocalVars: totalLocalVars,
|
|
args: len(fn.Sig.ParamTypes),
|
|
returns: len(fn.Sig.ReturnTypes) != 0,
|
|
}
|
|
}
|
|
|
|
for i, global := range module.GlobalIndexSpace {
|
|
val, err := module.ExecInitExpr(global.Init)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
switch v := val.(type) {
|
|
case int32:
|
|
vm.globals[i] = uint64(v)
|
|
case int64:
|
|
vm.globals[i] = uint64(v)
|
|
case float32:
|
|
vm.globals[i] = uint64(math.Float32bits(v))
|
|
case float64:
|
|
vm.globals[i] = uint64(math.Float64bits(v))
|
|
}
|
|
}
|
|
|
|
if module.Start != nil {
|
|
_, err := vm.ExecCode(int64(module.Start.Index))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &vm, nil
|
|
}
|
|
|
|
// Memory returns the linear memory space for the VM.
|
|
func (vm *VM) Memory() []byte {
|
|
return vm.memory
|
|
}
|
|
|
|
func (vm *VM) pushBool(v bool) {
|
|
if v {
|
|
vm.pushUint64(1)
|
|
} else {
|
|
vm.pushUint64(0)
|
|
}
|
|
}
|
|
|
|
func (vm *VM) fetchBool() bool {
|
|
return vm.fetchInt8() != 0
|
|
}
|
|
|
|
func (vm *VM) fetchInt8() int8 {
|
|
i := int8(vm.ctx.code[vm.ctx.pc])
|
|
vm.ctx.pc++
|
|
return i
|
|
}
|
|
|
|
func (vm *VM) fetchUint32() uint32 {
|
|
v := endianess.Uint32(vm.ctx.code[vm.ctx.pc:])
|
|
vm.ctx.pc += 4
|
|
return v
|
|
}
|
|
|
|
func (vm *VM) fetchInt32() int32 {
|
|
return int32(vm.fetchUint32())
|
|
}
|
|
|
|
func (vm *VM) fetchFloat32() float32 {
|
|
return math.Float32frombits(vm.fetchUint32())
|
|
}
|
|
|
|
func (vm *VM) fetchUint64() uint64 {
|
|
v := endianess.Uint64(vm.ctx.code[vm.ctx.pc:])
|
|
vm.ctx.pc += 8
|
|
return v
|
|
}
|
|
|
|
func (vm *VM) fetchInt64() int64 {
|
|
return int64(vm.fetchUint64())
|
|
}
|
|
|
|
func (vm *VM) fetchFloat64() float64 {
|
|
return math.Float64frombits(vm.fetchUint64())
|
|
}
|
|
|
|
func (vm *VM) popUint64() uint64 {
|
|
i := vm.ctx.stack[len(vm.ctx.stack)-1]
|
|
vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-1]
|
|
return i
|
|
}
|
|
|
|
func (vm *VM) popInt64() int64 {
|
|
return int64(vm.popUint64())
|
|
}
|
|
|
|
func (vm *VM) popFloat64() float64 {
|
|
return math.Float64frombits(vm.popUint64())
|
|
}
|
|
|
|
func (vm *VM) popUint32() uint32 {
|
|
return uint32(vm.popUint64())
|
|
}
|
|
|
|
func (vm *VM) popInt32() int32 {
|
|
return int32(vm.popUint32())
|
|
}
|
|
|
|
func (vm *VM) popFloat32() float32 {
|
|
return math.Float32frombits(vm.popUint32())
|
|
}
|
|
|
|
func (vm *VM) pushUint64(i uint64) {
|
|
vm.ctx.stack = append(vm.ctx.stack, i)
|
|
}
|
|
|
|
func (vm *VM) pushInt64(i int64) {
|
|
vm.pushUint64(uint64(i))
|
|
}
|
|
|
|
func (vm *VM) pushFloat64(f float64) {
|
|
vm.pushUint64(math.Float64bits(f))
|
|
}
|
|
|
|
func (vm *VM) pushUint32(i uint32) {
|
|
vm.pushUint64(uint64(i))
|
|
}
|
|
|
|
func (vm *VM) pushInt32(i int32) {
|
|
vm.pushUint64(uint64(i))
|
|
}
|
|
|
|
func (vm *VM) pushFloat32(f float32) {
|
|
vm.pushUint32(math.Float32bits(f))
|
|
}
|
|
|
|
// ExecCode calls the function with the given index and arguments.
|
|
// fnIndex should be a valid index into the function index space of
|
|
// the VM's module.
|
|
func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) {
|
|
// If used as a library, client code should set vm.RecoverPanic to true
|
|
// in order to have an error returned.
|
|
if vm.RecoverPanic {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
switch e := r.(type) {
|
|
case error:
|
|
err = e
|
|
default:
|
|
err = fmt.Errorf("exec: %v", e)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
if int(fnIndex) > len(vm.funcs) {
|
|
return nil, InvalidFunctionIndexError(fnIndex)
|
|
}
|
|
if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) {
|
|
return nil, ErrInvalidArgumentCount
|
|
}
|
|
compiled, ok := vm.funcs[fnIndex].(compiledFunction)
|
|
if !ok {
|
|
panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex))
|
|
}
|
|
if len(vm.ctx.stack) < compiled.maxDepth {
|
|
vm.ctx.stack = make([]uint64, 0, compiled.maxDepth)
|
|
}
|
|
vm.ctx.locals = make([]uint64, compiled.totalLocalVars)
|
|
vm.ctx.pc = 0
|
|
vm.ctx.code = compiled.code
|
|
vm.ctx.curFunc = fnIndex
|
|
|
|
for i, arg := range args {
|
|
vm.ctx.locals[i] = arg
|
|
}
|
|
|
|
res := vm.execCode(compiled)
|
|
if compiled.returns {
|
|
rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0]
|
|
switch rtrnType {
|
|
case wasm.ValueTypeI32:
|
|
rtrn = uint32(res)
|
|
case wasm.ValueTypeI64:
|
|
rtrn = uint64(res)
|
|
case wasm.ValueTypeF32:
|
|
rtrn = math.Float32frombits(uint32(res))
|
|
case wasm.ValueTypeF64:
|
|
rtrn = math.Float64frombits(res)
|
|
default:
|
|
return nil, InvalidReturnTypeError(rtrnType)
|
|
}
|
|
}
|
|
|
|
return rtrn, nil
|
|
}
|
|
|
|
func (vm *VM) execCode(compiled compiledFunction) uint64 {
|
|
outer:
|
|
for int(vm.ctx.pc) < len(vm.ctx.code) {
|
|
op := vm.ctx.code[vm.ctx.pc]
|
|
vm.ctx.pc++
|
|
switch op {
|
|
case ops.Return:
|
|
break outer
|
|
case compile.OpJmp:
|
|
vm.ctx.pc = vm.fetchInt64()
|
|
continue
|
|
case compile.OpJmpZ:
|
|
target := vm.fetchInt64()
|
|
if vm.popUint32() == 0 {
|
|
vm.ctx.pc = target
|
|
continue
|
|
}
|
|
case compile.OpJmpNz:
|
|
target := vm.fetchInt64()
|
|
preserveTop := vm.fetchBool()
|
|
discard := vm.fetchInt64()
|
|
if vm.popUint32() != 0 {
|
|
vm.ctx.pc = target
|
|
var top uint64
|
|
if preserveTop {
|
|
top = vm.ctx.stack[len(vm.ctx.stack)-1]
|
|
}
|
|
vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(discard)]
|
|
if preserveTop {
|
|
vm.pushUint64(top)
|
|
}
|
|
continue
|
|
}
|
|
case ops.BrTable:
|
|
index := vm.fetchInt64()
|
|
label := vm.popInt32()
|
|
cf, ok := vm.funcs[vm.ctx.curFunc].(compiledFunction)
|
|
if !ok {
|
|
panic(fmt.Sprintf("exec: function at index %d is not a compiled function", vm.ctx.curFunc))
|
|
}
|
|
table := cf.branchTables[index]
|
|
var target compile.Target
|
|
if label >= 0 && label < int32(len(table.Targets)) {
|
|
target = table.Targets[int32(label)]
|
|
} else {
|
|
target = table.DefaultTarget
|
|
}
|
|
|
|
if target.Return {
|
|
break outer
|
|
}
|
|
vm.ctx.pc = target.Addr
|
|
var top uint64
|
|
if target.PreserveTop {
|
|
top = vm.ctx.stack[len(vm.ctx.stack)-1]
|
|
}
|
|
vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(target.Discard)]
|
|
if target.PreserveTop {
|
|
vm.pushUint64(top)
|
|
}
|
|
continue
|
|
case compile.OpDiscard:
|
|
place := vm.fetchInt64()
|
|
vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(place)]
|
|
case compile.OpDiscardPreserveTop:
|
|
top := vm.ctx.stack[len(vm.ctx.stack)-1]
|
|
place := vm.fetchInt64()
|
|
vm.ctx.stack = vm.ctx.stack[:len(vm.ctx.stack)-int(place)]
|
|
vm.pushUint64(top)
|
|
default:
|
|
vm.funcTable[op]()
|
|
}
|
|
}
|
|
|
|
if compiled.returns {
|
|
return vm.ctx.stack[len(vm.ctx.stack)-1]
|
|
}
|
|
return 0
|
|
}
|