util.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "crypto/rand"
  8. "crypto/sha1"
  9. "encoding/base64"
  10. "io"
  11. "net/http"
  12. "strings"
  13. "unicode/utf8"
  14. )
  15. var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
  16. func computeAcceptKey(challengeKey string) string {
  17. h := sha1.New()
  18. h.Write([]byte(challengeKey))
  19. h.Write(keyGUID)
  20. return base64.StdEncoding.EncodeToString(h.Sum(nil))
  21. }
  22. func computeAcceptKeyBytes(challengeKey []byte) string {
  23. h := sha1.New()
  24. h.Write(challengeKey)
  25. h.Write(keyGUID)
  26. return base64.StdEncoding.EncodeToString(h.Sum(nil))
  27. }
  28. func generateChallengeKey() (string, error) {
  29. p := make([]byte, 16)
  30. if _, err := io.ReadFull(rand.Reader, p); err != nil {
  31. return "", err
  32. }
  33. return base64.StdEncoding.EncodeToString(p), nil
  34. }
  35. // Token octets per RFC 2616.
  36. var isTokenOctet = [256]bool{
  37. '!': true,
  38. '#': true,
  39. '$': true,
  40. '%': true,
  41. '&': true,
  42. '\'': true,
  43. '*': true,
  44. '+': true,
  45. '-': true,
  46. '.': true,
  47. '0': true,
  48. '1': true,
  49. '2': true,
  50. '3': true,
  51. '4': true,
  52. '5': true,
  53. '6': true,
  54. '7': true,
  55. '8': true,
  56. '9': true,
  57. 'A': true,
  58. 'B': true,
  59. 'C': true,
  60. 'D': true,
  61. 'E': true,
  62. 'F': true,
  63. 'G': true,
  64. 'H': true,
  65. 'I': true,
  66. 'J': true,
  67. 'K': true,
  68. 'L': true,
  69. 'M': true,
  70. 'N': true,
  71. 'O': true,
  72. 'P': true,
  73. 'Q': true,
  74. 'R': true,
  75. 'S': true,
  76. 'T': true,
  77. 'U': true,
  78. 'W': true,
  79. 'V': true,
  80. 'X': true,
  81. 'Y': true,
  82. 'Z': true,
  83. '^': true,
  84. '_': true,
  85. '`': true,
  86. 'a': true,
  87. 'b': true,
  88. 'c': true,
  89. 'd': true,
  90. 'e': true,
  91. 'f': true,
  92. 'g': true,
  93. 'h': true,
  94. 'i': true,
  95. 'j': true,
  96. 'k': true,
  97. 'l': true,
  98. 'm': true,
  99. 'n': true,
  100. 'o': true,
  101. 'p': true,
  102. 'q': true,
  103. 'r': true,
  104. 's': true,
  105. 't': true,
  106. 'u': true,
  107. 'v': true,
  108. 'w': true,
  109. 'x': true,
  110. 'y': true,
  111. 'z': true,
  112. '|': true,
  113. '~': true,
  114. }
  115. // skipSpace returns a slice of the string s with all leading RFC 2616 linear
  116. // whitespace removed.
  117. func skipSpace(s string) (rest string) {
  118. i := 0
  119. for ; i < len(s); i++ {
  120. if b := s[i]; b != ' ' && b != '\t' {
  121. break
  122. }
  123. }
  124. return s[i:]
  125. }
  126. // nextToken returns the leading RFC 2616 token of s and the string following
  127. // the token.
  128. func nextToken(s string) (token, rest string) {
  129. i := 0
  130. for ; i < len(s); i++ {
  131. if !isTokenOctet[s[i]] {
  132. break
  133. }
  134. }
  135. return s[:i], s[i:]
  136. }
  137. // nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
  138. // and the string following the token or quoted string.
  139. func nextTokenOrQuoted(s string) (value string, rest string) {
  140. if !strings.HasPrefix(s, "\"") {
  141. token, rest := nextToken(s)
  142. return token, rest
  143. }
  144. s = s[1:]
  145. for i := 0; i < len(s); i++ {
  146. switch s[i] {
  147. case '"':
  148. return s[:i], s[i+1:]
  149. case '\\':
  150. p := make([]byte, len(s)-1)
  151. j := copy(p, s[:i])
  152. escape := true
  153. for i = i + 1; i < len(s); i++ {
  154. b := s[i]
  155. switch {
  156. case escape:
  157. escape = false
  158. p[j] = b
  159. j++
  160. case b == '\\':
  161. escape = true
  162. case b == '"':
  163. return string(p[:j]), s[i+1:]
  164. default:
  165. p[j] = b
  166. j++
  167. }
  168. }
  169. return "", ""
  170. }
  171. }
  172. return "", ""
  173. }
  174. // equalASCIIFold returns true if s is equal to t with ASCII case folding as
  175. // defined in RFC 4790.
  176. func equalASCIIFold(s, t string) bool {
  177. for s != "" && t != "" {
  178. sr, size := utf8.DecodeRuneInString(s)
  179. s = s[size:]
  180. tr, size := utf8.DecodeRuneInString(t)
  181. t = t[size:]
  182. if sr == tr {
  183. continue
  184. }
  185. if 'A' <= sr && sr <= 'Z' {
  186. sr = sr + 'a' - 'A'
  187. }
  188. if 'A' <= tr && tr <= 'Z' {
  189. tr = tr + 'a' - 'A'
  190. }
  191. if sr != tr {
  192. return false
  193. }
  194. }
  195. return s == t
  196. }
  197. // tokenListContainsValue returns true if the 1#token header with the given
  198. // name contains a token equal to value with ASCII case folding.
  199. func tokenContainsValue(s string, value string) bool {
  200. for {
  201. var t string
  202. t, s = nextToken(skipSpace(s))
  203. if t == "" {
  204. return false
  205. }
  206. s = skipSpace(s)
  207. if s != "" && s[0] != ',' {
  208. return false
  209. }
  210. if equalASCIIFold(t, value) {
  211. return true
  212. }
  213. if s == "" {
  214. return false
  215. }
  216. s = s[1:]
  217. }
  218. }
  219. // tokenListContainsValue returns true if the 1#token header with the given
  220. // name contains token.
  221. func tokenListContainsValue(header http.Header, name string, value string) bool {
  222. for _, s := range header[name] {
  223. if tokenContainsValue(s, value) {
  224. return true
  225. }
  226. }
  227. return false
  228. }
  229. // parseExtensions parses WebSocket extensions from a header.
  230. func parseExtensions(header http.Header) []map[string]string {
  231. // From RFC 6455:
  232. //
  233. // Sec-WebSocket-Extensions = extension-list
  234. // extension-list = 1#extension
  235. // extension = extension-token *( ";" extension-param )
  236. // extension-token = registered-token
  237. // registered-token = token
  238. // extension-param = token [ "=" (token | quoted-string) ]
  239. // ;When using the quoted-string syntax variant, the value
  240. // ;after quoted-string unescaping MUST conform to the
  241. // ;'token' ABNF.
  242. var result []map[string]string
  243. headers:
  244. for _, s := range header["Sec-Websocket-Extensions"] {
  245. for {
  246. var t string
  247. t, s = nextToken(skipSpace(s))
  248. if t == "" {
  249. continue headers
  250. }
  251. ext := map[string]string{"": t}
  252. for {
  253. s = skipSpace(s)
  254. if !strings.HasPrefix(s, ";") {
  255. break
  256. }
  257. var k string
  258. k, s = nextToken(skipSpace(s[1:]))
  259. if k == "" {
  260. continue headers
  261. }
  262. s = skipSpace(s)
  263. var v string
  264. if strings.HasPrefix(s, "=") {
  265. v, s = nextTokenOrQuoted(skipSpace(s[1:]))
  266. s = skipSpace(s)
  267. }
  268. if s != "" && s[0] != ',' && s[0] != ';' {
  269. continue headers
  270. }
  271. ext[k] = v
  272. }
  273. if s != "" && s[0] != ',' {
  274. continue headers
  275. }
  276. result = append(result, ext)
  277. if s == "" {
  278. continue headers
  279. }
  280. s = s[1:]
  281. }
  282. }
  283. return result
  284. }
  285. // isValidChallengeKey checks if the argument meets RFC6455 specification.
  286. func isValidChallengeKey(s string) bool {
  287. // From RFC6455:
  288. //
  289. // A |Sec-WebSocket-Key| header field with a base64-encoded (see
  290. // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
  291. // length.
  292. if s == "" {
  293. return false
  294. }
  295. decoded, err := base64.StdEncoding.DecodeString(s)
  296. return err == nil && len(decoded) == 16
  297. }
  298. // parseDataHeader returns a list with values if header value is comma-separated
  299. func parseDataHeader(headerValue []byte) [][]byte {
  300. h := bytes.TrimSpace(headerValue)
  301. if bytes.Equal(h, []byte("")) {
  302. return nil
  303. }
  304. values := bytes.Split(h, []byte(","))
  305. for i := range values {
  306. values[i] = bytes.TrimSpace(values[i])
  307. }
  308. return values
  309. }