| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- from __future__ import annotations
- import unittest
- from mypy.test.helpers import assert_string_arrays_equal
- from mypyc.codegen.emit import Emitter, EmitterContext, ReturnHandler
- from mypyc.codegen.emitwrapper import generate_arg_check
- from mypyc.ir.rtypes import int_rprimitive, list_rprimitive
- from mypyc.namegen import NameGenerator
class TestArgCheck(unittest.TestCase):
    """Tests for the fragments ``generate_arg_check`` adds to an ``Emitter``."""

    def setUp(self) -> None:
        """Build the ``EmitterContext`` that each test hands to its ``Emitter``."""
        names = NameGenerator([["mod"]])
        self.context = EmitterContext(names)

    def test_check_list(self) -> None:
        """Compare the output for a ``list_rprimitive`` argument named ``x``."""
        expected = [
            "PyObject *arg_x;",
            "if (likely(PyList_Check(obj_x)))",
            " arg_x = obj_x;",
            "else {",
            ' CPy_TypeError("list", obj_x);',
            " return NULL;",
            "}",
        ]
        out = Emitter(self.context)
        generate_arg_check("x", list_rprimitive, out, ReturnHandler("NULL"))
        self.assert_lines(expected, out.fragments)

    def test_check_int(self) -> None:
        """Compare the output for two ``int_rprimitive`` arguments.

        ``x`` is checked with the default settings and ``y`` is checked with
        ``optional=True``; both calls write into the same emitter, so the
        expected lines cover ``x`` first and ``y`` second.
        """
        expected_x = [
            "CPyTagged arg_x;",
            "if (likely(PyLong_Check(obj_x)))",
            " arg_x = CPyTagged_BorrowFromObject(obj_x);",
            "else {",
            ' CPy_TypeError("int", obj_x); return NULL;',
            "}",
        ]
        expected_y = [
            "CPyTagged arg_y;",
            "if (obj_y == NULL) {",
            " arg_y = CPY_INT_TAG;",
            "} else if (likely(PyLong_Check(obj_y)))",
            " arg_y = CPyTagged_BorrowFromObject(obj_y);",
            "else {",
            ' CPy_TypeError("int", obj_y); return NULL;',
            "}",
        ]
        out = Emitter(self.context)
        generate_arg_check("x", int_rprimitive, out, ReturnHandler("NULL"))
        generate_arg_check("y", int_rprimitive, out, ReturnHandler("NULL"), optional=True)
        self.assert_lines(expected_x + expected_y, out.fragments)

    def assert_lines(self, expected: list[str], actual: list[str]) -> None:
        """Check ``actual`` against ``expected``, ignoring trailing newlines in ``actual``.

        Trailing ``"\\n"`` characters are stripped from every entry of
        ``actual`` (``expected`` is used as given) before the two lists are
        passed to ``assert_string_arrays_equal``.
        """
        stripped: list[str] = []
        for fragment in actual:
            stripped.append(fragment.rstrip("\n"))
        assert_string_arrays_equal(expected, stripped, "Invalid output")
|