# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt

"""Special methods checker and its helper functions."""

from __future__ import annotations

from collections.abc import Callable

import astroid
from astroid import bases, nodes, util
from astroid.context import InferenceContext
from astroid.typing import InferenceResult

from pylint.checkers import BaseChecker
from pylint.checkers.utils import (
    PYMETHODS,
    SPECIAL_METHODS_PARAMS,
    decorated_with,
    is_function_body_ellipsis,
    only_required_for_messages,
    safe_infer,
)
from pylint.lint.pylinter import PyLinter

NEXT_METHOD = "__next__"
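

# Illustrative sketch only (hypothetical helper, not used by this checker):
# a runtime analogue of the rule behind ``non-iterator-returned``. An object is
# treated as an iterator when its type defines the method named by NEXT_METHOD.
# The checker itself inspects inferred astroid nodes rather than live objects,
# so this approximates the idea, not the implementation.
def _sketch_is_runtime_iterator(obj: object, next_method: str = "__next__") -> bool:
    """Return True if ``type(obj)`` defines *next_method* (``__next__`` by default)."""
    # Look the method up on the type, as the interpreter does for dunder methods.
    return hasattr(type(obj), next_method)


# Example: iter([1, 2]) and generator objects qualify; a plain list does not,
# because a list is iterable (it defines __iter__) but is not an iterator.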


def _safe_infer_call_result(
    node: nodes.FunctionDef,
    caller: nodes.FunctionDef,
    context: InferenceContext | None = None,
) -> InferenceResult | None:
    """Safely infer the return value of a function.

    Returns None if inference failed or if there is some ambiguity (more than
    one node has been inferred). Otherwise, returns the inferred value.
    """
    try:
        inferit = node.infer_call_result(caller, context=context)
        value = next(inferit)
    except astroid.InferenceError:
        return None  # inference failed
    except StopIteration:
        return None  # no values inferred
    try:
        next(inferit)
        return None  # there is ambiguity on the inferred node
    except astroid.InferenceError:
        return None  # there is some kind of ambiguity
    except StopIteration:
        return value
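

# Illustrative sketch only (hypothetical helper, not used by this checker):
# the "exactly one result, otherwise None" pattern of _safe_infer_call_result,
# restated over a plain iterator instead of astroid's inference generator.
# Requesting a second item is what separates a unique result from an
# ambiguous one.
from collections.abc import Iterator  # needed only by the sketch below


def _sketch_unique_value(values: Iterator[object]) -> object | None:
    """Return the only item produced by *values*, or None for zero or 2+ items."""
    try:
        first = next(values)
    except StopIteration:
        return None  # nothing was produced
    try:
        next(values)
    except StopIteration:
        return first  # exactly one item: unambiguous
    return None  # a second item exists: ambiguous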


class SpecialMethodsChecker(BaseChecker):
    """Checker which verifies that special methods
    are implemented correctly.
    """

    name = "classes"
    msgs = {
        "E0301": (
            "__iter__ returns non-iterator",
            "non-iterator-returned",
            "Used when an __iter__ method returns something which is not an "
            f"iterator (i.e. has no `{NEXT_METHOD}` method)",
            {
                "old_names": [
                    ("W0234", "old-non-iterator-returned-1"),
                    ("E0234", "old-non-iterator-returned-2"),
                ]
            },
        ),
        "E0302": (
            "The special method %r expects %s param(s), %d %s given",
            "unexpected-special-method-signature",
            "Emitted when a special method was defined with an "
            "invalid number of parameters. If it has too few or "
            "too many, it might not work at all.",
            {"old_names": [("E0235", "bad-context-manager")]},
        ),
        "E0303": (
            "__len__ does not return non-negative integer",
            "invalid-length-returned",
            "Used when a __len__ method returns something which is not a "
            "non-negative integer",
        ),
        "E0304": (
            "__bool__ does not return bool",
            "invalid-bool-returned",
            "Used when a __bool__ method returns something which is not a bool",
        ),
        "E0305": (
            "__index__ does not return int",
            "invalid-index-returned",
            "Used when an __index__ method returns something which is not "
            "an integer",
        ),
        "E0306": (
            "__repr__ does not return str",
            "invalid-repr-returned",
            "Used when a __repr__ method returns something which is not a string",
        ),
        "E0307": (
            "__str__ does not return str",
            "invalid-str-returned",
            "Used when a __str__ method returns something which is not a string",
        ),
        "E0308": (
            "__bytes__ does not return bytes",
            "invalid-bytes-returned",
            "Used when a __bytes__ method returns something which is not bytes",
        ),
        "E0309": (
            "__hash__ does not return int",
            "invalid-hash-returned",
            "Used when a __hash__ method returns something which is not an integer",
        ),
        "E0310": (
            "__length_hint__ does not return non-negative integer",
            "invalid-length-hint-returned",
            "Used when a __length_hint__ method returns something which is not a "
            "non-negative integer",
        ),
        "E0311": (
            "__format__ does not return str",
            "invalid-format-returned",
            "Used when a __format__ method returns something which is not a string",
        ),
        "E0312": (
            "__getnewargs__ does not return a tuple",
            "invalid-getnewargs-returned",
            "Used when a __getnewargs__ method returns something which is not "
            "a tuple",
        ),
        "E0313": (
            "__getnewargs_ex__ does not return a tuple containing (tuple, dict)",
            "invalid-getnewargs-ex-returned",
            "Used when a __getnewargs_ex__ method returns something which is not "
            "of the form tuple(tuple, dict)",
        ),
    }

    def __init__(self, linter: PyLinter) -> None:
        super().__init__(linter)
        self._protocol_map: dict[
            str, Callable[[nodes.FunctionDef, InferenceResult], None]
        ] = {
            "__iter__": self._check_iter,
            "__len__": self._check_len,
            "__bool__": self._check_bool,
            "__index__": self._check_index,
            "__repr__": self._check_repr,
            "__str__": self._check_str,
            "__bytes__": self._check_bytes,
            "__hash__": self._check_hash,
            "__length_hint__": self._check_length_hint,
            "__format__": self._check_format,
            "__getnewargs__": self._check_getnewargs,
            "__getnewargs_ex__": self._check_getnewargs_ex,
        }

    @only_required_for_messages(
        "unexpected-special-method-signature",
        "non-iterator-returned",
        "invalid-length-returned",
        "invalid-bool-returned",
        "invalid-index-returned",
        "invalid-repr-returned",
        "invalid-str-returned",
        "invalid-bytes-returned",
        "invalid-hash-returned",
        "invalid-length-hint-returned",
        "invalid-format-returned",
        "invalid-getnewargs-returned",
        "invalid-getnewargs-ex-returned",
    )
    def visit_functiondef(self, node: nodes.FunctionDef) -> None:
        if not node.is_method():
            return

        inferred = _safe_infer_call_result(node, node)
        # Only want to check types that we are able to infer
        if (
            inferred
            and node.name in self._protocol_map
            and not is_function_body_ellipsis(node)
        ):
            self._protocol_map[node.name](node, inferred)

        if node.name in PYMETHODS:
            self._check_unexpected_method_signature(node)

    visit_asyncfunctiondef = visit_functiondef

    def _check_unexpected_method_signature(self, node: nodes.FunctionDef) -> None:
        expected_params = SPECIAL_METHODS_PARAMS[node.name]

        if expected_params is None:
            # This can support a variable number of parameters.
            return
        if not node.args.args and not node.args.vararg:
            # The method takes no parameters at all; that case is
            # reported by no-method-argument instead.
            return

        if decorated_with(node, ["builtins.staticmethod"]):
            # A staticmethod has no implicit ``self``, so count every argument.
            all_args = node.args.args
        else:
            all_args = node.args.args[1:]
        mandatory = len(all_args) - len(node.args.defaults)
        optional = len(node.args.defaults)
        current_params = mandatory + optional

        emit = False  # When in doubt, prefer a false negative.
        if isinstance(expected_params, tuple):
            # The expected number of parameters can be any value from this
            # tuple, although the user should implement the method
            # to take all of them into consideration.
            emit = mandatory not in expected_params
            # mypy thinks that expected_params has type tuple[int, int] | int | None
            # But at this point it must be 'tuple[int, int]' because of the type check
            expected_params = f"between {expected_params[0]} and {expected_params[1]}"  # type: ignore[assignment]
        else:
            # If the number of mandatory parameters doesn't
            # suffice, the expected parameters for this
            # function will be deduced from the optional
            # parameters.
            rest = expected_params - mandatory
            if rest == 0:
                emit = False
            elif rest < 0:
                emit = True
            elif rest > 0:
                emit = not ((optional - rest) >= 0 or node.args.vararg)

        if emit:
            verb = "was" if current_params <= 1 else "were"
            self.add_message(
                "unexpected-special-method-signature",
                args=(node.name, expected_params, current_params, verb),
                node=node,
            )

    @staticmethod
    def _is_wrapped_type(node: InferenceResult, type_: str) -> bool:
        return (
            isinstance(node, bases.Instance)
            and node.name == type_
            and not isinstance(node, nodes.Const)
        )

    @staticmethod
    def _is_int(node: InferenceResult) -> bool:
        if SpecialMethodsChecker._is_wrapped_type(node, "int"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, int)

    @staticmethod
    def _is_str(node: InferenceResult) -> bool:
        if SpecialMethodsChecker._is_wrapped_type(node, "str"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, str)

    @staticmethod
    def _is_bool(node: InferenceResult) -> bool:
        if SpecialMethodsChecker._is_wrapped_type(node, "bool"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, bool)

    @staticmethod
    def _is_bytes(node: InferenceResult) -> bool:
        if SpecialMethodsChecker._is_wrapped_type(node, "bytes"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, bytes)

    @staticmethod
    def _is_tuple(node: InferenceResult) -> bool:
        if SpecialMethodsChecker._is_wrapped_type(node, "tuple"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, tuple)

    @staticmethod
    def _is_dict(node: InferenceResult) -> bool:
        if SpecialMethodsChecker._is_wrapped_type(node, "dict"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, dict)

    @staticmethod
    def _is_iterator(node: InferenceResult) -> bool:
        if isinstance(node, bases.Generator):
            # Generators can be iterated.
            return True
        if isinstance(node, nodes.ComprehensionScope):
            # Comprehensions can be iterated.
            return True

        if isinstance(node, bases.Instance):
            try:
                node.local_attr(NEXT_METHOD)
                return True
            except astroid.NotFoundError:
                pass
        elif isinstance(node, nodes.ClassDef):
            metaclass = node.metaclass()
            if metaclass and isinstance(metaclass, nodes.ClassDef):
                try:
                    metaclass.local_attr(NEXT_METHOD)
                    return True
                except astroid.NotFoundError:
                    pass
        return False

    def _check_iter(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_iterator(inferred):
            self.add_message("non-iterator-returned", node=node)

    def _check_len(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-length-returned", node=node)
        elif isinstance(inferred, nodes.Const) and inferred.value < 0:
            self.add_message("invalid-length-returned", node=node)

    def _check_bool(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_bool(inferred):
            self.add_message("invalid-bool-returned", node=node)

    def _check_index(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-index-returned", node=node)

    def _check_repr(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_str(inferred):
            self.add_message("invalid-repr-returned", node=node)

    def _check_str(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_str(inferred):
            self.add_message("invalid-str-returned", node=node)

    def _check_bytes(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_bytes(inferred):
            self.add_message("invalid-bytes-returned", node=node)

    def _check_hash(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-hash-returned", node=node)

    def _check_length_hint(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-length-hint-returned", node=node)
        elif isinstance(inferred, nodes.Const) and inferred.value < 0:
            self.add_message("invalid-length-hint-returned", node=node)

    def _check_format(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_str(inferred):
            self.add_message("invalid-format-returned", node=node)

    def _check_getnewargs(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        if not self._is_tuple(inferred):
            self.add_message("invalid-getnewargs-returned", node=node)

    def _check_getnewargs_ex(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        if not self._is_tuple(inferred):
            self.add_message("invalid-getnewargs-ex-returned", node=node)
            return

        if not isinstance(inferred, nodes.Tuple):
            # If it's not a nodes.Tuple, we can't analyze it further.
            return

        found_error = False

        if len(inferred.elts) != 2:
            found_error = True
        else:
            for arg, check in (
                (inferred.elts[0], self._is_tuple),
                (inferred.elts[1], self._is_dict),
            ):
                if isinstance(arg, nodes.Call):
                    arg = safe_infer(arg)

                if arg and not isinstance(arg, util.UninferableBase):
                    if not check(arg):
                        found_error = True
                        break

        if found_error:
            self.add_message("invalid-getnewargs-ex-returned", node=node)
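

# Illustrative sketch only (hypothetical helper, not used by this checker):
# the counting rule of SpecialMethodsChecker._check_unexpected_method_signature,
# restated over plain integers so it can be exercised without astroid.
# ``mandatory`` and ``optional`` count parameters after ``self``; ``expected``
# is either an int or a tuple of allowed counts, mirroring the shapes the real
# method handles from SPECIAL_METHODS_PARAMS (its actual entries are not
# reproduced here).
def _sketch_signature_is_wrong(
    expected: int | tuple[int, ...],
    mandatory: int,
    optional: int,
    has_vararg: bool = False,
) -> bool:
    """Return True when the rule above would report a bad signature."""
    if isinstance(expected, tuple):
        # Only the mandatory count is compared against the allowed values.
        return mandatory not in expected
    rest = expected - mandatory
    if rest == 0:
        return False  # mandatory parameters match exactly
    if rest < 0:
        return True  # too many mandatory parameters
    # Too few mandatory parameters: defaults or *args must cover the gap.
    return not (optional >= rest or has_vararg)


# Example: for a method expected to take 3 parameters (e.g. an __exit__-style
# signature), (self, a, b) is reported, while (self, a, b, c=None) and
# (self, *args) are accepted.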