1
0

visitor_manager.go 4.9 KB


  1. // Copyright 2018 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package visitor
  15. import (
  16. "context"
  17. "fmt"
  18. "net"
  19. "reflect"
  20. "sync"
  21. "time"
  22. "github.com/samber/lo"
  23. v1 "github.com/fatedier/frp/pkg/config/v1"
  24. "github.com/fatedier/frp/pkg/transport"
  25. "github.com/fatedier/frp/pkg/util/xlog"
  26. )
// Manager owns the set of visitors for a client: it records the desired
// visitor configurations, starts and closes the corresponding visitors, and
// periodically retries visitors that are configured but not running.
type Manager struct {
	// clientCfg is the client-wide configuration passed to every visitor
	// created by the manager.
	clientCfg *v1.ClientCommonConfig

	// cfgs holds the desired visitor configurations, keyed by visitor name.
	cfgs map[string]v1.VisitorConfigurer

	// visitors holds the visitors whose Run call succeeded, keyed by visitor
	// name. A name present in cfgs but absent here is not running.
	visitors map[string]Visitor

	// helper is passed to every visitor created by the manager.
	helper Helper

	// checkInterval is the period of the ticker in keepVisitorsRunning.
	checkInterval time.Duration

	// keepVisitorsRunningOnce ensures the keepVisitorsRunning goroutine is
	// launched at most once.
	keepVisitorsRunningOnce sync.Once

	// mu guards cfgs and visitors.
	mu sync.RWMutex

	// ctx is used to obtain the logger and is passed to every visitor created
	// by the manager.
	ctx context.Context

	// stopCh is closed by Close to make keepVisitorsRunning return.
	stopCh chan struct{}
}
  38. func NewManager(
  39. ctx context.Context,
  40. runID string,
  41. clientCfg *v1.ClientCommonConfig,
  42. connectServer func() (net.Conn, error),
  43. msgTransporter transport.MessageTransporter,
  44. ) *Manager {
  45. m := &Manager{
  46. clientCfg: clientCfg,
  47. cfgs: make(map[string]v1.VisitorConfigurer),
  48. visitors: make(map[string]Visitor),
  49. checkInterval: 10 * time.Second,
  50. ctx: ctx,
  51. stopCh: make(chan struct{}),
  52. }
  53. m.helper = &visitorHelperImpl{
  54. connectServerFn: connectServer,
  55. msgTransporter: msgTransporter,
  56. transferConnFn: m.TransferConn,
  57. runID: runID,
  58. }
  59. return m
  60. }
  61. // keepVisitorsRunning checks all visitors' status periodically, if some visitor is not running, start it.
  62. // It will only start after Reload is called and a new visitor is added.
  63. func (vm *Manager) keepVisitorsRunning() {
  64. xl := xlog.FromContextSafe(vm.ctx)
  65. ticker := time.NewTicker(vm.checkInterval)
  66. defer ticker.Stop()
  67. for {
  68. select {
  69. case <-vm.stopCh:
  70. xl.Tracef("gracefully shutdown visitor manager")
  71. return
  72. case <-ticker.C:
  73. vm.mu.Lock()
  74. for _, cfg := range vm.cfgs {
  75. name := cfg.GetBaseConfig().Name
  76. if _, exist := vm.visitors[name]; !exist {
  77. xl.Infof("try to start visitor [%s]", name)
  78. _ = vm.startVisitor(cfg)
  79. }
  80. }
  81. vm.mu.Unlock()
  82. }
  83. }
  84. }
  85. func (vm *Manager) Close() {
  86. vm.mu.Lock()
  87. defer vm.mu.Unlock()
  88. for _, v := range vm.visitors {
  89. v.Close()
  90. }
  91. select {
  92. case <-vm.stopCh:
  93. default:
  94. close(vm.stopCh)
  95. }
  96. }
  97. // Hold lock before calling this function.
  98. func (vm *Manager) startVisitor(cfg v1.VisitorConfigurer) (err error) {
  99. xl := xlog.FromContextSafe(vm.ctx)
  100. name := cfg.GetBaseConfig().Name
  101. visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.helper)
  102. err = visitor.Run()
  103. if err != nil {
  104. xl.Warnf("start error: %v", err)
  105. } else {
  106. vm.visitors[name] = visitor
  107. xl.Infof("start visitor success")
  108. }
  109. return
  110. }
  111. func (vm *Manager) UpdateAll(cfgs []v1.VisitorConfigurer) {
  112. if len(cfgs) > 0 {
  113. // Only start keepVisitorsRunning goroutine once and only when there is at least one visitor.
  114. vm.keepVisitorsRunningOnce.Do(func() {
  115. go vm.keepVisitorsRunning()
  116. })
  117. }
  118. xl := xlog.FromContextSafe(vm.ctx)
  119. cfgsMap := lo.KeyBy(cfgs, func(c v1.VisitorConfigurer) string {
  120. return c.GetBaseConfig().Name
  121. })
  122. vm.mu.Lock()
  123. defer vm.mu.Unlock()
  124. delNames := make([]string, 0)
  125. for name, oldCfg := range vm.cfgs {
  126. del := false
  127. cfg, ok := cfgsMap[name]
  128. if !ok || !reflect.DeepEqual(oldCfg, cfg) {
  129. del = true
  130. }
  131. if del {
  132. delNames = append(delNames, name)
  133. delete(vm.cfgs, name)
  134. if visitor, ok := vm.visitors[name]; ok {
  135. visitor.Close()
  136. }
  137. delete(vm.visitors, name)
  138. }
  139. }
  140. if len(delNames) > 0 {
  141. xl.Infof("visitor removed: %v", delNames)
  142. }
  143. addNames := make([]string, 0)
  144. for _, cfg := range cfgs {
  145. name := cfg.GetBaseConfig().Name
  146. if _, ok := vm.cfgs[name]; !ok {
  147. vm.cfgs[name] = cfg
  148. addNames = append(addNames, name)
  149. _ = vm.startVisitor(cfg)
  150. }
  151. }
  152. if len(addNames) > 0 {
  153. xl.Infof("visitor added: %v", addNames)
  154. }
  155. }
  156. // TransferConn transfers a connection to a visitor.
  157. func (vm *Manager) TransferConn(name string, conn net.Conn) error {
  158. vm.mu.RLock()
  159. defer vm.mu.RUnlock()
  160. v, ok := vm.visitors[name]
  161. if !ok {
  162. return fmt.Errorf("visitor [%s] not found", name)
  163. }
  164. return v.AcceptConn(conn)
  165. }
// visitorHelperImpl is the Helper that NewManager builds and the Manager hands
// to each visitor. Every method simply exposes one of the fields below.
type visitorHelperImpl struct {
	// connectServerFn is invoked by ConnectServer.
	connectServerFn func() (net.Conn, error)

	// msgTransporter is returned by MsgTransporter.
	msgTransporter transport.MessageTransporter

	// transferConnFn is invoked by TransferConn; NewManager sets it to the
	// manager's own TransferConn method.
	transferConnFn func(name string, conn net.Conn) error

	// runID is returned by RunID.
	runID string
}
// ConnectServer calls the connectServerFn supplied at construction and returns
// its connection and error unchanged.
func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) {
	return v.connectServerFn()
}
// TransferConn forwards name and conn to transferConnFn and returns its error
// unchanged.
func (v *visitorHelperImpl) TransferConn(name string, conn net.Conn) error {
	return v.transferConnFn(name, conn)
}
// MsgTransporter returns the message transporter supplied at construction.
func (v *visitorHelperImpl) MsgTransporter() transport.MessageTransporter {
	return v.msgTransporter
}
// RunID returns the run ID supplied at construction.
func (v *visitorHelperImpl) RunID() string {
	return v.runID
}