pathfinder.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package matchfinder
  2. import (
  3. "encoding/binary"
  4. "math"
  5. "math/bits"
  6. "slices"
  7. )
  8. // Pathfinder is a MatchFinder that uses hash chains to find matches, and a
  9. // shortest-path optimizer to choose which matches to use.
  10. type Pathfinder struct {
  11. // MaxDistance is the maximum distance (in bytes) to look back for
  12. // a match. The default is 65535.
  13. MaxDistance int
  14. // MinLength is the length of the shortest match to return.
  15. // The default is 4.
  16. MinLength int
  17. // HashLen is the number of bytes to use to calculate the hashes.
  18. // The maximum is 8 and the default is 6.
  19. HashLen int
  20. // TableBits is the number of bits in the hash table indexes.
  21. // The default is 17 (128K entries).
  22. TableBits int
  23. // ChainLength is how many entries to search on the "match chain" of older
  24. // locations with the same hash as the current location.
  25. ChainLength int
  26. table []uint32
  27. chain []uint32
  28. history []byte
  29. // holding onto buffers to reduce allocations:
  30. arrivals []arrival
  31. foundMatches []absoluteMatch
  32. matches []Match
  33. }
  34. func (q *Pathfinder) Reset() {
  35. for i := range q.table {
  36. q.table[i] = 0
  37. }
  38. q.history = q.history[:0]
  39. q.chain = q.chain[:0]
  40. }
  // arrival records the cheapest known way to reach one byte position of the
// block being optimized.
//
//   - distance > 0: the position was reached by a match of `length` bytes
//     copied from `distance` bytes back.
//   - distance == 0: the position was reached by a run of `length` literals.
//
// cost is the total estimated cost from the start of the block up to this
// position. FindMatches treats a zero cost as "not reached yet".
type arrival struct {
	length, distance uint32
	cost             float32
}
  // baseMatchCost is the fixed cost charged for every match, in the same
