wire.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package protowire parses and formats the raw wire encoding.
// See https://protobuf.dev/programming-guides/encoding.
//
// For marshaling and unmarshaling entire protobuf messages,
// use the [google.golang.org/protobuf/proto] package instead.
  9. package protowire
  10. import (
  11. "io"
  12. "math"
  13. "math/bits"
  14. "google.golang.org/protobuf/internal/errors"
  15. )
  16. // Number represents the field number.
  17. type Number int32
  18. const (
  19. MinValidNumber Number = 1
  20. FirstReservedNumber Number = 19000
  21. LastReservedNumber Number = 19999
  22. MaxValidNumber Number = 1<<29 - 1
  23. DefaultRecursionLimit = 10000
  24. )
  25. // IsValid reports whether the field number is semantically valid.
  26. func (n Number) IsValid() bool {
  27. return MinValidNumber <= n && n <= MaxValidNumber
  28. }
  29. // Type represents the wire type.
  30. type Type int8
  31. const (
  32. VarintType Type = 0
  33. Fixed32Type Type = 5
  34. Fixed64Type Type = 1
  35. BytesType Type = 2
  36. StartGroupType Type = 3
  37. EndGroupType Type = 4
  38. )
  39. const (
  40. _ = -iota
  41. errCodeTruncated
  42. errCodeFieldNumber
  43. errCodeOverflow
  44. errCodeReserved
  45. errCodeEndGroup
  46. errCodeRecursionDepth
  47. )
  48. var (
  49. errFieldNumber = errors.New("invalid field number")
  50. errOverflow = errors.New("variable length integer overflow")
  51. errReserved = errors.New("cannot parse reserved wire type")
  52. errEndGroup = errors.New("mismatching end group marker")
  53. errParse = errors.New("parse error")
  54. )
  55. // ParseError converts an error code into an error value.
  56. // This returns nil if n is a non-negative number.
  57. func ParseError(n int) error {
  58. if n >= 0 {
  59. return nil
  60. }
  61. switch n {
  62. case errCodeTruncated:
  63. return io.ErrUnexpectedEOF
  64. case errCodeFieldNumber:
  65. return errFieldNumber
  66. case errCodeOverflow:
  67. return errOverflow
  68. case errCodeReserved:
  69. return errReserved
  70. case errCodeEndGroup:
  71. return errEndGroup
  72. default:
  73. return errParse
  74. }
  75. }
  76. // ConsumeField parses an entire field record (both tag and value) and returns
  77. // the field number, the wire type, and the total length.
  78. // This returns a negative length upon an error (see [ParseError]).
  79. //
  80. // The total length includes the tag header and the end group marker (if the
  81. // field is a group).
  82. func ConsumeField(b []byte) (Number, Type, int) {
  83. num, typ, n := ConsumeTag(b)
  84. if n < 0 {
  85. return 0, 0, n // forward error code
  86. }
  87. m := ConsumeFieldValue(num, typ, b[n:])
  88. if m < 0 {
  89. return 0, 0, m // forward error code
  90. }
  91. return num, typ, n + m
  92. }
  93. // ConsumeFieldValue parses a field value and returns its length.
  94. // This assumes that the field [Number] and wire [Type] have already been parsed.
  95. // This returns a negative length upon an error (see [ParseError]).
  96. //
  97. // When parsing a group, the length includes the end group marker and
  98. // the end group is verified to match the starting field number.
  99. func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
  100. return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
  101. }
  102. func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
  103. switch typ {
  104. case VarintType:
  105. _, n = ConsumeVarint(b)
  106. return n
  107. case Fixed32Type:
  108. _, n = ConsumeFixed32(b)
  109. return n
  110. case Fixed64Type:
  111. _, n = ConsumeFixed64(b)
  112. return n
  113. case BytesType:
  114. _, n = ConsumeBytes(b)
  115. return n
  116. case StartGroupType:
  117. if depth < 0 {
  118. return errCodeRecursionDepth
  119. }
  120. n0 := len(b)
  121. for {
  122. num2, typ2, n := ConsumeTag(b)
  123. if n < 0 {
  124. return n // forward error code
  125. }
  126. b = b[n:]
  127. if typ2 == EndGroupType {
  128. if num != num2 {
  129. return errCodeEndGroup
  130. }
  131. return n0 - len(b)
  132. }
  133. n = consumeFieldValueD(num2, typ2, b, depth-1)
  134. if n < 0 {
  135. return n // forward error code
  136. }
  137. b = b[n:]
  138. }
  139. case EndGroupType:
  140. return errCodeEndGroup
  141. default:
  142. return errCodeReserved
  143. }
  144. }
  145. // AppendTag encodes num and typ as a varint-encoded tag and appends it to b.
  146. func AppendTag(b []byte, num Number, typ Type) []byte {
  147. return AppendVarint(b, EncodeTag(num, typ))
  148. }
  149. // ConsumeTag parses b as a varint-encoded tag, reporting its length.
  150. // This returns a negative length upon an error (see [ParseError]).
  151. func ConsumeTag(b []byte) (Number, Type, int) {
  152. v, n := ConsumeVarint(b)
  153. if n < 0 {
  154. return 0, 0, n // forward error code
  155. }
  156. num, typ := DecodeTag(v)
  157. if num < MinValidNumber {
  158. return 0, 0, errCodeFieldNumber
  159. }
  160. return num, typ, n
  161. }
  162. func SizeTag(num Number) int {
  163. return SizeVarint(EncodeTag(num, 0)) // wire type has no effect on size
  164. }
  165. // AppendVarint appends v to b as a varint-encoded uint64.
  166. func AppendVarint(b []byte, v uint64) []byte {
  167. switch {
  168. case v < 1<<7:
  169. b = append(b, byte(v))
  170. case v < 1<<14:
  171. b = append(b,
  172. byte((v>>0)&0x7f|0x80),
  173. byte(v>>7))
  174. case v < 1<<21:
  175. b = append(b,
  176. byte((v>>0)&0x7f|0x80),
  177. byte((v>>7)&0x7f|0x80),
  178. byte(v>>14))
  179. case v < 1<<28:
  180. b = append(b,
  181. byte((v>>0)&0x7f|0x80),
  182. byte((v>>7)&0x7f|0x80),
  183. byte((v>>14)&0x7f|0x80),
  184. byte(v>>21))
  185. case v < 1<<35:
  186. b = append(b,
  187. byte((v>>0)&0x7f|0x80),
  188. byte((v>>7)&0x7f|0x80),
  189. byte((v>>14)&0x7f|0x80),
  190. byte((v>>21)&0x7f|0x80),
  191. byte(v>>28))
  192. case v < 1<<42:
  193. b = append(b,
  194. byte((v>>0)&0x7f|0x80),
  195. byte((v>>7)&0x7f|0x80),
  196. byte((v>>14)&0x7f|0x80),
  197. byte((v>>21)&0x7f|0x80),
  198. byte((v>>28)&0x7f|0x80),
  199. byte(v>>35))
  200. case v < 1<<49:
  201. b = append(b,
  202. byte((v>>0)&0x7f|0x80),
  203. byte((v>>7)&0x7f|0x80),
  204. byte((v>>14)&0x7f|0x80),
  205. byte((v>>21)&0x7f|0x80),
  206. byte((v>>28)&0x7f|0x80),
  207. byte((v>>35)&0x7f|0x80),
  208. byte(v>>42))
  209. case v < 1<<56:
  210. b = append(b,
  211. byte((v>>0)&0x7f|0x80),
  212. byte((v>>7)&0x7f|0x80),
  213. byte((v>>14)&0x7f|0x80),
  214. byte((v>>21)&0x7f|0x80),
  215. byte((v>>28)&0x7f|0x80),
  216. byte((v>>35)&0x7f|0x80),
  217. byte((v>>42)&0x7f|0x80),
  218. byte(v>>49))
  219. case v < 1<<63:
  220. b = append(b,
  221. byte((v>>0)&0x7f|0x80),
  222. byte((v>>7)&0x7f|0x80),
  223. byte((v>>14)&0x7f|0x80),
  224. byte((v>>21)&0x7f|0x80),
  225. byte((v>>28)&0x7f|0x80),
  226. byte((v>>35)&0x7f|0x80),
  227. byte((v>>42)&0x7f|0x80),
  228. byte((v>>49)&0x7f|0x80),
  229. byte(v>>56))
  230. default:
  231. b = append(b,
  232. byte((v>>0)&0x7f|0x80),
  233. byte((v>>7)&0x7f|0x80),
  234. byte((v>>14)&0x7f|0x80),
  235. byte((v>>21)&0x7f|0x80),
  236. byte((v>>28)&0x7f|0x80),
  237. byte((v>>35)&0x7f|0x80),
  238. byte((v>>42)&0x7f|0x80),
  239. byte((v>>49)&0x7f|0x80),
  240. byte((v>>56)&0x7f|0x80),
  241. 1)
  242. }
  243. return b
  244. }
// ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
// This returns a negative length upon an error (see [ParseError]).
//
// The decoder is unrolled by hand, one stage per input byte. Each stage
// checks that the byte exists (else errCodeTruncated), adds the whole byte
// shifted into position, and returns if the byte's continuation bit (0x80)
// is clear. Otherwise it subtracts the continuation bit it just added, at
// that same shift, before moving on to the next stage.
func ConsumeVarint(b []byte) (v uint64, n int) {
	var y uint64
	if len(b) <= 0 {
		return 0, errCodeTruncated
	}
	// Byte 0: bits 0-6.
	v = uint64(b[0])
	if v < 0x80 {
		return v, 1
	}
	v -= 0x80 // drop the continuation bit of byte 0
	if len(b) <= 1 {
		return 0, errCodeTruncated
	}
	// Byte 1: bits 7-13.
	y = uint64(b[1])
	v += y << 7
	if y < 0x80 {
		return v, 2
	}
	v -= 0x80 << 7
	if len(b) <= 2 {
		return 0, errCodeTruncated
	}
	// Byte 2: bits 14-20.
	y = uint64(b[2])
	v += y << 14
	if y < 0x80 {
		return v, 3
	}
	v -= 0x80 << 14
	if len(b) <= 3 {
		return 0, errCodeTruncated
	}
	// Byte 3: bits 21-27.
	y = uint64(b[3])
	v += y << 21
	if y < 0x80 {
		return v, 4
	}
	v -= 0x80 << 21
	if len(b) <= 4 {
		return 0, errCodeTruncated
	}
	// Byte 4: bits 28-34.
	y = uint64(b[4])
	v += y << 28
	if y < 0x80 {
		return v, 5
	}
	v -= 0x80 << 28
	if len(b) <= 5 {
		return 0, errCodeTruncated
	}
	// Byte 5: bits 35-41.
	y = uint64(b[5])
	v += y << 35
	if y < 0x80 {
		return v, 6
	}
	v -= 0x80 << 35
	if len(b) <= 6 {
		return 0, errCodeTruncated
	}
	// Byte 6: bits 42-48.
	y = uint64(b[6])
	v += y << 42
	if y < 0x80 {
		return v, 7
	}
	v -= 0x80 << 42
	if len(b) <= 7 {
		return 0, errCodeTruncated
	}
	// Byte 7: bits 49-55.
	y = uint64(b[7])
	v += y << 49
	if y < 0x80 {
		return v, 8
	}
	v -= 0x80 << 49
	if len(b) <= 8 {
		return 0, errCodeTruncated
	}
	// Byte 8: bits 56-62.
	y = uint64(b[8])
	v += y << 56
	if y < 0x80 {
		return v, 9
	}
	v -= 0x80 << 56
	if len(b) <= 9 {
		return 0, errCodeTruncated
	}
	// Byte 9: only bit 63 is left, so the byte must be 0 or 1. Anything
	// larger (including a set continuation bit) does not fit in 64 bits.
	y = uint64(b[9])
	v += y << 63
	if y < 2 {
		return v, 10
	}
	return 0, errCodeOverflow
}
  339. // SizeVarint returns the encoded size of a varint.
  340. // The size is guaranteed to be within 1 and 10, inclusive.
  341. func SizeVarint(v uint64) int {
  342. // This computes 1 + (bits.Len64(v)-1)/7.
  343. // 9/64 is a good enough approximation of 1/7
  344. //
  345. // The Go compiler can translate the bits.LeadingZeros64 call into the LZCNT
  346. // instruction, which is very fast on CPUs from the last few years. The
  347. // specific way of expressing the calculation matches C++ Protobuf, see
  348. // https://godbolt.org/z/4P3h53oM4 for the C++ code and how gcc/clang
  349. // optimize that function for GOAMD64=v1 and GOAMD64=v3 (-march=haswell).
  350. // By OR'ing v with 1, we guarantee that v is never 0, without changing the
  351. // result of SizeVarint. LZCNT is not defined for 0, meaning the compiler
  352. // needs to add extra instructions to handle that case.
  353. //
  354. // The Go compiler currently (go1.24.4) does not make use of this knowledge.
  355. // This opportunity (removing the XOR instruction, which handles the 0 case)
  356. // results in a small (1%) performance win across CPU architectures.
  357. //
  358. // Independently of avoiding the 0 case, we need the v |= 1 line because
  359. // it allows the Go compiler to eliminate an extra XCHGL barrier.
  360. v |= 1
  361. // It would be clearer to write log2value := 63 - uint32(...), but
  362. // writing uint32(...) ^ 63 is much more efficient (-14% ARM, -20% Intel).
  363. // Proof of identity for our value range [0..63]:
  364. // https://go.dev/play/p/Pdn9hEWYakX
  365. log2value := uint32(bits.LeadingZeros64(v)) ^ 63
  366. return int((log2value*9 + (64 + 9)) / 64)
  367. }
  368. // AppendFixed32 appends v to b as a little-endian uint32.
  369. func AppendFixed32(b []byte, v uint32) []byte {
  370. return append(b,
  371. byte(v>>0),
  372. byte(v>>8),
  373. byte(v>>16),
  374. byte(v>>24))
  375. }
  376. // ConsumeFixed32 parses b as a little-endian uint32, reporting its length.
  377. // This returns a negative length upon an error (see [ParseError]).
  378. func ConsumeFixed32(b []byte) (v uint32, n int) {
  379. if len(b) < 4 {
  380. return 0, errCodeTruncated
  381. }
  382. v = uint32(b[0])<<0 | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  383. return v, 4
  384. }
  385. // SizeFixed32 returns the encoded size of a fixed32; which is always 4.
  386. func SizeFixed32() int {
  387. return 4
  388. }
  389. // AppendFixed64 appends v to b as a little-endian uint64.
  390. func AppendFixed64(b []byte, v uint64) []byte {
  391. return append(b,
  392. byte(v>>0),
  393. byte(v>>8),
  394. byte(v>>16),
  395. byte(v>>24),
  396. byte(v>>32),
  397. byte(v>>40),
  398. byte(v>>48),
  399. byte(v>>56))
  400. }
  401. // ConsumeFixed64 parses b as a little-endian uint64, reporting its length.
  402. // This returns a negative length upon an error (see [ParseError]).
  403. func ConsumeFixed64(b []byte) (v uint64, n int) {
  404. if len(b) < 8 {
  405. return 0, errCodeTruncated
  406. }
  407. v = uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
  408. return v, 8
  409. }
  410. // SizeFixed64 returns the encoded size of a fixed64; which is always 8.
  411. func SizeFixed64() int {
  412. return 8
  413. }
  414. // AppendBytes appends v to b as a length-prefixed bytes value.
  415. func AppendBytes(b []byte, v []byte) []byte {
  416. return append(AppendVarint(b, uint64(len(v))), v...)
  417. }
  418. // ConsumeBytes parses b as a length-prefixed bytes value, reporting its length.
  419. // This returns a negative length upon an error (see [ParseError]).
  420. func ConsumeBytes(b []byte) (v []byte, n int) {
  421. m, n := ConsumeVarint(b)
  422. if n < 0 {
  423. return nil, n // forward error code
  424. }
  425. if m > uint64(len(b[n:])) {
  426. return nil, errCodeTruncated
  427. }
  428. return b[n:][:m], n + int(m)
  429. }
  430. // SizeBytes returns the encoded size of a length-prefixed bytes value,
  431. // given only the length.
  432. func SizeBytes(n int) int {
  433. return SizeVarint(uint64(n)) + n
  434. }
  435. // AppendString appends v to b as a length-prefixed bytes value.
  436. func AppendString(b []byte, v string) []byte {
  437. return append(AppendVarint(b, uint64(len(v))), v...)
  438. }
  439. // ConsumeString parses b as a length-prefixed bytes value, reporting its length.
  440. // This returns a negative length upon an error (see [ParseError]).
  441. func ConsumeString(b []byte) (v string, n int) {
  442. bb, n := ConsumeBytes(b)
  443. return string(bb), n
  444. }
  445. // AppendGroup appends v to b as group value, with a trailing end group marker.
  446. // The value v must not contain the end marker.
  447. func AppendGroup(b []byte, num Number, v []byte) []byte {
  448. return AppendVarint(append(b, v...), EncodeTag(num, EndGroupType))
  449. }
  450. // ConsumeGroup parses b as a group value until the trailing end group marker,
  451. // and verifies that the end marker matches the provided num. The value v
  452. // does not contain the end marker, while the length does contain the end marker.
  453. // This returns a negative length upon an error (see [ParseError]).
  454. func ConsumeGroup(num Number, b []byte) (v []byte, n int) {
  455. n = ConsumeFieldValue(num, StartGroupType, b)
  456. if n < 0 {
  457. return nil, n // forward error code
  458. }
  459. b = b[:n]
  460. // Truncate off end group marker, but need to handle denormalized varints.
  461. // Assuming end marker is never 0 (which is always the case since
  462. // EndGroupType is non-zero), we can truncate all trailing bytes where the
  463. // lower 7 bits are all zero (implying that the varint is denormalized).
  464. for len(b) > 0 && b[len(b)-1]&0x7f == 0 {
  465. b = b[:len(b)-1]
  466. }
  467. b = b[:len(b)-SizeTag(num)]
  468. return b, n
  469. }
  470. // SizeGroup returns the encoded size of a group, given only the length.
  471. func SizeGroup(num Number, n int) int {
  472. return n + SizeTag(num)
  473. }
  474. // DecodeTag decodes the field [Number] and wire [Type] from its unified form.
  475. // The [Number] is -1 if the decoded field number overflows int32.
  476. // Other than overflow, this does not check for field number validity.
  477. func DecodeTag(x uint64) (Number, Type) {
  478. // NOTE: MessageSet allows for larger field numbers than normal.
  479. if x>>3 > uint64(math.MaxInt32) {
  480. return -1, 0
  481. }
  482. return Number(x >> 3), Type(x & 7)
  483. }
  484. // EncodeTag encodes the field [Number] and wire [Type] into its unified form.
  485. func EncodeTag(num Number, typ Type) uint64 {
  486. return uint64(num)<<3 | uint64(typ&7)
  487. }
  488. // DecodeZigZag decodes a zig-zag-encoded uint64 as an int64.
  489. //
  490. // Input: {…, 5, 3, 1, 0, 2, 4, 6, …}
  491. // Output: {…, -3, -2, -1, 0, +1, +2, +3, …}
  492. func DecodeZigZag(x uint64) int64 {
  493. return int64(x>>1) ^ int64(x)<<63>>63
  494. }
  495. // EncodeZigZag encodes an int64 as a zig-zag-encoded uint64.
  496. //
  497. // Input: {…, -3, -2, -1, 0, +1, +2, +3, …}
  498. // Output: {…, 5, 3, 1, 0, 2, 4, 6, …}
  499. func EncodeZigZag(x int64) uint64 {
  500. return uint64(x<<1) ^ uint64(x>>63)
  501. }
  502. // DecodeBool decodes a uint64 as a bool.
  503. //
  504. // Input: { 0, 1, 2, …}
  505. // Output: {false, true, true, …}
  506. func DecodeBool(x uint64) bool {
  507. return x != 0
  508. }
  509. // EncodeBool encodes a bool as a uint64.
  510. //
  511. // Input: {false, true}
  512. // Output: { 0, 1}
  513. func EncodeBool(x bool) uint64 {
  514. if x {
  515. return 1
  516. }
  517. return 0
  518. }