checkpattern.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721
  1. """Pattern checker. This file is conceptually part of TypeChecker."""
  2. from __future__ import annotations
  3. from collections import defaultdict
  4. from typing import Final, NamedTuple
  5. import mypy.checker
  6. from mypy import message_registry
  7. from mypy.checkmember import analyze_member_access
  8. from mypy.expandtype import expand_type_by_instance
  9. from mypy.join import join_types
  10. from mypy.literals import literal_hash
  11. from mypy.maptype import map_instance_to_supertype
  12. from mypy.meet import narrow_declared_type
  13. from mypy.messages import MessageBuilder
  14. from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
  15. from mypy.options import Options
  16. from mypy.patterns import (
  17. AsPattern,
  18. ClassPattern,
  19. MappingPattern,
  20. OrPattern,
  21. Pattern,
  22. SequencePattern,
  23. SingletonPattern,
  24. StarredPattern,
  25. ValuePattern,
  26. )
  27. from mypy.plugin import Plugin
  28. from mypy.subtypes import is_subtype
  29. from mypy.typeops import (
  30. coerce_to_literal,
  31. make_simplified_union,
  32. try_getting_str_literals_from_type,
  33. tuple_fallback,
  34. )
  35. from mypy.types import (
  36. AnyType,
  37. Instance,
  38. LiteralType,
  39. NoneType,
  40. ProperType,
  41. TupleType,
  42. Type,
  43. TypedDictType,
  44. TypeOfAny,
  45. UninhabitedType,
  46. UnionType,
  47. get_proper_type,
  48. )
  49. from mypy.typevars import fill_typevars
  50. from mypy.visitor import PatternVisitor
  51. self_match_type_names: Final = [
  52. "builtins.bool",
  53. "builtins.bytearray",
  54. "builtins.bytes",
  55. "builtins.dict",
  56. "builtins.float",
  57. "builtins.frozenset",
  58. "builtins.int",
  59. "builtins.list",
  60. "builtins.set",
  61. "builtins.str",
  62. "builtins.tuple",
  63. ]
  64. non_sequence_match_type_names: Final = ["builtins.str", "builtins.bytes", "builtins.bytearray"]
  65. # For every Pattern a PatternType can be calculated. This requires recursively calculating
  66. # the PatternTypes of the sub-patterns first.
  67. # Using the data in the PatternType the match subject and captured names can be narrowed/inferred.
  68. class PatternType(NamedTuple):
  69. type: Type # The type the match subject can be narrowed to
  70. rest_type: Type # The remaining type if the pattern didn't match
  71. captures: dict[Expression, Type] # The variables captured by the pattern
  72. class PatternChecker(PatternVisitor[PatternType]):
  73. """Pattern checker.
  74. This class checks if a pattern can match a type, what the type can be narrowed to, and what
  75. type capture patterns should be inferred as.
  76. """
  77. # Some services are provided by a TypeChecker instance.
  78. chk: mypy.checker.TypeChecker
  79. # This is shared with TypeChecker, but stored also here for convenience.
  80. msg: MessageBuilder
  81. # Currently unused
  82. plugin: Plugin
  83. # The expression being matched against the pattern
  84. subject: Expression
  85. subject_type: Type
  86. # Type of the subject to check the (sub)pattern against
  87. type_context: list[Type]
  88. # Types that match against self instead of their __match_args__ if used as a class pattern
  89. # Filled in from self_match_type_names
  90. self_match_types: list[Type]
  91. # Types that are sequences, but don't match sequence patterns. Filled in from
  92. # non_sequence_match_type_names
  93. non_sequence_match_types: list[Type]
  94. options: Options
  95. def __init__(
  96. self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: Plugin, options: Options
  97. ) -> None:
  98. self.chk = chk
  99. self.msg = msg
  100. self.plugin = plugin
  101. self.type_context = []
  102. self.self_match_types = self.generate_types_from_names(self_match_type_names)
  103. self.non_sequence_match_types = self.generate_types_from_names(
  104. non_sequence_match_type_names
  105. )
  106. self.options = options
  107. def accept(self, o: Pattern, type_context: Type) -> PatternType:
  108. self.type_context.append(type_context)
  109. result = o.accept(self)
  110. self.type_context.pop()
  111. return result
  112. def visit_as_pattern(self, o: AsPattern) -> PatternType:
  113. current_type = self.type_context[-1]
  114. if o.pattern is not None:
  115. pattern_type = self.accept(o.pattern, current_type)
  116. typ, rest_type, type_map = pattern_type
  117. else:
  118. typ, rest_type, type_map = current_type, UninhabitedType(), {}
  119. if not is_uninhabited(typ) and o.name is not None:
  120. typ, _ = self.chk.conditional_types_with_intersection(
  121. current_type, [get_type_range(typ)], o, default=current_type
  122. )
  123. if not is_uninhabited(typ):
  124. type_map[o.name] = typ
  125. return PatternType(typ, rest_type, type_map)
  126. def visit_or_pattern(self, o: OrPattern) -> PatternType:
  127. current_type = self.type_context[-1]
  128. #
  129. # Check all the subpatterns
  130. #
  131. pattern_types = []
  132. for pattern in o.patterns:
  133. pattern_type = self.accept(pattern, current_type)
  134. pattern_types.append(pattern_type)
  135. current_type = pattern_type.rest_type
  136. #
  137. # Collect the final type
  138. #
  139. types = []
  140. for pattern_type in pattern_types:
  141. if not is_uninhabited(pattern_type.type):
  142. types.append(pattern_type.type)
  143. #
  144. # Check the capture types
  145. #
  146. capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
  147. # Collect captures from the first subpattern
  148. for expr, typ in pattern_types[0].captures.items():
  149. node = get_var(expr)
  150. capture_types[node].append((expr, typ))
  151. # Check if other subpatterns capture the same names
  152. for i, pattern_type in enumerate(pattern_types[1:]):
  153. vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
  154. if capture_types.keys() != vars:
  155. self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
  156. for expr, typ in pattern_type.captures.items():
  157. node = get_var(expr)
  158. capture_types[node].append((expr, typ))
  159. captures: dict[Expression, Type] = {}
  160. for var, capture_list in capture_types.items():
  161. typ = UninhabitedType()
  162. for _, other in capture_list:
  163. typ = join_types(typ, other)
  164. captures[capture_list[0][0]] = typ
  165. union_type = make_simplified_union(types)
  166. return PatternType(union_type, current_type, captures)
  167. def visit_value_pattern(self, o: ValuePattern) -> PatternType:
  168. current_type = self.type_context[-1]
  169. typ = self.chk.expr_checker.accept(o.expr)
  170. typ = coerce_to_literal(typ)
  171. narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
  172. current_type, [get_type_range(typ)], o, default=current_type
  173. )
  174. if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)):
  175. return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {})
  176. return PatternType(narrowed_type, rest_type, {})
  177. def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
  178. current_type = self.type_context[-1]
  179. value: bool | None = o.value
  180. if isinstance(value, bool):
  181. typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
  182. elif value is None:
  183. typ = NoneType()
  184. else:
  185. assert False
  186. narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
  187. current_type, [get_type_range(typ)], o, default=current_type
  188. )
  189. return PatternType(narrowed_type, rest_type, {})
  190. def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
  191. #
  192. # check for existence of a starred pattern
  193. #
  194. current_type = get_proper_type(self.type_context[-1])
  195. if not self.can_match_sequence(current_type):
  196. return self.early_non_match()
  197. star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
  198. star_position: int | None = None
  199. if len(star_positions) == 1:
  200. star_position = star_positions[0]
  201. elif len(star_positions) >= 2:
  202. assert False, "Parser should prevent multiple starred patterns"
  203. required_patterns = len(o.patterns)
  204. if star_position is not None:
  205. required_patterns -= 1
  206. #
  207. # get inner types of original type
  208. #
  209. if isinstance(current_type, TupleType):
  210. inner_types = current_type.items
  211. size_diff = len(inner_types) - required_patterns
  212. if size_diff < 0:
  213. return self.early_non_match()
  214. elif size_diff > 0 and star_position is None:
  215. return self.early_non_match()
  216. else:
  217. inner_type = self.get_sequence_type(current_type, o)
  218. if inner_type is None:
  219. inner_type = self.chk.named_type("builtins.object")
  220. inner_types = [inner_type] * len(o.patterns)
  221. #
  222. # match inner patterns
  223. #
  224. contracted_new_inner_types: list[Type] = []
  225. contracted_rest_inner_types: list[Type] = []
  226. captures: dict[Expression, Type] = {}
  227. contracted_inner_types = self.contract_starred_pattern_types(
  228. inner_types, star_position, required_patterns
  229. )
  230. for p, t in zip(o.patterns, contracted_inner_types):
  231. pattern_type = self.accept(p, t)
  232. typ, rest, type_map = pattern_type
  233. contracted_new_inner_types.append(typ)
  234. contracted_rest_inner_types.append(rest)
  235. self.update_type_map(captures, type_map)
  236. new_inner_types = self.expand_starred_pattern_types(
  237. contracted_new_inner_types, star_position, len(inner_types)
  238. )
  239. rest_inner_types = self.expand_starred_pattern_types(
  240. contracted_rest_inner_types, star_position, len(inner_types)
  241. )
  242. #
  243. # Calculate new type
  244. #
  245. new_type: Type
  246. rest_type: Type = current_type
  247. if isinstance(current_type, TupleType):
  248. narrowed_inner_types = []
  249. inner_rest_types = []
  250. for inner_type, new_inner_type in zip(inner_types, new_inner_types):
  251. (
  252. narrowed_inner_type,
  253. inner_rest_type,
  254. ) = self.chk.conditional_types_with_intersection(
  255. new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type
  256. )
  257. narrowed_inner_types.append(narrowed_inner_type)
  258. inner_rest_types.append(inner_rest_type)
  259. if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
  260. new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
  261. else:
  262. new_type = UninhabitedType()
  263. if all(is_uninhabited(typ) for typ in inner_rest_types):
  264. # All subpatterns always match, so we can apply negative narrowing
  265. rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
  266. else:
  267. new_inner_type = UninhabitedType()
  268. for typ in new_inner_types:
  269. new_inner_type = join_types(new_inner_type, typ)
  270. new_type = self.construct_sequence_child(current_type, new_inner_type)
  271. if is_subtype(new_type, current_type):
  272. new_type, _ = self.chk.conditional_types_with_intersection(
  273. current_type, [get_type_range(new_type)], o, default=current_type
  274. )
  275. else:
  276. new_type = current_type
  277. return PatternType(new_type, rest_type, captures)
  278. def get_sequence_type(self, t: Type, context: Context) -> Type | None:
  279. t = get_proper_type(t)
  280. if isinstance(t, AnyType):
  281. return AnyType(TypeOfAny.from_another_any, t)
  282. if isinstance(t, UnionType):
  283. items = [self.get_sequence_type(item, context) for item in t.items]
  284. not_none_items = [item for item in items if item is not None]
  285. if not_none_items:
  286. return make_simplified_union(not_none_items)
  287. else:
  288. return None
  289. if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
  290. if isinstance(t, TupleType):
  291. t = tuple_fallback(t)
  292. return self.chk.iterable_item_type(t, context)
  293. else:
  294. return None
  295. def contract_starred_pattern_types(
  296. self, types: list[Type], star_pos: int | None, num_patterns: int
  297. ) -> list[Type]:
  298. """
  299. Contracts a list of types in a sequence pattern depending on the position of a starred
  300. capture pattern.
  301. For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
  302. bytes] the contracted types are [bool, Union[int, str], bytes].
  303. If star_pos in None the types are returned unchanged.
  304. """
  305. if star_pos is None:
  306. return types
  307. new_types = types[:star_pos]
  308. star_length = len(types) - num_patterns
  309. new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
  310. new_types += types[star_pos + star_length :]
  311. return new_types
  312. def expand_starred_pattern_types(
  313. self, types: list[Type], star_pos: int | None, num_types: int
  314. ) -> list[Type]:
  315. """Undoes the contraction done by contract_starred_pattern_types.
  316. For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
  317. to length 4 the result is [bool, int, int, str].
  318. """
  319. if star_pos is None:
  320. return types
  321. new_types = types[:star_pos]
  322. star_length = num_types - len(types) + 1
  323. new_types += [types[star_pos]] * star_length
  324. new_types += types[star_pos + 1 :]
  325. return new_types
  326. def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
  327. captures: dict[Expression, Type] = {}
  328. if o.capture is not None:
  329. list_type = self.chk.named_generic_type("builtins.list", [self.type_context[-1]])
  330. captures[o.capture] = list_type
  331. return PatternType(self.type_context[-1], UninhabitedType(), captures)
  332. def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
  333. current_type = get_proper_type(self.type_context[-1])
  334. can_match = True
  335. captures: dict[Expression, Type] = {}
  336. for key, value in zip(o.keys, o.values):
  337. inner_type = self.get_mapping_item_type(o, current_type, key)
  338. if inner_type is None:
  339. can_match = False
  340. inner_type = self.chk.named_type("builtins.object")
  341. pattern_type = self.accept(value, inner_type)
  342. if is_uninhabited(pattern_type.type):
  343. can_match = False
  344. else:
  345. self.update_type_map(captures, pattern_type.captures)
  346. if o.rest is not None:
  347. mapping = self.chk.named_type("typing.Mapping")
  348. if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
  349. mapping_inst = map_instance_to_supertype(current_type, mapping.type)
  350. dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
  351. rest_type = Instance(dict_typeinfo, mapping_inst.args)
  352. else:
  353. object_type = self.chk.named_type("builtins.object")
  354. rest_type = self.chk.named_generic_type(
  355. "builtins.dict", [object_type, object_type]
  356. )
  357. captures[o.rest] = rest_type
  358. if can_match:
  359. # We can't narrow the type here, as Mapping key is invariant.
  360. new_type = self.type_context[-1]
  361. else:
  362. new_type = UninhabitedType()
  363. return PatternType(new_type, current_type, captures)
  364. def get_mapping_item_type(
  365. self, pattern: MappingPattern, mapping_type: Type, key: Expression
  366. ) -> Type | None:
  367. mapping_type = get_proper_type(mapping_type)
  368. if isinstance(mapping_type, TypedDictType):
  369. with self.msg.filter_errors() as local_errors:
  370. result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
  371. mapping_type, key
  372. )
  373. has_local_errors = local_errors.has_new_errors()
  374. # If we can't determine the type statically fall back to treating it as a normal
  375. # mapping
  376. if has_local_errors:
  377. with self.msg.filter_errors() as local_errors:
  378. result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
  379. if local_errors.has_new_errors():
  380. result = None
  381. else:
  382. with self.msg.filter_errors():
  383. result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
  384. return result
  385. def get_simple_mapping_item_type(
  386. self, pattern: MappingPattern, mapping_type: Type, key: Expression
  387. ) -> Type:
  388. result, _ = self.chk.expr_checker.check_method_call_by_name(
  389. "__getitem__", mapping_type, [key], [ARG_POS], pattern
  390. )
  391. return result
  392. def visit_class_pattern(self, o: ClassPattern) -> PatternType:
  393. current_type = get_proper_type(self.type_context[-1])
  394. #
  395. # Check class type
  396. #
  397. type_info = o.class_ref.node
  398. if type_info is None:
  399. return PatternType(AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error), {})
  400. if isinstance(type_info, TypeAlias) and not type_info.no_args:
  401. self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
  402. return self.early_non_match()
  403. if isinstance(type_info, TypeInfo):
  404. any_type = AnyType(TypeOfAny.implementation_artifact)
  405. typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
  406. elif isinstance(type_info, TypeAlias):
  407. typ = type_info.target
  408. else:
  409. if isinstance(type_info, Var) and type_info.type is not None:
  410. name = type_info.type.str_with_options(self.options)
  411. else:
  412. name = type_info.name
  413. self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o)
  414. return self.early_non_match()
  415. new_type, rest_type = self.chk.conditional_types_with_intersection(
  416. current_type, [get_type_range(typ)], o, default=current_type
  417. )
  418. if is_uninhabited(new_type):
  419. return self.early_non_match()
  420. # TODO: Do I need this?
  421. narrowed_type = narrow_declared_type(current_type, new_type)
  422. #
  423. # Convert positional to keyword patterns
  424. #
  425. keyword_pairs: list[tuple[str | None, Pattern]] = []
  426. match_arg_set: set[str] = set()
  427. captures: dict[Expression, Type] = {}
  428. if len(o.positionals) != 0:
  429. if self.should_self_match(typ):
  430. if len(o.positionals) > 1:
  431. self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
  432. pattern_type = self.accept(o.positionals[0], narrowed_type)
  433. if not is_uninhabited(pattern_type.type):
  434. return PatternType(
  435. pattern_type.type,
  436. join_types(rest_type, pattern_type.rest_type),
  437. pattern_type.captures,
  438. )
  439. captures = pattern_type.captures
  440. else:
  441. with self.msg.filter_errors() as local_errors:
  442. match_args_type = analyze_member_access(
  443. "__match_args__",
  444. typ,
  445. o,
  446. False,
  447. False,
  448. False,
  449. self.msg,
  450. original_type=typ,
  451. chk=self.chk,
  452. )
  453. has_local_errors = local_errors.has_new_errors()
  454. if has_local_errors:
  455. self.msg.fail(
  456. message_registry.MISSING_MATCH_ARGS.format(
  457. typ.str_with_options(self.options)
  458. ),
  459. o,
  460. )
  461. return self.early_non_match()
  462. proper_match_args_type = get_proper_type(match_args_type)
  463. if isinstance(proper_match_args_type, TupleType):
  464. match_arg_names = get_match_arg_names(proper_match_args_type)
  465. if len(o.positionals) > len(match_arg_names):
  466. self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
  467. return self.early_non_match()
  468. else:
  469. match_arg_names = [None] * len(o.positionals)
  470. for arg_name, pos in zip(match_arg_names, o.positionals):
  471. keyword_pairs.append((arg_name, pos))
  472. if arg_name is not None:
  473. match_arg_set.add(arg_name)
  474. #
  475. # Check for duplicate patterns
  476. #
  477. keyword_arg_set = set()
  478. has_duplicates = False
  479. for key, value in zip(o.keyword_keys, o.keyword_values):
  480. keyword_pairs.append((key, value))
  481. if key in match_arg_set:
  482. self.msg.fail(
  483. message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), value
  484. )
  485. has_duplicates = True
  486. elif key in keyword_arg_set:
  487. self.msg.fail(
  488. message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), value
  489. )
  490. has_duplicates = True
  491. keyword_arg_set.add(key)
  492. if has_duplicates:
  493. return self.early_non_match()
  494. #
  495. # Check keyword patterns
  496. #
  497. can_match = True
  498. for keyword, pattern in keyword_pairs:
  499. key_type: Type | None = None
  500. with self.msg.filter_errors() as local_errors:
  501. if keyword is not None:
  502. key_type = analyze_member_access(
  503. keyword,
  504. narrowed_type,
  505. pattern,
  506. False,
  507. False,
  508. False,
  509. self.msg,
  510. original_type=new_type,
  511. chk=self.chk,
  512. )
  513. else:
  514. key_type = AnyType(TypeOfAny.from_error)
  515. has_local_errors = local_errors.has_new_errors()
  516. if has_local_errors or key_type is None:
  517. key_type = AnyType(TypeOfAny.from_error)
  518. self.msg.fail(
  519. message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
  520. typ.str_with_options(self.options), keyword
  521. ),
  522. pattern,
  523. )
  524. inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
  525. if is_uninhabited(inner_type):
  526. can_match = False
  527. else:
  528. self.update_type_map(captures, inner_captures)
  529. if not is_uninhabited(inner_rest_type):
  530. rest_type = current_type
  531. if not can_match:
  532. new_type = UninhabitedType()
  533. return PatternType(new_type, rest_type, captures)
  534. def should_self_match(self, typ: Type) -> bool:
  535. typ = get_proper_type(typ)
  536. if isinstance(typ, Instance) and typ.type.is_named_tuple:
  537. return False
  538. for other in self.self_match_types:
  539. if is_subtype(typ, other):
  540. return True
  541. return False
  542. def can_match_sequence(self, typ: ProperType) -> bool:
  543. if isinstance(typ, UnionType):
  544. return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
  545. for other in self.non_sequence_match_types:
  546. # We have to ignore promotions, as memoryview should match, but bytes,
  547. # which it can be promoted to, shouldn't
  548. if is_subtype(typ, other, ignore_promotions=True):
  549. return False
  550. sequence = self.chk.named_type("typing.Sequence")
  551. # If the static type is more general than sequence the actual type could still match
  552. return is_subtype(typ, sequence) or is_subtype(sequence, typ)
  553. def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
  554. types: list[Type] = []
  555. for name in type_names:
  556. try:
  557. types.append(self.chk.named_type(name))
  558. except KeyError as e:
  559. # Some built in types are not defined in all test cases
  560. if not name.startswith("builtins."):
  561. raise e
  562. return types
  563. def update_type_map(
  564. self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
  565. ) -> None:
  566. # Calculating this would not be needed if TypeMap directly used literal hashes instead of
  567. # expressions, as suggested in the TODO above it's definition
  568. already_captured = {literal_hash(expr) for expr in original_type_map}
  569. for expr, typ in extra_type_map.items():
  570. if literal_hash(expr) in already_captured:
  571. node = get_var(expr)
  572. self.msg.fail(
  573. message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
  574. )
  575. else:
  576. original_type_map[expr] = typ
  577. def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
  578. """
  579. If outer_type is a child class of typing.Sequence returns a new instance of
  580. outer_type, that is a Sequence of inner_type. If outer_type is not a child class of
  581. typing.Sequence just returns a Sequence of inner_type
  582. For example:
  583. construct_sequence_child(List[int], str) = List[str]
  584. TODO: this doesn't make sense. For example if one has class S(Sequence[int], Generic[T])
  585. or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
  586. """
  587. proper_type = get_proper_type(outer_type)
  588. if isinstance(proper_type, UnionType):
  589. types = [
  590. self.construct_sequence_child(item, inner_type)
  591. for item in proper_type.items
  592. if self.can_match_sequence(get_proper_type(item))
  593. ]
  594. return make_simplified_union(types)
  595. sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
  596. if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
  597. proper_type = get_proper_type(outer_type)
  598. if isinstance(proper_type, TupleType):
  599. proper_type = tuple_fallback(proper_type)
  600. assert isinstance(proper_type, Instance)
  601. empty_type = fill_typevars(proper_type.type)
  602. partial_type = expand_type_by_instance(empty_type, sequence)
  603. return expand_type_by_instance(partial_type, proper_type)
  604. else:
  605. return sequence
  606. def early_non_match(self) -> PatternType:
  607. return PatternType(UninhabitedType(), self.type_context[-1], {})
  608. def get_match_arg_names(typ: TupleType) -> list[str | None]:
  609. args: list[str | None] = []
  610. for item in typ.items:
  611. values = try_getting_str_literals_from_type(item)
  612. if values is None or len(values) != 1:
  613. args.append(None)
  614. else:
  615. args.append(values[0])
  616. return args
  617. def get_var(expr: Expression) -> Var:
  618. """
  619. Warning: this in only true for expressions captured by a match statement.
  620. Don't call it from anywhere else
  621. """
  622. assert isinstance(expr, NameExpr)
  623. node = expr.node
  624. assert isinstance(node, Var)
  625. return node
  626. def get_type_range(typ: Type) -> mypy.checker.TypeRange:
  627. typ = get_proper_type(typ)
  628. if (
  629. isinstance(typ, Instance)
  630. and typ.last_known_value
  631. and isinstance(typ.last_known_value.value, bool)
  632. ):
  633. typ = typ.last_known_value
  634. return mypy.checker.TypeRange(typ, is_upper_bound=False)
  635. def is_uninhabited(typ: Type) -> bool:
  636. return isinstance(get_proper_type(typ), UninhabitedType)