Jelajahi Sumber

Strict configuration parsing (#3773)

* Test configuration loading more precisely

* Add strict configuration parsing
Aarni Koskela 1 tahun lalu
induk
melakukan
e8deb65c4b

+ 1 - 1
client/admin_api.go

@@ -57,7 +57,7 @@ func (svr *Service) apiReload(w http.ResponseWriter, _ *http.Request) {
 		}
 	}()
 
-	cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile)
+	cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile, svr.strictConfig)
 	if err != nil {
 		res.Code = 400
 		res.Msg = err.Error()

+ 11 - 6
client/service.go

@@ -70,6 +70,9 @@ type Service struct {
 	// string if no configuration file was used.
 	cfgFile string
 
+	// Whether strict configuration parsing had been requested.
+	strictConfig bool
+
 	// service context
 	ctx context.Context
 	// call cancel to stop service
@@ -82,14 +85,16 @@ func NewService(
 	pxyCfgs []v1.ProxyConfigurer,
 	visitorCfgs []v1.VisitorConfigurer,
 	cfgFile string,
+	strictConfig bool,
 ) *Service {
 	return &Service{
-		authSetter:  auth.NewAuthSetter(cfg.Auth),
-		cfg:         cfg,
-		cfgFile:     cfgFile,
-		pxyCfgs:     pxyCfgs,
-		visitorCfgs: visitorCfgs,
-		ctx:         context.Background(),
+		authSetter:   auth.NewAuthSetter(cfg.Auth),
+		cfg:          cfg,
+		cfgFile:      cfgFile,
+		strictConfig: strictConfig,
+		pxyCfgs:      pxyCfgs,
+		visitorCfgs:  visitorCfgs,
+		ctx:          context.Background(),
 	}
 }
 

+ 1 - 1
cmd/frpc/sub/admin.go

@@ -52,7 +52,7 @@ func NewAdminCommand(name, short string, handler func(*v1.ClientCommonConfig) er
 		Use:   name,
 		Short: short,
 		Run: func(cmd *cobra.Command, args []string) {
-			cfg, _, _, _, err := config.LoadClientConfig(cfgFile)
+			cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfig)
 			if err != nil {
 				fmt.Println(err)
 				os.Exit(1)

+ 1 - 1
cmd/frpc/sub/nathole.go

@@ -48,7 +48,7 @@ var natholeDiscoveryCmd = &cobra.Command{
 	Short: "Discover nathole information from stun server",
 	RunE: func(cmd *cobra.Command, args []string) error {
 		// ignore error here, because we can use command line pameters
-		cfg, _, _, _, err := config.LoadClientConfig(cfgFile)
+		cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfig)
 		if err != nil {
 			cfg = &v1.ClientCommonConfig{}
 		}

+ 2 - 2
cmd/frpc/sub/proxy.go

@@ -84,7 +84,7 @@ func NewProxyCommand(name string, c v1.ProxyConfigurer, clientCfg *v1.ClientComm
 				fmt.Println(err)
 				os.Exit(1)
 			}
-			err := startService(clientCfg, []v1.ProxyConfigurer{c}, nil, "")
+			err := startService(clientCfg, []v1.ProxyConfigurer{c}, nil, "", strictConfig)
 			if err != nil {
 				fmt.Println(err)
 				os.Exit(1)
@@ -110,7 +110,7 @@ func NewVisitorCommand(name string, c v1.VisitorConfigurer, clientCfg *v1.Client
 				fmt.Println(err)
 				os.Exit(1)
 			}
-			err := startService(clientCfg, nil, []v1.VisitorConfigurer{c}, "")
+			err := startService(clientCfg, nil, []v1.VisitorConfigurer{c}, "", strictConfig)
 			if err != nil {
 				fmt.Println(err)
 				os.Exit(1)

+ 12 - 6
cmd/frpc/sub/root.go

@@ -36,15 +36,17 @@ import (
 )
 
 var (
-	cfgFile     string
-	cfgDir      string
-	showVersion bool
+	cfgFile      string
+	cfgDir       string
+	showVersion  bool
+	strictConfig bool
 )
 
 func init() {
 	rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "./frpc.ini", "config file of frpc")
 	rootCmd.PersistentFlags().StringVarP(&cfgDir, "config_dir", "", "", "config directory, run one frpc service for each file in config directory")
 	rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frpc")
+	rootCmd.PersistentFlags().BoolVarP(&strictConfig, "strict_config", "", false, "strict config parsing mode")
 }
 
 var rootCmd = &cobra.Command{
@@ -108,7 +110,7 @@ func handleTermSignal(svr *client.Service) {
 }
 
 func runClient(cfgFilePath string) error {
-	cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath)
+	cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath, strictConfig)
 	if err != nil {
 		return err
 	}
@@ -120,11 +122,14 @@ func runClient(cfgFilePath string) error {
 	warning, err := validation.ValidateAllClientConfig(cfg, pxyCfgs, visitorCfgs)
 	if warning != nil {
 		fmt.Printf("WARNING: %v\n", warning)
+		if strictConfig {
+			return fmt.Errorf("warning: %v", warning)
+		}
 	}
 	if err != nil {
 		return err
 	}
-	return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath)
+	return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath, strictConfig)
 }
 
 func startService(
@@ -132,6 +137,7 @@ func startService(
 	pxyCfgs []v1.ProxyConfigurer,
 	visitorCfgs []v1.VisitorConfigurer,
 	cfgFile string,
+	strictConfig bool,
 ) error {
 	log.InitLog(cfg.Log.To, cfg.Log.Level, cfg.Log.MaxDays, cfg.Log.DisablePrintColor)
 
@@ -139,7 +145,7 @@ func startService(
 		log.Info("start frpc service for config file [%s]", cfgFile)
 		defer log.Info("frpc service for config file [%s] stopped", cfgFile)
 	}
-	svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile)
+	svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile, strictConfig)
 
 	shouldGracefulClose := cfg.Transport.Protocol == "kcp" || cfg.Transport.Protocol == "quic"
 	// Capture the exit signal if we use kcp or quic.

+ 1 - 1
cmd/frpc/sub/verify.go

@@ -37,7 +37,7 @@ var verifyCmd = &cobra.Command{
 			return nil
 		}
 
-		cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile)
+		cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile, strictConfig)
 		if err != nil {
 			fmt.Println(err)
 			os.Exit(1)

+ 5 - 3
cmd/frps/root.go

@@ -30,8 +30,9 @@ import (
 )
 
 var (
-	cfgFile     string
-	showVersion bool
+	cfgFile      string
+	showVersion  bool
+	strictConfig bool
 
 	serverCfg v1.ServerConfig
 )
@@ -39,6 +40,7 @@ var (
 func init() {
 	rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps")
 	rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps")
+	rootCmd.PersistentFlags().BoolVarP(&strictConfig, "strict_config", "", false, "strict config parsing mode")
 
 	RegisterServerConfigFlags(rootCmd, &serverCfg)
 }
@@ -58,7 +60,7 @@ var rootCmd = &cobra.Command{
 			err            error
 		)
 		if cfgFile != "" {
-			svrCfg, isLegacyFormat, err = config.LoadServerConfig(cfgFile)
+			svrCfg, isLegacyFormat, err = config.LoadServerConfig(cfgFile, strictConfig)
 			if err != nil {
 				fmt.Println(err)
 				os.Exit(1)

+ 1 - 1
cmd/frps/verify.go

@@ -36,7 +36,7 @@ var verifyCmd = &cobra.Command{
 			fmt.Println("frps: the configuration file is not specified")
 			return nil
 		}
-		svrCfg, _, err := config.LoadServerConfig(cfgFile)
+		svrCfg, _, err := config.LoadServerConfig(cfgFile, strictConfig)
 		if err != nil {
 			fmt.Println(err)
 			os.Exit(1)

+ 26 - 13
pkg/config/load.go

@@ -27,7 +27,7 @@ import (
 	"github.com/samber/lo"
 	"gopkg.in/ini.v1"
 	"k8s.io/apimachinery/pkg/util/sets"
-	"k8s.io/apimachinery/pkg/util/yaml"
+	yaml "k8s.io/apimachinery/pkg/util/yaml"
 
 	"github.com/fatedier/frp/pkg/config/legacy"
 	v1 "github.com/fatedier/frp/pkg/config/v1"
@@ -100,26 +100,39 @@ func LoadFileContentWithTemplate(path string, values *Values) ([]byte, error) {
 	return RenderWithTemplate(b, values)
 }
 
-func LoadConfigureFromFile(path string, c any) error {
+func LoadConfigureFromFile(path string, c any, strict bool) error {
 	content, err := LoadFileContentWithTemplate(path, GetValues())
 	if err != nil {
 		return err
 	}
-	return LoadConfigure(content, c)
+	return LoadConfigure(content, c, strict)
 }
 
 // LoadConfigure loads configuration from bytes and unmarshal into c.
 // Now it supports json, yaml and toml format.
-func LoadConfigure(b []byte, c any) error {
+func LoadConfigure(b []byte, c any, strict bool) error {
 	var tomlObj interface{}
+	// Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML).
+	// TODO: caller should probably be able to specify the format, so we don't need to swallow errors.
 	if err := toml.Unmarshal(b, &tomlObj); err == nil {
 		b, err = json.Marshal(&tomlObj)
 		if err != nil {
 			return err
 		}
 	}
-	decoder := yaml.NewYAMLOrJSONDecoder(bytes.NewBuffer(b), 4096)
-	return decoder.Decode(c)
+	// If the buffer smells like JSON (first non-whitespace character is '{'), unmarshal as JSON directly.
+	if yaml.IsJSONBuffer(b) {
+		decoder := json.NewDecoder(bytes.NewBuffer(b))
+		if strict {
+			decoder.DisallowUnknownFields()
+		}
+		return decoder.Decode(c)
+	}
+	// It wasn't JSON. Unmarshal as YAML.
+	if strict {
+		return yaml.UnmarshalStrict(b, c)
+	}
+	return yaml.Unmarshal(b, c)
 }
 
 func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.ProxyConfigurer, error) {
@@ -139,7 +152,7 @@ func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.
 	return configurer, nil
 }
 
-func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) {
+func LoadServerConfig(path string, strict bool) (*v1.ServerConfig, bool, error) {
 	var (
 		svrCfg         *v1.ServerConfig
 		isLegacyFormat bool
@@ -158,7 +171,7 @@ func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) {
 		isLegacyFormat = true
 	} else {
 		svrCfg = &v1.ServerConfig{}
-		if err := LoadConfigureFromFile(path, svrCfg); err != nil {
+		if err := LoadConfigureFromFile(path, svrCfg, strict); err != nil {
 			return nil, false, err
 		}
 	}
@@ -168,7 +181,7 @@ func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) {
 	return svrCfg, isLegacyFormat, nil
 }
 
-func LoadClientConfig(path string) (
+func LoadClientConfig(path string, strict bool) (
 	*v1.ClientCommonConfig,
 	[]v1.ProxyConfigurer,
 	[]v1.VisitorConfigurer,
@@ -196,7 +209,7 @@ func LoadClientConfig(path string) (
 		isLegacyFormat = true
 	} else {
 		allCfg := v1.ClientConfig{}
-		if err := LoadConfigureFromFile(path, &allCfg); err != nil {
+		if err := LoadConfigureFromFile(path, &allCfg, strict); err != nil {
 			return nil, nil, nil, false, err
 		}
 		cliCfg = &allCfg.ClientCommonConfig
@@ -211,7 +224,7 @@ func LoadClientConfig(path string) (
 	// Load additional config from includes.
 	// legacy ini format alredy handle this in ParseClientConfig.
 	if len(cliCfg.IncludeConfigFiles) > 0 && !isLegacyFormat {
-		extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat)
+		extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat, strict)
 		if err != nil {
 			return nil, nil, nil, isLegacyFormat, err
 		}
@@ -242,7 +255,7 @@ func LoadClientConfig(path string) (
 	return cliCfg, pxyCfgs, visitorCfgs, isLegacyFormat, nil
 }
 
-func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) {
+func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) {
 	pxyCfgs := make([]v1.ProxyConfigurer, 0)
 	visitorCfgs := make([]v1.VisitorConfigurer, 0)
 	for _, path := range paths {
@@ -265,7 +278,7 @@ func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool) ([]v1.Prox
 			if matched, _ := filepath.Match(filepath.Join(absDir, filepath.Base(path)), absFile); matched {
 				// support yaml/json/toml
 				cfg := v1.ClientConfig{}
-				if err := LoadConfigureFromFile(absFile, &cfg); err != nil {
+				if err := LoadConfigureFromFile(absFile, &cfg, strict); err != nil {
 					return nil, nil, fmt.Errorf("load additional config from %s error: %v", absFile, err)
 				}
 				for _, c := range cfg.Proxies {

+ 58 - 12
pkg/config/load_test.go

@@ -15,6 +15,7 @@
 package config
 
 import (
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -22,9 +23,7 @@ import (
 	v1 "github.com/fatedier/frp/pkg/config/v1"
 )
 
-func TestLoadConfigure(t *testing.T) {
-	require := require.New(t)
-	content := `
+const tomlServerContent = `
 bindAddr = "127.0.0.1"
 kcpBindPort = 7000
 quicBindPort = 7001
@@ -33,13 +32,60 @@ custom404Page = "/abc.html"
 transport.tcpKeepalive = 10
 `
 
-	svrCfg := v1.ServerConfig{}
-	err := LoadConfigure([]byte(content), &svrCfg)
-	require.NoError(err)
-	require.EqualValues("127.0.0.1", svrCfg.BindAddr)
-	require.EqualValues(7000, svrCfg.KCPBindPort)
-	require.EqualValues(7001, svrCfg.QUICBindPort)
-	require.EqualValues(7005, svrCfg.TCPMuxHTTPConnectPort)
-	require.EqualValues("/abc.html", svrCfg.Custom404Page)
-	require.EqualValues(10, svrCfg.Transport.TCPKeepAlive)
+const yamlServerContent = `
+bindAddr: 127.0.0.1
+kcpBindPort: 7000
+quicBindPort: 7001
+tcpmuxHTTPConnectPort: 7005
+custom404Page: /abc.html
+transport:
+  tcpKeepalive: 10
+`
+
+const jsonServerContent = `
+{
+  "bindAddr": "127.0.0.1",
+  "kcpBindPort": 7000,
+  "quicBindPort": 7001,
+  "tcpmuxHTTPConnectPort": 7005,
+  "custom404Page": "/abc.html",
+  "transport": {
+    "tcpKeepalive": 10
+  }
+}
+`
+
+func TestLoadServerConfig(t *testing.T) {
+	for _, content := range []string{tomlServerContent, yamlServerContent, jsonServerContent} {
+		svrCfg := v1.ServerConfig{}
+		err := LoadConfigure([]byte(content), &svrCfg, true)
+		require := require.New(t)
+		require.NoError(err)
+		require.EqualValues("127.0.0.1", svrCfg.BindAddr)
+		require.EqualValues(7000, svrCfg.KCPBindPort)
+		require.EqualValues(7001, svrCfg.QUICBindPort)
+		require.EqualValues(7005, svrCfg.TCPMuxHTTPConnectPort)
+		require.EqualValues("/abc.html", svrCfg.Custom404Page)
+		require.EqualValues(10, svrCfg.Transport.TCPKeepAlive)
+	}
+}
+
+// Test that loading in strict mode fails when the config is invalid.
+func TestLoadServerConfigErrorMode(t *testing.T) {
+	for strict := range []bool{false, true} {
+		for _, content := range []string{tomlServerContent, yamlServerContent, jsonServerContent} {
+			// Break the content with an innocent typo
+			brokenContent := strings.Replace(content, "bindAddr", "bindAdur", 1)
+			svrCfg := v1.ServerConfig{}
+			err := LoadConfigure([]byte(brokenContent), &svrCfg, strict == 1)
+			require := require.New(t)
+			if strict == 1 {
+				require.ErrorContains(err, "bindAdur")
+			} else {
+				require.NoError(err)
+				// BindAddr didn't get parsed because of the typo.
+				require.EqualValues("", svrCfg.BindAddr)
+			}
+		}
+	}
 }