real_ip.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package features
  2. import (
  3. "bufio"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "github.com/onsi/ginkgo/v2"
  8. pp "github.com/pires/go-proxyproto"
  9. "github.com/fatedier/frp/pkg/util/log"
  10. "github.com/fatedier/frp/test/e2e/framework"
  11. "github.com/fatedier/frp/test/e2e/framework/consts"
  12. "github.com/fatedier/frp/test/e2e/mock/server/httpserver"
  13. "github.com/fatedier/frp/test/e2e/mock/server/streamserver"
  14. "github.com/fatedier/frp/test/e2e/pkg/request"
  15. "github.com/fatedier/frp/test/e2e/pkg/rpc"
  16. )
  17. var _ = ginkgo.Describe("[Feature: Real IP]", func() {
  18. f := framework.NewDefaultFramework()
  19. ginkgo.It("HTTP X-Forwarded-For", func() {
  20. vhostHTTPPort := f.AllocPort()
  21. serverConf := consts.LegacyDefaultServerConfig + fmt.Sprintf(`
  22. vhost_http_port = %d
  23. `, vhostHTTPPort)
  24. localPort := f.AllocPort()
  25. localServer := httpserver.New(
  26. httpserver.WithBindPort(localPort),
  27. httpserver.WithHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  28. _, _ = w.Write([]byte(req.Header.Get("X-Forwarded-For")))
  29. })),
  30. )
  31. f.RunServer("", localServer)
  32. clientConf := consts.LegacyDefaultClientConfig
  33. clientConf += fmt.Sprintf(`
  34. [test]
  35. type = http
  36. local_port = %d
  37. custom_domains = normal.example.com
  38. `, localPort)
  39. f.RunProcesses([]string{serverConf}, []string{clientConf})
  40. framework.NewRequestExpect(f).Port(vhostHTTPPort).
  41. RequestModify(func(r *request.Request) {
  42. r.HTTP().HTTPHost("normal.example.com")
  43. }).
  44. ExpectResp([]byte("127.0.0.1")).
  45. Ensure()
  46. })
  47. ginkgo.Describe("Proxy Protocol", func() {
  48. ginkgo.It("TCP", func() {
  49. serverConf := consts.LegacyDefaultServerConfig
  50. clientConf := consts.LegacyDefaultClientConfig
  51. localPort := f.AllocPort()
  52. localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(localPort),
  53. streamserver.WithCustomHandler(func(c net.Conn) {
  54. defer c.Close()
  55. rd := bufio.NewReader(c)
  56. ppHeader, err := pp.Read(rd)
  57. if err != nil {
  58. log.Errorf("read proxy protocol error: %v", err)
  59. return
  60. }
  61. for {
  62. if _, err := rpc.ReadBytes(rd); err != nil {
  63. return
  64. }
  65. buf := []byte(ppHeader.SourceAddr.String())
  66. _, _ = rpc.WriteBytes(c, buf)
  67. }
  68. }))
  69. f.RunServer("", localServer)
  70. remotePort := f.AllocPort()
  71. clientConf += fmt.Sprintf(`
  72. [tcp]
  73. type = tcp
  74. local_port = %d
  75. remote_port = %d
  76. proxy_protocol_version = v2
  77. `, localPort, remotePort)
  78. f.RunProcesses([]string{serverConf}, []string{clientConf})
  79. framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool {
  80. log.Tracef("ProxyProtocol get SourceAddr: %s", string(resp.Content))
  81. addr, err := net.ResolveTCPAddr("tcp", string(resp.Content))
  82. if err != nil {
  83. return false
  84. }
  85. if addr.IP.String() != "127.0.0.1" {
  86. return false
  87. }
  88. return true
  89. })
  90. })
  91. ginkgo.It("HTTP", func() {
  92. vhostHTTPPort := f.AllocPort()
  93. serverConf := consts.LegacyDefaultServerConfig + fmt.Sprintf(`
  94. vhost_http_port = %d
  95. `, vhostHTTPPort)
  96. clientConf := consts.LegacyDefaultClientConfig
  97. localPort := f.AllocPort()
  98. var srcAddrRecord string
  99. localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(localPort),
  100. streamserver.WithCustomHandler(func(c net.Conn) {
  101. defer c.Close()
  102. rd := bufio.NewReader(c)
  103. ppHeader, err := pp.Read(rd)
  104. if err != nil {
  105. log.Errorf("read proxy protocol error: %v", err)
  106. return
  107. }
  108. srcAddrRecord = ppHeader.SourceAddr.String()
  109. }))
  110. f.RunServer("", localServer)
  111. clientConf += fmt.Sprintf(`
  112. [test]
  113. type = http
  114. local_port = %d
  115. custom_domains = normal.example.com
  116. proxy_protocol_version = v2
  117. `, localPort)
  118. f.RunProcesses([]string{serverConf}, []string{clientConf})
  119. framework.NewRequestExpect(f).Port(vhostHTTPPort).RequestModify(func(r *request.Request) {
  120. r.HTTP().HTTPHost("normal.example.com")
  121. }).Ensure(framework.ExpectResponseCode(404))
  122. log.Tracef("ProxyProtocol get SourceAddr: %s", srcAddrRecord)
  123. addr, err := net.ResolveTCPAddr("tcp", srcAddrRecord)
  124. framework.ExpectNoError(err, srcAddrRecord)
  125. framework.ExpectEqualValues("127.0.0.1", addr.IP.String())
  126. })
  127. })
  128. })