1
0

dial.go 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. package net
  2. import (
  3. "context"
  4. "net"
  5. "net/url"
  6. libnet "github.com/fatedier/golib/net"
  7. "golang.org/x/net/websocket"
  8. )
  9. func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) libnet.AfterHookFunc {
  10. return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
  11. if enableTLS && !disableCustomTLSHeadByte {
  12. _, err := c.Write([]byte{byte(FRPTLSHeadByte)})
  13. if err != nil {
  14. return nil, nil, err
  15. }
  16. }
  17. return ctx, c, nil
  18. }
  19. }
  20. func DialHookWebsocket(protocol string, host string) libnet.AfterHookFunc {
  21. return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
  22. if protocol != "wss" {
  23. protocol = "ws"
  24. }
  25. if host == "" {
  26. host = addr
  27. }
  28. addr = protocol + "://" + host + FrpWebsocketPath
  29. uri, err := url.Parse(addr)
  30. if err != nil {
  31. return nil, nil, err
  32. }
  33. origin := "http://" + uri.Host
  34. cfg, err := websocket.NewConfig(addr, origin)
  35. if err != nil {
  36. return nil, nil, err
  37. }
  38. conn, err := websocket.NewClient(cfg, c)
  39. if err != nil {
  40. return nil, nil, err
  41. }
  42. return ctx, conn, nil
  43. }
  44. }