tbotd/vendor/src/github.com/Sereal/Sereal/Go/sereal/encode.go

625 lines
16 KiB
Go

package sereal
import (
"encoding"
"encoding/binary"
"errors"
"fmt"
"math"
"reflect"
"runtime"
"unsafe"
)
// An Encoder encodes Go data structures into Sereal byte streams.
//
// Use NewEncoder, NewEncoderV2 or NewEncoderV3 to obtain an encoder with a
// fixed protocol version and a 1024-byte compression threshold. An Encoder
// whose version is still zero is switched to ProtocolVersion the first time
// MarshalWithHeader runs.
type Encoder struct {
	PerlCompat           bool       // try to mimic Perl's structure as much as possible (arrays, hashes and struct bodies are emitted behind a REFN tag)
	Compression          compressor // optionally compress the main payload of the document using SnappyCompressor or ZlibCompressor
	CompressionThreshold int        // body size in bytes at or above which compression is attempted; 0 means always. Constructors set 1024
	DisableDedup         bool       // should we disable deduping of class names and hash keys
	DisableFREEZE        bool       // should we disable the FREEZE tag, which calls MarshalBinary
	ExpectedSize         uint       // give a hint to encoder about expected size of encoded data (used as the initial capacity of the body buffer)
	version              int        // protocol version to encode; 0 is replaced by ProtocolVersion on first use
	tcache               tagsCache  // consulted by encodeStruct to obtain the field-name -> field-index table of a struct
}
// compressor is implemented by the payload compressors an Encoder can use.
// The interface is unexported, so only compressors defined in this package
// can be assigned to Encoder.Compression.
type compressor interface {
	// compress returns the compressed form of b, or an error.
	compress(b []byte) ([]byte, error)
}
// NewEncoder builds an Encoder that writes protocol version 1 documents,
// with Perl compatibility switched off and a 1024-byte compression threshold.
func NewEncoder() *Encoder {
	enc := new(Encoder)
	enc.PerlCompat = false
	enc.CompressionThreshold = 1024
	enc.version = 1
	return enc
}
// NewEncoderV2 builds an Encoder that writes protocol version 2 documents,
// with Perl compatibility switched off and a 1024-byte compression threshold.
func NewEncoderV2() *Encoder {
	enc := new(Encoder)
	enc.PerlCompat = false
	enc.CompressionThreshold = 1024
	enc.version = 2
	return enc
}
// NewEncoderV3 builds an Encoder that writes protocol version 3 documents,
// with Perl compatibility switched off and a 1024-byte compression threshold.
func NewEncoderV3() *Encoder {
	enc := new(Encoder)
	enc.PerlCompat = false
	enc.CompressionThreshold = 1024
	enc.version = 3
	return enc
}
// defaultEncoder is the shared protocol version 3 encoder used by the
// package-level Marshal function.
var defaultEncoder = NewEncoderV3()

