| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302 |
- from __future__ import annotations
- from typing import Final, FrozenSet, Tuple, Union
- from typing_extensions import TypeGuard
# Every kind of Python value that may be stored in the literal table. Items of
# tuples and frozensets must themselves be supported literal values, but that
# recursive constraint cannot be expressed precisely here, so the element type
# is spelled ``object``.
LiteralValue = Union[
    str,
    bytes,
    int,
    bool,
    float,
    complex,
    Tuple[object, ...],
    FrozenSet[object],
    None,
]
def _is_literal_value(obj: object) -> TypeGuard[LiteralValue]:
    """Return True if ``obj`` is of a type that can be stored as a literal.

    bool values are accepted through their int base class. The items of tuples
    and frozensets are not inspected; callers check them one by one.
    """
    supported_types = (str, bytes, int, float, complex, tuple, frozenset, type(None))
    return isinstance(obj, supported_types)
# Some literals are singletons and handled specially (None, False and True):
# Literals.record_literal never stores them, and Literals.literal_index maps
# them to the fixed indices 0, 1 and 2. The per-type literal tables therefore
# start at this offset in the flat literals array.
NUM_SINGLETONS: Final = 3
class Literals:
    """Collection of literal values used in a compilation group and related helpers.

    Every recorded literal is assigned an index into one flat array. The array
    starts with the NUM_SINGLETONS singletons (None at 0, False at 1, True at 2).
    These are followed by all str values, then bytes, int, float, complex and
    tuple values, and finally frozenset values. Within a category, values are
    numbered in the order in which they were first recorded.

    NOTE(review): each category is a plain dict keyed by the value itself. Values
    that compare and hash equal therefore share one slot, even though they are
    distinguishable at runtime -- e.g. ``0.0`` and ``-0.0`` in ``float_literals``,
    or ``(1,)``, ``(1.0,)`` and ``(True,)`` in ``tuple_literals``. Whichever is
    recorded first wins. Confirm that no caller needs to tell such values apart.
    """

    def __init__(self) -> None:
        # Each dict maps a value to its index (0, 1, ...) within its category;
        # literal_index() adds the offset of the category in the flat array.
        self.str_literals: dict[str, int] = {}
        self.bytes_literals: dict[bytes, int] = {}
        self.int_literals: dict[int, int] = {}
        self.float_literals: dict[float, int] = {}
        self.complex_literals: dict[complex, int] = {}
        self.tuple_literals: dict[tuple[object, ...], int] = {}
        self.frozenset_literals: dict[frozenset[object], int] = {}

    def record_literal(self, value: LiteralValue) -> None:
        """Ensure that the literal value is available in generated code.

        Recording a value that is already present is a no-op, so indices are
        stable. The items of a tuple or frozenset are recorded (recursively)
        before the collection itself. A nested tuple is therefore numbered
        before the tuple that contains it, and likewise for frozensets.

        Raises:
            AssertionError: if ``value``, or any item nested in it, is not of a
                supported literal type.
        """
        if value is None or value is True or value is False:
            # These are special cased and always present
            return
        if isinstance(value, str):
            str_literals = self.str_literals
            if value not in str_literals:
                str_literals[value] = len(str_literals)
        elif isinstance(value, bytes):
            bytes_literals = self.bytes_literals
            if value not in bytes_literals:
                bytes_literals[value] = len(bytes_literals)
        elif isinstance(value, int):
            int_literals = self.int_literals
            if value not in int_literals:
                int_literals[value] = len(int_literals)
        elif isinstance(value, float):
            float_literals = self.float_literals
            if value not in float_literals:
                float_literals[value] = len(float_literals)
        elif isinstance(value, complex):
            complex_literals = self.complex_literals
            if value not in complex_literals:
                complex_literals[value] = len(complex_literals)
        elif isinstance(value, tuple):
            tuple_literals = self.tuple_literals
            if value not in tuple_literals:
                for item in value:
                    # Narrows the item type for the type checker. Unsupported
                    # items are also rejected by the recursive call below.
                    assert _is_literal_value(item)
                    self.record_literal(item)
                # Taken after the loop, so nested tuples recorded above have
                # already claimed their (smaller) indices.
                tuple_literals[value] = len(tuple_literals)
        elif isinstance(value, frozenset):
            frozenset_literals = self.frozenset_literals
            if value not in frozenset_literals:
                # NOTE(review): frozenset iteration order depends on element
                # hashes (randomized for str/bytes unless PYTHONHASHSEED is
                # fixed). The numbering of items first seen here may therefore
                # differ between runs -- confirm reproducible output is not needed.
                for item in value:
                    assert _is_literal_value(item)
                    self.record_literal(item)
                frozenset_literals[value] = len(frozenset_literals)
        else:
            # Raised explicitly rather than with ``assert`` so that the check
            # is not stripped under ``python -O``.
            raise AssertionError("invalid literal: %r" % (value,))

    def literal_index(self, value: LiteralValue) -> int:
        """Return the index to the literals array for given value.

        The array contains first None and the booleans, followed by all str
        values, then bytes values, etc.; see the class docstring for the full
        order.

        Raises:
            KeyError: if ``value`` has a supported type but was never recorded.
            AssertionError: if ``value`` is not of a supported literal type.
        """
        if value is None:
            return 0
        elif value is False:
            return 1
        elif value is True:
            return 2
        # Running offset: start of the current category in the flat array.
        n = NUM_SINGLETONS
        if isinstance(value, str):
            return n + self.str_literals[value]
        n += len(self.str_literals)
        if isinstance(value, bytes):
            return n + self.bytes_literals[value]
        n += len(self.bytes_literals)
        if isinstance(value, int):
            return n + self.int_literals[value]
        n += len(self.int_literals)
        if isinstance(value, float):
            return n + self.float_literals[value]
        n += len(self.float_literals)
        if isinstance(value, complex):
            return n + self.complex_literals[value]
        n += len(self.complex_literals)
        if isinstance(value, tuple):
            return n + self.tuple_literals[value]
        n += len(self.tuple_literals)
        if isinstance(value, frozenset):
            return n + self.frozenset_literals[value]
        # Raised explicitly so that an unsupported value cannot fall through
        # and return None under ``python -O``.
        raise AssertionError("invalid literal: %r" % (value,))

    def num_literals(self) -> int:
        """Return the total length of the literals array, singletons included."""
        # The first three are for None, True and False
        return (
            NUM_SINGLETONS
            + len(self.str_literals)
            + len(self.bytes_literals)
            + len(self.int_literals)
            + len(self.float_literals)
            + len(self.complex_literals)
            + len(self.tuple_literals)
            + len(self.frozenset_literals)
        )

    # The following methods return the C encodings of literal values
    # of different types

    def encoded_str_values(self) -> list[bytes]:
        """Return the encoded str literals (see _encode_str_values)."""
        return _encode_str_values(self.str_literals)

    def encoded_int_values(self) -> list[bytes]:
        """Return the encoded int literals (see _encode_int_values)."""
        return _encode_int_values(self.int_literals)

    def encoded_bytes_values(self) -> list[bytes]:
        """Return the encoded bytes literals (see _encode_bytes_values)."""
        return _encode_bytes_values(self.bytes_literals)

    def encoded_float_values(self) -> list[str]:
        """Return the encoded float literals (see _encode_float_values)."""
        return _encode_float_values(self.float_literals)

    def encoded_complex_values(self) -> list[str]:
        """Return the encoded complex literals (see _encode_complex_values)."""
        return _encode_complex_values(self.complex_literals)

    def encoded_tuple_values(self) -> list[str]:
        """Return the encoded tuple literals (see _encode_collection_values)."""
        return self._encode_collection_values(self.tuple_literals)

    def encoded_frozenset_values(self) -> list[str]:
        """Return the encoded frozenset literals (see _encode_collection_values)."""
        return self._encode_collection_values(self.frozenset_literals)

    def _encode_collection_values(
        self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int]
    ) -> list[str]:
        """Encode tuple/frozenset values into a C array.

        The format of the result is like this:

            <number of collections>
            <length of the first collection>
            <literal index of first item>
            ...
            <literal index of last item>
            <length of the second collection>
            ...

        Items are referenced by their index in the flat literals array, so every
        item must already have been recorded; literal_index raises KeyError
        otherwise.
        """
        value_by_index = {index: value for value, index in values.items()}
        result = []
        count = len(values)
        result.append(str(count))
        for i in range(count):
            value = value_by_index[i]
            result.append(str(len(value)))
            for item in value:
                assert _is_literal_value(item)
                index = self.literal_index(item)
                result.append(str(index))
        return result
