cancelreader_windows.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. //go:build windows
  2. // +build windows
  3. package cancelreader
  4. import (
  5. "fmt"
  6. "io"
  7. "os"
  8. "syscall"
  9. "time"
  10. "unicode/utf16"
  11. "golang.org/x/sys/windows"
  12. )
  // fileShareValidFlags is FILE_SHARE_READ (0x1) | FILE_SHARE_WRITE (0x2) |
// FILE_SHARE_DELETE (0x4): CONIN$ is opened without denying any kind of
// access to other handles on the same console input.
var fileShareValidFlags = uint32(0x1 | 0x2 | 0x4)
  14. // NewReader returns a reader and a cancel function. If the input reader is a
  15. // File with the same file descriptor as os.Stdin, the cancel function can
  16. // be used to interrupt a blocking read call. In this case, the cancel function
  17. // returns true if the call was canceled successfully. If the input reader is
  18. // not a File with the same file descriptor as os.Stdin, the cancel
  19. // function does nothing and always returns false. The Windows implementation
  20. // is based on WaitForMultipleObject with overlapping reads from CONIN$.
  21. func NewReader(reader io.Reader) (CancelReader, error) {
  22. if f, ok := reader.(File); !ok || f.Fd() != os.Stdin.Fd() {
  23. return newFallbackCancelReader(reader)
  24. }
  25. // it is necessary to open CONIN$ (NOT windows.STD_INPUT_HANDLE) in
  26. // overlapped mode to be able to use it with WaitForMultipleObjects.
  27. conin, err := windows.CreateFile(
  28. &(utf16.Encode([]rune("CONIN$\x00"))[0]), windows.GENERIC_READ|windows.GENERIC_WRITE,
  29. fileShareValidFlags, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED, 0)
  30. if err != nil {
  31. return nil, fmt.Errorf("open CONIN$ in overlapping mode: %w", err)
  32. }
  33. resetConsole, err := prepareConsole(conin)
  34. if err != nil {
  35. return nil, fmt.Errorf("prepare console: %w", err)
  36. }
  37. // flush input, otherwise it can contain events which trigger
  38. // WaitForMultipleObjects but which ReadFile cannot read, resulting in an
  39. // un-cancelable read
  40. err = flushConsoleInputBuffer(conin)
  41. if err != nil {
  42. return nil, fmt.Errorf("flush console input buffer: %w", err)
  43. }
  44. cancelEvent, err := windows.CreateEvent(nil, 0, 0, nil)
  45. if err != nil {
  46. return nil, fmt.Errorf("create stop event: %w", err)
  47. }
  48. return &winCancelReader{
  49. conin: conin,
  50. cancelEvent: cancelEvent,
  51. resetConsole: resetConsole,
  52. blockingReadSignal: make(chan struct{}, 1),
  53. }, nil
  54. }
  55. type winCancelReader struct {
  56. conin windows.Handle
  57. cancelEvent windows.Handle
  58. cancelMixin
  59. resetConsole func() error
  60. blockingReadSignal chan struct{}
  61. }
  62. func (r *winCancelReader) Read(data []byte) (int, error) {
  63. if r.isCanceled() {
  64. return 0, ErrCanceled
  65. }
  66. err := r.wait()
  67. if err != nil {
  68. return 0, err
  69. }
  70. if r.isCanceled() {
  71. return 0, ErrCanceled
  72. }
  73. // windows.Read does not work on overlapping windows.Handles
  74. return r.readAsync(data)
  75. }
  76. // Cancel cancels ongoing and future Read() calls and returns true if the
  77. // cancelation of the ongoing Read() was successful. On Windows Terminal,
  78. // WaitForMultipleObjects sometimes immediately returns without input being
  79. // available. In this case, graceful cancelation is not possible and Cancel()
  80. // returns false.
  81. func (r *winCancelReader) Cancel() bool {
  82. r.setCanceled()
  83. select {
  84. case r.blockingReadSignal <- struct{}{}:
  85. err := windows.SetEvent(r.cancelEvent)
  86. if err != nil {
  87. return false
  88. }
  89. <-r.blockingReadSignal
  90. case <-time.After(100 * time.Millisecond):
  91. // Read() hangs in a GetOverlappedResult which is likely due to
  92. // WaitForMultipleObjects returning without input being available
  93. // so we cannot cancel this ongoing read.
  94. return false
  95. }
  96. return true
  97. }
  98. func (r *winCancelReader) Close() error {
  99. err := windows.CloseHandle(r.cancelEvent)
  100. if err != nil {
  101. return fmt.Errorf("closing cancel event handle: %w", err)
  102. }
  103. err = r.resetConsole()
  104. if err != nil {
  105. return err
  106. }
  107. err = windows.Close(r.conin)
  108. if err != nil {
  109. return fmt.Errorf("closing CONIN$")
  110. }
  111. return nil
  112. }
  113. func (r *winCancelReader) wait() error {
  114. event, err := windows.WaitForMultipleObjects([]windows.Handle{r.conin, r.cancelEvent}, false, windows.INFINITE)
  115. switch {
  116. case windows.WAIT_OBJECT_0 <= event && event < windows.WAIT_OBJECT_0+2:
  117. if event == windows.WAIT_OBJECT_0+1 {
  118. return ErrCanceled
  119. }
  120. if event == windows.WAIT_OBJECT_0 {
  121. return nil
  122. }
  123. return fmt.Errorf("unexpected wait object is ready: %d", event-windows.WAIT_OBJECT_0)
  124. case windows.WAIT_ABANDONED <= event && event < windows.WAIT_ABANDONED+2:
  125. return fmt.Errorf("abandoned")
  126. case event == uint32(windows.WAIT_TIMEOUT):
  127. return fmt.Errorf("timeout")
  128. case event == windows.WAIT_FAILED:
  129. return fmt.Errorf("failed")
  130. default:
  131. return fmt.Errorf("unexpected error: %w", error(err))
  132. }
  133. }
  134. // readAsync is necessary to read from a windows.Handle in overlapping mode.
  135. func (r *winCancelReader) readAsync(data []byte) (int, error) {
  136. hevent, err := windows.CreateEvent(nil, 0, 0, nil)
  137. if err != nil {
  138. return 0, fmt.Errorf("create event: %w", err)
  139. }
  140. overlapped := windows.Overlapped{
  141. HEvent: hevent,
  142. }
  143. var n uint32
  144. err = windows.ReadFile(r.conin, data, &n, &overlapped)
  145. if err != nil && err != windows.ERROR_IO_PENDING {
  146. return int(n), err
  147. }
  148. r.blockingReadSignal <- struct{}{}
  149. err = windows.GetOverlappedResult(r.conin, &overlapped, &n, true)
  150. if err != nil {
  151. return int(n), nil
  152. }
  153. <-r.blockingReadSignal
  154. return int(n), nil
  155. }
  156. func prepareConsole(input windows.Handle) (reset func() error, err error) {
  157. var originalMode uint32
  158. err = windows.GetConsoleMode(input, &originalMode)
  159. if err != nil {
  160. return nil, fmt.Errorf("get console mode: %w", err)
  161. }
  162. var newMode uint32
  163. newMode &^= windows.ENABLE_ECHO_INPUT
  164. newMode &^= windows.ENABLE_LINE_INPUT
  165. newMode &^= windows.ENABLE_MOUSE_INPUT
  166. newMode &^= windows.ENABLE_WINDOW_INPUT
  167. newMode &^= windows.ENABLE_PROCESSED_INPUT
  168. newMode |= windows.ENABLE_EXTENDED_FLAGS
  169. newMode |= windows.ENABLE_INSERT_MODE
  170. newMode |= windows.ENABLE_QUICK_EDIT_MODE
  171. // Enabling virtual terminal input is necessary for processing certain
  172. // types of input like X10 mouse events and arrows keys with the current
  173. // bytes-based input reader. It does, however, prevent cancelReader from
  174. // being able to cancel input. The planned solution for this is to read
  175. // Windows events in a more native fashion, rather than the current simple
  176. // bytes-based input reader which works well on unix systems.
  177. newMode |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT
  178. err = windows.SetConsoleMode(input, newMode)
  179. if err != nil {
  180. return nil, fmt.Errorf("set console mode: %w", err)
  181. }
  182. return func() error {
  183. err := windows.SetConsoleMode(input, originalMode)
  184. if err != nil {
  185. return fmt.Errorf("reset console mode: %w", err)
  186. }
  187. return nil
  188. }, nil
  189. }
  190. var (
  191. modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
  192. procFlushConsoleInputBuffer = modkernel32.NewProc("FlushConsoleInputBuffer")
  193. )
  194. func flushConsoleInputBuffer(consoleInput windows.Handle) error {
  195. r, _, e := syscall.Syscall(procFlushConsoleInputBuffer.Addr(), 1,
  196. uintptr(consoleInput), 0, 0)
  197. if r == 0 {
  198. return error(e)
  199. }
  200. return nil
  201. }