| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355 |
- from contextlib import contextmanager
- from typing import Generator, List, Optional, Tuple
- from mypy.nodes import MatchStmt, NameExpr, TypeInfo
- from mypy.patterns import (
- AsPattern,
- ClassPattern,
- MappingPattern,
- OrPattern,
- Pattern,
- SequencePattern,
- SingletonPattern,
- StarredPattern,
- ValuePattern,
- )
- from mypy.traverser import TraverserVisitor
- from mypy.types import Instance, TupleType, get_proper_type
- from mypyc.ir.ops import BasicBlock, Value
- from mypyc.ir.rtypes import object_rprimitive
- from mypyc.irbuild.builder import IRBuilder
- from mypyc.primitives.dict_ops import (
- dict_copy,
- dict_del_item,
- mapping_has_key,
- supports_mapping_protocol,
- )
- from mypyc.primitives.generic_ops import generic_ssize_t_len_op
- from mypyc.primitives.list_ops import (
- sequence_get_item,
- sequence_get_slice,
- supports_sequence_protocol,
- )
- from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op
# Builtin classes whose class pattern accepts a single positional sub-pattern
# that is matched against the subject itself (see visit_class_pattern).
# From: https://peps.python.org/pep-0634/#class-patterns
MATCHABLE_BUILTINS = {
    f"builtins.{type_name}"
    for type_name in (
        "bool",
        "bytearray",
        "bytes",
        "dict",
        "float",
        "frozenset",
        "int",
        "list",
        "set",
        "str",
        "tuple",
    )
}
class MatchVisitor(TraverserVisitor):
    """Generate IR for a ``match`` statement by visiting its case patterns.

    Pattern visitors emit checks against ``self.subject`` that send control to
    ``self.code_block`` when the pattern matches and to ``self.next_block``
    when it does not. Patterns that need several steps chain them by
    activating ``self.code_block`` and replacing it with a fresh block.
    """

    # IR builder used to emit every op and block produced by this visitor.
    builder: IRBuilder
    # Block that control reaches once the checks emitted so far have passed.
    code_block: BasicBlock
    # Block that control reaches when the current pattern fails.
    next_block: BasicBlock
    # Block every case body jumps to after it has run.
    final_block: BasicBlock
    # Value the current (sub)pattern is being tested against.
    subject: Value
    # The match statement being compiled.
    match: MatchStmt

    # Enclosing "as" pattern whose name has not been bound yet, if any.
    as_pattern: Optional[AsPattern] = None

    def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
        """Set up fresh blocks and evaluate the subject of ``match_node``."""
        self.builder = builder

        self.code_block = BasicBlock()
        self.next_block = BasicBlock()
        self.final_block = BasicBlock()

        self.match = match_node
        self.subject = builder.accept(match_node.subject)

    def build_match_body(self, index: int) -> None:
        """Emit the guard (if any) and body of the case at ``index``.

        A failing guard branches to ``self.next_block``; a completed body
        jumps to ``self.final_block``.
        """
        self.builder.activate_block(self.code_block)

        guard = self.match.guards[index]

        if guard:
            # The body gets its own block so the guard can skip over it.
            self.code_block = BasicBlock()

            cond = self.builder.accept(guard)
            self.builder.add_bool_branch(cond, self.code_block, self.next_block)

            self.builder.activate_block(self.code_block)

        self.builder.accept(self.match.bodies[index])
        self.builder.goto(self.final_block)

    def visit_match_stmt(self, m: MatchStmt) -> None:
        """Emit every case of ``m`` in order, then continue in the final block."""
        for i, pattern in enumerate(m.patterns):
            self.code_block = BasicBlock()
            self.next_block = BasicBlock()

            pattern.accept(self)

            self.build_match_body(i)
            # The failure block of this case is where the next case starts.
            self.builder.activate_block(self.next_block)

        self.builder.goto_and_activate(self.final_block)

    def visit_value_pattern(self, pattern: ValuePattern) -> None:
        """Match when the subject compares equal (``==``) to the pattern's expression."""
        value = self.builder.accept(pattern.expr)

        cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)

        # NOTE(review): this binds the pattern's value, not the subject, and
        # does so before the branch. PEP 634 specifies that an "as" pattern
        # binds the subject -- confirm whether the difference is intended.
        self.bind_as_pattern(value)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_or_pattern(self, pattern: OrPattern) -> None:
        """Try each alternative in turn.

        A failing alternative falls through to the next one; when all of them
        fail, control goes to the ``next_block`` that was current on entry.
        """
        backup_block = self.next_block
        self.next_block = BasicBlock()

        for p in pattern.patterns:
            # Hack to ensure the as pattern is bound to each pattern in the
            # "or" pattern, but not every subpattern
            backup = self.as_pattern
            p.accept(self)
            self.as_pattern = backup

            self.builder.activate_block(self.next_block)
            self.next_block = BasicBlock()

        self.next_block = backup_block
        self.builder.goto(self.next_block)

    def visit_class_pattern(self, pattern: ClassPattern) -> None:
        """Match an ``isinstance`` check followed by attribute sub-patterns.

        For classes in ``MATCHABLE_BUILTINS`` the first positional sub-pattern
        is matched against the subject itself and nothing else is emitted.
        For other classes, positional sub-patterns are matched against the
        attributes named by the class's ``__match_args__``, and keyword
        sub-patterns against the attribute named by their key.
        """
        # TODO: use faster instance check for native classes (while still
        # making sure to account for inheritance)
        isinstance_op = (
            fast_isinstance_op
            if self.builder.is_builtin_ref_expr(pattern.class_ref)
            else slow_isinstance_op
        )

        cond = self.builder.call_c(
            isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
        )

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

        self.bind_as_pattern(self.subject, new_block=True)

        if pattern.positionals:
            if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

                pattern.positionals[0].accept(self)

                return

            node = pattern.class_ref.node
            assert isinstance(node, TypeInfo)

            ty = node.names.get("__match_args__")
            assert ty

            match_args_type = get_proper_type(ty.type)
            assert isinstance(match_args_type, TupleType)

            # Attribute names are read from the literal values recorded on the
            # items of the __match_args__ tuple type.
            match_args: List[str] = []

            for item in match_args_type.items:
                proper_item = get_proper_type(item)
                assert isinstance(proper_item, Instance) and proper_item.last_known_value

                match_arg = proper_item.last_known_value.value
                assert isinstance(match_arg, str)

                match_args.append(match_arg)

            for i, expr in enumerate(pattern.positionals):
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

                # TODO: use faster "get_attr" method instead when calling on native or
                # builtin objects
                positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line)

                with self.enter_subpattern(positional):
                    expr.accept(self)

        for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # TODO: same as above "get_attr" comment
            attr = self.builder.py_get_attr(self.subject, key, value.line)

            with self.enter_subpattern(attr):
                value.accept(self)

    def visit_as_pattern(self, pattern: AsPattern) -> None:
        """Handle ``<pattern> as name`` and bare capture names.

        With a sub-pattern, this pattern is recorded as pending (so a nested
        visitor can bind the name) while the sub-pattern is visited. With only
        a name, the subject is assigned to it directly.
        """
        if pattern.pattern:
            old_pattern = self.as_pattern
            self.as_pattern = pattern
            pattern.pattern.accept(self)
            self.as_pattern = old_pattern

        elif pattern.name:
            target = self.builder.get_assignment_target(pattern.name)

            self.builder.assign(target, self.subject, pattern.line)

        # Reached in every case, including when neither a sub-pattern nor a
        # name is present (presumably the wildcard pattern -- TODO confirm).
        self.builder.goto(self.code_block)

    def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
        """Match when the subject ``is`` the ``None``/``True``/``False`` in the pattern."""
        if pattern.value is None:
            obj = self.builder.none_object()
        elif pattern.value is True:
            obj = self.builder.true()
        else:
            obj = self.builder.false()

        cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
        """Match a mapping: protocol check, then each key and its value sub-pattern.

        When the pattern has a ``**rest`` capture, it is bound to a copy of the
        subject from which every explicitly matched key has been deleted.
        """
        is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)

        self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)

        # Evaluated key values, kept so they can be removed from the rest capture.
        keys: List[Value] = []

        for key, value in zip(pattern.keys, pattern.values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            key_value = self.builder.accept(key)
            keys.append(key_value)

            exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)

            self.builder.add_bool_branch(exists, self.code_block, self.next_block)
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            item = self.builder.gen_method_call(
                self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
            )

            with self.enter_subpattern(item):
                value.accept(self)

        if pattern.rest:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)

            target = self.builder.get_assignment_target(pattern.rest)

            self.builder.assign(target, rest, pattern.rest.line)

            for i, key_name in enumerate(keys):
                self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line)

            self.builder.goto(self.code_block)

    def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
        """Match a sequence: protocol check, length check, then each item.

        Without a star pattern the length must equal the number of item
        patterns; with one it must be at least that number. Item patterns
        before the star are indexed from the front, and those after it are
        indexed relative to the end of the subject.
        """
        star_index, capture, patterns = prep_sequence_pattern(seq_pattern)

        is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)

        self.builder.add_bool_branch(is_list, self.code_block, self.next_block)

        self.builder.activate_block(self.code_block)
        self.code_block = BasicBlock()

        actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
        min_len = len(patterns)

        is_long_enough = self.builder.binary_op(
            actual_len,
            self.builder.load_int(min_len),
            "==" if star_index is None else ">=",
            seq_pattern.line,
        )

        self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)

        for i, pattern in enumerate(patterns):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            if star_index is not None and i >= star_index:
                # Items after the star sit (min_len - i) places from the end.
                current = self.builder.binary_op(
                    actual_len, self.builder.load_int(min_len - i), "-", pattern.line
                )
            else:
                current = self.builder.load_int(i)

            item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)

            with self.enter_subpattern(item):
                pattern.accept(self)

        if capture and star_index is not None:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # The starred slice stops where the trailing item patterns begin.
            capture_end = self.builder.binary_op(
                actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
            )

            rest = self.builder.call_c(
                sequence_get_slice,
                [self.subject, self.builder.load_int(star_index), capture_end],
                capture.line,
            )

            target = self.builder.get_assignment_target(capture)
            self.builder.assign(target, rest, capture.line)

            self.builder.goto(self.code_block)

    def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
        """Assign ``value`` to the pending "as" pattern's name, if there is one.

        Nothing is emitted unless a pending "as" pattern has both a
        sub-pattern and a name. After binding, the pending pattern is cleared.
        With ``new_block`` set, the assignment is emitted at the start of
        ``self.code_block`` and then jumps to a fresh ``self.code_block``.
        """
        if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
            if new_block:
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

            target = self.builder.get_assignment_target(self.as_pattern.name)
            self.builder.assign(target, value, self.as_pattern.pattern.line)

            self.as_pattern = None

            if new_block:
                self.builder.goto(self.code_block)

    @contextmanager
    def enter_subpattern(self, subject: Value) -> Generator[None, None, None]:
        """Temporarily make ``subject`` the value that patterns are matched against.

        The previous subject is restored when the ``with`` block exits, even
        if the block raises.
        """
        old_subject = self.subject
        self.subject = subject
        try:
            yield
        finally:
            self.subject = old_subject
def prep_sequence_pattern(
    seq_pattern: SequencePattern,
) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]:
    """Split a sequence pattern into its star pattern and its item patterns.

    Returns a tuple ``(star_index, capture, patterns)``:

    * ``star_index`` is the position of the starred pattern within
      ``seq_pattern.patterns``, or ``None`` when there is none.
    * ``capture`` is the starred pattern's ``capture`` attribute, or ``None``
      when there is no starred pattern.
    * ``patterns`` holds the remaining (non-starred) patterns in their
      original order.
    """
    star_index: Optional[int] = None
    capture: Optional[NameExpr] = None

    # Locate the star; if several were present the last one would be reported.
    for position, sub_pattern in enumerate(seq_pattern.patterns):
        if isinstance(sub_pattern, StarredPattern):
            star_index, capture = position, sub_pattern.capture

    patterns: List[Pattern] = [
        sub_pattern
        for sub_pattern in seq_pattern.patterns
        if not isinstance(sub_pattern, StarredPattern)
    ]

    return star_index, capture, patterns
|