cancelreader_linux.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. //go:build linux
  2. // +build linux
  3. package cancelreader
  4. import (
  5. "errors"
  6. "fmt"
  7. "io"
  8. "os"
  9. "strings"
  10. "golang.org/x/sys/unix"
  11. )
  12. // NewReader returns a reader and a cancel function. If the input reader is a
  13. // File, the cancel function can be used to interrupt a blocking read call.
  14. // In this case, the cancel function returns true if the call was canceled
  15. // successfully. If the input reader is not a File, the cancel function
  16. // does nothing and always returns false. The Linux implementation is based on
  17. // the epoll mechanism.
  18. func NewReader(reader io.Reader) (CancelReader, error) {
  19. file, ok := reader.(File)
  20. if !ok {
  21. return newFallbackCancelReader(reader)
  22. }
  23. epoll, err := unix.EpollCreate1(0)
  24. if err != nil {
  25. return nil, fmt.Errorf("create epoll: %w", err)
  26. }
  27. r := &epollCancelReader{
  28. file: file,
  29. epoll: epoll,
  30. }
  31. r.cancelSignalReader, r.cancelSignalWriter, err = os.Pipe()
  32. if err != nil {
  33. _ = unix.Close(epoll)
  34. return nil, err
  35. }
  36. err = unix.EpollCtl(epoll, unix.EPOLL_CTL_ADD, int(file.Fd()), &unix.EpollEvent{
  37. Events: unix.EPOLLIN,
  38. Fd: int32(file.Fd()),
  39. })
  40. if err != nil {
  41. _ = unix.Close(epoll)
  42. return nil, fmt.Errorf("add reader to epoll interest list")
  43. }
  44. err = unix.EpollCtl(epoll, unix.EPOLL_CTL_ADD, int(r.cancelSignalReader.Fd()), &unix.EpollEvent{
  45. Events: unix.EPOLLIN,
  46. Fd: int32(r.cancelSignalReader.Fd()),
  47. })
  48. if err != nil {
  49. _ = unix.Close(epoll)
  50. return nil, fmt.Errorf("add reader to epoll interest list")
  51. }
  52. return r, nil
  53. }
  54. type epollCancelReader struct {
  55. file File
  56. cancelSignalReader File
  57. cancelSignalWriter File
  58. cancelMixin
  59. epoll int
  60. }
  61. func (r *epollCancelReader) Read(data []byte) (int, error) {
  62. if r.isCanceled() {
  63. return 0, ErrCanceled
  64. }
  65. err := r.wait()
  66. if err != nil {
  67. if errors.Is(err, ErrCanceled) {
  68. // remove signal from pipe
  69. var b [1]byte
  70. _, readErr := r.cancelSignalReader.Read(b[:])
  71. if readErr != nil {
  72. return 0, fmt.Errorf("reading cancel signal: %w", readErr)
  73. }
  74. }
  75. return 0, err
  76. }
  77. return r.file.Read(data)
  78. }
  79. func (r *epollCancelReader) Cancel() bool {
  80. r.setCanceled()
  81. // send cancel signal
  82. _, err := r.cancelSignalWriter.Write([]byte{'c'})
  83. return err == nil
  84. }
  85. func (r *epollCancelReader) Close() error {
  86. var errMsgs []string
  87. // close kqueue
  88. err := unix.Close(r.epoll)
  89. if err != nil {
  90. errMsgs = append(errMsgs, fmt.Sprintf("closing epoll: %v", err))
  91. }
  92. // close pipe
  93. err = r.cancelSignalWriter.Close()
  94. if err != nil {
  95. errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal writer: %v", err))
  96. }
  97. err = r.cancelSignalReader.Close()
  98. if err != nil {
  99. errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal reader: %v", err))
  100. }
  101. if len(errMsgs) > 0 {
  102. return fmt.Errorf(strings.Join(errMsgs, ", "))
  103. }
  104. return nil
  105. }
  106. func (r *epollCancelReader) wait() error {
  107. events := make([]unix.EpollEvent, 1)
  108. for {
  109. _, err := unix.EpollWait(r.epoll, events, -1)
  110. if errors.Is(err, unix.EINTR) {
  111. continue // try again if the syscall was interrupted
  112. }
  113. if err != nil {
  114. return fmt.Errorf("kevent: %w", err)
  115. }
  116. break
  117. }
  118. switch events[0].Fd {
  119. case int32(r.file.Fd()):
  120. return nil
  121. case int32(r.cancelSignalReader.Fd()):
  122. return ErrCanceled
  123. }
  124. return fmt.Errorf("unknown error")
  125. }