route/vendor/github.com/bifurcation/mint/handshake-layer_test.go

385 lines
12 KiB
Go

package mint
import (
"bytes"
"encoding/hex"
"fmt"
"testing"
)
// ErrorReadWriter is a stub transport on which every Read and Write fails.
// The tests use it to exercise the write-failure paths of the layers built on
// top of a connection.
type ErrorReadWriter struct{}

// Read ignores p and always reports zero bytes read together with a fixed
// error.
func (ErrorReadWriter) Read(p []byte) (int, error) {
	err := fmt.Errorf("Unknown read error")
	return 0, err
}

// Write ignores p and always reports zero bytes written together with a fixed
// error.
func (ErrorReadWriter) Write(p []byte) (int, error) {
	err := fmt.Errorf("Unknown write error")
	return 0, err
}
// recordHeaderHex returns the hex encoding of the five-byte record header that
// the fixtures place in front of data: the bytes 0x16, 0x03, 0x01 followed by
// the low 16 bits of len(data) in big-endian order.
func recordHeaderHex(data []byte) string {
	n := len(data)
	header := [5]byte{0x16, 0x03, 0x01}
	header[3] = byte(n >> 8)
	header[4] = byte(n)
	return hex.EncodeToString(header[:])
}
// messageType is the handshake type carried by every fixture message below.
var messageType = HandshakeTypeClientHello

// Smallest fixture: a four-byte all-zero body and the bytes that
// TestMessageMarshal expects Marshal to produce for it.
var (
	tinyMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    []byte{0, 0, 0, 0},
		length:  4,
	}

	tinyMessageHex = "0100000400000000"
)

// Body lengths of the short and long fixtures.
//
// short: 0x000040
// long: 0x007fe0 = 0x4000 + 0x3fe0
var (
	shortMessageLen = 64
	longMessageLen  = 2*maxFragmentLen - (shortMessageLen / 2)
)

// Short fixture: header, body, the two concatenated, the record-framed hex
// form, and the message value the read/write tests compare against.
var (
	shortMessageHeader = []byte{byte(messageType), 0x00, 0x00, byte(shortMessageLen)}
	shortMessageBody   = bytes.Repeat([]byte{0xab}, shortMessageLen)
	shortMessage       = append(shortMessageHeader, shortMessageBody...)
	shortHex           = recordHeaderHex(shortMessage) + hex.EncodeToString(shortMessage)

	shortMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    shortMessageBody,
		length:  uint32(len(shortMessageBody)),
	}
)

// Long fixture: the serialized message is cut at maxFragmentLen into two
// pieces, each framed with its own record header in longHex.
var (
	longMessageHeader = []byte{byte(messageType), 0x00, byte(longMessageLen >> 8), byte(longMessageLen)}
	longMessageBody   = bytes.Repeat([]byte{0xcd}, longMessageLen)
	longMessage       = append(longMessageHeader, longMessageBody...)
	longFragment1     = longMessage[:maxFragmentLen]
	longFragment2     = longMessage[maxFragmentLen:]

	longHex = recordHeaderHex(longFragment1) + hex.EncodeToString(longFragment1) +
		recordHeaderHex(longFragment2) + hex.EncodeToString(longFragment2)

	longMessageIn = &HandshakeMessage{
		msgType: messageType,
		body:    longMessageBody,
		length:  uint32(len(longMessageBody)),
	}
)

// tooLongMessageIn has a body one byte longer than maxHandshakeMessageLen; the
// write test expects it to be rejected.
var tooLongMessageIn = &HandshakeMessage{
	msgType: messageType,
	body:    bytes.Repeat([]byte{0xef}, maxHandshakeMessageLen+1),
}

// Short, long and short messages back to back, cut into three pieces at
// multiples of maxFragmentLen and framed one record per piece.
var (
	shortLongMessage      = append(shortMessage, longMessage...)
	shortLongShortMessage = append(shortLongMessage, shortMessage...)

	slsFragment1 = shortLongShortMessage[:maxFragmentLen]
	slsFragment2 = shortLongShortMessage[maxFragmentLen : 2*maxFragmentLen]
	slsFragment3 = shortLongShortMessage[2*maxFragmentLen:]

	shortLongShortHex = recordHeaderHex(slsFragment1) + hex.EncodeToString(slsFragment1) +
		recordHeaderHex(slsFragment2) + hex.EncodeToString(slsFragment2) +
		recordHeaderHex(slsFragment3) + hex.EncodeToString(slsFragment3)
)

