Browse Source

frps: release resources in service.Close() (#4667)

fatedier 4 weeks ago
parent
commit
e0dd947e6a
6 changed files with 35 additions and 37 deletions
  1. 1 1
      go.mod
  2. 2 2
      go.sum
  3. 4 0
      pkg/ssh/gateway.go
  4. 4 0
      pkg/util/vhost/vhost.go
  5. 10 0
      server/controller/resource.go
  6. 14 34
      server/service.go

+ 1 - 1
go.mod

@@ -5,7 +5,7 @@ go 1.22.0
 require (
 	github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
 	github.com/coreos/go-oidc/v3 v3.10.0
-	github.com/fatedier/golib v0.5.0
+	github.com/fatedier/golib v0.5.1
 	github.com/google/uuid v1.6.0
 	github.com/gorilla/mux v1.8.1
 	github.com/gorilla/websocket v1.5.0

+ 2 - 2
go.sum

@@ -21,8 +21,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
 github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/fatedier/golib v0.5.0 h1:hNcH7hgfIFqVWbP+YojCCAj4eO94pPf4dEF8lmq2jWs=
-github.com/fatedier/golib v0.5.0/go.mod h1:W6kIYkIFxHsTzbgqg5piCxIiDo4LzwgTY6R5W8l9NFQ=
+github.com/fatedier/golib v0.5.1 h1:hcKAnaw5mdI/1KWRGejxR+i1Hn/NvbY5UsMKDr7o13M=
+github.com/fatedier/golib v0.5.1/go.mod h1:W6kIYkIFxHsTzbgqg5piCxIiDo4LzwgTY6R5W8l9NFQ=
 github.com/fatedier/yamux v0.0.0-20230628132301-7aca4898904d h1:ynk1ra0RUqDWQfvFi5KtMiSobkVQ3cNc0ODb8CfIETo=
 github.com/fatedier/yamux v0.0.0-20230628132301-7aca4898904d/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
 github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U=

+ 4 - 0
pkg/ssh/gateway.go

@@ -112,6 +112,10 @@ func (g *Gateway) Run() {
 	}
 }
 
+func (g *Gateway) Close() error {
+	return g.ln.Close()
+}
+
 func (g *Gateway) handleConn(conn net.Conn) {
 	defer conn.Close()
 

+ 4 - 0
pkg/util/vhost/vhost.go

@@ -100,6 +100,10 @@ func (v *Muxer) SetRewriteHostFunc(f hostRewriteFunc) *Muxer {
 	return v
 }
 
+func (v *Muxer) Close() error {
+	return v.listener.Close()
+}
+
 type ChooseEndpointFunc func() (string, error)
 
 type CreateConnFunc func(remoteAddr string) (net.Conn, error)

+ 10 - 0
server/controller/resource.go

@@ -59,3 +59,13 @@ type ResourceController struct {
 	// All server manager plugin
 	PluginManager *plugin.Manager
 }
+
+func (rc *ResourceController) Close() error {
+	if rc.VhostHTTPSMuxer != nil {
+		rc.VhostHTTPSMuxer.Close()
+	}
+	if rc.TCPMuxHTTPConnectMuxer != nil {
+		rc.TCPMuxHTTPConnectMuxer.Close()
+	}
+	return nil
+}

+ 14 - 34
server/service.go

@@ -77,7 +77,7 @@ type Service struct {
 	muxer *mux.Mux
 
 	// Accept connections from client
-	muxListener net.Listener
+	listener net.Listener
 
 	// Accept connections using kcp
 	kcpListener net.Listener
@@ -125,11 +125,6 @@ type Service struct {
 	ctx context.Context
 	// call cancel to stop service
 	cancel context.CancelFunc
-
-	// Track listeners so they can be closed manually
-	vhostHTTPSListener        net.Listener
-	tcpmuxHTTPConnectListener net.Listener
-	tcpListener               net.Listener
 }
 
 func NewService(cfg *v1.ServerConfig) (*Service, error) {
@@ -185,8 +180,6 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
 			return nil, fmt.Errorf("create server listener error, %v", err)
 		}
 
-		// Save listener so it can be closed in svr.Close()
-		svr.tcpmuxHTTPConnectListener = l
 		svr.rc.TCPMuxHTTPConnectMuxer, err = tcpmux.NewHTTPConnectTCPMuxer(l, cfg.TCPMuxPassthrough, vhostReadWriteTimeout)
 		if err != nil {
 			return nil, fmt.Errorf("create vhost tcpMuxer error, %v", err)
@@ -233,16 +226,14 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
 		return nil, fmt.Errorf("create server listener error, %v", err)
 	}
 
-	// Save listener so it can be closed in svr.Close()
-	svr.tcpListener = ln
-
 	svr.muxer = mux.NewMux(ln)
 	svr.muxer.SetKeepAlive(time.Duration(cfg.Transport.TCPKeepAlive) * time.Second)
 	go func() {
 		_ = svr.muxer.Serve()
 	}()
 	ln = svr.muxer.DefaultListener()
-	svr.muxListener = ln
+
+	svr.listener = ln
 	log.Infof("frps tcp listen on %s", address)
 
 	// Listen for accepting connections from client using kcp protocol.
@@ -327,8 +318,7 @@ func NewService(cfg *v1.ServerConfig) (*Service, error) {
 			}
 			log.Infof("https service listen on %s", address)
 		}
-		// Save listener so it can be closed in svr.Close()
-		svr.vhostHTTPSListener = l
+
 		svr.rc.VhostHTTPSMuxer, err = vhost.NewHTTPSMuxer(l, vhostReadWriteTimeout)
 		if err != nil {
 			return nil, fmt.Errorf("create vhost httpsMuxer error, %v", err)
@@ -384,11 +374,11 @@ func (svr *Service) Run(ctx context.Context) {
 		go svr.sshTunnelGateway.Run()
 	}
 
-	svr.HandleListener(svr.muxListener, false)
+	svr.HandleListener(svr.listener, false)
 
 	<-svr.ctx.Done()
 	// service context may not be canceled by svr.Close(), we should call it here to release resources
-	if svr.muxListener != nil {
+	if svr.listener != nil {
 		svr.Close()
 	}
 }
@@ -396,40 +386,30 @@ func (svr *Service) Run(ctx context.Context) {
 func (svr *Service) Close() error {
 	if svr.kcpListener != nil {
 		svr.kcpListener.Close()
-		svr.kcpListener = nil
 	}
 	if svr.quicListener != nil {
 		svr.quicListener.Close()
-		svr.quicListener = nil
 	}
 	if svr.websocketListener != nil {
 		svr.websocketListener.Close()
-		svr.websocketListener = nil
 	}
 	if svr.tlsListener != nil {
 		svr.tlsListener.Close()
-		svr.tlsConfig = nil
-	}
-	if svr.muxListener != nil {
-		svr.muxListener.Close()
-		svr.muxListener = nil
 	}
-	if svr.vhostHTTPSListener != nil {
-		svr.vhostHTTPSListener.Close()
-		svr.vhostHTTPSListener = nil
+	if svr.sshTunnelListener != nil {
+		svr.sshTunnelListener.Close()
 	}
-	if svr.tcpmuxHTTPConnectListener != nil {
-		svr.tcpmuxHTTPConnectListener.Close()
-		svr.tcpmuxHTTPConnectListener = nil
+	if svr.listener != nil {
+		svr.listener.Close()
 	}
 	if svr.webServer != nil {
 		svr.webServer.Close()
-		svr.webServer = nil
 	}
-	if svr.tcpListener != nil {
-		svr.tcpListener.Close()
-		svr.tcpListener = nil
+	if svr.sshTunnelGateway != nil {
+		svr.sshTunnelGateway.Close()
 	}
+	svr.rc.Close()
+	svr.muxer.Close()
 	svr.ctlManager.Close()
 	if svr.cancel != nil {
 		svr.cancel()