zstd.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package fasthttp
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "github.com/klauspost/compress/zstd"
  8. "github.com/valyala/bytebufferpool"
  9. "github.com/valyala/fasthttp/stackless"
  10. )
  // Zstandard compression levels accepted by WriteZstdLevel and
// AppendZstdBytesLevel.
//
// When an encoder is created the level is converted directly with
// zstd.EncoderLevel(level), so these values line up by ordinal with the
// github.com/klauspost/compress/zstd levels SpeedFastest (1), SpeedDefault (2),
// SpeedBetterCompression (3) and SpeedBestCompression (4).
const (
	// CompressZstdSpeedNotSet (0) means no explicit level was chosen. The zstd
	// package has no exported encoder level with this value.
	CompressZstdSpeedNotSet = iota
	// CompressZstdBestSpeed favours speed over compression ratio.
	CompressZstdBestSpeed
	// CompressZstdDefault is the level used when none (or an invalid one) is given.
	CompressZstdDefault
	// CompressZstdSpeedBetter trades some speed for a better ratio than the default.
	CompressZstdSpeedBetter
	// CompressZstdBestCompression favours compression ratio over speed.
	CompressZstdBestCompression
)
  18. var (
  19. zstdDecoderPool sync.Pool
  20. zstdEncoderPool sync.Pool
  21. realZstdWriterPoolMap = newCompressWriterPoolMap()
  22. stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
  23. )
  24. func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
  25. v := zstdDecoderPool.Get()
  26. if v == nil {
  27. return zstd.NewReader(r)
  28. }
  29. zr := v.(*zstd.Decoder)
  30. if err := zr.Reset(r); err != nil {
  31. return nil, err
  32. }
  33. return zr, nil
  34. }
  35. func releaseZstdReader(zr *zstd.Decoder) {
  36. zstdDecoderPool.Put(zr)
  37. }
  38. func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) {
  39. v := zstdEncoderPool.Get()
  40. if v == nil {
  41. return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level)))
  42. }
  43. zw := v.(*zstd.Encoder)
  44. zw.Reset(w)
  45. return zw, nil
  46. }
  47. func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused
  48. zw.Close()
  49. zstdEncoderPool.Put(zw)
  50. }
  51. func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
  52. nLevel := normalizeZstdCompressLevel(compressLevel)
  53. p := stacklessZstdWriterPoolMap[nLevel]
  54. v := p.Get()
  55. if v == nil {
  56. return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
  57. return acquireRealZstdWriter(w, compressLevel)
  58. })
  59. }
  60. sw := v.(stackless.Writer)
  61. sw.Reset(w)
  62. return sw
  63. }
  64. func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) {
  65. zf.Close()
  66. nLevel := normalizeZstdCompressLevel(zstdDefault)
  67. p := stacklessZstdWriterPoolMap[nLevel]
  68. p.Put(zf)
  69. }
  70. func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
  71. nLevel := normalizeZstdCompressLevel(level)
  72. p := realZstdWriterPoolMap[nLevel]
  73. v := p.Get()
  74. if v == nil {
  75. zw, err := acquireZstdWriter(w, level)
  76. if err != nil {
  77. panic(err)
  78. }
  79. return zw
  80. }
  81. zw := v.(*zstd.Encoder)
  82. zw.Reset(w)
  83. return zw
  84. }
  85. func releaseRealZstdWrter(zw *zstd.Encoder, level int) {
  86. zw.Close()
  87. nLevel := normalizeZstdCompressLevel(level)
  88. p := realZstdWriterPoolMap[nLevel]
  89. p.Put(zw)
  90. }
  91. func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
  92. w := &byteSliceWriter{b: dst}
  93. WriteZstdLevel(w, src, level) //nolint:errcheck
  94. return w.b
  95. }
  96. func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
  97. level = normalizeZstdCompressLevel(level)
  98. switch w.(type) {
  99. case *byteSliceWriter,
  100. *bytes.Buffer,
  101. *bytebufferpool.ByteBuffer:
  102. ctx := &compressCtx{
  103. w: w,
  104. p: p,
  105. level: level,
  106. }
  107. stacklessWriteZstd(ctx)
  108. return len(p), nil
  109. default:
  110. zw := acquireStacklessZstdWriter(w, level)
  111. n, err := zw.Write(p)
  112. releaseStacklessZstdWriter(zw, level)
  113. return n, err
  114. }
  115. }
  // stacklessWriteZstdOnce guards the lazy creation of stacklessWriteZstdFunc
// inside stacklessWriteZstd.
var stacklessWriteZstdOnce sync.Once

// stacklessWriteZstdFunc is the stackless.NewFunc wrapper around
// nonblockingWriteZstd. It is nil until stacklessWriteZstd first runs.
var stacklessWriteZstdFunc func(ctx any) bool
  120. func stacklessWriteZstd(ctx any) {
  121. stacklessWriteZstdOnce.Do(func() {
  122. stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
  123. })
  124. stacklessWriteZstdFunc(ctx)
  125. }
  126. func nonblockingWriteZstd(ctxv any) {
  127. ctx := ctxv.(*compressCtx)
  128. zw := acquireRealZstdWriter(ctx.w, ctx.level)
  129. zw.Write(ctx.p) //nolint:errcheck
  130. releaseRealZstdWrter(zw, ctx.level)
  131. }
  132. // AppendZstdBytes appends zstd src to dst and returns the resulting dst.
  133. func AppendZstdBytes(dst, src []byte) []byte {
  134. return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
  135. }
  136. // WriteUnzstd writes unzstd p to w and returns the number of uncompressed
  137. // bytes written to w.
  138. func WriteUnzstd(w io.Writer, p []byte) (int, error) {
  139. r := &byteSliceReader{b: p}
  140. zr, err := acquireZstdReader(r)
  141. if err != nil {
  142. return 0, err
  143. }
  144. n, err := copyZeroAlloc(w, zr)
  145. releaseZstdReader(zr)
  146. nn := int(n)
  147. if int64(nn) != n {
  148. return 0, fmt.Errorf("too much data unzstd: %d", n)
  149. }
  150. return nn, err
  151. }
  152. // AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst.
  153. func AppendUnzstdBytes(dst, src []byte) ([]byte, error) {
  154. w := &byteSliceWriter{b: dst}
  155. _, err := WriteUnzstd(w, src)
  156. return w.b, err
  157. }
  158. // normalizes compression level into [0..7], so it could be used as an index
  159. // in *PoolMap.
  160. func normalizeZstdCompressLevel(level int) int {
  161. if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression {
  162. level = CompressZstdDefault
  163. }
  164. return level
  165. }