mirror of
https://github.com/caddyserver/caddy.git
synced 2025-02-24 00:38:53 +01:00
vendor: Updated quic-go for QUIC 39+ (#1968)
* Updated lucas-clemente/quic-go for QUIC 39+ support * Update quic-go to latest
This commit is contained in:
parent
faa5248d1f
commit
1201492222
229 changed files with 26903 additions and 4254 deletions
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
21
vendor/github.com/bifurcation/mint/LICENSE.md
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Richard Barnes
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
99
vendor/github.com/bifurcation/mint/alert.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mint
|
||||
|
||||
import "strconv"
|
||||
|
||||
type Alert uint8
|
||||
|
||||
const (
|
||||
// alert level
|
||||
AlertLevelWarning = 1
|
||||
AlertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
AlertCloseNotify Alert = 0
|
||||
AlertUnexpectedMessage Alert = 10
|
||||
AlertBadRecordMAC Alert = 20
|
||||
AlertDecryptionFailed Alert = 21
|
||||
AlertRecordOverflow Alert = 22
|
||||
AlertDecompressionFailure Alert = 30
|
||||
AlertHandshakeFailure Alert = 40
|
||||
AlertBadCertificate Alert = 42
|
||||
AlertUnsupportedCertificate Alert = 43
|
||||
AlertCertificateRevoked Alert = 44
|
||||
AlertCertificateExpired Alert = 45
|
||||
AlertCertificateUnknown Alert = 46
|
||||
AlertIllegalParameter Alert = 47
|
||||
AlertUnknownCA Alert = 48
|
||||
AlertAccessDenied Alert = 49
|
||||
AlertDecodeError Alert = 50
|
||||
AlertDecryptError Alert = 51
|
||||
AlertProtocolVersion Alert = 70
|
||||
AlertInsufficientSecurity Alert = 71
|
||||
AlertInternalError Alert = 80
|
||||
AlertInappropriateFallback Alert = 86
|
||||
AlertUserCanceled Alert = 90
|
||||
AlertNoRenegotiation Alert = 100
|
||||
AlertMissingExtension Alert = 109
|
||||
AlertUnsupportedExtension Alert = 110
|
||||
AlertCertificateUnobtainable Alert = 111
|
||||
AlertUnrecognizedName Alert = 112
|
||||
AlertBadCertificateStatsResponse Alert = 113
|
||||
AlertBadCertificateHashValue Alert = 114
|
||||
AlertUnknownPSKIdentity Alert = 115
|
||||
AlertNoApplicationProtocol Alert = 120
|
||||
AlertWouldBlock Alert = 254
|
||||
AlertNoAlert Alert = 255
|
||||
)
|
||||
|
||||
var alertText = map[Alert]string{
|
||||
AlertCloseNotify: "close notify",
|
||||
AlertUnexpectedMessage: "unexpected message",
|
||||
AlertBadRecordMAC: "bad record MAC",
|
||||
AlertDecryptionFailed: "decryption failed",
|
||||
AlertRecordOverflow: "record overflow",
|
||||
AlertDecompressionFailure: "decompression failure",
|
||||
AlertHandshakeFailure: "handshake failure",
|
||||
AlertBadCertificate: "bad certificate",
|
||||
AlertUnsupportedCertificate: "unsupported certificate",
|
||||
AlertCertificateRevoked: "revoked certificate",
|
||||
AlertCertificateExpired: "expired certificate",
|
||||
AlertCertificateUnknown: "unknown certificate",
|
||||
AlertIllegalParameter: "illegal parameter",
|
||||
AlertUnknownCA: "unknown certificate authority",
|
||||
AlertAccessDenied: "access denied",
|
||||
AlertDecodeError: "error decoding message",
|
||||
AlertDecryptError: "error decrypting message",
|
||||
AlertProtocolVersion: "protocol version not supported",
|
||||
AlertInsufficientSecurity: "insufficient security level",
|
||||
AlertInternalError: "internal error",
|
||||
AlertInappropriateFallback: "inappropriate fallback",
|
||||
AlertUserCanceled: "user canceled",
|
||||
AlertMissingExtension: "missing extension",
|
||||
AlertUnsupportedExtension: "unsupported extension",
|
||||
AlertCertificateUnobtainable: "certificate unobtainable",
|
||||
AlertUnrecognizedName: "unrecognized name",
|
||||
AlertBadCertificateStatsResponse: "bad certificate status response",
|
||||
AlertBadCertificateHashValue: "bad certificate hash value",
|
||||
AlertUnknownPSKIdentity: "unknown PSK identity",
|
||||
AlertNoApplicationProtocol: "no application protocol",
|
||||
AlertNoRenegotiation: "no renegotiation",
|
||||
AlertWouldBlock: "would have blocked",
|
||||
AlertNoAlert: "no alert",
|
||||
}
|
||||
|
||||
func (e Alert) String() string {
|
||||
s, ok := alertText[e]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
return "alert(" + strconv.Itoa(int(e)) + ")"
|
||||
}
|
||||
|
||||
func (e Alert) Error() string {
|
||||
return e.String()
|
||||
}
|
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
42
vendor/github.com/bifurcation/mint/bin/mint-client-https/main.go
generated
vendored
Normal file
|
@ -0,0 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var url string
|
||||
|
||||
func main() {
|
||||
url := flag.String("url", "https://localhost:4430", "URL to send request")
|
||||
flag.Parse()
|
||||
mintdial := func(network, addr string) (net.Conn, error) {
|
||||
return mint.Dial(network, addr, nil)
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
DialTLS: mintdial,
|
||||
DisableCompression: true,
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
|
||||
response, err := client.Get(*url)
|
||||
if err != nil {
|
||||
fmt.Println("err:", err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("%s\n", string(contents))
|
||||
}
|
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
37
vendor/github.com/bifurcation/mint/bin/mint-client/main.go
generated
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var addr string
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&addr, "addr", "localhost:4430", "port")
|
||||
flag.Parse()
|
||||
|
||||
conn, err := mint.Dial("tcp", addr, nil)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println("TLS handshake failed:", err)
|
||||
return
|
||||
}
|
||||
|
||||
request := "GET / HTTP/1.0\r\n\r\n"
|
||||
conn.Write([]byte(request))
|
||||
|
||||
response := ""
|
||||
buffer := make([]byte, 1024)
|
||||
var read int
|
||||
for err == nil {
|
||||
read, err = conn.Read(buffer)
|
||||
fmt.Println(" ~~ read: ", read)
|
||||
response += string(buffer)
|
||||
}
|
||||
fmt.Println("err:", err)
|
||||
fmt.Println("Received from server:")
|
||||
fmt.Println(response)
|
||||
}
|
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
226
vendor/github.com/bifurcation/mint/bin/mint-server-https/main.go
generated
vendored
Normal file
|
@ -0,0 +1,226 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
port string
|
||||
serverName string
|
||||
certFile string
|
||||
keyFile string
|
||||
responseFile string
|
||||
h2 bool
|
||||
sendTickets bool
|
||||
)
|
||||
|
||||
type responder []byte
|
||||
|
||||
func (rsp responder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(rsp)
|
||||
}
|
||||
|
||||
// ParsePrivateKeyDER parses a PKCS #1, PKCS #8, or elliptic curve
|
||||
// PEM-encoded private key.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) {
|
||||
keyDER, _ := pem.Decode(keyPEM)
|
||||
if keyDER == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
generalKey, err = x509.ParseECPrivateKey(keyDER.Bytes)
|
||||
if err != nil {
|
||||
// We don't include the actual error into
|
||||
// the final error. The reason might be
|
||||
// we don't want to leak any info about
|
||||
// the private key.
|
||||
return nil, fmt.Errorf("No successful private key decoder")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch generalKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return generalKey.(*rsa.PrivateKey), nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return generalKey.(*ecdsa.PrivateKey), nil
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
return nil, fmt.Errorf("Should be unreachable")
|
||||
}
|
||||
|
||||
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
|
||||
// either a raw x509 certificate or a PKCS #7 structure possibly containing
|
||||
// multiple certificates, from the top of certsPEM, which itself may
|
||||
// contain multiple PEM encoded certificate objects.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
|
||||
block, rest := pem.Decode(certsPEM)
|
||||
if block == nil {
|
||||
return nil, rest, nil
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
var certs = []*x509.Certificate{cert}
|
||||
return certs, rest, err
|
||||
}
|
||||
|
||||
// ParseCertificatesPEM parses a sequence of PEM-encoded certificate and returns them,
|
||||
// can handle PEM encoded PKCS #7 structures.
|
||||
// XXX: Inlined from github.com/cloudflare/cfssl because of build issues with that module
|
||||
func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
|
||||
var certs []*x509.Certificate
|
||||
var err error
|
||||
certsPEM = bytes.TrimSpace(certsPEM)
|
||||
for len(certsPEM) > 0 {
|
||||
var cert []*x509.Certificate
|
||||
cert, certsPEM, err = ParseOneCertificateFromPEM(certsPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if cert == nil {
|
||||
break
|
||||
}
|
||||
|
||||
certs = append(certs, cert...)
|
||||
}
|
||||
if len(certsPEM) > 0 {
|
||||
return nil, fmt.Errorf("Trailing PEM data")
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&port, "port", "4430", "port")
|
||||
flag.StringVar(&serverName, "host", "example.com", "hostname")
|
||||
flag.StringVar(&certFile, "cert", "", "certificate chain in PEM or DER")
|
||||
flag.StringVar(&keyFile, "key", "", "private key in PEM format")
|
||||
flag.StringVar(&responseFile, "response", "", "file to serve")
|
||||
flag.BoolVar(&h2, "h2", false, "whether to use HTTP/2 (exclusively)")
|
||||
flag.BoolVar(&sendTickets, "tickets", true, "whether to send session tickets")
|
||||
flag.Parse()
|
||||
|
||||
var certChain []*x509.Certificate
|
||||
var priv crypto.Signer
|
||||
var response []byte
|
||||
var err error
|
||||
|
||||
// Load the key and certificate chain
|
||||
if certFile != "" {
|
||||
certs, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
} else {
|
||||
certChain, err = ParseCertificatesPEM(certs)
|
||||
if err != nil {
|
||||
certChain, err = x509.ParseCertificates(certs)
|
||||
if err != nil {
|
||||
log.Fatalf("Error parsing certificates: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if keyFile != "" {
|
||||
keyPEM, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
} else {
|
||||
priv, err = ParsePrivateKeyPEM(keyPEM)
|
||||
if priv == nil || err != nil {
|
||||
log.Fatalf("Error parsing private key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
|
||||
// Load response file
|
||||
if responseFile != "" {
|
||||
log.Printf("Loading response file: %v", responseFile)
|
||||
response, err = ioutil.ReadFile(responseFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
} else {
|
||||
response = []byte("Welcome to the TLS 1.3 zone!")
|
||||
}
|
||||
handler := responder(response)
|
||||
|
||||
config := mint.Config{
|
||||
SendSessionTickets: true,
|
||||
ServerName: serverName,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
|
||||
if h2 {
|
||||
config.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
config.SendSessionTickets = sendTickets
|
||||
|
||||
if certChain != nil && priv != nil {
|
||||
log.Printf("Loading cert: %v key: %v", certFile, keyFile)
|
||||
config.Certificates = []*mint.Certificate{
|
||||
{
|
||||
Chain: certChain,
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
config.Init(false)
|
||||
|
||||
service := "0.0.0.0:" + port
|
||||
srv := &http.Server{Handler: handler}
|
||||
|
||||
log.Printf("Listening on port %v", port)
|
||||
// Need the inner loop here because the h1 server errors on a dropped connection
|
||||
// Need the outer loop here because the h2 server is per-connection
|
||||
for {
|
||||
listener, err := mint.Listen("tcp", service, &config)
|
||||
if err != nil {
|
||||
log.Printf("Listen Error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !h2 {
|
||||
alert := srv.Serve(listener)
|
||||
if alert != mint.AlertNoAlert {
|
||||
log.Printf("Serve Error: %v", err)
|
||||
}
|
||||
} else {
|
||||
srv2 := new(http2.Server)
|
||||
opts := &http2.ServeConnOpts{
|
||||
Handler: handler,
|
||||
BaseConfig: srv,
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go srv2.ServeConn(conn, opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
65
vendor/github.com/bifurcation/mint/bin/mint-server/main.go
generated
vendored
Normal file
|
@ -0,0 +1,65 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
var port string
|
||||
|
||||
func main() {
|
||||
var config mint.Config
|
||||
config.SendSessionTickets = true
|
||||
config.ServerName = "localhost"
|
||||
config.Init(false)
|
||||
|
||||
flag.StringVar(&port, "port", "4430", "port")
|
||||
flag.Parse()
|
||||
|
||||
service := "0.0.0.0:" + port
|
||||
listener, err := mint.Listen("tcp", service, &config)
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("server: listen: %s", err)
|
||||
}
|
||||
log.Print("server: listening")
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("server: accept: %s", err)
|
||||
break
|
||||
}
|
||||
defer conn.Close()
|
||||
log.Printf("server: accepted from %s", conn.RemoteAddr())
|
||||
go handleClient(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func handleClient(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 10)
|
||||
for {
|
||||
log.Print("server: conn: waiting")
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != nil {
|
||||
log.Printf("server: conn: read: %s", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
n, err = conn.Write([]byte("hello world"))
|
||||
log.Printf("server: conn: wrote %d bytes", n)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("server: write: %s", err)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
log.Println("server: conn: closed")
|
||||
}
|
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
942
vendor/github.com/bifurcation/mint/client-state-machine.go
generated
vendored
Normal file
|
@ -0,0 +1,942 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"hash"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client State Machine
|
||||
//
|
||||
// START <----+
|
||||
// Send ClientHello | | Recv HelloRetryRequest
|
||||
// / v |
|
||||
// | WAIT_SH ---+
|
||||
// Can | | Recv ServerHello
|
||||
// send | V
|
||||
// early | WAIT_EE
|
||||
// data | | Recv EncryptedExtensions
|
||||
// | +--------+--------+
|
||||
// | Using | | Using certificate
|
||||
// | PSK | v
|
||||
// | | WAIT_CERT_CR
|
||||
// | | Recv | | Recv CertificateRequest
|
||||
// | | Certificate | v
|
||||
// | | | WAIT_CERT
|
||||
// | | | | Recv Certificate
|
||||
// | | v v
|
||||
// | | WAIT_CV
|
||||
// | | | Recv CertificateVerify
|
||||
// | +> WAIT_FINISHED <+
|
||||
// | | Recv Finished
|
||||
// \ |
|
||||
// | [Send EndOfEarlyData]
|
||||
// | [Send Certificate [+ CertificateVerify]]
|
||||
// | Send Finished
|
||||
// Can send v
|
||||
// app data --> CONNECTED
|
||||
// after
|
||||
// here
|
||||
//
|
||||
// State Instructions
|
||||
// START Send(CH); [RekeyOut; SendEarlyData]
|
||||
// WAIT_SH Send(CH) || RekeyIn
|
||||
// WAIT_EE {}
|
||||
// WAIT_CERT_CR {}
|
||||
// WAIT_CERT {}
|
||||
// WAIT_CV {}
|
||||
// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut;
|
||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||
|
||||
type ClientStateStart struct {
|
||||
Caps Capabilities
|
||||
Opts ConnectionOptions
|
||||
Params ConnectionParameters
|
||||
|
||||
cookie []byte
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ClientStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Unexpected non-nil message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// key_shares
|
||||
offeredDH := map[NamedGroup][]byte{}
|
||||
ks := KeyShareExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
Shares: make([]KeyShareEntry, len(state.Caps.Groups)),
|
||||
}
|
||||
for i, group := range state.Caps.Groups {
|
||||
pub, priv, err := newKeyShare(group)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
ks.Shares[i].Group = group
|
||||
ks.Shares[i].KeyExchange = pub
|
||||
offeredDH[group] = priv
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "opts: %+v", state.Opts)
|
||||
|
||||
// supported_versions, supported_groups, signature_algorithms, server_name
|
||||
sv := SupportedVersionsExtension{Versions: []uint16{supportedVersion}}
|
||||
sni := ServerNameExtension(state.Opts.ServerName)
|
||||
sg := SupportedGroupsExtension{Groups: state.Caps.Groups}
|
||||
sa := SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||
|
||||
state.Params.ServerName = state.Opts.ServerName
|
||||
|
||||
// Application Layer Protocol Negotiation
|
||||
var alpn *ALPNExtension
|
||||
if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) {
|
||||
alpn = &ALPNExtension{Protocols: state.Opts.NextProtos}
|
||||
}
|
||||
|
||||
// Construct base ClientHello
|
||||
ch := &ClientHelloBody{
|
||||
CipherSuites: state.Caps.CipherSuites,
|
||||
}
|
||||
_, err := prng.Read(ch.Random[:])
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} {
|
||||
err := ch.Extensions.Add(ext)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
// XXX: These optional extensions can't be folded into the above because Go
|
||||
// interface-typed values are never reported as nil
|
||||
if alpn != nil {
|
||||
err := ch.Extensions.Add(alpn)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.cookie != nil {
|
||||
err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PSK and EarlyData just before transmitting, so that we can
|
||||
// calculate the PSK binder value
|
||||
var psk *PreSharedKeyExtension
|
||||
var ed *EarlyDataExtension
|
||||
var offeredPSK PreSharedKey
|
||||
var earlyHash crypto.Hash
|
||||
var earlySecret []byte
|
||||
var clientEarlyTrafficKeys keySet
|
||||
var clientHello *HandshakeMessage
|
||||
if key, ok := state.Caps.PSKs.Get(state.Opts.ServerName); ok {
|
||||
offeredPSK = key
|
||||
|
||||
// Narrow ciphersuites to ones that match PSK hash
|
||||
params, ok := cipherSuiteMap[key.CipherSuite]
|
||||
if !ok {
|
||||
logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite")
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
compatibleSuites := []CipherSuite{}
|
||||
for _, suite := range ch.CipherSuites {
|
||||
if cipherSuiteMap[suite].Hash == params.Hash {
|
||||
compatibleSuites = append(compatibleSuites, suite)
|
||||
}
|
||||
}
|
||||
ch.CipherSuites = compatibleSuites
|
||||
|
||||
// Signal early data if we're going to do it
|
||||
if len(state.Opts.EarlyData) > 0 {
|
||||
state.Params.ClientSendingEarlyData = true
|
||||
ed = &EarlyDataExtension{}
|
||||
err = ch.Extensions.Add(ed)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error adding early data extension: %v", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Signal supported PSK key exchange modes
|
||||
if len(state.Caps.PSKModes) == 0 {
|
||||
logf(logTypeHandshake, "PSK selected, but no PSKModes")
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
kem := &PSKKeyExchangeModesExtension{KEModes: state.Caps.PSKModes}
|
||||
err = ch.Extensions.Add(kem)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Add the shim PSK extension to the ClientHello
|
||||
logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity)
|
||||
psk = &PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
Identities: []PSKIdentity{
|
||||
{
|
||||
Identity: key.Identity,
|
||||
ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd,
|
||||
},
|
||||
},
|
||||
Binders: []PSKBinderEntry{
|
||||
// Note: Stub to get the length fields right
|
||||
{Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())},
|
||||
},
|
||||
}
|
||||
ch.Extensions.Add(psk)
|
||||
|
||||
// Compute the binder key
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
earlyHash = params.Hash
|
||||
earlySecret = HkdfExtract(params.Hash, zero, key.Key)
|
||||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||
|
||||
binderLabel := labelExternalBinder
|
||||
if key.IsResumption {
|
||||
binderLabel = labelResumptionBinder
|
||||
}
|
||||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||
logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey)
|
||||
|
||||
// Compute the binder value
|
||||
trunc, err := ch.Truncated()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
truncHash := params.Hash.New()
|
||||
truncHash.Write(trunc)
|
||||
|
||||
binder := computeFinishedData(params, binderKey, truncHash.Sum(nil))
|
||||
|
||||
// Replace the PSK extension
|
||||
psk.Binders[0].Binder = binder
|
||||
ch.Extensions.Add(psk)
|
||||
|
||||
// If we got here, the earlier marshal succeeded (in ch.Truncated()), so
|
||||
// this one should too.
|
||||
clientHello, _ = HandshakeMessageFromBody(ch)
|
||||
|
||||
// Compute early traffic keys
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
chHash := h.Sum(nil)
|
||||
|
||||
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
|
||||
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
|
||||
} else if len(state.Opts.EarlyData) > 0 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
|
||||
return nil, nil, AlertInternalError
|
||||
} else {
|
||||
clientHello, err = HandshakeMessageFromBody(ch)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]")
|
||||
nextState := ClientStateWaitSH{
|
||||
Caps: state.Caps,
|
||||
Opts: state.Opts,
|
||||
Params: state.Params,
|
||||
OfferedDH: offeredDH,
|
||||
OfferedPSK: offeredPSK,
|
||||
|
||||
earlySecret: earlySecret,
|
||||
earlyHash: earlyHash,
|
||||
|
||||
firstClientHello: state.firstClientHello,
|
||||
helloRetryRequest: state.helloRetryRequest,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{clientHello},
|
||||
}
|
||||
if state.Params.ClientSendingEarlyData {
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyOut{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||
SendEarlyData{},
|
||||
}...)
|
||||
}
|
||||
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitSH struct {
|
||||
Caps Capabilities
|
||||
Opts ConnectionOptions
|
||||
Params ConnectionParameters
|
||||
OfferedDH map[NamedGroup][]byte
|
||||
OfferedPSK PreSharedKey
|
||||
PSK []byte
|
||||
|
||||
earlySecret []byte
|
||||
earlyHash crypto.Hash
|
||||
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
clientHello *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ClientStateWaitSH) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected nil message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *HelloRetryRequestBody:
|
||||
hrr := body
|
||||
|
||||
if state.helloRetryRequest != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Received a second HelloRetryRequest")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Check that the version sent by the server is the one we support
|
||||
if hrr.Version != supportedVersion {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", hrr.Version)
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
// Check that the server provided a supported ciphersuite
|
||||
supportedCipherSuite := false
|
||||
for _, suite := range state.Caps.CipherSuites {
|
||||
supportedCipherSuite = supportedCipherSuite || (suite == hrr.CipherSuite)
|
||||
}
|
||||
if !supportedCipherSuite {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", hrr.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Narrow the supported ciphersuites to the server-provided one
|
||||
state.Caps.CipherSuites = []CipherSuite{hrr.CipherSuite}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// The only thing we know how to respond to in an HRR is the Cookie
|
||||
// extension, so if there is either no Cookie extension or anything other
|
||||
// than a Cookie extension, we have to fail.
|
||||
serverCookie := new(CookieExtension)
|
||||
foundCookie := hrr.Extensions.Find(serverCookie)
|
||||
if !foundCookie || len(hrr.Extensions) != 1 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions))
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
// Hash the body into a pseudo-message
|
||||
// XXX: Ignoring some errors here
|
||||
params := cipherSuiteMap[hrr.CipherSuite]
|
||||
h := params.Hash.New()
|
||||
h.Write(state.clientHello.Marshal())
|
||||
firstClientHello := &HandshakeMessage{
|
||||
msgType: HandshakeTypeMessageHash,
|
||||
body: h.Sum(nil),
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
|
||||
return ClientStateStart{
|
||||
Caps: state.Caps,
|
||||
Opts: state.Opts,
|
||||
cookie: serverCookie.Cookie,
|
||||
firstClientHello: firstClientHello,
|
||||
helloRetryRequest: hm,
|
||||
}.Next(nil)
|
||||
|
||||
case *ServerHelloBody:
|
||||
sh := body
|
||||
|
||||
// Check that the version sent by the server is the one we support
|
||||
if sh.Version != supportedVersion {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported version [%v]", sh.Version)
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
// Check that the server provided a supported ciphersuite
|
||||
supportedCipherSuite := false
|
||||
for _, suite := range state.Caps.CipherSuites {
|
||||
supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite)
|
||||
}
|
||||
if !supportedCipherSuite {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Do PSK or key agreement depending on extensions
|
||||
serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello}
|
||||
serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello}
|
||||
|
||||
foundPSK := sh.Extensions.Find(&serverPSK)
|
||||
foundKeyShare := sh.Extensions.Find(&serverKeyShare)
|
||||
|
||||
if foundPSK && (serverPSK.SelectedIdentity == 0) {
|
||||
state.Params.UsingPSK = true
|
||||
}
|
||||
|
||||
var dhSecret []byte
|
||||
if foundKeyShare {
|
||||
sks := serverKeyShare.Shares[0]
|
||||
priv, ok := state.OfferedDH[sks.Group]
|
||||
if !ok {
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group")
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
state.Params.UsingDH = true
|
||||
dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv)
|
||||
}
|
||||
|
||||
suite := sh.CipherSuite
|
||||
state.Params.CipherSuite = suite
|
||||
|
||||
params, ok := cipherSuiteMap[suite]
|
||||
if !ok {
|
||||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Start up the handshake hash
|
||||
handshakeHash := params.Hash.New()
|
||||
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||
handshakeHash.Write(state.clientHello.Marshal())
|
||||
handshakeHash.Write(hm.Marshal())
|
||||
|
||||
// Compute handshake secrets
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
var earlySecret []byte
|
||||
if state.Params.UsingPSK {
|
||||
if params.Hash != state.earlyHash {
|
||||
logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]",
|
||||
state.earlyHash, suite, params.Hash)
|
||||
}
|
||||
|
||||
earlySecret = state.earlySecret
|
||||
} else {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||
}
|
||||
|
||||
if dhSecret == nil {
|
||||
dhSecret = zero
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
h2 := handshakeHash.Sum(nil)
|
||||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret)
|
||||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||
|
||||
logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret)
|
||||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||
|
||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
|
||||
nextState := ClientStateWaitEE{
|
||||
Caps: state.Caps,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
certificates: state.Caps.Certificates,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: serverHandshakeTrafficSecret,
|
||||
}
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message [%d]", hm.msgType)
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
type ClientStateWaitEE struct {
|
||||
Caps Capabilities
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
certificates []*Certificate
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitEE) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
ee := EncryptedExtensionsBody{}
|
||||
_, err := ee.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
serverALPN := ALPNExtension{}
|
||||
serverEarlyData := EarlyDataExtension{}
|
||||
|
||||
gotALPN := ee.Extensions.Find(&serverALPN)
|
||||
state.Params.UsingEarlyData = ee.Extensions.Find(&serverEarlyData)
|
||||
|
||||
if gotALPN && len(serverALPN.Protocols) > 0 {
|
||||
state.Params.NextProto = serverALPN.Protocols[0]
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
if state.Params.UsingPSK {
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
|
||||
nextState := ClientStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
|
||||
nextState := ClientStateWaitCertCR{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitCertCR struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
certificates []*Certificate
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCertCR) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *CertificateBody:
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]")
|
||||
nextState := ClientStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificate: body,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
|
||||
case *CertificateRequestBody:
|
||||
// A certificate request in the handshake should have a zero-length context
|
||||
if len(body.CertificateRequestContext) > 0 {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err)
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
state.Params.UsingClientAuth = true
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]")
|
||||
nextState := ClientStateWaitCert{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificateRequest: body,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
type ClientStateWaitCert struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
cert := &CertificateBody{}
|
||||
_, err := cert.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]")
|
||||
nextState := ClientStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificate: cert,
|
||||
serverCertificateRequest: state.serverCertificateRequest,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitCV struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificate *CertificateBody
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
certVerify := CertificateVerifyBody{}
|
||||
_, err := certVerify.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey
|
||||
if err := certVerify.Verify(serverPublicKey, hcv); err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
if state.AuthCertificate != nil {
|
||||
err := state.AuthCertificate(state.serverCertificate.CertificateList)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate")
|
||||
return nil, nil, AlertBadCertificate
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] WARNING: No verification of server certificate")
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]")
|
||||
nextState := ClientStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
certificates: state.certificates,
|
||||
serverCertificateRequest: state.serverCertificateRequest,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ClientStateWaitFinished struct {
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
handshakeHash hash.Hash
|
||||
|
||||
certificates []*Certificate
|
||||
serverCertificateRequest *CertificateRequestBody
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
serverHandshakeTrafficSecret []byte
|
||||
}
|
||||
|
||||
func (state ClientStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Verify server's Finished
|
||||
h3 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||
|
||||
serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3)
|
||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||
|
||||
fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)}
|
||||
_, err := fin.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
if !bytes.Equal(fin.VerifyData, serverFinishedData) {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]",
|
||||
fin.VerifyData, serverFinishedData)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Update the handshake hash with the Finished
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal())
|
||||
h4 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4)
|
||||
|
||||
// Compute traffic secrets and keys
|
||||
clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||
serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||
|
||||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret)
|
||||
serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret)
|
||||
|
||||
exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4)
|
||||
logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||
|
||||
// Assemble client's second flight
|
||||
toSend := []HandshakeAction{}
|
||||
|
||||
if state.Params.UsingEarlyData {
|
||||
// Note: We only send EOED if the server is actually going to use the early
|
||||
// data. Otherwise, it will never see it, and the transcripts will
|
||||
// mismatch.
|
||||
// EOED marshal is infallible
|
||||
eoedm, _ := HandshakeMessageFromBody(&EndOfEarlyDataBody{})
|
||||
toSend = append(toSend, SendHandshakeMessage{eoedm})
|
||||
state.handshakeHash.Write(eoedm.Marshal())
|
||||
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
|
||||
}
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||
toSend = append(toSend, RekeyOut{Label: "handshake", KeySet: clientHandshakeKeys})
|
||||
|
||||
if state.Params.UsingClientAuth {
|
||||
// Extract constraints from certicateRequest
|
||||
schemes := SignatureAlgorithmsExtension{}
|
||||
gotSchemes := state.serverCertificateRequest.Extensions.Find(&schemes)
|
||||
if !gotSchemes {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||
return nil, nil, AlertIllegalParameter
|
||||
}
|
||||
|
||||
// Select a certificate
|
||||
cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates)
|
||||
if err != nil {
|
||||
// XXX: Signal this to the application layer?
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err)
|
||||
|
||||
certificate := &CertificateBody{}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
state.handshakeHash.Write(certm.Marshal())
|
||||
} else {
|
||||
// Create and send Certificate, CertificateVerify
|
||||
certificate := &CertificateBody{
|
||||
CertificateList: make([]CertificateEntry, len(cert.Chain)),
|
||||
}
|
||||
for i, entry := range cert.Chain {
|
||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||
}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
state.handshakeHash.Write(certm.Marshal())
|
||||
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
certificateVerify := &CertificateVerifyBody{Algorithm: certScheme}
|
||||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash)
|
||||
|
||||
err = certificateVerify.Sign(cert.PrivateKey, hcv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||
state.handshakeHash.Write(certvm.Marshal())
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the client's Finished message
|
||||
h5 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||
|
||||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||
|
||||
fin = &FinishedBody{
|
||||
VerifyDataLen: len(clientFinishedData),
|
||||
VerifyData: clientFinishedData,
|
||||
}
|
||||
finm, err := HandshakeMessageFromBody(fin)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Compute the resumption secret
|
||||
state.handshakeHash.Write(finm.Marshal())
|
||||
h6 := state.handshakeHash.Sum(nil)
|
||||
|
||||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
SendHandshakeMessage{finm},
|
||||
RekeyIn{Label: "application", KeySet: serverTrafficKeys},
|
||||
RekeyOut{Label: "application", KeySet: clientTrafficKeys},
|
||||
}...)
|
||||
|
||||
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
|
||||
nextState := StateConnected{
|
||||
Params: state.Params,
|
||||
isClient: true,
|
||||
cryptoParams: state.cryptoParams,
|
||||
resumptionSecret: resumptionSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
152
vendor/github.com/bifurcation/mint/common.go
generated
vendored
Normal file
|
@ -0,0 +1,152 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedVersion uint16 = 0x7f15 // draft-21
|
||||
|
||||
// Flags for some minor compat issues
|
||||
allowWrongVersionNumber = true
|
||||
allowPKCS1 = true
|
||||
)
|
||||
|
||||
// enum {...} ContentType;
|
||||
type RecordType byte
|
||||
|
||||
const (
|
||||
RecordTypeAlert RecordType = 21
|
||||
RecordTypeHandshake RecordType = 22
|
||||
RecordTypeApplicationData RecordType = 23
|
||||
)
|
||||
|
||||
// enum {...} HandshakeType;
|
||||
type HandshakeType byte
|
||||
|
||||
const (
|
||||
// Omitted: *_RESERVED
|
||||
HandshakeTypeClientHello HandshakeType = 1
|
||||
HandshakeTypeServerHello HandshakeType = 2
|
||||
HandshakeTypeNewSessionTicket HandshakeType = 4
|
||||
HandshakeTypeEndOfEarlyData HandshakeType = 5
|
||||
HandshakeTypeHelloRetryRequest HandshakeType = 6
|
||||
HandshakeTypeEncryptedExtensions HandshakeType = 8
|
||||
HandshakeTypeCertificate HandshakeType = 11
|
||||
HandshakeTypeCertificateRequest HandshakeType = 13
|
||||
HandshakeTypeCertificateVerify HandshakeType = 15
|
||||
HandshakeTypeServerConfiguration HandshakeType = 17
|
||||
HandshakeTypeFinished HandshakeType = 20
|
||||
HandshakeTypeKeyUpdate HandshakeType = 24
|
||||
HandshakeTypeMessageHash HandshakeType = 254
|
||||
)
|
||||
|
||||
// uint8 CipherSuite[2];
|
||||
type CipherSuite uint16
|
||||
|
||||
const (
|
||||
// XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero
|
||||
// value for this type so that we can detect when a field is set.
|
||||
CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000
|
||||
TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301
|
||||
TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302
|
||||
TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303
|
||||
TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304
|
||||
TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305
|
||||
)
|
||||
|
||||
func (c CipherSuite) String() string {
|
||||
switch c {
|
||||
case CIPHER_SUITE_UNKNOWN:
|
||||
return "unknown"
|
||||
case TLS_AES_128_GCM_SHA256:
|
||||
return "TLS_AES_128_GCM_SHA256"
|
||||
case TLS_AES_256_GCM_SHA384:
|
||||
return "TLS_AES_256_GCM_SHA384"
|
||||
case TLS_CHACHA20_POLY1305_SHA256:
|
||||
return "TLS_CHACHA20_POLY1305_SHA256"
|
||||
case TLS_AES_128_CCM_SHA256:
|
||||
return "TLS_AES_128_CCM_SHA256"
|
||||
case TLS_AES_256_CCM_8_SHA256:
|
||||
return "TLS_AES_256_CCM_8_SHA256"
|
||||
}
|
||||
// cannot use %x here, since it calls String(), leading to infinite recursion
|
||||
return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16))
|
||||
}
|
||||
|
||||
// enum {...} SignatureScheme
|
||||
type SignatureScheme uint16
|
||||
|
||||
const (
|
||||
// RSASSA-PKCS1-v1_5 algorithms
|
||||
RSA_PKCS1_SHA1 SignatureScheme = 0x0201
|
||||
RSA_PKCS1_SHA256 SignatureScheme = 0x0401
|
||||
RSA_PKCS1_SHA384 SignatureScheme = 0x0501
|
||||
RSA_PKCS1_SHA512 SignatureScheme = 0x0601
|
||||
// ECDSA algorithms
|
||||
ECDSA_P256_SHA256 SignatureScheme = 0x0403
|
||||
ECDSA_P384_SHA384 SignatureScheme = 0x0503
|
||||
ECDSA_P521_SHA512 SignatureScheme = 0x0603
|
||||
// RSASSA-PSS algorithms
|
||||
RSA_PSS_SHA256 SignatureScheme = 0x0804
|
||||
RSA_PSS_SHA384 SignatureScheme = 0x0805
|
||||
RSA_PSS_SHA512 SignatureScheme = 0x0806
|
||||
// EdDSA algorithms
|
||||
Ed25519 SignatureScheme = 0x0807
|
||||
Ed448 SignatureScheme = 0x0808
|
||||
)
|
||||
|
||||
// enum {...} ExtensionType
|
||||
type ExtensionType uint16
|
||||
|
||||
const (
|
||||
ExtensionTypeServerName ExtensionType = 0
|
||||
ExtensionTypeSupportedGroups ExtensionType = 10
|
||||
ExtensionTypeSignatureAlgorithms ExtensionType = 13
|
||||
ExtensionTypeALPN ExtensionType = 16
|
||||
ExtensionTypeKeyShare ExtensionType = 40
|
||||
ExtensionTypePreSharedKey ExtensionType = 41
|
||||
ExtensionTypeEarlyData ExtensionType = 42
|
||||
ExtensionTypeSupportedVersions ExtensionType = 43
|
||||
ExtensionTypeCookie ExtensionType = 44
|
||||
ExtensionTypePSKKeyExchangeModes ExtensionType = 45
|
||||
ExtensionTypeTicketEarlyDataInfo ExtensionType = 46
|
||||
)
|
||||
|
||||
// enum {...} NamedGroup
|
||||
type NamedGroup uint16
|
||||
|
||||
const (
|
||||
// Elliptic Curve Groups.
|
||||
P256 NamedGroup = 23
|
||||
P384 NamedGroup = 24
|
||||
P521 NamedGroup = 25
|
||||
// ECDH functions.
|
||||
X25519 NamedGroup = 29
|
||||
X448 NamedGroup = 30
|
||||
// Finite field groups.
|
||||
FFDHE2048 NamedGroup = 256
|
||||
FFDHE3072 NamedGroup = 257
|
||||
FFDHE4096 NamedGroup = 258
|
||||
FFDHE6144 NamedGroup = 259
|
||||
FFDHE8192 NamedGroup = 260
|
||||
)
|
||||
|
||||
// enum {...} PskKeyExchangeMode;
|
||||
type PSKKeyExchangeMode uint8
|
||||
|
||||
const (
|
||||
PSKModeKE PSKKeyExchangeMode = 0
|
||||
PSKModeDHEKE PSKKeyExchangeMode = 1
|
||||
)
|
||||
|
||||
// enum {
|
||||
// update_not_requested(0), update_requested(1), (255)
|
||||
// } KeyUpdateRequest;
|
||||
type KeyUpdateRequest uint8
|
||||
|
||||
const (
|
||||
KeyUpdateNotRequested KeyUpdateRequest = 0
|
||||
KeyUpdateRequested KeyUpdateRequest = 1
|
||||
)
|
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
819
vendor/github.com/bifurcation/mint/conn.go
generated
vendored
Normal file
|
@ -0,0 +1,819 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var WouldBlock = fmt.Errorf("Would have blocked")
|
||||
|
||||
type Certificate struct {
|
||||
Chain []*x509.Certificate
|
||||
PrivateKey crypto.Signer
|
||||
}
|
||||
|
||||
type PreSharedKey struct {
|
||||
CipherSuite CipherSuite
|
||||
IsResumption bool
|
||||
Identity []byte
|
||||
Key []byte
|
||||
NextProto string
|
||||
ReceivedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
TicketAgeAdd uint32
|
||||
}
|
||||
|
||||
type PreSharedKeyCache interface {
|
||||
Get(string) (PreSharedKey, bool)
|
||||
Put(string, PreSharedKey)
|
||||
Size() int
|
||||
}
|
||||
|
||||
type PSKMapCache map[string]PreSharedKey
|
||||
|
||||
// A CookieHandler does two things:
|
||||
// - generates a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest
|
||||
// - validates this byte string echoed by the client in the ClientHello
|
||||
type CookieHandler interface {
|
||||
Generate(*Conn) ([]byte, error)
|
||||
Validate(*Conn, []byte) bool
|
||||
}
|
||||
|
||||
func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) {
|
||||
psk, ok = cache[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (cache *PSKMapCache) Put(key string, psk PreSharedKey) {
|
||||
(*cache)[key] = psk
|
||||
}
|
||||
|
||||
func (cache PSKMapCache) Size() int {
|
||||
return len(cache)
|
||||
}
|
||||
|
||||
// Config is the struct used to pass configuration settings to a TLS client or
|
||||
// server instance. The settings for client and server are pretty different,
|
||||
// but we just throw them all in here.
|
||||
type Config struct {
|
||||
// Client fields
|
||||
ServerName string
|
||||
|
||||
// Server fields
|
||||
SendSessionTickets bool
|
||||
TicketLifetime uint32
|
||||
TicketLen int
|
||||
EarlyDataLifetime uint32
|
||||
AllowEarlyData bool
|
||||
// Require the client to echo a cookie.
|
||||
RequireCookie bool
|
||||
// If cookies are required and no CookieHandler is set, a default cookie handler is used.
|
||||
// The default cookie handler uses 32 random bytes as a cookie.
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
|
||||
// Shared fields
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
NextProtos []string
|
||||
PSKs PreSharedKeyCache
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
NonBlocking bool
|
||||
|
||||
// The same config object can be shared among different connections, so it
|
||||
// needs its own mutex
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Clone returns a shallow clone of c. It is safe to clone a Config that is
|
||||
// being used concurrently by a TLS client or server.
|
||||
func (c *Config) Clone() *Config {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return &Config{
|
||||
ServerName: c.ServerName,
|
||||
|
||||
SendSessionTickets: c.SendSessionTickets,
|
||||
TicketLifetime: c.TicketLifetime,
|
||||
TicketLen: c.TicketLen,
|
||||
EarlyDataLifetime: c.EarlyDataLifetime,
|
||||
AllowEarlyData: c.AllowEarlyData,
|
||||
RequireCookie: c.RequireCookie,
|
||||
RequireClientAuth: c.RequireClientAuth,
|
||||
|
||||
Certificates: c.Certificates,
|
||||
AuthCertificate: c.AuthCertificate,
|
||||
CipherSuites: c.CipherSuites,
|
||||
Groups: c.Groups,
|
||||
SignatureSchemes: c.SignatureSchemes,
|
||||
NextProtos: c.NextProtos,
|
||||
PSKs: c.PSKs,
|
||||
PSKModes: c.PSKModes,
|
||||
NonBlocking: c.NonBlocking,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Init(isClient bool) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Set defaults
|
||||
if len(c.CipherSuites) == 0 {
|
||||
c.CipherSuites = defaultSupportedCipherSuites
|
||||
}
|
||||
if len(c.Groups) == 0 {
|
||||
c.Groups = defaultSupportedGroups
|
||||
}
|
||||
if len(c.SignatureSchemes) == 0 {
|
||||
c.SignatureSchemes = defaultSignatureSchemes
|
||||
}
|
||||
if c.TicketLen == 0 {
|
||||
c.TicketLen = defaultTicketLen
|
||||
}
|
||||
if !reflect.ValueOf(c.PSKs).IsValid() {
|
||||
c.PSKs = &PSKMapCache{}
|
||||
}
|
||||
if len(c.PSKModes) == 0 {
|
||||
c.PSKModes = defaultPSKModes
|
||||
}
|
||||
|
||||
// If there is no certificate, generate one
|
||||
if !isClient && len(c.Certificates) == 0 {
|
||||
logf(logTypeHandshake, "Generating key name=%v", c.ServerName)
|
||||
priv, err := newSigningKey(RSA_PSS_SHA256)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := newSelfSigned(c.ServerName, RSA_PKCS1_SHA256, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Certificates = []*Certificate{
|
||||
{
|
||||
Chain: []*x509.Certificate{cert},
|
||||
PrivateKey: priv,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) ValidForServer() bool {
|
||||
return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) ||
|
||||
(len(c.Certificates) > 0 &&
|
||||
len(c.Certificates[0].Chain) > 0 &&
|
||||
c.Certificates[0].PrivateKey != nil)
|
||||
}
|
||||
|
||||
func (c *Config) ValidForClient() bool {
|
||||
return len(c.ServerName) > 0
|
||||
}
|
||||
|
||||
var (
|
||||
defaultSupportedCipherSuites = []CipherSuite{
|
||||
TLS_AES_128_GCM_SHA256,
|
||||
TLS_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
defaultSupportedGroups = []NamedGroup{
|
||||
P256,
|
||||
P384,
|
||||
FFDHE2048,
|
||||
X25519,
|
||||
}
|
||||
|
||||
defaultSignatureSchemes = []SignatureScheme{
|
||||
RSA_PSS_SHA256,
|
||||
RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA512,
|
||||
ECDSA_P256_SHA256,
|
||||
ECDSA_P384_SHA384,
|
||||
ECDSA_P521_SHA512,
|
||||
}
|
||||
|
||||
defaultTicketLen = 16
|
||||
|
||||
defaultPSKModes = []PSKKeyExchangeMode{
|
||||
PSKModeKE,
|
||||
PSKModeDHEKE,
|
||||
}
|
||||
)
|
||||
|
||||
type ConnectionState struct {
|
||||
HandshakeState string // string representation of the handshake state.
|
||||
CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer TODO(ekr@rtfm.com): implement
|
||||
NextProto string // Selected ALPN proto
|
||||
}
|
||||
|
||||
// Conn implements the net.Conn interface, as with "crypto/tls"
|
||||
// * Read, Write, and Close are provided locally
|
||||
// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn
|
||||
type Conn struct {
|
||||
config *Config
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
|
||||
EarlyData []byte
|
||||
|
||||
state StateConnected
|
||||
hState HandshakeState
|
||||
handshakeMutex sync.Mutex
|
||||
handshakeAlert Alert
|
||||
handshakeComplete bool
|
||||
|
||||
readBuffer []byte
|
||||
in, out *RecordLayer
|
||||
hIn, hOut *HandshakeLayer
|
||||
|
||||
extHandler AppExtensionHandler
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
|
||||
c := &Conn{conn: conn, config: config, isClient: isClient}
|
||||
c.in = NewRecordLayer(c.conn)
|
||||
c.out = NewRecordLayer(c.conn)
|
||||
c.hIn = NewHandshakeLayer(c.in)
|
||||
c.hIn.nonblocking = c.config.NonBlocking
|
||||
c.hOut = NewHandshakeLayer(c.out)
|
||||
return c
|
||||
}
|
||||
|
||||
// Read up
|
||||
func (c *Conn) consumeRecord() error {
|
||||
pt, err := c.in.ReadRecord()
|
||||
if pt == nil {
|
||||
logf(logTypeIO, "extendBuffer returns error %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt.contentType {
|
||||
case RecordTypeHandshake:
|
||||
logf(logTypeHandshake, "Received post-handshake message")
|
||||
// We do not support fragmentation of post-handshake handshake messages.
|
||||
// TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage()
|
||||
start := 0
|
||||
for start < len(pt.fragment) {
|
||||
if len(pt.fragment[start:]) < handshakeHeaderLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for header")
|
||||
}
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(pt.fragment[start])
|
||||
hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3])
|
||||
|
||||
if len(pt.fragment[start+handshakeHeaderLen:]) < hmLen {
|
||||
return fmt.Errorf("Post-handshake handshake message too short for body")
|
||||
}
|
||||
hm.body = pt.fragment[start+handshakeHeaderLen : start+handshakeHeaderLen+hmLen]
|
||||
|
||||
// Advance state machine
|
||||
state, actions, alert := c.state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: If we want to support more advanced cases, e.g., post-handshake
|
||||
// authentication, we'll need to allow transitions other than
|
||||
// Connected -> Connected
|
||||
var connected bool
|
||||
c.state, connected = state.(StateConnected)
|
||||
if !connected {
|
||||
logf(logTypeHandshake, "Disconnected after state transition: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
start += handshakeHeaderLen + hmLen
|
||||
}
|
||||
case RecordTypeAlert:
|
||||
logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
if len(pt.fragment) != 2 {
|
||||
c.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
if Alert(pt.fragment[1]) == AlertCloseNotify {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
switch pt.fragment[0] {
|
||||
case AlertLevelWarning:
|
||||
// drop on the floor
|
||||
case AlertLevelError:
|
||||
return Alert(pt.fragment[1])
|
||||
default:
|
||||
c.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
case RecordTypeApplicationData:
|
||||
c.readBuffer = append(c.readBuffer, pt.fragment...)
|
||||
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Read application data up to the size of buffer. Handshake and alert records
|
||||
// are consumed by the Conn object directly.
|
||||
func (c *Conn) Read(buffer []byte) (int, error) {
|
||||
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
|
||||
if alert := c.Handshake(); alert != AlertNoAlert {
|
||||
return 0, alert
|
||||
}
|
||||
|
||||
if len(buffer) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Lock the input channel
|
||||
c.in.Lock()
|
||||
defer c.in.Unlock()
|
||||
for len(c.readBuffer) == 0 {
|
||||
err := c.consumeRecord()
|
||||
|
||||
// err can be nil if consumeRecord processed a non app-data
|
||||
// record.
|
||||
if err != nil {
|
||||
if c.config.NonBlocking || err != WouldBlock {
|
||||
logf(logTypeIO, "conn.Read returns err=%v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var read int
|
||||
n := len(buffer)
|
||||
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
|
||||
if len(c.readBuffer) <= n {
|
||||
buffer = buffer[:len(c.readBuffer)]
|
||||
copy(buffer, c.readBuffer)
|
||||
read = len(c.readBuffer)
|
||||
c.readBuffer = c.readBuffer[:0]
|
||||
} else {
|
||||
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
|
||||
copy(buffer[:n], c.readBuffer[:n])
|
||||
c.readBuffer = c.readBuffer[n:]
|
||||
read = n
|
||||
}
|
||||
|
||||
logf(logTypeVerbose, "Returning %v", string(buffer))
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// Write application data
|
||||
func (c *Conn) Write(buffer []byte) (int, error) {
|
||||
// Lock the output channel
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
sent := 0
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent += maxFragmentLen
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
sent += len(buffer[start:])
|
||||
}
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
// c.out.Mutex <= L.
|
||||
func (c *Conn) sendAlert(err Alert) error {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
var level int
|
||||
switch err {
|
||||
case AlertNoRenegotiation, AlertCloseNotify:
|
||||
level = AlertLevelWarning
|
||||
default:
|
||||
level = AlertLevelError
|
||||
}
|
||||
|
||||
buf := []byte{byte(err), byte(level)}
|
||||
c.out.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: buf,
|
||||
})
|
||||
|
||||
// close_notify and end_of_early_data are not actually errors
|
||||
if level == AlertLevelWarning {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
// XXX crypto/tls has an interlock with Write here. Do we need that?
|
||||
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated with the connection.
|
||||
// A zero value for t means Read and Write will not time out.
|
||||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline on the underlying connection.
|
||||
// A zero value for t means Read will not time out.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline on the underlying connection.
|
||||
// A zero value for t means Write will not time out.
|
||||
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
|
||||
label := "[server]"
|
||||
if c.isClient {
|
||||
label = "[client]"
|
||||
}
|
||||
|
||||
switch action := actionGeneric.(type) {
|
||||
case SendHandshakeMessage:
|
||||
err := c.hOut.WriteMessage(action.Message)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyIn:
|
||||
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.in.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case RekeyOut:
|
||||
logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.Label, action.KeySet)
|
||||
err := c.out.Rekey(action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case SendEarlyData:
|
||||
logf(logTypeHandshake, "%s Sending early data...", label)
|
||||
_, err := c.Write(c.EarlyData)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
case ReadPastEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading past early data...", label)
|
||||
// Scan past all records that fail to decrypt
|
||||
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok := err.(DecryptError)
|
||||
|
||||
for ok {
|
||||
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
_, ok = err.(DecryptError)
|
||||
}
|
||||
|
||||
case ReadEarlyData:
|
||||
logf(logTypeHandshake, "%s Reading early data...", label)
|
||||
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
|
||||
|
||||
for t == RecordTypeApplicationData {
|
||||
// Read a record into the buffer. Note that this is safe
|
||||
// in blocking mode because we read the record in in
|
||||
// PeekRecordType.
|
||||
pt, err := c.in.ReadRecord()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
|
||||
c.EarlyData = append(c.EarlyData, pt.fragment...)
|
||||
|
||||
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
|
||||
return AlertInternalError
|
||||
}
|
||||
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
|
||||
}
|
||||
logf(logTypeHandshake, "%s Done reading early data", label)
|
||||
|
||||
case StorePSK:
|
||||
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
|
||||
if c.isClient {
|
||||
// Clients look up PSKs based on server name
|
||||
c.config.PSKs.Put(c.config.ServerName, action.PSK)
|
||||
} else {
|
||||
// Servers look them up based on the identity in the extension
|
||||
c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK)
|
||||
}
|
||||
|
||||
default:
|
||||
logf(logTypeHandshake, "%s Unknown actionuction type", label)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
func (c *Conn) HandshakeSetup() Alert {
|
||||
var state HandshakeState
|
||||
var actions []HandshakeAction
|
||||
var alert Alert
|
||||
|
||||
if err := c.config.Init(c.isClient); err != nil {
|
||||
logf(logTypeHandshake, "Error initializing config: %v", err)
|
||||
return AlertInternalError
|
||||
}
|
||||
|
||||
// Set things up
|
||||
caps := Capabilities{
|
||||
CipherSuites: c.config.CipherSuites,
|
||||
Groups: c.config.Groups,
|
||||
SignatureSchemes: c.config.SignatureSchemes,
|
||||
PSKs: c.config.PSKs,
|
||||
PSKModes: c.config.PSKModes,
|
||||
AllowEarlyData: c.config.AllowEarlyData,
|
||||
RequireCookie: c.config.RequireCookie,
|
||||
CookieHandler: c.config.CookieHandler,
|
||||
RequireClientAuth: c.config.RequireClientAuth,
|
||||
NextProtos: c.config.NextProtos,
|
||||
Certificates: c.config.Certificates,
|
||||
ExtensionHandler: c.extHandler,
|
||||
}
|
||||
opts := ConnectionOptions{
|
||||
ServerName: c.config.ServerName,
|
||||
NextProtos: c.config.NextProtos,
|
||||
EarlyData: c.EarlyData,
|
||||
}
|
||||
|
||||
if caps.RequireCookie && caps.CookieHandler == nil {
|
||||
caps.CookieHandler = &defaultCookieHandler{}
|
||||
}
|
||||
|
||||
if c.isClient {
|
||||
state, actions, alert = ClientStateStart{Caps: caps, Opts: opts}.Next(nil)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error initializing client state: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
} else {
|
||||
state = ServerStateStart{Caps: caps, conn: c}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
// Handshake causes a TLS handshake on the connection. The `isClient` member
|
||||
// determines whether a client or server handshake is performed. If a
|
||||
// handshake has already been performed, then its result will be returned.
|
||||
func (c *Conn) Handshake() Alert {
|
||||
label := "[server]"
|
||||
if c.isClient {
|
||||
label = "[client]"
|
||||
}
|
||||
|
||||
// TODO Lock handshakeMutex
|
||||
// TODO Remove CloseNotify hack
|
||||
if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify {
|
||||
logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert)
|
||||
return c.handshakeAlert
|
||||
}
|
||||
if c.handshakeComplete {
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
var alert Alert
|
||||
if c.hState == nil {
|
||||
logf(logTypeHandshake, "%s First time through handshake, setting up", label)
|
||||
alert = c.HandshakeSetup()
|
||||
if alert != AlertNoAlert {
|
||||
return alert
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "Re-entering handshake, state=%v", c.hState)
|
||||
}
|
||||
|
||||
state := c.hState
|
||||
_, connected := state.(StateConnected)
|
||||
|
||||
var actions []HandshakeAction
|
||||
|
||||
for !connected {
|
||||
// Read a handshake message
|
||||
hm, err := c.hIn.ReadMessage()
|
||||
if err == WouldBlock {
|
||||
logf(logTypeHandshake, "%s Would block reading message: %v", label, err)
|
||||
return AlertWouldBlock
|
||||
}
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "%s Error reading message: %v", label, err)
|
||||
c.sendAlert(AlertCloseNotify)
|
||||
return AlertCloseNotify
|
||||
}
|
||||
logf(logTypeHandshake, "Read message with type: %v", hm.msgType)
|
||||
|
||||
// Advance the state machine
|
||||
state, actions, alert = state.Next(hm)
|
||||
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error in state transition: %v", alert)
|
||||
return alert
|
||||
}
|
||||
|
||||
for index, action := range actions {
|
||||
logf(logTypeHandshake, "%s taking next action (%d)", label, index)
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
|
||||
c.hState = state
|
||||
logf(logTypeHandshake, "state is now %s", c.GetHsState())
|
||||
|
||||
_, connected = state.(StateConnected)
|
||||
}
|
||||
|
||||
c.state = state.(StateConnected)
|
||||
|
||||
// Send NewSessionTicket if acting as server
|
||||
if !c.isClient && c.config.SendSessionTickets {
|
||||
actions, alert := c.state.NewSessionTicket(
|
||||
c.config.TicketLen,
|
||||
c.config.TicketLifetime,
|
||||
c.config.EarlyDataLifetime)
|
||||
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
|
||||
c.sendAlert(alert)
|
||||
return alert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
return AlertNoAlert
|
||||
}
|
||||
|
||||
func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
|
||||
if !c.handshakeComplete {
|
||||
return fmt.Errorf("Cannot update keys until after handshake")
|
||||
}
|
||||
|
||||
request := KeyUpdateNotRequested
|
||||
if requestUpdate {
|
||||
request = KeyUpdateRequested
|
||||
}
|
||||
|
||||
// Create the key update and update state
|
||||
actions, alert := c.state.KeyUpdate(request)
|
||||
if alert != AlertNoAlert {
|
||||
c.sendAlert(alert)
|
||||
return fmt.Errorf("Alert while generating key update: %v", alert)
|
||||
}
|
||||
|
||||
// Take actions (send key update and rekey)
|
||||
for _, action := range actions {
|
||||
alert = c.takeAction(action)
|
||||
if alert != AlertNoAlert {
|
||||
c.sendAlert(alert)
|
||||
return fmt.Errorf("Alert during key update actions: %v", alert)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) GetHsState() string {
|
||||
return reflect.TypeOf(c.hState).Name()
|
||||
}
|
||||
|
||||
func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
_, connected := c.hState.(StateConnected)
|
||||
if !connected {
|
||||
return nil, fmt.Errorf("Cannot compute exporter when state is not connected")
|
||||
}
|
||||
|
||||
if c.state.exporterSecret == nil {
|
||||
return nil, fmt.Errorf("Internal error: no exporter secret")
|
||||
}
|
||||
|
||||
h0 := c.state.cryptoParams.Hash.New().Sum(nil)
|
||||
tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0)
|
||||
|
||||
hc := c.state.cryptoParams.Hash.New().Sum(context)
|
||||
return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil
|
||||
}
|
||||
|
||||
func (c *Conn) State() ConnectionState {
|
||||
state := ConnectionState{
|
||||
HandshakeState: c.GetHsState(),
|
||||
}
|
||||
|
||||
if c.handshakeComplete {
|
||||
state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite]
|
||||
state.NextProto = c.state.Params.NextProto
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (c *Conn) SetExtensionHandler(h AppExtensionHandler) error {
|
||||
if c.hState != nil {
|
||||
return fmt.Errorf("Can't set extension handler after setup")
|
||||
}
|
||||
|
||||
c.extHandler = h
|
||||
return nil
|
||||
}
|
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
654
vendor/github.com/bifurcation/mint/crypto.go
generated
vendored
Normal file
|
@ -0,0 +1,654 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
// Blank includes to ensure hash support
|
||||
_ "crypto/sha1"
|
||||
_ "crypto/sha256"
|
||||
_ "crypto/sha512"
|
||||
)
|
||||
|
||||
var prng = rand.Reader
|
||||
|
||||
type aeadFactory func(key []byte) (cipher.AEAD, error)
|
||||
|
||||
type CipherSuiteParams struct {
|
||||
Suite CipherSuite
|
||||
Cipher aeadFactory // Cipher factory
|
||||
Hash crypto.Hash // Hash function
|
||||
KeyLen int // Key length in octets
|
||||
IvLen int // IV length in octets
|
||||
}
|
||||
|
||||
type signatureAlgorithm uint8
|
||||
|
||||
const (
|
||||
signatureAlgorithmUnknown = iota
|
||||
signatureAlgorithmRSA_PKCS1
|
||||
signatureAlgorithmRSA_PSS
|
||||
signatureAlgorithmECDSA
|
||||
)
|
||||
|
||||
var (
|
||||
hashMap = map[SignatureScheme]crypto.Hash{
|
||||
RSA_PKCS1_SHA1: crypto.SHA1,
|
||||
RSA_PKCS1_SHA256: crypto.SHA256,
|
||||
RSA_PKCS1_SHA384: crypto.SHA384,
|
||||
RSA_PKCS1_SHA512: crypto.SHA512,
|
||||
ECDSA_P256_SHA256: crypto.SHA256,
|
||||
ECDSA_P384_SHA384: crypto.SHA384,
|
||||
ECDSA_P521_SHA512: crypto.SHA512,
|
||||
RSA_PSS_SHA256: crypto.SHA256,
|
||||
RSA_PSS_SHA384: crypto.SHA384,
|
||||
RSA_PSS_SHA512: crypto.SHA512,
|
||||
}
|
||||
|
||||
sigMap = map[SignatureScheme]signatureAlgorithm{
|
||||
RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1,
|
||||
RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1,
|
||||
ECDSA_P256_SHA256: signatureAlgorithmECDSA,
|
||||
ECDSA_P384_SHA384: signatureAlgorithmECDSA,
|
||||
ECDSA_P521_SHA512: signatureAlgorithmECDSA,
|
||||
RSA_PSS_SHA256: signatureAlgorithmRSA_PSS,
|
||||
RSA_PSS_SHA384: signatureAlgorithmRSA_PSS,
|
||||
RSA_PSS_SHA512: signatureAlgorithmRSA_PSS,
|
||||
}
|
||||
|
||||
curveMap = map[SignatureScheme]NamedGroup{
|
||||
ECDSA_P256_SHA256: P256,
|
||||
ECDSA_P384_SHA384: P384,
|
||||
ECDSA_P521_SHA512: P521,
|
||||
}
|
||||
|
||||
newAESGCM = func(key []byte) (cipher.AEAD, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TLS always uses 12-byte nonces
|
||||
return cipher.NewGCMWithNonceSize(block, 12)
|
||||
}
|
||||
|
||||
cipherSuiteMap = map[CipherSuite]CipherSuiteParams{
|
||||
TLS_AES_128_GCM_SHA256: {
|
||||
Suite: TLS_AES_128_GCM_SHA256,
|
||||
Cipher: newAESGCM,
|
||||
Hash: crypto.SHA256,
|
||||
KeyLen: 16,
|
||||
IvLen: 12,
|
||||
},
|
||||
TLS_AES_256_GCM_SHA384: {
|
||||
Suite: TLS_AES_256_GCM_SHA384,
|
||||
Cipher: newAESGCM,
|
||||
Hash: crypto.SHA384,
|
||||
KeyLen: 32,
|
||||
IvLen: 12,
|
||||
},
|
||||
}
|
||||
|
||||
x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{
|
||||
RSA_PKCS1_SHA1: x509.SHA1WithRSA,
|
||||
RSA_PKCS1_SHA256: x509.SHA256WithRSA,
|
||||
RSA_PKCS1_SHA384: x509.SHA384WithRSA,
|
||||
RSA_PKCS1_SHA512: x509.SHA512WithRSA,
|
||||
ECDSA_P256_SHA256: x509.ECDSAWithSHA256,
|
||||
ECDSA_P384_SHA384: x509.ECDSAWithSHA384,
|
||||
ECDSA_P521_SHA512: x509.ECDSAWithSHA512,
|
||||
}
|
||||
|
||||
defaultRSAKeySize = 2048
|
||||
)
|
||||
|
||||
func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) {
|
||||
switch group {
|
||||
case P256:
|
||||
crv = elliptic.P256()
|
||||
case P384:
|
||||
crv = elliptic.P384()
|
||||
case P521:
|
||||
crv = elliptic.P521()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) {
|
||||
switch key.Curve.Params().Name {
|
||||
case elliptic.P256().Params().Name:
|
||||
g = P256
|
||||
case elliptic.P384().Params().Name:
|
||||
g = P384
|
||||
case elliptic.P521().Params().Name:
|
||||
g = P521
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) {
|
||||
size = 0
|
||||
switch group {
|
||||
case X25519:
|
||||
size = 32
|
||||
case P256:
|
||||
size = 65
|
||||
case P384:
|
||||
size = 97
|
||||
case P521:
|
||||
size = 133
|
||||
case FFDHE2048:
|
||||
size = 256
|
||||
case FFDHE3072:
|
||||
size = 384
|
||||
case FFDHE4096:
|
||||
size = 512
|
||||
case FFDHE6144:
|
||||
size = 768
|
||||
case FFDHE8192:
|
||||
size = 1024
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func primeFromNamedGroup(group NamedGroup) (p *big.Int) {
|
||||
switch group {
|
||||
case FFDHE2048:
|
||||
p = finiteFieldPrime2048
|
||||
case FFDHE3072:
|
||||
p = finiteFieldPrime3072
|
||||
case FFDHE4096:
|
||||
p = finiteFieldPrime4096
|
||||
case FFDHE6144:
|
||||
p = finiteFieldPrime6144
|
||||
case FFDHE8192:
|
||||
p = finiteFieldPrime8192
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool {
|
||||
sigType := sigMap[alg]
|
||||
switch key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS
|
||||
case *ecdsa.PrivateKey:
|
||||
return sigType == signatureAlgorithmECDSA
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) {
|
||||
primeLen := len(p.Bytes())
|
||||
for {
|
||||
// g = 2 for all ffdhe groups
|
||||
priv, err = rand.Int(prng, p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pub = big.NewInt(0)
|
||||
pub.Exp(big.NewInt(2), priv, p)
|
||||
|
||||
if len(pub.Bytes()) == primeLen {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) {
|
||||
switch group {
|
||||
case P256, P384, P521:
|
||||
var x, y *big.Int
|
||||
crv := curveFromNamedGroup(group)
|
||||
priv, x, y, err = elliptic.GenerateKey(crv, prng)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pub = elliptic.Marshal(crv, x, y)
|
||||
return
|
||||
|
||||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||
p := primeFromNamedGroup(group)
|
||||
x, X, err2 := ffdheKeyShareFromPrime(p)
|
||||
if err2 != nil {
|
||||
err = err2
|
||||
return
|
||||
}
|
||||
|
||||
priv = x.Bytes()
|
||||
pubBytes := X.Bytes()
|
||||
|
||||
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||
|
||||
pub = make([]byte, numBytes)
|
||||
copy(pub[numBytes-len(pubBytes):], pubBytes)
|
||||
|
||||
return
|
||||
|
||||
case X25519:
|
||||
var private, public [32]byte
|
||||
_, err = prng.Read(private[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
curve25519.ScalarBaseMult(&public, &private)
|
||||
priv = private[:]
|
||||
pub = public[:]
|
||||
return
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group)
|
||||
}
|
||||
}
|
||||
|
||||
func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) {
|
||||
switch group {
|
||||
case P256, P384, P521:
|
||||
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
|
||||
crv := curveFromNamedGroup(group)
|
||||
pubX, pubY := elliptic.Unmarshal(crv, pub)
|
||||
x, _ := crv.Params().ScalarMult(pubX, pubY, priv)
|
||||
xBytes := x.Bytes()
|
||||
|
||||
numBytes := len(crv.Params().P.Bytes())
|
||||
|
||||
ret := make([]byte, numBytes)
|
||||
copy(ret[numBytes-len(xBytes):], xBytes)
|
||||
|
||||
return ret, nil
|
||||
|
||||
case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192:
|
||||
numBytes := keyExchangeSizeFromNamedGroup(group)
|
||||
if len(pub) != numBytes {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
p := primeFromNamedGroup(group)
|
||||
x := big.NewInt(0).SetBytes(priv)
|
||||
Y := big.NewInt(0).SetBytes(pub)
|
||||
ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes()
|
||||
|
||||
ret := make([]byte, numBytes)
|
||||
copy(ret[numBytes-len(ZBytes):], ZBytes)
|
||||
|
||||
return ret, nil
|
||||
|
||||
case X25519:
|
||||
if len(pub) != keyExchangeSizeFromNamedGroup(group) {
|
||||
return nil, fmt.Errorf("tls.keyagreement: Wrong public key size")
|
||||
}
|
||||
|
||||
var private, public, ret [32]byte
|
||||
copy(private[:], priv)
|
||||
copy(public[:], pub)
|
||||
curve25519.ScalarMult(&ret, &private, &public)
|
||||
|
||||
return ret[:], nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group)
|
||||
}
|
||||
}
|
||||
|
||||
func newSigningKey(sig SignatureScheme) (crypto.Signer, error) {
|
||||
switch sig {
|
||||
case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256,
|
||||
RSA_PKCS1_SHA384, RSA_PKCS1_SHA512,
|
||||
RSA_PSS_SHA256, RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA512:
|
||||
return rsa.GenerateKey(prng, defaultRSAKeySize)
|
||||
case ECDSA_P256_SHA256:
|
||||
return ecdsa.GenerateKey(elliptic.P256(), prng)
|
||||
case ECDSA_P384_SHA384:
|
||||
return ecdsa.GenerateKey(elliptic.P384(), prng)
|
||||
case ECDSA_P521_SHA512:
|
||||
return ecdsa.GenerateKey(elliptic.P521(), prng)
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig)
|
||||
}
|
||||
}
|
||||
|
||||
func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) {
|
||||
sigAlg, ok := x509AlgMap[alg]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg)
|
||||
}
|
||||
if len(name) == 0 {
|
||||
return nil, fmt.Errorf("tls.selfsigned: No name provided")
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(0, 0, 1),
|
||||
SignatureAlgorithm: sigAlg,
|
||||
Subject: pkix.Name{CommonName: name},
|
||||
DNSNames: []string{name},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// It is safe to ignore the error here because we're parsing known-good data
|
||||
cert, _ := x509.ParseCertificate(der)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// XXX(rlb): Copied from crypto/x509
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) {
|
||||
var opts crypto.SignerOpts
|
||||
|
||||
hash := hashMap[alg]
|
||||
if hash == crypto.SHA1 {
|
||||
return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||
}
|
||||
|
||||
sigType := sigMap[alg]
|
||||
var realInput []byte
|
||||
switch key := privateKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
switch {
|
||||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size())
|
||||
opts = hash
|
||||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
fallthrough
|
||||
case sigType == signatureAlgorithmRSA_PSS:
|
||||
logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size())
|
||||
opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput = h.Sum(nil)
|
||||
case *ecdsa.PrivateKey:
|
||||
if sigType != signatureAlgorithmECDSA {
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key")
|
||||
}
|
||||
|
||||
algGroup := curveMap[alg]
|
||||
keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey))
|
||||
if algGroup != keyGroup {
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput = h.Sum(nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type")
|
||||
}
|
||||
|
||||
sig, err := privateKey.Sign(prng, realInput, opts)
|
||||
logf(logTypeCrypto, "signature: %x", sig)
|
||||
return sig, err
|
||||
}
|
||||
|
||||
func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error {
|
||||
hash := hashMap[alg]
|
||||
|
||||
if hash == crypto.SHA1 {
|
||||
return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden")
|
||||
}
|
||||
|
||||
sigType := sigMap[alg]
|
||||
switch pub := publicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
switch {
|
||||
case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size())
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
return rsa.VerifyPKCS1v15(pub, hash, realInput, sig)
|
||||
case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1:
|
||||
fallthrough
|
||||
case sigType == signatureAlgorithmRSA_PSS:
|
||||
logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size())
|
||||
opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
return rsa.VerifyPSS(pub, hash, realInput, sig, opts)
|
||||
default:
|
||||
return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key")
|
||||
}
|
||||
|
||||
case *ecdsa.PublicKey:
|
||||
if sigType != signatureAlgorithmECDSA {
|
||||
return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key")
|
||||
}
|
||||
|
||||
if curveMap[alg] != namedGroupFromECDSAKey(pub) {
|
||||
return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key")
|
||||
}
|
||||
|
||||
ecdsaSig := new(ecdsaSignature)
|
||||
if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
|
||||
return err
|
||||
} else if len(rest) != 0 {
|
||||
return fmt.Errorf("tls.verify: trailing data after ECDSA signature")
|
||||
}
|
||||
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||
return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values")
|
||||
}
|
||||
|
||||
h := hash.New()
|
||||
h.Write(sigInput)
|
||||
realInput := h.Sum(nil)
|
||||
if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) {
|
||||
return fmt.Errorf("tls.verify: ECDSA verification failure")
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("tls.verify: Unsupported key type")
|
||||
}
|
||||
}
|
||||
|
||||
// 0
|
||||
// |
|
||||
// v
|
||||
// PSK -> HKDF-Extract = Early Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(.,
|
||||
// | "ext binder" |
|
||||
// | "res binder",
|
||||
// | "")
|
||||
// | = binder_key
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c e traffic",
|
||||
// | ClientHello)
|
||||
// | = client_early_traffic_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "e exp master",
|
||||
// | ClientHello)
|
||||
// | = early_exporter_master_secret
|
||||
// v
|
||||
// Derive-Secret(., "derived", "")
|
||||
// |
|
||||
// v
|
||||
// (EC)DHE -> HKDF-Extract = Handshake Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c hs traffic",
|
||||
// | ClientHello...ServerHello)
|
||||
// | = client_handshake_traffic_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "s hs traffic",
|
||||
// | ClientHello...ServerHello)
|
||||
// | = server_handshake_traffic_secret
|
||||
// v
|
||||
// Derive-Secret(., "derived", "")
|
||||
// |
|
||||
// v
|
||||
// 0 -> HKDF-Extract = Master Secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "c ap traffic",
|
||||
// | ClientHello...server Finished)
|
||||
// | = client_application_traffic_secret_0
|
||||
// |
|
||||
// +-----> Derive-Secret(., "s ap traffic",
|
||||
// | ClientHello...server Finished)
|
||||
// | = server_application_traffic_secret_0
|
||||
// |
|
||||
// +-----> Derive-Secret(., "exp master",
|
||||
// | ClientHello...server Finished)
|
||||
// | = exporter_master_secret
|
||||
// |
|
||||
// +-----> Derive-Secret(., "res master",
|
||||
// ClientHello...client Finished)
|
||||
// = resumption_master_secret
|
||||
|
||||
// From RFC 5869
|
||||
// PRK = HMAC-Hash(salt, IKM)
|
||||
func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte {
|
||||
salt := saltIn
|
||||
|
||||
// if [salt is] not provided, it is set to a string of HashLen zeros
|
||||
if salt == nil {
|
||||
salt = bytes.Repeat([]byte{0}, hash.Size())
|
||||
}
|
||||
|
||||
h := hmac.New(hash.New, salt)
|
||||
h.Write(input)
|
||||
out := h.Sum(nil)
|
||||
|
||||
logf(logTypeCrypto, "HKDF Extract:\n")
|
||||
logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt)
|
||||
logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input)
|
||||
logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
const (
|
||||
labelExternalBinder = "ext binder"
|
||||
labelResumptionBinder = "res binder"
|
||||
labelEarlyTrafficSecret = "c e traffic"
|
||||
labelEarlyExporterSecret = "e exp master"
|
||||
labelClientHandshakeTrafficSecret = "c hs traffic"
|
||||
labelServerHandshakeTrafficSecret = "s hs traffic"
|
||||
labelClientApplicationTrafficSecret = "c ap traffic"
|
||||
labelServerApplicationTrafficSecret = "s ap traffic"
|
||||
labelExporterSecret = "exp master"
|
||||
labelResumptionSecret = "res master"
|
||||
labelDerived = "derived"
|
||||
labelFinished = "finished"
|
||||
labelResumption = "resumption"
|
||||
)
|
||||
|
||||
// struct HkdfLabel {
|
||||
// uint16 length;
|
||||
// opaque label<9..255>;
|
||||
// opaque hash_value<0..255>;
|
||||
// };
|
||||
func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte {
|
||||
label := "tls13 " + labelIn
|
||||
|
||||
labelLen := len(label)
|
||||
hashLen := len(hashValue)
|
||||
hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen)
|
||||
hkdfLabel[0] = byte(outLen >> 8)
|
||||
hkdfLabel[1] = byte(outLen)
|
||||
hkdfLabel[2] = byte(labelLen)
|
||||
copy(hkdfLabel[3:3+labelLen], []byte(label))
|
||||
hkdfLabel[3+labelLen] = byte(hashLen)
|
||||
copy(hkdfLabel[3+labelLen+1:], hashValue)
|
||||
|
||||
return hkdfLabel
|
||||
}
|
||||
|
||||
func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte {
|
||||
out := []byte{}
|
||||
T := []byte{}
|
||||
i := byte(1)
|
||||
for len(out) < outLen {
|
||||
block := append(T, info...)
|
||||
block = append(block, i)
|
||||
|
||||
h := hmac.New(hash.New, prk)
|
||||
h.Write(block)
|
||||
|
||||
T = h.Sum(nil)
|
||||
out = append(out, T...)
|
||||
i++
|
||||
}
|
||||
return out[:outLen]
|
||||
}
|
||||
|
||||
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte {
|
||||
info := hkdfEncodeLabel(label, hashValue, outLen)
|
||||
derived := HkdfExpand(hash, secret, info, outLen)
|
||||
|
||||
logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen)
|
||||
logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret)
|
||||
logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue)
|
||||
logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info)
|
||||
logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived)
|
||||
|
||||
return derived
|
||||
}
|
||||
|
||||
func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte {
|
||||
return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size())
|
||||
}
|
||||
|
||||
func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte {
|
||||
macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size())
|
||||
mac := hmac.New(params.Hash.New, macKey)
|
||||
mac.Write(input)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
type keySet struct {
|
||||
cipher aeadFactory
|
||||
key []byte
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet {
|
||||
logf(logTypeCrypto, "making traffic keys: secret=%x", secret)
|
||||
return keySet{
|
||||
cipher: params.Cipher,
|
||||
key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen),
|
||||
iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen),
|
||||
}
|
||||
}
|
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
586
vendor/github.com/bifurcation/mint/extensions.go
generated
vendored
Normal file
|
@ -0,0 +1,586 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
type ExtensionBody interface {
|
||||
Type() ExtensionType
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ExtensionType extension_type;
|
||||
// opaque extension_data<0..2^16-1>;
|
||||
// } Extension;
|
||||
type Extension struct {
|
||||
ExtensionType ExtensionType
|
||||
ExtensionData []byte `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ext Extension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ext)
|
||||
}
|
||||
|
||||
func (ext *Extension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ext)
|
||||
}
|
||||
|
||||
type ExtensionList []Extension
|
||||
|
||||
type extensionListInner struct {
|
||||
List []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (el ExtensionList) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(extensionListInner{el})
|
||||
}
|
||||
|
||||
func (el *ExtensionList) Unmarshal(data []byte) (int, error) {
|
||||
var list extensionListInner
|
||||
read, err := syntax.Unmarshal(data, &list)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
*el = list.List
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (el *ExtensionList) Add(src ExtensionBody) error {
|
||||
data, err := src.Marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if el == nil {
|
||||
el = new(ExtensionList)
|
||||
}
|
||||
|
||||
// If one already exists with this type, replace it
|
||||
for i := range *el {
|
||||
if (*el)[i].ExtensionType == src.Type() {
|
||||
(*el)[i].ExtensionData = data
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise append
|
||||
*el = append(*el, Extension{
|
||||
ExtensionType: src.Type(),
|
||||
ExtensionData: data,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (el ExtensionList) Find(dst ExtensionBody) bool {
|
||||
for _, ext := range el {
|
||||
if ext.ExtensionType == dst.Type() {
|
||||
_, err := dst.Unmarshal(ext.ExtensionData)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NameType name_type;
|
||||
// select (name_type) {
|
||||
// case host_name: HostName;
|
||||
// } name;
|
||||
// } ServerName;
|
||||
//
|
||||
// enum {
|
||||
// host_name(0), (255)
|
||||
// } NameType;
|
||||
//
|
||||
// opaque HostName<1..2^16-1>;
|
||||
//
|
||||
// struct {
|
||||
// ServerName server_name_list<1..2^16-1>
|
||||
// } ServerNameList;
|
||||
//
|
||||
// But we only care about the case where there's a single DNS hostname. We
|
||||
// will never create anything else, and throw if we receive something else
|
||||
//
|
||||
// 2 1 2
|
||||
// | listLen | NameType | nameLen | name |
|
||||
type ServerNameExtension string
|
||||
|
||||
type serverNameInner struct {
|
||||
NameType uint8
|
||||
HostName []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
type serverNameListInner struct {
|
||||
ServerNameList []serverNameInner `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (sni ServerNameExtension) Type() ExtensionType {
|
||||
return ExtensionTypeServerName
|
||||
}
|
||||
|
||||
func (sni ServerNameExtension) Marshal() ([]byte, error) {
|
||||
list := serverNameListInner{
|
||||
ServerNameList: []serverNameInner{{
|
||||
NameType: 0x00, // host_name
|
||||
HostName: []byte(sni),
|
||||
}},
|
||||
}
|
||||
|
||||
return syntax.Marshal(list)
|
||||
}
|
||||
|
||||
func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) {
|
||||
var list serverNameListInner
|
||||
read, err := syntax.Unmarshal(data, &list)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Syntax requires at least one entry
|
||||
// Entries beyond the first are ignored
|
||||
if nameType := list.ServerNameList[0].NameType; nameType != 0x00 {
|
||||
return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType)
|
||||
}
|
||||
|
||||
*sni = ServerNameExtension(list.ServerNameList[0].HostName)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NamedGroup group;
|
||||
// opaque key_exchange<1..2^16-1>;
|
||||
// } KeyShareEntry;
|
||||
//
|
||||
// struct {
|
||||
// select (Handshake.msg_type) {
|
||||
// case client_hello:
|
||||
// KeyShareEntry client_shares<0..2^16-1>;
|
||||
//
|
||||
// case hello_retry_request:
|
||||
// NamedGroup selected_group;
|
||||
//
|
||||
// case server_hello:
|
||||
// KeyShareEntry server_share;
|
||||
// };
|
||||
// } KeyShare;
|
||||
type KeyShareEntry struct {
|
||||
Group NamedGroup
|
||||
KeyExchange []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (kse KeyShareEntry) SizeValid() bool {
|
||||
return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group)
|
||||
}
|
||||
|
||||
type KeyShareExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
SelectedGroup NamedGroup
|
||||
Shares []KeyShareEntry
|
||||
}
|
||||
|
||||
type KeyShareClientHelloInner struct {
|
||||
ClientShares []KeyShareEntry `tls:"head=2,min=0"`
|
||||
}
|
||||
type KeyShareHelloRetryInner struct {
|
||||
SelectedGroup NamedGroup
|
||||
}
|
||||
type KeyShareServerHelloInner struct {
|
||||
ServerShare KeyShareEntry
|
||||
}
|
||||
|
||||
func (ks KeyShareExtension) Type() ExtensionType {
|
||||
return ExtensionTypeKeyShare
|
||||
}
|
||||
|
||||
func (ks KeyShareExtension) Marshal() ([]byte, error) {
|
||||
switch ks.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
for _, share := range ks.Shares {
|
||||
if !share.SizeValid() {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
}
|
||||
return syntax.Marshal(KeyShareClientHelloInner{ks.Shares})
|
||||
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
if len(ks.Shares) > 0 {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest")
|
||||
}
|
||||
|
||||
return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup})
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
if len(ks.Shares) != 1 {
|
||||
return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share")
|
||||
}
|
||||
|
||||
if !ks.Shares[0].SizeValid() {
|
||||
return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
|
||||
return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) {
|
||||
switch ks.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner KeyShareClientHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, share := range inner.ClientShares {
|
||||
if !share.SizeValid() {
|
||||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
}
|
||||
|
||||
ks.Shares = inner.ClientShares
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
var inner KeyShareHelloRetryInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ks.SelectedGroup = inner.SelectedGroup
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
var inner KeyShareServerHelloInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if !inner.ServerShare.SizeValid() {
|
||||
return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group")
|
||||
}
|
||||
|
||||
ks.Shares = []KeyShareEntry{inner.ServerShare}
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// struct {
|
||||
// NamedGroup named_group_list<2..2^16-1>;
|
||||
// } NamedGroupList;
|
||||
type SupportedGroupsExtension struct {
|
||||
Groups []NamedGroup `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (sg SupportedGroupsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedGroups
|
||||
}
|
||||
|
||||
func (sg SupportedGroupsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sg)
|
||||
}
|
||||
|
||||
func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sg)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// SignatureScheme supported_signature_algorithms<2..2^16-2>;
|
||||
// } SignatureSchemeList
|
||||
type SignatureAlgorithmsExtension struct {
|
||||
Algorithms []SignatureScheme `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (sa SignatureAlgorithmsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSignatureAlgorithms
|
||||
}
|
||||
|
||||
func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sa)
|
||||
}
|
||||
|
||||
func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sa)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque identity<1..2^16-1>;
|
||||
// uint32 obfuscated_ticket_age;
|
||||
// } PskIdentity;
|
||||
//
|
||||
// opaque PskBinderEntry<32..255>;
|
||||
//
|
||||
// struct {
|
||||
// select (Handshake.msg_type) {
|
||||
// case client_hello:
|
||||
// PskIdentity identities<7..2^16-1>;
|
||||
// PskBinderEntry binders<33..2^16-1>;
|
||||
//
|
||||
// case server_hello:
|
||||
// uint16 selected_identity;
|
||||
// };
|
||||
//
|
||||
// } PreSharedKeyExtension;
|
||||
type PSKIdentity struct {
|
||||
Identity []byte `tls:"head=2,min=1"`
|
||||
ObfuscatedTicketAge uint32
|
||||
}
|
||||
|
||||
type PSKBinderEntry struct {
|
||||
Binder []byte `tls:"head=1,min=32"`
|
||||
}
|
||||
|
||||
type PreSharedKeyExtension struct {
|
||||
HandshakeType HandshakeType
|
||||
Identities []PSKIdentity
|
||||
Binders []PSKBinderEntry
|
||||
SelectedIdentity uint16
|
||||
}
|
||||
|
||||
type preSharedKeyClientInner struct {
|
||||
Identities []PSKIdentity `tls:"head=2,min=7"`
|
||||
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||
}
|
||||
|
||||
type preSharedKeyServerInner struct {
|
||||
SelectedIdentity uint16
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) Type() ExtensionType {
|
||||
return ExtensionTypePreSharedKey
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) Marshal() ([]byte, error) {
|
||||
switch psk.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
return syntax.Marshal(preSharedKeyClientInner{
|
||||
Identities: psk.Identities,
|
||||
Binders: psk.Binders,
|
||||
})
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
if len(psk.Identities) > 0 || len(psk.Binders) > 0 {
|
||||
return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index")
|
||||
}
|
||||
return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) {
|
||||
switch psk.HandshakeType {
|
||||
case HandshakeTypeClientHello:
|
||||
var inner preSharedKeyClientInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(inner.Identities) != len(inner.Binders) {
|
||||
return 0, fmt.Errorf("Lengths of identities and binders not equal")
|
||||
}
|
||||
|
||||
psk.Identities = inner.Identities
|
||||
psk.Binders = inner.Binders
|
||||
return read, nil
|
||||
|
||||
case HandshakeTypeServerHello:
|
||||
var inner preSharedKeyServerInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
psk.SelectedIdentity = inner.SelectedIdentity
|
||||
return read, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) {
|
||||
for i, localID := range psk.Identities {
|
||||
if bytes.Equal(localID.Identity, id) {
|
||||
return psk.Binders[i].Binder, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
|
||||
//
|
||||
// struct {
|
||||
// PskKeyExchangeMode ke_modes<1..255>;
|
||||
// } PskKeyExchangeModes;
|
||||
type PSKKeyExchangeModesExtension struct {
|
||||
KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"`
|
||||
}
|
||||
|
||||
func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType {
|
||||
return ExtensionTypePSKKeyExchangeModes
|
||||
}
|
||||
|
||||
func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(pkem)
|
||||
}
|
||||
|
||||
func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, pkem)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// } EarlyDataIndication;
|
||||
|
||||
type EarlyDataExtension struct{}
|
||||
|
||||
func (ed EarlyDataExtension) Type() ExtensionType {
|
||||
return ExtensionTypeEarlyData
|
||||
}
|
||||
|
||||
func (ed EarlyDataExtension) Marshal() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// uint32 max_early_data_size;
|
||||
// } TicketEarlyDataInfo;
|
||||
|
||||
type TicketEarlyDataInfoExtension struct {
|
||||
MaxEarlyDataSize uint32
|
||||
}
|
||||
|
||||
func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType {
|
||||
return ExtensionTypeTicketEarlyDataInfo
|
||||
}
|
||||
|
||||
func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(tedi)
|
||||
}
|
||||
|
||||
func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, tedi)
|
||||
}
|
||||
|
||||
// opaque ProtocolName<1..2^8-1>;
|
||||
//
|
||||
// struct {
|
||||
// ProtocolName protocol_name_list<2..2^16-1>
|
||||
// } ProtocolNameList;
|
||||
type ALPNExtension struct {
|
||||
Protocols []string
|
||||
}
|
||||
|
||||
type protocolNameInner struct {
|
||||
Name []byte `tls:"head=1,min=1"`
|
||||
}
|
||||
|
||||
type alpnExtensionInner struct {
|
||||
Protocols []protocolNameInner `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (alpn ALPNExtension) Type() ExtensionType {
|
||||
return ExtensionTypeALPN
|
||||
}
|
||||
|
||||
func (alpn ALPNExtension) Marshal() ([]byte, error) {
|
||||
protocols := make([]protocolNameInner, len(alpn.Protocols))
|
||||
for i, protocol := range alpn.Protocols {
|
||||
protocols[i] = protocolNameInner{[]byte(protocol)}
|
||||
}
|
||||
return syntax.Marshal(alpnExtensionInner{protocols})
|
||||
}
|
||||
|
||||
func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) {
|
||||
var inner alpnExtensionInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
alpn.Protocols = make([]string, len(inner.Protocols))
|
||||
for i, protocol := range inner.Protocols {
|
||||
alpn.Protocols[i] = string(protocol.Name)
|
||||
}
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion versions<2..254>;
|
||||
// } SupportedVersions;
|
||||
type SupportedVersionsExtension struct {
|
||||
Versions []uint16 `tls:"head=1,min=2,max=254"`
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Type() ExtensionType {
|
||||
return ExtensionTypeSupportedVersions
|
||||
}
|
||||
|
||||
func (sv SupportedVersionsExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sv)
|
||||
}
|
||||
|
||||
func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sv)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque cookie<1..2^16-1>;
|
||||
// } Cookie;
|
||||
type CookieExtension struct {
|
||||
Cookie []byte `tls:"head=2,min=1"`
|
||||
}
|
||||
|
||||
func (c CookieExtension) Type() ExtensionType {
|
||||
return ExtensionTypeCookie
|
||||
}
|
||||
|
||||
func (c CookieExtension) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *CookieExtension) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, c)
|
||||
}
|
||||
|
||||
// defaultCookieLength is the default length of a cookie
|
||||
const defaultCookieLength = 32
|
||||
|
||||
type defaultCookieHandler struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
var _ CookieHandler = &defaultCookieHandler{}
|
||||
|
||||
// NewRandomCookie generates a cookie with DefaultCookieLength bytes of random data
|
||||
func (h *defaultCookieHandler) Generate(*Conn) ([]byte, error) {
|
||||
h.data = make([]byte, defaultCookieLength)
|
||||
if _, err := prng.Read(h.data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.data, nil
|
||||
}
|
||||
|
||||
func (h *defaultCookieHandler) Validate(_ *Conn, data []byte) bool {
|
||||
return bytes.Equal(h.data, data)
|
||||
}
|
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
147
vendor/github.com/bifurcation/mint/ffdhe.go
generated
vendored
Normal file
|
@ -0,0 +1,147 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B423861285C97FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex)
|
||||
finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes)
|
||||
|
||||
finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex)
|
||||
finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes)
|
||||
|
||||
finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" +
|
||||
"FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex)
|
||||
finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes)
|
||||
|
||||
finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||
"A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex)
|
||||
finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes)
|
||||
|
||||
finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" +
|
||||
"D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" +
|
||||
"7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" +
|
||||
"2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" +
|
||||
"984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" +
|
||||
"30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" +
|
||||
"B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" +
|
||||
"0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" +
|
||||
"9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" +
|
||||
"3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" +
|
||||
"886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" +
|
||||
"61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" +
|
||||
"AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" +
|
||||
"64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" +
|
||||
"ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" +
|
||||
"3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" +
|
||||
"7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" +
|
||||
"87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" +
|
||||
"A907600A918130C46DC778F971AD0038092999A333CB8B7A" +
|
||||
"1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" +
|
||||
"8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" +
|
||||
"0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" +
|
||||
"3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" +
|
||||
"CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" +
|
||||
"A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" +
|
||||
"0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" +
|
||||
"763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" +
|
||||
"B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" +
|
||||
"D72B03746AE77F5E62292C311562A846505DC82DB854338A" +
|
||||
"E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" +
|
||||
"5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" +
|
||||
"A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" +
|
||||
"1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" +
|
||||
"0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" +
|
||||
"CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" +
|
||||
"2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" +
|
||||
"BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" +
|
||||
"51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" +
|
||||
"D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" +
|
||||
"1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" +
|
||||
"FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" +
|
||||
"97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" +
|
||||
"D68C8BB7C5C6424CFFFFFFFFFFFFFFFF"
|
||||
finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex)
|
||||
finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes)
|
||||
)
|
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
98
vendor/github.com/bifurcation/mint/frame-reader.go
generated
vendored
Normal file
|
@ -0,0 +1,98 @@
|
|||
// Read a generic "framed" packet consisting of a header and a
|
||||
// This is used for both TLS Records and TLS Handshake Messages
|
||||
package mint
|
||||
|
||||
type framing interface {
|
||||
headerLen() int
|
||||
defaultReadLen() int
|
||||
frameLen(hdr []byte) (int, error)
|
||||
}
|
||||
|
||||
const (
|
||||
kFrameReaderHdr = 0
|
||||
kFrameReaderBody = 1
|
||||
)
|
||||
|
||||
type frameNextAction func(f *frameReader) error
|
||||
|
||||
type frameReader struct {
|
||||
details framing
|
||||
state uint8
|
||||
header []byte
|
||||
body []byte
|
||||
working []byte
|
||||
writeOffset int
|
||||
remainder []byte
|
||||
}
|
||||
|
||||
func newFrameReader(d framing) *frameReader {
|
||||
hdr := make([]byte, d.headerLen())
|
||||
return &frameReader{
|
||||
d,
|
||||
kFrameReaderHdr,
|
||||
hdr,
|
||||
nil,
|
||||
hdr,
|
||||
0,
|
||||
nil,
|
||||
}
|
||||
}
|
||||
|
||||
func dup(a []byte) []byte {
|
||||
r := make([]byte, len(a))
|
||||
copy(r, a)
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *frameReader) needed() int {
|
||||
tmp := (len(f.working) - f.writeOffset) - len(f.remainder)
|
||||
if tmp < 0 {
|
||||
return 0
|
||||
}
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (f *frameReader) addChunk(in []byte) {
|
||||
// Append to the buffer.
|
||||
logf(logTypeFrameReader, "Appending %v", len(in))
|
||||
f.remainder = append(f.remainder, in...)
|
||||
}
|
||||
|
||||
func (f *frameReader) process() (hdr []byte, body []byte, err error) {
|
||||
for f.needed() == 0 {
|
||||
logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset)
|
||||
// Fill out our working block
|
||||
copied := copy(f.working[f.writeOffset:], f.remainder)
|
||||
f.remainder = f.remainder[copied:]
|
||||
f.writeOffset += copied
|
||||
if f.writeOffset < len(f.working) {
|
||||
logf(logTypeFrameReader, "Read would have blocked 1")
|
||||
return nil, nil, WouldBlock
|
||||
}
|
||||
// Reset the write offset, because we are now full.
|
||||
f.writeOffset = 0
|
||||
|
||||
// We have read a full frame
|
||||
if f.state == kFrameReaderBody {
|
||||
logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder))
|
||||
f.state = kFrameReaderHdr
|
||||
f.working = f.header
|
||||
return dup(f.header), dup(f.body), nil
|
||||
}
|
||||
|
||||
// We have read the header
|
||||
bodyLen, err := f.details.frameLen(f.header)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen)
|
||||
|
||||
f.body = make([]byte, bodyLen)
|
||||
f.working = f.body
|
||||
f.writeOffset = 0
|
||||
f.state = kFrameReaderBody
|
||||
}
|
||||
|
||||
logf(logTypeFrameReader, "Read would have blocked 2")
|
||||
return nil, nil, WouldBlock
|
||||
}
|
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
253
vendor/github.com/bifurcation/mint/handshake-layer.go
generated
vendored
Normal file
|
@ -0,0 +1,253 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
handshakeHeaderLen = 4 // handshake message header length
|
||||
maxHandshakeMessageLen = 1 << 24 // max handshake message length
|
||||
)
|
||||
|
||||
// struct {
|
||||
// HandshakeType msg_type; /* handshake type */
|
||||
// uint24 length; /* bytes in message */
|
||||
// select (HandshakeType) {
|
||||
// ...
|
||||
// } body;
|
||||
// } Handshake;
|
||||
//
|
||||
// We do the select{...} part in a different layer, so we treat the
|
||||
// actual message body as opaque:
|
||||
//
|
||||
// struct {
|
||||
// HandshakeType msg_type;
|
||||
// opaque msg<0..2^24-1>
|
||||
// } Handshake;
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type HandshakeMessage struct {
|
||||
// Omitted: length
|
||||
msgType HandshakeType
|
||||
body []byte
|
||||
}
|
||||
|
||||
// Note: This could be done with the `syntax` module, using the simplified
|
||||
// syntax as discussed above. However, since this is so simple, there's not
|
||||
// much benefit to doing so.
|
||||
func (hm *HandshakeMessage) Marshal() []byte {
|
||||
if hm == nil {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
msgLen := len(hm.body)
|
||||
data := make([]byte, 4+len(hm.body))
|
||||
data[0] = byte(hm.msgType)
|
||||
data[1] = byte(msgLen >> 16)
|
||||
data[2] = byte(msgLen >> 8)
|
||||
data[3] = byte(msgLen)
|
||||
copy(data[4:], hm.body)
|
||||
return data
|
||||
}
|
||||
|
||||
func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) {
|
||||
logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body)
|
||||
|
||||
var body HandshakeMessageBody
|
||||
switch hm.msgType {
|
||||
case HandshakeTypeClientHello:
|
||||
body = new(ClientHelloBody)
|
||||
case HandshakeTypeServerHello:
|
||||
body = new(ServerHelloBody)
|
||||
case HandshakeTypeHelloRetryRequest:
|
||||
body = new(HelloRetryRequestBody)
|
||||
case HandshakeTypeEncryptedExtensions:
|
||||
body = new(EncryptedExtensionsBody)
|
||||
case HandshakeTypeCertificate:
|
||||
body = new(CertificateBody)
|
||||
case HandshakeTypeCertificateRequest:
|
||||
body = new(CertificateRequestBody)
|
||||
case HandshakeTypeCertificateVerify:
|
||||
body = new(CertificateVerifyBody)
|
||||
case HandshakeTypeFinished:
|
||||
body = &FinishedBody{VerifyDataLen: len(hm.body)}
|
||||
case HandshakeTypeNewSessionTicket:
|
||||
body = new(NewSessionTicketBody)
|
||||
case HandshakeTypeKeyUpdate:
|
||||
body = new(KeyUpdateBody)
|
||||
case HandshakeTypeEndOfEarlyData:
|
||||
body = new(EndOfEarlyDataBody)
|
||||
default:
|
||||
return body, fmt.Errorf("tls.handshakemessage: Unsupported body type")
|
||||
}
|
||||
|
||||
_, err := body.Unmarshal(hm.body)
|
||||
return body, err
|
||||
}
|
||||
|
||||
func HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) {
|
||||
data, err := body.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HandshakeMessage{
|
||||
msgType: body.Type(),
|
||||
body: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type HandshakeLayer struct {
|
||||
nonblocking bool // Should we operate in nonblocking mode
|
||||
conn *RecordLayer // Used for reading/writing records
|
||||
frame *frameReader // The buffered frame reader
|
||||
}
|
||||
|
||||
type handshakeLayerFrameDetails struct{}
|
||||
|
||||
func (d handshakeLayerFrameDetails) headerLen() int {
|
||||
return handshakeHeaderLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) defaultReadLen() int {
|
||||
return handshakeHeaderLen + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
logf(logTypeIO, "Header=%x", hdr)
|
||||
return (int(hdr[1]) << 16) | (int(hdr[2]) << 8) | int(hdr[3]), nil
|
||||
}
|
||||
|
||||
func NewHandshakeLayer(r *RecordLayer) *HandshakeLayer {
|
||||
h := HandshakeLayer{}
|
||||
h.conn = r
|
||||
h.frame = newFrameReader(&handshakeLayerFrameDetails{})
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) readRecord() error {
|
||||
logf(logTypeIO, "Trying to read record")
|
||||
pt, err := h.conn.ReadRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pt.contentType != RecordTypeHandshake &&
|
||||
pt.contentType != RecordTypeAlert {
|
||||
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
|
||||
}
|
||||
|
||||
if pt.contentType == RecordTypeAlert {
|
||||
logf(logTypeIO, "read alert %v", pt.fragment[1])
|
||||
if len(pt.fragment) < 2 {
|
||||
h.sendAlert(AlertUnexpectedMessage)
|
||||
return io.EOF
|
||||
}
|
||||
return Alert(pt.fragment[1])
|
||||
}
|
||||
|
||||
logf(logTypeIO, "read handshake record of len %v", len(pt.fragment))
|
||||
h.frame.addChunk(pt.fragment)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (h *HandshakeLayer) sendAlert(err Alert) error {
|
||||
tmp := make([]byte, 2)
|
||||
tmp[0] = AlertLevelError
|
||||
tmp[1] = byte(err)
|
||||
h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeAlert,
|
||||
fragment: tmp},
|
||||
)
|
||||
|
||||
// closeNotify is a special case in that it isn't an error:
|
||||
if err != AlertCloseNotify {
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
|
||||
var hdr, body []byte
|
||||
var err error
|
||||
|
||||
for {
|
||||
logf(logTypeHandshake, "ReadMessage() buffered=%v", len(h.frame.remainder))
|
||||
if h.frame.needed() > 0 {
|
||||
logf(logTypeHandshake, "Trying to read a new record")
|
||||
err = h.readRecord()
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdr, body, err = h.frame.process()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err != nil && (h.nonblocking || err != WouldBlock) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "read handshake message")
|
||||
|
||||
hm := &HandshakeMessage{}
|
||||
hm.msgType = HandshakeType(hdr[0])
|
||||
|
||||
hm.body = make([]byte, len(body))
|
||||
copy(hm.body, body)
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
|
||||
return h.WriteMessages([]*HandshakeMessage{hm})
|
||||
}
|
||||
|
||||
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
|
||||
for _, hm := range hms {
|
||||
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
|
||||
}
|
||||
|
||||
// Write out headers and bodies
|
||||
buffer := []byte{}
|
||||
for _, msg := range hms {
|
||||
msgLen := len(msg.body)
|
||||
if msgLen > maxHandshakeMessageLen {
|
||||
return fmt.Errorf("tls.handshakelayer: Message too large to send")
|
||||
}
|
||||
|
||||
buffer = append(buffer, msg.Marshal()...)
|
||||
}
|
||||
|
||||
// Send full-size fragments
|
||||
var start int
|
||||
for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start : start+maxFragmentLen],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Send a final partial fragment if necessary
|
||||
if start < len(buffer) {
|
||||
err := h.conn.WriteRecord(&TLSPlaintext{
|
||||
contentType: RecordTypeHandshake,
|
||||
fragment: buffer[start:],
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
450
vendor/github.com/bifurcation/mint/handshake-messages.go
generated
vendored
Normal file
|
@ -0,0 +1,450 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint/syntax"
|
||||
)
|
||||
|
||||
type HandshakeMessageBody interface {
|
||||
Type() HandshakeType
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) (int, error)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */
|
||||
// Random random;
|
||||
// opaque legacy_session_id<0..32>;
|
||||
// CipherSuite cipher_suites<2..2^16-2>;
|
||||
// opaque legacy_compression_methods<1..2^8-1>;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ClientHello;
|
||||
type ClientHelloBody struct {
|
||||
// Omitted: clientVersion
|
||||
// Omitted: legacySessionID
|
||||
// Omitted: legacyCompressionMethods
|
||||
Random [32]byte
|
||||
CipherSuites []CipherSuite
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type clientHelloBodyInner struct {
|
||||
LegacyVersion uint16
|
||||
Random [32]byte
|
||||
LegacySessionID []byte `tls:"head=1,max=32"`
|
||||
CipherSuites []CipherSuite `tls:"head=2,min=2"`
|
||||
LegacyCompressionMethods []byte `tls:"head=1,min=1"`
|
||||
Extensions []Extension `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeClientHello
|
||||
}
|
||||
|
||||
func (ch ClientHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(clientHelloBodyInner{
|
||||
LegacyVersion: 0x0303,
|
||||
Random: ch.Random,
|
||||
LegacySessionID: []byte{},
|
||||
CipherSuites: ch.CipherSuites,
|
||||
LegacyCompressionMethods: []byte{0},
|
||||
Extensions: ch.Extensions,
|
||||
})
|
||||
}
|
||||
|
||||
func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
var inner clientHelloBodyInner
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// We are strict about these things because we only support 1.3
|
||||
if inner.LegacyVersion != 0x0303 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Incorrect version number")
|
||||
}
|
||||
|
||||
if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 {
|
||||
return 0, fmt.Errorf("tls.clienthello: Invalid compression method")
|
||||
}
|
||||
|
||||
ch.Random = inner.Random
|
||||
ch.CipherSuites = inner.CipherSuites
|
||||
ch.Extensions = inner.Extensions
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// TODO: File a spec bug to clarify this
|
||||
func (ch ClientHelloBody) Truncated() ([]byte, error) {
|
||||
if len(ch.Extensions) == 0 {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: No extensions")
|
||||
}
|
||||
|
||||
pskExt := ch.Extensions[len(ch.Extensions)-1]
|
||||
if pskExt.ExtensionType != ExtensionTypePreSharedKey {
|
||||
return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK")
|
||||
}
|
||||
|
||||
chm, err := HandshakeMessageFromBody(&ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chData := chm.Marshal()
|
||||
|
||||
psk := PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeClientHello,
|
||||
}
|
||||
_, err = psk.Unmarshal(pskExt.ExtensionData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Marshal just the binders so that we know how much to truncate
|
||||
binders := struct {
|
||||
Binders []PSKBinderEntry `tls:"head=2,min=33"`
|
||||
}{Binders: psk.Binders}
|
||||
binderData, _ := syntax.Marshal(binders)
|
||||
binderLen := len(binderData)
|
||||
|
||||
chLen := len(chData)
|
||||
return chData[:chLen-binderLen], nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion server_version;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } HelloRetryRequest;
|
||||
type HelloRetryRequestBody struct {
|
||||
Version uint16
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2,min=2"`
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeHelloRetryRequest
|
||||
}
|
||||
|
||||
func (hrr HelloRetryRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(hrr)
|
||||
}
|
||||
|
||||
func (hrr *HelloRetryRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, hrr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ProtocolVersion version;
|
||||
// Random random;
|
||||
// CipherSuite cipher_suite;
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } ServerHello;
|
||||
type ServerHelloBody struct {
|
||||
Version uint16
|
||||
Random [32]byte
|
||||
CipherSuite CipherSuite
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Type() HandshakeType {
|
||||
return HandshakeTypeServerHello
|
||||
}
|
||||
|
||||
func (sh ServerHelloBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(sh)
|
||||
}
|
||||
|
||||
func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, sh)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque verify_data[verify_data_length];
|
||||
// } Finished;
|
||||
//
|
||||
// verifyDataLen is not a field in the TLS struct, but we add it here so
|
||||
// that calling code can tell us how much data to expect when we marshal /
|
||||
// unmarshal. (We could add this to the marshal/unmarshal methods, but let's
|
||||
// try to keep the signature consistent for now.)
|
||||
//
|
||||
// For similar reasons, we don't use the `syntax` module here, because this
|
||||
// struct doesn't map well to standard TLS presentation language concepts.
|
||||
//
|
||||
// TODO: File a spec bug
|
||||
type FinishedBody struct {
|
||||
VerifyDataLen int
|
||||
VerifyData []byte
|
||||
}
|
||||
|
||||
func (fin FinishedBody) Type() HandshakeType {
|
||||
return HandshakeTypeFinished
|
||||
}
|
||||
|
||||
func (fin FinishedBody) Marshal() ([]byte, error) {
|
||||
if len(fin.VerifyData) != fin.VerifyDataLen {
|
||||
return nil, fmt.Errorf("tls.finished: data length mismatch")
|
||||
}
|
||||
|
||||
body := make([]byte, len(fin.VerifyData))
|
||||
copy(body, fin.VerifyData)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (fin *FinishedBody) Unmarshal(data []byte) (int, error) {
|
||||
if len(data) < fin.VerifyDataLen {
|
||||
return 0, fmt.Errorf("tls.finished: Malformed finished; too short")
|
||||
}
|
||||
|
||||
fin.VerifyData = make([]byte, fin.VerifyDataLen)
|
||||
copy(fin.VerifyData, data[:fin.VerifyDataLen])
|
||||
return fin.VerifyDataLen, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// Extension extensions<0..2^16-1>;
|
||||
// } EncryptedExtensions;
|
||||
//
|
||||
// Marshal() and Unmarshal() are handled by ExtensionList
|
||||
type EncryptedExtensionsBody struct {
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (ee EncryptedExtensionsBody) Type() HandshakeType {
|
||||
return HandshakeTypeEncryptedExtensions
|
||||
}
|
||||
|
||||
func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ee)
|
||||
}
|
||||
|
||||
func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ee)
|
||||
}
|
||||
|
||||
// opaque ASN1Cert<1..2^24-1>;
|
||||
//
|
||||
// struct {
|
||||
// ASN1Cert cert_data;
|
||||
// Extension extensions<0..2^16-1>
|
||||
// } CertificateEntry;
|
||||
//
|
||||
// struct {
|
||||
// opaque certificate_request_context<0..2^8-1>;
|
||||
// CertificateEntry certificate_list<0..2^24-1>;
|
||||
// } Certificate;
|
||||
type CertificateEntry struct {
|
||||
CertData *x509.Certificate
|
||||
Extensions ExtensionList
|
||||
}
|
||||
|
||||
type CertificateBody struct {
|
||||
CertificateRequestContext []byte
|
||||
CertificateList []CertificateEntry
|
||||
}
|
||||
|
||||
type certificateEntryInner struct {
|
||||
CertData []byte `tls:"head=3,min=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
type certificateBodyInner struct {
|
||||
CertificateRequestContext []byte `tls:"head=1"`
|
||||
CertificateList []certificateEntryInner `tls:"head=3"`
|
||||
}
|
||||
|
||||
func (c CertificateBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificate
|
||||
}
|
||||
|
||||
func (c CertificateBody) Marshal() ([]byte, error) {
|
||||
inner := certificateBodyInner{
|
||||
CertificateRequestContext: c.CertificateRequestContext,
|
||||
CertificateList: make([]certificateEntryInner, len(c.CertificateList)),
|
||||
}
|
||||
|
||||
for i, entry := range c.CertificateList {
|
||||
inner.CertificateList[i] = certificateEntryInner{
|
||||
CertData: entry.CertData.Raw,
|
||||
Extensions: entry.Extensions,
|
||||
}
|
||||
}
|
||||
|
||||
return syntax.Marshal(inner)
|
||||
}
|
||||
|
||||
func (c *CertificateBody) Unmarshal(data []byte) (int, error) {
|
||||
inner := certificateBodyInner{}
|
||||
read, err := syntax.Unmarshal(data, &inner)
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
|
||||
c.CertificateRequestContext = inner.CertificateRequestContext
|
||||
c.CertificateList = make([]CertificateEntry, len(inner.CertificateList))
|
||||
|
||||
for i, entry := range inner.CertificateList {
|
||||
c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err)
|
||||
}
|
||||
|
||||
c.CertificateList[i].Extensions = entry.Extensions
|
||||
}
|
||||
|
||||
return read, nil
|
||||
}
|
||||
|
||||
// struct {
|
||||
// SignatureScheme algorithm;
|
||||
// opaque signature<0..2^16-1>;
|
||||
// } CertificateVerify;
|
||||
type CertificateVerifyBody struct {
|
||||
Algorithm SignatureScheme
|
||||
Signature []byte `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (cv CertificateVerifyBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificateVerify
|
||||
}
|
||||
|
||||
func (cv CertificateVerifyBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(cv)
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, cv)
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte {
|
||||
// TODO: Change context for client auth
|
||||
// TODO: Put this in a const
|
||||
const context = "TLS 1.3, server CertificateVerify"
|
||||
sigInput := bytes.Repeat([]byte{0x20}, 64)
|
||||
sigInput = append(sigInput, []byte(context)...)
|
||||
sigInput = append(sigInput, []byte{0}...)
|
||||
sigInput = append(sigInput, data...)
|
||||
return sigInput
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) {
|
||||
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||
cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput)
|
||||
logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||
return
|
||||
}
|
||||
|
||||
func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error {
|
||||
sigInput := cv.EncodeSignatureInput(handshakeHash)
|
||||
logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature)
|
||||
return verify(cv.Algorithm, publicKey, sigInput, cv.Signature)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// opaque certificate_request_context<0..2^8-1>;
|
||||
// Extension extensions<2..2^16-1>;
|
||||
// } CertificateRequest;
|
||||
type CertificateRequestBody struct {
|
||||
CertificateRequestContext []byte `tls:"head=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
func (cr CertificateRequestBody) Type() HandshakeType {
|
||||
return HandshakeTypeCertificateRequest
|
||||
}
|
||||
|
||||
func (cr CertificateRequestBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(cr)
|
||||
}
|
||||
|
||||
func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, cr)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// uint32 ticket_lifetime;
|
||||
// uint32 ticket_age_add;
|
||||
// opaque ticket_nonce<1..255>;
|
||||
// opaque ticket<1..2^16-1>;
|
||||
// Extension extensions<0..2^16-2>;
|
||||
// } NewSessionTicket;
|
||||
type NewSessionTicketBody struct {
|
||||
TicketLifetime uint32
|
||||
TicketAgeAdd uint32
|
||||
TicketNonce []byte `tls:"head=1,min=1"`
|
||||
Ticket []byte `tls:"head=2,min=1"`
|
||||
Extensions ExtensionList `tls:"head=2"`
|
||||
}
|
||||
|
||||
const ticketNonceLen = 16
|
||||
|
||||
func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) {
|
||||
buf := make([]byte, 4+ticketNonceLen+ticketLen)
|
||||
_, err := prng.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tkt := &NewSessionTicketBody{
|
||||
TicketLifetime: ticketLifetime,
|
||||
TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]),
|
||||
TicketNonce: buf[4 : 4+ticketNonceLen],
|
||||
Ticket: buf[4+ticketNonceLen:],
|
||||
}
|
||||
|
||||
return tkt, err
|
||||
}
|
||||
|
||||
func (tkt NewSessionTicketBody) Type() HandshakeType {
|
||||
return HandshakeTypeNewSessionTicket
|
||||
}
|
||||
|
||||
func (tkt NewSessionTicketBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(tkt)
|
||||
}
|
||||
|
||||
func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, tkt)
|
||||
}
|
||||
|
||||
// enum {
|
||||
// update_not_requested(0), update_requested(1), (255)
|
||||
// } KeyUpdateRequest;
|
||||
//
|
||||
// struct {
|
||||
// KeyUpdateRequest request_update;
|
||||
// } KeyUpdate;
|
||||
type KeyUpdateBody struct {
|
||||
KeyUpdateRequest KeyUpdateRequest
|
||||
}
|
||||
|
||||
func (ku KeyUpdateBody) Type() HandshakeType {
|
||||
return HandshakeTypeKeyUpdate
|
||||
}
|
||||
|
||||
func (ku KeyUpdateBody) Marshal() ([]byte, error) {
|
||||
return syntax.Marshal(ku)
|
||||
}
|
||||
|
||||
func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) {
|
||||
return syntax.Unmarshal(data, ku)
|
||||
}
|
||||
|
||||
// struct {} EndOfEarlyData;
|
||||
type EndOfEarlyDataBody struct{}
|
||||
|
||||
func (eoed EndOfEarlyDataBody) Type() HandshakeType {
|
||||
return HandshakeTypeEndOfEarlyData
|
||||
}
|
||||
|
||||
func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
55
vendor/github.com/bifurcation/mint/log.go
generated
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// We use this environment variable to control logging. It should be a
|
||||
// comma-separated list of log tags (see below) or "*" to enable all logging.
|
||||
const logConfigVar = "MINT_LOG"
|
||||
|
||||
// Pre-defined log types
|
||||
const (
|
||||
logTypeCrypto = "crypto"
|
||||
logTypeHandshake = "handshake"
|
||||
logTypeNegotiation = "negotiation"
|
||||
logTypeIO = "io"
|
||||
logTypeFrameReader = "frame"
|
||||
logTypeVerbose = "verbose"
|
||||
)
|
||||
|
||||
var (
|
||||
logFunction = log.Printf
|
||||
logAll = false
|
||||
logSettings = map[string]bool{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
parseLogEnv(os.Environ())
|
||||
}
|
||||
|
||||
func parseLogEnv(env []string) {
|
||||
for _, stmt := range env {
|
||||
if strings.HasPrefix(stmt, logConfigVar+"=") {
|
||||
val := stmt[len(logConfigVar)+1:]
|
||||
|
||||
if val == "*" {
|
||||
logAll = true
|
||||
} else {
|
||||
for _, t := range strings.Split(val, ",") {
|
||||
logSettings[t] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func logf(tag string, format string, args ...interface{}) {
|
||||
if logAll || logSettings[tag] {
|
||||
fullFormat := fmt.Sprintf("[%s] %s", tag, format)
|
||||
logFunction(fullFormat, args...)
|
||||
}
|
||||
}
|
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
217
vendor/github.com/bifurcation/mint/negotiation.go
generated
vendored
Normal file
|
@ -0,0 +1,217 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func VersionNegotiation(offered, supported []uint16) (bool, uint16) {
|
||||
for _, offeredVersion := range offered {
|
||||
for _, supportedVersion := range supported {
|
||||
logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion)
|
||||
if offeredVersion == supportedVersion {
|
||||
// XXX: Should probably be highest supported version, but for now, we
|
||||
// only support one version, so it doesn't really matter.
|
||||
return true, offeredVersion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) {
|
||||
for _, share := range keyShares {
|
||||
for _, group := range groups {
|
||||
if group != share.Group {
|
||||
continue
|
||||
}
|
||||
|
||||
pub, priv, err := newKeyShare(share.Group)
|
||||
if err != nil {
|
||||
// If we encounter an error, just keep looking
|
||||
continue
|
||||
}
|
||||
|
||||
dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv)
|
||||
if err != nil {
|
||||
// If we encounter an error, just keep looking
|
||||
continue
|
||||
}
|
||||
|
||||
return true, group, pub, dhSecret
|
||||
}
|
||||
}
|
||||
|
||||
return false, 0, nil, nil
|
||||
}
|
||||
|
||||
const (
|
||||
ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds
|
||||
)
|
||||
|
||||
func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) {
|
||||
logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size())
|
||||
for i, id := range identities {
|
||||
identityHex := hex.EncodeToString(id.Identity)
|
||||
|
||||
psk, ok := psks.Get(identityHex)
|
||||
if !ok {
|
||||
logf(logTypeNegotiation, "No PSK for identity %x", identityHex)
|
||||
continue
|
||||
}
|
||||
|
||||
// For resumption, make sure the ticket age is correct
|
||||
if psk.IsResumption {
|
||||
extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd
|
||||
knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond)
|
||||
ticketAgeDelta := knownTicketAge - extTicketAge
|
||||
if knownTicketAge < extTicketAge {
|
||||
ticketAgeDelta = extTicketAge - knownTicketAge
|
||||
}
|
||||
if ticketAgeDelta > ticketAgeTolerance {
|
||||
logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity)
|
||||
logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]",
|
||||
extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance)
|
||||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity)
|
||||
}
|
||||
}
|
||||
|
||||
params, ok := cipherSuiteMap[psk.CipherSuite]
|
||||
if !ok {
|
||||
err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite)
|
||||
return false, 0, nil, CipherSuiteParams{}, err
|
||||
}
|
||||
|
||||
// Compute binder
|
||||
binderLabel := labelExternalBinder
|
||||
if psk.IsResumption {
|
||||
binderLabel = labelResumptionBinder
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
earlySecret := HkdfExtract(params.Hash, zero, psk.Key)
|
||||
binderKey := deriveSecret(params, earlySecret, binderLabel, h0)
|
||||
|
||||
// context = ClientHello[truncated]
|
||||
// context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated]
|
||||
ctxHash := params.Hash.New()
|
||||
ctxHash.Write(context)
|
||||
|
||||
binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil))
|
||||
if !bytes.Equal(binder, binders[i].Binder) {
|
||||
logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder)
|
||||
return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity)
|
||||
}
|
||||
|
||||
logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity)
|
||||
return true, i, &psk, params, nil
|
||||
}
|
||||
|
||||
logf(logTypeNegotiation, "Failed to find a usable PSK")
|
||||
return false, 0, nil, CipherSuiteParams{}, nil
|
||||
}
|
||||
|
||||
func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) {
|
||||
logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes)
|
||||
dhAllowed := false
|
||||
dhRequired := true
|
||||
for _, mode := range modes {
|
||||
dhAllowed = dhAllowed || (mode == PSKModeDHEKE)
|
||||
dhRequired = dhRequired && (mode == PSKModeDHEKE)
|
||||
}
|
||||
|
||||
// Use PSK if we can meet DH requirement and modes were provided
|
||||
usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0)
|
||||
|
||||
// Use DH if allowed
|
||||
usingDH := canDoDH && (dhAllowed || !usingPSK)
|
||||
|
||||
logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK)
|
||||
return usingDH, usingPSK
|
||||
}
|
||||
|
||||
func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) {
|
||||
// Select for server name if provided
|
||||
candidates := certs
|
||||
if serverName != nil {
|
||||
candidatesByName := []*Certificate{}
|
||||
for _, cert := range certs {
|
||||
for _, name := range cert.Chain[0].DNSNames {
|
||||
if len(*serverName) > 0 && name == *serverName {
|
||||
candidatesByName = append(candidatesByName, cert)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidatesByName) == 0 {
|
||||
return nil, 0, fmt.Errorf("No certificates available for server name")
|
||||
}
|
||||
|
||||
candidates = candidatesByName
|
||||
}
|
||||
|
||||
// Select for signature scheme
|
||||
for _, cert := range candidates {
|
||||
for _, scheme := range signatureSchemes {
|
||||
if !schemeValidForKey(scheme, cert.PrivateKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
return cert, scheme, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
|
||||
}
|
||||
|
||||
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
|
||||
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
|
||||
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
|
||||
return usingEarlyData
|
||||
}
|
||||
|
||||
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {
|
||||
for _, s1 := range offered {
|
||||
if psk != nil {
|
||||
if s1 == psk.CipherSuite {
|
||||
return s1, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, s2 := range supported {
|
||||
if s1 == s2 {
|
||||
return s1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil)
|
||||
}
|
||||
|
||||
func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) {
|
||||
for _, p1 := range offered {
|
||||
if psk != nil {
|
||||
if p1 != psk.NextProto {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, p2 := range supported {
|
||||
if p1 == p2 {
|
||||
return p1, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the client offers ALPN on resumption, it must match the earlier one
|
||||
var err error
|
||||
if psk != nil && psk.IsResumption && (len(offered) > 0) {
|
||||
err = fmt.Errorf("ALPN for PSK not provided")
|
||||
}
|
||||
return "", err
|
||||
}
|
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
296
vendor/github.com/bifurcation/mint/record-layer.go
generated
vendored
Normal file
|
@ -0,0 +1,296 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
sequenceNumberLen = 8 // sequence number length
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxFragmentLen = 1 << 14 // max number of bytes in a record
|
||||
)
|
||||
|
||||
type DecryptError string
|
||||
|
||||
func (err DecryptError) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
// struct {
|
||||
// ContentType type;
|
||||
// ProtocolVersion record_version = { 3, 1 }; /* TLS v1.x */
|
||||
// uint16 length;
|
||||
// opaque fragment[TLSPlaintext.length];
|
||||
// } TLSPlaintext;
|
||||
type TLSPlaintext struct {
|
||||
// Omitted: record_version (static)
|
||||
// Omitted: length (computed from fragment)
|
||||
contentType RecordType
|
||||
fragment []byte
|
||||
}
|
||||
|
||||
type RecordLayer struct {
|
||||
sync.Mutex
|
||||
|
||||
conn io.ReadWriter // The underlying connection
|
||||
frame *frameReader // The buffered frame reader
|
||||
nextData []byte // The next record to send
|
||||
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
|
||||
cachedError error // Error on the last record read
|
||||
|
||||
ivLength int // Length of the seq and nonce fields
|
||||
seq []byte // Zero-padded sequence number
|
||||
nonce []byte // Buffer for per-record nonces
|
||||
cipher cipher.AEAD // AEAD cipher
|
||||
}
|
||||
|
||||
type recordLayerFrameDetails struct{}
|
||||
|
||||
func (d recordLayerFrameDetails) headerLen() int {
|
||||
return recordHeaderLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) defaultReadLen() int {
|
||||
return recordHeaderLen + maxFragmentLen
|
||||
}
|
||||
|
||||
func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
|
||||
return (int(hdr[3]) << 8) | int(hdr[4]), nil
|
||||
}
|
||||
|
||||
func NewRecordLayer(conn io.ReadWriter) *RecordLayer {
|
||||
r := RecordLayer{}
|
||||
r.conn = conn
|
||||
r.frame = newFrameReader(recordLayerFrameDetails{})
|
||||
r.ivLength = 0
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *RecordLayer) Rekey(cipher aeadFactory, key []byte, iv []byte) error {
|
||||
var err error
|
||||
r.cipher, err = cipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ivLength = len(iv)
|
||||
r.seq = bytes.Repeat([]byte{0}, r.ivLength)
|
||||
r.nonce = make([]byte, r.ivLength)
|
||||
copy(r.nonce, iv)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) incrementSequenceNumber() {
|
||||
if r.ivLength == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i := r.ivLength - 1; i > r.ivLength-sequenceNumberLen; i-- {
|
||||
r.seq[i]++
|
||||
r.nonce[i] ^= (r.seq[i] - 1) ^ r.seq[i]
|
||||
if r.seq[i] != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Not allowed to let sequence number wrap.
|
||||
// Instead, must renegotiate before it does.
|
||||
// Not likely enough to bother.
|
||||
panic("TLS: sequence number wraparound")
|
||||
}
|
||||
|
||||
func (r *RecordLayer) encrypt(pt *TLSPlaintext, padLen int) *TLSPlaintext {
|
||||
// Expand the fragment to hold contentType, padding, and overhead
|
||||
originalLen := len(pt.fragment)
|
||||
plaintextLen := originalLen + 1 + padLen
|
||||
ciphertextLen := plaintextLen + r.cipher.Overhead()
|
||||
|
||||
// Assemble the revised plaintext
|
||||
out := &TLSPlaintext{
|
||||
contentType: RecordTypeApplicationData,
|
||||
fragment: make([]byte, ciphertextLen),
|
||||
}
|
||||
copy(out.fragment, pt.fragment)
|
||||
out.fragment[originalLen] = byte(pt.contentType)
|
||||
for i := 1; i <= padLen; i++ {
|
||||
out.fragment[originalLen+i] = 0
|
||||
}
|
||||
|
||||
// Encrypt the fragment
|
||||
payload := out.fragment[:plaintextLen]
|
||||
r.cipher.Seal(payload[:0], r.nonce, payload, nil)
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *RecordLayer) decrypt(pt *TLSPlaintext) (*TLSPlaintext, int, error) {
|
||||
if len(pt.fragment) < r.cipher.Overhead() {
|
||||
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.Overhead())
|
||||
return nil, 0, DecryptError(msg)
|
||||
}
|
||||
|
||||
decryptLen := len(pt.fragment) - r.cipher.Overhead()
|
||||
out := &TLSPlaintext{
|
||||
contentType: pt.contentType,
|
||||
fragment: make([]byte, decryptLen),
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
_, err := r.cipher.Open(out.fragment[:0], r.nonce, pt.fragment, nil)
|
||||
if err != nil {
|
||||
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
|
||||
}
|
||||
|
||||
// Find the padding boundary
|
||||
padLen := 0
|
||||
for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
|
||||
}
|
||||
|
||||
// Transfer the content type
|
||||
newLen := decryptLen - padLen - 1
|
||||
out.contentType = RecordType(out.fragment[newLen])
|
||||
|
||||
// Truncate the message to remove contentType, padding, overhead
|
||||
out.fragment = out.fragment[:newLen]
|
||||
return out, padLen, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
|
||||
var pt *TLSPlaintext
|
||||
var err error
|
||||
|
||||
for {
|
||||
pt, err = r.nextRecord()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if !block || err != WouldBlock {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return pt.contentType, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
|
||||
pt, err := r.nextRecord()
|
||||
|
||||
// Consume the cached record if there was one
|
||||
r.cachedRecord = nil
|
||||
r.cachedError = nil
|
||||
|
||||
return pt, err
|
||||
}
|
||||
|
||||
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
|
||||
if r.cachedRecord != nil {
|
||||
logf(logTypeIO, "Returning cached record")
|
||||
return r.cachedRecord, r.cachedError
|
||||
}
|
||||
|
||||
// Loop until one of three things happens:
|
||||
//
|
||||
// 1. We get a frame
|
||||
// 2. We try to read off the socket and get nothing, in which case
|
||||
// return WouldBlock
|
||||
// 3. We get an error.
|
||||
err := WouldBlock
|
||||
var header, body []byte
|
||||
|
||||
for err != nil {
|
||||
if r.frame.needed() > 0 {
|
||||
buf := make([]byte, recordHeaderLen+maxFragmentLen)
|
||||
n, err := r.conn.Read(buf)
|
||||
if err != nil {
|
||||
logf(logTypeIO, "Error reading, %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nil, WouldBlock
|
||||
}
|
||||
|
||||
logf(logTypeIO, "Read %v bytes", n)
|
||||
|
||||
buf = buf[:n]
|
||||
r.frame.addChunk(buf)
|
||||
}
|
||||
|
||||
header, body, err = r.frame.process()
|
||||
// Loop around on WouldBlock to see if some
|
||||
// data is now available.
|
||||
if err != nil && err != WouldBlock {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pt := &TLSPlaintext{}
|
||||
// Validate content type
|
||||
switch RecordType(header[0]) {
|
||||
default:
|
||||
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
|
||||
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
|
||||
pt.contentType = RecordType(header[0])
|
||||
}
|
||||
|
||||
// Validate version
|
||||
if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
|
||||
return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
|
||||
}
|
||||
|
||||
// Validate size < max
|
||||
size := (int(header[3]) << 8) + int(header[4])
|
||||
if size > maxFragmentLen+256 {
|
||||
return nil, fmt.Errorf("tls.record: Ciphertext size too big")
|
||||
}
|
||||
|
||||
pt.fragment = make([]byte, size)
|
||||
copy(pt.fragment, body)
|
||||
|
||||
// Attempt to decrypt fragment
|
||||
if r.cipher != nil {
|
||||
pt, _, err = r.decrypt(pt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Check that plaintext length is not too long
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return nil, fmt.Errorf("tls.record: Plaintext size too big")
|
||||
}
|
||||
|
||||
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
|
||||
r.cachedRecord = pt
|
||||
r.incrementSequenceNumber()
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
|
||||
return r.WriteRecordWithPadding(pt, 0)
|
||||
}
|
||||
|
||||
func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
|
||||
if r.cipher != nil {
|
||||
pt = r.encrypt(pt, padLen)
|
||||
} else if padLen > 0 {
|
||||
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
|
||||
}
|
||||
|
||||
if len(pt.fragment) > maxFragmentLen {
|
||||
return fmt.Errorf("tls.record: Record size too big")
|
||||
}
|
||||
|
||||
length := len(pt.fragment)
|
||||
header := []byte{byte(pt.contentType), 0x03, 0x01, byte(length >> 8), byte(length)}
|
||||
record := append(header, pt.fragment...)
|
||||
|
||||
logf(logTypeIO, "RecordLayer.WriteRecord [%d] [%x]", pt.contentType, pt.fragment)
|
||||
|
||||
r.incrementSequenceNumber()
|
||||
_, err := r.conn.Write(record)
|
||||
return err
|
||||
}
|
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
898
vendor/github.com/bifurcation/mint/server-state-machine.go
generated
vendored
Normal file
|
@ -0,0 +1,898 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"hash"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Server State Machine
|
||||
//
|
||||
// START <-----+
|
||||
// Recv ClientHello | | Send HelloRetryRequest
|
||||
// v |
|
||||
// RECVD_CH ----+
|
||||
// | Select parameters
|
||||
// | Send ServerHello
|
||||
// v
|
||||
// NEGOTIATED
|
||||
// | Send EncryptedExtensions
|
||||
// | [Send CertificateRequest]
|
||||
// Can send | [Send Certificate + CertificateVerify]
|
||||
// app data --> | Send Finished
|
||||
// after +--------+--------+
|
||||
// here No 0-RTT | | 0-RTT
|
||||
// | v
|
||||
// | WAIT_EOED <---+
|
||||
// | Recv | | | Recv
|
||||
// | EndOfEarlyData | | | early data
|
||||
// | | +-----+
|
||||
// +> WAIT_FLIGHT2 <-+
|
||||
// |
|
||||
// +--------+--------+
|
||||
// No auth | | Client auth
|
||||
// | |
|
||||
// | v
|
||||
// | WAIT_CERT
|
||||
// | Recv | | Recv Certificate
|
||||
// | empty | v
|
||||
// | Certificate | WAIT_CV
|
||||
// | | | Recv
|
||||
// | v | CertificateVerify
|
||||
// +-> WAIT_FINISHED <---+
|
||||
// | Recv Finished
|
||||
// v
|
||||
// CONNECTED
|
||||
//
|
||||
// NB: Not using state RECVD_CH
|
||||
//
|
||||
// State Instructions
|
||||
// START {}
|
||||
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
|
||||
// WAIT_EOED RekeyIn;
|
||||
// WAIT_FLIGHT2 {}
|
||||
// WAIT_CERT_CR {}
|
||||
// WAIT_CERT {}
|
||||
// WAIT_CV {}
|
||||
// WAIT_FINISHED RekeyIn; RekeyOut;
|
||||
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
|
||||
|
||||
type ServerStateStart struct {
|
||||
Caps Capabilities
|
||||
conn *Conn
|
||||
|
||||
cookieSent bool
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ServerStateStart) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeClientHello {
|
||||
logf(logTypeHandshake, "[ServerStateStart] unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
ch := &ClientHelloBody{}
|
||||
_, err := ch.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
clientHello := hm
|
||||
connParams := ConnectionParameters{}
|
||||
|
||||
supportedVersions := new(SupportedVersionsExtension)
|
||||
serverName := new(ServerNameExtension)
|
||||
supportedGroups := new(SupportedGroupsExtension)
|
||||
signatureAlgorithms := new(SignatureAlgorithmsExtension)
|
||||
clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello}
|
||||
clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello}
|
||||
clientEarlyData := &EarlyDataExtension{}
|
||||
clientALPN := new(ALPNExtension)
|
||||
clientPSKModes := new(PSKKeyExchangeModesExtension)
|
||||
clientCookie := new(CookieExtension)
|
||||
|
||||
// Handle external extensions.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
gotSupportedVersions := ch.Extensions.Find(supportedVersions)
|
||||
gotServerName := ch.Extensions.Find(serverName)
|
||||
gotSupportedGroups := ch.Extensions.Find(supportedGroups)
|
||||
gotSignatureAlgorithms := ch.Extensions.Find(signatureAlgorithms)
|
||||
gotEarlyData := ch.Extensions.Find(clientEarlyData)
|
||||
ch.Extensions.Find(clientKeyShares)
|
||||
ch.Extensions.Find(clientPSK)
|
||||
ch.Extensions.Find(clientALPN)
|
||||
ch.Extensions.Find(clientPSKModes)
|
||||
ch.Extensions.Find(clientCookie)
|
||||
|
||||
if gotServerName {
|
||||
connParams.ServerName = string(*serverName)
|
||||
}
|
||||
|
||||
// If the client didn't send supportedVersions or doesn't support 1.3,
|
||||
// then we're done here.
|
||||
if !gotSupportedVersions {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions")
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion})
|
||||
if !versionOK {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version")
|
||||
return nil, nil, AlertProtocolVersion
|
||||
}
|
||||
|
||||
if state.Caps.RequireCookie && state.cookieSent && !state.Caps.CookieHandler.Validate(state.conn, clientCookie.Cookie) {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch")
|
||||
return nil, nil, AlertAccessDenied
|
||||
}
|
||||
|
||||
// Figure out if we can do DH
|
||||
canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Caps.Groups)
|
||||
|
||||
// Figure out if we can do PSK
|
||||
canDoPSK := false
|
||||
var selectedPSK int
|
||||
var psk *PreSharedKey
|
||||
var params CipherSuiteParams
|
||||
if len(clientPSK.Identities) > 0 {
|
||||
contextBase := []byte{}
|
||||
if state.helloRetryRequest != nil {
|
||||
chBytes := state.firstClientHello.Marshal()
|
||||
hrrBytes := state.helloRetryRequest.Marshal()
|
||||
contextBase = append(chBytes, hrrBytes...)
|
||||
}
|
||||
|
||||
chTrunc, err := ch.Truncated()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
context := append(contextBase, chTrunc...)
|
||||
|
||||
canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Caps.PSKs)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Figure out if we actually should do DH / PSK
|
||||
connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes)
|
||||
|
||||
// Select a ciphersuite
|
||||
connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Caps.CipherSuites)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Send a cookie if required
|
||||
// NB: Need to do this here because it's after ciphersuite selection, which
|
||||
// has to be after PSK selection.
|
||||
// XXX: Doing this statefully for now, could be stateless
|
||||
var cookieData []byte
|
||||
if state.Caps.RequireCookie && !state.cookieSent {
|
||||
var err error
|
||||
cookieData, err = state.Caps.CookieHandler.Generate(state.conn)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if cookieData != nil {
|
||||
// Ignoring errors because everything here is newly constructed, so there
|
||||
// shouldn't be marshal errors
|
||||
hrr := &HelloRetryRequestBody{
|
||||
Version: supportedVersion,
|
||||
CipherSuite: connParams.CipherSuite,
|
||||
}
|
||||
hrr.Extensions.Add(&CookieExtension{Cookie: cookieData})
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
helloRetryRequest, err := HandshakeMessageFromBody(hrr)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
params := cipherSuiteMap[connParams.CipherSuite]
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
firstClientHello := &HandshakeMessage{
|
||||
msgType: HandshakeTypeMessageHash,
|
||||
body: h.Sum(nil),
|
||||
}
|
||||
|
||||
nextState := ServerStateStart{
|
||||
Caps: state.Caps,
|
||||
conn: state.conn,
|
||||
cookieSent: true,
|
||||
firstClientHello: firstClientHello,
|
||||
helloRetryRequest: helloRetryRequest,
|
||||
}
|
||||
toSend := []HandshakeAction{SendHandshakeMessage{helloRetryRequest}}
|
||||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]")
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
// If we've got no entropy to make keys from, fail
|
||||
if !connParams.UsingDH && !connParams.UsingPSK {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
var pskSecret []byte
|
||||
var cert *Certificate
|
||||
var certScheme SignatureScheme
|
||||
if connParams.UsingPSK {
|
||||
pskSecret = psk.Key
|
||||
} else {
|
||||
psk = nil
|
||||
|
||||
// If we're not using a PSK mode, then we need to have certain extensions
|
||||
if !gotServerName || !gotSupportedGroups || !gotSignatureAlgorithms {
|
||||
logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v %v %v)",
|
||||
gotServerName, gotSupportedGroups, gotSignatureAlgorithms)
|
||||
return nil, nil, AlertMissingExtension
|
||||
}
|
||||
|
||||
// Select a certificate
|
||||
name := string(*serverName)
|
||||
var err error
|
||||
cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Caps.Certificates)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err)
|
||||
return nil, nil, AlertAccessDenied
|
||||
}
|
||||
}
|
||||
|
||||
if !connParams.UsingDH {
|
||||
dhSecret = nil
|
||||
}
|
||||
|
||||
// Figure out if we're going to do early data
|
||||
var clientEarlyTrafficSecret []byte
|
||||
connParams.ClientSendingEarlyData = gotEarlyData
|
||||
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, gotEarlyData, state.Caps.AllowEarlyData)
|
||||
if connParams.UsingEarlyData {
|
||||
|
||||
h := params.Hash.New()
|
||||
h.Write(clientHello.Marshal())
|
||||
chHash := h.Sum(nil)
|
||||
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
earlySecret := HkdfExtract(params.Hash, zero, pskSecret)
|
||||
clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
|
||||
}
|
||||
|
||||
// Select a next protocol
|
||||
connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Caps.NextProtos)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err)
|
||||
return nil, nil, AlertNoApplicationProtocol
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
|
||||
return ServerStateNegotiated{
|
||||
Caps: state.Caps,
|
||||
Params: connParams,
|
||||
|
||||
dhGroup: dhGroup,
|
||||
dhPublic: dhPublic,
|
||||
dhSecret: dhSecret,
|
||||
pskSecret: pskSecret,
|
||||
selectedPSK: selectedPSK,
|
||||
cert: cert,
|
||||
certScheme: certScheme,
|
||||
clientEarlyTrafficSecret: clientEarlyTrafficSecret,
|
||||
|
||||
firstClientHello: state.firstClientHello,
|
||||
helloRetryRequest: state.helloRetryRequest,
|
||||
clientHello: clientHello,
|
||||
}.Next(nil)
|
||||
}
|
||||
|
||||
type ServerStateNegotiated struct {
|
||||
Caps Capabilities
|
||||
Params ConnectionParameters
|
||||
|
||||
dhGroup NamedGroup
|
||||
dhPublic []byte
|
||||
dhSecret []byte
|
||||
pskSecret []byte
|
||||
clientEarlyTrafficSecret []byte
|
||||
selectedPSK int
|
||||
cert *Certificate
|
||||
certScheme SignatureScheme
|
||||
|
||||
firstClientHello *HandshakeMessage
|
||||
helloRetryRequest *HandshakeMessage
|
||||
clientHello *HandshakeMessage
|
||||
}
|
||||
|
||||
func (state ServerStateNegotiated) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
// Create the ServerHello
|
||||
sh := &ServerHelloBody{
|
||||
Version: supportedVersion,
|
||||
CipherSuite: state.Params.CipherSuite,
|
||||
}
|
||||
_, err := prng.Read(sh.Random[:])
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
if state.Params.UsingDH {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension")
|
||||
err = sh.Extensions.Add(&KeyShareExtension{
|
||||
HandshakeType: HandshakeTypeServerHello,
|
||||
Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}},
|
||||
})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.Params.UsingPSK {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension")
|
||||
err = sh.Extensions.Add(&PreSharedKeyExtension{
|
||||
HandshakeType: HandshakeTypeServerHello,
|
||||
SelectedIdentity: uint16(state.selectedPSK),
|
||||
})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
serverHello, err := HandshakeMessageFromBody(sh)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
// Look up crypto params
|
||||
params, ok := cipherSuiteMap[sh.CipherSuite]
|
||||
if !ok {
|
||||
logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Start up the handshake hash
|
||||
handshakeHash := params.Hash.New()
|
||||
handshakeHash.Write(state.firstClientHello.Marshal())
|
||||
handshakeHash.Write(state.helloRetryRequest.Marshal())
|
||||
handshakeHash.Write(state.clientHello.Marshal())
|
||||
handshakeHash.Write(serverHello.Marshal())
|
||||
|
||||
// Compute handshake secrets
|
||||
zero := bytes.Repeat([]byte{0}, params.Hash.Size())
|
||||
|
||||
var earlySecret []byte
|
||||
if state.Params.UsingPSK {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret)
|
||||
} else {
|
||||
earlySecret = HkdfExtract(params.Hash, zero, zero)
|
||||
}
|
||||
|
||||
if state.dhSecret == nil {
|
||||
state.dhSecret = zero
|
||||
}
|
||||
|
||||
h0 := params.Hash.New().Sum(nil)
|
||||
h2 := handshakeHash.Sum(nil)
|
||||
preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0)
|
||||
handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret)
|
||||
clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2)
|
||||
serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2)
|
||||
preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0)
|
||||
masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero)
|
||||
|
||||
logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret)
|
||||
logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret)
|
||||
logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret)
|
||||
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret)
|
||||
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
|
||||
|
||||
// Send an EncryptedExtensions message (even if it's empty)
|
||||
eeList := ExtensionList{}
|
||||
if state.Params.NextProto != "" {
|
||||
logf(logTypeHandshake, "[server] sending ALPN extension")
|
||||
err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
if state.Params.UsingEarlyData {
|
||||
logf(logTypeHandshake, "[server] sending EDI extension")
|
||||
err = eeList.Add(&EarlyDataExtension{})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
ee := &EncryptedExtensionsBody{eeList}
|
||||
|
||||
// Run the external extension handler.
|
||||
if state.Caps.ExtensionHandler != nil {
|
||||
err := state.Caps.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
}
|
||||
|
||||
eem, err := HandshakeMessageFromBody(ee)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
handshakeHash.Write(eem.Marshal())
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{serverHello},
|
||||
RekeyOut{Label: "handshake", KeySet: serverHandshakeKeys},
|
||||
SendHandshakeMessage{eem},
|
||||
}
|
||||
|
||||
// Authenticate with a certificate if required
|
||||
if !state.Params.UsingPSK {
|
||||
// Send a CertificateRequest message if we want client auth
|
||||
if state.Caps.RequireClientAuth {
|
||||
state.Params.UsingClientAuth = true
|
||||
|
||||
// XXX: We don't support sending any constraints besides a list of
|
||||
// supported signature algorithms
|
||||
cr := &CertificateRequestBody{}
|
||||
schemes := &SignatureAlgorithmsExtension{Algorithms: state.Caps.SignatureSchemes}
|
||||
err := cr.Extensions.Add(schemes)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
crm, err := HandshakeMessageFromBody(cr)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
//TODO state.state.serverCertificateRequest = cr
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{crm})
|
||||
handshakeHash.Write(crm.Marshal())
|
||||
}
|
||||
|
||||
// Create and send Certificate, CertificateVerify
|
||||
certificate := &CertificateBody{
|
||||
CertificateList: make([]CertificateEntry, len(state.cert.Chain)),
|
||||
}
|
||||
for i, entry := range state.cert.Chain {
|
||||
certificate.CertificateList[i] = CertificateEntry{CertData: entry}
|
||||
}
|
||||
certm, err := HandshakeMessageFromBody(certificate)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certm})
|
||||
handshakeHash.Write(certm.Marshal())
|
||||
|
||||
certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme}
|
||||
logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash)
|
||||
|
||||
hcv := handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
err = certificateVerify.Sign(state.cert.PrivateKey, hcv)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
certvm, err := HandshakeMessageFromBody(certificateVerify)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err)
|
||||
return nil, nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{certvm})
|
||||
handshakeHash.Write(certvm.Marshal())
|
||||
}
|
||||
|
||||
// Compute secrets resulting from the server's first flight
|
||||
h3 := handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3)
|
||||
|
||||
serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3)
|
||||
logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData)
|
||||
|
||||
// Assemble the Finished message
|
||||
fin := &FinishedBody{
|
||||
VerifyDataLen: len(serverFinishedData),
|
||||
VerifyData: serverFinishedData,
|
||||
}
|
||||
finm, _ := HandshakeMessageFromBody(fin)
|
||||
|
||||
toSend = append(toSend, SendHandshakeMessage{finm})
|
||||
handshakeHash.Write(finm.Marshal())
|
||||
|
||||
// Compute traffic secrets
|
||||
h4 := handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4)
|
||||
logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4)
|
||||
|
||||
clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4)
|
||||
serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4)
|
||||
logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret)
|
||||
logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret)
|
||||
|
||||
serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret)
|
||||
toSend = append(toSend, RekeyOut{Label: "application", KeySet: serverTrafficKeys})
|
||||
|
||||
exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4)
|
||||
logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret)
|
||||
|
||||
if state.Params.UsingEarlyData {
|
||||
clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]")
|
||||
nextState := ServerStateWaitEOED{
|
||||
AuthCertificate: state.Caps.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyIn{Label: "early", KeySet: clientEarlyTrafficKeys},
|
||||
ReadEarlyData{},
|
||||
}...)
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
|
||||
toSend = append(toSend, []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||
ReadPastEarlyData{},
|
||||
}...)
|
||||
waitFlight2 := ServerStateWaitFlight2{
|
||||
AuthCertificate: state.Caps.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: params,
|
||||
handshakeHash: handshakeHash,
|
||||
masterSecret: masterSecret,
|
||||
clientHandshakeTrafficSecret: clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: clientTrafficSecret,
|
||||
serverTrafficSecret: serverTrafficSecret,
|
||||
exporterSecret: exporterSecret,
|
||||
}
|
||||
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||
toSend = append(toSend, moreToSend...)
|
||||
return nextState, toSend, alert
|
||||
}
|
||||
|
||||
type ServerStateWaitEOED struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitEOED) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData {
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
if len(hm.body) > 0 {
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]")
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]")
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "handshake", KeySet: clientHandshakeKeys},
|
||||
}
|
||||
waitFlight2 := ServerStateWaitFlight2{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
nextState, moreToSend, alert := waitFlight2.Next(nil)
|
||||
toSend = append(toSend, moreToSend...)
|
||||
return nextState, toSend, alert
|
||||
}
|
||||
|
||||
type ServerStateWaitFlight2 struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitFlight2) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
if state.Params.UsingClientAuth {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]")
|
||||
nextState := ServerStateWaitCert{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
handshakeHash: state.handshakeHash,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitCert struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitCert) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificate {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
cert := &CertificateBody{}
|
||||
_, err := cert.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message")
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
if len(cert.CertificateList) == 0 {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate")
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]")
|
||||
nextState := ServerStateWaitCV{
|
||||
AuthCertificate: state.AuthCertificate,
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
clientCertificate: cert,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitCV struct {
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
|
||||
clientCertificate *CertificateBody
|
||||
}
|
||||
|
||||
func (state ServerStateWaitCV) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeCertificateVerify {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm))
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
certVerify := &CertificateVerifyBody{}
|
||||
_, err := certVerify.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Verify client signature over handshake hash
|
||||
hcv := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv)
|
||||
|
||||
clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey
|
||||
if err := certVerify.Verify(clientPublicKey, hcv); err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err)
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
if state.AuthCertificate != nil {
|
||||
err := state.AuthCertificate(state.clientCertificate.CertificateList)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate")
|
||||
return nil, nil, AlertBadCertificate
|
||||
}
|
||||
} else {
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] WARNING: No verification of client certificate")
|
||||
}
|
||||
|
||||
// If it passes, record the certificateVerify in the transcript hash
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]")
|
||||
nextState := ServerStateWaitFinished{
|
||||
Params: state.Params,
|
||||
cryptoParams: state.cryptoParams,
|
||||
masterSecret: state.masterSecret,
|
||||
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
|
||||
handshakeHash: state.handshakeHash,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
return nextState, nil, AlertNoAlert
|
||||
}
|
||||
|
||||
type ServerStateWaitFinished struct {
|
||||
Params ConnectionParameters
|
||||
cryptoParams CipherSuiteParams
|
||||
|
||||
masterSecret []byte
|
||||
clientHandshakeTrafficSecret []byte
|
||||
|
||||
handshakeHash hash.Hash
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state ServerStateWaitFinished) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil || hm.msgType != HandshakeTypeFinished {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()}
|
||||
_, err := fin.Unmarshal(hm.body)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
// Verify client Finished data
|
||||
h5 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5)
|
||||
|
||||
clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5)
|
||||
logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData)
|
||||
|
||||
if !bytes.Equal(fin.VerifyData, clientFinishedData) {
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify")
|
||||
return nil, nil, AlertHandshakeFailure
|
||||
}
|
||||
|
||||
// Compute the resumption secret
|
||||
state.handshakeHash.Write(hm.Marshal())
|
||||
h6 := state.handshakeHash.Sum(nil)
|
||||
logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6)
|
||||
|
||||
resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6)
|
||||
logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret)
|
||||
|
||||
// Compute client traffic keys
|
||||
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
|
||||
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
|
||||
nextState := StateConnected{
|
||||
Params: state.Params,
|
||||
isClient: false,
|
||||
cryptoParams: state.cryptoParams,
|
||||
resumptionSecret: resumptionSecret,
|
||||
clientTrafficSecret: state.clientTrafficSecret,
|
||||
serverTrafficSecret: state.serverTrafficSecret,
|
||||
exporterSecret: state.exporterSecret,
|
||||
}
|
||||
toSend := []HandshakeAction{
|
||||
RekeyIn{Label: "application", KeySet: clientTrafficKeys},
|
||||
}
|
||||
return nextState, toSend, AlertNoAlert
|
||||
}
|
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
230
vendor/github.com/bifurcation/mint/state-machine.go
generated
vendored
Normal file
|
@ -0,0 +1,230 @@
|
|||
package mint
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Marker interface for actions that an implementation should take based on
|
||||
// state transitions.
|
||||
type HandshakeAction interface{}
|
||||
|
||||
type SendHandshakeMessage struct {
|
||||
Message *HandshakeMessage
|
||||
}
|
||||
|
||||
type SendEarlyData struct{}
|
||||
|
||||
type ReadEarlyData struct{}
|
||||
|
||||
type ReadPastEarlyData struct{}
|
||||
|
||||
type RekeyIn struct {
|
||||
Label string
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type RekeyOut struct {
|
||||
Label string
|
||||
KeySet keySet
|
||||
}
|
||||
|
||||
type StorePSK struct {
|
||||
PSK PreSharedKey
|
||||
}
|
||||
|
||||
type HandshakeState interface {
|
||||
Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert)
|
||||
}
|
||||
|
||||
type AppExtensionHandler interface {
|
||||
Send(hs HandshakeType, el *ExtensionList) error
|
||||
Receive(hs HandshakeType, el *ExtensionList) error
|
||||
}
|
||||
|
||||
// Capabilities objects represent the capabilities of a TLS client or server,
|
||||
// as an input to TLS negotiation
|
||||
type Capabilities struct {
|
||||
// For both client and server
|
||||
CipherSuites []CipherSuite
|
||||
Groups []NamedGroup
|
||||
SignatureSchemes []SignatureScheme
|
||||
PSKs PreSharedKeyCache
|
||||
Certificates []*Certificate
|
||||
AuthCertificate func(chain []CertificateEntry) error
|
||||
ExtensionHandler AppExtensionHandler
|
||||
|
||||
// For client
|
||||
PSKModes []PSKKeyExchangeMode
|
||||
|
||||
// For server
|
||||
NextProtos []string
|
||||
AllowEarlyData bool
|
||||
RequireCookie bool
|
||||
CookieHandler CookieHandler
|
||||
RequireClientAuth bool
|
||||
}
|
||||
|
||||
// ConnectionOptions objects represent per-connection settings for a client
|
||||
// initiating a connection
|
||||
type ConnectionOptions struct {
|
||||
ServerName string
|
||||
NextProtos []string
|
||||
EarlyData []byte
|
||||
}
|
||||
|
||||
// ConnectionParameters objects represent the parameters negotiated for a
|
||||
// connection.
|
||||
type ConnectionParameters struct {
|
||||
UsingPSK bool
|
||||
UsingDH bool
|
||||
ClientSendingEarlyData bool
|
||||
UsingEarlyData bool
|
||||
UsingClientAuth bool
|
||||
|
||||
CipherSuite CipherSuite
|
||||
ServerName string
|
||||
NextProto string
|
||||
}
|
||||
|
||||
// StateConnected is symmetric between client and server
|
||||
type StateConnected struct {
|
||||
Params ConnectionParameters
|
||||
isClient bool
|
||||
cryptoParams CipherSuiteParams
|
||||
resumptionSecret []byte
|
||||
clientTrafficSecret []byte
|
||||
serverTrafficSecret []byte
|
||||
exporterSecret []byte
|
||||
}
|
||||
|
||||
func (state *StateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) {
|
||||
var trafficKeys keySet
|
||||
if state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
} else {
|
||||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
kum, err := HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
SendHandshakeMessage{kum},
|
||||
RekeyOut{Label: "update", KeySet: trafficKeys},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state *StateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) {
|
||||
tkt, err := NewSessionTicket(length, lifetime)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime})
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
newPSK := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
Identity: tkt.Ticket,
|
||||
Key: resumptionKey,
|
||||
NextProto: state.Params.NextProto,
|
||||
ReceivedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second),
|
||||
TicketAgeAdd: tkt.TicketAgeAdd,
|
||||
}
|
||||
|
||||
tktm, err := HandshakeMessageFromBody(tkt)
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err)
|
||||
return nil, AlertInternalError
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{
|
||||
StorePSK{newPSK},
|
||||
SendHandshakeMessage{tktm},
|
||||
}
|
||||
return toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
func (state StateConnected) Next(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) {
|
||||
if hm == nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message")
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
bodyGeneric, err := hm.ToBody()
|
||||
if err != nil {
|
||||
logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err)
|
||||
return nil, nil, AlertDecodeError
|
||||
}
|
||||
|
||||
switch body := bodyGeneric.(type) {
|
||||
case *KeyUpdateBody:
|
||||
var trafficKeys keySet
|
||||
if !state.isClient {
|
||||
state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret,
|
||||
labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
|
||||
} else {
|
||||
state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret,
|
||||
labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size())
|
||||
trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret)
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{RekeyIn{Label: "update", KeySet: trafficKeys}}
|
||||
|
||||
// If requested, roll outbound keys and send a KeyUpdate
|
||||
if body.KeyUpdateRequest == KeyUpdateRequested {
|
||||
moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested)
|
||||
if alert != AlertNoAlert {
|
||||
return nil, nil, alert
|
||||
}
|
||||
|
||||
toSend = append(toSend, moreToSend...)
|
||||
}
|
||||
|
||||
return state, toSend, AlertNoAlert
|
||||
|
||||
case *NewSessionTicketBody:
|
||||
// XXX: Allow NewSessionTicket in both directions?
|
||||
if !state.isClient {
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
||||
|
||||
resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret,
|
||||
labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size())
|
||||
|
||||
psk := PreSharedKey{
|
||||
CipherSuite: state.cryptoParams.Suite,
|
||||
IsResumption: true,
|
||||
Identity: body.Ticket,
|
||||
Key: resumptionKey,
|
||||
NextProto: state.Params.NextProto,
|
||||
ReceivedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second),
|
||||
TicketAgeAdd: body.TicketAgeAdd,
|
||||
}
|
||||
|
||||
toSend := []HandshakeAction{StorePSK{psk}}
|
||||
return state, toSend, AlertNoAlert
|
||||
}
|
||||
|
||||
logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType)
|
||||
return nil, nil, AlertUnexpectedMessage
|
||||
}
|
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
243
vendor/github.com/bifurcation/mint/syntax/decode.go
generated
vendored
Normal file
|
@ -0,0 +1,243 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v interface{}) (int, error) {
|
||||
// Check for well-formedness.
|
||||
// Avoids filling out half a data structure
|
||||
// before discovering a JSON syntax error.
|
||||
d := decodeState{}
|
||||
d.Write(data)
|
||||
return d.unmarshal(v)
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type decOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
}
|
||||
|
||||
type decodeState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if s, ok := r.(string); ok {
|
||||
panic(s)
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr || rv.IsNil() {
|
||||
return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
|
||||
}
|
||||
|
||||
read = d.value(rv)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (e *decodeState) value(v reflect.Value) int {
|
||||
return valueDecoder(v)(e, v, decOpts{})
|
||||
}
|
||||
|
||||
type decoderFunc func(e *decodeState, v reflect.Value, opts decOpts) int
|
||||
|
||||
func valueDecoder(v reflect.Value) decoderFunc {
|
||||
return typeDecoder(v.Type().Elem())
|
||||
}
|
||||
|
||||
func typeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||
return newTypeDecoder(t)
|
||||
}
|
||||
|
||||
func newTypeDecoder(t reflect.Type) decoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return uintDecoder
|
||||
case reflect.Array:
|
||||
return newArrayDecoder(t)
|
||||
case reflect.Slice:
|
||||
return newSliceDecoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructDecoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
}
|
||||
|
||||
///// Specific decoders below
|
||||
|
||||
func uintDecoder(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
var uintLen int
|
||||
switch v.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
uintLen = 1
|
||||
case reflect.Uint16:
|
||||
uintLen = 2
|
||||
case reflect.Uint32:
|
||||
uintLen = 4
|
||||
case reflect.Uint64:
|
||||
uintLen = 8
|
||||
}
|
||||
|
||||
buf := make([]byte, uintLen)
|
||||
n, err := d.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if n != uintLen {
|
||||
panic(fmt.Errorf("Insufficient data to read uint"))
|
||||
}
|
||||
|
||||
val := uint64(0)
|
||||
for _, b := range buf {
|
||||
val = (val << 8) + uint64(b)
|
||||
}
|
||||
|
||||
v.Elem().SetUint(val)
|
||||
return uintLen
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type arrayDecoder struct {
|
||||
elemDec decoderFunc
|
||||
}
|
||||
|
||||
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
n := v.Elem().Type().Len()
|
||||
read := 0
|
||||
for i := 0; i < n; i += 1 {
|
||||
read += ad.elemDec(d, v.Elem().Index(i).Addr(), opts)
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newArrayDecoder(t reflect.Type) decoderFunc {
|
||||
dec := &arrayDecoder{typeDecoder(t.Elem())}
|
||||
return dec.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type sliceDecoder struct {
|
||||
elementType reflect.Type
|
||||
elementDec decoderFunc
|
||||
}
|
||||
|
||||
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot decode a slice without a header length"))
|
||||
}
|
||||
|
||||
lengthBytes := make([]byte, opts.head)
|
||||
n, err := d.Read(lengthBytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != opts.head {
|
||||
panic(fmt.Errorf("Not enough data to read header"))
|
||||
}
|
||||
|
||||
length := uint(0)
|
||||
for _, b := range lengthBytes {
|
||||
length = (length << 8) + uint(b)
|
||||
}
|
||||
|
||||
if opts.max > 0 && length > opts.max {
|
||||
panic(fmt.Errorf("Length of vector exceeds declared max"))
|
||||
}
|
||||
if length < opts.min {
|
||||
panic(fmt.Errorf("Length of vector below declared min"))
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err = d.Read(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if uint(n) != length {
|
||||
panic(fmt.Errorf("Available data less than declared length [%04x < %04x]", n, length))
|
||||
}
|
||||
|
||||
elemBuf := &decodeState{}
|
||||
elemBuf.Write(data)
|
||||
elems := []reflect.Value{}
|
||||
read := int(opts.head)
|
||||
for elemBuf.Len() > 0 {
|
||||
elem := reflect.New(sd.elementType)
|
||||
read += sd.elementDec(elemBuf, elem, opts)
|
||||
elems = append(elems, elem)
|
||||
}
|
||||
|
||||
v.Elem().Set(reflect.MakeSlice(v.Elem().Type(), len(elems), len(elems)))
|
||||
for i := 0; i < len(elems); i += 1 {
|
||||
v.Elem().Index(i).Set(elems[i].Elem())
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newSliceDecoder(t reflect.Type) decoderFunc {
|
||||
dec := &sliceDecoder{
|
||||
elementType: t.Elem(),
|
||||
elementDec: typeDecoder(t.Elem()),
|
||||
}
|
||||
return dec.decode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type structDecoder struct {
|
||||
fieldOpts []decOpts
|
||||
fieldDecs []decoderFunc
|
||||
}
|
||||
|
||||
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts decOpts) int {
|
||||
read := 0
|
||||
for i := range sd.fieldDecs {
|
||||
read += sd.fieldDecs[i](d, v.Elem().Field(i).Addr(), sd.fieldOpts[i])
|
||||
}
|
||||
return read
|
||||
}
|
||||
|
||||
func newStructDecoder(t reflect.Type) decoderFunc {
|
||||
n := t.NumField()
|
||||
sd := structDecoder{
|
||||
fieldOpts: make([]decOpts, n),
|
||||
fieldDecs: make([]decoderFunc, n),
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 1 {
|
||||
f := t.Field(i)
|
||||
|
||||
tag := f.Tag.Get("tls")
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
sd.fieldOpts[i] = decOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
}
|
||||
|
||||
sd.fieldDecs[i] = typeDecoder(f.Type)
|
||||
}
|
||||
|
||||
return sd.decode
|
||||
}
|
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
187
vendor/github.com/bifurcation/mint/syntax/encode.go
generated
vendored
Normal file
|
@ -0,0 +1,187 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func Marshal(v interface{}) ([]byte, error) {
|
||||
e := &encodeState{}
|
||||
err := e.marshal(v, encOpts{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.Bytes(), nil
|
||||
}
|
||||
|
||||
// These are the options that can be specified in the struct tag. Right now,
|
||||
// all of them apply to variable-length vectors and nothing else
|
||||
type encOpts struct {
|
||||
head uint // length of length in bytes
|
||||
min uint // minimum size in bytes
|
||||
max uint // maximum size in bytes
|
||||
}
|
||||
|
||||
type encodeState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (e *encodeState) marshal(v interface{}, opts encOpts) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if _, ok := r.(runtime.Error); ok {
|
||||
panic(r)
|
||||
}
|
||||
if s, ok := r.(string); ok {
|
||||
panic(s)
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
e.reflectValue(reflect.ValueOf(v), opts)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
|
||||
valueEncoder(v)(e, v, opts)
|
||||
}
|
||||
|
||||
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
|
||||
|
||||
func valueEncoder(v reflect.Value) encoderFunc {
|
||||
if !v.IsValid() {
|
||||
panic(fmt.Errorf("Cannot encode an invalid value"))
|
||||
}
|
||||
return typeEncoder(v.Type())
|
||||
}
|
||||
|
||||
func typeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Omits the caching / wait-group things that encoding/json uses
|
||||
return newTypeEncoder(t)
|
||||
}
|
||||
|
||||
func newTypeEncoder(t reflect.Type) encoderFunc {
|
||||
// Note: Does not support Marshaler, so don't need the allowAddr argument
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return uintEncoder
|
||||
case reflect.Array:
|
||||
return newArrayEncoder(t)
|
||||
case reflect.Slice:
|
||||
return newSliceEncoder(t)
|
||||
case reflect.Struct:
|
||||
return newStructEncoder(t)
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported type (%s)", t))
|
||||
}
|
||||
}
|
||||
|
||||
///// Specific encoders below
|
||||
|
||||
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
u := v.Uint()
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Uint8:
|
||||
e.WriteByte(byte(u))
|
||||
case reflect.Uint16:
|
||||
e.Write([]byte{byte(u >> 8), byte(u)})
|
||||
case reflect.Uint32:
|
||||
e.Write([]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
case reflect.Uint64:
|
||||
e.Write([]byte{byte(u >> 56), byte(u >> 48), byte(u >> 40), byte(u >> 32),
|
||||
byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
|
||||
}
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type arrayEncoder struct {
|
||||
elemEnc encoderFunc
|
||||
}
|
||||
|
||||
func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
n := v.Len()
|
||||
for i := 0; i < n; i += 1 {
|
||||
ae.elemEnc(e, v.Index(i), opts)
|
||||
}
|
||||
}
|
||||
|
||||
func newArrayEncoder(t reflect.Type) encoderFunc {
|
||||
enc := &arrayEncoder{typeEncoder(t.Elem())}
|
||||
return enc.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type sliceEncoder struct {
|
||||
ae *arrayEncoder
|
||||
}
|
||||
|
||||
func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
if opts.head == 0 {
|
||||
panic(fmt.Errorf("Cannot encode a slice without a header length"))
|
||||
}
|
||||
|
||||
arrayState := &encodeState{}
|
||||
se.ae.encode(arrayState, v, opts)
|
||||
|
||||
n := uint(arrayState.Len())
|
||||
if opts.max > 0 && n > opts.max {
|
||||
panic(fmt.Errorf("Encoded length more than max [%d > %d]", n, opts.max))
|
||||
}
|
||||
if n>>(8*opts.head) > 0 {
|
||||
panic(fmt.Errorf("Encoded length too long for header length [%d, %d]", n, opts.head))
|
||||
}
|
||||
if n < opts.min {
|
||||
panic(fmt.Errorf("Encoded length less than min [%d < %d]", n, opts.min))
|
||||
}
|
||||
|
||||
for i := int(opts.head - 1); i >= 0; i -= 1 {
|
||||
e.WriteByte(byte(n >> (8 * uint(i))))
|
||||
}
|
||||
e.Write(arrayState.Bytes())
|
||||
}
|
||||
|
||||
func newSliceEncoder(t reflect.Type) encoderFunc {
|
||||
enc := &sliceEncoder{&arrayEncoder{typeEncoder(t.Elem())}}
|
||||
return enc.encode
|
||||
}
|
||||
|
||||
//////////
|
||||
|
||||
type structEncoder struct {
|
||||
fieldOpts []encOpts
|
||||
fieldEncs []encoderFunc
|
||||
}
|
||||
|
||||
func (se *structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
|
||||
for i := range se.fieldEncs {
|
||||
se.fieldEncs[i](e, v.Field(i), se.fieldOpts[i])
|
||||
}
|
||||
}
|
||||
|
||||
func newStructEncoder(t reflect.Type) encoderFunc {
|
||||
n := t.NumField()
|
||||
se := structEncoder{
|
||||
fieldOpts: make([]encOpts, n),
|
||||
fieldEncs: make([]encoderFunc, n),
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 1 {
|
||||
f := t.Field(i)
|
||||
tag := f.Tag.Get("tls")
|
||||
tagOpts := parseTag(tag)
|
||||
|
||||
se.fieldOpts[i] = encOpts{
|
||||
head: tagOpts["head"],
|
||||
max: tagOpts["max"],
|
||||
min: tagOpts["min"],
|
||||
}
|
||||
se.fieldEncs[i] = typeEncoder(f.Type)
|
||||
}
|
||||
|
||||
return se.encode
|
||||
}
|
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
30
vendor/github.com/bifurcation/mint/syntax/tags.go
generated
vendored
Normal file
|
@ -0,0 +1,30 @@
|
|||
package syntax
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// `tls:"head=2,min=2,max=255"`
|
||||
|
||||
type tagOptions map[string]uint
|
||||
|
||||
// parseTag parses a struct field's "tls" tag as a comma-separated list of
|
||||
// name=value pairs, where the values MUST be unsigned integers
|
||||
func parseTag(tag string) tagOptions {
|
||||
opts := tagOptions{}
|
||||
for _, token := range strings.Split(tag, ",") {
|
||||
if strings.Index(token, "=") == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(token, "=")
|
||||
if len(parts[0]) == 0 {
|
||||
continue
|
||||
}
|
||||
if val, err := strconv.Atoi(parts[1]); err == nil && val >= 0 {
|
||||
opts[parts[0]] = uint(val)
|
||||
}
|
||||
}
|
||||
return opts
|
||||
}
|
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
168
vendor/github.com/bifurcation/mint/tls.go
generated
vendored
Normal file
|
@ -0,0 +1,168 @@
|
|||
package mint
|
||||
|
||||
// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Server returns a new TLS server side connection
|
||||
// using conn as the underlying transport.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Server(conn net.Conn, config *Config) *Conn {
|
||||
return NewConn(conn, config, false)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection
|
||||
// using conn as the underlying transport.
|
||||
// The config cannot be nil: users must set either ServerName or
|
||||
// InsecureSkipVerify in the config.
|
||||
func Client(conn net.Conn, config *Config) *Conn {
|
||||
return NewConn(conn, config, true)
|
||||
}
|
||||
|
||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
config *Config
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
// The returned connection c is a *tls.Conn.
|
||||
func (l *Listener) Accept() (c net.Conn, err error) {
|
||||
c, err = l.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
server := Server(c, l.config)
|
||||
err = server.Handshake()
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
c = server
|
||||
return
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||
l := new(Listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
return l
|
||||
}
|
||||
|
||||
// Listen creates a TLS listener accepting connections on the
|
||||
// given network address using net.Listen.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||
if config == nil || !config.ValidForServer() {
|
||||
return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
|
||||
}
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config), nil
|
||||
}
|
||||
|
||||
type TimeoutError struct{}
|
||||
|
||||
func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||
func (TimeoutError) Timeout() bool { return true }
|
||||
func (TimeoutError) Temporary() bool { return true }
|
||||
|
||||
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||
// timeout or deadline given in the dialer apply to connection and TLS
|
||||
// handshake as a whole.
|
||||
//
|
||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||
// configuration; see the documentation of Config for the defaults.
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
// We want the Timeout and Deadline values from dialer to cover the
|
||||
// whole process: TCP connection and TLS handshake. This means that we
|
||||
// also need to start our own timers now.
|
||||
timeout := dialer.Timeout
|
||||
|
||||
if !dialer.Deadline.IsZero() {
|
||||
deadlineTimeout := dialer.Deadline.Sub(time.Now())
|
||||
if timeout == 0 || deadlineTimeout < timeout {
|
||||
timeout = deadlineTimeout
|
||||
}
|
||||
}
|
||||
|
||||
var errChannel chan error
|
||||
|
||||
if timeout != 0 {
|
||||
errChannel = make(chan error, 2)
|
||||
time.AfterFunc(timeout, func() {
|
||||
errChannel <- TimeoutError{}
|
||||
})
|
||||
}
|
||||
|
||||
rawConn, err := dialer.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
hostname := addr[:colonPos]
|
||||
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
c := config.Clone()
|
||||
c.ServerName = hostname
|
||||
config = c
|
||||
}
|
||||
|
||||
conn := Client(rawConn, config)
|
||||
|
||||
if timeout == 0 {
|
||||
err = conn.Handshake()
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
} else {
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
|
||||
err = <-errChannel
|
||||
if err == AlertNoAlert {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial connects to the given network address using net.Dial
|
||||
// and then initiates a TLS handshake, returning the resulting
|
||||
// TLS connection.
|
||||
// Dial interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||
}
|
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
32
vendor/github.com/lucas-clemente/quic-go/ackhandler/interfaces.go
generated
vendored
|
@ -1,32 +0,0 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet) error
|
||||
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
|
||||
|
||||
SendingAllowed() bool
|
||||
GetStopWaitingFrame(force bool) *frames.StopWaitingFrame
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
|
||||
SetLowerLimit(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *frames.AckFrame
|
||||
}
|
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/buffer_pool.go
generated
vendored
|
@ -3,7 +3,7 @@ package quic
|
|||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var bufferPool sync.Pool
|
||||
|
|
360
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
360
vendor/github.com/lucas-clemente/quic-go/client.go
generated
vendored
|
@ -10,32 +10,39 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
listenErr error
|
||||
mutex sync.Mutex
|
||||
|
||||
conn connection
|
||||
hostname string
|
||||
|
||||
errorChan chan struct{}
|
||||
handshakeChan <-chan handshakeEvent
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
tls handshake.MintTLS // only used when using TLS
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
session packetHandler
|
||||
}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = utils.GenerateConnectionID
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
)
|
||||
|
||||
|
@ -53,71 +60,6 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
|||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrNonFWSecure establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrNonFWSecure(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func DialNonFWSecure(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (NonFWSession, error) {
|
||||
connID, err := utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
errorChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
err = c.createNewSession(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
return c.session.(NonFWSession), c.establishSecureConnection()
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
|
@ -127,15 +69,39 @@ func Dial(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
|
||||
connID, err := generateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = sess.WaitUntilHandshakeComplete()
|
||||
if err != nil {
|
||||
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientConfig := populateClientConfig(config)
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
connectionID: connID,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: clientConfig,
|
||||
version: clientConfig.Versions[0],
|
||||
versionNegotiationChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
|
@ -153,6 +119,10 @@ func populateClientConfig(config *Config) *Config {
|
|||
if config.HandshakeTimeout != 0 {
|
||||
handshakeTimeout = config.HandshakeTimeout
|
||||
}
|
||||
idleTimeout := protocol.DefaultIdleTimeout
|
||||
if config.IdleTimeout != 0 {
|
||||
idleTimeout = config.IdleTimeout
|
||||
}
|
||||
|
||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||
if maxReceiveStreamFlowControlWindow == 0 {
|
||||
|
@ -166,32 +136,109 @@ func populateClientConfig(config *Config) *Config {
|
|||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
KeepAlive: config.KeepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS()
|
||||
} else {
|
||||
err = c.dialGQUIC()
|
||||
}
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC() error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
return c.establishSecureConnection()
|
||||
}
|
||||
|
||||
func (c *client) dialTLS() error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
// TODO(#523): make these values configurable
|
||||
MaxBidiStreamID: protocol.MaxBidiStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||
MaxUniStreamID: protocol.MaxUniStreamID(protocol.MaxIncomingStreams, protocol.PerspectiveClient),
|
||||
}
|
||||
csc := handshake.NewCryptoStreamConn(nil)
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mintConf.ExtensionHandler = extHandler
|
||||
mintConf.ServerName = c.hostname
|
||||
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Received a Retry packet. Recreating session.")
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
go func() {
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
c.conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
select {
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case <-c.versionNegotiationChan:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.errorChan:
|
||||
return c.listenErr
|
||||
case ev := <-c.handshakeChan:
|
||||
if ev.err != nil {
|
||||
return ev.err
|
||||
}
|
||||
if ev.encLevel != protocol.EncryptionSecure {
|
||||
return fmt.Errorf("Client BUG: Expected encryption level to be secure, was %s", ev.encLevel)
|
||||
}
|
||||
return nil
|
||||
case <-errorChan:
|
||||
return runErr
|
||||
case err := <-c.session.handshakeStatus():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Listen listens
|
||||
// Listen listens on the underlying connection and passes packets on for handling.
|
||||
// It returns when the connection is closed.
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
|
@ -205,13 +252,15 @@ func (c *client) listen() {
|
|||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.session.Close(err)
|
||||
c.mutex.Lock()
|
||||
if c.session != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
data = data[:n]
|
||||
|
||||
c.handlePacket(addr, data)
|
||||
c.handlePacket(addr, data[:n])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -219,10 +268,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(packet)
|
||||
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r, c.version)
|
||||
if err != nil {
|
||||
utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
|
||||
// drop this packet if we can't parse the Public Header
|
||||
// drop this packet if we can't parse the header
|
||||
return
|
||||
}
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||
return
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
|
@ -230,6 +283,11 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.ResetFlag {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
|
@ -238,44 +296,48 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
utils.Infof("Received a spoofed Public Reset. Ignoring.")
|
||||
return
|
||||
}
|
||||
pr, err := parsePublicReset(r)
|
||||
pr, err := wire.ParsePublicReset(r)
|
||||
if err != nil {
|
||||
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
|
||||
utils.Infof("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
return
|
||||
}
|
||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
|
||||
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.RejectedPacketNumber)
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
return
|
||||
}
|
||||
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.versionNegotiated && hdr.VersionFlag {
|
||||
return
|
||||
}
|
||||
// handle Version Negotiation Packets
|
||||
if hdr.IsVersionNegotiation {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet after the client sent a packet with the VersionFlag set
|
||||
// if the server doesn't send a version negotiation packet, it supports the suggested version
|
||||
if !hdr.VersionFlag && !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
}
|
||||
|
||||
if hdr.VersionFlag {
|
||||
// version negotiation packets have no payload
|
||||
if err := c.handlePacketWithVersionFlag(hdr); err != nil {
|
||||
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||
|
||||
c.session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
publicHeader: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
rcvTime: rcvTime,
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == c.version {
|
||||
// the version negotiation packet contains the version that we offered
|
||||
|
@ -285,51 +347,57 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
|
|||
}
|
||||
}
|
||||
|
||||
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if newVersion == protocol.VersionUnsupported {
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
||||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
c.versionNegotiated = true
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
|
||||
|
||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
return c.createNewSession(hdr.SupportedVersions)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
c.session, c.handshakeChan, err = newClientSession(
|
||||
func (c *client) createNewGQUICSession() (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
negotiatedVersions,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
err := c.session.run()
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return
|
||||
}
|
||||
c.listenErr = err
|
||||
close(c.errorChan)
|
||||
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
c.conn.Close()
|
||||
}()
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newTLSClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.config,
|
||||
c.tls,
|
||||
paramsChan,
|
||||
1,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
58
vendor/github.com/lucas-clemente/quic-go/crypto/aesgcm_aead.go
generated
vendored
|
@ -1,58 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/aes12"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
type aeadAESGCM struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
// NewAEADAESGCM creates a AEAD using AES-GCM with 12 bytes tag size
|
||||
//
|
||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
||||
// tag size, and couples the cipher and aes packages closely.
|
||||
// See https://github.com/lucas-clemente/aes12.
|
||||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
||||
}
|
||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadAESGCM{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/crypto/nonce.go
generated
vendored
|
@ -1,14 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
76
vendor/github.com/lucas-clemente/quic-go/crypto/source_address_token.go
generated
vendored
|
@ -1,76 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// StkSource is used to create and verify source address tokens
|
||||
type StkSource interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type stkSource struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
const stkKeySize = 16
|
||||
|
||||
// Chrome currently sets this to 12, but discusses changing it to 16. We start
|
||||
// at 16 :)
|
||||
const stkNonceSize = 16
|
||||
|
||||
// NewStkSource creates a source for source address tokens
|
||||
func NewStkSource() (StkSource, error) {
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := deriveKey(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCMWithNonceSize(c, stkNonceSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &stkSource{aead: aead}, nil
|
||||
}
|
||||
|
||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
||||
nonce := make([]byte, stkNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < stkNonceSize {
|
||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:stkNonceSize]
|
||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
||||
}
|
||||
|
||||
func deriveKey(secret []byte) ([]byte, error) {
|
||||
r := hkdf.New(sha256.New, secret, nil, []byte("QUIC source address token key"))
|
||||
key := make([]byte, stkKeySize)
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
41
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
generated
vendored
Normal file
|
@ -0,0 +1,41 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStreamI interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
io.Writer
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
setReadOffset(protocol.ByteCount)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
*stream
|
||||
}
|
||||
|
||||
var _ cryptoStreamI = &cryptoStream{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStream{str}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPosition = offset
|
||||
}
|
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/client/main.go
generated
vendored
|
@ -7,12 +7,15 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func main() {
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
||||
|
@ -23,8 +26,17 @@ func main() {
|
|||
}
|
||||
utils.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
roundTripper := &h2quic.RoundTripper{
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
defer roundTripper.Close()
|
||||
hclient := &http.Client{
|
||||
Transport: &h2quic.RoundTripper{},
|
||||
Transport: roundTripper,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
|
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
14
vendor/github.com/lucas-clemente/quic-go/example/main.go
generated
vendored
|
@ -17,7 +17,9 @@ import (
|
|||
|
||||
_ "net/http/pprof"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -121,6 +123,7 @@ func main() {
|
|||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
||||
www := flag.String("www", "/var/www", "www data")
|
||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
|
||||
if *verbose {
|
||||
|
@ -130,6 +133,11 @@ func main() {
|
|||
}
|
||||
utils.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
certFile := *certPath + "/fullchain.pem"
|
||||
keyFile := *certPath + "/privkey.pem"
|
||||
|
||||
|
@ -148,7 +156,11 @@ func main() {
|
|||
if *tcp {
|
||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||
} else {
|
||||
err = h2quic.ListenAndServeQUIC(bCap, certFile, keyFile, nil)
|
||||
server := h2quic.Server{
|
||||
Server: &http.Server{Addr: bCap},
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
|
|
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
240
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_control_manager.go
generated
vendored
|
@ -1,240 +0,0 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type flowControlManager struct {
|
||||
connectionParameters handshake.ConnectionParametersManager
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
streamFlowController map[protocol.StreamID]*flowController
|
||||
connFlowController *flowController
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var _ FlowControlManager = &flowControlManager{}
|
||||
|
||||
var errMapAccess = errors.New("Error accessing the flowController map.")
|
||||
|
||||
// NewFlowControlManager creates a new flow control manager
|
||||
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
|
||||
return &flowControlManager{
|
||||
connectionParameters: connectionParameters,
|
||||
rttStats: rttStats,
|
||||
streamFlowController: make(map[protocol.StreamID]*flowController),
|
||||
connFlowController: newFlowController(0, false, connectionParameters, rttStats),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStream creates new flow controllers for a stream
|
||||
// it does nothing if the stream already exists
|
||||
func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesToConnection bool) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if _, ok := f.streamFlowController[streamID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
f.streamFlowController[streamID] = newFlowController(streamID, contributesToConnection, f.connectionParameters, f.rttStats)
|
||||
}
|
||||
|
||||
// RemoveStream removes a closed stream from flow control
|
||||
func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
delete(f.streamFlowController, streamID)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ResetStream should be called when receiving a RstStreamFrame
|
||||
// it updates the byte offset to the value in the RstStreamFrame
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
streamFlowController, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
|
||||
if err != nil {
|
||||
return qerr.StreamDataAfterTermination
|
||||
}
|
||||
|
||||
if streamFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
||||
}
|
||||
|
||||
if streamFlowController.ContributesToConnection() {
|
||||
f.connFlowController.IncrementHighestReceived(increment)
|
||||
if f.connFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highest received byte offset for a stream
|
||||
// it adds the number of additional bytes to connection level flow control
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
streamFlowController, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
|
||||
// this error can be ignored here
|
||||
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
|
||||
|
||||
if streamFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveWindow))
|
||||
}
|
||||
|
||||
if streamFlowController.ContributesToConnection() {
|
||||
f.connFlowController.IncrementHighestReceived(increment)
|
||||
if f.connFlowController.CheckFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
fc, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fc.AddBytesRead(n)
|
||||
if fc.ContributesToConnection() {
|
||||
f.connFlowController.AddBytesRead(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// get WindowUpdates for streams
|
||||
for id, fc := range f.streamFlowController {
|
||||
if necessary, newIncrement, offset := fc.MaybeUpdateWindow(); necessary {
|
||||
res = append(res, WindowUpdate{StreamID: id, Offset: offset})
|
||||
if fc.ContributesToConnection() && newIncrement != 0 {
|
||||
f.connFlowController.EnsureMinimumWindowIncrement(protocol.ByteCount(float64(newIncrement) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
}
|
||||
}
|
||||
// get a WindowUpdate for the connection
|
||||
if necessary, _, offset := f.connFlowController.MaybeUpdateWindow(); necessary {
|
||||
res = append(res, WindowUpdate{StreamID: 0, Offset: offset})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
// StreamID can be 0 when retransmitting
|
||||
if streamID == 0 {
|
||||
return f.connFlowController.receiveWindow, nil
|
||||
}
|
||||
|
||||
flowController, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return flowController.receiveWindow, nil
|
||||
}
|
||||
|
||||
// streamID must not be 0 here
|
||||
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
fc, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fc.AddBytesSent(n)
|
||||
if fc.ContributesToConnection() {
|
||||
f.connFlowController.AddBytesSent(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// must not be called with StreamID 0
|
||||
func (f *flowControlManager) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) {
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
fc, err := f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res := fc.SendWindowSize()
|
||||
|
||||
if fc.ContributesToConnection() {
|
||||
res = utils.MinByteCount(res, f.connFlowController.SendWindowSize())
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (f *flowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
return f.connFlowController.SendWindowSize()
|
||||
}
|
||||
|
||||
// streamID may be 0 here
|
||||
func (f *flowControlManager) UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
var fc *flowController
|
||||
if streamID == 0 {
|
||||
fc = f.connFlowController
|
||||
} else {
|
||||
var err error
|
||||
fc, err = f.getFlowController(streamID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return fc.UpdateSendWindow(offset), nil
|
||||
}
|
||||
|
||||
func (f *flowControlManager) getFlowController(streamID protocol.StreamID) (*flowController, error) {
|
||||
streamFlowController, ok := f.streamFlowController[streamID]
|
||||
if !ok {
|
||||
return nil, errMapAccess
|
||||
}
|
||||
return streamFlowController, nil
|
||||
}
|
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
198
vendor/github.com/lucas-clemente/quic-go/flowcontrol/flow_controller.go
generated
vendored
|
@ -1,198 +0,0 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
type flowController struct {
|
||||
streamID protocol.StreamID
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
|
||||
connectionParameters handshake.ConnectionParametersManager
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
|
||||
lastWindowUpdateTime time.Time
|
||||
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowIncrement protocol.ByteCount
|
||||
maxReceiveWindowIncrement protocol.ByteCount
|
||||
}
|
||||
|
||||
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
|
||||
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
|
||||
|
||||
// newFlowController gets a new flow controller
|
||||
func newFlowController(streamID protocol.StreamID, contributesToConnection bool, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
|
||||
fc := flowController{
|
||||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connectionParameters: connectionParameters,
|
||||
rttStats: rttStats,
|
||||
}
|
||||
|
||||
if streamID == 0 {
|
||||
fc.receiveWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
|
||||
fc.receiveWindowIncrement = fc.receiveWindow
|
||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
|
||||
} else {
|
||||
fc.receiveWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
|
||||
fc.receiveWindowIncrement = fc.receiveWindow
|
||||
fc.maxReceiveWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
|
||||
}
|
||||
|
||||
return &fc
|
||||
}
|
||||
|
||||
func (c *flowController) ContributesToConnection() bool {
|
||||
return c.contributesToConnection
|
||||
}
|
||||
|
||||
func (c *flowController) getSendWindow() protocol.ByteCount {
|
||||
if c.sendWindow == 0 {
|
||||
if c.streamID == 0 {
|
||||
return c.connectionParameters.GetSendConnectionFlowControlWindow()
|
||||
}
|
||||
return c.connectionParameters.GetSendStreamFlowControlWindow()
|
||||
}
|
||||
return c.sendWindow
|
||||
}
|
||||
|
||||
func (c *flowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||
// it returns true if the window was actually updated
|
||||
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
|
||||
if newOffset > c.sendWindow {
|
||||
c.sendWindow = newOffset
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *flowController) SendWindowSize() protocol.ByteCount {
|
||||
sendWindow := c.getSendWindow()
|
||||
|
||||
if c.bytesSent > sendWindow { // should never happen, but make sure we don't do an underflow here
|
||||
return 0
|
||||
}
|
||||
return sendWindow - c.bytesSent
|
||||
}
|
||||
|
||||
func (c *flowController) SendWindowOffset() protocol.ByteCount {
|
||||
return c.getSendWindow()
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
||||
// Should **only** be used for the stream-level FlowController
|
||||
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
||||
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
|
||||
// It should only be treated as an error when resetting a stream
|
||||
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
|
||||
if byteOffset == c.highestReceived {
|
||||
return 0, nil
|
||||
}
|
||||
if byteOffset > c.highestReceived {
|
||||
increment := byteOffset - c.highestReceived
|
||||
c.highestReceived = byteOffset
|
||||
return increment, nil
|
||||
}
|
||||
return 0, ErrReceivedSmallerByteOffset
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
// Should **only** be used for the connection-level FlowController
|
||||
func (c *flowController) IncrementHighestReceived(increment protocol.ByteCount) {
|
||||
c.highestReceived += increment
|
||||
}
|
||||
|
||||
func (c *flowController) AddBytesRead(n protocol.ByteCount) {
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window increment already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
// MaybeUpdateWindow updates the receive window, if necessary
|
||||
// if the receive window increment is changed, the new value is returned, otherwise a 0
|
||||
// the last return value is the new offset of the receive window
|
||||
func (c *flowController) MaybeUpdateWindow() (bool, protocol.ByteCount /* new increment */, protocol.ByteCount /* new offset */) {
|
||||
diff := c.receiveWindow - c.bytesRead
|
||||
|
||||
// Chromium implements the same threshold
|
||||
if diff < (c.receiveWindowIncrement / 2) {
|
||||
var newWindowIncrement protocol.ByteCount
|
||||
oldWindowIncrement := c.receiveWindowIncrement
|
||||
|
||||
c.maybeAdjustWindowIncrement()
|
||||
if c.receiveWindowIncrement != oldWindowIncrement {
|
||||
newWindowIncrement = c.receiveWindowIncrement
|
||||
}
|
||||
|
||||
c.lastWindowUpdateTime = time.Now()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowIncrement
|
||||
return true, newWindowIncrement, c.receiveWindow
|
||||
}
|
||||
|
||||
return false, 0, 0
|
||||
}
|
||||
|
||||
// maybeAdjustWindowIncrement increases the receiveWindowIncrement if we're sending WindowUpdates too often
|
||||
func (c *flowController) maybeAdjustWindowIncrement() {
|
||||
if c.lastWindowUpdateTime.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
timeSinceLastWindowUpdate := time.Since(c.lastWindowUpdateTime)
|
||||
|
||||
// interval between the window updates is sufficiently large, no need to increase the increment
|
||||
if timeSinceLastWindowUpdate >= 2*rtt {
|
||||
return
|
||||
}
|
||||
|
||||
oldWindowSize := c.receiveWindowIncrement
|
||||
c.receiveWindowIncrement = utils.MinByteCount(2*c.receiveWindowIncrement, c.maxReceiveWindowIncrement)
|
||||
|
||||
// debug log, if the window size was actually increased
|
||||
if oldWindowSize < c.receiveWindowIncrement {
|
||||
newWindowSize := c.receiveWindowIncrement / (1 << 10)
|
||||
if c.streamID == 0 {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
|
||||
} else {
|
||||
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowIncrement sets a minimum window increment
|
||||
// it is intended be used for the connection-level flow controller
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *flowController) EnsureMinimumWindowIncrement(inc protocol.ByteCount) {
|
||||
if inc > c.receiveWindowIncrement {
|
||||
c.receiveWindowIncrement = utils.MinByteCount(inc, c.maxReceiveWindowIncrement)
|
||||
c.lastWindowUpdateTime = time.Time{} // disables autotuning for the next window update
|
||||
}
|
||||
}
|
||||
|
||||
func (c *flowController) CheckFlowControlViolation() bool {
|
||||
return c.highestReceived > c.receiveWindow
|
||||
}
|
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
26
vendor/github.com/lucas-clemente/quic-go/flowcontrol/interface.go
generated
vendored
|
@ -1,26 +0,0 @@
|
|||
package flowcontrol
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// WindowUpdate provides the data for WindowUpdateFrames.
|
||||
type WindowUpdate struct {
|
||||
StreamID protocol.StreamID
|
||||
Offset protocol.ByteCount
|
||||
}
|
||||
|
||||
// A FlowControlManager manages the flow control
|
||||
type FlowControlManager interface {
|
||||
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
|
||||
RemoveStream(streamID protocol.StreamID)
|
||||
// methods needed for receiving data
|
||||
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
||||
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
|
||||
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
|
||||
GetWindowUpdates() []WindowUpdate
|
||||
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
|
||||
// methods needed for sending data
|
||||
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
|
||||
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
|
||||
RemainingConnectionWindowSize() protocol.ByteCount
|
||||
UpdateWindow(streamID protocol.StreamID, offset protocol.ByteCount) (bool, error)
|
||||
}
|
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/frames/ack_range.go
generated
vendored
|
@ -1,9 +0,0 @@
|
|||
package frames
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// AckRange is an ACK range
|
||||
type AckRange struct {
|
||||
FirstPacketNumber protocol.PacketNumber
|
||||
LastPacketNumber protocol.PacketNumber
|
||||
}
|
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
44
vendor/github.com/lucas-clemente/quic-go/frames/blocked_frame.go
generated
vendored
|
@ -1,44 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A BlockedFrame in QUIC
|
||||
type BlockedFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
}
|
||||
|
||||
//Write writes a BlockedFrame frame
|
||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x05)
|
||||
utils.WriteUint32(b, uint32(f.StreamID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *BlockedFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4, nil
|
||||
}
|
||||
|
||||
// ParseBlockedFrame parses a BLOCKED frame
|
||||
func ParseBlockedFrame(r *bytes.Reader) (*BlockedFrame, error) {
|
||||
frame := &BlockedFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sid, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.StreamID = protocol.StreamID(sid)
|
||||
|
||||
return frame, nil
|
||||
}
|
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/frames/connection_close_frame.go
generated
vendored
|
@ -1,73 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// A ConnectionCloseFrame in QUIC
|
||||
type ConnectionCloseFrame struct {
|
||||
ErrorCode qerr.ErrorCode
|
||||
ReasonPhrase string
|
||||
}
|
||||
|
||||
// ParseConnectionCloseFrame reads a CONNECTION_CLOSE frame
|
||||
func ParseConnectionCloseFrame(r *bytes.Reader) (*ConnectionCloseFrame, error) {
|
||||
frame := &ConnectionCloseFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errorCode, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ErrorCode = qerr.ErrorCode(errorCode)
|
||||
|
||||
reasonPhraseLen, err := utils.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if reasonPhraseLen > uint16(protocol.MaxPacketSize) {
|
||||
return nil, qerr.Error(qerr.InvalidConnectionCloseData, "reason phrase too long")
|
||||
}
|
||||
|
||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ReasonPhrase = string(reasonPhrase)
|
||||
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *ConnectionCloseFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase)), nil
|
||||
}
|
||||
|
||||
// Write writes an CONNECTION_CLOSE frame.
|
||||
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x02)
|
||||
utils.WriteUint32(b, uint32(f.ErrorCode))
|
||||
|
||||
if len(f.ReasonPhrase) > math.MaxUint16 {
|
||||
return errors.New("ConnectionFrame: ReasonPhrase too long")
|
||||
}
|
||||
|
||||
reasonPhraseLen := uint16(len(f.ReasonPhrase))
|
||||
utils.WriteUint16(b, reasonPhraseLen)
|
||||
b.WriteString(f.ReasonPhrase)
|
||||
|
||||
return nil
|
||||
}
|
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/frames/frame.go
generated
vendored
|
@ -1,13 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A Frame in QUIC
|
||||
type Frame interface {
|
||||
Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
||||
MinLength(version protocol.VersionNumber) (protocol.ByteCount, error)
|
||||
}
|
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/frames/log.go
generated
vendored
|
@ -1,28 +0,0 @@
|
|||
package frames
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
// LogFrame logs a frame, either sent or received
|
||||
func LogFrame(frame Frame, sent bool) {
|
||||
if !utils.Debug() {
|
||||
return
|
||||
}
|
||||
dir := "<-"
|
||||
if sent {
|
||||
dir = "->"
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *StreamFrame:
|
||||
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
||||
case *StopWaitingFrame:
|
||||
if sent {
|
||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
|
||||
} else {
|
||||
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
|
||||
}
|
||||
case *AckFrame:
|
||||
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
|
||||
default:
|
||||
utils.Debugf("\t%s %#v", dir, frame)
|
||||
}
|
||||
}
|
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
59
vendor/github.com/lucas-clemente/quic-go/frames/rst_stream_frame.go
generated
vendored
|
@ -1,59 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A RstStreamFrame in QUIC
|
||||
type RstStreamFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
ErrorCode uint32
|
||||
ByteOffset protocol.ByteCount
|
||||
}
|
||||
|
||||
//Write writes a RST_STREAM frame
|
||||
func (f *RstStreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x01)
|
||||
utils.WriteUint32(b, uint32(f.StreamID))
|
||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
||||
utils.WriteUint32(b, f.ErrorCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *RstStreamFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4 + 8 + 4, nil
|
||||
}
|
||||
|
||||
// ParseRstStreamFrame parses a RST_STREAM frame
|
||||
func ParseRstStreamFrame(r *bytes.Reader) (*RstStreamFrame, error) {
|
||||
frame := &RstStreamFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sid, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.StreamID = protocol.StreamID(sid)
|
||||
|
||||
byteOffset, err := utils.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
||||
|
||||
frame.ErrorCode, err = utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return frame, nil
|
||||
}
|
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
54
vendor/github.com/lucas-clemente/quic-go/frames/window_update_frame.go
generated
vendored
|
@ -1,54 +0,0 @@
|
|||
package frames
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// A WindowUpdateFrame in QUIC
|
||||
type WindowUpdateFrame struct {
|
||||
StreamID protocol.StreamID
|
||||
ByteOffset protocol.ByteCount
|
||||
}
|
||||
|
||||
//Write writes a RST_STREAM frame
|
||||
func (f *WindowUpdateFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
typeByte := uint8(0x04)
|
||||
b.WriteByte(typeByte)
|
||||
|
||||
utils.WriteUint32(b, uint32(f.StreamID))
|
||||
utils.WriteUint64(b, uint64(f.ByteOffset))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MinLength of a written frame
|
||||
func (f *WindowUpdateFrame) MinLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
return 1 + 4 + 8, nil
|
||||
}
|
||||
|
||||
// ParseWindowUpdateFrame parses a RST_STREAM frame
|
||||
func ParseWindowUpdateFrame(r *bytes.Reader) (*WindowUpdateFrame, error) {
|
||||
frame := &WindowUpdateFrame{}
|
||||
|
||||
// read the TypeByte
|
||||
_, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sid, err := utils.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.StreamID = protocol.StreamID(sid)
|
||||
|
||||
byteOffset, err := utils.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ByteOffset = protocol.ByteCount(byteOffset)
|
||||
|
||||
return frame, nil
|
||||
}
|
109
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
109
vendor/github.com/lucas-clemente/quic-go/h2quic/client.go
generated
vendored
|
@ -15,8 +15,8 @@ import (
|
|||
"golang.org/x/net/idna"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
|
@ -34,10 +34,10 @@ type client struct {
|
|||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
hostname string
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
hostname string
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
|
@ -51,8 +51,8 @@ type client struct {
|
|||
var _ http.RoundTripper = &client{}
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
RequestConnectionIDTruncation: true,
|
||||
KeepAlive: true,
|
||||
RequestConnectionIDOmission: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
// newClient creates a new client
|
||||
|
@ -61,26 +61,31 @@ func newClient(
|
|||
tlsConfig *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||
) *client {
|
||||
config := defaultQuicConfig
|
||||
if quicConfig != nil {
|
||||
config = quicConfig
|
||||
}
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
encryptionLevel: protocol.EncryptionUnencrypted,
|
||||
tlsConf: tlsConfig,
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
tlsConf: tlsConfig,
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
// dial dials the connection
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
if c.dialer != nil {
|
||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -90,9 +95,6 @@ func (c *client) dial() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.headerStream.StreamID() != 3 {
|
||||
return errors.New("h2quic Client BUG: StreamID of Header Stream is not 3")
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
|
@ -102,45 +104,44 @@ func (c *client) handleHeaderStream() {
|
|||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||
|
||||
var lastStream protocol.StreamID
|
||||
var err error
|
||||
for err == nil {
|
||||
err = c.readResponse(h2framer, decoder)
|
||||
}
|
||||
utils.Debugf("Error handling header stream: %s", err)
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||
// stop all running request
|
||||
close(c.headerErrored)
|
||||
}
|
||||
|
||||
for {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
break
|
||||
}
|
||||
lastStream = protocol.StreamID(frame.Header().StreamID)
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
|
||||
break
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
|
||||
break
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
|
||||
break
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
|
||||
}
|
||||
responseChan <- rsp
|
||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
return errors.New("not a headers frame")
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||
}
|
||||
|
||||
// stop all running request
|
||||
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
|
||||
close(c.headerErrored)
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseChan <- rsp
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roundtrip executes a request and returns a response
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/h2quic/request_writer.go
generated
vendored
|
@ -13,8 +13,8 @@ import (
|
|||
"golang.org/x/net/lex/httplex"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
type requestWriter struct {
|
||||
|
|
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
4
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer.go
generated
vendored
|
@ -8,8 +8,8 @@ import (
|
|||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
@ -83,7 +83,7 @@ func (w *responseWriter) Write(p []byte) (int, error) {
|
|||
|
||||
func (w *responseWriter) Flush() {}
|
||||
|
||||
// TODO: Implement a functional CloseNotify method.
|
||||
// This is a NOP. Use http.Request.Context
|
||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||
|
||||
// test that we implement http.Flusher
|
||||
|
|
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
13
vendor/github.com/lucas-clemente/quic-go/h2quic/roundtrip.go
generated
vendored
|
@ -41,6 +41,11 @@ type RoundTripper struct {
|
|||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddr will be used.
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
|
@ -120,7 +125,13 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
|||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
|
||||
client = newClient(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||
r.QuicConfig,
|
||||
r.Dial,
|
||||
)
|
||||
r.clients[hostname] = client
|
||||
}
|
||||
return client, nil
|
||||
|
|
87
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
87
vendor/github.com/lucas-clemente/quic-go/h2quic/server.go
generated
vendored
|
@ -7,14 +7,14 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
@ -50,6 +50,7 @@ type Server struct {
|
|||
|
||||
listenerMutex sync.Mutex
|
||||
listener quic.Listener
|
||||
closed bool
|
||||
|
||||
supportedVersionsAsString string
|
||||
}
|
||||
|
@ -88,6 +89,10 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("Server is already closed")
|
||||
}
|
||||
if s.listener != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
|
@ -122,29 +127,23 @@ func (s *Server) handleHeaderStream(session streamCreator) {
|
|||
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
|
||||
return
|
||||
}
|
||||
if stream.StreamID() != 3 {
|
||||
session.Close(qerr.Error(qerr.InternalError, "h2quic server BUG: header stream does not have stream ID 3"))
|
||||
return
|
||||
}
|
||||
|
||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||
h2framer := http2.NewFramer(nil, stream)
|
||||
|
||||
go func() {
|
||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||
for {
|
||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
if _, ok := err.(*qerr.QuicError); !ok {
|
||||
utils.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.Close(err)
|
||||
return
|
||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||
for {
|
||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
if _, ok := err.(*qerr.QuicError); !ok {
|
||||
utils.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.Close(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||
|
@ -170,8 +169,6 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
return err
|
||||
}
|
||||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
if utils.Debug() {
|
||||
utils.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||
} else {
|
||||
|
@ -187,19 +184,25 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
return nil
|
||||
}
|
||||
|
||||
var streamEnded bool
|
||||
if h2headersFrame.StreamEnded() {
|
||||
dataStream.(remoteCloser).CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
reqBody := newRequestBody(dataStream)
|
||||
req.Body = reqBody
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
// handleRequest should be as non-blocking as possible to minimize
|
||||
// head-of-line blocking. Potentially blocking code is run in a separate
|
||||
// goroutine, enabling handleRequest to return before the code is executed.
|
||||
go func() {
|
||||
streamEnded := h2headersFrame.StreamEnded()
|
||||
if streamEnded {
|
||||
dataStream.(remoteCloser).CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
req = req.WithContext(dataStream.Context())
|
||||
reqBody := newRequestBody(dataStream)
|
||||
req.Body = reqBody
|
||||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
|
||||
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
|
@ -225,7 +228,8 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
}
|
||||
if responseWriter.dataStream != nil {
|
||||
if !streamEnded && !reqBody.requestRead {
|
||||
responseWriter.dataStream.Reset(nil)
|
||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||
responseWriter.dataStream.CancelRead(0)
|
||||
}
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
|
@ -243,6 +247,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream,
|
|||
func (s *Server) Close() error {
|
||||
s.listenerMutex.Lock()
|
||||
defer s.listenerMutex.Unlock()
|
||||
s.closed = true
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
|
@ -279,12 +284,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
|||
}
|
||||
|
||||
if s.supportedVersionsAsString == "" {
|
||||
for i, v := range protocol.SupportedVersions {
|
||||
s.supportedVersionsAsString += strconv.Itoa(int(v))
|
||||
if i != len(protocol.SupportedVersions)-1 {
|
||||
s.supportedVersionsAsString += ","
|
||||
}
|
||||
var versions []string
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
versions = append(versions, v.ToAltSvc())
|
||||
}
|
||||
s.supportedVersionsAsString = strings.Join(versions, ",")
|
||||
}
|
||||
|
||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
||||
|
@ -344,6 +348,9 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||
}
|
||||
defer tcpConn.Close()
|
||||
|
||||
tlsConn := tls.NewListener(tcpConn, config)
|
||||
defer tlsConn.Close()
|
||||
|
||||
// Start the servers
|
||||
httpServer := &http.Server{
|
||||
Addr: addr,
|
||||
|
@ -365,7 +372,7 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||
hErr := make(chan error)
|
||||
qErr := make(chan error)
|
||||
go func() {
|
||||
hErr <- httpServer.Serve(tcpConn)
|
||||
hErr <- httpServer.Serve(tlsConn)
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
|
|
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
265
vendor/github.com/lucas-clemente/quic-go/handshake/connection_parameters_manager.go
generated
vendored
|
@ -1,265 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// ConnectionParametersManager negotiates and stores the connection parameters
|
||||
// A ConnectionParametersManager can be used for a server as well as a client
|
||||
// For the server:
|
||||
// 1. call SetFromMap with the values received in the CHLO. This sets the corresponding values here, subject to negotiation
|
||||
// 2. call GetHelloMap to get the values to send in the SHLO
|
||||
// For the client:
|
||||
// 1. call GetHelloMap to get the values to send in a CHLO
|
||||
// 2. call SetFromMap with the values received in the SHLO
|
||||
type ConnectionParametersManager interface {
|
||||
SetFromMap(map[Tag][]byte) error
|
||||
GetHelloMap() (map[Tag][]byte, error)
|
||||
|
||||
GetSendStreamFlowControlWindow() protocol.ByteCount
|
||||
GetSendConnectionFlowControlWindow() protocol.ByteCount
|
||||
GetReceiveStreamFlowControlWindow() protocol.ByteCount
|
||||
GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount
|
||||
GetReceiveConnectionFlowControlWindow() protocol.ByteCount
|
||||
GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount
|
||||
GetMaxOutgoingStreams() uint32
|
||||
GetMaxIncomingStreams() uint32
|
||||
GetIdleConnectionStateLifetime() time.Duration
|
||||
TruncateConnectionID() bool
|
||||
}
|
||||
|
||||
type connectionParametersManager struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
version protocol.VersionNumber
|
||||
perspective protocol.Perspective
|
||||
|
||||
flowControlNegotiated bool
|
||||
|
||||
truncateConnectionID bool
|
||||
maxStreamsPerConnection uint32
|
||||
maxIncomingDynamicStreamsPerConnection uint32
|
||||
idleConnectionStateLifetime time.Duration
|
||||
sendStreamFlowControlWindow protocol.ByteCount
|
||||
sendConnectionFlowControlWindow protocol.ByteCount
|
||||
receiveStreamFlowControlWindow protocol.ByteCount
|
||||
receiveConnectionFlowControlWindow protocol.ByteCount
|
||||
maxReceiveStreamFlowControlWindow protocol.ByteCount
|
||||
maxReceiveConnectionFlowControlWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
var _ ConnectionParametersManager = &connectionParametersManager{}
|
||||
|
||||
// ErrMalformedTag is returned when the tag value cannot be read
|
||||
var (
|
||||
ErrMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
||||
ErrFlowControlRenegotiationNotSupported = qerr.Error(qerr.InvalidCryptoMessageParameter, "renegotiation of flow control parameters not supported")
|
||||
)
|
||||
|
||||
// NewConnectionParamatersManager creates a new connection parameters manager
|
||||
func NewConnectionParamatersManager(
|
||||
pers protocol.Perspective, v protocol.VersionNumber,
|
||||
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
|
||||
) ConnectionParametersManager {
|
||||
h := &connectionParametersManager{
|
||||
perspective: pers,
|
||||
version: v,
|
||||
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
|
||||
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
|
||||
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
}
|
||||
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
h.idleConnectionStateLifetime = protocol.DefaultIdleTimeout
|
||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective
|
||||
} else {
|
||||
h.idleConnectionStateLifetime = protocol.MaxIdleTimeoutClient
|
||||
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
|
||||
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the server's perspective
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// SetFromMap reads all params
|
||||
func (h *connectionParametersManager) SetFromMap(params map[Tag][]byte) error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if value, ok := params[TagTCID]; ok && h.perspective == protocol.PerspectiveServer {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.truncateConnectionID = (clientValue == 0)
|
||||
}
|
||||
if value, ok := params[TagMSPC]; ok {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue)
|
||||
}
|
||||
if value, ok := params[TagMIDS]; ok {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.maxIncomingDynamicStreamsPerConnection = h.negotiateMaxIncomingDynamicStreamsPerConnection(clientValue)
|
||||
}
|
||||
if value, ok := params[TagICSL]; ok {
|
||||
clientValue, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second)
|
||||
}
|
||||
if value, ok := params[TagSFCW]; ok {
|
||||
if h.flowControlNegotiated {
|
||||
return ErrFlowControlRenegotiationNotSupported
|
||||
}
|
||||
sendStreamFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.sendStreamFlowControlWindow = protocol.ByteCount(sendStreamFlowControlWindow)
|
||||
}
|
||||
if value, ok := params[TagCFCW]; ok {
|
||||
if h.flowControlNegotiated {
|
||||
return ErrFlowControlRenegotiationNotSupported
|
||||
}
|
||||
sendConnectionFlowControlWindow, err := utils.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return ErrMalformedTag
|
||||
}
|
||||
h.sendConnectionFlowControlWindow = protocol.ByteCount(sendConnectionFlowControlWindow)
|
||||
}
|
||||
|
||||
_, containsSFCW := params[TagSFCW]
|
||||
_, containsCFCW := params[TagCFCW]
|
||||
if containsCFCW || containsSFCW {
|
||||
h.flowControlNegotiated = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 {
|
||||
return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection)
|
||||
}
|
||||
|
||||
func (h *connectionParametersManager) negotiateMaxIncomingDynamicStreamsPerConnection(clientValue uint32) uint32 {
|
||||
return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection)
|
||||
}
|
||||
|
||||
func (h *connectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration {
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutServer)
|
||||
}
|
||||
return utils.MinDuration(clientValue, protocol.MaxIdleTimeoutClient)
|
||||
}
|
||||
|
||||
// GetHelloMap gets all parameters needed for the Hello message
|
||||
func (h *connectionParametersManager) GetHelloMap() (map[Tag][]byte, error) {
|
||||
sfcw := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow()))
|
||||
cfcw := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow()))
|
||||
mspc := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(mspc, h.maxStreamsPerConnection)
|
||||
mids := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection)
|
||||
icsl := bytes.NewBuffer([]byte{})
|
||||
utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second))
|
||||
|
||||
return map[Tag][]byte{
|
||||
TagICSL: icsl.Bytes(),
|
||||
TagMSPC: mspc.Bytes(),
|
||||
TagMIDS: mids.Bytes(),
|
||||
TagCFCW: cfcw.Bytes(),
|
||||
TagSFCW: sfcw.Bytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.sendStreamFlowControlWindow
|
||||
}
|
||||
|
||||
// GetSendConnectionFlowControlWindow gets the size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.sendConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// GetReceiveStreamFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
||||
func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.receiveStreamFlowControlWindow
|
||||
}
|
||||
|
||||
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
|
||||
return h.maxReceiveStreamFlowControlWindow
|
||||
}
|
||||
|
||||
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
|
||||
func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.receiveConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
|
||||
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
|
||||
return h.maxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection
|
||||
func (h *connectionParametersManager) GetMaxOutgoingStreams() uint32 {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
return h.maxIncomingDynamicStreamsPerConnection
|
||||
}
|
||||
|
||||
// GetMaxIncomingStreams get the maximum number of incoming streams per connection
|
||||
func (h *connectionParametersManager) GetMaxIncomingStreams() uint32 {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
maxStreams := protocol.MaxIncomingDynamicStreamsPerConnection
|
||||
return utils.MaxUint32(uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement, uint32(float64(maxStreams)*protocol.MaxStreamsMultiplier))
|
||||
}
|
||||
|
||||
// GetIdleConnectionStateLifetime gets the idle timeout
|
||||
func (h *connectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.idleConnectionStateLifetime
|
||||
}
|
||||
|
||||
// TruncateConnectionID determines if the client requests truncated ConnectionIDs
|
||||
func (h *connectionParametersManager) TruncateConnectionID() bool {
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
return false
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.truncateConnectionID
|
||||
}
|
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/handshake/interface.go
generated
vendored
|
@ -1,24 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
// Sealer seals a packet
|
||||
type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
|
||||
// CryptoSetup is a crypto setup
|
||||
type CryptoSetup interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
HandleCryptoStream() error
|
||||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
}
|
||||
|
||||
// TransportParameters are parameters sent to the peer during the handshake
|
||||
type TransportParameters struct {
|
||||
RequestConnectionIDTruncation bool
|
||||
}
|
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
100
vendor/github.com/lucas-clemente/quic-go/handshake/stk_generator.go
generated
vendored
|
@ -1,100 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
stkPrefixIP byte = iota
|
||||
stkPrefixString
|
||||
)
|
||||
|
||||
// An STK is a source address token
|
||||
type STK struct {
|
||||
RemoteAddr string
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
Data []byte
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// An STKGenerator generates STKs
|
||||
type STKGenerator struct {
|
||||
stkSource crypto.StkSource
|
||||
}
|
||||
|
||||
// NewSTKGenerator initializes a new STKGenerator
|
||||
func NewSTKGenerator() (*STKGenerator, error) {
|
||||
stkSource, err := crypto.NewStkSource()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &STKGenerator{
|
||||
stkSource: stkSource,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewToken generates a new STK token for a given source address
|
||||
func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
Data: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.stkSource.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes an STK token
|
||||
func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) {
|
||||
// if the client didn't send any STK, DecodeToken will be called with a nil-slice
|
||||
if len(encrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.stkSource.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &token{}
|
||||
rest, err := asn1.Unmarshal(data, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
return &STK{
|
||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
||||
SentTime: time.Unix(t.Timestamp, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the STK
|
||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
return append([]byte{stkPrefixIP}, udpAddr.IP...)
|
||||
}
|
||||
return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
||||
|
||||
// decodeRemoteAddr decodes the remote address saved in the STK
|
||||
func decodeRemoteAddr(data []byte) string {
|
||||
// data will never be empty for an STK that we generated. Check it to be on the safe side
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
if data[0] == stkPrefixIP {
|
||||
return net.IP(data[1:]).String()
|
||||
}
|
||||
return string(data[1:])
|
||||
}
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/chrome/chrome.go
generated
vendored
|
@ -1 +0,0 @@
|
|||
package chrome
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/gquic/gquic.go
generated
vendored
|
@ -1 +0,0 @@
|
|||
package gquic
|
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
1
vendor/github.com/lucas-clemente/quic-go/integrationtests/self/self.go
generated
vendored
|
@ -1 +0,0 @@
|
|||
package self
|
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy.go
generated
vendored
|
@ -1,14 +1,12 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Connection is a UDP connection
|
||||
|
@ -28,21 +26,43 @@ const (
|
|||
DirectionIncoming Direction = iota
|
||||
// DirectionOutgoing is the direction from the server to the client.
|
||||
DirectionOutgoing
|
||||
// DirectionBoth is both incoming and outgoing
|
||||
DirectionBoth
|
||||
)
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case DirectionIncoming:
|
||||
return "incoming"
|
||||
case DirectionOutgoing:
|
||||
return "outgoing"
|
||||
case DirectionBoth:
|
||||
return "both"
|
||||
default:
|
||||
panic("unknown direction")
|
||||
}
|
||||
}
|
||||
|
||||
func (d Direction) Is(dir Direction) bool {
|
||||
if d == DirectionBoth || dir == DirectionBoth {
|
||||
return true
|
||||
}
|
||||
return d == dir
|
||||
}
|
||||
|
||||
// DropCallback is a callback that determines which packet gets dropped.
|
||||
type DropCallback func(Direction, protocol.PacketNumber) bool
|
||||
type DropCallback func(dir Direction, packetCount uint64) bool
|
||||
|
||||
// NoDropper doesn't drop packets.
|
||||
var NoDropper DropCallback = func(Direction, protocol.PacketNumber) bool {
|
||||
var NoDropper DropCallback = func(Direction, uint64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
||||
type DelayCallback func(Direction, protocol.PacketNumber) time.Duration
|
||||
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
||||
|
||||
// NoDelay doesn't apply a delay.
|
||||
var NoDelay DelayCallback = func(Direction, protocol.PacketNumber) time.Duration {
|
||||
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -62,6 +82,8 @@ type Opts struct {
|
|||
type QuicProxy struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
conn *net.UDPConn
|
||||
serverAddr *net.UDPAddr
|
||||
|
||||
|
@ -73,7 +95,10 @@ type QuicProxy struct {
|
|||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
||||
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
||||
if opts == nil {
|
||||
opts = &Opts{}
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -103,6 +128,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) {
|
|||
serverAddr: raddr,
|
||||
dropPacket: packetDropper,
|
||||
delayPacket: packetDelayer,
|
||||
version: version,
|
||||
}
|
||||
|
||||
go p.runProxy()
|
||||
|
@ -119,6 +145,7 @@ func (p *QuicProxy) LocalAddr() net.Addr {
|
|||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalPort is the UDP port number the proxy is listening on.
|
||||
func (p *QuicProxy) LocalPort() int {
|
||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
||||
}
|
||||
|
@ -137,7 +164,7 @@ func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
|||
// runProxy listens on the proxy address and handles incoming packets.
|
||||
func (p *QuicProxy) runProxy() error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxPacketSize)
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -159,20 +186,14 @@ func (p *QuicProxy) runProxy() error {
|
|||
}
|
||||
p.mutex.Unlock()
|
||||
|
||||
atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
|
||||
r := bytes.NewReader(raw)
|
||||
hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.dropPacket(DirectionIncoming, hdr.PacketNumber) {
|
||||
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet to the server
|
||||
delay := p.delayPacket(DirectionIncoming, hdr.PacketNumber)
|
||||
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||
if delay != 0 {
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
|
@ -190,28 +211,20 @@ func (p *QuicProxy) runProxy() error {
|
|||
// runConnection handles packets from server to a single client
|
||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxPacketSize)
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
// TODO: Switch back to using the public header once Chrome properly sets the type byte.
|
||||
// r := bytes.NewReader(raw)
|
||||
// , err := quic.ParsePublicHeader(r, protocol.PerspectiveServer)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
v := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
packetNumber := protocol.PacketNumber(v)
|
||||
if p.dropPacket(DirectionOutgoing, packetNumber) {
|
||||
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, packetNumber)
|
||||
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||
if delay != 0 {
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testlog/testlog.go
generated
vendored
|
@ -27,7 +27,7 @@ var _ = BeforeEach(func() {
|
|||
|
||||
if len(logFileName) > 0 {
|
||||
var err error
|
||||
logFile, err = os.Create("./log.txt")
|
||||
logFile, err = os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
log.SetOutput(logFile)
|
||||
utils.SetLogLevel(utils.LogLevelDebug)
|
||||
|
|
18
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
18
vendor/github.com/lucas-clemente/quic-go/integrationtests/tools/testserver/server.go
generated
vendored
|
@ -7,7 +7,9 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
|
@ -23,8 +25,9 @@ var (
|
|||
PRData = GeneratePRData(dataLen)
|
||||
PRDataLong = GeneratePRData(dataLenLong)
|
||||
|
||||
server *h2quic.Server
|
||||
port string
|
||||
server *h2quic.Server
|
||||
stoppedServing chan struct{}
|
||||
port string
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -75,11 +78,16 @@ func GeneratePRData(l int) []byte {
|
|||
return res
|
||||
}
|
||||
|
||||
func StartQuicServer() {
|
||||
// StartQuicServer starts a h2quic.Server.
|
||||
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||
server = &h2quic.Server{
|
||||
Server: &http.Server{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
},
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: versions,
|
||||
},
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||
|
@ -88,14 +96,18 @@ func StartQuicServer() {
|
|||
Expect(err).NotTo(HaveOccurred())
|
||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
stoppedServing = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.Serve(conn)
|
||||
close(stoppedServing)
|
||||
}()
|
||||
}
|
||||
|
||||
func StopQuicServer() {
|
||||
Expect(server.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing).Should(BeClosed())
|
||||
}
|
||||
|
||||
func Port() string {
|
||||
|
|
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
121
vendor/github.com/lucas-clemente/quic-go/interface.go
generated
vendored
|
@ -6,23 +6,55 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
type ConnectionState = handshake.ConnectionState
|
||||
|
||||
// An ErrorCode is an application-defined error code.
|
||||
type ErrorCode = protocol.ApplicationErrorCode
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
type Stream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Reader
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
StreamID() protocol.StreamID
|
||||
// Reset closes the stream with an error.
|
||||
Reset(error)
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// It must not be called after Close.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
CancelWrite(ErrorCode) error
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
CancelRead(ErrorCode) error
|
||||
// The context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
|
@ -43,6 +75,41 @@ type Stream interface {
|
|||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Read
|
||||
io.Reader
|
||||
// see Stream.CancelRead
|
||||
CancelRead(ErrorCode) error
|
||||
// see Stream.SetReadDealine
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Write
|
||||
io.Writer
|
||||
// see Stream.Close
|
||||
io.Closer
|
||||
// see Stream.CancelWrite
|
||||
CancelWrite(ErrorCode) error
|
||||
// see Stream.Context
|
||||
Context() context.Context
|
||||
// see Stream.SetWriteDeadline
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||
type StreamError interface {
|
||||
error
|
||||
Canceled() bool
|
||||
ErrorCode() ErrorCode
|
||||
}
|
||||
|
||||
// A Session is a QUIC connection between two peers.
|
||||
type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
|
@ -64,53 +131,41 @@ type Session interface {
|
|||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
||||
// The communication is encrypted, but not yet forward secure.
|
||||
type NonFWSession interface {
|
||||
Session
|
||||
WaitUntilHandshakeComplete() error
|
||||
}
|
||||
|
||||
// An STK is a Source Address token.
|
||||
// It is issued by the server and sent to the client. For the client, it is an opaque blob.
|
||||
// The client can send the STK in subsequent handshakes to prove ownership of its IP address.
|
||||
type STK struct {
|
||||
// The remote address this token was issued for.
|
||||
// If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String())
|
||||
// Otherwise, this is the string representation of the net.Addr (net.Addr.String())
|
||||
remoteAddr string
|
||||
// The time that the STK was issued (resolution 1 second)
|
||||
sentTime time.Time
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
|
||||
type Config struct {
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
// Warning: This API should not be considered stable and will change soon.
|
||||
Versions []protocol.VersionNumber
|
||||
// Ask the server to truncate the connection ID sent in the Public Header.
|
||||
Versions []VersionNumber
|
||||
// Ask the server to omit the connection ID sent in the Public Header.
|
||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDTruncation bool
|
||||
RequestConnectionIDOmission bool
|
||||
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 10 seconds.
|
||||
HandshakeTimeout time.Duration
|
||||
// AcceptSTK determines if an STK is accepted.
|
||||
// It is called with stk = nil if the client didn't send an STK.
|
||||
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
|
||||
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||
// This value only applies after the handshake has completed.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
IdleTimeout time.Duration
|
||||
// AcceptCookie determines if a Cookie is accepted.
|
||||
// It is called with cookie = nil if the client didn't send an Cookie.
|
||||
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
|
||||
// This option is only valid for the server.
|
||||
AcceptSTK func(clientAddr net.Addr, stk *STK) bool
|
||||
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
|
||||
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
||||
MaxReceiveStreamFlowControlWindow protocol.ByteCount
|
||||
MaxReceiveStreamFlowControlWindow uint64
|
||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
|
||||
MaxReceiveConnectionFlowControlWindow uint64
|
||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||
KeepAlive bool
|
||||
}
|
||||
|
|
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
48
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
|
@ -0,0 +1,48 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet) error
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetHandshakeComplete()
|
||||
|
||||
// SendingAllowed says if a packet can be sent.
|
||||
// Sending packets might not be possible because:
|
||||
// * we're congestion limited
|
||||
// * we're tracking the maximum number of sent packets
|
||||
SendingAllowed() bool
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||
// It always returns a number greater or equal than 1.
|
||||
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||
ShouldSendNumPackets() int
|
||||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() (packet *Packet)
|
||||
GetLeastUnacked() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm()
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||
IgnoreBelow(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *wire.AckFrame
|
||||
}
|
|
@ -3,29 +3,30 @@ package ackhandler
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// A Packet is a packet
|
||||
// +gen linkedlist
|
||||
type Packet struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
Frames []frames.Frame
|
||||
Frames []wire.Frame
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
|
||||
SendTime time.Time
|
||||
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||
sendTime time.Time
|
||||
}
|
||||
|
||||
// GetFramesForRetransmission gets all the frames for retransmission
|
||||
func (p *Packet) GetFramesForRetransmission() []frames.Frame {
|
||||
var fs []frames.Frame
|
||||
func (p *Packet) GetFramesForRetransmission() []wire.Frame {
|
||||
var fs []wire.Frame
|
||||
for _, frame := range p.Frames {
|
||||
switch frame.(type) {
|
||||
case *frames.AckFrame:
|
||||
case *wire.AckFrame:
|
||||
continue
|
||||
case *frames.StopWaitingFrame:
|
||||
case *wire.StopWaitingFrame:
|
||||
continue
|
||||
}
|
||||
fs = append(fs, frame)
|
|
@ -1,18 +1,15 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
|
||||
|
||||
type receivedPacketHandler struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
lowerLimit protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
largestObservedReceivedTime time.Time
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
@ -23,46 +20,45 @@ type receivedPacketHandler struct {
|
|||
retransmittablePacketsReceivedSinceLastAck int
|
||||
ackQueued bool
|
||||
ackAlarm time.Time
|
||||
lastAck *frames.AckFrame
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||
func NewReceivedPacketHandler() ReceivedPacketHandler {
|
||||
func NewReceivedPacketHandler(version protocol.VersionNumber) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
ackSendDelay: protocol.AckSendDelay,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
|
||||
if packetNumber == 0 {
|
||||
return errInvalidPacketNumber
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||
if packetNumber > h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = time.Now()
|
||||
h.largestObservedReceivedTime = rcvTime
|
||||
}
|
||||
|
||||
if packetNumber <= h.lowerLimit {
|
||||
if packetNumber < h.ignoreBelow {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.maybeQueueAck(packetNumber, shouldInstigateAck)
|
||||
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLowerLimit sets a lower limit for acking packets.
|
||||
// Packets with packet numbers smaller or equal than p will not be acked.
|
||||
func (h *receivedPacketHandler) SetLowerLimit(p protocol.PacketNumber) {
|
||||
h.lowerLimit = p
|
||||
h.packetHistory.DeleteUpTo(p)
|
||||
// IgnoreBelow sets a lower limit for acking packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
||||
h.ignoreBelow = p
|
||||
h.packetHistory.DeleteBelow(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
if shouldInstigateAck {
|
||||
|
@ -74,12 +70,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
h.ackQueued = true
|
||||
}
|
||||
|
||||
// Always send an ack every 20 packets in order to allow the peer to discard
|
||||
// information from the SentPacketManager and provide an RTT measurement.
|
||||
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
|
||||
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
|
||||
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
|
||||
|
@ -87,7 +77,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
}
|
||||
|
||||
// check if a new missing range above the previously was created
|
||||
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
|
||||
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().First > h.lastAck.LargestAcked {
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
|
@ -96,7 +86,7 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
h.ackQueued = true
|
||||
} else {
|
||||
if h.ackAlarm.IsZero() {
|
||||
h.ackAlarm = time.Now().Add(h.ackSendDelay)
|
||||
h.ackAlarm = rcvTime.Add(h.ackSendDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,15 +97,15 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
|
|||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
|
||||
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ackRanges := h.packetHistory.GetAckRanges()
|
||||
ack := &frames.AckFrame{
|
||||
ack := &wire.AckFrame{
|
||||
LargestAcked: h.largestObserved,
|
||||
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
|
||||
LowestAcked: ackRanges[len(ackRanges)-1].First,
|
||||
PacketReceivedTime: h.largestObservedReceivedTime,
|
||||
}
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
|
@ -12,21 +12,15 @@ import (
|
|||
type receivedPacketHistory struct {
|
||||
ranges *utils.PacketIntervalList
|
||||
|
||||
// the map is used as a replacement for a set here. The bool is always supposed to be set to true
|
||||
receivedPacketNumbers map[protocol.PacketNumber]bool
|
||||
lowestInReceivedPacketNumbers protocol.PacketNumber
|
||||
}
|
||||
|
||||
var (
|
||||
errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
|
||||
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets")
|
||||
)
|
||||
var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
|
||||
|
||||
// newReceivedPacketHistory creates a new received packet history
|
||||
func newReceivedPacketHistory() *receivedPacketHistory {
|
||||
return &receivedPacketHistory{
|
||||
ranges: utils.NewPacketIntervalList(),
|
||||
receivedPacketNumbers: make(map[protocol.PacketNumber]bool),
|
||||
ranges: utils.NewPacketIntervalList(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,12 +30,6 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
|||
return errTooManyOutstandingReceivedAckRanges
|
||||
}
|
||||
|
||||
if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets {
|
||||
return errTooManyOutstandingReceivedPackets
|
||||
}
|
||||
|
||||
h.receivedPacketNumbers[p] = true
|
||||
|
||||
if h.ranges.Len() == 0 {
|
||||
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
|
||||
return nil
|
||||
|
@ -86,23 +74,20 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all entries up to (and including) p
|
||||
func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
|
||||
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, p+1)
|
||||
// DeleteBelow deletes all entries below (but not including) p
|
||||
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||
if p <= h.lowestInReceivedPacketNumbers {
|
||||
return
|
||||
}
|
||||
h.lowestInReceivedPacketNumbers = p
|
||||
|
||||
nextEl := h.ranges.Front()
|
||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||
nextEl = el.Next()
|
||||
|
||||
if p >= el.Value.Start && p < el.Value.End {
|
||||
for i := el.Value.Start; i <= p; i++ { // adjust start value of a range
|
||||
delete(h.receivedPacketNumbers, i)
|
||||
}
|
||||
el.Value.Start = p + 1
|
||||
} else if el.Value.End <= p { // delete a whole range
|
||||
for i := el.Value.Start; i <= el.Value.End; i++ {
|
||||
delete(h.receivedPacketNumbers, i)
|
||||
}
|
||||
if p > el.Value.Start && p <= el.Value.End {
|
||||
el.Value.Start = p
|
||||
} else if el.Value.End < p { // delete a whole range
|
||||
h.ranges.Remove(el)
|
||||
} else { // no ranges affected. Nothing to do
|
||||
return
|
||||
|
@ -110,38 +95,27 @@ func (h *receivedPacketHistory) DeleteUpTo(p protocol.PacketNumber) {
|
|||
}
|
||||
}
|
||||
|
||||
// IsDuplicate determines if a packet should be regarded as a duplicate packet
|
||||
// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed
|
||||
func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool {
|
||||
if p < h.lowestInReceivedPacketNumbers {
|
||||
return true
|
||||
}
|
||||
|
||||
_, ok := h.receivedPacketNumbers[p]
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
|
||||
func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
|
||||
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
|
||||
if h.ranges.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ackRanges []frames.AckRange
|
||||
|
||||
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
||||
i := 0
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
ackRanges = append(ackRanges, frames.AckRange{FirstPacketNumber: el.Value.Start, LastPacketNumber: el.Value.End})
|
||||
ackRanges[i] = wire.AckRange{First: el.Value.Start, Last: el.Value.End}
|
||||
i++
|
||||
}
|
||||
|
||||
return ackRanges
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange {
|
||||
ackRange := frames.AckRange{}
|
||||
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
||||
ackRange := wire.AckRange{}
|
||||
if h.ranges.Len() > 0 {
|
||||
r := h.ranges.Back().Value
|
||||
ackRange.FirstPacketNumber = r.Start
|
||||
ackRange.LastPacketNumber = r.End
|
||||
ackRange.First = r.Start
|
||||
ackRange.Last = r.End
|
||||
}
|
||||
return ackRange
|
||||
}
|
|
@ -1,12 +1,10 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
)
|
||||
import "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
// Returns a new slice with all non-retransmittable frames deleted.
|
||||
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
|
||||
res := make([]frames.Frame, 0, len(fs))
|
||||
func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
|
||||
res := make([]wire.Frame, 0, len(fs))
|
||||
for _, f := range fs {
|
||||
if IsFrameRetransmittable(f) {
|
||||
res = append(res, f)
|
||||
|
@ -16,11 +14,11 @@ func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
|
|||
}
|
||||
|
||||
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
||||
func IsFrameRetransmittable(f frames.Frame) bool {
|
||||
func IsFrameRetransmittable(f wire.Frame) bool {
|
||||
switch f.(type) {
|
||||
case *frames.StopWaitingFrame:
|
||||
case *wire.StopWaitingFrame:
|
||||
return false
|
||||
case *frames.AckFrame:
|
||||
case *wire.AckFrame:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
|
@ -28,7 +26,7 @@ func IsFrameRetransmittable(f frames.Frame) bool {
|
|||
}
|
||||
|
||||
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
|
||||
func HasRetransmittableFrames(fs []frames.Frame) bool {
|
||||
func HasRetransmittableFrames(fs []wire.Frame) bool {
|
||||
for _, f := range fs {
|
||||
if IsFrameRetransmittable(f) {
|
||||
return true
|
|
@ -3,12 +3,13 @@ package ackhandler
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
|
@ -16,33 +17,33 @@ const (
|
|||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// In fraction of an RTT.
|
||||
timeReorderingFraction = 1.0 / 8
|
||||
// The default RTT used before an RTT sample is taken.
|
||||
// Note: This constant is also defined in the congestion package.
|
||||
defaultInitialRTT = 100 * time.Millisecond
|
||||
// defaultRTOTimeout is the RTO time on new connections
|
||||
defaultRTOTimeout = 500 * time.Millisecond
|
||||
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||
minTPLTimeout = 10 * time.Millisecond
|
||||
// Minimum time in the future an RTO alarm may be set for.
|
||||
minRTOTimeout = 200 * time.Millisecond
|
||||
// maxRTOTimeout is the maximum RTO time
|
||||
maxRTOTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
||||
ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
||||
// ErrTooManyTrackedSentPackets occurs when the sentPacketHandler has to keep track of too many packets
|
||||
ErrTooManyTrackedSentPackets = errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
||||
// ErrAckForSkippedPacket occurs when the client sent an ACK for a packet number that we intentionally skipped
|
||||
ErrAckForSkippedPacket = qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
errAckForUnsentPacket = qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
)
|
||||
|
||||
var errPacketNumberNotIncreasing = errors.New("Already sent a packet with a higher packet number")
|
||||
// ErrDuplicateOrOutOfOrderAck occurs when a duplicate or an out-of-order ACK is received
|
||||
var ErrDuplicateOrOutOfOrderAck = errors.New("SentPacketHandler: Duplicate or out-of-order ACK")
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
LargestAcked protocol.PacketNumber
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
|
||||
packetHistory *PacketList
|
||||
stopWaitingManager stopWaitingManager
|
||||
|
@ -54,6 +55,10 @@ type sentPacketHandler struct {
|
|||
congestion congestion.SendAlgorithm
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
handshakeComplete bool
|
||||
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||
handshakeCount uint32
|
||||
|
||||
// The number of times an RTO has been sent without receiving an ack.
|
||||
rtoCount uint32
|
||||
|
||||
|
@ -82,20 +87,27 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) largestInOrderAcked() protocol.PacketNumber {
|
||||
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
||||
if f := h.packetHistory.Front(); f != nil {
|
||||
return f.Value.PacketNumber - 1
|
||||
return f.Value.PacketNumber
|
||||
}
|
||||
return h.LargestAcked
|
||||
return h.largestAcked + 1
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
h.retransmissionQueue = queue
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
||||
if packet.PacketNumber <= h.lastSentPacketNumber {
|
||||
return errPacketNumberNotIncreasing
|
||||
}
|
||||
|
||||
if protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()+1) > protocol.MaxTrackedSentPackets {
|
||||
return ErrTooManyTrackedSentPackets
|
||||
return errors.New("Too many outstanding non-acked and non-retransmitted packets")
|
||||
}
|
||||
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
|
@ -106,14 +118,22 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||
}
|
||||
}
|
||||
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
now := time.Now()
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
|
||||
var largestAcked protocol.PacketNumber
|
||||
if len(packet.Frames) > 0 {
|
||||
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||
largestAcked = ackFrame.LargestAcked
|
||||
}
|
||||
}
|
||||
|
||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
packet.SendTime = now
|
||||
packet.sendTime = now
|
||||
packet.largestAcked = largestAcked
|
||||
h.bytesInFlight += packet.Length
|
||||
h.packetHistory.PushBack(*packet)
|
||||
}
|
||||
|
@ -126,29 +146,32 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
|
|||
isRetransmittable,
|
||||
)
|
||||
|
||||
h.updateLossDetectionAlarm()
|
||||
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, now).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||
|
||||
h.updateLossDetectionAlarm(now)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, rcvTime time.Time) error {
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
if ackFrame.LargestAcked > h.lastSentPacketNumber {
|
||||
return errAckForUnsentPacket
|
||||
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
}
|
||||
|
||||
// duplicate or out-of-order ACK
|
||||
// if withPacketNumber <= h.largestReceivedPacketWithAck && withPacketNumber != 0 {
|
||||
if withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
return ErrDuplicateOrOutOfOrderAck
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
|
||||
// ignore repeated ACK (ACKs that don't have a higher LargestAcked than the last ACK)
|
||||
if ackFrame.LargestAcked <= h.largestInOrderAcked() {
|
||||
if ackFrame.LargestAcked < h.lowestUnacked() {
|
||||
return nil
|
||||
}
|
||||
h.LargestAcked = ackFrame.LargestAcked
|
||||
h.largestAcked = ackFrame.LargestAcked
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
return ErrAckForSkippedPacket
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
rttUpdated := h.maybeUpdateRTT(ackFrame.LargestAcked, ackFrame.DelayTime, rcvTime)
|
||||
|
@ -164,13 +187,22 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
|
|||
|
||||
if len(ackedPackets) > 0 {
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.Value.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.Value.PacketNumber, p.Value.EncryptionLevel)
|
||||
}
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
if p.Value.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.Value.largestAcked+1)
|
||||
}
|
||||
h.onPacketAcked(p)
|
||||
h.congestion.OnPacketAcked(p.Value.PacketNumber, p.Value.Length, h.bytesInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
h.detectLostPackets()
|
||||
h.updateLossDetectionAlarm()
|
||||
h.detectLostPackets(rcvTime)
|
||||
h.updateLossDetectionAlarm(rcvTime)
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
@ -178,7 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *frames.AckFrame, withPacketNum
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame) ([]*PacketElement, error) {
|
||||
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||
return h.lowestPacketNotConfirmedAcked
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*PacketElement, error) {
|
||||
var ackedPackets []*PacketElement
|
||||
ackRangeIndex := 0
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
|
@ -197,14 +233,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame
|
|||
if ackFrame.HasMissingRanges() {
|
||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for packetNumber > ackRange.LastPacketNumber && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
for packetNumber > ackRange.Last && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if packetNumber >= ackRange.FirstPacketNumber { // packet i contained in ACK range
|
||||
if packetNumber > ackRange.LastPacketNumber {
|
||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.FirstPacketNumber, ackRange.LastPacketNumber)
|
||||
if packetNumber >= ackRange.First { // packet i contained in ACK range
|
||||
if packetNumber > ackRange.Last {
|
||||
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", packetNumber, ackRange.First, ackRange.Last)
|
||||
}
|
||||
ackedPackets = append(ackedPackets, el)
|
||||
}
|
||||
|
@ -212,7 +248,6 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *frames.AckFrame
|
|||
ackedPackets = append(ackedPackets, el)
|
||||
}
|
||||
}
|
||||
|
||||
return ackedPackets, nil
|
||||
}
|
||||
|
||||
|
@ -220,7 +255,7 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
if packet.PacketNumber == largestAcked {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.SendTime), ackDelay, time.Now())
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(packet.sendTime), ackDelay, rcvTime)
|
||||
return true
|
||||
}
|
||||
// Packets are sorted by number, so we can stop searching
|
||||
|
@ -231,27 +266,27 @@ func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, a
|
|||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm(now time.Time) {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if h.packetHistory.Len() == 0 {
|
||||
h.alarm = time.Time{}
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(#496): Handle handshake packets separately
|
||||
// TODO(#497): TLP
|
||||
if !h.lossTime.IsZero() {
|
||||
if !h.handshakeComplete {
|
||||
h.alarm = now.Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO
|
||||
h.alarm = time.Now().Add(h.computeRTOTimeout())
|
||||
h.alarm = now.Add(h.computeRTOTimeout())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets() {
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time) {
|
||||
h.lossTime = time.Time{}
|
||||
now := time.Now()
|
||||
|
||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||
|
@ -260,11 +295,11 @@ func (h *sentPacketHandler) detectLostPackets() {
|
|||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
packet := el.Value
|
||||
|
||||
if packet.PacketNumber > h.LargestAcked {
|
||||
if packet.PacketNumber > h.largestAcked {
|
||||
break
|
||||
}
|
||||
|
||||
timeSinceSent := now.Sub(packet.SendTime)
|
||||
timeSinceSent := now.Sub(packet.sendTime)
|
||||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, el)
|
||||
} else if h.lossTime.IsZero() {
|
||||
|
@ -282,18 +317,22 @@ func (h *sentPacketHandler) detectLostPackets() {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() {
|
||||
// TODO(#496): Handle handshake packets separately
|
||||
now := time.Now()
|
||||
|
||||
// TODO(#497): TLP
|
||||
if !h.lossTime.IsZero() {
|
||||
if !h.handshakeComplete {
|
||||
h.queueHandshakePacketsForRetransmission()
|
||||
h.handshakeCount++
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit or time loss detection
|
||||
h.detectLostPackets()
|
||||
h.detectLostPackets(now)
|
||||
} else {
|
||||
// RTO
|
||||
h.retransmitOldestTwoPackets()
|
||||
h.rtoCount++
|
||||
}
|
||||
|
||||
h.updateLossDetectionAlarm()
|
||||
h.updateLossDetectionAlarm(now)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
|
@ -303,6 +342,7 @@ func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
|||
func (h *sentPacketHandler) onPacketAcked(packetElement *PacketElement) {
|
||||
h.bytesInFlight -= packetElement.Value.Length
|
||||
h.rtoCount = 0
|
||||
h.handshakeCount = 0
|
||||
// TODO(#497): h.tlpCount = 0
|
||||
h.packetHistory.Remove(packetElement)
|
||||
}
|
||||
|
@ -320,20 +360,19 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
|
||||
return h.largestInOrderAcked() + 1
|
||||
return h.lowestUnacked()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendingAllowed() bool {
|
||||
congestionLimited := h.bytesInFlight > h.congestion.GetCongestionWindow()
|
||||
cwnd := h.congestion.GetCongestionWindow()
|
||||
congestionLimited := h.bytesInFlight > cwnd
|
||||
maxTrackedLimited := protocol.PacketNumber(len(h.retransmissionQueue)+h.packetHistory.Len()) >= protocol.MaxTrackedSentPackets
|
||||
if congestionLimited {
|
||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d",
|
||||
h.bytesInFlight,
|
||||
h.congestion.GetCongestionWindow())
|
||||
utils.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
}
|
||||
// Workaround for #555:
|
||||
// Always allow sending of retransmissions. This should probably be limited
|
||||
|
@ -342,6 +381,18 @@ func (h *sentPacketHandler) SendingAllowed() bool {
|
|||
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||
return h.nextPacketSendTime
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||
if delay == 0 || delay > protocol.MinPacingDelay {
|
||||
return 1
|
||||
}
|
||||
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) retransmitOldestTwoPackets() {
|
||||
if p := h.packetHistory.Front(); p != nil {
|
||||
h.queueRTO(p)
|
||||
|
@ -363,6 +414,18 @@ func (h *sentPacketHandler) queueRTO(el *PacketElement) {
|
|||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() {
|
||||
var handshakePackets []*PacketElement
|
||||
for el := h.packetHistory.Front(); el != nil; el = el.Next() {
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, el)
|
||||
}
|
||||
}
|
||||
for _, el := range handshakePackets {
|
||||
h.queuePacketForRetransmission(el)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketElement) {
|
||||
packet := &packetElement.Value
|
||||
h.bytesInFlight -= packet.Length
|
||||
|
@ -371,6 +434,17 @@ func (h *sentPacketHandler) queuePacketForRetransmission(packetElement *PacketEl
|
|||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(packet.PacketNumber)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||
duration := 2 * h.rttStats.SmoothedRTT()
|
||||
if duration == 0 {
|
||||
duration = 2 * defaultInitialRTT
|
||||
}
|
||||
duration = utils.MaxDuration(duration, minTPLTimeout)
|
||||
// exponential backoff
|
||||
// There's an implicit limit to this set by the handshake timeout.
|
||||
return duration << h.handshakeCount
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
rto := h.congestion.RetransmissionDelay()
|
||||
if rto == 0 {
|
||||
|
@ -382,7 +456,7 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
|||
return utils.MinDuration(rto, maxRTOTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool {
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||
for _, p := range h.skippedPackets {
|
||||
if ackFrame.AcksPacket(p) {
|
||||
return true
|
||||
|
@ -392,10 +466,10 @@ func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *frames.AckFrame) bool
|
|||
}
|
||||
|
||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||
lioa := h.largestInOrderAcked()
|
||||
lowestUnacked := h.lowestUnacked()
|
||||
deleteIndex := 0
|
||||
for i, p := range h.skippedPackets {
|
||||
if p <= lioa {
|
||||
if p < lowestUnacked {
|
||||
deleteIndex = i + 1
|
||||
}
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
||||
|
@ -10,10 +10,10 @@ type stopWaitingManager struct {
|
|||
largestLeastUnackedSent protocol.PacketNumber
|
||||
nextLeastUnacked protocol.PacketNumber
|
||||
|
||||
lastStopWaitingFrame *frames.StopWaitingFrame
|
||||
lastStopWaitingFrame *wire.StopWaitingFrame
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaitingFrame {
|
||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
||||
if force {
|
||||
return s.lastStopWaitingFrame
|
||||
|
@ -22,14 +22,14 @@ func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *frames.StopWaiting
|
|||
}
|
||||
|
||||
s.largestLeastUnackedSent = s.nextLeastUnacked
|
||||
swf := &frames.StopWaitingFrame{
|
||||
swf := &wire.StopWaitingFrame{
|
||||
LeastUnacked: s.nextLeastUnacked,
|
||||
}
|
||||
s.lastStopWaitingFrame = swf
|
||||
return swf
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *frames.AckFrame) {
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||
if ack.LargestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = ack.LargestAcked + 1
|
||||
}
|
|
@ -3,7 +3,7 @@ package congestion
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
|
@ -4,8 +4,8 @@ import (
|
|||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
|
@ -3,8 +3,8 @@ package congestion
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -76,15 +76,19 @@ func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestio
|
|||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
return c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold())
|
||||
if c.prr.TimeUntilSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) == 0 {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if c.GetCongestionWindow() > bytesInFlight {
|
||||
return 0
|
||||
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow()/protocol.DefaultTCPMSS)
|
||||
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
|
||||
delay = delay * 8 / 5
|
||||
}
|
||||
return utils.InfDuration
|
||||
return delay
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
|
@ -3,8 +3,8 @@ package congestion
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
|
@ -3,12 +3,12 @@ package congestion
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A SendAlgorithm performs congestion control and calculates the congestion window
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
MaybeExitSlowStart()
|
|
@ -3,8 +3,8 @@ package congestion
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
)
|
||||
|
||||
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
|
@ -7,6 +7,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// Note: This constant is also defined in the ackhandler package.
|
||||
initialRTTus = 100 * 1000
|
||||
rttAlpha float32 = 0.125
|
||||
oneMinusAlpha float32 = (1 - rttAlpha)
|
||||
|
@ -97,10 +98,10 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
|||
r.updateRecentMinRTT(sendDelta, now)
|
||||
|
||||
// Correct for ackDelay if information received from the peer results in a
|
||||
// positive RTT sample. Otherwise, we use the sendDelta as a reasonable
|
||||
// measure for smoothedRTT.
|
||||
// an RTT sample at least as large as minRTT. Otherwise, only use the
|
||||
// sendDelta.
|
||||
sample := sendDelta
|
||||
if sample > ackDelay {
|
||||
if sample-r.minRTT >= ackDelay {
|
||||
sample -= ackDelay
|
||||
}
|
||||
r.latestRTT = sample
|
|
@ -1,6 +1,6 @@
|
|||
package congestion
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
type connectionStats struct {
|
||||
slowstartPacketsLost protocol.PacketNumber
|
|
@ -1,9 +1,10 @@
|
|||
package crypto
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/protocol"
|
||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
// An AEAD implements QUIC's authenticated encryption and associated data
|
||||
type AEAD interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
Overhead() int
|
||||
}
|
72
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go
generated
vendored
Normal file
72
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm12_aead.go
generated
vendored
Normal file
|
@ -0,0 +1,72 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/aes12"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadAESGCM12 struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
var _ AEAD = &aeadAESGCM12{}
|
||||
|
||||
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
|
||||
//
|
||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
||||
// tag size, and couples the cipher and aes packages closely.
|
||||
// See https://github.com/lucas-clemente/aes12.
|
||||
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
||||
}
|
||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadAESGCM12{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Overhead() int {
|
||||
return aead.encrypter.Overhead()
|
||||
}
|
74
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go
generated
vendored
Normal file
74
vendor/github.com/lucas-clemente/quic-go/internal/crypto/aesgcm_aead.go
generated
vendored
Normal file
|
@ -0,0 +1,74 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadAESGCM struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
var _ AEAD = &aeadAESGCM{}
|
||||
|
||||
const ivLen = 12
|
||||
|
||||
// NewAEADAESGCM creates a AEAD using AES-GCM
|
||||
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
// the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce
|
||||
if len(otherIV) != ivLen || len(myIV) != ivLen {
|
||||
return nil, errors.New("AES-GCM: expected 12 byte IVs")
|
||||
}
|
||||
|
||||
encrypterCipher, err := aes.NewCipher(myKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encrypter, err := cipher.NewGCM(encrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypterCipher, err := aes.NewCipher(otherKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := cipher.NewGCM(decrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &aeadAESGCM{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
nonce := make([]byte, ivLen)
|
||||
binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber))
|
||||
for i := 0; i < ivLen; i++ {
|
||||
nonce[i] ^= iv[i]
|
||||
}
|
||||
return nonce
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM) Overhead() int {
|
||||
return aead.encrypter.Overhead()
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
"hash/fnv"
|
||||
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
|
@ -51,10 +51,10 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
|
|||
res.WriteByte(uint8(e.t))
|
||||
switch e.t {
|
||||
case entryCached:
|
||||
utils.WriteUint64(res, e.h)
|
||||
utils.LittleEndian.WriteUint64(res, e.h)
|
||||
case entryCommon:
|
||||
utils.WriteUint64(res, e.h)
|
||||
utils.WriteUint32(res, e.i)
|
||||
utils.LittleEndian.WriteUint64(res, e.h)
|
||||
utils.LittleEndian.WriteUint32(res, e.i)
|
||||
case entryCompressed:
|
||||
totalUncompressedLen += 4 + len(chain[i])
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
|
|||
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
|
||||
}
|
||||
|
||||
utils.WriteUint32(res, uint32(totalUncompressedLen))
|
||||
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
|
||||
|
||||
for i, e := range entries {
|
||||
if e.t != entryCompressed {
|
||||
|
@ -115,11 +115,11 @@ func decompressChain(data []byte) ([][]byte, error) {
|
|||
return nil, errors.New("unexpected cached certificate")
|
||||
case entryCommon:
|
||||
e := entry{t: entryCommon}
|
||||
e.h, err = utils.ReadUint64(r)
|
||||
e.h, err = utils.LittleEndian.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.i, err = utils.ReadUint32(r)
|
||||
e.i, err = utils.LittleEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func decompressChain(data []byte) ([][]byte, error) {
|
|||
}
|
||||
|
||||
if hasCompressedCerts {
|
||||
uncompressedLength, err := utils.ReadUint32(r)
|
||||
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
fmt.Println(4)
|
||||
return nil, err
|
|
@ -18,6 +18,7 @@ type CertManager interface {
|
|||
GetLeafCertHash() (uint64, error)
|
||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||
Verify(hostname string) error
|
||||
GetChain() []*x509.Certificate
|
||||
}
|
||||
|
||||
type certManager struct {
|
||||
|
@ -54,6 +55,10 @@ func (c *certManager) SetData(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *certManager) GetChain() []*x509.Certificate {
|
||||
return c.chain
|
||||
}
|
||||
|
||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||
return getCommonCertificateHashes()
|
||||
}
|
|
@ -4,11 +4,12 @@ package crypto
|
|||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/aead/chacha20"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadChacha20Poly1305 struct {
|
||||
|
@ -45,9 +46,16 @@ func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV
|
|||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
49
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
Normal file
49
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
generated
vendored
Normal file
|
@ -0,0 +1,49 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExporterLabel = "EXPORTER-QUIC client 1-RTT Secret"
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1-RTT Secret"
|
||||
)
|
||||
|
||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||
type TLSExporter interface {
|
||||
GetCipherSuite() mint.CipherSuiteParams
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
}
|
||||
|
||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||
var myLabel, otherLabel string
|
||||
if pers == protocol.PerspectiveClient {
|
||||
myLabel = clientExporterLabel
|
||||
otherLabel = serverExporterLabel
|
||||
} else {
|
||||
myLabel = serverExporterLabel
|
||||
otherLabel = clientExporterLabel
|
||||
}
|
||||
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
|
||||
cs := tls.GetCipherSuite()
|
||||
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
key = mint.HkdfExpandLabel(cs.Hash, secret, "key", nil, cs.KeyLen)
|
||||
iv = mint.HkdfExpandLabel(cs.Hash, secret, "iv", nil, cs.IvLen)
|
||||
return key, iv, nil
|
||||
}
|
|
@ -5,8 +5,8 @@ import (
|
|||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
@ -20,8 +20,8 @@ import (
|
|||
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
|
||||
// }
|
||||
|
||||
// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
||||
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
||||
var swap bool
|
||||
if pers == protocol.PerspectiveClient {
|
||||
swap = true
|
||||
|
@ -30,7 +30,7 @@ func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID pr
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
// deriveKeys derives the keys and the IVs
|
||||
|
@ -42,7 +42,7 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
|
|||
} else {
|
||||
info.Write([]byte("QUIC key expansion\x00"))
|
||||
}
|
||||
utils.WriteUint64(&info, uint64(connID))
|
||||
utils.BigEndian.WriteUint64(&info, uint64(connID))
|
||||
info.Write(chlo)
|
||||
info.Write(scfg)
|
||||
info.Write(cert)
|
11
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go
generated
vendored
Normal file
11
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead.go
generated
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
package crypto
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
// NewNullAEAD creates a NullAEAD
|
||||
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
|
||||
if v.UsesTLS() {
|
||||
return newNullAEADAESGCM(connID, p)
|
||||
}
|
||||
return &nullAEADFNV128a{perspective: p}, nil
|
||||
}
|
44
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
Normal file
44
vendor/github.com/lucas-clemente/quic-go/internal/crypto/null_aead_aesgcm.go
generated
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var quicVersion1Salt = []byte{0xaf, 0xc8, 0x24, 0xec, 0x5f, 0xc7, 0x7e, 0xca, 0x1e, 0x9d, 0x36, 0xf3, 0x7f, 0xb2, 0xd4, 0x65, 0x18, 0xc3, 0x66, 0x39}
|
||||
|
||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
mySecret = clientSecret
|
||||
otherSecret = serverSecret
|
||||
} else {
|
||||
mySecret = serverSecret
|
||||
otherSecret = clientSecret
|
||||
}
|
||||
|
||||
myKey, myIV := computeNullAEADKeyAndIV(mySecret)
|
||||
otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret)
|
||||
|
||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
func computeSecrets(connectionID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
connID := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(connID, uint64(connectionID))
|
||||
cleartextSecret := mint.HkdfExtract(crypto.SHA256, []byte(quicVersion1Salt), connID)
|
||||
clientSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC client cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||
serverSecret = mint.HkdfExpandLabel(crypto.SHA256, cleartextSecret, "QUIC server cleartext Secret", []byte{}, crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||
key = mint.HkdfExpandLabel(crypto.SHA256, secret, "key", nil, 16)
|
||||
iv = mint.HkdfExpandLabel(crypto.SHA256, secret, "iv", nil, 12)
|
||||
return
|
||||
}
|
|
@ -5,27 +5,18 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/fnv128a"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// nullAEAD handles not-yet encrypted packets
|
||||
type nullAEAD struct {
|
||||
type nullAEADFNV128a struct {
|
||||
perspective protocol.Perspective
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ AEAD = &nullAEAD{}
|
||||
|
||||
// NewNullAEAD creates a NullAEAD
|
||||
func NewNullAEAD(p protocol.Perspective, v protocol.VersionNumber) AEAD {
|
||||
return &nullAEAD{
|
||||
perspective: p,
|
||||
version: v,
|
||||
}
|
||||
}
|
||||
var _ AEAD = &nullAEADFNV128a{}
|
||||
|
||||
// Open and verify the ciphertext
|
||||
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
if len(src) < 12 {
|
||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||
}
|
||||
|
@ -33,12 +24,10 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||
hash := fnv128a.New()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src[12:])
|
||||
if n.version >= protocol.Version37 {
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Client"))
|
||||
} else {
|
||||
hash.Write([]byte("Server"))
|
||||
}
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Client"))
|
||||
} else {
|
||||
hash.Write([]byte("Server"))
|
||||
}
|
||||
testHigh, testLow := hash.Sum128()
|
||||
|
||||
|
@ -52,7 +41,7 @@ func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||
}
|
||||
|
||||
// Seal writes hash and ciphertext to the buffer
|
||||
func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
if cap(dst) < 12+len(src) {
|
||||
dst = make([]byte, 12+len(src))
|
||||
} else {
|
||||
|
@ -63,12 +52,10 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||
hash.Write(associatedData)
|
||||
hash.Write(src)
|
||||
|
||||
if n.version >= protocol.Version37 {
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Server"))
|
||||
} else {
|
||||
hash.Write([]byte("Client"))
|
||||
}
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Server"))
|
||||
} else {
|
||||
hash.Write([]byte("Client"))
|
||||
}
|
||||
|
||||
high, low := hash.Sum128()
|
||||
|
@ -78,3 +65,7 @@ func (n *nullAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, ass
|
|||
binary.LittleEndian.PutUint32(dst[8:], uint32(high))
|
||||
return dst
|
||||
}
|
||||
|
||||
func (n *nullAEADFNV128a) Overhead() int {
|
||||
return 12
|
||||
}
|
108
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
108
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
|
@ -0,0 +1,108 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type baseFlowController struct {
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
|
||||
// for receiving data
|
||||
mutex sync.RWMutex
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowSize protocol.ByteCount
|
||||
maxReceiveWindowSize protocol.ByteCount
|
||||
|
||||
epochStartTime time.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *congestion.RTTStats
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
|
||||
// it returns true if the window was actually updated
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
||||
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
|
||||
if c.bytesSent > c.sendWindow {
|
||||
return 0
|
||||
}
|
||||
return c.sendWindow - c.bytesSent
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than the threshold was consumed
|
||||
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
|
||||
}
|
||||
|
||||
// getWindowUpdate updates the receive window, if necessary
|
||||
// it returns the new offset
|
||||
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||
if !c.hasWindowUpdate() {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.maybeAdjustWindowSize()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||
return c.receiveWindow
|
||||
}
|
||||
|
||||
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||
// don't do anything if less than half the window has been consumed
|
||||
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||
return
|
||||
}
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||
// window is consumed too fast, try to increase the window size
|
||||
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||
}
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
|
||||
func (c *baseFlowController) startNewAutoTuningEpoch() {
|
||||
c.epochStartTime = time.Now()
|
||||
c.epochStartOffset = c.bytesRead
|
||||
}
|
||||
|
||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||
return c.highestReceived > c.receiveWindow
|
||||
}
|
83
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
83
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
|
@ -0,0 +1,83 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
lastBlockedAt protocol.ByteCount
|
||||
baseFlowController
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
|
||||
// NewConnectionFlowController gets a new flow controller for the connection
|
||||
// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0.
|
||||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
rttStats *congestion.RTTStats,
|
||||
) ConnectionFlowController {
|
||||
return &connectionFlowController{
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return c.baseFlowController.sendWindowSize()
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.highestReceived += increment
|
||||
if c.checkFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowSize sets a minimum window size
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
if inc > c.receiveWindowSize {
|
||||
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
|
||||
c.startNewAutoTuningEpoch()
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
42
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
42
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
|
@ -0,0 +1,42 @@
|
|||
package flowcontrol
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||
}
|
||||
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsBlocked() (bool, protocol.ByteCount)
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
// HasWindowUpdate says if it is necessary to update the window
|
||||
HasWindowUpdate() bool
|
||||
}
|
||||
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
// for sending
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
ConnectionFlowController
|
||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||
// for sending
|
||||
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||
// for receiving
|
||||
IncrementHighestReceived(protocol.ByteCount) error
|
||||
}
|
147
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
147
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
|
@ -0,0 +1,147 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type streamFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
connection connectionFlowControllerI
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
|
||||
receivedFinalOffset bool
|
||||
}
|
||||
|
||||
var _ StreamFlowController = &streamFlowController{}
|
||||
|
||||
// NewStreamFlowController gets a new flow controller for a stream
|
||||
func NewStreamFlowController(
|
||||
streamID protocol.StreamID,
|
||||
contributesToConnection bool,
|
||||
cfc ConnectionFlowController,
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
rttStats *congestion.RTTStats,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
|
||||
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
|
||||
func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// when receiving a final offset, check that this final offset is consistent with a final offset we might have received earlier
|
||||
if final && c.receivedFinalOffset && byteOffset != c.highestReceived {
|
||||
return qerr.Error(qerr.StreamDataAfterTermination, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, byteOffset))
|
||||
}
|
||||
// if we already received a final offset, check that the offset in the STREAM frames is below the final offset
|
||||
if c.receivedFinalOffset && byteOffset > c.highestReceived {
|
||||
return qerr.StreamDataAfterTermination
|
||||
}
|
||||
if final {
|
||||
c.receivedFinalOffset = true
|
||||
}
|
||||
if byteOffset == c.highestReceived {
|
||||
return nil
|
||||
}
|
||||
if byteOffset <= c.highestReceived {
|
||||
// a STREAM_FRAME with a higher offset was received before.
|
||||
if final {
|
||||
// If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream
|
||||
return qerr.StreamDataAfterTermination
|
||||
}
|
||||
// this is a reordered STREAM_FRAME
|
||||
return nil
|
||||
}
|
||||
|
||||
increment := byteOffset - c.highestReceived
|
||||
c.highestReceived = byteOffset
|
||||
if c.checkFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesRead(n)
|
||||
if c.contributesToConnection {
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesSent(n)
|
||||
if c.contributesToConnection {
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
window := c.baseFlowController.sendWindowSize()
|
||||
if c.contributesToConnection {
|
||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||
}
|
||||
return window
|
||||
}
|
||||
|
||||
// IsBlocked says if it is blocked by stream-level flow control.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 {
|
||||
return false, 0
|
||||
}
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *streamFlowController) HasWindowUpdate() bool {
|
||||
c.mutex.Lock()
|
||||
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
return hasWindowUpdate
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||
c.mutex.Lock()
|
||||
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
|
||||
if c.receivedFinalOffset {
|
||||
c.mutex.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
utils.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
Normal file
101
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
Normal file
|
@ -0,0 +1,101 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
)
|
||||
|
||||
const (
|
||||
cookiePrefixIP byte = iota
|
||||
cookiePrefixString
|
||||
)
|
||||
|
||||
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
|
||||
type Cookie struct {
|
||||
RemoteAddr string
|
||||
// The time that the STK was issued (resolution 1 second)
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
Data []byte
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// A CookieGenerator generates Cookies
|
||||
type CookieGenerator struct {
|
||||
cookieProtector mint.CookieProtector
|
||||
}
|
||||
|
||||
// NewCookieGenerator initializes a new CookieGenerator
|
||||
func NewCookieGenerator() (*CookieGenerator, error) {
|
||||
cookieProtector, err := mint.NewDefaultCookieProtector()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CookieGenerator{
|
||||
cookieProtector: cookieProtector,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewToken generates a new Cookie for a given source address
|
||||
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
Data: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.cookieProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes a Cookie
|
||||
func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
|
||||
// if the client didn't send any Cookie, DecodeToken will be called with a nil-slice
|
||||
if len(encrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.cookieProtector.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &token{}
|
||||
rest, err := asn1.Unmarshal(data, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
return &Cookie{
|
||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
||||
SentTime: time.Unix(t.Timestamp, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
|
||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
return append([]byte{cookiePrefixIP}, udpAddr.IP...)
|
||||
}
|
||||
return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
||||
|
||||
// decodeRemoteAddr decodes the remote address saved in the Cookie
|
||||
func decodeRemoteAddr(data []byte) string {
|
||||
// data will never be empty for a Cookie that we generated. Check it to be on the safe side
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
if data[0] == cookiePrefixIP {
|
||||
return net.IP(data[1:]).String()
|
||||
}
|
||||
return string(data[1:])
|
||||
}
|
43
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
Normal file
43
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
generated
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type CookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
var _ mint.CookieHandler = &CookieHandler{}
|
||||
|
||||
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CookieHandler{
|
||||
callback: callback,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
if h.callback(conn.RemoteAddr(), nil) {
|
||||
return nil, nil
|
||||
}
|
||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
data, err := h.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
return false
|
||||
}
|
||||
return h.callback(conn.RemoteAddr(), data)
|
||||
}
|
|
@ -11,9 +11,9 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
|
@ -23,6 +23,7 @@ type cryptoSetupClient struct {
|
|||
hostname string
|
||||
connID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
initialVersion protocol.VersionNumber
|
||||
negotiatedVersions []protocol.VersionNumber
|
||||
|
||||
cryptoStream io.ReadWriter
|
||||
|
@ -42,17 +43,18 @@ type cryptoSetupClient struct {
|
|||
|
||||
clientHelloCounter int
|
||||
serverVerified bool // has the certificate chain and the proof already been verified
|
||||
keyDerivation KeyDerivationFunction
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
||||
receivedSecurePacket bool
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
|
||||
params *TransportParameters
|
||||
connectionParameters ConnectionParametersManager
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupClient{}
|
||||
|
@ -65,36 +67,42 @@ var (
|
|||
|
||||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
||||
func NewCryptoSetupClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
hostname string,
|
||||
connID protocol.ConnectionID,
|
||||
version protocol.VersionNumber,
|
||||
cryptoStream io.ReadWriter,
|
||||
tlsConfig *tls.Config,
|
||||
connectionParameters ConnectionParametersManager,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
handshakeEvent chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cryptoSetupClient{
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
version: version,
|
||||
cryptoStream: cryptoStream,
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
connectionParameters: connectionParameters,
|
||||
keyDerivation: crypto.DeriveKeysAESGCM,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
||||
aeadChanged: aeadChanged,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
divNonceChan: make(chan []byte),
|
||||
params: params,
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
version: version,
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
negotiatedVersions: negotiatedVersions,
|
||||
divNonceChan: make(chan []byte),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
messageChan := make(chan HandshakeMessage)
|
||||
errorChan := make(chan error)
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
|
@ -141,15 +149,21 @@ func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|||
utils.Debugf("Got %s", message)
|
||||
switch message.Tag {
|
||||
case TagREJ:
|
||||
err = h.handleREJMessage(message.Data)
|
||||
if err := h.handleREJMessage(message.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
case TagSHLO:
|
||||
err = h.handleSHLOMessage(message.Data)
|
||||
params, err := h.handleSHLOMessage(message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// blocks until the session has received the parameters
|
||||
h.paramsChan <- *params
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
default:
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,12 +229,12 @@ func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if !h.receivedSecurePacket {
|
||||
return qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
||||
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
||||
}
|
||||
|
||||
if sno, ok := cryptoData[TagSNO]; ok {
|
||||
|
@ -229,22 +243,22 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
|||
|
||||
serverPubs, ok := cryptoData[TagPUBS]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
}
|
||||
|
||||
verTag, ok := cryptoData[TagVER]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
||||
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
||||
}
|
||||
if !h.validateVersionList(verTag) {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||
}
|
||||
|
||||
nonce := append(h.nonc, h.sno...)
|
||||
|
||||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
|
@ -261,39 +275,32 @@ func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) error {
|
|||
protocol.PerspectiveClient,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = h.connectionParameters.SetFromMap(cryptoData)
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
return qerr.InvalidCryptoMessageParameter
|
||||
return nil, qerr.InvalidCryptoMessageParameter
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionForwardSecure
|
||||
close(h.aeadChanged)
|
||||
|
||||
return nil
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
||||
if len(h.negotiatedVersions) == 0 {
|
||||
numNegotiatedVersions := len(h.negotiatedVersions)
|
||||
if numNegotiatedVersions == 0 {
|
||||
return true
|
||||
}
|
||||
if len(verTags)%4 != 0 || len(verTags)/4 != len(h.negotiatedVersions) {
|
||||
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
|
||||
return false
|
||||
}
|
||||
|
||||
b := bytes.NewReader(verTags)
|
||||
for _, negotiatedVersion := range h.negotiatedVersions {
|
||||
verTag, err := utils.ReadUint32(b)
|
||||
for i := 0; i < numNegotiatedVersions; i++ {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil { // should never occur, since the length was already checked
|
||||
return false
|
||||
}
|
||||
ver := protocol.VersionTagToNumber(verTag)
|
||||
if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) {
|
||||
ver = protocol.VersionUnsupported
|
||||
}
|
||||
if ver != negotiatedVersion {
|
||||
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -333,16 +340,16 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.forwardSecureAEAD != nil {
|
||||
return protocol.EncryptionForwardSecure, h.sealForwardSecure
|
||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
||||
} else if h.secureAEAD != nil {
|
||||
return protocol.EncryptionSecure, h.sealSecure
|
||||
return protocol.EncryptionSecure, h.secureAEAD
|
||||
} else {
|
||||
return protocol.EncryptionUnencrypted, h.sealUnencrypted
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||
return protocol.EncryptionUnencrypted, h.sealUnencrypted
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||
|
@ -351,33 +358,21 @@ func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.Encry
|
|||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionUnencrypted:
|
||||
return h.sealUnencrypted, nil
|
||||
return h.nullAEAD, nil
|
||||
case protocol.EncryptionSecure:
|
||||
if h.secureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupClient: no secureAEAD")
|
||||
}
|
||||
return h.sealSecure, nil
|
||||
return h.secureAEAD, nil
|
||||
case protocol.EncryptionForwardSecure:
|
||||
if h.forwardSecureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
|
||||
}
|
||||
return h.sealForwardSecure, nil
|
||||
return h.forwardSecureAEAD, nil
|
||||
}
|
||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return h.nullAEAD.Seal(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) DiversificationNonce() []byte {
|
||||
panic("not needed for cryptoSetupClient")
|
||||
}
|
||||
|
@ -386,6 +381,15 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
|||
h.divNonceChan <- data
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||
PeerCertificates: h.certManager.GetChain(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
|
@ -413,15 +417,11 @@ func (h *cryptoSetupClient) sendCHLO() error {
|
|||
}
|
||||
|
||||
h.lastSentCHLO = b.Bytes()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
||||
tags, err := h.connectionParameters.GetHelloMap()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tags := h.params.getHelloMap()
|
||||
tags[TagSNI] = []byte(h.hostname)
|
||||
tags[TagPDMD] = []byte("X509")
|
||||
|
||||
|
@ -431,12 +431,9 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
|||
}
|
||||
|
||||
versionTag := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
|
||||
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
|
||||
tags[TagVER] = versionTag
|
||||
|
||||
if h.params.RequestConnectionIDTruncation {
|
||||
tags[TagTCID] = []byte{0, 0, 0, 0}
|
||||
}
|
||||
if len(h.stk) > 0 {
|
||||
tags[TagSTK] = h.stk
|
||||
}
|
||||
|
@ -470,7 +467,7 @@ func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
|||
for _, tag := range tags {
|
||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||
}
|
||||
paddingSize := protocol.ClientHelloMinimumSize - size
|
||||
paddingSize := protocol.MinClientHelloSize - size
|
||||
if paddingSize > 0 {
|
||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
||||
}
|
||||
|
@ -508,10 +505,8 @@ func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.aeadChanged <- protocol.EncryptionSecure
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue