# test_emitwrapper.py: unit tests for mypyc.codegen.emitwrapper.generate_arg_check
from __future__ import annotations

import unittest

from mypy.test.helpers import assert_string_arrays_equal
from mypyc.codegen.emit import Emitter, EmitterContext, ReturnHandler
from mypyc.codegen.emitwrapper import generate_arg_check
from mypyc.ir.rtypes import int_rprimitive, list_rprimitive
from mypyc.namegen import NameGenerator
  8. class TestArgCheck(unittest.TestCase):
  9. def setUp(self) -> None:
  10. self.context = EmitterContext(NameGenerator([["mod"]]))
  11. def test_check_list(self) -> None:
  12. emitter = Emitter(self.context)
  13. generate_arg_check("x", list_rprimitive, emitter, ReturnHandler("NULL"))
  14. lines = emitter.fragments
  15. self.assert_lines(
  16. [
  17. "PyObject *arg_x;",
  18. "if (likely(PyList_Check(obj_x)))",
  19. " arg_x = obj_x;",
  20. "else {",
  21. ' CPy_TypeError("list", obj_x);',
  22. " return NULL;",
  23. "}",
  24. ],
  25. lines,
  26. )
  27. def test_check_int(self) -> None:
  28. emitter = Emitter(self.context)
  29. generate_arg_check("x", int_rprimitive, emitter, ReturnHandler("NULL"))
  30. generate_arg_check("y", int_rprimitive, emitter, ReturnHandler("NULL"), optional=True)
  31. lines = emitter.fragments
  32. self.assert_lines(
  33. [
  34. "CPyTagged arg_x;",
  35. "if (likely(PyLong_Check(obj_x)))",
  36. " arg_x = CPyTagged_BorrowFromObject(obj_x);",
  37. "else {",
  38. ' CPy_TypeError("int", obj_x); return NULL;',
  39. "}",
  40. "CPyTagged arg_y;",
  41. "if (obj_y == NULL) {",
  42. " arg_y = CPY_INT_TAG;",
  43. "} else if (likely(PyLong_Check(obj_y)))",
  44. " arg_y = CPyTagged_BorrowFromObject(obj_y);",
  45. "else {",
  46. ' CPy_TypeError("int", obj_y); return NULL;',
  47. "}",
  48. ],
  49. lines,
  50. )
  51. def assert_lines(self, expected: list[str], actual: list[str]) -> None:
  52. actual = [line.rstrip("\n") for line in actual]
  53. assert_string_arrays_equal(expected, actual, "Invalid output")