package sereal import ( "encoding/binary" "errors" "fmt" "math" "sort" ) type topLevelElementType int // Top-level data structures for merged documents const ( TopLevelArray topLevelElementType = iota TopLevelArrayRef // TopLevelHash // TopLevelHashRef ) const hashKeysValuesFlag = uint32(1 << 31) // Merger merges multiple sereal documents without reconstructing them type Merger struct { buf []byte strTable map[string]int objTable map[string]int version int length int lenOffset int bodyOffset int // 1-based // public arguments // TopLevelElement allows a user to choose what container will be used // at top level. Available options: array, arrayref, hash, hashref TopLevelElement topLevelElementType // optionally compress the main payload of the document using SnappyCompressor or ZlibCompressor // CompressionThreshold specifies threshold in bytes above which compression is attempted: 1024 bytes by default Compression compressor CompressionThreshold int // If enabled, merger will deduplicate all strings it meets. // Otherwise, only hash key and class names will be deduplicated DedupeStrings bool // If enabled, KeepFlat keeps flat structture of final document. // Specifically, consider two arrays [A,B,C] and [D,E,F]: // - when KeepFlat == false, the result of merging is [[A,B,C],[D,E,F]] // - when KeepFlat == true, the result is [A,B,C,D,E,F] // This mode is relevant only to top level elements KeepFlat bool // give a hint to encoder about expected size of encoded data ExpectedSize uint // moved bool fields here to make struct smaller inited bool finished bool } type mergerDoc struct { buf []byte trackIdxs []int trackTable map[int]int version int startIdx int // 0-based bodyOffset int // 1-based } // NewMerger returns a merger using the latest sereal version func NewMerger() *Merger { return &Merger{ TopLevelElement: TopLevelArrayRef, CompressionThreshold: 1024, } } // NewMergerV2 returns a merger for processing sereal v2 documents func NewMergerV2() *Merger { return &Merger{ version: 2, TopLevelElement: TopLevelArrayRef, CompressionThreshold: 1024, } } // NewMergerV3 returns a merger for processing sereal v3 documents func NewMergerV3() *Merger { return &Merger{ version: 3, TopLevelElement: TopLevelArrayRef, CompressionThreshold: 1024, } } func (m *Merger) initMerger() error { if m.inited { return nil } // initialize internal fields m.strTable = make(map[string]int) m.objTable = make(map[string]int) if m.ExpectedSize > 0 { m.buf = make([]byte, headerSize, m.ExpectedSize) } else { m.buf = make([]byte, headerSize) } if m.version == 0 { m.version = ProtocolVersion } switch { case m.version > ProtocolVersion: return fmt.Errorf("protocol version '%v' not yet supported", m.version) case m.version < 3: binary.LittleEndian.PutUint32(m.buf[:4], magicHeaderBytes) default: binary.LittleEndian.PutUint32(m.buf[:4], magicHeaderBytesHighBit) } m.buf[4] = byte(m.version) // fill version m.buf = append(m.buf, 0) // no header m.bodyOffset = len(m.buf) - 1 // remember body offset // append top level tags switch m.TopLevelElement { case TopLevelArray: m.buf = append(m.buf, typeARRAY) case TopLevelArrayRef: m.buf = append(m.buf, typeREFN, typeARRAY) default: return errors.New("invalid TopLevelElement") } // remember len offset + pad bytes for length m.lenOffset = len(m.buf) for i := 0; i < binary.MaxVarintLen32; i++ { m.buf = append(m.buf, typePAD) } m.inited = true return nil } // Append adds the sereal document b and returns the number of elements added to the top-level structure func (m *Merger) Append(b []byte) (int, error) { if err := m.initMerger(); err != nil { return 0, err } if m.finished { return 0, errors.New("finished document") } docHeader, err := readHeader(b) if err != nil { return 0, err } doc := mergerDoc{ buf: b[headerSize+docHeader.suffixSize:], version: int(docHeader.version), startIdx: 0, bodyOffset: -1, // 1-based offsets } var decomp decompressor switch docHeader.doctype { case serealRaw: // nothing case serealSnappy: if doc.version != 1 { return 0, errors.New("snappy compression only valid for v1 documents") } decomp = SnappyCompressor{Incremental: false} case serealSnappyIncremental: decomp = SnappyCompressor{Incremental: true} case serealZlib: if doc.version < 3 { return 0, errors.New("zlib compression only valid for v3 documents and up") } decomp = ZlibCompressor{} default: return 0, fmt.Errorf("document type '%d' not yet supported", docHeader.doctype) } if decomp != nil { if doc.buf, err = decomp.decompress(doc.buf); err != nil { return 0, err } } oldLength := m.length lastElementOffset := len(m.buf) // first pass: build table of tracked tags if err := m.buildTrackTable(&doc); err != nil { return 0, err } // preallocate memory // copying data from doc.buf might seem to be unefficient, // but profiling/benchmarking shows that there is no // difference between growing slice by append() or via new() + copy() m.buf = append(m.buf, doc.buf...) m.buf = m.buf[:lastElementOffset] // second pass: do the work if err := m.mergeItems(&doc); err != nil { m.buf = m.buf[0:lastElementOffset] // remove appended stuff return 0, err } return m.length - oldLength, nil } // Finish is called to terminate the merging process func (m *Merger) Finish() ([]byte, error) { if err := m.initMerger(); err != nil { return m.buf, err } if !m.finished { m.finished = true binary.PutUvarint(m.buf[m.lenOffset:], uint64(m.length)) if m.Compression != nil && (m.CompressionThreshold == 0 || len(m.buf) >= m.CompressionThreshold) { compressed, err := m.Compression.compress(m.buf[m.bodyOffset+1:]) if err != nil { return m.buf, err } // TODO think about some optimizations here copy(m.buf[m.bodyOffset+1:], compressed) m.buf = m.buf[:len(compressed)+m.bodyOffset+1] // verify compressor, there was little point in veryfing compressor in initMerger() // because use can change it meanwhile switch comp := m.Compression.(type) { case SnappyCompressor: if !comp.Incremental { return nil, errors.New("non-incremental snappy compression is not supported") } m.buf[4] |= byte(serealSnappyIncremental) << 4 case ZlibCompressor: if m.version < 3 { return nil, errors.New("zlib compression only valid for v3 documents and up") } m.buf[4] |= byte(serealZlib) << 4 default: return nil, errors.New("unknown compressor") } } } return m.buf, nil } func (m *Merger) buildTrackTable(doc *mergerDoc) error { buf := doc.buf idx := doc.startIdx if idx < 0 || idx > len(buf) { return errors.New("invalid index") } doc.trackTable = make(map[int]int) doc.trackIdxs = make([]int, 0) for idx < len(buf) { tag := buf[idx] if (tag & trackFlag) == trackFlag { doc.trackTable[idx-doc.bodyOffset] = -1 tag &^= trackFlag } //fmt.Printf("%x (%x) at %d (%d)\n", tag, buf[idx], idx, idx - doc.bodyOffset) switch { case tag < typeVARINT, tag == typePAD, tag == typeREFN, tag == typeWEAKEN, tag == typeUNDEF, tag == typeCANONICAL_UNDEF, tag == typeTRUE, tag == typeFALSE, tag == typeEXTEND, tag == typeREGEXP, tag == typeOBJECT, tag == typeOBJECT_FREEZE: idx++ case tag == typeVARINT, tag == typeZIGZAG: _, sz, err := varintdecode(buf[idx+1:]) if err != nil { return err } idx += sz + 1 case tag == typeFLOAT: idx += 5 // 4 bytes + tag case tag == typeDOUBLE: idx += 9 // 8 bytes + tag case tag == typeLONG_DOUBLE: idx += 17 // 16 bytes + tag case tag == typeBINARY, tag == typeSTR_UTF8: ln, sz, err := varintdecode(buf[idx+1:]) if err != nil { return err } idx += sz + ln + 1 if ln < 0 || ln > math.MaxUint32 { return fmt.Errorf("bad size for string: %d", ln) } else if idx > len(buf) { return fmt.Errorf("truncated document, expect %d bytes", len(buf)-idx) } case tag == typeARRAY, tag == typeHASH: _, sz, err := varintdecode(buf[idx+1:]) if err != nil { return err } idx += sz + 1 case tag == typeCOPY, tag == typeALIAS, tag == typeREFP, tag == typeOBJECTV, tag == typeOBJECTV_FREEZE: offset, sz, err := varintdecode(buf[idx+1:]) if err != nil { return err } if offset < 0 || offset >= idx { return fmt.Errorf("tag %d refers to invalid offset: %d", tag, offset) } doc.trackTable[offset] = -1 idx += sz + 1 case tag >= typeARRAYREF_0 && tag < typeARRAYREF_0+16: idx++ case tag >= typeHASHREF_0 && tag < typeHASHREF_0+16: idx++ case tag >= typeSHORT_BINARY_0 && tag < typeSHORT_BINARY_0+32: idx += 1 + int(tag&0x1F) // case tag == typeMANY: TODO case tag == typePACKET_START: return errors.New("unexpected start of new document") default: return fmt.Errorf("unknown tag: %d (0x%x) at offset %d", tag, tag, idx) } } for idx := range doc.trackTable { doc.trackIdxs = append(doc.trackIdxs, idx) } sort.Ints(doc.trackIdxs) return nil } func (m *Merger) mergeItems(doc *mergerDoc) error { mbuf := m.buf dbuf := doc.buf didx := doc.startIdx expElements, offset, err := m.expectedElements(dbuf[didx:]) if err != nil { return err } if expElements < 0 || expElements > math.MaxUint32 { return fmt.Errorf("bad amount of expected elements: %d", expElements) } didx += offset // stack is needed for three things: // - keep track of expected things // - verify document consistency // if a value put on stack has the highest significant bit on, // it means that hash keys/values are processed stack := make([]uint32, 0, 16) // preallocate 16 nested levels stack = append(stack, uint32(expElements)) LOOP: for didx < len(dbuf) { tag := dbuf[didx] tag &^= trackFlag docRelativeIdx := didx - doc.bodyOffset mrgRelativeIdx := len(mbuf) - m.bodyOffset trackme := len(doc.trackIdxs) > 0 && doc.trackIdxs[0] == docRelativeIdx level := len(stack) - 1 for stack[level]&^hashKeysValuesFlag == 0 { stack = stack[:level] level-- if level < 0 { break LOOP } } // If m.DedupeStrings is true - dedup all strings, otherwise dedup only hash keys and class names. // The trick with stack[level] % 2 == 0 works because stack[level] for hashes is always even at // the beggining (for each item in hash we expect key and value). In practise it means, // that if stack[level] is even - a key is being processed, if stack[level] is odd - value is being processed dedupString := m.DedupeStrings || ((stack[level]&hashKeysValuesFlag) == hashKeysValuesFlag && stack[level]%2 == 0) //fmt.Printf("0x%x (0x%x) at didx: %d (rlt: %d) len(dbuf): %d\n", tag, dbuf[didx], didx, didx-doc.bodyOffset, len(dbuf)) //fmt.Printf("level: %d, value: %d len: %d\n", level, stack[level], len(stack)) //fmt.Println("------") switch { case tag < typeVARINT, tag == typeUNDEF, tag == typeCANONICAL_UNDEF, tag == typeTRUE, tag == typeFALSE, tag == typeSHORT_BINARY_0: mbuf = append(mbuf, dbuf[didx]) didx++ case tag == typePAD, tag == typeREFN, tag == typeWEAKEN, tag == typeEXTEND: // this elemets are fake ones, so stack counter should not be decreased // but, I don't want to create another if-branch, so fake it stack[level]++ mbuf = append(mbuf, dbuf[didx]) didx++ case tag == typeVARINT, tag == typeZIGZAG: _, sz, err := varintdecode(dbuf[didx+1:]) if err != nil { return err } mbuf = append(mbuf, dbuf[didx:didx+sz+1]...) didx += sz + 1 case tag == typeFLOAT: mbuf = append(mbuf, dbuf[didx:didx+5]...) didx += 5 // 4 bytes + tag case tag == typeDOUBLE: mbuf = append(mbuf, dbuf[didx:didx+9]...) didx += 9 // 8 bytes + tag case tag == typeLONG_DOUBLE: mbuf = append(mbuf, dbuf[didx:didx+17]...) didx += 17 // 16 bytes + tag case tag == typeSHORT_BINARY_0+1: mbuf = append(mbuf, dbuf[didx:didx+2]...) didx += 2 case tag == typeBINARY, tag == typeSTR_UTF8, tag > typeSHORT_BINARY_0+1 && tag < typeSHORT_BINARY_0+32: // I don't want to call readString here because of performance reasons: // this path is the hot spot, so keep it overhead-free as much as possible var ln, sz int if tag > typeSHORT_BINARY_0 { ln = int(tag & 0x1F) // get length from tag } else { var err error ln, sz, err = varintdecode(dbuf[didx+1:]) if err != nil { return err } } length := sz + ln + 1 if ln < 0 || ln > math.MaxUint32 { return fmt.Errorf("bad size for string: %d", ln) } else if didx+length > len(dbuf) { return fmt.Errorf("truncated document, expect %d bytes", len(dbuf)-didx-length) } if dedupString { val := dbuf[didx+sz+1 : didx+length] if savedOffset, ok := m.strTable[string(val)]; ok { mbuf = appendTagVarint(mbuf, typeCOPY, uint(savedOffset)) mrgRelativeIdx = savedOffset } else { m.strTable[string(val)] = mrgRelativeIdx mbuf = append(mbuf, dbuf[didx:didx+length]...) } } else { mbuf = append(mbuf, dbuf[didx:didx+length]...) } didx += length case tag == typeCOPY, tag == typeREFP, tag == typeALIAS, tag == typeOBJECTV, tag == typeOBJECTV_FREEZE: offset, sz, err := varintdecode(dbuf[didx+1:]) if err != nil { return err } targetOffset, ok := doc.trackTable[offset] if !ok || targetOffset < 0 { return errors.New("bad target offset at COPY, ALIAS or REFP tag") } mbuf = appendTagVarint(mbuf, dbuf[didx], uint(targetOffset)) didx += sz + 1 if tag == typeALIAS || tag == typeREFP { mbuf[targetOffset] |= trackFlag } else if tag == typeOBJECTV || tag == typeOBJECTV_FREEZE { stack = append(stack, 1) } case tag == typeARRAY, tag == typeHASH: ln, sz, err := varintdecode(dbuf[didx+1:]) if err != nil { return err } if ln < 0 { return errors.New("bad array or hash length") } mbuf = append(mbuf, dbuf[didx:didx+sz+1]...) didx += sz + 1 if tag == typeHASH { stack = append(stack, uint32(ln*2)|hashKeysValuesFlag) } else { stack = append(stack, uint32(ln)) } case (tag >= typeARRAYREF_0 && tag < typeARRAYREF_0+16) || (tag >= typeHASHREF_0 && tag < typeHASHREF_0+16): mbuf = append(mbuf, dbuf[didx]) didx++ // for hash read 2*ln items if tag >= typeHASHREF_0 { stack = append(stack, uint32(tag&0xF*2)|hashKeysValuesFlag) } else { stack = append(stack, uint32(tag&0xF)) } case tag == typeREGEXP: offset, str, err := readString(dbuf[didx+1:]) if err != nil { return err } sizeToCopy := offset + len(str) + 1 offset, str, err = readString(dbuf[didx+sizeToCopy:]) if err != nil { return err } sizeToCopy += offset + len(str) mbuf = append(mbuf, dbuf[didx:didx+sizeToCopy]...) didx += sizeToCopy case tag == typeOBJECT, tag == typeOBJECT_FREEZE: // skip main tag for a second, and parse offset, str, err := readString(dbuf[didx+1:]) if err != nil { return err } length := offset + len(str) + 1 // respect typeOBJECT tag if savedOffset, ok := m.objTable[string(str)]; ok { if tag == typeOBJECT { mbuf = appendTagVarint(mbuf, typeOBJECTV, uint(savedOffset)) } else { mbuf = appendTagVarint(mbuf, typeOBJECTV_FREEZE, uint(savedOffset)) } mrgRelativeIdx = savedOffset } else { // +1 because we should refer to string tag, not object tag mrgRelativeIdx++ m.objTable[string(str)] = mrgRelativeIdx mbuf = append(mbuf, dbuf[didx:didx+length]...) } // parse stack = append(stack, 1) didx += length case tag == typePACKET_START: return errors.New("unexpected start of new document") default: // TODO typeMANY return fmt.Errorf("unknown tag: %d (0x%x) at offset %d", tag, tag, didx) } stack[level]-- if trackme { // if tag is tracked, remember its offset doc.trackTable[docRelativeIdx] = mrgRelativeIdx doc.trackIdxs = doc.trackIdxs[1:] } } m.length += expElements m.buf = mbuf return nil } func (m *Merger) expectedElements(b []byte) (int, int, error) { if m.KeepFlat { tag0 := b[0] &^ trackFlag tag1 := b[1] &^ trackFlag switch m.TopLevelElement { case TopLevelArray: if tag0 == typeARRAY { ln, sz, err := varintdecode(b[1:]) return ln, sz + 1, err } case TopLevelArrayRef: if tag0 == typeREFN && tag1 == typeARRAY { ln, sz, err := varintdecode(b[2:]) return ln, sz + 2, err } else if tag0 >= typeARRAYREF_0 && tag0 < typeARRAYREF_0+16 { return int(tag0 & 0xF), 1, nil } } } return 1, 0, nil // by default expect only one element } func isShallowStringish(tag byte) bool { return tag == typeBINARY || tag == typeSTR_UTF8 || (tag >= typeSHORT_BINARY_0 && tag < typeSHORT_BINARY_0+32) } func readString(buf []byte) (int, []byte, error) { tag := buf[0] tag &^= trackFlag if !isShallowStringish(tag) { return 0, nil, fmt.Errorf("expected stringish but found %d (0x%x)", int(tag), int(tag)) } var ln, offset int if tag > typeSHORT_BINARY_0 { ln = int(tag & 0x1F) // get length from tag } else { var err error ln, offset, err = varintdecode(buf[1:]) if err != nil { return 0, nil, err } } offset++ // respect tag itself if ln < 0 || ln > math.MaxUint32 { return 0, nil, fmt.Errorf("bad size for string: %d", ln) } else if offset+ln > len(buf) { return 0, nil, fmt.Errorf("truncated document, expect %d bytes", len(buf)-ln-offset) } return offset, buf[offset : offset+ln], nil } func appendTagVarint(by []byte, tag byte, n uint) []uint8 { // the slice should be allocated on stack due to escape analysis varintBuf := make([]byte, binary.MaxVarintLen64) varintBuf[0] = tag idx := 1 for n >= 0x80 { varintBuf[idx] = byte(n) | 0x80 n >>= 7 idx++ } varintBuf[idx] = byte(n) return append(by, varintBuf[:idx+1]...) }