package mint

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"testing"
)

// ErrorReadWriter has Read and Write methods that always fail. It is handed
// to NewRecordLayerTLS in place of a buffer to exercise write-error paths.
type ErrorReadWriter struct{}

// Read always returns zero bytes and an error.
func (e ErrorReadWriter) Read(p []byte) (n int, err error) {
	return 0, fmt.Errorf("Unknown read error")
}

// Write always returns zero bytes and an error.
func (e ErrorReadWriter) Write(p []byte) (n int, err error) {
	return 0, fmt.Errorf("Unknown write error")
}

// recordHeaderHex returns the hex encoding of a 5-byte record header with
// content type 0x16, version bytes 0x03 0x01, and a big-endian length of
// len(data). Only the low 16 bits of the length are encoded.
func recordHeaderHex(data []byte) string {
	dataLen := len(data)
	return hex.EncodeToString([]byte{0x16, 0x03, 0x01, byte(dataLen >> 8), byte(dataLen)})
}

var (
	messageType   = HandshakeTypeClientHello
	tinyMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    []byte{0, 0, 0, 0},
		length:  4,
	}
	tinyMessageHex = "0100000400000000"

	// short: 0x000040
	// long:  0x007fe0 = 0x4000 + 0x3fe0
	shortMessageLen = 64
	longMessageLen  = 2*maxFragmentLen - (shortMessageLen / 2)

	// The concatenations below use bytes.Join so every message owns a fresh
	// backing array; chained append() calls could otherwise share storage
	// between these package-level fixtures.
	shortMessageHeader = []byte{byte(messageType), 0x00, 0x00, byte(shortMessageLen)}
	shortMessageBody   = bytes.Repeat([]byte{0xab}, shortMessageLen)
	shortMessage       = bytes.Join([][]byte{shortMessageHeader, shortMessageBody}, nil)

	longMessageHeader = []byte{byte(messageType), 0x00, byte(longMessageLen >> 8), byte(longMessageLen)}
	longMessageBody   = bytes.Repeat([]byte{0xcd}, longMessageLen)
	longMessage       = bytes.Join([][]byte{longMessageHeader, longMessageBody}, nil)

	shortLongMessage      = bytes.Join([][]byte{shortMessage, longMessage}, nil)
	shortLongShortMessage = bytes.Join([][]byte{shortMessage, longMessage, shortMessage}, nil)

	shortHex       = recordHeaderHex(shortMessage) + hex.EncodeToString(shortMessage)
	shortMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    shortMessageBody,
		length:  uint32(len(shortMessageBody)),
	}
	longMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    longMessageBody,
		length:  uint32(len(longMessageBody)),
	}
	tooLongMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    bytes.Repeat([]byte{0xef}, maxHandshakeMessageLen+1),
	}

	// The long message is split at maxFragmentLen into two records.
	longFragment1 = longMessage[:maxFragmentLen]
	longFragment2 = longMessage[maxFragmentLen:]
	longHex       = recordHeaderHex(longFragment1) + hex.EncodeToString(longFragment1) +
		recordHeaderHex(longFragment2) + hex.EncodeToString(longFragment2)

	// short+long+short is split into three records at multiples of maxFragmentLen.
	slsFragment1      = shortLongShortMessage[:maxFragmentLen]
	slsFragment2      = shortLongShortMessage[maxFragmentLen : 2*maxFragmentLen]
	slsFragment3      = shortLongShortMessage[2*maxFragmentLen:]
	shortLongShortHex = recordHeaderHex(slsFragment1) + hex.EncodeToString(slsFragment1) +
		recordHeaderHex(slsFragment2) + hex.EncodeToString(slsFragment2) +
		recordHeaderHex(slsFragment3) + hex.EncodeToString(slsFragment3)

	// A header announcing 4 body bytes, followed by a record carrying only 2.
	insufficientDataHex = "1603010004" + "01000004" + "1603010002" + "0000"
	// A record whose content type byte is 0x15 rather than 0x16.
	nonHandshakeHex = "15030100020000"
)

func TestMessageMarshal(t *testing.T) {
	tinyMessage := unhex(tinyMessageHex)
	out := tinyMessageIn.Marshal()
	assertByteEquals(t, out, tinyMessage)
}

// newTestHandshakeMessage builds a HandshakeMessage with only its type and
// body populated.
func newTestHandshakeMessage(typ HandshakeType, body []byte) HandshakeMessage {
	return HandshakeMessage{
		msgType: typ,
		body:    body,
	}
}

