| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- """Semantic analysis of call-based Enum definitions.
- This is conceptually part of mypy.semanal (semantic analyzer pass 2).
- """
- from __future__ import annotations
- from typing import cast
- from typing_extensions import Final
- from mypy.nodes import (
- ARG_NAMED,
- ARG_POS,
- MDEF,
- AssignmentStmt,
- CallExpr,
- Context,
- DictExpr,
- EnumCallExpr,
- Expression,
- ListExpr,
- MemberExpr,
- NameExpr,
- RefExpr,
- StrExpr,
- SymbolTableNode,
- TupleExpr,
- TypeInfo,
- Var,
- is_StrExpr_list,
- )
- from mypy.options import Options
- from mypy.semanal_shared import SemanticAnalyzerInterface
- from mypy.types import ENUM_REMOVED_PROPS, LiteralType, get_proper_type
# Fully qualified names of the enum base classes recognized by this module.
# Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
# enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
ENUM_BASES: Final = frozenset(
    {
        "enum.Enum",
        "enum.IntEnum",
        "enum.Flag",
        "enum.IntFlag",
        "enum.StrEnum",
    }
)

# Attribute names that are special on enum classes: the name/value accessors in both
# spellings, everything listed in ENUM_REMOVED_PROPS, and a handful of class-level dunders.
ENUM_SPECIAL_PROPS: Final = frozenset(
    {
        # Member name/value accessors (public and sunder forms).
        "name",
        "value",
        "_name_",
        "_value_",
        *ENUM_REMOVED_PROPS,
        # Also attributes from `object`:
        "__module__",
        "__annotations__",
        "__doc__",
        "__slots__",
        "__dict__",
    }
)
class EnumCallAnalyzer:
    """Analyze functional-API Enum definitions such as ``A = enum.Enum('A', 'x y')``.

    A recognized call gets a synthetic TypeInfo, which is added to the symbol
    table and attached to the call expression as an ``EnumCallExpr``.
    """

    def __init__(self, options: Options, api: SemanticAnalyzerInterface) -> None:
        self.options = options
        self.api = api

    def process_enum_call(self, s: AssignmentStmt, is_func_scope: bool) -> bool:
        """Check if s defines an Enum; if yes, store the definition in symbol table.

        Return True if this looks like an Enum definition (but maybe with errors),
        otherwise return False.
        """
        if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
            return False
        lvalue = s.lvalues[0]
        name = lvalue.name
        enum_call = self.check_enum_call(s.rvalue, name, is_func_scope)
        if enum_call is None:
            return False
        if isinstance(lvalue, MemberExpr):
            # Attribute targets are reported, and the statement is then not
            # treated as an Enum definition.
            self.fail("Enum type as attribute is not supported", lvalue)
            return False
        # Yes, it's a valid Enum definition. Add it to the symbol table.
        self.api.add_symbol(name, enum_call, s)
        return True

    def check_enum_call(
        self, node: Expression, var_name: str, is_func_scope: bool
    ) -> TypeInfo | None:
        """Check if a call defines an Enum.

        Example:

          A = enum.Enum('A', 'foo bar')

        is equivalent to:

          class A(enum.Enum):
              foo = 1
              bar = 2

        Return None if ``node`` is not a call to one of ENUM_BASES. Otherwise
        return the generated TypeInfo. If the call's arguments are invalid, an
        error has been reported and the TypeInfo is a placeholder named
        ``var_name`` with no members.
        """
        if not isinstance(node, CallExpr):
            return None
        call = node
        callee = call.callee
        if not isinstance(callee, RefExpr):
            return None
        fullname = callee.fullname
        if fullname not in ENUM_BASES:
            return None
        items, values, ok = self.parse_enum_call_args(call, fullname.split(".")[-1])
        if not ok:
            # Error. Construct dummy return value.
            info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line)
        else:
            # parse_enum_call_args has verified that the "value" argument (the
            # class name) is a string literal. It may have been passed by
            # keyword, so look it up the same way rather than assuming args[0].
            name_expr = self._lookup_arg(call, "value", 0)
            assert isinstance(name_expr, StrExpr)
            name = name_expr.value
            needs_unique_name = name != var_name or is_func_scope
            if needs_unique_name:
                # Give it a unique name derived from the line number.
                name += "@" + str(call.line)
            info = self.build_enum_call_typeinfo(name, items, fullname, call.line)
            # Store generated TypeInfo under both names, see semanal_namedtuple for more details.
            if needs_unique_name:
                self.api.add_symbol_skip_local(name, info)
        call.analyzed = EnumCallExpr(info, items, values)
        call.analyzed.set_line(call)
        info.line = node.line
        return info

    def build_enum_call_typeinfo(
        self, name: str, items: list[str], fullname: str, line: int
    ) -> TypeInfo:
        """Create an enum TypeInfo named ``name`` deriving from ``fullname``.

        Each entry of ``items`` becomes a property-like Var member of the class.
        """
        base = self.api.named_type_or_none(fullname)
        assert base is not None
        info = self.api.basic_new_typeinfo(name, base, line)
        info.metaclass_type = info.calculate_metaclass_type()
        info.is_enum = True
        for item in items:
            var = Var(item)
            var.info = info
            var.is_property = True
            var._fullname = f"{info.fullname}.{item}"
            info.names[item] = SymbolTableNode(MDEF, var)
        return info

    def parse_enum_call_args(
        self, call: CallExpr, class_name: str
    ) -> tuple[list[str], list[Expression | None], bool]:
        """Parse arguments of an Enum call.

        Return a tuple (member names, member value expressions, ok). ``ok`` is
        False if an error was reported, in which case both lists are empty.
        Value expressions are None for members whose value is implicit.
        """
        args = call.args
        if not all(arg_kind in (ARG_POS, ARG_NAMED) for arg_kind in call.arg_kinds):
            return self.fail_enum_call_arg(f"Unexpected arguments to {class_name}()", call)
        if len(args) < 2:
            return self.fail_enum_call_arg(f"Too few arguments for {class_name}()", call)
        if len(args) > 6:
            return self.fail_enum_call_arg(f"Too many arguments for {class_name}()", call)
        valid_names = (None, "value", "names", "module", "qualname", "type", "start")
        for arg_name in call.arg_names:
            if arg_name not in valid_names:
                # Report, but keep going: unlike the checks above, an unknown
                # keyword does not make the remaining arguments unparseable.
                self.fail(f'Unexpected keyword argument "{arg_name}"', call)
        value = self._lookup_arg(call, "value", 0)
        names = self._lookup_arg(call, "names", 1)
        if value is None or names is None:
            # E.g. Enum('A', module='m'): there is no member-name argument at all.
            return self.fail_enum_call_arg(f"Too few arguments for {class_name}()", call)
        if not isinstance(value, StrExpr):
            return self.fail_enum_call_arg(
                f"{class_name}() expects a string literal as the first argument", call
            )
        second_arg_error = (
            f"Second argument of {class_name}() must be string, tuple, list or dict literal"
            " for mypy to determine Enum members"
        )
        items: list[str] = []
        values: list[Expression | None] = []
        if isinstance(names, StrExpr):
            items = self._split_member_names(names.value)
        elif isinstance(names, (TupleExpr, ListExpr)):
            seq_items = names.items
            if is_StrExpr_list(seq_items):
                items = [seq_item.value for seq_item in seq_items]
            elif all(
                isinstance(seq_item, (TupleExpr, ListExpr))
                and len(seq_item.items) == 2
                and isinstance(seq_item.items[0], StrExpr)
                for seq_item in seq_items
            ):
                for seq_item in seq_items:
                    assert isinstance(seq_item, (TupleExpr, ListExpr))
                    member_name, member_value = seq_item.items
                    assert isinstance(member_name, StrExpr)
                    items.append(member_name.value)
                    values.append(member_value)
            else:
                return self.fail_enum_call_arg(
                    f"{class_name}() with tuple or list expects strings or (name, value) pairs",
                    call,
                )
        elif isinstance(names, DictExpr):
            for key, member_value in names.items:
                if not isinstance(key, StrExpr):
                    return self.fail_enum_call_arg(
                        f"{class_name}() with dict literal requires string literals", call
                    )
                items.append(key.value)
                values.append(member_value)
        elif isinstance(names, RefExpr) and isinstance(names.node, Var):
            # A variable whose member-name string is statically known, either
            # through a str Literal type or a Final initializer.
            var = names.node
            proper_type = get_proper_type(var.type)
            if isinstance(proper_type, LiteralType) and isinstance(proper_type.value, str):
                items = self._split_member_names(proper_type.value)
            elif var.is_final and isinstance(var.final_value, str):
                items = self._split_member_names(var.final_value)
            else:
                return self.fail_enum_call_arg(second_arg_error, call)
        else:
            # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
            return self.fail_enum_call_arg(second_arg_error, call)
        if not items:
            return self.fail_enum_call_arg(f"{class_name}() needs at least one item", call)
        if not values:
            values = [None] * len(items)
        assert len(items) == len(values)
        return items, values, True

    @staticmethod
    def _lookup_arg(call: CallExpr, keyword: str, position: int) -> Expression | None:
        """Return the argument bound to ``keyword``, passed by name or by position.

        A keyword argument takes precedence. Otherwise the ``position``-th
        positional argument is returned. Return None if neither is present.
        """
        for arg_name, arg in zip(call.arg_names, call.args):
            if arg_name == keyword:
                return arg
        positional = [arg for arg, arg_name in zip(call.args, call.arg_names) if arg_name is None]
        if position < len(positional):
            return positional[position]
        return None

    @staticmethod
    def _split_member_names(fields: str) -> list[str]:
        """Split a member-name string such as "a b" or "a, b" into individual names."""
        return fields.replace(",", " ").split()

    def fail_enum_call_arg(
        self, message: str, context: Context
    ) -> tuple[list[str], list[Expression | None], bool]:
        """Report ``message`` and return the "failed" result of parse_enum_call_args."""
        self.fail(message, context)
        return [], [], False

    # Helpers

    def fail(self, msg: str, ctx: Context) -> None:
        """Report an error at ``ctx`` through the semantic analyzer API."""
        self.api.fail(msg, ctx)
|