protocol: add AEAD encryption negotiation to v2 wire control channel (#5304)
This commit is contained in:
committed by
GitHub
Unverified
parent
57bb9e80fe
commit
8666e3643f
@@ -29,6 +29,7 @@ import (
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
type testConnector struct {
|
||||
@@ -140,8 +141,17 @@ func TestControlSessionDialerDialV2(t *testing.T) {
|
||||
}
|
||||
|
||||
wireConn := wire.NewConn(serverRaw)
|
||||
clientHelloFrame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if clientHelloFrame.Type != wire.FrameTypeClientHello {
|
||||
serverErrCh <- fmt.Errorf("unexpected frame type %d, want %d", clientHelloFrame.Type, wire.FrameTypeClientHello)
|
||||
return
|
||||
}
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.ReadJSONFrame(wire.FrameTypeClientHello, &hello); err != nil {
|
||||
if err := wireConn.UnmarshalFrame(clientHelloFrame, &hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
@@ -160,11 +170,52 @@ func TestControlSessionDialerDialV2(t *testing.T) {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
if err := wireConn.WriteJSONFrame(wire.FrameTypeServerHello, wire.DefaultServerHello()); err != nil {
|
||||
serverHello, err := wire.NewServerHello(hello)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"})
|
||||
serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
cryptoContext := wire.NewCryptoContext(
|
||||
serverHello.Selected.Crypto.Algorithm,
|
||||
clientHelloFrame.Payload,
|
||||
serverHelloFrame.Payload,
|
||||
)
|
||||
if err := wireConn.WriteFrame(serverHelloFrame); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
controlRW, err := netpkg.NewAEADCryptoReadWriter(
|
||||
serverRaw,
|
||||
[]byte("token"),
|
||||
netpkg.AEADCryptoRoleServer,
|
||||
cryptoContext.Algorithm,
|
||||
cryptoContext.TranscriptHash,
|
||||
)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
controlMsgRW := msg.NewReadWriter(controlRW, wire.ProtocolV2)
|
||||
var ping msg.Ping
|
||||
if err := controlMsgRW.ReadMsgInto(&ping); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if ping.PrivilegeKey != "v2-ping" || ping.Timestamp != 12345 {
|
||||
serverErrCh <- fmt.Errorf("unexpected ping: %+v", ping)
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
||||
@@ -177,6 +228,7 @@ func TestControlSessionDialerDialV2(t *testing.T) {
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{PrivilegeKey: "v2-ping", Timestamp: 12345}))
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user