128 lines
2.9 KiB
Go
128 lines
2.9 KiB
Go
|
package crypto
|
||
|
|
||
|
import (
|
||
|
"crypto/aes"
|
||
|
"crypto/cipher"
|
||
|
"crypto/rand"
|
||
|
"crypto/sha256"
|
||
|
"crypto/subtle"
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"time"
|
||
|
|
||
|
"github.com/lucas-clemente/quic-go/protocol"
|
||
|
|
||
|
"golang.org/x/crypto/hkdf"
|
||
|
)
|
||
|
|
||
|
// StkSource issues and checks source address tokens (STKs).
type StkSource interface {
	// NewToken returns a new token bound to the given source (IP) address.
	NewToken(sourceAddress []byte) ([]byte, error)
	// VerifyToken returns nil if data is a token that was issued for
	// sourceAddress and is not outdated, and a non-nil error otherwise.
	VerifyToken(sourceAddress, data []byte) error
}
|
||
|
|
||
|
// sourceAddressToken is the plaintext content of an STK: the address the token
// was issued for and the time it was issued.
type sourceAddressToken struct {
	sourceAddr []byte
	timestamp  uint64 // issue time as a Unix timestamp, in seconds
}
|
||
|
|
||
|
func (t *sourceAddressToken) serialize() []byte {
|
||
|
res := make([]byte, 8+len(t.sourceAddr))
|
||
|
binary.LittleEndian.PutUint64(res, t.timestamp)
|
||
|
copy(res[8:], t.sourceAddr)
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
// parseToken decodes the layout written by serialize: an 8-byte little-endian
// timestamp followed by a 4- or 16-byte address (the IPv4 / IPv6 sizes). Any
// other total length is rejected. The returned sourceAddr aliases data rather
// than copying it.
func parseToken(data []byte) (*sourceAddressToken, error) {
	switch n := len(data); n {
	case 8 + 4, 8 + 16:
		return &sourceAddressToken{
			timestamp:  binary.LittleEndian.Uint64(data[:8]),
			sourceAddr: data[8:],
		}, nil
	default:
		return nil, fmt.Errorf("invalid STK length: %d", n)
	}
}
|
||
|
|
||
|
type stkSource struct {
|
||
|
aead cipher.AEAD
|
||
|
}
|
||
|
|
||
|
const (
	// stkKeySize is the length in bytes of the AES key derived for sealing
	// tokens.
	stkKeySize = 16

	// stkNonceSize is the GCM nonce length in bytes. Chrome currently uses 12
	// and has discussed moving to 16; this implementation starts at 16.
	stkNonceSize = 16
)
|
||
|
|
||
|
// NewStkSource creates a source for source address tokens
|
||
|
func NewStkSource(secret []byte) (StkSource, error) {
|
||
|
key, err := deriveKey(secret)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
c, err := aes.NewCipher(key)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &stkSource{aead: aead}, nil
|
||
|
}
|
||
|
|
||
|
// NewToken issues a token binding sourceAddr to the current time (whole
// seconds since the Unix epoch), sealed with the source's AEAD.
func (s *stkSource) NewToken(sourceAddr []byte) ([]byte, error) {
	issuedAt := uint64(time.Now().Unix())
	tok := &sourceAddressToken{sourceAddr: sourceAddr, timestamp: issuedAt}
	return encryptToken(s.aead, tok)
}
|
||
|
|
||
|
// VerifyToken checks that data is a token sealed by this source for sourceAddr
// and that it has not expired. data must be the stkNonceSize-byte nonce
// followed by the sealed serialized token. Errors from the AEAD and from
// parseToken are returned as-is.
func (s *stkSource) VerifyToken(sourceAddr []byte, data []byte) error {
	if len(data) < stkNonceSize {
		return errors.New("STK too short")
	}
	nonce, sealed := data[:stkNonceSize], data[stkNonceSize:]

	plain, err := s.aead.Open(nil, nonce, sealed, nil)
	if err != nil {
		return err
	}

	token, err := parseToken(plain)
	if err != nil {
		return err
	}

	// Constant-time comparison so the check's timing does not depend on the
	// address contents.
	if subtle.ConstantTimeCompare(token.sourceAddr, sourceAddr) != 1 {
		return errors.New("invalid source address in STK")
	}

	expiry := int64(token.timestamp) + protocol.STKExpiryTimeSec
	if time.Now().Unix() > expiry {
		return errors.New("STK expired")
	}
	return nil
}
|
||
|
|
||
|
func deriveKey(secret []byte) ([]byte, error) {
|
||
|
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
||
|
key := make([]byte, stkKeySize)
|
||
|
if _, err := io.ReadFull(r, key); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return key, nil
|
||
|
}
|
||
|
|
||
|
// encryptToken serializes token and seals it under a fresh random nonce. The
// result is the stkNonceSize-byte nonce followed by the AEAD output. An error
// is returned only if reading random bytes for the nonce fails.
func encryptToken(aead cipher.AEAD, token *sourceAddressToken) ([]byte, error) {
	nonce := make([]byte, stkNonceSize)
	if _, err := rand.Read(nonce); err != nil {
		return nil, err
	}
	plaintext := token.serialize()
	// Seal appends its output to dst; passing nonce as dst produces
	// nonce || ciphertext in a single slice.
	sealed := aead.Seal(nonce, nonce, plaintext, nil)
	return sealed, nil
}
|