// Source: route/vendor/github.com/bifurcation/mint/conn_test.go
// (862 lines, 27 KiB, Go)

package mint
import (
"bytes"
"crypto/x509"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
)
// pipeConn is one endpoint of an in-memory, two-way pipe that stands in for a
// network connection in these tests. Each direction is a bytes.Buffer guarded
// by its own mutex; what one endpoint writes, its peer reads.
type pipeConn struct {
	r     *bytes.Buffer // data arriving from the peer
	w     *bytes.Buffer // data headed to the peer
	rLock *sync.Mutex   // guards r (shared with the peer's wLock)
	wLock *sync.Mutex   // guards w (shared with the peer's rLock)
}

// pipe returns a connected client/server pair of pipeConn endpoints. The
// client's write buffer is the server's read buffer and vice versa, and each
// buffer is protected by a single mutex shared between both endpoints.
func pipe() (client *pipeConn, server *pipeConn) {
	clientToServer, serverToClient := new(bytes.Buffer), new(bytes.Buffer)
	clientToServerLock, serverToClientLock := new(sync.Mutex), new(sync.Mutex)

	client = &pipeConn{
		r:     serverToClient,
		w:     clientToServer,
		rLock: serverToClientLock,
		wLock: clientToServerLock,
	}
	server = &pipeConn{
		r:     clientToServer,
		w:     serverToClient,
		rLock: clientToServerLock,
		wLock: serverToClientLock,
	}
	return client, server
}

// Read copies pending bytes from the peer into data. It never blocks: when
// nothing is pending, bytes.Buffer reports io.EOF, which is translated into a
// nil error so that an empty pipe looks like "no data yet" rather than a
// closed connection.
func (p *pipeConn) Read(data []byte) (n int, err error) {
	p.rLock.Lock()
	defer p.rLock.Unlock()

	if n, err = p.r.Read(data); err == io.EOF {
		return n, nil
	}
	return n, err
}

// Write appends data to the buffer that the peer reads from.
func (p *pipeConn) Write(data []byte) (n int, err error) {
	p.wLock.Lock()
	defer p.wLock.Unlock()

	n, err = p.w.Write(data)
	return n, err
}

// Close is a no-op; the pipe holds no external resources.
func (p *pipeConn) Close() error {
	return nil
}

// The remaining methods complete the net.Conn method set. The pipe has no
// addresses and ignores deadlines.

func (p *pipeConn) LocalAddr() net.Addr { return nil }

func (p *pipeConn) RemoteAddr() net.Addr { return nil }

func (p *pipeConn) SetDeadline(t time.Time) error { return nil }

func (p *pipeConn) SetReadDeadline(t time.Time) error { return nil }

func (p *pipeConn) SetWriteDeadline(t time.Time) error { return nil }
// bufferedConn wraps a net.Conn and holds back everything written to it
// until Flush is called. Reads pass straight through to the wrapped
// connection. Tests use it to control exactly when a flight of messages
// becomes visible to the peer.
type bufferedConn struct {
	buffer bytes.Buffer // bytes written but not yet flushed
	w      net.Conn     // underlying connection
}

// Write stores buf in the pending buffer; nothing reaches the wrapped
// connection until Flush.
func (b *bufferedConn) Write(buf []byte) (n int, err error) {
	n, err = b.buffer.Write(buf)
	return n, err
}

// Read reads directly from the wrapped connection.
func (b *bufferedConn) Read(data []byte) (n int, err error) {
	n, err = b.w.Read(data)
	return n, err
}

// Close is a no-op; it neither flushes nor closes the wrapped connection.
func (b *bufferedConn) Close() error {
	return nil
}

// The remaining methods complete the net.Conn method set. No addresses are
// reported and deadlines are ignored.

func (b *bufferedConn) LocalAddr() net.Addr { return nil }

func (b *bufferedConn) RemoteAddr() net.Addr { return nil }

func (b *bufferedConn) SetDeadline(t time.Time) error { return nil }

func (b *bufferedConn) SetReadDeadline(t time.Time) error { return nil }

func (b *bufferedConn) SetWriteDeadline(t time.Time) error { return nil }

// Flush writes all pending bytes to the wrapped connection in a single
// Write call. On a write error or a short write the pending buffer is left
// untouched and an error is returned; on success the buffer is emptied.
func (b *bufferedConn) Flush() error {
	pending := b.buffer.Bytes()
	switch written, err := b.w.Write(pending); {
	case err != nil:
		return err
	case written != len(pending):
		return fmt.Errorf("Incomplete flush")
	}
	b.buffer.Reset()
	return nil
}

