123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- package lightsocks
import (
	"io"
	"log"
	"net"
	"sync"
	"time"
)
// bufSize is the length, in bytes, of every buffer that bpool creates.
const bufSize = 1024

// bpool recycles byte slices so that callers of bufferPoolGet do not need to
// allocate a fresh buffer on every call. When the pool is empty it creates a
// new slice of bufSize bytes.
var bpool = sync.Pool{
	New: func() interface{} {
		return make([]byte, bufSize)
	},
}
- func bufferPoolGet() []byte {
- return bpool.Get().([]byte)
- }
- func bufferPoolPut(b []byte) {
- bpool.Put(b)
- }
- // 加密传输的 TCP Socket
- type SecureTCPConn struct {
- io.ReadWriteCloser
- Cipher *Cipher
- }
- // 从输入流里读取加密过的数据,解密后把原数据放到bs里
- func (secureSocket *SecureTCPConn) DecodeRead(bs []byte) (n int, err error) {
- n, err = secureSocket.Read(bs)
- if err != nil {
- return
- }
- secureSocket.Cipher.Decode(bs[:n])
- return
- }
- // 把放在bs里的数据加密后立即全部写入输出流
- func (secureSocket *SecureTCPConn) EncodeWrite(bs []byte) (int, error) {
- secureSocket.Cipher.Encode(bs)
- return secureSocket.Write(bs)
- }
- // 从src中源源不断的读取原数据加密后写入到dst,直到src中没有数据可以再读取
- func (secureSocket *SecureTCPConn) EncodeCopy(dst io.ReadWriteCloser) error {
- buf := bufferPoolGet()
- defer bufferPoolPut(buf)
- for {
- readCount, errRead := secureSocket.Read(buf)
- if errRead != nil {
- if errRead != io.EOF {
- return errRead
- } else {
- return nil
- }
- }
- if readCount > 0 {
- writeCount, errWrite := (&SecureTCPConn{
- ReadWriteCloser: dst,
- Cipher: secureSocket.Cipher,
- }).EncodeWrite(buf[0:readCount])
- if errWrite != nil {
- return errWrite
- }
- if readCount != writeCount {
- return io.ErrShortWrite
- }
- }
- }
- }
- // 从src中源源不断的读取加密后的数据解密后写入到dst,直到src中没有数据可以再读取
- func (secureSocket *SecureTCPConn) DecodeCopy(dst io.Writer) error {
- buf := bufferPoolGet()
- defer bufferPoolPut(buf)
- for {
- readCount, errRead := secureSocket.DecodeRead(buf)
- if errRead != nil {
- if errRead != io.EOF {
- return errRead
- } else {
- return nil
- }
- }
- if readCount > 0 {
- writeCount, errWrite := dst.Write(buf[0:readCount])
- if errWrite != nil {
- return errWrite
- }
- if readCount != writeCount {
- return io.ErrShortWrite
- }
- }
- }
- }
- // see net.DialTCP
- func DialEncryptedTCP(raddr *net.TCPAddr, cipher *Cipher) (*SecureTCPConn, error) {
- remoteConn, err := net.DialTCP("tcp", nil, raddr)
- if err != nil {
- return nil, err
- }
- // Conn被关闭时直接清除所有数据 不管没有发送的数据
- remoteConn.SetLinger(0)
- return &SecureTCPConn{
- ReadWriteCloser: remoteConn,
- Cipher: cipher,
- }, nil
- }
- // see net.ListenTCP
- func ListenEncryptedTCP(laddr *net.TCPAddr, cipher *Cipher, handleConn func(localConn *SecureTCPConn), didListen func(listenAddr *net.TCPAddr)) error {
- listener, err := net.ListenTCP("tcp", laddr)
- if err != nil {
- return err
- }
- defer listener.Close()
- if didListen != nil {
- // didListen 可能有阻塞操作
- go didListen(listener.Addr().(*net.TCPAddr))
- }
- for {
- localConn, err := listener.AcceptTCP()
- if err != nil {
- log.Println(err)
- continue
- }
- // localConn被关闭时直接清除所有数据 不管没有发送的数据
- localConn.SetLinger(0)
- go handleConn(&SecureTCPConn{
- ReadWriteCloser: localConn,
- Cipher: cipher,
- })
- }
- }
|