encoder.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. // Copyright 2019+ Klaus Post. All rights reserved.
  2. // License information can be found in the LICENSE file.
  3. // Based on work by Yann Collet, released under BSD License.
  4. package zstd
  5. import (
  6. "crypto/rand"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "math"
  11. rdebug "runtime/debug"
  12. "sync"
  13. "github.com/klauspost/compress/zstd/internal/xxhash"
  14. )
  15. // Encoder provides encoding to Zstandard.
  16. // An Encoder can be used for either compressing a stream via the
  17. // io.WriteCloser interface supported by the Encoder or as multiple independent
  18. // tasks via the EncodeAll function.
  19. // Smaller encodes are encouraged to use the EncodeAll function.
  20. // Use NewWriter to create a new instance.
  21. type Encoder struct {
  22. o encoderOptions
  23. encoders chan encoder
  24. state encoderState
  25. init sync.Once
  26. }
  27. type encoder interface {
  28. Encode(blk *blockEnc, src []byte)
  29. EncodeNoHist(blk *blockEnc, src []byte)
  30. Block() *blockEnc
  31. CRC() *xxhash.Digest
  32. AppendCRC([]byte) []byte
  33. WindowSize(size int64) int32
  34. UseBlock(*blockEnc)
  35. Reset(d *dict, singleBlock bool)
  36. }
  37. type encoderState struct {
  38. w io.Writer
  39. filling []byte
  40. current []byte
  41. previous []byte
  42. encoder encoder
  43. writing *blockEnc
  44. err error
  45. writeErr error
  46. nWritten int64
  47. nInput int64
  48. frameContentSize int64
  49. headerWritten bool
  50. eofWritten bool
  51. fullFrameWritten bool
  52. // This waitgroup indicates an encode is running.
  53. wg sync.WaitGroup
  54. // This waitgroup indicates we have a block encoding/writing.
  55. wWg sync.WaitGroup
  56. }
  57. // NewWriter will create a new Zstandard encoder.
  58. // If the encoder will be used for encoding blocks a nil writer can be used.
  59. func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
  60. initPredefined()
  61. var e Encoder
  62. e.o.setDefault()
  63. for _, o := range opts {
  64. err := o(&e.o)
  65. if err != nil {
  66. return nil, err
  67. }
  68. }
  69. if w != nil {
  70. e.Reset(w)
  71. }
  72. return &e, nil
  73. }
  74. func (e *Encoder) initialize() {
  75. if e.o.concurrent == 0 {
  76. e.o.setDefault()
  77. }
  78. e.encoders = make(chan encoder, e.o.concurrent)
  79. for i := 0; i < e.o.concurrent; i++ {
  80. enc := e.o.encoder()
  81. e.encoders <- enc
  82. }
  83. }
  84. // Reset will re-initialize the writer and new writes will encode to the supplied writer
  85. // as a new, independent stream.
  86. func (e *Encoder) Reset(w io.Writer) {
  87. s := &e.state
  88. s.wg.Wait()
  89. s.wWg.Wait()
  90. if cap(s.filling) == 0 {
  91. s.filling = make([]byte, 0, e.o.blockSize)
  92. }
  93. if e.o.concurrent > 1 {
  94. if cap(s.current) == 0 {
  95. s.current = make([]byte, 0, e.o.blockSize)
  96. }
  97. if cap(s.previous) == 0 {
  98. s.previous = make([]byte, 0, e.o.blockSize)
  99. }
  100. s.current = s.current[:0]
  101. s.previous = s.previous[:0]
  102. if s.writing == nil {
  103. s.writing = &blockEnc{lowMem: e.o.lowMem}
  104. s.writing.init()
  105. }
  106. s.writing.initNewEncode()
  107. }
  108. if s.encoder == nil {
  109. s.encoder = e.o.encoder()
  110. }
  111. s.filling = s.filling[:0]
  112. s.encoder.Reset(e.o.dict, false)
  113. s.headerWritten = false
  114. s.eofWritten = false
  115. s.fullFrameWritten = false
  116. s.w = w
  117. s.err = nil
  118. s.nWritten = 0
  119. s.nInput = 0
  120. s.writeErr = nil
  121. s.frameContentSize = 0
  122. }
  123. // ResetWithOptions will re-initialize the writer and apply the given options
  124. // as a new, independent stream.
  125. // Options are applied on top of the existing options.
  126. // Some options cannot be changed on reset and will return an error.
  127. func (e *Encoder) ResetWithOptions(w io.Writer, opts ...EOption) error {
  128. e.o.resetOpt = true
  129. defer func() { e.o.resetOpt = false }()
  130. for _, o := range opts {
  131. if err := o(&e.o); err != nil {
  132. return err
  133. }
  134. }
  135. e.Reset(w)
  136. return nil
  137. }
  138. // ResetContentSize will reset and set a content size for the next stream.
  139. // If the bytes written does not match the size given an error will be returned
  140. // when calling Close().
  141. // This is removed when Reset is called.
  142. // Sizes <= 0 results in no content size set.
  143. func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
  144. e.Reset(w)
  145. if size >= 0 {
  146. e.state.frameContentSize = size
  147. }
  148. }
  149. // Write data to the encoder.
  150. // Input data will be buffered and as the buffer fills up
  151. // content will be compressed and written to the output.
  152. // When done writing, use Close to flush the remaining output
  153. // and write CRC if requested.
  154. func (e *Encoder) Write(p []byte) (n int, err error) {
  155. s := &e.state
  156. if s.eofWritten {
  157. return 0, ErrEncoderClosed
  158. }
  159. for len(p) > 0 {
  160. if len(p)+len(s.filling) < e.o.blockSize {
  161. if e.o.crc {
  162. _, _ = s.encoder.CRC().Write(p)
  163. }
  164. s.filling = append(s.filling, p...)
  165. return n + len(p), nil
  166. }
  167. add := p
  168. if len(p)+len(s.filling) > e.o.blockSize {
  169. add = add[:e.o.blockSize-len(s.filling)]
  170. }
  171. if e.o.crc {
  172. _, _ = s.encoder.CRC().Write(add)
  173. }
  174. s.filling = append(s.filling, add...)
  175. p = p[len(add):]
  176. n += len(add)
  177. if len(s.filling) < e.o.blockSize {
  178. return n, nil
  179. }
  180. err := e.nextBlock(false)
  181. if err != nil {
  182. return n, err
  183. }
  184. if debugAsserts && len(s.filling) > 0 {
  185. panic(len(s.filling))
  186. }
  187. }
  188. return n, nil
  189. }
  190. // nextBlock will synchronize and start compressing input in e.state.filling.
  191. // If an error has occurred during encoding it will be returned.
  192. func (e *Encoder) nextBlock(final bool) error {
  193. s := &e.state
  194. // Wait for current block.
  195. s.wg.Wait()
  196. if s.err != nil {
  197. return s.err
  198. }
  199. if len(s.filling) > e.o.blockSize {
  200. return fmt.Errorf("block > maxStoreBlockSize")
  201. }
  202. if !s.headerWritten {
  203. // If we have a single block encode, do a sync compression.
  204. if final && len(s.filling) == 0 && !e.o.fullZero {
  205. s.headerWritten = true
  206. s.fullFrameWritten = true
  207. s.eofWritten = true
  208. return nil
  209. }
  210. if final && len(s.filling) > 0 {
  211. s.current = e.encodeAll(s.encoder, s.filling, s.current[:0])
  212. var n2 int
  213. n2, s.err = s.w.Write(s.current)
  214. if s.err != nil {
  215. return s.err
  216. }
  217. s.nWritten += int64(n2)
  218. s.nInput += int64(len(s.filling))
  219. s.current = s.current[:0]
  220. s.filling = s.filling[:0]
  221. s.headerWritten = true
  222. s.fullFrameWritten = true
  223. s.eofWritten = true
  224. return nil
  225. }
  226. var tmp [maxHeaderSize]byte
  227. fh := frameHeader{
  228. ContentSize: uint64(s.frameContentSize),
  229. WindowSize: uint32(s.encoder.WindowSize(s.frameContentSize)),
  230. SingleSegment: false,
  231. Checksum: e.o.crc,
  232. DictID: e.o.dict.ID(),
  233. }
  234. dst := fh.appendTo(tmp[:0])
  235. s.headerWritten = true
  236. s.wWg.Wait()
  237. var n2 int
  238. n2, s.err = s.w.Write(dst)
  239. if s.err != nil {
  240. return s.err
  241. }
  242. s.nWritten += int64(n2)
  243. }
  244. if s.eofWritten {
  245. // Ensure we only write it once.
  246. final = false
  247. }
  248. if len(s.filling) == 0 {
  249. // Final block, but no data.
  250. if final {
  251. enc := s.encoder
  252. blk := enc.Block()
  253. blk.reset(nil)
  254. blk.last = true
  255. blk.encodeRaw(nil)
  256. s.wWg.Wait()
  257. _, s.err = s.w.Write(blk.output)
  258. s.nWritten += int64(len(blk.output))
  259. s.eofWritten = true
  260. }
  261. return s.err
  262. }
  263. // SYNC:
  264. if e.o.concurrent == 1 {
  265. src := s.filling
  266. s.nInput += int64(len(s.filling))
  267. if debugEncoder {
  268. println("Adding sync block,", len(src), "bytes, final:", final)
  269. }
  270. enc := s.encoder
  271. blk := enc.Block()
  272. blk.reset(nil)
  273. enc.Encode(blk, src)
  274. blk.last = final
  275. if final {
  276. s.eofWritten = true
  277. }
  278. s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  279. if s.err != nil {
  280. return s.err
  281. }
  282. _, s.err = s.w.Write(blk.output)
  283. s.nWritten += int64(len(blk.output))
  284. s.filling = s.filling[:0]
  285. return s.err
  286. }
  287. // Move blocks forward.
  288. s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
  289. s.nInput += int64(len(s.current))
  290. s.wg.Add(1)
  291. if final {
  292. s.eofWritten = true
  293. }
  294. go func(src []byte) {
  295. if debugEncoder {
  296. println("Adding block,", len(src), "bytes, final:", final)
  297. }
  298. defer func() {
  299. if r := recover(); r != nil {
  300. s.err = fmt.Errorf("panic while encoding: %v", r)
  301. rdebug.PrintStack()
  302. }
  303. s.wg.Done()
  304. }()
  305. enc := s.encoder
  306. blk := enc.Block()
  307. enc.Encode(blk, src)
  308. blk.last = final
  309. // Wait for pending writes.
  310. s.wWg.Wait()
  311. if s.writeErr != nil {
  312. s.err = s.writeErr
  313. return
  314. }
  315. // Transfer encoders from previous write block.
  316. blk.swapEncoders(s.writing)
  317. // Transfer recent offsets to next.
  318. enc.UseBlock(s.writing)
  319. s.writing = blk
  320. s.wWg.Add(1)
  321. go func() {
  322. defer func() {
  323. if r := recover(); r != nil {
  324. s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
  325. rdebug.PrintStack()
  326. }
  327. s.wWg.Done()
  328. }()
  329. s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  330. if s.writeErr != nil {
  331. return
  332. }
  333. _, s.writeErr = s.w.Write(blk.output)
  334. s.nWritten += int64(len(blk.output))
  335. }()
  336. }(s.current)
  337. return nil
  338. }
  339. // ReadFrom reads data from r until EOF or error.
  340. // The return value n is the number of bytes read.
  341. // Any error except io.EOF encountered during the read is also returned.
  342. //
  343. // The Copy function uses ReaderFrom if available.
  344. func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
  345. if debugEncoder {
  346. println("Using ReadFrom")
  347. }
  348. // Flush any current writes.
  349. if len(e.state.filling) > 0 {
  350. if err := e.nextBlock(false); err != nil {
  351. return 0, err
  352. }
  353. }
  354. e.state.filling = e.state.filling[:e.o.blockSize]
  355. src := e.state.filling
  356. for {
  357. n2, err := r.Read(src)
  358. if e.o.crc {
  359. _, _ = e.state.encoder.CRC().Write(src[:n2])
  360. }
  361. // src is now the unfilled part...
  362. src = src[n2:]
  363. n += int64(n2)
  364. switch err {
  365. case io.EOF:
  366. e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
  367. if debugEncoder {
  368. println("ReadFrom: got EOF final block:", len(e.state.filling))
  369. }
  370. return n, nil
  371. case nil:
  372. default:
  373. if debugEncoder {
  374. println("ReadFrom: got error:", err)
  375. }
  376. e.state.err = err
  377. return n, err
  378. }
  379. if len(src) > 0 {
  380. if debugEncoder {
  381. println("ReadFrom: got space left in source:", len(src))
  382. }
  383. continue
  384. }
  385. err = e.nextBlock(false)
  386. if err != nil {
  387. return n, err
  388. }
  389. e.state.filling = e.state.filling[:e.o.blockSize]
  390. src = e.state.filling
  391. }
  392. }
  393. // Flush will send the currently written data to output
  394. // and block until everything has been written.
  395. // This should only be used on rare occasions where pushing the currently queued data is critical.
  396. func (e *Encoder) Flush() error {
  397. s := &e.state
  398. if len(s.filling) > 0 {
  399. err := e.nextBlock(false)
  400. if err != nil {
  401. // Ignore Flush after Close.
  402. if errors.Is(s.err, ErrEncoderClosed) {
  403. return nil
  404. }
  405. return err
  406. }
  407. }
  408. s.wg.Wait()
  409. s.wWg.Wait()
  410. if s.err != nil {
  411. // Ignore Flush after Close.
  412. if errors.Is(s.err, ErrEncoderClosed) {
  413. return nil
  414. }
  415. return s.err
  416. }
  417. return s.writeErr
  418. }
  419. // Close will flush the final output and close the stream.
  420. // The function will block until everything has been written.
  421. // The Encoder can still be re-used after calling this.
  422. func (e *Encoder) Close() error {
  423. s := &e.state
  424. if s.encoder == nil {
  425. return nil
  426. }
  427. err := e.nextBlock(true)
  428. if err != nil {
  429. if errors.Is(s.err, ErrEncoderClosed) {
  430. return nil
  431. }
  432. return err
  433. }
  434. if s.frameContentSize > 0 {
  435. if s.nInput != s.frameContentSize {
  436. return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
  437. }
  438. }
  439. if e.state.fullFrameWritten {
  440. return s.err
  441. }
  442. s.wg.Wait()
  443. s.wWg.Wait()
  444. if s.err != nil {
  445. return s.err
  446. }
  447. if s.writeErr != nil {
  448. return s.writeErr
  449. }
  450. // Write CRC
  451. if e.o.crc && s.err == nil {
  452. // heap alloc.
  453. var tmp [4]byte
  454. _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
  455. s.nWritten += 4
  456. }
  457. // Add padding with content from crypto/rand.Reader
  458. if s.err == nil && e.o.pad > 0 {
  459. add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
  460. frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
  461. if err != nil {
  462. return err
  463. }
  464. _, s.err = s.w.Write(frame)
  465. }
  466. if s.err == nil {
  467. s.err = ErrEncoderClosed
  468. return nil
  469. }
  470. return s.err
  471. }
  472. // EncodeAll will encode all input in src and append it to dst.
  473. // This function can be called concurrently, but each call will only run on a single goroutine.
  474. // If empty input is given, nothing is returned, unless WithZeroFrames is specified.
  475. // Encoded blocks can be concatenated and the result will be the combined input stream.
  476. // Data compressed with EncodeAll can be decoded with the Decoder,
  477. // using either a stream or DecodeAll.
  478. func (e *Encoder) EncodeAll(src, dst []byte) []byte {
  479. e.init.Do(e.initialize)
  480. enc := <-e.encoders
  481. defer func() {
  482. e.encoders <- enc
  483. }()
  484. return e.encodeAll(enc, src, dst)
  485. }
  486. func (e *Encoder) encodeAll(enc encoder, src, dst []byte) []byte {
  487. if len(src) == 0 {
  488. if e.o.fullZero {
  489. // Add frame header.
  490. fh := frameHeader{
  491. ContentSize: 0,
  492. WindowSize: MinWindowSize,
  493. SingleSegment: true,
  494. // Adding a checksum would be a waste of space.
  495. Checksum: false,
  496. DictID: 0,
  497. }
  498. dst = fh.appendTo(dst)
  499. // Write raw block as last one only.
  500. var blk blockHeader
  501. blk.setSize(0)
  502. blk.setType(blockTypeRaw)
  503. blk.setLast(true)
  504. dst = blk.appendTo(dst)
  505. }
  506. return dst
  507. }
  508. // Use single segments when above minimum window and below window size.
  509. single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
  510. if e.o.single != nil {
  511. single = *e.o.single
  512. }
  513. fh := frameHeader{
  514. ContentSize: uint64(len(src)),
  515. WindowSize: uint32(enc.WindowSize(int64(len(src)))),
  516. SingleSegment: single,
  517. Checksum: e.o.crc,
  518. DictID: e.o.dict.ID(),
  519. }
  520. // If less than 1MB, allocate a buffer up front.
  521. if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
  522. dst = make([]byte, 0, len(src))
  523. }
  524. dst = fh.appendTo(dst)
  525. // If we can do everything in one block, prefer that.
  526. if len(src) <= e.o.blockSize {
  527. enc.Reset(e.o.dict, true)
  528. // Slightly faster with no history and everything in one block.
  529. if e.o.crc {
  530. _, _ = enc.CRC().Write(src)
  531. }
  532. blk := enc.Block()
  533. blk.last = true
  534. if e.o.dict == nil {
  535. enc.EncodeNoHist(blk, src)
  536. } else {
  537. enc.Encode(blk, src)
  538. }
  539. // If we got the exact same number of literals as input,
  540. // assume the literals cannot be compressed.
  541. oldout := blk.output
  542. // Output directly to dst
  543. blk.output = dst
  544. err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
  545. if err != nil {
  546. panic(err)
  547. }
  548. dst = blk.output
  549. blk.output = oldout
  550. } else {
  551. enc.Reset(e.o.dict, false)
  552. blk := enc.Block()
  553. for len(src) > 0 {
  554. todo := src
  555. if len(todo) > e.o.blockSize {
  556. todo = todo[:e.o.blockSize]
  557. }
  558. src = src[len(todo):]
  559. if e.o.crc {
  560. _, _ = enc.CRC().Write(todo)
  561. }
  562. blk.pushOffsets()
  563. enc.Encode(blk, todo)
  564. if len(src) == 0 {
  565. blk.last = true
  566. }
  567. err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
  568. if err != nil {
  569. panic(err)
  570. }
  571. dst = append(dst, blk.output...)
  572. blk.reset(nil)
  573. }
  574. }
  575. if e.o.crc {
  576. dst = enc.AppendCRC(dst)
  577. }
  578. // Add padding with content from crypto/rand.Reader
  579. if e.o.pad > 0 {
  580. add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
  581. var err error
  582. dst, err = skippableFrame(dst, add, rand.Reader)
  583. if err != nil {
  584. panic(err)
  585. }
  586. }
  587. return dst
  588. }
  589. // MaxEncodedSize returns the expected maximum
  590. // size of an encoded block or stream.
  591. func (e *Encoder) MaxEncodedSize(size int) int {
  592. frameHeader := 4 + 2 // magic + frame header & window descriptor
  593. if e.o.dict != nil {
  594. frameHeader += 4
  595. }
  596. // Frame content size:
  597. if size < 256 {
  598. frameHeader++
  599. } else if size < 65536+256 {
  600. frameHeader += 2
  601. } else if size < math.MaxInt32 {
  602. frameHeader += 4
  603. } else {
  604. frameHeader += 8
  605. }
  606. // Final crc
  607. if e.o.crc {
  608. frameHeader += 4
  609. }
  610. // Max overhead is 3 bytes/block.
  611. // There cannot be 0 blocks.
  612. blocks := (size + e.o.blockSize) / e.o.blockSize
  613. // Combine, add padding.
  614. maxSz := frameHeader + 3*blocks + size
  615. if e.o.pad > 1 {
  616. maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
  617. }
  618. return maxSz
  619. }