// Marshal encodes body with the default encoder.
// No user header is written; it is equivalent to calling
// MarshalWithHeader(nil, body) on an encoder returned by NewEncoderV3.
func Marshal(body interface{}) ([]byte, error) {
	return defaultEncoder.MarshalWithHeader(nil, body)
}
// Marshal returns the Sereal encoding of body.
// It is shorthand for e.MarshalWithHeader(nil, body), i.e. the document is
// written without a user header.
func (e *Encoder) Marshal(body interface{}) (b []byte, err error) {
	return e.MarshalWithHeader(nil, body)
}
// MarshalWithHeader returns the Sereal encoding of body with header data.
//
// The document consists of the magic bytes, a version/type byte, the
// (optional) encoded user header and the encoded body. header is only
// written for protocol version 2 and above; for version 1, or when header is
// nil, an empty header section is emitted.
//
// When e.Compression is set and the encoded body is at least
// e.CompressionThreshold bytes long (or the threshold is 0), the body is
// compressed and the document type in the header is updated accordingly.
//
// Panics raised while encoding are converted into the returned error, with
// the exception of runtime.Error panics, which are re-raised because they
// indicate a programming fault rather than unsupported input.
func (e *Encoder) MarshalWithHeader(header interface{}, body interface{}) (b []byte, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if _, ok := r.(runtime.Error); ok {
			panic(r)
		}
		switch cause := r.(type) {
		case string:
			err = errors.New(cause)
		case error:
			err = cause
		default:
			// A blind r.(error) assertion would itself panic for any other
			// panic value; report it as an error instead.
			err = fmt.Errorf("%v", cause)
		}
	}()

	// uninitialized encoder? set to the most recent supported protocol version
	if e.version == 0 {
		e.version = ProtocolVersion
	}

	encHeader := make([]byte, headerSize, 32)

	if e.version < 3 {
		binary.LittleEndian.PutUint32(encHeader[:4], magicHeaderBytes)
	} else {
		binary.LittleEndian.PutUint32(encHeader[:4], magicHeaderBytesHighBit)
	}

	// Set the <version-type> component in the header
	encHeader[4] = byte(e.version) | byte(serealRaw)<<4

	if header != nil && e.version >= 2 {
		// The user header gets its own string/pointer tables: offsets inside
		// it are independent of the body.
		strTable := make(map[string]int)
		ptrTable := make(map[uintptr]int)

		// The leading byte is both the flag byte (== "there is user data")
		// and a hack to make 1-based offsets work.
		henv := []byte{0x01}
		encHeaderSuffix, err := e.encode(henv, header, false, false, strTable, ptrTable)
		if err != nil {
			return nil, err
		}

		encHeader = varint(encHeader, uint(len(encHeaderSuffix)))
		encHeader = append(encHeader, encHeaderSuffix...)
	} else {
		/* header size */
		encHeader = append(encHeader, 0)
	}

	strTable := make(map[string]int)
	ptrTable := make(map[uintptr]int)

	encBody := make([]byte, 0, e.ExpectedSize)

	switch e.version {
	case 1:
		encBody, err = e.encode(encBody, body, false, false, strTable, ptrTable)
	case 2, 3:
		encBody = append(encBody, 0) // hack for 1-based offsets
		// Only trim the placeholder byte on success: on failure the encoder
		// may hand back a nil buffer, and slicing it with [1:] would raise a
		// runtime panic that hides the real error from the caller.
		if encBody, err = e.encode(encBody, body, false, false, strTable, ptrTable); err == nil {
			encBody = encBody[1:] // trim hacky first byte
		}
	}

	if err != nil {
		return nil, err
	}

	if e.Compression != nil && (e.CompressionThreshold == 0 || len(encBody) >= e.CompressionThreshold) {
		encBody, err = e.Compression.compress(encBody)
		if err != nil {
			return nil, err
		}

		var doctype documentType

		switch c := e.Compression.(type) {
		case SnappyCompressor:
			if e.version > 1 && !c.Incremental {
				return nil, errors.New("non-incremental snappy compression only valid for v1 documents")
			}

			if e.version == 1 {
				doctype = serealSnappy
			} else {
				doctype = serealSnappyIncremental
			}
		case ZlibCompressor:
			if e.version < 3 {
				return nil, errors.New("zlib compression only valid for v3 documents and up")
			}
			doctype = serealZlib
		default:
			// Defensive programming: this point should never be
			// reached in production code because the compressor
			// interface is not exported, hence no way to pass in
			// an unknown thing. But it may happen during
			// development when a new compressor is implemented,
			// but a relevant document type is not defined.
			panic("undefined compression")
		}

		// Record the document type in the high nibble of the version byte.
		encHeader[4] |= byte(doctype) << 4
	}

	return append(encHeader, encBody...), nil
}
/*************************************
 * Encode via static types - fast path
 *************************************/

