modified_iterating_checker.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt
  4. from __future__ import annotations
  5. from typing import TYPE_CHECKING
  6. from astroid import nodes
  7. from pylint import checkers, interfaces
  8. from pylint.checkers import utils
  9. if TYPE_CHECKING:
  10. from pylint.lint import PyLinter
  11. _LIST_MODIFIER_METHODS = {"append", "remove"}
  12. _SET_MODIFIER_METHODS = {"add", "remove"}
  13. class ModifiedIterationChecker(checkers.BaseChecker):
  14. """Checks for modified iterators in for loops iterations.
  15. Currently supports `for` loops for Sets, Dictionaries and Lists.
  16. """
  17. name = "modified_iteration"
  18. msgs = {
  19. "W4701": (
  20. "Iterated list '%s' is being modified inside for loop body, consider iterating through a copy of it "
  21. "instead.",
  22. "modified-iterating-list",
  23. "Emitted when items are added or removed to a list being iterated through. "
  24. "Doing so can result in unexpected behaviour, that is why it is preferred to use a copy of the list.",
  25. ),
  26. "E4702": (
  27. "Iterated dict '%s' is being modified inside for loop body, iterate through a copy of it instead.",
  28. "modified-iterating-dict",
  29. "Emitted when items are added or removed to a dict being iterated through. "
  30. "Doing so raises a RuntimeError.",
  31. ),
  32. "E4703": (
  33. "Iterated set '%s' is being modified inside for loop body, iterate through a copy of it instead.",
  34. "modified-iterating-set",
  35. "Emitted when items are added or removed to a set being iterated through. "
  36. "Doing so raises a RuntimeError.",
  37. ),
  38. }
  39. options = ()
  40. @utils.only_required_for_messages(
  41. "modified-iterating-list", "modified-iterating-dict", "modified-iterating-set"
  42. )
  43. def visit_for(self, node: nodes.For) -> None:
  44. iter_obj = node.iter
  45. for body_node in node.body:
  46. self._modified_iterating_check_on_node_and_children(body_node, iter_obj)
  47. def _modified_iterating_check_on_node_and_children(
  48. self, body_node: nodes.NodeNG, iter_obj: nodes.NodeNG
  49. ) -> None:
  50. """See if node or any of its children raises modified iterating messages."""
  51. self._modified_iterating_check(body_node, iter_obj)
  52. for child in body_node.get_children():
  53. self._modified_iterating_check_on_node_and_children(child, iter_obj)
  54. def _modified_iterating_check(
  55. self, node: nodes.NodeNG, iter_obj: nodes.NodeNG
  56. ) -> None:
  57. msg_id = None
  58. if isinstance(node, nodes.Delete) and any(
  59. self._deleted_iteration_target_cond(t, iter_obj) for t in node.targets
  60. ):
  61. inferred = utils.safe_infer(iter_obj)
  62. if isinstance(inferred, nodes.List):
  63. msg_id = "modified-iterating-list"
  64. elif isinstance(inferred, nodes.Dict):
  65. msg_id = "modified-iterating-dict"
  66. elif isinstance(inferred, nodes.Set):
  67. msg_id = "modified-iterating-set"
  68. elif not isinstance(iter_obj, (nodes.Name, nodes.Attribute)):
  69. pass
  70. elif self._modified_iterating_list_cond(node, iter_obj):
  71. msg_id = "modified-iterating-list"
  72. elif self._modified_iterating_dict_cond(node, iter_obj):
  73. msg_id = "modified-iterating-dict"
  74. elif self._modified_iterating_set_cond(node, iter_obj):
  75. msg_id = "modified-iterating-set"
  76. if msg_id:
  77. if isinstance(iter_obj, nodes.Attribute):
  78. obj_name = iter_obj.attrname
  79. else:
  80. obj_name = iter_obj.name
  81. self.add_message(
  82. msg_id,
  83. node=node,
  84. args=(obj_name,),
  85. confidence=interfaces.INFERENCE,
  86. )
  87. @staticmethod
  88. def _is_node_expr_that_calls_attribute_name(node: nodes.NodeNG) -> bool:
  89. return (
  90. isinstance(node, nodes.Expr)
  91. and isinstance(node.value, nodes.Call)
  92. and isinstance(node.value.func, nodes.Attribute)
  93. and isinstance(node.value.func.expr, nodes.Name)
  94. )
  95. @staticmethod
  96. def _common_cond_list_set(
  97. node: nodes.Expr,
  98. iter_obj: nodes.Name | nodes.Attribute,
  99. infer_val: nodes.List | nodes.Set,
  100. ) -> bool:
  101. iter_obj_name = (
  102. iter_obj.attrname
  103. if isinstance(iter_obj, nodes.Attribute)
  104. else iter_obj.name
  105. )
  106. return (infer_val == utils.safe_infer(iter_obj)) and ( # type: ignore[no-any-return]
  107. node.value.func.expr.name == iter_obj_name
  108. )
  109. @staticmethod
  110. def _is_node_assigns_subscript_name(node: nodes.NodeNG) -> bool:
  111. return isinstance(node, nodes.Assign) and (
  112. isinstance(node.targets[0], nodes.Subscript)
  113. and (isinstance(node.targets[0].value, nodes.Name))
  114. )
  115. def _modified_iterating_list_cond(
  116. self, node: nodes.NodeNG, iter_obj: nodes.Name | nodes.Attribute
  117. ) -> bool:
  118. if not self._is_node_expr_that_calls_attribute_name(node):
  119. return False
  120. infer_val = utils.safe_infer(node.value.func.expr)
  121. if not isinstance(infer_val, nodes.List):
  122. return False
  123. return (
  124. self._common_cond_list_set(node, iter_obj, infer_val)
  125. and node.value.func.attrname in _LIST_MODIFIER_METHODS
  126. )
  127. def _modified_iterating_dict_cond(
  128. self, node: nodes.NodeNG, iter_obj: nodes.Name | nodes.Attribute
  129. ) -> bool:
  130. if not self._is_node_assigns_subscript_name(node):
  131. return False
  132. # Do not emit when merely updating the same key being iterated
  133. if (
  134. isinstance(iter_obj, nodes.Name)
  135. and iter_obj.name == node.targets[0].value.name
  136. and isinstance(iter_obj.parent.target, nodes.AssignName)
  137. and isinstance(node.targets[0].slice, nodes.Name)
  138. and iter_obj.parent.target.name == node.targets[0].slice.name
  139. ):
  140. return False
  141. infer_val = utils.safe_infer(node.targets[0].value)
  142. if not isinstance(infer_val, nodes.Dict):
  143. return False
  144. if infer_val != utils.safe_infer(iter_obj):
  145. return False
  146. if isinstance(iter_obj, nodes.Attribute):
  147. iter_obj_name = iter_obj.attrname
  148. else:
  149. iter_obj_name = iter_obj.name
  150. return node.targets[0].value.name == iter_obj_name # type: ignore[no-any-return]
  151. def _modified_iterating_set_cond(
  152. self, node: nodes.NodeNG, iter_obj: nodes.Name | nodes.Attribute
  153. ) -> bool:
  154. if not self._is_node_expr_that_calls_attribute_name(node):
  155. return False
  156. infer_val = utils.safe_infer(node.value.func.expr)
  157. if not isinstance(infer_val, nodes.Set):
  158. return False
  159. return (
  160. self._common_cond_list_set(node, iter_obj, infer_val)
  161. and node.value.func.attrname in _SET_MODIFIER_METHODS
  162. )
  163. def _deleted_iteration_target_cond(
  164. self, node: nodes.DelName, iter_obj: nodes.NodeNG
  165. ) -> bool:
  166. if not isinstance(node, nodes.DelName):
  167. return False
  168. if not isinstance(iter_obj.parent, nodes.For):
  169. return False
  170. if not isinstance(
  171. iter_obj.parent.target, (nodes.AssignName, nodes.BaseContainer)
  172. ):
  173. return False
  174. return any(
  175. t == node.name
  176. for t in utils.find_assigned_names_recursive(iter_obj.parent.target)
  177. )
  178. def register(linter: PyLinter) -> None:
  179. linter.register_checker(ModifiedIterationChecker(linter))