request.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. package request
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strconv"
  12. "time"
  13. libnet "github.com/fatedier/golib/net"
  14. httppkg "github.com/fatedier/frp/pkg/util/http"
  15. "github.com/fatedier/frp/test/e2e/pkg/rpc"
  16. )
  // Request describes one client exchange used by the e2e tests. It is built
  // with New and configured through chained setter methods (each returns the
  // same *Request), then executed with Do.
  type Request struct {
  	// protocol selects how Do talks to the target: "tcp", "udp", "http" or "https".
  	protocol string

  	// Settings shared by every protocol.
  	addr     string        // target host
  	port     int           // target port; joined to addr by Do only when > 0
  	body     []byte        // payload to send
  	timeout  time.Duration // time limit used by Do; <= 0 means none
  	resolver *net.Resolver // optional DNS resolver for dialing; nil uses the default

  	// Settings only consulted for "http" and "https".
  	method    string            // HTTP method
  	host      string            // Host header override; empty keeps the URL host
  	path      string            // request path appended after the address
  	headers   map[string]string // extra request headers
  	tlsConfig *tls.Config       // client TLS configuration
  	authValue string            // Authorization header value; empty means none
  	proxyURL  string            // proxy to route through; empty means direct
  }
  34. func New() *Request {
  35. return &Request{
  36. protocol: "tcp",
  37. addr: "127.0.0.1",
  38. method: "GET",
  39. path: "/",
  40. headers: map[string]string{},
  41. }
  42. }
  43. func (r *Request) Protocol(protocol string) *Request {
  44. r.protocol = protocol
  45. return r
  46. }
  47. func (r *Request) TCP() *Request {
  48. r.protocol = "tcp"
  49. return r
  50. }
  51. func (r *Request) UDP() *Request {
  52. r.protocol = "udp"
  53. return r
  54. }
  55. func (r *Request) HTTP() *Request {
  56. r.protocol = "http"
  57. return r
  58. }
  59. func (r *Request) HTTPS() *Request {
  60. r.protocol = "https"
  61. return r
  62. }
  63. func (r *Request) Proxy(url string) *Request {
  64. r.proxyURL = url
  65. return r
  66. }
  67. func (r *Request) Addr(addr string) *Request {
  68. r.addr = addr
  69. return r
  70. }
  71. func (r *Request) Port(port int) *Request {
  72. r.port = port
  73. return r
  74. }
  75. func (r *Request) HTTPParams(method, host, path string, headers map[string]string) *Request {
  76. r.method = method
  77. r.host = host
  78. r.path = path
  79. r.headers = headers
  80. return r
  81. }
  82. func (r *Request) HTTPHost(host string) *Request {
  83. r.host = host
  84. return r
  85. }
  86. func (r *Request) HTTPPath(path string) *Request {
  87. r.path = path
  88. return r
  89. }
  90. func (r *Request) HTTPHeaders(headers map[string]string) *Request {
  91. r.headers = headers
  92. return r
  93. }
  94. func (r *Request) HTTPAuth(user, password string) *Request {
  95. r.authValue = httppkg.BasicAuth(user, password)
  96. return r
  97. }
  98. func (r *Request) TLSConfig(tlsConfig *tls.Config) *Request {
  99. r.tlsConfig = tlsConfig
  100. return r
  101. }
  102. func (r *Request) Timeout(timeout time.Duration) *Request {
  103. r.timeout = timeout
  104. return r
  105. }
  106. func (r *Request) Body(content []byte) *Request {
  107. r.body = content
  108. return r
  109. }
  110. func (r *Request) Resolver(resolver *net.Resolver) *Request {
  111. r.resolver = resolver
  112. return r
  113. }
  114. func (r *Request) Do() (*Response, error) {
  115. var (
  116. conn net.Conn
  117. err error
  118. )
  119. addr := r.addr
  120. if r.port > 0 {
  121. addr = net.JoinHostPort(r.addr, strconv.Itoa(r.port))
  122. }
  123. // for protocol http and https
  124. if r.protocol == "http" || r.protocol == "https" {
  125. return r.sendHTTPRequest(r.method, fmt.Sprintf("%s://%s%s", r.protocol, addr, r.path),
  126. r.host, r.headers, r.proxyURL, r.body, r.tlsConfig)
  127. }
  128. // for protocol tcp and udp
  129. if len(r.proxyURL) > 0 {
  130. if r.protocol != "tcp" {
  131. return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
  132. }
  133. proxyType, proxyAddress, auth, err := libnet.ParseProxyURL(r.proxyURL)
  134. if err != nil {
  135. return nil, fmt.Errorf("parse ProxyURL error: %v", err)
  136. }
  137. conn, err = libnet.Dial(addr, libnet.WithProxy(proxyType, proxyAddress), libnet.WithProxyAuth(auth))
  138. if err != nil {
  139. return nil, err
  140. }
  141. } else {
  142. dialer := &net.Dialer{Resolver: r.resolver}
  143. switch r.protocol {
  144. case "tcp":
  145. conn, err = dialer.Dial("tcp", addr)
  146. case "udp":
  147. conn, err = dialer.Dial("udp", addr)
  148. default:
  149. return nil, fmt.Errorf("invalid protocol")
  150. }
  151. if err != nil {
  152. return nil, err
  153. }
  154. }
  155. defer conn.Close()
  156. if r.timeout > 0 {
  157. _ = conn.SetDeadline(time.Now().Add(r.timeout))
  158. }
  159. buf, err := r.sendRequestByConn(conn, r.body)
  160. if err != nil {
  161. return nil, err
  162. }
  163. return &Response{Content: buf}, nil
  164. }
  // Response is the result of Request.Do.
  //
  // For "http"/"https" requests all fields are populated from the HTTP
  // response. For raw "tcp"/"udp" exchanges only Content is set, holding the
  // reply payload.
  type Response struct {
  	Code    int         // HTTP status code; 0 for tcp/udp
  	Header  http.Header // HTTP response headers; nil for tcp/udp
  	Content []byte      // HTTP body or raw reply payload
  }
  170. func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string,
  171. proxy string, body []byte, tlsConfig *tls.Config,
  172. ) (*Response, error) {
  173. var inBody io.Reader
  174. if len(body) != 0 {
  175. inBody = bytes.NewReader(body)
  176. }
  177. req, err := http.NewRequest(method, urlstr, inBody)
  178. if err != nil {
  179. return nil, err
  180. }
  181. if host != "" {
  182. req.Host = host
  183. }
  184. for k, v := range headers {
  185. req.Header.Set(k, v)
  186. }
  187. if r.authValue != "" {
  188. req.Header.Set("Authorization", r.authValue)
  189. }
  190. tr := &http.Transport{
  191. DialContext: (&net.Dialer{
  192. Timeout: time.Second,
  193. KeepAlive: 30 * time.Second,
  194. DualStack: true,
  195. Resolver: r.resolver,
  196. }).DialContext,
  197. MaxIdleConns: 100,
  198. IdleConnTimeout: 90 * time.Second,
  199. TLSHandshakeTimeout: 10 * time.Second,
  200. ExpectContinueTimeout: 1 * time.Second,
  201. TLSClientConfig: tlsConfig,
  202. }
  203. if len(proxy) != 0 {
  204. tr.Proxy = func(req *http.Request) (*url.URL, error) {
  205. return url.Parse(proxy)
  206. }
  207. }
  208. client := http.Client{Transport: tr}
  209. resp, err := client.Do(req)
  210. if err != nil {
  211. return nil, err
  212. }
  213. defer resp.Body.Close()
  214. ret := &Response{Code: resp.StatusCode, Header: resp.Header}
  215. buf, err := io.ReadAll(resp.Body)
  216. if err != nil {
  217. return nil, err
  218. }
  219. ret.Content = buf
  220. return ret, nil
  221. }
  222. func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
  223. _, err := rpc.WriteBytes(c, content)
  224. if err != nil {
  225. return nil, fmt.Errorf("write error: %v", err)
  226. }
  227. var reader io.Reader = c
  228. if r.protocol == "udp" {
  229. reader = bufio.NewReader(c)
  230. }
  231. buf, err := rpc.ReadBytes(reader)
  232. if err != nil {
  233. return nil, fmt.Errorf("read error: %v", err)
  234. }
  235. return buf, nil
  236. }