diff --git a/client/proxy/xtcp.go b/client/proxy/xtcp.go index 41dc5229..aef66780 100644 --- a/client/proxy/xtcp.go +++ b/client/proxy/xtcp.go @@ -57,8 +57,7 @@ func NewXTCPProxy(baseProxy *BaseProxy, cfg v1.ProxyConfigurer) Proxy { func (pxy *XTCPProxy) InWorkConn(conn net.Conn, startWorkConnMsg *msg.StartWorkConn) { xl := pxy.xl defer conn.Close() - var natHoleSidMsg msg.NatHoleSid - err := msg.ReadMsgInto(conn, &natHoleSidMsg) + natHoleSidMsg, err := readNatHoleSid(conn, pxy.clientCfg.Transport.WireProtocol) if err != nil { xl.Errorf("xtcp read from workConn error: %v", err) return @@ -131,6 +130,15 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, startWorkConnMsg *msg.StartWorkC pxy.listenByQUIC(listenConn, raddr, startWorkConnMsg) } +func readNatHoleSid(conn net.Conn, wireProtocol string) (*msg.NatHoleSid, error) { + workMsgConn := msg.NewConn(conn, msg.NewReadWriter(conn, wireProtocol)) + var natHoleSidMsg msg.NatHoleSid + if err := workMsgConn.ReadMsgInto(&natHoleSidMsg); err != nil { + return nil, err + } + return &natHoleSidMsg, nil +} + func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, startWorkConnMsg *msg.StartWorkConn) { xl := pxy.xl listenConn.Close() diff --git a/client/proxy/xtcp_test.go b/client/proxy/xtcp_test.go new file mode 100644 index 00000000..bd295a07 --- /dev/null +++ b/client/proxy/xtcp_test.go @@ -0,0 +1,66 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !frps + +package proxy + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/wire" +) + +func TestReadNatHoleSidUsesSelectedWireProtocol(t *testing.T) { + for _, tc := range []struct { + name string + wireProtocol string + }{ + {name: "v2", wireProtocol: wire.ProtocolV2}, + {name: "v1", wireProtocol: wire.ProtocolV1}, + {name: "default", wireProtocol: ""}, + } { + t.Run(tc.name, func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + setPipeDeadline(t, client, server) + + errCh := make(chan error, 1) + go func() { + writer := msg.NewConn(server, msg.NewReadWriter(server, tc.wireProtocol)) + errCh <- writer.WriteMsg(&msg.NatHoleSid{Sid: "sid"}) + }() + + out, err := readNatHoleSid(client, tc.wireProtocol) + require.NoError(t, err) + require.Equal(t, "sid", out.Sid) + require.NoError(t, <-errCh) + }) + } +} + +func setPipeDeadline(t *testing.T, conns ...net.Conn) { + t.Helper() + + deadline := time.Now().Add(time.Second) + for _, conn := range conns { + require.NoError(t, conn.SetDeadline(deadline)) + } +} diff --git a/server/proxy/xtcp.go b/server/proxy/xtcp.go index bef7320e..a6488d42 100644 --- a/server/proxy/xtcp.go +++ b/server/proxy/xtcp.go @@ -16,6 +16,7 @@ package proxy import ( "fmt" + "net" "reflect" "sync" @@ -73,10 +74,7 @@ func (pxy *XTCPProxy) Run() (remoteAddr string, err error) { if errRet != nil { continue } - m := &msg.NatHoleSid{ - Sid: sid, - } - errRet = msg.WriteMsg(workConn, m) + errRet = writeNatHoleSid(workConn, pxy.wireProtocol, sid) if errRet != nil { xl.Warnf("write nat hole sid package error, %v", errRet) } @@ -87,6 +85,13 @@ func (pxy *XTCPProxy) Run() (remoteAddr string, err error) { return } +func writeNatHoleSid(workConn net.Conn, wireProtocol string, sid string) error { + workMsgConn := msg.NewConn(workConn, msg.NewReadWriter(workConn, wireProtocol)) + return workMsgConn.WriteMsg(&msg.NatHoleSid{ + Sid: sid, + }) +} + func (pxy *XTCPProxy) Close() { pxy.closeOnce.Do(func() { pxy.BaseProxy.Close() diff --git a/server/proxy/xtcp_test.go b/server/proxy/xtcp_test.go new file mode 100644 index 00000000..b2e6c1ad --- /dev/null +++ b/server/proxy/xtcp_test.go @@ -0,0 +1,93 @@ +// Copyright 2026 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "bufio" + "encoding/binary" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/fatedier/frp/pkg/msg" + "github.com/fatedier/frp/pkg/proto/wire" +) + +func TestWriteNatHoleSidUsesWireV2MessageFrame(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + setPipeDeadline(t, client, server) + + errCh := make(chan error, 1) + go func() { + errCh <- writeNatHoleSid(server, wire.ProtocolV2, "sid-v2") + }() + + frame, err := wire.NewConn(client).ReadFrame() + require.NoError(t, err) + require.Equal(t, wire.FrameTypeMessage, frame.Type) + require.GreaterOrEqual(t, len(frame.Payload), 2) + require.Equal(t, msg.V2TypeNatHoleSid, binary.BigEndian.Uint16(frame.Payload[:2])) + + var out msg.NatHoleSid + require.NoError(t, msg.DecodeV2MessageFrameInto(frame, &out)) + require.Equal(t, "sid-v2", out.Sid) + require.NoError(t, <-errCh) +} + +func TestWriteNatHoleSidUsesLegacyCodecForWireV1AndDefault(t *testing.T) { + for _, tc := range []struct { + name string + wireProtocol string + }{ + {name: "default", wireProtocol: ""}, + {name: "v1", wireProtocol: wire.ProtocolV1}, + } { + t.Run(tc.name, func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + setPipeDeadline(t, client, server) + + errCh := make(chan error, 1) + go func() { + errCh <- writeNatHoleSid(server, tc.wireProtocol, "sid-legacy") + }() + + reader := bufio.NewReader(client) + typeByte, err := reader.ReadByte() + require.NoError(t, err) + require.Equal(t, msg.TypeNatHoleSid, typeByte) + require.NoError(t, reader.UnreadByte()) + + var out msg.NatHoleSid + require.NoError(t, msg.ReadMsgInto(reader, &out)) + require.Equal(t, "sid-legacy", out.Sid) + require.NoError(t, <-errCh) + }) + } +} + +func setPipeDeadline(t *testing.T, conns ...net.Conn) { + t.Helper() + + deadline := time.Now().Add(time.Second) + for _, conn := range conns { + require.NoError(t, conn.SetDeadline(deadline)) + } +}