nested_min_max.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/PyCQA/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/PyCQA/pylint/blob/main/CONTRIBUTORS.txt
  4. """Check for use of nested min/max functions."""
  5. from __future__ import annotations
  6. import copy
  7. from typing import TYPE_CHECKING
  8. from astroid import nodes, objects
  9. from pylint.checkers import BaseChecker
  10. from pylint.checkers.utils import only_required_for_messages, safe_infer
  11. from pylint.interfaces import INFERENCE
  12. if TYPE_CHECKING:
  13. from pylint.lint import PyLinter
  14. DICT_TYPES = (
  15. objects.DictValues,
  16. objects.DictKeys,
  17. objects.DictItems,
  18. nodes.node_classes.Dict,
  19. )
  20. class NestedMinMaxChecker(BaseChecker):
  21. """Multiple nested min/max calls on the same line will raise multiple messages.
  22. This behaviour is intended as it would slow down the checker to check
  23. for nested call with minimal benefits.
  24. """
  25. FUNC_NAMES = ("builtins.min", "builtins.max")
  26. name = "nested_min_max"
  27. msgs = {
  28. "W3301": (
  29. "Do not use nested call of '%s'; it's possible to do '%s' instead",
  30. "nested-min-max",
  31. "Nested calls ``min(1, min(2, 3))`` can be rewritten as ``min(1, 2, 3)``.",
  32. )
  33. }
  34. @classmethod
  35. def is_min_max_call(cls, node: nodes.NodeNG) -> bool:
  36. if not isinstance(node, nodes.Call):
  37. return False
  38. inferred = safe_infer(node.func)
  39. return (
  40. isinstance(inferred, nodes.FunctionDef)
  41. and inferred.qname() in cls.FUNC_NAMES
  42. )
  43. @classmethod
  44. def get_redundant_calls(cls, node: nodes.Call) -> list[nodes.Call]:
  45. return [
  46. arg
  47. for arg in node.args
  48. if cls.is_min_max_call(arg) and arg.func.name == node.func.name
  49. ]
  50. @only_required_for_messages("nested-min-max")
  51. def visit_call(self, node: nodes.Call) -> None:
  52. if not self.is_min_max_call(node):
  53. return
  54. redundant_calls = self.get_redundant_calls(node)
  55. if not redundant_calls:
  56. return
  57. fixed_node = copy.copy(node)
  58. while len(redundant_calls) > 0:
  59. for i, arg in enumerate(fixed_node.args):
  60. # Exclude any calls with generator expressions as there is no
  61. # clear better suggestion for them.
  62. if isinstance(arg, nodes.Call) and any(
  63. isinstance(a, nodes.GeneratorExp) for a in arg.args
  64. ):
  65. return
  66. if arg in redundant_calls:
  67. fixed_node.args = (
  68. fixed_node.args[:i] + arg.args + fixed_node.args[i + 1 :]
  69. )
  70. break
  71. redundant_calls = self.get_redundant_calls(fixed_node)
  72. for idx, arg in enumerate(fixed_node.args):
  73. if not isinstance(arg, nodes.Const):
  74. inferred = safe_infer(arg)
  75. if isinstance(
  76. inferred, (nodes.List, nodes.Tuple, nodes.Set, *DICT_TYPES)
  77. ):
  78. splat_node = nodes.Starred(lineno=inferred.lineno)
  79. splat_node.value = arg
  80. fixed_node.args = (
  81. fixed_node.args[:idx]
  82. + [splat_node]
  83. + fixed_node.args[idx + 1 : idx]
  84. )
  85. self.add_message(
  86. "nested-min-max",
  87. node=node,
  88. args=(node.func.name, fixed_node.as_string()),
  89. confidence=INFERENCE,
  90. )
  91. def register(linter: PyLinter) -> None:
  92. linter.register_checker(NestedMinMaxChecker(linter))