securetcp.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package lightsocks
  import (
	"io"
	"log"
	"net"
	"sync"
	"time"
)
  // bufSize is the length, in bytes, of every scratch buffer created by bpool.
const bufSize = 1024

// bpool recycles bufSize-byte scratch slices (used, e.g., by EncodeCopy and
// DecodeCopy) so each copy loop does not allocate a fresh buffer.
//
// It is initialized with a composite literal instead of an init function, so
// the pool is fully usable before any init() in the package runs.
//
// NOTE(review): storing a plain []byte (not a pointer) in a sync.Pool boxes
// the slice header on every Put, which allocates (staticcheck SA6002).
// Pooling *[]byte would avoid that but requires changing
// bufferPoolGet/bufferPoolPut and their callers.
var bpool = sync.Pool{
	New: func() interface{} {
		return make([]byte, bufSize)
	},
}
  17. func bufferPoolGet() []byte {
  18. return bpool.Get().([]byte)
  19. }
  20. func bufferPoolPut(b []byte) {
  21. bpool.Put(b)
  22. }
  23. // 加密传输的 TCP Socket
  24. type SecureTCPConn struct {
  25. io.ReadWriteCloser
  26. Cipher *Cipher
  27. }
  28. // 从输入流里读取加密过的数据,解密后把原数据放到bs里
  29. func (secureSocket *SecureTCPConn) DecodeRead(bs []byte) (n int, err error) {
  30. n, err = secureSocket.Read(bs)
  31. if err != nil {
  32. return
  33. }
  34. secureSocket.Cipher.Decode(bs[:n])
  35. return
  36. }
  37. // 把放在bs里的数据加密后立即全部写入输出流
  38. func (secureSocket *SecureTCPConn) EncodeWrite(bs []byte) (int, error) {
  39. secureSocket.Cipher.Encode(bs)
  40. return secureSocket.Write(bs)
  41. }
  42. // 从src中源源不断的读取原数据加密后写入到dst,直到src中没有数据可以再读取
  43. func (secureSocket *SecureTCPConn) EncodeCopy(dst io.ReadWriteCloser) error {
  44. buf := bufferPoolGet()
  45. defer bufferPoolPut(buf)
  46. for {
  47. readCount, errRead := secureSocket.Read(buf)
  48. if errRead != nil {
  49. if errRead != io.EOF {
  50. return errRead
  51. } else {
  52. return nil
  53. }
  54. }
  55. if readCount > 0 {
  56. writeCount, errWrite := (&SecureTCPConn{
  57. ReadWriteCloser: dst,
  58. Cipher: secureSocket.Cipher,
  59. }).EncodeWrite(buf[0:readCount])
  60. if errWrite != nil {
  61. return errWrite
  62. }
  63. if readCount != writeCount {
  64. return io.ErrShortWrite
  65. }
  66. }
  67. }
  68. }
  69. // 从src中源源不断的读取加密后的数据解密后写入到dst,直到src中没有数据可以再读取
  70. func (secureSocket *SecureTCPConn) DecodeCopy(dst io.Writer) error {
  71. buf := bufferPoolGet()
  72. defer bufferPoolPut(buf)
  73. for {
  74. readCount, errRead := secureSocket.DecodeRead(buf)
  75. if errRead != nil {
  76. if errRead != io.EOF {
  77. return errRead
  78. } else {
  79. return nil
  80. }
  81. }
  82. if readCount > 0 {
  83. writeCount, errWrite := dst.Write(buf[0:readCount])
  84. if errWrite != nil {
  85. return errWrite
  86. }
  87. if readCount != writeCount {
  88. return io.ErrShortWrite
  89. }
  90. }
  91. }
  92. }
  93. // see net.DialTCP
  94. func DialEncryptedTCP(raddr *net.TCPAddr, cipher *Cipher) (*SecureTCPConn, error) {
  95. remoteConn, err := net.DialTCP("tcp", nil, raddr)
  96. if err != nil {
  97. return nil, err
  98. }
  99. // Conn被关闭时直接清除所有数据 不管没有发送的数据
  100. remoteConn.SetLinger(0)
  101. return &SecureTCPConn{
  102. ReadWriteCloser: remoteConn,
  103. Cipher: cipher,
  104. }, nil
  105. }
  106. // see net.ListenTCP
  107. func ListenEncryptedTCP(laddr *net.TCPAddr, cipher *Cipher, handleConn func(localConn *SecureTCPConn), didListen func(listenAddr *net.TCPAddr)) error {
  108. listener, err := net.ListenTCP("tcp", laddr)
  109. if err != nil {
  110. return err
  111. }
  112. defer listener.Close()
  113. if didListen != nil {
  114. // didListen 可能有阻塞操作
  115. go didListen(listener.Addr().(*net.TCPAddr))
  116. }
  117. for {
  118. localConn, err := listener.AcceptTCP()
  119. if err != nil {
  120. log.Println(err)
  121. continue
  122. }
  123. // localConn被关闭时直接清除所有数据 不管没有发送的数据
  124. localConn.SetLinger(0)
  125. go handleConn(&SecureTCPConn{
  126. ReadWriteCloser: localConn,
  127. Cipher: cipher,
  128. })
  129. }
  130. }