| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- from __future__ import annotations
- import contextlib
- import os
- import pathlib
- import re
- import shutil
- import sys
- import time
- from typing import Any, Callable, Iterable, Iterator, Pattern
- # Exporting Suite as alias to TestCase for backwards compatibility
- # TODO: avoid aliasing - import and subclass TestCase directly
- from unittest import TestCase
- Suite = TestCase # re-exporting
- import pytest
- import mypy.api as api
- import mypy.version
- from mypy import defaults
- from mypy.main import process_options
- from mypy.options import Options
- from mypy.test.config import test_data_prefix, test_temp_dir
- from mypy.test.data import DataDrivenTestCase, DeleteFile, UpdateFile, fix_cobertura_filename
# Shorthand for pytest's skip marker, exposed at module level.
skip = pytest.mark.skip
# assert_string_arrays_equal() shows the line-alignment helper
# (show_align_message) only if the first differing line, in either the
# expected or the actual output, has at least this many characters.
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5
def run_mypy(args: list[str]) -> None:
    """Run mypy in-process on ``args`` and fail the current test on a non-zero exit.

    On failure, mypy's captured stdout/stderr are echoed to the real streams
    before the test is failed with "Sample check failed".
    """
    __tracebackhide__ = True
    # We must enable site packages even though they could cause problems,
    # since stubs for typing_extensions live there.
    extra_flags = ["--show-traceback", "--no-silence-site-packages"]
    stdout_text, stderr_text, exit_status = api.run(args + extra_flags)
    if exit_status == 0:
        return
    sys.stdout.write(stdout_text)
    sys.stderr.write(stderr_text)
    pytest.fail(msg="Sample check failed", pytrace=False)
def assert_string_arrays_equal(expected: list[str], actual: list[str], msg: str) -> None:
    """Assert that two string arrays are equal.

    We consider "can't" and "cannot" equivalent, by replacing the
    former with the latter before comparing. ``actual`` is first passed
    through clean_up().

    On mismatch, both arrays are written to stderr (runs of identical
    leading/trailing lines are elided as "...", differing lines are tagged
    "(diff)", and an empty side is shown as "(empty)"). An alignment hint for
    the first differing line follows when that line is long enough, and then
    the test fails with ``msg``.
    """
    actual = clean_up(actual)
    actual = [line.replace("can't", "cannot") for line in actual]
    expected = [line.replace("can't", "cannot") for line in expected]
    if actual == expected:
        return

    num_skip_start = num_skipped_prefix_lines(expected, actual)
    num_skip_end = num_skipped_suffix_lines(expected, actual)

    sys.stderr.write("Expected:\n")
    # Only a difference found on the expected side can be aligned below,
    # because alignment needs a line from both arrays.
    first_diff = _write_diff_side(expected, actual, num_skip_start, num_skip_end)
    sys.stderr.write("Actual:\n")
    _write_diff_side(actual, expected, num_skip_start, num_skip_end)
    sys.stderr.write("\n")

    if 0 <= first_diff < len(actual) and (
        len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
        or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
    ):
        # Display message that helps visualize the differences between two
        # long lines.
        show_align_message(expected[first_diff], actual[first_diff])

    pytest.fail(msg, pytrace=False)


def _write_diff_side(lines: list[str], other: list[str], skip_start: int, skip_end: int) -> int:
    """Write one side of a diff to stderr and return its first differing index (-1 if none).

    Lines before ``skip_start`` and the last ``skip_end`` lines are replaced by a
    "..." marker. A line is tagged "(diff)" when ``other`` has no line at that
    index or holds a different line there.
    """
    # Display only this many first characters of identical lines.
    width = 75
    first_diff = -1
    if skip_start > 0:
        sys.stderr.write(" ...\n")
    for i in range(skip_start, len(lines) - skip_end):
        line = lines[i]
        if i >= len(other) or line != other[i]:
            if first_diff < 0:
                first_diff = i
            sys.stderr.write(f" {line:<45} (diff)")
        else:
            sys.stderr.write(" " + line[:width])
            if len(line) > width:
                sys.stderr.write("...")
        sys.stderr.write("\n")
    if not lines:
        sys.stderr.write(" (empty)\n")
    if skip_end > 0:
        sys.stderr.write(" ...\n")
    return first_diff
