trie.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. /*
  2. * Copyright 2019 Dgraph Labs, Inc. and Contributors
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package trie
  17. import (
  18. "fmt"
  19. "strconv"
  20. "strings"
  21. "github.com/pkg/errors"
  22. "github.com/dgraph-io/badger/v4/pb"
  23. "github.com/dgraph-io/badger/v4/y"
  24. )
  25. type node struct {
  26. children map[byte]*node
  27. ignore *node
  28. ids []uint64
  29. }
  30. func (n *node) isEmpty() bool {
  31. return len(n.children) == 0 && len(n.ids) == 0 && n.ignore == nil
  32. }
  33. func newNode() *node {
  34. return &node{
  35. children: make(map[byte]*node),
  36. ids: []uint64{},
  37. }
  38. }
  39. // Trie datastructure.
  40. type Trie struct {
  41. root *node
  42. }
  43. // NewTrie returns Trie.
  44. func NewTrie() *Trie {
  45. return &Trie{
  46. root: newNode(),
  47. }
  48. }
  49. // parseIgnoreBytes would parse the ignore string, and convert it into a list of bools, where
  50. // bool[idx] = true implies that key[idx] can be ignored during comparison.
  51. func parseIgnoreBytes(ig string) ([]bool, error) {
  52. var out []bool
  53. if ig == "" {
  54. return out, nil
  55. }
  56. for _, each := range strings.Split(strings.TrimSpace(ig), ",") {
  57. r := strings.Split(strings.TrimSpace(each), "-")
  58. if len(r) == 0 || len(r) > 2 {
  59. return out, fmt.Errorf("Invalid range: %s", each)
  60. }
  61. start, end := -1, -1 //nolint:ineffassign
  62. if len(r) == 2 {
  63. idx, err := strconv.Atoi(strings.TrimSpace(r[1]))
  64. if err != nil {
  65. return out, err
  66. }
  67. end = idx
  68. }
  69. {
  70. // Always consider r[0]
  71. idx, err := strconv.Atoi(strings.TrimSpace(r[0]))
  72. if err != nil {
  73. return out, err
  74. }
  75. start = idx
  76. }
  77. if start == -1 {
  78. return out, fmt.Errorf("Invalid range: %s", each)
  79. }
  80. for start >= len(out) {
  81. out = append(out, false)
  82. }
  83. for end >= len(out) { // end could be -1, so do have the start loop above.
  84. out = append(out, false)
  85. }
  86. if end == -1 {
  87. out[start] = true
  88. } else {
  89. for i := start; i <= end; i++ {
  90. out[i] = true
  91. }
  92. }
  93. }
  94. return out, nil
  95. }
  96. // Add adds the id in the trie for the given prefix path.
  97. func (t *Trie) Add(prefix []byte, id uint64) {
  98. m := pb.Match{
  99. Prefix: prefix,
  100. }
  101. y.Check(t.AddMatch(m, id))
  102. }
  103. // AddMatch allows you to send in a prefix match, with "holes" in the prefix. The holes are
  104. // specified via IgnoreBytes in a comma-separated list of indices starting from 0. A dash can be
  105. // used to denote a range. Valid example is "3, 5-8, 10, 12-15". Length of IgnoreBytes does not need
  106. // to match the length of the Prefix passed.
  107. //
  108. // Consider a prefix = "aaaa". If the IgnoreBytes is set to "0, 2", then along with key "aaaa...",
  109. // a key "baba..." would also match.
  110. func (t *Trie) AddMatch(m pb.Match, id uint64) error {
  111. return t.fix(m, id, set)
  112. }
  113. const (
  114. set = iota
  115. del
  116. )
  117. func (t *Trie) fix(m pb.Match, id uint64, op int) error {
  118. curNode := t.root
  119. ignore, err := parseIgnoreBytes(m.IgnoreBytes)
  120. if err != nil {
  121. return errors.Wrapf(err, "while parsing ignore bytes: %s", m.IgnoreBytes)
  122. }
  123. for len(ignore) < len(m.Prefix) {
  124. ignore = append(ignore, false)
  125. }
  126. for idx, byt := range m.Prefix {
  127. var child *node
  128. if ignore[idx] {
  129. child = curNode.ignore
  130. if child == nil {
  131. if op == del {
  132. // No valid node found for delete operation. Return immediately.
  133. return nil
  134. }
  135. child = newNode()
  136. curNode.ignore = child
  137. }
  138. } else {
  139. child = curNode.children[byt]
  140. if child == nil {
  141. if op == del {
  142. // No valid node found for delete operation. Return immediately.
  143. return nil
  144. }
  145. child = newNode()
  146. curNode.children[byt] = child
  147. }
  148. }
  149. curNode = child
  150. }
  151. // We only need to add the id to the last node of the given prefix.
  152. if op == set {
  153. curNode.ids = append(curNode.ids, id)
  154. } else if op == del {
  155. out := curNode.ids[:0]
  156. for _, cid := range curNode.ids {
  157. if id != cid {
  158. out = append(out, cid)
  159. }
  160. }
  161. curNode.ids = out
  162. } else {
  163. y.AssertTrue(false)
  164. }
  165. return nil
  166. }
  167. func (t *Trie) Get(key []byte) map[uint64]struct{} {
  168. return t.get(t.root, key)
  169. }
  170. // Get returns prefix matched ids for the given key.
  171. func (t *Trie) get(curNode *node, key []byte) map[uint64]struct{} {
  172. y.AssertTrue(curNode != nil)
  173. out := make(map[uint64]struct{})
  174. // If any node in the path of the key has ids, pick them up.
  175. // This would also match nil prefixes.
  176. for _, i := range curNode.ids {
  177. out[i] = struct{}{}
  178. }
  179. if len(key) == 0 {
  180. return out
  181. }
  182. // If we found an ignore node, traverse that path.
  183. if curNode.ignore != nil {
  184. res := t.get(curNode.ignore, key[1:])
  185. for id := range res {
  186. out[id] = struct{}{}
  187. }
  188. }
  189. if child := curNode.children[key[0]]; child != nil {
  190. res := t.get(child, key[1:])
  191. for id := range res {
  192. out[id] = struct{}{}
  193. }
  194. }
  195. return out
  196. }
  197. func removeEmpty(curNode *node) bool {
  198. // Go depth first.
  199. if curNode.ignore != nil {
  200. if empty := removeEmpty(curNode.ignore); empty {
  201. curNode.ignore = nil
  202. }
  203. }
  204. for byt, n := range curNode.children {
  205. if empty := removeEmpty(n); empty {
  206. delete(curNode.children, byt)
  207. }
  208. }
  209. return curNode.isEmpty()
  210. }
  211. // Delete will delete the id if the id exist in the given index path.
  212. func (t *Trie) Delete(prefix []byte, id uint64) error {
  213. return t.DeleteMatch(pb.Match{Prefix: prefix}, id)
  214. }
  215. func (t *Trie) DeleteMatch(m pb.Match, id uint64) error {
  216. if err := t.fix(m, id, del); err != nil {
  217. return err
  218. }
  219. // Would recursively delete empty nodes.
  220. // Do not remove the t.root even if its empty.
  221. removeEmpty(t.root)
  222. return nil
  223. }
  224. func numNodes(curNode *node) int {
  225. if curNode == nil {
  226. return 0
  227. }
  228. num := numNodes(curNode.ignore)
  229. for _, n := range curNode.children {
  230. num += numNodes(n)
  231. }
  232. return num + 1
  233. }