peripconn.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package fasthttp
  2. import (
  3. "crypto/tls"
  4. "net"
  5. "sync"
  6. )
  // perIPConnCounter counts currently open connections per client IP, keyed by
  // the uint32 form of the address (presumably produced by getUint32IP — verify
  // against callers). It also owns the sync.Pools from which acquirePerIPConn
  // takes, and into which the wrappers' Close methods return, the
  // perIPConn / perIPTLSConn wrappers.
  //
  // The zero value is ready to use: m is allocated lazily by Register.
  type perIPConnCounter struct {
  	perIPConnPool    sync.Pool      // recycled *perIPConn wrappers
  	perIPTLSConnPool sync.Pool      // recycled *perIPTLSConn wrappers
  	m                map[uint32]int // open-connection count per IP; guarded by lock
  	lock             sync.Mutex
  }

  // Register records one more open connection from ip and returns how many
  // connections from ip are open now, including this one.
  func (cc *perIPConnCounter) Register(ip uint32) int {
  	cc.lock.Lock()
  	defer cc.lock.Unlock()
  	if cc.m == nil {
  		cc.m = make(map[uint32]int)
  	}
  	n := cc.m[ip] + 1
  	cc.m[ip] = n
  	return n
  }

  // Unregister records that one connection from ip has been closed.
  //
  // The entry for ip is deleted as soon as its count drops to zero, so m holds
  // only IPs that currently have open connections rather than every IP ever
  // seen. Surplus calls for an IP with no open connections are ignored: the
  // count never goes negative. It panics if Register has never been called on
  // cc, because that indicates a programming error in the caller.
  func (cc *perIPConnCounter) Unregister(ip uint32) {
  	cc.lock.Lock()
  	defer cc.lock.Unlock()
  	if cc.m == nil {
  		// developer safeguard
  		panic("BUG: perIPConnCounter.Register() wasn't called")
  	}
  	n := cc.m[ip] - 1
  	if n <= 0 {
  		// Register treats a missing key as 0, so deleting is equivalent to
  		// storing 0. Unlike storing 0, it keeps the map from growing without bound.
  		delete(cc.m, ip)
  		return
  	}
  	cc.m[ip] = n
  }
  36. type perIPConn struct {
  37. net.Conn
  38. perIPConnCounter *perIPConnCounter
  39. ip uint32
  40. }
  41. type perIPTLSConn struct {
  42. *tls.Conn
  43. perIPConnCounter *perIPConnCounter
  44. ip uint32
  45. }
  46. func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
  47. if tlsConn, ok := conn.(*tls.Conn); ok {
  48. v := counter.perIPTLSConnPool.Get()
  49. if v == nil {
  50. return &perIPTLSConn{
  51. perIPConnCounter: counter,
  52. Conn: tlsConn,
  53. ip: ip,
  54. }
  55. }
  56. c := v.(*perIPTLSConn)
  57. c.Conn = tlsConn
  58. c.ip = ip
  59. return c
  60. }
  61. v := counter.perIPConnPool.Get()
  62. if v == nil {
  63. return &perIPConn{
  64. perIPConnCounter: counter,
  65. Conn: conn,
  66. ip: ip,
  67. }
  68. }
  69. c := v.(*perIPConn)
  70. c.Conn = conn
  71. c.ip = ip
  72. return c
  73. }
  74. func (c *perIPConn) Close() error {
  75. err := c.Conn.Close()
  76. c.perIPConnCounter.Unregister(c.ip)
  77. c.Conn = nil
  78. c.perIPConnCounter.perIPConnPool.Put(c)
  79. return err
  80. }
  81. func (c *perIPTLSConn) Close() error {
  82. err := c.Conn.Close()
  83. c.perIPConnCounter.Unregister(c.ip)
  84. c.Conn = nil
  85. c.perIPConnCounter.perIPTLSConnPool.Put(c)
  86. return err
  87. }
  88. func getUint32IP(c net.Conn) uint32 {
  89. return ip2uint32(getConnIP4(c))
  90. }
  // getConnIP4 returns the IPv4 address of c's remote peer in its 4-byte form.
  //
  // It returns a freshly allocated 4-byte 0.0.0.0 in two cases: when the remote
  // address is not a *net.TCPAddr (including a nil RemoteAddr), and when it is
  // a nil *net.TCPAddr. It returns nil when the peer's TCP address has no IPv4
  // form (e.g. a native IPv6 client), because net.IP.To4 yields nil for such
  // addresses.
  func getConnIP4(c net.Conn) net.IP {
  	ipAddr, ok := c.RemoteAddr().(*net.TCPAddr)
  	if !ok || ipAddr == nil {
  		// Return the 4-byte form so every non-nil result has the same length as
  		// the To4 path. net.IPv4zero is the 16-byte form and a shared package
  		// global, so it is not returned directly.
  		return net.IP{0, 0, 0, 0}
  	}
  	return ipAddr.IP.To4()
  }
  // ip2uint32 packs an IPv4 address into a uint32, most significant byte first
  // (1.2.3.4 becomes 0x01020304).
  //
  // It accepts both the 4-byte form and the 16-byte IPv4-mapped form, which is
  // what net.IPv4 and net.ParseIP return. It returns 0 for nil, IPv6 and
  // malformed addresses.
  func ip2uint32(ip net.IP) uint32 {
  	ip4 := ip.To4()
  	if ip4 == nil {
  		return 0
  	}
  	return uint32(ip4[0])<<24 | uint32(ip4[1])<<16 | uint32(ip4[2])<<8 | uint32(ip4[3])
  }
  // uint322ip is the inverse of ip2uint32 for 4-byte addresses. It unpacks a
  // uint32, most significant byte first, into a newly allocated 4-byte net.IP.
  func uint322ip(ip uint32) net.IP {
  	return net.IP{
  		byte(ip >> 24),
  		byte(ip >> 16),
  		byte(ip >> 8),
  		byte(ip),
  	}
  }