literals.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. from __future__ import annotations
  2. from typing import Final, FrozenSet, Tuple, Union
  3. from typing_extensions import TypeGuard
# Supported Python literal types. Every item of a tuple / frozenset literal must
# itself have a supported literal type, but that recursive constraint can't be
# represented precisely, so the item type is declared as plain `object`.
LiteralValue = Union[
    str, bytes, int, bool, float, complex, Tuple[object, ...], FrozenSet[object], None
]
  9. def _is_literal_value(obj: object) -> TypeGuard[LiteralValue]:
  10. return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, type(None)))
# Some literals are singletons and handled specially (None, False and True).
# They are never stored in the per-type dicts; literal_index() maps them to the
# first NUM_SINGLETONS slots (0, 1 and 2) of the literals array.
NUM_SINGLETONS: Final = 3
  13. class Literals:
  14. """Collection of literal values used in a compilation group and related helpers."""
  15. def __init__(self) -> None:
  16. # Each dict maps value to literal index (0, 1, ...)
  17. self.str_literals: dict[str, int] = {}
  18. self.bytes_literals: dict[bytes, int] = {}
  19. self.int_literals: dict[int, int] = {}
  20. self.float_literals: dict[float, int] = {}
  21. self.complex_literals: dict[complex, int] = {}
  22. self.tuple_literals: dict[tuple[object, ...], int] = {}
  23. self.frozenset_literals: dict[frozenset[object], int] = {}
  24. def record_literal(self, value: LiteralValue) -> None:
  25. """Ensure that the literal value is available in generated code."""
  26. if value is None or value is True or value is False:
  27. # These are special cased and always present
  28. return
  29. if isinstance(value, str):
  30. str_literals = self.str_literals
  31. if value not in str_literals:
  32. str_literals[value] = len(str_literals)
  33. elif isinstance(value, bytes):
  34. bytes_literals = self.bytes_literals
  35. if value not in bytes_literals:
  36. bytes_literals[value] = len(bytes_literals)
  37. elif isinstance(value, int):
  38. int_literals = self.int_literals
  39. if value not in int_literals:
  40. int_literals[value] = len(int_literals)
  41. elif isinstance(value, float):
  42. float_literals = self.float_literals
  43. if value not in float_literals:
  44. float_literals[value] = len(float_literals)
  45. elif isinstance(value, complex):
  46. complex_literals = self.complex_literals
  47. if value not in complex_literals:
  48. complex_literals[value] = len(complex_literals)
  49. elif isinstance(value, tuple):
  50. tuple_literals = self.tuple_literals
  51. if value not in tuple_literals:
  52. for item in value:
  53. assert _is_literal_value(item)
  54. self.record_literal(item)
  55. tuple_literals[value] = len(tuple_literals)
  56. elif isinstance(value, frozenset):
  57. frozenset_literals = self.frozenset_literals
  58. if value not in frozenset_literals:
  59. for item in value:
  60. assert _is_literal_value(item)
  61. self.record_literal(item)
  62. frozenset_literals[value] = len(frozenset_literals)
  63. else:
  64. assert False, "invalid literal: %r" % value
  65. def literal_index(self, value: LiteralValue) -> int:
  66. """Return the index to the literals array for given value."""
  67. # The array contains first None and booleans, followed by all str values,
  68. # followed by bytes values, etc.
  69. if value is None:
  70. return 0
  71. elif value is False:
  72. return 1
  73. elif value is True:
  74. return 2
  75. n = NUM_SINGLETONS
  76. if isinstance(value, str):
  77. return n + self.str_literals[value]
  78. n += len(self.str_literals)
  79. if isinstance(value, bytes):
  80. return n + self.bytes_literals[value]
  81. n += len(self.bytes_literals)
  82. if isinstance(value, int):
  83. return n + self.int_literals[value]
  84. n += len(self.int_literals)
  85. if isinstance(value, float):
  86. return n + self.float_literals[value]
  87. n += len(self.float_literals)
  88. if isinstance(value, complex):
  89. return n + self.complex_literals[value]
  90. n += len(self.complex_literals)
  91. if isinstance(value, tuple):
  92. return n + self.tuple_literals[value]
  93. n += len(self.tuple_literals)
  94. if isinstance(value, frozenset):
  95. return n + self.frozenset_literals[value]
  96. assert False, "invalid literal: %r" % value
  97. def num_literals(self) -> int:
  98. # The first three are for None, True and False
  99. return (
  100. NUM_SINGLETONS
  101. + len(self.str_literals)
  102. + len(self.bytes_literals)
  103. + len(self.int_literals)
  104. + len(self.float_literals)
  105. + len(self.complex_literals)
  106. + len(self.tuple_literals)
  107. + len(self.frozenset_literals)
  108. )
  109. # The following methods return the C encodings of literal values
  110. # of different types
  111. def encoded_str_values(self) -> list[bytes]:
  112. return _encode_str_values(self.str_literals)
  113. def encoded_int_values(self) -> list[bytes]:
  114. return _encode_int_values(self.int_literals)
  115. def encoded_bytes_values(self) -> list[bytes]:
  116. return _encode_bytes_values(self.bytes_literals)
  117. def encoded_float_values(self) -> list[str]:
  118. return _encode_float_values(self.float_literals)
  119. def encoded_complex_values(self) -> list[str]:
  120. return _encode_complex_values(self.complex_literals)
  121. def encoded_tuple_values(self) -> list[str]:
  122. return self._encode_collection_values(self.tuple_literals)
  123. def encoded_frozenset_values(self) -> list[str]:
  124. return self._encode_collection_values(self.frozenset_literals)
  125. def _encode_collection_values(
  126. self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int]
  127. ) -> list[str]:
  128. """Encode tuple/frozenset values into a C array.
  129. The format of the result is like this:
  130. <number of collections>
  131. <length of the first collection>
  132. <literal index of first item>
  133. ...
  134. <literal index of last item>
  135. <length of the second collection>
  136. ...
  137. """
  138. value_by_index = {index: value for value, index in values.items()}
  139. result = []
  140. count = len(values)
  141. result.append(str(count))
  142. for i in range(count):
  143. value = value_by_index[i]
  144. result.append(str(len(value)))
  145. for item in value:
  146. assert _is_literal_value(item)
  147. index = self.literal_index(item)
  148. result.append(str(index))
  149. return result
  150. def _encode_str_values(values: dict[str, int]) -> list[bytes]:
  151. value_by_index = {index: value for value, index in values.items()}
  152. result = []
  153. line: list[bytes] = []
  154. line_len = 0
  155. for i in range(len(values)):
  156. value = value_by_index[i]
  157. c_literal = format_str_literal(value)
  158. c_len = len(c_literal)
  159. if line_len > 0 and line_len + c_len > 70:
  160. result.append(format_int(len(line)) + b"".join(line))
  161. line = []
  162. line_len = 0
  163. line.append(c_literal)
  164. line_len += c_len
  165. if line:
  166. result.append(format_int(len(line)) + b"".join(line))
  167. result.append(b"")
  168. return result
  169. def _encode_bytes_values(values: dict[bytes, int]) -> list[bytes]:
  170. value_by_index = {index: value for value, index in values.items()}
  171. result = []
  172. line: list[bytes] = []
  173. line_len = 0
  174. for i in range(len(values)):
  175. value = value_by_index[i]
  176. c_init = format_int(len(value))
  177. c_len = len(c_init) + len(value)
  178. if line_len > 0 and line_len + c_len > 70:
  179. result.append(format_int(len(line)) + b"".join(line))
  180. line = []
  181. line_len = 0
  182. line.append(c_init + value)
  183. line_len += c_len
  184. if line:
  185. result.append(format_int(len(line)) + b"".join(line))
  186. result.append(b"")
  187. return result
  188. def format_int(n: int) -> bytes:
  189. """Format an integer using a variable-length binary encoding."""
  190. if n < 128:
  191. a = [n]
  192. else:
  193. a = []
  194. while n > 0:
  195. a.insert(0, n & 0x7F)
  196. n >>= 7
  197. for i in range(len(a) - 1):
  198. # If the highest bit is set, more 7-bit digits follow
  199. a[i] |= 0x80
  200. return bytes(a)
  201. def format_str_literal(s: str) -> bytes:
  202. utf8 = s.encode("utf-8")
  203. return format_int(len(utf8)) + utf8
  204. def _encode_int_values(values: dict[int, int]) -> list[bytes]:
  205. """Encode int values into C strings.
  206. Values are stored in base 10 and separated by 0 bytes.
  207. """
  208. value_by_index = {index: value for value, index in values.items()}
  209. result = []
  210. line: list[bytes] = []
  211. line_len = 0
  212. for i in range(len(values)):
  213. value = value_by_index[i]
  214. encoded = b"%d" % value
  215. if line_len > 0 and line_len + len(encoded) > 70:
  216. result.append(format_int(len(line)) + b"\0".join(line))
  217. line = []
  218. line_len = 0
  219. line.append(encoded)
  220. line_len += len(encoded)
  221. if line:
  222. result.append(format_int(len(line)) + b"\0".join(line))
  223. result.append(b"")
  224. return result
  225. def float_to_c(x: float) -> str:
  226. """Return C literal representation of a float value."""
  227. s = str(x)
  228. if s == "inf":
  229. return "INFINITY"
  230. elif s == "-inf":
  231. return "-INFINITY"
  232. elif s == "nan":
  233. return "NAN"
  234. return s
  235. def _encode_float_values(values: dict[float, int]) -> list[str]:
  236. """Encode float values into a C array values.
  237. The result contains the number of values followed by individual values.
  238. """
  239. value_by_index = {index: value for value, index in values.items()}
  240. result = []
  241. num = len(values)
  242. result.append(str(num))
  243. for i in range(num):
  244. value = value_by_index[i]
  245. result.append(float_to_c(value))
  246. return result
  247. def _encode_complex_values(values: dict[complex, int]) -> list[str]:
  248. """Encode float values into a C array values.
  249. The result contains the number of values followed by pairs of doubles
  250. representing complex numbers.
  251. """
  252. value_by_index = {index: value for value, index in values.items()}
  253. result = []
  254. num = len(values)
  255. result.append(str(num))
  256. for i in range(num):
  257. value = value_by_index[i]
  258. result.append(float_to_c(value.real))
  259. result.append(float_to_c(value.imag))
  260. return result