ipv6.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package fasthttp
  2. import (
  3. "bytes"
  4. "errors"
  5. )
// Sentinel errors returned by validateIPv6Literal.
var (
	// errInvalidIPv6Host is returned when a host starting with '[' has no
	// closing ']' or has nothing between the brackets.
	errInvalidIPv6Host = errors.New("invalid IPv6 host")
	// errInvalidIPv6Zone is returned when the '%' zone marker is the last
	// byte inside the brackets, i.e. the zone is empty.
	errInvalidIPv6Zone = errors.New("invalid IPv6 zone")
	// errInvalidIPv6Address is returned when the address part inside the
	// brackets (with any zone removed) is not a well-formed IPv6 address.
	errInvalidIPv6Address = errors.New("invalid IPv6 address")
)
  11. func validateIPv6Literal(host []byte) error {
  12. if len(host) == 0 || host[0] != '[' {
  13. return nil
  14. }
  15. end := bytes.IndexByte(host, ']')
  16. if end < 0 || end == 1 {
  17. return errInvalidIPv6Host
  18. }
  19. addr := host[1:end]
  20. // Optional zone.
  21. if zi := bytes.IndexByte(addr, '%'); zi >= 0 {
  22. if zi == len(addr)-1 {
  23. return errInvalidIPv6Zone
  24. }
  25. addr = addr[:zi]
  26. }
  27. // Must have a colon to be IPv6.
  28. if bytes.IndexByte(addr, ':') < 0 {
  29. return errInvalidIPv6Address
  30. }
  31. // IPv4-embedded?
  32. if bytes.IndexByte(addr, '.') >= 0 {
  33. lastColon := bytes.LastIndexByte(addr, ':')
  34. if lastColon < 0 || lastColon == len(addr)-1 {
  35. return errInvalidIPv6Address
  36. }
  37. ipv4 := addr[lastColon+1:]
  38. if !validIPv4(ipv4) {
  39. return errInvalidIPv6Address
  40. }
  41. head := addr[:lastColon]
  42. seenDoubleAtSplit := lastColon > 0 && addr[lastColon-1] == ':'
  43. if seenDoubleAtSplit {
  44. head = addr[:lastColon-1]
  45. }
  46. hextets, seenDoubleHead, ok := parseIPv6Hextets(head, false)
  47. if !ok {
  48. return errInvalidIPv6Address
  49. }
  50. if seenDoubleHead && seenDoubleAtSplit {
  51. return errInvalidIPv6Address
  52. }
  53. hextets += 2 // IPv4 tail = 2 hextets
  54. seenDouble := seenDoubleHead || seenDoubleAtSplit
  55. // '::' must compress at least one hextet.
  56. if (!seenDouble && hextets != 8) || (seenDouble && hextets >= 8) {
  57. return errInvalidIPv6Address
  58. }
  59. return nil
  60. }
  61. // Pure IPv6
  62. hextets, seenDouble, ok := parseIPv6Hextets(addr, false)
  63. if !ok {
  64. return errInvalidIPv6Address
  65. }
  66. if (!seenDouble && hextets != 8) || (seenDouble && hextets >= 8) {
  67. return errInvalidIPv6Address
  68. }
  69. return nil
  70. }
  71. func parseIPv6Hextets(s []byte, allowTrailingColon bool) (groups int, seenDouble, ok bool) {
  72. n := len(s)
  73. if n == 0 {
  74. return 0, false, true
  75. }
  76. i := 0
  77. justSawDouble := false
  78. for i < n {
  79. if s[i] == ':' {
  80. if i+1 < n && s[i+1] == ':' {
  81. if seenDouble || justSawDouble {
  82. return 0, false, false
  83. }
  84. seenDouble = true
  85. justSawDouble = true
  86. i += 2
  87. if i == n {
  88. break
  89. }
  90. continue
  91. }
  92. if i == 0 {
  93. return 0, false, false
  94. }
  95. if justSawDouble {
  96. return 0, false, false
  97. }
  98. if i == n-1 {
  99. if allowTrailingColon {
  100. break
  101. }
  102. return 0, false, false
  103. }
  104. if !ishex(s[i+1]) {
  105. return 0, false, false
  106. }
  107. i++
  108. continue
  109. }
  110. justSawDouble = false
  111. cnt := 0
  112. for cnt < 4 && i < n && ishex(s[i]) {
  113. i++
  114. cnt++
  115. }
  116. if cnt == 0 {
  117. return 0, false, false
  118. }
  119. groups++
  120. if i < n && s[i] != ':' {
  121. return 0, false, false
  122. }
  123. }
  124. return groups, seenDouble, true
  125. }
  126. // validIPv4 validates a dotted-quad (exactly 4 parts, 0..255) with no leading zeros
  127. // unless the octet is exactly "0".
  128. func validIPv4(s []byte) bool {
  129. parts := 0
  130. i := 0
  131. n := len(s)
  132. for parts < 4 {
  133. if i >= n {
  134. return false
  135. }
  136. start := i
  137. val := 0
  138. digits := 0
  139. for i < n {
  140. c := s[i]
  141. if c < '0' || c > '9' {
  142. break
  143. }
  144. val = val*10 + int(c-'0')
  145. if val > 255 {
  146. return false
  147. }
  148. i++
  149. digits++
  150. if digits > 3 {
  151. return false
  152. }
  153. }
  154. if digits == 0 {
  155. return false
  156. }
  157. // Disallow leading zeros like "00", "01", "001".
  158. // Allowed: exactly "0" or any number that doesn't start with '0'.
  159. if digits > 1 && s[start] == '0' {
  160. return false
  161. }
  162. parts++
  163. if parts == 4 {
  164. return i == n // must consume all input
  165. }
  166. if i >= n || s[i] != '.' {
  167. return false
  168. }
  169. i++ // skip dot
  170. }
  171. return false
  172. }