ipc.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """Cross platform abstractions for inter-process communication
  2. On Unix, this uses AF_UNIX sockets.
  3. On Windows, this uses NamedPipes.
  4. """
  5. from __future__ import annotations
  6. import base64
  7. import os
  8. import shutil
  9. import sys
  10. import tempfile
  11. from types import TracebackType
  12. from typing import Callable, Final
  13. if sys.platform == "win32":
  14. # This may be private, but it is needed for IPC on Windows, and is basically stable
  15. import ctypes
  16. import _winapi
  17. _IPCHandle = int
  18. kernel32 = ctypes.windll.kernel32
  19. DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe
  20. FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers
  21. else:
  22. import socket
  23. _IPCHandle = socket.socket
  24. class IPCException(Exception):
  25. """Exception for IPC issues."""
  26. class IPCBase:
  27. """Base class for communication between the dmypy client and server.
  28. This contains logic shared between the client and server, such as reading
  29. and writing.
  30. """
  31. connection: _IPCHandle
  32. def __init__(self, name: str, timeout: float | None) -> None:
  33. self.name = name
  34. self.timeout = timeout
  35. def read(self, size: int = 100000) -> bytes:
  36. """Read bytes from an IPC connection until its empty."""
  37. bdata = bytearray()
  38. if sys.platform == "win32":
  39. while True:
  40. ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
  41. try:
  42. if err == _winapi.ERROR_IO_PENDING:
  43. timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
  44. res = _winapi.WaitForSingleObject(ov.event, timeout)
  45. if res != _winapi.WAIT_OBJECT_0:
  46. raise IPCException(f"Bad result from I/O wait: {res}")
  47. except BaseException:
  48. ov.cancel()
  49. raise
  50. _, err = ov.GetOverlappedResult(True)
  51. more = ov.getbuffer()
  52. if more:
  53. bdata.extend(more)
  54. if err == 0:
  55. # we are done!
  56. break
  57. elif err == _winapi.ERROR_MORE_DATA:
  58. # read again
  59. continue
  60. elif err == _winapi.ERROR_OPERATION_ABORTED:
  61. raise IPCException("ReadFile operation aborted.")
  62. else:
  63. while True:
  64. more = self.connection.recv(size)
  65. if not more:
  66. break
  67. bdata.extend(more)
  68. return bytes(bdata)
  69. def write(self, data: bytes) -> None:
  70. """Write bytes to an IPC connection."""
  71. if sys.platform == "win32":
  72. try:
  73. ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
  74. try:
  75. if err == _winapi.ERROR_IO_PENDING:
  76. timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
  77. res = _winapi.WaitForSingleObject(ov.event, timeout)
  78. if res != _winapi.WAIT_OBJECT_0:
  79. raise IPCException(f"Bad result from I/O wait: {res}")
  80. elif err != 0:
  81. raise IPCException(f"Failed writing to pipe with error: {err}")
  82. except BaseException:
  83. ov.cancel()
  84. raise
  85. bytes_written, err = ov.GetOverlappedResult(True)
  86. assert err == 0, err
  87. assert bytes_written == len(data)
  88. except OSError as e:
  89. raise IPCException(f"Failed to write with error: {e.winerror}") from e
  90. else:
  91. self.connection.sendall(data)
  92. self.connection.shutdown(socket.SHUT_WR)
  93. def close(self) -> None:
  94. if sys.platform == "win32":
  95. if self.connection != _winapi.NULL:
  96. _winapi.CloseHandle(self.connection)
  97. else:
  98. self.connection.close()
  99. class IPCClient(IPCBase):
  100. """The client side of an IPC connection."""
  101. def __init__(self, name: str, timeout: float | None) -> None:
  102. super().__init__(name, timeout)
  103. if sys.platform == "win32":
  104. timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
  105. try:
  106. _winapi.WaitNamedPipe(self.name, timeout)
  107. except FileNotFoundError as e:
  108. raise IPCException(f"The NamedPipe at {self.name} was not found.") from e
  109. except OSError as e:
  110. if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
  111. raise IPCException("Timed out waiting for connection.") from e
  112. else:
  113. raise
  114. try:
  115. self.connection = _winapi.CreateFile(
  116. self.name,
  117. _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
  118. 0,
  119. _winapi.NULL,
  120. _winapi.OPEN_EXISTING,
  121. _winapi.FILE_FLAG_OVERLAPPED,
  122. _winapi.NULL,
  123. )
  124. except OSError as e:
  125. if e.winerror == _winapi.ERROR_PIPE_BUSY:
  126. raise IPCException("The connection is busy.") from e
  127. else:
  128. raise
  129. _winapi.SetNamedPipeHandleState(
  130. self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None
  131. )
  132. else:
  133. self.connection = socket.socket(socket.AF_UNIX)
  134. self.connection.settimeout(timeout)
  135. self.connection.connect(name)
  136. def __enter__(self) -> IPCClient:
  137. return self
  138. def __exit__(
  139. self,
  140. exc_ty: type[BaseException] | None = None,
  141. exc_val: BaseException | None = None,
  142. exc_tb: TracebackType | None = None,
  143. ) -> None:
  144. self.close()
  145. class IPCServer(IPCBase):
  146. BUFFER_SIZE: Final = 2**16
  147. def __init__(self, name: str, timeout: float | None = None) -> None:
  148. if sys.platform == "win32":
  149. name = r"\\.\pipe\{}-{}.pipe".format(
  150. name, base64.urlsafe_b64encode(os.urandom(6)).decode()
  151. )
  152. else:
  153. name = f"{name}.sock"
  154. super().__init__(name, timeout)
  155. if sys.platform == "win32":
  156. self.connection = _winapi.CreateNamedPipe(
  157. self.name,
  158. _winapi.PIPE_ACCESS_DUPLEX
  159. | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
  160. | _winapi.FILE_FLAG_OVERLAPPED,
  161. _winapi.PIPE_READMODE_MESSAGE
  162. | _winapi.PIPE_TYPE_MESSAGE
  163. | _winapi.PIPE_WAIT
  164. | 0x8, # PIPE_REJECT_REMOTE_CLIENTS
  165. 1, # one instance
  166. self.BUFFER_SIZE,
  167. self.BUFFER_SIZE,
  168. _winapi.NMPWAIT_WAIT_FOREVER,
  169. 0, # Use default security descriptor
  170. )
  171. if self.connection == -1: # INVALID_HANDLE_VALUE
  172. err = _winapi.GetLastError()
  173. raise IPCException(f"Invalid handle to pipe: {err}")
  174. else:
  175. self.sock_directory = tempfile.mkdtemp()
  176. sockfile = os.path.join(self.sock_directory, self.name)
  177. self.sock = socket.socket(socket.AF_UNIX)
  178. self.sock.bind(sockfile)
  179. self.sock.listen(1)
  180. if timeout is not None:
  181. self.sock.settimeout(timeout)
  182. def __enter__(self) -> IPCServer:
  183. if sys.platform == "win32":
  184. # NOTE: It is theoretically possible that this will hang forever if the
  185. # client never connects, though this can be "solved" by killing the server
  186. try:
  187. ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
  188. except OSError as e:
  189. # Don't raise if the client already exists, or the client already connected
  190. if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
  191. raise
  192. else:
  193. try:
  194. timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
  195. res = _winapi.WaitForSingleObject(ov.event, timeout)
  196. assert res == _winapi.WAIT_OBJECT_0
  197. except BaseException:
  198. ov.cancel()
  199. _winapi.CloseHandle(self.connection)
  200. raise
  201. _, err = ov.GetOverlappedResult(True)
  202. assert err == 0
  203. else:
  204. try:
  205. self.connection, _ = self.sock.accept()
  206. except socket.timeout as e:
  207. raise IPCException("The socket timed out") from e
  208. return self
  209. def __exit__(
  210. self,
  211. exc_ty: type[BaseException] | None = None,
  212. exc_val: BaseException | None = None,
  213. exc_tb: TracebackType | None = None,
  214. ) -> None:
  215. if sys.platform == "win32":
  216. try:
  217. # Wait for the client to finish reading the last write before disconnecting
  218. if not FlushFileBuffers(self.connection):
  219. raise IPCException(
  220. "Failed to flush NamedPipe buffer, maybe the client hung up?"
  221. )
  222. finally:
  223. DisconnectNamedPipe(self.connection)
  224. else:
  225. self.close()
  226. def cleanup(self) -> None:
  227. if sys.platform == "win32":
  228. self.close()
  229. else:
  230. shutil.rmtree(self.sock_directory)
  231. @property
  232. def connection_name(self) -> str:
  233. if sys.platform == "win32":
  234. return self.name
  235. else:
  236. name = self.sock.getsockname()
  237. assert isinstance(name, str)
  238. return name