vendor: Updated quic-go for QUIC 39+ (#1968)

* Updated lucas-clemente/quic-go for QUIC 39+ support

* Update quic-go to latest
This commit is contained in:
Amos Ng 2018-02-17 13:29:53 +08:00 committed by Matt Holt
parent faa5248d1f
commit 1201492222
229 changed files with 26903 additions and 4254 deletions

21
vendor/github.com/bifurcation/mint/LICENSE.md generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Richard Barnes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

99
vendor/github.com/bifurcation/mint/alert.go generated vendored Normal file
View file

@ -0,0 +1,99 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mint
import "strconv"
type Alert uint8
const (
// alert level
AlertLevelWarning = 1
AlertLevelError = 2
)
const (
AlertCloseNotify Alert = 0
AlertUnexpectedMessage Alert = 10
AlertBadRecordMAC Alert = 20
AlertDecryptionFailed Alert = 21
AlertRecordOverflow Alert = 22
AlertDecompressionFailure Alert = 30
AlertHandshakeFailure Alert = 40
AlertBadCertificate Alert = 42
AlertUnsupportedCertificate Alert = 43
AlertCertificateRevoked Alert = 44
AlertCertificateExpired Alert = 45
AlertCertificateUnknown Alert = 46
AlertIllegalParameter Alert = 47
AlertUnknownCA Alert = 48
AlertAccessDenied Alert = 49
AlertDecodeError Alert = 50
AlertDecryptError Alert = 51
AlertProtocolVersion Alert = 70
AlertInsufficientSecurity Alert = 71
AlertInternalError Alert = 80
AlertInappropriateFallback Alert = 86
AlertUserCanceled Alert = 90
AlertNoRenegotiation Alert = 100
AlertMissingExtension Alert = 109
AlertUnsupportedExtension Alert = 110
AlertCertificateUnobtainable Alert = 111
AlertUnrecognizedName Alert = 112
AlertBadCertificateStatsResponse Alert = 113
AlertBadCertificateHashValue Alert = 114
AlertUnknownPSKIdentity Alert = 115
AlertNoApplicationProtocol Alert = 120
AlertWouldBlock Alert = 254
AlertNoAlert Alert = 255
)
var alertText = map[Alert]string{
AlertCloseNotify: "close notify",
AlertUnexpectedMessage: "unexpected message",
AlertBadRecordMAC: "bad record MAC",
AlertDecryptionFailed: "decryption failed",
AlertRecordOverflow: "record overflow",
AlertDecompressionFailure: "decompression failure",
AlertHandshakeFailure: "handshake failure",
AlertBadCertificate: "bad certificate",
AlertUnsupportedCertificate: "unsupported certificate",
AlertCertificateRevoked: "revoked certificate",
AlertCertificateExpired: "expired certificate",
AlertCertificateUnknown: "unknown certificate",
AlertIllegalParameter: "illegal parameter",
AlertUnknownCA: "unknown certificate authority",
AlertAccessDenied: "access denied",
AlertDecodeError: "error decoding message",
AlertDecryptError: "error decrypting message",
AlertProtocolVersion: "protocol version not supported",
AlertInsufficientSecurity: "insufficient security level",
AlertInternalError: "internal error",
AlertInappropriateFallback: "inappropriate fallback",
AlertUserCanceled: "user canceled",
AlertMissingExtension: "missing extension",
AlertUnsupportedExtension: "unsupported extension",
AlertCertificateUnobtainable: "certificate unobtainable",
AlertUnrecognizedName: "unrecognized name",
AlertBadCertificateStatsResponse: "bad certificate status response",
AlertBadCertificateHashValue: "bad certificate hash value",
AlertUnknownPSKIdentity: "unknown PSK identity",
AlertNoApplicationProtocol: "no application protocol",
AlertNoRenegotiation: "no renegotiation",
AlertWouldBlock: "would have blocked",
AlertNoAlert: "no alert",
}
func (e Alert) String() string {
s, ok := alertText[e]
if ok {
return s
}
return "alert(" + strconv.Itoa(int(e)) + ")"
}
func (e Alert) Error() string {
return e.String()
}

View file

@ -0,0 +1,42 @@
package main
import (
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"github.com/bifurcation/mint"
)
var url string
func main() {
url := flag.String("url", "https://localhost:4430", "URL to send request")
flag.Parse()
mintdial := func(network, addr string) (net.Conn, error) {
return mint.Dial(network, addr, nil)
}
tr := &http.Transport{
DialTLS: mintdial,
DisableCompression: true,
}
client := &http.Client{Transport: tr}
response, err := client.Get(*url)
if err != nil {
fmt.Println("err:", err)
return
}
defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
fmt.Printf("%s", err)
os.Exit(1)
}
fmt.Printf("%s\n", string(contents))
}

View file

@ -0,0 +1,37 @@
package main
import (
"flag"
"fmt"
"github.com/bifurcation/mint"
)
var addr string
func main() {
flag.StringVar(&addr, "addr", "localhost:4430", "port")
flag.Parse()
conn, err := mint.Dial("tcp", addr, nil)
if err != nil {
fmt.Println("TLS handshake failed:", err)
return
}
request := "GET / HTTP/1.0\r\n\r\n"
conn.Write([]byte(request))
response := ""
buffer := make([]byte, 1024)
var read int
for err == nil {
read, err = conn.Read(buffer)
fmt.Println(" ~~ read: ", read)
response += string(buffer)
}
fmt.Println("err:", err)
fmt.Println("Received from server:")
fmt.Println(response)
}

View file

@ -0,0 +1,226 @@
package main
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/bifurcation/mint"
"golang.org/x/net/http2"
)
var (
port string
serverName string
certFile string
keyFile string
responseFile string
h2 bool
sendTickets bool
)
type responder []byte
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write(rsp)
}
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
// PEM-encoded private key.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
keyDER, _ := pem.Decode(keyPEM)
if keyDER == nil {
return nil, err
}
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
if err != nil {
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
if err != nil {
// We don't include the actual error into
// the final error. The reason might be
// we don't want to leak any info about
// the private key.
return nil, fmt.Errorf("No successful private key decoder")
}
}
}
switch generalKey.(type) {
case *rsa.PrivateKey:
return generalKey.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return generalKey.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, fmt.Errorf("Should be unreachable")
}
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
// either a raw x509 certificate or a PKCS #7 structure possibly containing
// multiple certificates, from the top of certsPEM, which itself may
// contain multiple PEM encoded certificate objects.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
block, rest := pem.Decode(certsPEM)
if block == nil {
return nil, rest, nil
}
cert, err := x509.ParseCertificate(block.Bytes)
var certs = []*x509.Certificate{cert}
return certs, rest, err
}
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
// can handle PEM encoded PKCS #7 structures.
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
var err error
certsPEM = bytes.TrimSpace(certsPEM)
for len(certsPEM) > 0 {
var cert []*x509.Certificate
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
if err != nil {
return nil, err
} else if cert == nil {
break
}
certs = append(certs, cert...)
}
if len(certsPEM) > 0 {
return nil, fmt.Errorf("Trailing PEM data")
}
return certs, nil
}
func main() {
flag.StringVar(&port, "port", "4430", "port")
flag.StringVar(&serverName, "host", "example.com", "hostname")
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
flag.StringVar(&responseFile, "response", "", "file to serve")
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
flag.Parse()
var certChain []*x509.Certificate
var priv crypto.Signer
var response []byte
var err error
// Load the key and certificate chain
if certFile != "" {
certs, err := ioutil.ReadFile(certFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
certChain, err = ParseCertificatesPEM(certs)
if err != nil {
certChain, err = x509.ParseCertificates(certs)
if err != nil {
log.Fatalf("Error parsing certificates: %v", err)
}
}
}
}
if keyFile != "" {
keyPEM, err := ioutil.ReadFile(keyFile)
if err != nil {
log.Fatalf("Error: %v", err)
} else {
priv, err = ParsePrivateKeyPEM(keyPEM)
if priv == nil || err != nil {
log.Fatalf("Error parsing private key: %v", err)
}
}
}
if err != nil {
log.Fatalf("Error: %v", err)
}
// Load response file
if responseFile != "" {
log.Printf("Loading response file: %v", responseFile)
response, err = ioutil.ReadFile(responseFile)
if err != nil {
log.Fatalf("Error: %v", err)
}
} else {
response = []byte("Welcome to the TLS 1.3 zone!")
}
handler := responder(response)
config := mint.Config{
SendSessionTickets: true,
ServerName: serverName,
NextProtos: []string{"http/1.1"},
}
if h2 {
config.NextProtos = []string{"h2"}
}
config.SendSessionTickets = sendTickets
if certChain != nil && priv != nil {
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
config.Certificates = []*mint.Certificate{
{
Chain: certChain,
PrivateKey: priv,
},
}
}
config.Init(false)
service := "0.0.0.0:" + port
srv := &http.Server{Handler: handler}
log.Printf("Listening on port %v", port)
// Need the inner loop here because the h1 server errors on a dropped connection
// Need the outer loop here because the h2 server is per-connection
for {
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Printf("Listen Error: %v", err)
continue
}
if !h2 {
alert := srv.Serve(listener)
if alert != mint.AlertNoAlert {
log.Printf("Serve Error: %v", err)
}
} else {
srv2 := new(http2.Server)
opts := &http2.ServeConnOpts{
Handler: handler,
BaseConfig: srv,
}
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go srv2.ServeConn(conn, opts)
}
}
}
}

View file

@ -0,0 +1,65 @@
package main
import (
"flag"
"log"
"net"
"github.com/bifurcation/mint"
)
var port string
func main() {
var config mint.Config
config.SendSessionTickets = true
config.ServerName = "localhost"
config.Init(false)
flag.StringVar(&port, "port", "4430", "port")
flag.Parse()
service := "0.0.0.0:" + port
listener, err := mint.Listen("tcp", service, &config)
if err != nil {
log.Fatalf("server: listen: %s", err)
}
log.Print("server: listening")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("server: accept: %s", err)
break
}
defer conn.Close()
log.Printf("server: accepted from %s", conn.RemoteAddr())
go handleClient(conn)
}
}
func handleClient(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 10)
for {
log.Print("server: conn: waiting")
n, err := conn.Read(buf)
if err != nil {
if err != nil {
log.Printf("server: conn: read: %s", err)
}
break
}
n, err = conn.Write([]byte("hello world"))
log.Printf("server: conn: wrote %d bytes", n)
if err != nil {
log.Printf("server: write: %s", err)
break
}
break
}
log.Println("server: conn: closed")
}

View file

@ -0,0 +1,942 @@
package mint
import (
"bytes"
"crypto"
"hash"
"time"
)
// Client State Machine
//
// START <----+
// Send ClientHello | | Recv HelloRetryRequest
// / v |
// | WAIT_SH ---+
// Can | | Recv ServerHello
// send | V
// early | WAIT_EE
// data | | Recv EncryptedExtensions
// | +--------+--------+
// | Using | | Using certificate
// | PSK | v
// | | WAIT_CERT_CR
// | | Recv | | Recv CertificateRequest
// | | Certificate | v
// | | | WAIT_CERT
// | | | | Recv Certificate
// | | v v
// | | WAIT_CV
// | | | Recv CertificateVerify
// | +> WAIT_FINISHED <+
// | | Recv Finished
// \ |
// | [Send EndOfEarlyData]
// | [Send Certificate [+ CertificateVerify]]
// | Send Finished
// Can send v
// app data --> CONNECTED
// after
// here
//
// State Instructions
// START Send(CH); [RekeyOut; SendEarlyData]
// WAIT_SH Send(CH) || RekeyIn
// WAIT_EE {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ClientStateStart struct {
Caps Capabilities
Opts ConnectionOptions
Params ConnectionParameters
cookie []byte
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
}
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
return nil, nil, AlertUnexpectedMessage
}
// key_shares
offeredDH := map[NamedGroup][]byte{}
ks := KeyShareExtension{
HandshakeType: HandshakeTypeClientHello,
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
}
for i, group := range state.Caps.Groups {
pub, priv, err := newKeyShare(group)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
return nil, nil, AlertInternalError
}
ks.Shares[i].Group = group
ks.Shares[i].KeyExchange = pub
offeredDH[group] = priv
}
logf(logTypeHandshake, "opts: %+v", state.Opts)
// supported_versions, supported_groups, signature_algorithms, server_name
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
sni := ServerNameExtension(state.Opts.ServerName)
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
state.Params.ServerName = state.Opts.ServerName
// Application Layer Protocol Negotiation
var alpn *ALPNExtension
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
}
// Construct base ClientHello
ch := &ClientHelloBody{
CipherSuites: state.Caps.CipherSuites,
}
_, err := prng.Read(ch.Random[:])
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
return nil, nil, AlertInternalError
}
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
err := ch.Extensions.Add(ext)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
return nil, nil, AlertInternalError
}
}
// XXX: These optional extensions can't be folded into the above because Go
// interface-typed values are never reported as nil
if alpn != nil {
err := ch.Extensions.Add(alpn)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.cookie != nil {
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
return nil, nil, AlertInternalError
}
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
// Handle PSK and EarlyData just before transmitting, so that we can
// calculate the PSK binder value
var psk *PreSharedKeyExtension
var ed *EarlyDataExtension
var offeredPSK PreSharedKey
var earlyHash crypto.Hash
var earlySecret []byte
var clientEarlyTrafficKeys keySet
var clientHello *HandshakeMessage
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
offeredPSK = key
// Narrow ciphersuites to ones that match PSK hash
params, ok := cipherSuiteMap[key.CipherSuite]
if !ok {
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
return nil, nil, AlertInternalError
}
compatibleSuites := []CipherSuite{}
for _, suite := range ch.CipherSuites {
if cipherSuiteMap[suite].Hash == params.Hash {
compatibleSuites = append(compatibleSuites, suite)
}
}
ch.CipherSuites = compatibleSuites
// Signal early data if we're going to do it
if len(state.Opts.EarlyData) > 0 {
state.Params.ClientSendingEarlyData = true
ed = &EarlyDataExtension{}
err = ch.Extensions.Add(ed)
if err != nil {
logf(logTypeHandshake, "Error adding early data extension: %v", err)
return nil, nil, AlertInternalError
}
}
// Signal supported PSK key exchange modes
if len(state.Caps.PSKModes) == 0 {
logf(logTypeHandshake, "PSK selected, but no PSKModes")
return nil, nil, AlertInternalError
}
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
err = ch.Extensions.Add(kem)
if err != nil {
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
return nil, nil, AlertInternalError
}
// Add the shim PSK extension to the ClientHello
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
psk = &PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
Identities: []PSKIdentity{
{
Identity: key.Identity,
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
},
},
Binders: []PSKBinderEntry{
// Note: Stub to get the length fields right
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
},
}
ch.Extensions.Add(psk)
// Compute the binder key
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlyHash = params.Hash
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
binderLabel := labelExternalBinder
if key.IsResumption {
binderLabel = labelResumptionBinder
}
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
// Compute the binder value
trunc, err := ch.Truncated()
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
return nil, nil, AlertInternalError
}
truncHash := params.Hash.New()
truncHash.Write(trunc)
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
// Replace the PSK extension
psk.Binders[0].Binder = binder
ch.Extensions.Add(psk)
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
// this one should too.
clientHello, _ = HandshakeMessageFromBody(ch)
// Compute early traffic keys
h := params.Hash.New()
h.Write(clientHello.Marshal())
chHash := h.Sum(nil)
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
} else if len(state.Opts.EarlyData) > 0 {
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
return nil, nil, AlertInternalError
} else {
clientHello, err = HandshakeMessageFromBody(ch)
if err != nil {
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
return nil, nil, AlertInternalError
}
}
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
nextState := ClientStateWaitSH{
Caps: state.Caps,
Opts: state.Opts,
Params: state.Params,
OfferedDH: offeredDH,
OfferedPSK: offeredPSK,
earlySecret: earlySecret,
earlyHash: earlyHash,
firstClientHello: state.firstClientHello,
helloRetryRequest: state.helloRetryRequest,
clientHello: clientHello,
}
toSend := []HandshakeAction{
SendHandshakeMessage{clientHello},
}
if state.Params.ClientSendingEarlyData {
toSend = append(toSend, []HandshakeAction{
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
SendEarlyData{},
}...)
}
return nextState, toSend, AlertNoAlert
}
type ClientStateWaitSH struct {
Caps Capabilities
Opts ConnectionOptions
Params ConnectionParameters
OfferedDH map[NamedGroup][]byte
OfferedPSK PreSharedKey
PSK []byte
earlySecret []byte
earlyHash crypto.Hash
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage
}
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *HelloRetryRequestBody:
hrr := body
if state.helloRetryRequest != nil {
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
return nil, nil, AlertUnexpectedMessage
}
// Check that the version sent by the server is the one we support
if hrr.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Narrow the supported ciphersuites to the server-provided one
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// The only thing we know how to respond to in an HRR is the Cookie
// extension, so if there is either no Cookie extension or anything other
// than a Cookie extension, we have to fail.
serverCookie := new(CookieExtension)
foundCookie := hrr.Extensions.Find(serverCookie)
if !foundCookie || len(hrr.Extensions) != 1 {
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
return nil, nil, AlertIllegalParameter
}
// Hash the body into a pseudo-message
// XXX: Ignoring some errors here
params := cipherSuiteMap[hrr.CipherSuite]
h := params.Hash.New()
h.Write(state.clientHello.Marshal())
firstClientHello := &HandshakeMessage{
msgType: HandshakeTypeMessageHash,
body: h.Sum(nil),
}
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return ClientStateStart{
Caps: state.Caps,
Opts: state.Opts,
cookie: serverCookie.Cookie,
firstClientHello: firstClientHello,
helloRetryRequest: hm,
}.Next(nil)
case *ServerHelloBody:
sh := body
// Check that the version sent by the server is the one we support
if sh.Version != supportedVersion {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
return nil, nil, AlertProtocolVersion
}
// Check that the server provided a supported ciphersuite
supportedCipherSuite := false
for _, suite := range state.Caps.CipherSuites {
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
}
if !supportedCipherSuite {
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
// Do PSK or key agreement depending on extensions
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
foundPSK := sh.Extensions.Find(&serverPSK)
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
if foundPSK && (serverPSK.SelectedIdentity == 0) {
state.Params.UsingPSK = true
}
var dhSecret []byte
if foundKeyShare {
sks := serverKeyShare.Shares[0]
priv, ok := state.OfferedDH[sks.Group]
if !ok {
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
return nil, nil, AlertIllegalParameter
}
state.Params.UsingDH = true
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
}
suite := sh.CipherSuite
state.Params.CipherSuite = suite
params, ok := cipherSuiteMap[suite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(hm.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
if params.Hash != state.earlyHash {
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
state.earlyHash, suite, params.Hash)
}
earlySecret = state.earlySecret
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if dhSecret == nil {
dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := ClientStateWaitEE{
Caps: state.Caps,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
certificates: state.Caps.Certificates,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
}
toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
}
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}
type ClientStateWaitEE struct {
Caps Capabilities
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
ee := EncryptedExtensionsBody{}
_, err := ee.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
serverALPN := ALPNExtension{}
serverEarlyData := EarlyDataExtension{}
gotALPN := ee.Extensions.Find(&serverALPN)
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
if gotALPN && len(serverALPN.Protocols) > 0 {
state.Params.NextProto = serverALPN.Protocols[0]
}
state.handshakeHash.Write(hm.Marshal())
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
nextState := ClientStateWaitCertCR{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitCertCR struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
switch body := bodyGeneric.(type) {
case *CertificateBody:
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
case *CertificateRequestBody:
// A certificate request in the handshake should have a zero-length context
if len(body.CertificateRequestContext) > 0 {
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
return nil, nil, AlertIllegalParameter
}
state.Params.UsingClientAuth = true
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
nextState := ClientStateWaitCert{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: body,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
return nil, nil, AlertUnexpectedMessage
}
type ClientStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
nextState := ClientStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificate: cert,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificate *CertificateBody
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
certVerify := CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.serverCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
}
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
nextState := ClientStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
certificates: state.certificates,
serverCertificateRequest: state.serverCertificateRequest,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
}
type ClientStateWaitFinished struct {
Params ConnectionParameters
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
certificates []*Certificate
serverCertificateRequest *CertificateRequestBody
masterSecret []byte
clientHandshakeTrafficSecret []byte
serverHandshakeTrafficSecret []byte
}
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Verify server's Finished
h3 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
_, err := fin.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
fin.VerifyData, serverFinishedData)
return nil, nil, AlertHandshakeFailure
}
// Update the handshake hash with the Finished
state.handshakeHash.Write(hm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
h4 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
// Compute traffic secrets and keys
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
// Assemble client's second flight
toSend := []HandshakeAction{}
if state.Params.UsingEarlyData {
// Note: We only send EOED if the server is actually going to use the early
// data. Otherwise, it will never see it, and the transcripts will
// mismatch.
// EOED marshal is infallible
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
toSend = append(toSend, SendHandshakeMessage{eoedm})
state.handshakeHash.Write(eoedm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
}
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest
schemes := SignatureAlgorithmsExtension{}
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
if !gotSchemes {
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
return nil, nil, AlertIllegalParameter
}
// Select a certificate
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
if err != nil {
// XXX: Signal this to the application layer?
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
certificate := &CertificateBody{}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())
} else {
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(cert.Chain)),
}
for i, entry := range cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
state.handshakeHash.Write(certm.Marshal())
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
err = certificateVerify.Sign(cert.PrivateKey, hcv)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
certvm, err := HandshakeMessageFromBody(certificateVerify)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certvm})
state.handshakeHash.Write(certvm.Marshal())
}
}
// Compute the client's Finished message
h5 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
fin = &FinishedBody{
VerifyDataLen: len(clientFinishedData),
VerifyData: clientFinishedData,
}
finm, err := HandshakeMessageFromBody(fin)
if err != nil {
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
return nil, nil, AlertInternalError
}
// Compute the resumption secret
state.handshakeHash.Write(finm.Marshal())
h6 := state.handshakeHash.Sum(nil)
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
toSend = append(toSend, []HandshakeAction{
SendHandshakeMessage{finm},
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
}...)
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{
Params: state.Params,
isClient: true,
cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
return nextState, toSend, AlertNoAlert
}

152
vendor/github.com/bifurcation/mint/common.go generated vendored Normal file
View file

