Pārlūkot izejas kodu

support tls connection

fatedier 6 gadi atpakaļ
vecāks
revīzija
d812488767

+ 9 - 2
client/control.go

@@ -15,6 +15,7 @@
 package client
 
 import (
+	"crypto/tls"
 	"fmt"
 	"io"
 	"runtime/debug"
@@ -166,8 +167,14 @@ func (ctl *Control) connectServer() (conn frpNet.Conn, err error) {
 		}
 		conn = frpNet.WrapConn(stream)
 	} else {
-		conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
-			fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort))
+		var tlsConfig *tls.Config
+		if g.GlbClientCfg.TLSEnable {
+			tlsConfig = &tls.Config{
+				InsecureSkipVerify: true,
+			}
+		}
+		conn, err = frpNet.ConnectServerByProxyWithTLS(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
+			fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort), tlsConfig)
 		if err != nil {
 			ctl.Warn("start new connection to server error: %v", err)
 			return

+ 9 - 2
client/service.go

@@ -15,6 +15,7 @@
 package client
 
 import (
+	"crypto/tls"
 	"fmt"
 	"io/ioutil"
 	"runtime"
@@ -151,8 +152,14 @@ func (svr *Service) keepControllerWorking() {
 // conn: control connection
 // session: if it's not nil, using tcp mux
 func (svr *Service) login() (conn frpNet.Conn, session *fmux.Session, err error) {
-	conn, err = frpNet.ConnectServerByProxy(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
-		fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort))
+	var tlsConfig *tls.Config
+	if g.GlbClientCfg.TLSEnable {
+		tlsConfig = &tls.Config{
+			InsecureSkipVerify: true,
+		}
+	}
+	conn, err = frpNet.ConnectServerByProxyWithTLS(g.GlbClientCfg.HttpProxy, g.GlbClientCfg.Protocol,
+		fmt.Sprintf("%s:%d", g.GlbClientCfg.ServerAddr, g.GlbClientCfg.ServerPort), tlsConfig)
 	if err != nil {
 		return
 	}

+ 3 - 0
conf/frpc_full.ini

@@ -44,6 +44,9 @@ login_fail_exit = true
 # now it supports tcp and kcp and websocket, default is tcp
 protocol = tcp
 
+# if tls_enable is true, frpc will connect frps by tls
+tls_enable = true
+
 # specify a dns server, so frpc will use this instead of default one
 # dns_server = 8.8.8.8
 

+ 8 - 0
models/config/client_common.go

@@ -44,6 +44,7 @@ type ClientCommonConf struct {
 	LoginFailExit     bool                `json:"login_fail_exit"`
 	Start             map[string]struct{} `json:"start"`
 	Protocol          string              `json:"protocol"`
+	TLSEnable         bool                `json:"tls_enable"`
 	HeartBeatInterval int64               `json:"heartbeat_interval"`
 	HeartBeatTimeout  int64               `json:"heartbeat_timeout"`
 }
@@ -69,6 +70,7 @@ func GetDefaultClientConf() *ClientCommonConf {
 		LoginFailExit:     true,
 		Start:             make(map[string]struct{}),
 		Protocol:          "tcp",
+		TLSEnable:         false,
 		HeartBeatInterval: 30,
 		HeartBeatTimeout:  90,
 	}
@@ -194,6 +196,12 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c
 		cfg.Protocol = tmpStr
 	}
 
+	if tmpStr, ok = conf.Get("common", "tls_enable"); ok && tmpStr == "true" {
+		cfg.TLSEnable = true
+	} else {
+		cfg.TLSEnable = false
+	}
+
 	if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok {
 		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
 			err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")

+ 41 - 0
server/service.go

@@ -16,8 +16,14 @@ package server
 
 import (
 	"bytes"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
 	"fmt"
 	"io/ioutil"
+	"math/big"
 	"net"
 	"net/http"
 	"time"
@@ -61,6 +67,9 @@ type Service struct {
 	// Accept connections using websocket
 	websocketListener frpNet.Listener
 
+	// Accept frp tls connections
+	tlsListener frpNet.Listener
+
 	// Manage all controllers
 	ctlManager *ControlManager
 
@@ -72,6 +81,8 @@ type Service struct {
 
 	// stats collector to store server and proxies stats info
 	statsCollector stats.Collector
+
+	tlsConfig *tls.Config
 }
 
 func NewService() (svr *Service, err error) {
@@ -84,6 +95,7 @@ func NewService() (svr *Service, err error) {
 			TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
 			UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
 		},
+		tlsConfig: generateTLSConfig(),
 	}
 
 	// Init group controller
@@ -187,6 +199,12 @@ func NewService() (svr *Service, err error) {
 		log.Info("https service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort)
 	}
 
+	// frp tls listener
+	tlsListener := svr.muxer.Listen(1, 1, func(data []byte) bool {
+		return int(data[0]) == frpNet.FRP_TLS_HEAD_BYTE
+	})
+	svr.tlsListener = frpNet.WrapLogListener(tlsListener)
+
 	// Create nat hole controller.
 	if cfg.BindUdpPort > 0 {
 		var nc *nathole.NatHoleController
@@ -225,6 +243,7 @@ func (svr *Service) Run() {
 	}
 
 	go svr.HandleListener(svr.websocketListener)
+	go svr.HandleListener(svr.tlsListener)
 
 	svr.HandleListener(svr.listener)
 }
@@ -237,6 +256,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) {
 			log.Warn("Listener for incoming connections from client closed")
 			return
 		}
+		c = frpNet.CheckAndEnableTLSServerConn(c, svr.tlsConfig)
 
 		// Start a new goroutine for dealing connections.
 		go func(frpConn frpNet.Conn) {
@@ -373,3 +393,24 @@ func (svr *Service) RegisterVisitorConn(visitorConn frpNet.Conn, newMsg *msg.New
 	return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
 		newMsg.UseEncryption, newMsg.UseCompression)
 }
+
+// Setup a bare-bones TLS config for the server
+func generateTLSConfig() *tls.Config {
+	key, err := rsa.GenerateKey(rand.Reader, 1024)
+	if err != nil {
+		panic(err)
+	}
+	template := x509.Certificate{SerialNumber: big.NewInt(1)}
+	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
+	if err != nil {
+		panic(err)
+	}
+	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
+	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+
+	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
+	if err != nil {
+		panic(err)
+	}
+	return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
+}

+ 188 - 0
tests/ci/tls_test.go

@@ -0,0 +1,188 @@
+package ci
+
+import (
+	"os"
+	"testing"
+	"time"
+
+	"github.com/fatedier/frp/tests/config"
+	"github.com/fatedier/frp/tests/consts"
+	"github.com/fatedier/frp/tests/util"
+
+	"github.com/stretchr/testify/assert"
+)
+
+const FRPS_TLS_TCP_CONF = `
+[common]
+bind_addr = 0.0.0.0
+bind_port = 20000
+log_file = console
+log_level = debug
+token = 123456
+`
+
+const FRPC_TLS_TCP_CONF = `
+[common]
+server_addr = 127.0.0.1
+server_port = 20000
+log_file = console
+log_level = debug
+token = 123456
+protocol = tcp
+tls_enable = true
+
+[tcp]
+type = tcp
+local_port = 10701
+remote_port = 20801
+`
+
+func TestTlsOverTCP(t *testing.T) {
+	assert := assert.New(t)
+	frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_TCP_CONF)
+	if assert.NoError(err) {
+		defer os.Remove(frpsCfgPath)
+	}
+
+	frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_TCP_CONF)
+	if assert.NoError(err) {
+		defer os.Remove(frpcCfgPath)
+	}
+
+	frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath})
+	err = frpsProcess.Start()
+	if assert.NoError(err) {
+		defer frpsProcess.Stop()
+	}
+
+	time.Sleep(100 * time.Millisecond)
+
+	frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath})
+	err = frpcProcess.Start()
+	if assert.NoError(err) {
+		defer frpcProcess.Stop()
+	}
+	time.Sleep(250 * time.Millisecond)
+
+	// test tcp
+	res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR)
+	assert.NoError(err)
+	assert.Equal(consts.TEST_TCP_ECHO_STR, res)
+}
+
+const FRPS_TLS_KCP_CONF = `
+[common]
+bind_addr = 0.0.0.0
+bind_port = 20000
+kcp_bind_port = 20000
+log_file = console
+log_level = debug
+token = 123456
+`
+
+const FRPC_TLS_KCP_CONF = `
+[common]
+server_addr = 127.0.0.1
+server_port = 20000
+log_file = console
+log_level = debug
+token = 123456
+protocol = kcp
+tls_enable = true
+
+[tcp]
+type = tcp
+local_port = 10701
+remote_port = 20801
+`
+
+func TestTLSOverKCP(t *testing.T) {
+	assert := assert.New(t)
+	frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_KCP_CONF)
+	if assert.NoError(err) {
+		defer os.Remove(frpsCfgPath)
+	}
+
+	frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_KCP_CONF)
+	if assert.NoError(err) {
+		defer os.Remove(frpcCfgPath)
+	}
+
+	frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath})
+	err = frpsProcess.Start()
+	if assert.NoError(err) {
+		defer frpsProcess.Stop()
+	}
+
+	time.Sleep(200 * time.Millisecond)
+
+	frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath})
+	err = frpcProcess.Start()
+	if assert.NoError(err) {
+		defer frpcProcess.Stop()
+	}
+	time.Sleep(500 * time.Millisecond)
+
+	// test tcp
+	res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR)
+	assert.NoError(err)
+	assert.Equal(consts.TEST_TCP_ECHO_STR, res)
+}
+
+const FRPS_TLS_WS_CONF = `
+[common]
+bind_addr = 0.0.0.0
+bind_port = 20000
+log_file = console
+log_level = debug
+token = 123456
+`
+
+const FRPC_TLS_WS_CONF = `
+[common]
+server_addr = 127.0.0.1
+server_port = 20000
+log_file = console
+log_level = debug
+token = 123456
+protocol = websocket
+tls_enable = true
+
+[tcp]
+type = tcp
+local_port = 10701
+remote_port = 20801
+`
+
+func TestTLSOverWebsocket(t *testing.T) {
+	assert := assert.New(t)
+	frpsCfgPath, err := config.GenerateConfigFile(consts.FRPS_NORMAL_CONFIG, FRPS_TLS_WS_CONF)
+	if assert.NoError(err) {
+		defer os.Remove(frpsCfgPath)
+	}
+
+	frpcCfgPath, err := config.GenerateConfigFile(consts.FRPC_NORMAL_CONFIG, FRPC_TLS_WS_CONF)
+	if assert.NoError(err) {
+		defer os.Remove(frpcCfgPath)
+	}
+
+	frpsProcess := util.NewProcess(consts.FRPS_BIN_PATH, []string{"-c", frpsCfgPath})
+	err = frpsProcess.Start()
+	if assert.NoError(err) {
+		defer frpsProcess.Stop()
+	}
+
+	time.Sleep(200 * time.Millisecond)
+
+	frpcProcess := util.NewProcess(consts.FRPC_BIN_PATH, []string{"-c", frpcCfgPath})
+	err = frpcProcess.Start()
+	if assert.NoError(err) {
+		defer frpcProcess.Stop()
+	}
+	time.Sleep(500 * time.Millisecond)
+
+	// test tcp
+	res, err := util.SendTcpMsg("127.0.0.1:20801", consts.TEST_TCP_ECHO_STR)
+	assert.NoError(err)
+	assert.Equal(consts.TEST_TCP_ECHO_STR, res)
+}

