reachability.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. """Utilities related to determining the reachability of code (in semantic analysis)."""
  2. from __future__ import annotations
  3. from typing import Final, Tuple, TypeVar
  4. from mypy.literals import literal
  5. from mypy.nodes import (
  6. LITERAL_YES,
  7. AssertStmt,
  8. Block,
  9. CallExpr,
  10. ComparisonExpr,
  11. Expression,
  12. FuncDef,
  13. IfStmt,
  14. Import,
  15. ImportAll,
  16. ImportFrom,
  17. IndexExpr,
  18. IntExpr,
  19. MatchStmt,
  20. MemberExpr,
  21. NameExpr,
  22. OpExpr,
  23. SliceExpr,
  24. StrExpr,
  25. TupleExpr,
  26. UnaryExpr,
  27. )
  28. from mypy.options import Options
  29. from mypy.patterns import AsPattern, OrPattern, Pattern
  30. from mypy.traverser import TraverserVisitor
  31. # Inferred truth value of an expression.
  32. ALWAYS_TRUE: Final = 1
  33. MYPY_TRUE: Final = 2 # True in mypy, False at runtime
  34. ALWAYS_FALSE: Final = 3
  35. MYPY_FALSE: Final = 4 # False in mypy, True at runtime
  36. TRUTH_VALUE_UNKNOWN: Final = 5
  37. inverted_truth_mapping: Final = {
  38. ALWAYS_TRUE: ALWAYS_FALSE,
  39. ALWAYS_FALSE: ALWAYS_TRUE,
  40. TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
  41. MYPY_TRUE: MYPY_FALSE,
  42. MYPY_FALSE: MYPY_TRUE,
  43. }
  44. reverse_op: Final = {"==": "==", "!=": "!=", "<": ">", ">": "<", "<=": ">=", ">=": "<="}
  45. def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
  46. for i in range(len(s.expr)):
  47. result = infer_condition_value(s.expr[i], options)
  48. if result in (ALWAYS_FALSE, MYPY_FALSE):
  49. # The condition is considered always false, so we skip the if/elif body.
  50. mark_block_unreachable(s.body[i])
  51. elif result in (ALWAYS_TRUE, MYPY_TRUE):
  52. # This condition is considered always true, so all of the remaining
  53. # elif/else bodies should not be checked.
  54. if result == MYPY_TRUE:
  55. # This condition is false at runtime; this will affect
  56. # import priorities.
  57. mark_block_mypy_only(s.body[i])
  58. for body in s.body[i + 1 :]:
  59. mark_block_unreachable(body)
  60. # Make sure else body always exists and is marked as
  61. # unreachable so the type checker always knows that
  62. # all control flow paths will flow through the if
  63. # statement body.
  64. if not s.else_body:
  65. s.else_body = Block([])
  66. mark_block_unreachable(s.else_body)
  67. break
  68. def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None:
  69. for i, guard in enumerate(s.guards):
  70. pattern_value = infer_pattern_value(s.patterns[i])
  71. if guard is not None:
  72. guard_value = infer_condition_value(guard, options)
  73. else:
  74. guard_value = ALWAYS_TRUE
  75. if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) or guard_value in (
  76. ALWAYS_FALSE,
  77. MYPY_FALSE,
  78. ):
  79. # The case is considered always false, so we skip the case body.
  80. mark_block_unreachable(s.bodies[i])
  81. elif pattern_value in (ALWAYS_FALSE, MYPY_TRUE) and guard_value in (
  82. ALWAYS_TRUE,
  83. MYPY_TRUE,
  84. ):
  85. for body in s.bodies[i + 1 :]:
  86. mark_block_unreachable(body)
  87. if guard_value == MYPY_TRUE:
  88. # This condition is false at runtime; this will affect
  89. # import priorities.
  90. mark_block_mypy_only(s.bodies[i])
  91. def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
  92. return infer_condition_value(s.expr, options) in (ALWAYS_FALSE, MYPY_FALSE)
  93. def infer_condition_value(expr: Expression, options: Options) -> int:
  94. """Infer whether the given condition is always true/false.
  95. Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
  96. MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
  97. false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.
  98. """
  99. pyversion = options.python_version
  100. name = ""
  101. negated = False
  102. alias = expr
  103. if isinstance(alias, UnaryExpr):
  104. if alias.op == "not":
  105. expr = alias.expr
  106. negated = True
  107. result = TRUTH_VALUE_UNKNOWN
  108. if isinstance(expr, NameExpr):
  109. name = expr.name
  110. elif isinstance(expr, MemberExpr):
  111. name = expr.name
  112. elif isinstance(expr, OpExpr) and expr.op in ("and", "or"):
  113. left = infer_condition_value(expr.left, options)
  114. if (left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or (
  115. left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or"
  116. ):
  117. # Either `True and <other>` or `False or <other>`: the result will
  118. # always be the right-hand-side.
  119. return infer_condition_value(expr.right, options)
  120. else:
  121. # The result will always be the left-hand-side (e.g. ALWAYS_* or
  122. # TRUTH_VALUE_UNKNOWN).
  123. return left
  124. else:
  125. result = consider_sys_version_info(expr, pyversion)
  126. if result == TRUTH_VALUE_UNKNOWN:
  127. result = consider_sys_platform(expr, options.platform)
  128. if result == TRUTH_VALUE_UNKNOWN:
  129. if name == "PY2":
  130. result = ALWAYS_FALSE
  131. elif name == "PY3":
  132. result = ALWAYS_TRUE
  133. elif name == "MYPY" or name == "TYPE_CHECKING":
  134. result = MYPY_TRUE
  135. elif name in options.always_true:
  136. result = ALWAYS_TRUE
  137. elif name in options.always_false:
  138. result = ALWAYS_FALSE
  139. if negated:
  140. result = inverted_truth_mapping[result]
  141. return result
  142. def infer_pattern_value(pattern: Pattern) -> int:
  143. if isinstance(pattern, AsPattern) and pattern.pattern is None:
  144. return ALWAYS_TRUE
  145. elif isinstance(pattern, OrPattern) and any(
  146. infer_pattern_value(p) == ALWAYS_TRUE for p in pattern.patterns
  147. ):
  148. return ALWAYS_TRUE
  149. else:
  150. return TRUTH_VALUE_UNKNOWN
  151. def consider_sys_version_info(expr: Expression, pyversion: tuple[int, ...]) -> int:
  152. """Consider whether expr is a comparison involving sys.version_info.
  153. Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
  154. """
  155. # Cases supported:
  156. # - sys.version_info[<int>] <compare_op> <int>
  157. # - sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>
  158. # - sys.version_info <compare_op> <tuple_of_1_or_2_ints>
  159. # (in this case <compare_op> must be >, >=, <, <=, but cannot be ==, !=)
  160. if not isinstance(expr, ComparisonExpr):
  161. return TRUTH_VALUE_UNKNOWN
  162. # Let's not yet support chained comparisons.
  163. if len(expr.operators) > 1:
  164. return TRUTH_VALUE_UNKNOWN
  165. op = expr.operators[0]
  166. if op not in ("==", "!=", "<=", ">=", "<", ">"):
  167. return TRUTH_VALUE_UNKNOWN
  168. index = contains_sys_version_info(expr.operands[0])
  169. thing = contains_int_or_tuple_of_ints(expr.operands[1])
  170. if index is None or thing is None:
  171. index = contains_sys_version_info(expr.operands[1])
  172. thing = contains_int_or_tuple_of_ints(expr.operands[0])
  173. op = reverse_op[op]
  174. if isinstance(index, int) and isinstance(thing, int):
  175. # sys.version_info[i] <compare_op> k
  176. if 0 <= index <= 1:
  177. return fixed_comparison(pyversion[index], op, thing)
  178. else:
  179. return TRUTH_VALUE_UNKNOWN
  180. elif isinstance(index, tuple) and isinstance(thing, tuple):
  181. lo, hi = index
  182. if lo is None:
  183. lo = 0
  184. if hi is None:
  185. hi = 2
  186. if 0 <= lo < hi <= 2:
  187. val = pyversion[lo:hi]
  188. if len(val) == len(thing) or len(val) > len(thing) and op not in ("==", "!="):
  189. return fixed_comparison(val, op, thing)
  190. return TRUTH_VALUE_UNKNOWN
  191. def consider_sys_platform(expr: Expression, platform: str) -> int:
  192. """Consider whether expr is a comparison involving sys.platform.
  193. Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
  194. """
  195. # Cases supported:
  196. # - sys.platform == 'posix'
  197. # - sys.platform != 'win32'
  198. # - sys.platform.startswith('win')
  199. if isinstance(expr, ComparisonExpr):
  200. # Let's not yet support chained comparisons.
  201. if len(expr.operators) > 1:
  202. return TRUTH_VALUE_UNKNOWN
  203. op = expr.operators[0]
  204. if op not in ("==", "!="):
  205. return TRUTH_VALUE_UNKNOWN
  206. if not is_sys_attr(expr.operands[0], "platform"):
  207. return TRUTH_VALUE_UNKNOWN
  208. right = expr.operands[1]
  209. if not isinstance(right, StrExpr):
  210. return TRUTH_VALUE_UNKNOWN
  211. return fixed_comparison(platform, op, right.value)
  212. elif isinstance(expr, CallExpr):
  213. if not isinstance(expr.callee, MemberExpr):
  214. return TRUTH_VALUE_UNKNOWN
  215. if len(expr.args) != 1 or not isinstance(expr.args[0], StrExpr):
  216. return TRUTH_VALUE_UNKNOWN
  217. if not is_sys_attr(expr.callee.expr, "platform"):
  218. return TRUTH_VALUE_UNKNOWN
  219. if expr.callee.name != "startswith":
  220. return TRUTH_VALUE_UNKNOWN
  221. if platform.startswith(expr.args[0].value):
  222. return ALWAYS_TRUE
  223. else:
  224. return ALWAYS_FALSE
  225. else:
  226. return TRUTH_VALUE_UNKNOWN
  227. Targ = TypeVar("Targ", int, str, Tuple[int, ...])
  228. def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
  229. rmap = {False: ALWAYS_FALSE, True: ALWAYS_TRUE}
  230. if op == "==":
  231. return rmap[left == right]
  232. if op == "!=":
  233. return rmap[left != right]
  234. if op == "<=":
  235. return rmap[left <= right]
  236. if op == ">=":
  237. return rmap[left >= right]
  238. if op == "<":
  239. return rmap[left < right]
  240. if op == ">":
  241. return rmap[left > right]
  242. return TRUTH_VALUE_UNKNOWN
  243. def contains_int_or_tuple_of_ints(expr: Expression) -> None | int | tuple[int, ...]:
  244. if isinstance(expr, IntExpr):
  245. return expr.value
  246. if isinstance(expr, TupleExpr):
  247. if literal(expr) == LITERAL_YES:
  248. thing = []
  249. for x in expr.items:
  250. if not isinstance(x, IntExpr):
  251. return None
  252. thing.append(x.value)
  253. return tuple(thing)
  254. return None
  255. def contains_sys_version_info(expr: Expression) -> None | int | tuple[int | None, int | None]:
  256. if is_sys_attr(expr, "version_info"):
  257. return (None, None) # Same as sys.version_info[:]
  258. if isinstance(expr, IndexExpr) and is_sys_attr(expr.base, "version_info"):
  259. index = expr.index
  260. if isinstance(index, IntExpr):
  261. return index.value
  262. if isinstance(index, SliceExpr):
  263. if index.stride is not None:
  264. if not isinstance(index.stride, IntExpr) or index.stride.value != 1:
  265. return None
  266. begin = end = None
  267. if index.begin_index is not None:
  268. if not isinstance(index.begin_index, IntExpr):
  269. return None
  270. begin = index.begin_index.value
  271. if index.end_index is not None:
  272. if not isinstance(index.end_index, IntExpr):
  273. return None
  274. end = index.end_index.value
  275. return (begin, end)
  276. return None
  277. def is_sys_attr(expr: Expression, name: str) -> bool:
  278. # TODO: This currently doesn't work with code like this:
  279. # - import sys as _sys
  280. # - from sys import version_info
  281. if isinstance(expr, MemberExpr) and expr.name == name:
  282. if isinstance(expr.expr, NameExpr) and expr.expr.name == "sys":
  283. # TODO: Guard against a local named sys, etc.
  284. # (Though later passes will still do most checking.)
  285. return True
  286. return False
  287. def mark_block_unreachable(block: Block) -> None:
  288. block.is_unreachable = True
  289. block.accept(MarkImportsUnreachableVisitor())
class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Visitor that flags all imports nested within a node as unreachable."""

    # Each override below handles one kind of import statement and only sets
    # the node's is_unreachable flag.

    def visit_import(self, node: Import) -> None:
        node.is_unreachable = True

    def visit_import_from(self, node: ImportFrom) -> None:
        node.is_unreachable = True

    def visit_import_all(self, node: ImportAll) -> None:
        node.is_unreachable = True
  298. def mark_block_mypy_only(block: Block) -> None:
  299. block.accept(MarkImportsMypyOnlyVisitor())
class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Visitor that sets is_mypy_only (which affects priority)."""

    # Each override below only sets the node's is_mypy_only flag; none of them
    # calls the base-class implementation.

    def visit_import(self, node: Import) -> None:
        node.is_mypy_only = True

    def visit_import_from(self, node: ImportFrom) -> None:
        node.is_mypy_only = True

    def visit_import_all(self, node: ImportAll) -> None:
        node.is_mypy_only = True

    # Function definitions are flagged as well, not just import statements.
    def visit_func_def(self, node: FuncDef) -> None:
        node.is_mypy_only = True
  308. def visit_func_def(self, node: FuncDef) -> None:
  309. node.is_mypy_only = True