peripconn.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package fasthttp
  2. import (
  3. "fmt"
  4. "net"
  5. "sync"
  6. )
  // perIPConnCounter counts currently open connections per client IPv4
  // address, with the address packed into a uint32. It also owns the
  // sync.Pool that recycles perIPConn wrappers.
  //
  // The zero value is ready to use: Register allocates the map lazily.
  type perIPConnCounter struct {
  	pool sync.Pool
  	lock sync.Mutex
  	m    map[uint32]int // guarded by lock; only IPs with a non-zero count are stored
  }

  // Register increments the open-connection count for ip and returns the
  // new count. Access to the map is serialized by cc.lock.
  func (cc *perIPConnCounter) Register(ip uint32) int {
  	cc.lock.Lock()
  	defer cc.lock.Unlock()
  	if cc.m == nil {
  		cc.m = make(map[uint32]int)
  	}
  	n := cc.m[ip] + 1
  	cc.m[ip] = n
  	return n
  }

  // Unregister decrements the open-connection count for ip.
  //
  // When the count drops to zero the entry is deleted. Otherwise the map
  // would keep one entry for every client IP ever seen, and its memory
  // would grow without bound on a long-running server.
  //
  // It panics if Register was never called on cc, or if ip has no
  // outstanding registration. Either case is a Register/Unregister
  // imbalance, which is a programming error.
  func (cc *perIPConnCounter) Unregister(ip uint32) {
  	cc.lock.Lock()
  	// The deferred unlock also runs while a panic below unwinds, so the
  	// mutex is never left held.
  	defer cc.lock.Unlock()
  	if cc.m == nil {
  		panic("BUG: perIPConnCounter.Register() wasn't called")
  	}
  	n := cc.m[ip] - 1
  	if n < 0 {
  		panic(fmt.Sprintf("BUG: negative per-ip counter=%d for ip=%d", n, ip))
  	}
  	if n == 0 {
  		// A missing key reads as 0 in Register, so deleting is equivalent
  		// to storing 0 but does not keep the entry around.
  		delete(cc.m, ip)
  		return
  	}
  	cc.m[ip] = n
  }
  36. type perIPConn struct {
  37. net.Conn
  38. ip uint32
  39. perIPConnCounter *perIPConnCounter
  40. }
  41. func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn {
  42. v := counter.pool.Get()
  43. if v == nil {
  44. return &perIPConn{
  45. perIPConnCounter: counter,
  46. Conn: conn,
  47. ip: ip,
  48. }
  49. }
  50. c := v.(*perIPConn)
  51. c.Conn = conn
  52. c.ip = ip
  53. return c
  54. }
  55. func releasePerIPConn(c *perIPConn) {
  56. c.Conn = nil
  57. c.perIPConnCounter.pool.Put(c)
  58. }
  59. func (c *perIPConn) Close() error {
  60. err := c.Conn.Close()
  61. c.perIPConnCounter.Unregister(c.ip)
  62. releasePerIPConn(c)
  63. return err
  64. }
  65. func getUint32IP(c net.Conn) uint32 {
  66. return ip2uint32(getConnIP4(c))
  67. }
  // getConnIP4 returns the remote IP of c in 4-byte IPv4 form.
  //
  // For a *net.TCPAddr remote address it returns IP.To4(), which is nil
  // for IPv6 clients. For any other (or nil) remote address it returns
  // net.IPv4zero.
  func getConnIP4(c net.Conn) net.IP {
  	if tcpAddr, isTCP := c.RemoteAddr().(*net.TCPAddr); isTCP {
  		return tcpAddr.IP.To4()
  	}
  	return net.IPv4zero
  }
  // ip2uint32 packs an IPv4 address into a uint32 in big-endian (network)
  // order, so 1.2.3.4 becomes 0x01020304.
  //
  // It accepts both the 4-byte form and the 16-byte IPv4-mapped form
  // produced by net.IPv4 and net.ParseIP. Previously the 16-byte form
  // silently yielded 0. Any other input (nil, IPv6 addresses, or slices of
  // unexpected length) yields 0, which is also the encoding of 0.0.0.0.
  func ip2uint32(ip net.IP) uint32 {
  	ip4 := ip.To4()
  	if ip4 == nil {
  		return 0
  	}
  	return uint32(ip4[0])<<24 | uint32(ip4[1])<<16 | uint32(ip4[2])<<8 | uint32(ip4[3])
  }
  // uint322ip unpacks a big-endian uint32 into a newly allocated 4-byte
  // net.IP. It reverses ip2uint32 for IPv4 addresses.
  func uint322ip(ip uint32) net.IP {
  	return net.IP{
  		byte(ip >> 24),
  		byte(ip >> 16),
  		byte(ip >> 8),
  		byte(ip),
  	}
  }