trie.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. /*
  2. * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
  3. * SPDX-License-Identifier: Apache-2.0
  4. */
  5. package trie
  6. import (
  7. "fmt"
  8. "strconv"
  9. "strings"
  10. "github.com/pkg/errors"
  11. "github.com/dgraph-io/badger/v4/pb"
  12. "github.com/dgraph-io/badger/v4/y"
  13. )
  14. type node struct {
  15. children map[byte]*node
  16. ignore *node
  17. ids []uint64
  18. }
  19. func (n *node) isEmpty() bool {
  20. return len(n.children) == 0 && len(n.ids) == 0 && n.ignore == nil
  21. }
  22. func newNode() *node {
  23. return &node{
  24. children: make(map[byte]*node),
  25. ids: []uint64{},
  26. }
  27. }
  28. // Trie datastructure.
  29. type Trie struct {
  30. root *node
  31. }
  32. // NewTrie returns Trie.
  33. func NewTrie() *Trie {
  34. return &Trie{
  35. root: newNode(),
  36. }
  37. }
  38. // parseIgnoreBytes would parse the ignore string, and convert it into a list of bools, where
  39. // bool[idx] = true implies that key[idx] can be ignored during comparison.
  40. func parseIgnoreBytes(ig string) ([]bool, error) {
  41. var out []bool
  42. if ig == "" {
  43. return out, nil
  44. }
  45. for _, each := range strings.Split(strings.TrimSpace(ig), ",") {
  46. r := strings.Split(strings.TrimSpace(each), "-")
  47. if len(r) == 0 || len(r) > 2 {
  48. return out, fmt.Errorf("Invalid range: %s", each)
  49. }
  50. start, end := -1, -1 //nolint:ineffassign
  51. if len(r) == 2 {
  52. idx, err := strconv.Atoi(strings.TrimSpace(r[1]))
  53. if err != nil {
  54. return out, err
  55. }
  56. end = idx
  57. }
  58. {
  59. // Always consider r[0]
  60. idx, err := strconv.Atoi(strings.TrimSpace(r[0]))
  61. if err != nil {
  62. return out, err
  63. }
  64. start = idx
  65. }
  66. if start == -1 {
  67. return out, fmt.Errorf("Invalid range: %s", each)
  68. }
  69. for start >= len(out) {
  70. out = append(out, false)
  71. }
  72. for end >= len(out) { // end could be -1, so do have the start loop above.
  73. out = append(out, false)
  74. }
  75. if end == -1 {
  76. out[start] = true
  77. } else {
  78. for i := start; i <= end; i++ {
  79. out[i] = true
  80. }
  81. }
  82. }
  83. return out, nil
  84. }
  85. // Add adds the id in the trie for the given prefix path.
  86. func (t *Trie) Add(prefix []byte, id uint64) {
  87. m := pb.Match{
  88. Prefix: prefix,
  89. }
  90. y.Check(t.AddMatch(m, id))
  91. }
  92. // AddMatch allows you to send in a prefix match, with "holes" in the prefix. The holes are
  93. // specified via IgnoreBytes in a comma-separated list of indices starting from 0. A dash can be
  94. // used to denote a range. Valid example is "3, 5-8, 10, 12-15". Length of IgnoreBytes does not need
  95. // to match the length of the Prefix passed.
  96. //
  97. // Consider a prefix = "aaaa". If the IgnoreBytes is set to "0, 2", then along with key "aaaa...",
  98. // a key "baba..." would also match.
  99. func (t *Trie) AddMatch(m pb.Match, id uint64) error {
  100. return t.fix(m, id, set)
  101. }
  102. const (
  103. set = iota
  104. del
  105. )
  106. func (t *Trie) fix(m pb.Match, id uint64, op int) error {
  107. curNode := t.root
  108. ignore, err := parseIgnoreBytes(m.IgnoreBytes)
  109. if err != nil {
  110. return errors.Wrapf(err, "while parsing ignore bytes: %s", m.IgnoreBytes)
  111. }
  112. for len(ignore) < len(m.Prefix) {
  113. ignore = append(ignore, false)
  114. }
  115. for idx, byt := range m.Prefix {
  116. var child *node
  117. if ignore[idx] {
  118. child = curNode.ignore
  119. if child == nil {
  120. if op == del {
  121. // No valid node found for delete operation. Return immediately.
  122. return nil
  123. }
  124. child = newNode()
  125. curNode.ignore = child
  126. }
  127. } else {
  128. child = curNode.children[byt]
  129. if child == nil {
  130. if op == del {
  131. // No valid node found for delete operation. Return immediately.
  132. return nil
  133. }
  134. child = newNode()
  135. curNode.children[byt] = child
  136. }
  137. }
  138. curNode = child
  139. }
  140. // We only need to add the id to the last node of the given prefix.
  141. if op == set {
  142. curNode.ids = append(curNode.ids, id)
  143. } else if op == del {
  144. out := curNode.ids[:0]
  145. for _, cid := range curNode.ids {
  146. if id != cid {
  147. out = append(out, cid)
  148. }
  149. }
  150. curNode.ids = out
  151. } else {
  152. y.AssertTrue(false)
  153. }
  154. return nil
  155. }
  156. func (t *Trie) Get(key []byte) map[uint64]struct{} {
  157. return t.get(t.root, key)
  158. }
  159. // Get returns prefix matched ids for the given key.
  160. func (t *Trie) get(curNode *node, key []byte) map[uint64]struct{} {
  161. y.AssertTrue(curNode != nil)
  162. out := make(map[uint64]struct{})
  163. // If any node in the path of the key has ids, pick them up.
  164. // This would also match nil prefixes.
  165. for _, i := range curNode.ids {
  166. out[i] = struct{}{}
  167. }
  168. if len(key) == 0 {
  169. return out
  170. }
  171. // If we found an ignore node, traverse that path.
  172. if curNode.ignore != nil {
  173. res := t.get(curNode.ignore, key[1:])
  174. for id := range res {
  175. out[id] = struct{}{}
  176. }
  177. }
  178. if child := curNode.children[key[0]]; child != nil {
  179. res := t.get(child, key[1:])
  180. for id := range res {
  181. out[id] = struct{}{}
  182. }
  183. }
  184. return out
  185. }
  186. func removeEmpty(curNode *node) bool {
  187. // Go depth first.
  188. if curNode.ignore != nil {
  189. if empty := removeEmpty(curNode.ignore); empty {
  190. curNode.ignore = nil
  191. }
  192. }
  193. for byt, n := range curNode.children {
  194. if empty := removeEmpty(n); empty {
  195. delete(curNode.children, byt)
  196. }
  197. }
  198. return curNode.isEmpty()
  199. }
  200. // Delete will delete the id if the id exist in the given index path.
  201. func (t *Trie) Delete(prefix []byte, id uint64) error {
  202. return t.DeleteMatch(pb.Match{Prefix: prefix}, id)
  203. }
  204. func (t *Trie) DeleteMatch(m pb.Match, id uint64) error {
  205. if err := t.fix(m, id, del); err != nil {
  206. return err
  207. }
  208. // Would recursively delete empty nodes.
  209. // Do not remove the t.root even if its empty.
  210. removeEmpty(t.root)
  211. return nil
  212. }
  213. func numNodes(curNode *node) int {
  214. if curNode == nil {
  215. return 0
  216. }
  217. num := numNodes(curNode.ignore)
  218. for _, n := range curNode.children {
  219. num += numNodes(n)
  220. }
  221. return num + 1
  222. }