// newBufferedConn returns a bufferedConn with an empty pending buffer that
// wraps p.
func newBufferedConn(p net.Conn) *bufferedConn {
	return &bufferedConn{w: p}
}
// Shared fixtures for the connection tests in this file: a server identity
// (certificate + key), a static PSK, and one Config per handshake variant
// under test. The Config values are package-level and shared between tests.
var (
// serverName is the name used in every Config below and as one of the PSK
// cache keys.
serverName = "example.com"
// Certificate generated by Go
// * CN: "example.com"
// * SANs: "example.com", "www.example.com"
// * Random 2048-bit public key
// * Self-signed
//
// NOTE(review): the validity fields encoded in the DER below are the UTCTime
// values 150101000000Z and 250101000000Z, i.e. the certificate is valid from
// 2015-01-01 to 2025-01-01. Whether the handshake under test enforces the
// expiry date is not visible from this file — confirm.
serverCertHex = "308202f7308201dfa00302010202012a300d06092a864886f70d01010b05003016311430" +
"120603550403130b6578616d706c652e636f6d301e170d3135303130313030303030305a" +
"170d3235303130313030303030305a3016311430120603550403130b6578616d706c652e" +
"636f6d30820122300d06092a864886f70d01010105000382010f003082010a0282010100" +
"a558ff3c12b8c4906b7f638878c71963ac95548c5d36975bc575de8775a141408c449c3e" +
"7fe7eddf93329dd894ecb2705b7f79caa06f1477b7bd2d3ff32f43076dd32a7f9f97ed4d" +
"4593db3f28adbea7794c14d8d206832652e93959e2b8d2b4781fadcf55c852641482f7fc" +
"6b9e7e751442a0818c21c9cacc28e7594606ff692392510df57ce26d9c0d052f84e236b9" +
"9e3f81daa98c554607432e3bb26a5fe3fa2b5fc5e5c1fcb1d76050328b92edc80238773d" +
"16547ccc24c0784933d86b3f8d0ee33d90a1b47ecbfbaad12e77155f1b4e84b3e5c4d565" +
"1717832fcbf82886eb6f925435b4ca9f87ec207b4338f03a846fbf0f68ea0e674bf50a21" +
"d9165b690203010001a350304e300e0603551d0f0101ff04040302028430130603551d25" +
"040c300a06082b0601050507030130270603551d110420301e820b6578616d706c652e63" +
"6f6d820f7777772e6578616d706c652e636f6d300d06092a864886f70d01010b05000382" +
"01010024ed08531171df6256ed5668b962ca3740e61c20d014b3aae8ac436b200ee89b50" +
"dc5e5a74859fdf1d0f1136ad16d5fe018dac83d0e4dcfd1b5951e1502e23b7b740621b10" +
"b3376f8b3bc5494c25917b4103d88670390e33c2b089e66326316e4bbd75fd6e5dced78f" +
"79caf97d83b981950ed10449f61d826af4a6eb70e291fccdaa76145f7ba085d27698197f" +
"60e944646640ea18d5439955d91a80d4dfb1e4c12f539da9423a33f479ee19a0fa9c5339" +
"1e0d164633bea4346dc0c8081172d67ee7bca4bd5463cc147d8c062ebb31be6e9c39518c" +
"37f5607a2d6f36114800f6c6f509893fa352a468b30ad874ae56db769f1786567e9c96c1" +
"6b4a4b2a25dda3"
// unhex is defined elsewhere in the package; presumably it decodes the hex
// string into raw bytes.
serverCertDER = unhex(serverCertHex)
// The parse error is discarded: if the fixture were malformed, serverCert
// would be nil and the failure would only surface inside a test.
serverCert, _ = x509.ParseCertificate(serverCertDER)
// The corresponding private key
serverKeyHex = "308204a40201000282010100a558ff3c12b8c4906b7f638878c71963ac95548c5d36975b" +
"c575de8775a141408c449c3e7fe7eddf93329dd894ecb2705b7f79caa06f1477b7bd2d3f" +
"f32f43076dd32a7f9f97ed4d4593db3f28adbea7794c14d8d206832652e93959e2b8d2b4" +
"781fadcf55c852641482f7fc6b9e7e751442a0818c21c9cacc28e7594606ff692392510d" +
"f57ce26d9c0d052f84e236b99e3f81daa98c554607432e3bb26a5fe3fa2b5fc5e5c1fcb1" +
"d76050328b92edc80238773d16547ccc24c0784933d86b3f8d0ee33d90a1b47ecbfbaad1" +
"2e77155f1b4e84b3e5c4d5651717832fcbf82886eb6f925435b4ca9f87ec207b4338f03a" +
"846fbf0f68ea0e674bf50a21d9165b6902030100010282010074f08262ec22bcf21ef4d3" +
"621b79445d981b6cd670be4141e85f3a68b72abac979eab44e078bf25222fab3640fbf6f" +
"5bc37a5e9a8de8c1a301d1cb84e4ead20f18ff35995937cbded08c878d1da9f3a2e2488a" +
"9de5bc3159135e5aef5547bdcd60ff969f825dd0d77322455cc2882f8b822eb4f1aa37e3" +
"4d88228dac37b88f3d9b671ef6b05e2f47b562265e0d09fefb01c190c7fb4b3682231cd8" +
"564c59b6cc788ff742fb040562110b1f849f1535164503b0a402399e2c6cf1c0847dd50a" +
"a917b62fc3215e4eb43d7d07fa9731a51e01f0f7b694dd002b48c0bad04b9ff34e576393" +
"c0a213a12dda4bf43a7dd4ee0563c5e0de2025eb76e049cd771c96330102818100c590bd" +
"8f226cec50c818afb3ebe7ceeacabb107ac73ac159b1eca1a194ea550a0609c432a183e2" +
"fee62dafdc0201426f90cb46f9b2fc7a9bcc2365b58177529cf78c209eb6a3afd1896466" +
"63e8462729e8bf902dc1c42c7d46c1c0c99c632f0560418604b4260a1ed8d165375c674c" +
"806c2a8e202d0b7c5a8b8717309106fb3102818100d640cae7b6adea649a8c11863a3ba8" +
"098025a972d130aecaa4db08154fd0feb8af79bf7009c1ea2a790752464e923b53b41ff4" +
"3ff84e6ddb94bfc5b157e6a21e1fefe11cc082e7e8b31d07eab5e13d7a84cdeeba24d283" +
"699a8fa5138e753e88856a033ab2153c1a8200caac28377a1d09d6318ac2e946cef879a0" +
"5acbd8e5b902818100bfe142ea189257b66190f05d3bba8951aa92a27fccadf90a076f7e" +
"cff354e040fafa534ea565f57a81ce4fa5cb60b3c8ad8570aaa5b6e7d217232dee6a0e9c" +
"f30cce510434f8a79347f0762d84735628330092a48e33dccdd381ec9f233f8574a03723" +
"55c02dcdd885d6618ab23935a8e8e52fe27a3d548a90472533ab376f910281805253fd64" +
"02875bbd22c1d5ee0d2c654a994a5f8d7622cdd7a27763e8c48ddb835e325b44930b478e" +
"e088d6ad9b7d877c87878bd494f696323d3b5f9ce0d907cca99b049686c706941d577776" +
"524365db5172cc5c0cd0339cfdbe5ac164095b691c52fb40afb3872fec6a9f767dd1ab83" +
"c306e26c9eaf02fd7eef4595fe24af4902818100b5a2294d7567283f3f4bf54be7b98785" +
"fc564f24ff2d67215ecdc7955cbf05260f48c9608a59a8ebfbedc62b4d110c1704ade704" +
"cb27a591f69752d1d6ebe21291aec29b301efe47eced0187125f741ce52b3826beac3778" +
"f3560448e91644fd52460f8c3afa1596c01e6cd2c37120d8122c09edf326988b48d98c27" +
"f788eb83"
serverKeyDER = unhex(serverKeyHex)
// As with serverCert, the parse error is discarded.
serverKey, _ = x509.ParsePKCS1PrivateKey(serverKeyDER)
// psk is a fixed, non-resumption pre-shared key used by the PSK configs.
psk = PreSharedKey{
CipherSuite: TLS_AES_128_GCM_SHA256,
IsResumption: false,
Identity: []byte{0, 1, 2, 3},
Key: []byte{4, 5, 6, 7},
}
// certificates is the single-entry certificate list (serverCert/serverKey)
// shared by every certificate-based config.
certificates = []*Certificate{
{
Chain: []*x509.Certificate{serverCert},
PrivateKey: serverKey,
},
}
// psks registers psk under two keys: the server name and "00010203", which
// is the hex form of psk.Identity. Presumably one key serves client-side
// lookup and the other server-side lookup — confirm against PSKMapCache.
psks = &PSKMapCache{
serverName: psk,
"00010203": psk,
}
// basicConfig: certificate authentication with no other options set.
basicConfig = &Config{
ServerName: serverName,
Certificates: certificates,
}
// dtlsConfig: basicConfig with UseDTLS enabled.
dtlsConfig = &Config{
ServerName: serverName,
Certificates: certificates,
UseDTLS: true,
}
// nbConfig: basicConfig with NonBlocking enabled.
nbConfig = &Config{
ServerName: serverName,
Certificates: certificates,
NonBlocking: true,
}
// hrrConfig: basicConfig with RequireCookie enabled (used for the "HRR"
// test case).
hrrConfig = &Config{
ServerName: serverName,
Certificates: certificates,
RequireCookie: true,
}
// alpnConfig: basicConfig offering the "http/1.1" and "h2" protocols.
alpnConfig = &Config{
ServerName: serverName,
Certificates: certificates,
NextProtos: []string{"http/1.1", "h2"},
}
// clientAuthConfig: basicConfig with RequireClientAuth enabled.
clientAuthConfig = &Config{
ServerName: serverName,
RequireClientAuth: true,
Certificates: certificates,
}
// pskConfig: PSK only (no certificates), restricted to
// TLS_AES_128_GCM_SHA256, with early data allowed.
pskConfig = &Config{
ServerName: serverName,
CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
PSKs: psks,
AllowEarlyData: true,
}
// pskECDHEConfig: PSK plus certificates, with the group list left unset.
pskECDHEConfig = &Config{
ServerName: serverName,
CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
Certificates: certificates,
PSKs: psks,
}
// pskDHEConfig: PSK plus certificates, restricted to the FFDHE2048 group.
pskDHEConfig = &Config{
ServerName: serverName,
CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
Certificates: certificates,
PSKs: psks,
Groups: []NamedGroup{FFDHE2048},
}
// resumptionConfig: basicConfig with SendSessionTickets enabled.
// TestResumption clones it before use.
resumptionConfig = &Config{
ServerName: serverName,
Certificates: certificates,
SendSessionTickets: true,
}
// ffdhConfig: certificate authentication restricted to the FFDHE2048 group.
ffdhConfig = &Config{
ServerName: serverName,
Certificates: certificates,
CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
Groups: []NamedGroup{FFDHE2048},
}
// x25519Config: certificate authentication restricted to the X25519 group.
x25519Config = &Config{
ServerName: serverName,
Certificates: certificates,
CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
Groups: []NamedGroup{X25519},
}
)
// assertKeySetEquals asserts that two key sets carry the same IV and the
// same key, using assertByteEquals for each comparison. It is marked as a
// test helper so failures are attributed to the caller.
func assertKeySetEquals(t *testing.T, k1, k2 keySet) {
t.Helper()
// Assume cipher is the same
assertByteEquals(t, k1.iv, k2.iv)
assertByteEquals(t, k1.key, k2.key)
}
// computeExporter calls c.ComputeExporter with the given label, context and
// output length, asserts that no error was returned, and hands back the
// exported bytes. It is marked as a test helper so failures are attributed
// to the caller.
func computeExporter(t *testing.T, c *Conn, label string, context []byte, length int) []byte {
	t.Helper()

	exported, exportErr := c.ComputeExporter(label, context, length)
	assertNotError(t, exportErr, "Could not compute exporter")

	return exported
}
// TestBasicFlows runs a full client/server handshake over an in-memory pipe
// for each certificate-based configuration and checks that both sides end up
// with identical negotiated parameters, secrets and exporter outputs.
func TestBasicFlows(t *testing.T) {
	tests := []struct {
		name   string
		config *Config
	}{
		{"basic config", basicConfig},
		{"HRR", hrrConfig},
		{"ALPN", alpnConfig},
		{"FFDH", ffdhConfig},
		{"x25519", x25519Config},
	}
	for _, testcase := range tests {
		t.Run(fmt.Sprintf("with %s", testcase.name), func(t *testing.T) {
			conf := testcase.config
			cConn, sConn := pipe()
			client := Client(cConn, conf)
			server := Server(sConn, conf)

			// The server goroutine only reports its result. Assertions stay on
			// the test goroutine: a fatal assertion inside the goroutine would
			// skip the completion signal and hang the test (and t.FailNow must
			// not be called from other goroutines). The channel is buffered so
			// the goroutine can always finish.
			serverDone := make(chan Alert, 1)
			go func() {
				serverDone <- server.Handshake()
			}()

			clientAlert := client.Handshake()
			assertEquals(t, clientAlert, AlertNoAlert)
			serverAlert := <-serverDone
			assertEquals(t, serverAlert, AlertNoAlert)

			assertDeepEquals(t, client.state.Params, server.state.Params)
			assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
			assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
			assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
			assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
			assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret)

			// Exporters must agree for identical inputs and differ when the
			// length, label or context differs.
			emptyContext := []byte{}
			assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20))
			assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21))
			assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20))
			assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20))
			assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20))
		})
	}
}
// TestClientAuth runs a handshake with RequireClientAuth set and checks that
// both sides agree on parameters and secrets and that client authentication
// was negotiated.
func TestClientAuth(t *testing.T) {
	cConn, sConn := pipe()
	client := Client(cConn, clientAuthConfig)
	server := Server(sConn, clientAuthConfig)

	// The server goroutine only reports its result; assertions stay on the
	// test goroutine so a failure cannot skip the completion signal and hang
	// the test. The channel is buffered so the goroutine can always finish.
	serverDone := make(chan Alert, 1)
	go func() {
		serverDone <- server.Handshake()
	}()

	clientAlert := client.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	serverAlert := <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)

	assertDeepEquals(t, client.state.Params, server.state.Params)
	assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
	assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
	assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
	assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
	assert(t, client.state.Params.UsingClientAuth, "Session did not negotiate client auth")
}
// TestPSKFlows runs a handshake for each PSK configuration (PSK only, PSK
// with certificates, PSK with FFDHE2048) and checks that both sides agree on
// parameters and secrets and that the PSK was actually used.
func TestPSKFlows(t *testing.T) {
	for _, conf := range []*Config{pskConfig, pskECDHEConfig, pskDHEConfig} {
		cConn, sConn := pipe()
		client := Client(cConn, conf)
		server := Server(sConn, conf)

		// The server goroutine only reports its result; assertions stay on
		// the test goroutine so a failure cannot skip the completion signal
		// and hang the test. The channel is buffered so the goroutine can
		// always finish.
		serverDone := make(chan Alert, 1)
		go func() {
			serverDone <- server.Handshake()
		}()

		clientAlert := client.Handshake()
		assertEquals(t, clientAlert, AlertNoAlert)
		serverAlert := <-serverDone
		assertEquals(t, serverAlert, AlertNoAlert)

		assertDeepEquals(t, client.state.Params, server.state.Params)
		assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
		assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
		assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
		assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
		assert(t, client.state.Params.UsingPSK, "Session did not use the provided PSK")
	}
}
// TestNonBlockingReadBeforeConnected checks that Read on a non-blocking
// client that has not completed a handshake fails with the expected error
// message.
func TestNonBlockingReadBeforeConnected(t *testing.T) {
	conn := Client(&bufferedConn{}, &Config{NonBlocking: true})
	_, err := conn.Read(make([]byte, 10))
	// Guard before dereferencing: calling err.Error() on a nil error would
	// panic instead of producing a readable test failure.
	if err == nil {
		t.Fatal("Read before the handshake completed did not return an error")
	}
	assertEquals(t, err.Error(), "Read called before the handshake completed")
}
// TestResumption verifies session resumption in two phases. Phase 1 runs a
// full handshake with session tickets enabled and checks that client and
// server each cache one matching PSK. Phase 2 runs a second handshake with
// the same configs and checks that the cached PSK is used.
func TestResumption(t *testing.T) {
	// Phase 1: Verify that the session ticket gets sent and stored
	clientConfig := resumptionConfig.Clone()
	serverConfig := resumptionConfig.Clone()
	cConn1, sConn1 := pipe()
	client1 := Client(cConn1, clientConfig)
	server1 := Server(sConn1, serverConfig)

	// The server goroutine only reports its results; assertions stay on the
	// test goroutine so a failure cannot skip the completion signal and hang
	// the test. The channel is buffered so the goroutine can always finish.
	serverDone := make(chan Alert, 1)
	var serverWriteErr error
	go func() {
		alert := server1.Handshake()
		if alert == AlertNoAlert {
			// One byte of application data gives the client something to
			// read after the handshake.
			_, serverWriteErr = server1.Write([]byte{'a'})
		}
		serverDone <- alert
	}()

	clientAlert := client1.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	tmpBuf := make([]byte, 1)
	n, err := client1.Read(tmpBuf)
	assertNil(t, err, "Couldn't read one byte")
	assertEquals(t, 1, n)
	serverAlert := <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)
	assertNotError(t, serverWriteErr, "Server couldn't write one byte")

	assertDeepEquals(t, client1.state.Params, server1.state.Params)
	assertCipherSuiteParamsEquals(t, client1.state.cryptoParams, server1.state.cryptoParams)
	assertByteEquals(t, client1.state.resumptionSecret, server1.state.resumptionSecret)
	assertByteEquals(t, client1.state.clientTrafficSecret, server1.state.clientTrafficSecret)
	assertByteEquals(t, client1.state.serverTrafficSecret, server1.state.serverTrafficSecret)
	assertEquals(t, clientConfig.PSKs.Size(), 1)
	assertEquals(t, serverConfig.PSKs.Size(), 1)

	// Each cache holds exactly one entry (asserted above), so ranging over
	// it picks out that entry.
	clientCache := clientConfig.PSKs.(*PSKMapCache)
	serverCache := serverConfig.PSKs.(*PSKMapCache)
	var serverPSK PreSharedKey
	for _, key := range *serverCache {
		serverPSK = key
	}
	var clientPSK PreSharedKey
	for _, key := range *clientCache {
		clientPSK = key
	}

	// Ensure that the PSKs are the same, except with regard to the
	// receivedAt/expiresAt times, which might differ by a little.
	assertEquals(t, clientPSK.CipherSuite, serverPSK.CipherSuite)
	assertEquals(t, clientPSK.IsResumption, serverPSK.IsResumption)
	assertByteEquals(t, clientPSK.Identity, serverPSK.Identity)
	assertByteEquals(t, clientPSK.Key, serverPSK.Key)
	assertEquals(t, clientPSK.NextProto, serverPSK.NextProto)
	assertEquals(t, clientPSK.TicketAgeAdd, serverPSK.TicketAgeAdd)
	receivedDelta := clientPSK.ReceivedAt.Sub(serverPSK.ReceivedAt) / time.Millisecond
	expiresDelta := clientPSK.ExpiresAt.Sub(serverPSK.ExpiresAt) / time.Millisecond
	assert(t, receivedDelta < 10 && receivedDelta > -10, "Unequal received times")
	assert(t, expiresDelta < 10 && expiresDelta > -10, "Unequal expiry times")

	// Phase 2: Verify that the session ticket gets used as a PSK
	cConn2, sConn2 := pipe()
	client2 := Client(cConn2, clientConfig)
	server2 := Server(sConn2, serverConfig)
	go func() {
		serverDone <- server2.Handshake()
	}()

	clientAlert = client2.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	// Result intentionally ignored. NOTE(review): presumably this zero-length
	// read lets the client process post-handshake messages — confirm.
	client2.Read(nil)
	serverAlert = <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)

	assertDeepEquals(t, client2.state.Params, server2.state.Params)
	assertCipherSuiteParamsEquals(t, client2.state.cryptoParams, server2.state.cryptoParams)
	assertByteEquals(t, client2.state.resumptionSecret, server2.state.resumptionSecret)
	assertByteEquals(t, client2.state.clientTrafficSecret, server2.state.clientTrafficSecret)
	assertByteEquals(t, client2.state.serverTrafficSecret, server2.state.serverTrafficSecret)
	assert(t, client2.state.Params.UsingPSK, "Session did not use the provided PSK")
}
// Test0xRTT runs a PSK handshake in which the client supplies early data and
// checks that early data was negotiated and that the server received exactly
// the bytes the client sent.
func Test0xRTT(t *testing.T) {
	conf := pskConfig
	cConn, sConn := pipe()
	client := Client(cConn, conf)
	client.EarlyData = []byte("hello 0xRTT world!")
	server := Server(sConn, conf)

	// The server goroutine only reports its result; assertions stay on the
	// test goroutine so a failure cannot skip the completion signal and hang
	// the test. The channel is buffered so the goroutine can always finish.
	serverDone := make(chan Alert, 1)
	go func() {
		serverDone <- server.Handshake()
	}()

	clientAlert := client.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	serverAlert := <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)

	assertDeepEquals(t, client.state.Params, server.state.Params)
	assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
	assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
	assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
	assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
	assert(t, client.state.Params.UsingEarlyData, "Session did not negotiate early data")
	assertByteEquals(t, client.EarlyData, server.EarlyData)
}
// Test0xRTTFailure runs a handshake where the client offers a PSK and early
// data but the server has no PSKs configured. The test only requires that
// both sides finish the handshake without an alert.
func Test0xRTTFailure(t *testing.T) {
	// Client thinks it has a PSK
	clientConfig := &Config{
		ServerName:   serverName,
		CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
		PSKs:         psks,
	}
	// Server doesn't
	serverConfig := &Config{
		ServerName:   serverName,
		CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256},
	}
	cConn, sConn := pipe()
	client := Client(cConn, clientConfig)
	client.EarlyData = []byte("hello 0xRTT world!")
	server := Server(sConn, serverConfig)

	// The server goroutine only reports its result; assertions stay on the
	// test goroutine so a failure cannot skip the completion signal and hang
	// the test. The channel is buffered so the goroutine can always finish.
	serverDone := make(chan Alert, 1)
	go func() {
		serverDone <- server.Handshake()
	}()

	clientAlert := client.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	serverAlert := <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)
}
// TestKeyUpdate exercises KeyUpdate in three steps after a normal handshake:
// server-initiated, client-initiated, and client-initiated with a requested
// update from the peer. After each step it compares the traffic secrets held
// by both sides and checks which of them rotated.
//
// The two goroutines run in lock step over the c2s/s2c channels; the server
// goroutine records its results in variables that the test goroutine checks
// only after the matching s2c signal, so all assertions run on the test
// goroutine.
func TestKeyUpdate(t *testing.T) {
	cConn, sConn := pipe()
	conf := basicConfig
	client := Client(cConn, conf)
	server := Server(sConn, conf)
	oneBuf := []byte{'a'}
	c2s := make(chan bool)
	s2c := make(chan bool)

	var (
		serverAlert     Alert
		serverUpdateErr error
	)
	go func() {
		serverAlert = server.Handshake()
		// Send a single byte so that the client can consume NST.
		server.Write(oneBuf)
		s2c <- true

		// Test server-initiated KeyUpdate
		<-c2s
		serverUpdateErr = server.SendKeyUpdate(false)
		// Write a single byte so that the client can read it
		// after KeyUpdate.
		server.Write(oneBuf)
		s2c <- true

		// Null read to trigger key update
		<-c2s
		server.Read(oneBuf)
		s2c <- true

		// Null read to trigger key update and KeyUpdate response
		<-c2s
		server.Read(oneBuf)
		server.Write(oneBuf)
		s2c <- true
	}()

	alert := client.Handshake()
	assertEquals(t, alert, AlertNoAlert)
	// Read NST.
	client.Read(oneBuf)
	<-s2c
	assertEquals(t, serverAlert, AlertNoAlert)
	clientState0 := client.state
	serverState0 := server.state
	assertByteEquals(t, clientState0.serverTrafficSecret, serverState0.serverTrafficSecret)
	assertByteEquals(t, clientState0.clientTrafficSecret, serverState0.clientTrafficSecret)

	// Null read to trigger key update
	c2s <- true
	<-s2c
	assertNotError(t, serverUpdateErr, "Key update send failed")
	client.Read(oneBuf)
	logf(logTypeHandshake, "Client read key update")
	clientState1 := client.state
	serverState1 := server.state
	// Only the server's sending secret should have rotated.
	assertByteEquals(t, clientState1.serverTrafficSecret, serverState1.serverTrafficSecret)
	assertByteEquals(t, clientState1.clientTrafficSecret, serverState1.clientTrafficSecret)
	assertNotByteEquals(t, serverState0.serverTrafficSecret, serverState1.serverTrafficSecret)
	assertByteEquals(t, clientState0.clientTrafficSecret, clientState1.clientTrafficSecret)

	// Test client-initiated KeyUpdate
	err := client.SendKeyUpdate(false)
	assertNotError(t, err, "Key update send failed")
	client.Write(oneBuf)
	c2s <- true
	<-s2c
	clientState2 := client.state
	serverState2 := server.state
	// Only the client's sending secret should have rotated.
	assertByteEquals(t, clientState2.serverTrafficSecret, serverState2.serverTrafficSecret)
	assertByteEquals(t, clientState2.clientTrafficSecret, serverState2.clientTrafficSecret)
	assertByteEquals(t, serverState1.serverTrafficSecret, serverState2.serverTrafficSecret)
	assertNotByteEquals(t, clientState1.clientTrafficSecret, clientState2.clientTrafficSecret)

	// Test client-initiated with keyUpdateRequested
	err = client.SendKeyUpdate(true)
	assertNotError(t, err, "Key update send failed")
	client.Write(oneBuf)
	c2s <- true
	<-s2c
	client.Read(oneBuf)
	clientState3 := client.state
	serverState3 := server.state
	// Both directions should have rotated.
	assertByteEquals(t, clientState3.serverTrafficSecret, serverState3.serverTrafficSecret)
	assertByteEquals(t, clientState3.clientTrafficSecret, serverState3.clientTrafficSecret)
	assertNotByteEquals(t, serverState2.serverTrafficSecret, serverState3.serverTrafficSecret)
	assertNotByteEquals(t, clientState2.clientTrafficSecret, clientState3.clientTrafficSecret)
}
// TestNonblockingHandshakeAndDataFlow drives a non-blocking handshake one
// step at a time on a single goroutine. Both endpoints write into
// bufferedConn wrappers, so each flight reaches the peer only when the test
// flushes it; between flushes the test checks the handshake state and that
// the waiting side reports AlertWouldBlock.
func TestNonblockingHandshakeAndDataFlow(t *testing.T) {
	cConn, sConn := pipe()
	// Wrap these in a buffer so we can simulate blocking
	cbConn := newBufferedConn(cConn)
	sbConn := newBufferedConn(sConn)
	client := Client(cbConn, nbConfig)
	server := Server(sbConn, nbConfig)
	var clientAlert, serverAlert Alert

	// Send ClientHello
	clientAlert = client.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	assertEquals(t, client.GetHsState(), StateClientWaitSH)
	serverAlert = server.Handshake()
	assertEquals(t, serverAlert, AlertWouldBlock)
	assertEquals(t, server.GetHsState(), StateServerStart)

	// Release ClientHello
	assertNotError(t, cbConn.Flush(), "Couldn't flush ClientHello")

	// Process ClientHello, send server first flight.
	states := []State{StateServerNegotiated, StateServerWaitFlight2, StateServerWaitFinished}
	for _, state := range states {
		serverAlert = server.Handshake()
		assertEquals(t, serverAlert, AlertNoAlert)
		assertEquals(t, server.GetHsState(), state)
	}
	serverAlert = server.Handshake()
	assertEquals(t, serverAlert, AlertWouldBlock)
	clientAlert = client.Handshake()
	assertEquals(t, clientAlert, AlertWouldBlock)

	// Release server first flight
	assertNotError(t, sbConn.Flush(), "Couldn't flush server first flight")
	states = []State{StateClientWaitEE, StateClientWaitCertCR, StateClientWaitCV, StateClientWaitFinished, StateClientConnected}
	for _, state := range states {
		clientAlert = client.Handshake()
		assertEquals(t, client.GetHsState(), state)
		assertEquals(t, clientAlert, AlertNoAlert)
	}
	serverAlert = server.Handshake()
	assertEquals(t, serverAlert, AlertWouldBlock)
	assertEquals(t, server.GetHsState(), StateServerWaitFinished)

	// Release client's second flight.
	assertNotError(t, cbConn.Flush(), "Couldn't flush client second flight")
	serverAlert = server.Handshake()
	assertEquals(t, serverAlert, AlertNoAlert)
	assertEquals(t, server.GetHsState(), StateServerConnected)

	assertDeepEquals(t, client.state.Params, server.state.Params)
	assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
	assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
	assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
	assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)

	buf := []byte{'a', 'b', 'c'}
	n, err := client.Write(buf)
	assertNotError(t, err, "Couldn't write")
	assertEquals(t, n, len(buf))
	// NOTE(review): the server-side read of this data was left disabled in
	// the original, so the data flow is only checked on the write side:
	// read := make([]byte, 5)
	// n, err = server.Read(buf)
}
// testExtensionHandler is an extension handler for tests that records, per
// handshake message type, whether its test extension was sent and whether it
// was received.
type testExtensionHandler struct {
	sent map[HandshakeType]bool // message types Send was called for
	rcvd map[HandshakeType]bool // message types Receive accepted
}

