protocol: add AEAD encryption negotiation to v2 wire control channel (#5304)

This commit is contained in:
fatedier
2026-05-06 10:43:47 +08:00
committed by GitHub
Unverified
parent 57bb9e80fe
commit 8666e3643f
15 changed files with 866 additions and 86 deletions
+197
View File
@@ -0,0 +1,197 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package wire
import (
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"fmt"
"hash"
"runtime"
"golang.org/x/sys/cpu"
)
const (
AEADAlgorithmAES256GCM = "aes-256-gcm"
AEADAlgorithmXChaCha20Poly1305 = "xchacha20-poly1305"
CryptoRandomSize = 32
cryptoTranscriptLabel = "frp wire v2 crypto transcript"
)
var supportedAEADAlgorithms = []string{
AEADAlgorithmAES256GCM,
AEADAlgorithmXChaCha20Poly1305,
}
type CryptoContext struct {
Algorithm string
TranscriptHash []byte
}
func NewClientHello(bootstrap BootstrapInfo) (ClientHello, error) {
clientRandom, err := newCryptoRandom()
if err != nil {
return ClientHello{}, err
}
return clientHelloWithCryptoRandom(bootstrap, clientRandom), nil
}
func NewServerHello(clientHello ClientHello) (ServerHello, error) {
if err := ValidateClientHello(clientHello); err != nil {
return ServerHello{}, err
}
algorithm, ok := SelectAEADAlgorithm(clientHello.Capabilities.Crypto.Algorithms)
if !ok {
return ServerHello{}, fmt.Errorf("no supported crypto algorithm")
}
serverRandom, err := newCryptoRandom()
if err != nil {
return ServerHello{}, err
}
return ServerHello{
Selected: ServerSelection{
Message: MessageSelection{
Codec: MessageCodecJSON,
},
Crypto: CryptoSelection{
Algorithm: algorithm,
ServerRandom: serverRandom,
},
},
}, nil
}
func ValidateCryptoCapabilities(c CryptoCapabilities) error {
if len(c.ClientRandom) != CryptoRandomSize {
return fmt.Errorf("invalid crypto client random length %d, want %d", len(c.ClientRandom), CryptoRandomSize)
}
if _, ok := SelectAEADAlgorithm(c.Algorithms); !ok {
return fmt.Errorf("no supported crypto algorithm")
}
return nil
}
func ValidateServerHelloForClient(clientHello ClientHello, serverHello ServerHello) error {
if serverHello.Selected.Message.Codec != MessageCodecJSON {
return fmt.Errorf("unsupported selected message codec: %s", serverHello.Selected.Message.Codec)
}
cryptoSelection := serverHello.Selected.Crypto
if !IsSupportedAEADAlgorithm(cryptoSelection.Algorithm) {
return fmt.Errorf("unknown selected crypto algorithm: %s", cryptoSelection.Algorithm)
}
if !Supports(clientHello.Capabilities.Crypto.Algorithms, cryptoSelection.Algorithm) {
return fmt.Errorf("selected crypto algorithm was not advertised by client: %s", cryptoSelection.Algorithm)
}
if len(cryptoSelection.ServerRandom) != CryptoRandomSize {
return fmt.Errorf("invalid crypto server random length %d, want %d", len(cryptoSelection.ServerRandom), CryptoRandomSize)
}
return nil
}
func NewCryptoContext(algorithm string, clientHelloPayload, serverHelloPayload []byte) *CryptoContext {
return &CryptoContext{
Algorithm: algorithm,
TranscriptHash: HashCryptoTranscript(clientHelloPayload, serverHelloPayload),
}
}
func NewClientCryptoContext(clientHelloPayload, serverHelloPayload []byte) (*CryptoContext, error) {
var clientHello ClientHello
if err := json.Unmarshal(clientHelloPayload, &clientHello); err != nil {
return nil, fmt.Errorf("decode ClientHello transcript: %w", err)
}
var serverHello ServerHello
if err := json.Unmarshal(serverHelloPayload, &serverHello); err != nil {
return nil, fmt.Errorf("decode ServerHello transcript: %w", err)
}
if err := ValidateServerHelloForClient(clientHello, serverHello); err != nil {
return nil, err
}
return NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload), nil
}
func HashCryptoTranscript(clientHelloPayload, serverHelloPayload []byte) []byte {
h := sha256.New()
_, _ = h.Write([]byte(cryptoTranscriptLabel))
writeCryptoTranscriptPart(h, "client hello", clientHelloPayload)
writeCryptoTranscriptPart(h, "server hello", serverHelloPayload)
return h.Sum(nil)
}
func writeCryptoTranscriptPart(h hash.Hash, label string, payload []byte) {
var length [8]byte
binary.BigEndian.PutUint64(length[:], uint64(len(payload)))
_, _ = h.Write([]byte{0})
_, _ = h.Write([]byte(label))
_, _ = h.Write([]byte{0})
_, _ = h.Write(length[:])
_, _ = h.Write(payload)
}
func PreferredAEADAlgorithms() []string {
if hasFastAESGCM() {
return []string{AEADAlgorithmAES256GCM, AEADAlgorithmXChaCha20Poly1305}
}
return []string{AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
}
func SelectAEADAlgorithm(clientAlgorithms []string) (string, bool) {
for _, algorithm := range clientAlgorithms {
if IsSupportedAEADAlgorithm(algorithm) {
return algorithm, true
}
}
return "", false
}
func IsSupportedAEADAlgorithm(algorithm string) bool {
return Supports(supportedAEADAlgorithms, algorithm)
}
func newCryptoRandom() ([]byte, error) {
b := make([]byte, CryptoRandomSize)
if _, err := rand.Read(b); err != nil {
return nil, fmt.Errorf("generate crypto random: %w", err)
}
return b, nil
}
func hasFastAESGCM() bool {
switch runtime.GOARCH {
case "amd64":
return cpu.X86.HasAES &&
cpu.X86.HasPCLMULQDQ &&
cpu.X86.HasSSE41 &&
cpu.X86.HasSSSE3
case "arm64":
return cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
case "s390x":
return cpu.S390X.HasAES &&
cpu.S390X.HasAESCTR &&
cpu.S390X.HasGHASH
case "ppc64", "ppc64le":
// Go's ppc64/ppc64le port targets POWER8+, which has AES instructions;
// x/sys/cpu does not expose a PPC64 AES feature flag.
return true
default:
return false
}
}
+31 -7
View File
@@ -120,15 +120,23 @@ func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
return json.Unmarshal(f.Payload, out)
}
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
func NewJSONFrame(frameType uint16, in any) (*Frame, error) {
payload, err := json.Marshal(in)
if err != nil {
return nil, err
}
return &Frame{
Type: frameType,
Payload: payload,
}, nil
}
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
f, err := NewJSONFrame(frameType, in)
if err != nil {
return err
}
return c.WriteFrame(&Frame{
Type: frameType,
Payload: payload,
})
return c.WriteFrame(f)
}
func WriteMagic(w io.Writer) error {
@@ -170,12 +178,18 @@ type ClientHello struct {
type ClientCapabilities struct {
Message MessageCapabilities `json:"message,omitempty"`
Crypto CryptoCapabilities `json:"crypto,omitempty"`
}
type MessageCapabilities struct {
Codecs []string `json:"codecs,omitempty"`
}
type CryptoCapabilities struct {
Algorithms []string `json:"algorithms,omitempty"`
ClientRandom []byte `json:"clientRandom,omitempty"`
}
type ServerHello struct {
Selected ServerSelection `json:"selected,omitempty"`
Error string `json:"error,omitempty"`
@@ -183,19 +197,29 @@ type ServerHello struct {
type ServerSelection struct {
Message MessageSelection `json:"message,omitempty"`
Crypto CryptoSelection `json:"crypto,omitempty"`
}
type MessageSelection struct {
Codec string `json:"codec,omitempty"`
}
func DefaultClientHello(bootstrap BootstrapInfo) ClientHello {
type CryptoSelection struct {
Algorithm string `json:"algorithm,omitempty"`
ServerRandom []byte `json:"serverRandom,omitempty"`
}
func clientHelloWithCryptoRandom(bootstrap BootstrapInfo, clientRandom []byte) ClientHello {
return ClientHello{
Bootstrap: bootstrap,
Capabilities: ClientCapabilities{
Message: MessageCapabilities{
Codecs: []string{MessageCodecJSON},
},
Crypto: CryptoCapabilities{
Algorithms: PreferredAEADAlgorithms(),
ClientRandom: clientRandom,
},
},
}
}
@@ -218,5 +242,5 @@ func ValidateClientHello(h ClientHello) error {
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
return fmt.Errorf("unsupported message codec")
}
return nil
return ValidateCryptoCapabilities(h.Capabilities.Crypto)
}
+116 -3
View File
@@ -28,7 +28,7 @@ func TestFrameRoundTrip(t *testing.T) {
var buf bytes.Buffer
conn := NewConn(&buf)
in := DefaultClientHello(BootstrapInfo{
in := mustClientHello(t, BootstrapInfo{
Transport: "tcp",
TLS: true,
TCPMux: true,
@@ -112,9 +112,122 @@ func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
}
func TestValidateClientHello(t *testing.T) {
require.NoError(t, ValidateClientHello(DefaultClientHello(BootstrapInfo{})))
hello := mustClientHello(t, BootstrapInfo{})
require.NoError(t, ValidateClientHello(hello))
require.Len(t, hello.Capabilities.Crypto.ClientRandom, CryptoRandomSize)
require.ElementsMatch(t, []string{
AEADAlgorithmAES256GCM,
AEADAlgorithmXChaCha20Poly1305,
}, hello.Capabilities.Crypto.Algorithms)
hello := DefaultClientHello(BootstrapInfo{})
hello.Capabilities.Message.Codecs = []string{"unknown"}
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
}
func TestValidateClientHelloRejectsInvalidCrypto(t *testing.T) {
hello := mustClientHello(t, BootstrapInfo{})
hello.Capabilities.Crypto.ClientRandom = hello.Capabilities.Crypto.ClientRandom[:CryptoRandomSize-1]
require.ErrorContains(t, ValidateClientHello(hello), "invalid crypto client random length")
hello = mustClientHello(t, BootstrapInfo{})
hello.Capabilities.Crypto.Algorithms = []string{"unknown"}
require.ErrorContains(t, ValidateClientHello(hello), "no supported crypto algorithm")
}
func TestPreferredAEADAlgorithms(t *testing.T) {
require.ElementsMatch(t, []string{
AEADAlgorithmAES256GCM,
AEADAlgorithmXChaCha20Poly1305,
}, PreferredAEADAlgorithms())
}
func TestNewServerHelloSelectsFirstSupportedAEADAlgorithm(t *testing.T) {
hello := mustClientHello(t, BootstrapInfo{})
hello.Capabilities.Crypto.Algorithms = []string{"future-aead", AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
serverHello, err := NewServerHello(hello)
require.NoError(t, err)
require.Equal(t, MessageCodecJSON, serverHello.Selected.Message.Codec)
require.Equal(t, AEADAlgorithmXChaCha20Poly1305, serverHello.Selected.Crypto.Algorithm)
require.Len(t, serverHello.Selected.Crypto.ServerRandom, CryptoRandomSize)
}
func TestNewClientCryptoContextValidatesServerHello(t *testing.T) {
hello := mustClientHello(t, BootstrapInfo{})
serverHello, err := NewServerHello(hello)
require.NoError(t, err)
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
ctx, err := NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
require.NoError(t, err)
require.Equal(t, serverHello.Selected.Crypto.Algorithm, ctx.Algorithm)
require.Len(t, ctx.TranscriptHash, 32)
tampered := serverHello
tampered.Selected.Crypto.ServerRandom = append([]byte(nil), serverHello.Selected.Crypto.ServerRandom...)
tampered.Selected.Crypto.ServerRandom[0] ^= 0xff
_, tamperedServerHelloPayload := mustCryptoTranscriptPayloads(t, hello, tampered)
tamperedCtx, err := NewClientCryptoContext(clientHelloPayload, tamperedServerHelloPayload)
require.NoError(t, err)
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
}
func TestNewCryptoContextBindsFullClientHelloPayload(t *testing.T) {
hello := mustClientHello(t, BootstrapInfo{
Transport: "tcp",
TLS: true,
TCPMux: true,
})
serverHello, err := NewServerHello(hello)
require.NoError(t, err)
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
ctx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload)
tamperedHello := hello
tamperedHello.Bootstrap.TLS = false
tamperedClientHelloPayload, _ := mustCryptoTranscriptPayloads(t, tamperedHello, serverHello)
tamperedCtx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, tamperedClientHelloPayload, serverHelloPayload)
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
}
func TestNewClientCryptoContextRejectsUnknownServerSelection(t *testing.T) {
hello := mustClientHello(t, BootstrapInfo{})
serverHello, err := NewServerHello(hello)
require.NoError(t, err)
serverHello.Selected.Crypto.Algorithm = "unknown"
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
require.ErrorContains(t, err, "unknown selected crypto algorithm")
}
func TestNewClientCryptoContextRejectsUnadvertisedServerSelection(t *testing.T) {
hello := mustClientHello(t, BootstrapInfo{})
hello.Capabilities.Crypto.Algorithms = []string{AEADAlgorithmAES256GCM}
serverHello, err := NewServerHello(hello)
require.NoError(t, err)
serverHello.Selected.Crypto.Algorithm = AEADAlgorithmXChaCha20Poly1305
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
require.ErrorContains(t, err, "selected crypto algorithm was not advertised by client")
}
func mustClientHello(t *testing.T, bootstrap BootstrapInfo) ClientHello {
t.Helper()
hello, err := NewClientHello(bootstrap)
require.NoError(t, err)
return hello
}
func mustCryptoTranscriptPayloads(t *testing.T, hello ClientHello, serverHello ServerHello) ([]byte, []byte) {
t.Helper()
clientHelloFrame, err := NewJSONFrame(FrameTypeClientHello, hello)
require.NoError(t, err)
serverHelloFrame, err := NewJSONFrame(FrameTypeServerHello, serverHello)
require.NoError(t, err)
return clientHelloFrame.Payload, serverHelloFrame.Payload
}
+92 -3
View File
@@ -16,14 +16,16 @@ package net
import (
"context"
"crypto/sha256"
"errors"
"io"
"net"
"sync/atomic"
"time"
"github.com/fatedier/golib/crypto"
libcrypto "github.com/fatedier/golib/crypto"
quic "github.com/quic-go/quic-go"
"golang.org/x/crypto/hkdf"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error {
}
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
encReader := crypto.NewReader(rw, key)
encWriter, err := crypto.NewWriter(rw, key)
encReader := libcrypto.NewReader(rw, key)
encWriter, err := libcrypto.NewWriter(rw, key)
if err != nil {
return nil, err
}
@@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
Writer: encWriter,
}, nil
}
type AEADCryptoRole int
const (
AEADCryptoRoleClient AEADCryptoRole = iota + 1
AEADCryptoRoleServer
)
const (
aeadControlHKDFInfoPrefix = "frp wire v2 control aead"
aeadDirectionClientToServer = "client-to-server"
aeadDirectionServerToClient = "server-to-client"
)
// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2
// control channel. Frames and their order are authenticated, but end-of-stream
// is not: a clean EOF at a frame boundary is returned as normal EOF by the
// underlying AEAD stream. Protocols that need truncation detection for finite
// objects must add their own authenticated final message.
func NewAEADCryptoReadWriter(
rw io.ReadWriter,
key []byte,
role AEADCryptoRole,
algorithm string,
transcriptHash []byte,
) (io.ReadWriter, error) {
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash)
if err != nil {
return nil, err
}
var readKey, writeKey []byte
switch role {
case AEADCryptoRoleClient:
readKey = serverToClientKey
writeKey = clientToServerKey
case AEADCryptoRoleServer:
readKey = clientToServerKey
writeKey = serverToClientKey
default:
return nil, errors.New("invalid aead crypto role")
}
encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{
Algorithm: libcrypto.AEADAlgorithm(algorithm),
Key: readKey,
})
if err != nil {
return nil, err
}
encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{
Algorithm: libcrypto.AEADAlgorithm(algorithm),
Key: writeKey,
})
if err != nil {
return nil, err
}
return struct {
io.Reader
io.Writer
}{
Reader: encReader,
Writer: encWriter,
}, nil
}
func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) {
clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer)
if err != nil {
return nil, nil, err
}
serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient)
if err != nil {
return nil, nil, err
}
return clientToServerKey, serverToClientKey, nil
}
func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) {
info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction)
reader := hkdf.New(sha256.New, key, transcriptHash, info)
out := make([]byte, libcrypto.AEADKeySize)
if _, err := io.ReadFull(reader, out); err != nil {
return nil, err
}
return out, nil
}
+118
View File
@@ -0,0 +1,118 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"bytes"
"io"
stdnet "net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/fatedier/frp/pkg/proto/wire"
)
func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) {
clientConn, serverConn := stdnet.Pipe()
defer clientConn.Close()
defer serverConn.Close()
key := []byte("token")
transcriptHash := bytes.Repeat([]byte{0x11}, 32)
clientRW, err := NewAEADCryptoReadWriter(
clientConn,
key,
AEADCryptoRoleClient,
wire.AEADAlgorithmXChaCha20Poly1305,
transcriptHash,
)
require.NoError(t, err)
serverRW, err := NewAEADCryptoReadWriter(
serverConn,
key,
AEADCryptoRoleServer,
wire.AEADAlgorithmXChaCha20Poly1305,
transcriptHash,
)
require.NoError(t, err)
clientErrCh := make(chan error, 1)
go func() {
if _, err := clientRW.Write([]byte("ping")); err != nil {
clientErrCh <- err
return
}
buf := make([]byte, len("pong"))
_, err := io.ReadFull(clientRW, buf)
clientErrCh <- err
}()
buf := make([]byte, len("ping"))
_, err = io.ReadFull(serverRW, buf)
require.NoError(t, err)
require.Equal(t, "ping", string(buf))
_, err = serverRW.Write([]byte("pong"))
require.NoError(t, err)
require.NoError(t, <-clientErrCh)
}
func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) {
clientConn, serverConn := stdnet.Pipe()
defer clientConn.Close()
defer serverConn.Close()
require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second)))
require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second)))
key := []byte("token")
clientRW, err := NewAEADCryptoReadWriter(
clientConn,
key,
AEADCryptoRoleClient,
wire.AEADAlgorithmAES256GCM,
bytes.Repeat([]byte{0x22}, 32),
)
require.NoError(t, err)
serverRW, err := NewAEADCryptoReadWriter(
serverConn,
key,
AEADCryptoRoleServer,
wire.AEADAlgorithmAES256GCM,
bytes.Repeat([]byte{0x33}, 32),
)
require.NoError(t, err)
writeErrCh := make(chan error, 1)
go func() {
_, err := clientRW.Write([]byte("ping"))
writeErrCh <- err
}()
buf := make([]byte, len("ping"))
_, err = io.ReadFull(serverRW, buf)
require.Error(t, err)
require.NoError(t, <-writeErrCh)
}
func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) {
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(
[]byte("token"),
wire.AEADAlgorithmXChaCha20Poly1305,
bytes.Repeat([]byte{0x44}, 32),
)
require.NoError(t, err)
require.NotEqual(t, clientToServerKey, serverToClientKey)
}