| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- """Cross platform abstractions for inter-process communication
- On Unix, this uses AF_UNIX sockets.
- On Windows, this uses NamedPipes.
- """
- from __future__ import annotations
- import base64
- import os
- import shutil
- import sys
- import tempfile
- from types import TracebackType
- from typing import Callable, Final
if sys.platform != "win32":
    # POSIX: IPC runs over AF_UNIX sockets, so a connection is a socket object.
    import socket

    _IPCHandle = socket.socket
else:
    # Windows: IPC runs over named pipes, so a connection is a raw handle (int).
    import ctypes

    # _winapi is a private CPython module, but named-pipe IPC needs it and its
    # interface has been stable in practice.
    import _winapi

    _IPCHandle = int

    # kernel32 functions called directly through ctypes (presumably because
    # _winapi does not expose them).
    kernel32 = ctypes.windll.kernel32
    DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe
    FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers
class IPCException(Exception):
    """Raised when an inter-process communication operation fails."""
class IPCBase:
    """Base class for communication between the dmypy client and server.

    This contains logic shared between the client and server, such as reading
    and writing. Subclasses are responsible for setting ``connection``: a
    connected AF_UNIX socket on Unix, or a named-pipe handle on Windows.
    """

    connection: _IPCHandle

    def __init__(self, name: str, timeout: float | None) -> None:
        """Store the endpoint name and the I/O timeout in seconds (None: no timeout)."""
        self.name = name
        self.timeout = timeout

    def read(self, size: int = 100000) -> bytes:
        """Read bytes from the IPC connection until the current message is exhausted.

        On Windows this keeps issuing ReadFile calls while the pipe reports
        ERROR_MORE_DATA, i.e. while the message is larger than ``size``. On Unix
        it reads until ``recv`` returns no data, which happens once the peer
        shuts down its writing side.

        Args:
            size: maximum number of bytes requested per underlying read call.

        Returns:
            All bytes read.

        Raises:
            IPCException: on Windows, if waiting for the read fails or times out,
                the read is aborted, or ReadFile reports an unexpected status.
        """
        bdata = bytearray()
        if sys.platform == "win32":
            while True:
                ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        # A falsy timeout (None or 0) means wait forever.
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException(f"Bad result from I/O wait: {res}")
                except BaseException:
                    # Don't leave an overlapped read outstanding on the handle.
                    ov.cancel()
                    raise
                _, err = ov.GetOverlappedResult(True)
                more = ov.getbuffer()
                if more:
                    bdata.extend(more)
                if err == 0:
                    # The whole message has been read.
                    break
                elif err == _winapi.ERROR_MORE_DATA:
                    # The message is larger than `size`; keep reading it.
                    continue
                elif err == _winapi.ERROR_OPERATION_ABORTED:
                    raise IPCException("ReadFile operation aborted.")
                else:
                    # Never silently loop again on a status we don't understand.
                    raise IPCException(f"Unexpected ReadFile result: {err}")
        else:
            while True:
                more = self.connection.recv(size)
                if not more:
                    # An empty recv() means the peer shut down its writing side.
                    break
                bdata.extend(more)
        return bytes(bdata)

    def write(self, data: bytes) -> None:
        """Write ``data`` to the IPC connection.

        On Unix the writing side of the socket is shut down afterwards, which
        is what lets the peer's ``read`` detect the end of the data.

        Raises:
            IPCException: on Windows, if the write fails, times out, or writes
                fewer bytes than requested.
        """
        if sys.platform == "win32":
            try:
                ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        # A falsy timeout (None or 0) means wait forever.
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException(f"Bad result from I/O wait: {res}")
                    elif err != 0:
                        raise IPCException(f"Failed writing to pipe with error: {err}")
                except BaseException:
                    # Don't leave an overlapped write outstanding on the handle.
                    ov.cancel()
                    raise
                bytes_written, err = ov.GetOverlappedResult(True)
            except OSError as e:
                raise IPCException(f"Failed to write with error: {e.winerror}") from e
            # Explicit checks rather than `assert`, so they still run under `python -O`.
            if err != 0:
                raise IPCException(f"Failed writing to pipe with error: {err}")
            if bytes_written != len(data):
                raise IPCException(
                    f"Short write to pipe: wrote {bytes_written} of {len(data)} bytes"
                )
        else:
            self.connection.sendall(data)
            self.connection.shutdown(socket.SHUT_WR)

    def close(self) -> None:
        """Close the underlying connection.

        On Windows the stored handle is reset to NULL after closing, so a second
        call is a no-op instead of closing an unrelated handle that happens to
        reuse the same value.
        """
        if sys.platform == "win32":
            if self.connection != _winapi.NULL:
                _winapi.CloseHandle(self.connection)
                self.connection = _winapi.NULL
        else:
            self.connection.close()