def assert_module_equivalence(name: str, expected: Iterable[str], actual: Iterable[str]) -> None:
    """Fail unless ``actual`` modules (minus ``__main__``) equal ``expected``, ignoring order.

    ``actual`` is de-duplicated; both sides are sorted before comparison.
    """
    wanted = sorted(expected)
    got = sorted(set(actual) - {"__main__"})
    got_text = ", ".join(got)
    wanted_text = ", ".join(wanted)
    message = (
        f"Actual modules ({got_text}) do not match expected modules ({wanted_text}) "
        f'for "[{name} ...]"'
    )
    assert_string_arrays_equal(wanted, got, message)
def assert_target_equivalence(name: str, expected: list[str], actual: list[str]) -> None:
    """Fail unless ``actual`` targets equal ``expected`` in the same order."""
    got_text = ", ".join(actual)
    wanted_text = ", ".join(expected)
    message = (
        f"Actual targets ({got_text}) do not match expected targets ({wanted_text}) "
        f'for "[{name} ...]"'
    )
    assert_string_arrays_equal(expected, actual, message)
def show_align_message(s1: str, s2: str) -> None:
    """Align s1 and s2 so that their first difference is highlighted.

    For example, if s1 is 'foobar' and s2 is 'fobar', this writes an "E:"
    line with foobar, an "A:" line with fobar, and a line with a '^' placed
    directly under the third character, where the strings first differ.

    If s1 and s2 are long, only display a fragment of the strings around the
    first difference. If s1 is very short, or the strings are identical, do
    nothing.
    """
    # Seeing what went wrong is trivial even without alignment if the expected
    # string is very short. In this case do nothing to simplify output.
    if len(s1) < 4:
        return
    # Identical strings have no difference to point at, and the prefix-trimming
    # loop below would never terminate for them.
    if s1 == s2:
        return

    maxw = 72  # Maximum number of characters shown
    sys.stderr.write("Alignment of first line difference:\n")

    # Drop 10-character chunks from the front while the first 30 characters
    # agree, so that the difference lands inside the displayed window.
    trunc = False
    while s1[:30] == s2[:30]:
        s1 = s1[10:]
        s2 = s2[10:]
        trunc = True
    if trunc:
        s1 = "..." + s1
        s2 = "..." + s2

    max_len = max(len(s1), len(s2))
    extra = "..." if max_len > maxw else ""

    # Write a chunk of both lines, aligned.
    expected_label = " E: "
    actual_label = " A: "
    sys.stderr.write(f"{expected_label}{s1[:maxw]}{extra}\n")
    sys.stderr.write(f"{actual_label}{s2[:maxw]}{extra}\n")

    # Write an indicator character under the first differing column. The
    # indent must be exactly as wide as the label for the caret to line up.
    sys.stderr.write(" " * len(expected_label))
    for j in range(min(maxw, max_len)):
        if s1[j : j + 1] != s2[j : j + 1]:
            sys.stderr.write("^")  # Difference
            break
        sys.stderr.write(" ")  # Equal
    sys.stderr.write("\n")
def clean_up(a: list[str]) -> list[str]:
    """Normalize lines of output before comparing them with expectations.

    For each string, strip a trailing carriage return, then trailing spaces,
    and replace the absolute path ``<cwd>/driver.py`` with ``driver.py``.
    Returns a new list; ``a`` is not modified.
    """
    driver = os.getcwd() + "/driver.py"
    res = []
    for s in a:
        # Strip the carriage return first, so that spaces sitting before it
        # (e.g. "foo  \r") are recognised as trailing and removed too.
        s = re.sub("\\r$", "", s)
        # Ignore spaces at end of line.
        s = re.sub(" +$", "", s)
        # Remove pwd from driver.py's path
        s = s.replace(driver, "driver.py")
        res.append(s)
    return res
