protocol: add AEAD encryption negotiation to v2 wire control channel (#5304)

This commit is contained in:
fatedier
2026-05-06 10:43:47 +08:00
committed by GitHub
Unverified
parent 57bb9e80fe
commit 8666e3643f
15 changed files with 866 additions and 86 deletions
+69 -20
View File
@@ -60,6 +60,7 @@ import (
const (
connReadTimeout time.Duration = 10 * time.Second
connWriteTimeout time.Duration = 5 * time.Second
vhostReadWriteTimeout time.Duration = 30 * time.Second
)
@@ -456,7 +457,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
controlConn := acceptedConn.conn
if !internal {
var controlRW io.ReadWriter
controlRW, err = netpkg.NewCryptoReadWriter(conn, svr.auth.EncryptionKey())
controlRW, err = acceptedConn.newControlReadWriter(conn, svr.auth.EncryptionKey())
if err == nil {
controlConn = acceptedConn.messageConnFor(controlRW)
}
@@ -468,17 +469,23 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
if err != nil {
xl.Warnf("register control error: %v", err)
_ = acceptedConn.conn.WriteMsg(&msg.LoginResp{
Version: version.Full(),
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
})
if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
return acceptedConn.conn.WriteMsg(&msg.LoginResp{
Version: version.Full(),
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
})
}); writeErr != nil {
xl.Warnf("write login error response error: %v", writeErr)
}
conn.Close()
return
}
if err = acceptedConn.conn.WriteMsg(&msg.LoginResp{
Version: version.Full(),
RunID: ctl.runID,
Error: "",
if err = writeWithDeadline(conn, connWriteTimeout, func() error {
return acceptedConn.conn.WriteMsg(&msg.LoginResp{
Version: version.Full(),
RunID: ctl.runID,
Error: "",
})
}); err != nil {
xl.Warnf("write login response error: %v", err)
svr.ctlManager.Del(m.RunID, ctl)
@@ -521,9 +528,10 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
}
type acceptedConnection struct {
conn *msg.Conn
wireProtocol string
firstMsg msg.Message
conn *msg.Conn
wireProtocol string
cryptoContext *wire.CryptoContext
firstMsg msg.Message
}
func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) {
@@ -544,7 +552,7 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep
wireConn := wire.NewConn(conn)
rw := msg.NewV2ReadWriterWithConn(wireConn)
acceptedConn.conn = msg.NewConn(conn, rw)
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(wireConn)
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(conn, wireConn)
} else {
rw := msg.NewV1ReadWriter(conn)
acceptedConn.conn = msg.NewConn(conn, rw)
@@ -557,17 +565,41 @@ func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*accep
return acceptedConn, nil
}
func writeWithDeadline(conn net.Conn, timeout time.Duration, writeFn func() error) error {
_ = conn.SetWriteDeadline(time.Now().Add(timeout))
defer func() {
_ = conn.SetWriteDeadline(time.Time{})
}()
return writeFn()
}
func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
}
func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message, error) {
func (ac *acceptedConnection) newControlReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
if ac.wireProtocol == wire.ProtocolV2 {
if ac.cryptoContext == nil {
return nil, fmt.Errorf("missing v2 crypto negotiation")
}
return netpkg.NewAEADCryptoReadWriter(
rw,
key,
netpkg.AEADCryptoRoleServer,
ac.cryptoContext.Algorithm,
ac.cryptoContext.TranscriptHash,
)
}
return netpkg.NewCryptoReadWriter(rw, key)
}
func (ac *acceptedConnection) readFirstV2Msg(conn net.Conn, wireConn *wire.Conn) (msg.Message, error) {
frame, err := wireConn.ReadFrame()
if err != nil {
return nil, fmt.Errorf("read v2 frame: %w", err)
}
if frame.Type == wire.FrameTypeClientHello {
if err := ac.handleClientHello(wireConn, frame); err != nil {
if err := ac.handleClientHello(conn, wireConn, frame); err != nil {
return nil, err
}
frame, err = wireConn.ReadFrame()
@@ -583,21 +615,38 @@ func (ac *acceptedConnection) readFirstV2Msg(wireConn *wire.Conn) (msg.Message,
return m, nil
}
func (ac *acceptedConnection) handleClientHello(wireConn *wire.Conn, frame *wire.Frame) error {
func (ac *acceptedConnection) handleClientHello(conn net.Conn, wireConn *wire.Conn, frame *wire.Frame) error {
var hello wire.ClientHello
if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
return fmt.Errorf("decode ClientHello: %w", err)
}
serverHello := wire.DefaultServerHello()
if err := wire.ValidateClientHello(hello); err != nil {
serverHello, err := wire.NewServerHello(hello)
if err != nil {
serverHello = wire.DefaultServerHello()
serverHello.Error = err.Error()
_ = wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
return wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
}); writeErr != nil {
return fmt.Errorf("%w; write ServerHello error: %v", err, writeErr)
}
return err
}
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello); err != nil {
serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
if err != nil {
return fmt.Errorf("encode ServerHello: %w", err)
}
cryptoContext := wire.NewCryptoContext(
serverHello.Selected.Crypto.Algorithm,
frame.Payload,
serverHelloFrame.Payload,
)
if err := writeWithDeadline(conn, connWriteTimeout, func() error {
return wireConn.WriteFrame(serverHelloFrame)
}); err != nil {
return fmt.Errorf("write ServerHello: %w", err)
}
ac.cryptoContext = cryptoContext
return nil
}
+63
View File
@@ -0,0 +1,63 @@
// Copyright 2026 The frp Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestWriteWithDeadlineTimesOutAndClearsDeadline(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
err := writeWithDeadline(serverConn, 50*time.Millisecond, func() error {
_, writeErr := serverConn.Write([]byte("x"))
return writeErr
})
require.Error(t, err)
var netErr net.Error
require.True(t, errors.As(err, &netErr))
require.True(t, netErr.Timeout())
readCh := make(chan byte, 1)
errCh := make(chan error, 1)
go func() {
buf := make([]byte, 1)
if _, readErr := clientConn.Read(buf); readErr != nil {
errCh <- readErr
return
}
readCh <- buf[0]
}()
_, err = serverConn.Write([]byte("y"))
require.NoError(t, err)
select {
case b := <-readCh:
require.Equal(t, byte('y'), b)
case err := <-errCh:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timed out waiting for write after deadline reset")
}
}