1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- package net
- import (
- "context"
- "net"
- "net/url"
- libnet "github.com/fatedier/golib/net"
- "golang.org/x/net/websocket"
- )
// DialHookCustomTLSHeadByte returns an after-dial hook for use with libnet.
//
// When enableTLS is true and disableCustomTLSHeadByte is false, the hook
// writes the single byte FRPTLSHeadByte to the freshly dialed connection
// before passing it on. In every other case the hook returns ctx and c
// unchanged. If the write fails, the hook returns nil context, nil
// connection and the write error.
func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) libnet.AfterHookFunc {
	// Both inputs are fixed for the lifetime of the hook, so the decision
	// can be made once up front.
	sendHeadByte := enableTLS && !disableCustomTLSHeadByte

	return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
		if !sendHeadByte {
			return ctx, c, nil
		}
		if _, err := c.Write([]byte{byte(FRPTLSHeadByte)}); err != nil {
			return nil, nil, err
		}
		return ctx, c, nil
	}
}
// DialHookWebsocket returns an after-dial hook for use with libnet.
//
// The hook performs a WebSocket client handshake over the already-dialed
// connection c and returns the resulting WebSocket connection in its place.
//
// protocol selects the URL scheme: "wss" is kept and any other value falls
// back to "ws". host, when non-empty, is used as the URL host. When host is
// empty, the addr passed to the hook on that particular call is used instead.
// The request path is FrpWebsocketPath, and the Origin is "http://" followed
// by the parsed URL host.
//
// On any URL-parse, config or handshake error, the hook returns nil context,
// nil connection and that error. The hook never modifies the values it
// captured, so one hook can be reused for several dials.
func DialHookWebsocket(protocol string, host string) libnet.AfterHookFunc {
	// The scheme depends only on the hook's construction-time argument, so
	// normalize it once.
	scheme := "ws"
	if protocol == "wss" {
		scheme = "wss"
	}

	return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
		// Resolve the host into a local rather than assigning to the captured
		// `host`. The hook may be invoked for many dials (possibly
		// concurrently), and each call must use its own addr. Writing back to
		// the captured variable would pin every later call to the first
		// address and race between goroutines.
		target := host
		if target == "" {
			target = addr
		}

		wsURL := scheme + "://" + target + FrpWebsocketPath
		uri, err := url.Parse(wsURL)
		if err != nil {
			return nil, nil, err
		}

		origin := "http://" + uri.Host
		cfg, err := websocket.NewConfig(wsURL, origin)
		if err != nil {
			return nil, nil, err
		}

		conn, err := websocket.NewClient(cfg, c)
		if err != nil {
			return nil, nil, err
		}
		return ctx, conn, nil
	}
}
|