semanal_enum.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """Semantic analysis of call-based Enum definitions.
  2. This is conceptually part of mypy.semanal (semantic analyzer pass 2).
  3. """
  4. from __future__ import annotations
  5. from typing import cast
  6. from typing_extensions import Final
  7. from mypy.nodes import (
  8. ARG_NAMED,
  9. ARG_POS,
  10. MDEF,
  11. AssignmentStmt,
  12. CallExpr,
  13. Context,
  14. DictExpr,
  15. EnumCallExpr,
  16. Expression,
  17. ListExpr,
  18. MemberExpr,
  19. NameExpr,
  20. RefExpr,
  21. StrExpr,
  22. SymbolTableNode,
  23. TupleExpr,
  24. TypeInfo,
  25. Var,
  26. is_StrExpr_list,
  27. )
  28. from mypy.options import Options
  29. from mypy.semanal_shared import SemanticAnalyzerInterface
  30. from mypy.types import ENUM_REMOVED_PROPS, LiteralType, get_proper_type
  31. # Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
  32. # enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
  33. ENUM_BASES: Final = frozenset(
  34. ("enum.Enum", "enum.IntEnum", "enum.Flag", "enum.IntFlag", "enum.StrEnum")
  35. )
  36. ENUM_SPECIAL_PROPS: Final = frozenset(
  37. (
  38. "name",
  39. "value",
  40. "_name_",
  41. "_value_",
  42. *ENUM_REMOVED_PROPS,
  43. # Also attributes from `object`:
  44. "__module__",
  45. "__annotations__",
  46. "__doc__",
  47. "__slots__",
  48. "__dict__",
  49. )
  50. )
  51. class EnumCallAnalyzer:
  52. def __init__(self, options: Options, api: SemanticAnalyzerInterface) -> None:
  53. self.options = options
  54. self.api = api
  55. def process_enum_call(self, s: AssignmentStmt, is_func_scope: bool) -> bool:
  56. """Check if s defines an Enum; if yes, store the definition in symbol table.
  57. Return True if this looks like an Enum definition (but maybe with errors),
  58. otherwise return False.
  59. """
  60. if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
  61. return False
  62. lvalue = s.lvalues[0]
  63. name = lvalue.name
  64. enum_call = self.check_enum_call(s.rvalue, name, is_func_scope)
  65. if enum_call is None:
  66. return False
  67. if isinstance(lvalue, MemberExpr):
  68. self.fail("Enum type as attribute is not supported", lvalue)
  69. return False
  70. # Yes, it's a valid Enum definition. Add it to the symbol table.
  71. self.api.add_symbol(name, enum_call, s)
  72. return True
  73. def check_enum_call(
  74. self, node: Expression, var_name: str, is_func_scope: bool
  75. ) -> TypeInfo | None:
  76. """Check if a call defines an Enum.
  77. Example:
  78. A = enum.Enum('A', 'foo bar')
  79. is equivalent to:
  80. class A(enum.Enum):
  81. foo = 1
  82. bar = 2
  83. """
  84. if not isinstance(node, CallExpr):
  85. return None
  86. call = node
  87. callee = call.callee
  88. if not isinstance(callee, RefExpr):
  89. return None
  90. fullname = callee.fullname
  91. if fullname not in ENUM_BASES:
  92. return None
  93. items, values, ok = self.parse_enum_call_args(call, fullname.split(".")[-1])
  94. if not ok:
  95. # Error. Construct dummy return value.
  96. info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line)
  97. else:
  98. name = cast(StrExpr, call.args[0]).value
  99. if name != var_name or is_func_scope:
  100. # Give it a unique name derived from the line number.
  101. name += "@" + str(call.line)
  102. info = self.build_enum_call_typeinfo(name, items, fullname, call.line)
  103. # Store generated TypeInfo under both names, see semanal_namedtuple for more details.
  104. if name != var_name or is_func_scope:
  105. self.api.add_symbol_skip_local(name, info)
  106. call.analyzed = EnumCallExpr(info, items, values)
  107. call.analyzed.set_line(call)
  108. info.line = node.line
  109. return info
  110. def build_enum_call_typeinfo(
  111. self, name: str, items: list[str], fullname: str, line: int
  112. ) -> TypeInfo:
  113. base = self.api.named_type_or_none(fullname)
  114. assert base is not None
  115. info = self.api.basic_new_typeinfo(name, base, line)
  116. info.metaclass_type = info.calculate_metaclass_type()
  117. info.is_enum = True
  118. for item in items:
  119. var = Var(item)
  120. var.info = info
  121. var.is_property = True
  122. var._fullname = f"{info.fullname}.{item}"
  123. info.names[item] = SymbolTableNode(MDEF, var)
  124. return info
  125. def parse_enum_call_args(
  126. self, call: CallExpr, class_name: str
  127. ) -> tuple[list[str], list[Expression | None], bool]:
  128. """Parse arguments of an Enum call.
  129. Return a tuple of fields, values, was there an error.
  130. """
  131. args = call.args
  132. if not all([arg_kind in [ARG_POS, ARG_NAMED] for arg_kind in call.arg_kinds]):
  133. return self.fail_enum_call_arg(f"Unexpected arguments to {class_name}()", call)
  134. if len(args) < 2:
  135. return self.fail_enum_call_arg(f"Too few arguments for {class_name}()", call)
  136. if len(args) > 6:
  137. return self.fail_enum_call_arg(f"Too many arguments for {class_name}()", call)
  138. valid_name = [None, "value", "names", "module", "qualname", "type", "start"]
  139. for arg_name in call.arg_names:
  140. if arg_name not in valid_name:
  141. self.fail_enum_call_arg(f'Unexpected keyword argument "{arg_name}"', call)
  142. value, names = None, None
  143. for arg_name, arg in zip(call.arg_names, args):
  144. if arg_name == "value":
  145. value = arg
  146. if arg_name == "names":
  147. names = arg
  148. if value is None:
  149. value = args[0]
  150. if names is None:
  151. names = args[1]
  152. if not isinstance(value, StrExpr):
  153. return self.fail_enum_call_arg(
  154. f"{class_name}() expects a string literal as the first argument", call
  155. )
  156. items = []
  157. values: list[Expression | None] = []
  158. if isinstance(names, StrExpr):
  159. fields = names.value
  160. for field in fields.replace(",", " ").split():
  161. items.append(field)
  162. elif isinstance(names, (TupleExpr, ListExpr)):
  163. seq_items = names.items
  164. if is_StrExpr_list(seq_items):
  165. items = [seq_item.value for seq_item in seq_items]
  166. elif all(
  167. isinstance(seq_item, (TupleExpr, ListExpr))
  168. and len(seq_item.items) == 2
  169. and isinstance(seq_item.items[0], StrExpr)
  170. for seq_item in seq_items
  171. ):
  172. for seq_item in seq_items:
  173. assert isinstance(seq_item, (TupleExpr, ListExpr))
  174. name, value = seq_item.items
  175. assert isinstance(name, StrExpr)
  176. items.append(name.value)
  177. values.append(value)
  178. else:
  179. return self.fail_enum_call_arg(
  180. "%s() with tuple or list expects strings or (name, value) pairs" % class_name,
  181. call,
  182. )
  183. elif isinstance(names, DictExpr):
  184. for key, value in names.items:
  185. if not isinstance(key, StrExpr):
  186. return self.fail_enum_call_arg(
  187. f"{class_name}() with dict literal requires string literals", call
  188. )
  189. items.append(key.value)
  190. values.append(value)
  191. elif isinstance(args[1], RefExpr) and isinstance(args[1].node, Var):
  192. proper_type = get_proper_type(args[1].node.type)
  193. if (
  194. proper_type is not None
  195. and isinstance(proper_type, LiteralType)
  196. and isinstance(proper_type.value, str)
  197. ):
  198. fields = proper_type.value
  199. for field in fields.replace(",", " ").split():
  200. items.append(field)
  201. elif args[1].node.is_final and isinstance(args[1].node.final_value, str):
  202. fields = args[1].node.final_value
  203. for field in fields.replace(",", " ").split():
  204. items.append(field)
  205. else:
  206. return self.fail_enum_call_arg(
  207. "Second argument of %s() must be string, tuple, list or dict literal for mypy to determine Enum members"
  208. % class_name,
  209. call,
  210. )
  211. else:
  212. # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
  213. return self.fail_enum_call_arg(
  214. "Second argument of %s() must be string, tuple, list or dict literal for mypy to determine Enum members"
  215. % class_name,
  216. call,
  217. )
  218. if not items:
  219. return self.fail_enum_call_arg(f"{class_name}() needs at least one item", call)
  220. if not values:
  221. values = [None] * len(items)
  222. assert len(items) == len(values)
  223. return items, values, True
  224. def fail_enum_call_arg(
  225. self, message: str, context: Context
  226. ) -> tuple[list[str], list[Expression | None], bool]:
  227. self.fail(message, context)
  228. return [], [], False
  229. # Helpers
  230. def fail(self, msg: str, ctx: Context) -> None:
  231. self.api.fail(msg, ctx)