123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273 |
- package request
- import (
- "bufio"
- "bytes"
- "crypto/tls"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/url"
- "strconv"
- "time"
- libnet "github.com/fatedier/golib/net"
- httppkg "github.com/fatedier/frp/pkg/util/http"
- "github.com/fatedier/frp/test/e2e/pkg/rpc"
- )
// Request describes one e2e probe: either a raw tcp/udp exchange or an
// http/https request. Create it with New, configure it through the chained
// setters, then call Do.
type Request struct {
	protocol string

	// Used by every protocol. A positive timeout is applied by Do; resolver is
	// handed to the dialers (nil leaves the net package default in effect).
	addr     string
	port     int
	body     []byte
	timeout  time.Duration
	resolver *net.Resolver

	// Used by http and https. proxyURL is additionally honored by Do when
	// dialing tcp connections.
	method    string
	host      string
	path      string
	headers   map[string]string
	tlsConfig *tls.Config
	authValue string
	proxyURL  string
}
- func New() *Request {
- return &Request{
- protocol: "tcp",
- addr: "127.0.0.1",
- method: "GET",
- path: "/",
- headers: map[string]string{},
- }
- }
- func (r *Request) Protocol(protocol string) *Request {
- r.protocol = protocol
- return r
- }
- func (r *Request) TCP() *Request {
- r.protocol = "tcp"
- return r
- }
- func (r *Request) UDP() *Request {
- r.protocol = "udp"
- return r
- }
- func (r *Request) HTTP() *Request {
- r.protocol = "http"
- return r
- }
- func (r *Request) HTTPS() *Request {
- r.protocol = "https"
- return r
- }
- func (r *Request) Proxy(url string) *Request {
- r.proxyURL = url
- return r
- }
- func (r *Request) Addr(addr string) *Request {
- r.addr = addr
- return r
- }
- func (r *Request) Port(port int) *Request {
- r.port = port
- return r
- }
- func (r *Request) HTTPParams(method, host, path string, headers map[string]string) *Request {
- r.method = method
- r.host = host
- r.path = path
- r.headers = headers
- return r
- }
- func (r *Request) HTTPHost(host string) *Request {
- r.host = host
- return r
- }
- func (r *Request) HTTPPath(path string) *Request {
- r.path = path
- return r
- }
- func (r *Request) HTTPHeaders(headers map[string]string) *Request {
- r.headers = headers
- return r
- }
- func (r *Request) HTTPAuth(user, password string) *Request {
- r.authValue = httppkg.BasicAuth(user, password)
- return r
- }
- func (r *Request) TLSConfig(tlsConfig *tls.Config) *Request {
- r.tlsConfig = tlsConfig
- return r
- }
- func (r *Request) Timeout(timeout time.Duration) *Request {
- r.timeout = timeout
- return r
- }
- func (r *Request) Body(content []byte) *Request {
- r.body = content
- return r
- }
- func (r *Request) Resolver(resolver *net.Resolver) *Request {
- r.resolver = resolver
- return r
- }
- func (r *Request) Do() (*Response, error) {
- var (
- conn net.Conn
- err error
- )
- addr := r.addr
- if r.port > 0 {
- addr = net.JoinHostPort(r.addr, strconv.Itoa(r.port))
- }
- // for protocol http and https
- if r.protocol == "http" || r.protocol == "https" {
- return r.sendHTTPRequest(r.method, fmt.Sprintf("%s://%s%s", r.protocol, addr, r.path),
- r.host, r.headers, r.proxyURL, r.body, r.tlsConfig)
- }
- // for protocol tcp and udp
- if len(r.proxyURL) > 0 {
- if r.protocol != "tcp" {
- return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
- }
- proxyType, proxyAddress, auth, err := libnet.ParseProxyURL(r.proxyURL)
- if err != nil {
- return nil, fmt.Errorf("parse ProxyURL error: %v", err)
- }
- conn, err = libnet.Dial(addr, libnet.WithProxy(proxyType, proxyAddress), libnet.WithProxyAuth(auth))
- if err != nil {
- return nil, err
- }
- } else {
- dialer := &net.Dialer{Resolver: r.resolver}
- switch r.protocol {
- case "tcp":
- conn, err = dialer.Dial("tcp", addr)
- case "udp":
- conn, err = dialer.Dial("udp", addr)
- default:
- return nil, fmt.Errorf("invalid protocol")
- }
- if err != nil {
- return nil, err
- }
- }
- defer conn.Close()
- if r.timeout > 0 {
- _ = conn.SetDeadline(time.Now().Add(r.timeout))
- }
- buf, err := r.sendRequestByConn(conn, r.body)
- if err != nil {
- return nil, err
- }
- return &Response{Content: buf}, nil
- }
// Response is the result of Request.Do.
type Response struct {
	// Code is the HTTP status code; zero for tcp/udp exchanges.
	Code int

	// Header holds the HTTP response headers; nil for tcp/udp exchanges.
	Header http.Header

	// Content is the full response body (http/https) or the reply read back
	// from the connection (tcp/udp).
	Content []byte
}
- func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string,
- proxy string, body []byte, tlsConfig *tls.Config,
- ) (*Response, error) {
- var inBody io.Reader
- if len(body) != 0 {
- inBody = bytes.NewReader(body)
- }
- req, err := http.NewRequest(method, urlstr, inBody)
- if err != nil {
- return nil, err
- }
- if host != "" {
- req.Host = host
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- if r.authValue != "" {
- req.Header.Set("Authorization", r.authValue)
- }
- tr := &http.Transport{
- DialContext: (&net.Dialer{
- Timeout: time.Second,
- KeepAlive: 30 * time.Second,
- DualStack: true,
- Resolver: r.resolver,
- }).DialContext,
- MaxIdleConns: 100,
- IdleConnTimeout: 90 * time.Second,
- TLSHandshakeTimeout: 10 * time.Second,
- ExpectContinueTimeout: 1 * time.Second,
- TLSClientConfig: tlsConfig,
- }
- if len(proxy) != 0 {
- tr.Proxy = func(req *http.Request) (*url.URL, error) {
- return url.Parse(proxy)
- }
- }
- client := http.Client{Transport: tr}
- resp, err := client.Do(req)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- ret := &Response{Code: resp.StatusCode, Header: resp.Header}
- buf, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, err
- }
- ret.Content = buf
- return ret, nil
- }
- func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
- _, err := rpc.WriteBytes(c, content)
- if err != nil {
- return nil, fmt.Errorf("write error: %v", err)
- }
- var reader io.Reader = c
- if r.protocol == "udp" {
- reader = bufio.NewReader(c)
- }
- buf, err := rpc.ReadBytes(reader)
- if err != nil {
- return nil, fmt.Errorf("read error: %v", err)
- }
- return buf, nil
- }
|