@ -0,0 +1,152 @@
package mint
import (
"fmt"
"strconv"
)
var (
supportedVersion uint16 = 0x7f15 // draft-21
// Flags for some minor compat issues
allowWrongVersionNumber = true
allowPKCS1 = true
)
// enum {...} ContentType;
type RecordType byte
const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
)
// enum {...} HandshakeType;
type HandshakeType byte
const (
// Omitted: *_RESERVED
HandshakeTypeClientHello HandshakeType = 1
HandshakeTypeServerHello HandshakeType = 2
HandshakeTypeNewSessionTicket HandshakeType = 4
HandshakeTypeEndOfEarlyData HandshakeType = 5
HandshakeTypeHelloRetryRequest HandshakeType = 6
HandshakeTypeEncryptedExtensions HandshakeType = 8
HandshakeTypeCertificate HandshakeType = 11
HandshakeTypeCertificateRequest HandshakeType = 13
HandshakeTypeCertificateVerify HandshakeType = 15
HandshakeTypeServerConfiguration HandshakeType = 17
HandshakeTypeFinished HandshakeType = 20
HandshakeTypeKeyUpdate HandshakeType = 24
HandshakeTypeMessageHash HandshakeType = 254
)
// uint8 CipherSuite[2];
type CipherSuite uint16
const (
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
// value for this type so that we can detect when a field is set.
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
)
func (c CipherSuite) String() string {
switch c {
case CIPHER_SUITE_UNKNOWN:
return "unknown"
case TLS_AES_128_GCM_SHA256:
return "TLS_AES_128_GCM_SHA256"
case TLS_AES_256_GCM_SHA384:
return "TLS_AES_256_GCM_SHA384"
case TLS_CHACHA20_POLY1305_SHA256:
return "TLS_CHACHA20_POLY1305_SHA256"
case TLS_AES_128_CCM_SHA256:
return "TLS_AES_128_CCM_SHA256"
case TLS_AES_256_CCM_8_SHA256:
return "TLS_AES_256_CCM_8_SHA256"
}
// cannot use %x here, since it calls String(), leading to infinite recursion
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
}
// enum {...} SignatureScheme
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
// ECDSA algorithms
ECDSA_P256_SHA256 SignatureScheme = 0x0403
ECDSA_P384_SHA384 SignatureScheme = 0x0503
ECDSA_P521_SHA512 SignatureScheme = 0x0603
// RSASSA-PSS algorithms
RSA_PSS_SHA256 SignatureScheme = 0x0804
RSA_PSS_SHA384 SignatureScheme = 0x0805
RSA_PSS_SHA512 SignatureScheme = 0x0806
// EdDSA algorithms
Ed25519 SignatureScheme = 0x0807
Ed448 SignatureScheme = 0x0808
)
// enum {...} ExtensionType
type ExtensionType uint16
const (
ExtensionTypeServerName ExtensionType = 0
ExtensionTypeSupportedGroups ExtensionType = 10
ExtensionTypeSignatureAlgorithms ExtensionType = 13
ExtensionTypeALPN ExtensionType = 16
ExtensionTypeKeyShare ExtensionType = 40
ExtensionTypePreSharedKey ExtensionType = 41
ExtensionTypeEarlyData ExtensionType = 42
ExtensionTypeSupportedVersions ExtensionType = 43
ExtensionTypeCookie ExtensionType = 44
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
)
// enum {...} NamedGroup
type NamedGroup uint16
const (
// Elliptic Curve Groups.
P256 NamedGroup = 23
P384 NamedGroup = 24
P521 NamedGroup = 25
// ECDH functions.
X25519 NamedGroup = 29
X448 NamedGroup = 30
// Finite field groups.
FFDHE2048 NamedGroup = 256
FFDHE3072 NamedGroup = 257
FFDHE4096 NamedGroup = 258
FFDHE6144 NamedGroup = 259
FFDHE8192 NamedGroup = 260
)
// enum {...} PskKeyExchangeMode;
type PSKKeyExchangeMode uint8
const (
PSKModeKE PSKKeyExchangeMode = 0
PSKModeDHEKE PSKKeyExchangeMode = 1
)
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
type KeyUpdateRequest uint8
const (
KeyUpdateNotRequested KeyUpdateRequest = 0
KeyUpdateRequested KeyUpdateRequest = 1
)

819
vendor/github.com/bifurcation/mint/conn.go generated vendored Normal file
View file

@ -0,0 +1,819 @@
package mint
import (
"crypto"
"crypto/x509"
"encoding/hex"
"fmt"
"io"
"net"
"reflect"
"sync"
"time"
)
var WouldBlock = fmt.Errorf("Would have blocked")
type Certificate struct {
Chain []*x509.Certificate
PrivateKey crypto.Signer
}
type PreSharedKey struct {
CipherSuite CipherSuite
IsResumption bool
Identity []byte
Key []byte
NextProto string
ReceivedAt time.Time
ExpiresAt time.Time
TicketAgeAdd uint32
}
type PreSharedKeyCache interface {
Get(string) (PreSharedKey, bool)
Put(string, PreSharedKey)
Size() int
}
type PSKMapCache map[string]PreSharedKey
// A CookieHandler does two things:
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
// - validates this byte string echoed by the client in the ClientHello
type CookieHandler interface {
Generate(*Conn) ([]byte, error)
Validate(*Conn, []byte) bool
}
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
psk, ok = cache[key]
return
}
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
(*cache)[key] = psk
}
func (cache PSKMapCache) Size() int {
return len(cache)
}
// Config is the struct used to pass configuration settings to a TLS client or
// server instance. The settings for client and server are pretty different,
// but we just throw them all in here.
type Config struct {
// Client fields
ServerName string
// Server fields
SendSessionTickets bool
TicketLifetime uint32
TicketLen int
EarlyDataLifetime uint32
AllowEarlyData bool
// Require the client to echo a cookie.
RequireCookie bool
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
// The default cookie handler uses 32 random bytes as a cookie.
CookieHandler CookieHandler
RequireClientAuth bool
// Shared fields
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
NextProtos []string
PSKs PreSharedKeyCache
PSKModes []PSKKeyExchangeMode
NonBlocking bool
// The same config object can be shared among different connections, so it
// needs its own mutex
mutex sync.RWMutex
}
// Clone returns a shallow clone of c. It is safe to clone a Config that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
c.mutex.Lock()
defer c.mutex.Unlock()
return &Config{
ServerName: c.ServerName,
SendSessionTickets: c.SendSessionTickets,
TicketLifetime: c.TicketLifetime,
TicketLen: c.TicketLen,
EarlyDataLifetime: c.EarlyDataLifetime,
AllowEarlyData: c.AllowEarlyData,
RequireCookie: c.RequireCookie,
RequireClientAuth: c.RequireClientAuth,
Certificates: c.Certificates,
AuthCertificate: c.AuthCertificate,
CipherSuites: c.CipherSuites,
Groups: c.Groups,
SignatureSchemes: c.SignatureSchemes,
NextProtos: c.NextProtos,
PSKs: c.PSKs,
PSKModes: c.PSKModes,
NonBlocking: c.NonBlocking,
}
}
func (c *Config) Init(isClient bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// Set defaults
if len(c.CipherSuites) == 0 {
c.CipherSuites = defaultSupportedCipherSuites
}
if len(c.Groups) == 0 {
c.Groups = defaultSupportedGroups
}
if len(c.SignatureSchemes) == 0 {
c.SignatureSchemes = defaultSignatureSchemes
}
if c.TicketLen == 0 {
c.TicketLen = defaultTicketLen
}
if !reflect.ValueOf(c.PSKs).IsValid() {
c.PSKs = &PSKMapCache{}
}
if len(c.PSKModes) == 0 {
c.PSKModes = defaultPSKModes
}
// If there is no certificate, generate one
if !isClient && len(c.Certificates) == 0 {
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
priv, err := newSigningKey(RSA_PSS_SHA256)
if err != nil {
return err
}
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
if err != nil {
return err
}
c.Certificates = []*Certificate{
{
Chain: []*x509.Certificate{cert},
PrivateKey: priv,
},
}
}
return nil
}
func (c *Config) ValidForServer() bool {
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
(len(c.Certificates) > 0 &&
len(c.Certificates[0].Chain) > 0 &&
c.Certificates[0].PrivateKey != nil)
}
func (c *Config) ValidForClient() bool {
return len(c.ServerName) > 0
}
var (
defaultSupportedCipherSuites = []CipherSuite{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
defaultSupportedGroups = []NamedGroup{
P256,
P384,
FFDHE2048,
X25519,
}
defaultSignatureSchemes = []SignatureScheme{
RSA_PSS_SHA256,
RSA_PSS_SHA384,
RSA_PSS_SHA512,
ECDSA_P256_SHA256,
ECDSA_P384_SHA384,
ECDSA_P521_SHA512,
}
defaultTicketLen = 16
defaultPSKModes = []PSKKeyExchangeMode{
PSKModeKE,
PSKModeDHEKE,
}
)
type ConnectionState struct {
HandshakeState string // string representation of the handshake state.
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
NextProto string // Selected ALPN proto
}
// Conn implements the net.Conn interface, as with "crypto/tls"
// * Read, Write, and Close are provided locally
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
type Conn struct {
config *Config
conn net.Conn
isClient bool
EarlyData []byte
state StateConnected
hState HandshakeState
handshakeMutex sync.Mutex
handshakeAlert Alert
handshakeComplete bool
readBuffer []byte
in, out *RecordLayer
hIn, hOut *HandshakeLayer
extHandler AppExtensionHandler
}
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient}
c.in = NewRecordLayer(c.conn)
c.out = NewRecordLayer(c.conn)
c.hIn = NewHandshakeLayer(c.in)
c.hIn.nonblocking = c.config.NonBlocking
c.hOut = NewHandshakeLayer(c.out)
return c
}
// Read up
func (c *Conn) consumeRecord() error {
pt, err := c.in.ReadRecord()
if pt == nil {
logf(logTypeIO, "extendBuffer returns error %v", err)
return err
}
switch pt.contentType {
case RecordTypeHandshake:
logf(logTypeHandshake, "Received post-handshake message")
// We do not support fragmentation of post-handshake handshake messages.
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
start := 0
for start < len(pt.fragment) {
if len(pt.fragment[start:]) < handshakeHeaderLen {
return fmt.Errorf("Post-handshake handshake message too short for header")
}
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(pt.fragment[start])
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
return fmt.Errorf("Post-handshake handshake message too short for body")
}
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
// Advance state machine
state, actions, alert := c.state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return io.EOF
}
}
// XXX: If we want to support more advanced cases, e.g., post-handshake
// authentication, we'll need to allow transitions other than
// Connected -> Connected
var connected bool
c.state, connected = state.(StateConnected)
if !connected {
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
c.sendAlert(alert)
return io.EOF
}
start += handshakeHeaderLen + hmLen
}
case RecordTypeAlert:
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
if len(pt.fragment) != 2 {
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
if Alert(pt.fragment[1]) == AlertCloseNotify {
return io.EOF
}
switch pt.fragment[0] {
case AlertLevelWarning:
// drop on the floor
case AlertLevelError:
return Alert(pt.fragment[1])
default:
c.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
case RecordTypeApplicationData:
c.readBuffer = append(c.readBuffer, pt.fragment...)
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
}
return err
}
// Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) {
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
if alert := c.Handshake(); alert != AlertNoAlert {
return 0, alert
}
if len(buffer) == 0 {
return 0, nil
}
// Lock the input channel
c.in.Lock()
defer c.in.Unlock()
for len(c.readBuffer) == 0 {
err := c.consumeRecord()
// err can be nil if consumeRecord processed a non app-data
// record.
if err != nil {
if c.config.NonBlocking || err != WouldBlock {
logf(logTypeIO, "conn.Read returns err=%v", err)
return 0, err
}
}
}
var read int
n := len(buffer)
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
if len(c.readBuffer) <= n {
buffer = buffer[:len(c.readBuffer)]
copy(buffer, c.readBuffer)
read = len(c.readBuffer)
c.readBuffer = c.readBuffer[:0]
} else {
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
copy(buffer[:n], c.readBuffer[:n])
c.readBuffer = c.readBuffer[n:]
read = n
}
logf(logTypeVerbose, "Returning %v", string(buffer))
return read, nil
}
// Write application data
func (c *Conn) Write(buffer []byte) (int, error) {
// Lock the output channel
c.out.Lock()
defer c.out.Unlock()
// Send full-size fragments
var start int
sent := 0
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return sent, err
}
sent += maxFragmentLen
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: buffer[start:],
})
if err != nil {
return sent, err
}
sent += len(buffer[start:])
}
return sent, nil
}
// sendAlert sends a TLS alert message.
// c.out.Mutex <= L.
func (c *Conn) sendAlert(err Alert) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
var level int
switch err {
case AlertNoRenegotiation, AlertCloseNotify:
level = AlertLevelWarning
default:
level = AlertLevelError
}
buf := []byte{byte(err), byte(level)}
c.out.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: buf,
})
// close_notify and end_of_early_data are not actually errors
if level == AlertLevelWarning {
return &net.OpError{Op: "local error", Err: err}
}
return c.Close()
}
// Close closes the connection.
func (c *Conn) Close() error {
// XXX crypto/tls has an interlock with Write here. Do we need that?
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
switch action := actionGeneric.(type) {
case SendHandshakeMessage:
err := c.hOut.WriteMessage(action.Message)
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
case RekeyIn:
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
return AlertInternalError
}
case RekeyOut:
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
return AlertInternalError
}
case SendEarlyData:
logf(logTypeHandshake, "%s Sending early data...", label)
_, err := c.Write(c.EarlyData)
if err != nil {
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
return AlertInternalError
}
case ReadPastEarlyData:
logf(logTypeHandshake, "%s Reading past early data...", label)
// Scan past all records that fail to decrypt
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok := err.(DecryptError)
for ok {
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok = err.(DecryptError)
}
case ReadEarlyData:
logf(logTypeHandshake, "%s Reading early data...", label)
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
for t == RecordTypeApplicationData {
// Read a record into the buffer. Note that this is safe
// in blocking mode because we read the record in in
// PeekRecordType.
pt, err := c.in.ReadRecord()
if err != nil {
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
c.EarlyData = append(c.EarlyData, pt.fragment...)
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
}
logf(logTypeHandshake, "%s Done reading early data", label)
case StorePSK:
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
if c.isClient {
// Clients look up PSKs based on server name
c.config.PSKs.Put(c.config.ServerName, action.PSK)
} else {
// Servers look them up based on the identity in the extension
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
}
default:
logf(logTypeHandshake, "%s Unknown actionuction type", label)
return AlertInternalError
}
return AlertNoAlert
}
func (c *Conn) HandshakeSetup() Alert {
var state HandshakeState
var actions []HandshakeAction
var alert Alert
if err := c.config.Init(c.isClient); err != nil {
logf(logTypeHandshake, "Error initializing config: %v", err)
return AlertInternalError
}
// Set things up
caps := Capabilities{
CipherSuites: c.config.CipherSuites,
Groups: c.config.Groups,
SignatureSchemes: c.config.SignatureSchemes,
PSKs: c.config.PSKs,
PSKModes: c.config.PSKModes,
AllowEarlyData: c.config.AllowEarlyData,
RequireCookie: c.config.RequireCookie,
CookieHandler: c.config.CookieHandler,
RequireClientAuth: c.config.RequireClientAuth,
NextProtos: c.config.NextProtos,
Certificates: c.config.Certificates,
ExtensionHandler: c.extHandler,
}
opts := ConnectionOptions{
ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos,
EarlyData: c.EarlyData,
}
if caps.RequireCookie && caps.CookieHandler == nil {
caps.CookieHandler = &defaultCookieHandler{}
}
if c.isClient {
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error initializing client state: %v", alert)
return alert
}
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
return alert
}
}
} else {
state = ServerStateStart{Caps: caps, conn: c}
}
c.hState = state
return AlertNoAlert
}
// Handshake causes a TLS handshake on the connection. The `isClient` member
// determines whether a client or server handshake is performed. If a
// handshake has already been performed, then its result will be returned.
func (c *Conn) Handshake() Alert {
label := "[server]"
if c.isClient {
label = "[client]"
}
// TODO Lock handshakeMutex
// TODO Remove CloseNotify hack
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
return c.handshakeAlert
}
if c.handshakeComplete {
return AlertNoAlert
}
var alert Alert
if c.hState == nil {
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
alert = c.HandshakeSetup()
if alert != AlertNoAlert {
return alert
}
} else {
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
}
state := c.hState
_, connected := state.(StateConnected)
var actions []HandshakeAction
for !connected {
// Read a handshake message
hm, err := c.hIn.ReadMessage()
if err == WouldBlock {
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
return AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
c.sendAlert(AlertCloseNotify)
return AlertCloseNotify
}
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
// Advance the state machine
state, actions, alert = state.Next(hm)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error in state transition: %v", alert)
return alert
}
for index, action := range actions {
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
c.hState = state
logf(logTypeHandshake, "state is now %s", c.GetHsState())
_, connected = state.(StateConnected)
}
c.state = state.(StateConnected)
// Send NewSessionTicket if acting as server
if !c.isClient && c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket(
c.config.TicketLen,
c.config.TicketLifetime,
c.config.EarlyDataLifetime)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
}
c.handshakeComplete = true
return AlertNoAlert
}
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
if !c.handshakeComplete {
return fmt.Errorf("Cannot update keys until after handshake")
}
request := KeyUpdateNotRequested
if requestUpdate {
request = KeyUpdateRequested
}
// Create the key update and update state
actions, alert := c.state.KeyUpdate(request)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert while generating key update: %v", alert)
}
// Take actions (send key update and rekey)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
c.sendAlert(alert)
return fmt.Errorf("Alert during key update actions: %v", alert)
}
}
return nil
}
func (c *Conn) GetHsState() string {
return reflect.TypeOf(c.hState).Name()
}
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
_, connected := c.hState.(StateConnected)
if !connected {
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
}
if c.state.exporterSecret == nil {
return nil, fmt.Errorf("Internal error: no exporter secret")
}
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
hc := c.state.cryptoParams.Hash.New().Sum(context)
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
}
func (c *Conn) State() ConnectionState {
state := ConnectionState{
HandshakeState: c.GetHsState(),
}
if c.handshakeComplete {
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
state.NextProto = c.state.Params.NextProto
}
return state
}
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
if c.hState != nil {
return fmt.Errorf("Can't set extension handler after setup")
}
c.extHandler = h
return nil
}

654
vendor/github.com/bifurcation/mint/crypto.go generated vendored Normal file
View file

