udp.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package services
  2. import (
  3. "bufio"
  4. "fmt"
  5. "hash/crc32"
  6. "io"
  7. "log"
  8. "net"
  9. "github.com/snail007/goproxy/utils"
  10. "runtime/debug"
  11. "strconv"
  12. "strings"
  13. "time"
  14. )
  15. type UDP struct {
  16. p utils.ConcurrentMap
  17. outPool utils.OutPool
  18. cfg UDPArgs
  19. sc *utils.ServerChannel
  20. }
  21. func NewUDP() Service {
  22. return &UDP{
  23. outPool: utils.OutPool{},
  24. p: utils.NewConcurrentMap(),
  25. }
  26. }
  27. func (s *UDP) InitService() {
  28. if *s.cfg.ParentType != TYPE_UDP {
  29. s.InitOutConnPool()
  30. }
  31. }
  32. func (s *UDP) StopService() {
  33. if s.outPool.Pool != nil {
  34. s.outPool.Pool.ReleaseAll()
  35. }
  36. }
  37. func (s *UDP) Start(args interface{}) (err error) {
  38. s.cfg = args.(UDPArgs)
  39. if *s.cfg.Parent != "" {
  40. log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent)
  41. } else {
  42. log.Fatalf("parent required for udp %s", *s.cfg.Local)
  43. }
  44. s.InitService()
  45. host, port, _ := net.SplitHostPort(*s.cfg.Local)
  46. p, _ := strconv.Atoi(port)
  47. sc := utils.NewServerChannel(host, p)
  48. s.sc = &sc
  49. err = sc.ListenUDP(s.callback)
  50. if err != nil {
  51. return
  52. }
  53. log.Printf("udp proxy on %s", (*sc.UDPListener).LocalAddr())
  54. return
  55. }
  56. func (s *UDP) Clean() {
  57. s.StopService()
  58. }
  59. func (s *UDP) callback(packet []byte, localAddr, srcAddr *net.UDPAddr) {
  60. defer func() {
  61. if err := recover(); err != nil {
  62. log.Printf("udp conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
  63. }
  64. }()
  65. var err error
  66. switch *s.cfg.ParentType {
  67. case TYPE_TCP:
  68. fallthrough
  69. case TYPE_TLS:
  70. err = s.OutToTCP(packet, localAddr, srcAddr)
  71. case TYPE_UDP:
  72. err = s.OutToUDP(packet, localAddr, srcAddr)
  73. default:
  74. err = fmt.Errorf("unkown parent type %s", *s.cfg.ParentType)
  75. }
  76. if err != nil {
  77. log.Printf("connect to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
  78. }
  79. }
  80. func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) {
  81. isNew = !s.p.Has(connKey)
  82. var _conn interface{}
  83. if isNew {
  84. _conn, err = s.outPool.Pool.Get()
  85. if err != nil {
  86. return nil, false, err
  87. }
  88. s.p.Set(connKey, _conn)
  89. } else {
  90. _conn, _ = s.p.Get(connKey)
  91. }
  92. conn = _conn.(net.Conn)
  93. return
  94. }
  95. func (s *UDP) OutToTCP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
  96. numLocal := crc32.ChecksumIEEE([]byte(localAddr.String()))
  97. numSrc := crc32.ChecksumIEEE([]byte(srcAddr.String()))
  98. mod := uint32(*s.cfg.PoolSize)
  99. if mod == 0 {
  100. mod = 10
  101. }
  102. connKey := uint64((numLocal/10)*10 + numSrc%mod)
  103. conn, isNew, err := s.GetConn(fmt.Sprintf("%d", connKey))
  104. if err != nil {
  105. log.Printf("upd get conn to %s parent %s fail, ERR:%s", *s.cfg.ParentType, *s.cfg.Parent, err)
  106. return
  107. }
  108. if isNew {
  109. go func() {
  110. defer func() {
  111. if err := recover(); err != nil {
  112. log.Printf("udp conn handler out to tcp crashed with err : %s \nstack: %s", err, string(debug.Stack()))
  113. }
  114. }()
  115. log.Printf("conn %d created , local: %s", connKey, srcAddr.String())
  116. for {
  117. srcAddrFromConn, body, err := utils.ReadUDPPacket(&conn)
  118. if err == io.EOF || err == io.ErrUnexpectedEOF {
  119. //log.Printf("connection %d released", connKey)
  120. s.p.Remove(fmt.Sprintf("%d", connKey))
  121. break
  122. }
  123. if err != nil {
  124. log.Printf("parse revecived udp packet fail, err: %s", err)
  125. continue
  126. }
  127. //log.Printf("udp packet revecived over parent , local:%s", srcAddrFromConn)
  128. _srcAddr := strings.Split(srcAddrFromConn, ":")
  129. if len(_srcAddr) != 2 {
  130. log.Printf("parse revecived udp packet fail, addr error : %s", srcAddrFromConn)
  131. continue
  132. }
  133. port, _ := strconv.Atoi(_srcAddr[1])
  134. dstAddr := &net.UDPAddr{IP: net.ParseIP(_srcAddr[0]), Port: port}
  135. _, err = s.sc.UDPListener.WriteToUDP(body, dstAddr)
  136. if err != nil {
  137. log.Printf("udp response to local %s fail,ERR:%s", srcAddr, err)
  138. continue
  139. }
  140. //log.Printf("udp response to local %s success", srcAddr)
  141. }
  142. }()
  143. }
  144. //log.Printf("select conn %d , local: %s", connKey, srcAddr.String())
  145. writer := bufio.NewWriter(conn)
  146. //fmt.Println(conn, writer)
  147. writer.Write(utils.UDPPacket(srcAddr.String(), packet))
  148. err = writer.Flush()
  149. if err != nil {
  150. log.Printf("write udp packet to %s fail ,flush err:%s", *s.cfg.Parent, err)
  151. return
  152. }
  153. //log.Printf("write packet %v", packet)
  154. return
  155. }
  156. func (s *UDP) OutToUDP(packet []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
  157. //log.Printf("udp packet revecived:%s,%v", srcAddr, packet)
  158. dstAddr, err := net.ResolveUDPAddr("udp", *s.cfg.Parent)
  159. if err != nil {
  160. log.Printf("resolve udp addr %s fail fail,ERR:%s", dstAddr.String(), err)
  161. return
  162. }
  163. clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
  164. conn, err := net.DialUDP("udp", clientSrcAddr, dstAddr)
  165. if err != nil {
  166. log.Printf("connect to udp %s fail,ERR:%s", dstAddr.String(), err)
  167. return
  168. }
  169. conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
  170. _, err = conn.Write(packet)
  171. if err != nil {
  172. log.Printf("send udp packet to %s fail,ERR:%s", dstAddr.String(), err)
  173. return
  174. }
  175. //log.Printf("send udp packet to %s success", dstAddr.String())
  176. buf := make([]byte, 512)
  177. len, _, err := conn.ReadFromUDP(buf)
  178. if err != nil {
  179. log.Printf("read udp response from %s fail ,ERR:%s", dstAddr.String(), err)
  180. return
  181. }
  182. //log.Printf("revecived udp packet from %s , %v", dstAddr.String(), respBody)
  183. _, err = s.sc.UDPListener.WriteToUDP(buf[0:len], srcAddr)
  184. if err != nil {
  185. log.Printf("send udp response to cluster fail ,ERR:%s", err)
  186. return
  187. }
  188. //log.Printf("send udp response to cluster success ,from:%s", dstAddr.String())
  189. return
  190. }
  191. func (s *UDP) InitOutConnPool() {
  192. if *s.cfg.ParentType == TYPE_TLS || *s.cfg.ParentType == TYPE_TCP {
  193. //dur int, isTLS bool, certBytes, keyBytes []byte,
  194. //parent string, timeout int, InitialCap int, MaxCap int
  195. s.outPool = utils.NewOutPool(
  196. *s.cfg.CheckParentInterval,
  197. *s.cfg.ParentType == TYPE_TLS,
  198. s.cfg.CertBytes, s.cfg.KeyBytes,
  199. *s.cfg.Parent,
  200. *s.cfg.Timeout,
  201. *s.cfg.PoolSize,
  202. *s.cfg.PoolSize*2,
  203. )
  204. }
  205. }