semanal_enum.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """Semantic analysis of call-based Enum definitions.
  2. This is conceptually part of mypy.semanal (semantic analyzer pass 2).
  3. """
  4. from __future__ import annotations
  5. from typing import Final, cast
  6. from mypy.nodes import (
  7. ARG_NAMED,
  8. ARG_POS,
  9. MDEF,
  10. AssignmentStmt,
  11. CallExpr,
  12. Context,
  13. DictExpr,
  14. EnumCallExpr,
  15. Expression,
  16. ListExpr,
  17. MemberExpr,
  18. NameExpr,
  19. RefExpr,
  20. StrExpr,
  21. SymbolTableNode,
  22. TupleExpr,
  23. TypeInfo,
  24. Var,
  25. is_StrExpr_list,
  26. )
  27. from mypy.options import Options
  28. from mypy.semanal_shared import SemanticAnalyzerInterface
  29. from mypy.types import ENUM_REMOVED_PROPS, LiteralType, get_proper_type
  30. # Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
  31. # enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
  32. ENUM_BASES: Final = frozenset(
  33. ("enum.Enum", "enum.IntEnum", "enum.Flag", "enum.IntFlag", "enum.StrEnum")
  34. )
  35. ENUM_SPECIAL_PROPS: Final = frozenset(
  36. (
  37. "name",
  38. "value",
  39. "_name_",
  40. "_value_",
  41. *ENUM_REMOVED_PROPS,
  42. # Also attributes from `object`:
  43. "__module__",
  44. "__annotations__",
  45. "__doc__",
  46. "__slots__",
  47. "__dict__",
  48. )
  49. )
  50. class EnumCallAnalyzer:
  51. def __init__(self, options: Options, api: SemanticAnalyzerInterface) -> None:
  52. self.options = options
  53. self.api = api
  54. def process_enum_call(self, s: AssignmentStmt, is_func_scope: bool) -> bool:
  55. """Check if s defines an Enum; if yes, store the definition in symbol table.
  56. Return True if this looks like an Enum definition (but maybe with errors),
  57. otherwise return False.
  58. """
  59. if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
  60. return False
  61. lvalue = s.lvalues[0]
  62. name = lvalue.name
  63. enum_call = self.check_enum_call(s.rvalue, name, is_func_scope)
  64. if enum_call is None:
  65. return False
  66. if isinstance(lvalue, MemberExpr):
  67. self.fail("Enum type as attribute is not supported", lvalue)
  68. return False
  69. # Yes, it's a valid Enum definition. Add it to the symbol table.
  70. self.api.add_symbol(name, enum_call, s)
  71. return True
  72. def check_enum_call(
  73. self, node: Expression, var_name: str, is_func_scope: bool
  74. ) -> TypeInfo | None:
  75. """Check if a call defines an Enum.
  76. Example:
  77. A = enum.Enum('A', 'foo bar')
  78. is equivalent to:
  79. class A(enum.Enum):
  80. foo = 1
  81. bar = 2
  82. """
  83. if not isinstance(node, CallExpr):
  84. return None
  85. call = node
  86. callee = call.callee
  87. if not isinstance(callee, RefExpr):
  88. return None
  89. fullname = callee.fullname
  90. if fullname not in ENUM_BASES:
  91. return None
  92. items, values, ok = self.parse_enum_call_args(call, fullname.split(".")[-1])
  93. if not ok:
  94. # Error. Construct dummy return value.
  95. info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line)
  96. else:
  97. name = cast(StrExpr, call.args[0]).value
  98. if name != var_name or is_func_scope:
  99. # Give it a unique name derived from the line number.
  100. name += "@" + str(call.line)
  101. info = self.build_enum_call_typeinfo(name, items, fullname, call.line)
  102. # Store generated TypeInfo under both names, see semanal_namedtuple for more details.
  103. if name != var_name or is_func_scope:
  104. self.api.add_symbol_skip_local(name, info)
  105. call.analyzed = EnumCallExpr(info, items, values)
  106. call.analyzed.set_line(call)
  107. info.line = node.line
  108. return info
  109. def build_enum_call_typeinfo(
  110. self, name: str, items: list[str], fullname: str, line: int
  111. ) -> TypeInfo:
  112. base = self.api.named_type_or_none(fullname)
  113. assert base is not None
  114. info = self.api.basic_new_typeinfo(name, base, line)
  115. info.metaclass_type = info.calculate_metaclass_type()
  116. info.is_enum = True
  117. for item in items:
  118. var = Var(item)
  119. var.info = info
  120. var.is_property = True
  121. var._fullname = f"{info.fullname}.{item}"
  122. info.names[item] = SymbolTableNode(MDEF, var)
  123. return info
  124. def parse_enum_call_args(
  125. self, call: CallExpr, class_name: str
  126. ) -> tuple[list[str], list[Expression | None], bool]:
  127. """Parse arguments of an Enum call.
  128. Return a tuple of fields, values, was there an error.
  129. """
  130. args = call.args
  131. if not all([arg_kind in [ARG_POS, ARG_NAMED] for arg_kind in call.arg_kinds]):
  132. return self.fail_enum_call_arg(f"Unexpected arguments to {class_name}()", call)
  133. if len(args) < 2:
  134. return self.fail_enum_call_arg(f"Too few arguments for {class_name}()", call)
  135. if len(args) > 6:
  136. return self.fail_enum_call_arg(f"Too many arguments for {class_name}()", call)
  137. valid_name = [None, "value", "names", "module", "qualname", "type", "start"]
  138. for arg_name in call.arg_names:
  139. if arg_name not in valid_name:
  140. self.fail_enum_call_arg(f'Unexpected keyword argument "{arg_name}"', call)
  141. value, names = None, None
  142. for arg_name, arg in zip(call.arg_names, args):
  143. if arg_name == "value":
  144. value = arg
  145. if arg_name == "names":
  146. names = arg
  147. if value is None:
  148. value = args[0]
  149. if names is None:
  150. names = args[1]
  151. if not isinstance(value, StrExpr):
  152. return self.fail_enum_call_arg(
  153. f"{class_name}() expects a string literal as the first argument", call
  154. )
  155. items = []
  156. values: list[Expression | None] = []
  157. if isinstance(names, StrExpr):
  158. fields = names.value
  159. for field in fields.replace(",", " ").split():
  160. items.append(field)
  161. elif isinstance(names, (TupleExpr, ListExpr)):
  162. seq_items = names.items
  163. if is_StrExpr_list(seq_items):
  164. items = [seq_item.value for seq_item in seq_items]
  165. elif all(
  166. isinstance(seq_item, (TupleExpr, ListExpr))
  167. and len(seq_item.items) == 2
  168. and isinstance(seq_item.items[0], StrExpr)
  169. for seq_item in seq_items
  170. ):
  171. for seq_item in seq_items:
  172. assert isinstance(seq_item, (TupleExpr, ListExpr))
  173. name, value = seq_item.items
  174. assert isinstance(name, StrExpr)
  175. items.append(name.value)
  176. values.append(value)
  177. else:
  178. return self.fail_enum_call_arg(
  179. "%s() with tuple or list expects strings or (name, value) pairs" % class_name,
  180. call,
  181. )
  182. elif isinstance(names, DictExpr):
  183. for key, value in names.items:
  184. if not isinstance(key, StrExpr):
  185. return self.fail_enum_call_arg(
  186. f"{class_name}() with dict literal requires string literals", call
  187. )
  188. items.append(key.value)
  189. values.append(value)
  190. elif isinstance(args[1], RefExpr) and isinstance(args[1].node, Var):
  191. proper_type = get_proper_type(args[1].node.type)
  192. if (
  193. proper_type is not None
  194. and isinstance(proper_type, LiteralType)
  195. and isinstance(proper_type.value, str)
  196. ):
  197. fields = proper_type.value
  198. for field in fields.replace(",", " ").split():
  199. items.append(field)
  200. elif args[1].node.is_final and isinstance(args[1].node.final_value, str):
  201. fields = args[1].node.final_value
  202. for field in fields.replace(",", " ").split():
  203. items.append(field)
  204. else:
  205. return self.fail_enum_call_arg(
  206. "Second argument of %s() must be string, tuple, list or dict literal for mypy to determine Enum members"
  207. % class_name,
  208. call,
  209. )
  210. else:
  211. # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
  212. return self.fail_enum_call_arg(
  213. "Second argument of %s() must be string, tuple, list or dict literal for mypy to determine Enum members"
  214. % class_name,
  215. call,
  216. )
  217. if not items:
  218. return self.fail_enum_call_arg(f"{class_name}() needs at least one item", call)
  219. if not values:
  220. values = [None] * len(items)
  221. assert len(items) == len(values)
  222. return items, values, True
  223. def fail_enum_call_arg(
  224. self, message: str, context: Context
  225. ) -> tuple[list[str], list[Expression | None], bool]:
  226. self.fail(message, context)
  227. return [], [], False
  228. # Helpers
  229. def fail(self, msg: str, ctx: Context) -> None:
  230. self.api.fail(msg, ctx)