test_refcount.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. """Test runner for reference count opcode insertion transform test cases.
  2. The transform inserts needed reference count increment/decrement
  3. operations to IR.
  4. """
  5. from __future__ import annotations
  6. import os.path
  7. from mypy.errors import CompileError
  8. from mypy.test.config import test_temp_dir
  9. from mypy.test.data import DataDrivenTestCase
  10. from mypyc.common import TOP_LEVEL_NAME
  11. from mypyc.ir.pprint import format_func
  12. from mypyc.test.testutil import (
  13. ICODE_GEN_BUILTINS,
  14. MypycDataSuite,
  15. assert_test_output,
  16. build_ir_for_single_file,
  17. infer_ir_build_options_from_test_name,
  18. remove_comment_lines,
  19. replace_word_size,
  20. use_custom_builtins,
  21. )
  22. from mypyc.transform.refcount import insert_ref_count_opcodes
  23. from mypyc.transform.uninit import insert_uninit_checks
  24. files = ["refcount.test"]
  25. class TestRefCountTransform(MypycDataSuite):
  26. files = files
  27. base_path = test_temp_dir
  28. optional_out = True
  29. def run_case(self, testcase: DataDrivenTestCase) -> None:
  30. """Perform a runtime checking transformation test case."""
  31. options = infer_ir_build_options_from_test_name(testcase.name)
  32. if options is None:
  33. # Skipped test case
  34. return
  35. with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase):
  36. expected_output = remove_comment_lines(testcase.output)
  37. expected_output = replace_word_size(expected_output)
  38. try:
  39. ir = build_ir_for_single_file(testcase.input, options)
  40. except CompileError as e:
  41. actual = e.messages
  42. else:
  43. actual = []
  44. for fn in ir:
  45. if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"):
  46. continue
  47. insert_uninit_checks(fn)
  48. insert_ref_count_opcodes(fn)
  49. actual.extend(format_func(fn))
  50. assert_test_output(testcase, actual, "Invalid source code output", expected_output)