class IPCClient(IPCBase):
    """The client side of an IPC connection."""

    def __init__(self, name: str, timeout: float | None) -> None:
        """Connect to the server endpoint ``name``.

        Args:
            name: pipe name on Windows or socket path on Unix (presumably the
                server's ``connection_name``).
            timeout: seconds to wait for the connection and for later I/O;
                None waits indefinitely.

        Raises:
            IPCException: on Windows, if the pipe does not exist, waiting for
                it times out, or the pipe is busy.
            OSError: for other connection failures, including socket errors and
                socket timeouts on Unix.
        """
        super().__init__(name, timeout)
        if sys.platform == "win32":
            # A falsy timeout (None or 0) means wait forever.
            timeout_ms = (
                int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
            )
            try:
                _winapi.WaitNamedPipe(self.name, timeout_ms)
            except FileNotFoundError as e:
                raise IPCException(f"The NamedPipe at {self.name} was not found.") from e
            except OSError as e:
                if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
                    raise IPCException("Timed out waiting for connection.") from e
                else:
                    raise
            try:
                self.connection = _winapi.CreateFile(
                    self.name,
                    _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0,
                    _winapi.NULL,
                    _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED,
                    _winapi.NULL,
                )
            except OSError as e:
                if e.winerror == _winapi.ERROR_PIPE_BUSY:
                    raise IPCException("The connection is busy.") from e
                else:
                    raise
            try:
                # Read the pipe in message mode.
                _winapi.SetNamedPipeHandleState(
                    self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None
                )
            except BaseException:
                # Don't leak the freshly opened handle if configuring it fails.
                _winapi.CloseHandle(self.connection)
                raise
        else:
            sock = socket.socket(socket.AF_UNIX)
            try:
                sock.settimeout(timeout)
                sock.connect(name)
            except BaseException:
                # Don't leak the socket if the connection attempt fails.
                sock.close()
                raise
            self.connection = sock

    def __enter__(self) -> IPCClient:
        """Return ``self``; the connection was already opened by ``__init__``."""
        return self

    def __exit__(
        self,
        exc_ty: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: TracebackType | None = None,
    ) -> None:
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()
class IPCServer(IPCBase):
    """The server side of an IPC connection.

    On Windows this owns a single-instance named pipe; on Unix it owns a
    listening AF_UNIX socket whose file lives in a fresh temporary directory.
    Each ``with`` block waits for and serves one client connection.
    """

    BUFFER_SIZE: Final = 2**16

    def __init__(self, name: str, timeout: float | None = None) -> None:
        """Create the server endpoint derived from ``name``.

        Args:
            name: base name. On Windows it is turned into a pipe path with a
                random suffix; on Unix ``.sock`` is appended.
            timeout: seconds to wait for a client connection and for I/O;
                None waits indefinitely.

        Raises:
            IPCException: on Windows, if the pipe handle is invalid.
            OSError: if the Unix socket cannot be created, bound, or listened on
                (the temporary directory is removed in that case).
        """
        if sys.platform == "win32":
            name = r"\\.\pipe\{}-{}.pipe".format(
                name, base64.urlsafe_b64encode(os.urandom(6)).decode()
            )
        else:
            name = f"{name}.sock"
        super().__init__(name, timeout)
        if sys.platform == "win32":
            self.connection = _winapi.CreateNamedPipe(
                self.name,
                _winapi.PIPE_ACCESS_DUPLEX
                | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
                | _winapi.FILE_FLAG_OVERLAPPED,
                _winapi.PIPE_READMODE_MESSAGE
                | _winapi.PIPE_TYPE_MESSAGE
                | _winapi.PIPE_WAIT
                | 0x8,  # PIPE_REJECT_REMOTE_CLIENTS
                1,  # one instance
                self.BUFFER_SIZE,
                self.BUFFER_SIZE,
                _winapi.NMPWAIT_WAIT_FOREVER,
                0,  # Use default security descriptor
            )
            if self.connection == -1:  # INVALID_HANDLE_VALUE
                err = _winapi.GetLastError()
                raise IPCException(f"Invalid handle to pipe: {err}")
        else:
            self.sock_directory = tempfile.mkdtemp()
            try:
                sockfile = os.path.join(self.sock_directory, self.name)
                self.sock = socket.socket(socket.AF_UNIX)
                try:
                    self.sock.bind(sockfile)
                    self.sock.listen(1)
                    if timeout is not None:
                        self.sock.settimeout(timeout)
                except BaseException:
                    # e.g. bind() can fail when the AF_UNIX path is too long.
                    self.sock.close()
                    raise
            except BaseException:
                # Don't leave an orphaned temporary directory behind.
                shutil.rmtree(self.sock_directory, ignore_errors=True)
                raise

    def __enter__(self) -> IPCServer:
        """Wait for a client to connect and return ``self``.

        Raises:
            IPCException: if no client connects within the timeout, or (on
                Windows) the connection attempt reports an error.
        """
        if sys.platform == "win32":
            # NOTE: Without a timeout this can block forever if no client ever
            # connects; killing the server is the way out.
            try:
                ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
            except OSError as e:
                # Don't raise if the client already exists, or the client already connected
                if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
                    raise
            else:
                try:
                    # A falsy timeout (None or 0) means wait forever.
                    timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                    res = _winapi.WaitForSingleObject(ov.event, timeout)
                    # Explicit check rather than `assert`: under `python -O` a timeout
                    # would otherwise fall through to a blocking GetOverlappedResult.
                    if res != _winapi.WAIT_OBJECT_0:
                        raise IPCException(f"Bad result from I/O wait: {res}")
                except BaseException:
                    ov.cancel()
                    _winapi.CloseHandle(self.connection)
                    # Mark the handle as closed so a later close()/cleanup() skips it.
                    self.connection = _winapi.NULL
                    raise
                _, err = ov.GetOverlappedResult(True)
                if err != 0:
                    raise IPCException(f"ConnectNamedPipe failed with error: {err}")
        else:
            try:
                self.connection, _ = self.sock.accept()
            except socket.timeout as e:
                raise IPCException("The socket timed out") from e
        return self

    def __exit__(
        self,
        exc_ty: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: TracebackType | None = None,
    ) -> None:
        """End the current client connection, keeping the server endpoint alive.

        On Windows the pipe is flushed (waiting for the client to read the last
        write) and then disconnected, so it can accept another client. On Unix
        the accepted client socket is closed.

        Raises:
            IPCException: on Windows, if flushing the pipe buffer fails.
        """
        if sys.platform == "win32":
            try:
                # Wait for the client to finish reading the last write before disconnecting
                if not FlushFileBuffers(self.connection):
                    raise IPCException(
                        "Failed to flush NamedPipe buffer, maybe the client hung up?"
                    )
            finally:
                DisconnectNamedPipe(self.connection)
        else:
            self.close()

    def cleanup(self) -> None:
        """Release the server endpoint permanently.

        On Windows this closes the pipe handle. On Unix it closes the listening
        socket and removes the temporary directory that holds its socket file.
        """
        if sys.platform == "win32":
            self.close()
        else:
            # Close the listener before deleting the directory containing its file.
            self.sock.close()
            shutil.rmtree(self.sock_directory)

    @property
    def connection_name(self) -> str:
        """The name a client passes to ``IPCClient``: the pipe name or socket path."""
        if sys.platform == "win32":
            return self.name
        else:
            name = self.sock.getsockname()
            assert isinstance(name, str)
            return name
|