refactor: clean up code (#5308)
This commit is contained in:
committed by
GitHub
Unverified
parent
ad07d27914
commit
a88e0e9a49
@@ -14,11 +14,7 @@
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
import v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
|
||||
// ConfigSource implements Source for in-memory configuration.
|
||||
// All operations are thread-safe.
|
||||
@@ -39,23 +35,17 @@ func (s *ConfigSource) ReplaceAll(proxies []v1.ProxyConfigurer, visitors []v1.Vi
|
||||
|
||||
nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies))
|
||||
for _, p := range proxies {
|
||||
if p == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
name := p.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
name, err := validateProxyName(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextProxies[name] = p
|
||||
}
|
||||
nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors))
|
||||
for _, v := range visitors {
|
||||
if v == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
name := v.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
name, err := validateVisitorName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextVisitors[name] = v
|
||||
}
|
||||
|
||||
+100
-118
@@ -43,6 +43,11 @@ var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
const (
|
||||
storeKindProxy = "proxy"
|
||||
storeKindVisitor = "visitor"
|
||||
)
|
||||
|
||||
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
|
||||
if cfg.Path == "" {
|
||||
return nil, fmt.Errorf("path is required")
|
||||
@@ -172,79 +177,111 @@ func (s *StoreSource) saveToFileUnlocked() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
if proxy == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
func (s *StoreSource) persistOrRollbackUnlocked(rollback func()) error {
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
rollback()
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store map selectors return the target map for generic helpers.
|
||||
func proxyStoreEntries(s *StoreSource) map[string]v1.ProxyConfigurer {
|
||||
return s.proxies
|
||||
}
|
||||
|
||||
func visitorStoreEntries(s *StoreSource) map[string]v1.VisitorConfigurer {
|
||||
return s.visitors
|
||||
}
|
||||
|
||||
// Store entry helpers share mutation, persistence, and rollback for proxy and visitor maps.
|
||||
// T is intentionally limited by callers to v1.ProxyConfigurer or v1.VisitorConfigurer.
|
||||
func addStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
value T,
|
||||
) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := entriesFn(s)
|
||||
if _, exists := entries[name]; exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrAlreadyExists, kind, name)
|
||||
}
|
||||
|
||||
name := proxy.GetBaseConfig().Name
|
||||
entries[name] = value
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
delete(entries, name)
|
||||
})
|
||||
}
|
||||
|
||||
func updateStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
value T,
|
||||
) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := entriesFn(s)
|
||||
old, exists := entries[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||
}
|
||||
|
||||
entries[name] = value
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
entries[name] = old
|
||||
})
|
||||
}
|
||||
|
||||
func removeStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
return fmt.Errorf("%s name cannot be empty", kind)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.proxies[name]; exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
|
||||
entries := entriesFn(s)
|
||||
old, exists := entries[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
delete(entries, name)
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
entries[name] = old
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
delete(s.proxies, name)
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
name, err := validateProxyName(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return addStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||
}
|
||||
|
||||
func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
|
||||
if proxy == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
name, err := validateProxyName(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := proxy.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.proxies[name] = oldProxy
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return updateStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||
}
|
||||
|
||||
func (s *StoreSource) RemoveProxy(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.proxies, name)
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.proxies[name] = oldProxy
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return removeStoreEntry(s, proxyStoreEntries, storeKindProxy, name)
|
||||
}
|
||||
|
||||
func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||
@@ -259,78 +296,23 @@ func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||
}
|
||||
|
||||
func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
|
||||
if visitor == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
name, err := validateVisitorName(visitor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.visitors[name]; exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
delete(s.visitors, name)
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return addStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||
}
|
||||
|
||||
func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
|
||||
if visitor == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
name, err := validateVisitorName(visitor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.visitors[name] = oldVisitor
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return updateStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||
}
|
||||
|
||||
func (s *StoreSource) RemoveVisitor(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.visitors, name)
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.visitors[name] = oldVisitor
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return removeStoreEntry(s, visitorStoreEntries, storeKindVisitor, name)
|
||||
}
|
||||
|
||||
func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer {
|
||||
|
||||
@@ -17,6 +17,7 @@ package source
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -59,6 +60,101 @@ func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T
|
||||
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
|
||||
}
|
||||
|
||||
func TestStoreSource_UpdateAndRemoveProxyAndVisitor(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
storeSource := newTestStoreSource(t)
|
||||
|
||||
proxyCfg := mockProxy("proxy1")
|
||||
visitorCfg := mockVisitor("visitor1")
|
||||
|
||||
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||
require.ErrorIs(storeSource.AddProxy(proxyCfg), ErrAlreadyExists)
|
||||
require.ErrorIs(storeSource.AddVisitor(visitorCfg), ErrAlreadyExists)
|
||||
require.ErrorContains(storeSource.RemoveProxy(""), "proxy name cannot be empty")
|
||||
require.ErrorContains(storeSource.RemoveVisitor(""), "visitor name cannot be empty")
|
||||
|
||||
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||
updatedProxy.RemotePort = 19090
|
||||
require.NoError(storeSource.UpdateProxy(updatedProxy))
|
||||
require.Equal(19090, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||
|
||||
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||
updatedVisitor.ServerName = "updated-server"
|
||||
require.NoError(storeSource.UpdateVisitor(updatedVisitor))
|
||||
require.Equal("updated-server", storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||
|
||||
require.NoError(storeSource.RemoveProxy("proxy1"))
|
||||
require.Nil(storeSource.GetProxy("proxy1"))
|
||||
require.ErrorIs(storeSource.RemoveProxy("proxy1"), ErrNotFound)
|
||||
|
||||
require.NoError(storeSource.RemoveVisitor("visitor1"))
|
||||
require.Nil(storeSource.GetVisitor("visitor1"))
|
||||
require.ErrorIs(storeSource.RemoveVisitor("visitor1"), ErrNotFound)
|
||||
|
||||
require.ErrorIs(storeSource.UpdateProxy(updatedProxy), ErrNotFound)
|
||||
require.ErrorIs(storeSource.UpdateVisitor(updatedVisitor), ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreSource_MutationRollsBackOnPersistFailure(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("chmod does not make directories unwritable on Windows")
|
||||
}
|
||||
if os.Getuid() == 0 {
|
||||
t.Skip("chmod does not block writes for uid 0")
|
||||
}
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
|
||||
require.NoError(err)
|
||||
|
||||
proxyCfg := mockProxy("proxy1")
|
||||
visitorCfg := mockVisitor("visitor1")
|
||||
originalRemotePort := proxyCfg.(*v1.TCPProxyConfig).RemotePort
|
||||
originalServerName := visitorCfg.(*v1.STCPVisitorConfig).ServerName
|
||||
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||
|
||||
require.NoError(os.Chmod(dir, 0o500))
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chmod(dir, 0o700)
|
||||
})
|
||||
|
||||
requirePersistError := func(err error) {
|
||||
t.Helper()
|
||||
require.Error(err)
|
||||
require.ErrorContains(err, "failed to persist")
|
||||
require.NotErrorIs(err, ErrAlreadyExists)
|
||||
require.NotErrorIs(err, ErrNotFound)
|
||||
}
|
||||
|
||||
requirePersistError(storeSource.AddProxy(mockProxy("proxy2")))
|
||||
require.Nil(storeSource.GetProxy("proxy2"))
|
||||
|
||||
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||
updatedProxy.RemotePort = 19090
|
||||
requirePersistError(storeSource.UpdateProxy(updatedProxy))
|
||||
require.Equal(originalRemotePort, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||
|
||||
requirePersistError(storeSource.RemoveProxy("proxy1"))
|
||||
require.NotNil(storeSource.GetProxy("proxy1"))
|
||||
|
||||
requirePersistError(storeSource.AddVisitor(mockVisitor("visitor2")))
|
||||
require.Nil(storeSource.GetVisitor("visitor2"))
|
||||
|
||||
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||
updatedVisitor.ServerName = "updated-server"
|
||||
requirePersistError(storeSource.UpdateVisitor(updatedVisitor))
|
||||
require.Equal(originalServerName, storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||
|
||||
requirePersistError(storeSource.RemoveVisitor("visitor1"))
|
||||
require.NotNil(storeSource.GetVisitor("visitor1"))
|
||||
}
|
||||
|
||||
func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func validateProxyName(proxy v1.ProxyConfigurer) (string, error) {
|
||||
if proxy == nil {
|
||||
return "", fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
name := proxy.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func validateVisitorName(visitor v1.VisitorConfigurer) (string, error) {
|
||||
if visitor == nil {
|
||||
return "", fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
Reference in New Issue
Block a user