def _encode_str_values(values: dict[str, int]) -> list[bytes]:
    """Encode str literals into chunks of length-prefixed UTF-8 data.

    Values are emitted in index order, each encoded with format_str_literal.
    They are grouped into chunks. The next value starts a new chunk if it would
    push the combined size of the encoded values in the current chunk past 70
    bytes, and every chunk holds at least one value. Each chunk is prefixed with
    the number of values it holds (see format_int). An empty bytes object
    terminates the list.
    """
    by_index = {idx: text for text, idx in values.items()}
    groups: list[list[bytes]] = []
    group_size = 0
    for position in range(len(values)):
        item = format_str_literal(by_index[position])
        # Every item is at least one byte (its length prefix), so "a group
        # exists" is the same as "the current group is non-empty".
        if groups and group_size + len(item) <= 70:
            groups[-1].append(item)
            group_size += len(item)
        else:
            groups.append([item])
            group_size = len(item)
    encoded = [format_int(len(group)) + b"".join(group) for group in groups]
    encoded.append(b"")
    return encoded
def _encode_bytes_values(values: dict[bytes, int]) -> list[bytes]:
    """Encode bytes literals into chunks of length-prefixed data.

    Each value is written as its length (see format_int) followed by its raw
    bytes, in index order. Values are grouped into chunks. The next value
    starts a new chunk if it would push the combined size of the encoded values
    in the current chunk past 70 bytes, and every chunk holds at least one
    value. Each chunk is prefixed with the number of values it holds. An empty
    bytes object terminates the list.
    """
    by_index = {idx: blob for blob, idx in values.items()}
    groups: list[list[bytes]] = []
    group_size = 0
    for position in range(len(values)):
        blob = by_index[position]
        item = format_int(len(blob)) + blob
        # Every item is at least one byte (its length prefix), so "a group
        # exists" is the same as "the current group is non-empty".
        if groups and group_size + len(item) <= 70:
            groups[-1].append(item)
            group_size += len(item)
        else:
            groups.append([item])
            group_size = len(item)
    encoded = [format_int(len(group)) + b"".join(group) for group in groups]
    encoded.append(b"")
    return encoded
