streaming.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package fasthttp
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io"
  6. "sync"
  7. "github.com/valyala/bytebufferpool"
  8. )
  9. type requestStream struct {
  10. header *RequestHeader
  11. prefetchedBytes *bytes.Reader
  12. reader *bufio.Reader
  13. totalBytesRead int
  14. chunkLeft int
  15. }
  16. func (rs *requestStream) Read(p []byte) (int, error) {
  17. var (
  18. n int
  19. err error
  20. )
  21. if rs.header.contentLength == -1 {
  22. if rs.chunkLeft == 0 {
  23. chunkSize, err := parseChunkSize(rs.reader)
  24. if err != nil {
  25. return 0, err
  26. }
  27. if chunkSize == 0 {
  28. err = rs.header.ReadTrailer(rs.reader)
  29. if err != nil && err != io.EOF {
  30. return 0, err
  31. }
  32. return 0, io.EOF
  33. }
  34. rs.chunkLeft = chunkSize
  35. }
  36. bytesToRead := len(p)
  37. if rs.chunkLeft < len(p) {
  38. bytesToRead = rs.chunkLeft
  39. }
  40. n, err = rs.reader.Read(p[:bytesToRead])
  41. rs.totalBytesRead += n
  42. rs.chunkLeft -= n
  43. if err == io.EOF {
  44. err = io.ErrUnexpectedEOF
  45. }
  46. if err == nil && rs.chunkLeft == 0 {
  47. err = readCrLf(rs.reader)
  48. }
  49. return n, err
  50. }
  51. if rs.totalBytesRead == rs.header.contentLength {
  52. return 0, io.EOF
  53. }
  54. prefetchedSize := int(rs.prefetchedBytes.Size())
  55. if prefetchedSize > rs.totalBytesRead {
  56. left := prefetchedSize - rs.totalBytesRead
  57. if len(p) > left {
  58. p = p[:left]
  59. }
  60. n, err := rs.prefetchedBytes.Read(p)
  61. rs.totalBytesRead += n
  62. if n == rs.header.contentLength {
  63. return n, io.EOF
  64. }
  65. return n, err
  66. } else {
  67. left := rs.header.contentLength - rs.totalBytesRead
  68. if len(p) > left {
  69. p = p[:left]
  70. }
  71. n, err = rs.reader.Read(p)
  72. rs.totalBytesRead += n
  73. if err != nil {
  74. return n, err
  75. }
  76. }
  77. if rs.totalBytesRead == rs.header.contentLength {
  78. err = io.EOF
  79. }
  80. return n, err
  81. }
  82. func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h *RequestHeader) *requestStream {
  83. rs := requestStreamPool.Get().(*requestStream)
  84. rs.prefetchedBytes = bytes.NewReader(b.B)
  85. rs.reader = r
  86. rs.header = h
  87. return rs
  88. }
  89. func releaseRequestStream(rs *requestStream) {
  90. rs.prefetchedBytes = nil
  91. rs.totalBytesRead = 0
  92. rs.chunkLeft = 0
  93. rs.reader = nil
  94. rs.header = nil
  95. requestStreamPool.Put(rs)
  96. }
  97. var requestStreamPool = sync.Pool{
  98. New: func() interface{} {
  99. return &requestStream{}
  100. },
  101. }