decode.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. // Copyright 2011 The Snappy-Go Authors. All rights reserved.
  2. // Copyright (c) 2019 Klaus Post. All rights reserved.
  3. // Use of this source code is governed by a BSD-style
  4. // license that can be found in the LICENSE file.
  5. package s2
  6. import (
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "strconv"
  11. "github.com/klauspost/compress/internal/race"
  12. )
  13. var (
  14. // ErrCorrupt reports that the input is invalid.
  15. ErrCorrupt = errors.New("s2: corrupt input")
  16. // ErrCRC reports that the input failed CRC validation (streams only)
  17. ErrCRC = errors.New("s2: corrupt input, crc mismatch")
  18. // ErrTooLarge reports that the uncompressed length is too large.
  19. ErrTooLarge = errors.New("s2: decoded block is too large")
  20. // ErrUnsupported reports that the input isn't supported.
  21. ErrUnsupported = errors.New("s2: unsupported input")
  22. )
  23. // DecodedLen returns the length of the decoded block.
  24. func DecodedLen(src []byte) (int, error) {
  25. v, _, err := decodedLen(src)
  26. return v, err
  27. }
  28. // decodedLen returns the length of the decoded block and the number of bytes
  29. // that the length header occupied.
  30. func decodedLen(src []byte) (blockLen, headerLen int, err error) {
  31. v, n := binary.Uvarint(src)
  32. if n <= 0 || v > 0xffffffff {
  33. return 0, 0, ErrCorrupt
  34. }
  35. const wordSize = 32 << (^uint(0) >> 32 & 1)
  36. if wordSize == 32 && v > 0x7fffffff {
  37. return 0, 0, ErrTooLarge
  38. }
  39. return int(v), n, nil
  40. }
  41. const (
  42. decodeErrCodeCorrupt = 1
  43. )
  44. // Decode returns the decoded form of src. The returned slice may be a sub-
  45. // slice of dst if dst was large enough to hold the entire decoded block.
  46. // Otherwise, a newly allocated slice will be returned.
  47. //
  48. // The dst and src must not overlap. It is valid to pass a nil dst.
  49. func Decode(dst, src []byte) ([]byte, error) {
  50. dLen, s, err := decodedLen(src)
  51. if err != nil {
  52. return nil, err
  53. }
  54. if dLen <= cap(dst) {
  55. dst = dst[:dLen]
  56. } else {
  57. dst = make([]byte, dLen)
  58. }
  59. race.WriteSlice(dst)
  60. race.ReadSlice(src[s:])
  61. if s2Decode(dst, src[s:]) != 0 {
  62. return nil, ErrCorrupt
  63. }
  64. return dst, nil
  65. }
