tbotd/vendor/src/github.com/Sereal/Sereal/Go/sereal/decode.go

1267 lines
28 KiB
Go

package sereal
import (
"encoding"
"encoding/binary"
"errors"
"fmt"
"math"
"reflect"
"runtime"
"strings"
)
type serealHeader struct {
doctype documentType
version byte
suffixStart int
suffixSize int
suffixFlags uint8
}
func readHeader(b []byte) (serealHeader, error) {
if len(b) <= headerSize {
return serealHeader{}, ErrBadHeader
}
first4Bytes := binary.LittleEndian.Uint32(b[:4])
var h serealHeader
h.doctype = documentType(b[4] >> 4)
h.version = b[4] & 0x0f
validHeader := false
switch first4Bytes {
case magicHeaderBytes:
if 1 <= h.version && h.version <= 2 {
validHeader = true
}
case magicHeaderBytesHighBit:
if h.version >= 3 {
validHeader = true
}
case magicHeaderBytesHighBitUTF8:
return serealHeader{}, ErrBadHeaderUTF8
}
if !validHeader {
return serealHeader{}, ErrBadHeader
}
ln, sz, err := varintdecode(b[5:])
if err != nil {
return serealHeader{}, err
}
h.suffixSize = ln + sz
h.suffixStart = headerSize + sz
return h, nil
}
// A Decoder reads and decodes Sereal objects from an input buffer
type Decoder struct {
tracked map[int]reflect.Value
umcache map[string]reflect.Type
tcache tagsCache
copyDepth int
PerlCompat bool
}
type decompressor interface {
decompress(b []byte) ([]byte, error)
}
// NewDecoder returns a decoder with default flags
func NewDecoder() *Decoder {
return &Decoder{}
}
// Unmarshal decodes b into body with the default decoder
func Unmarshal(b []byte, body interface{}) error {
decoder := &Decoder{}
return decoder.UnmarshalHeaderBody(b, nil, body)
}
// UnmarshalHeader parses the Sereal-v2-encoded buffer b and stores the header data into the variable pointed to by vheader
func (d *Decoder) UnmarshalHeader(b []byte, vheader interface{}) (err error) {
return d.UnmarshalHeaderBody(b, vheader, nil)
}
// Unmarshal parses the Sereal-encoded buffer b and stores the result in the value pointed to by vbody
func (d *Decoder) Unmarshal(b []byte, vbody interface{}) (err error) {
return d.UnmarshalHeaderBody(b, nil, vbody)
}
// UnmarshalHeaderBody parses the Sereal-encoded buffer b extracts the header and body data into vheader and vbody, respectively
func (d *Decoder) UnmarshalHeaderBody(b []byte, vheader interface{}, vbody interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
err = errors.New(s)
} else {
err = r.(error)
}
}
}()
header, err := readHeader(b)
if err != nil {
return err
}
bodyStart := headerSize + header.suffixSize
if bodyStart > len(b) || bodyStart < 0 {
return ErrCorrupt{errBadOffset}
}
switch header.version {
case 1:
break
case 2:
break
case 3:
break
default:
return fmt.Errorf("document version '%d' not yet supported", header.version)
}
var decomp decompressor
switch header.doctype {
case serealRaw:
// nothing
case serealSnappy:
if header.version != 1 {
return ErrBadSnappy
}
decomp = SnappyCompressor{Incremental: false}
case serealSnappyIncremental:
decomp = SnappyCompressor{Incremental: true}
case serealZlib:
if header.version < 3 {
return ErrBadZlibV3
}
decomp = ZlibCompressor{}
default:
return fmt.Errorf("document type '%d' not yet supported", header.doctype)
}
if vheader != nil && header.suffixSize != 1 {
d.tracked = make(map[int]reflect.Value)
defer func() { d.tracked = nil }()
headerValue := reflect.ValueOf(vheader)
if headerValue.Kind() != reflect.Ptr {
return ErrHeaderPointer
}
header.suffixFlags = b[header.suffixStart]
if header.suffixFlags&1 == 1 {
if ptr, ok := vheader.(*interface{}); ok && *ptr == nil {
_, err = d.decode(b[:bodyStart], header.suffixStart+1, ptr)
} else {
_, err = d.decodeViaReflection(b[:bodyStart], header.suffixStart+1, headerValue.Elem())
}
}
}
if err == nil && vbody != nil {
/* XXX instead of creating an uncompressed copy of the document,
* it would be more flexible to use a sort of "Reader" interface */
if decomp != nil {
decompBody, err := decomp.decompress(b[bodyStart:])
if err != nil {
return err
}
// shrink down b to reuse the allocated buffer
b = b[:0]
b = append(b, b[:bodyStart]...)
b = append(b, decompBody...)
}
d.tracked = make(map[int]reflect.Value)
defer func() { d.tracked = nil }()
bodyValue := reflect.ValueOf(vbody)
if bodyValue.Kind() != reflect.Ptr {
return ErrBodyPointer
}
if header.version == 1 {
if ptr, ok := vbody.(*interface{}); ok && *ptr == nil {
_, err = d.decode(b, bodyStart, ptr)
} else {
_, err = d.decodeViaReflection(b, bodyStart, bodyValue.Elem())
}
} else {
// serealv2 documents have 1-based offsets :/
if ptr, ok := vbody.(*interface{}); ok && *ptr == nil {
_, err = d.decode(b[bodyStart-1:], 1, ptr)
} else {
_, err = d.decodeViaReflection(b[bodyStart-1:], 1, bodyValue.Elem())
}
}
}
return err
}
/****************************************************************
* Decode document of unknown structure (i.e. without reflection)
****************************************************************/
func (d *Decoder) decode(by []byte, idx int, ptr *interface{}) (int, error) {
if idx < 0 || idx >= len(by) {
return 0, ErrTruncated
}
tag := by[idx]
// skip over any padding bytes
for tag == typePAD || tag == typePAD|trackFlag {
idx++
if idx >= len(by) {
return 0, ErrTruncated
}
tag = by[idx]
}
trackme := (tag & trackFlag) == trackFlag
if trackme {
tag &^= trackFlag
d.tracked[idx] = reflect.ValueOf(ptr)
}
//fmt.Printf("start decode: tag %d (0x%x) at %d\n", int(tag), int(tag), idx)
idx++
var err error
switch {
case tag < typeVARINT:
*ptr = d.decodeInt(tag)
case tag == typeVARINT:
var val int
if val, idx, err = d.decodeVarint(by, idx); err != nil {
return 0, err
}
if val < 0 {
*ptr = uint(val)
} else {
*ptr = val
}
case tag == typeZIGZAG:
*ptr, idx, err = d.decodeZigzag(by, idx)
case tag == typeFLOAT:
*ptr, idx, err = d.decodeFloat(by, idx)
case tag == typeDOUBLE:
*ptr, idx, err = d.decodeDouble(by, idx)
case tag == typeTRUE:
*ptr = true
case tag == typeFALSE:
*ptr = false
case tag == typeHASH:
// see commends at top of the function
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
idx, err = d.decodeHash(by, idx+sz, ln, ptr, false)
case tag >= typeHASHREF_0 && tag < typeHASHREF_0+16:
idx, err = d.decodeHash(by, idx, int(tag&0x0f), ptr, d.PerlCompat)
if err != nil {
return 0, err
}
case tag == typeARRAY:
// see commends at top of the function
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
idx, err = d.decodeArray(by, idx+sz, ln, ptr, false)
if err != nil {
return 0, err
}
case tag >= typeARRAYREF_0 && tag < typeARRAYREF_0+16:
idx, err = d.decodeArray(by, idx, int(tag&0x0f), ptr, d.PerlCompat)
if err != nil {
return 0, err
}
case tag == typeSTR_UTF8:
var val []byte
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
if val, idx, err = d.decodeBinary(by, idx+sz, ln, false); err != nil {
return 0, err
}
*ptr = string(val)
case tag == typeBINARY:
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
*ptr, idx, err = d.decodeBinary(by, idx+sz, ln, true)
if err != nil {
return 0, err
}
case tag >= typeSHORT_BINARY_0 && tag < typeSHORT_BINARY_0+32:
*ptr, idx, err = d.decodeBinary(by, idx, int(tag&0x1f), true)
if err != nil {
return 0, err
}
case tag == typeUNDEF, tag == typeCANONICAL_UNDEF:
if d.PerlCompat && tag == typeCANONICAL_UNDEF {
*ptr = perlCanonicalUndef
} else if d.PerlCompat {
*ptr = &PerlUndef{}
} else {
*ptr = nil
}
case tag == typeCOPY:
if d.copyDepth > 0 {
return 0, ErrCorrupt{errNestedCOPY}
}
var offs, sz int
offs, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
if offs < 0 || offs >= idx {
return 0, ErrCorrupt{errBadOffset}
}
idx += sz
d.copyDepth++
_, err = d.decode(by, offs, ptr)
d.copyDepth--
case tag == typeREFN:
if d.PerlCompat {
var iface interface{}
*ptr = &iface
if idx, err = d.decode(by, idx, &iface); !trackme && err == nil {
// if REFN is not tracked, build a pointer to concrete type
// otherwise let ptr be pointer to something (i.e. *interface{})
riface := reflect.ValueOf(iface)
// reflect.New create value of type *riface.Type())
val := reflect.New(riface.Type())
val.Elem().Set(riface)
*ptr = val.Interface()
}
} else {
idx, err = d.decode(by, idx, ptr)
}
case tag == typeREFP, tag == typeALIAS:
var val reflect.Value
if val, idx, err = d.decodeREFP_ALIAS(by, idx, tag == typeREFP); err == nil {
*ptr = val.Interface()
}
case tag == typeWEAKEN:
if d.PerlCompat {
pweak := PerlWeakRef{}
*ptr = &pweak
idx, err = d.decode(by, idx, &pweak.Reference)
} else {
idx, err = d.decode(by, idx, ptr)
}
case tag == typeREGEXP:
*ptr, idx, err = d.decodeRegexp(by, idx)
case tag == typeOBJECT, tag == typeOBJECTV:
rvPtr := reflect.ValueOf(ptr)
idx, err = d.decodeObjectViaReflection(by, idx, rvPtr.Elem(), tag == typeOBJECTV)
case tag == typeOBJECT_FREEZE, tag == typeOBJECTV_FREEZE:
rvPtr := reflect.ValueOf(ptr)
idx, err = d.decodeObjectFreezeViaReflection(by, idx, rvPtr.Elem(), tag == typeOBJECTV_FREEZE)
default:
return 0, fmt.Errorf("unknown tag byte: %d (0x%x)", int(tag), int(tag))
}
//fmt.Printf("stop decode: tag %d (0x%x)\n", int(tag), int(tag))
return idx, err
}
func (d *Decoder) decodeInt(tag byte) int {
if (tag & 0x10) == 0x10 {
return int(tag) - 32 // negative number
}
return int(tag)
}
func (d *Decoder) decodeVarint(by []byte, idx int) (int, int, error) {
i, sz, err := varintdecode(by[idx:])
return i, idx + sz, err
}
func (d *Decoder) decodeZigzag(by []byte, idx int) (int, int, error) {
i, sz, err := varintdecode(by[idx:])
return int(-(1 + (uint64(i) >> 1))), idx + sz, err
}
func (d *Decoder) decodeFloat(by []byte, idx int) (float32, int, error) {
if idx+3 >= len(by) {
return 0, 0, ErrTruncated
}
bits := uint32(by[idx]) | uint32(by[idx+1])<<8 | uint32(by[idx+2])<<16 | uint32(by[idx+3])<<24
return math.Float32frombits(bits), idx + 4, nil
}
func (d *Decoder) decodeDouble(by []byte, idx int) (float64, int, error) {
if idx+7 >= len(by) {
return 0, 0, ErrTruncated
}
bits := uint64(by[idx]) | uint64(by[idx+1])<<8 | uint64(by[idx+2])<<16 | uint64(by[idx+3])<<24 | uint64(by[idx+4])<<32 | uint64(by[idx+5])<<40 | uint64(by[idx+6])<<48 | uint64(by[idx+7])<<56
return math.Float64frombits(bits), idx + 8, nil
}
func (d *Decoder) decodeHash(by []byte, idx int, ln int, ptr *interface{}, isRef bool) (int, error) {
if ln < 0 || ln > math.MaxInt32 {
return 0, ErrCorrupt{errBadHashSize}
}
if idx+2*ln > len(by) {
return 0, ErrTruncated
}
hash := make(map[string]interface{}, ln)
if isRef {
*ptr = &hash
} else {
*ptr = hash
}
var err error
for i := 0; i < ln; i++ {
var key []byte
key, idx, err = d.decodeStringish(by, idx)
if err != nil {
return 0, err
}
var value interface{}
idx, err = d.decode(by, idx, &value)
if err != nil {
return 0, err
}
hash[string(key)] = value
}
return idx, nil
}
func (d *Decoder) decodeArray(by []byte, idx int, ln int, ptr *interface{}, isRef bool) (int, error) {
if ln < 0 || ln > math.MaxInt32 {
return 0, ErrCorrupt{errBadSliceSize}
}
if idx+ln > len(by) {
return 0, ErrTruncated
}
var slice []interface{}
if ln == 0 {
// FIXME this is not optimal
slice = make([]interface{}, 0, 1)
} else {
slice = make([]interface{}, ln, ln)
}
if isRef {
*ptr = &slice
} else {
*ptr = slice
}
var err error
for i := 0; i < ln; i++ {
idx, err = d.decode(by, idx, &slice[i])
if err != nil {
return 0, err
}
}
return idx, nil
}
func (d *Decoder) decodeBinary(by []byte, idx int, ln int, makeCopy bool) ([]byte, int, error) {
if ln < 0 || ln > math.MaxInt32 {
return nil, 0, ErrCorrupt{errBadStringSize}
}
if idx+ln > len(by) {
return nil, 0, ErrTruncated
}
if makeCopy {
res := make([]byte, ln, ln)
copy(res, by[idx:idx+ln])
return res, idx + ln, nil
}
return by[idx : idx+ln], idx + ln, nil
}
// decodeStringish() return slice of by, i.e. not a copy
func (d *Decoder) decodeStringish(by []byte, idx int) ([]byte, int, error) {
if idx < 0 || idx >= len(by) {
return nil, 0, ErrTruncated
}
//TODO trackme
tag := by[idx]
for tag == typePAD || tag == typePAD|trackFlag {
idx++
if idx >= len(by) {
return nil, 0, ErrTruncated
}
tag = by[idx]
}
tag &^= trackFlag
idx++
//fmt.Printf("decodeStringish: tag %d (0x%x) at %d\n", int(tag), int(tag), idx)
var res []byte
switch {
case tag == typeBINARY, tag == typeSTR_UTF8:
ln, sz, err := varintdecode(by[idx:])
if err != nil {
return nil, 0, err
}
idx += sz
if ln < 0 || ln > math.MaxInt32 {
return nil, 0, ErrCorrupt{errBadStringSize}
} else if idx+ln > len(by) {
return nil, 0, ErrTruncated
}
res = by[idx : idx+ln]
idx += ln
case tag >= typeSHORT_BINARY_0 && tag < typeSHORT_BINARY_0+32:
ln := int(tag & 0x1F) // get length from tag
if idx+ln > len(by) {
return nil, 0, ErrTruncated
}
res = by[idx : idx+ln]
idx += ln
case tag == typeCOPY:
if d.copyDepth > 0 {
return nil, 0, ErrCorrupt{errNestedCOPY}
}
offs, sz, err := varintdecode(by[idx:])
if err != nil {
return nil, 0, err
}
if offs < 0 || offs >= idx {
return nil, 0, ErrCorrupt{errBadOffset}
}
idx += sz
d.copyDepth++
res, _, err = d.decodeStringish(by, offs)
d.copyDepth--
if err != nil {
return nil, 0, err
}
default:
return nil, 0, fmt.Errorf("expect stringish at offset %d but got %d (0x%x)", idx, int(tag), int(tag))
}
//fmt.Printf("decodeStringish res: %s at %d\n", string(res), idx)
return res, idx, nil
}
func (d *Decoder) decodeRegexp(by []byte, idx int) (*PerlRegexp, int, error) {
var err error
var pattern []byte
if pattern, idx, err = d.decodeStringish(by, idx); err != nil {
return nil, 0, err
}
var modifiers []byte
if modifiers, idx, err = d.decodeStringish(by, idx); err != nil {
return nil, 0, err
}
// TODO perhaps, copy values
return &PerlRegexp{pattern, modifiers}, idx, nil
}
/********************************************************************
* Decode document with predefined structure (have to use reflection)
********************************************************************/
func (d *Decoder) decodeViaReflection(by []byte, idx int, ptr reflect.Value) (int, error) {
if idx < 0 || idx >= len(by) {
return 0, ErrTruncated
}
ptrKind := ptr.Kind()
// at this point structure of decoding document is uknown, make a shortcut
if ptrKind == reflect.Interface && ptr.IsNil() {
var iface interface{}
var err error
idx, err = d.decode(by, idx, &iface)
ptr.Set(reflect.ValueOf(iface))
return idx, err
}
tag := by[idx]
for tag == typePAD || tag == typePAD|trackFlag {
idx++
if idx >= len(by) {
return 0, ErrTruncated
}
tag = by[idx]
}
if (tag & trackFlag) == trackFlag {
tag &^= trackFlag
d.tracked[idx] = ptr
}
//fmt.Printf("start decodeViaReflection: tag %d (0x%x) at %d\n", int(tag), int(tag), idx)
idx++
var err error
switch {
case tag < typeVARINT:
setInt(ptr, d.decodeInt(tag))
case tag == typeVARINT:
var val int
val, idx, err = d.decodeVarint(by, idx)
if err != nil {
return 0, err
}
setInt(ptr, val)
case tag == typeZIGZAG:
var val int
val, idx, err = d.decodeZigzag(by, idx)
if err != nil {
return 0, err
}
setInt(ptr, val)
case tag == typeFLOAT:
var val float32
if val, idx, err = d.decodeFloat(by, idx); err != nil {
return 0, err
}
ptr.SetFloat(float64(val))
case tag == typeDOUBLE:
var val float64
if val, idx, err = d.decodeDouble(by, idx); err != nil {
return 0, err
}
ptr.SetFloat(float64(val))
case tag == typeTRUE, tag == typeFALSE:
ptr.SetBool(tag == typeTRUE)
case tag == typeBINARY:
var val []byte
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
if val, idx, err = d.decodeBinary(by, idx+sz, ln, false); err != nil {
return 0, err
}
setBinary(ptr, val)
case tag >= typeSHORT_BINARY_0 && tag < typeSHORT_BINARY_0+32:
var val []byte
if val, idx, err = d.decodeBinary(by, idx, int(tag&0x1f), false); err != nil {
return 0, err
}
setBinary(ptr, val)
case tag == typeSTR_UTF8:
var val []byte
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
if val, idx, err = d.decodeBinary(by, idx+sz, ln, false); err != nil {
return 0, err
}
ptr.SetString(string(val))
case tag == typeHASH:
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
idx, err = d.decodeHashViaReflection(by, idx+sz, ln, ptr)
case tag >= typeHASHREF_0 && tag < typeHASHREF_0+16:
idx, err = d.decodeHashViaReflection(by, idx, int(tag&0x0f), ptr)
case tag == typeARRAY:
var ln, sz int
ln, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
idx, err = d.decodeArrayViaReflection(by, idx+sz, ln, ptr)
case tag >= typeARRAYREF_0 && tag < typeARRAYREF_0+16:
idx, err = d.decodeArrayViaReflection(by, idx, int(tag&0x0f), ptr)
case tag == typeUNDEF, tag == typeCANONICAL_UNDEF:
if d.PerlCompat && tag == typeCANONICAL_UNDEF {
ptr.Set(reflect.ValueOf(perlCanonicalUndef))
} else if d.PerlCompat {
ptr.Set(reflect.ValueOf(&PerlUndef{}))
} else {
if ptrKind == reflect.Ptr || ptrKind == reflect.Map || ptrKind == reflect.Slice {
ptr.Set(reflect.Zero(ptr.Type()))
} else {
// maybe panic
}
}
case tag == typeCOPY:
if d.copyDepth > 0 {
return 0, ErrCorrupt{errNestedCOPY}
}
var offs, sz int
offs, sz, err = varintdecode(by[idx:])
if err != nil {
return 0, err
}
if offs < 0 || offs >= idx {
return 0, ErrCorrupt{errBadOffset}
}
idx += sz
d.copyDepth++
_, err = d.decodeViaReflection(by, offs, ptr)
d.copyDepth--
case tag == typeREFN:
idx, err = d.decodeViaReflection(by, idx, ptr)
case tag == typeREFP, tag == typeALIAS:
var val reflect.Value
if val, idx, err = d.decodeREFP_ALIAS(by, idx, tag == typeREFP); err != nil {
return 0, err
}
ptr.Set(val.Elem())
case tag == typeWEAKEN:
if d.PerlCompat {
pweak := PerlWeakRef{}
ptr.Set(reflect.ValueOf(&pweak))
idx, err = d.decode(by, idx, &pweak.Reference)
} else {
idx, err = d.decodeViaReflection(by, idx, ptr)
}
case tag == typeREGEXP:
var pregexp *PerlRegexp
if pregexp, idx, err = d.decodeRegexp(by, idx); err != nil {
return 0, err
}
ptr.Set(reflect.ValueOf(pregexp))
case tag == typeOBJECT, tag == typeOBJECTV:
idx, err = d.decodeObjectViaReflection(by, idx, ptr, tag == typeOBJECTV)
case tag == typeOBJECT_FREEZE, tag == typeOBJECTV_FREEZE:
idx, err = d.decodeObjectFreezeViaReflection(by, idx, ptr, tag == typeOBJECTV_FREEZE)
default:
return 0, fmt.Errorf("unknown tag byte: %d (0x%x)", int(tag), int(tag))
}
return idx, err
}
func (d *Decoder) decodeArrayViaReflection(by []byte, idx int, ln int, ptr reflect.Value) (int, error) {
if ln < 0 || ln > math.MaxInt32 {
return 0, ErrCorrupt{errBadSliceSize}
}
if idx+ln > len(by) {
return 0, ErrTruncated
}
switch ptr.Kind() {
case reflect.Slice:
if ptr.IsNil() || ptr.Len() == 0 {
ptr.Set(reflect.MakeSlice(ptr.Type(), ln, ln))
}
case reflect.Array:
// do nothing
default:
panic(&reflect.ValueError{Method: "sereal.decodeArrayViaReflection", Kind: ptr.Kind()})
}
var err error
ptrLen := ptr.Len()
for i := 0; i < ln; i++ {
if i < ptrLen {
idx, err = d.decodeViaReflection(by, idx, ptr.Index(i))
} else {
// we went outside of array length, so ignore folowwing content
var iface interface{}
idx, err = d.decode(by, idx, &iface) // TODO make this process to be efficient
}
if err != nil {
return 0, err
}
}
return idx, nil
}
func (d *Decoder) decodeHashViaReflection(by []byte, idx int, ln int, ptr reflect.Value) (int, error) {
if ln < 0 || ln > math.MaxInt32 {
return 0, ErrCorrupt{errBadHashSize}
}
if idx+2*ln > len(by) {
return 0, ErrTruncated
}
switch ptr.Kind() {
case reflect.Map:
if ptr.IsNil() {
ptr.Set(reflect.MakeMap(ptr.Type()))
}
var err error
for i := 0; i < ln; i++ {
var key []byte
key, idx, err = d.decodeStringish(by, idx)
if err != nil {
return 0, err
}
keyValue := reflect.ValueOf(string(key))
if value := ptr.MapIndex(keyValue); value.IsValid() {
// strkey exists in map, replace its content but respect structure
idx, err = d.decodeViaReflection(by, idx, value)
} else {
// there is no strkey in map, crete a new one
riface := reflect.New(ptr.Type().Elem())
idx, err = d.decodeViaReflection(by, idx, riface.Elem())
if err != nil {
return 0, err
}
ptr.SetMapIndex(keyValue, riface.Elem())
}
if err != nil {
return 0, err
}
}
case reflect.Ptr:
if ptr.IsNil() {
n := reflect.New(ptr.Type().Elem())
ptr.Set(n)
}
return d.decodeHashViaReflection(by, idx, ln, ptr.Elem())
case reflect.Struct:
tags := d.tcache.Get(ptr)
var err error
for i := 0; i < ln; i++ {
var key []byte
key, idx, err = d.decodeStringish(by, idx)
if err != nil {
return 0, err
}
fld := 0
found := false
strkey := string(key)
if tags == nil {
// do nothing
} else if fld, found = tags[strkey]; found {
idx, err = d.decodeViaReflection(by, idx, ptr.Field(fld))
} else if fld, found = tags[strings.Title(strkey)]; found {
idx, err = d.decodeViaReflection(by, idx, ptr.Field(fld))
}
if !found {
// struct doesn't contain field with strkey name
var iface interface{}
idx, err = d.decode(by, idx, &iface) // TODO make this process to be efficient
}
if err != nil {
return 0, err
}
}
default:
panic(&reflect.ValueError{Method: "sereal.decodeHashViaReflection", Kind: ptr.Kind()})
}
return idx, nil
}
func (d *Decoder) decodeREFP_ALIAS(by []byte, idx int, isREFP bool) (reflect.Value, int, error) {
offs, sz, err := varintdecode(by[idx:])
if err != nil {
var res reflect.Value
return res, 0, err
}
idx += sz
if offs < 0 || offs >= idx {
var res reflect.Value
return res, 0, ErrCorrupt{errBadOffset}
}
rv, ok := d.tracked[offs]
if !ok {
var res reflect.Value
return res, 0, ErrCorrupt{errUntrackedOffsetREFP}
}
var res reflect.Value
if rv.Kind() == reflect.Ptr && rv.Elem().Kind() == reflect.Interface {
// rv contains *interface{},
// i.e. it was saved in decode() path
// rv.Elem() will be an interface
// rv.Elem().Elem() should be the data inside interface
if isREFP {
rvData := rv.Elem().Elem()
res = reflect.New(rvData.Type())
res.Elem().Set(rvData)
} else {
res = rv.Elem()
}
} else {
// rv contains original value
// i.e. it was saved in decodeViaReflection() path
res = reflect.New(rv.Type())
res.Elem().Set(rv)
}
return res, idx, nil
}
func (d *Decoder) decodeObjectViaReflection(by []byte, idx int, ptr reflect.Value, isObjectV bool) (int, error) {
var err error
var className []byte
if !isObjectV {
// typeOBJECT
className, idx, err = d.decodeStringish(by, idx)
} else {
// typeOBJECTV
offs, sz, err := varintdecode(by[idx:])
if err != nil {
return 0, err
}
idx += sz
className, _, err = d.decodeStringish(by, offs)
}
if err != nil {
return 0, err
}
if d.PerlCompat {
pobj := PerlObject{Class: string(className)}
ptr.Set(reflect.ValueOf(&pobj))
idx, err = d.decode(by, idx, &pobj.Reference)
} else {
// FIXME: stuff className somewhere if map/struct?
idx, err = d.decodeViaReflection(by, idx, ptr)
}
return idx, err
}
func (d *Decoder) decodeObjectFreezeViaReflection(by []byte, idx int, ptr reflect.Value, isObjectV bool) (int, error) {
var err error
var className, classData []byte
if !isObjectV {
// typeOBJECT_FREEZE
className, idx, err = d.decodeStringish(by, idx)
if err != nil {
return 0, err
}
} else {
// typeOBJECTV_FREEZE
offs, sz, err := varintdecode(by[idx:])
if err != nil {
return 0, err
}
idx += sz
className, _, err = d.decodeStringish(by, offs)
if err != nil {
return 0, err
}
}
if err != nil {
return 0, err
}
if idx+1 >= len(by) {
return 0, ErrTruncated
}
if by[idx] != typeREFN || by[idx+1] != typeARRAY {
return 0, ErrCorrupt{errFreezeNotRefnArray}
}
var iface interface{}
if idx, err = d.decode(by, idx, &iface); err != nil {
return 0, err
}
wrapper, ok := iface.([]interface{})
if !ok {
return 0, ErrCorrupt{errFreezeNotArray}
}
if len(wrapper) != 1 {
return 0, ErrCorrupt{errFreezeMultipleElts}
}
// Expecting a single item in the array ref
if classData, ok = wrapper[0].([]byte); !ok {
return 0, ErrCorrupt{errFreezeNotByteSlice}
}
strClassName := string(className)
if d.PerlCompat {
ptr.Set(reflect.ValueOf(&PerlFreeze{strClassName, classData}))
} else {
if obj, ok := findUnmarshaler(ptr); ok {
if err := obj.UnmarshalBinary(classData); err != nil {
return 0, err
}
} else {
switch {
case ptr.Kind() == reflect.Interface && ptr.IsNil():
// do we have a registered handler for this type?
concreteClass, ok := d.getUnmarshalerType(strClassName)
if ok {
rzero := instantiateZero(concreteClass)
obj, ok := findUnmarshaler(rzero)
if !ok {
// only things that have an unmarshaler should have been put into the map
panic(fmt.Sprintf("unable to find unmarshaler for %s", rzero))
}
if err := obj.UnmarshalBinary(classData); err != nil {
return 0, err
}
ptr.Set(reflect.ValueOf(obj))
} else {
ptr.Set(reflect.ValueOf(&PerlFreeze{strClassName, classData}))
}
case ptr.Kind() == reflect.Slice && ptr.Type().Elem().Kind() == reflect.Uint8 && ptr.IsNil():
ptr.Set(reflect.ValueOf(classData))
default:
return 0, fmt.Errorf("can't unpack FROZEN object into %v", ptr.Type())
}
}
}
return idx, err
}
func setInt(ptr reflect.Value, i int) {
switch ptr.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
ptr.SetInt(int64(i))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
ptr.SetUint(uint64(i))
default:
panic(&reflect.ValueError{Method: "sereal.setInt", Kind: ptr.Kind()})
}
}
func setBinary(ptr reflect.Value, val []byte) {
switch ptr.Kind() {
case reflect.Slice:
if ptr.Type().Elem().Kind() == reflect.Uint8 && ptr.IsNil() {
slice := make([]byte, len(val), len(val))
ptr.Set(reflect.ValueOf(slice))
}
reflect.Copy(ptr, reflect.ValueOf(val))
case reflect.Array:
reflect.Copy(ptr.Slice(0, ptr.Len()), reflect.ValueOf(val))
case reflect.String:
ptr.SetString(string(val))
default:
panic(&reflect.ValueError{Method: "sereal.setBinary", Kind: ptr.Kind()})
}
}
func varintdecode(by []byte) (n int, sz int, err error) {
s := uint(0) // shift count
for i, b := range by {
n |= int(b&0x7f) << s
s += 7
if (b & 0x80) == 0 {
return n, i + 1, nil
}
if s > 63 {
// too many continuation bits
return 0, i + 1, ErrCorrupt{errBadVarint}
}
}
// byte without continuation bit
return 0, len(by), ErrCorrupt{errBadVarint}
}
func findUnmarshaler(ptr reflect.Value) (encoding.BinaryUnmarshaler, bool) {
if ptr.Kind() == reflect.Ptr && ptr.IsNil() {
p := reflect.New(ptr.Type().Elem())
ptr.Set(p)
}
if obj, ok := ptr.Interface().(encoding.BinaryUnmarshaler); ok {
return obj, true
}
pptr := ptr.Addr()
if obj, ok := pptr.Interface().(encoding.BinaryUnmarshaler); ok {
return obj, true
}
return nil, false
}
// RegisterName registers the named class with an instance of 'value'. When the
// decoder finds a FREEZE tag with the given class, the binary data will be
// passed to value's UnmarshalBinary method.
func (d *Decoder) RegisterName(name string, value interface{}) {
if d.umcache == nil {
d.umcache = make(map[string]reflect.Type)
}
rv := reflect.ValueOf(value)
if _, ok := value.(encoding.BinaryUnmarshaler); ok {
d.umcache[name] = rv.Type()
return
}
prv := rv.Addr()
if _, ok := prv.Interface().(encoding.BinaryUnmarshaler); ok {
d.umcache[name] = prv.Type()
return
}
panic(fmt.Sprintf("unable to register type %s: not encoding.BinaryUnmarshaler", rv.Type()))
}
func (d *Decoder) getUnmarshalerType(name string) (reflect.Type, bool) {
if d.umcache == nil {
return nil, false
}
val, ok := d.umcache[name]
return val, ok
}
func instantiateZero(typ reflect.Type) reflect.Value {
if typ.Kind() == reflect.Ptr {
return reflect.New(typ.Elem())
}
v := reflect.New(typ)
return v.Addr()
}