peripconn.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package fasthttp
  2. import (
  3. "crypto/tls"
  4. "net"
  5. "sync"
  6. )
  // perIPConnCounter counts the currently open connections per client IP and
// owns the pools used to recycle the per-IP connection wrappers.
//
// IPs are keyed as uint32 values (presumably the value produced by
// getUint32IP — confirm at call sites). The zero value is usable: Register
// allocates the map lazily.
type perIPConnCounter struct {
	// Recycled *perIPConn and *perIPTLSConn wrappers (see acquirePerIPConn).
	perIPConnPool, perIPTLSConnPool sync.Pool

	// lock guards m.
	lock sync.Mutex
	// m maps an IP to the number of connections registered and not yet
	// unregistered for it.
	m map[uint32]int
}
  13. func (cc *perIPConnCounter) Register(ip uint32) int {
  14. cc.lock.Lock()
  15. if cc.m == nil {
  16. cc.m = make(map[uint32]int)
  17. }
  18. n := cc.m[ip] + 1
  19. cc.m[ip] = n
  20. cc.lock.Unlock()
  21. return n
  22. }
  23. func (cc *perIPConnCounter) Unregister(ip uint32) {
  24. cc.lock.Lock()
  25. defer cc.lock.Unlock()
  26. if cc.m == nil {
  27. // developer safeguard
  28. panic("BUG: perIPConnCounter.Register() wasn't called")
  29. }
  30. n := cc.m[ip] - 1
  31. if n < 0 {
  32. n = 0
  33. }
  34. cc.m[ip] = n
  35. }
  36. type perIPConn struct {
  37. net.Conn
  38. perIPConnCounter *perIPConnCounter
  39. ip uint32
  40. lock sync.Mutex
  41. }
  42. type perIPTLSConn struct {
  43. *tls.Conn
  44. perIPConnCounter *perIPConnCounter
  45. ip uint32
  46. lock sync.Mutex
  47. }
  48. func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
  49. if tlsConn, ok := conn.(*tls.Conn); ok {
  50. v := counter.perIPTLSConnPool.Get()
  51. if v == nil {
  52. return &perIPTLSConn{
  53. perIPConnCounter: counter,
  54. Conn: tlsConn,
  55. ip: ip,
  56. }
  57. }
  58. c := v.(*perIPTLSConn)
  59. c.Conn = tlsConn
  60. c.ip = ip
  61. return c
  62. }
  63. v := counter.perIPConnPool.Get()
  64. if v == nil {
  65. return &perIPConn{
  66. perIPConnCounter: counter,
  67. Conn: conn,
  68. ip: ip,
  69. }
  70. }
  71. c := v.(*perIPConn)
  72. c.Conn = conn
  73. c.ip = ip
  74. return c
  75. }
  76. func (c *perIPConn) Close() error {
  77. c.lock.Lock()
  78. cc := c.Conn
  79. c.Conn = nil
  80. c.lock.Unlock()
  81. if cc == nil {
  82. return nil
  83. }
  84. err := cc.Close()
  85. c.perIPConnCounter.Unregister(c.ip)
  86. c.perIPConnCounter.perIPConnPool.Put(c)
  87. return err
  88. }
  89. func (c *perIPTLSConn) Close() error {
  90. c.lock.Lock()
  91. cc := c.Conn
  92. c.Conn = nil
  93. c.lock.Unlock()
  94. if cc == nil {
  95. return nil
  96. }
  97. err := cc.Close()
  98. c.perIPConnCounter.Unregister(c.ip)
  99. c.perIPConnCounter.perIPTLSConnPool.Put(c)
  100. return err
  101. }
  102. func getUint32IP(c net.Conn) uint32 {
  103. return ip2uint32(getConnIP4(c))
  104. }
  // getConnIP4 returns the IPv4 form of c's remote address.
//
// For a *net.TCPAddr it returns IP.To4(), which is a 4-byte slice for IPv4
// peers and nil otherwise (e.g. IPv6). For any other address type it
// returns net.IPv4zero, which is in 16-byte form.
func getConnIP4(c net.Conn) net.IP {
	if tcpAddr, isTCP := c.RemoteAddr().(*net.TCPAddr); isTCP {
		return tcpAddr.IP.To4()
	}
	return net.IPv4zero
}
  // ip2uint32 packs an IPv4 address into a uint32 in big-endian (network)
// byte order, e.g. 1.2.3.4 -> 0x01020304.
//
// Both the 4-byte form and the 16-byte IPv4-mapped form (what net.ParseIP
// and net.IPv4 return) are accepted. Any other input — IPv6 addresses, nil,
// or slices of an unexpected length — yields 0.
func ip2uint32(ip net.IP) uint32 {
	// To4 normalises a 16-byte IPv4-mapped address to 4 bytes. Checking
	// len(ip) == 4 alone would map e.g. net.ParseIP("1.2.3.4") to 0.
	ip4 := ip.To4()
	if ip4 == nil {
		return 0
	}
	return uint32(ip4[0])<<24 | uint32(ip4[1])<<16 | uint32(ip4[2])<<8 | uint32(ip4[3])
}
  // uint322ip unpacks a big-endian uint32 into a newly allocated 4-byte
// net.IP. It is the inverse of ip2uint32 for IPv4 addresses.
func uint322ip(ip uint32) net.IP {
	return net.IP{
		byte(ip >> 24),
		byte(ip >> 16),
		byte(ip >> 8),
		byte(ip),
	}
}