// s2DecodeDict writes the decoding of src to dst. It assumes that the varint-encoded
// length of the decompressed bytes has already been read, and that len(dst)
// equals that length.
//
// If dict is nil the call is forwarded to s2Decode. Otherwise a copy whose
// offset reaches back before the start of dst is served from dict.dict, and
// the initial repeat offset is len(dict.dict) - dict.repeat.
//
// Decoding runs in two loops over the same tag format: a first loop that is
// used while more than 5 bytes of src remain and that omits the explicit
// "s went past len(src)" checks, and a second loop that handles the tail with
// those checks in place.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
func s2DecodeDict(dst, src []byte, dict *Dict) int {
	if dict == nil {
		return s2Decode(dst, src)
	}
	// Compile-time switches for trace output; debugErrs only covers the
	// error paths.
	const debug = false
	const debugErrs = debug

	if debug {
		fmt.Println("Starting decode, dst len:", len(dst))
	}
	// d is the write position in dst, s the read position in src and length
	// the size of the literal/copy currently being processed.
	var d, s, length int
	// offset holds the most recent copy offset; it is reused by "repeat"
	// copies and starts out pointing into the dictionary.
	offset := len(dict.dict) - dict.repeat

	// As long as we can read at least 5 bytes...
	for s < len(src)-5 {
		// Removing bounds checks is SLOWER when doing
		// in := src[s:s+5]
		// Checked on Go 1.18
		switch src[s] & 0x03 {
		case tagLiteral:
			// The upper 6 bits of the tag hold length-1 when below 60.
			// Values 60..63 mean length-1 follows in 1..4 little-endian bytes.
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				x = uint32(src[s-1])
			case x == 61:
				in := src[s : s+3]
				x = uint32(in[1]) | uint32(in[2])<<8
				s += 3
			case x == 62:
				in := src[s : s+4]
				// Load as 32 bit and shift down.
				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
				x >>= 8
				s += 4
			case x == 63:
				in := src[s : s+5]
				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
				s += 5
			}
			length = int(x) + 1
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}
			// Reject literals that overrun dst or src. With a 32-bit int the
			// uint32 -> int conversion can yield a non-positive length, which
			// is rejected as well.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			s += 2
			// 11-bit offset: bits 5-7 of the tag byte become bits 8-10, the
			// following byte supplies bits 0-7.
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			// 3-bit length field taken from bits 2-4 of the tag byte.
			length = int(src[s-2]) >> 2 & 0x7
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				// Length codes 5, 6 and 7 read an extended length from the
				// next 1, 2 or 3 bytes (little-endian); other codes are used
				// as they are.
				switch length {
				case 5:
					length = int(src[s]) + 4
					s += 1
				case 6:
					in := src[s : s+2]
					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
					s += 2
				case 7:
					in := src[s : s+3]
					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
					s += 3
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			// Added for both repeat and regular tagCopy1 copies.
			length += 4
		case tagCopy2:
			// Tag byte followed by a 2-byte little-endian offset; the length
			// is 1 + the upper 6 bits of the tag.
			in := src[s : s+3]
			offset = int(uint32(in[1]) | uint32(in[2])<<8)
			length = 1 + int(in[0])>>2
			s += 3

		case tagCopy4:
			// Tag byte followed by a 4-byte little-endian offset; the length
			// is 1 + the upper 6 bits of the tag.
			in := src[s : s+5]
			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
			length = 1 + int(in[0])>>2
			s += 5
		}

		// A copy needs a positive offset and must fit in what is left of dst.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// The offset reaches back before the start of dst, so the bytes are
		// taken from the dictionary instead.
		if d < offset {
			// Dictionary references are rejected once d is past MaxDictSrcOffset.
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			// Position inside dict.dict, treating the dictionary as if it
			// sat directly in front of dst. The whole copy must lie inside
			// the dictionary.
			startOff := len(dict.dict) - offset + d
			if startOff < 0 || startOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict))
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff)
			}
			copy(dst[d:d+length], dict.dict[startOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// Remaining with extra checks...
	// Same tag handling as the loop above, but s is advanced first and then
	// validated against len(src) before any of the tag's bytes are read.
	for s < len(src) {
		switch src[s] & 0x03 {
		case tagLiteral:
			// The upper 6 bits of the tag hold length-1 when below 60.
			// Values 60..63 mean length-1 follows in 1..4 little-endian bytes.
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-1])
			case x == 61:
				s += 3
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-2]) | uint32(src[s-1])<<8
			case x == 62:
				s += 4
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
			case x == 63:
				s += 5
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
			}
			length = int(x) + 1
			// Reject literals that overrun dst or src. With a 32-bit int the
			// uint32 -> int conversion can yield a non-positive length, which
			// is rejected as well.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			s += 2
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			// 3-bit length field taken from bits 2-4 of the tag byte.
			length = int(src[s-2]) >> 2 & 0x7
			// 11-bit offset: bits 5-7 of the tag byte become bits 8-10, the
			// following byte supplies bits 0-7.
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				// Length codes 5, 6 and 7 read an extended length from the
				// next 1, 2 or 3 bytes (little-endian); other codes are used
				// as they are.
				switch length {
				case 5:
					s += 1
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-1])) + 4
				case 6:
					s += 2
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
				case 7:
					s += 3
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			// Added for both repeat and regular tagCopy1 copies.
			length += 4
		case tagCopy2:
			s += 3
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			// Length is 1 + the upper 6 bits of the tag; the offset is the
			// 2-byte little-endian value that follows.
			length = 1 + int(src[s-3])>>2
			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)

		case tagCopy4:
			s += 5
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			// Length is 1 + the upper 6 bits of the tag; the offset is the
			// 4-byte little-endian value that follows.
			length = 1 + int(src[s-5])>>2
			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
		}

		// A copy needs a positive offset and must fit in what is left of dst.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// The offset reaches back before the start of dst, so the bytes are
		// taken from the dictionary instead.
		if d < offset {
			// Dictionary references are rejected once d is past MaxDictSrcOffset.
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			// Position inside dict.dict, treating the dictionary as if it
			// sat directly in front of dst.
			rOff := len(dict.dict) - (offset - d)
			if debug {
				fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff)
			}
			// The copy must end inside the dictionary...
			if rOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			// ...and must not start before it.
			if rOff < 0 {
				if debugErrs {
					fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			copy(dst[d:d+length], dict.dict[rOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// The decoded output must fill dst exactly.
	if d != len(dst) {
		if debugErrs {
			fmt.Println("wanted length", len(dst), "got", d)
		}
		return decodeErrCodeCorrupt
	}
	return 0
}