@contextlib.contextmanager
def local_sys_path_set() -> Iterator[None]:
    """Temporarily put the current directory on sys.path.

    If neither "" nor "." is already on sys.path, "" is inserted at the front
    for the duration of the ``with`` block. sys.path is restored to a copy of
    its original contents on exit. This can be used by test cases that do
    runtime imports, for example by the stubgen tests.
    """
    saved_path = list(sys.path)
    cwd_present = any(entry in sys.path for entry in ("", "."))
    if not cwd_present:
        sys.path.insert(0, "")
    try:
        yield
    finally:
        sys.path = saved_path
def num_skipped_prefix_lines(a1: list[str], a2: list[str]) -> int:
    """Return how many shared leading lines of a1 and a2 can be elided.

    The last 4 lines of the common prefix are kept as context.
    """
    common = 0
    for left, right in zip(a1, a2):
        if left != right:
            break
        common += 1
    return max(0, common - 4)
def num_skipped_suffix_lines(a1: list[str], a2: list[str]) -> int:
    """Return how many shared trailing lines of a1 and a2 can be elided.

    The first 4 lines of the common suffix are kept as context. The common
    suffix is never allowed to overlap the common prefix. Without this limit,
    inputs such as ["a"] * 10 vs ["a"] * 11 would count the same lines as both
    prefix and suffix, and the diff display would elide the actual difference.
    """
    limit = min(len(a1), len(a2))
    prefix = 0
    while prefix < limit and a1[prefix] == a2[prefix]:
        prefix += 1
    # Only lines after the common prefix are eligible to be part of the suffix.
    limit -= prefix
    num_eq = 0
    while num_eq < limit and a1[-num_eq - 1] == a2[-num_eq - 1]:
        num_eq += 1
    return max(0, num_eq - 4)
- def testfile_pyversion(path: str) -> tuple[int, int]:
- if path.endswith("python311.test"):
- return 3, 11
- elif path.endswith("python310.test"):
- return 3, 10
- elif path.endswith("python39.test"):
- return 3, 9
- elif path.endswith("python38.test"):
- return 3, 8
- else:
- return defaults.PYTHON3_VERSION
def normalize_error_messages(messages: list[str]) -> list[str]:
    """Return a copy of messages with every os.sep replaced by '/'."""
    return [message.replace(os.sep, "/") for message in messages]
def retry_on_error(func: Callable[[], Any], max_wait: float = 1.0) -> None:
    """Call ``func``, retrying with exponential backoff while it raises OSError.

    The delay starts at 0.02s and doubles, capped by the time remaining until
    ``max_wait`` seconds after the first attempt. Once that remaining time is
    no more than 0.01s, the OSError is re-raised. This can be effective against
    random file system operation failures on Windows.
    """
    deadline = time.time() + max_wait
    delay = 0.01
    while True:
        try:
            func()
        except OSError:
            delay = min(delay * 2, deadline - time.time())
            if delay <= 0.01:
                # Done enough waiting, the error seems persistent.
                raise
            time.sleep(delay)
        else:
            return
def good_repr(obj: object) -> str:
    """Return repr(obj), except that strings with 2+ newlines become a triple-quoted block.

    The block starts with ``'''\\`` on its own line, followed by one escaped
    line per source line, and the final line ends with ``'''``.
    """
    if not isinstance(obj, str) or obj.count("\n") <= 1:
        return repr(obj)
    # Prefixing '"' forces repr to use single quotes; [2:-1] then cuts off
    # the opening quote, the added '"' and the closing quote.
    escaped_lines = [repr('"' + line)[2:-1] for line in obj.split("\n")]
    return "'''\\\n" + "\n".join(escaped_lines) + "'''"
def assert_equal(a: object, b: object, fmt: str = "{} != {}") -> None:
    """Raise AssertionError unless ``a == b``.

    The message is ``fmt`` formatted with good_repr(a) and good_repr(b).
    """
    __tracebackhide__ = True
    if a != b:
        message = fmt.format(good_repr(a), good_repr(b))
        raise AssertionError(message)
def typename(t: type) -> str:
    """Return a short class name for ``t``, derived from its ``str()`` form.

    For example "<class 'int'>" gives "int", and
    "<class 'collections.OrderedDict'>" gives "OrderedDict".
    """
    text = str(t)
    if "." not in text:
        return text[8:-2]
    return text.rpartition(".")[2].rstrip("'>")
