sketch.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. /*
  2. * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
  3. * SPDX-License-Identifier: Apache-2.0
  4. */
  5. package ristretto
  6. import (
  7. "fmt"
  8. "math/rand"
  9. "time"
  10. )
  11. // cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily
  12. // based on Damian Gryski's CM4 [1].
  13. //
  14. // [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go
  15. type cmSketch struct {
  16. rows [cmDepth]cmRow
  17. seed [cmDepth]uint64
  18. mask uint64
  19. }
  20. const (
  21. // cmDepth is the number of counter copies to store (think of it as rows).
  22. cmDepth = 4
  23. )
  24. func newCmSketch(numCounters int64) *cmSketch {
  25. if numCounters == 0 {
  26. panic("cmSketch: bad numCounters")
  27. }
  28. // Get the next power of 2 for better cache performance.
  29. numCounters = next2Power(numCounters)
  30. sketch := &cmSketch{mask: uint64(numCounters - 1)}
  31. // Initialize rows of counters and seeds.
  32. // Cryptographic precision not needed
  33. source := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec
  34. for i := 0; i < cmDepth; i++ {
  35. sketch.seed[i] = source.Uint64()
  36. sketch.rows[i] = newCmRow(numCounters)
  37. }
  38. return sketch
  39. }
  40. // Increment increments the count(ers) for the specified key.
  41. func (s *cmSketch) Increment(hashed uint64) {
  42. for i := range s.rows {
  43. s.rows[i].increment((hashed ^ s.seed[i]) & s.mask)
  44. }
  45. }
  46. // Estimate returns the value of the specified key.
  47. func (s *cmSketch) Estimate(hashed uint64) int64 {
  48. min := byte(255)
  49. for i := range s.rows {
  50. val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask)
  51. if val < min {
  52. min = val
  53. }
  54. }
  55. return int64(min)
  56. }
  57. // Reset halves all counter values.
  58. func (s *cmSketch) Reset() {
  59. for _, r := range s.rows {
  60. r.reset()
  61. }
  62. }
  63. // Clear zeroes all counters.
  64. func (s *cmSketch) Clear() {
  65. for _, r := range s.rows {
  66. r.clear()
  67. }
  68. }
  69. // cmRow is a row of bytes, with each byte holding two counters.
  70. type cmRow []byte
  71. func newCmRow(numCounters int64) cmRow {
  72. return make(cmRow, numCounters/2)
  73. }
  74. func (r cmRow) get(n uint64) byte {
  75. return (r[n/2] >> ((n & 1) * 4)) & 0x0f
  76. }
  77. func (r cmRow) increment(n uint64) {
  78. // Index of the counter.
  79. i := n / 2
  80. // Shift distance (even 0, odd 4).
  81. s := (n & 1) * 4
  82. // Counter value.
  83. v := (r[i] >> s) & 0x0f
  84. // Only increment if not max value (overflow wrap is bad for LFU).
  85. if v < 15 {
  86. r[i] += 1 << s
  87. }
  88. }
  89. func (r cmRow) reset() {
  90. // Halve each counter.
  91. for i := range r {
  92. r[i] = (r[i] >> 1) & 0x77
  93. }
  94. }
  95. func (r cmRow) clear() {
  96. // Zero each counter.
  97. for i := range r {
  98. r[i] = 0
  99. }
  100. }
  101. func (r cmRow) string() string {
  102. s := ""
  103. for i := uint64(0); i < uint64(len(r)*2); i++ {
  104. s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f)
  105. }
  106. s = s[:len(s)-1]
  107. return s
  108. }
  109. // next2Power rounds x up to the next power of 2, if it's not already one.
  110. func next2Power(x int64) int64 {
  111. x--
  112. x |= x >> 1
  113. x |= x >> 2
  114. x |= x >> 4
  115. x |= x >> 8
  116. x |= x >> 16
  117. x |= x >> 32
  118. x++
  119. return x
  120. }