@ -0,0 +1,654 @@
package mint
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"time"
"golang.org/x/crypto/curve25519"
// Blank includes to ensure hash support
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
)
var prng = rand.Reader
type aeadFactory func(key []byte) (cipher.AEAD, error)
type CipherSuiteParams struct {
Suite CipherSuite
Cipher aeadFactory // Cipher factory
Hash crypto.Hash // Hash function
KeyLen int // Key length in octets
IvLen int // IV length in octets
}
type signatureAlgorithm uint8
const (
signatureAlgorithmUnknown = iota
signatureAlgorithmRSA_PKCS1
signatureAlgorithmRSA_PSS
signatureAlgorithmECDSA
)
var (
hashMap = map[SignatureScheme]crypto.Hash{
RSA_PKCS1_SHA1: crypto.SHA1,
RSA_PKCS1_SHA256: crypto.SHA256,
RSA_PKCS1_SHA384: crypto.SHA384,
RSA_PKCS1_SHA512: crypto.SHA512,
ECDSA_P256_SHA256: crypto.SHA256,
ECDSA_P384_SHA384: crypto.SHA384,
ECDSA_P521_SHA512: crypto.SHA512,
RSA_PSS_SHA256: crypto.SHA256,
RSA_PSS_SHA384: crypto.SHA384,
RSA_PSS_SHA512: crypto.SHA512,
}
sigMap = map[SignatureScheme]signatureAlgorithm{
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
}
curveMap = map[SignatureScheme]NamedGroup{
ECDSA_P256_SHA256: P256,
ECDSA_P384_SHA384: P384,
ECDSA_P521_SHA512: P521,
}
newAESGCM = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// TLS always uses 12-byte nonces
return cipher.NewGCMWithNonceSize(block, 12)
}
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
TLS_AES_128_GCM_SHA256: {
Suite: TLS_AES_128_GCM_SHA256,
Cipher: newAESGCM,
Hash: crypto.SHA256,
KeyLen: 16,
IvLen: 12,
},
TLS_AES_256_GCM_SHA384: {
Suite: TLS_AES_256_GCM_SHA384,
Cipher: newAESGCM,
Hash: crypto.SHA384,
KeyLen: 32,
IvLen: 12,
},
}
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
}
defaultRSAKeySize = 2048
)
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
switch group {
case P256:
crv = elliptic.P256()
case P384:
crv = elliptic.P384()
case P521:
crv = elliptic.P521()
}
return
}
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
switch key.Curve.Params().Name {
case elliptic.P256().Params().Name:
g = P256
case elliptic.P384().Params().Name:
g = P384
case elliptic.P521().Params().Name:
g = P521
}
return
}
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
size = 0
switch group {
case X25519:
size = 32
case P256:
size = 65
case P384:
size = 97
case P521:
size = 133
case FFDHE2048:
size = 256
case FFDHE3072:
size = 384
case FFDHE4096:
size = 512
case FFDHE6144:
size = 768
case FFDHE8192:
size = 1024
}
return
}
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
switch group {
case FFDHE2048:
p = finiteFieldPrime2048
case FFDHE3072:
p = finiteFieldPrime3072
case FFDHE4096:
p = finiteFieldPrime4096
case FFDHE6144:
p = finiteFieldPrime6144
case FFDHE8192:
p = finiteFieldPrime8192
}
return
}
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
sigType := sigMap[alg]
switch key.(type) {
case *rsa.PrivateKey:
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
case *ecdsa.PrivateKey:
return sigType == signatureAlgorithmECDSA
default:
return false
}
}
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
primeLen := len(p.Bytes())
for {
// g = 2 for all ffdhe groups
priv, err = rand.Int(prng, p)
if err != nil {
return
}
pub = big.NewInt(0)
pub.Exp(big.NewInt(2), priv, p)
if len(pub.Bytes()) == primeLen {
return
}
}
}
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
switch group {
case P256, P384, P521:
var x, y *big.Int
crv := curveFromNamedGroup(group)
priv, x, y, err = elliptic.GenerateKey(crv, prng)
if err != nil {
return
}
pub = elliptic.Marshal(crv, x, y)
return
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
p := primeFromNamedGroup(group)
x, X, err2 := ffdheKeyShareFromPrime(p)
if err2 != nil {
err = err2
return
}
priv = x.Bytes()
pubBytes := X.Bytes()
numBytes := keyExchangeSizeFromNamedGroup(group)
pub = make([]byte, numBytes)
copy(pub[numBytes-len(pubBytes):], pubBytes)
return
case X25519:
var private, public [32]byte
_, err = prng.Read(private[:])
if err != nil {
return
}
curve25519.ScalarBaseMult(&public, &private)
priv = private[:]
pub = public[:]
return
default:
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
}
}
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
switch group {
case P256, P384, P521:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
crv := curveFromNamedGroup(group)
pubX, pubY := elliptic.Unmarshal(crv, pub)
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
xBytes := x.Bytes()
numBytes := len(crv.Params().P.Bytes())
ret := make([]byte, numBytes)
copy(ret[numBytes-len(xBytes):], xBytes)
return ret, nil
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
numBytes := keyExchangeSizeFromNamedGroup(group)
if len(pub) != numBytes {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
p := primeFromNamedGroup(group)
x := big.NewInt(0).SetBytes(priv)
Y := big.NewInt(0).SetBytes(pub)
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
ret := make([]byte, numBytes)
copy(ret[numBytes-len(ZBytes):], ZBytes)
return ret, nil
case X25519:
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
}
var private, public, ret [32]byte
copy(private[:], priv)
copy(public[:], pub)
curve25519.ScalarMult(&ret, &private, &public)
return ret[:], nil
default:
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
}
}
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
switch sig {
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
RSA_PSS_SHA256, RSA_PSS_SHA384,
RSA_PSS_SHA512:
return rsa.GenerateKey(prng, defaultRSAKeySize)
case ECDSA_P256_SHA256:
return ecdsa.GenerateKey(elliptic.P256(), prng)
case ECDSA_P384_SHA384:
return ecdsa.GenerateKey(elliptic.P384(), prng)
case ECDSA_P521_SHA512:
return ecdsa.GenerateKey(elliptic.P521(), prng)
default:
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
}
}
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
sigAlg, ok := x509AlgMap[alg]
if !ok {
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
}
if len(name) == 0 {
return nil, fmt.Errorf("tls.selfsigned: No name provided")
}
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
if err != nil {
return nil, err
}
template := &x509.Certificate{
SerialNumber: serial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(0, 0, 1),
SignatureAlgorithm: sigAlg,
Subject: pkix.Name{CommonName: name},
DNSNames: []string{name},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
if err != nil {
return nil, err
}
// It is safe to ignore the error here because we're parsing known-good data
cert, _ := x509.ParseCertificate(der)
return cert, nil
}
// XXX(rlb): Copied from crypto/x509
type ecdsaSignature struct {
R, S *big.Int
}
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
var opts crypto.SignerOpts
hash := hashMap[alg]
if hash == crypto.SHA1 {
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
var realInput []byte
switch key := privateKey.(type) {
case *rsa.PrivateKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
opts = hash
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
case *ecdsa.PrivateKey:
if sigType != signatureAlgorithmECDSA {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
}
algGroup := curveMap[alg]
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
if algGroup != keyGroup {
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
}
h := hash.New()
h.Write(sigInput)
realInput = h.Sum(nil)
default:
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
}
sig, err := privateKey.Sign(prng, realInput, opts)
logf(logTypeCrypto, "signature: %x", sig)
return sig, err
}
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
hash := hashMap[alg]
if hash == crypto.SHA1 {
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
}
sigType := sigMap[alg]
switch pub := publicKey.(type) {
case *rsa.PublicKey:
switch {
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
fallthrough
case sigType == signatureAlgorithmRSA_PSS:
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
default:
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
}
case *ecdsa.PublicKey:
if sigType != signatureAlgorithmECDSA {
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
}
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
}
ecdsaSig := new(ecdsaSignature)
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
return err
} else if len(rest) != 0 {
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
}
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
}
h := hash.New()
h.Write(sigInput)
realInput := h.Sum(nil)
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
return fmt.Errorf("tls.verify: ECDSA verification failure")
}
return nil
default:
return fmt.Errorf("tls.verify: Unsupported key type")
}
}
// 0
// |
// v
// PSK -> HKDF-Extract = Early Secret
// |
// +-----> Derive-Secret(.,
// | "ext binder" |
// | "res binder",
// | "")
// | = binder_key
// |
// +-----> Derive-Secret(., "c e traffic",
// | ClientHello)
// | = client_early_traffic_secret
// |
// +-----> Derive-Secret(., "e exp master",
// | ClientHello)
// | = early_exporter_master_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// (EC)DHE -> HKDF-Extract = Handshake Secret
// |
// +-----> Derive-Secret(., "c hs traffic",
// | ClientHello...ServerHello)
// | = client_handshake_traffic_secret
// |
// +-----> Derive-Secret(., "s hs traffic",
// | ClientHello...ServerHello)
// | = server_handshake_traffic_secret
// v
// Derive-Secret(., "derived", "")
// |
// v
// 0 -> HKDF-Extract = Master Secret
// |
// +-----> Derive-Secret(., "c ap traffic",
// | ClientHello...server Finished)
// | = client_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "s ap traffic",
// | ClientHello...server Finished)
// | = server_application_traffic_secret_0
// |
// +-----> Derive-Secret(., "exp master",
// | ClientHello...server Finished)
// | = exporter_master_secret
// |
// +-----> Derive-Secret(., "res master",
// ClientHello...client Finished)
// = resumption_master_secret
// From RFC 5869
// PRK = HMAC-Hash(salt, IKM)
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
salt := saltIn
// if [salt is] not provided, it is set to a string of HashLen zeros
if salt == nil {
salt = bytes.Repeat([]byte{0}, hash.Size())
}
h := hmac.New(hash.New, salt)
h.Write(input)
out := h.Sum(nil)
logf(logTypeCrypto, "HKDF Extract:\n")
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
return out
}
const (
labelExternalBinder = "ext binder"
labelResumptionBinder = "res binder"
labelEarlyTrafficSecret = "c e traffic"
labelEarlyExporterSecret = "e exp master"
labelClientHandshakeTrafficSecret = "c hs traffic"
labelServerHandshakeTrafficSecret = "s hs traffic"
labelClientApplicationTrafficSecret = "c ap traffic"
labelServerApplicationTrafficSecret = "s ap traffic"
labelExporterSecret = "exp master"
labelResumptionSecret = "res master"
labelDerived = "derived"
labelFinished = "finished"
labelResumption = "resumption"
)
// struct HkdfLabel {
// uint16 length;
// opaque label<9..255>;
// opaque hash_value<0..255>;
// };
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
label := "tls13 " + labelIn
labelLen := len(label)
hashLen := len(hashValue)
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
hkdfLabel[0] = byte(outLen >> 8)
hkdfLabel[1] = byte(outLen)
hkdfLabel[2] = byte(labelLen)
copy(hkdfLabel[3:3+labelLen], []byte(label))
hkdfLabel[3+labelLen] = byte(hashLen)
copy(hkdfLabel[3+labelLen+1:], hashValue)
return hkdfLabel
}
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
out := []byte{}
T := []byte{}
i := byte(1)
for len(out) < outLen {
block := append(T, info...)
block = append(block, i)
h := hmac.New(hash.New, prk)
h.Write(block)
T = h.Sum(nil)
out = append(out, T...)
i++
}
return out[:outLen]
}
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
info := hkdfEncodeLabel(label, hashValue, outLen)
derived := HkdfExpand(hash, secret, info, outLen)
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
return derived
}
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
}
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
mac := hmac.New(params.Hash.New, macKey)
mac.Write(input)
return mac.Sum(nil)
}
type keySet struct {
cipher aeadFactory
key []byte
iv []byte
}
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
return keySet{
cipher: params.Cipher,
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
}
}

586
vendor/github.com/bifurcation/mint/extensions.go generated vendored Normal file
View file

@ -0,0 +1,586 @@
package mint
import (
"bytes"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type ExtensionBody interface {
Type() ExtensionType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ExtensionType extension_type;
// opaque extension_data<0..2^16-1>;
// } Extension;
type Extension struct {
ExtensionType ExtensionType
ExtensionData []byte `tls:"head=2"`
}
func (ext Extension) Marshal() ([]byte, error) {
return syntax.Marshal(ext)
}
func (ext *Extension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ext)
}
type ExtensionList []Extension
type extensionListInner struct {
List []Extension `tls:"head=2"`
}
func (el ExtensionList) Marshal() ([]byte, error) {
return syntax.Marshal(extensionListInner{el})
}
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
var list extensionListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
*el = list.List
return read, nil
}
func (el *ExtensionList) Add(src ExtensionBody) error {
data, err := src.Marshal()
if err != nil {
return err
}
if el == nil {
el = new(ExtensionList)
}
// If one already exists with this type, replace it
for i := range *el {
if (*el)[i].ExtensionType == src.Type() {
(*el)[i].ExtensionData = data
return nil
}
}
// Otherwise append
*el = append(*el, Extension{
ExtensionType: src.Type(),
ExtensionData: data,
})
return nil
}
func (el ExtensionList) Find(dst ExtensionBody) bool {
for _, ext := range el {
if ext.ExtensionType == dst.Type() {
_, err := dst.Unmarshal(ext.ExtensionData)
return err == nil
}
}
return false
}
// struct {
// NameType name_type;
// select (name_type) {
// case host_name: HostName;
// } name;
// } ServerName;
//
// enum {
// host_name(0), (255)
// } NameType;
//
// opaque HostName<1..2^16-1>;
//
// struct {
// ServerName server_name_list<1..2^16-1>
// } ServerNameList;
//
// But we only care about the case where there's a single DNS hostname. We
// will never create anything else, and throw if we receive something else
//
// 2 1 2
// | listLen | NameType | nameLen | name |
type ServerNameExtension string
type serverNameInner struct {
NameType uint8
HostName []byte `tls:"head=2,min=1"`
}
type serverNameListInner struct {
ServerNameList []serverNameInner `tls:"head=2,min=1"`
}
func (sni ServerNameExtension) Type() ExtensionType {
return ExtensionTypeServerName
}
func (sni ServerNameExtension) Marshal() ([]byte, error) {
list := serverNameListInner{
ServerNameList: []serverNameInner{{
NameType: 0x00, // host_name
HostName: []byte(sni),
}},
}
return syntax.Marshal(list)
}
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
var list serverNameListInner
read, err := syntax.Unmarshal(data, &list)
if err != nil {
return 0, err
}
// Syntax requires at least one entry
// Entries beyond the first are ignored
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
}
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
return read, nil
}
// struct {
// NamedGroup group;
// opaque key_exchange<1..2^16-1>;
// } KeyShareEntry;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// KeyShareEntry client_shares<0..2^16-1>;
//
// case hello_retry_request:
// NamedGroup selected_group;
//
// case server_hello:
// KeyShareEntry server_share;
// };
// } KeyShare;
type KeyShareEntry struct {
Group NamedGroup
KeyExchange []byte `tls:"head=2,min=1"`
}
func (kse KeyShareEntry) SizeValid() bool {
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
}
type KeyShareExtension struct {
HandshakeType HandshakeType
SelectedGroup NamedGroup
Shares []KeyShareEntry
}
type KeyShareClientHelloInner struct {
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
}
type KeyShareHelloRetryInner struct {
SelectedGroup NamedGroup
}
type KeyShareServerHelloInner struct {
ServerShare KeyShareEntry
}
func (ks KeyShareExtension) Type() ExtensionType {
return ExtensionTypeKeyShare
}
func (ks KeyShareExtension) Marshal() ([]byte, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
for _, share := range ks.Shares {
if !share.SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
case HandshakeTypeHelloRetryRequest:
if len(ks.Shares) > 0 {
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
}
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
case HandshakeTypeServerHello:
if len(ks.Shares) != 1 {
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
}
if !ks.Shares[0].SizeValid() {
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
default:
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
switch ks.HandshakeType {
case HandshakeTypeClientHello:
var inner KeyShareClientHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
for _, share := range inner.ClientShares {
if !share.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
}
ks.Shares = inner.ClientShares
return read, nil
case HandshakeTypeHelloRetryRequest:
var inner KeyShareHelloRetryInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
ks.SelectedGroup = inner.SelectedGroup
return read, nil
case HandshakeTypeServerHello:
var inner KeyShareServerHelloInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if !inner.ServerShare.SizeValid() {
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
}
ks.Shares = []KeyShareEntry{inner.ServerShare}
return read, nil
default:
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
}
}
// struct {
// NamedGroup named_group_list<2..2^16-1>;
// } NamedGroupList;
type SupportedGroupsExtension struct {
Groups []NamedGroup `tls:"head=2,min=2"`
}
func (sg SupportedGroupsExtension) Type() ExtensionType {
return ExtensionTypeSupportedGroups
}
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sg)
}
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sg)
}
// struct {
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
// } SignatureSchemeList
type SignatureAlgorithmsExtension struct {
Algorithms []SignatureScheme `tls:"head=2,min=2"`
}
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
return ExtensionTypeSignatureAlgorithms
}
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sa)
}
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sa)
}
// struct {
// opaque identity<1..2^16-1>;
// uint32 obfuscated_ticket_age;
// } PskIdentity;
//
// opaque PskBinderEntry<32..255>;
//
// struct {
// select (Handshake.msg_type) {
// case client_hello:
// PskIdentity identities<7..2^16-1>;
// PskBinderEntry binders<33..2^16-1>;
//
// case server_hello:
// uint16 selected_identity;
// };
//
// } PreSharedKeyExtension;
type PSKIdentity struct {
Identity []byte `tls:"head=2,min=1"`
ObfuscatedTicketAge uint32
}
type PSKBinderEntry struct {
Binder []byte `tls:"head=1,min=32"`
}
type PreSharedKeyExtension struct {
HandshakeType HandshakeType
Identities []PSKIdentity
Binders []PSKBinderEntry
SelectedIdentity uint16
}
type preSharedKeyClientInner struct {
Identities []PSKIdentity `tls:"head=2,min=7"`
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}
type preSharedKeyServerInner struct {
SelectedIdentity uint16
}
func (psk PreSharedKeyExtension) Type() ExtensionType {
return ExtensionTypePreSharedKey
}
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
return syntax.Marshal(preSharedKeyClientInner{
Identities: psk.Identities,
Binders: psk.Binders,
})
case HandshakeTypeServerHello:
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
}
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
default:
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
switch psk.HandshakeType {
case HandshakeTypeClientHello:
var inner preSharedKeyClientInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
if len(inner.Identities) != len(inner.Binders) {
return 0, fmt.Errorf("Lengths of identities and binders not equal")
}
psk.Identities = inner.Identities
psk.Binders = inner.Binders
return read, nil
case HandshakeTypeServerHello:
var inner preSharedKeyServerInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
psk.SelectedIdentity = inner.SelectedIdentity
return read, nil
default:
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
}
}
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
for i, localID := range psk.Identities {
if bytes.Equal(localID.Identity, id) {
return psk.Binders[i].Binder, true
}
}
return nil, false
}
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
//
// struct {
// PskKeyExchangeMode ke_modes<1..255>;
// } PskKeyExchangeModes;
type PSKKeyExchangeModesExtension struct {
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
}
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
return ExtensionTypePSKKeyExchangeModes
}
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
return syntax.Marshal(pkem)
}
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, pkem)
}
// struct {
// } EarlyDataIndication;
type EarlyDataExtension struct{}
func (ed EarlyDataExtension) Type() ExtensionType {
return ExtensionTypeEarlyData
}
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
return 0, nil
}
// struct {
// uint32 max_early_data_size;
// } TicketEarlyDataInfo;
type TicketEarlyDataInfoExtension struct {
MaxEarlyDataSize uint32
}
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
return ExtensionTypeTicketEarlyDataInfo
}
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
return syntax.Marshal(tedi)
}
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tedi)
}
// opaque ProtocolName<1..2^8-1>;
//
// struct {
// ProtocolName protocol_name_list<2..2^16-1>
// } ProtocolNameList;
type ALPNExtension struct {
Protocols []string
}
type protocolNameInner struct {
Name []byte `tls:"head=1,min=1"`
}
type alpnExtensionInner struct {
Protocols []protocolNameInner `tls:"head=2,min=2"`
}
func (alpn ALPNExtension) Type() ExtensionType {
return ExtensionTypeALPN
}
func (alpn ALPNExtension) Marshal() ([]byte, error) {
protocols := make([]protocolNameInner, len(alpn.Protocols))
for i, protocol := range alpn.Protocols {
protocols[i] = protocolNameInner{[]byte(protocol)}
}
return syntax.Marshal(alpnExtensionInner{protocols})
}
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
var inner alpnExtensionInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
alpn.Protocols = make([]string, len(inner.Protocols))
for i, protocol := range inner.Protocols {
alpn.Protocols[i] = string(protocol.Name)
}
return read, nil
}
// struct {
// ProtocolVersion versions<2..254>;
// } SupportedVersions;
type SupportedVersionsExtension struct {
Versions []uint16 `tls:"head=1,min=2,max=254"`
}
func (sv SupportedVersionsExtension) Type() ExtensionType {
return ExtensionTypeSupportedVersions
}
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
return syntax.Marshal(sv)
}
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sv)
}
// struct {
// opaque cookie<1..2^16-1>;
// } Cookie;
type CookieExtension struct {
Cookie []byte `tls:"head=2,min=1"`
}
func (c CookieExtension) Type() ExtensionType {
return ExtensionTypeCookie
}
func (c CookieExtension) Marshal() ([]byte, error) {
return syntax.Marshal(c)
}
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, c)
}
// defaultCookieLength is the default length of a cookie
const defaultCookieLength = 32
type defaultCookieHandler struct {
data []byte
}
var _ CookieHandler = &defaultCookieHandler{}
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
h.data = make([]byte, defaultCookieLength)
if _, err := prng.Read(h.data); err != nil {
return nil, err
}
return h.data, nil
}
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
return bytes.Equal(h.data, data)
}

147
vendor/github.com/bifurcation/mint/ffdhe.go generated vendored Normal file
View file

@ -0,0 +1,147 @@
package mint
import (
"encoding/hex"
"math/big"
)
var (
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B423861285C97FFFFFFFFFFFFFFFF"
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
"FFFFFFFFFFFFFFFF"
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
)

98
vendor/github.com/bifurcation/mint/frame-reader.go generated vendored Normal file
View file

@ -0,0 +1,98 @@
// Read a generic "framed" packet consisting of a header and a
// This is used for both TLS Records and TLS Handshake Messages
package mint
type framing interface {
headerLen() int
defaultReadLen() int
frameLen(hdr []byte) (int, error)
}
const (
kFrameReaderHdr = 0
kFrameReaderBody = 1
)
type frameNextAction func(f *frameReader) error
type frameReader struct {
details framing
state uint8
header []byte
body []byte
working []byte
writeOffset int
remainder []byte
}
func newFrameReader(d framing) *frameReader {
hdr := make([]byte, d.headerLen())
return &frameReader{
d,
kFrameReaderHdr,
hdr,
nil,
hdr,
0,
nil,
}
}
func dup(a []byte) []byte {
r := make([]byte, len(a))
copy(r, a)
return r
}
func (f *frameReader) needed() int {
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
if tmp < 0 {
return 0
}
return tmp
}
func (f *frameReader) addChunk(in []byte) {
// Append to the buffer.
logf(logTypeFrameReader, "Appending %v", len(in))
f.remainder = append(f.remainder, in...)
}
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
for f.needed() == 0 {
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
// Fill out our working block
copied := copy(f.working[f.writeOffset:], f.remainder)
f.remainder = f.remainder[copied:]
f.writeOffset += copied
if f.writeOffset < len(f.working) {
logf(logTypeFrameReader, "Read would have blocked 1")
return nil, nil, WouldBlock
}
// Reset the write offset, because we are now full.
f.writeOffset = 0
// We have read a full frame
if f.state == kFrameReaderBody {
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
f.state = kFrameReaderHdr
f.working = f.header
return dup(f.header), dup(f.body), nil
}
// We have read the header
bodyLen, err := f.details.frameLen(f.header)
if err != nil {
return nil, nil, err
}
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
f.body = make([]byte, bodyLen)
f.working = f.body
f.writeOffset = 0
f.state = kFrameReaderBody
}
logf(logTypeFrameReader, "Read would have blocked 2")
return nil, nil, WouldBlock
}

253
vendor/github.com/bifurcation/mint/handshake-layer.go generated vendored Normal file
View file