def assert_type(typ: type, value: object) -> None:
    """Raise AssertionError unless ``type(value)`` is exactly ``typ``.

    Subclass instances are rejected.
    """
    __tracebackhide__ = True
    actual_type = type(value)
    if actual_type != typ:
        raise AssertionError(f"Invalid type {typename(actual_type)}, expected {typename(typ)}")
def parse_options(
    program_text: str, testcase: DataDrivenTestCase, incremental_step: int
) -> Options:
    """Build the mypy Options for a test case from comments like '# flags: --foo'.

    For incremental steps after the first, a '# flagsN: ...' comment (N being
    ``incremental_step``) replaces the base flags when present. Without any
    flags comment, a default Options() is tuned for the test suite. Unless a
    Python-version flag is given, python_version is taken from the test file
    name via testfile_pyversion().
    """
    pragma = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE)
    if incremental_step > 1:
        step_pragma = re.search(
            f"# flags{incremental_step}: (.*)$", program_text, flags=re.MULTILINE
        )
        if step_pragma:
            pragma = step_pragma

    if not pragma:
        flag_list: list[str] = []
        options = Options()
        # TODO: Enable strict optional in test cases by default (requires *many* test case changes)
        options.strict_optional = False
        options.error_summary = False
        options.hide_error_codes = True
        options.force_uppercase_builtins = True
        options.force_union_syntax = True
    else:
        flag_list = pragma.group(1).split()
        # The tests shouldn't need an installed Python.
        flag_list.append("--no-site-packages")
        targets, options = process_options(flag_list, require_targets=False)
        if targets:
            # TODO: support specifying targets via the flags pragma
            raise RuntimeError("Specifying targets via the flags pragma is not supported.")
        if "--show-error-codes" not in flag_list:
            options.hide_error_codes = True

    # Allow custom python version to override testfile_pyversion.
    version_flags = ("--python-version", "-2", "--py2")
    if not any(flag.split("=")[0] in version_flags for flag in flag_list):
        options.python_version = testfile_pyversion(testcase.file)

    verbosity = testcase.config.getoption("--mypy-verbose")
    if verbosity:
        options.verbosity = verbosity
    return options
def split_lines(*streams: bytes) -> list[str]:
    """Decode each byte stream as UTF-8 and return all of their lines, in order."""
    lines: list[str] = []
    for stream in streams:
        lines.extend(stream.decode("utf8").splitlines())
    return lines
def write_and_fudge_mtime(content: str, target_path: str) -> None:
    """Write ``content`` (UTF-8) to ``target_path`` and guarantee its mtime changes.

    Missing parent directories are created. If the file already existed, its
    new mtime is set to the old mtime + 1 second.
    """
    # In some systems, mtime has a resolution of 1 second which can
    # cause annoying-to-debug issues when a file has the same size
    # after a change. We manually set the mtime to circumvent this.
    # Note that we increment the old file's mtime, which guarantees a
    # different value, rather than incrementing the mtime after the
    # copy, which could leave the mtime unchanged if the old file had
    # a similarly fudged mtime.
    new_time = None
    if os.path.isfile(target_path):
        new_time = os.stat(target_path).st_mtime + 1

    parent_dir = os.path.dirname(target_path)
    # os.makedirs("") raises, so only create a parent when the path has one
    # (a bare file name refers to the current directory, which exists).
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(target_path, "w", encoding="utf-8") as target:
        target.write(content)

    if new_time is not None:
        os.utime(target_path, times=(new_time, new_time))