func TestMessageToBody(t *testing.T) {
	// Borrowing serialized bodies from handshake-messages_test.go
	valid := []struct {
		name    string
		msgType HandshakeType
		body    []byte
	}{
		{"ClientHello", HandshakeTypeClientHello, unhex(chValidHex)},
		{"ServerHello", HandshakeTypeServerHello, unhex(shValidHex)},
		{"EncryptedExtensions", HandshakeTypeEncryptedExtensions, unhex(encExtValidHex)},
		{"Certificate", HandshakeTypeCertificate, unhex(certValidHex)},
		{"CertificateVerify", HandshakeTypeCertificateVerify, unhex(certVerifyValidHex)},
		{"Finished", HandshakeTypeFinished, unhex(finValidHex)},
		{"NewSessionTicket", HandshakeTypeNewSessionTicket, unhex(ticketValidHex)},
	}

	// Test successful conversion of each supported body type
	for _, tc := range valid {
		hm := newTestHandshakeMessage(tc.msgType, tc.body)
		_, err := hm.ToBody()
		assertNotError(t, err, "Failed to convert "+tc.name+" body")
	}

	// Test failure on unsupported body type
	hm := newTestHandshakeMessage(HandshakeTypeHelloRetryRequest, []byte{})
	_, err := hm.ToBody()
	assertError(t, err, "Converted an unsupported message")

	// Test failure on marshal failure
	hm = newTestHandshakeMessage(HandshakeTypeClientHello, []byte{})
	_, err = hm.ToBody()
	assertError(t, err, "Converted an empty message")
}

func TestMessageFromBody(t *testing.T) {
	chValid := unhex(chValidHex)

	b := bytes.NewBuffer(nil)
	h := NewHandshakeLayerTLS(NewRecordLayerTLS(b))

	// Test successful conversion
	hm, err := h.HandshakeMessageFromBody(&chValidIn)
	assertNotError(t, err, "Failed to convert ClientHello body to message")
	assertEquals(t, hm.msgType, chValidIn.Type())
	assertByteEquals(t, hm.body, chValid)

	// Test conversion failure on marshal failure. chValidIn is a shared
	// package-level fixture, so restore it via defer: a failing assertion
	// that aborts the test must not leave it corrupted for later tests.
	savedSuites := chValidIn.CipherSuites
	defer func() { chValidIn.CipherSuites = savedSuites }()
	chValidIn.CipherSuites = []CipherSuite{}
	_, err = h.HandshakeMessageFromBody(&chValidIn)
	assertError(t, err, "Converted a ClientHello that should not have marshaled")
}

func TestReadHandshakeMessage(t *testing.T) {
	short := unhex(shortHex)
	long := unhex(longHex)
	shortLongShort := unhex(shortLongShortHex)
	insufficientData := unhex(insufficientDataHex)
	nonHandshake := unhex(nonHandshakeHex)

	// newReader returns a TLS handshake layer reading from a buffer over data.
	newReader := func(data []byte) *HandshakeLayer {
		return NewHandshakeLayerTLS(NewRecordLayerTLS(bytes.NewBuffer(data)))
	}

	// Test successful read of a message in a single record
	h := newReader(short)
	hm, err := h.ReadMessage()
	assertNotError(t, err, "Failed to read a short handshake message")
	assertDeepEquals(t, hm, shortMessageIn)

	// Test successful read of a message split across records
	h = newReader(long)
	hm, err = h.ReadMessage()
	assertNotError(t, err, "Failed to read a long handshake message")
	assertDeepEquals(t, hm, longMessageIn)

	// Test successful read of multiple messages sequentially
	h = newReader(shortLongShort)
	hm1, err := h.ReadMessage()
	assertNotError(t, err, "Failed to read first handshake message")
	assertDeepEquals(t, hm1, shortMessageIn)
	hm2, err := h.ReadMessage()
	assertNotError(t, err, "Failed to read second handshake message")
	assertDeepEquals(t, hm2, longMessageIn)
	hm3, err := h.ReadMessage()
	assertNotError(t, err, "Failed to read third handshake message")
	assertDeepEquals(t, hm3, shortMessageIn)

	// Test read failure on inability to read header
	h = newReader(short[:handshakeHeaderLenTLS-1])
	_, err = h.ReadMessage()
	assertError(t, err, "Read handshake message with an incomplete header")

	// Test read failure on inability to read body
	h = newReader(insufficientData)
	_, err = h.ReadMessage()
	assertError(t, err, "Read handshake message with an incomplete body")

	// Test read failure on receiving a non-handshake record
	h = newReader(nonHandshake)
	_, err = h.ReadMessage()
	assertError(t, err, "Read handshake message from a non-handshake record")
}

