source.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. from __future__ import annotations
  2. from copy import copy
  3. from typing import Any
  4. from tomlkit.exceptions import ParseError
  5. from tomlkit.exceptions import UnexpectedCharError
  6. from tomlkit.toml_char import TOMLChar
  7. class _State:
  8. def __init__(
  9. self,
  10. source: Source,
  11. save_marker: bool | None = False,
  12. restore: bool | None = False,
  13. ) -> None:
  14. self._source = source
  15. self._save_marker = save_marker
  16. self.restore = restore
  17. def __enter__(self) -> _State:
  18. # Entering this context manager - save the state
  19. self._chars = copy(self._source._chars)
  20. self._idx = self._source._idx
  21. self._current = self._source._current
  22. self._marker = self._source._marker
  23. return self
  24. def __exit__(self, exception_type, exception_val, trace):
  25. # Exiting this context manager - restore the prior state
  26. if self.restore or exception_type:
  27. self._source._chars = self._chars
  28. self._source._idx = self._idx
  29. self._source._current = self._current
  30. if self._save_marker:
  31. self._source._marker = self._marker
  32. class _StateHandler:
  33. """
  34. State preserver for the Parser.
  35. """
  36. def __init__(self, source: Source) -> None:
  37. self._source = source
  38. self._states = []
  39. def __call__(self, *args, **kwargs):
  40. return _State(self._source, *args, **kwargs)
  41. def __enter__(self) -> _State:
  42. state = self()
  43. self._states.append(state)
  44. return state.__enter__()
  45. def __exit__(self, exception_type, exception_val, trace):
  46. state = self._states.pop()
  47. return state.__exit__(exception_type, exception_val, trace)
  48. class Source(str):
  49. EOF = TOMLChar("\0")
  50. def __init__(self, _: str) -> None:
  51. super().__init__()
  52. # Collection of TOMLChars
  53. self._chars = iter([(i, TOMLChar(c)) for i, c in enumerate(self)])
  54. self._idx = 0
  55. self._marker = 0
  56. self._current = TOMLChar("")
  57. self._state = _StateHandler(self)
  58. self.inc()
  59. def reset(self):
  60. # initialize both idx and current
  61. self.inc()
  62. # reset marker
  63. self.mark()
  64. @property
  65. def state(self) -> _StateHandler:
  66. return self._state
  67. @property
  68. def idx(self) -> int:
  69. return self._idx
  70. @property
  71. def current(self) -> TOMLChar:
  72. return self._current
  73. @property
  74. def marker(self) -> int:
  75. return self._marker
  76. def extract(self) -> str:
  77. """
  78. Extracts the value between marker and index
  79. """
  80. return self[self._marker : self._idx]
  81. def inc(self, exception: type[ParseError] | None = None) -> bool:
  82. """
  83. Increments the parser if the end of the input has not been reached.
  84. Returns whether or not it was able to advance.
  85. """
  86. try:
  87. self._idx, self._current = next(self._chars)
  88. return True
  89. except StopIteration:
  90. self._idx = len(self)
  91. self._current = self.EOF
  92. if exception:
  93. raise self.parse_error(exception)
  94. return False
  95. def inc_n(self, n: int, exception: type[ParseError] | None = None) -> bool:
  96. """
  97. Increments the parser by n characters
  98. if the end of the input has not been reached.
  99. """
  100. return all(self.inc(exception=exception) for _ in range(n))
  101. def consume(self, chars, min=0, max=-1):
  102. """
  103. Consume chars until min/max is satisfied is valid.
  104. """
  105. while self.current in chars and max != 0:
  106. min -= 1
  107. max -= 1
  108. if not self.inc():
  109. break
  110. # failed to consume minimum number of characters
  111. if min > 0:
  112. raise self.parse_error(UnexpectedCharError, self.current)
  113. def end(self) -> bool:
  114. """
  115. Returns True if the parser has reached the end of the input.
  116. """
  117. return self._current is self.EOF
  118. def mark(self) -> None:
  119. """
  120. Sets the marker to the index's current position
  121. """
  122. self._marker = self._idx
  123. def parse_error(
  124. self,
  125. exception: type[ParseError] = ParseError,
  126. *args: Any,
  127. **kwargs: Any,
  128. ) -> ParseError:
  129. """
  130. Creates a generic "parse error" at the current position.
  131. """
  132. line, col = self._to_linecol()
  133. return exception(line, col, *args, **kwargs)
  134. def _to_linecol(self) -> tuple[int, int]:
  135. cur = 0
  136. for i, line in enumerate(self.splitlines()):
  137. if cur + len(line) + 1 > self.idx:
  138. return (i + 1, self.idx - cur)
  139. cur += len(line) + 1
  140. return len(self.splitlines()), 0