constant_fold.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
"""Constant folding of IR values.

For example, 3 + 5 can be constant folded into 8.

This is mostly like mypy.constant_fold, but we can bind some additional
NameExpr and MemberExpr references here, since we have more knowledge
about which definitions can be trusted -- we constant fold only references
to other compiled modules in the same compilation unit.
"""
  8. from __future__ import annotations
  9. from typing import Union
  10. from typing_extensions import Final
  11. from mypy.constant_fold import (
  12. constant_fold_binary_int_op,
  13. constant_fold_binary_str_op,
  14. constant_fold_unary_float_op,
  15. constant_fold_unary_int_op,
  16. )
  17. from mypy.nodes import (
  18. Expression,
  19. FloatExpr,
  20. IntExpr,
  21. MemberExpr,
  22. NameExpr,
  23. OpExpr,
  24. StrExpr,
  25. UnaryExpr,
  26. Var,
  27. )
  28. from mypyc.irbuild.builder import IRBuilder
# All possible result types of constant folding
ConstantValue = Union[int, str, float]
# The same set of types as a tuple, for use as the second argument of
# isinstance() when checking whether a final value can be folded.
CONST_TYPES: Final = (int, str, float)
  32. def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
  33. """Return the constant value of an expression for supported operations.
  34. Return None otherwise.
  35. """
  36. if isinstance(expr, IntExpr):
  37. return expr.value
  38. if isinstance(expr, StrExpr):
  39. return expr.value
  40. if isinstance(expr, FloatExpr):
  41. return expr.value
  42. elif isinstance(expr, NameExpr):
  43. node = expr.node
  44. if isinstance(node, Var) and node.is_final:
  45. value = node.final_value
  46. if isinstance(value, (CONST_TYPES)):
  47. return value
  48. elif isinstance(expr, MemberExpr):
  49. final = builder.get_final_ref(expr)
  50. if final is not None:
  51. fn, final_var, native = final
  52. if final_var.is_final:
  53. value = final_var.final_value
  54. if isinstance(value, (CONST_TYPES)):
  55. return value
  56. elif isinstance(expr, OpExpr):
  57. left = constant_fold_expr(builder, expr.left)
  58. right = constant_fold_expr(builder, expr.right)
  59. if isinstance(left, int) and isinstance(right, int):
  60. return constant_fold_binary_int_op(expr.op, left, right)
  61. elif isinstance(left, str) and isinstance(right, str):
  62. return constant_fold_binary_str_op(expr.op, left, right)
  63. elif isinstance(expr, UnaryExpr):
  64. value = constant_fold_expr(builder, expr.expr)
  65. if isinstance(value, int):
  66. return constant_fold_unary_int_op(expr.op, value)
  67. if isinstance(value, float):
  68. return constant_fold_unary_float_op(expr.op, value)
  69. return None