refactor: clean up code (#5308)
This commit is contained in:
committed by
GitHub
Unverified
parent
ad07d27914
commit
a88e0e9a49
+46
-18
@@ -7,6 +7,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
)
|
||||
|
||||
@@ -38,16 +40,25 @@ type Manager struct {
|
||||
|
||||
bindAddr string
|
||||
netType string
|
||||
clock clock.WithTicker
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager {
|
||||
return newManagerWithClock(netType, bindAddr, allowPorts, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newManagerWithClock(netType string, bindAddr string, allowPorts []types.PortsRange, clk clock.WithTicker) *Manager {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
pm := &Manager{
|
||||
reservedPorts: make(map[string]*PortCtx),
|
||||
usedPorts: make(map[int]*PortCtx),
|
||||
freePorts: make(map[int]struct{}),
|
||||
bindAddr: bindAddr,
|
||||
netType: netType,
|
||||
clock: clk,
|
||||
}
|
||||
if len(allowPorts) > 0 {
|
||||
for _, pair := range allowPorts {
|
||||
@@ -72,7 +83,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
portCtx := &PortCtx{
|
||||
ProxyName: name,
|
||||
Closed: false,
|
||||
UpdateTime: time.Now(),
|
||||
UpdateTime: pm.clock.Now(),
|
||||
}
|
||||
|
||||
var ok bool
|
||||
@@ -90,9 +101,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
if ctx, ok := pm.reservedPorts[name]; ok {
|
||||
if pm.isPortAvailable(ctx.Port) {
|
||||
realPort = ctx.Port
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -109,9 +118,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
}
|
||||
if pm.isPortAvailable(k) {
|
||||
realPort = k
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -123,9 +130,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
if _, ok = pm.freePorts[port]; ok {
|
||||
if pm.isPortAvailable(port) {
|
||||
realPort = port
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
} else {
|
||||
err = ErrPortUnAvailable
|
||||
}
|
||||
@@ -140,6 +145,13 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// markPortAcquiredLocked records a successful acquisition. pm.mu must be held.
|
||||
func (pm *Manager) markPortAcquiredLocked(name string, port int, portCtx *PortCtx) {
|
||||
pm.usedPorts[port] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, port)
|
||||
}
|
||||
|
||||
func (pm *Manager) isPortAvailable(port int) bool {
|
||||
if pm.netType == "udp" {
|
||||
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
|
||||
@@ -169,20 +181,36 @@ func (pm *Manager) Release(port int) {
|
||||
pm.freePorts[port] = struct{}{}
|
||||
delete(pm.usedPorts, port)
|
||||
ctx.Closed = true
|
||||
ctx.UpdateTime = time.Now()
|
||||
ctx.UpdateTime = pm.clock.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// Release reserved port if it isn't used in last 24 hours.
|
||||
func (pm *Manager) cleanReservedPortsWorker() {
|
||||
pm.cleanReservedPortsWorkerUntil(nil)
|
||||
}
|
||||
|
||||
func (pm *Manager) cleanReservedPortsWorkerUntil(stopCh <-chan struct{}) {
|
||||
ticker := pm.clock.NewTicker(CleanReservedPortsInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
time.Sleep(CleanReservedPortsInterval)
|
||||
pm.mu.Lock()
|
||||
for name, ctx := range pm.reservedPorts {
|
||||
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||
delete(pm.reservedPorts, name)
|
||||
}
|
||||
select {
|
||||
case <-ticker.C():
|
||||
pm.cleanReservedPortsOnce()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *Manager) cleanReservedPortsOnce() {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
for name, ctx := range pm.reservedPorts {
|
||||
if ctx.Closed && pm.clock.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||
delete(pm.reservedPorts, name)
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
)
|
||||
|
||||
func TestManagerUsesClockForPortTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
port := freeTCPPort(t)
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||
|
||||
realPort, err := pm.Acquire("proxy", port)
|
||||
require.NoError(err)
|
||||
require.Equal(port, realPort)
|
||||
require.Equal(start, pm.usedPorts[port].UpdateTime)
|
||||
|
||||
releasedAt := start.Add(time.Minute)
|
||||
clk.SetTime(releasedAt)
|
||||
pm.Release(port)
|
||||
|
||||
require.Equal(releasedAt, pm.reservedPorts["proxy"].UpdateTime)
|
||||
}
|
||||
|
||||
func TestManagerCleanReservedPortsWorkerUsesClockTicker(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
port := freeTCPPort(t)
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||
|
||||
realPort, err := pm.Acquire("proxy", port)
|
||||
require.NoError(err)
|
||||
require.Equal(port, realPort)
|
||||
pm.Release(port)
|
||||
require.True(pm.hasReservedPort("proxy"))
|
||||
|
||||
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||
clk.Step(MaxPortReservedDuration + CleanReservedPortsInterval + time.Minute)
|
||||
|
||||
require.Eventually(func() bool {
|
||||
return !pm.hasReservedPort("proxy")
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
|
||||
func (pm *Manager) hasReservedPort(name string) bool {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
_, ok := pm.reservedPorts[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func freeTCPPort(t *testing.T) int {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
// ClientInfo captures metadata about a connected frpc instance.
|
||||
@@ -42,12 +44,21 @@ type ClientRegistry struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*ClientInfo
|
||||
runIndex map[string]string
|
||||
clock clock.PassiveClock
|
||||
}
|
||||
|
||||
func NewClientRegistry() *ClientRegistry {
|
||||
return newClientRegistryWithClock(clock.RealClock{})
|
||||
}
|
||||
|
||||
func newClientRegistryWithClock(clk clock.PassiveClock) *ClientRegistry {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &ClientRegistry{
|
||||
clients: make(map[string]*ClientInfo),
|
||||
runIndex: make(map[string]string),
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +75,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
|
||||
key = cr.composeClientKey(user, effectiveID)
|
||||
enforceUnique := rawClientID != ""
|
||||
|
||||
now := time.Now()
|
||||
now := cr.clock.Now()
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
|
||||
@@ -116,7 +127,7 @@ func (cr *ClientRegistry) MarkOfflineByRunID(runID string) {
|
||||
} else {
|
||||
info.RunID = ""
|
||||
info.Online = false
|
||||
now := time.Now()
|
||||
now := cr.clock.Now()
|
||||
info.DisconnectedAt = now
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,9 @@ package registry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
@@ -35,3 +38,37 @@ func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
|
||||
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRegistryUsesClockForTimestamps(t *testing.T) {
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
registry := newClientRegistryWithClock(clk)
|
||||
|
||||
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
|
||||
if conflict {
|
||||
t.Fatal("unexpected client conflict")
|
||||
}
|
||||
|
||||
info, ok := registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found", key)
|
||||
}
|
||||
if !info.FirstConnectedAt.Equal(start) {
|
||||
t.Fatalf("first connected time mismatch, want %s got %s", start, info.FirstConnectedAt)
|
||||
}
|
||||
if !info.LastConnectedAt.Equal(start) {
|
||||
t.Fatalf("last connected time mismatch, want %s got %s", start, info.LastConnectedAt)
|
||||
}
|
||||
|
||||
disconnectedAt := start.Add(time.Minute)
|
||||
clk.SetTime(disconnectedAt)
|
||||
registry.MarkOfflineByRunID("run-id")
|
||||
|
||||
info, ok = registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found after disconnect", key)
|
||||
}
|
||||
if !info.DisconnectedAt.Equal(disconnectedAt) {
|
||||
t.Fatalf("disconnected time mismatch, want %s got %s", disconnectedAt, info.DisconnectedAt)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user