testutil.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. """Helpers for writing tests"""
  2. from __future__ import annotations
  3. import contextlib
  4. import os
  5. import os.path
  6. import re
  7. import shutil
  8. from typing import Callable, Iterator
  9. from mypy import build
  10. from mypy.errors import CompileError
  11. from mypy.options import Options
  12. from mypy.test.config import test_temp_dir
  13. from mypy.test.data import DataDrivenTestCase, DataSuite
  14. from mypy.test.helpers import assert_string_arrays_equal
  15. from mypyc.analysis.ircheck import assert_func_ir_valid
  16. from mypyc.common import IS_32_BIT_PLATFORM, PLATFORM_SIZE
  17. from mypyc.errors import Errors
  18. from mypyc.ir.func_ir import FuncIR
  19. from mypyc.ir.module_ir import ModuleIR
  20. from mypyc.irbuild.main import build_ir
  21. from mypyc.irbuild.mapper import Mapper
  22. from mypyc.options import CompilerOptions
  23. from mypyc.test.config import test_data_prefix
  24. # The builtins stub used during icode generation test cases.
  25. ICODE_GEN_BUILTINS = os.path.join(test_data_prefix, "fixtures/ir.py")
  26. # The testutil support library
  27. TESTUTIL_PATH = os.path.join(test_data_prefix, "fixtures/testutil.py")
  28. class MypycDataSuite(DataSuite):
  29. # Need to list no files, since this will be picked up as a suite of tests
  30. files: list[str] = []
  31. data_prefix = test_data_prefix
  32. def builtins_wrapper(
  33. func: Callable[[DataDrivenTestCase], None], path: str
  34. ) -> Callable[[DataDrivenTestCase], None]:
  35. """Decorate a function that implements a data-driven test case to copy an
  36. alternative builtins module implementation in place before performing the
  37. test case. Clean up after executing the test case.
  38. """
  39. return lambda testcase: perform_test(func, path, testcase)
  40. @contextlib.contextmanager
  41. def use_custom_builtins(builtins_path: str, testcase: DataDrivenTestCase) -> Iterator[None]:
  42. for path, _ in testcase.files:
  43. if os.path.basename(path) == "builtins.pyi":
  44. default_builtins = False
  45. break
  46. else:
  47. # Use default builtins.
  48. builtins = os.path.abspath(os.path.join(test_temp_dir, "builtins.pyi"))
  49. shutil.copyfile(builtins_path, builtins)
  50. default_builtins = True
  51. # Actually perform the test case.
  52. try:
  53. yield None
  54. finally:
  55. if default_builtins:
  56. # Clean up.
  57. os.remove(builtins)
  58. def perform_test(
  59. func: Callable[[DataDrivenTestCase], None], builtins_path: str, testcase: DataDrivenTestCase
  60. ) -> None:
  61. for path, _ in testcase.files:
  62. if os.path.basename(path) == "builtins.py":
  63. default_builtins = False
  64. break
  65. else:
  66. # Use default builtins.
  67. builtins = os.path.join(test_temp_dir, "builtins.py")
  68. shutil.copyfile(builtins_path, builtins)
  69. default_builtins = True
  70. # Actually perform the test case.
  71. func(testcase)
  72. if default_builtins:
  73. # Clean up.
  74. os.remove(builtins)
  75. def build_ir_for_single_file(
  76. input_lines: list[str], compiler_options: CompilerOptions | None = None
  77. ) -> list[FuncIR]:
  78. return build_ir_for_single_file2(input_lines, compiler_options).functions
  79. def build_ir_for_single_file2(
  80. input_lines: list[str], compiler_options: CompilerOptions | None = None
  81. ) -> ModuleIR:
  82. program_text = "\n".join(input_lines)
  83. # By default generate IR compatible with the earliest supported Python C API.
  84. # If a test needs more recent API features, this should be overridden.
  85. compiler_options = compiler_options or CompilerOptions(capi_version=(3, 5))
  86. options = Options()
  87. options.show_traceback = True
  88. options.hide_error_codes = True
  89. options.use_builtins_fixtures = True
  90. options.strict_optional = True
  91. options.python_version = compiler_options.python_version or (3, 6)
  92. options.export_types = True
  93. options.preserve_asts = True
  94. options.allow_empty_bodies = True
  95. options.per_module_options["__main__"] = {"mypyc": True}
  96. source = build.BuildSource("main", "__main__", program_text)
  97. # Construct input as a single single.
  98. # Parse and type check the input program.
  99. result = build.build(sources=[source], options=options, alt_lib_path=test_temp_dir)
  100. if result.errors:
  101. raise CompileError(result.errors)
  102. errors = Errors(options)
  103. modules = build_ir(
  104. [result.files["__main__"]],
  105. result.graph,
  106. result.types,
  107. Mapper({"__main__": None}),
  108. compiler_options,
  109. errors,
  110. )
  111. if errors.num_errors:
  112. raise CompileError(errors.new_messages())
  113. module = list(modules.values())[0]
  114. for fn in module.functions:
  115. assert_func_ir_valid(fn)
  116. return module
  117. def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None:
  118. # TODO: backport this to mypy
  119. assert testcase.old_cwd is not None, "test was not properly set up"
  120. testcase_path = os.path.join(testcase.old_cwd, testcase.file)
  121. with open(testcase_path) as f:
  122. data_lines = f.read().splitlines()
  123. # We can't rely on the test line numbers to *find* the test, since
  124. # we might fix multiple tests in a run. So find it by the case
  125. # header. Give up if there are multiple tests with the same name.
  126. test_slug = f"[case {testcase.name}]"
  127. if data_lines.count(test_slug) != 1:
  128. return
  129. start_idx = data_lines.index(test_slug)
  130. stop_idx = start_idx + 11
  131. while stop_idx < len(data_lines) and not data_lines[stop_idx].startswith("[case "):
  132. stop_idx += 1
  133. test = data_lines[start_idx:stop_idx]
  134. out_start = test.index("[out]")
  135. test[out_start + 1 :] = output
  136. data_lines[start_idx:stop_idx] = test + [""]
  137. data = "\n".join(data_lines)
  138. with open(testcase_path, "w") as f:
  139. print(data, file=f)
  140. def assert_test_output(
  141. testcase: DataDrivenTestCase,
  142. actual: list[str],
  143. message: str,
  144. expected: list[str] | None = None,
  145. formatted: list[str] | None = None,
  146. ) -> None:
  147. __tracebackhide__ = True
  148. expected_output = expected if expected is not None else testcase.output
  149. if expected_output != actual and testcase.config.getoption("--update-data", False):
  150. update_testcase_output(testcase, actual)
  151. assert_string_arrays_equal(
  152. expected_output, actual, f"{message} ({testcase.file}, line {testcase.line})"
  153. )
  154. def get_func_names(expected: list[str]) -> list[str]:
  155. res = []
  156. for s in expected:
  157. m = re.match(r"def ([_a-zA-Z0-9.*$]+)\(", s)
  158. if m:
  159. res.append(m.group(1))
  160. return res
  161. def remove_comment_lines(a: list[str]) -> list[str]:
  162. """Return a copy of array with comments removed.
  163. Lines starting with '--' (but not with '---') are removed.
  164. """
  165. r = []
  166. for s in a:
  167. if s.strip().startswith("--") and not s.strip().startswith("---"):
  168. pass
  169. else:
  170. r.append(s)
  171. return r
  172. def print_with_line_numbers(s: str) -> None:
  173. lines = s.splitlines()
  174. for i, line in enumerate(lines):
  175. print("%-4d %s" % (i + 1, line))
  176. def heading(text: str) -> None:
  177. print("=" * 20 + " " + text + " " + "=" * 20)
  178. def show_c(cfiles: list[list[tuple[str, str]]]) -> None:
  179. heading("Generated C")
  180. for group in cfiles:
  181. for cfile, ctext in group:
  182. print(f"== {cfile} ==")
  183. print_with_line_numbers(ctext)
  184. heading("End C")
  185. def fudge_dir_mtimes(dir: str, delta: int) -> None:
  186. for dirpath, _, filenames in os.walk(dir):
  187. for name in filenames:
  188. path = os.path.join(dirpath, name)
  189. new_mtime = os.stat(path).st_mtime + delta
  190. os.utime(path, times=(new_mtime, new_mtime))
  191. def replace_word_size(text: list[str]) -> list[str]:
  192. """Replace WORDSIZE with platform specific word sizes"""
  193. result = []
  194. for line in text:
  195. index = line.find("WORD_SIZE")
  196. if index != -1:
  197. # get 'WORDSIZE*n' token
  198. word_size_token = line[index:].split()[0]
  199. n = int(word_size_token[10:])
  200. replace_str = str(PLATFORM_SIZE * n)
  201. result.append(line.replace(word_size_token, replace_str))
  202. else:
  203. result.append(line)
  204. return result
  205. def infer_ir_build_options_from_test_name(name: str) -> CompilerOptions | None:
  206. """Look for magic substrings in test case name to set compiler options.
  207. Return None if the test case should be skipped (always pass).
  208. Supported naming conventions:
  209. *_64bit*:
  210. Run test case only on 64-bit platforms
  211. *_32bit*:
  212. Run test caseonly on 32-bit platforms
  213. *_python3_8* (or for any Python version):
  214. Use Python 3.8+ C API features (default: lowest supported version)
  215. *StripAssert*:
  216. Don't generate code for assert statements
  217. """
  218. # If this is specific to some bit width, always pass if platform doesn't match.
  219. if "_64bit" in name and IS_32_BIT_PLATFORM:
  220. return None
  221. if "_32bit" in name and not IS_32_BIT_PLATFORM:
  222. return None
  223. options = CompilerOptions(strip_asserts="StripAssert" in name, capi_version=(3, 5))
  224. # A suffix like _python3.8 is used to set the target C API version.
  225. m = re.search(r"_python([3-9]+)_([0-9]+)(_|\b)", name)
  226. if m:
  227. options.capi_version = (int(m.group(1)), int(m.group(2)))
  228. options.python_version = options.capi_version
  229. elif "_py" in name or "_Python" in name:
  230. assert False, f"Invalid _py* suffix (should be _pythonX_Y): {name}"
  231. return options