hybi.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. // This file implements the hybi draft of the WebSocket protocol.
  6. // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
  7. import (
  8. "bufio"
  9. "bytes"
  10. "crypto/rand"
  11. "crypto/sha1"
  12. "encoding/base64"
  13. "encoding/binary"
  14. "fmt"
  15. "io"
  16. "net/http"
  17. "net/url"
  18. "strings"
  19. )
  // websocketGUID is the fixed GUID appended to the client's nonce when
// computing the Sec-WebSocket-Accept value (see getNonceAccept).
const websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// Close status codes carried in the first two bytes of a close frame's
// payload. The names follow the hybi draft; NOTE(review): 1004 is reserved in
// the final RFC 6455, and 1009 is the "message too big" code there.
const (
	closeStatusNormal            = 1000
	closeStatusGoingAway         = 1001
	closeStatusProtocolError     = 1002
	closeStatusUnsupportedData   = 1003
	closeStatusFrameTooLarge     = 1004
	closeStatusNoStatusRcvd      = 1005
	closeStatusAbnormalClosure   = 1006
	closeStatusBadMessageData    = 1007
	closeStatusPolicyViolation   = 1008
	closeStatusTooBigData        = 1009
	closeStatusExtensionMismatch = 1010
)

// maxControlFramePayloadLength is the largest ping/pong payload the frame
// handler reads before discarding the rest of the frame.
const maxControlFramePayloadLength = 125
  35. var (
  36. ErrBadMaskingKey = &ProtocolError{"bad masking key"}
  37. ErrBadPongMessage = &ProtocolError{"bad pong message"}
  38. ErrBadClosingStatus = &ProtocolError{"bad closing status"}
  39. ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
  40. ErrNotImplemented = &ProtocolError{"not implemented"}
  41. handshakeHeader = map[string]bool{
  42. "Host": true,
  43. "Upgrade": true,
  44. "Connection": true,
  45. "Sec-Websocket-Key": true,
  46. "Sec-Websocket-Origin": true,
  47. "Sec-Websocket-Version": true,
  48. "Sec-Websocket-Protocol": true,
  49. "Sec-Websocket-Accept": true,
  50. }
  51. )
  // A hybiFrameHeader is a frame header as defined in hybi draft.
//
// Fin, Rsv and OpCode mirror the bits of the first header byte. Length is the
// payload length, and MaskingKey is the 4-byte masking key (nil for an
// unmasked frame). data holds the header bytes consumed by the frame reader
// factory, if any.
type hybiFrameHeader struct {
	Fin        bool
	Rsv        [3]bool
	OpCode     byte
	Length     int64
	MaskingKey []byte

	data *bytes.Buffer
}

// A hybiFrameReader reads the payload of a single hybi frame. When the header
// carries a masking key, the payload is unmasked as it is read.
type hybiFrameReader struct {
	reader io.Reader       // payload source
	header hybiFrameHeader // decoded frame header
	pos    int64           // payload bytes consumed so far; selects the key byte
	length int             // header size plus payload size
}

// Read fills msg with payload bytes. If the frame is masked, each byte is
// XORed with the masking key. Key rotation continues across calls.
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
	n, err = frame.reader.Read(msg)
	if key := frame.header.MaskingKey; key != nil {
		for i := 0; i < n; i, frame.pos = i+1, frame.pos+1 {
			msg[i] ^= key[frame.pos%4]
		}
	}
	return n, err
}

// PayloadType returns the frame's opcode.
func (frame *hybiFrameReader) PayloadType() byte {
	return frame.header.OpCode
}

// HeaderReader returns the buffered raw header bytes, or nil when there are
// none left to read.
func (frame *hybiFrameReader) HeaderReader() io.Reader {
	if d := frame.header.data; d != nil && d.Len() > 0 {
		return d
	}
	return nil
}

// TrailerReader always returns nil; hybi frames have no trailer.
func (frame *hybiFrameReader) TrailerReader() io.Reader {
	return nil
}