@ -0,0 +1,253 @@
package mint
import (
"fmt"
"io"
"net"
)
const (
handshakeHeaderLen = 4 // handshake message header length
maxHandshakeMessageLen = 1 << 24 // max handshake message length
)
// struct {
// HandshakeType msg_type; /* handshake type */
// uint24 length; /* bytes in message */
// select (HandshakeType) {
// ...
// } body;
// } Handshake;
//
// We do the select{...} part in a different layer, so we treat the
// actual message body as opaque:
//
// struct {
// HandshakeType msg_type;
// opaque msg<0..2^24-1>
// } Handshake;
//
// TODO: File a spec bug
type HandshakeMessage struct {
// Omitted: length
msgType HandshakeType
body []byte
}
// Note: This could be done with the `syntax` module, using the simplified
// syntax as discussed above. However, since this is so simple, there's not
// much benefit to doing so.
func (hm *HandshakeMessage) Marshal() []byte {
if hm == nil {
return []byte{}
}
msgLen := len(hm.body)
data := make([]byte, 4+len(hm.body))
data[0] = byte(hm.msgType)
data[1] = byte(msgLen >> 16)
data[2] = byte(msgLen >> 8)
data[3] = byte(msgLen)
copy(data[4:], hm.body)
return data
}
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
var body HandshakeMessageBody
switch hm.msgType {
case HandshakeTypeClientHello:
body = new(ClientHelloBody)
case HandshakeTypeServerHello:
body = new(ServerHelloBody)
case HandshakeTypeHelloRetryRequest:
body = new(HelloRetryRequestBody)
case HandshakeTypeEncryptedExtensions:
body = new(EncryptedExtensionsBody)
case HandshakeTypeCertificate:
body = new(CertificateBody)
case HandshakeTypeCertificateRequest:
body = new(CertificateRequestBody)
case HandshakeTypeCertificateVerify:
body = new(CertificateVerifyBody)
case HandshakeTypeFinished:
body = &FinishedBody{VerifyDataLen: len(hm.body)}
case HandshakeTypeNewSessionTicket:
body = new(NewSessionTicketBody)
case HandshakeTypeKeyUpdate:
body = new(KeyUpdateBody)
case HandshakeTypeEndOfEarlyData:
body = new(EndOfEarlyDataBody)
default:
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
}
_, err := body.Unmarshal(hm.body)
return body, err
}
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
data, err := body.Marshal()
if err != nil {
return nil, err
}
return &HandshakeMessage{
msgType: body.Type(),
body: data,
}, nil
}
type HandshakeLayer struct {
nonblocking bool // Should we operate in nonblocking mode
conn *RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader
}
type handshakeLayerFrameDetails struct{}
func (d handshakeLayerFrameDetails) headerLen() int {
return handshakeHeaderLen
}
func (d handshakeLayerFrameDetails) defaultReadLen() int {
return handshakeHeaderLen + maxFragmentLen
}
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
logf(logTypeIO, "Header=%x", hdr)
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
}
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.conn = r
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
return &h
}
func (h *HandshakeLayer) readRecord() error {
logf(logTypeIO, "Trying to read record")
pt, err := h.conn.ReadRecord()
if err != nil {
return err
}
if pt.contentType != RecordTypeHandshake &&
pt.contentType != RecordTypeAlert {
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
}
if pt.contentType == RecordTypeAlert {
logf(logTypeIO, "read alert %v", pt.fragment[1])
if len(pt.fragment) < 2 {
h.sendAlert(AlertUnexpectedMessage)
return io.EOF
}
return Alert(pt.fragment[1])
}
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
h.frame.addChunk(pt.fragment)
return nil
}
// sendAlert sends a TLS alert message.
func (h *HandshakeLayer) sendAlert(err Alert) error {
tmp := make([]byte, 2)
tmp[0] = AlertLevelError
tmp[1] = byte(err)
h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAlert,
fragment: tmp},
)
// closeNotify is a special case in that it isn't an error:
if err != AlertCloseNotify {
return &net.OpError{Op: "local error", Err: err}
}
return nil
}
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
var hdr, body []byte
var err error
for {
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 {
logf(logTypeHandshake, "Trying to read a new record")
err = h.readRecord()
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err
}
hdr, body, err = h.frame.process()
if err == nil {
break
}
if err != nil && (h.nonblocking || err != WouldBlock) {
return nil, err
}
}
logf(logTypeHandshake, "read handshake message")
hm := &HandshakeMessage{}
hm.msgType = HandshakeType(hdr[0])
hm.body = make([]byte, len(body))
copy(hm.body, body)
return hm, nil
}
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
return h.WriteMessages([]*HandshakeMessage{hm})
}
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
for _, hm := range hms {
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
}
// Write out headers and bodies
buffer := []byte{}
for _, msg := range hms {
msgLen := len(msg.body)
if msgLen > maxHandshakeMessageLen {
return fmt.Errorf("tls.handshakelayer: Message too large to send")
}
buffer = append(buffer, msg.Marshal()...)
}
// Send full-size fragments
var start int
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start : start+maxFragmentLen],
})
if err != nil {
return err
}
}
// Send a final partial fragment if necessary
if start < len(buffer) {
err := h.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buffer[start:],
})
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,450 @@
package mint
import (
"bytes"
"crypto"
"crypto/x509"
"encoding/binary"
"fmt"
"github.com/bifurcation/mint/syntax"
)
type HandshakeMessageBody interface {
Type() HandshakeType
Marshal() ([]byte, error)
Unmarshal(data []byte) (int, error)
}
// struct {
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
// Random random;
// opaque legacy_session_id<0..32>;
// CipherSuite cipher_suites<2..2^16-2>;
// opaque legacy_compression_methods<1..2^8-1>;
// Extension extensions<0..2^16-1>;
// } ClientHello;
type ClientHelloBody struct {
// Omitted: clientVersion
// Omitted: legacySessionID
// Omitted: legacyCompressionMethods
Random [32]byte
CipherSuites []CipherSuite
Extensions ExtensionList
}
type clientHelloBodyInner struct {
LegacyVersion uint16
Random [32]byte
LegacySessionID []byte `tls:"head=1,max=32"`
CipherSuites []CipherSuite `tls:"head=2,min=2"`
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
Extensions []Extension `tls:"head=2"`
}
func (ch ClientHelloBody) Type() HandshakeType {
return HandshakeTypeClientHello
}
func (ch ClientHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(clientHelloBodyInner{
LegacyVersion: 0x0303,
Random: ch.Random,
LegacySessionID: []byte{},
CipherSuites: ch.CipherSuites,
LegacyCompressionMethods: []byte{0},
Extensions: ch.Extensions,
})
}
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
var inner clientHelloBodyInner
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return 0, err
}
// We are strict about these things because we only support 1.3
if inner.LegacyVersion != 0x0303 {
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
}
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
}
ch.Random = inner.Random
ch.CipherSuites = inner.CipherSuites
ch.Extensions = inner.Extensions
return read, nil
}
// TODO: File a spec bug to clarify this
func (ch ClientHelloBody) Truncated() ([]byte, error) {
if len(ch.Extensions) == 0 {
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
}
pskExt := ch.Extensions[len(ch.Extensions)-1]
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
}
chm, err := HandshakeMessageFromBody(&ch)
if err != nil {
return nil, err
}
chData := chm.Marshal()
psk := PreSharedKeyExtension{
HandshakeType: HandshakeTypeClientHello,
}
_, err = psk.Unmarshal(pskExt.ExtensionData)
if err != nil {
return nil, err
}
// Marshal just the binders so that we know how much to truncate
binders := struct {
Binders []PSKBinderEntry `tls:"head=2,min=33"`
}{Binders: psk.Binders}
binderData, _ := syntax.Marshal(binders)
binderLen := len(binderData)
chLen := len(chData)
return chData[:chLen-binderLen], nil
}
// struct {
// ProtocolVersion server_version;
// CipherSuite cipher_suite;
// Extension extensions<2..2^16-1>;
// } HelloRetryRequest;
type HelloRetryRequestBody struct {
Version uint16
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2,min=2"`
}
func (hrr HelloRetryRequestBody) Type() HandshakeType {
return HandshakeTypeHelloRetryRequest
}
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(hrr)
}
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, hrr)
}
// struct {
// ProtocolVersion version;
// Random random;
// CipherSuite cipher_suite;
// Extension extensions<0..2^16-1>;
// } ServerHello;
type ServerHelloBody struct {
Version uint16
Random [32]byte
CipherSuite CipherSuite
Extensions ExtensionList `tls:"head=2"`
}
func (sh ServerHelloBody) Type() HandshakeType {
return HandshakeTypeServerHello
}
func (sh ServerHelloBody) Marshal() ([]byte, error) {
return syntax.Marshal(sh)
}
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, sh)
}
// struct {
// opaque verify_data[verify_data_length];
// } Finished;
//
// verifyDataLen is not a field in the TLS struct, but we add it here so
// that calling code can tell us how much data to expect when we marshal /
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
// try to keep the signature consistent for now.)
//
// For similar reasons, we don't use the `syntax` module here, because this
// struct doesn't map well to standard TLS presentation language concepts.
//
// TODO: File a spec bug
type FinishedBody struct {
VerifyDataLen int
VerifyData []byte
}
func (fin FinishedBody) Type() HandshakeType {
return HandshakeTypeFinished
}
func (fin FinishedBody) Marshal() ([]byte, error) {
if len(fin.VerifyData) != fin.VerifyDataLen {
return nil, fmt.Errorf("tls.finished: data length mismatch")
}
body := make([]byte, len(fin.VerifyData))
copy(body, fin.VerifyData)
return body, nil
}
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
if len(data) < fin.VerifyDataLen {
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
}
fin.VerifyData = make([]byte, fin.VerifyDataLen)
copy(fin.VerifyData, data[:fin.VerifyDataLen])
return fin.VerifyDataLen, nil
}
// struct {
// Extension extensions<0..2^16-1>;
// } EncryptedExtensions;
//
// Marshal() and Unmarshal() are handled by ExtensionList
type EncryptedExtensionsBody struct {
Extensions ExtensionList `tls:"head=2"`
}
func (ee EncryptedExtensionsBody) Type() HandshakeType {
return HandshakeTypeEncryptedExtensions
}
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
return syntax.Marshal(ee)
}
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ee)
}
// opaque ASN1Cert<1..2^24-1>;
//
// struct {
// ASN1Cert cert_data;
// Extension extensions<0..2^16-1>
// } CertificateEntry;
//
// struct {
// opaque certificate_request_context<0..2^8-1>;
// CertificateEntry certificate_list<0..2^24-1>;
// } Certificate;
type CertificateEntry struct {
CertData *x509.Certificate
Extensions ExtensionList
}
type CertificateBody struct {
CertificateRequestContext []byte
CertificateList []CertificateEntry
}
type certificateEntryInner struct {
CertData []byte `tls:"head=3,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
type certificateBodyInner struct {
CertificateRequestContext []byte `tls:"head=1"`
CertificateList []certificateEntryInner `tls:"head=3"`
}
func (c CertificateBody) Type() HandshakeType {
return HandshakeTypeCertificate
}
func (c CertificateBody) Marshal() ([]byte, error) {
inner := certificateBodyInner{
CertificateRequestContext: c.CertificateRequestContext,
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
}
for i, entry := range c.CertificateList {
inner.CertificateList[i] = certificateEntryInner{
CertData: entry.CertData.Raw,
Extensions: entry.Extensions,
}
}
return syntax.Marshal(inner)
}
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
inner := certificateBodyInner{}
read, err := syntax.Unmarshal(data, &inner)
if err != nil {
return read, err
}
c.CertificateRequestContext = inner.CertificateRequestContext
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
for i, entry := range inner.CertificateList {
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
if err != nil {
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
}
c.CertificateList[i].Extensions = entry.Extensions
}
return read, nil
}
// struct {
// SignatureScheme algorithm;
// opaque signature<0..2^16-1>;
// } CertificateVerify;
type CertificateVerifyBody struct {
Algorithm SignatureScheme
Signature []byte `tls:"head=2"`
}
func (cv CertificateVerifyBody) Type() HandshakeType {
return HandshakeTypeCertificateVerify
}
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
return syntax.Marshal(cv)
}
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cv)
}
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
// TODO: Change context for client auth
// TODO: Put this in a const
const context = "TLS 1.3, server CertificateVerify"
sigInput := bytes.Repeat([]byte{0x20}, 64)
sigInput = append(sigInput, []byte(context)...)
sigInput = append(sigInput, []byte{0}...)
sigInput = append(sigInput, data...)
return sigInput
}
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
sigInput := cv.EncodeSignatureInput(handshakeHash)
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return
}
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
sigInput := cv.EncodeSignatureInput(handshakeHash)
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
}
// struct {
// opaque certificate_request_context<0..2^8-1>;
// Extension extensions<2..2^16-1>;
// } CertificateRequest;
type CertificateRequestBody struct {
CertificateRequestContext []byte `tls:"head=1"`
Extensions ExtensionList `tls:"head=2"`
}
func (cr CertificateRequestBody) Type() HandshakeType {
return HandshakeTypeCertificateRequest
}
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
return syntax.Marshal(cr)
}
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, cr)
}
// struct {
// uint32 ticket_lifetime;
// uint32 ticket_age_add;
// opaque ticket_nonce<1..255>;
// opaque ticket<1..2^16-1>;
// Extension extensions<0..2^16-2>;
// } NewSessionTicket;
type NewSessionTicketBody struct {
TicketLifetime uint32
TicketAgeAdd uint32
TicketNonce []byte `tls:"head=1,min=1"`
Ticket []byte `tls:"head=2,min=1"`
Extensions ExtensionList `tls:"head=2"`
}
const ticketNonceLen = 16
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
buf := make([]byte, 4+ticketNonceLen+ticketLen)
_, err := prng.Read(buf)
if err != nil {
return nil, err
}
tkt := &NewSessionTicketBody{
TicketLifetime: ticketLifetime,
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
TicketNonce: buf[4 : 4+ticketNonceLen],
Ticket: buf[4+ticketNonceLen:],
}
return tkt, err
}
func (tkt NewSessionTicketBody) Type() HandshakeType {
return HandshakeTypeNewSessionTicket
}
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
return syntax.Marshal(tkt)
}
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, tkt)
}
// enum {
// update_not_requested(0), update_requested(1), (255)
// } KeyUpdateRequest;
//
// struct {
// KeyUpdateRequest request_update;
// } KeyUpdate;
type KeyUpdateBody struct {
KeyUpdateRequest KeyUpdateRequest
}
func (ku KeyUpdateBody) Type() HandshakeType {
return HandshakeTypeKeyUpdate
}
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
return syntax.Marshal(ku)
}
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
return syntax.Unmarshal(data, ku)
}
// struct {} EndOfEarlyData;
type EndOfEarlyDataBody struct{}
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
return HandshakeTypeEndOfEarlyData
}
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
return []byte{}, nil
}
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
return 0, nil
}

55
vendor/github.com/bifurcation/mint/log.go generated vendored Normal file
View file

@ -0,0 +1,55 @@
package mint
import (
"fmt"
"log"
"os"
"strings"
)
// We use this environment variable to control logging. It should be a
// comma-separated list of log tags (see below) or "*" to enable all logging.
const logConfigVar = "MINT_LOG"
// Pre-defined log types
const (
logTypeCrypto = "crypto"
logTypeHandshake = "handshake"
logTypeNegotiation = "negotiation"
logTypeIO = "io"
logTypeFrameReader = "frame"
logTypeVerbose = "verbose"
)
var (
logFunction = log.Printf
logAll = false
logSettings = map[string]bool{}
)
func init() {
parseLogEnv(os.Environ())
}
func parseLogEnv(env []string) {
for _, stmt := range env {
if strings.HasPrefix(stmt, logConfigVar+"=") {
val := stmt[len(logConfigVar)+1:]
if val == "*" {
logAll = true
} else {
for _, t := range strings.Split(val, ",") {
logSettings[t] = true
}
}
}
}
}
func logf(tag string, format string, args ...interface{}) {
if logAll || logSettings[tag] {
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
logFunction(fullFormat, args...)
}
}

217
vendor/github.com/bifurcation/mint/negotiation.go generated vendored Normal file
View file

@ -0,0 +1,217 @@
package mint
import (
"bytes"
"encoding/hex"
"fmt"
"time"
)
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
for _, offeredVersion := range offered {
for _, supportedVersion := range supported {
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
if offeredVersion == supportedVersion {
// XXX: Should probably be highest supported version, but for now, we
// only support one version, so it doesn't really matter.
return true, offeredVersion
}
}
}
return false, 0
}
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
for _, share := range keyShares {
for _, group := range groups {
if group != share.Group {
continue
}
pub, priv, err := newKeyShare(share.Group)
if err != nil {
// If we encounter an error, just keep looking
continue
}
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
if err != nil {
// If we encounter an error, just keep looking
continue
}
return true, group, pub, dhSecret
}
}
return false, 0, nil, nil
}
const (
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
)
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
for i, id := range identities {
identityHex := hex.EncodeToString(id.Identity)
psk, ok := psks.Get(identityHex)
if !ok {
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
continue
}
// For resumption, make sure the ticket age is correct
if psk.IsResumption {
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
ticketAgeDelta := knownTicketAge - extTicketAge
if knownTicketAge < extTicketAge {
ticketAgeDelta = extTicketAge - knownTicketAge
}
if ticketAgeDelta > ticketAgeTolerance {
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
}
}
params, ok := cipherSuiteMap[psk.CipherSuite]
if !ok {
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
return false, 0, nil, CipherSuiteParams{}, err
}
// Compute binder
binderLabel := labelExternalBinder
if psk.IsResumption {
binderLabel = labelResumptionBinder
}
h0 := params.Hash.New().Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
// context = ClientHello[truncated]
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
ctxHash := params.Hash.New()
ctxHash.Write(context)
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
if !bytes.Equal(binder, binders[i].Binder) {
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
}
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
return true, i, &psk, params, nil
}
logf(logTypeNegotiation, "Failed to find a usable PSK")
return false, 0, nil, CipherSuiteParams{}, nil
}
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
dhAllowed := false
dhRequired := true
for _, mode := range modes {
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
dhRequired = dhRequired && (mode == PSKModeDHEKE)
}
// Use PSK if we can meet DH requirement and modes were provided
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
// Use DH if allowed
usingDH := canDoDH && (dhAllowed || !usingPSK)
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
return usingDH, usingPSK
}
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
// Select for server name if provided
candidates := certs
if serverName != nil {
candidatesByName := []*Certificate{}
for _, cert := range certs {
for _, name := range cert.Chain[0].DNSNames {
if len(*serverName) > 0 && name == *serverName {
candidatesByName = append(candidatesByName, cert)
}
}
}
if len(candidatesByName) == 0 {
return nil, 0, fmt.Errorf("No certificates available for server name")
}
candidates = candidatesByName
}
// Select for signature scheme
for _, cert := range candidates {
for _, scheme := range signatureSchemes {
if !schemeValidForKey(scheme, cert.PrivateKey) {
continue
}
return cert, scheme, nil
}
}
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
}
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
return usingEarlyData
}
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
for _, s1 := range offered {
if psk != nil {
if s1 == psk.CipherSuite {
return s1, nil
}
continue
}
for _, s2 := range supported {
if s1 == s2 {
return s1, nil
}
}
}
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
}
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
for _, p1 := range offered {
if psk != nil {
if p1 != psk.NextProto {
continue
}
}
for _, p2 := range supported {
if p1 == p2 {
return p1, nil
}
}
}
// If the client offers ALPN on resumption, it must match the earlier one
var err error
if psk != nil && psk.IsResumption && (len(offered) > 0) {
err = fmt.Errorf("ALPN for PSK not provided")
}
return "", err
}

296
vendor/github.com/bifurcation/mint/record-layer.go generated vendored Normal file
View file

@ -0,0 +1,296 @@
package mint
import (
"bytes"
"crypto/cipher"
"fmt"
"io"
"sync"
)
const (
sequenceNumberLen = 8 // sequence number length
recordHeaderLen = 5 // record header length
maxFragmentLen = 1 << 14 // max number of bytes in a record
)
type DecryptError string
func (err DecryptError) Error() string {
return string(err)
}
// struct {
// ContentType type;
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
// uint16 length;
// opaque fragment[TLSPlaintext.length];
// } TLSPlaintext;
type TLSPlaintext struct {
// Omitted: record_version (static)
// Omitted: length (computed from fragment)
contentType RecordType
fragment []byte
}
type RecordLayer struct {
sync.Mutex
conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader
nextData []byte // The next record to send
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
cachedError error // Error on the last record read
ivLength int // Length of the seq and nonce fields
seq []byte // Zero-padded sequence number
nonce []byte // Buffer for per-record nonces
cipher cipher.AEAD // AEAD cipher
}
type recordLayerFrameDetails struct{}
func (d recordLayerFrameDetails) headerLen() int {
return recordHeaderLen
}
func (d recordLayerFrameDetails) defaultReadLen() int {
return recordHeaderLen + maxFragmentLen
}
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return (int(hdr[3]) << 8) | int(hdr[4]), nil
}
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
r := RecordLayer{}
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{})
r.ivLength = 0
return &r
}
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
var err error
r.cipher, err = cipher(key)
if err != nil {
return err
}
r.ivLength = len(iv)
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
r.nonce = make([]byte, r.ivLength)
copy(r.nonce, iv)
return nil
}
func (r *RecordLayer) incrementSequenceNumber() {
if r.ivLength == 0 {
return
}
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
r.seq[i]++
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
if r.seq[i] != 0 {
return
}
}
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
panic("TLS: sequence number wraparound")
}
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
// Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen
ciphertextLen := plaintextLen + r.cipher.Overhead()
// Assemble the revised plaintext
out := &TLSPlaintext{
contentType: RecordTypeApplicationData,
fragment: make([]byte, ciphertextLen),
}
copy(out.fragment, pt.fragment)
out.fragment[originalLen] = byte(pt.contentType)
for i := 1; i <= padLen; i++ {
out.fragment[originalLen+i] = 0
}
// Encrypt the fragment
payload := out.fragment[:plaintextLen]
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
return out
}
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
if len(pt.fragment) < r.cipher.Overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
return nil, 0, DecryptError(msg)
}
decryptLen := len(pt.fragment) - r.cipher.Overhead()
out := &TLSPlaintext{
contentType: pt.contentType,
fragment: make([]byte, decryptLen),
}
// Decrypt
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
if err != nil {
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
}
// Find the padding boundary
padLen := 0
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
}
// Transfer the content type
newLen := decryptLen - padLen - 1
out.contentType = RecordType(out.fragment[newLen])
// Truncate the message to remove contentType, padding, overhead
out.fragment = out.fragment[:newLen]
return out, padLen, nil
}
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
var pt *TLSPlaintext
var err error
for {
pt, err = r.nextRecord()
if err == nil {
break
}
if !block || err != WouldBlock {
return 0, err
}
}
return pt.contentType, nil
}
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
pt, err := r.nextRecord()
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
if r.cachedRecord != nil {
logf(logTypeIO, "Returning cached record")
return r.cachedRecord, r.cachedError
}
// Loop until one of three things happens:
//
// 1. We get a frame
// 2. We try to read off the socket and get nothing, in which case
// return WouldBlock
// 3. We get an error.
err := WouldBlock
var header, body []byte
for err != nil {
if r.frame.needed() > 0 {
buf := make([]byte, recordHeaderLen+maxFragmentLen)
n, err := r.conn.Read(buf)
if err != nil {
logf(logTypeIO, "Error reading, %v", err)
return nil, err
}
if n == 0 {
return nil, WouldBlock
}
logf(logTypeIO, "Read %v bytes", n)
buf = buf[:n]
r.frame.addChunk(buf)
}
header, body, err = r.frame.process()
// Loop around on WouldBlock to see if some
// data is now available.
if err != nil && err != WouldBlock {
return nil, err
}
}
pt := &TLSPlaintext{}
// Validate content type
switch RecordType(header[0]) {
default:
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
pt.contentType = RecordType(header[0])
}
// Validate version
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
}
// Validate size < max
size := (int(header[3]) << 8) + int(header[4])
if size > maxFragmentLen+256 {
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
}
pt.fragment = make([]byte, size)
copy(pt.fragment, body)
// Attempt to decrypt fragment
if r.cipher != nil {
pt, _, err = r.decrypt(pt)
if err != nil {
return nil, err
}
}
// Check that plaintext length is not too long
if len(pt.fragment) > maxFragmentLen {
return nil, fmt.Errorf("tls.record: Plaintext size too big")
}
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
r.cachedRecord = pt
r.incrementSequenceNumber()
return pt, nil
}
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
return r.WriteRecordWithPadding(pt, 0)
}
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
if r.cipher != nil {
pt = r.encrypt(pt, padLen)
} else if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
}
if len(pt.fragment) > maxFragmentLen {
return fmt.Errorf("tls.record: Record size too big")
}
length := len(pt.fragment)
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
record := append(header, pt.fragment...)
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
r.incrementSequenceNumber()
_, err := r.conn.Write(record)
return err
}

View file

@ -0,0 +1,898 @@
package mint
import (
"bytes"
"hash"
"reflect"
)
// Server State Machine
//
// START <-----+
// Recv ClientHello | | Send HelloRetryRequest
// v |
// RECVD_CH ----+
// | Select parameters
// | Send ServerHello
// v
// NEGOTIATED
// | Send EncryptedExtensions
// | [Send CertificateRequest]
// Can send | [Send Certificate + CertificateVerify]
// app data --> | Send Finished
// after +--------+--------+
// here No 0-RTT | | 0-RTT
// | v
// | WAIT_EOED <---+
// | Recv | | | Recv
// | EndOfEarlyData | | | early data
// | | +-----+
// +> WAIT_FLIGHT2 <-+
// |
// +--------+--------+
// No auth | | Client auth
// | |
// | v
// | WAIT_CERT
// | Recv | | Recv Certificate
// | empty | v
// | Certificate | WAIT_CV
// | | | Recv
// | v | CertificateVerify
// +-> WAIT_FINISHED <---+
// | Recv Finished
// v
// CONNECTED
//
// NB: Not using state RECVD_CH
//
// State Instructions
// START {}
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
// WAIT_EOED RekeyIn;
// WAIT_FLIGHT2 {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
type ServerStateStart struct {
Caps Capabilities
conn *Conn
cookieSent bool
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
}
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeClientHello {
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
return nil, nil, AlertUnexpectedMessage
}
ch := &ClientHelloBody{}
_, err := ch.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
clientHello := hm
connParams := ConnectionParameters{}
supportedVersions := new(SupportedVersionsExtension)
serverName := new(ServerNameExtension)
supportedGroups := new(SupportedGroupsExtension)
signatureAlgorithms := new(SignatureAlgorithmsExtension)
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
clientEarlyData := &EarlyDataExtension{}
clientALPN := new(ALPNExtension)
clientPSKModes := new(PSKKeyExchangeModesExtension)
clientCookie := new(CookieExtension)
// Handle external extensions.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
return nil, nil, AlertInternalError
}
}
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
gotServerName := ch.Extensions.Find(serverName)
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
gotEarlyData := ch.Extensions.Find(clientEarlyData)
ch.Extensions.Find(clientKeyShares)
ch.Extensions.Find(clientPSK)
ch.Extensions.Find(clientALPN)
ch.Extensions.Find(clientPSKModes)
ch.Extensions.Find(clientCookie)
if gotServerName {
connParams.ServerName = string(*serverName)
}
// If the client didn't send supportedVersions or doesn't support 1.3,
// then we're done here.
if !gotSupportedVersions {
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
return nil, nil, AlertProtocolVersion
}
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
if !versionOK {
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
return nil, nil, AlertProtocolVersion
}
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
return nil, nil, AlertAccessDenied
}
// Figure out if we can do DH
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
// Figure out if we can do PSK
canDoPSK := false
var selectedPSK int
var psk *PreSharedKey
var params CipherSuiteParams
if len(clientPSK.Identities) > 0 {
contextBase := []byte{}
if state.helloRetryRequest != nil {
chBytes := state.firstClientHello.Marshal()
hrrBytes := state.helloRetryRequest.Marshal()
contextBase = append(chBytes, hrrBytes...)
}
chTrunc, err := ch.Truncated()
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
return nil, nil, AlertDecodeError
}
context := append(contextBase, chTrunc...)
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
return nil, nil, AlertInternalError
}
}
// Figure out if we actually should do DH / PSK
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
// Select a ciphersuite
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
return nil, nil, AlertHandshakeFailure
}
// Send a cookie if required
// NB: Need to do this here because it's after ciphersuite selection, which
// has to be after PSK selection.
// XXX: Doing this statefully for now, could be stateless
var cookieData []byte
if state.Caps.RequireCookie && !state.cookieSent {
var err error
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
return nil, nil, AlertInternalError
}
}
if cookieData != nil {
// Ignoring errors because everything here is newly constructed, so there
// shouldn't be marshal errors
hrr := &HelloRetryRequestBody{
Version: supportedVersion,
CipherSuite: connParams.CipherSuite,
}
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
return nil, nil, AlertInternalError
}
params := cipherSuiteMap[connParams.CipherSuite]
h := params.Hash.New()
h.Write(clientHello.Marshal())
firstClientHello := &HandshakeMessage{
msgType: HandshakeTypeMessageHash,
body: h.Sum(nil),
}
nextState := ServerStateStart{
Caps: state.Caps,
conn: state.conn,
cookieSent: true,
firstClientHello: firstClientHello,
helloRetryRequest: helloRetryRequest,
}
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
return nextState, toSend, AlertNoAlert
}
// If we've got no entropy to make keys from, fail
if !connParams.UsingDH && !connParams.UsingPSK {
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
return nil, nil, AlertHandshakeFailure
}
var pskSecret []byte
var cert *Certificate
var certScheme SignatureScheme
if connParams.UsingPSK {
pskSecret = psk.Key
} else {
psk = nil
// If we're not using a PSK mode, then we need to have certain extensions
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
return nil, nil, AlertMissingExtension
}
// Select a certificate
name := string(*serverName)
var err error
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
return nil, nil, AlertAccessDenied
}
}
if !connParams.UsingDH {
dhSecret = nil
}
// Figure out if we're going to do early data
var clientEarlyTrafficSecret []byte
connParams.ClientSendingEarlyData = gotEarlyData
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
if connParams.UsingEarlyData {
h := params.Hash.New()
h.Write(clientHello.Marshal())
chHash := h.Sum(nil)
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
}
// Select a next protocol
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
if err != nil {
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
return nil, nil, AlertNoApplicationProtocol
}
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
return ServerStateNegotiated{
Caps: state.Caps,
Params: connParams,
dhGroup: dhGroup,
dhPublic: dhPublic,
dhSecret: dhSecret,
pskSecret: pskSecret,
selectedPSK: selectedPSK,
cert: cert,
certScheme: certScheme,
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
firstClientHello: state.firstClientHello,
helloRetryRequest: state.helloRetryRequest,
clientHello: clientHello,
}.Next(nil)
}
type ServerStateNegotiated struct {
Caps Capabilities
Params ConnectionParameters
dhGroup NamedGroup
dhPublic []byte
dhSecret []byte
pskSecret []byte
clientEarlyTrafficSecret []byte
selectedPSK int
cert *Certificate
certScheme SignatureScheme
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
clientHello *HandshakeMessage
}
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
// Create the ServerHello
sh := &ServerHelloBody{
Version: supportedVersion,
CipherSuite: state.Params.CipherSuite,
}
_, err := prng.Read(sh.Random[:])
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
return nil, nil, AlertInternalError
}
if state.Params.UsingDH {
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
err = sh.Extensions.Add(&KeyShareExtension{
HandshakeType: HandshakeTypeServerHello,
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
err = sh.Extensions.Add(&PreSharedKeyExtension{
HandshakeType: HandshakeTypeServerHello,
SelectedIdentity: uint16(state.selectedPSK),
})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
return nil, nil, AlertInternalError
}
}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
serverHello, err := HandshakeMessageFromBody(sh)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
return nil, nil, AlertInternalError
}
// Look up crypto params
params, ok := cipherSuiteMap[sh.CipherSuite]
if !ok {
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
return nil, nil, AlertHandshakeFailure
}
// Start up the handshake hash
handshakeHash := params.Hash.New()
handshakeHash.Write(state.firstClientHello.Marshal())
handshakeHash.Write(state.helloRetryRequest.Marshal())
handshakeHash.Write(state.clientHello.Marshal())
handshakeHash.Write(serverHello.Marshal())
// Compute handshake secrets
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
var earlySecret []byte
if state.Params.UsingPSK {
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
} else {
earlySecret = HkdfExtract(params.Hash, zero, zero)
}
if state.dhSecret == nil {
state.dhSecret = zero
}
h0 := params.Hash.New().Sum(nil)
h2 := handshakeHash.Sum(nil)
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
// Send an EncryptedExtensions message (even if it's empty)
eeList := ExtensionList{}
if state.Params.NextProto != "" {
logf(logTypeHandshake, "[server] sending ALPN extension")
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
}
if state.Params.UsingEarlyData {
logf(logTypeHandshake, "[server] sending EDI extension")
err = eeList.Add(&EarlyDataExtension{})
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
}
ee := &EncryptedExtensionsBody{eeList}
// Run the external extension handler.
if state.Caps.ExtensionHandler != nil {
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
return nil, nil, AlertInternalError
}
}
eem, err := HandshakeMessageFromBody(ee)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
return nil, nil, AlertInternalError
}
handshakeHash.Write(eem.Marshal())
toSend := []HandshakeAction{
SendHandshakeMessage{serverHello},
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
SendHandshakeMessage{eem},
}
// Authenticate with a certificate if required
if !state.Params.UsingPSK {
// Send a CertificateRequest message if we want client auth
if state.Caps.RequireClientAuth {
state.Params.UsingClientAuth = true
// XXX: We don't support sending any constraints besides a list of
// supported signature algorithms
cr := &CertificateRequestBody{}
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
err := cr.Extensions.Add(schemes)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
return nil, nil, AlertInternalError
}
crm, err := HandshakeMessageFromBody(cr)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
return nil, nil, AlertInternalError
}
//TODO state.state.serverCertificateRequest = cr
toSend = append(toSend, SendHandshakeMessage{crm})
handshakeHash.Write(crm.Marshal())
}
// Create and send Certificate, CertificateVerify
certificate := &CertificateBody{
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
}
for i, entry := range state.cert.Chain {
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
}
certm, err := HandshakeMessageFromBody(certificate)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certm})
handshakeHash.Write(certm.Marshal())
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
hcv := handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
certvm, err := HandshakeMessageFromBody(certificateVerify)
if err != nil {
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
return nil, nil, AlertInternalError
}
toSend = append(toSend, SendHandshakeMessage{certvm})
handshakeHash.Write(certvm.Marshal())
}
// Compute secrets resulting from the server's first flight
h3 := handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
// Assemble the Finished message
fin := &FinishedBody{
VerifyDataLen: len(serverFinishedData),
VerifyData: serverFinishedData,
}
finm, _ := HandshakeMessageFromBody(fin)
toSend = append(toSend, SendHandshakeMessage{finm})
handshakeHash.Write(finm.Marshal())
// Compute traffic secrets
h4 := handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
if state.Params.UsingEarlyData {
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
nextState := ServerStateWaitEOED{
AuthCertificate: state.Caps.AuthCertificate,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
toSend = append(toSend, []HandshakeAction{
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
ReadEarlyData{},
}...)
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
toSend = append(toSend, []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
ReadPastEarlyData{},
}...)
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.Caps.AuthCertificate,
Params: state.Params,
cryptoParams: params,
handshakeHash: handshakeHash,
masterSecret: masterSecret,
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
clientTrafficSecret: clientTrafficSecret,
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
nextState, moreToSend, alert := waitFlight2.Next(nil)
toSend = append(toSend, moreToSend...)
return nextState, toSend, alert
}
type ServerStateWaitEOED struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
if len(hm.body) > 0 {
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
toSend := []HandshakeAction{
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
}
waitFlight2 := ServerStateWaitFlight2{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
nextState, moreToSend, alert := waitFlight2.Next(nil)
toSend = append(toSend, moreToSend...)
return nextState, toSend, alert
}
type ServerStateWaitFlight2 struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm != nil {
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
if state.Params.UsingClientAuth {
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
nextState := ServerStateWaitCert{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
handshakeHash: state.handshakeHash,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitCert struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificate {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
cert := &CertificateBody{}
_, err := cert.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
return nil, nil, AlertDecodeError
}
state.handshakeHash.Write(hm.Marshal())
if len(cert.CertificateList) == 0 {
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
nextState := ServerStateWaitCV{
AuthCertificate: state.AuthCertificate,
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
clientCertificate: cert,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitCV struct {
AuthCertificate func(chain []CertificateEntry) error
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
clientCertificate *CertificateBody
}
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
return nil, nil, AlertUnexpectedMessage
}
certVerify := &CertificateVerifyBody{}
_, err := certVerify.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
// Verify client signature over handshake hash
hcv := state.handshakeHash.Sum(nil)
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
return nil, nil, AlertHandshakeFailure
}
if state.AuthCertificate != nil {
err := state.AuthCertificate(state.clientCertificate.CertificateList)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
return nil, nil, AlertBadCertificate
}
} else {
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
}
// If it passes, record the certificateVerify in the transcript hash
state.handshakeHash.Write(hm.Marshal())
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
nextState := ServerStateWaitFinished{
Params: state.Params,
cryptoParams: state.cryptoParams,
masterSecret: state.masterSecret,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
handshakeHash: state.handshakeHash,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
return nextState, nil, AlertNoAlert
}
type ServerStateWaitFinished struct {
Params ConnectionParameters
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
handshakeHash hash.Hash
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil || hm.msgType != HandshakeTypeFinished {
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
_, err := fin.Unmarshal(hm.body)
if err != nil {
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
return nil, nil, AlertDecodeError
}
// Verify client Finished data
h5 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
return nil, nil, AlertHandshakeFailure
}
// Compute the resumption secret
state.handshakeHash.Write(hm.Marshal())
h6 := state.handshakeHash.Sum(nil)
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
// Compute client traffic keys
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
nextState := StateConnected{
Params: state.Params,
isClient: false,
cryptoParams: state.cryptoParams,
resumptionSecret: resumptionSecret,
clientTrafficSecret: state.clientTrafficSecret,
serverTrafficSecret: state.serverTrafficSecret,
exporterSecret: state.exporterSecret,
}
toSend := []HandshakeAction{
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
}
return nextState, toSend, AlertNoAlert
}

230
vendor/github.com/bifurcation/mint/state-machine.go generated vendored Normal file
View file

@ -0,0 +1,230 @@
package mint
import (
"time"
)
// Marker interface for actions that an implementation should take based on
// state transitions.
type HandshakeAction interface{}
type SendHandshakeMessage struct {
Message *HandshakeMessage
}
type SendEarlyData struct{}
type ReadEarlyData struct{}
type ReadPastEarlyData struct{}
type RekeyIn struct {
Label string
KeySet keySet
}
type RekeyOut struct {
Label string
KeySet keySet
}
type StorePSK struct {
PSK PreSharedKey
}
type HandshakeState interface {
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
}
type AppExtensionHandler interface {
Send(hs HandshakeType, el *ExtensionList) error
Receive(hs HandshakeType, el *ExtensionList) error
}
// Capabilities objects represent the capabilities of a TLS client or server,
// as an input to TLS negotiation
type Capabilities struct {
// For both client and server
CipherSuites []CipherSuite
Groups []NamedGroup
SignatureSchemes []SignatureScheme
PSKs PreSharedKeyCache
Certificates []*Certificate
AuthCertificate func(chain []CertificateEntry) error
ExtensionHandler AppExtensionHandler
// For client
PSKModes []PSKKeyExchangeMode
// For server
NextProtos []string
AllowEarlyData bool
RequireCookie bool
CookieHandler CookieHandler
RequireClientAuth bool
}
// ConnectionOptions objects represent per-connection settings for a client
// initiating a connection
type ConnectionOptions struct {
ServerName string
NextProtos []string
EarlyData []byte
}
// ConnectionParameters objects represent the parameters negotiated for a
// connection.
type ConnectionParameters struct {
UsingPSK bool
UsingDH bool
ClientSendingEarlyData bool
UsingEarlyData bool
UsingClientAuth bool
CipherSuite CipherSuite
ServerName string
NextProto string
}
// StateConnected is symmetric between client and server
type StateConnected struct {
Params ConnectionParameters
isClient bool
cryptoParams CipherSuiteParams
resumptionSecret []byte
clientTrafficSecret []byte
serverTrafficSecret []byte
exporterSecret []byte
}
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
var trafficKeys keySet
if state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
SendHandshakeMessage{kum},
RekeyOut{Label: "update", KeySet: trafficKeys},
}
return toSend, AlertNoAlert
}
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
tkt, err := NewSessionTicket(length, lifetime)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
return nil, AlertInternalError
}
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
return nil, AlertInternalError
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
newPSK := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: tkt.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
TicketAgeAdd: tkt.TicketAgeAdd,
}
tktm, err := HandshakeMessageFromBody(tkt)
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
return nil, AlertInternalError
}
toSend := []HandshakeAction{
StorePSK{newPSK},
SendHandshakeMessage{tktm},
}
return toSend, AlertNoAlert
}
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
if hm == nil {
logf(logTypeHandshake, "[StateConnected] Unexpected message")
return nil, nil, AlertUnexpectedMessage
}
bodyGeneric, err := hm.ToBody()
if err != nil {
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
return nil, nil, AlertDecodeError
}
switch body := bodyGeneric.(type) {
case *KeyUpdateBody:
var trafficKeys keySet
if !state.isClient {
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
} else {
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
}
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
// If requested, roll outbound keys and send a KeyUpdate
if body.KeyUpdateRequest == KeyUpdateRequested {
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
if alert != AlertNoAlert {
return nil, nil, alert
}
toSend = append(toSend, moreToSend...)
}
return state, toSend, AlertNoAlert
case *NewSessionTicketBody:
// XXX: Allow NewSessionTicket in both directions?
if !state.isClient {
return nil, nil, AlertUnexpectedMessage
}
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
psk := PreSharedKey{
CipherSuite: state.cryptoParams.Suite,
IsResumption: true,
Identity: body.Ticket,
Key: resumptionKey,
NextProto: state.Params.NextProto,
ReceivedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
TicketAgeAdd: body.TicketAgeAdd,
}
toSend := []HandshakeAction{StorePSK{psk}}
return state, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
return nil, nil, AlertUnexpectedMessage
}

243
vendor/github.com/bifurcation/mint/syntax/decode.go generated vendored Normal file
View file

@ -0,0 +1,243 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Unmarshal(data []byte, v interface{}) (int, error) {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
d := decodeState{}
d.Write(data)
return d.unmarshal(v)
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type decOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
}
type decodeState struct {
bytes.Buffer
}
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
}
read = d.value(rv)
return read, nil
}
func (e *decodeState) value(v reflect.Value) int {
return valueDecoder(v)(e, v, decOpts{})
}
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
func valueDecoder(v reflect.Value) decoderFunc {
return typeDecoder(v.Type().Elem())
}
func typeDecoder(t reflect.Type) decoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeDecoder(t)
}
func newTypeDecoder(t reflect.Type) decoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintDecoder
case reflect.Array:
return newArrayDecoder(t)
case reflect.Slice:
return newSliceDecoder(t)
case reflect.Struct:
return newStructDecoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific decoders below
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
var uintLen int
switch v.Elem().Kind() {
case reflect.Uint8:
uintLen = 1
case reflect.Uint16:
uintLen = 2
case reflect.Uint32:
uintLen = 4
case reflect.Uint64:
uintLen = 8
}
buf := make([]byte, uintLen)
n, err := d.Read(buf)
if err != nil {
panic(err)
}
if n != uintLen {
panic(fmt.Errorf("Insufficient data to read uint"))
}
val := uint64(0)
for _, b := range buf {
val = (val << 8) + uint64(b)
}
v.Elem().SetUint(val)
return uintLen
}
//////////
type arrayDecoder struct {
elemDec decoderFunc
}
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
n := v.Elem().Type().Len()
read := 0
for i := 0; i < n; i += 1 {
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
}
return read
}
func newArrayDecoder(t reflect.Type) decoderFunc {
dec := &arrayDecoder{typeDecoder(t.Elem())}
return dec.decode
}
//////////
type sliceDecoder struct {
elementType reflect.Type
elementDec decoderFunc
}
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
if opts.head == 0 {
panic(fmt.Errorf("Cannot decode a slice without a header length"))
}
lengthBytes := make([]byte, opts.head)
n, err := d.Read(lengthBytes)
if err != nil {
panic(err)
}
if uint(n) != opts.head {
panic(fmt.Errorf("Not enough data to read header"))
}
length := uint(0)
for _, b := range lengthBytes {
length = (length << 8) + uint(b)
}
if opts.max > 0 && length > opts.max {
panic(fmt.Errorf("Length of vector exceeds declared max"))
}
if length < opts.min {
panic(fmt.Errorf("Length of vector below declared min"))
}
data := make([]byte, length)
n, err = d.Read(data)
if err != nil {
panic(err)
}
if uint(n) != length {
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
}
elemBuf := &decodeState{}
elemBuf.Write(data)
elems := []reflect.Value{}
read := int(opts.head)
for elemBuf.Len() > 0 {
elem := reflect.New(sd.elementType)
read += sd.elementDec(elemBuf, elem, opts)
elems = append(elems, elem)
}
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
for i := 0; i < len(elems); i += 1 {
v.Elem().Index(i).Set(elems[i].Elem())
}
return read
}
func newSliceDecoder(t reflect.Type) decoderFunc {
dec := &sliceDecoder{
elementType: t.Elem(),
elementDec: typeDecoder(t.Elem()),
}
return dec.decode
}
//////////
type structDecoder struct {
fieldOpts []decOpts
fieldDecs []decoderFunc
}
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
read := 0
for i := range sd.fieldDecs {
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
}
return read
}
func newStructDecoder(t reflect.Type) decoderFunc {
n := t.NumField()
sd := structDecoder{
fieldOpts: make([]decOpts, n),
fieldDecs: make([]decoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
sd.fieldOpts[i] = decOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
}
sd.fieldDecs[i] = typeDecoder(f.Type)
}
return sd.decode
}

187
vendor/github.com/bifurcation/mint/syntax/encode.go generated vendored Normal file
View file

@ -0,0 +1,187 @@
package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
)
func Marshal(v interface{}) ([]byte, error) {
e := &encodeState{}
err := e.marshal(v, encOpts{})
if err != nil {
return nil, err
}
return e.Bytes(), nil
}
// These are the options that can be specified in the struct tag. Right now,
// all of them apply to variable-length vectors and nothing else
type encOpts struct {
head uint // length of length in bytes
min uint // minimum size in bytes
max uint // maximum size in bytes
}
type encodeState struct {
bytes.Buffer
}
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
if s, ok := r.(string); ok {
panic(s)
}
err = r.(error)
}
}()
e.reflectValue(reflect.ValueOf(v), opts)
return nil
}
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
valueEncoder(v)(e, v, opts)
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
panic(fmt.Errorf("Cannot encode an invalid value"))
}
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type) encoderFunc {
// Note: Omits the caching / wait-group things that encoding/json uses
return newTypeEncoder(t)
}
func newTypeEncoder(t reflect.Type) encoderFunc {
// Note: Does not support Marshaler, so don't need the allowAddr argument
switch t.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return uintEncoder
case reflect.Array:
return newArrayEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Struct:
return newStructEncoder(t)
default:
panic(fmt.Errorf("Unsupported type (%s)", t))
}
}
///// Specific encoders below
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
u := v.Uint()
switch v.Type().Kind() {
case reflect.Uint8:
e.WriteByte(byte(u))
case reflect.Uint16:
e.Write([]byte{byte(u >> 8), byte(u)})
case reflect.Uint32:
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
case reflect.Uint64:
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
}
}
//////////
type arrayEncoder struct {
elemEnc encoderFunc
}
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
n := v.Len()
for i := 0; i < n; i += 1 {
ae.elemEnc(e, v.Index(i), opts)
}
}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := &arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
//////////
type sliceEncoder struct {
ae *arrayEncoder
}
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if opts.head == 0 {
panic(fmt.Errorf("Cannot encode a slice without a header length"))
}
arrayState := &encodeState{}
se.ae.encode(arrayState, v, opts)
n := uint(arrayState.Len())
if opts.max > 0 && n > opts.max {
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
}
if n>>(8*opts.head) > 0 {
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
}
if n < opts.min {
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
}
for i := int(opts.head - 1); i >= 0; i -= 1 {
e.WriteByte(byte(n >> (8 * uint(i))))
}
e.Write(arrayState.Bytes())
}
func newSliceEncoder(t reflect.Type) encoderFunc {
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
return enc.encode
}
//////////
type structEncoder struct {
fieldOpts []encOpts
fieldEncs []encoderFunc
}
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
for i := range se.fieldEncs {
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
}
}
func newStructEncoder(t reflect.Type) encoderFunc {
n := t.NumField()
se := structEncoder{
fieldOpts: make([]encOpts, n),
fieldEncs: make([]encoderFunc, n),
}
for i := 0; i < n; i += 1 {
f := t.Field(i)
tag := f.Tag.Get("tls")
tagOpts := parseTag(tag)
se.fieldOpts[i] = encOpts{
head: tagOpts["head"],
max: tagOpts["max"],
min: tagOpts["min"],
}
se.fieldEncs[i] = typeEncoder(f.Type)
}
return se.encode
}

30
vendor/github.com/bifurcation/mint/syntax/tags.go generated vendored Normal file
View file

@ -0,0 +1,30 @@
package syntax
import (
"strconv"
"strings"
)
// `tls:"head=2,min=2,max=255"`
type tagOptions map[string]uint
// parseTag parses a struct field's "tls" tag as a comma-separated list of
// name=value pairs, where the values MUST be unsigned integers
func parseTag(tag string) tagOptions {
opts := tagOptions{}
for _, token := range strings.Split(tag, ",") {
if strings.Index(token, "=") == -1 {
continue
}
parts := strings.Split(token, "=")
if len(parts[0]) == 0 {
continue
}
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
opts[parts[0]] = uint(val)
}
}
return opts
}

168
vendor/github.com/bifurcation/mint/tls.go generated vendored Normal file
View file

@ -0,0 +1,168 @@
package mint
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
import (
"errors"
"net"
"strings"
"time"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, false)
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
return NewConn(conn, config, true)
}
// A listener implements a network listener (net.Listener) for TLS connections.
type Listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection c is a *tls.Conn.
func (l *Listener) Accept() (c net.Conn, err error) {
c, err = l.Listener.Accept()
if err != nil {
return
}
server := Server(c, l.config)
err = server.Handshake()
if err == AlertNoAlert {
err = nil
}
c = server
return
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(Listener)
l.Listener = inner
l.config = config
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || !config.ValidForServer() {
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
type TimeoutError struct{}
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (TimeoutError) Timeout() bool { return true }
func (TimeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := dialer.Deadline.Sub(time.Now())
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- TimeoutError{}
})
}
rawConn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = &Config{}
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if timeout == 0 {
err = conn.Handshake()
if err == AlertNoAlert {
err = nil
}
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
if err == AlertNoAlert {
err = nil
}
}
if err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}

View file

@ -1,32 +0,0 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
SendingAllowed() bool
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
SetLowerLimit(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame
}

View file

@ -3,7 +3,7 @@ package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var bufferPool sync.Pool

View file

@ -10,32 +10,39 @@ import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type client struct {
mutex sync.Mutex
listenErr error
mutex sync.Mutex
conn connection
hostname string
errorChan chan struct{}
handshakeChan <-chan handshakeEvent
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet
tlsConf *tls.Config
config *Config
tls handshake.MintTLS // only used when using TLS
connectionID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
version protocol.VersionNumber
session packetHandler
}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = utils.GenerateConnectionID
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
)
@ -53,71 +60,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
}
err = c.createNewSession(nil)
if err != nil {
return nil, err
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection()
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
@ -127,15 +69,39 @@ func Dial(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
connID, err := generateConnectionID()
if err != nil {
return nil, err
}
err = sess.WaitUntilHandshakeComplete()
if err != nil {
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
versionNegotiationChan: make(chan struct{}),
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
if err := c.dial(); err != nil {
return nil, err
}
return sess, nil
return c.session, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config {
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.IdleTimeout != 0 {
idleTimeout = config.IdleTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config {
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
}
}
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
func (c *client) dial() error {
var err error
if c.version.UsesTLS() {
err = c.dialTLS()
} else {
err = c.dialGQUIC()
}
if err == errCloseSessionForNewVersion {
return c.dial()
}
return err
}
func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
go c.listen()
return c.establishSecureConnection()
}
func (c *client) dialTLS() error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
// TODO(#523): make these values configurable
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
}
csc := handshake.NewCryptoStreamConn(nil)
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
go c.listen()
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
}
utils.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
return err
}
}
return nil
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error
errorChan := make(chan struct{})
go func() {
runErr = c.session.run() // returns as soon as the session is closed
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
}()
// wait until the server accepts the QUIC version (or an error occurs)
select {
case <-errorChan:
return runErr
case <-c.versionNegotiationChan:
}
select {
case <-c.errorChan:
return c.listenErr
case ev := <-c.handshakeChan:
if ev.err != nil {
return ev.err
}
if ev.encLevel != protocol.EncryptionSecure {
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
}
return nil
case <-errorChan:
return runErr
case err := <-c.session.handshakeStatus():
return err
}
}
// Listen listens
// Listen listens on the underlying connection and passes packets on for handling.
// It returns when the connection is closed.
func (c *client) listen() {
var err error
@ -205,13 +252,15 @@ func (c *client) listen() {
n, addr, err = c.conn.Read(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.session.Close(err)
c.mutex.Lock()
if c.session != nil {
c.session.Close(err)
}
c.mutex.Unlock()
}
break
}
data = data[:n]
c.handlePacket(addr, data)
c.handlePacket(addr, data[:n])
}
}
@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
if err != nil {
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the Public Header
// drop this packet if we can't parse the header
return
}
// reject packets with truncated connection id if we didn't request truncation
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := parsePublicReset(r)
pr, err := wire.ParsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
return
}
// ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag {
return
}
// handle Version Negotiation Packets
if hdr.IsVersionNegotiation {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
return
}
// this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
}
if hdr.VersionFlag {
// version negotiation packets have no payload
if err := c.handlePacketWithVersionFlag(hdr); err != nil {
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
c.session.Close(err)
}
return
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
close(c.versionNegotiationChan)
}
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
publicHeader: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
}
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
}
}
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if newVersion == protocol.VersionUnsupported {
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
c.versionNegotiated = true
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
return err
}
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
return c.createNewSession(hdr.SupportedVersions)
return nil
}
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, c.handshakeChan, err = newClientSession(
func (c *client) createNewGQUICSession() (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.tlsConf,
c.config,
negotiatedVersions,
c.initialVersion,
c.negotiatedVersions,
)
if err != nil {
return err
}
go func() {
// session.run() returns as soon as the session is closed
err := c.session.run()
if err == errCloseSessionForNewVersion {
return
}
c.listenErr = err
close(c.errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close()
}()
return nil
return err
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newTLSClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.config,
c.tls,
paramsChan,
1,
)
return err
}

View file

@ -1,58 +0,0 @@
package crypto
import (
"crypto/cipher"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/protocol"
)
type aeadAESGCM struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
}

View file

@ -1,14 +0,0 @@
package crypto
import (
"encoding/binary"
"github.com/lucas-clemente/quic-go/protocol"
)
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View file

@ -1,76 +0,0 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// StkSource is used to create and verify source address tokens
type StkSource interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
type stkSource struct {
aead cipher.AEAD
}
const stkKeySize = 16
// Chrome currently sets this to 12, but discusses changing it to 16. We start
// at 16 :)
const stkNonceSize = 16
// NewStkSource creates a source for source address tokens
func NewStkSource() (StkSource, error) {
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
key, err := deriveKey(secret)
if err != nil {
return nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
if err != nil {
return nil, err
}
return &stkSource{aead: aead}, nil
}
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, stkNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
return s.aead.Seal(nonce, nonce, data, nil), nil
}
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
if len(p) < stkNonceSize {
return nil, fmt.Errorf("STK too short: %d", len(p))
}
nonce := p[:stkNonceSize]
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
}
func deriveKey(secret []byte) ([]byte, error) {
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
key := make([]byte, stkKeySize)
if _, err := io.ReadFull(r, key); err != nil {
return nil, err
}
return key, nil
}

View file

@ -0,0 +1,41 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStreamI interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStream struct {
*stream
}
var _ cryptoStreamI = &cryptoStream{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStream{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPosition = offset
}

View file

@ -7,12 +7,15 @@ import (
"net/http"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func main() {
verbose := flag.Bool("v", false, "verbose")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
urls := flag.Args()
@ -23,8 +26,17 @@ func main() {
}
utils.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
roundTripper := &h2quic.RoundTripper{
QuicConfig: &quic.Config{Versions: versions},
}
defer roundTripper.Close()
hclient := &http.Client{
Transport: &h2quic.RoundTripper{},
Transport: roundTripper,
}
var wg sync.WaitGroup

View file

@ -17,7 +17,9 @@ import (
_ "net/http/pprof"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -121,6 +123,7 @@ func main() {
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
www := flag.String("www", "/var/www", "www data")
tcp := flag.Bool("tcp", false, "also listen on TCP")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
if *verbose {
@ -130,6 +133,11 @@ func main() {
}
utils.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
certFile := *certPath + "/fullchain.pem"
keyFile := *certPath + "/privkey.pem"
@ -148,7 +156,11 @@ func main() {
if *tcp {
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
} else {
err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil)
server := h2quic.Server{
Server: &http.Server{Addr: bCap},
QuicConfig: &quic.Config{Versions: versions},
}
err = server.ListenAndServeTLS(certFile, keyFile)
}
if err != nil {
fmt.Println(err)

View file

@ -1,240 +0,0 @@
package flowcontrol
import (
"errors"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
type flowControlManager struct {
connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
streamFlowController map[protocol.StreamID]*flowController
connFlowController *flowController
mutex sync.RWMutex
}
var _ FlowControlManager = &flowControlManager{}
var errMapAccess = errors.New("Error accessing the flowController map.")
// NewFlowControlManager creates a new flow control manager
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
return &flowControlManager{
connectionParameters: connectionParameters,
rttStats: rttStats,
streamFlowController: make(map[protocol.StreamID]*flowController),
connFlowController: newFlowController(0, false, connectionParameters, rttStats),
}
}
// NewStream creates new flow controllers for a stream
// it does nothing if the stream already exists
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
f.mutex.Lock()
defer f.mutex.Unlock()
if _, ok := f.streamFlowController[streamID]; ok {
return
}
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
}
// RemoveStream removes a closed stream from flow control
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
f.mutex.Lock()
delete(f.streamFlowController, streamID)
f.mutex.Unlock()
}
// ResetStream should be called when receiving a RstStreamFrame
// it updates the byte offset to the value in the RstStreamFrame
// streamID must not be 0 here
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID)
if err != nil {
return err
}
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
if err != nil {
return qerr.StreamDataAfterTermination
}
if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
}
if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
}
}
return nil
}
// UpdateHighestReceived updates the highest received byte offset for a stream
// it adds the number of additional bytes to connection level flow control
// streamID must not be 0 here
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID)
if err != nil {
return err
}
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
// this error can be ignored here
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
}
if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
}
}
return nil
}
// streamID must not be 0 here
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return err
}
fc.AddBytesRead(n)
if fc.ContributesToConnection() {
f.connFlowController.AddBytesRead(n)
}
return nil
}
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
f.mutex.Lock()
defer f.mutex.Unlock()
// get WindowUpdates for streams
for id, fc := range f.streamFlowController {
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
res = append(res, WindowUpdate{StreamID: id, Offset: offset})
if fc.ContributesToConnection() && newIncrement != 0 {
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
}
}
}
// get a WindowUpdate for the connection
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
}
return
}
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.RLock()
defer f.mutex.RUnlock()
// StreamID can be 0 when retransmitting
if streamID == 0 {
return f.connFlowController.receiveWindow, nil
}
flowController, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return flowController.receiveWindow, nil
}
// streamID must not be 0 here
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return err
}
fc.AddBytesSent(n)
if fc.ContributesToConnection() {
f.connFlowController.AddBytesSent(n)
}
return nil
}
// must not be called with StreamID 0
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.RLock()
defer f.mutex.RUnlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
res := fc.SendWindowSize()
if fc.ContributesToConnection() {
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
}
return res, nil
}
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
f.mutex.RLock()
defer f.mutex.RUnlock()
return f.connFlowController.SendWindowSize()
}
// streamID may be 0 here
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
var fc *flowController
if streamID == 0 {
fc = f.connFlowController
} else {
var err error
fc, err = f.getFlowController(streamID)
if err != nil {
return false, err
}
}
return fc.UpdateSendWindow(offset), nil
}
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {
streamFlowController, ok := f.streamFlowController[streamID]
if !ok {
return nil, errMapAccess
}
return streamFlowController, nil
}

View file

@ -1,198 +0,0 @@
package flowcontrol
import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
type flowController struct {
streamID protocol.StreamID
contributesToConnection bool // does the stream contribute to connection level flow control
connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastWindowUpdateTime time.Time
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowIncrement protocol.ByteCount
maxReceiveWindowIncrement protocol.ByteCount
}
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
// newFlowController gets a new flow controller
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
fc := flowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connectionParameters: connectionParameters,
rttStats: rttStats,
}
if streamID == 0 {
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
fc.receiveWindowIncrement = fc.receiveWindow
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
} else {
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
fc.receiveWindowIncrement = fc.receiveWindow
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
}
return &fc
}
func (c *flowController) ContributesToConnection() bool {
return c.contributesToConnection
}
func (c *flowController) getSendWindow() protocol.ByteCount {
if c.sendWindow == 0 {
if c.streamID == 0 {
return c.connectionParameters.GetSendConnectionFlowControlWindow()
}
return c.connectionParameters.GetSendStreamFlowControlWindow()
}
return c.sendWindow
}
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
if newOffset > c.sendWindow {
c.sendWindow = newOffset
return true
}
return false
}
func (c *flowController) SendWindowSize() protocol.ByteCount {
sendWindow := c.getSendWindow()
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
return 0
}
return sendWindow - c.bytesSent
}
func (c *flowController) SendWindowOffset() protocol.ByteCount {
return c.getSendWindow()
}
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
// Should **only** be used for the stream-level FlowController
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
// It should only be treated as an error when resetting a stream
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
if byteOffset == c.highestReceived {
return 0, nil
}
if byteOffset > c.highestReceived {
increment := byteOffset - c.highestReceived
c.highestReceived = byteOffset
return increment, nil
}
return 0, ErrReceivedSmallerByteOffset
}
// IncrementHighestReceived adds an increment to the highestReceived value
// Should **only** be used for the connection-level FlowController
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) {
c.highestReceived += increment
}
func (c *flowController) AddBytesRead(n protocol.ByteCount) {
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window increment already works for the first WindowUpdate
if c.bytesRead == 0 {
c.lastWindowUpdateTime = time.Now()
}
c.bytesRead += n
}
// MaybeUpdateWindow updates the receive window, if necessary
// if the receive window increment is changed, the new value is returned, otherwise a 0
// the last return value is the new offset of the receive window
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
diff := c.receiveWindow - c.bytesRead
// Chromium implements the same threshold
if diff < (c.receiveWindowIncrement / 2) {
var newWindowIncrement protocol.ByteCount
oldWindowIncrement := c.receiveWindowIncrement
c.maybeAdjustWindowIncrement()
if c.receiveWindowIncrement != oldWindowIncrement {
newWindowIncrement = c.receiveWindowIncrement
}
c.lastWindowUpdateTime = time.Now()
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
return true, newWindowIncrement, c.receiveWindow
}
return false, 0, 0
}
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
func (c *flowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
// interval between the window updates is sufficiently large, no need to increase the increment
if timeSinceLastWindowUpdate >= 2*rtt {
return
}
oldWindowSize := c.receiveWindowIncrement
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
// debug log, if the window size was actually increased
if oldWindowSize < c.receiveWindowIncrement {
newWindowSize := c.receiveWindowIncrement / (1 << 10)
if c.streamID == 0 {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
} else {
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
}
}
}
// EnsureMinimumWindowIncrement sets a minimum window increment
// it is intended be used for the connection-level flow controller
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
if inc > c.receiveWindowIncrement {
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
}
}
func (c *flowController) CheckFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}

View file

@ -1,26 +0,0 @@
package flowcontrol
import "github.com/lucas-clemente/quic-go/protocol"
// WindowUpdate provides the data for WindowUpdateFrames.
type WindowUpdate struct {
StreamID protocol.StreamID
Offset protocol.ByteCount
}
// A FlowControlManager manages the flow control
type FlowControlManager interface {
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
RemoveStream(streamID protocol.StreamID)
// methods needed for receiving data
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
GetWindowUpdates() []WindowUpdate
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
// methods needed for sending data
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
RemainingConnectionWindowSize() protocol.ByteCount
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error)
}

View file

@ -1,9 +0,0 @@
package frames
import "github.com/lucas-clemente/quic-go/protocol"
// AckRange is an ACK range
type AckRange struct {
FirstPacketNumber protocol.PacketNumber
LastPacketNumber protocol.PacketNumber
}

View file

@ -1,44 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// A BlockedFrame in QUIC
type BlockedFrame struct {
StreamID protocol.StreamID
}
//Write writes a BlockedFrame frame
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x05)
utils.WriteUint32(b, uint32(f.StreamID))
return nil
}
// MinLength of a written frame
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4, nil
}
// ParseBlockedFrame parses a BLOCKED frame
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) {
frame := &BlockedFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
sid, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(sid)
return frame, nil
}

View file

@ -1,73 +0,0 @@
package frames
import (
"bytes"
"errors"
"io"
"math"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
// A ConnectionCloseFrame in QUIC
type ConnectionCloseFrame struct {
ErrorCode qerr.ErrorCode
ReasonPhrase string
}
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame
func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) {
frame := &ConnectionCloseFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
errorCode, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.ErrorCode = qerr.ErrorCode(errorCode)
reasonPhraseLen, err := utils.ReadUint16(r)
if err != nil {
return nil, err
}
if reasonPhraseLen > uint16(protocol.MaxPacketSize) {
return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long")
}
reasonPhrase := make([]byte, reasonPhraseLen)
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
return nil, err
}
frame.ReasonPhrase = string(reasonPhrase)
return frame, nil
}
// MinLength of a written frame
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
}
// Write writes an CONNECTION_CLOSE frame.
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x02)
utils.WriteUint32(b, uint32(f.ErrorCode))
if len(f.ReasonPhrase) > math.MaxUint16 {
return errors.New("ConnectionFrame: ReasonPhrase too long")
}
reasonPhraseLen := uint16(len(f.ReasonPhrase))
utils.WriteUint16(b, reasonPhraseLen)
b.WriteString(f.ReasonPhrase)
return nil
}

View file

@ -1,13 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/protocol"
)
// A Frame in QUIC
type Frame interface {
Write(b *bytes.Buffer, version protocol.VersionNumber) error
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
}

View file

@ -1,28 +0,0 @@
package frames
import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) {
if !utils.Debug() {
return
}
dir := "<-"
if sent {
dir = "->"
}
switch f := frame.(type) {
case *StreamFrame:
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *StopWaitingFrame:
if sent {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
} else {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
}
case *AckFrame:
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
default:
utils.Debugf("\t%s %#v", dir, frame)
}
}

View file

@ -1,59 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// A RstStreamFrame in QUIC
type RstStreamFrame struct {
StreamID protocol.StreamID
ErrorCode uint32
ByteOffset protocol.ByteCount
}
//Write writes a RST_STREAM frame
func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x01)
utils.WriteUint32(b, uint32(f.StreamID))
utils.WriteUint64(b, uint64(f.ByteOffset))
utils.WriteUint32(b, f.ErrorCode)
return nil
}
// MinLength of a written frame
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4 + 8 + 4, nil
}
// ParseRstStreamFrame parses a RST_STREAM frame
func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) {
frame := &RstStreamFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
sid, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(sid)
byteOffset, err := utils.ReadUint64(r)
if err != nil {
return nil, err
}
frame.ByteOffset = protocol.ByteCount(byteOffset)
frame.ErrorCode, err = utils.ReadUint32(r)
if err != nil {
return nil, err
}
return frame, nil
}

View file

@ -1,54 +0,0 @@
package frames
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// A WindowUpdateFrame in QUIC
type WindowUpdateFrame struct {
StreamID protocol.StreamID
ByteOffset protocol.ByteCount
}
//Write writes a RST_STREAM frame
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
typeByte := uint8(0x04)
b.WriteByte(typeByte)
utils.WriteUint32(b, uint32(f.StreamID))
utils.WriteUint64(b, uint64(f.ByteOffset))
return nil
}
// MinLength of a written frame
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
return 1 + 4 + 8, nil
}
// ParseWindowUpdateFrame parses a RST_STREAM frame
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) {
frame := &WindowUpdateFrame{}
// read the TypeByte
_, err := r.ReadByte()
if err != nil {
return nil, err
}
sid, err := utils.ReadUint32(r)
if err != nil {
return nil, err
}
frame.StreamID = protocol.StreamID(sid)
byteOffset, err := utils.ReadUint64(r)
if err != nil {
return nil, err
}
frame.ByteOffset = protocol.ByteCount(byteOffset)
return frame, nil
}

View file

@ -15,8 +15,8 @@ import (
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
@ -34,10 +34,10 @@ type client struct {
config *quic.Config
opts *roundTripperOpts
hostname string
encryptionLevel protocol.EncryptionLevel
handshakeErr error
dialOnce sync.Once
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
@ -51,8 +51,8 @@ type client struct {
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDTruncation: true,
KeepAlive: true,
RequestConnectionIDOmission: true,
KeepAlive: true,
}
// newClient creates a new client
@ -61,26 +61,31 @@ func newClient(
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
@ -90,9 +95,6 @@ func (c *client) dial() error {
if err != nil {
return err
}
if c.headerStream.StreamID() != 3 {
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
}
c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream()
return nil
@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var lastStream protocol.StreamID
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
utils.Debugf("Error handling header stream: %s", err)
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
for {
frame, err := h2framer.ReadFrame()
if err != nil {
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
break
}
lastStream = protocol.StreamID(frame.Header().StreamID)
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
break
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
break
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
break
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
}
responseChan <- rsp
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
close(c.headerErrored)
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response

View file

@ -13,8 +13,8 @@ import (
"golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
type requestWriter struct {

View file

@ -8,8 +8,8 @@ import (
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
func (w *responseWriter) Flush() {}
// TODO: Implement a functional CloseNotify method.
// This is a NOP. Use http.Request.Context
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher

View file

@ -41,6 +41,11 @@ type RoundTripper struct {
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser
}
@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client
}
return client, nil

View file

@ -7,14 +7,14 @@ import (
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
@ -50,6 +50,7 @@ type Server struct {
listenerMutex sync.Mutex
listener quic.Listener
closed bool
supportedVersionsAsString string
}
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
return errors.New("use of h2quic.Server without http.Server")
}
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
@ -122,29 +127,23 @@ func (s *Server) handleHeaderStream(session streamCreator) {
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
return
}
if stream.StreamID() != 3 {
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)
go func() {
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
return
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(err)
return
}
}()
}
}
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return err
}
req.RemoteAddr = session.RemoteAddr().String()
if utils.Debug() {
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else {
@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
return nil
}
var streamEnded bool
if h2headersFrame.StreamEnded() {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
reqBody := newRequestBody(dataStream)
req.Body = reqBody
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
// handleRequest should be as non-blocking as possible to minimize
// head-of-line blocking. Potentially blocking code is run in a separate
// goroutine, enabling handleRequest to return before the code is executed.
go func() {
streamEnded := h2headersFrame.StreamEnded()
if streamEnded {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
req = req.WithContext(dataStream.Context())
reqBody := newRequestBody(dataStream)
req.Body = reqBody
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
responseWriter.dataStream.Reset(nil)
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
}
responseWriter.dataStream.Close()
}
@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
func (s *Server) Close() error {
s.listenerMutex.Lock()
defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
}
if s.supportedVersionsAsString == "" {
for i, v := range protocol.SupportedVersions {
s.supportedVersionsAsString += strconv.Itoa(int(v))
if i != len(protocol.SupportedVersions)-1 {
s.supportedVersionsAsString += ","
}
var versions []string
for _, v := range protocol.SupportedVersions {
versions = append(versions, v.ToAltSvc())
}
s.supportedVersionsAsString = strings.Join(versions, ",")
}
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
}
defer tcpConn.Close()
tlsConn := tls.NewListener(tcpConn, config)
defer tlsConn.Close()
// Start the servers
httpServer := &http.Server{
Addr: addr,
@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
hErr := make(chan error)
qErr := make(chan error)
go func() {
hErr <- httpServer.Serve(tcpConn)
hErr <- httpServer.Serve(tlsConn)
}()
go func() {
qErr <- quicServer.Serve(udpConn)

View file

@ -1,265 +0,0 @@
package handshake
import (
"bytes"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
// ConnectionParametersManager negotiates and stores the connection parameters
// A ConnectionParametersManager can be used for a server as well as a client
// For the server:
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
// 2. call GetHelloMap to get the values to send in the SHLO
// For the client:
// 1. call GetHelloMap to get the values to send in a CHLO
// 2. call SetFromMap with the values received in the SHLO
type ConnectionParametersManager interface {
SetFromMap(map[Tag][]byte) error
GetHelloMap() (map[Tag][]byte, error)
GetSendStreamFlowControlWindow() protocol.ByteCount
GetSendConnectionFlowControlWindow() protocol.ByteCount
GetReceiveStreamFlowControlWindow() protocol.ByteCount
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount
GetReceiveConnectionFlowControlWindow() protocol.ByteCount
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount
GetMaxOutgoingStreams() uint32
GetMaxIncomingStreams() uint32
GetIdleConnectionStateLifetime() time.Duration
TruncateConnectionID() bool
}
type connectionParametersManager struct {
mutex sync.RWMutex
version protocol.VersionNumber
perspective protocol.Perspective
flowControlNegotiated bool
truncateConnectionID bool
maxStreamsPerConnection uint32
maxIncomingDynamicStreamsPerConnection uint32
idleConnectionStateLifetime time.Duration
sendStreamFlowControlWindow protocol.ByteCount
sendConnectionFlowControlWindow protocol.ByteCount
receiveStreamFlowControlWindow protocol.ByteCount
receiveConnectionFlowControlWindow protocol.ByteCount
maxReceiveStreamFlowControlWindow protocol.ByteCount
maxReceiveConnectionFlowControlWindow protocol.ByteCount
}
var _ ConnectionParametersManager = &connectionParametersManager{}
// ErrMalformedTag is returned when the tag value cannot be read
var (
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
)
// NewConnectionParamatersManager creates a new connection parameters manager
func NewConnectionParamatersManager(
pers protocol.Perspective, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
) ConnectionParametersManager {
h := &connectionParametersManager{
perspective: pers,
version: v,
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
}
if h.perspective == protocol.PerspectiveServer {
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
} else {
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
}
return h
}
// SetFromMap reads all params
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
h.mutex.Lock()
defer h.mutex.Unlock()
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.truncateConnectionID = (clientValue == 0)
}
if value, ok := params[TagMSPC]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue)
}
if value, ok := params[TagMIDS]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
}
if value, ok := params[TagICSL]; ok {
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second)
}
if value, ok := params[TagSFCW]; ok {
if h.flowControlNegotiated {
return ErrFlowControlRenegotiationNotSupported
}
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
}
if value, ok := params[TagCFCW]; ok {
if h.flowControlNegotiated {
return ErrFlowControlRenegotiationNotSupported
}
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return ErrMalformedTag
}
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
}
_, containsSFCW := params[TagSFCW]
_, containsCFCW := params[TagCFCW]
if containsCFCW || containsSFCW {
h.flowControlNegotiated = true
}
return nil
}
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 {
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection)
}
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 {
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection)
}
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration {
if h.perspective == protocol.PerspectiveServer {
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer)
}
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient)
}
// GetHelloMap gets all parameters needed for the Hello message
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
sfcw := bytes.NewBuffer([]byte{})
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow()))
cfcw := bytes.NewBuffer([]byte{})
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
mspc := bytes.NewBuffer([]byte{})
utils.WriteUint32(mspc, h.maxStreamsPerConnection)
mids := bytes.NewBuffer([]byte{})
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
icsl := bytes.NewBuffer([]byte{})
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
return map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMSPC: mspc.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}, nil
}
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.sendStreamFlowControlWindow
}
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.sendConnectionFlowControlWindow
}
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.receiveStreamFlowControlWindow
}
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
return h.maxReceiveStreamFlowControlWindow
}
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.receiveConnectionFlowControlWindow
}
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
return h.maxReceiveConnectionFlowControlWindow
}
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.maxIncomingDynamicStreamsPerConnection
}
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
h.mutex.RLock()
defer h.mutex.RUnlock()
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
}
// GetIdleConnectionStateLifetime gets the idle timeout
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.idleConnectionStateLifetime
}
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
func (h *connectionParametersManager) TruncateConnectionID() bool {
if h.perspective == protocol.PerspectiveClient {
return false
}
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.truncateConnectionID
}

View file

@ -1,24 +0,0 @@
package handshake
import "github.com/lucas-clemente/quic-go/protocol"
// Sealer seals a packet
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
// CryptoSetup is a crypto setup
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
RequestConnectionIDTruncation bool
}

View file

@ -1,100 +0,0 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/crypto"
)
const (
stkPrefixIP byte = iota
stkPrefixString
)
// An STK is a source address token
type STK struct {
RemoteAddr string
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// An STKGenerator generates STKs
type STKGenerator struct {
stkSource crypto.StkSource
}
// NewSTKGenerator initializes a new STKGenerator
func NewSTKGenerator() (*STKGenerator, error) {
stkSource, err := crypto.NewStkSource()
if err != nil {
return nil, err
}
return &STKGenerator{
stkSource: stkSource,
}, nil
}
// NewToken generates a new STK token for a given source address
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.stkSource.NewToken(data)
}
// DecodeToken decodes an STK token
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.stkSource.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &STK{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{stkPrefixIP}, udpAddr.IP...)
}
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the STK
func decodeRemoteAddr(data []byte) string {
// data will never be empty for an STK that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == stkPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View file

@ -1 +0,0 @@
package chrome

View file

@ -1 +0,0 @@
package gquic

View file

@ -1 +0,0 @@
package self

View file

@ -1,14 +1,12 @@
package quicproxy
import (
"bytes"
"net"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Connection is a UDP connection
@ -28,21 +26,43 @@ const (
DirectionIncoming Direction = iota
// DirectionOutgoing is the direction from the server to the client.
DirectionOutgoing
// DirectionBoth is both incoming and outgoing
DirectionBoth
)
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
case DirectionOutgoing:
return "outgoing"
case DirectionBoth:
return "both"
default:
panic("unknown direction")
}
}
func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth {
return true
}
return d == dir
}
// DropCallback is a callback that determines which packet gets dropped.
type DropCallback func(Direction, protocol.PacketNumber) bool
type DropCallback func(dir Direction, packetCount uint64) bool
// NoDropper doesn't drop packets.
var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool {
var NoDropper DropCallback = func(Direction, uint64) bool {
return false
}
// DelayCallback is a callback that determines how much delay to apply to a packet.
type DelayCallback func(Direction, protocol.PacketNumber) time.Duration
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
// NoDelay doesn't apply a delay.
var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration {
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
return 0
}
@ -62,6 +82,8 @@ type Opts struct {
type QuicProxy struct {
mutex sync.Mutex
version protocol.VersionNumber
conn *net.UDPConn
serverAddr *net.UDPAddr
@ -73,7 +95,10 @@ type QuicProxy struct {
}
// NewQuicProxy creates a new UDP proxy
func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
if opts == nil {
opts = &Opts{}
}
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return nil, err
@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
version: version,
}
go p.runProxy()
@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// LocalPort is the UDP port number the proxy is listening on.
func (p *QuicProxy) LocalPort() int {
return p.conn.LocalAddr().(*net.UDPAddr).Port
}
@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
// runProxy listens on the proxy address and handles incoming packets.
func (p *QuicProxy) runProxy() error {
for {
buffer := make([]byte, protocol.MaxPacketSize)
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
if err != nil {
return err
@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error {
}
p.mutex.Unlock()
atomic.AddUint64(&conn.incomingPacketCounter, 1)
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
r := bytes.NewReader(raw)
hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient)
if err != nil {
return err
}
if p.dropPacket(DirectionIncoming, hdr.PacketNumber) {
if p.dropPacket(DirectionIncoming, packetCount) {
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber)
delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 {
time.AfterFunc(delay, func() {
// TODO: handle error
@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error {
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
for {
buffer := make([]byte, protocol.MaxPacketSize)
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
// TODO: Switch back to using the public header once Chrome properly sets the type byte.
// r := bytes.NewReader(raw)
// , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer)
// if err != nil {
// return err
// }
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
v := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
packetNumber := protocol.PacketNumber(v)
if p.dropPacket(DirectionOutgoing, packetNumber) {
if p.dropPacket(DirectionOutgoing, packetCount) {
continue
}
delay := p.delayPacket(DirectionOutgoing, packetNumber)
delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 {
time.AfterFunc(delay, func() {
// TODO: handle error

View file

@ -27,7 +27,7 @@ var _ = BeforeEach(func() {
if len(logFileName) > 0 {
var err error
logFile, err = os.Create("./log.txt")
logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.SetLogLevel(utils.LogLevelDebug)

View file

@ -7,7 +7,9 @@ import (
"net/http"
"strconv"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
@ -23,8 +25,9 @@ var (
PRData = GeneratePRData(dataLen)
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
port string
server *h2quic.Server
stoppedServing chan struct{}
port string
)
func init() {
@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte {
return res
}
func StartQuicServer() {
// StartQuicServer starts a h2quic.Server.
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
func StartQuicServer(versions []protocol.VersionNumber) {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: &quic.Config{
Versions: versions,
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
@ -88,14 +96,18 @@ func StartQuicServer() {
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
stoppedServing = make(chan struct{})
go func() {
defer GinkgoRecover()
server.Serve(conn)
close(stoppedServing)
}()
}
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
}
func Port() string {

View file

@ -6,23 +6,55 @@ import (
"net"
"time"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams
type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
StreamID() protocol.StreamID
// Reset closes the stream with an error.
Reset(error)
// CancelWrite aborts sending on this stream.
// It must not be called after Close.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
@ -43,6 +75,41 @@ type Stream interface {
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
@ -64,53 +131,41 @@ type Session interface {
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
// The communication is encrypted, but not yet forward secure.
type NonFWSession interface {
Session
WaitUntilHandshakeComplete() error
}
// An STK is a Source Address token.
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
type STK struct {
// The remote address this token was issued for.
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
remoteAddr string
// The time that the STK was issued (resolution 1 second)
sentTime time.Time
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
}
// Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct {
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDTruncation bool
RequestConnectionIDOmission bool
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
// This value only applies after the handshake has completed.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
IdleTimeout time.Duration
// AcceptCookie determines if a Cookie is accepted.
// It is called with cookie = nil if the client didn't send an Cookie.
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
// This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow protocol.ByteCount
MaxReceiveStreamFlowControlWindow uint64
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
MaxReceiveConnectionFlowControlWindow uint64
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}

View file

@ -0,0 +1,48 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// SendingAllowed says if a packet can be sent.
// Sending packets might not be possible because:
// * we're congestion limited
// * we're tracking the maximum number of sent packets
SendingAllowed() bool
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
// It always returns a number greater or equal than 1.
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
// Note that the number of packets is only calculated based on the pacing algorithm.
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() (packet *Packet)
GetLeastUnacked() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm()
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View file

@ -3,29 +3,30 @@ package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
// +gen linkedlist
type Packet struct {
PacketNumber protocol.PacketNumber
Frames []frames.Frame
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
sendTime time.Time
}
// GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetFramesForRetransmission() []frames.Frame {
var fs []frames.Frame
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
var fs []wire.Frame
for _, frame := range p.Frames {
switch frame.(type) {
case *frames.AckFrame:
case *wire.AckFrame:
continue
case *frames.StopWaitingFrame:
case *wire.StopWaitingFrame:
continue
}
fs = append(fs, frame)

View file

@ -1,18 +1,15 @@
package ackhandler
import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
lowerLimit protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
@ -23,46 +20,45 @@ type receivedPacketHandler struct {
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *frames.AckFrame
lastAck *wire.AckFrame
version protocol.VersionNumber
}
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler() ReceivedPacketHandler {
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: protocol.AckSendDelay,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 {
return errInvalidPacketNumber
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
h.largestObservedReceivedTime = rcvTime
}
if packetNumber <= h.lowerLimit {
if packetNumber < h.ignoreBelow {
return nil
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
return nil
}
// SetLowerLimit sets a lower limit for acking packets.
// Packets with packet numbers smaller or equal than p will not be acked.
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
h.lowerLimit = p
h.packetHistory.DeleteUpTo(p)
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
}
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
@ -74,12 +70,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
h.ackQueued = true
}
// Always send an ack every 20 packets in order to allow the peer to discard
// information from the SentPacketManager and provide an RTT measurement.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
@ -87,7 +77,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
h.ackQueued = true
}
@ -96,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
}
}
}
@ -107,15 +97,15 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
}
}
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
}
ackRanges := h.packetHistory.GetAckRanges()
ack := &frames.AckFrame{
ack := &wire.AckFrame{
LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
LowestAcked: ackRanges[len(ackRanges)-1].First,
PacketReceivedTime: h.largestObservedReceivedTime,
}

View file

@ -1,9 +1,9 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
@ -12,21 +12,15 @@ import (
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList
// the map is used as a replacement for a set here. The bool is always supposed to be set to true
receivedPacketNumbers map[protocol.PacketNumber]bool
lowestInReceivedPacketNumbers protocol.PacketNumber
}
var (
errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets")
)
var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
// newReceivedPacketHistory creates a new received packet history
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: utils.NewPacketIntervalList(),
receivedPacketNumbers: make(map[protocol.PacketNumber]bool),
ranges: utils.NewPacketIntervalList(),
}
}
@ -36,12 +30,6 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
return errTooManyOutstandingReceivedAckRanges
}
if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets {
return errTooManyOutstandingReceivedPackets
}
h.receivedPacketNumbers[p] = true
if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return nil
@ -86,23 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
return nil
}
// DeleteUpTo deletes all entries up to (and including) p
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p <= h.lowestInReceivedPacketNumbers {
return
}
h.lowestInReceivedPacketNumbers = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if p >= el.Value.Start && p < el.Value.End {
for i := el.Value.Start; i <= p; i++ { // adjust start value of a range
delete(h.receivedPacketNumbers, i)
}
el.Value.Start = p + 1
} else if el.Value.End <= p { // delete a whole range
for i := el.Value.Start; i <= el.Value.End; i++ {
delete(h.receivedPacketNumbers, i)
}
if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
} else if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else { // no ranges affected. Nothing to do
return
@ -110,38 +95,27 @@ func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
}
}
// IsDuplicate determines if a packet should be regarded as a duplicate packet
// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed
func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool {
if p < h.lowestInReceivedPacketNumbers {
return true
}
_, ok := h.receivedPacketNumbers[p]
return ok
}
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
if h.ranges.Len() == 0 {
return nil
}
var ackRanges []frames.AckRange
ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges = append(ackRanges, frames.AckRange{FirstPacketNumber: el.Value.Start, LastPacketNumber: el.Value.End})
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
i++
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange {
ackRange := frames.AckRange{}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.FirstPacketNumber = r.Start
ackRange.LastPacketNumber = r.End
ackRange.First = r.Start
ackRange.Last = r.End
}
return ackRange
}

View file

@ -1,12 +1,10 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
)
import "github.com/lucas-clemente/quic-go/internal/wire"
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
res := make([]frames.Frame, 0, len(fs))
func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
res := make([]wire.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
@ -16,11 +14,11 @@ func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f frames.Frame) bool {
func IsFrameRetransmittable(f wire.Frame) bool {
switch f.(type) {
case *frames.StopWaitingFrame:
case *wire.StopWaitingFrame:
return false
case *frames.AckFrame:
case *wire.AckFrame:
return false
default:
return true
@ -28,7 +26,7 @@ func IsFrameRetransmittable(f frames.Frame) bool {
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []frames.Frame) bool {
func HasRetransmittableFrames(fs []wire.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true

View file

@ -3,12 +3,13 @@ package ackhandler
import (
"errors"
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
@ -16,33 +17,33 @@ const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// The default RTT used before an RTT sample is taken.
// Note: This constant is also defined in the congestion package.
defaultInitialRTT = 100 * time.Millisecond
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
var (
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
)
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
var ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
LargestAcked protocol.PacketNumber
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
packetHistory *PacketList
stopWaitingManager stopWaitingManager
@ -54,6 +55,10 @@ type sentPacketHandler struct {
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
@ -82,20 +87,27 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
}
}
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
if f := h.packetHistory.Front(); f != nil {
return f.Value.PacketNumber - 1
return f.Value.PacketNumber
}
return h.LargestAcked
return h.largestAcked + 1
}
func (h *sentPacketHandler) SetHandshakeComplete() {
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
queue = append(queue, packet)
}
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
if packet.PacketNumber <= h.lastSentPacketNumber {
return errPacketNumberNotIncreasing
}
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
return ErrTooManyTrackedSentPackets
return errors.New("Too many outstanding non-acked and non-retransmitted packets")
}
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
@ -106,14 +118,22 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
}
}
h.lastSentPacketNumber = packet.PacketNumber
now := time.Now()
h.lastSentPacketNumber = packet.PacketNumber
var largestAcked protocol.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
largestAcked = ackFrame.LargestAcked
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
packet.sendTime = now
packet.largestAcked = largestAcked
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
}
@ -126,29 +146,32 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
isRetransmittable,
)
h.updateLossDetectionAlarm()
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
h.updateLossDetectionAlarm(now)
return nil
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, rcvTime time.Time) error {
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
if ackFrame.LargestAcked > h.lastSentPacketNumber {
return errAckForUnsentPacket
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
// duplicate or out-of-order ACK
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
if withPacketNumber <= h.largestReceivedPacketWithAck {
return ErrDuplicateOrOutOfOrderAck
}
h.largestReceivedPacketWithAck = withPacketNumber
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
if ackFrame.LargestAcked < h.lowestUnacked() {
return nil
}
h.LargestAcked = ackFrame.LargestAcked
h.largestAcked = ackFrame.LargestAcked
if h.skippedPacketsAcked(ackFrame) {
return ErrAckForSkippedPacket
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
@ -164,13 +187,22 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
if len(ackedPackets) > 0 {
for _, p := range ackedPackets {
if encLevel < p.Value.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.Value.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
}
h.onPacketAcked(p)
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
}
}
h.detectLostPackets()
h.updateLossDetectionAlarm()
h.detectLostPackets(rcvTime)
h.updateLossDetectionAlarm(rcvTime)
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
@ -178,7 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
return nil
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) {
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
var ackedPackets []*PacketElement
ackRangeIndex := 0
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
@ -197,14 +233,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for packetNumber > ackRange.LastPacketNumber && ackRangeIndex < len(ackFrame.AckRanges)-1 {
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range
if packetNumber > ackRange.LastPacketNumber {
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber)
if packetNumber >= ackRange.First { // packet i contained in ACK range
if packetNumber > ackRange.Last {
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
}
ackedPackets = append(ackedPackets, el)
}
@ -212,7 +248,6 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame
ackedPackets = append(ackedPackets, el)
}
}
return ackedPackets, nil
}
@ -220,7 +255,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber == largestAcked {
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
return true
}
// Packets are sorted by number, so we can stop searching
@ -231,27 +266,27 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
// Cancel the alarm if no packets are outstanding
if h.packetHistory.Len() == 0 {
h.alarm = time.Time{}
return
}
// TODO(#496): Handle handshake packets separately
// TODO(#497): TLP
if !h.lossTime.IsZero() {
if !h.handshakeComplete {
h.alarm = now.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO
h.alarm = time.Now().Add(h.computeRTOTimeout())
h.alarm = now.Add(h.computeRTOTimeout())
}
}
func (h *sentPacketHandler) detectLostPackets() {
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
h.lossTime = time.Time{}
now := time.Now()
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
@ -260,11 +295,11 @@ func (h *sentPacketHandler) detectLostPackets() {
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
packet := el.Value
if packet.PacketNumber > h.LargestAcked {
if packet.PacketNumber > h.largestAcked {
break
}
timeSinceSent := now.Sub(packet.SendTime)
timeSinceSent := now.Sub(packet.sendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, el)
} else if h.lossTime.IsZero() {
@ -282,18 +317,22 @@ func (h *sentPacketHandler) detectLostPackets() {
}
func (h *sentPacketHandler) OnAlarm() {
// TODO(#496): Handle handshake packets separately
now := time.Now()
// TODO(#497): TLP
if !h.lossTime.IsZero() {
if !h.handshakeComplete {
h.queueHandshakePacketsForRetransmission()
h.handshakeCount++
} else if !h.lossTime.IsZero() {
// Early retransmit or time loss detection
h.detectLostPackets()
h.detectLostPackets(now)
} else {
// RTO
h.retransmitOldestTwoPackets()
h.rtoCount++
}
h.updateLossDetectionAlarm()
h.updateLossDetectionAlarm(now)
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
@ -303,6 +342,7 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
h.bytesInFlight -= packetElement.Value.Length
h.rtoCount = 0
h.handshakeCount = 0
// TODO(#497): h.tlpCount = 0
h.packetHistory.Remove(packetElement)
}
@ -320,20 +360,19 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
}
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
return h.largestInOrderAcked() + 1
return h.lowestUnacked()
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendingAllowed() bool {
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
cwnd := h.congestion.GetCongestionWindow()
congestionLimited := h.bytesInFlight > cwnd
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
if congestionLimited {
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
h.bytesInFlight,
h.congestion.GetCongestionWindow())
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
// Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
@ -342,6 +381,18 @@ func (h *sentPacketHandler) SendingAllowed() bool {
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.nextPacketSendTime
}
func (h *sentPacketHandler) ShouldSendNumPackets() int {
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
if delay == 0 || delay > protocol.MinPacingDelay {
return 1
}
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
if p := h.packetHistory.Front(); p != nil {
h.queueRTO(p)
@ -363,6 +414,18 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) {
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
var handshakePackets []*PacketElement
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, el)
}
}
for _, el := range handshakePackets {
h.queuePacketForRetransmission(el)
}
}
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
packet := &packetElement.Value
h.bytesInFlight -= packet.Length
@ -371,6 +434,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketEl
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := 2 * h.rttStats.SmoothedRTT()
if duration == 0 {
duration = 2 * defaultInitialRTT
}
duration = utils.MaxDuration(duration, minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
rto := h.congestion.RetransmissionDelay()
if rto == 0 {
@ -382,7 +456,7 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool {
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
@ -392,10 +466,10 @@ func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lioa := h.largestInOrderAcked()
lowestUnacked := h.lowestUnacked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p <= lioa {
if p < lowestUnacked {
deleteIndex = i + 1
}
}

View file

@ -1,8 +1,8 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
@ -10,10 +10,10 @@ type stopWaitingManager struct {
largestLeastUnackedSent protocol.PacketNumber
nextLeastUnacked protocol.PacketNumber
lastStopWaitingFrame *frames.StopWaitingFrame
lastStopWaitingFrame *wire.StopWaitingFrame
}
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
if force {
return s.lastStopWaitingFrame
@ -22,14 +22,14 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaiting
}
s.largestLeastUnackedSent = s.nextLeastUnacked
swf := &frames.StopWaitingFrame{
swf := &wire.StopWaitingFrame{
LeastUnacked: s.nextLeastUnacked,
}
s.lastStopWaitingFrame = swf
return swf
}
func (s *stopWaitingManager) ReceivedAck(ack *frames.AckFrame) {
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
if ack.LargestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = ack.LargestAcked + 1
}

View file

@ -3,7 +3,7 @@ package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Bandwidth of a connection

View file

@ -4,8 +4,8 @@ import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// This cubic implementation is based on the one found in Chromiums's QUIC

View file

@ -3,8 +3,8 @@ package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
const (
@ -76,15 +76,19 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
}
}
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
if c.InRecovery() {
// PRR is used when in recovery.
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 {
return 0
}
}
if c.GetCongestionWindow() > bytesInFlight {
return 0
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS)
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
delay = delay * 8 / 5
}
return utils.InfDuration
return delay
}
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {

View file

@ -3,8 +3,8 @@ package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// Note(pwestin): the magic clamping numbers come from the original code in

View file

@ -3,12 +3,12 @@ package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A SendAlgorithm performs congestion control and calculates the congestion window
type SendAlgorithm interface {
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
GetCongestionWindow() protocol.ByteCount
MaybeExitSlowStart()

View file

@ -3,8 +3,8 @@ package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
)
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937

View file

@ -7,6 +7,7 @@ import (
)
const (
// Note: This constant is also defined in the ackhandler package.
initialRTTus = 100 * 1000
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
@ -97,10 +98,10 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
r.updateRecentMinRTT(sendDelta, now)
// Correct for ackDelay if information received from the peer results in a
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
// measure for smoothedRTT.
// an RTT sample at least as large as minRTT. Otherwise, only use the
// sendDelta.
sample := sendDelta
if sample > ackDelay {
if sample-r.minRTT >= ackDelay {
sample -= ackDelay
}
r.latestRTT = sample

View file

@ -1,6 +1,6 @@
package congestion
import "github.com/lucas-clemente/quic-go/protocol"
import "github.com/lucas-clemente/quic-go/internal/protocol"
type connectionStats struct {
slowstartPacketsLost protocol.PacketNumber

View file

@ -1,9 +1,10 @@
package crypto
import "github.com/lucas-clemente/quic-go/protocol"
import "github.com/lucas-clemente/quic-go/internal/protocol"
// An AEAD implements QUIC's authenticated encryption and associated data
type AEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
Overhead() int
}

View file

@ -0,0 +1,72 @@
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM12 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM12{}
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM12{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}
func (aead *aeadAESGCM12) Overhead() int {
return aead.encrypter.Overhead()
}

View file

@ -0,0 +1,74 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM{}
const ivLen = 12
// NewAEADAESGCM creates a AEAD using AES-GCM
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
// the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce
if len(otherIV) != ivLen || len(myIV) != ivLen {
return nil, errors.New("AES-GCM: expected 12 byte IVs")
}
encrypterCipher, err := aes.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := cipher.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := cipher.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
nonce := make([]byte, ivLen)
binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber))
for i := 0; i < ivLen; i++ {
nonce[i] ^= iv[i]
}
return nonce
}
func (aead *aeadAESGCM) Overhead() int {
return aead.encrypter.Overhead()
}

View file

@ -5,7 +5,7 @@ import (
"hash/fnv"
"github.com/hashicorp/golang-lru"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (

View file

@ -51,10 +51,10 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
res.WriteByte(uint8(e.t))
switch e.t {
case entryCached:
utils.WriteUint64(res, e.h)
utils.LittleEndian.WriteUint64(res, e.h)
case entryCommon:
utils.WriteUint64(res, e.h)
utils.WriteUint32(res, e.i)
utils.LittleEndian.WriteUint64(res, e.h)
utils.LittleEndian.WriteUint32(res, e.i)
case entryCompressed:
totalUncompressedLen += 4 + len(chain[i])
}
@ -67,7 +67,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
}
utils.WriteUint32(res, uint32(totalUncompressedLen))
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
for i, e := range entries {
if e.t != entryCompressed {
@ -115,11 +115,11 @@ func decompressChain(data []byte) ([][]byte, error) {
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.ReadUint64(r)
e.h, err = utils.LittleEndian.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.ReadUint32(r)
e.i, err = utils.LittleEndian.ReadUint32(r)
if err != nil {
return nil, err
}
@ -146,7 +146,7 @@ func decompressChain(data []byte) ([][]byte, error) {
}
if hasCompressedCerts {
uncompressedLength, err := utils.ReadUint32(r)
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err

View file

@ -18,6 +18,7 @@ type CertManager interface {
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
GetChain() []*x509.Certificate
}
type certManager struct {
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
return nil
}
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}

View file

@ -4,11 +4,12 @@ package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/aead/chacha20"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadChacha20Poly1305 struct {
@ -45,9 +46,16 @@ func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV
}
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View file

@ -0,0 +1,49 @@
package crypto
import (
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
)
// A TLSExporter gets the negotiated ciphersuite and computes exporter
type TLSExporter interface {
GetCipherSuite() mint.CipherSuiteParams
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string
if pers == protocol.PerspectiveClient {
myLabel = clientExporterLabel
otherLabel = serverExporterLabel
} else {
myLabel = serverExporterLabel
otherLabel = clientExporterLabel
}
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
if err != nil {
return nil, err
}
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
if err != nil {
return nil, err
}
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
cs := tls.GetCipherSuite()
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
if err != nil {
return nil, nil, err
}
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
return key, iv, nil
}

View file

@ -5,8 +5,8 @@ import (
"crypto/sha256"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"golang.org/x/crypto/hkdf"
)
@ -20,8 +20,8 @@ import (
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
// }
// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
var swap bool
if pers == protocol.PerspectiveClient {
swap = true
@ -30,7 +30,7 @@ func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID pr
if err != nil {
return nil, err
}
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
}
// deriveKeys derives the keys and the IVs
@ -42,7 +42,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
} else {
info.Write([]byte("QUIC key expansion\x00"))
}
utils.WriteUint64(&info, uint64(connID))
utils.BigEndian.WriteUint64(&info, uint64(connID))
info.Write(chlo)
info.Write(scfg)
info.Write(cert)

View file

@ -0,0 +1,11 @@
package crypto
import "github.com/lucas-clemente/quic-go/internal/protocol"
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
if v.UsesTLS() {
return newNullAEADAESGCM(connID, p)
}
return &nullAEADFNV128a{perspective: p}, nil
}

View file

@ -0,0 +1,44 @@
package crypto
import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myIV := computeNullAEADKeyAndIV(mySecret)
otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret)
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
connID := make([]byte, 8)
binary.BigEndian.PutUint64(connID, uint64(connectionID))
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
return
}
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
return
}

View file

@ -5,27 +5,18 @@ import (
"errors"
"github.com/lucas-clemente/fnv128a"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// nullAEAD handles not-yet encrypted packets
type nullAEAD struct {
type nullAEADFNV128a struct {
perspective protocol.Perspective
version protocol.VersionNumber
}
var _ AEAD = &nullAEAD{}
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD {
return &nullAEAD{
perspective: p,
version: v,
}
}
var _ AEAD = &nullAEADFNV128a{}
// Open and verify the ciphertext
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
if len(src) < 12 {
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
}
@ -33,12 +24,10 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass
hash := fnv128a.New()
hash.Write(associatedData)
hash.Write(src[12:])
if n.version >= protocol.Version37 {
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Client"))
} else {
hash.Write([]byte("Server"))
}
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Client"))
} else {
hash.Write([]byte("Server"))
}
testHigh, testLow := hash.Sum128()
@ -52,7 +41,7 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass
}
// Seal writes hash and ciphertext to the buffer
func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
if cap(dst) < 12+len(src) {
dst = make([]byte, 12+len(src))
} else {
@ -63,12 +52,10 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass
hash.Write(associatedData)
hash.Write(src)
if n.version >= protocol.Version37 {
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Server"))
} else {
hash.Write([]byte("Client"))
}
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Server"))
} else {
hash.Write([]byte("Client"))
}
high, low := hash.Sum128()
@ -78,3 +65,7 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
return dst
}
func (n *nullAEADFNV128a) Overhead() int {
return 12
}

View file

@ -0,0 +1,108 @@
package flowcontrol
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type baseFlowController struct {
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
// for receiving data
mutex sync.RWMutex
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowSize protocol.ByteCount
maxReceiveWindowSize protocol.ByteCount
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
if offset > c.sendWindow {
c.sendWindow = offset
}
}
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
}
return c.sendWindow - c.bytesSent
}
func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 {
c.startNewAutoTuningEpoch()
}
c.bytesRead += n
}
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
}
// getWindowUpdate updates the receive window, if necessary
// it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
}
c.startNewAutoTuningEpoch()
}
func (c *baseFlowController) startNewAutoTuningEpoch() {
c.epochStartTime = time.Now()
c.epochStartOffset = c.bytesRead
}
func (c *baseFlowController) checkFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}

View file

@ -0,0 +1,83 @@
package flowcontrol
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type connectionFlowController struct {
lastBlockedAt protocol.ByteCount
baseFlowController
}
var _ ConnectionFlowController = &connectionFlowController{}
// NewConnectionFlowController gets a new flow controller for the connection
// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0.
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
rttStats *congestion.RTTStats,
) ConnectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
},
}
}
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
return c.baseFlowController.sendWindowSize()
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.highestReceived += increment
if c.checkFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow))
}
return nil
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset
}
// EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock()
if inc > c.receiveWindowSize {
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
c.startNewAutoTuningEpoch()
}
c.mutex.Unlock()
}

View file

@ -0,0 +1,42 @@
package flowcontrol
import "github.com/lucas-clemente/quic-go/internal/protocol"
type flowController interface {
// for sending
SendWindowSize() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for sending
IsBlocked() (bool, protocol.ByteCount)
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// HasWindowUpdate says if it is necessary to update the window
HasWindowUpdate() bool
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
// for sending
IsNewlyBlocked() (bool, protocol.ByteCount)
}
type connectionFlowControllerI interface {
ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending
EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving
IncrementHighestReceived(protocol.ByteCount) error
}

View file

@ -0,0 +1,147 @@
package flowcontrol
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type streamFlowController struct {
baseFlowController
streamID protocol.StreamID
connection connectionFlowControllerI
contributesToConnection bool // does the stream contribute to connection level flow control
receivedFinalOffset bool
}
var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
contributesToConnection bool,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
rttStats *congestion.RTTStats,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI),
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
},
}
}
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// when receiving a final offset, check that this final offset is consistent with a final offset we might have received earlier
if final && c.receivedFinalOffset && byteOffset != c.highestReceived {
return qerr.Error(qerr.StreamDataAfterTermination, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, byteOffset))
}
// if we already received a final offset, check that the offset in the STREAM frames is below the final offset
if c.receivedFinalOffset && byteOffset > c.highestReceived {
return qerr.StreamDataAfterTermination
}
if final {
c.receivedFinalOffset = true
}
if byteOffset == c.highestReceived {
return nil
}
if byteOffset <= c.highestReceived {
// a STREAM_FRAME with a higher offset was received before.
if final {
// If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream
return qerr.StreamDataAfterTermination
}
// this is a reordered STREAM_FRAME
return nil
}
increment := byteOffset - c.highestReceived
c.highestReceived = byteOffset
if c.checkFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
}
if c.contributesToConnection {
return c.connection.IncrementHighestReceived(increment)
}
return nil
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
c.baseFlowController.AddBytesRead(n)
if c.contributesToConnection {
c.connection.AddBytesRead(n)
}
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
if c.contributesToConnection {
c.connection.AddBytesSent(n)
}
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize())
}
return window
}
// IsBlocked says if it is blocked by stream-level flow control.
// If it is blocked, the offset is returned.
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 {
return false, 0
}
return true, c.sendWindow
}
func (c *streamFlowController) HasWindowUpdate() bool {
c.mutex.Lock()
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
c.mutex.Unlock()
return hasWindowUpdate
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
if c.receivedFinalOffset {
c.mutex.Unlock()
return 0
}
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
}
c.mutex.Unlock()
return offset
}

View file

@ -0,0 +1,101 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/bifurcation/mint"
)
const (
cookiePrefixIP byte = iota
cookiePrefixString
)
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
type Cookie struct {
RemoteAddr string
// The time that the STK was issued (resolution 1 second)
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// A CookieGenerator generates Cookies
type CookieGenerator struct {
cookieProtector mint.CookieProtector
}
// NewCookieGenerator initializes a new CookieGenerator
func NewCookieGenerator() (*CookieGenerator, error) {
cookieProtector, err := mint.NewDefaultCookieProtector()
if err != nil {
return nil, err
}
return &CookieGenerator{
cookieProtector: cookieProtector,
}, nil
}
// NewToken generates a new Cookie for a given source address
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.cookieProtector.NewToken(data)
}
// DecodeToken decodes a Cookie
func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
// if the client didn't send any Cookie, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.cookieProtector.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &Cookie{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{cookiePrefixIP}, udpAddr.IP...)
}
return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the Cookie
func decodeRemoteAddr(data []byte) string {
// data will never be empty for a Cookie that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == cookiePrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View file

@ -0,0 +1,43 @@
package handshake
import (
"net"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type CookieHandler struct {
callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator
}
var _ mint.CookieHandler = &CookieHandler{}
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &CookieHandler{
callback: callback,
cookieGenerator: cookieGenerator,
}, nil
}
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) {
return nil, nil
}
return h.cookieGenerator.NewToken(conn.RemoteAddr())
}
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token)
if err != nil {
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
return false
}
return h.callback(conn.RemoteAddr(), data)
}

View file

@ -11,9 +11,9 @@ import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
@ -23,6 +23,7 @@ type cryptoSetupClient struct {
hostname string
connID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
negotiatedVersions []protocol.VersionNumber
cryptoStream io.ReadWriter
@ -42,17 +43,18 @@ type cryptoSetupClient struct {
clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation KeyDerivationFunction
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
receivedSecurePacket bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
aeadChanged chan<- protocol.EncryptionLevel
params *TransportParameters
connectionParameters ConnectionParametersManager
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
params *TransportParameters
}
var _ CryptoSetup = &cryptoSetupClient{}
@ -65,36 +67,42 @@ var (
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
hostname string,
connID protocol.ConnectionID,
version protocol.VersionNumber,
cryptoStream io.ReadWriter,
tlsConfig *tls.Config,
connectionParameters ConnectionParametersManager,
aeadChanged chan<- protocol.EncryptionLevel,
params *TransportParameters,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupClient{
hostname: hostname,
connID: connID,
version: version,
cryptoStream: cryptoStream,
certManager: crypto.NewCertManager(tlsConfig),
connectionParameters: connectionParameters,
keyDerivation: crypto.DeriveKeysAESGCM,
keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
aeadChanged: aeadChanged,
negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte),
params: params,
cryptoStream: cryptoStream,
hostname: hostname,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConfig),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte),
}, nil
}
func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage)
errorChan := make(chan error)
errorChan := make(chan error, 1)
go func() {
for {
@ -141,15 +149,21 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
utils.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
err = h.handleREJMessage(message.Data)
if err := h.handleREJMessage(message.Data); err != nil {
return err
}
case TagSHLO:
err = h.handleSHLOMessage(message.Data)
params, err := h.handleSHLOMessage(message.Data)
if err != nil {
return err
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
default:
return qerr.InvalidCryptoMessageType
}
if err != nil {
return err
}
}
}
@ -215,12 +229,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
return nil
}
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.receivedSecurePacket {
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
}
if sno, ok := cryptoData[TagSNO]; ok {
@ -229,22 +243,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
serverPubs, ok := cryptoData[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
verTag, ok := cryptoData[TagVER]
if !ok {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
}
if !h.validateVersionList(verTag) {
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
nonce := append(h.nonc, h.sno...)
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
if err != nil {
return err
return nil, err
}
leafCert := h.certManager.GetLeafCert()
@ -261,39 +275,32 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
protocol.PerspectiveClient,
)
if err != nil {
return err
return nil, err
}
err = h.connectionParameters.SetFromMap(cryptoData)
params, err := readHelloMap(cryptoData)
if err != nil {
return qerr.InvalidCryptoMessageParameter
return nil, qerr.InvalidCryptoMessageParameter
}
h.aeadChanged <- protocol.EncryptionForwardSecure
close(h.aeadChanged)
return nil
return params, nil
}
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
if len(h.negotiatedVersions) == 0 {
numNegotiatedVersions := len(h.negotiatedVersions)
if numNegotiatedVersions == 0 {
return true
}
if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) {
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
return false
}
b := bytes.NewReader(verTags)
for _, negotiatedVersion := range h.negotiatedVersions {
verTag, err := utils.ReadUint32(b)
for i := 0; i < numNegotiatedVersions; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil { // should never occur, since the length was already checked
return false
}
ver := protocol.VersionTagToNumber(verTag)
if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) {
ver = protocol.VersionUnsupported
}
if ver != negotiatedVersion {
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
return false
}
}
@ -333,16 +340,16 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.sealForwardSecure
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure
return protocol.EncryptionSecure, h.secureAEAD
} else {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
return protocol.EncryptionUnencrypted, h.nullAEAD
}
}
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
@ -351,33 +358,21 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.sealUnencrypted, nil
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD")
}
return h.sealSecure, nil
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
}
return h.sealForwardSecure, nil
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupClient) DiversificationNonce() []byte {
panic("not needed for cryptoSetupClient")
}
@ -386,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
@ -413,15 +417,11 @@ func (h *cryptoSetupClient) sendCHLO() error {
}
h.lastSentCHLO = b.Bytes()
return nil
}
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags, err := h.connectionParameters.GetHelloMap()
if err != nil {
return nil, err
}
tags := h.params.getHelloMap()
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
@ -431,12 +431,9 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
}
versionTag := make([]byte, 4)
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
tags[TagVER] = versionTag
if h.params.RequestConnectionIDTruncation {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
if len(h.stk) > 0 {
tags[TagSTK] = h.stk
}
@ -470,7 +467,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
for _, tag := range tags {
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
}
paddingSize := protocol.ClientHelloMinimumSize - size
paddingSize := protocol.MinClientHelloSize - size
if paddingSize > 0 {
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
}
@ -508,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if err != nil {
return err
}
h.aeadChanged <- protocol.EncryptionSecure
h.handshakeEvent <- struct{}{}
}
return nil
}

Some files were not shown because too many files have changed in this diff Show more