zstd.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package fasthttp
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "github.com/klauspost/compress/zstd"
  8. "github.com/valyala/bytebufferpool"
  9. "github.com/valyala/fasthttp/stackless"
  10. )
  // Zstd compression levels accepted by the zstd helpers in this package.
  // The numeric values are significant: they are converted directly to
  // zstd.EncoderLevel when a new encoder is created.
  const (
  	// CompressZstdSpeedNotSet (0) presumably denotes "no level chosen".
  	CompressZstdSpeedNotSet = 0
  	// CompressZstdBestSpeed (1) is handed to zstd as zstd.EncoderLevel(1).
  	CompressZstdBestSpeed = 1
  	// CompressZstdDefault (2) is the level used when none is given or an
  	// out-of-range level is normalized.
  	CompressZstdDefault = 2
  	// CompressZstdSpeedBetter (3) is handed to zstd as zstd.EncoderLevel(3).
  	CompressZstdSpeedBetter = 3
  	// CompressZstdBestCompression (4) is the highest accepted level.
  	CompressZstdBestCompression = 4
  )
  18. var (
  19. zstdDecoderPool sync.Pool
  20. realZstdWriterPoolMap = newCompressWriterPoolMap()
  21. stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
  22. )
  // acquireZstdReader returns a zstd decoder that reads from r.
  //
  // A decoder taken from zstdDecoderPool is Reset onto r and reused. If that
  // Reset fails, the decoder is discarded and the error is returned. With an
  // empty pool a fresh decoder is built with zstd.NewReader.
  func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
  	if pooled := zstdDecoderPool.Get(); pooled != nil {
  		dec := pooled.(*zstd.Decoder)
  		if err := dec.Reset(r); err != nil {
  			return nil, err
  		}
  		return dec, nil
  	}
  	return zstd.NewReader(r)
  }
  34. func releaseZstdReader(zr *zstd.Decoder) {
  35. zstdDecoderPool.Put(zr)
  36. }
  // acquireStacklessZstdWriter returns a stackless.Writer that zstd-compresses
  // into w at compressLevel.
  //
  // Wrappers are pooled per normalized level in stacklessZstdWriterPoolMap. A
  // pooled wrapper is Reset onto w. Otherwise a new wrapper is built around an
  // encoder obtained from acquireRealZstdWriter.
  //
  // The encoder is requested with the normalized level, not the raw
  // compressLevel. The encoder owned by the wrapper then matches the pool
  // that releaseStacklessZstdWriter returns the wrapper to, and out-of-range
  // levels never reach the zstd encoder constructor.
  func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
  	nLevel := normalizeZstdCompressLevel(compressLevel)
  	p := stacklessZstdWriterPoolMap[nLevel]
  	v := p.Get()
  	if v == nil {
  		return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
  			return acquireRealZstdWriter(w, nLevel)
  		})
  	}
  	sw := v.(stackless.Writer)
  	sw.Reset(w)
  	return sw
  }
  50. func releaseStacklessZstdWriter(zf stackless.Writer, level int) {
  51. zf.Close()
  52. nLevel := normalizeZstdCompressLevel(level)
  53. p := stacklessZstdWriterPoolMap[nLevel]
  54. p.Put(zf)
  55. }
  // acquireRealZstdWriter returns a *zstd.Encoder that writes to w.
  //
  // Encoders are pooled per normalized level in realZstdWriterPoolMap. A
  // pooled encoder is Reset onto w. Otherwise a new encoder is created.
  //
  // New encoders are built from the normalized level rather than the raw one,
  // for two reasons:
  //   - an out-of-range raw level is rejected by zstd.WithEncoderLevel, which
  //     would trigger the panic below;
  //   - the encoder would otherwise disagree with the pool it is later
  //     returned to.
  // CompressZstdSpeedNotSet is mapped to zstd.SpeedDefault for the same reason.
  //
  // It panics if zstd.NewWriter reports an error.
  func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
  	nLevel := normalizeZstdCompressLevel(level)
  	p := realZstdWriterPoolMap[nLevel]
  	if v := p.Get(); v != nil {
  		zw := v.(*zstd.Encoder)
  		zw.Reset(w)
  		return zw
  	}
  	encLevel := zstd.EncoderLevel(nLevel)
  	if nLevel == CompressZstdSpeedNotSet {
  		encLevel = zstd.SpeedDefault
  	}
  	zw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(encLevel))
  	if err != nil {
  		panic(err)
  	}
  	return zw
  }
  71. func releaseRealZstdWriter(zw *zstd.Encoder, level int) {
  72. zw.Close()
  73. nLevel := normalizeZstdCompressLevel(level)
  74. p := realZstdWriterPoolMap[nLevel]
  75. p.Put(zw)
  76. }
  77. func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
  78. w := &byteSliceWriter{b: dst}
  79. WriteZstdLevel(w, src, level) //nolint:errcheck
  80. return w.b
  81. }
  // WriteZstdLevel writes the zstd-compressed form of p to w at the given
  // level. Levels outside
  // [CompressZstdSpeedNotSet, CompressZstdBestCompression] are replaced by
  // CompressZstdDefault.
  //
  // The path depends on the writer type:
  //   - *byteSliceWriter, *bytes.Buffer and *bytebufferpool.ByteBuffer are
  //     compressed through stacklessWriteZstd, and the function always
  //     returns len(p), nil.
  //   - Any other writer goes through a pooled stackless.Writer, and the
  //     result of its Write is returned.
  //
  // NOTE(review): on the second path, the error from Close (invoked inside
  // releaseStacklessZstdWriter) is not reported to the caller.
  func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
  	level = normalizeZstdCompressLevel(level)
  	switch w.(type) {
  	case *byteSliceWriter, *bytes.Buffer, *bytebufferpool.ByteBuffer:
  		// In-memory sinks: handled after the switch.
  	default:
  		sw := acquireStacklessZstdWriter(w, level)
  		written, err := sw.Write(p)
  		releaseStacklessZstdWriter(sw, level)
  		return written, err
  	}
  	stacklessWriteZstd(&compressCtx{w: w, p: p, level: level})
  	return len(p), nil
  }
  102. var (
  103. stacklessWriteZstdOnce sync.Once
  104. stacklessWriteZstdFunc func(ctx any) bool
  105. )
  106. func stacklessWriteZstd(ctx any) {
  107. stacklessWriteZstdOnce.Do(func() {
  108. stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
  109. })
  110. stacklessWriteZstdFunc(ctx)
  111. }
  112. func nonblockingWriteZstd(ctxv any) {
  113. ctx := ctxv.(*compressCtx)
  114. zw := acquireRealZstdWriter(ctx.w, ctx.level)
  115. zw.Write(ctx.p) //nolint:errcheck
  116. releaseRealZstdWriter(zw, ctx.level)
  117. }
  118. // AppendZstdBytes appends zstd src to dst and returns the resulting dst.
  119. func AppendZstdBytes(dst, src []byte) []byte {
  120. return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
  121. }
  // WriteUnzstd decompresses the zstd data in p, writes the result to w and
  // returns the number of decompressed bytes written.
  //
  // If no decoder can be obtained, it returns 0 and that error. If the byte
  // count does not fit in an int, it returns 0 and an overflow error.
  // Otherwise it returns the count together with any copy error.
  func WriteUnzstd(w io.Writer, p []byte) (int, error) {
  	zr, err := acquireZstdReader(&byteSliceReader{b: p})
  	if err != nil {
  		return 0, err
  	}
  	written, copyErr := copyZeroAlloc(w, zr)
  	releaseZstdReader(zr)
  	if n := int(written); int64(n) == written {
  		return n, copyErr
  	}
  	return 0, fmt.Errorf("too much data unzstd: %d", written)
  }
  138. // AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst.
  139. func AppendUnzstdBytes(dst, src []byte) ([]byte, error) {
  140. w := &byteSliceWriter{b: dst}
  141. _, err := WriteUnzstd(w, src)
  142. return w.b, err
  143. }
  144. // normalizes compression level into [0..7], so it could be used as an index
  145. // in *PoolMap.
  146. func normalizeZstdCompressLevel(level int) int {
  147. if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression {
  148. level = CompressZstdDefault
  149. }
  150. return level
  151. }