// Len returns the total frame length (header plus payload) recorded when the
// frame was created.
func (frame *hybiFrameReader) Len() (n int) {
	return frame.length
}
  90. // A hybiFrameReaderFactory creates new frame reader based on its frame type.
  91. type hybiFrameReaderFactory struct {
  92. *bufio.Reader
  93. }
  94. // NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
  95. // See Section 5.2 Base Framing protocol for detail.
  96. // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
  97. func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
  98. hybiFrame := new(hybiFrameReader)
  99. frame = hybiFrame
  100. var header []byte
  101. var b byte
  102. // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
  103. b, err = buf.ReadByte()
  104. if err != nil {
  105. return
  106. }
  107. header = append(header, b)
  108. hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
  109. for i := 0; i < 3; i++ {
  110. j := uint(6 - i)
  111. hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
  112. }
  113. hybiFrame.header.OpCode = header[0] & 0x0f
  114. // Second byte. Mask/Payload len(7bits)
  115. b, err = buf.ReadByte()
  116. if err != nil {
  117. return
  118. }
  119. header = append(header, b)
  120. mask := (b & 0x80) != 0
  121. b &= 0x7f
  122. lengthFields := 0
  123. switch {
  124. case b <= 125: // Payload length 7bits.
  125. hybiFrame.header.Length = int64(b)
  126. case b == 126: // Payload length 7+16bits
  127. lengthFields = 2
  128. case b == 127: // Payload length 7+64bits
  129. lengthFields = 8
  130. }
  131. for i := 0; i < lengthFields; i++ {
  132. b, err = buf.ReadByte()
  133. if err != nil {
  134. return
  135. }
  136. if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
  137. b &= 0x7f
  138. }
  139. header = append(header, b)
  140. hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
  141. }
  142. if mask {
  143. // Masking key. 4 bytes.
  144. for i := 0; i < 4; i++ {
  145. b, err = buf.ReadByte()
  146. if err != nil {
  147. return
  148. }
  149. header = append(header, b)
  150. hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
  151. }
  152. }
  153. hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
  154. hybiFrame.header.data = bytes.NewBuffer(header)
  155. hybiFrame.length = len(header) + int(hybiFrame.header.Length)
  156. return
  157. }
  158. // A HybiFrameWriter is a writer for hybi frame.
  159. type hybiFrameWriter struct {
  160. writer *bufio.Writer
  161. header *hybiFrameHeader
  162. }
  163. func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
  164. var header []byte
  165. var b byte
  166. if frame.header.Fin {
  167. b |= 0x80
  168. }
  169. for i := 0; i < 3; i++ {
  170. if frame.header.Rsv[i] {
  171. j := uint(6 - i)
  172. b |= 1 << j
  173. }
  174. }
  175. b |= frame.header.OpCode
  176. header = append(header, b)
  177. if frame.header.MaskingKey != nil {
  178. b = 0x80
  179. } else {
  180. b = 0
  181. }
  182. lengthFields := 0
  183. length := len(msg)
  184. switch {
  185. case length <= 125:
  186. b |= byte(length)
  187. case length < 65536:
  188. b |= 126
  189. lengthFields = 2
  190. default:
  191. b |= 127
  192. lengthFields = 8
  193. }
  194. header = append(header, b)
  195. for i := 0; i < lengthFields; i++ {
  196. j := uint((lengthFields - i - 1) * 8)
  197. b = byte((length >> j) & 0xff)
  198. header = append(header, b)
  199. }
  200. if frame.header.MaskingKey != nil {
  201. if len(frame.header.MaskingKey) != 4 {
  202. return 0, ErrBadMaskingKey
  203. }
  204. header = append(header, frame.header.MaskingKey...)
  205. frame.writer.Write(header)
  206. data := make([]byte, length)
  207. for i := range data {
  208. data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
  209. }
  210. frame.writer.Write(data)
  211. err = frame.writer.Flush()
  212. return length, err
  213. }
  214. frame.writer.Write(header)
  215. frame.writer.Write(msg)
  216. err = frame.writer.Flush()
  217. return length, err
  218. }
  219. func (frame *hybiFrameWriter) Close() error { return nil }
  220. type hybiFrameWriterFactory struct {
  221. *bufio.Writer
  222. needMaskingKey bool
  223. }
  224. func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
  225. frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
  226. if buf.needMaskingKey {
  227. frameHeader.MaskingKey, err = generateMaskingKey()
  228. if err != nil {
  229. return nil, err
  230. }
  231. }
  232. return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
  233. }
  234. type hybiFrameHandler struct {
  235. conn *Conn
  236. payloadType byte
  237. }
  238. func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
  239. if handler.conn.IsServerConn() {
  240. // The client MUST mask all frames sent to the server.
  241. if frame.(*hybiFrameReader).header.MaskingKey == nil {
  242. handler.WriteClose(closeStatusProtocolError)
  243. return nil, io.EOF
  244. }
  245. } else {
  246. // The server MUST NOT mask all frames.
  247. if frame.(*hybiFrameReader).header.MaskingKey != nil {
  248. handler.WriteClose(closeStatusProtocolError)
  249. return nil, io.EOF
  250. }
  251. }
  252. if header := frame.HeaderReader(); header != nil {
  253. io.Copy(io.Discard, header)
  254. }
  255. switch frame.PayloadType() {
  256. case ContinuationFrame:
  257. frame.(*hybiFrameReader).header.OpCode = handler.payloadType
  258. case TextFrame, BinaryFrame:
  259. handler.payloadType = frame.PayloadType()
  260. case CloseFrame:
  261. return nil, io.EOF
  262. case PingFrame, PongFrame:
  263. b := make([]byte, maxControlFramePayloadLength)
  264. n, err := io.ReadFull(frame, b)
  265. if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
  266. return nil, err
  267. }
  268. io.Copy(io.Discard, frame)
  269. if frame.PayloadType() == PingFrame {
  270. if _, err := handler.WritePong(b[:n]); err != nil {
  271. return nil, err
  272. }
  273. }
  274. return nil, nil
  275. }
  276. return frame, nil
  277. }
  278. func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
  279. handler.conn.wio.Lock()
  280. defer handler.conn.wio.Unlock()
  281. w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
  282. if err != nil {
  283. return err
  284. }
  285. msg := make([]byte, 2)
  286. binary.BigEndian.PutUint16(msg, uint16(status))
  287. _, err = w.Write(msg)
  288. w.Close()
  289. return err
  290. }
  291. func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
  292. handler.conn.wio.Lock()
  293. defer handler.conn.wio.Unlock()
  294. w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
  295. if err != nil {
  296. return 0, err
  297. }
  298. n, err = w.Write(msg)
  299. w.Close()
  300. return n, err
  301. }
  302. // newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
  303. func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  304. if buf == nil {
  305. br := bufio.NewReader(rwc)
  306. bw := bufio.NewWriter(rwc)
  307. buf = bufio.NewReadWriter(br, bw)
  308. }
  309. ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
  310. frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
  311. frameWriterFactory: hybiFrameWriterFactory{
  312. buf.Writer, request == nil},
  313. PayloadType: TextFrame,
  314. defaultCloseStatus: closeStatusNormal}
  315. ws.frameHandler = &hybiFrameHandler{conn: ws}
  316. return ws
  317. }
  // generateMaskingKey returns 4 bytes from crypto/rand for use as a frame
