| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- """Constant folding of expressions.
- For example, 3 + 5 can be constant folded into 8.
- """
- from __future__ import annotations
- from typing import Union
- from typing_extensions import Final
- from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
# Runtime types a folded constant may have, in a form usable with isinstance().
CONST_TYPES: Final = (int, bool, float, str)

# Static type covering every value that constant folding can produce.
ConstantValue = Union[int, bool, float, str]
def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
    """Try to evaluate an expression statically and return its value.

    Supported inputs are int, str and float literals, the names True
    and False, references to final variables, binary operations on
    two ints or on two strs (for example, 3 + 5 folds to 8), and unary
    operations on ints and floats.

    A reference to a final variable is only resolved if the variable
    is defined in the current module (cur_mod_id). This is best effort
    -- references to other modules are never bound. Mypyc trusts the
    folded values in compiled modules and may substitute them for the
    original expression (or for a reference to it), so values must not
    be inferred from stubs, as these might not match the
    implementation (due to version skew, for example).

    Return None if the expression can't be folded.
    """
    # Literals carry their value directly.
    if isinstance(expr, (IntExpr, StrExpr, FloatExpr)):
        return expr.value

    if isinstance(expr, NameExpr):
        if expr.name == "True":
            return True
        if expr.name == "False":
            return False
        target = expr.node
        if not isinstance(target, Var) or not target.is_final:
            return None
        # Only trust finals whose defining module is the current one.
        if target.fullname.rsplit(".", 1)[0] != cur_mod_id:
            return None
        final_value = target.final_value
        return final_value if isinstance(final_value, CONST_TYPES) else None

    if isinstance(expr, OpExpr):
        lhs = constant_fold_expr(expr.left, cur_mod_id)
        rhs = constant_fold_expr(expr.right, cur_mod_id)
        if isinstance(lhs, int) and isinstance(rhs, int):
            return constant_fold_binary_int_op(expr.op, lhs, rhs)
        if isinstance(lhs, str) and isinstance(rhs, str):
            return constant_fold_binary_str_op(expr.op, lhs, rhs)
        return None

    if isinstance(expr, UnaryExpr):
        operand = constant_fold_expr(expr.expr, cur_mod_id)
        if isinstance(operand, int):
            return constant_fold_unary_int_op(expr.op, operand)
        if isinstance(operand, float):
            return constant_fold_unary_float_op(expr.op, operand)
        return None

    return None
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
    """Fold a binary operation applied to two int constants.

    The operator is given as its source text (for example "+" or "<<").

    Return the int result, or None if the operation is not folded:
    the operator is not supported, the operation would fail at runtime
    (division or modulo by zero, negative shift count), the result
    would not be an int (negative exponent), or a shift result is too
    large to be constructed.
    """
    if op == "+":
        return left + right
    elif op == "-":
        return left - right
    elif op == "*":
        return left * right
    elif op == "//":
        if right != 0:
            return left // right
    elif op == "%":
        if right != 0:
            return left % right
    elif op == "&":
        return left & right
    elif op == "|":
        return left | right
    elif op == "^":
        return left ^ right
    elif op == "<<":
        if right >= 0:
            try:
                return left << right
            except (OverflowError, MemoryError):
                # A huge shift count (e.g. 1 << 10**100) can't be
                # evaluated; folding is best effort, so give up
                # instead of letting the error propagate.
                return None
    elif op == ">>":
        if right >= 0:
            try:
                return left >> right
            except (OverflowError, MemoryError):
                # Guarded like "<<" in case the interpreter rejects a
                # huge shift count.
                return None
    elif op == "**":
        if right >= 0:
            ret = left**right
            # With a non-negative exponent the result is always an int;
            # the assert makes that explicit for the type checker.
            assert isinstance(ret, int)
            return ret
    return None
def constant_fold_unary_int_op(op: str, value: int) -> int | None:
    """Fold a unary operation applied to an int constant.

    Supported operators are "+", "-" and "~". Return None for any
    other operator.
    """
    if op == "+":
        return value
    if op == "-":
        return -value
    if op == "~":
        return ~value
    return None
def constant_fold_unary_float_op(op: str, value: float) -> float | None:
    """Fold a unary operation applied to a float constant.

    Supported operators are "+" and "-". Return None for any other
    operator.
    """
    if op == "+":
        return value
    if op == "-":
        return -value
    return None
def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
    """Fold a binary operation applied to two str constants.

    Only concatenation ("+") is supported. Return None for any other
    operator.
    """
    return left + right if op == "+" else None
|