cancelreader_windows.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. //go:build windows
  2. // +build windows
  3. package input
  4. import (
  5. "fmt"
  6. "io"
  7. "os"
  8. "sync"
  9. "time"
  10. "github.com/erikgeiser/coninput"
  11. "github.com/muesli/cancelreader"
  12. "golang.org/x/sys/windows"
  13. )
  14. type conInputReader struct {
  15. cancelMixin
  16. conin windows.Handle
  17. cancelEvent windows.Handle
  18. originalMode uint32
  19. // blockingReadSignal is used to signal that a blocking read is in progress.
  20. blockingReadSignal chan struct{}
  21. }
  22. var _ cancelreader.CancelReader = &conInputReader{}
  23. func newCancelreader(r io.Reader) (cancelreader.CancelReader, error) {
  24. fallback := func(io.Reader) (cancelreader.CancelReader, error) {
  25. return cancelreader.NewReader(r)
  26. }
  27. var dummy uint32
  28. if f, ok := r.(cancelreader.File); !ok || f.Fd() != os.Stdin.Fd() ||
  29. // If data was piped to the standard input, it does not emit events
  30. // anymore. We can detect this if the console mode cannot be set anymore,
  31. // in this case, we fallback to the default cancelreader implementation.
  32. windows.GetConsoleMode(windows.Handle(f.Fd()), &dummy) != nil {
  33. return fallback(r)
  34. }
  35. conin, err := coninput.NewStdinHandle()
  36. if err != nil {
  37. return fallback(r)
  38. }
  39. originalMode, err := prepareConsole(conin,
  40. windows.ENABLE_MOUSE_INPUT,
  41. windows.ENABLE_WINDOW_INPUT,
  42. windows.ENABLE_EXTENDED_FLAGS,
  43. )
  44. if err != nil {
  45. return nil, fmt.Errorf("failed to prepare console input: %w", err)
  46. }
  47. cancelEvent, err := windows.CreateEvent(nil, 0, 0, nil)
  48. if err != nil {
  49. return nil, fmt.Errorf("create stop event: %w", err)
  50. }
  51. return &conInputReader{
  52. conin: conin,
  53. cancelEvent: cancelEvent,
  54. originalMode: originalMode,
  55. blockingReadSignal: make(chan struct{}, 1),
  56. }, nil
  57. }
  58. // Cancel implements cancelreader.CancelReader.
  59. func (r *conInputReader) Cancel() bool {
  60. r.setCanceled()
  61. select {
  62. case r.blockingReadSignal <- struct{}{}:
  63. err := windows.SetEvent(r.cancelEvent)
  64. if err != nil {
  65. return false
  66. }
  67. <-r.blockingReadSignal
  68. case <-time.After(100 * time.Millisecond):
  69. // Read() hangs in a GetOverlappedResult which is likely due to
  70. // WaitForMultipleObjects returning without input being available
  71. // so we cannot cancel this ongoing read.
  72. return false
  73. }
  74. return true
  75. }
  76. // Close implements cancelreader.CancelReader.
  77. func (r *conInputReader) Close() error {
  78. err := windows.CloseHandle(r.cancelEvent)
  79. if err != nil {
  80. return fmt.Errorf("closing cancel event handle: %w", err)
  81. }
  82. if r.originalMode != 0 {
  83. err := windows.SetConsoleMode(r.conin, r.originalMode)
  84. if err != nil {
  85. return fmt.Errorf("reset console mode: %w", err)
  86. }
  87. }
  88. return nil
  89. }
  90. // Read implements cancelreader.CancelReader.
  91. func (r *conInputReader) Read(data []byte) (n int, err error) {
  92. if r.isCanceled() {
  93. return 0, cancelreader.ErrCanceled
  94. }
  95. err = waitForInput(r.conin, r.cancelEvent)
  96. if err != nil {
  97. return 0, err
  98. }
  99. if r.isCanceled() {
  100. return 0, cancelreader.ErrCanceled
  101. }
  102. r.blockingReadSignal <- struct{}{}
  103. n, err = overlappedReader(r.conin).Read(data)
  104. <-r.blockingReadSignal
  105. return
  106. }
  107. func prepareConsole(input windows.Handle, modes ...uint32) (originalMode uint32, err error) {
  108. err = windows.GetConsoleMode(input, &originalMode)
  109. if err != nil {
  110. return 0, fmt.Errorf("get console mode: %w", err)
  111. }
  112. newMode := coninput.AddInputModes(0, modes...)
  113. err = windows.SetConsoleMode(input, newMode)
  114. if err != nil {
  115. return 0, fmt.Errorf("set console mode: %w", err)
  116. }
  117. return originalMode, nil
  118. }
  119. func waitForInput(conin, cancel windows.Handle) error {
  120. event, err := windows.WaitForMultipleObjects([]windows.Handle{conin, cancel}, false, windows.INFINITE)
  121. switch {
  122. case windows.WAIT_OBJECT_0 <= event && event < windows.WAIT_OBJECT_0+2:
  123. if event == windows.WAIT_OBJECT_0+1 {
  124. return cancelreader.ErrCanceled
  125. }
  126. if event == windows.WAIT_OBJECT_0 {
  127. return nil
  128. }
  129. return fmt.Errorf("unexpected wait object is ready: %d", event-windows.WAIT_OBJECT_0)
  130. case windows.WAIT_ABANDONED <= event && event < windows.WAIT_ABANDONED+2:
  131. return fmt.Errorf("abandoned")
  132. case event == uint32(windows.WAIT_TIMEOUT):
  133. return fmt.Errorf("timeout")
  134. case event == windows.WAIT_FAILED:
  135. return fmt.Errorf("failed")
  136. default:
  137. return fmt.Errorf("unexpected error: %w", err)
  138. }
  139. }
  // cancelMixin stores a "canceled" flag guarded by a mutex, so it can be set
// and queried from different goroutines. Its zero value is not canceled.
type cancelMixin struct {
	unsafeCanceled bool // only access while holding lock
	lock           sync.Mutex
}

// setCanceled marks the status as canceled; calling it repeatedly is harmless.
func (c *cancelMixin) setCanceled() {
	c.lock.Lock()
	c.unsafeCanceled = true
	c.lock.Unlock()
}

// isCanceled reports whether setCanceled has been called.
func (c *cancelMixin) isCanceled() bool {
	c.lock.Lock()
	canceled := c.unsafeCanceled
	c.lock.Unlock()
	return canceled
}
  155. type overlappedReader windows.Handle
  156. // Read performs an overlapping read fom a windows.Handle.
  157. func (r overlappedReader) Read(data []byte) (int, error) {
  158. hevent, err := windows.CreateEvent(nil, 0, 0, nil)
  159. if err != nil {
  160. return 0, fmt.Errorf("create event: %w", err)
  161. }
  162. overlapped := windows.Overlapped{HEvent: hevent}
  163. var n uint32
  164. err = windows.ReadFile(windows.Handle(r), data, &n, &overlapped)
  165. if err != nil && err != windows.ERROR_IO_PENDING {
  166. return int(n), err
  167. }
  168. err = windows.GetOverlappedResult(windows.Handle(r), &overlapped, &n, true)
  169. if err != nil {
  170. return int(n), nil
  171. }
  172. return int(n), nil
  173. }