// Source file: structs.go

  1. package utils
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "log"
  10. "net"
  11. "net/url"
  12. "strings"
  13. "time"
  14. )
  15. type Checker struct {
  16. data ConcurrentMap
  17. blockedMap ConcurrentMap
  18. directMap ConcurrentMap
  19. interval int64
  20. timeout int
  21. }
  22. type CheckerItem struct {
  23. IsHTTPS bool
  24. Method string
  25. URL string
  26. Domain string
  27. Host string
  28. Data []byte
  29. SuccessCount uint
  30. FailCount uint
  31. }
  32. //NewChecker args:
  33. //timeout : tcp timeout milliseconds ,connect to host
  34. //interval: recheck domain interval seconds
  35. func NewChecker(timeout int, interval int64, blockedFile, directFile string) Checker {
  36. ch := Checker{
  37. data: NewConcurrentMap(),
  38. interval: interval,
  39. timeout: timeout,
  40. }
  41. ch.blockedMap = ch.loadMap(blockedFile)
  42. ch.directMap = ch.loadMap(directFile)
  43. if !ch.blockedMap.IsEmpty() {
  44. log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count())
  45. }
  46. if !ch.directMap.IsEmpty() {
  47. log.Printf("direct file loaded , domains : %d", ch.directMap.Count())
  48. }
  49. ch.start()
  50. return ch
  51. }
  52. func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) {
  53. dataMap = NewConcurrentMap()
  54. if PathExists(f) {
  55. _contents, err := ioutil.ReadFile(f)
  56. if err != nil {
  57. log.Printf("load file err:%s", err)
  58. return
  59. }
  60. for _, line := range strings.Split(string(_contents), "\n") {
  61. line = strings.Trim(line, "\r \t")
  62. if line != "" {
  63. dataMap.Set(line, true)
  64. }
  65. }
  66. }
  67. return
  68. }
  69. func (c *Checker) start() {
  70. go func() {
  71. for {
  72. for _, v := range c.data.Items() {
  73. go func(item CheckerItem) {
  74. if c.isNeedCheck(item) {
  75. //log.Printf("check %s", item.Domain)
  76. var conn net.Conn
  77. var err error
  78. if item.IsHTTPS {
  79. conn, err = ConnectHost(item.Host, c.timeout)
  80. if err == nil {
  81. conn.SetDeadline(time.Now().Add(time.Millisecond))
  82. conn.Close()
  83. }
  84. } else {
  85. err = HTTPGet(item.URL, c.timeout)
  86. }
  87. if err != nil {
  88. item.FailCount = item.FailCount + 1
  89. } else {
  90. item.SuccessCount = item.SuccessCount + 1
  91. }
  92. c.data.Set(item.Host, item)
  93. }
  94. }(v.(CheckerItem))
  95. }
  96. time.Sleep(time.Second * time.Duration(c.interval))
  97. }
  98. }()
  99. }
  100. func (c *Checker) isNeedCheck(item CheckerItem) bool {
  101. var minCount uint = 5
  102. if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount) ||
  103. (item.FailCount >= minCount && item.SuccessCount > item.FailCount) ||
  104. c.domainIsInMap(item.Host, false) ||
  105. c.domainIsInMap(item.Host, true) {
  106. return false
  107. }
  108. return true
  109. }
  110. func (c *Checker) IsBlocked(address string) (blocked bool, failN, successN uint) {
  111. if c.domainIsInMap(address, true) {
  112. //log.Printf("%s in blocked ? true", address)
  113. return true, 0, 0
  114. }
  115. if c.domainIsInMap(address, false) {
  116. //log.Printf("%s in direct ? true", address)
  117. return false, 0, 0
  118. }
  119. _item, ok := c.data.Get(address)
  120. if !ok {
  121. //log.Printf("%s not in map, blocked true", address)
  122. return true, 0, 0
  123. }
  124. item := _item.(CheckerItem)
  125. return item.FailCount >= item.SuccessCount, item.FailCount, item.SuccessCount
  126. }
  127. func (c *Checker) domainIsInMap(address string, blockedMap bool) bool {
  128. u, err := url.Parse("http://" + address)
  129. if err != nil {
  130. log.Printf("blocked check , url parse err:%s", err)
  131. return true
  132. }
  133. domainSlice := strings.Split(u.Hostname(), ".")
  134. if len(domainSlice) > 1 {
  135. subSlice := domainSlice[:len(domainSlice)-1]
  136. topDomain := strings.Join(domainSlice[len(domainSlice)-1:], ".")
  137. checkDomain := topDomain
  138. for i := len(subSlice) - 1; i >= 0; i-- {
  139. checkDomain = subSlice[i] + "." + checkDomain
  140. if !blockedMap && c.directMap.Has(checkDomain) {
  141. return true
  142. }
  143. if blockedMap && c.blockedMap.Has(checkDomain) {
  144. return true
  145. }
  146. }
  147. }
  148. return false
  149. }
  150. func (c *Checker) Add(address string, isHTTPS bool, method, URL string, data []byte) {
  151. if c.domainIsInMap(address, false) || c.domainIsInMap(address, true) {
  152. return
  153. }
  154. if !isHTTPS && strings.ToLower(method) != "get" {
  155. return
  156. }
  157. var item CheckerItem
  158. u := strings.Split(address, ":")
  159. item = CheckerItem{
  160. URL: URL,
  161. Domain: u[0],
  162. Host: address,
  163. Data: data,
  164. IsHTTPS: isHTTPS,
  165. Method: method,
  166. }
  167. c.data.SetIfAbsent(item.Host, item)
  168. }
  169. type BasicAuth struct {
  170. data ConcurrentMap
  171. }
  172. func NewBasicAuth() BasicAuth {
  173. return BasicAuth{
  174. data: NewConcurrentMap(),
  175. }
  176. }
  177. func (ba *BasicAuth) AddFromFile(file string) (n int, err error) {
  178. _content, err := ioutil.ReadFile(file)
  179. if err != nil {
  180. return
  181. }
  182. userpassArr := strings.Split(strings.Replace(string(_content), "\r", "", -1), "\n")
  183. for _, userpass := range userpassArr {
  184. if strings.HasPrefix("#", userpass) {
  185. continue
  186. }
  187. u := strings.Split(strings.Trim(userpass, " "), ":")
  188. if len(u) == 2 {
  189. ba.data.Set(u[0], u[1])
  190. n++
  191. }
  192. }
  193. return
  194. }
  195. func (ba *BasicAuth) Add(userpassArr []string) (n int) {
  196. for _, userpass := range userpassArr {
  197. u := strings.Split(userpass, ":")
  198. if len(u) == 2 {
  199. ba.data.Set(u[0], u[1])
  200. n++
  201. }
  202. }
  203. return
  204. }
  205. func (ba *BasicAuth) Check(userpass string) (ok bool) {
  206. u := strings.Split(strings.Trim(userpass, " "), ":")
  207. if len(u) == 2 {
  208. if p, _ok := ba.data.Get(u[0]); _ok {
  209. return p.(string) == u[1]
  210. }
  211. }
  212. return
  213. }
  214. func (ba *BasicAuth) Total() (n int) {
  215. n = ba.data.Count()
  216. return
  217. }
  218. type HTTPRequest struct {
  219. HeadBuf []byte
  220. conn *net.Conn
  221. Host string
  222. Method string
  223. URL string
  224. hostOrURL string
  225. isBasicAuth bool
  226. basicAuth *BasicAuth
  227. }
  228. func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth) (req HTTPRequest, err error) {
  229. buf := make([]byte, bufSize)
  230. len := 0
  231. req = HTTPRequest{
  232. conn: inConn,
  233. }
  234. len, err = (*inConn).Read(buf[:])
  235. if err != nil {
  236. if err != io.EOF {
  237. err = fmt.Errorf("http decoder read err:%s", err)
  238. }
  239. CloseConn(inConn)
  240. return
  241. }
  242. req.HeadBuf = buf[:len]
  243. index := bytes.IndexByte(req.HeadBuf, '\n')
  244. if index == -1 {
  245. err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50])
  246. CloseConn(inConn)
  247. return
  248. }
  249. fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
  250. if req.Method == "" || req.hostOrURL == "" {
  251. err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50])
  252. CloseConn(inConn)
  253. return
  254. }
  255. req.Method = strings.ToUpper(req.Method)
  256. req.isBasicAuth = isBasicAuth
  257. req.basicAuth = basicAuth
  258. log.Printf("%s:%s", req.Method, req.hostOrURL)
  259. if req.IsHTTPS() {
  260. err = req.HTTPS()
  261. } else {
  262. err = req.HTTP()
  263. }
  264. return
  265. }
  266. func (req *HTTPRequest) HTTP() (err error) {
  267. if req.isBasicAuth {
  268. err = req.BasicAuth()
  269. if err != nil {
  270. return
  271. }
  272. }
  273. req.URL, err = req.getHTTPURL()
  274. if err == nil {
  275. u, _ := url.Parse(req.URL)
  276. req.Host = u.Host
  277. req.addPortIfNot()
  278. }
  279. return
  280. }
  281. func (req *HTTPRequest) HTTPS() (err error) {
  282. req.Host = req.hostOrURL
  283. req.addPortIfNot()
  284. //_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
  285. return
  286. }
  287. func (req *HTTPRequest) HTTPSReply() (err error) {
  288. _, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
  289. return
  290. }
  291. func (req *HTTPRequest) IsHTTPS() bool {
  292. return req.Method == "CONNECT"
  293. }
  294. func (req *HTTPRequest) BasicAuth() (err error) {
  295. //log.Printf("request :%s", string(b[:n]))
  296. authorization, err := req.getHeader("Authorization")
  297. if err != nil {
  298. fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
  299. CloseConn(req.conn)
  300. return
  301. }
  302. //log.Printf("Authorization:%s", authorization)
  303. basic := strings.Fields(authorization)
  304. if len(basic) != 2 {
  305. err = fmt.Errorf("authorization data error,ERR:%s", authorization)
  306. CloseConn(req.conn)
  307. return
  308. }
  309. user, err := base64.StdEncoding.DecodeString(basic[1])
  310. if err != nil {
  311. err = fmt.Errorf("authorization data parse error,ERR:%s", err)
  312. CloseConn(req.conn)
  313. return
  314. }
  315. authOk := (*req.basicAuth).Check(string(user))
  316. //log.Printf("auth %s,%v", string(user), authOk)
  317. if !authOk {
  318. fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
  319. CloseConn(req.conn)
  320. err = fmt.Errorf("basic auth fail")
  321. return
  322. }
  323. return
  324. }
  325. func (req *HTTPRequest) getHTTPURL() (URL string, err error) {
  326. if !strings.HasPrefix(req.hostOrURL, "/") {
  327. return req.hostOrURL, nil
  328. }
  329. _host, err := req.getHeader("host")
  330. if err != nil {
  331. return
  332. }
  333. URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL)
  334. return
  335. }
  336. func (req *HTTPRequest) getHeader(key string) (val string, err error) {
  337. key = strings.ToUpper(key)
  338. lines := strings.Split(string(req.HeadBuf), "\r\n")
  339. for _, line := range lines {
  340. line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2)
  341. if len(line) == 2 {
  342. k := strings.ToUpper(strings.Trim(line[0], " "))
  343. v := strings.Trim(line[1], " ")
  344. if key == k {
  345. val = v
  346. return
  347. }
  348. }
  349. }
  350. err = fmt.Errorf("can not find HOST header")
  351. return
  352. }
  353. func (req *HTTPRequest) addPortIfNot() (newHost string) {
  354. //newHost = req.Host
  355. port := "80"
  356. if req.IsHTTPS() {
  357. port = "443"
  358. }
  359. if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) {
  360. //newHost = req.Host + ":" + port
  361. //req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1))
  362. req.Host = req.Host + ":" + port
  363. }
  364. return
  365. }
  366. type OutPool struct {
  367. Pool ConnPool
  368. dur int
  369. isTLS bool
  370. certBytes []byte
  371. keyBytes []byte
  372. address string
  373. timeout int
  374. }
  375. func NewOutPool(dur int, isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) {
  376. op = OutPool{
  377. dur: dur,
  378. isTLS: isTLS,
  379. certBytes: certBytes,
  380. keyBytes: keyBytes,
  381. address: address,
  382. timeout: timeout,
  383. }
  384. var err error
  385. op.Pool, err = NewConnPool(poolConfig{
  386. IsActive: func(conn interface{}) bool { return true },
  387. Release: func(conn interface{}) {
  388. if conn != nil {
  389. conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
  390. conn.(net.Conn).Close()
  391. // log.Println("conn released")
  392. }
  393. },
  394. InitialCap: InitialCap,
  395. MaxCap: MaxCap,
  396. Factory: func() (conn interface{}, err error) {
  397. conn, err = op.getConn()
  398. return
  399. },
  400. })
  401. if err != nil {
  402. log.Fatalf("init conn pool fail ,%s", err)
  403. } else {
  404. if InitialCap > 0 {
  405. log.Printf("init conn pool success")
  406. op.initPoolDeamon()
  407. } else {
  408. log.Printf("conn pool closed")
  409. }
  410. }
  411. return
  412. }
  413. func (op *OutPool) getConn() (conn interface{}, err error) {
  414. if op.isTLS {
  415. var _conn tls.Conn
  416. _conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes)
  417. if err == nil {
  418. conn = net.Conn(&_conn)
  419. }
  420. } else {
  421. conn, err = ConnectHost(op.address, op.timeout)
  422. }
  423. return
  424. }
  425. func (op *OutPool) initPoolDeamon() {
  426. go func() {
  427. if op.dur <= 0 {
  428. return
  429. }
  430. log.Printf("pool deamon started")
  431. for {
  432. time.Sleep(time.Second * time.Duration(op.dur))
  433. conn, err := op.getConn()
  434. if err != nil {
  435. log.Printf("pool deamon err %s , release pool", err)
  436. op.Pool.ReleaseAll()
  437. } else {
  438. conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
  439. conn.(net.Conn).Close()
  440. }
  441. }
  442. }()
  443. }