// newTestExtensionHandler returns a handler with both record maps allocated
// and empty.
func newTestExtensionHandler() *testExtensionHandler {
	handler := &testExtensionHandler{
		sent: map[HandshakeType]bool{},
		rcvd: map[HandshakeType]bool{},
	}
	return handler
}
// testExtensionBody is the payload of the test extension: a single byte
// holding the handshake message type the extension was attached to.
type testExtensionBody struct {
	t HandshakeType
}

// testExtensionType is the extension code point used by testExtensionBody.
const testExtensionType = ExtensionType(240) // Dummy type.

// Type reports the extension type of the test extension.
func (b testExtensionBody) Type() ExtensionType {
	return testExtensionType
}

// Marshal encodes the body as one byte: the handshake type.
func (b testExtensionBody) Marshal() ([]byte, error) {
	encoded := []byte{byte(b.t)}
	return encoded, nil
}

// Unmarshal decodes a body from data, which must be exactly one byte long.
// It returns the number of bytes consumed (1), or 0 and an error for any
// other length.
func (b *testExtensionBody) Unmarshal(data []byte) (int, error) {
	if len(data) == 1 {
		b.t = HandshakeType(data[0])
		return 1, nil
	}
	return 0, fmt.Errorf("Illegal length")
}
// Send records that an extension was sent for handshake message hs and adds
// a test extension carrying hs to el. It always returns nil; whatever
// el.Add returns is not inspected.
func (h *testExtensionHandler) Send(hs HandshakeType, el *ExtensionList) error {
	h.sent[hs] = true

	body := &testExtensionBody{t: hs}
	el.Add(body)

	return nil
}
// Receive looks for the test extension in el and checks that the handshake
// type it carries matches hs, the message it arrived in. On success the
// message type is recorded as received.
//
// It returns an error if the lookup itself fails, if the extension is
// absent, or if the carried type does not match hs.
func (h *testExtensionHandler) Receive(hs HandshakeType, el *ExtensionList) error {
	var body testExtensionBody
	found, err := el.Find(&body)
	// Surface lookup/decoding failures instead of discarding them; otherwise
	// a failed decode would be reported as a missing or mismatched extension.
	if err != nil {
		return err
	}
	if !found {
		return fmt.Errorf("Couldn't find extension")
	}
	if hs != body.t {
		return fmt.Errorf("Does not match hs type")
	}
	h.rcvd[hs] = true
	return nil
}
// Check asserts that the handler saw the test extension sent and received
// for exactly the handshake message types in hs: both record maps must have
// len(hs) entries, and every listed type must be present and true in each.
func (h *testExtensionHandler) Check(t *testing.T, hs []HandshakeType) {
	want := len(hs)
	assertEquals(t, want, len(h.sent))
	assertEquals(t, want, len(h.rcvd))

	for _, ht := range hs {
		wasSent, inSent := h.sent[ht]
		assert(t, inSent, "Cannot find handshake type in sent")
		assert(t, wasSent, "Value wasn't true in sent")

		wasRcvd, inRcvd := h.rcvd[ht]
		assert(t, inRcvd, "Cannot find handshake type in rcvd")
		assert(t, wasRcvd, "Value wasn't true in rcvd")
	}
}
// TestExternalExtensions installs the same test extension handler on client
// and server, runs a handshake, and checks that the extension was sent and
// received in ClientHello, ServerHello and EncryptedExtensions.
func TestExternalExtensions(t *testing.T) {
	cConn, sConn := pipe()
	handler := newTestExtensionHandler()
	client := Client(cConn, basicConfig)
	client.SetExtensionHandler(handler)
	server := Server(sConn, basicConfig)
	server.SetExtensionHandler(handler)

	// The server goroutine only reports its result; assertions stay on the
	// test goroutine so a failure cannot skip the completion signal and hang
	// the test. The channel is buffered so the goroutine can always finish.
	serverDone := make(chan Alert, 1)
	go func() {
		serverDone <- server.Handshake()
	}()

	clientAlert := client.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	serverAlert := <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)

	assertDeepEquals(t, client.state.Params, server.state.Params)
	assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
	assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
	assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
	assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
	handler.Check(t, []HandshakeType{
		HandshakeTypeClientHello,
		HandshakeTypeServerHello,
		HandshakeTypeEncryptedExtensions,
	})
}
// TestDTLS runs a handshake with UseDTLS enabled and the test extension
// handler installed on both sides, then checks that both sides agree on
// parameters and secrets and that the extension was seen in ClientHello,
// ServerHello and EncryptedExtensions.
func TestDTLS(t *testing.T) {
	cConn, sConn := pipe()
	handler := newTestExtensionHandler()
	client := Client(cConn, dtlsConfig)
	client.SetExtensionHandler(handler)
	server := Server(sConn, dtlsConfig)
	server.SetExtensionHandler(handler)

	// The server goroutine only reports its result; assertions stay on the
	// test goroutine so a failure cannot skip the completion signal and hang
	// the test. The channel is buffered so the goroutine can always finish.
	serverDone := make(chan Alert, 1)
	go func() {
		serverDone <- server.Handshake()
	}()

	clientAlert := client.Handshake()
	assertEquals(t, clientAlert, AlertNoAlert)
	serverAlert := <-serverDone
	assertEquals(t, serverAlert, AlertNoAlert)

	assertDeepEquals(t, client.state.Params, server.state.Params)
	assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams)
	assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret)
	assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret)
	assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret)
	handler.Check(t, []HandshakeType{
		HandshakeTypeClientHello,
		HandshakeTypeServerHello,
		HandshakeTypeEncryptedExtensions,
	})
}