def format_int(n: int) -> bytes:
    """Format an integer using a variable-length binary encoding.

    The value is split into 7-bit groups, most significant group first. Every
    byte except the last has its high bit (0x80) set to signal that more groups
    follow, so values below 128 encode as a single byte.
    """
    if n < 128:
        return bytes([n])
    digits = []
    while n:
        digits.append(n & 0x7F)
        n >>= 7
    # Collected least significant first; the encoding wants the reverse.
    digits.reverse()
    last = len(digits) - 1
    return bytes(d | 0x80 if i < last else d for i, d in enumerate(digits))
def format_str_literal(s: str) -> bytes:
    """Encode a str literal as its UTF-8 byte length followed by the UTF-8 bytes.

    The length prefix uses format_int. Python str values may contain lone
    surrogates (for example from a ``\\udc80`` escape in source code), which
    strict UTF-8 encoding rejects with UnicodeEncodeError. Such characters are
    encoded with the ``surrogatepass`` error handler instead. Strings without
    lone surrogates encode exactly as with strict UTF-8.

    NOTE(review): the decoder consuming this output must use the matching
    ``surrogatepass`` handler to round-trip such strings -- confirm in the C
    runtime.
    """
    utf8 = s.encode("utf-8", errors="surrogatepass")
    return format_int(len(utf8)) + utf8
def _encode_int_values(values: dict[int, int]) -> list[bytes]:
    """Encode int values into C strings.

    Values are written in base 10, in index order, and grouped into chunks.
    Within a chunk, values are separated by 0 bytes, and the chunk is prefixed
    with its value count (see format_int). The next value starts a new chunk if
    the digits in the current chunk would exceed 70 bytes (separators and the
    prefix are not counted); every chunk holds at least one value. An empty
    bytes object terminates the list.
    """
    by_index = {idx: number for number, idx in values.items()}
    groups: list[list[bytes]] = []
    group_size = 0
    for position in range(len(values)):
        digits = b"%d" % by_index[position]
        # Every rendering has at least one character, so "a group exists" is
        # the same as "the current group is non-empty".
        if groups and group_size + len(digits) <= 70:
            groups[-1].append(digits)
            group_size += len(digits)
        else:
            groups.append([digits])
            group_size = len(digits)
    encoded = [format_int(len(group)) + b"\0".join(group) for group in groups]
    encoded.append(b"")
    return encoded
def float_to_c(x: float) -> str:
    """Return C literal representation of a float value.

    Finite values use Python's str() form unchanged. Positive and negative
    infinity map to the C macros INFINITY / -INFINITY, and NaN maps to NAN.
    """
    text = str(x)
    special_spellings = {"inf": "INFINITY", "-inf": "-INFINITY", "nan": "NAN"}
    return special_spellings.get(text, text)
def _encode_float_values(values: dict[float, int]) -> list[str]:
    """Encode float values into entries of a C array.

    The first entry is the number of values. It is followed by one C literal
    (see float_to_c) per value, in index order.
    """
    by_index = {idx: number for number, idx in values.items()}
    count = len(values)
    literals = [float_to_c(by_index[position]) for position in range(count)]
    return [str(count)] + literals
def _encode_complex_values(values: dict[complex, int]) -> list[str]:
    """Encode complex values into entries of a C array.

    The first entry is the number of values. It is followed, for each value in
    index order, by a pair of doubles: the real part, then the imaginary part
    (each formatted with float_to_c).
    """
    by_index = {idx: number for number, idx in values.items()}
    count = len(values)
    encoded = [str(count)]
    for position in range(count):
        number = by_index[position]
        encoded.extend((float_to_c(number.real), float_to_c(number.imag)))
    return encoded
|