// masking key, together with any error from the random source.
func generateMaskingKey() (maskingKey []byte, err error) {
	key := make([]byte, 4)
	_, err = io.ReadFull(rand.Reader, key)
	return key, err
}
  // generateNonce returns a handshake nonce: 16 bytes from crypto/rand, encoded
// with standard base64 (24 characters). It panics if the random source fails.
func generateNonce() (nonce []byte) {
	var raw [16]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		panic(err)
	}
	nonce = make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
	base64.StdEncoding.Encode(nonce, raw[:])
	return nonce
}
  // removeZone strips an IPv6 zone identifier from a bracketed host.
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080".
// Hosts that do not start with '[' are returned unchanged, as are bracketed
// hosts with no ']' or with no '%' before the last ']'.
func removeZone(host string) string {
	if len(host) == 0 || host[0] != '[' {
		return host
	}
	closing := strings.LastIndexByte(host, ']')
	if closing >= 0 {
		if zone := strings.LastIndexByte(host[:closing], '%'); zone >= 0 {
			return host[:zone] + host[closing:]
		}
	}
	return host
}
  353. // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
  354. // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
  355. func getNonceAccept(nonce []byte) (expected []byte, err error) {
  356. h := sha1.New()
  357. if _, err = h.Write(nonce); err != nil {
  358. return
  359. }
  360. if _, err = h.Write([]byte(websocketGUID)); err != nil {
  361. return
  362. }
  363. expected = make([]byte, 28)
  364. base64.StdEncoding.Encode(expected, h.Sum(nil))
  365. return
  366. }
  367. // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
  368. func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
  369. bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
  370. // According to RFC 6874, an HTTP client, proxy, or other
  371. // intermediary must remove any IPv6 zone identifier attached
  372. // to an outgoing URI.
  373. bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
  374. bw.WriteString("Upgrade: websocket\r\n")
  375. bw.WriteString("Connection: Upgrade\r\n")
  376. nonce := generateNonce()
  377. if config.handshakeData != nil {
  378. nonce = []byte(config.handshakeData["key"])
  379. }
  380. bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
  381. bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
  382. if config.Version != ProtocolVersionHybi13 {
  383. return ErrBadProtocolVersion
  384. }
  385. bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
  386. if len(config.Protocol) > 0 {
  387. bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
  388. }
  389. // TODO(ukai): send Sec-WebSocket-Extensions.
  390. err = config.Header.WriteSubset(bw, handshakeHeader)
  391. if err != nil {
  392. return err
  393. }
  394. bw.WriteString("\r\n")
  395. if err = bw.Flush(); err != nil {
  396. return err
  397. }
  398. resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
  399. if err != nil {
  400. return err
  401. }
  402. if resp.StatusCode != 101 {
  403. return ErrBadStatus
  404. }
  405. if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
  406. strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
  407. return ErrBadUpgrade
  408. }
  409. expectedAccept, err := getNonceAccept(nonce)
  410. if err != nil {
  411. return err
  412. }
  413. if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
  414. return ErrChallengeResponse
  415. }
  416. if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
  417. return ErrUnsupportedExtensions
  418. }
  419. offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
  420. if offeredProtocol != "" {
  421. protocolMatched := false
  422. for i := 0; i < len(config.Protocol); i++ {
  423. if config.Protocol[i] == offeredProtocol {
  424. protocolMatched = true
  425. break
  426. }
  427. }
  428. if !protocolMatched {
  429. return ErrBadWebSocketProtocol
  430. }
  431. config.Protocol = []string{offeredProtocol}
  432. }
  433. return nil
  434. }
  435. // newHybiClientConn creates a client WebSocket connection after handshake.
  436. func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
  437. return newHybiConn(config, buf, rwc, nil)
  438. }
  439. // A HybiServerHandshaker performs a server handshake using hybi draft protocol.
  440. type hybiServerHandshaker struct {
  441. *Config
  442. accept []byte
  443. }
  444. func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
  445. c.Version = ProtocolVersionHybi13
  446. if req.Method != "GET" {
  447. return http.StatusMethodNotAllowed, ErrBadRequestMethod
  448. }
  449. // HTTP version can be safely ignored.
  450. if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
  451. !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
  452. return http.StatusBadRequest, ErrNotWebSocket
  453. }
  454. key := req.Header.Get("Sec-Websocket-Key")
  455. if key == "" {
  456. return http.StatusBadRequest, ErrChallengeResponse
  457. }
  458. version := req.Header.Get("Sec-Websocket-Version")
  459. switch version {
  460. case "13":
  461. c.Version = ProtocolVersionHybi13
  462. default:
  463. return http.StatusBadRequest, ErrBadWebSocketVersion
  464. }
  465. var scheme string
  466. if req.TLS != nil {
  467. scheme = "wss"
  468. } else {
  469. scheme = "ws"
  470. }
  471. c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
  472. if err != nil {
  473. return http.StatusBadRequest, err
  474. }
  475. protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
  476. if protocol != "" {
  477. protocols := strings.Split(protocol, ",")
  478. for i := 0; i < len(protocols); i++ {
  479. c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
  480. }
  481. }
  482. c.accept, err = getNonceAccept([]byte(key))
  483. if err != nil {
  484. return http.StatusInternalServerError, err
  485. }
  486. return http.StatusSwitchingProtocols, nil
  487. }
  488. // Origin parses the Origin header in req.
  489. // If the Origin header is not set, it returns nil and nil.
  490. func Origin(config *Config, req *http.Request) (*url.URL, error) {
  491. var origin string
  492. switch config.Version {
  493. case ProtocolVersionHybi13:
  494. origin = req.Header.Get("Origin")
  495. }
  496. if origin == "" {
  497. return nil, nil
  498. }
  499. return url.ParseRequestURI(origin)
  500. }
  501. func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
  502. if len(c.Protocol) > 0 {
  503. if len(c.Protocol) != 1 {
  504. // You need choose a Protocol in Handshake func in Server.
  505. return ErrBadWebSocketProtocol
  506. }
  507. }
  508. buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
  509. buf.WriteString("Upgrade: websocket\r\n")
  510. buf.WriteString("Connection: Upgrade\r\n")
  511. buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
  512. if len(c.Protocol) > 0 {
  513. buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
  514. }
  515. // TODO(ukai): send Sec-WebSocket-Extensions.
  516. if c.Header != nil {
  517. err := c.Header.WriteSubset(buf, handshakeHeader)
  518. if err != nil {
  519. return err
  520. }
  521. }
  522. buf.WriteString("\r\n")
  523. return buf.Flush()
  524. }
  525. func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  526. return newHybiServerConn(c.Config, buf, rwc, request)
  527. }
  528. // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
  529. func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
  530. return newHybiConn(config, buf, rwc, request)
  531. }