package mint import ( "bytes" "crypto/x509" "fmt" "io" "net" "sync" "testing" "time" ) type pipeConn struct { r *bytes.Buffer w *bytes.Buffer rLock *sync.Mutex wLock *sync.Mutex } func pipe() (client *pipeConn, server *pipeConn) { client = new(pipeConn) server = new(pipeConn) c2s := bytes.NewBuffer(nil) server.r = c2s client.w = c2s c2sLock := new(sync.Mutex) server.rLock = c2sLock client.wLock = c2sLock s2c := bytes.NewBuffer(nil) client.r = s2c server.w = s2c s2cLock := new(sync.Mutex) client.rLock = s2cLock server.wLock = s2cLock return } func (p *pipeConn) Read(data []byte) (n int, err error) { p.rLock.Lock() n, err = p.r.Read(data) p.rLock.Unlock() // Suppress bytes.Buffer's EOF on an empty buffer if err == io.EOF { err = nil } return } func (p *pipeConn) Write(data []byte) (n int, err error) { p.wLock.Lock() defer p.wLock.Unlock() return p.w.Write(data) } func (p *pipeConn) Close() error { return nil } func (p *pipeConn) LocalAddr() net.Addr { return nil } func (p *pipeConn) RemoteAddr() net.Addr { return nil } func (p *pipeConn) SetDeadline(t time.Time) error { return nil } func (p *pipeConn) SetReadDeadline(t time.Time) error { return nil } func (p *pipeConn) SetWriteDeadline(t time.Time) error { return nil } type bufferedConn struct { buffer bytes.Buffer w net.Conn } func (b *bufferedConn) Write(buf []byte) (n int, err error) { return b.buffer.Write(buf) } func (p *bufferedConn) Read(data []byte) (n int, err error) { return p.w.Read(data) } func (p *bufferedConn) Close() error { return nil } func (p *bufferedConn) LocalAddr() net.Addr { return nil } func (p *bufferedConn) RemoteAddr() net.Addr { return nil } func (p *bufferedConn) SetDeadline(t time.Time) error { return nil } func (p *bufferedConn) SetReadDeadline(t time.Time) error { return nil } func (p *bufferedConn) SetWriteDeadline(t time.Time) error { return nil } func (b *bufferedConn) Flush() error { buf := b.buffer.Bytes() n, err := b.w.Write(buf) if err != nil { return err } if n != len(buf) { return fmt.Errorf("Incomplete flush") } b.buffer.Reset() return nil } func newBufferedConn(p net.Conn) *bufferedConn { return &bufferedConn{bytes.Buffer{}, p} } var ( serverName = "example.com" // Certificate generated by Go // * CN: "example.com" // * SANs: "example.com", "www.example.com" // * Random 2048-bit public key // * Self-signed serverCertHex = "308202f7308201dfa00302010202012a300d06092a864886f70d01010b05003016311430" + "120603550403130b6578616d706c652e636f6d301e170d3135303130313030303030305a" + "170d3235303130313030303030305a3016311430120603550403130b6578616d706c652e" + "636f6d30820122300d06092a864886f70d01010105000382010f003082010a0282010100" + "a558ff3c12b8c4906b7f638878c71963ac95548c5d36975bc575de8775a141408c449c3e" + "7fe7eddf93329dd894ecb2705b7f79caa06f1477b7bd2d3ff32f43076dd32a7f9f97ed4d" + "4593db3f28adbea7794c14d8d206832652e93959e2b8d2b4781fadcf55c852641482f7fc" + "6b9e7e751442a0818c21c9cacc28e7594606ff692392510df57ce26d9c0d052f84e236b9" + "9e3f81daa98c554607432e3bb26a5fe3fa2b5fc5e5c1fcb1d76050328b92edc80238773d" + "16547ccc24c0784933d86b3f8d0ee33d90a1b47ecbfbaad12e77155f1b4e84b3e5c4d565" + "1717832fcbf82886eb6f925435b4ca9f87ec207b4338f03a846fbf0f68ea0e674bf50a21" + "d9165b690203010001a350304e300e0603551d0f0101ff04040302028430130603551d25" + "040c300a06082b0601050507030130270603551d110420301e820b6578616d706c652e63" + "6f6d820f7777772e6578616d706c652e636f6d300d06092a864886f70d01010b05000382" + "01010024ed08531171df6256ed5668b962ca3740e61c20d014b3aae8ac436b200ee89b50" + "dc5e5a74859fdf1d0f1136ad16d5fe018dac83d0e4dcfd1b5951e1502e23b7b740621b10" + "b3376f8b3bc5494c25917b4103d88670390e33c2b089e66326316e4bbd75fd6e5dced78f" + "79caf97d83b981950ed10449f61d826af4a6eb70e291fccdaa76145f7ba085d27698197f" + "60e944646640ea18d5439955d91a80d4dfb1e4c12f539da9423a33f479ee19a0fa9c5339" + "1e0d164633bea4346dc0c8081172d67ee7bca4bd5463cc147d8c062ebb31be6e9c39518c" + "37f5607a2d6f36114800f6c6f509893fa352a468b30ad874ae56db769f1786567e9c96c1" + "6b4a4b2a25dda3" serverCertDER = unhex(serverCertHex) serverCert, _ = x509.ParseCertificate(serverCertDER) // The corresponding private key serverKeyHex = "308204a40201000282010100a558ff3c12b8c4906b7f638878c71963ac95548c5d36975b" + "c575de8775a141408c449c3e7fe7eddf93329dd894ecb2705b7f79caa06f1477b7bd2d3f" + "f32f43076dd32a7f9f97ed4d4593db3f28adbea7794c14d8d206832652e93959e2b8d2b4" + "781fadcf55c852641482f7fc6b9e7e751442a0818c21c9cacc28e7594606ff692392510d" + "f57ce26d9c0d052f84e236b99e3f81daa98c554607432e3bb26a5fe3fa2b5fc5e5c1fcb1" + "d76050328b92edc80238773d16547ccc24c0784933d86b3f8d0ee33d90a1b47ecbfbaad1" + "2e77155f1b4e84b3e5c4d5651717832fcbf82886eb6f925435b4ca9f87ec207b4338f03a" + "846fbf0f68ea0e674bf50a21d9165b6902030100010282010074f08262ec22bcf21ef4d3" + "621b79445d981b6cd670be4141e85f3a68b72abac979eab44e078bf25222fab3640fbf6f" + "5bc37a5e9a8de8c1a301d1cb84e4ead20f18ff35995937cbded08c878d1da9f3a2e2488a" + "9de5bc3159135e5aef5547bdcd60ff969f825dd0d77322455cc2882f8b822eb4f1aa37e3" + "4d88228dac37b88f3d9b671ef6b05e2f47b562265e0d09fefb01c190c7fb4b3682231cd8" + "564c59b6cc788ff742fb040562110b1f849f1535164503b0a402399e2c6cf1c0847dd50a" + "a917b62fc3215e4eb43d7d07fa9731a51e01f0f7b694dd002b48c0bad04b9ff34e576393" + "c0a213a12dda4bf43a7dd4ee0563c5e0de2025eb76e049cd771c96330102818100c590bd" + "8f226cec50c818afb3ebe7ceeacabb107ac73ac159b1eca1a194ea550a0609c432a183e2" + "fee62dafdc0201426f90cb46f9b2fc7a9bcc2365b58177529cf78c209eb6a3afd1896466" + "63e8462729e8bf902dc1c42c7d46c1c0c99c632f0560418604b4260a1ed8d165375c674c" + "806c2a8e202d0b7c5a8b8717309106fb3102818100d640cae7b6adea649a8c11863a3ba8" + "098025a972d130aecaa4db08154fd0feb8af79bf7009c1ea2a790752464e923b53b41ff4" + "3ff84e6ddb94bfc5b157e6a21e1fefe11cc082e7e8b31d07eab5e13d7a84cdeeba24d283" + "699a8fa5138e753e88856a033ab2153c1a8200caac28377a1d09d6318ac2e946cef879a0" + "5acbd8e5b902818100bfe142ea189257b66190f05d3bba8951aa92a27fccadf90a076f7e" + "cff354e040fafa534ea565f57a81ce4fa5cb60b3c8ad8570aaa5b6e7d217232dee6a0e9c" + "f30cce510434f8a79347f0762d84735628330092a48e33dccdd381ec9f233f8574a03723" + "55c02dcdd885d6618ab23935a8e8e52fe27a3d548a90472533ab376f910281805253fd64" + "02875bbd22c1d5ee0d2c654a994a5f8d7622cdd7a27763e8c48ddb835e325b44930b478e" + "e088d6ad9b7d877c87878bd494f696323d3b5f9ce0d907cca99b049686c706941d577776" + "524365db5172cc5c0cd0339cfdbe5ac164095b691c52fb40afb3872fec6a9f767dd1ab83" + "c306e26c9eaf02fd7eef4595fe24af4902818100b5a2294d7567283f3f4bf54be7b98785" + "fc564f24ff2d67215ecdc7955cbf05260f48c9608a59a8ebfbedc62b4d110c1704ade704" + "cb27a591f69752d1d6ebe21291aec29b301efe47eced0187125f741ce52b3826beac3778" + "f3560448e91644fd52460f8c3afa1596c01e6cd2c37120d8122c09edf326988b48d98c27" + "f788eb83" serverKeyDER = unhex(serverKeyHex) serverKey, _ = x509.ParsePKCS1PrivateKey(serverKeyDER) psk = PreSharedKey{ CipherSuite: TLS_AES_128_GCM_SHA256, IsResumption: false, Identity: []byte{0, 1, 2, 3}, Key: []byte{4, 5, 6, 7}, } certificates = []*Certificate{ { Chain: []*x509.Certificate{serverCert}, PrivateKey: serverKey, }, } psks = &PSKMapCache{ serverName: psk, "00010203": psk, } basicConfig = &Config{ ServerName: serverName, Certificates: certificates, } dtlsConfig = &Config{ ServerName: serverName, Certificates: certificates, UseDTLS: true, } nbConfig = &Config{ ServerName: serverName, Certificates: certificates, NonBlocking: true, } hrrConfig = &Config{ ServerName: serverName, Certificates: certificates, RequireCookie: true, } alpnConfig = &Config{ ServerName: serverName, Certificates: certificates, NextProtos: []string{"http/1.1", "h2"}, } clientAuthConfig = &Config{ ServerName: serverName, RequireClientAuth: true, Certificates: certificates, } pskConfig = &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, PSKs: psks, AllowEarlyData: true, } pskECDHEConfig = &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, Certificates: certificates, PSKs: psks, } pskDHEConfig = &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, Certificates: certificates, PSKs: psks, Groups: []NamedGroup{FFDHE2048}, } resumptionConfig = &Config{ ServerName: serverName, Certificates: certificates, SendSessionTickets: true, } ffdhConfig = &Config{ ServerName: serverName, Certificates: certificates, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, Groups: []NamedGroup{FFDHE2048}, } x25519Config = &Config{ ServerName: serverName, Certificates: certificates, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, Groups: []NamedGroup{X25519}, } ) func assertKeySetEquals(t *testing.T, k1, k2 keySet) { t.Helper() // Assume cipher is the same assertByteEquals(t, k1.iv, k2.iv) assertByteEquals(t, k1.key, k2.key) } func computeExporter(t *testing.T, c *Conn, label string, context []byte, length int) []byte { t.Helper() res, err := c.ComputeExporter(label, context, length) assertNotError(t, err, "Could not compute exporter") return res } func TestBasicFlows(t *testing.T) { tests := []struct { name string config *Config }{ {"basic config", basicConfig}, {"HRR", hrrConfig}, {"ALPN", alpnConfig}, {"FFDH", ffdhConfig}, {"x25519", x25519Config}, } for _, testcase := range tests { t.Run(fmt.Sprintf("with %s", testcase.name), func(t *testing.T) { conf := testcase.config cConn, sConn := pipe() client := Client(cConn, conf) server := Server(sConn, conf) var clientAlert, serverAlert Alert done := make(chan bool) go func(t *testing.T) { serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) done <- true }(t) clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertNoAlert) <-done assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) assertByteEquals(t, client.state.exporterSecret, server.state.exporterSecret) emptyContext := []byte{} assertByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 20)) assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "E", emptyContext, 21)) assertNotByteEquals(t, computeExporter(t, client, "E", emptyContext, 20), computeExporter(t, server, "F", emptyContext, 20)) assertByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'A'}, 20)) assertNotByteEquals(t, computeExporter(t, client, "E", []byte{'A'}, 20), computeExporter(t, server, "E", []byte{'B'}, 20)) }) } } func TestClientAuth(t *testing.T) { cConn, sConn := pipe() client := Client(cConn, clientAuthConfig) server := Server(sConn, clientAuthConfig) var clientAlert, serverAlert Alert done := make(chan bool) go func(t *testing.T) { serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) done <- true }(t) clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertNoAlert) <-done assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) assert(t, client.state.Params.UsingClientAuth, "Session did not negotiate client auth") } func TestPSKFlows(t *testing.T) { for _, conf := range []*Config{pskConfig, pskECDHEConfig, pskDHEConfig} { cConn, sConn := pipe() client := Client(cConn, conf) server := Server(sConn, conf) var clientAlert, serverAlert Alert done := make(chan bool) go func(t *testing.T) { serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) done <- true }(t) clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertNoAlert) <-done assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) assert(t, client.state.Params.UsingPSK, "Session did not use the provided PSK") } } func TestNonBlockingReadBeforeConnected(t *testing.T) { conn := Client(&bufferedConn{}, &Config{NonBlocking: true}) _, err := conn.Read(make([]byte, 10)) assertEquals(t, err.Error(), "Read called before the handshake completed") } func TestResumption(t *testing.T) { // Phase 1: Verify that the session ticket gets sent and stored clientConfig := resumptionConfig.Clone() serverConfig := resumptionConfig.Clone() cConn1, sConn1 := pipe() client1 := Client(cConn1, clientConfig) server1 := Server(sConn1, serverConfig) var clientAlert, serverAlert Alert done := make(chan bool) go func(t *testing.T) { serverAlert = server1.Handshake() assertEquals(t, serverAlert, AlertNoAlert) server1.Write([]byte{'a'}) done <- true }(t) clientAlert = client1.Handshake() assertEquals(t, clientAlert, AlertNoAlert) tmpBuf := make([]byte, 1) n, err := client1.Read(tmpBuf) assertNil(t, err, "Couldn't read one byte") assertEquals(t, 1, n) <-done assertDeepEquals(t, client1.state.Params, server1.state.Params) assertCipherSuiteParamsEquals(t, client1.state.cryptoParams, server1.state.cryptoParams) assertByteEquals(t, client1.state.resumptionSecret, server1.state.resumptionSecret) assertByteEquals(t, client1.state.clientTrafficSecret, server1.state.clientTrafficSecret) assertByteEquals(t, client1.state.serverTrafficSecret, server1.state.serverTrafficSecret) assertEquals(t, clientConfig.PSKs.Size(), 1) assertEquals(t, serverConfig.PSKs.Size(), 1) clientCache := clientConfig.PSKs.(*PSKMapCache) serverCache := serverConfig.PSKs.(*PSKMapCache) var serverPSK PreSharedKey for _, key := range *serverCache { serverPSK = key } var clientPSK PreSharedKey for _, key := range *clientCache { clientPSK = key } // Ensure that the PSKs are the same, except with regard to the // receivedAt/expiresAt times, which might differ by a little. assertEquals(t, clientPSK.CipherSuite, serverPSK.CipherSuite) assertEquals(t, clientPSK.IsResumption, serverPSK.IsResumption) assertByteEquals(t, clientPSK.Identity, serverPSK.Identity) assertByteEquals(t, clientPSK.Key, serverPSK.Key) assertEquals(t, clientPSK.NextProto, serverPSK.NextProto) assertEquals(t, clientPSK.TicketAgeAdd, serverPSK.TicketAgeAdd) receivedDelta := clientPSK.ReceivedAt.Sub(serverPSK.ReceivedAt) / time.Millisecond expiresDelta := clientPSK.ExpiresAt.Sub(serverPSK.ExpiresAt) / time.Millisecond assert(t, receivedDelta < 10 && receivedDelta > -10, "Unequal received times") assert(t, expiresDelta < 10 && expiresDelta > -10, "Unequal received times") // Phase 2: Verify that the session ticket gets used as a PSK cConn2, sConn2 := pipe() client2 := Client(cConn2, clientConfig) server2 := Server(sConn2, serverConfig) go func(t *testing.T) { serverAlert = server2.Handshake() assertEquals(t, serverAlert, AlertNoAlert) done <- true }(t) clientAlert = client2.Handshake() assertEquals(t, clientAlert, AlertNoAlert) client2.Read(nil) <-done assertDeepEquals(t, client2.state.Params, server2.state.Params) assertCipherSuiteParamsEquals(t, client2.state.cryptoParams, server2.state.cryptoParams) assertByteEquals(t, client2.state.resumptionSecret, server2.state.resumptionSecret) assertByteEquals(t, client2.state.clientTrafficSecret, server2.state.clientTrafficSecret) assertByteEquals(t, client2.state.serverTrafficSecret, server2.state.serverTrafficSecret) assert(t, client2.state.Params.UsingPSK, "Session did not use the provided PSK") } func Test0xRTT(t *testing.T) { conf := pskConfig cConn, sConn := pipe() client := Client(cConn, conf) client.EarlyData = []byte("hello 0xRTT world!") server := Server(sConn, conf) done := make(chan bool) go func(t *testing.T) { alert := server.Handshake() assertEquals(t, alert, AlertNoAlert) done <- true }(t) alert := client.Handshake() assertEquals(t, alert, AlertNoAlert) <-done assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) assert(t, client.state.Params.UsingEarlyData, "Session did not negotiate early data") assertByteEquals(t, client.EarlyData, server.EarlyData) } func Test0xRTTFailure(t *testing.T) { // Client thinks it has a PSK clientConfig := &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, PSKs: psks, } // Server doesn't serverConfig := &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, } cConn, sConn := pipe() client := Client(cConn, clientConfig) client.EarlyData = []byte("hello 0xRTT world!") server := Server(sConn, serverConfig) done := make(chan bool) go func(t *testing.T) { alert := server.Handshake() assertEquals(t, alert, AlertNoAlert) done <- true }(t) alert := client.Handshake() assertEquals(t, alert, AlertNoAlert) <-done } func TestKeyUpdate(t *testing.T) { cConn, sConn := pipe() conf := basicConfig client := Client(cConn, conf) server := Server(sConn, conf) oneBuf := []byte{'a'} c2s := make(chan bool) s2c := make(chan bool) go func(t *testing.T) { alert := server.Handshake() assertEquals(t, alert, AlertNoAlert) // Send a single byte so that the client can consume NST. server.Write(oneBuf) s2c <- true // Test server-initiated KeyUpdate <-c2s err := server.SendKeyUpdate(false) assertNotError(t, err, "Key update send failed") // Write a single byte so that the client can read it // after KeyUpdate. server.Write(oneBuf) s2c <- true // Null read to trigger key update <-c2s server.Read(oneBuf) s2c <- true // Null read to trigger key update and KeyUpdate response <-c2s server.Read(oneBuf) server.Write(oneBuf) s2c <- true }(t) alert := client.Handshake() assertEquals(t, alert, AlertNoAlert) // Read NST. client.Read(oneBuf) <-s2c clientState0 := client.state serverState0 := server.state assertByteEquals(t, clientState0.serverTrafficSecret, serverState0.serverTrafficSecret) assertByteEquals(t, clientState0.clientTrafficSecret, serverState0.clientTrafficSecret) // Null read to trigger key update c2s <- true <-s2c client.Read(oneBuf) logf(logTypeHandshake, "Client read key update") clientState1 := client.state serverState1 := server.state assertByteEquals(t, clientState1.serverTrafficSecret, serverState1.serverTrafficSecret) assertByteEquals(t, clientState1.clientTrafficSecret, serverState1.clientTrafficSecret) assertNotByteEquals(t, serverState0.serverTrafficSecret, serverState1.serverTrafficSecret) assertByteEquals(t, clientState0.clientTrafficSecret, clientState1.clientTrafficSecret) // Test client-initiated KeyUpdate client.SendKeyUpdate(false) client.Write(oneBuf) c2s <- true <-s2c clientState2 := client.state serverState2 := server.state assertByteEquals(t, clientState2.serverTrafficSecret, serverState2.serverTrafficSecret) assertByteEquals(t, clientState2.clientTrafficSecret, serverState2.clientTrafficSecret) assertByteEquals(t, serverState1.serverTrafficSecret, serverState2.serverTrafficSecret) assertNotByteEquals(t, clientState1.clientTrafficSecret, clientState2.clientTrafficSecret) // Test client-initiated with keyUpdateRequested client.SendKeyUpdate(true) client.Write(oneBuf) c2s <- true <-s2c client.Read(oneBuf) clientState3 := client.state serverState3 := server.state assertByteEquals(t, clientState3.serverTrafficSecret, serverState3.serverTrafficSecret) assertByteEquals(t, clientState3.clientTrafficSecret, serverState3.clientTrafficSecret) assertNotByteEquals(t, serverState2.serverTrafficSecret, serverState3.serverTrafficSecret) assertNotByteEquals(t, clientState2.clientTrafficSecret, clientState3.clientTrafficSecret) } func TestNonblockingHandshakeAndDataFlow(t *testing.T) { cConn, sConn := pipe() // Wrap these in a buffer so we can simulate blocking cbConn := newBufferedConn(cConn) sbConn := newBufferedConn(sConn) client := Client(cbConn, nbConfig) server := Server(sbConn, nbConfig) var clientAlert, serverAlert Alert // Send ClientHello clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertNoAlert) assertEquals(t, client.GetHsState(), StateClientWaitSH) serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertWouldBlock) assertEquals(t, server.GetHsState(), StateServerStart) // Release ClientHello cbConn.Flush() // Process ClientHello, send server first flight. states := []State{StateServerNegotiated, StateServerWaitFlight2, StateServerWaitFinished} for _, state := range states { serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) assertEquals(t, server.GetHsState(), state) } serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertWouldBlock) clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertWouldBlock) // Release server first flight sbConn.Flush() states = []State{StateClientWaitEE, StateClientWaitCertCR, StateClientWaitCV, StateClientWaitFinished, StateClientConnected} for _, state := range states { clientAlert = client.Handshake() assertEquals(t, client.GetHsState(), state) assertEquals(t, clientAlert, AlertNoAlert) } serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertWouldBlock) assertEquals(t, server.GetHsState(), StateServerWaitFinished) // Release client's second flight. cbConn.Flush() serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) assertEquals(t, server.GetHsState(), StateServerConnected) assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) buf := []byte{'a', 'b', 'c'} n, err := client.Write(buf) assertNotError(t, err, "Couldn't write") assertEquals(t, n, len(buf)) // read := make([]byte, 5) // n, err = server.Read(buf) } type testExtensionHandler struct { sent map[HandshakeType]bool rcvd map[HandshakeType]bool } func newTestExtensionHandler() *testExtensionHandler { return &testExtensionHandler{ make(map[HandshakeType]bool), make(map[HandshakeType]bool), } } type testExtensionBody struct { t HandshakeType } const ( testExtensionType = ExtensionType(240) // Dummy type. ) func (t testExtensionBody) Type() ExtensionType { return testExtensionType } func (t testExtensionBody) Marshal() ([]byte, error) { return []byte{byte(t.t)}, nil } func (t *testExtensionBody) Unmarshal(data []byte) (int, error) { if len(data) != 1 { return 0, fmt.Errorf("Illegal length") } t.t = HandshakeType(data[0]) return 1, nil } func (t *testExtensionHandler) Send(hs HandshakeType, el *ExtensionList) error { t.sent[hs] = true el.Add(&testExtensionBody{t: hs}) return nil } func (t *testExtensionHandler) Receive(hs HandshakeType, el *ExtensionList) error { var body testExtensionBody ok, _ := el.Find(&body) if !ok { return fmt.Errorf("Couldn't find extension") } if hs != body.t { return fmt.Errorf("Does not match hs type") } t.rcvd[hs] = true return nil } func (h *testExtensionHandler) Check(t *testing.T, hs []HandshakeType) { assertEquals(t, len(hs), len(h.sent)) assertEquals(t, len(hs), len(h.rcvd)) for _, ht := range hs { v, ok := h.sent[ht] assert(t, ok, "Cannot find handshake type in sent") assert(t, v, "Value wasn't true in sent") v, ok = h.rcvd[ht] assert(t, ok, "Cannot find handshake type in rcvd") assert(t, v, "Value wasn't true in rcvd") } } func TestExternalExtensions(t *testing.T) { cConn, sConn := pipe() var handler = newTestExtensionHandler() client := Client(cConn, basicConfig) client.SetExtensionHandler(handler) server := Server(sConn, basicConfig) server.SetExtensionHandler(handler) var clientAlert, serverAlert Alert done := make(chan bool) go func(t *testing.T) { serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) done <- true }(t) clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertNoAlert) <-done assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) handler.Check(t, []HandshakeType{ HandshakeTypeClientHello, HandshakeTypeServerHello, HandshakeTypeEncryptedExtensions, }) } func TestDTLS(t *testing.T) { cConn, sConn := pipe() var handler = newTestExtensionHandler() client := Client(cConn, dtlsConfig) client.SetExtensionHandler(handler) server := Server(sConn, dtlsConfig) server.SetExtensionHandler(handler) var clientAlert, serverAlert Alert done := make(chan bool) go func(t *testing.T) { serverAlert = server.Handshake() assertEquals(t, serverAlert, AlertNoAlert) done <- true }(t) clientAlert = client.Handshake() assertEquals(t, clientAlert, AlertNoAlert) <-done assertDeepEquals(t, client.state.Params, server.state.Params) assertCipherSuiteParamsEquals(t, client.state.cryptoParams, server.state.cryptoParams) assertByteEquals(t, client.state.resumptionSecret, server.state.resumptionSecret) assertByteEquals(t, client.state.clientTrafficSecret, server.state.clientTrafficSecret) assertByteEquals(t, client.state.serverTrafficSecret, server.state.serverTrafficSecret) handler.Check(t, []HandshakeType{ HandshakeTypeClientHello, HandshakeTypeServerHello, HandshakeTypeEncryptedExtensions, }) }