// encode appends the Sereal encoding of v to b and returns the extended
// buffer.
//
// Commonly used concrete types are dispatched through a type switch;
// everything else is handed to encodeViaReflection.
//
//   - isKeyOrClass marks v as a hash key or class name, which makes string
//     and byte-slice values eligible for deduplication through strTable.
//   - isRefNext tells container encoders that a reference tag has already
//     been written for v.
//   - strTable maps already-emitted key/class strings to their offsets.
//   - ptrTable maps already-emitted pointers to their offsets.
func (e *Encoder) encode(b []byte, v interface{}, isKeyOrClass bool, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	var err error

	switch value := v.(type) {
	case nil:
		b = append(b, typeUNDEF)

	case bool:
		if value {
			b = append(b, typeTRUE)
		} else {
			b = append(b, typeFALSE)
		}

	// All signed widths share the signed integer encoding.
	case int:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int8:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int16:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int32:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int64:
		b = e.encodeInt(b, reflect.Int, int64(value))

	// All unsigned widths share the unsigned integer encoding. The value is
	// passed as int64 bits; encodeInt uses the kind to interpret them.
	case uint:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint8:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint16:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint32:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint64:
		b = e.encodeInt(b, reflect.Uint, int64(value))

	case float32:
		b = e.encodeFloat(b, value)
	case float64:
		b = e.encodeDouble(b, value)

	case string:
		b = e.encodeString(b, value, isKeyOrClass, strTable)

	case []uint8:
		b = e.encodeBytes(b, value, isKeyOrClass, strTable)

	case []interface{}:
		b, err = e.encodeIntfArray(b, value, isRefNext, strTable, ptrTable)

	case map[string]interface{}:
		b, err = e.encodeStrMap(b, value, isRefNext, strTable, ptrTable)

	case reflect.Value:
		if value.Kind() == reflect.Invalid {
			b = append(b, typeUNDEF)
		} else {
			// Unwrap and dispatch again on the concrete type (could be
			// optimized to a tail call). isKeyOrClass is forwarded so that
			// map keys handed over as reflect.Value by encodeMap remain
			// eligible for deduplication, like keys in encodeStrMap.
			b, err = e.encode(b, value.Interface(), isKeyOrClass, isRefNext, strTable, ptrTable)
		}

	case PerlUndef:
		if value.canonical {
			b = append(b, typeCANONICAL_UNDEF)
		} else {
			b = append(b, typeUNDEF)
		}

	case PerlObject:
		b = append(b, typeOBJECT)
		b = e.encodeBytes(b, []byte(value.Class), true, strTable)
		b, err = e.encode(b, value.Reference, false, false, strTable, ptrTable)

	case PerlRegexp:
		b = append(b, typeREGEXP)
		b = e.encodeBytes(b, value.Pattern, false, strTable)
		b = e.encodeBytes(b, value.Modifiers, false, strTable)

	case PerlWeakRef:
		b = append(b, typeWEAKEN)
		b, err = e.encode(b, value.Reference, false, false, strTable, ptrTable)

	// There is deliberately no `case *interface{}` / `case interface{}` here.
	// An interface never holds another interface value (see
	// http://blog.golang.org/laws-of-reflection), and in practice an
	// `interface{}` case ended up matching *interface{} values as well, so
	// such values take the reflection path below.
	// TODO: handle *interface{} here if it turns out to be easy.

	default:
		b, err = e.encodeViaReflection(b, reflect.ValueOf(value), isKeyOrClass, isRefNext, strTable, ptrTable)
	}

	return b, err
}
// encodeInt appends the integer i to by and returns the extended buffer.
//
// k selects the interpretation of i: reflect.Int means a signed value, any
// other kind means the bits of an unsigned value. Values 0..15 (and -16..-1
// for signed input) are packed into a single tag byte; larger non-negative
// values are written as VARINT, remaining signed negatives as ZIGZAG, and
// unsigned values whose int64 bit pattern is negative as VARINT.
func (e *Encoder) encodeInt(by []byte, k reflect.Kind, i int64) []byte {
	if i >= 0 {
		if i <= 15 {
			// small positive: value lives in the low nibble of the tag
			return append(by, byte(i)&0x0f)
		}
		by = append(by, typeVARINT)
		return varint(by, uint(i))
	}

	// From here on i < 0.
	if k != reflect.Int {
		// unsigned value with the top bit set: emit its raw bits
		by = append(by, typeVARINT)
		return varint(by, uint(i))
	}

	if i >= -16 {
		// small negative: low nibble of the value behind the 0x10 marker
		return append(by, 0x010|(byte(i)&0x0f))
	}

	by = append(by, typeZIGZAG)
	zigzag := uint((i << 1) ^ (i >> 63))
	return varint(by, zigzag)
}
// encodeFloat appends the FLOAT tag followed by the four bytes of the
// IEEE 754 bit pattern of f, least significant byte first.
func (e *Encoder) encodeFloat(by []byte, f float32) []byte {
	bits := math.Float32bits(f)
	return append(by,
		typeFLOAT,
		byte(bits),
		byte(bits>>8),
		byte(bits>>16),
		byte(bits>>24),
	)
}
// encodeDouble appends the DOUBLE tag followed by the eight bytes of the
// IEEE 754 bit pattern of f, least significant byte first.
func (e *Encoder) encodeDouble(by []byte, f float64) []byte {
	bits := math.Float64bits(f)
	by = append(by, typeDOUBLE)
	for shift := uint(0); shift < 64; shift += 8 {
		by = append(by, byte(bits>>shift))
	}
	return by
}
// encodeString appends s to by as a STR_UTF8 item (tag, varint length, bytes).
//
// When s is a key or class name and deduplication is enabled, strTable is
// consulted first: a string seen before is written as a COPY tag followed by
// the offset of its first occurrence, otherwise the current offset is
// remembered for later copies.
func (e *Encoder) encodeString(by []byte, s string, isKeyOrClass bool, strTable map[string]int) []byte {
	dedup := !e.DisableDedup && isKeyOrClass
	if dedup {
		prev, seen := strTable[s]
		if seen {
			return varint(append(by, typeCOPY), uint(prev))
		}
		strTable[s] = len(by)
	}

	by = varint(append(by, typeSTR_UTF8), uint(len(s)))
	by = append(by, s...)
	return by
}
// encodeBytes appends byt to by as a binary item: SHORT_BINARY (length folded
// into the tag) for fewer than 32 bytes, otherwise BINARY with a varint
// length.
//
// When byt is a key or class name and deduplication is enabled, a value seen
// before is written as a COPY tag plus the offset of its first occurrence;
// otherwise the current offset is stored in strTable.
func (e *Encoder) encodeBytes(by []byte, byt []byte, isKeyOrClass bool, strTable map[string]int) []byte {
	dedup := !e.DisableDedup && isKeyOrClass
	if dedup {
		key := string(byt)
		if prev, seen := strTable[key]; seen {
			return varint(append(by, typeCOPY), uint(prev))
		}
		// remember where this value starts for later COPY tags
		strTable[key] = len(by)
	}

	size := len(byt)
	if size >= 32 {
		by = append(by, typeBINARY)
		by = varint(by, uint(size))
	} else {
		by = append(by, typeSHORT_BINARY_0+byte(size))
	}

	return append(by, byt...)
}
// encodeIntfArray appends arr as an ARRAY item: tag, varint element count,
// then every element in order. In PerlCompat mode a REFN tag is written in
// front unless the caller already emitted a reference (isRefNext).
// On an element error the buffer is discarded and (nil, err) is returned.
func (e *Encoder) encodeIntfArray(by []byte, arr []interface{}, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	if e.PerlCompat && !isRefNext {
		by = append(by, typeREFN)
	}

	// TODO implement ARRAYREF for small arrays

	by = varint(append(by, typeARRAY), uint(len(arr)))

	for _, elem := range arr {
		out, err := e.encode(by, elem, false, false, strTable, ptrTable)
		if err != nil {
			return nil, err
		}
		by = out
	}

	return by, nil
}
// encodeStrMap appends m as a HASH item: tag, varint entry count, then each
// key (deduplicated as a hash key) followed by its value, in map iteration
// order. In PerlCompat mode a REFN tag is written in front unless the caller
// already emitted a reference (isRefNext).
// On a value error the buffer built so far is returned along with the error.
func (e *Encoder) encodeStrMap(by []byte, m map[string]interface{}, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	if e.PerlCompat && !isRefNext {
		by = append(by, typeREFN)
	}

	// TODO implement HASHREF for small maps

	by = varint(append(by, typeHASH), uint(len(m)))

	for key, val := range m {
		by = e.encodeString(by, key, true, strTable)

		var err error
		by, err = e.encode(by, val, false, false, strTable, ptrTable)
		if err != nil {
			return by, err
		}
	}

	return by, nil
}
/*************************************
 * Encode via reflection
 *************************************/

