package sereal import ( "encoding" "encoding/binary" "errors" "fmt" "math" "reflect" "runtime" "unsafe" ) // An Encoder encodes Go data structures into Sereal byte streams type Encoder struct { PerlCompat bool // try to mimic Perl's structure as much as possible Compression compressor // optionally compress the main payload of the document using SnappyCompressor or ZlibCompressor CompressionThreshold int // threshold in bytes above which compression is attempted: 1024 bytes by default DisableDedup bool // should we disable deduping of class names and hash keys DisableFREEZE bool // should we disable the FREEZE tag, which calls MarshalBinary ExpectedSize uint // give a hint to encoder about expected size of encoded data version int // default version to encode tcache tagsCache } type compressor interface { compress(b []byte) ([]byte, error) } // NewEncoder returns a new Encoder struct with default values func NewEncoder() *Encoder { return &Encoder{ PerlCompat: false, CompressionThreshold: 1024, version: 1, } } // NewEncoderV2 returns a new Encoder that encodes version 2 func NewEncoderV2() *Encoder { return &Encoder{ PerlCompat: false, CompressionThreshold: 1024, version: 2, } } // NewEncoderV3 returns a new Encoder that encodes version 3 func NewEncoderV3() *Encoder { return &Encoder{ PerlCompat: false, CompressionThreshold: 1024, version: 3, } } var defaultEncoder = NewEncoderV3() // Marshal encodes body with the default encoder func Marshal(body interface{}) ([]byte, error) { return defaultEncoder.MarshalWithHeader(nil, body) } // Marshal returns the Sereal encoding of body func (e *Encoder) Marshal(body interface{}) (b []byte, err error) { return e.MarshalWithHeader(nil, body) } // MarshalWithHeader returns the Sereal encoding of body with header data func (e *Encoder) MarshalWithHeader(header interface{}, body interface{}) (b []byte, err error) { defer func() { //return if r := recover(); r != nil { if _, ok := r.(runtime.Error); ok { panic(r) } if s, ok := 
r.(string); ok { err = errors.New(s) } else { err = r.(error) } } }() // uninitialized encoder? set to the most recent supported protocol version if e.version == 0 { e.version = ProtocolVersion } encHeader := make([]byte, headerSize, 32) if e.version < 3 { binary.LittleEndian.PutUint32(encHeader[:4], magicHeaderBytes) } else { binary.LittleEndian.PutUint32(encHeader[:4], magicHeaderBytesHighBit) } // Set the component in the header encHeader[4] = byte(e.version) | byte(serealRaw)<<4 if header != nil && e.version >= 2 { strTable := make(map[string]int) ptrTable := make(map[uintptr]int) // this is both the flag byte (== "there is user data") and also a hack to make 1-based offsets work henv := []byte{0x01} // flag byte == "there is user data" encHeaderSuffix, err := e.encode(henv, header, false, false, strTable, ptrTable) if err != nil { return nil, err } encHeader = varint(encHeader, uint(len(encHeaderSuffix))) encHeader = append(encHeader, encHeaderSuffix...) } else { /* header size */ encHeader = append(encHeader, 0) } strTable := make(map[string]int) ptrTable := make(map[uintptr]int) encBody := make([]byte, 0, e.ExpectedSize) switch e.version { case 1: encBody, err = e.encode(encBody, body, false, false, strTable, ptrTable) case 2, 3: encBody = append(encBody, 0) // hack for 1-based offsets encBody, err = e.encode(encBody, body, false, false, strTable, ptrTable) encBody = encBody[1:] // trim hacky first byte } if err != nil { return nil, err } if e.Compression != nil && (e.CompressionThreshold == 0 || len(encBody) >= e.CompressionThreshold) { encBody, err = e.Compression.compress(encBody) if err != nil { return nil, err } var doctype documentType switch c := e.Compression.(type) { case SnappyCompressor: if e.version > 1 && !c.Incremental { return nil, errors.New("non-incremental snappy compression only valid for v1 documents") } if e.version == 1 { doctype = serealSnappy } else { doctype = serealSnappyIncremental } case ZlibCompressor: if e.version < 3 { return 
nil, errors.New("zlib compression only valid for v3 documents and up") } doctype = serealZlib default: // Defensive programming: this point should never be // reached in production code because the compressor // interface is not exported, hence no way to pass in // an unknown thing. But it may happen during // development when a new compressor is implemented, // but a relevant document type is not defined. panic("undefined compression") } encHeader[4] |= byte(doctype) << 4 } return append(encHeader, encBody...), nil } /************************************* * Encode via static types - fast path *************************************/ func (e *Encoder) encode(b []byte, v interface{}, isKeyOrClass bool, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { var err error switch value := v.(type) { case nil: b = append(b, typeUNDEF) case bool: if value { b = append(b, typeTRUE) } else { b = append(b, typeFALSE) } case int: b = e.encodeInt(b, reflect.Int, int64(value)) case int8: b = e.encodeInt(b, reflect.Int, int64(value)) case int16: b = e.encodeInt(b, reflect.Int, int64(value)) case int32: b = e.encodeInt(b, reflect.Int, int64(value)) case int64: b = e.encodeInt(b, reflect.Int, int64(value)) case uint: b = e.encodeInt(b, reflect.Uint, int64(value)) case uint8: b = e.encodeInt(b, reflect.Uint, int64(value)) case uint16: b = e.encodeInt(b, reflect.Uint, int64(value)) case uint32: b = e.encodeInt(b, reflect.Uint, int64(value)) case uint64: b = e.encodeInt(b, reflect.Uint, int64(value)) case float32: b = e.encodeFloat(b, value) case float64: b = e.encodeDouble(b, value) case string: b = e.encodeString(b, value, isKeyOrClass, strTable) case []uint8: b = e.encodeBytes(b, value, isKeyOrClass, strTable) case []interface{}: b, err = e.encodeIntfArray(b, value, isRefNext, strTable, ptrTable) case map[string]interface{}: b, err = e.encodeStrMap(b, value, isRefNext, strTable, ptrTable) case reflect.Value: if value.Kind() == reflect.Invalid { b = 
append(b, typeUNDEF) } else { // could be optimized to tail call b, err = e.encode(b, value.Interface(), false, isRefNext, strTable, ptrTable) } case PerlUndef: if value.canonical { b = append(b, typeCANONICAL_UNDEF) } else { b = append(b, typeUNDEF) } case PerlObject: b = append(b, typeOBJECT) b = e.encodeBytes(b, []byte(value.Class), true, strTable) b, err = e.encode(b, value.Reference, false, false, strTable, ptrTable) case PerlRegexp: b = append(b, typeREGEXP) b = e.encodeBytes(b, value.Pattern, false, strTable) b = e.encodeBytes(b, value.Modifiers, false, strTable) case PerlWeakRef: b = append(b, typeWEAKEN) b, err = e.encode(b, value.Reference, false, false, strTable, ptrTable) //case *interface{}: //TODO handle here if easy //case interface{}: // http://blog.golang.org/laws-of-reflection // One important detail is that the pair inside an interface always has the form (value, concrete type) // and cannot have the form (value, interface type). Interfaces do not hold interface values. 
//panic("interface cannot hold an interface") // ikruglov // in theory this block should no be commented, // but in practise type *interface{} somehow manages to match interface{} // if one manages to properly implement *interface{} case, this block should be uncommented default: b, err = e.encodeViaReflection(b, reflect.ValueOf(value), isKeyOrClass, isRefNext, strTable, ptrTable) } return b, err } func (e *Encoder) encodeInt(by []byte, k reflect.Kind, i int64) []byte { switch { case 0 <= i && i <= 15: by = append(by, byte(i)&0x0f) case -16 <= i && i < 0 && k == reflect.Int: by = append(by, 0x010|(byte(i)&0x0f)) case i > 15: by = append(by, typeVARINT) by = varint(by, uint(i)) case i < 0: n := uint(i) if k == reflect.Int { by = append(by, typeZIGZAG) n = uint((i << 1) ^ (i >> 63)) } else { by = append(by, typeVARINT) } by = varint(by, uint(n)) } return by } func (e *Encoder) encodeFloat(by []byte, f float32) []byte { u := math.Float32bits(f) by = append(by, typeFLOAT) by = append(by, byte(u)) by = append(by, byte(u>>8)) by = append(by, byte(u>>16)) by = append(by, byte(u>>24)) return by } func (e *Encoder) encodeDouble(by []byte, f float64) []byte { u := math.Float64bits(f) by = append(by, typeDOUBLE) by = append(by, byte(u)) by = append(by, byte(u>>8)) by = append(by, byte(u>>16)) by = append(by, byte(u>>24)) by = append(by, byte(u>>32)) by = append(by, byte(u>>40)) by = append(by, byte(u>>48)) by = append(by, byte(u>>56)) return by } func (e *Encoder) encodeString(by []byte, s string, isKeyOrClass bool, strTable map[string]int) []byte { if !e.DisableDedup && isKeyOrClass { if copyOffs, ok := strTable[s]; ok { by = append(by, typeCOPY) by = varint(by, uint(copyOffs)) return by } strTable[s] = len(by) } by = append(by, typeSTR_UTF8) by = varint(by, uint(len(s))) return append(by, s...) 
} func (e *Encoder) encodeBytes(by []byte, byt []byte, isKeyOrClass bool, strTable map[string]int) []byte { if !e.DisableDedup && isKeyOrClass { if copyOffs, ok := strTable[string(byt)]; ok { by = append(by, typeCOPY) by = varint(by, uint(copyOffs)) return by } // save for later strTable[string(byt)] = len(by) } if l := len(byt); l < 32 { by = append(by, typeSHORT_BINARY_0+byte(l)) } else { by = append(by, typeBINARY) by = varint(by, uint(l)) } return append(by, byt...) } func (e *Encoder) encodeIntfArray(by []byte, arr []interface{}, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { if e.PerlCompat && !isRefNext { by = append(by, typeREFN) } // TODO implement ARRAYREF for small arrays l := len(arr) by = append(by, typeARRAY) by = varint(by, uint(l)) var err error for i := 0; i < l; i++ { if by, err = e.encode(by, arr[i], false, false, strTable, ptrTable); err != nil { return nil, err } } return by, nil } func (e *Encoder) encodeStrMap(by []byte, m map[string]interface{}, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { if e.PerlCompat && !isRefNext { by = append(by, typeREFN) } // TODO implement HASHREF for small maps by = append(by, typeHASH) by = varint(by, uint(len(m))) var err error for k, v := range m { by = e.encodeString(by, k, true, strTable) if by, err = e.encode(by, v, false, false, strTable, ptrTable); err != nil { return by, err } } return by, nil } /************************************* * Encode via reflection *************************************/ func (e *Encoder) encodeViaReflection(b []byte, rv reflect.Value, isKeyOrClass bool, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { if !e.DisableFREEZE && rv.Kind() != reflect.Invalid && rv.Kind() != reflect.Ptr { if m, ok := rv.Interface().(encoding.BinaryMarshaler); ok { by, err := m.MarshalBinary() if err != nil { return nil, err } b = append(b, typeOBJECT_FREEZE) b = e.encodeString(b, 
concreteName(rv), true, strTable) b = append(b, typeREFN) b = append(b, typeARRAY) b = varint(b, uint(1)) return e.encode(b, reflect.ValueOf(by), false, false, strTable, ptrTable) } } // make sure we're looking at a real type and not an interface for rv.Kind() == reflect.Interface { rv = rv.Elem() } var err error switch rk := rv.Kind(); rk { case reflect.Slice: // uint8 case is handled in encode() fallthrough case reflect.Array: b, err = e.encodeArray(b, rv, isRefNext, strTable, ptrTable) case reflect.Map: b, err = e.encodeMap(b, rv, isRefNext, strTable, ptrTable) case reflect.Struct: b, err = e.encodeStruct(b, rv, strTable, ptrTable) case reflect.Ptr: b, err = e.encodePointer(b, rv, strTable, ptrTable) default: panic(fmt.Sprintf("no support for type '%s' (%s)", rk.String(), rv.Type())) } return b, err } func (e *Encoder) encodeArray(by []byte, arr reflect.Value, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { if e.PerlCompat && !isRefNext { by = append(by, typeREFN) } l := arr.Len() by = append(by, typeARRAY) by = varint(by, uint(l)) var err error for i := 0; i < l; i++ { if by, err = e.encode(by, arr.Index(i), false, false, strTable, ptrTable); err != nil { return nil, err } } return by, nil } func (e *Encoder) encodeMap(by []byte, m reflect.Value, isRefNext bool, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { if e.PerlCompat && !isRefNext { by = append(by, typeREFN) } keys := m.MapKeys() by = append(by, typeHASH) by = varint(by, uint(len(keys))) if e.PerlCompat { var err error for _, k := range keys { by = e.encodeString(by, k.String(), true, strTable) if by, err = e.encode(by, m.MapIndex(k), false, false, strTable, ptrTable); err != nil { return by, err } } } else { var err error for _, k := range keys { if by, err = e.encode(by, k, true, false, strTable, ptrTable); err != nil { return by, err } if by, err = e.encode(by, m.MapIndex(k), false, false, strTable, ptrTable); err != nil { return by, err } } } 
return by, nil } func (e *Encoder) encodeStruct(by []byte, st reflect.Value, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { tags := e.tcache.Get(st) by = append(by, typeOBJECT) by = e.encodeBytes(by, []byte(st.Type().Name()), true, strTable) if e.PerlCompat { // must be a reference by = append(by, typeREFN) } by = append(by, typeHASH) by = varint(by, uint(len(tags))) var err error for f, i := range tags { by = e.encodeString(by, f, true, strTable) if by, err = e.encode(by, st.Field(i), false, false, strTable, ptrTable); err != nil { return nil, err } } return by, nil } func (e *Encoder) encodePointer(by []byte, rv reflect.Value, strTable map[string]int, ptrTable map[uintptr]int) ([]byte, error) { // ikruglov // I don't fully understand this logic, so leave it as is :-) if rv.Elem().Kind() == reflect.Struct { switch rv.Elem().Interface().(type) { case PerlRegexp: return e.encode(by, rv.Elem(), false, false, strTable, ptrTable) case PerlUndef: return e.encode(by, rv.Elem(), false, false, strTable, ptrTable) case PerlObject: return e.encode(by, rv.Elem(), false, false, strTable, ptrTable) case PerlWeakRef: return e.encode(by, rv.Elem(), false, false, strTable, ptrTable) } } rvptr := rv.Pointer() rvptr2 := getPointer(rv.Elem()) offs, ok := ptrTable[rvptr] if !ok && rvptr2 != 0 { offs, ok = ptrTable[rvptr2] if ok { rvptr = rvptr2 } } if ok { // seen this before by = append(by, typeREFP) by = varint(by, uint(offs)) by[offs] |= trackFlag // original offset now tracked } else { lenbOrig := len(by) by = append(by, typeREFN) if rvptr != 0 { ptrTable[rvptr] = lenbOrig } var err error by, err = e.encode(by, rv.Elem(), false, true, strTable, ptrTable) if err != nil { return nil, err } if rvptr2 != 0 { // The thing this this points to starts one after the current pointer ptrTable[rvptr2] = lenbOrig + 1 } } return by, nil } func varint(by []byte, n uint) []uint8 { for n >= 0x80 { b := byte(n) | 0x80 by = append(by, b) n >>= 7 } return append(by, byte(n)) } 
// getPointer returns an address identifying the data behind rv, or 0 when rv
// has no such address.
//
// Maps, slices and pointers report rv.Pointer(). Interfaces are unwrapped and
// the contained value is inspected instead. Addressable strings report the
// data pointer stored in their string header. Every other kind, and any
// string that is not addressable, yields 0.
func getPointer(rv reflect.Value) uintptr {
	switch rv.Kind() {
	case reflect.Map, reflect.Slice, reflect.Ptr:
		return rv.Pointer()

	case reflect.Interface:
		// FIXME: still needed?
		return getPointer(rv.Elem())

	case reflect.String:
		// UnsafeAddr panics on values that are not addressable, which is
		// what Elem() of an interface value produces; report "no address"
		// for those instead of panicking.
		if !rv.CanAddr() {
			return 0
		}
		ps := (*reflect.StringHeader)(unsafe.Pointer(rv.UnsafeAddr()))
		return ps.Data
	}

	return 0
}

// concreteName returns the package path and type name of value's type, joined
// by a dot.
func concreteName(value reflect.Value) string {
	return value.Type().PkgPath() + "." + value.Type().Name()
}