ipc.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
"""Cross-platform abstractions for inter-process communication.

On Unix, this uses AF_UNIX sockets.
On Windows, this uses named pipes.
"""
  5. from __future__ import annotations
  6. import base64
  7. import os
  8. import shutil
  9. import sys
  10. import tempfile
  11. from types import TracebackType
  12. from typing import Callable
  13. from typing_extensions import Final
  14. if sys.platform == "win32":
  15. # This may be private, but it is needed for IPC on Windows, and is basically stable
  16. import ctypes
  17. import _winapi
  18. _IPCHandle = int
  19. kernel32 = ctypes.windll.kernel32
  20. DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe
  21. FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers
  22. else:
  23. import socket
  24. _IPCHandle = socket.socket
  25. class IPCException(Exception):
  26. """Exception for IPC issues."""
  27. class IPCBase:
  28. """Base class for communication between the dmypy client and server.
  29. This contains logic shared between the client and server, such as reading
  30. and writing.
  31. """
  32. connection: _IPCHandle
  33. def __init__(self, name: str, timeout: float | None) -> None:
  34. self.name = name
  35. self.timeout = timeout
  36. def read(self, size: int = 100000) -> bytes:
  37. """Read bytes from an IPC connection until its empty."""
  38. bdata = bytearray()
  39. if sys.platform == "win32":
  40. while True:
  41. ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
  42. try:
  43. if err == _winapi.ERROR_IO_PENDING:
  44. timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
  45. res = _winapi.WaitForSingleObject(ov.event, timeout)
  46. if res != _winapi.WAIT_OBJECT_0:
  47. raise IPCException(f"Bad result from I/O wait: {res}")
  48. except BaseException:
  49. ov.cancel()
  50. raise
  51. _, err = ov.GetOverlappedResult(True)
  52. more = ov.getbuffer()
  53. if more:
  54. bdata.extend(more)
  55. if err == 0:
  56. # we are done!
  57. break
  58. elif err == _winapi.ERROR_MORE_DATA:
  59. # read again
  60. continue
  61. elif err == _winapi.ERROR_OPERATION_ABORTED:
  62. raise IPCException("ReadFile operation aborted.")
  63. else:
  64. while True:
  65. more = self.connection.recv(size)
  66. if not more:
  67. break
  68. bdata.extend(more)
  69. return bytes(bdata)
  70. def write(self, data: bytes) -> None:
  71. """Write bytes to an IPC connection."""
  72. if sys.platform == "win32":
  73. try:
  74. ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
  75. try:
  76. if err == _winapi.ERROR_IO_PENDING:
  77. timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
  78. res = _winapi.WaitForSingleObject(ov.event, timeout)
  79. if res != _winapi.WAIT_OBJECT_0:
  80. raise IPCException(f"Bad result from I/O wait: {res}")
  81. elif err != 0:
  82. raise IPCException(f"Failed writing to pipe with error: {err}")
  83. except BaseException:
  84. ov.cancel()
  85. raise
  86. bytes_written, err = ov.GetOverlappedResult(True)
  87. assert err == 0, err
  88. assert bytes_written == len(data)
  89. except OSError as e:
  90. raise IPCException(f"Failed to write with error: {e.winerror}") from e
  91. else:
  92. self.connection.sendall(data)
  93. self.connection.shutdown(socket.SHUT_WR)
  94. def close(self) -> None:
  95. if sys.platform == "win32":
  96. if self.connection != _winapi.NULL:
  97. _winapi.CloseHandle(self.connection)
  98. else:
  99. self.connection.close()
  100. class IPCClient(IPCBase):
  101. """The client side of an IPC connection."""
  102. def __init__(self, name: str, timeout: float | None) -> None:
  103. super().__init__(name, timeout)
  104. if sys.platform == "win32":
  105. timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
  106. try:
  107. _winapi.WaitNamedPipe(self.name, timeout)
  108. except FileNotFoundError as e:
  109. raise IPCException(f"The NamedPipe at {self.name} was not found.") from e
  110. except OSError as e:
  111. if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
  112. raise IPCException("Timed out waiting for connection.") from e
  113. else:
  114. raise
  115. try:
  116. self.connection = _winapi.CreateFile(
  117. self.name,
  118. _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
  119. 0,
  120. _winapi.NULL,
  121. _winapi.OPEN_EXISTING,
  122. _winapi.FILE_FLAG_OVERLAPPED,
  123. _winapi.NULL,
  124. )
  125. except OSError as e:
  126. if e.winerror == _winapi.ERROR_PIPE_BUSY:
  127. raise IPCException("The connection is busy.") from e
  128. else:
  129. raise
  130. _winapi.SetNamedPipeHandleState(
  131. self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None
  132. )
  133. else:
  134. self.connection = socket.socket(socket.AF_UNIX)
  135. self.connection.settimeout(timeout)
  136. self.connection.connect(name)
  137. def __enter__(self) -> IPCClient:
  138. return self
  139. def __exit__(
  140. self,
  141. exc_ty: type[BaseException] | None = None,
  142. exc_val: BaseException | None = None,
  143. exc_tb: TracebackType | None = None,
  144. ) -> None:
  145. self.close()
  146. class IPCServer(IPCBase):
  147. BUFFER_SIZE: Final = 2**16
  148. def __init__(self, name: str, timeout: float | None = None) -> None:
  149. if sys.platform == "win32":
  150. name = r"\\.\pipe\{}-{}.pipe".format(
  151. name, base64.urlsafe_b64encode(os.urandom(6)).decode()
  152. )
  153. else:
  154. name = f"{name}.sock"
  155. super().__init__(name, timeout)
  156. if sys.platform == "win32":
  157. self.connection = _winapi.CreateNamedPipe(
  158. self.name,
  159. _winapi.PIPE_ACCESS_DUPLEX
  160. | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
  161. | _winapi.FILE_FLAG_OVERLAPPED,
  162. _winapi.PIPE_READMODE_MESSAGE
  163. | _winapi.PIPE_TYPE_MESSAGE
  164. | _winapi.PIPE_WAIT
  165. | 0x8, # PIPE_REJECT_REMOTE_CLIENTS
  166. 1, # one instance
  167. self.BUFFER_SIZE,
  168. self.BUFFER_SIZE,
  169. _winapi.NMPWAIT_WAIT_FOREVER,
  170. 0, # Use default security descriptor
  171. )
  172. if self.connection == -1: # INVALID_HANDLE_VALUE
  173. err = _winapi.GetLastError()
  174. raise IPCException(f"Invalid handle to pipe: {err}")
  175. else:
  176. self.sock_directory = tempfile.mkdtemp()
  177. sockfile = os.path.join(self.sock_directory, self.name)
  178. self.sock = socket.socket(socket.AF_UNIX)
  179. self.sock.bind(sockfile)
  180. self.sock.listen(1)
  181. if timeout is not None:
  182. self.sock.settimeout(timeout)
  183. def __enter__(self) -> IPCServer:
  184. if sys.platform == "win32":
  185. # NOTE: It is theoretically possible that this will hang forever if the
  186. # client never connects, though this can be "solved" by killing the server
  187. try:
  188. ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
  189. except OSError as e:
  190. # Don't raise if the client already exists, or the client already connected
  191. if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
  192. raise
  193. else:
  194. try:
  195. timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
  196. res = _winapi.WaitForSingleObject(ov.event, timeout)
  197. assert res == _winapi.WAIT_OBJECT_0
  198. except BaseException:
  199. ov.cancel()
  200. _winapi.CloseHandle(self.connection)
  201. raise
  202. _, err = ov.GetOverlappedResult(True)
  203. assert err == 0
  204. else:
  205. try:
  206. self.connection, _ = self.sock.accept()
  207. except socket.timeout as e:
  208. raise IPCException("The socket timed out") from e
  209. return self
  210. def __exit__(
  211. self,
  212. exc_ty: type[BaseException] | None = None,
  213. exc_val: BaseException | None = None,
  214. exc_tb: TracebackType | None = None,
  215. ) -> None:
  216. if sys.platform == "win32":
  217. try:
  218. # Wait for the client to finish reading the last write before disconnecting
  219. if not FlushFileBuffers(self.connection):
  220. raise IPCException(
  221. "Failed to flush NamedPipe buffer, maybe the client hung up?"
  222. )
  223. finally:
  224. DisconnectNamedPipe(self.connection)
  225. else:
  226. self.close()
  227. def cleanup(self) -> None:
  228. if sys.platform == "win32":
  229. self.close()
  230. else:
  231. shutil.rmtree(self.sock_directory)
  232. @property
  233. def connection_name(self) -> str:
  234. if sys.platform == "win32":
  235. return self.name
  236. else:
  237. name = self.sock.getsockname()
  238. assert isinstance(name, str)
  239. return name