checkpattern.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722
  1. """Pattern checker. This file is conceptually part of TypeChecker."""
  2. from __future__ import annotations
  3. from collections import defaultdict
  4. from typing import NamedTuple
  5. from typing_extensions import Final
  6. import mypy.checker
  7. from mypy import message_registry
  8. from mypy.checkmember import analyze_member_access
  9. from mypy.expandtype import expand_type_by_instance
  10. from mypy.join import join_types
  11. from mypy.literals import literal_hash
  12. from mypy.maptype import map_instance_to_supertype
  13. from mypy.meet import narrow_declared_type
  14. from mypy.messages import MessageBuilder
  15. from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
  16. from mypy.options import Options
  17. from mypy.patterns import (
  18. AsPattern,
  19. ClassPattern,
  20. MappingPattern,
  21. OrPattern,
  22. Pattern,
  23. SequencePattern,
  24. SingletonPattern,
  25. StarredPattern,
  26. ValuePattern,
  27. )
  28. from mypy.plugin import Plugin
  29. from mypy.subtypes import is_subtype
  30. from mypy.typeops import (
  31. coerce_to_literal,
  32. make_simplified_union,
  33. try_getting_str_literals_from_type,
  34. tuple_fallback,
  35. )
  36. from mypy.types import (
  37. AnyType,
  38. Instance,
  39. LiteralType,
  40. NoneType,
  41. ProperType,
  42. TupleType,
  43. Type,
  44. TypedDictType,
  45. TypeOfAny,
  46. UninhabitedType,
  47. UnionType,
  48. get_proper_type,
  49. )
  50. from mypy.typevars import fill_typevars
  51. from mypy.visitor import PatternVisitor
  52. self_match_type_names: Final = [
  53. "builtins.bool",
  54. "builtins.bytearray",
  55. "builtins.bytes",
  56. "builtins.dict",
  57. "builtins.float",
  58. "builtins.frozenset",
  59. "builtins.int",
  60. "builtins.list",
  61. "builtins.set",
  62. "builtins.str",
  63. "builtins.tuple",
  64. ]
  65. non_sequence_match_type_names: Final = ["builtins.str", "builtins.bytes", "builtins.bytearray"]
  66. # For every Pattern a PatternType can be calculated. This requires recursively calculating
  67. # the PatternTypes of the sub-patterns first.
  68. # Using the data in the PatternType the match subject and captured names can be narrowed/inferred.
  69. class PatternType(NamedTuple):
  70. type: Type # The type the match subject can be narrowed to
  71. rest_type: Type # The remaining type if the pattern didn't match
  72. captures: dict[Expression, Type] # The variables captured by the pattern
  73. class PatternChecker(PatternVisitor[PatternType]):
  74. """Pattern checker.
  75. This class checks if a pattern can match a type, what the type can be narrowed to, and what
  76. type capture patterns should be inferred as.
  77. """
  78. # Some services are provided by a TypeChecker instance.
  79. chk: mypy.checker.TypeChecker
  80. # This is shared with TypeChecker, but stored also here for convenience.
  81. msg: MessageBuilder
  82. # Currently unused
  83. plugin: Plugin
  84. # The expression being matched against the pattern
  85. subject: Expression
  86. subject_type: Type
  87. # Type of the subject to check the (sub)pattern against
  88. type_context: list[Type]
  89. # Types that match against self instead of their __match_args__ if used as a class pattern
  90. # Filled in from self_match_type_names
  91. self_match_types: list[Type]
  92. # Types that are sequences, but don't match sequence patterns. Filled in from
  93. # non_sequence_match_type_names
  94. non_sequence_match_types: list[Type]
  95. options: Options
  96. def __init__(
  97. self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: Plugin, options: Options
  98. ) -> None:
  99. self.chk = chk
  100. self.msg = msg
  101. self.plugin = plugin
  102. self.type_context = []
  103. self.self_match_types = self.generate_types_from_names(self_match_type_names)
  104. self.non_sequence_match_types = self.generate_types_from_names(
  105. non_sequence_match_type_names
  106. )
  107. self.options = options
  108. def accept(self, o: Pattern, type_context: Type) -> PatternType:
  109. self.type_context.append(type_context)
  110. result = o.accept(self)
  111. self.type_context.pop()
  112. return result
  113. def visit_as_pattern(self, o: AsPattern) -> PatternType:
  114. current_type = self.type_context[-1]
  115. if o.pattern is not None:
  116. pattern_type = self.accept(o.pattern, current_type)
  117. typ, rest_type, type_map = pattern_type
  118. else:
  119. typ, rest_type, type_map = current_type, UninhabitedType(), {}
  120. if not is_uninhabited(typ) and o.name is not None:
  121. typ, _ = self.chk.conditional_types_with_intersection(
  122. current_type, [get_type_range(typ)], o, default=current_type
  123. )
  124. if not is_uninhabited(typ):
  125. type_map[o.name] = typ
  126. return PatternType(typ, rest_type, type_map)
  127. def visit_or_pattern(self, o: OrPattern) -> PatternType:
  128. current_type = self.type_context[-1]
  129. #
  130. # Check all the subpatterns
  131. #
  132. pattern_types = []
  133. for pattern in o.patterns:
  134. pattern_type = self.accept(pattern, current_type)
  135. pattern_types.append(pattern_type)
  136. current_type = pattern_type.rest_type
  137. #
  138. # Collect the final type
  139. #
  140. types = []
  141. for pattern_type in pattern_types:
  142. if not is_uninhabited(pattern_type.type):
  143. types.append(pattern_type.type)
  144. #
  145. # Check the capture types
  146. #
  147. capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
  148. # Collect captures from the first subpattern
  149. for expr, typ in pattern_types[0].captures.items():
  150. node = get_var(expr)
  151. capture_types[node].append((expr, typ))
  152. # Check if other subpatterns capture the same names
  153. for i, pattern_type in enumerate(pattern_types[1:]):
  154. vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
  155. if capture_types.keys() != vars:
  156. self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
  157. for expr, typ in pattern_type.captures.items():
  158. node = get_var(expr)
  159. capture_types[node].append((expr, typ))
  160. captures: dict[Expression, Type] = {}
  161. for var, capture_list in capture_types.items():
  162. typ = UninhabitedType()
  163. for _, other in capture_list:
  164. typ = join_types(typ, other)
  165. captures[capture_list[0][0]] = typ
  166. union_type = make_simplified_union(types)
  167. return PatternType(union_type, current_type, captures)
  168. def visit_value_pattern(self, o: ValuePattern) -> PatternType:
  169. current_type = self.type_context[-1]
  170. typ = self.chk.expr_checker.accept(o.expr)
  171. typ = coerce_to_literal(typ)
  172. narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
  173. current_type, [get_type_range(typ)], o, default=current_type
  174. )
  175. if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)):
  176. return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {})
  177. return PatternType(narrowed_type, rest_type, {})
  178. def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
  179. current_type = self.type_context[-1]
  180. value: bool | None = o.value
  181. if isinstance(value, bool):
  182. typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
  183. elif value is None:
  184. typ = NoneType()
  185. else:
  186. assert False
  187. narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
  188. current_type, [get_type_range(typ)], o, default=current_type
  189. )
  190. return PatternType(narrowed_type, rest_type, {})
  191. def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
  192. #
  193. # check for existence of a starred pattern
  194. #
  195. current_type = get_proper_type(self.type_context[-1])
  196. if not self.can_match_sequence(current_type):
  197. return self.early_non_match()
  198. star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
  199. star_position: int | None = None
  200. if len(star_positions) == 1:
  201. star_position = star_positions[0]
  202. elif len(star_positions) >= 2:
  203. assert False, "Parser should prevent multiple starred patterns"
  204. required_patterns = len(o.patterns)
  205. if star_position is not None:
  206. required_patterns -= 1
  207. #
  208. # get inner types of original type
  209. #
  210. if isinstance(current_type, TupleType):
  211. inner_types = current_type.items
  212. size_diff = len(inner_types) - required_patterns
  213. if size_diff < 0:
  214. return self.early_non_match()
  215. elif size_diff > 0 and star_position is None:
  216. return self.early_non_match()
  217. else:
  218. inner_type = self.get_sequence_type(current_type, o)
  219. if inner_type is None:
  220. inner_type = self.chk.named_type("builtins.object")
  221. inner_types = [inner_type] * len(o.patterns)
  222. #
  223. # match inner patterns
  224. #
  225. contracted_new_inner_types: list[Type] = []
  226. contracted_rest_inner_types: list[Type] = []
  227. captures: dict[Expression, Type] = {}
  228. contracted_inner_types = self.contract_starred_pattern_types(
  229. inner_types, star_position, required_patterns
  230. )
  231. for p, t in zip(o.patterns, contracted_inner_types):
  232. pattern_type = self.accept(p, t)
  233. typ, rest, type_map = pattern_type
  234. contracted_new_inner_types.append(typ)
  235. contracted_rest_inner_types.append(rest)
  236. self.update_type_map(captures, type_map)
  237. new_inner_types = self.expand_starred_pattern_types(
  238. contracted_new_inner_types, star_position, len(inner_types)
  239. )
  240. rest_inner_types = self.expand_starred_pattern_types(
  241. contracted_rest_inner_types, star_position, len(inner_types)
  242. )
  243. #
  244. # Calculate new type
  245. #
  246. new_type: Type
  247. rest_type: Type = current_type
  248. if isinstance(current_type, TupleType):
  249. narrowed_inner_types = []
  250. inner_rest_types = []
  251. for inner_type, new_inner_type in zip(inner_types, new_inner_types):
  252. (
  253. narrowed_inner_type,
  254. inner_rest_type,
  255. ) = self.chk.conditional_types_with_intersection(
  256. new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type
  257. )
  258. narrowed_inner_types.append(narrowed_inner_type)
  259. inner_rest_types.append(inner_rest_type)
  260. if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
  261. new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
  262. else:
  263. new_type = UninhabitedType()
  264. if all(is_uninhabited(typ) for typ in inner_rest_types):
  265. # All subpatterns always match, so we can apply negative narrowing
  266. rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
  267. else:
  268. new_inner_type = UninhabitedType()
  269. for typ in new_inner_types:
  270. new_inner_type = join_types(new_inner_type, typ)
  271. new_type = self.construct_sequence_child(current_type, new_inner_type)
  272. if is_subtype(new_type, current_type):
  273. new_type, _ = self.chk.conditional_types_with_intersection(
  274. current_type, [get_type_range(new_type)], o, default=current_type
  275. )
  276. else:
  277. new_type = current_type
  278. return PatternType(new_type, rest_type, captures)
  279. def get_sequence_type(self, t: Type, context: Context) -> Type | None:
  280. t = get_proper_type(t)
  281. if isinstance(t, AnyType):
  282. return AnyType(TypeOfAny.from_another_any, t)
  283. if isinstance(t, UnionType):
  284. items = [self.get_sequence_type(item, context) for item in t.items]
  285. not_none_items = [item for item in items if item is not None]
  286. if not_none_items:
  287. return make_simplified_union(not_none_items)
  288. else:
  289. return None
  290. if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
  291. if isinstance(t, TupleType):
  292. t = tuple_fallback(t)
  293. return self.chk.iterable_item_type(t, context)
  294. else:
  295. return None
  296. def contract_starred_pattern_types(
  297. self, types: list[Type], star_pos: int | None, num_patterns: int
  298. ) -> list[Type]:
  299. """
  300. Contracts a list of types in a sequence pattern depending on the position of a starred
  301. capture pattern.
  302. For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
  303. bytes] the contracted types are [bool, Union[int, str], bytes].
  304. If star_pos in None the types are returned unchanged.
  305. """
  306. if star_pos is None:
  307. return types
  308. new_types = types[:star_pos]
  309. star_length = len(types) - num_patterns
  310. new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
  311. new_types += types[star_pos + star_length :]
  312. return new_types
  313. def expand_starred_pattern_types(
  314. self, types: list[Type], star_pos: int | None, num_types: int
  315. ) -> list[Type]:
  316. """Undoes the contraction done by contract_starred_pattern_types.
  317. For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
  318. to length 4 the result is [bool, int, int, str].
  319. """
  320. if star_pos is None:
  321. return types
  322. new_types = types[:star_pos]
  323. star_length = num_types - len(types) + 1
  324. new_types += [types[star_pos]] * star_length
  325. new_types += types[star_pos + 1 :]
  326. return new_types
  327. def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
  328. captures: dict[Expression, Type] = {}
  329. if o.capture is not None:
  330. list_type = self.chk.named_generic_type("builtins.list", [self.type_context[-1]])
  331. captures[o.capture] = list_type
  332. return PatternType(self.type_context[-1], UninhabitedType(), captures)
  333. def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
  334. current_type = get_proper_type(self.type_context[-1])
  335. can_match = True
  336. captures: dict[Expression, Type] = {}
  337. for key, value in zip(o.keys, o.values):
  338. inner_type = self.get_mapping_item_type(o, current_type, key)
  339. if inner_type is None:
  340. can_match = False
  341. inner_type = self.chk.named_type("builtins.object")
  342. pattern_type = self.accept(value, inner_type)
  343. if is_uninhabited(pattern_type.type):
  344. can_match = False
  345. else:
  346. self.update_type_map(captures, pattern_type.captures)
  347. if o.rest is not None:
  348. mapping = self.chk.named_type("typing.Mapping")
  349. if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
  350. mapping_inst = map_instance_to_supertype(current_type, mapping.type)
  351. dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
  352. rest_type = Instance(dict_typeinfo, mapping_inst.args)
  353. else:
  354. object_type = self.chk.named_type("builtins.object")
  355. rest_type = self.chk.named_generic_type(
  356. "builtins.dict", [object_type, object_type]
  357. )
  358. captures[o.rest] = rest_type
  359. if can_match:
  360. # We can't narrow the type here, as Mapping key is invariant.
  361. new_type = self.type_context[-1]
  362. else:
  363. new_type = UninhabitedType()
  364. return PatternType(new_type, current_type, captures)
  365. def get_mapping_item_type(
  366. self, pattern: MappingPattern, mapping_type: Type, key: Expression
  367. ) -> Type | None:
  368. mapping_type = get_proper_type(mapping_type)
  369. if isinstance(mapping_type, TypedDictType):
  370. with self.msg.filter_errors() as local_errors:
  371. result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
  372. mapping_type, key
  373. )
  374. has_local_errors = local_errors.has_new_errors()
  375. # If we can't determine the type statically fall back to treating it as a normal
  376. # mapping
  377. if has_local_errors:
  378. with self.msg.filter_errors() as local_errors:
  379. result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
  380. if local_errors.has_new_errors():
  381. result = None
  382. else:
  383. with self.msg.filter_errors():
  384. result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
  385. return result
  386. def get_simple_mapping_item_type(
  387. self, pattern: MappingPattern, mapping_type: Type, key: Expression
  388. ) -> Type:
  389. result, _ = self.chk.expr_checker.check_method_call_by_name(
  390. "__getitem__", mapping_type, [key], [ARG_POS], pattern
  391. )
  392. return result
  393. def visit_class_pattern(self, o: ClassPattern) -> PatternType:
  394. current_type = get_proper_type(self.type_context[-1])
  395. #
  396. # Check class type
  397. #
  398. type_info = o.class_ref.node
  399. if type_info is None:
  400. return PatternType(AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error), {})
  401. if isinstance(type_info, TypeAlias) and not type_info.no_args:
  402. self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
  403. return self.early_non_match()
  404. if isinstance(type_info, TypeInfo):
  405. any_type = AnyType(TypeOfAny.implementation_artifact)
  406. typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
  407. elif isinstance(type_info, TypeAlias):
  408. typ = type_info.target
  409. else:
  410. if isinstance(type_info, Var) and type_info.type is not None:
  411. name = type_info.type.str_with_options(self.options)
  412. else:
  413. name = type_info.name
  414. self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o.class_ref)
  415. return self.early_non_match()
  416. new_type, rest_type = self.chk.conditional_types_with_intersection(
  417. current_type, [get_type_range(typ)], o, default=current_type
  418. )
  419. if is_uninhabited(new_type):
  420. return self.early_non_match()
  421. # TODO: Do I need this?
  422. narrowed_type = narrow_declared_type(current_type, new_type)
  423. #
  424. # Convert positional to keyword patterns
  425. #
  426. keyword_pairs: list[tuple[str | None, Pattern]] = []
  427. match_arg_set: set[str] = set()
  428. captures: dict[Expression, Type] = {}
  429. if len(o.positionals) != 0:
  430. if self.should_self_match(typ):
  431. if len(o.positionals) > 1:
  432. self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
  433. pattern_type = self.accept(o.positionals[0], narrowed_type)
  434. if not is_uninhabited(pattern_type.type):
  435. return PatternType(
  436. pattern_type.type,
  437. join_types(rest_type, pattern_type.rest_type),
  438. pattern_type.captures,
  439. )
  440. captures = pattern_type.captures
  441. else:
  442. with self.msg.filter_errors() as local_errors:
  443. match_args_type = analyze_member_access(
  444. "__match_args__",
  445. typ,
  446. o,
  447. False,
  448. False,
  449. False,
  450. self.msg,
  451. original_type=typ,
  452. chk=self.chk,
  453. )
  454. has_local_errors = local_errors.has_new_errors()
  455. if has_local_errors:
  456. self.msg.fail(
  457. message_registry.MISSING_MATCH_ARGS.format(
  458. typ.str_with_options(self.options)
  459. ),
  460. o,
  461. )
  462. return self.early_non_match()
  463. proper_match_args_type = get_proper_type(match_args_type)
  464. if isinstance(proper_match_args_type, TupleType):
  465. match_arg_names = get_match_arg_names(proper_match_args_type)
  466. if len(o.positionals) > len(match_arg_names):
  467. self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
  468. return self.early_non_match()
  469. else:
  470. match_arg_names = [None] * len(o.positionals)
  471. for arg_name, pos in zip(match_arg_names, o.positionals):
  472. keyword_pairs.append((arg_name, pos))
  473. if arg_name is not None:
  474. match_arg_set.add(arg_name)
  475. #
  476. # Check for duplicate patterns
  477. #
  478. keyword_arg_set = set()
  479. has_duplicates = False
  480. for key, value in zip(o.keyword_keys, o.keyword_values):
  481. keyword_pairs.append((key, value))
  482. if key in match_arg_set:
  483. self.msg.fail(
  484. message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), value
  485. )
  486. has_duplicates = True
  487. elif key in keyword_arg_set:
  488. self.msg.fail(
  489. message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), value
  490. )
  491. has_duplicates = True
  492. keyword_arg_set.add(key)
  493. if has_duplicates:
  494. return self.early_non_match()
  495. #
  496. # Check keyword patterns
  497. #
  498. can_match = True
  499. for keyword, pattern in keyword_pairs:
  500. key_type: Type | None = None
  501. with self.msg.filter_errors() as local_errors:
  502. if keyword is not None:
  503. key_type = analyze_member_access(
  504. keyword,
  505. narrowed_type,
  506. pattern,
  507. False,
  508. False,
  509. False,
  510. self.msg,
  511. original_type=new_type,
  512. chk=self.chk,
  513. )
  514. else:
  515. key_type = AnyType(TypeOfAny.from_error)
  516. has_local_errors = local_errors.has_new_errors()
  517. if has_local_errors or key_type is None:
  518. key_type = AnyType(TypeOfAny.from_error)
  519. self.msg.fail(
  520. message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
  521. typ.str_with_options(self.options), keyword
  522. ),
  523. pattern,
  524. )
  525. inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
  526. if is_uninhabited(inner_type):
  527. can_match = False
  528. else:
  529. self.update_type_map(captures, inner_captures)
  530. if not is_uninhabited(inner_rest_type):
  531. rest_type = current_type
  532. if not can_match:
  533. new_type = UninhabitedType()
  534. return PatternType(new_type, rest_type, captures)
  535. def should_self_match(self, typ: Type) -> bool:
  536. typ = get_proper_type(typ)
  537. if isinstance(typ, Instance) and typ.type.is_named_tuple:
  538. return False
  539. for other in self.self_match_types:
  540. if is_subtype(typ, other):
  541. return True
  542. return False
  543. def can_match_sequence(self, typ: ProperType) -> bool:
  544. if isinstance(typ, UnionType):
  545. return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
  546. for other in self.non_sequence_match_types:
  547. # We have to ignore promotions, as memoryview should match, but bytes,
  548. # which it can be promoted to, shouldn't
  549. if is_subtype(typ, other, ignore_promotions=True):
  550. return False
  551. sequence = self.chk.named_type("typing.Sequence")
  552. # If the static type is more general than sequence the actual type could still match
  553. return is_subtype(typ, sequence) or is_subtype(sequence, typ)
  554. def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
  555. types: list[Type] = []
  556. for name in type_names:
  557. try:
  558. types.append(self.chk.named_type(name))
  559. except KeyError as e:
  560. # Some built in types are not defined in all test cases
  561. if not name.startswith("builtins."):
  562. raise e
  563. return types
  564. def update_type_map(
  565. self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
  566. ) -> None:
  567. # Calculating this would not be needed if TypeMap directly used literal hashes instead of
  568. # expressions, as suggested in the TODO above it's definition
  569. already_captured = {literal_hash(expr) for expr in original_type_map}
  570. for expr, typ in extra_type_map.items():
  571. if literal_hash(expr) in already_captured:
  572. node = get_var(expr)
  573. self.msg.fail(
  574. message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
  575. )
  576. else:
  577. original_type_map[expr] = typ
  578. def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
  579. """
  580. If outer_type is a child class of typing.Sequence returns a new instance of
  581. outer_type, that is a Sequence of inner_type. If outer_type is not a child class of
  582. typing.Sequence just returns a Sequence of inner_type
  583. For example:
  584. construct_sequence_child(List[int], str) = List[str]
  585. TODO: this doesn't make sense. For example if one has class S(Sequence[int], Generic[T])
  586. or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
  587. """
  588. proper_type = get_proper_type(outer_type)
  589. if isinstance(proper_type, UnionType):
  590. types = [
  591. self.construct_sequence_child(item, inner_type)
  592. for item in proper_type.items
  593. if self.can_match_sequence(get_proper_type(item))
  594. ]
  595. return make_simplified_union(types)
  596. sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
  597. if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
  598. proper_type = get_proper_type(outer_type)
  599. if isinstance(proper_type, TupleType):
  600. proper_type = tuple_fallback(proper_type)
  601. assert isinstance(proper_type, Instance)
  602. empty_type = fill_typevars(proper_type.type)
  603. partial_type = expand_type_by_instance(empty_type, sequence)
  604. return expand_type_by_instance(partial_type, proper_type)
  605. else:
  606. return sequence
  607. def early_non_match(self) -> PatternType:
  608. return PatternType(UninhabitedType(), self.type_context[-1], {})
  609. def get_match_arg_names(typ: TupleType) -> list[str | None]:
  610. args: list[str | None] = []
  611. for item in typ.items:
  612. values = try_getting_str_literals_from_type(item)
  613. if values is None or len(values) != 1:
  614. args.append(None)
  615. else:
  616. args.append(values[0])
  617. return args
  618. def get_var(expr: Expression) -> Var:
  619. """
  620. Warning: this in only true for expressions captured by a match statement.
  621. Don't call it from anywhere else
  622. """
  623. assert isinstance(expr, NameExpr)
  624. node = expr.node
  625. assert isinstance(node, Var)
  626. return node
  627. def get_type_range(typ: Type) -> mypy.checker.TypeRange:
  628. typ = get_proper_type(typ)
  629. if (
  630. isinstance(typ, Instance)
  631. and typ.last_known_value
  632. and isinstance(typ.last_known_value.value, bool)
  633. ):
  634. typ = typ.last_known_value
  635. return mypy.checker.TypeRange(typ, is_upper_bound=False)
  636. def is_uninhabited(typ: Type) -> bool:
  637. return isinstance(get_proper_type(typ), UninhabitedType)