websocket.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. // 🚀 Fiber is an Express inspired web framework written in Go with 💖
  2. // 📌 API Documentation: https://fiber.wiki
  3. // 📝 Github Repository: https://github.com/gofiber/fiber
  4. package websocket
  5. import (
  6. "errors"
  7. "io"
  8. "sync"
  9. "time"
  10. "github.com/fasthttp/websocket"
  11. "github.com/gofiber/fiber/v2"
  12. "github.com/gofiber/fiber/v2/utils"
  13. "github.com/valyala/fasthttp"
  14. )
  15. // Config ...
  16. type Config struct {
  17. // Filter defines a function to skip middleware.
  18. // Optional. Default: nil
  19. Filter func(*fiber.Ctx) bool
  20. // HandshakeTimeout specifies the duration for the handshake to complete.
  21. HandshakeTimeout time.Duration
  22. // Subprotocols specifies the client's requested subprotocols.
  23. Subprotocols []string
  24. // Allowed Origin's based on the Origin header, this validate the request origin to
  25. // prevent cross-site request forgery. Everything is allowed if left empty.
  26. Origins []string
  27. // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
  28. // size is zero, then a useful default size is used. The I/O buffer sizes
  29. // do not limit the size of the messages that can be sent or received.
  30. ReadBufferSize, WriteBufferSize int
  31. // WriteBufferPool is a pool of buffers for write operations. If the value
  32. // is not set, then write buffers are allocated to the connection for the
  33. // lifetime of the connection.
  34. //
  35. // A pool is most useful when the application has a modest volume of writes
  36. // across a large number of connections.
  37. //
  38. // Applications should use a single pool for each unique value of
  39. // WriteBufferSize.
  40. WriteBufferPool websocket.BufferPool
  41. // EnableCompression specifies if the client should attempt to negotiate
  42. // per message compression (RFC 7692). Setting this value to true does not
  43. // guarantee that compression will be supported. Currently only "no context
  44. // takeover" modes are supported.
  45. EnableCompression bool
  46. }
  47. // New returns a new `handler func(*Conn)` that upgrades a client to the
  48. // websocket protocol, you can pass an optional config.
  49. func New(handler func(*Conn), config ...Config) fiber.Handler {
  50. // Init config
  51. var cfg Config
  52. if len(config) > 0 {
  53. cfg = config[0]
  54. }
  55. if len(cfg.Origins) == 0 {
  56. cfg.Origins = []string{"*"}
  57. }
  58. if cfg.ReadBufferSize == 0 {
  59. cfg.ReadBufferSize = 1024
  60. }
  61. if cfg.WriteBufferSize == 0 {
  62. cfg.WriteBufferSize = 1024
  63. }
  64. var upgrader = websocket.FastHTTPUpgrader{
  65. HandshakeTimeout: cfg.HandshakeTimeout,
  66. Subprotocols: cfg.Subprotocols,
  67. ReadBufferSize: cfg.ReadBufferSize,
  68. WriteBufferSize: cfg.WriteBufferSize,
  69. EnableCompression: cfg.EnableCompression,
  70. WriteBufferPool: cfg.WriteBufferPool,
  71. CheckOrigin: func(fctx *fasthttp.RequestCtx) bool {
  72. if cfg.Origins[0] == "*" {
  73. return true
  74. }
  75. origin := utils.UnsafeString(fctx.Request.Header.Peek("Origin"))
  76. for i := range cfg.Origins {
  77. if cfg.Origins[i] == origin {
  78. return true
  79. }
  80. }
  81. return false
  82. },
  83. }
  84. return func(c *fiber.Ctx) error {
  85. if cfg.Filter != nil && !cfg.Filter(c) {
  86. return c.Next()
  87. }
  88. conn := acquireConn()
  89. // locals
  90. c.Context().VisitUserValues(func(key []byte, value interface{}) {
  91. conn.locals[string(key)] = value
  92. })
  93. // params
  94. params := c.Route().Params
  95. for i := 0; i < len(params); i++ {
  96. conn.params[utils.CopyString(params[i])] = utils.ImmutableString(c.Params(params[i]))
  97. }
  98. // queries
  99. c.Context().QueryArgs().VisitAll(func(key, value []byte) {
  100. conn.queries[string(key)] = string(value)
  101. })
  102. // cookies
  103. c.Context().Request.Header.VisitAllCookie(func(key, value []byte) {
  104. conn.cookies[string(key)] = string(value)
  105. })
  106. if err := upgrader.Upgrade(c.Context(), func(fconn *websocket.Conn) {
  107. conn.Conn = fconn
  108. defer releaseConn(conn)
  109. handler(conn)
  110. }); err != nil { // Upgrading required
  111. return fiber.ErrUpgradeRequired
  112. }
  113. return nil
  114. }
  115. }
  116. // Conn https://godoc.org/github.com/gorilla/websocket#pkg-index
  117. type Conn struct {
  118. *websocket.Conn
  119. locals map[string]interface{}
  120. params map[string]string
  121. cookies map[string]string
  122. queries map[string]string
  123. }
  124. // Conn pool
  125. var poolConn = sync.Pool{
  126. New: func() interface{} {
  127. return new(Conn)
  128. },
  129. }
  130. // Acquire Conn from pool
  131. func acquireConn() *Conn {
  132. conn := poolConn.Get().(*Conn)
  133. conn.locals = make(map[string]interface{})
  134. conn.params = make(map[string]string)
  135. conn.queries = make(map[string]string)
  136. conn.cookies = make(map[string]string)
  137. return conn
  138. }
  139. // Return Conn to pool
  140. func releaseConn(conn *Conn) {
  141. conn.Conn = nil
  142. poolConn.Put(conn)
  143. }
  144. // Locals makes it possible to pass interface{} values under string keys scoped to the request
  145. // and therefore available to all following routes that match the request.
  146. func (conn *Conn) Locals(key string) interface{} {
  147. return conn.locals[key]
  148. }
  149. // Params is used to get the route parameters.
  150. // Defaults to empty string "" if the param doesn't exist.
  151. // If a default value is given, it will return that value if the param doesn't exist.
  152. func (conn *Conn) Params(key string, defaultValue ...string) string {
  153. v, ok := conn.params[key]
  154. if !ok && len(defaultValue) > 0 {
  155. return defaultValue[0]
  156. }
  157. return v
  158. }
  159. // Query returns the query string parameter in the url.
  160. // Defaults to empty string "" if the query doesn't exist.
  161. // If a default value is given, it will return that value if the query doesn't exist.
  162. func (conn *Conn) Query(key string, defaultValue ...string) string {
  163. v, ok := conn.queries[key]
  164. if !ok && len(defaultValue) > 0 {
  165. return defaultValue[0]
  166. }
  167. return v
  168. }
  169. // Cookies is used for getting a cookie value by key
  170. // Defaults to empty string "" if the cookie doesn't exist.
  171. // If a default value is given, it will return that value if the cookie doesn't exist.
  172. func (conn *Conn) Cookies(key string, defaultValue ...string) string {
  173. v, ok := conn.cookies[key]
  174. if !ok && len(defaultValue) > 0 {
  175. return defaultValue[0]
  176. }
  177. return v
  178. }
  // Websocket close codes, as defined in RFC 6455, section 11.7.
  //
  // The values are copied from
  // https://github.com/fasthttp/websocket/blob/master/conn.go#L43 so callers
  // can use them without importing that package. 1004 and 1014 are not
  // declared here.
  const (
  	CloseNormalClosure           = 1000
  	CloseGoingAway               = 1001
  	CloseProtocolError           = 1002
  	CloseUnsupportedData         = 1003
  	CloseNoStatusReceived        = 1005
  	CloseAbnormalClosure         = 1006
  	CloseInvalidFramePayloadData = 1007
  	ClosePolicyViolation         = 1008
  	CloseMessageTooBig           = 1009
  	CloseMandatoryExtension      = 1010
  	CloseInternalServerErr       = 1011
  	CloseServiceRestart          = 1012
  	CloseTryAgainLater           = 1013
  	CloseTLSHandshake            = 1015
  )
  // Websocket message types, as defined in RFC 6455, section 11.8.
  const (
  	// TextMessage marks a text data message; its payload is interpreted as
  	// UTF-8 encoded text.
  	TextMessage = 1

  	// BinaryMessage marks a binary data message.
  	BinaryMessage = 2

  	// CloseMessage marks a close control message. Its optional payload holds
  	// a numeric code and text; build that payload with FormatCloseMessage.
  	CloseMessage = 8

  	// PingMessage marks a ping control message. Its optional payload is
  	// UTF-8 encoded text.
  	PingMessage = 9

  	// PongMessage marks a pong control message. Its optional payload is
  	// UTF-8 encoded text.
  	PongMessage = 10
  )
  // Sentinel errors exposed by this package.
  //
  // NOTE(review): each value is created here with errors.New, so it is a
  // different error value from any same-named variable in the imported
  // fasthttp/websocket package. Whether errors returned by a connection compare
  // equal to these should be verified before relying on == or errors.Is.
  var (
  	// ErrBadHandshake is returned when the server response to the opening
  	// handshake is invalid.
  	ErrBadHandshake = errors.New("websocket: bad handshake")

  	// ErrCloseSent is returned when the application writes a message to the
  	// connection after a close message has been sent.
  	ErrCloseSent = errors.New("websocket: close sent")

  	// ErrReadLimit is returned when a message being read is larger than the
  	// read limit set for the connection.
  	ErrReadLimit = errors.New("websocket: read limit exceeded")
  )
  226. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
  227. // An empty message is returned for code CloseNoStatusReceived.
  228. func FormatCloseMessage(closeCode int, text string) []byte {
  229. return websocket.FormatCloseMessage(closeCode, text)
  230. }
  231. // IsCloseError returns boolean indicating whether the error is a *CloseError
  232. // with one of the specified codes.
  233. func IsCloseError(err error, codes ...int) bool {
  234. return websocket.IsCloseError(err, codes...)
  235. }
  236. // IsUnexpectedCloseError returns boolean indicating whether the error is a
  237. // *CloseError with a code not in the list of expected codes.
  238. func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
  239. return websocket.IsUnexpectedCloseError(err, expectedCodes...)
  240. }
  241. // IsWebSocketUpgrade returns true if the client requested upgrade to the
  242. // WebSocket protocol.
  243. func IsWebSocketUpgrade(c *fiber.Ctx) bool {
  244. return websocket.FastHTTPIsWebSocketUpgrade(c.Context())
  245. }
  246. // JoinMessages concatenates received messages to create a single io.Reader.
  247. // The string term is appended to each message. The returned reader does not
  248. // support concurrent calls to the Read method.
  249. func JoinMessages(c *websocket.Conn, term string) io.Reader {
  250. return websocket.JoinMessages(c, term)
  251. }