// Inputs for the read-failure cases.
var (
	// Two records carrying four and two payload bytes respectively; used by
	// the "incomplete body" read test.
	insufficientDataHex = "1603010004" + "01000004" + "1603010002" + "0000"

	// A record whose first byte is 0x15 rather than the 0x16 used by
	// recordHeaderHex; used by the "non-handshake record" read test.
	nonHandshakeHex = "15030100020000"
)
// TestMessageMarshal checks that marshaling tinyMessageIn produces exactly the
// bytes encoded in tinyMessageHex.
func TestMessageMarshal(t *testing.T) {
	want := unhex(tinyMessageHex)
	got := tinyMessageIn.Marshal()
	assertByteEquals(t, got, want)
}
// newTestHandshakeMessage returns a HandshakeMessage with only its type and
// body set to t and m; every other field keeps its zero value.
func newTestHandshakeMessage(t HandshakeType, m []byte) HandshakeMessage {
	var hm HandshakeMessage
	hm.msgType = t
	hm.body = m
	return hm
}
// TestMessageToBody runs ToBody over one message of each handshake type listed
// below. The valid serialized bodies must convert without error; an
// unsupported type and an empty ClientHello body must be rejected.
func TestMessageToBody(t *testing.T) {
	// Borrowing serialized bodies from handshake-messages_test.go
	chBody := unhex(chValidHex)
	shBody := unhex(shValidHex)
	finBody := unhex(finValidHex)
	encExtBody := unhex(encExtValidHex)
	certBody := unhex(certValidHex)
	certVerifyBody := unhex(certVerifyValidHex)
	ticketBody := unhex(ticketValidHex)

	cases := []struct {
		msgType HandshakeType
		body    []byte
		wantErr bool
		msg     string // assertion message reported when the expectation is violated
	}{
		// Conversions that must succeed
		{HandshakeTypeClientHello, chBody, false, "Failed to convert ClientHello body"},
		{HandshakeTypeServerHello, shBody, false, "Failed to convert ServerHello body"},
		{HandshakeTypeEncryptedExtensions, encExtBody, false, "Failed to convert EncryptedExtensions body"},
		{HandshakeTypeCertificate, certBody, false, "Failed to convert Certificate body"},
		{HandshakeTypeCertificateVerify, certVerifyBody, false, "Failed to convert CertificateVerify body"},
		{HandshakeTypeFinished, finBody, false, "Failed to convert Finished body"},
		{HandshakeTypeNewSessionTicket, ticketBody, false, "Failed to convert NewSessionTicket body"},

		// Unsupported body type
		{HandshakeTypeHelloRetryRequest, []byte{}, true, "Converted an unsupported message"},

		// Supported type whose body is empty
		{HandshakeTypeClientHello, []byte{}, true, "Converted an empty message"},
	}

	for _, tc := range cases {
		hm := newTestHandshakeMessage(tc.msgType, tc.body)
		_, err := hm.ToBody()
		if tc.wantErr {
			assertError(t, err, tc.msg)
		} else {
			assertNotError(t, err, tc.msg)
		}
	}
}
// TestMessageFromBody checks HandshakeMessageFromBody on the shared chValidIn
// ClientHello fixture: the valid body must yield a message with the matching
// type and serialized bytes, and the same body with an empty cipher-suite list
// must be rejected.
func TestMessageFromBody(t *testing.T) {
	chValid := unhex(chValidHex)

	b := bytes.NewBuffer(nil)
	h := NewHandshakeLayerTLS(NewRecordLayerTLS(b))

	// Test successful conversion
	hm, err := h.HandshakeMessageFromBody(&chValidIn)
	assertNotError(t, err, "Failed to convert ClientHello body to message")
	assertEquals(t, hm.msgType, chValidIn.Type())
	assertByteEquals(t, hm.body, chValid)

	// Test conversion failure on marshal failure.
	//
	// chValidIn is a package-level fixture, so the temporary mutation is
	// undone in a defer: the restore then also runs when the assertion below
	// aborts the test, and later tests never see the emptied cipher-suite list.
	defer func() { chValidIn.CipherSuites = chCipherSuites }()
	chValidIn.CipherSuites = []CipherSuite{}
	_, err = h.HandshakeMessageFromBody(&chValidIn)
	assertError(t, err, "Converted a ClientHello that should not have marshaled")
}
// TestReadHandshakeMessage feeds record-framed fixtures to a TLS handshake
// layer and checks ReadMessage: a message in one record, a message split
// across two records, three messages read back to back, and three malformed
// inputs that must produce an error.
func TestReadHandshakeMessage(t *testing.T) {
	short := unhex(shortHex)
	long := unhex(longHex)
	shortLongShort := unhex(shortLongShortHex)
	insufficientData := unhex(insufficientDataHex)
	nonHandshake := unhex(nonHandshakeHex)

	// layerOver builds a fresh handshake layer reading from wire.
	layerOver := func(wire []byte) *HandshakeLayer {
		return NewHandshakeLayerTLS(NewRecordLayerTLS(bytes.NewBuffer(wire)))
	}

	// Test successful read of a message in a single record
	h := layerOver(short)
	hm, err := h.ReadMessage()
	assertNotError(t, err, "Failed to read a short handshake message")
	assertDeepEquals(t, hm, shortMessageIn)

	// Test successful read of a message split across records
	h = layerOver(long)
	hm, err = h.ReadMessage()
	assertNotError(t, err, "Failed to read a long handshake message")
	assertDeepEquals(t, hm, longMessageIn)

	// Test successful read of multiple messages sequentially
	h = layerOver(shortLongShort)
	sequence := []struct {
		want *HandshakeMessage
		msg  string
	}{
		{shortMessageIn, "Failed to read first handshake message"},
		{longMessageIn, "Failed to read second handshake message"},
		{shortMessageIn, "Failed to read third handshake message"},
	}
	for _, step := range sequence {
		hm, err = h.ReadMessage()
		assertNotError(t, err, step.msg)
		assertDeepEquals(t, hm, step.want)
	}

	// Test read failures: truncated header, truncated body, and a
	// non-handshake record. Only the error is inspected.
	failures := []struct {
		wire []byte
		msg  string
	}{
		{short[:handshakeHeaderLenTLS-1], "Read handshake message with an incomplete header"},
		{insufficientData, "Read handshake message with an incomplete body"},
		{nonHandshake, "Read handshake message from a non-handshake record"},
	}
	for _, tc := range failures {
		_, err = layerOver(tc.wire).ReadMessage()
		assertError(t, err, tc.msg)
	}
}
// testWriteHandshakeMessage writes hm through h using the cipher currently
// held by h's connection and returns WriteMessage's error.
//
// The cipher is set on a shallow copy rather than on hm itself: callers pass
// package-level fixtures (shortMessageIn, longMessageIn, ...) that
// TestReadHandshakeMessage also compares with assertDeepEquals, so writing
// into them would make the outcome depend on test execution order.
func testWriteHandshakeMessage(h *HandshakeLayer, hm *HandshakeMessage) error {
	msg := *hm
	msg.cipher = h.conn.cipher
	return h.WriteMessage(&msg)
}
// TestWriteHandshakeMessage writes fixture messages through a TLS handshake
// layer and checks the bytes that reach the underlying buffer, then checks
// that an oversized message and a failing transport both produce errors.
func TestWriteHandshakeMessage(t *testing.T) {
	short := unhex(shortHex)
	long := unhex(longHex)

	// Test successful write of a single short and a single long message
	successes := []struct {
		in   *HandshakeMessage
		wire []byte
		msg  string
	}{
		{shortMessageIn, short, "Failed to write valid short message"},
		{longMessageIn, long, "Failed to write valid long message"},
	}
	for _, tc := range successes {
		out := bytes.NewBuffer(nil)
		layer := NewHandshakeLayerTLS(NewRecordLayerTLS(out))
		werr := testWriteHandshakeMessage(layer, tc.in)
		assertNotError(t, werr, tc.msg)
		assertByteEquals(t, out.Bytes(), tc.wire)
	}

	// Test write failure on message too large
	sink := bytes.NewBuffer(nil)
	h := NewHandshakeLayerTLS(NewRecordLayerTLS(sink))
	err := testWriteHandshakeMessage(h, tooLongMessageIn)
	assertError(t, err, "Wrote a message exceeding the length bound")

	// Test write failure on underlying write failure
	h = NewHandshakeLayerTLS(NewRecordLayerTLS(ErrorReadWriter{}))
	err = testWriteHandshakeMessage(h, longMessageIn)
	assertError(t, err, "Write succeeded despite error in full fragment send")
	err = testWriteHandshakeMessage(h, shortMessageIn)
	assertError(t, err, "Write succeeded despite error in last fragment send")
}
// testReassembleFixture bundles what the DTLS reassembly tests need: a
// handshake layer reading from one end of a pipe, the other end for injecting
// records, and a set of pre-built messages and fragments.
type testReassembleFixture struct {
	t *testing.T

	// Handshake layer under test and the record layer it is built on.
	h *HandshakeLayer
	r *RecordLayer

	// Pipe ends: the record layer is built on rd, addFragment writes to wr.
	rd *pipeConn
	wr *pipeConn

	// Messages as built by newTestReassembleFixture:
	//   m0    - complete 2048-byte message, sequence 0
	//   m0f0  - bytes [0, 1024) of m0
	//   m0f1  - bytes [1024, 1536) of m0
	//   m0f2  - bytes [1536, 2048) of m0
	//   m0f1x - bytes [512, 1512) of m0, overlapping m0f0 and m0f1
	//   m0f1y - bytes [512, 1560) of m0, overlapping m0f0 and m0f2
	//   m1    - complete 2048-byte message, sequence 1
	m0    *HandshakeMessage
	m0f0  *HandshakeMessage
	m0f1  *HandshakeMessage
	m0f2  *HandshakeMessage
	m0f1x *HandshakeMessage
	m0f1y *HandshakeMessage
	m1    *HandshakeMessage
}
// newTestReassembleFixture builds the fixture for the DTLS reassembly tests:
// two 2048-byte messages (sequence 0 and 1), several fragments of the first,
// and a non-blocking DTLS handshake layer reading from one end of a pipe.
func newTestReassembleFixture(t *testing.T) *testReassembleFixture {
	// pattern returns a 2048-byte payload whose i-th byte is i modulo mod, so
	// the two messages carry distinguishable contents.
	pattern := func(mod int) []byte {
		buf := make([]byte, 2048)
		for i := range buf {
			buf[i] = byte(i % mod)
		}
		return buf
	}
	first, second := pattern(13), pattern(23)

	// Make two messages, m0 and m1, with m0 fragmented
	f := &testReassembleFixture{
		t:     t,
		m0:    newHsFragment(first, 0, 0, 2048),
		m0f0:  newHsFragment(first, 0, 0, 1024),
		m0f1:  newHsFragment(first, 0, 1024, 512),
		m0f2:  newHsFragment(first, 0, 1536, 512),
		m0f1x: newHsFragment(first, 0, 512, 1000),
		m0f1y: newHsFragment(first, 0, 512, 1048),
		m1:    newHsFragment(second, 1, 0, 2048),
	}

	f.rd, f.wr = pipe()
	f.r = NewRecordLayerDTLS(f.rd)
	f.h = NewHandshakeLayerDTLS(f.r)
	// With nonblocking set, addFragment expects ReadMessage to return
	// WouldBlock when no complete message is available.
	f.h.nonblocking = true
	return f
}
// newHsFragment builds a ClientHello-typed HandshakeMessage whose body is
// full[offset : offset+fragLen], tagged with sequence number seq, the given
// offset, and the total length len(full). The range is not validated here;
// callers must keep offset+fragLen within full.
func newHsFragment(full []byte, seq uint32, offset uint32, fragLen uint32) *HandshakeMessage {
	chunk := full[offset : offset+fragLen]
	total := uint32(len(full))

	// The literal is positional, so the order below must match the field order
	// of HandshakeMessage. NOTE(review): the fourth value is presumably the
	// datagram flag and the last two are left nil — confirm against the
	// HandshakeMessage definition.
	frag := HandshakeMessage{
		HandshakeTypeClientHello,
		seq,
		chunk,
		true,
		offset,
		total,
		nil,
		nil,
	}
	return &frag
}
// addFragment optionally injects one handshake message or fragment into the
// pipe and then performs a single ReadMessage on the handshake layer.
//
// If in is non-nil it is marshaled, wrapped in a 13-byte record header and
// written to the write end of the pipe; a nil in injects nothing, which lets a
// test drain a message that was buffered earlier.
//
// If expected is nil, ReadMessage must return a nil message and WouldBlock.
// Otherwise it must succeed and return a message whose seq and body match
// expected.
func (f *testReassembleFixture) addFragment(in *HandshakeMessage, expected *HandshakeMessage) {
	if in != nil {
		b := in.Marshal()
		// Record header: handshake content type, version bytes 0xfe 0xff,
		// eight zero bytes (the epoch and sequence-number fields of a DTLS
		// record header), then the payload length in big-endian order.
		r := []byte{byte(RecordTypeHandshake), 0xfe, 0xff,
			0, 0, 0, 0, 0, 0, 0, 0,
			byte((len(b) >> 8) & 0xff), byte(len(b) & 0xff)}
		r = append(r, b...)
		// A failed write would otherwise only show up as a confusing
		// mismatch in the read assertions below, so report it here.
		_, werr := f.wr.Write(r)
		assertNotError(f.t, werr, "Error writing handshake record")
	}

	h2, err := f.h.ReadMessage()
	if expected == nil {
		assertEquals(f.t, (*HandshakeMessage)(nil), h2)
		assertEquals(f.t, WouldBlock, err)
	} else {
		assertNotError(f.t, err, "Error reading handshake")
		assertEquals(f.t, expected.seq, h2.seq)
		assertByteEquals(f.t, expected.body, h2.body)
	}
}
// TestHandshakeDTLSInOrder delivers m0, then m0 a second time, then m1. The
// first and third deliveries must each yield the message; the repeated m0
// must yield nothing (WouldBlock).
func TestHandshakeDTLSInOrder(t *testing.T) {
	f := newTestReassembleFixture(t)
	steps := []struct{ in, expected *HandshakeMessage }{
		{f.m0, f.m0},
		{f.m0, nil}, // Should block
		{f.m1, f.m1},
	}
	for _, s := range steps {
		f.addFragment(s.in, s.expected)
	}
}
// TestHandshakeDTLSOutOfOrder delivers m1 before m0. The early m1 must yield
// nothing; m0 must then be returned, and a further read with no new input
// must return m1.
func TestHandshakeDTLSOutOfOrder(t *testing.T) {
	f := newTestReassembleFixture(t)
	steps := []struct{ in, expected *HandshakeMessage }{
		{f.m1, nil},
		{f.m0, f.m0},
		{nil, f.m1},
	}
	for _, s := range steps {
		f.addFragment(s.in, s.expected)
	}
}
// TestHandshakeDTLSNonOverlappingFragments delivers the three disjoint
// fragments of m0 in offset order. Only the final fragment must produce the
// complete message.
func TestHandshakeDTLSNonOverlappingFragments(t *testing.T) {
	f := newTestReassembleFixture(t)
	steps := []struct{ in, expected *HandshakeMessage }{
		{f.m0f0, nil},
		{f.m0f1, nil},
		{f.m0f2, f.m0},
	}
	for _, s := range steps {
		f.addFragment(s.in, s.expected)
	}
}
// TestHandshakeDTLSNonOverlappingFragmentsOO delivers the three disjoint
// fragments of m0 with the last two swapped. The complete message must appear
// once the middle fragment fills the gap.
func TestHandshakeDTLSNonOverlappingFragmentsOO(t *testing.T) {
	f := newTestReassembleFixture(t)
	steps := []struct{ in, expected *HandshakeMessage }{
		{f.m0f0, nil},
		{f.m0f2, nil},
		{f.m0f1, f.m0},
	}
	for _, s := range steps {
		f.addFragment(s.in, s.expected)
	}
}
// TestHandshakeDTLSOverlappingFragments1 delivers m0f0 and m0f1, then m0f1x,
// which covers bytes already received, and finally m0f2. Nothing must be
// returned until m0f2 completes the message.
func TestHandshakeDTLSOverlappingFragments1(t *testing.T) {
	f := newTestReassembleFixture(t)
	steps := []struct{ in, expected *HandshakeMessage }{
		{f.m0f0, nil},
		{f.m0f1, nil},
		{f.m0f1x, nil},
		{f.m0f2, f.m0},
	}
	for _, s := range steps {
		f.addFragment(s.in, s.expected)
	}
}
// TestHandshakeDTLSOverlappingFragments2 delivers m0f0, then m0f1y, whose
// range overlaps both m0f0 and m0f2, and finally m0f2. The complete message
// must be returned only after m0f2.
func TestHandshakeDTLSOverlappingFragments2(t *testing.T) {
	f := newTestReassembleFixture(t)
	steps := []struct{ in, expected *HandshakeMessage }{
		{f.m0f0, nil},
		{f.m0f1y, nil},
		{f.m0f2, f.m0},
	}
	for _, s := range steps {
		f.addFragment(s.in, s.expected)
	}
}