// estimated-bit units as the per-byte literal costs; length and distance
// terms are added on top of it where the cost is computed.
const baseMatchCost float32 = 4
  53. func (q *Pathfinder) FindMatches(dst []Match, src []byte) []Match {
  54. if q.MaxDistance == 0 {
  55. q.MaxDistance = 65535
  56. }
  57. if q.MinLength == 0 {
  58. q.MinLength = 4
  59. }
  60. if q.HashLen == 0 {
  61. q.HashLen = 6
  62. }
  63. if q.TableBits == 0 {
  64. q.TableBits = 17
  65. }
  66. if len(q.table) < 1<<q.TableBits {
  67. q.table = make([]uint32, 1<<q.TableBits)
  68. }
  69. var histogram [256]uint32
  70. for _, b := range src {
  71. histogram[b]++
  72. }
  73. var byteCost [256]float32
  74. for b, n := range histogram {
  75. cost := max(math.Log2(float64(len(src))/float64(n)), 1)
  76. byteCost[b] = float32(cost)
  77. }
  78. // Each element in arrivals corresponds to the position just after
  79. // the corresponding byte in src.
  80. arrivals := q.arrivals
  81. if len(arrivals) < len(src) {
  82. arrivals = make([]arrival, len(src))
  83. q.arrivals = arrivals
  84. } else {
  85. arrivals = arrivals[:len(src)]
  86. for i := range arrivals {
  87. arrivals[i] = arrival{}
  88. }
  89. }
  90. if len(q.history) > q.MaxDistance*2 {
  91. // Trim down the history buffer.
  92. delta := len(q.history) - q.MaxDistance
  93. copy(q.history, q.history[delta:])
  94. q.history = q.history[:q.MaxDistance]
  95. q.chain = q.chain[:q.MaxDistance]
  96. for i, v := range q.table {
  97. newV := max(int(v)-delta, 0)
  98. q.table[i] = uint32(newV)
  99. }
  100. }
  101. // Append src to the history buffer.
  102. historyLen := len(q.history)
  103. q.history = append(q.history, src...)
  104. q.chain = append(q.chain, make([]uint32, len(src))...)
  105. src = q.history
  106. // Calculate hashes and build the chain.
  107. for i := historyLen; i < len(src)-7; i++ {
  108. h := ((binary.LittleEndian.Uint64(src[i:]) & (1<<(8*q.HashLen) - 1)) * hashMul64) >> (64 - q.TableBits)
  109. candidate := int(q.table[h])
  110. q.table[h] = uint32(i)
  111. if candidate != 0 {
  112. delta := i - candidate
  113. q.chain[i] = uint32(delta)
  114. }
  115. }
  116. // Look for matches, and collect them in foundMatches. Later we'll figure out
  117. // which ones to actually use.
  118. foundMatches := q.foundMatches[:0]
  119. var prevMatch absoluteMatch
  120. i := historyLen
  121. for i < len(src)-7 {
  122. delta := q.chain[i]
  123. if delta == 0 {
  124. i++
  125. continue
  126. }
  127. candidate := i - int(delta)
  128. if candidate <= 0 || i-candidate > q.MaxDistance {
  129. i++
  130. continue
  131. }
  132. var currentMatch absoluteMatch
  133. if i >= prevMatch.End && prevMatch != (absoluteMatch{}) {
  134. // Look for a repeat match at i+1.
  135. prevDistance := prevMatch.Start - prevMatch.Match
  136. if binary.LittleEndian.Uint32(src[i+1:]) == binary.LittleEndian.Uint32(src[i+1-prevDistance:]) {
  137. m := extendMatch2(src, i+1, i+1-prevDistance, i+1)
  138. if m.End-m.Start > q.MinLength {
  139. currentMatch = m
  140. foundMatches = append(foundMatches, m)
  141. }
  142. }
  143. }
  144. if binary.LittleEndian.Uint32(src[candidate:]) == binary.LittleEndian.Uint32(src[i:]) {
  145. m := extendMatch2(src, i, candidate, max(historyLen, prevMatch.Start))
  146. if m.End-m.Start > q.MinLength {
  147. currentMatch = m
  148. foundMatches = append(foundMatches, m)
  149. }
  150. }
  151. for range q.ChainLength {
  152. delta := q.chain[candidate]
  153. if delta == 0 {
  154. break
  155. }
  156. candidate -= int(delta)
  157. if candidate <= 0 || i-candidate > q.MaxDistance {
  158. break
  159. }
  160. if binary.LittleEndian.Uint32(src[candidate:]) == binary.LittleEndian.Uint32(src[i:]) {
  161. m := extendMatch2(src, i, candidate, max(historyLen, prevMatch.Start))
  162. if m.End-m.Start > q.MinLength && m.End-m.Start > currentMatch.End-currentMatch.Start {
  163. currentMatch = m
  164. foundMatches = append(foundMatches, m)
  165. }
  166. }
  167. }
  168. if i < prevMatch.End && currentMatch.End-currentMatch.Start <= prevMatch.End-prevMatch.Start {
  169. // We were looking for an overlapping match, but we didn't find one longer
  170. // than the previous match. So we'll go back to sequential search,
  171. // starting right after the previous match.
  172. i = prevMatch.End
  173. continue
  174. }
  175. if currentMatch == (absoluteMatch{}) {
  176. // No match found. Continue with sequential search.
  177. i++
  178. continue
  179. }
  180. // We've found a match; now look for matches overlapping the end of it.
  181. prevMatch = currentMatch
  182. i = currentMatch.End + 2 - q.HashLen
  183. }
  184. q.foundMatches = foundMatches
  185. slices.SortFunc(foundMatches, func(a, b absoluteMatch) int { return a.Start - b.Start })
  186. matchIndex := 0
  187. var pending absoluteMatch
  188. for i := historyLen; i < len(src); i++ {
  189. var arrivedHere arrival
  190. if i > historyLen {
  191. arrivedHere = arrivals[i-historyLen-1]
  192. }
  193. unmatched := 0
  194. if arrivedHere.distance == 0 {
  195. unmatched = int(arrivedHere.length)
  196. }
  197. prevDistance := 0
  198. if i-unmatched > historyLen {
  199. prevDistance = int(arrivals[i-historyLen-1-unmatched].distance)
  200. }
  201. literalCost := byteCost[src[i]]
  202. nextArrival := &arrivals[i-historyLen]
  203. if nextArrival.cost == 0 || arrivedHere.cost+literalCost < nextArrival.cost {
  204. *nextArrival = arrival{
  205. cost: arrivedHere.cost + literalCost,
  206. length: uint32(unmatched + 1),
  207. }
  208. }
  209. for matchIndex < len(foundMatches) && foundMatches[matchIndex].Start == i {
  210. m := foundMatches[matchIndex]
  211. matchIndex++
  212. if m.End > pending.End {
  213. pending = m
  214. }
  215. matchCost := baseMatchCost + float32(bits.Len(uint(unmatched)))
  216. if m.Start-m.Match != prevDistance {
  217. matchCost += float32(bits.Len(uint(m.Start - m.Match)))
  218. }
  219. for j := m.Start + q.MinLength; j <= m.End; j++ {
  220. adjustedCost := matchCost
  221. if j-m.Start < 6 {
  222. // Matches shorter than 6 are comparatively rare, and therefore
  223. // have longer codes.
  224. adjustedCost += float32(6-(j-m.Start)) * 2
  225. }
  226. a := &arrivals[j-historyLen-1]
  227. if a.cost == 0 || arrivedHere.cost+adjustedCost < a.cost {
  228. *a = arrival{
  229. length: uint32(j - m.Start),
  230. distance: uint32(m.Start - m.Match),
  231. cost: arrivedHere.cost + adjustedCost,
  232. }
  233. }
  234. }
  235. }
  236. // If a match from an earlier position extends far enough past the current
  237. // position, try using the tail of it, starting from here.
  238. if unmatched == 0 && pending.Start != i && pending.End >= i+q.MinLength &&
  239. !(arrivedHere.length != 0 && arrivedHere.distance == uint32(pending.Start-pending.Match)) {
  240. matchCost := baseMatchCost + float32(bits.Len(uint(pending.Start-pending.Match)))
  241. for j := i + q.MinLength; j <= pending.End; j++ {
  242. adjustedCost := matchCost
  243. if j-i < 6 {
  244. // Matches shorter than 6 are comparatively rare, and therefore
  245. // have longer codes.
  246. adjustedCost += float32(6-(j-i)) * 2
  247. }
  248. a := &arrivals[j-historyLen-1]
  249. if a.cost == 0 || arrivedHere.cost+adjustedCost < a.cost {
  250. *a = arrival{
  251. length: uint32(j - i),
  252. distance: uint32(pending.Start - pending.Match),
  253. cost: arrivedHere.cost + adjustedCost,
  254. }
  255. }
  256. }
  257. }
  258. delta := q.chain[i]
  259. if delta == 0 {
  260. continue
  261. }
  262. candidate := i - int(delta)
  263. if candidate <= 0 || i-candidate > q.MaxDistance {
  264. continue
  265. }
  266. }
  267. // We've found the shortest path; now walk it backward and store the matches.
  268. matches := q.matches[:0]
  269. i = len(arrivals) - 1
  270. for i >= 0 {
  271. a := arrivals[i]
  272. if a.distance > 0 {
  273. matches = append(matches, Match{
  274. Length: int(a.length),
  275. Distance: int(a.distance),
  276. })
  277. i -= int(a.length)
  278. } else {
  279. if len(matches) == 0 {
  280. matches = append(matches, Match{})
  281. }
  282. matches[len(matches)-1].Unmatched = int(a.length)
  283. i -= int(a.length)
  284. }
  285. }
  286. q.matches = matches
  287. slices.Reverse(matches)
  288. return append(dst, matches...)
  289. }