special_methods_checker.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Special methods checker and its helper functions."""
  5. from __future__ import annotations
  6. from collections.abc import Callable
  7. import astroid
  8. from astroid import bases, nodes, util
  9. from astroid.context import InferenceContext
  10. from astroid.typing import InferenceResult
  11. from pylint.checkers import BaseChecker
  12. from pylint.checkers.utils import (
  13. PYMETHODS,
  14. SPECIAL_METHODS_PARAMS,
  15. decorated_with,
  16. is_function_body_ellipsis,
  17. only_required_for_messages,
  18. safe_infer,
  19. )
  20. from pylint.lint.pylinter import PyLinter
# Name of the method looked up by `SpecialMethodsChecker._is_iterator` to decide
# whether an inferred object is an iterator; also interpolated into the
# "non-iterator-returned" message description.
NEXT_METHOD = "__next__"
  22. def _safe_infer_call_result(
  23. node: nodes.FunctionDef,
  24. caller: nodes.FunctionDef,
  25. context: InferenceContext | None = None,
  26. ) -> InferenceResult | None:
  27. """Safely infer the return value of a function.
  28. Returns None if inference failed or if there is some ambiguity (more than
  29. one node has been inferred). Otherwise, returns inferred value.
  30. """
  31. try:
  32. inferit = node.infer_call_result(caller, context=context)
  33. value = next(inferit)
  34. except astroid.InferenceError:
  35. return None # inference failed
  36. except StopIteration:
  37. return None # no values inferred
  38. try:
  39. next(inferit)
  40. return None # there is ambiguity on the inferred node
  41. except astroid.InferenceError:
  42. return None # there is some kind of ambiguity
  43. except StopIteration:
  44. return value
  45. class SpecialMethodsChecker(BaseChecker):
  46. """Checker which verifies that special methods
  47. are implemented correctly.
  48. """
  49. name = "classes"
  50. msgs = {
  51. "E0301": (
  52. "__iter__ returns non-iterator",
  53. "non-iterator-returned",
  54. "Used when an __iter__ method returns something which is not an "
  55. f"iterable (i.e. has no `{NEXT_METHOD}` method)",
  56. {
  57. "old_names": [
  58. ("W0234", "old-non-iterator-returned-1"),
  59. ("E0234", "old-non-iterator-returned-2"),
  60. ]
  61. },
  62. ),
  63. "E0302": (
  64. "The special method %r expects %s param(s), %d %s given",
  65. "unexpected-special-method-signature",
  66. "Emitted when a special method was defined with an "
  67. "invalid number of parameters. If it has too few or "
  68. "too many, it might not work at all.",
  69. {"old_names": [("E0235", "bad-context-manager")]},
  70. ),
  71. "E0303": (
  72. "__len__ does not return non-negative integer",
  73. "invalid-length-returned",
  74. "Used when a __len__ method returns something which is not a "
  75. "non-negative integer",
  76. ),
  77. "E0304": (
  78. "__bool__ does not return bool",
  79. "invalid-bool-returned",
  80. "Used when a __bool__ method returns something which is not a bool",
  81. ),
  82. "E0305": (
  83. "__index__ does not return int",
  84. "invalid-index-returned",
  85. "Used when an __index__ method returns something which is not "
  86. "an integer",
  87. ),
  88. "E0306": (
  89. "__repr__ does not return str",
  90. "invalid-repr-returned",
  91. "Used when a __repr__ method returns something which is not a string",
  92. ),
  93. "E0307": (
  94. "__str__ does not return str",
  95. "invalid-str-returned",
  96. "Used when a __str__ method returns something which is not a string",
  97. ),
  98. "E0308": (
  99. "__bytes__ does not return bytes",
  100. "invalid-bytes-returned",
  101. "Used when a __bytes__ method returns something which is not bytes",
  102. ),
  103. "E0309": (
  104. "__hash__ does not return int",
  105. "invalid-hash-returned",
  106. "Used when a __hash__ method returns something which is not an integer",
  107. ),
  108. "E0310": (
  109. "__length_hint__ does not return non-negative integer",
  110. "invalid-length-hint-returned",
  111. "Used when a __length_hint__ method returns something which is not a "
  112. "non-negative integer",
  113. ),
  114. "E0311": (
  115. "__format__ does not return str",
  116. "invalid-format-returned",
  117. "Used when a __format__ method returns something which is not a string",
  118. ),
  119. "E0312": (
  120. "__getnewargs__ does not return a tuple",
  121. "invalid-getnewargs-returned",
  122. "Used when a __getnewargs__ method returns something which is not "
  123. "a tuple",
  124. ),
  125. "E0313": (
  126. "__getnewargs_ex__ does not return a tuple containing (tuple, dict)",
  127. "invalid-getnewargs-ex-returned",
  128. "Used when a __getnewargs_ex__ method returns something which is not "
  129. "of the form tuple(tuple, dict)",
  130. ),
  131. }
  132. def __init__(self, linter: PyLinter) -> None:
  133. super().__init__(linter)
  134. self._protocol_map: dict[
  135. str, Callable[[nodes.FunctionDef, InferenceResult], None]
  136. ] = {
  137. "__iter__": self._check_iter,
  138. "__len__": self._check_len,
  139. "__bool__": self._check_bool,
  140. "__index__": self._check_index,
  141. "__repr__": self._check_repr,
  142. "__str__": self._check_str,
  143. "__bytes__": self._check_bytes,
  144. "__hash__": self._check_hash,
  145. "__length_hint__": self._check_length_hint,
  146. "__format__": self._check_format,
  147. "__getnewargs__": self._check_getnewargs,
  148. "__getnewargs_ex__": self._check_getnewargs_ex,
  149. }
  150. @only_required_for_messages(
  151. "unexpected-special-method-signature",
  152. "non-iterator-returned",
  153. "invalid-length-returned",
  154. "invalid-bool-returned",
  155. "invalid-index-returned",
  156. "invalid-repr-returned",
  157. "invalid-str-returned",
  158. "invalid-bytes-returned",
  159. "invalid-hash-returned",
  160. "invalid-length-hint-returned",
  161. "invalid-format-returned",
  162. "invalid-getnewargs-returned",
  163. "invalid-getnewargs-ex-returned",
  164. )
  165. def visit_functiondef(self, node: nodes.FunctionDef) -> None:
  166. if not node.is_method():
  167. return
  168. inferred = _safe_infer_call_result(node, node)
  169. # Only want to check types that we are able to infer
  170. if (
  171. inferred
  172. and node.name in self._protocol_map
  173. and not is_function_body_ellipsis(node)
  174. ):
  175. self._protocol_map[node.name](node, inferred)
  176. if node.name in PYMETHODS:
  177. self._check_unexpected_method_signature(node)
  178. visit_asyncfunctiondef = visit_functiondef
  179. def _check_unexpected_method_signature(self, node: nodes.FunctionDef) -> None:
  180. expected_params = SPECIAL_METHODS_PARAMS[node.name]
  181. if expected_params is None:
  182. # This can support a variable number of parameters.
  183. return
  184. if not node.args.args and not node.args.vararg:
  185. # Method has no parameter, will be caught
  186. # by no-method-argument.
  187. return
  188. if decorated_with(node, ["builtins.staticmethod"]):
  189. # We expect to not take in consideration self.
  190. all_args = node.args.args
  191. else:
  192. all_args = node.args.args[1:]
  193. mandatory = len(all_args) - len(node.args.defaults)
  194. optional = len(node.args.defaults)
  195. current_params = mandatory + optional
  196. emit = False # If we don't know we choose a false negative
  197. if isinstance(expected_params, tuple):
  198. # The expected number of parameters can be any value from this
  199. # tuple, although the user should implement the method
  200. # to take all of them in consideration.
  201. emit = mandatory not in expected_params
  202. # mypy thinks that expected_params has type tuple[int, int] | int | None
  203. # But at this point it must be 'tuple[int, int]' because of the type check
  204. expected_params = f"between {expected_params[0]} or {expected_params[1]}" # type: ignore[assignment]
  205. else:
  206. # If the number of mandatory parameters doesn't
  207. # suffice, the expected parameters for this
  208. # function will be deduced from the optional
  209. # parameters.
  210. rest = expected_params - mandatory
  211. if rest == 0:
  212. emit = False
  213. elif rest < 0:
  214. emit = True
  215. elif rest > 0:
  216. emit = not ((optional - rest) >= 0 or node.args.vararg)
  217. if emit:
  218. verb = "was" if current_params <= 1 else "were"
  219. self.add_message(
  220. "unexpected-special-method-signature",
  221. args=(node.name, expected_params, current_params, verb),
  222. node=node,
  223. )
  224. @staticmethod
  225. def _is_wrapped_type(node: InferenceResult, type_: str) -> bool:
  226. return (
  227. isinstance(node, bases.Instance)
  228. and node.name == type_
  229. and not isinstance(node, nodes.Const)
  230. )
  231. @staticmethod
  232. def _is_int(node: InferenceResult) -> bool:
  233. if SpecialMethodsChecker._is_wrapped_type(node, "int"):
  234. return True
  235. return isinstance(node, nodes.Const) and isinstance(node.value, int)
  236. @staticmethod
  237. def _is_str(node: InferenceResult) -> bool:
  238. if SpecialMethodsChecker._is_wrapped_type(node, "str"):
  239. return True
  240. return isinstance(node, nodes.Const) and isinstance(node.value, str)
  241. @staticmethod
  242. def _is_bool(node: InferenceResult) -> bool:
  243. if SpecialMethodsChecker._is_wrapped_type(node, "bool"):
  244. return True
  245. return isinstance(node, nodes.Const) and isinstance(node.value, bool)
  246. @staticmethod
  247. def _is_bytes(node: InferenceResult) -> bool:
  248. if SpecialMethodsChecker._is_wrapped_type(node, "bytes"):
  249. return True
  250. return isinstance(node, nodes.Const) and isinstance(node.value, bytes)
  251. @staticmethod
  252. def _is_tuple(node: InferenceResult) -> bool:
  253. if SpecialMethodsChecker._is_wrapped_type(node, "tuple"):
  254. return True
  255. return isinstance(node, nodes.Const) and isinstance(node.value, tuple)
  256. @staticmethod
  257. def _is_dict(node: InferenceResult) -> bool:
  258. if SpecialMethodsChecker._is_wrapped_type(node, "dict"):
  259. return True
  260. return isinstance(node, nodes.Const) and isinstance(node.value, dict)
  261. @staticmethod
  262. def _is_iterator(node: InferenceResult) -> bool:
  263. if isinstance(node, bases.Generator):
  264. # Generators can be iterated.
  265. return True
  266. if isinstance(node, nodes.ComprehensionScope):
  267. # Comprehensions can be iterated.
  268. return True
  269. if isinstance(node, bases.Instance):
  270. try:
  271. node.local_attr(NEXT_METHOD)
  272. return True
  273. except astroid.NotFoundError:
  274. pass
  275. elif isinstance(node, nodes.ClassDef):
  276. metaclass = node.metaclass()
  277. if metaclass and isinstance(metaclass, nodes.ClassDef):
  278. try:
  279. metaclass.local_attr(NEXT_METHOD)
  280. return True
  281. except astroid.NotFoundError:
  282. pass
  283. return False
  284. def _check_iter(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  285. if not self._is_iterator(inferred):
  286. self.add_message("non-iterator-returned", node=node)
  287. def _check_len(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  288. if not self._is_int(inferred):
  289. self.add_message("invalid-length-returned", node=node)
  290. elif isinstance(inferred, nodes.Const) and inferred.value < 0:
  291. self.add_message("invalid-length-returned", node=node)
  292. def _check_bool(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  293. if not self._is_bool(inferred):
  294. self.add_message("invalid-bool-returned", node=node)
  295. def _check_index(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  296. if not self._is_int(inferred):
  297. self.add_message("invalid-index-returned", node=node)
  298. def _check_repr(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  299. if not self._is_str(inferred):
  300. self.add_message("invalid-repr-returned", node=node)
  301. def _check_str(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  302. if not self._is_str(inferred):
  303. self.add_message("invalid-str-returned", node=node)
  304. def _check_bytes(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  305. if not self._is_bytes(inferred):
  306. self.add_message("invalid-bytes-returned", node=node)
  307. def _check_hash(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  308. if not self._is_int(inferred):
  309. self.add_message("invalid-hash-returned", node=node)
  310. def _check_length_hint(
  311. self, node: nodes.FunctionDef, inferred: InferenceResult
  312. ) -> None:
  313. if not self._is_int(inferred):
  314. self.add_message("invalid-length-hint-returned", node=node)
  315. elif isinstance(inferred, nodes.Const) and inferred.value < 0:
  316. self.add_message("invalid-length-hint-returned", node=node)
  317. def _check_format(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
  318. if not self._is_str(inferred):
  319. self.add_message("invalid-format-returned", node=node)
  320. def _check_getnewargs(
  321. self, node: nodes.FunctionDef, inferred: InferenceResult
  322. ) -> None:
  323. if not self._is_tuple(inferred):
  324. self.add_message("invalid-getnewargs-returned", node=node)
  325. def _check_getnewargs_ex(
  326. self, node: nodes.FunctionDef, inferred: InferenceResult
  327. ) -> None:
  328. if not self._is_tuple(inferred):
  329. self.add_message("invalid-getnewargs-ex-returned", node=node)
  330. return
  331. if not isinstance(inferred, nodes.Tuple):
  332. # If it's not an astroid.Tuple we can't analyze it further
  333. return
  334. found_error = False
  335. if len(inferred.elts) != 2:
  336. found_error = True
  337. else:
  338. for arg, check in (
  339. (inferred.elts[0], self._is_tuple),
  340. (inferred.elts[1], self._is_dict),
  341. ):
  342. if isinstance(arg, nodes.Call):
  343. arg = safe_infer(arg)
  344. if arg and not isinstance(arg, util.UninferableBase):
  345. if not check(arg):
  346. found_error = True
  347. break
  348. if found_error:
  349. self.add_message("invalid-getnewargs-ex-returned", node=node)