constant_fold.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
"""Constant folding of IR values.

For example, 3 + 5 can be constant folded into 8.

This is mostly like mypy.constant_fold, but we can bind some additional
NameExpr and MemberExpr references here, since we have more knowledge
about which definitions can be trusted -- we only constant fold references
to definitions in other compiled modules of the same compilation unit.
"""
  8. from __future__ import annotations
  9. from typing import Final, Union
  10. from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
  11. from mypy.nodes import (
  12. BytesExpr,
  13. ComplexExpr,
  14. Expression,
  15. FloatExpr,
  16. IntExpr,
  17. MemberExpr,
  18. NameExpr,
  19. OpExpr,
  20. StrExpr,
  21. UnaryExpr,
  22. Var,
  23. )
  24. from mypyc.irbuild.builder import IRBuilder
  25. from mypyc.irbuild.util import bytes_from_str
  26. # All possible result types of constant folding
  27. ConstantValue = Union[int, float, complex, str, bytes]
  28. CONST_TYPES: Final = (int, float, complex, str, bytes)
  29. def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
  30. """Return the constant value of an expression for supported operations.
  31. Return None otherwise.
  32. """
  33. if isinstance(expr, IntExpr):
  34. return expr.value
  35. if isinstance(expr, FloatExpr):
  36. return expr.value
  37. if isinstance(expr, StrExpr):
  38. return expr.value
  39. if isinstance(expr, BytesExpr):
  40. return bytes_from_str(expr.value)
  41. if isinstance(expr, ComplexExpr):
  42. return expr.value
  43. elif isinstance(expr, NameExpr):
  44. node = expr.node
  45. if isinstance(node, Var) and node.is_final:
  46. final_value = node.final_value
  47. if isinstance(final_value, (CONST_TYPES)):
  48. return final_value
  49. elif isinstance(expr, MemberExpr):
  50. final = builder.get_final_ref(expr)
  51. if final is not None:
  52. fn, final_var, native = final
  53. if final_var.is_final:
  54. final_value = final_var.final_value
  55. if isinstance(final_value, (CONST_TYPES)):
  56. return final_value
  57. elif isinstance(expr, OpExpr):
  58. left = constant_fold_expr(builder, expr.left)
  59. right = constant_fold_expr(builder, expr.right)
  60. if left is not None and right is not None:
  61. return constant_fold_binary_op_extended(expr.op, left, right)
  62. elif isinstance(expr, UnaryExpr):
  63. value = constant_fold_expr(builder, expr.expr)
  64. if value is not None and not isinstance(value, bytes):
  65. return constant_fold_unary_op(expr.op, value)
  66. return None
  67. def constant_fold_binary_op_extended(
  68. op: str, left: ConstantValue, right: ConstantValue
  69. ) -> ConstantValue | None:
  70. """Like mypy's constant_fold_binary_op(), but includes bytes support.
  71. mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.
  72. """
  73. if not isinstance(left, bytes) and not isinstance(right, bytes):
  74. return constant_fold_binary_op(op, left, right)
  75. if op == "+" and isinstance(left, bytes) and isinstance(right, bytes):
  76. return left + right
  77. elif op == "*" and isinstance(left, bytes) and isinstance(right, int):
  78. return left * right
  79. elif op == "*" and isinstance(left, int) and isinstance(right, bytes):
  80. return left * right
  81. return None