// testWriteHandshakeMessage copies the record layer's current cipher onto hm
// and then writes hm through h. Note that this mutates hm.
func testWriteHandshakeMessage(h *HandshakeLayer, hm *HandshakeMessage) error {
	hm.cipher = h.conn.cipher
	return h.WriteMessage(hm)
}

func TestWriteHandshakeMessage(t *testing.T) {
	short := unhex(shortHex)
	long := unhex(longHex)

	// Test successful write of single message
	b := bytes.NewBuffer(nil)
	h := NewHandshakeLayerTLS(NewRecordLayerTLS(b))
	err := testWriteHandshakeMessage(h, shortMessageIn)
	assertNotError(t, err, "Failed to write valid short message")
	assertByteEquals(t, b.Bytes(), short)

	// Test successful write of single long message
	b = bytes.NewBuffer(nil)
	h = NewHandshakeLayerTLS(NewRecordLayerTLS(b))
	err = testWriteHandshakeMessage(h, longMessageIn)
	assertNotError(t, err, "Failed to write valid long message")
	assertByteEquals(t, b.Bytes(), long)

	// Test write failure on message too large
	b = bytes.NewBuffer(nil)
	h = NewHandshakeLayerTLS(NewRecordLayerTLS(b))
	err = testWriteHandshakeMessage(h, tooLongMessageIn)
	assertError(t, err, "Wrote a message exceeding the length bound")

	// Test write failure on underlying write failure
	h = NewHandshakeLayerTLS(NewRecordLayerTLS(ErrorReadWriter{}))
	err = testWriteHandshakeMessage(h, longMessageIn)
	assertError(t, err, "Write succeeded despite error in full fragment send")
	err = testWriteHandshakeMessage(h, shortMessageIn)
	assertError(t, err, "Write succeeded despite error in last fragment send")
}

// testReassembleFixture wires a non-blocking DTLS handshake layer to a pipe
// and holds pre-built messages and fragments used by the reassembly tests.
type testReassembleFixture struct {
	t     *testing.T
	h     *HandshakeLayer
	r     *RecordLayer
	rd    *pipeConn
	wr    *pipeConn
	m0    *HandshakeMessage // message seq 0, complete (2048 bytes)
	m0f0  *HandshakeMessage // m0 bytes [0, 1024)
	m0f1  *HandshakeMessage // m0 bytes [1024, 1536)
	m0f2  *HandshakeMessage // m0 bytes [1536, 2048)
	m0f1x *HandshakeMessage // m0 bytes [512, 1512), overlapping m0f0 and m0f1
	m0f1y *HandshakeMessage // m0 bytes [512, 1560), overlapping m0f0 and m0f2
	m1    *HandshakeMessage // message seq 1, complete (2048 bytes)
}

// newTestReassembleFixture builds the fixture: two 2048-byte messages (m0 with
// seq 0, m1 with seq 1), a set of fragments of m0, and a DTLS handshake layer
// in non-blocking mode reading from one end of a pipe.
func newTestReassembleFixture(t *testing.T) *testReassembleFixture {
	f := testReassembleFixture{t: t}

	// Make two messages, m0 and m1, with m0 fragmented
	m0 := make([]byte, 2048)
	for i := range m0 {
		m0[i] = byte(i % 13)
	}
	f.m0 = newHsFragment(m0, 0, 0, 2048)
	f.m0f0 = newHsFragment(m0, 0, 0, 1024)
	f.m0f1 = newHsFragment(m0, 0, 1024, 512)
	f.m0f2 = newHsFragment(m0, 0, 1536, 512)
	f.m0f1x = newHsFragment(m0, 0, 512, 1000)
	f.m0f1y = newHsFragment(m0, 0, 512, 1048)

	m1 := make([]byte, 2048)
	for i := range m1 {
		m1[i] = byte(i % 23)
	}
	f.m1 = newHsFragment(m1, 1, 0, 2048)

	f.rd, f.wr = pipe()
	f.r = NewRecordLayerDTLS(f.rd)
	f.h = NewHandshakeLayerDTLS(f.r)
	f.h.nonblocking = true
	return &f
}

