protocol: add AEAD encryption negotiation to v2 wire control channel (#5304)

This commit is contained in:
fatedier
2026-05-06 10:43:47 +08:00
committed by GitHub
Unverified
parent 57bb9e80fe
commit 8666e3643f
15 changed files with 866 additions and 86 deletions
+92 -3
View File
@@ -16,14 +16,16 @@ package net
import (
"context"
"crypto/sha256"
"errors"
"io"
"net"
"sync/atomic"
"time"
"github.com/fatedier/golib/crypto"
libcrypto "github.com/fatedier/golib/crypto"
quic "github.com/quic-go/quic-go"
"golang.org/x/crypto/hkdf"
"github.com/fatedier/frp/pkg/util/xlog"
)
@@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error {
}
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
encReader := crypto.NewReader(rw, key)
encWriter, err := crypto.NewWriter(rw, key)
encReader := libcrypto.NewReader(rw, key)
encWriter, err := libcrypto.NewWriter(rw, key)
if err != nil {
return nil, err
}
@@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
Writer: encWriter,
}, nil
}
type AEADCryptoRole int
const (
AEADCryptoRoleClient AEADCryptoRole = iota + 1
AEADCryptoRoleServer
)
const (
aeadControlHKDFInfoPrefix = "frp wire v2 control aead"
aeadDirectionClientToServer = "client-to-server"
aeadDirectionServerToClient = "server-to-client"
)
// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2
// control channel. Frames and their order are authenticated, but end-of-stream
// is not: a clean EOF at a frame boundary is returned as normal EOF by the
// underlying AEAD stream. Protocols that need truncation detection for finite
// objects must add their own authenticated final message.
func NewAEADCryptoReadWriter(
rw io.ReadWriter,
key []byte,
role AEADCryptoRole,
algorithm string,
transcriptHash []byte,
) (io.ReadWriter, error) {
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash)
if err != nil {
return nil, err
}
var readKey, writeKey []byte
switch role {
case AEADCryptoRoleClient:
readKey = serverToClientKey
writeKey = clientToServerKey
case AEADCryptoRoleServer:
readKey = clientToServerKey
writeKey = serverToClientKey
default:
return nil, errors.New("invalid aead crypto role")
}
encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{
Algorithm: libcrypto.AEADAlgorithm(algorithm),
Key: readKey,
})
if err != nil {
return nil, err
}
encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{
Algorithm: libcrypto.AEADAlgorithm(algorithm),
Key: writeKey,
})
if err != nil {
return nil, err
}
return struct {
io.Reader
io.Writer
}{
Reader: encReader,
Writer: encWriter,
}, nil
}
func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) {
clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer)
if err != nil {
return nil, nil, err
}
serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient)
if err != nil {
return nil, nil, err
}
return clientToServerKey, serverToClientKey, nil
}
func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) {
info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction)
reader := hkdf.New(sha256.New, key, transcriptHash, info)
out := make([]byte, libcrypto.AEADKeySize)
if _, err := io.ReadFull(reader, out); err != nil {
return nil, err
}
return out, nil
}
+118
View File
@@ -0,0 +1,118 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package net
import (
"bytes"
"io"
stdnet "net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/fatedier/frp/pkg/proto/wire"
)
func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) {
clientConn, serverConn := stdnet.Pipe()
defer clientConn.Close()
defer serverConn.Close()
key := []byte("token")
transcriptHash := bytes.Repeat([]byte{0x11}, 32)
clientRW, err := NewAEADCryptoReadWriter(
clientConn,
key,
AEADCryptoRoleClient,
wire.AEADAlgorithmXChaCha20Poly1305,
transcriptHash,
)
require.NoError(t, err)
serverRW, err := NewAEADCryptoReadWriter(
serverConn,
key,
AEADCryptoRoleServer,
wire.AEADAlgorithmXChaCha20Poly1305,
transcriptHash,
)
require.NoError(t, err)
clientErrCh := make(chan error, 1)
go func() {
if _, err := clientRW.Write([]byte("ping")); err != nil {
clientErrCh <- err
return
}
buf := make([]byte, len("pong"))
_, err := io.ReadFull(clientRW, buf)
clientErrCh <- err
}()
buf := make([]byte, len("ping"))
_, err = io.ReadFull(serverRW, buf)
require.NoError(t, err)
require.Equal(t, "ping", string(buf))
_, err = serverRW.Write([]byte("pong"))
require.NoError(t, err)
require.NoError(t, <-clientErrCh)
}
func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) {
clientConn, serverConn := stdnet.Pipe()
defer clientConn.Close()
defer serverConn.Close()
require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second)))
require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second)))
key := []byte("token")
clientRW, err := NewAEADCryptoReadWriter(
clientConn,
key,
AEADCryptoRoleClient,
wire.AEADAlgorithmAES256GCM,
bytes.Repeat([]byte{0x22}, 32),
)
require.NoError(t, err)
serverRW, err := NewAEADCryptoReadWriter(
serverConn,
key,
AEADCryptoRoleServer,
wire.AEADAlgorithmAES256GCM,
bytes.Repeat([]byte{0x33}, 32),
)
require.NoError(t, err)
writeErrCh := make(chan error, 1)
go func() {
_, err := clientRW.Write([]byte("ping"))
writeErrCh <- err
}()
buf := make([]byte, len("ping"))
_, err = io.ReadFull(serverRW, buf)
require.Error(t, err)
require.NoError(t, <-writeErrCh)
}
func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) {
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(
[]byte("token"),
wire.AEADAlgorithmXChaCha20Poly1305,
bytes.Repeat([]byte{0x44}, 32),
)
require.NoError(t, err)
require.NotEqual(t, clientToServerKey, serverToClientKey)
}