match.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. from contextlib import contextmanager
  2. from typing import Generator, List, Optional, Tuple
  3. from mypy.nodes import MatchStmt, NameExpr, TypeInfo
  4. from mypy.patterns import (
  5. AsPattern,
  6. ClassPattern,
  7. MappingPattern,
  8. OrPattern,
  9. Pattern,
  10. SequencePattern,
  11. SingletonPattern,
  12. StarredPattern,
  13. ValuePattern,
  14. )
  15. from mypy.traverser import TraverserVisitor
  16. from mypy.types import Instance, TupleType, get_proper_type
  17. from mypyc.ir.ops import BasicBlock, Value
  18. from mypyc.ir.rtypes import object_rprimitive
  19. from mypyc.irbuild.builder import IRBuilder
  20. from mypyc.primitives.dict_ops import (
  21. dict_copy,
  22. dict_del_item,
  23. mapping_has_key,
  24. supports_mapping_protocol,
  25. )
  26. from mypyc.primitives.generic_ops import generic_ssize_t_len_op
  27. from mypyc.primitives.list_ops import (
  28. sequence_get_item,
  29. sequence_get_slice,
  30. supports_sequence_protocol,
  31. )
  32. from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op
  33. # From: https://peps.python.org/pep-0634/#class-patterns
  34. MATCHABLE_BUILTINS = {
  35. "builtins.bool",
  36. "builtins.bytearray",
  37. "builtins.bytes",
  38. "builtins.dict",
  39. "builtins.float",
  40. "builtins.frozenset",
  41. "builtins.int",
  42. "builtins.list",
  43. "builtins.set",
  44. "builtins.str",
  45. "builtins.tuple",
  46. }
  47. class MatchVisitor(TraverserVisitor):
  48. builder: IRBuilder
  49. code_block: BasicBlock
  50. next_block: BasicBlock
  51. final_block: BasicBlock
  52. subject: Value
  53. match: MatchStmt
  54. as_pattern: Optional[AsPattern] = None
  55. def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
  56. self.builder = builder
  57. self.code_block = BasicBlock()
  58. self.next_block = BasicBlock()
  59. self.final_block = BasicBlock()
  60. self.match = match_node
  61. self.subject = builder.accept(match_node.subject)
  62. def build_match_body(self, index: int) -> None:
  63. self.builder.activate_block(self.code_block)
  64. guard = self.match.guards[index]
  65. if guard:
  66. self.code_block = BasicBlock()
  67. cond = self.builder.accept(guard)
  68. self.builder.add_bool_branch(cond, self.code_block, self.next_block)
  69. self.builder.activate_block(self.code_block)
  70. self.builder.accept(self.match.bodies[index])
  71. self.builder.goto(self.final_block)
  72. def visit_match_stmt(self, m: MatchStmt) -> None:
  73. for i, pattern in enumerate(m.patterns):
  74. self.code_block = BasicBlock()
  75. self.next_block = BasicBlock()
  76. pattern.accept(self)
  77. self.build_match_body(i)
  78. self.builder.activate_block(self.next_block)
  79. self.builder.goto_and_activate(self.final_block)
  80. def visit_value_pattern(self, pattern: ValuePattern) -> None:
  81. value = self.builder.accept(pattern.expr)
  82. cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)
  83. self.bind_as_pattern(value)
  84. self.builder.add_bool_branch(cond, self.code_block, self.next_block)
  85. def visit_or_pattern(self, pattern: OrPattern) -> None:
  86. backup_block = self.next_block
  87. self.next_block = BasicBlock()
  88. for p in pattern.patterns:
  89. # Hack to ensure the as pattern is bound to each pattern in the
  90. # "or" pattern, but not every subpattern
  91. backup = self.as_pattern
  92. p.accept(self)
  93. self.as_pattern = backup
  94. self.builder.activate_block(self.next_block)
  95. self.next_block = BasicBlock()
  96. self.next_block = backup_block
  97. self.builder.goto(self.next_block)
  98. def visit_class_pattern(self, pattern: ClassPattern) -> None:
  99. # TODO: use faster instance check for native classes (while still
  100. # making sure to account for inheritence)
  101. isinstance_op = (
  102. fast_isinstance_op
  103. if self.builder.is_builtin_ref_expr(pattern.class_ref)
  104. else slow_isinstance_op
  105. )
  106. cond = self.builder.call_c(
  107. isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
  108. )
  109. self.builder.add_bool_branch(cond, self.code_block, self.next_block)
  110. self.bind_as_pattern(self.subject, new_block=True)
  111. if pattern.positionals:
  112. if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
  113. self.builder.activate_block(self.code_block)
  114. self.code_block = BasicBlock()
  115. pattern.positionals[0].accept(self)
  116. return
  117. node = pattern.class_ref.node
  118. assert isinstance(node, TypeInfo)
  119. ty = node.names.get("__match_args__")
  120. assert ty
  121. match_args_type = get_proper_type(ty.type)
  122. assert isinstance(match_args_type, TupleType)
  123. match_args: List[str] = []
  124. for item in match_args_type.items:
  125. proper_item = get_proper_type(item)
  126. assert isinstance(proper_item, Instance) and proper_item.last_known_value
  127. match_arg = proper_item.last_known_value.value
  128. assert isinstance(match_arg, str)
  129. match_args.append(match_arg)
  130. for i, expr in enumerate(pattern.positionals):
  131. self.builder.activate_block(self.code_block)
  132. self.code_block = BasicBlock()
  133. # TODO: use faster "get_attr" method instead when calling on native or
  134. # builtin objects
  135. positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line)
  136. with self.enter_subpattern(positional):
  137. expr.accept(self)
  138. for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
  139. self.builder.activate_block(self.code_block)
  140. self.code_block = BasicBlock()
  141. # TODO: same as above "get_attr" comment
  142. attr = self.builder.py_get_attr(self.subject, key, value.line)
  143. with self.enter_subpattern(attr):
  144. value.accept(self)
  145. def visit_as_pattern(self, pattern: AsPattern) -> None:
  146. if pattern.pattern:
  147. old_pattern = self.as_pattern
  148. self.as_pattern = pattern
  149. pattern.pattern.accept(self)
  150. self.as_pattern = old_pattern
  151. elif pattern.name:
  152. target = self.builder.get_assignment_target(pattern.name)
  153. self.builder.assign(target, self.subject, pattern.line)
  154. self.builder.goto(self.code_block)
  155. def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
  156. if pattern.value is None:
  157. obj = self.builder.none_object()
  158. elif pattern.value is True:
  159. obj = self.builder.true()
  160. else:
  161. obj = self.builder.false()
  162. cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)
  163. self.builder.add_bool_branch(cond, self.code_block, self.next_block)
  164. def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
  165. is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)
  166. self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)
  167. keys: List[Value] = []
  168. for key, value in zip(pattern.keys, pattern.values):
  169. self.builder.activate_block(self.code_block)
  170. self.code_block = BasicBlock()
  171. key_value = self.builder.accept(key)
  172. keys.append(key_value)
  173. exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)
  174. self.builder.add_bool_branch(exists, self.code_block, self.next_block)
  175. self.builder.activate_block(self.code_block)
  176. self.code_block = BasicBlock()
  177. item = self.builder.gen_method_call(
  178. self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
  179. )
  180. with self.enter_subpattern(item):
  181. value.accept(self)
  182. if pattern.rest:
  183. self.builder.activate_block(self.code_block)
  184. self.code_block = BasicBlock()
  185. rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)
  186. target = self.builder.get_assignment_target(pattern.rest)
  187. self.builder.assign(target, rest, pattern.rest.line)
  188. for i, key_name in enumerate(keys):
  189. self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line)
  190. self.builder.goto(self.code_block)
  191. def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
  192. star_index, capture, patterns = prep_sequence_pattern(seq_pattern)
  193. is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)
  194. self.builder.add_bool_branch(is_list, self.code_block, self.next_block)
  195. self.builder.activate_block(self.code_block)
  196. self.code_block = BasicBlock()
  197. actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
  198. min_len = len(patterns)
  199. is_long_enough = self.builder.binary_op(
  200. actual_len,
  201. self.builder.load_int(min_len),
  202. "==" if star_index is None else ">=",
  203. seq_pattern.line,
  204. )
  205. self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)
  206. for i, pattern in enumerate(patterns):
  207. self.builder.activate_block(self.code_block)
  208. self.code_block = BasicBlock()
  209. if star_index is not None and i >= star_index:
  210. current = self.builder.binary_op(
  211. actual_len, self.builder.load_int(min_len - i), "-", pattern.line
  212. )
  213. else:
  214. current = self.builder.load_int(i)
  215. item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)
  216. with self.enter_subpattern(item):
  217. pattern.accept(self)
  218. if capture and star_index is not None:
  219. self.builder.activate_block(self.code_block)
  220. self.code_block = BasicBlock()
  221. capture_end = self.builder.binary_op(
  222. actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
  223. )
  224. rest = self.builder.call_c(
  225. sequence_get_slice,
  226. [self.subject, self.builder.load_int(star_index), capture_end],
  227. capture.line,
  228. )
  229. target = self.builder.get_assignment_target(capture)
  230. self.builder.assign(target, rest, capture.line)
  231. self.builder.goto(self.code_block)
  232. def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
  233. if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
  234. if new_block:
  235. self.builder.activate_block(self.code_block)
  236. self.code_block = BasicBlock()
  237. target = self.builder.get_assignment_target(self.as_pattern.name)
  238. self.builder.assign(target, value, self.as_pattern.pattern.line)
  239. self.as_pattern = None
  240. if new_block:
  241. self.builder.goto(self.code_block)
  242. @contextmanager
  243. def enter_subpattern(self, subject: Value) -> Generator[None, None, None]:
  244. old_subject = self.subject
  245. self.subject = subject
  246. yield
  247. self.subject = old_subject
  248. def prep_sequence_pattern(
  249. seq_pattern: SequencePattern,
  250. ) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]:
  251. star_index: Optional[int] = None
  252. capture: Optional[NameExpr] = None
  253. patterns: List[Pattern] = []
  254. for i, pattern in enumerate(seq_pattern.patterns):
  255. if isinstance(pattern, StarredPattern):
  256. star_index = i
  257. capture = pattern.capture
  258. else:
  259. patterns.append(pattern)
  260. return star_index, capture, patterns