websocket.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. // 🚀 Fiber is an Express inspired web framework written in Go with 💖
  2. // 📌 API Documentation: https://fiber.wiki
  3. // 📝 Github Repository: https://github.com/gofiber/fiber
  4. package websocket
  5. import (
  6. "errors"
  7. "io"
  8. "sync"
  9. "time"
  10. "github.com/fasthttp/websocket"
  11. "github.com/gofiber/fiber/v2"
  12. "github.com/gofiber/fiber/v2/utils"
  13. "github.com/valyala/fasthttp"
  14. )
  15. // Config ...
  16. type Config struct {
  17. // Filter defines a function to skip middleware.
  18. // Optional. Default: nil
  19. Filter func(*fiber.Ctx) bool
  20. // HandshakeTimeout specifies the duration for the handshake to complete.
  21. HandshakeTimeout time.Duration
  22. // Subprotocols specifies the client's requested subprotocols.
  23. Subprotocols []string
  24. // Allowed Origin's based on the Origin header, this validate the request origin to
  25. // prevent cross-site request forgery. Everything is allowed if left empty.
  26. Origins []string
  27. // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
  28. // size is zero, then a useful default size is used. The I/O buffer sizes
  29. // do not limit the size of the messages that can be sent or received.
  30. ReadBufferSize, WriteBufferSize int
  31. // WriteBufferPool is a pool of buffers for write operations. If the value
  32. // is not set, then write buffers are allocated to the connection for the
  33. // lifetime of the connection.
  34. //
  35. // A pool is most useful when the application has a modest volume of writes
  36. // across a large number of connections.
  37. //
  38. // Applications should use a single pool for each unique value of
  39. // WriteBufferSize.
  40. WriteBufferPool websocket.BufferPool
  41. // EnableCompression specifies if the client should attempt to negotiate
  42. // per message compression (RFC 7692). Setting this value to true does not
  43. // guarantee that compression will be supported. Currently only "no context
  44. // takeover" modes are supported.
  45. EnableCompression bool
  46. }
  47. // New returns a new `handler func(*Conn)` that upgrades a client to the
  48. // websocket protocol, you can pass an optional config.
  49. func New(handler func(*Conn), config ...Config) fiber.Handler {
  50. // Init config
  51. var cfg Config
  52. if len(config) > 0 {
  53. cfg = config[0]
  54. }
  55. if len(cfg.Origins) == 0 {
  56. cfg.Origins = []string{"*"}
  57. }
  58. if cfg.ReadBufferSize == 0 {
  59. cfg.ReadBufferSize = 1024
  60. }
  61. if cfg.WriteBufferSize == 0 {
  62. cfg.WriteBufferSize = 1024
  63. }
  64. var upgrader = websocket.FastHTTPUpgrader{
  65. HandshakeTimeout: cfg.HandshakeTimeout,
  66. Subprotocols: cfg.Subprotocols,
  67. ReadBufferSize: cfg.ReadBufferSize,
  68. WriteBufferSize: cfg.WriteBufferSize,
  69. EnableCompression: cfg.EnableCompression,
  70. WriteBufferPool: cfg.WriteBufferPool,
  71. CheckOrigin: func(fctx *fasthttp.RequestCtx) bool {
  72. if cfg.Origins[0] == "*" {
  73. return true
  74. }
  75. origin := utils.UnsafeString(fctx.Request.Header.Peek("Origin"))
  76. for i := range cfg.Origins {
  77. if cfg.Origins[i] == origin {
  78. return true
  79. }
  80. }
  81. return false
  82. },
  83. }
  84. return func(c *fiber.Ctx) error {
  85. conn := acquireConn()
  86. // locals
  87. c.Context().VisitUserValues(func(key []byte, value interface{}) {
  88. conn.locals[string(key)] = value
  89. })
  90. // params
  91. params := c.Route().Params
  92. for i := 0; i < len(params); i++ {
  93. conn.params[utils.CopyString(params[i])] = utils.ImmutableString(c.Params(params[i]))
  94. }
  95. // queries
  96. c.Context().QueryArgs().VisitAll(func(key, value []byte) {
  97. conn.queries[string(key)] = string(value)
  98. })
  99. // cookies
  100. c.Context().Request.Header.VisitAllCookie(func(key, value []byte) {
  101. conn.cookies[string(key)] = string(value)
  102. })
  103. if err := upgrader.Upgrade(c.Context(), func(fconn *websocket.Conn) {
  104. conn.Conn = fconn
  105. defer releaseConn(conn)
  106. handler(conn)
  107. }); err != nil { // Upgrading required
  108. return fiber.ErrUpgradeRequired
  109. }
  110. return nil
  111. }
  112. }
  113. // Conn https://godoc.org/github.com/gorilla/websocket#pkg-index
  114. type Conn struct {
  115. *websocket.Conn
  116. locals map[string]interface{}
  117. params map[string]string
  118. cookies map[string]string
  119. queries map[string]string
  120. }
  121. // Conn pool
  122. var poolConn = sync.Pool{
  123. New: func() interface{} {
  124. return new(Conn)
  125. },
  126. }
  127. // Acquire Conn from pool
  128. func acquireConn() *Conn {
  129. conn := poolConn.Get().(*Conn)
  130. conn.locals = make(map[string]interface{})
  131. conn.params = make(map[string]string)
  132. conn.queries = make(map[string]string)
  133. conn.cookies = make(map[string]string)
  134. return conn
  135. }
  136. // Return Conn to pool
  137. func releaseConn(conn *Conn) {
  138. conn.Conn = nil
  139. poolConn.Put(conn)
  140. }
  141. // Locals makes it possible to pass interface{} values under string keys scoped to the request
  142. // and therefore available to all following routes that match the request.
  143. func (conn *Conn) Locals(key string) interface{} {
  144. return conn.locals[key]
  145. }
  146. // Params is used to get the route parameters.
  147. // Defaults to empty string "" if the param doesn't exist.
  148. // If a default value is given, it will return that value if the param doesn't exist.
  149. func (conn *Conn) Params(key string, defaultValue ...string) string {
  150. v, ok := conn.params[key]
  151. if !ok && len(defaultValue) > 0 {
  152. return defaultValue[0]
  153. }
  154. return v
  155. }
  156. // Query returns the query string parameter in the url.
  157. // Defaults to empty string "" if the query doesn't exist.
  158. // If a default value is given, it will return that value if the query doesn't exist.
  159. func (conn *Conn) Query(key string, defaultValue ...string) string {
  160. v, ok := conn.queries[key]
  161. if !ok && len(defaultValue) > 0 {
  162. return defaultValue[0]
  163. }
  164. return v
  165. }
  166. // Cookies is used for getting a cookie value by key
  167. // Defaults to empty string "" if the cookie doesn't exist.
  168. // If a default value is given, it will return that value if the cookie doesn't exist.
  169. func (conn *Conn) Cookies(key string, defaultValue ...string) string {
  170. v, ok := conn.cookies[key]
  171. if !ok && len(defaultValue) > 0 {
  172. return defaultValue[0]
  173. }
  174. return v
  175. }
  176. // Constants are taken from https://github.com/fasthttp/websocket/blob/master/conn.go#L43
  177. // Close codes defined in RFC 6455, section 11.7.
  178. const (
  179. CloseNormalClosure = 1000
  180. CloseGoingAway = 1001
  181. CloseProtocolError = 1002
  182. CloseUnsupportedData = 1003
  183. CloseNoStatusReceived = 1005
  184. CloseAbnormalClosure = 1006
  185. CloseInvalidFramePayloadData = 1007
  186. ClosePolicyViolation = 1008
  187. CloseMessageTooBig = 1009
  188. CloseMandatoryExtension = 1010
  189. CloseInternalServerErr = 1011
  190. CloseServiceRestart = 1012
  191. CloseTryAgainLater = 1013
  192. CloseTLSHandshake = 1015
  193. )
  194. // The message types are defined in RFC 6455, section 11.8.
  195. const (
  196. // TextMessage denotes a text data message. The text message payload is
  197. // interpreted as UTF-8 encoded text data.
  198. TextMessage = 1
  199. // BinaryMessage denotes a binary data message.
  200. BinaryMessage = 2
  201. // CloseMessage denotes a close control message. The optional message
  202. // payload contains a numeric code and text. Use the FormatCloseMessage
  203. // function to format a close message payload.
  204. CloseMessage = 8
  205. // PingMessage denotes a ping control message. The optional message payload
  206. // is UTF-8 encoded text.
  207. PingMessage = 9
  208. // PongMessage denotes a pong control message. The optional message payload
  209. // is UTF-8 encoded text.
  210. PongMessage = 10
  211. )
  212. var (
  213. // ErrBadHandshake is returned when the server response to opening handshake is
  214. // invalid.
  215. ErrBadHandshake = errors.New("websocket: bad handshake")
  216. // ErrCloseSent is returned when the application writes a message to the
  217. // connection after sending a close message.
  218. ErrCloseSent = errors.New("websocket: close sent")
  219. // ErrReadLimit is returned when reading a message that is larger than the
  220. // read limit set for the connection.
  221. ErrReadLimit = errors.New("websocket: read limit exceeded")
  222. )
  223. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
  224. // An empty message is returned for code CloseNoStatusReceived.
  225. func FormatCloseMessage(closeCode int, text string) []byte {
  226. return websocket.FormatCloseMessage(closeCode, text)
  227. }
  228. // IsCloseError returns boolean indicating whether the error is a *CloseError
  229. // with one of the specified codes.
  230. func IsCloseError(err error, codes ...int) bool {
  231. return websocket.IsCloseError(err, codes...)
  232. }
  233. // IsUnexpectedCloseError returns boolean indicating whether the error is a
  234. // *CloseError with a code not in the list of expected codes.
  235. func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
  236. return websocket.IsUnexpectedCloseError(err, expectedCodes...)
  237. }
  238. // IsWebSocketUpgrade returns true if the client requested upgrade to the
  239. // WebSocket protocol.
  240. func IsWebSocketUpgrade(c *fiber.Ctx) bool {
  241. return websocket.FastHTTPIsWebSocketUpgrade(c.Context())
  242. }
  243. // JoinMessages concatenates received messages to create a single io.Reader.
  244. // The string term is appended to each message. The returned reader does not
  245. // support concurrent calls to the Read method.
  246. func JoinMessages(c *websocket.Conn, term string) io.Reader {
  247. return websocket.JoinMessages(c, term)
  248. }