// newHsFragment returns a ClientHello HandshakeMessage with sequence number
// seq whose body is full[offset:offset+fragLen], placed at offset within a
// message of total length len(full).
//
// The literal is positional, so its values must track HandshakeMessage's
// field declaration order. NOTE(review): the `true` value presumably marks the
// message as a datagram (DTLS) message — confirm against the struct.
func newHsFragment(full []byte, seq uint32, offset uint32, fragLen uint32) *HandshakeMessage {
	return &HandshakeMessage{
		HandshakeTypeClientHello,
		seq,
		full[offset : offset+fragLen],
		true,
		offset,
		uint32(len(full)),
		nil,
		nil,
	}
}

// addFragment optionally sends in as a single record, then performs one read.
//
// If in is non-nil, it is marshaled and written to the pipe as a single
// record. Then one ReadMessage call is made. If expected is nil, the read must
// yield a nil message and WouldBlock. Otherwise the read must succeed with
// expected's sequence number and body.
func (f *testReassembleFixture) addFragment(in *HandshakeMessage, expected *HandshakeMessage) {
	if in != nil {
		b := in.Marshal()
		// 13-byte record header: content type, version bytes 0xfe 0xff, eight
		// zero bytes (presumably epoch + sequence number), 16-bit length.
		r := []byte{byte(RecordTypeHandshake), 0xfe, 0xff,
			0, 0, 0, 0, 0, 0, 0, 0,
			byte((len(b) >> 8) & 0xff), byte(len(b) & 0xff)}
		r = append(r, b...)
		if _, err := f.wr.Write(r); err != nil {
			f.t.Fatalf("Failed to write record to pipe: %v", err)
		}
	}

	h2, err := f.h.ReadMessage()
	if expected == nil {
		assertEquals(f.t, (*HandshakeMessage)(nil), h2)
		assertEquals(f.t, WouldBlock, err)
		return
	}

	assertNotError(f.t, err, "Error reading handshake")
	if h2 == nil {
		// Guard the field accesses below against a nil message with nil error.
		f.t.Fatal("ReadMessage returned a nil message without an error")
	}
	assertEquals(f.t, expected.seq, h2.seq)
	assertByteEquals(f.t, expected.body, h2.body)
}

// Whole messages delivered in sequence order; a repeat of m0 yields nothing.
func TestHandshakeDTLSInOrder(t *testing.T) {
	f := newTestReassembleFixture(t)

	f.addFragment(f.m0, f.m0)
	f.addFragment(f.m0, nil) // Should block
	f.addFragment(f.m1, f.m1)
}

// m1 arrives before m0; it is held until m0 has been delivered.
func TestHandshakeDTLSOutOfOrder(t *testing.T) {
	f := newTestReassembleFixture(t)

	f.addFragment(f.m1, nil)
	f.addFragment(f.m0, f.m0)
	f.addFragment(nil, f.m1)
}

// m0 arrives as three disjoint fragments in order.
func TestHandshakeDTLSNonOverlappingFragments(t *testing.T) {
	f := newTestReassembleFixture(t)

	f.addFragment(f.m0f0, nil)
	f.addFragment(f.m0f1, nil)
	f.addFragment(f.m0f2, f.m0)
}

// m0 arrives as three disjoint fragments, the middle one last.
func TestHandshakeDTLSNonOverlappingFragmentsOO(t *testing.T) {
	f := newTestReassembleFixture(t)

	f.addFragment(f.m0f0, nil)
	f.addFragment(f.m0f2, nil)
	f.addFragment(f.m0f1, f.m0)
}

// A redundant fragment overlapping already-received data precedes the last one.
func TestHandshakeDTLSOverlappingFragments1(t *testing.T) {
	f := newTestReassembleFixture(t)

	f.addFragment(f.m0f0, nil)
	f.addFragment(f.m0f1, nil)
	f.addFragment(f.m0f1x, nil)
	f.addFragment(f.m0f2, f.m0)
}

// A fragment overlapping both its neighbours replaces the middle fragment.
func TestHandshakeDTLSOverlappingFragments2(t *testing.T) {
	f := newTestReassembleFixture(t)

	f.addFragment(f.m0f0, nil)
	f.addFragment(f.m0f1y, nil)
	f.addFragment(f.m0f2, f.m0)
}