def perform_file_operations(operations: list[UpdateFile | DeleteFile]) -> None:
    """Apply file updates and deletions in order.

    UpdateFile writes ``op.content`` to ``op.target_path`` via
    write_and_fudge_mtime. Any other operation deletes ``op.path``: a
    directory is removed recursively, but only if its path starts with "tmp";
    a file is removed with retries.

    Raises:
        AssertionError: if asked to delete a directory whose path does not
            start with "tmp".
    """
    for op in operations:
        if isinstance(op, UpdateFile):
            # Modify/create file
            write_and_fudge_mtime(op.content, op.target_path)
        elif os.path.isdir(op.path):
            # Sanity check to avoid unexpected deletions. Raised explicitly
            # rather than via `assert`, so it still guards rmtree under `python -O`.
            if not op.path.startswith("tmp"):
                raise AssertionError(f"Refusing to delete directory outside tmp: {op.path!r}")
            shutil.rmtree(op.path)
        else:
            # Use retries to work around potential flakiness on Windows (AppVeyor).
            path = op.path
            retry_on_error(lambda: os.remove(path))
def check_test_output_files(
    testcase: DataDrivenTestCase, step: int, strip_prefix: str = ""
) -> None:
    """Compare each of ``testcase.output_files`` with the file actually on disk.

    ``strip_prefix`` is removed from the front of each expected path before
    the file is opened.

    - An expectation that is a compiled regex must fullmatch the raw file
      contents.
    - Any other expectation is compared line by line, using
      assert_string_arrays_equal, against the normalized file contents.

    Raises:
        AssertionError: if an expected file is missing or does not match a
            regex expectation. Text mismatches fail via
            assert_string_arrays_equal.
    """
    # Only mention the step number when the test case has multiple steps.
    step_note = " on step %d" % step if testcase.output2 else ""
    for path, expected_content in testcase.output_files:
        if path.startswith(strip_prefix):
            path = path[len(strip_prefix) :]
        if not os.path.exists(path):
            raise AssertionError(
                f"Expected file {path} was not produced by test case{step_note}"
            )
        with open(path, encoding="utf8") as output_file:
            actual_output_content = output_file.read()

        # Use re.Pattern; the typing.Pattern alias is deprecated.
        if isinstance(expected_content, re.Pattern):
            if expected_content.fullmatch(actual_output_content) is not None:
                continue
            raise AssertionError(
                "Output file {} did not match its expected output pattern\n---\n{}\n---".format(
                    path, actual_output_content
                )
            )

        normalized_output = normalize_file_output(
            actual_output_content.splitlines(), os.path.abspath(test_temp_dir)
        )
        # We always normalize things like timestamp, but only handle operating-system
        # specific things if requested.
        if testcase.normalize_output:
            if testcase.suite.native_sep and os.path.sep == "\\":
                normalized_output = [fix_cobertura_filename(line) for line in normalized_output]
            normalized_output = normalize_error_messages(normalized_output)
        assert_string_arrays_equal(
            expected_content.splitlines(),
            normalized_output,
            f"Output file {path} did not match its expected output{step_note}",
        )
def normalize_file_output(content: list[str], current_abs_path: str) -> list[str]:
    """Normalize file output lines for comparison.

    Each line has, in order:
    - ``current_abs_path`` replaced by "$PWD";
    - mypy's version and base version (as whole words) replaced by "$VERSION";
    - every run of 10 digits replaced by "$TIMESTAMP".
    """
    version = mypy.version.__version__
    # We generate a new mypy.version when building mypy wheels that
    # lacks base_version, so handle that case.
    base_version = getattr(mypy.version, "base_version", version)
    version_re = re.compile(r"\b" + re.escape(version) + r"\b")
    base_version_re = re.compile(r"\b" + re.escape(base_version) + r"\b")
    timestamp_re = re.compile(r"\d{10}")

    normalized = []
    for line in content:
        line = line.replace(current_abs_path, "$PWD")
        line = version_re.sub("$VERSION", line)
        line = base_version_re.sub("$VERSION", line)
        line = timestamp_re.sub("$TIMESTAMP", line)
        normalized.append(line)
    return normalized
def find_test_files(pattern: str, exclude: list[str] | None = None) -> list[str]:
    """Return base names of files under test_data_prefix matching ``pattern`` recursively.

    Files whose base name is in ``exclude`` are left out. The order is the
    order in which pathlib's rglob yields the files.
    """
    excluded_names = set(exclude) if exclude else set()
    matches = pathlib.Path(test_data_prefix).rglob(pattern)
    return [match.name for match in matches if match.name not in excluded_names]
|