// encodeViaReflection encodes values the static fast path in encode does not
// recognise.
//
// Unless FREEZE support is disabled, a non-pointer value implementing
// encoding.BinaryMarshaler is written as OBJECT_FREEZE: class name, then a
// reference to a one-element array holding the MarshalBinary output.
// Otherwise interfaces are unwrapped and the value is dispatched by kind to
// the array, map, struct or pointer encoder. Any other kind panics.
func (e *Encoder) encodeViaReflection(b []byte, rv reflect.Value, isKeyOrClass bool, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	if !e.DisableFREEZE && rv.Kind() != reflect.Invalid && rv.Kind() != reflect.Ptr {
		if marshaler, ok := rv.Interface().(encoding.BinaryMarshaler); ok {
			frozen, err := marshaler.MarshalBinary()
			if err != nil {
				return nil, err
			}

			b = append(b, typeOBJECT_FREEZE)
			b = e.encodeString(b, concreteName(rv), true, strTable)
			b = append(b, typeREFN, typeARRAY)
			b = varint(b, uint(1))
			return e.encode(b, reflect.ValueOf(frozen), false, false, strTable, ptrTable)
		}
	}

	// make sure we're looking at a real type and not an interface
	for rv.Kind() == reflect.Interface {
		rv = rv.Elem()
	}

	kind := rv.Kind()
	switch kind {
	case reflect.Slice, reflect.Array:
		// []uint8 never gets here: it is handled in encode()
		return e.encodeArray(b, rv, isRefNext, strTable, ptrTable)
	case reflect.Map:
		return e.encodeMap(b, rv, isRefNext, strTable, ptrTable)
	case reflect.Struct:
		return e.encodeStruct(b, rv, strTable, ptrTable)
	case reflect.Ptr:
		return e.encodePointer(b, rv, strTable, ptrTable)
	}

	panic(fmt.Sprintf("no support for type '%s' (%s)", kind.String(), rv.Type()))
}
// encodeArray appends the slice or array held in arr as an ARRAY item: tag,
// varint element count, then every element in index order. In PerlCompat
// mode a REFN tag is written in front unless the caller already emitted a
// reference (isRefNext).
// On an element error the buffer is discarded and (nil, err) is returned.
func (e *Encoder) encodeArray(by []byte, arr reflect.Value, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	if e.PerlCompat && !isRefNext {
		by = append(by, typeREFN)
	}

	count := arr.Len()
	by = varint(append(by, typeARRAY), uint(count))

	for idx := 0; idx < count; idx++ {
		out, err := e.encode(by, arr.Index(idx), false, false, strTable, ptrTable)
		if err != nil {
			return nil, err
		}
		by = out
	}

	return by, nil
}
// encodeMap appends the map held in m as a HASH item: tag, varint entry
// count, then each key followed by its value in the order MapKeys returns
// them.
//
// In PerlCompat mode a REFN tag is written in front (unless isRefNext) and
// every key is written as a string obtained from reflect.Value.String;
// otherwise keys are encoded according to their own type, flagged as hash
// keys. On error the buffer built so far is returned along with the error.
func (e *Encoder) encodeMap(by []byte, m reflect.Value, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	if e.PerlCompat && !isRefNext {
		by = append(by, typeREFN)
	}

	keys := m.MapKeys()
	by = varint(append(by, typeHASH), uint(len(keys)))

	var err error
	for _, key := range keys {
		if e.PerlCompat {
			by = e.encodeString(by, key.String(), true, strTable)
		} else {
			by, err = e.encode(by, key, true, false, strTable, ptrTable)
			if err != nil {
				return by, err
			}
		}

		by, err = e.encode(by, m.MapIndex(key), false, false, strTable, ptrTable)
		if err != nil {
			return by, err
		}
	}

	return by, nil
}
// encodeStruct appends the struct held in st as an OBJECT item: the
// unqualified type name as class, then a HASH whose keys are the names
// supplied by the tag cache and whose values are the corresponding fields.
// In PerlCompat mode the hash is preceded by a REFN tag.
// On a field error the buffer is discarded and (nil, err) is returned.
func (e *Encoder) encodeStruct(by []byte, st reflect.Value, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	fields := e.tcache.Get(st)

	by = append(by, typeOBJECT)
	className := []byte(st.Type().Name())
	by = e.encodeBytes(by, className, true, strTable)

	if e.PerlCompat {
		// must be a reference
		by = append(by, typeREFN)
	}

	by = varint(append(by, typeHASH), uint(len(fields)))

	for name, idx := range fields {
		by = e.encodeString(by, name, true, strTable)

		out, err := e.encode(by, st.Field(idx), false, false, strTable, ptrTable)
		if err != nil {
			return nil, err
		}
		by = out
	}

	return by, nil
}
// encodePointer appends the pointer held in rv to by.
//
// A pointer that has not been seen before is written as a REFN tag followed
// by the encoding of its pointee; the offsets of both are recorded in
// ptrTable. A pointer (or pointee) already present in ptrTable is written as
// a REFP tag plus the recorded offset, and the byte at that offset gets
// trackFlag set.
//
// On a pointee error the buffer is discarded and (nil, err) is returned.
func (e *Encoder) encodePointer(by []byte, rv reflect.Value, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) {
	// Pointers to the Perl helper types are encoded as the pointee itself,
	// without any reference tag or pointer tracking.
	// NOTE(review): an earlier maintainer (ikruglov) marked this block as not
	// fully understood and left it untouched; confirm intent before changing.
	if rv.Elem().Kind() == reflect.Struct {
		switch rv.Elem().Interface().(type) {
		case PerlRegexp:
			return e.encode(by, rv.Elem(), false, false, strTable, ptrTable)
		case PerlUndef:
			return e.encode(by, rv.Elem(), false, false, strTable, ptrTable)
		case PerlObject:
			return e.encode(by, rv.Elem(), false, false, strTable, ptrTable)
		case PerlWeakRef:
			return e.encode(by, rv.Elem(), false, false, strTable, ptrTable)
		}
	}

	// rvptr is the address stored in the pointer; rvptr2 is the identity of
	// the pointee as reported by getPointer (0 when it has none).
	rvptr := rv.Pointer()
	rvptr2 := getPointer(rv.Elem())

	// Look the pointer up first, then fall back to the pointee's identity.
	offs, ok := ptrTable[rvptr]

	if !ok && rvptr2 != 0 {
		offs, ok = ptrTable[rvptr2]
		if ok {
			// rvptr is not read again on the "seen" path below.
			rvptr = rvptr2
		}
	}

	if ok { // seen this before
		by = append(by, typeREFP)
		by = varint(by, uint(offs))
		by[offs] |= trackFlag // original offset now tracked
	} else {
		// Remember where the REFN tag is written so later occurrences of
		// the same pointer can refer back to it.
		lenbOrig := len(by)

		by = append(by, typeREFN)

		if rvptr != 0 {
			ptrTable[rvptr] = lenbOrig
		}

		var err error

		// isRefNext=true: the REFN above already serves as the reference tag.
		by, err = e.encode(by, rv.Elem(), false, true, strTable, ptrTable)
		if err != nil {
			return nil, err
		}

		if rvptr2 != 0 {
			// The thing this pointer points to starts one byte after the
			// REFN tag written above.
			ptrTable[rvptr2] = lenbOrig + 1
		}
	}

	return by, nil
}
// varint appends n to by as a base-128 variable-length integer: seven bits
// per byte, least significant group first, with the high bit set on every
// byte except the last. It returns the extended buffer.
func varint(by []byte, n uint) []uint8 {
	for ; n >= 0x80; n >>= 7 {
		by = append(by, byte(n)|0x80)
	}
	return append(by, byte(n))
}
// getPointer returns an address that identifies the data behind rv, or 0
// when rv has no such identity.
//
//   - Map, Slice and Ptr values report reflect.Value.Pointer.
//   - Interface values report the identity of the value they hold.
//   - Addressable String values report the address of the string's bytes.
//   - Everything else, including strings that are not addressable (e.g. a
//     string stored inside an interface), reports 0.
func getPointer(rv reflect.Value) uintptr {
	switch rv.Kind() {
	case reflect.Map, reflect.Slice, reflect.Ptr:
		return rv.Pointer()
	case reflect.Interface:
		// FIXME: still needed?
		return getPointer(rv.Elem())
	case reflect.String:
		// UnsafeAddr panics on values that are not addressable, which is
		// what the Interface branch above yields for a boxed string. Such a
		// value has no stable location to track, so report "no pointer".
		if !rv.CanAddr() {
			return 0
		}
		ps := (*reflect.StringHeader)(unsafe.Pointer(rv.UnsafeAddr()))
		return ps.Data
	}

	return 0
}
// concreteName returns the package path and type name of value's type,
// joined by a dot. Either part may be empty, e.g. for predeclared or unnamed
// types.
func concreteName(value reflect.Value) string {
	typ := value.Type()
	pkg, name := typ.PkgPath(), typ.Name()
	return pkg + "." + name
}