| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- """Constant folding of IR values.
- For example, 3 + 5 can be constant folded into 8.
- This is mostly like mypy.constant_fold, but we can bind some additional
- NameExpr and MemberExpr references here, since we have more knowledge
- about which definitions can be trusted -- we constant fold only references
- to other compiled modules in the same compilation unit.
- """
- from __future__ import annotations
- from typing import Final, Union
- from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
- from mypy.nodes import (
- BytesExpr,
- ComplexExpr,
- Expression,
- FloatExpr,
- IntExpr,
- MemberExpr,
- NameExpr,
- OpExpr,
- StrExpr,
- UnaryExpr,
- Var,
- )
- from mypyc.irbuild.builder import IRBuilder
- from mypyc.irbuild.util import bytes_from_str
# Every type that a successful constant fold can produce. Compared to the
# folding done by mypy itself, bytes are added here (see
# constant_fold_binary_op_extended below).
ConstantValue = Union[
    int,
    float,
    complex,
    str,
    bytes,
]
# The same set of types as a runtime tuple, for isinstance() checks
# against the recorded final values of variables.
CONST_TYPES: Final = (int, float, complex, str, bytes)
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
    """Return the constant value of an expression for supported operations.

    Supported expressions are:
    * int, float, str, bytes and complex literals;
    * references, by plain name or by member access, to final variables
      whose recorded final value is one of CONST_TYPES;
    * unary and binary operations whose operands themselves fold.

    Return None otherwise.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    elif isinstance(expr, FloatExpr):
        return expr.value
    elif isinstance(expr, StrExpr):
        return expr.value
    elif isinstance(expr, BytesExpr):
        # bytes_from_str() converts the literal's stored value into a bytes object.
        return bytes_from_str(expr.value)
    elif isinstance(expr, ComplexExpr):
        return expr.value
    elif isinstance(expr, NameExpr):
        node = expr.node
        if isinstance(node, Var):
            return _final_var_constant(node)
    elif isinstance(expr, MemberExpr):
        final = builder.get_final_ref(expr)
        if final is not None:
            # Only the referenced Var matters for folding; the other two
            # items of the returned tuple are not needed here.
            _, final_var, _ = final
            return _final_var_constant(final_var)
    elif isinstance(expr, OpExpr):
        left = constant_fold_expr(builder, expr.left)
        right = constant_fold_expr(builder, expr.right)
        if left is not None and right is not None:
            return constant_fold_binary_op_extended(expr.op, left, right)
    elif isinstance(expr, UnaryExpr):
        value = constant_fold_expr(builder, expr.expr)
        # Bytes operands are excluded: only the binary folder in this module
        # adds bytes support on top of mypy's folding.
        if value is not None and not isinstance(value, bytes):
            return constant_fold_unary_op(expr.op, value)
    return None


def _final_var_constant(var: Var) -> ConstantValue | None:
    """Return var's recorded final value if var is final and the value is foldable.

    Return None if var is not final or its final value is not one of CONST_TYPES.
    """
    value = var.final_value
    if var.is_final and isinstance(value, CONST_TYPES):
        return value
    return None
def constant_fold_binary_op_extended(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Fold a binary operation like mypy's constant_fold_binary_op(), plus bytes.

    Bytes support lives only in mypyc, because mypy cannot easily make use of
    constant-folded bytes values. The bytes forms handled are:
    * bytes + bytes;
    * bytes * int, in either operand order.

    Any other combination involving bytes returns None. Operands that are
    both non-bytes are handed to mypy's folder unchanged.
    """
    # Left operand is bytes: concatenation or repetition by an int.
    if isinstance(left, bytes):
        if op == "+" and isinstance(right, bytes):
            return left + right
        if op == "*" and isinstance(right, int):
            return left * right
        return None
    # Only the right operand is bytes: repetition with the int on the left.
    if isinstance(right, bytes):
        if op == "*" and isinstance(left, int):
            return left * right
        return None
    # Neither operand is bytes, so mypy's folder covers the rest.
    return constant_fold_binary_op(op, left, right)
|