run-singledispatch.test 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. # Test cases related to the functools.singledispatch decorator
  2. # Only the singledispatchmethod cases below are still marked as xfails (they carry the -xfail suffix);
  3. # the plain singledispatch cases are not marked as xfails.
  4. [case testSpecializedImplementationUsed]
  5. from functools import singledispatch
  6. @singledispatch
  7. def fun(arg) -> bool:
  8. return False
  9. @fun.register
  10. def fun_specialized(arg: str) -> bool:
  11. return True
  12. def test_specialize() -> None:
  13. assert fun('a')
  14. assert not fun(3)
  15. [case testSubclassesOfExpectedTypeUseSpecialized]
  16. from functools import singledispatch
  17. class A: pass
  18. class B(A): pass
  19. @singledispatch
  20. def fun(arg) -> bool:
  21. return False
  22. @fun.register
  23. def fun_specialized(arg: A) -> bool:
  24. return True
  25. def test_specialize() -> None:
  26. assert fun(B())
  27. assert fun(A())
  28. [case testSuperclassImplementationNotUsedWhenSubclassHasImplementation]
  29. from functools import singledispatch
  30. class A: pass
  31. class B(A): pass
  32. @singledispatch
  33. def fun(arg) -> bool:
  34. # shouldn't be using this
  35. assert False
  36. @fun.register
  37. def fun_specialized(arg: A) -> bool:
  38. return False
  39. @fun.register
  40. def fun_specialized2(arg: B) -> bool:
  41. return True
  42. def test_specialize() -> None:
  43. assert fun(B())
  44. assert not fun(A())
  45. [case testMultipleUnderscoreFunctionsIsntError]
  46. from functools import singledispatch
  47. @singledispatch
  48. def fun(arg) -> str:
  49. return 'default'
  50. @fun.register
  51. def _(arg: str) -> str:
  52. return 'str'
  53. @fun.register
  54. def _(arg: int) -> str:
  55. return 'int'
  56. # extra function to make sure all 3 underscore functions aren't treated as one OverloadedFuncDef
  57. def a(b): pass
  58. @fun.register
  59. def _(arg: list) -> str:
  60. return 'list'
  61. def test_singledispatch() -> None:
  62. assert fun(0) == 'int'
  63. assert fun('a') == 'str'
  64. assert fun([1, 2]) == 'list'
  65. assert fun({'a': 'b'}) == 'default'
  66. [case testCanRegisterCompiledClasses]
  67. from functools import singledispatch
  68. class A: pass
  69. @singledispatch
  70. def fun(arg) -> bool:
  71. return False
  72. @fun.register
  73. def fun_specialized(arg: A) -> bool:
  74. return True
  75. def test_singledispatch() -> None:
  76. assert fun(A())
  77. assert not fun(1)
  78. [case testTypeUsedAsArgumentToRegister]
  79. from functools import singledispatch
  80. @singledispatch
  81. def fun(arg) -> bool:
  82. return False
  83. @fun.register(int)
  84. def fun_specialized(arg) -> bool:
  85. return True
  86. def test_singledispatch() -> None:
  87. assert fun(1)
  88. assert not fun('a')
  89. [case testUseRegisterAsAFunction]
  90. from functools import singledispatch
  91. @singledispatch
  92. def fun(arg) -> bool:
  93. return False
  94. def fun_specialized_impl(arg) -> bool:
  95. return True
  96. fun.register(int, fun_specialized_impl)
  97. def test_singledispatch() -> None:
  98. assert fun(0)
  99. assert not fun('a')
  100. [case testRegisterDoesntChangeFunction]
  101. from functools import singledispatch
  102. @singledispatch
  103. def fun(arg) -> bool:
  104. return False
  105. @fun.register(int)
  106. def fun_specialized(arg) -> bool:
  107. return True
  108. def test_singledispatch() -> None:
  109. assert fun_specialized('a')
  110. # TODO: turn this into a mypy error
  111. [case testNoneIsntATypeWhenUsedAsArgumentToRegister]
  112. from functools import singledispatch
  113. @singledispatch
  114. def fun(arg) -> bool:
  115. return False
  116. try:
  117. @fun.register
  118. def fun_specialized(arg: None) -> bool:
  119. return True
  120. except TypeError:
  121. pass
  122. [case testRegisteringTheSameFunctionSeveralTimes]
  123. from functools import singledispatch
  124. @singledispatch
  125. def fun(arg) -> bool:
  126. return False
  127. @fun.register(int)
  128. @fun.register(str)
  129. def fun_specialized(arg) -> bool:
  130. return True
  131. def test_singledispatch() -> None:
  132. assert fun(0)
  133. assert fun('a')
  134. assert not fun([1, 2])
  135. [case testTypeIsAnABC]
  136. from functools import singledispatch
  137. from collections.abc import Mapping
  138. @singledispatch
  139. def fun(arg) -> bool:
  140. return False
  141. @fun.register
  142. def fun_specialized(arg: Mapping) -> bool:
  143. return True
  144. def test_singledispatch() -> None:
  145. assert not fun(1)
  146. assert fun({'a': 'b'})
  147. [case testSingleDispatchMethod-xfail]
  148. from functools import singledispatchmethod
  149. class A:
  150. @singledispatchmethod
  151. def fun(self, arg) -> str:
  152. return 'default'
  153. @fun.register
  154. def fun_int(self, arg: int) -> str:
  155. return 'int'
  156. @fun.register
  157. def fun_str(self, arg: str) -> str:
  158. return 'str'
  159. def test_singledispatchmethod() -> None:
  160. x = A()
  161. assert x.fun(5) == 'int'
  162. assert x.fun('a') == 'str'
  163. assert x.fun([1, 2]) == 'default'
  164. [case testSingleDispatchMethodWithOtherDecorator-xfail]
  165. from functools import singledispatchmethod
  166. class A:
  167. @singledispatchmethod
  168. @staticmethod
  169. def fun(arg) -> str:
  170. return 'default'
  171. @fun.register
  172. @staticmethod
  173. def fun_int(arg: int) -> str:
  174. return 'int'
  175. @fun.register
  176. @staticmethod
  177. def fun_str(arg: str) -> str:
  178. return 'str'
  179. def test_singledispatchmethod() -> None:
  180. x = A()
  181. assert x.fun(5) == 'int'
  182. assert x.fun('a') == 'str'
  183. assert x.fun([1, 2]) == 'default'
  184. [case testSingledispatchTreeSumAndEqual]
  185. from functools import singledispatch
  186. class Tree:
  187. pass
  188. class Leaf(Tree):
  189. pass
  190. class Node(Tree):
  191. def __init__(self, value: int, left: Tree, right: Tree) -> None:
  192. self.value = value
  193. self.left = left
  194. self.right = right
  195. @singledispatch
  196. def calc_sum(x: Tree) -> int:
  197. raise TypeError('invalid type for x')
  198. @calc_sum.register
  199. def _(x: Leaf) -> int:
  200. return 0
  201. @calc_sum.register
  202. def _(x: Node) -> int:
  203. return x.value + calc_sum(x.left) + calc_sum(x.right)
  204. @singledispatch
  205. def equal(to_compare: Tree, known: Tree) -> bool:
  206. raise TypeError('invalid type for x')
  207. @equal.register
  208. def _(to_compare: Leaf, known: Tree) -> bool:
  209. return isinstance(known, Leaf)
  210. @equal.register
  211. def _(to_compare: Node, known: Tree) -> bool:
  212. if isinstance(known, Node):
  213. if to_compare.value != known.value:
  214. return False
  215. else:
  216. return equal(to_compare.left, known.left) and equal(to_compare.right, known.right)
  217. return False
  218. def build(n: int) -> Tree:
  219. if n == 0:
  220. return Leaf()
  221. return Node(n, build(n - 1), build(n - 1))
  222. def test_sum_and_equal():
  223. tree = build(5)
  224. tree2 = build(5)
  225. tree2.right.right.right.value = 10
  226. assert calc_sum(tree) == 57
  227. assert calc_sum(tree2) == 65
  228. assert equal(tree, tree)
  229. assert not equal(tree, tree2)
  230. tree3 = build(4)
  231. assert not equal(tree, tree3)
  232. [case testSimulateMypySingledispatch]
  233. from functools import singledispatch
  234. from mypy_extensions import trait
  235. from typing import Iterator, Union, TypeVar, Any, List, Type
  236. # based on use of singledispatch in stubtest.py
  237. class Error:
  238. def __init__(self, msg: str) -> None:
  239. self.msg = msg
  240. @trait
  241. class Node: pass
  242. class MypyFile(Node): pass
  243. class TypeInfo(Node): pass
  244. @trait
  245. class SymbolNode(Node): pass
  246. @trait
  247. class Expression(Node): pass
  248. class TypeVarLikeExpr(SymbolNode, Expression): pass
  249. class TypeVarExpr(TypeVarLikeExpr): pass
  250. class TypeAlias(SymbolNode): pass
  251. class Missing: pass
  252. MISSING = Missing()
  253. T = TypeVar("T")
  254. MaybeMissing = Union[T, Missing]
  255. @singledispatch
  256. def verify(stub: Node, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
  257. yield Error('unknown node type')
  258. @verify.register(MypyFile)
  259. def verify_mypyfile(stub: MypyFile, a: MaybeMissing[int], b: List[str]) -> Iterator[Error]:
  260. if isinstance(a, Missing):
  261. yield Error("shouldn't be missing")
  262. return
  263. if not isinstance(a, int):
  264. # this check should be unnecessary because of the type signature and the previous check,
  265. # but stubtest.py has this check
  266. yield Error("should be an int")
  267. return
  268. yield from verify(TypeInfo(), str, ['abc', 'def'])
  269. @verify.register(TypeInfo)
  270. def verify_typeinfo(stub: TypeInfo, a: MaybeMissing[Type[Any]], b: List[str]) -> Iterator[Error]:
  271. yield Error('in TypeInfo')
  272. yield Error('hello')
  273. @verify.register(TypeVarExpr)
  274. def verify_typevarexpr(stub: TypeVarExpr, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
  275. if False:
  276. yield None
  277. def verify_list(stub, a, b) -> List[str]:
  278. """Helper function that converts iterator of errors to list of messages"""
  279. return list(err.msg for err in verify(stub, a, b))
  280. def test_verify() -> None:
  281. assert verify_list(TypeAlias(), 'a', ['a', 'b']) == ['unknown node type']
  282. assert verify_list(MypyFile(), MISSING, ['a', 'b']) == ["shouldn't be missing"]
  283. assert verify_list(MypyFile(), 5, ['a', 'b']) == ['in TypeInfo', 'hello']
  284. assert verify_list(TypeInfo(), str, ['a', 'b']) == ['in TypeInfo', 'hello']
  285. assert verify_list(TypeVarExpr(), 'a', ['x', 'y']) == []
  286. [case testArgsInRegisteredImplNamedDifferentlyFromMainFunction]
  287. from functools import singledispatch
  288. @singledispatch
  289. def f(a) -> bool:
  290. return False
  291. @f.register
  292. def g(b: int) -> bool:
  293. return True
  294. def test_singledispatch():
  295. assert f(5)
  296. assert not f('a')
  297. [case testKeywordArguments]
  298. from functools import singledispatch
  299. @singledispatch
  300. def f(arg, *, kwarg: int = 0) -> int:
  301. return kwarg + 10
  302. @f.register
  303. def g(arg: int, *, kwarg: int = 5) -> int:
  304. return kwarg - 10
  305. def test_keywords():
  306. assert f('a') == 10
  307. assert f('a', kwarg=3) == 13
  308. assert f('a', kwarg=7) == 17
  309. assert f(1) == -5
  310. assert f(1, kwarg=4) == -6
  311. assert f(1, kwarg=6) == -4
  312. [case testGeneratorAndMultipleTypesOfIterable]
  313. from functools import singledispatch
  314. from typing import *
  315. @singledispatch
  316. def f(arg: Any) -> Iterable[int]:
  317. yield 1
  318. @f.register
  319. def g(arg: str) -> Iterable[int]:
  320. return [0]
  321. def test_iterables():
  322. assert f(1) != [1]
  323. assert list(f(1)) == [1]
  324. assert f('a') == [0]
  325. [case testRegisterUsedAtSameTimeAsOtherDecorators]
  326. from functools import singledispatch
  327. from typing import TypeVar
  328. class A: pass
  329. class B: pass
  330. T = TypeVar('T')
  331. def decorator(f: T) -> T:
  332. return f
  333. @singledispatch
  334. def f(arg) -> int:
  335. return 0
  336. @f.register
  337. @decorator
  338. def h(arg: str) -> int:
  339. return 2
  340. def test_singledispatch():
  341. assert f(1) == 0
  342. assert f('a') == 2
  343. [case testDecoratorModifiesFunction]
  344. from functools import singledispatch
  345. from typing import Callable, Any
  346. class A: pass
  347. def decorator(f: Callable[[Any], int]) -> Callable[[Any], int]:
  348. def wrapper(x) -> int:
  349. return f(x) * 7
  350. return wrapper
  351. @singledispatch
  352. def f(arg) -> int:
  353. return 10
  354. @f.register
  355. @decorator
  356. def h(arg: str) -> int:
  357. return 5
  358. def test_singledispatch():
  359. assert f('a') == 35
  360. assert f(A()) == 10
  361. [case testMoreSpecificTypeBeforeLessSpecificType]
  362. from functools import singledispatch
  363. class A: pass
  364. class B(A): pass
  365. @singledispatch
  366. def f(arg) -> str:
  367. return 'default'
  368. @f.register
  369. def g(arg: B) -> str:
  370. return 'b'
  371. @f.register
  372. def h(arg: A) -> str:
  373. return 'a'
  374. def test_singledispatch():
  375. assert f(B()) == 'b'
  376. assert f(A()) == 'a'
  377. assert f(5) == 'default'
  378. [case testMultipleRelatedClassesBeingRegistered]
  379. from functools import singledispatch
  380. class A: pass
  381. class B(A): pass
  382. class C(B): pass
  383. @singledispatch
  384. def f(arg) -> str: return 'default'
  385. @f.register
  386. def _(arg: A) -> str: return 'a'
  387. @f.register
  388. def _(arg: C) -> str: return 'c'
  389. @f.register
  390. def _(arg: B) -> str: return 'b'
  391. def test_singledispatch():
  392. assert f(A()) == 'a'
  393. assert f(B()) == 'b'
  394. assert f(C()) == 'c'
  395. assert f(1) == 'default'
  396. [case testRegisteredImplementationsInDifferentFiles]
  397. from other_a import f, A, B, C
  398. @f.register
  399. def a(arg: A) -> int:
  400. return 2
  401. @f.register
  402. def _(arg: C) -> int:
  403. return 3
  404. def test_singledispatch():
  405. assert f(B()) == 1
  406. assert f(A()) == 2
  407. assert f(C()) == 3
  408. assert f(1) == 0
  409. [file other_a.py]
  410. from functools import singledispatch
  411. class A: pass
  412. class B(A): pass
  413. class C(B): pass
  414. @singledispatch
  415. def f(arg) -> int:
  416. return 0
  417. @f.register
  418. def g(arg: B) -> int:
  419. return 1
  420. [case testOrderCanOnlyBeDeterminedFromMRONotIsinstanceChecks]
  421. from mypy_extensions import trait
  422. from functools import singledispatch
  423. @trait
  424. class A: pass
  425. @trait
  426. class B: pass
  427. class AB(A, B): pass
  428. class BA(B, A): pass
  429. @singledispatch
  430. def f(arg) -> str:
  431. return "default"
  432. pass
  433. @f.register
  434. def fa(arg: A) -> str:
  435. return "a"
  436. @f.register
  437. def fb(arg: B) -> str:
  438. return "b"
  439. def test_singledispatch():
  440. assert f(AB()) == "a"
  441. assert f(BA()) == "b"
  442. [case testCallingFunctionBeforeAllImplementationsRegistered]
  443. from functools import singledispatch
  444. class A: pass
  445. class B(A): pass
  446. @singledispatch
  447. def f(arg) -> str:
  448. return 'default'
  449. assert f(A()) == 'default'
  450. assert f(B()) == 'default'
  451. assert f(1) == 'default'
  452. @f.register
  453. def g(arg: A) -> str:
  454. return 'a'
  455. assert f(A()) == 'a'
  456. assert f(B()) == 'a'
  457. assert f(1) == 'default'
  458. @f.register
  459. def _(arg: B) -> str:
  460. return 'b'
  461. assert f(A()) == 'a'
  462. assert f(B()) == 'b'
  463. assert f(1) == 'default'
  464. [case testDynamicallyRegisteringFunctionFromInterpretedCode]
  465. from functools import singledispatch
  466. class A: pass
  467. class B(A): pass
  468. class C(B): pass
  469. class D(C): pass
  470. @singledispatch
  471. def f(arg) -> str:
  472. return "default"
  473. @f.register
  474. def _(arg: B) -> str:
  475. return 'b'
  476. [file register_impl.py]
  477. from native import f, A, B, C
  478. @f.register(A)
  479. def a(arg) -> str:
  480. return 'a'
  481. @f.register
  482. def c(arg: C) -> str:
  483. return 'c'
  484. [file driver.py]
  485. from native import f, A, B, C
  486. from register_impl import a, c
  487. # We need a custom driver here because register_impl has to be run before we test this (so that the
  488. # additional implementations are registered)
  489. assert f(C()) == 'c'
  490. assert f(A()) == 'a'
  491. assert f(B()) == 'b'
  492. assert a(C()) == 'a'
  493. assert c(A()) == 'c'
  494. [case testMalformedDynamicRegisterCall]
  495. from functools import singledispatch
  496. @singledispatch
  497. def f(arg) -> None:
  498. pass
  499. [file register.py]
  500. from native import f
  501. from testutil import assertRaises
  502. with assertRaises(TypeError, 'Invalid first argument to `register()`'):
  503. @f.register
  504. def _():
  505. pass
  506. [file driver.py]
  507. import register
  508. [case testCacheClearedWhenNewFunctionRegistered]
  509. from functools import singledispatch
  510. @singledispatch
  511. def f(arg) -> str:
  512. return 'default'
  513. [file register.py]
  514. from native import f
  515. class A: pass
  516. class B: pass
  517. class C: pass
  518. # annotated function
  519. assert f(A()) == 'default'
  520. @f.register
  521. def _(arg: A) -> str:
  522. return 'a'
  523. assert f(A()) == 'a'
  524. # type passed as argument
  525. assert f(B()) == 'default'
  526. @f.register(B)
  527. def _(arg: B) -> str:
  528. return 'b'
  529. assert f(B()) == 'b'
  530. # 2 argument form
  531. assert f(C()) == 'default'
  532. def c(arg) -> str:
  533. return 'c'
  534. f.register(C, c)
  535. assert f(C()) == 'c'
  536. [file driver.py]
  537. import register