+ 11 - 0
utils/net/conn.go

@@ -15,6 +15,7 @@
 package net
 
 import (
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"io"
@@ -207,3 +208,13 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
 		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
 	}
 }
+
+func ConnectServerByProxyWithTLS(proxyUrl string, protocol string, addr string, tlsConfig *tls.Config) (c Conn, err error) {
+	c, err = ConnectServerByProxy(proxyUrl, protocol, addr)
+	if tlsConfig == nil {
+		return
+	}
+
+	c = WrapTLSClientConn(c, tlsConfig)
+	return
+}

+ 44 - 0
utils/net/tls.go

@@ -0,0 +1,44 @@
+// Copyright 2019 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package net
+
+import (
+	"crypto/tls"
+	"net"
+
+	gnet "github.com/fatedier/golib/net"
+)
+
+var (
+	FRP_TLS_HEAD_BYTE = 0x17
+)
+
+func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out Conn) {
+	c.Write([]byte{byte(FRP_TLS_HEAD_BYTE)})
+	out = WrapConn(tls.Client(c, tlsConfig))
+	return
+}
+
+func CheckAndEnableTLSServerConn(c net.Conn, tlsConfig *tls.Config) (out Conn) {
+	sc, r := gnet.NewSharedConnSize(c, 1)
+	buf := make([]byte, 1)
+	n, _ := r.Read(buf)
+	if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE {
+		out = WrapConn(tls.Server(c, tlsConfig))
+	} else {
+		out = WrapConn(sc)
+	}
+	return
+}

+ 1 - 1
utils/version/version.go

@@ -19,7 +19,7 @@ import (
 	"strings"
 )
 
-var version string = "0.24.1"
+var version string = "0.25.0"
 
 func Full() string {
 	return version