testinfer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. """Test cases for type inference helper functions."""
  2. from __future__ import annotations
  3. from mypy.argmap import map_actuals_to_formals
  4. from mypy.checker import DisjointDict, group_comparison_operands
  5. from mypy.literals import Key
  6. from mypy.nodes import ARG_NAMED, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind, NameExpr
  7. from mypy.test.helpers import Suite, assert_equal
  8. from mypy.test.typefixture import TypeFixture
  9. from mypy.types import AnyType, TupleType, Type, TypeOfAny
  10. class MapActualsToFormalsSuite(Suite):
  11. """Test cases for argmap.map_actuals_to_formals."""
  12. def test_basic(self) -> None:
  13. self.assert_map([], [], [])
  14. def test_positional_only(self) -> None:
  15. self.assert_map([ARG_POS], [ARG_POS], [[0]])
  16. self.assert_map([ARG_POS, ARG_POS], [ARG_POS, ARG_POS], [[0], [1]])
  17. def test_optional(self) -> None:
  18. self.assert_map([], [ARG_OPT], [[]])
  19. self.assert_map([ARG_POS], [ARG_OPT], [[0]])
  20. self.assert_map([ARG_POS], [ARG_OPT, ARG_OPT], [[0], []])
  21. def test_callee_star(self) -> None:
  22. self.assert_map([], [ARG_STAR], [[]])
  23. self.assert_map([ARG_POS], [ARG_STAR], [[0]])
  24. self.assert_map([ARG_POS, ARG_POS], [ARG_STAR], [[0, 1]])
  25. def test_caller_star(self) -> None:
  26. self.assert_map([ARG_STAR], [ARG_STAR], [[0]])
  27. self.assert_map([ARG_POS, ARG_STAR], [ARG_STAR], [[0, 1]])
  28. self.assert_map([ARG_STAR], [ARG_POS, ARG_STAR], [[0], [0]])
  29. self.assert_map([ARG_STAR], [ARG_OPT, ARG_STAR], [[0], [0]])
  30. def test_too_many_caller_args(self) -> None:
  31. self.assert_map([ARG_POS], [], [])
  32. self.assert_map([ARG_STAR], [], [])
  33. self.assert_map([ARG_STAR], [ARG_POS], [[0]])
  34. def test_tuple_star(self) -> None:
  35. any_type = AnyType(TypeOfAny.special_form)
  36. self.assert_vararg_map([ARG_STAR], [ARG_POS], [[0]], self.make_tuple(any_type))
  37. self.assert_vararg_map(
  38. [ARG_STAR], [ARG_POS, ARG_POS], [[0], [0]], self.make_tuple(any_type, any_type)
  39. )
  40. self.assert_vararg_map(
  41. [ARG_STAR],
  42. [ARG_POS, ARG_OPT, ARG_OPT],
  43. [[0], [0], []],
  44. self.make_tuple(any_type, any_type),
  45. )
  46. def make_tuple(self, *args: Type) -> TupleType:
  47. return TupleType(list(args), TypeFixture().std_tuple)
  48. def test_named_args(self) -> None:
  49. self.assert_map(["x"], [(ARG_POS, "x")], [[0]])
  50. self.assert_map(["y", "x"], [(ARG_POS, "x"), (ARG_POS, "y")], [[1], [0]])
  51. def test_some_named_args(self) -> None:
  52. self.assert_map(["y"], [(ARG_OPT, "x"), (ARG_OPT, "y"), (ARG_OPT, "z")], [[], [0], []])
  53. def test_missing_named_arg(self) -> None:
  54. self.assert_map(["y"], [(ARG_OPT, "x")], [[]])
  55. def test_duplicate_named_arg(self) -> None:
  56. self.assert_map(["x", "x"], [(ARG_OPT, "x")], [[0, 1]])
  57. def test_varargs_and_bare_asterisk(self) -> None:
  58. self.assert_map([ARG_STAR], [ARG_STAR, (ARG_NAMED, "x")], [[0], []])
  59. self.assert_map([ARG_STAR, "x"], [ARG_STAR, (ARG_NAMED, "x")], [[0], [1]])
  60. def test_keyword_varargs(self) -> None:
  61. self.assert_map(["x"], [ARG_STAR2], [[0]])
  62. self.assert_map(["x", ARG_STAR2], [ARG_STAR2], [[0, 1]])
  63. self.assert_map(["x", ARG_STAR2], [(ARG_POS, "x"), ARG_STAR2], [[0], [1]])
  64. self.assert_map([ARG_POS, ARG_STAR2], [(ARG_POS, "x"), ARG_STAR2], [[0], [1]])
  65. def test_both_kinds_of_varargs(self) -> None:
  66. self.assert_map([ARG_STAR, ARG_STAR2], [(ARG_POS, "x"), (ARG_POS, "y")], [[0, 1], [0, 1]])
  67. def test_special_cases(self) -> None:
  68. self.assert_map([ARG_STAR], [ARG_STAR, ARG_STAR2], [[0], []])
  69. self.assert_map([ARG_STAR, ARG_STAR2], [ARG_STAR, ARG_STAR2], [[0], [1]])
  70. self.assert_map([ARG_STAR2], [(ARG_POS, "x"), ARG_STAR2], [[0], [0]])
  71. self.assert_map([ARG_STAR2], [ARG_STAR2], [[0]])
  72. def assert_map(
  73. self,
  74. caller_kinds_: list[ArgKind | str],
  75. callee_kinds_: list[ArgKind | tuple[ArgKind, str]],
  76. expected: list[list[int]],
  77. ) -> None:
  78. caller_kinds, caller_names = expand_caller_kinds(caller_kinds_)
  79. callee_kinds, callee_names = expand_callee_kinds(callee_kinds_)
  80. result = map_actuals_to_formals(
  81. caller_kinds,
  82. caller_names,
  83. callee_kinds,
  84. callee_names,
  85. lambda i: AnyType(TypeOfAny.special_form),
  86. )
  87. assert_equal(result, expected)
  88. def assert_vararg_map(
  89. self,
  90. caller_kinds: list[ArgKind],
  91. callee_kinds: list[ArgKind],
  92. expected: list[list[int]],
  93. vararg_type: Type,
  94. ) -> None:
  95. result = map_actuals_to_formals(caller_kinds, [], callee_kinds, [], lambda i: vararg_type)
  96. assert_equal(result, expected)
  97. def expand_caller_kinds(
  98. kinds_or_names: list[ArgKind | str],
  99. ) -> tuple[list[ArgKind], list[str | None]]:
  100. kinds = []
  101. names: list[str | None] = []
  102. for k in kinds_or_names:
  103. if isinstance(k, str):
  104. kinds.append(ARG_NAMED)
  105. names.append(k)
  106. else:
  107. kinds.append(k)
  108. names.append(None)
  109. return kinds, names
  110. def expand_callee_kinds(
  111. kinds_and_names: list[ArgKind | tuple[ArgKind, str]]
  112. ) -> tuple[list[ArgKind], list[str | None]]:
  113. kinds = []
  114. names: list[str | None] = []
  115. for v in kinds_and_names:
  116. if isinstance(v, tuple):
  117. kinds.append(v[0])
  118. names.append(v[1])
  119. else:
  120. kinds.append(v)
  121. names.append(None)
  122. return kinds, names
  123. class OperandDisjointDictSuite(Suite):
  124. """Test cases for checker.DisjointDict, which is used for type inference with operands."""
  125. def new(self) -> DisjointDict[int, str]:
  126. return DisjointDict()
  127. def test_independent_maps(self) -> None:
  128. d = self.new()
  129. d.add_mapping({0, 1}, {"group1"})
  130. d.add_mapping({2, 3, 4}, {"group2"})
  131. d.add_mapping({5, 6, 7}, {"group3"})
  132. self.assertEqual(
  133. d.items(), [({0, 1}, {"group1"}), ({2, 3, 4}, {"group2"}), ({5, 6, 7}, {"group3"})]
  134. )
  135. def test_partial_merging(self) -> None:
  136. d = self.new()
  137. d.add_mapping({0, 1}, {"group1"})
  138. d.add_mapping({1, 2}, {"group2"})
  139. d.add_mapping({3, 4}, {"group3"})
  140. d.add_mapping({5, 0}, {"group4"})
  141. d.add_mapping({5, 6}, {"group5"})
  142. d.add_mapping({4, 7}, {"group6"})
  143. self.assertEqual(
  144. d.items(),
  145. [
  146. ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}),
  147. ({3, 4, 7}, {"group3", "group6"}),
  148. ],
  149. )
  150. def test_full_merging(self) -> None:
  151. d = self.new()
  152. d.add_mapping({0, 1, 2}, {"a"})
  153. d.add_mapping({3, 4, 2}, {"b"})
  154. d.add_mapping({10, 11, 12}, {"c"})
  155. d.add_mapping({13, 14, 15}, {"d"})
  156. d.add_mapping({14, 10, 16}, {"e"})
  157. d.add_mapping({0, 10}, {"f"})
  158. self.assertEqual(
  159. d.items(),
  160. [({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"})],
  161. )
  162. def test_merge_with_multiple_overlaps(self) -> None:
  163. d = self.new()
  164. d.add_mapping({0, 1, 2}, {"a"})
  165. d.add_mapping({3, 4, 5}, {"b"})
  166. d.add_mapping({1, 2, 4, 5}, {"c"})
  167. d.add_mapping({6, 1, 2, 4, 5}, {"d"})
  168. d.add_mapping({6, 1, 2, 4, 5}, {"e"})
  169. self.assertEqual(d.items(), [({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"})])
  170. class OperandComparisonGroupingSuite(Suite):
  171. """Test cases for checker.group_comparison_operands."""
  172. def literal_keymap(self, assignable_operands: dict[int, NameExpr]) -> dict[int, Key]:
  173. output: dict[int, Key] = {}
  174. for index, expr in assignable_operands.items():
  175. output[index] = ("FakeExpr", expr.name)
  176. return output
  177. def test_basic_cases(self) -> None:
  178. # Note: the grouping function doesn't actually inspect the input exprs, so we
  179. # just default to using NameExprs for simplicity.
  180. x0 = NameExpr("x0")
  181. x1 = NameExpr("x1")
  182. x2 = NameExpr("x2")
  183. x3 = NameExpr("x3")
  184. x4 = NameExpr("x4")
  185. basic_input = [("==", x0, x1), ("==", x1, x2), ("<", x2, x3), ("==", x3, x4)]
  186. none_assignable = self.literal_keymap({})
  187. all_assignable = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4})
  188. for assignable in [none_assignable, all_assignable]:
  189. self.assertEqual(
  190. group_comparison_operands(basic_input, assignable, set()),
  191. [("==", [0, 1]), ("==", [1, 2]), ("<", [2, 3]), ("==", [3, 4])],
  192. )
  193. self.assertEqual(
  194. group_comparison_operands(basic_input, assignable, {"=="}),
  195. [("==", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4])],
  196. )
  197. self.assertEqual(
  198. group_comparison_operands(basic_input, assignable, {"<"}),
  199. [("==", [0, 1]), ("==", [1, 2]), ("<", [2, 3]), ("==", [3, 4])],
  200. )
  201. self.assertEqual(
  202. group_comparison_operands(basic_input, assignable, {"==", "<"}),
  203. [("==", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4])],
  204. )
  205. def test_multiple_groups(self) -> None:
  206. x0 = NameExpr("x0")
  207. x1 = NameExpr("x1")
  208. x2 = NameExpr("x2")
  209. x3 = NameExpr("x3")
  210. x4 = NameExpr("x4")
  211. x5 = NameExpr("x5")
  212. self.assertEqual(
  213. group_comparison_operands(
  214. [("==", x0, x1), ("==", x1, x2), ("is", x2, x3), ("is", x3, x4)],
  215. self.literal_keymap({}),
  216. {"==", "is"},
  217. ),
  218. [("==", [0, 1, 2]), ("is", [2, 3, 4])],
  219. )
  220. self.assertEqual(
  221. group_comparison_operands(
  222. [("==", x0, x1), ("==", x1, x2), ("==", x2, x3), ("==", x3, x4)],
  223. self.literal_keymap({}),
  224. {"==", "is"},
  225. ),
  226. [("==", [0, 1, 2, 3, 4])],
  227. )
  228. self.assertEqual(
  229. group_comparison_operands(
  230. [("is", x0, x1), ("==", x1, x2), ("==", x2, x3), ("==", x3, x4)],
  231. self.literal_keymap({}),
  232. {"==", "is"},
  233. ),
  234. [("is", [0, 1]), ("==", [1, 2, 3, 4])],
  235. )
  236. self.assertEqual(
  237. group_comparison_operands(
  238. [("is", x0, x1), ("is", x1, x2), ("<", x2, x3), ("==", x3, x4), ("==", x4, x5)],
  239. self.literal_keymap({}),
  240. {"==", "is"},
  241. ),
  242. [("is", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4, 5])],
  243. )
  244. def test_multiple_groups_coalescing(self) -> None:
  245. x0 = NameExpr("x0")
  246. x1 = NameExpr("x1")
  247. x2 = NameExpr("x2")
  248. x3 = NameExpr("x3")
  249. x4 = NameExpr("x4")
  250. nothing_combined = [("==", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4, 5])]
  251. everything_combined = [("==", [0, 1, 2, 3, 4, 5]), ("<", [2, 3])]
  252. # Note: We do 'x4 == x0' at the very end!
  253. two_groups = [
  254. ("==", x0, x1),
  255. ("==", x1, x2),
  256. ("<", x2, x3),
  257. ("==", x3, x4),
  258. ("==", x4, x0),
  259. ]
  260. self.assertEqual(
  261. group_comparison_operands(
  262. two_groups, self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0}), {"=="}
  263. ),
  264. everything_combined,
  265. "All vars are assignable, everything is combined",
  266. )
  267. self.assertEqual(
  268. group_comparison_operands(
  269. two_groups, self.literal_keymap({1: x1, 2: x2, 3: x3, 4: x4}), {"=="}
  270. ),
  271. nothing_combined,
  272. "x0 is unassignable, so no combining",
  273. )
  274. self.assertEqual(
  275. group_comparison_operands(
  276. two_groups, self.literal_keymap({0: x0, 1: x1, 3: x3, 5: x0}), {"=="}
  277. ),
  278. everything_combined,
  279. "Some vars are unassignable but x0 is, so we combine",
  280. )
  281. self.assertEqual(
  282. group_comparison_operands(two_groups, self.literal_keymap({0: x0, 5: x0}), {"=="}),
  283. everything_combined,
  284. "All vars are unassignable but x0 is, so we combine",
  285. )
  286. def test_multiple_groups_different_operators(self) -> None:
  287. x0 = NameExpr("x0")
  288. x1 = NameExpr("x1")
  289. x2 = NameExpr("x2")
  290. x3 = NameExpr("x3")
  291. groups = [("==", x0, x1), ("==", x1, x2), ("is", x2, x3), ("is", x3, x0)]
  292. keymap = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x0})
  293. self.assertEqual(
  294. group_comparison_operands(groups, keymap, {"==", "is"}),
  295. [("==", [0, 1, 2]), ("is", [2, 3, 4])],
  296. "Different operators can never be combined",
  297. )
  298. def test_single_pair(self) -> None:
  299. x0 = NameExpr("x0")
  300. x1 = NameExpr("x1")
  301. single_comparison = [("==", x0, x1)]
  302. expected_output = [("==", [0, 1])]
  303. assignable_combinations: list[dict[int, NameExpr]] = [{}, {0: x0}, {1: x1}, {0: x0, 1: x1}]
  304. to_group_by: list[set[str]] = [set(), {"=="}, {"is"}]
  305. for combo in assignable_combinations:
  306. for operators in to_group_by:
  307. keymap = self.literal_keymap(combo)
  308. self.assertEqual(
  309. group_comparison_operands(single_comparison, keymap, operators),
  310. expected_output,
  311. )
  312. def test_empty_pair_list(self) -> None:
  313. # This case should never occur in practice -- ComparisionExprs
  314. # always contain at least one comparison. But in case it does...
  315. self.assertEqual(group_comparison_operands([], {}, set()), [])
  316. self.assertEqual(group_comparison_operands([], {}, {"=="}), [])