reachability.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. """Utilities related to determining the reachability of code (in semantic analysis)."""
  2. from __future__ import annotations
  3. from typing import Tuple, TypeVar
  4. from typing_extensions import Final
  5. from mypy.literals import literal
  6. from mypy.nodes import (
  7. LITERAL_YES,
  8. AssertStmt,
  9. Block,
  10. CallExpr,
  11. ComparisonExpr,
  12. Expression,
  13. FuncDef,
  14. IfStmt,
  15. Import,
  16. ImportAll,
  17. ImportFrom,
  18. IndexExpr,
  19. IntExpr,
  20. MatchStmt,
  21. MemberExpr,
  22. NameExpr,
  23. OpExpr,
  24. SliceExpr,
  25. StrExpr,
  26. TupleExpr,
  27. UnaryExpr,
  28. )
  29. from mypy.options import Options
  30. from mypy.patterns import AsPattern, OrPattern, Pattern
  31. from mypy.traverser import TraverserVisitor
  # Inferred truth value of an expression.
  # The MYPY_* values mark conditions whose value as seen by mypy differs from
  # the value at runtime.
  ALWAYS_TRUE: Final = 1
  MYPY_TRUE: Final = 2  # True in mypy, False at runtime
  ALWAYS_FALSE: Final = 3
  MYPY_FALSE: Final = 4  # False in mypy, True at runtime
  TRUTH_VALUE_UNKNOWN: Final = 5

  # Truth value of `not <expr>`, keyed by the truth value of `<expr>`.
  # Each pair below is entered in both directions; UNKNOWN maps to itself.
  inverted_truth_mapping: Final = {
      value: negation
      for first, second in (
          (ALWAYS_TRUE, ALWAYS_FALSE),
          (TRUTH_VALUE_UNKNOWN, TRUTH_VALUE_UNKNOWN),
          (MYPY_TRUE, MYPY_FALSE),
      )
      for value, negation in ((first, second), (second, first))
  }

  # Comparison operator to use once the two operands of a comparison are
  # swapped, e.g. `a < b` is equivalent to `b > a`.
  reverse_op: Final = {
      "==": "==",
      "!=": "!=",
      "<": ">",
      ">": "<",
      "<=": ">=",
      ">=": "<=",
  }
  def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
      """Mark the statically dead branches of an ``if``/``elif``/``else`` chain.

      Each condition in ``s.expr`` is evaluated with ``infer_condition_value``,
      in order:

      * a condition known to be false (ALWAYS_FALSE / MYPY_FALSE) makes its own
        body unreachable, and evaluation continues with the next ``elif``;
      * a condition known to be true (ALWAYS_TRUE / MYPY_TRUE) makes every later
        ``elif`` body and the ``else`` body unreachable, and evaluation stops.
        For MYPY_TRUE the taken body is additionally marked mypy-only.

      An empty ``else`` body is created when needed, so it can be marked.
      """
      for i, condition in enumerate(s.expr):
          result = infer_condition_value(condition, options)
          if result in (ALWAYS_FALSE, MYPY_FALSE):
              # The condition is considered always false, so we skip the if/elif body.
              mark_block_unreachable(s.body[i])
          elif result in (ALWAYS_TRUE, MYPY_TRUE):
              # This condition is considered always true, so all of the remaining
              # elif/else bodies should not be checked.
              if result == MYPY_TRUE:
                  # This condition is false at runtime; this will affect
                  # import priorities.
                  mark_block_mypy_only(s.body[i])
              for later_body in s.body[i + 1 :]:
                  mark_block_unreachable(later_body)
              # Make sure else body always exists and is marked as
              # unreachable so the type checker always knows that
              # all control flow paths will flow through the if
              # statement body.
              if not s.else_body:
                  s.else_body = Block([])
              mark_block_unreachable(s.else_body)
              break
  def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None:
      """Mark the statically dead ``case`` bodies of a ``match`` statement.

      For every case, the pattern (via ``infer_pattern_value``) and the optional
      guard (via ``infer_condition_value``) are evaluated:

      * if either is known to be false, the case body is unreachable;
      * otherwise, if both are known to be true, the case always matches and
        every later case body is unreachable;
      * a guard that is true only under mypy (MYPY_TRUE) is false at runtime,
        so the case body is marked mypy-only.
      """
      for i, guard in enumerate(s.guards):
          pattern_value = infer_pattern_value(s.patterns[i])
          # A case without a guard is never rejected by its guard.
          guard_value = ALWAYS_TRUE if guard is None else infer_condition_value(guard, options)

          if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) or guard_value in (
              ALWAYS_FALSE,
              MYPY_FALSE,
          ):
              # The case is considered always false, so we skip the case body.
              mark_block_unreachable(s.bodies[i])
          elif pattern_value in (ALWAYS_TRUE, MYPY_TRUE) and guard_value in (
              ALWAYS_TRUE,
              MYPY_TRUE,
          ):
              # The case always matches, so no later case can be reached.
              # (This must test for *true* pattern values: a false one was
              # already handled by the branch above.)
              for body in s.bodies[i + 1 :]:
                  mark_block_unreachable(body)

          if guard_value == MYPY_TRUE:
              # This condition is false at runtime; this will affect
              # import priorities.
              mark_block_mypy_only(s.bodies[i])
  def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
      """Return True if the asserted expression is statically known to be false.

      Both ALWAYS_FALSE and MYPY_FALSE count as "always fails", i.e. an assertion
      that mypy considers false is treated as failing even if it would hold at
      runtime.
      """
      verdict = infer_condition_value(s.expr, options)
      return verdict == ALWAYS_FALSE or verdict == MYPY_FALSE
  def infer_condition_value(expr: Expression, options: Options) -> int:
      """Infer whether the given condition is always true/false.

      Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
      MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
      false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.

      Recognized conditions:

      * ``not <cond>``: the inferred value of ``<cond>``, inverted;
      * ``<cond> and <cond>`` / ``<cond> or <cond>``: short-circuited on the
        inferred value of the left operand;
      * comparisons involving ``sys.version_info`` or ``sys.platform``;
      * the names ``PY2``, ``PY3``, ``MYPY`` and ``TYPE_CHECKING`` (bare or as an
        attribute), plus names listed in ``options.always_true`` /
        ``options.always_false``.
      """
      if isinstance(expr, UnaryExpr) and expr.op == "not":
          # Recurse rather than unwrapping a single level: the `and`/`or` branch
          # below returns early, so a one-level unwrap would silently drop the
          # negation of `not (A and B)` / `not (A or B)`; this also handles
          # nested `not not X`.
          return inverted_truth_mapping[infer_condition_value(expr.expr, options)]

      if isinstance(expr, OpExpr) and expr.op in ("and", "or"):
          left = infer_condition_value(expr.left, options)
          if (left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or (
              left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or"
          ):
              # Either `True and <other>` or `False or <other>`: the result will
              # always be the right-hand-side.
              return infer_condition_value(expr.right, options)
          # The result will always be the left-hand-side (e.g. ALWAYS_* or
          # TRUTH_VALUE_UNKNOWN).
          return left

      name = ""
      result = TRUTH_VALUE_UNKNOWN
      if isinstance(expr, (NameExpr, MemberExpr)):
          name = expr.name
      else:
          result = consider_sys_version_info(expr, options.python_version)
          if result == TRUTH_VALUE_UNKNOWN:
              result = consider_sys_platform(expr, options.platform)
      if result == TRUTH_VALUE_UNKNOWN:
          if name == "PY2":
              result = ALWAYS_FALSE
          elif name == "PY3":
              result = ALWAYS_TRUE
          elif name == "MYPY" or name == "TYPE_CHECKING":
              result = MYPY_TRUE
          elif name in options.always_true:
              result = ALWAYS_TRUE
          elif name in options.always_false:
              result = ALWAYS_FALSE
      return result
  def infer_pattern_value(pattern: Pattern) -> int:
      """Report whether *pattern* is known to match every subject.

      Returns ALWAYS_TRUE for an ``AsPattern`` without a sub-pattern, and for an
      ``OrPattern`` that has at least one alternative for which this function
      (applied recursively) returns ALWAYS_TRUE. Every other pattern yields
      TRUTH_VALUE_UNKNOWN; no "false" value is ever produced.
      """
      if isinstance(pattern, AsPattern):
          # NOTE(review): presumably a bare capture name or `_` -- confirm
          # against how the parser builds AsPattern.
          matches_everything = pattern.pattern is None
      elif isinstance(pattern, OrPattern):
          matches_everything = any(
              infer_pattern_value(alternative) == ALWAYS_TRUE
              for alternative in pattern.patterns
          )
      else:
          matches_everything = False
      return ALWAYS_TRUE if matches_everything else TRUTH_VALUE_UNKNOWN
  def consider_sys_version_info(expr: Expression, pyversion: tuple[int, ...]) -> int:
      """Consider whether expr is a comparison involving sys.version_info.

      Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.

      Supported forms (``sys.version_info`` may be on either side; when it is on
      the right, the operator is mirrored via ``reverse_op``):

      - ``sys.version_info[<int>] <compare_op> <int>``
      - ``sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>``
      - ``sys.version_info <compare_op> <tuple_of_1_or_2_ints>``
        (in this case <compare_op> must be >, >=, <, <=, but cannot be ==, !=)

      Only the first two components of *pyversion* are ever consulted.
      """
      # Let's not yet support chained comparisons.
      if not isinstance(expr, ComparisonExpr) or len(expr.operators) > 1:
          return TRUTH_VALUE_UNKNOWN
      op = expr.operators[0]
      if op not in ("==", "!=", "<=", ">=", "<", ">"):
          return TRUTH_VALUE_UNKNOWN

      version_part = contains_sys_version_info(expr.operands[0])
      constant = contains_int_or_tuple_of_ints(expr.operands[1])
      if version_part is None or constant is None:
          # Try the mirrored form, e.g. `(3, 8) <= sys.version_info`.
          version_part = contains_sys_version_info(expr.operands[1])
          constant = contains_int_or_tuple_of_ints(expr.operands[0])
          op = reverse_op[op]

      if isinstance(version_part, int) and isinstance(constant, int):
          # sys.version_info[i] <compare_op> k -- only major (0) and minor (1)
          # are known.
          if 0 <= version_part <= 1:
              return fixed_comparison(pyversion[version_part], op, constant)
          return TRUTH_VALUE_UNKNOWN

      if isinstance(version_part, tuple) and isinstance(constant, tuple):
          start, stop = version_part
          start = 0 if start is None else start
          stop = 2 if stop is None else stop
          if 0 <= start < stop <= 2:
              known = pyversion[start:stop]
              same_length = len(known) == len(constant)
              # A longer left side can only be ordered, never tested for equality.
              longer_ordering = len(known) > len(constant) and op not in ("==", "!=")
              if same_length or longer_ordering:
                  return fixed_comparison(known, op, constant)

      return TRUTH_VALUE_UNKNOWN
  def consider_sys_platform(expr: Expression, platform: str) -> int:
      """Consider whether expr is a comparison involving sys.platform.

      Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
      """
      # Cases supported:
      # - sys.platform == 'posix'
      # - sys.platform != 'win32'
      # - 'linux' == sys.platform  (string literal on either side)
      # - sys.platform.startswith('win')
      if isinstance(expr, ComparisonExpr):
          # Let's not yet support chained comparisons.
          if len(expr.operators) > 1:
              return TRUTH_VALUE_UNKNOWN
          op = expr.operators[0]
          if op not in ("==", "!="):
              return TRUTH_VALUE_UNKNOWN
          left, right = expr.operands[0], expr.operands[1]
          # Accept the operands in either order, as consider_sys_version_info
          # does. `==` and `!=` are symmetric, so no operator change is needed.
          if is_sys_attr(right, "platform"):
              left, right = right, left
          if not is_sys_attr(left, "platform"):
              return TRUTH_VALUE_UNKNOWN
          if not isinstance(right, StrExpr):
              return TRUTH_VALUE_UNKNOWN
          return fixed_comparison(platform, op, right.value)
      elif isinstance(expr, CallExpr):
          if not isinstance(expr.callee, MemberExpr):
              return TRUTH_VALUE_UNKNOWN
          if len(expr.args) != 1 or not isinstance(expr.args[0], StrExpr):
              return TRUTH_VALUE_UNKNOWN
          if not is_sys_attr(expr.callee.expr, "platform"):
              return TRUTH_VALUE_UNKNOWN
          if expr.callee.name != "startswith":
              return TRUTH_VALUE_UNKNOWN
          if platform.startswith(expr.args[0].value):
              return ALWAYS_TRUE
          else:
              return ALWAYS_FALSE
      else:
          return TRUTH_VALUE_UNKNOWN
  # Operand types accepted by fixed_comparison; both operands share one of them.
  Targ = TypeVar("Targ", int, str, Tuple[int, ...])


  def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
      """Evaluate ``left <op> right`` for two statically known operands.

      Returns ALWAYS_TRUE or ALWAYS_FALSE according to the outcome, or
      TRUTH_VALUE_UNKNOWN when *op* is not one of ``==``, ``!=``, ``<=``,
      ``>=``, ``<`` or ``>``.
      """
      if op == "==":
          outcome = left == right
      elif op == "!=":
          outcome = left != right
      elif op == "<=":
          outcome = left <= right
      elif op == ">=":
          outcome = left >= right
      elif op == "<":
          outcome = left < right
      elif op == ">":
          outcome = left > right
      else:
          return TRUTH_VALUE_UNKNOWN
      return ALWAYS_TRUE if outcome else ALWAYS_FALSE
  def contains_int_or_tuple_of_ints(expr: Expression) -> None | int | tuple[int, ...]:
      """Extract an int literal, or a tuple of int literals, from *expr*.

      Returns the value of an ``IntExpr``; for a ``TupleExpr`` that ``literal()``
      rates LITERAL_YES and whose items are all ``IntExpr``, a tuple of their
      values; and None for anything else.
      """
      if isinstance(expr, IntExpr):
          return expr.value
      if not isinstance(expr, TupleExpr) or literal(expr) != LITERAL_YES:
          return None
      int_items = [item for item in expr.items if isinstance(item, IntExpr)]
      if len(int_items) != len(expr.items):
          # At least one element is not an int literal.
          return None
      return tuple(item.value for item in int_items)
  def contains_sys_version_info(expr: Expression) -> None | int | tuple[int | None, int | None]:
      """Recognize ``sys.version_info``, optionally subscripted, in *expr*.

      Returns:
      - ``(None, None)`` for bare ``sys.version_info`` (same as ``[:]``);
      - the index for ``sys.version_info[<int literal>]``;
      - a ``(begin, end)`` pair for a slice whose bounds are int literals or
        omitted (None) and whose step is omitted or the literal ``1``;
      - None for anything else.
      """
      if is_sys_attr(expr, "version_info"):
          return (None, None)  # Same as sys.version_info[:]
      if not (isinstance(expr, IndexExpr) and is_sys_attr(expr.base, "version_info")):
          return None

      subscript = expr.index
      if isinstance(subscript, IntExpr):
          return subscript.value
      if not isinstance(subscript, SliceExpr):
          return None

      step = subscript.stride
      if step is not None and not (isinstance(step, IntExpr) and step.value == 1):
          return None

      # Bounds are examined begin-first; any non-literal bound rejects the slice.
      bounds: list[int | None] = []
      for bound in (subscript.begin_index, subscript.end_index):
          if bound is None:
              bounds.append(None)
          elif isinstance(bound, IntExpr):
              bounds.append(bound.value)
          else:
              return None
      return (bounds[0], bounds[1])
  def is_sys_attr(expr: Expression, name: str) -> bool:
      """Return True if *expr* is literally ``sys.<name>``.

      The check is purely syntactic: a ``MemberExpr`` named *name* whose base is
      a ``NameExpr`` spelled ``sys``.
      """
      # TODO: This currently doesn't work with code like this:
      # - import sys as _sys
      # - from sys import version_info
      # TODO: Guard against a local named sys, etc.
      # (Though later passes will still do most checking.)
      return (
          isinstance(expr, MemberExpr)
          and expr.name == name
          and isinstance(expr.expr, NameExpr)
          and expr.expr.name == "sys"
      )
  def mark_block_unreachable(block: Block) -> None:
      """Flag *block* as unreachable, then flag the imports nested inside it too."""
      block.is_unreachable = True
      import_marker = MarkImportsUnreachableVisitor()
      block.accept(import_marker)
  class MarkImportsUnreachableVisitor(TraverserVisitor):
      """Visitor that flags all imports nested within a node as unreachable.

      The three import node kinds share one setter; every other node kind is
      left to the ``TraverserVisitor`` base class.
      """

      def visit_import(self, node: Import) -> None:
          self._flag_unreachable(node)

      def visit_import_from(self, node: ImportFrom) -> None:
          self._flag_unreachable(node)

      def visit_import_all(self, node: ImportAll) -> None:
          self._flag_unreachable(node)

      def _flag_unreachable(self, node: Import | ImportFrom | ImportAll) -> None:
          # Set the flag on the import statement itself.
          node.is_unreachable = True
  def mark_block_mypy_only(block: Block) -> None:
      """Set ``is_mypy_only`` on the nodes inside *block* that the visitor handles.

      The block node itself is not modified; see MarkImportsMypyOnlyVisitor for
      which nodes are flagged.
      """
      mypy_only_marker = MarkImportsMypyOnlyVisitor()
      block.accept(mypy_only_marker)
  class MarkImportsMypyOnlyVisitor(TraverserVisitor):
      """Visitor that sets is_mypy_only (which affects priority).

      Import statements of all three kinds and function definitions are
      flagged. ``visit_func_def`` does not call the base implementation, so
      (assuming the base method is what descends into a function) nothing
      inside a flagged function is visited.
      """

      def visit_import(self, node: Import) -> None:
          self._flag_mypy_only(node)

      def visit_import_from(self, node: ImportFrom) -> None:
          self._flag_mypy_only(node)

      def visit_import_all(self, node: ImportAll) -> None:
          self._flag_mypy_only(node)

      def visit_func_def(self, node: FuncDef) -> None:
          self._flag_mypy_only(node)

      def _flag_mypy_only(self, node: Import | ImportFrom | ImportAll | FuncDef) -> None:
          # Set the flag on the visited node itself.
          node.is_mypy_only = True