store.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. /*
  2. * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
  3. * SPDX-License-Identifier: Apache-2.0
  4. */
  5. package ristretto
  6. import (
  7. "sync"
  8. "time"
  9. )
  10. type updateFn[V any] func(cur, prev V) bool
  11. // TODO: Do we need this to be a separate struct from Item?
  12. type storeItem[V any] struct {
  13. key uint64
  14. conflict uint64
  15. value V
  16. expiration time.Time
  17. }
  18. // store is the interface fulfilled by all hash map implementations in this
  19. // file. Some hash map implementations are better suited for certain data
  20. // distributions than others, so this allows us to abstract that out for use
  21. // in Ristretto.
  22. //
  23. // Every store is safe for concurrent usage.
  24. type store[V any] interface {
  25. // Get returns the value associated with the key parameter.
  26. Get(uint64, uint64) (V, bool)
  27. // Expiration returns the expiration time for this key.
  28. Expiration(uint64) time.Time
  29. // Set adds the key-value pair to the Map or updates the value if it's
  30. // already present. The key-value pair is passed as a pointer to an
  31. // item object.
  32. Set(*Item[V])
  33. // Del deletes the key-value pair from the Map.
  34. Del(uint64, uint64) (uint64, V)
  35. // Update attempts to update the key with a new value and returns true if
  36. // successful.
  37. Update(*Item[V]) (V, bool)
  38. // Cleanup removes items that have an expired TTL.
  39. Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V]))
  40. // Clear clears all contents of the store.
  41. Clear(onEvict func(item *Item[V]))
  42. SetShouldUpdateFn(f updateFn[V])
  43. }
  44. // newStore returns the default store implementation.
  45. func newStore[V any]() store[V] {
  46. return newShardedMap[V]()
  47. }
  48. const numShards uint64 = 256
  49. type shardedMap[V any] struct {
  50. shards []*lockedMap[V]
  51. expiryMap *expirationMap[V]
  52. }
  53. func newShardedMap[V any]() *shardedMap[V] {
  54. sm := &shardedMap[V]{
  55. shards: make([]*lockedMap[V], int(numShards)),
  56. expiryMap: newExpirationMap[V](),
  57. }
  58. for i := range sm.shards {
  59. sm.shards[i] = newLockedMap[V](sm.expiryMap)
  60. }
  61. return sm
  62. }
  63. func (m *shardedMap[V]) SetShouldUpdateFn(f updateFn[V]) {
  64. for i := range m.shards {
  65. m.shards[i].setShouldUpdateFn(f)
  66. }
  67. }
  68. func (sm *shardedMap[V]) Get(key, conflict uint64) (V, bool) {
  69. return sm.shards[key%numShards].get(key, conflict)
  70. }
  71. func (sm *shardedMap[V]) Expiration(key uint64) time.Time {
  72. return sm.shards[key%numShards].Expiration(key)
  73. }
  74. func (sm *shardedMap[V]) Set(i *Item[V]) {
  75. if i == nil {
  76. // If item is nil make this Set a no-op.
  77. return
  78. }
  79. sm.shards[i.Key%numShards].Set(i)
  80. }
  81. func (sm *shardedMap[V]) Del(key, conflict uint64) (uint64, V) {
  82. return sm.shards[key%numShards].Del(key, conflict)
  83. }
  84. func (sm *shardedMap[V]) Update(newItem *Item[V]) (V, bool) {
  85. return sm.shards[newItem.Key%numShards].Update(newItem)
  86. }
  87. func (sm *shardedMap[V]) Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V])) {
  88. sm.expiryMap.cleanup(sm, policy, onEvict)
  89. }
  90. func (sm *shardedMap[V]) Clear(onEvict func(item *Item[V])) {
  91. for i := uint64(0); i < numShards; i++ {
  92. sm.shards[i].Clear(onEvict)
  93. }
  94. sm.expiryMap.clear()
  95. }
  96. type lockedMap[V any] struct {
  97. sync.RWMutex
  98. data map[uint64]storeItem[V]
  99. em *expirationMap[V]
  100. shouldUpdate updateFn[V]
  101. }
  102. func newLockedMap[V any](em *expirationMap[V]) *lockedMap[V] {
  103. return &lockedMap[V]{
  104. data: make(map[uint64]storeItem[V]),
  105. em: em,
  106. shouldUpdate: func(cur, prev V) bool {
  107. return true
  108. },
  109. }
  110. }
  111. func (m *lockedMap[V]) setShouldUpdateFn(f updateFn[V]) {
  112. m.shouldUpdate = f
  113. }
  114. func (m *lockedMap[V]) get(key, conflict uint64) (V, bool) {
  115. m.RLock()
  116. item, ok := m.data[key]
  117. m.RUnlock()
  118. if !ok {
  119. return zeroValue[V](), false
  120. }
  121. if conflict != 0 && (conflict != item.conflict) {
  122. return zeroValue[V](), false
  123. }
  124. // Handle expired items.
  125. if !item.expiration.IsZero() && time.Now().After(item.expiration) {
  126. return zeroValue[V](), false
  127. }
  128. return item.value, true
  129. }
  130. func (m *lockedMap[V]) Expiration(key uint64) time.Time {
  131. m.RLock()
  132. defer m.RUnlock()
  133. return m.data[key].expiration
  134. }
  135. func (m *lockedMap[V]) Set(i *Item[V]) {
  136. if i == nil {
  137. // If the item is nil make this Set a no-op.
  138. return
  139. }
  140. m.Lock()
  141. defer m.Unlock()
  142. item, ok := m.data[i.Key]
  143. if ok {
  144. // The item existed already. We need to check the conflict key and reject the
  145. // update if they do not match. Only after that the expiration map is updated.
  146. if i.Conflict != 0 && (i.Conflict != item.conflict) {
  147. return
  148. }
  149. if m.shouldUpdate != nil && !m.shouldUpdate(i.Value, item.value) {
  150. return
  151. }
  152. m.em.update(i.Key, i.Conflict, item.expiration, i.Expiration)
  153. } else {
  154. // The value is not in the map already. There's no need to return anything.
  155. // Simply add the expiration map.
  156. m.em.add(i.Key, i.Conflict, i.Expiration)
  157. }
  158. m.data[i.Key] = storeItem[V]{
  159. key: i.Key,
  160. conflict: i.Conflict,
  161. value: i.Value,
  162. expiration: i.Expiration,
  163. }
  164. }
  165. func (m *lockedMap[V]) Del(key, conflict uint64) (uint64, V) {
  166. m.Lock()
  167. defer m.Unlock()
  168. item, ok := m.data[key]
  169. if !ok {
  170. return 0, zeroValue[V]()
  171. }
  172. if conflict != 0 && (conflict != item.conflict) {
  173. return 0, zeroValue[V]()
  174. }
  175. if !item.expiration.IsZero() {
  176. m.em.del(key, item.expiration)
  177. }
  178. delete(m.data, key)
  179. return item.conflict, item.value
  180. }
  181. func (m *lockedMap[V]) Update(newItem *Item[V]) (V, bool) {
  182. m.Lock()
  183. defer m.Unlock()
  184. item, ok := m.data[newItem.Key]
  185. if !ok {
  186. return zeroValue[V](), false
  187. }
  188. if newItem.Conflict != 0 && (newItem.Conflict != item.conflict) {
  189. return zeroValue[V](), false
  190. }
  191. if m.shouldUpdate != nil && !m.shouldUpdate(newItem.Value, item.value) {
  192. return item.value, false
  193. }
  194. m.em.update(newItem.Key, newItem.Conflict, item.expiration, newItem.Expiration)
  195. m.data[newItem.Key] = storeItem[V]{
  196. key: newItem.Key,
  197. conflict: newItem.Conflict,
  198. value: newItem.Value,
  199. expiration: newItem.Expiration,
  200. }
  201. return item.value, true
  202. }
  203. func (m *lockedMap[V]) Clear(onEvict func(item *Item[V])) {
  204. m.Lock()
  205. defer m.Unlock()
  206. i := &Item[V]{}
  207. if onEvict != nil {
  208. for _, si := range m.data {
  209. i.Key = si.key
  210. i.Conflict = si.conflict
  211. i.Value = si.value
  212. onEvict(i)
  213. }
  214. }
  215. m.data = make(map[uint64]storeItem[V])
  216. }