protocol: add AEAD encryption negotiation to v2 wire control channel (#5304)
This commit is contained in:
committed by
GitHub
Unverified
parent
57bb9e80fe
commit
8666e3643f
@@ -0,0 +1,197 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
const (
|
||||
AEADAlgorithmAES256GCM = "aes-256-gcm"
|
||||
AEADAlgorithmXChaCha20Poly1305 = "xchacha20-poly1305"
|
||||
|
||||
CryptoRandomSize = 32
|
||||
|
||||
cryptoTranscriptLabel = "frp wire v2 crypto transcript"
|
||||
)
|
||||
|
||||
var supportedAEADAlgorithms = []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}
|
||||
|
||||
type CryptoContext struct {
|
||||
Algorithm string
|
||||
TranscriptHash []byte
|
||||
}
|
||||
|
||||
func NewClientHello(bootstrap BootstrapInfo) (ClientHello, error) {
|
||||
clientRandom, err := newCryptoRandom()
|
||||
if err != nil {
|
||||
return ClientHello{}, err
|
||||
}
|
||||
return clientHelloWithCryptoRandom(bootstrap, clientRandom), nil
|
||||
}
|
||||
|
||||
func NewServerHello(clientHello ClientHello) (ServerHello, error) {
|
||||
if err := ValidateClientHello(clientHello); err != nil {
|
||||
return ServerHello{}, err
|
||||
}
|
||||
algorithm, ok := SelectAEADAlgorithm(clientHello.Capabilities.Crypto.Algorithms)
|
||||
if !ok {
|
||||
return ServerHello{}, fmt.Errorf("no supported crypto algorithm")
|
||||
}
|
||||
serverRandom, err := newCryptoRandom()
|
||||
if err != nil {
|
||||
return ServerHello{}, err
|
||||
}
|
||||
return ServerHello{
|
||||
Selected: ServerSelection{
|
||||
Message: MessageSelection{
|
||||
Codec: MessageCodecJSON,
|
||||
},
|
||||
Crypto: CryptoSelection{
|
||||
Algorithm: algorithm,
|
||||
ServerRandom: serverRandom,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ValidateCryptoCapabilities(c CryptoCapabilities) error {
|
||||
if len(c.ClientRandom) != CryptoRandomSize {
|
||||
return fmt.Errorf("invalid crypto client random length %d, want %d", len(c.ClientRandom), CryptoRandomSize)
|
||||
}
|
||||
if _, ok := SelectAEADAlgorithm(c.Algorithms); !ok {
|
||||
return fmt.Errorf("no supported crypto algorithm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateServerHelloForClient(clientHello ClientHello, serverHello ServerHello) error {
|
||||
if serverHello.Selected.Message.Codec != MessageCodecJSON {
|
||||
return fmt.Errorf("unsupported selected message codec: %s", serverHello.Selected.Message.Codec)
|
||||
}
|
||||
cryptoSelection := serverHello.Selected.Crypto
|
||||
if !IsSupportedAEADAlgorithm(cryptoSelection.Algorithm) {
|
||||
return fmt.Errorf("unknown selected crypto algorithm: %s", cryptoSelection.Algorithm)
|
||||
}
|
||||
if !Supports(clientHello.Capabilities.Crypto.Algorithms, cryptoSelection.Algorithm) {
|
||||
return fmt.Errorf("selected crypto algorithm was not advertised by client: %s", cryptoSelection.Algorithm)
|
||||
}
|
||||
if len(cryptoSelection.ServerRandom) != CryptoRandomSize {
|
||||
return fmt.Errorf("invalid crypto server random length %d, want %d", len(cryptoSelection.ServerRandom), CryptoRandomSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewCryptoContext(algorithm string, clientHelloPayload, serverHelloPayload []byte) *CryptoContext {
|
||||
return &CryptoContext{
|
||||
Algorithm: algorithm,
|
||||
TranscriptHash: HashCryptoTranscript(clientHelloPayload, serverHelloPayload),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientCryptoContext(clientHelloPayload, serverHelloPayload []byte) (*CryptoContext, error) {
|
||||
var clientHello ClientHello
|
||||
if err := json.Unmarshal(clientHelloPayload, &clientHello); err != nil {
|
||||
return nil, fmt.Errorf("decode ClientHello transcript: %w", err)
|
||||
}
|
||||
var serverHello ServerHello
|
||||
if err := json.Unmarshal(serverHelloPayload, &serverHello); err != nil {
|
||||
return nil, fmt.Errorf("decode ServerHello transcript: %w", err)
|
||||
}
|
||||
if err := ValidateServerHelloForClient(clientHello, serverHello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload), nil
|
||||
}
|
||||
|
||||
func HashCryptoTranscript(clientHelloPayload, serverHelloPayload []byte) []byte {
|
||||
h := sha256.New()
|
||||
_, _ = h.Write([]byte(cryptoTranscriptLabel))
|
||||
writeCryptoTranscriptPart(h, "client hello", clientHelloPayload)
|
||||
writeCryptoTranscriptPart(h, "server hello", serverHelloPayload)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func writeCryptoTranscriptPart(h hash.Hash, label string, payload []byte) {
|
||||
var length [8]byte
|
||||
binary.BigEndian.PutUint64(length[:], uint64(len(payload)))
|
||||
_, _ = h.Write([]byte{0})
|
||||
_, _ = h.Write([]byte(label))
|
||||
_, _ = h.Write([]byte{0})
|
||||
_, _ = h.Write(length[:])
|
||||
_, _ = h.Write(payload)
|
||||
}
|
||||
|
||||
func PreferredAEADAlgorithms() []string {
|
||||
if hasFastAESGCM() {
|
||||
return []string{AEADAlgorithmAES256GCM, AEADAlgorithmXChaCha20Poly1305}
|
||||
}
|
||||
return []string{AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
|
||||
}
|
||||
|
||||
func SelectAEADAlgorithm(clientAlgorithms []string) (string, bool) {
|
||||
for _, algorithm := range clientAlgorithms {
|
||||
if IsSupportedAEADAlgorithm(algorithm) {
|
||||
return algorithm, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func IsSupportedAEADAlgorithm(algorithm string) bool {
|
||||
return Supports(supportedAEADAlgorithms, algorithm)
|
||||
}
|
||||
|
||||
func newCryptoRandom() ([]byte, error) {
|
||||
b := make([]byte, CryptoRandomSize)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, fmt.Errorf("generate crypto random: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func hasFastAESGCM() bool {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return cpu.X86.HasAES &&
|
||||
cpu.X86.HasPCLMULQDQ &&
|
||||
cpu.X86.HasSSE41 &&
|
||||
cpu.X86.HasSSSE3
|
||||
case "arm64":
|
||||
return cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
||||
case "s390x":
|
||||
return cpu.S390X.HasAES &&
|
||||
cpu.S390X.HasAESCTR &&
|
||||
cpu.S390X.HasGHASH
|
||||
case "ppc64", "ppc64le":
|
||||
// Go's ppc64/ppc64le port targets POWER8+, which has AES instructions;
|
||||
// x/sys/cpu does not expose a PPC64 AES feature flag.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
+31
-7
@@ -120,15 +120,23 @@ func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
|
||||
return json.Unmarshal(f.Payload, out)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||
func NewJSONFrame(frameType uint16, in any) (*Frame, error) {
|
||||
payload, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||
f, err := NewJSONFrame(frameType, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WriteFrame(&Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
})
|
||||
return c.WriteFrame(f)
|
||||
}
|
||||
|
||||
func WriteMagic(w io.Writer) error {
|
||||
@@ -170,12 +178,18 @@ type ClientHello struct {
|
||||
|
||||
type ClientCapabilities struct {
|
||||
Message MessageCapabilities `json:"message,omitempty"`
|
||||
Crypto CryptoCapabilities `json:"crypto,omitempty"`
|
||||
}
|
||||
|
||||
type MessageCapabilities struct {
|
||||
Codecs []string `json:"codecs,omitempty"`
|
||||
}
|
||||
|
||||
type CryptoCapabilities struct {
|
||||
Algorithms []string `json:"algorithms,omitempty"`
|
||||
ClientRandom []byte `json:"clientRandom,omitempty"`
|
||||
}
|
||||
|
||||
type ServerHello struct {
|
||||
Selected ServerSelection `json:"selected,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
@@ -183,19 +197,29 @@ type ServerHello struct {
|
||||
|
||||
type ServerSelection struct {
|
||||
Message MessageSelection `json:"message,omitempty"`
|
||||
Crypto CryptoSelection `json:"crypto,omitempty"`
|
||||
}
|
||||
|
||||
type MessageSelection struct {
|
||||
Codec string `json:"codec,omitempty"`
|
||||
}
|
||||
|
||||
func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
|
||||
type CryptoSelection struct {
|
||||
Algorithm string `json:"algorithm,omitempty"`
|
||||
ServerRandom []byte `json:"serverRandom,omitempty"`
|
||||
}
|
||||
|
||||
func clientHelloWithCryptoRandom(bootstrap BootstrapInfo, clientRandom []byte) ClientHello {
|
||||
return ClientHello{
|
||||
Bootstrap: bootstrap,
|
||||
Capabilities: ClientCapabilities{
|
||||
Message: MessageCapabilities{
|
||||
Codecs: []string{MessageCodecJSON},
|
||||
},
|
||||
Crypto: CryptoCapabilities{
|
||||
Algorithms: PreferredAEADAlgorithms(),
|
||||
ClientRandom: clientRandom,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -218,5 +242,5 @@ func ValidateClientHello(h ClientHello) error {
|
||||
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
|
||||
return fmt.Errorf("unsupported message codec")
|
||||
}
|
||||
return nil
|
||||
return ValidateCryptoCapabilities(h.Capabilities.Crypto)
|
||||
}
|
||||
|
||||
+116
-3
@@ -28,7 +28,7 @@ func TestFrameRoundTrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
conn := NewConn(&buf)
|
||||
|
||||
in := DefaultClientHello(BootstrapInfo{
|
||||
in := mustClientHello(t, BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
@@ -112,9 +112,122 @@ func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateClientHello(t *testing.T) {
|
||||
require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
require.NoError(t, ValidateClientHello(hello))
|
||||
require.Len(t, hello.Capabilities.Crypto.ClientRandom, CryptoRandomSize)
|
||||
require.ElementsMatch(t, []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}, hello.Capabilities.Crypto.Algorithms)
|
||||
|
||||
hello := DefaultClientHello(BootstrapInfo{})
|
||||
hello.Capabilities.Message.Codecs = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
|
||||
}
|
||||
|
||||
func TestValidateClientHelloRejectsInvalidCrypto(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.ClientRandom = hello.Capabilities.Crypto.ClientRandom[:CryptoRandomSize-1]
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "invalid crypto client random length")
|
||||
|
||||
hello = mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "no supported crypto algorithm")
|
||||
}
|
||||
|
||||
func TestPreferredAEADAlgorithms(t *testing.T) {
|
||||
require.ElementsMatch(t, []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}, PreferredAEADAlgorithms())
|
||||
}
|
||||
|
||||
func TestNewServerHelloSelectsFirstSupportedAEADAlgorithm(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{"future-aead", AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
|
||||
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, MessageCodecJSON, serverHello.Selected.Message.Codec)
|
||||
require.Equal(t, AEADAlgorithmXChaCha20Poly1305, serverHello.Selected.Crypto.Algorithm)
|
||||
require.Len(t, serverHello.Selected.Crypto.ServerRandom, CryptoRandomSize)
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextValidatesServerHello(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
|
||||
ctx, err := NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, serverHello.Selected.Crypto.Algorithm, ctx.Algorithm)
|
||||
require.Len(t, ctx.TranscriptHash, 32)
|
||||
|
||||
tampered := serverHello
|
||||
tampered.Selected.Crypto.ServerRandom = append([]byte(nil), serverHello.Selected.Crypto.ServerRandom...)
|
||||
tampered.Selected.Crypto.ServerRandom[0] ^= 0xff
|
||||
_, tamperedServerHelloPayload := mustCryptoTranscriptPayloads(t, hello, tampered)
|
||||
tamperedCtx, err := NewClientCryptoContext(clientHelloPayload, tamperedServerHelloPayload)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
|
||||
}
|
||||
|
||||
func TestNewCryptoContextBindsFullClientHelloPayload(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
|
||||
ctx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload)
|
||||
|
||||
tamperedHello := hello
|
||||
tamperedHello.Bootstrap.TLS = false
|
||||
tamperedClientHelloPayload, _ := mustCryptoTranscriptPayloads(t, tamperedHello, serverHello)
|
||||
tamperedCtx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, tamperedClientHelloPayload, serverHelloPayload)
|
||||
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextRejectsUnknownServerSelection(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverHello.Selected.Crypto.Algorithm = "unknown"
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.ErrorContains(t, err, "unknown selected crypto algorithm")
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextRejectsUnadvertisedServerSelection(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{AEADAlgorithmAES256GCM}
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverHello.Selected.Crypto.Algorithm = AEADAlgorithmXChaCha20Poly1305
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.ErrorContains(t, err, "selected crypto algorithm was not advertised by client")
|
||||
}
|
||||
|
||||
func mustClientHello(t *testing.T, bootstrap BootstrapInfo) ClientHello {
|
||||
t.Helper()
|
||||
|
||||
hello, err := NewClientHello(bootstrap)
|
||||
require.NoError(t, err)
|
||||
return hello
|
||||
}
|
||||
|
||||
func mustCryptoTranscriptPayloads(t *testing.T, hello ClientHello, serverHello ServerHello) ([]byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
clientHelloFrame, err := NewJSONFrame(FrameTypeClientHello, hello)
|
||||
require.NoError(t, err)
|
||||
serverHelloFrame, err := NewJSONFrame(FrameTypeServerHello, serverHello)
|
||||
require.NoError(t, err)
|
||||
return clientHelloFrame.Payload, serverHelloFrame.Payload
|
||||
}
|
||||
|
||||
+92
-3
@@ -16,14 +16,16 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/crypto"
|
||||
libcrypto "github.com/fatedier/golib/crypto"
|
||||
quic "github.com/quic-go/quic-go"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
@@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error {
|
||||
}
|
||||
|
||||
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
encReader := crypto.NewReader(rw, key)
|
||||
encWriter, err := crypto.NewWriter(rw, key)
|
||||
encReader := libcrypto.NewReader(rw, key)
|
||||
encWriter, err := libcrypto.NewWriter(rw, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
Writer: encWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type AEADCryptoRole int
|
||||
|
||||
const (
|
||||
AEADCryptoRoleClient AEADCryptoRole = iota + 1
|
||||
AEADCryptoRoleServer
|
||||
)
|
||||
|
||||
const (
|
||||
aeadControlHKDFInfoPrefix = "frp wire v2 control aead"
|
||||
aeadDirectionClientToServer = "client-to-server"
|
||||
aeadDirectionServerToClient = "server-to-client"
|
||||
)
|
||||
|
||||
// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2
|
||||
// control channel. Frames and their order are authenticated, but end-of-stream
|
||||
// is not: a clean EOF at a frame boundary is returned as normal EOF by the
|
||||
// underlying AEAD stream. Protocols that need truncation detection for finite
|
||||
// objects must add their own authenticated final message.
|
||||
func NewAEADCryptoReadWriter(
|
||||
rw io.ReadWriter,
|
||||
key []byte,
|
||||
role AEADCryptoRole,
|
||||
algorithm string,
|
||||
transcriptHash []byte,
|
||||
) (io.ReadWriter, error) {
|
||||
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var readKey, writeKey []byte
|
||||
switch role {
|
||||
case AEADCryptoRoleClient:
|
||||
readKey = serverToClientKey
|
||||
writeKey = clientToServerKey
|
||||
case AEADCryptoRoleServer:
|
||||
readKey = clientToServerKey
|
||||
writeKey = serverToClientKey
|
||||
default:
|
||||
return nil, errors.New("invalid aead crypto role")
|
||||
}
|
||||
|
||||
encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{
|
||||
Algorithm: libcrypto.AEADAlgorithm(algorithm),
|
||||
Key: readKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{
|
||||
Algorithm: libcrypto.AEADAlgorithm(algorithm),
|
||||
Key: writeKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{
|
||||
Reader: encReader,
|
||||
Writer: encWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) {
|
||||
clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return clientToServerKey, serverToClientKey, nil
|
||||
}
|
||||
|
||||
func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) {
|
||||
info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction)
|
||||
reader := hkdf.New(sha256.New, key, transcriptHash, info)
|
||||
out := make([]byte, libcrypto.AEADKeySize)
|
||||
if _, err := io.ReadFull(reader, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
stdnet "net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) {
|
||||
clientConn, serverConn := stdnet.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
key := []byte("token")
|
||||
transcriptHash := bytes.Repeat([]byte{0x11}, 32)
|
||||
clientRW, err := NewAEADCryptoReadWriter(
|
||||
clientConn,
|
||||
key,
|
||||
AEADCryptoRoleClient,
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
transcriptHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
serverRW, err := NewAEADCryptoReadWriter(
|
||||
serverConn,
|
||||
key,
|
||||
AEADCryptoRoleServer,
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
transcriptHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
if _, err := clientRW.Write([]byte("ping")); err != nil {
|
||||
clientErrCh <- err
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len("pong"))
|
||||
_, err := io.ReadFull(clientRW, buf)
|
||||
clientErrCh <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("ping"))
|
||||
_, err = io.ReadFull(serverRW, buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ping", string(buf))
|
||||
_, err = serverRW.Write([]byte("pong"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, <-clientErrCh)
|
||||
}
|
||||
|
||||
func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) {
|
||||
clientConn, serverConn := stdnet.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second)))
|
||||
require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second)))
|
||||
|
||||
key := []byte("token")
|
||||
clientRW, err := NewAEADCryptoReadWriter(
|
||||
clientConn,
|
||||
key,
|
||||
AEADCryptoRoleClient,
|
||||
wire.AEADAlgorithmAES256GCM,
|
||||
bytes.Repeat([]byte{0x22}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
serverRW, err := NewAEADCryptoReadWriter(
|
||||
serverConn,
|
||||
key,
|
||||
AEADCryptoRoleServer,
|
||||
wire.AEADAlgorithmAES256GCM,
|
||||
bytes.Repeat([]byte{0x33}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := clientRW.Write([]byte("ping"))
|
||||
writeErrCh <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("ping"))
|
||||
_, err = io.ReadFull(serverRW, buf)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, <-writeErrCh)
|
||||
}
|
||||
|
||||
func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) {
|
||||
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(
|
||||
[]byte("token"),
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
bytes.Repeat([]byte{0x44}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, clientToServerKey, serverToClientKey)
|
||||
}
|
||||
Reference in New Issue
Block a user