trie.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. /*
  2. * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
  3. * SPDX-License-Identifier: Apache-2.0
  4. */
  5. package trie
  6. import (
  7. "fmt"
  8. "strconv"
  9. "strings"
  10. "github.com/dgraph-io/badger/v4/pb"
  11. "github.com/dgraph-io/badger/v4/y"
  12. )
  // node is a single position in the Trie.
  type node struct {
  	// children maps a concrete byte value to the node for the next position.
  	children map[byte]*node
  	// ignore, when non-nil, is the next node for a position whose byte is
  	// ignored during matching (any byte value is accepted there).
  	ignore *node
  	// ids are the ids registered for the prefix that ends at this node. The
  	// same id can appear more than once if it was added repeatedly.
  	ids []uint64
  }
  18. func (n *node) isEmpty() bool {
  19. return len(n.children) == 0 && len(n.ids) == 0 && n.ignore == nil
  20. }
  21. func newNode() *node {
  22. return &node{
  23. children: make(map[byte]*node),
  24. ids: []uint64{},
  25. }
  26. }
  27. // Trie datastructure.
  28. type Trie struct {
  29. root *node
  30. }
  31. // NewTrie returns Trie.
  32. func NewTrie() *Trie {
  33. return &Trie{
  34. root: newNode(),
  35. }
  36. }
  // parseIgnoreBytes parses an IgnoreBytes specification into a mask where
  // out[i] == true means byte i of a key is ignored (matches any value) during
  // comparison.
  //
  // The specification is a comma-separated list of zero-based indices or
  // inclusive ranges "start-end", e.g. "3, 5-8, 10". Whitespace around items
  // and numbers is ignored. The returned mask is just long enough to address
  // the largest index mentioned; positions past its end are implicitly false.
  // An empty or all-whitespace specification yields a nil mask.
  //
  // For a malformed item (non-numeric index, more than one dash, or a range
  // whose end precedes its start) it returns a nil mask and an error.
  func parseIgnoreBytes(ig string) ([]bool, error) {
  	ig = strings.TrimSpace(ig)
  	if ig == "" {
  		return nil, nil
  	}
  	var out []bool
  	for _, each := range strings.Split(ig, ",") {
  		r := strings.Split(strings.TrimSpace(each), "-")
  		// strings.Split always returns at least one element, so only the
  		// "too many dashes" case needs checking.
  		if len(r) > 2 {
  			return nil, fmt.Errorf("Invalid range: %s", each)
  		}
  		start, err := strconv.Atoi(strings.TrimSpace(r[0]))
  		if err != nil {
  			return nil, err
  		}
  		// A lone index is treated as the one-element range start-start.
  		end := start
  		if len(r) == 2 {
  			if end, err = strconv.Atoi(strings.TrimSpace(r[1])); err != nil {
  				return nil, err
  			}
  			// Reject reversed ranges instead of silently ignoring nothing.
  			if end < start {
  				return nil, fmt.Errorf("Invalid range: %s", each)
  			}
  		}
  		// Grow the mask in one step so that index end is addressable.
  		if end >= len(out) {
  			out = append(out, make([]bool, end+1-len(out))...)
  		}
  		for i := start; i <= end; i++ {
  			out[i] = true
  		}
  	}
  	return out, nil
  }
  84. // Add adds the id in the trie for the given prefix path.
  85. func (t *Trie) Add(prefix []byte, id uint64) {
  86. m := pb.Match{
  87. Prefix: prefix,
  88. }
  89. y.Check(t.AddMatch(m, id))
  90. }
  91. // AddMatch allows you to send in a prefix match, with "holes" in the prefix. The holes are
  92. // specified via IgnoreBytes in a comma-separated list of indices starting from 0. A dash can be
  93. // used to denote a range. Valid example is "3, 5-8, 10, 12-15". Length of IgnoreBytes does not need
  94. // to match the length of the Prefix passed.
  95. //
  96. // Consider a prefix = "aaaa". If the IgnoreBytes is set to "0, 2", then along with key "aaaa...",
  97. // a key "baba..." would also match.
  98. func (t *Trie) AddMatch(m pb.Match, id uint64) error {
  99. return t.fix(m, id, set)
  100. }
  // Operations understood by (*Trie).fix.
  const (
  	set = 0 // append an id to the node ending the path, creating nodes as needed
  	del = 1 // remove an id from that node; a missing path is a no-op
  )
  105. func (t *Trie) fix(m pb.Match, id uint64, op int) error {
  106. curNode := t.root
  107. ignore, err := parseIgnoreBytes(m.IgnoreBytes)
  108. if err != nil {
  109. return fmt.Errorf( "while parsing ignore bytes: %s: %w", m.IgnoreBytes,err)
  110. }
  111. for len(ignore) < len(m.Prefix) {
  112. ignore = append(ignore, false)
  113. }
  114. for idx, byt := range m.Prefix {
  115. var child *node
  116. if ignore[idx] {
  117. child = curNode.ignore
  118. if child == nil {
  119. if op == del {
  120. // No valid node found for delete operation. Return immediately.
  121. return nil
  122. }
  123. child = newNode()
  124. curNode.ignore = child
  125. }
  126. } else {
  127. child = curNode.children[byt]
  128. if child == nil {
  129. if op == del {
  130. // No valid node found for delete operation. Return immediately.
  131. return nil
  132. }
  133. child = newNode()
  134. curNode.children[byt] = child
  135. }
  136. }
  137. curNode = child
  138. }
  139. // We only need to add the id to the last node of the given prefix.
  140. if op == set {
  141. curNode.ids = append(curNode.ids, id)
  142. } else if op == del {
  143. out := curNode.ids[:0]
  144. for _, cid := range curNode.ids {
  145. if id != cid {
  146. out = append(out, cid)
  147. }
  148. }
  149. curNode.ids = out
  150. } else {
  151. y.AssertTrue(false)
  152. }
  153. return nil
  154. }
  155. func (t *Trie) Get(key []byte) map[uint64]struct{} {
  156. return t.get(t.root, key)
  157. }
  158. // Get returns prefix matched ids for the given key.
  159. func (t *Trie) get(curNode *node, key []byte) map[uint64]struct{} {
  160. y.AssertTrue(curNode != nil)
  161. out := make(map[uint64]struct{})
  162. // If any node in the path of the key has ids, pick them up.
  163. // This would also match nil prefixes.
  164. for _, i := range curNode.ids {
  165. out[i] = struct{}{}
  166. }
  167. if len(key) == 0 {
  168. return out
  169. }
  170. // If we found an ignore node, traverse that path.
  171. if curNode.ignore != nil {
  172. res := t.get(curNode.ignore, key[1:])
  173. for id := range res {
  174. out[id] = struct{}{}
  175. }
  176. }
  177. if child := curNode.children[key[0]]; child != nil {
  178. res := t.get(child, key[1:])
  179. for id := range res {
  180. out[id] = struct{}{}
  181. }
  182. }
  183. return out
  184. }
  185. func removeEmpty(curNode *node) bool {
  186. // Go depth first.
  187. if curNode.ignore != nil {
  188. if empty := removeEmpty(curNode.ignore); empty {
  189. curNode.ignore = nil
  190. }
  191. }
  192. for byt, n := range curNode.children {
  193. if empty := removeEmpty(n); empty {
  194. delete(curNode.children, byt)
  195. }
  196. }
  197. return curNode.isEmpty()
  198. }
  199. // Delete will delete the id if the id exist in the given index path.
  200. func (t *Trie) Delete(prefix []byte, id uint64) error {
  201. return t.DeleteMatch(pb.Match{Prefix: prefix}, id)
  202. }
  203. func (t *Trie) DeleteMatch(m pb.Match, id uint64) error {
  204. if err := t.fix(m, id, del); err != nil {
  205. return err
  206. }
  207. // Would recursively delete empty nodes.
  208. // Do not remove the t.root even if its empty.
  209. removeEmpty(t.root)
  210. return nil
  211. }
  212. func numNodes(curNode *node) int {
  213. if curNode == nil {
  214. return 0
  215. }
  216. num := numNodes(curNode.ignore)
  217. for _, n := range curNode.children {
  218. num += numNodes(n)
  219. }
  220. return num + 1
  221. }