tcpdialer.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. package fasthttp
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "strconv"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. )
  12. // Dial dials the given TCP addr using tcp4.
  13. //
  14. // This function has the following additional features comparing to net.Dial:
  15. //
  16. // - It reduces load on DNS resolver by caching resolved TCP addressed
  17. // for DNSCacheDuration.
  18. // - It dials all the resolved TCP addresses in round-robin manner until
  19. // connection is established. This may be useful if certain addresses
  20. // are temporarily unreachable.
  21. // - It returns ErrDialTimeout if connection cannot be established during
  22. // DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
  23. //
  24. // This dialer is intended for custom code wrapping before passing
  25. // to Client.Dial or HostClient.Dial.
  26. //
  27. // For instance, per-host counters and/or limits may be implemented
  28. // by such wrappers.
  29. //
  30. // The addr passed to the function must contain port. Example addr values:
  31. //
  32. // - foobar.baz:443
  33. // - foo.bar:80
  34. // - aaa.com:8080
  35. func Dial(addr string) (net.Conn, error) {
  36. return defaultDialer.Dial(addr)
  37. }
  38. // DialTimeout dials the given TCP addr using tcp4 using the given timeout.
  39. //
  40. // This function has the following additional features comparing to net.Dial:
  41. //
  42. // - It reduces load on DNS resolver by caching resolved TCP addressed
  43. // for DNSCacheDuration.
  44. // - It dials all the resolved TCP addresses in round-robin manner until
  45. // connection is established. This may be useful if certain addresses
  46. // are temporarily unreachable.
  47. //
  48. // This dialer is intended for custom code wrapping before passing
  49. // to Client.DialTimeout or HostClient.DialTimeout.
  50. //
  51. // For instance, per-host counters and/or limits may be implemented
  52. // by such wrappers.
  53. //
  54. // The addr passed to the function must contain port. Example addr values:
  55. //
  56. // - foobar.baz:443
  57. // - foo.bar:80
  58. // - aaa.com:8080
  59. func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  60. return defaultDialer.DialTimeout(addr, timeout)
  61. }
  62. // DialDualStack dials the given TCP addr using both tcp4 and tcp6.
  63. //
  64. // This function has the following additional features comparing to net.Dial:
  65. //
  66. // - It reduces load on DNS resolver by caching resolved TCP addressed
  67. // for DNSCacheDuration.
  68. // - It dials all the resolved TCP addresses in round-robin manner until
  69. // connection is established. This may be useful if certain addresses
  70. // are temporarily unreachable.
  71. // - It returns ErrDialTimeout if connection cannot be established during
  72. // DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
  73. // timeout.
  74. //
  75. // This dialer is intended for custom code wrapping before passing
  76. // to Client.Dial or HostClient.Dial.
  77. //
  78. // For instance, per-host counters and/or limits may be implemented
  79. // by such wrappers.
  80. //
  81. // The addr passed to the function must contain port. Example addr values:
  82. //
  83. // - foobar.baz:443
  84. // - foo.bar:80
  85. // - aaa.com:8080
  86. func DialDualStack(addr string) (net.Conn, error) {
  87. return defaultDialer.DialDualStack(addr)
  88. }
  89. // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
  90. // using the given timeout.
  91. //
  92. // This function has the following additional features comparing to net.Dial:
  93. //
  94. // - It reduces load on DNS resolver by caching resolved TCP addressed
  95. // for DNSCacheDuration.
  96. // - It dials all the resolved TCP addresses in round-robin manner until
  97. // connection is established. This may be useful if certain addresses
  98. // are temporarily unreachable.
  99. //
  100. // This dialer is intended for custom code wrapping before passing
  101. // to Client.DialTimeout or HostClient.DialTimeout.
  102. //
  103. // For instance, per-host counters and/or limits may be implemented
  104. // by such wrappers.
  105. //
  106. // The addr passed to the function must contain port. Example addr values:
  107. //
  108. // - foobar.baz:443
  109. // - foo.bar:80
  110. // - aaa.com:8080
  111. func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  112. return defaultDialer.DialDualStackTimeout(addr, timeout)
  113. }
  114. var defaultDialer = &TCPDialer{Concurrency: 1000}
  115. // Resolver represents interface of the tcp resolver.
  116. type Resolver interface {
  117. LookupIPAddr(context.Context, string) (names []net.IPAddr, err error)
  118. }
// TCPDialer contains options to control a group of Dial calls.
//
// Initialization is lazy. On the first Dial* call the dialer:
//
//   - creates the concurrency limiter from Concurrency;
//   - substitutes DefaultDNSCacheDuration for a zero DNSCacheDuration;
//   - starts a background goroutine that evicts stale DNS cache entries,
//     unless DisableDNSResolution is set. That goroutine never exits.
//
// NOTE(review): the field order looks tuned for struct alignment; keep it
// when adding fields.
type TCPDialer struct {
	// Resolver, if non-nil, replaces net.DefaultResolver for host name
	// lookups. It may be used to override the DNS resolving policy, for
	// example:
	//
	//	var dialer = &fasthttp.TCPDialer{
	//		Resolver: &net.Resolver{
	//			PreferGo:     true,
	//			StrictErrors: false,
	//			Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
	//				d := net.Dialer{}
	//				return d.DialContext(ctx, "udp", "8.8.8.8:53")
	//			},
	//		},
	//	}
	Resolver Resolver

	// LocalAddr is the local address to use when dialing an
	// address.
	// If nil, a local address is automatically chosen.
	LocalAddr *net.TCPAddr

	// concurrencyCh is a counting semaphore bounding concurrent dials. It is
	// created on the first dial when Concurrency > 0; nil means unlimited.
	concurrencyCh chan struct{}

	// tcpAddrsMap caches DNS resolutions as *tcpAddrEntry values.
	tcpAddrsMap sync.Map

	// Concurrency controls the maximum number of concurrent Dials
	// that can be performed using this object.
	// Setting this to 0 means unlimited.
	//
	// WARNING: This can only be changed before the first Dial.
	// Changes made after the first Dial will not affect anything.
	Concurrency int

	// DNSCacheDuration is how long resolved addresses are reused before
	// being resolved again. Zero means DefaultDNSCacheDuration. The
	// background cleaner evicts entries resolved more than twice this
	// long ago.
	DNSCacheDuration time.Duration

	// once guards the lazy initialization done by the first dial.
	once sync.Once

	// DisableDNSResolution, when true, skips this dialer's own resolution
	// and caching. addr is passed unchanged to net.Dialer, and no
	// cache-cleaning goroutine is started.
	DisableDNSResolution bool
}
  152. // Dial dials the given TCP addr using tcp4.
  153. //
  154. // This function has the following additional features comparing to net.Dial:
  155. //
  156. // - It reduces load on DNS resolver by caching resolved TCP addressed
  157. // for DNSCacheDuration.
  158. // - It dials all the resolved TCP addresses in round-robin manner until
  159. // connection is established. This may be useful if certain addresses
  160. // are temporarily unreachable.
  161. // - It returns ErrDialTimeout if connection cannot be established during
  162. // DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
  163. //
  164. // This dialer is intended for custom code wrapping before passing
  165. // to Client.Dial or HostClient.Dial.
  166. //
  167. // For instance, per-host counters and/or limits may be implemented
  168. // by such wrappers.
  169. //
  170. // The addr passed to the function must contain port. Example addr values:
  171. //
  172. // - foobar.baz:443
  173. // - foo.bar:80
  174. // - aaa.com:8080
  175. func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
  176. return d.dial(addr, false, DefaultDialTimeout)
  177. }
  178. // DialTimeout dials the given TCP addr using tcp4 using the given timeout.
  179. //
  180. // This function has the following additional features comparing to net.Dial:
  181. //
  182. // - It reduces load on DNS resolver by caching resolved TCP addressed
  183. // for DNSCacheDuration.
  184. // - It dials all the resolved TCP addresses in round-robin manner until
  185. // connection is established. This may be useful if certain addresses
  186. // are temporarily unreachable.
  187. //
  188. // This dialer is intended for custom code wrapping before passing
  189. // to Client.DialTimeout or HostClient.DialTimeout.
  190. //
  191. // For instance, per-host counters and/or limits may be implemented
  192. // by such wrappers.
  193. //
  194. // The addr passed to the function must contain port. Example addr values:
  195. //
  196. // - foobar.baz:443
  197. // - foo.bar:80
  198. // - aaa.com:8080
  199. func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  200. return d.dial(addr, false, timeout)
  201. }
  202. // DialDualStack dials the given TCP addr using both tcp4 and tcp6.
  203. //
  204. // This function has the following additional features comparing to net.Dial:
  205. //
  206. // - It reduces load on DNS resolver by caching resolved TCP addressed
  207. // for DNSCacheDuration.
  208. // - It dials all the resolved TCP addresses in round-robin manner until
  209. // connection is established. This may be useful if certain addresses
  210. // are temporarily unreachable.
  211. // - It returns ErrDialTimeout if connection cannot be established during
  212. // DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
  213. // timeout.
  214. //
  215. // This dialer is intended for custom code wrapping before passing
  216. // to Client.Dial or HostClient.Dial.
  217. //
  218. // For instance, per-host counters and/or limits may be implemented
  219. // by such wrappers.
  220. //
  221. // The addr passed to the function must contain port. Example addr values:
  222. //
  223. // - foobar.baz:443
  224. // - foo.bar:80
  225. // - aaa.com:8080
  226. func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
  227. return d.dial(addr, true, DefaultDialTimeout)
  228. }
  229. // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
  230. // using the given timeout.
  231. //
  232. // This function has the following additional features comparing to net.Dial:
  233. //
  234. // - It reduces load on DNS resolver by caching resolved TCP addressed
  235. // for DNSCacheDuration.
  236. // - It dials all the resolved TCP addresses in round-robin manner until
  237. // connection is established. This may be useful if certain addresses
  238. // are temporarily unreachable.
  239. //
  240. // This dialer is intended for custom code wrapping before passing
  241. // to Client.DialTimeout or HostClient.DialTimeout.
  242. //
  243. // For instance, per-host counters and/or limits may be implemented
  244. // by such wrappers.
  245. //
  246. // The addr passed to the function must contain port. Example addr values:
  247. //
  248. // - foobar.baz:443
  249. // - foo.bar:80
  250. // - aaa.com:8080
  251. func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  252. return d.dial(addr, true, timeout)
  253. }
  254. func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) {
  255. d.once.Do(func() {
  256. if d.Concurrency > 0 {
  257. d.concurrencyCh = make(chan struct{}, d.Concurrency)
  258. }
  259. if d.DNSCacheDuration == 0 {
  260. d.DNSCacheDuration = DefaultDNSCacheDuration
  261. }
  262. if !d.DisableDNSResolution {
  263. go d.tcpAddrsClean()
  264. }
  265. })
  266. deadline := time.Now().Add(timeout)
  267. network := "tcp4"
  268. if dualStack {
  269. network = "tcp"
  270. }
  271. if d.DisableDNSResolution {
  272. return d.tryDial(network, addr, deadline, d.concurrencyCh)
  273. }
  274. addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline)
  275. if err != nil {
  276. return nil, err
  277. }
  278. var conn net.Conn
  279. n := uint32(len(addrs)) // #nosec G115
  280. for n > 0 {
  281. conn, err = d.tryDial(network, addrs[idx%n].String(), deadline, d.concurrencyCh)
  282. if err == nil {
  283. return conn, nil
  284. }
  285. if errors.Is(err, ErrDialTimeout) {
  286. return nil, err
  287. }
  288. idx++
  289. n--
  290. }
  291. return nil, err
  292. }
  293. func (d *TCPDialer) tryDial(
  294. network string, addr string, deadline time.Time, concurrencyCh chan struct{},
  295. ) (net.Conn, error) {
  296. timeout := time.Until(deadline)
  297. if timeout <= 0 {
  298. return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
  299. }
  300. if concurrencyCh != nil {
  301. select {
  302. case concurrencyCh <- struct{}{}:
  303. default:
  304. tc := AcquireTimer(timeout)
  305. isTimeout := false
  306. select {
  307. case concurrencyCh <- struct{}{}:
  308. case <-tc.C:
  309. isTimeout = true
  310. }
  311. ReleaseTimer(tc)
  312. if isTimeout {
  313. return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
  314. }
  315. }
  316. defer func() { <-concurrencyCh }()
  317. }
  318. dialer := net.Dialer{}
  319. if d.LocalAddr != nil {
  320. dialer.LocalAddr = d.LocalAddr
  321. }
  322. ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)
  323. defer cancelCtx()
  324. conn, err := dialer.DialContext(ctx, network, addr)
  325. if err != nil {
  326. if ctx.Err() == context.DeadlineExceeded {
  327. return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
  328. }
  329. return nil, wrapDialWithUpstream(err, addr)
  330. }
  331. return conn, nil
  332. }
  333. // ErrDialTimeout is returned when TCP dialing is timed out.
  334. var ErrDialTimeout = errors.New("dialing to the given TCP address timed out")
  335. // ErrDialWithUpstream wraps dial error with upstream info.
  336. //
  337. // Should use errors.As to get upstream information from error:
  338. //
  339. // hc := fasthttp.HostClient{Addr: "foo.com,bar.com"}
  340. // err := hc.Do(req, res)
  341. //
  342. // var dialErr *fasthttp.ErrDialWithUpstream
  343. // if errors.As(err, &dialErr) {
  344. // upstream = dialErr.Upstream // 34.206.39.153:80
  345. // }
  346. type ErrDialWithUpstream struct {
  347. wrapErr error
  348. Upstream string
  349. }
  350. func (e *ErrDialWithUpstream) Error() string {
  351. return fmt.Sprintf("error when dialing %s: %s", e.Upstream, e.wrapErr.Error())
  352. }
  353. func (e *ErrDialWithUpstream) Unwrap() error {
  354. return e.wrapErr
  355. }
  356. func wrapDialWithUpstream(err error, upstream string) error {
  357. return &ErrDialWithUpstream{
  358. Upstream: upstream,
  359. wrapErr: err,
  360. }
  361. }
  362. // DefaultDialTimeout is timeout used by Dial and DialDualStack
  363. // for establishing TCP connections.
  364. const DefaultDialTimeout = 3 * time.Second
// tcpAddrEntry is one cached DNS resolution stored in TCPDialer.tcpAddrsMap.
//
// NOTE(review): the field order looks tuned for struct alignment; keep it
// when editing.
type tcpAddrEntry struct {
	// resolveTime is when addrs were resolved. It drives both re-resolution
	// (after DNSCacheDuration) and eviction (after 2*DNSCacheDuration).
	resolveTime time.Time
	// addrs are the resolved TCP addresses. They are only cached after a
	// successful resolution, which never yields an empty list.
	addrs []net.TCPAddr
	// addrsIdx is a round-robin counter, incremented atomically on every
	// lookup to choose the first address to dial.
	addrsIdx uint32
	// pending is atomically swapped to 1 by the single goroutine that
	// re-resolves a stale entry, so other callers keep using the stale
	// addrs. It is reset to 0 if that resolution fails.
	pending int32
}
  371. // DefaultDNSCacheDuration is the duration for caching resolved TCP addresses
  372. // by Dial* functions.
  373. const DefaultDNSCacheDuration = time.Minute
  374. func (d *TCPDialer) tcpAddrsClean() {
  375. expireDuration := 2 * d.DNSCacheDuration
  376. for {
  377. time.Sleep(time.Second)
  378. t := time.Now()
  379. d.tcpAddrsMap.Range(func(k, v any) bool {
  380. if e, ok := v.(*tcpAddrEntry); ok && t.Sub(e.resolveTime) > expireDuration {
  381. d.tcpAddrsMap.Delete(k)
  382. }
  383. return true
  384. })
  385. }
  386. }
  387. func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool, deadline time.Time) ([]net.TCPAddr, uint32, error) {
  388. item, exist := d.tcpAddrsMap.Load(addr)
  389. e, ok := item.(*tcpAddrEntry)
  390. if exist && ok && e != nil && time.Since(e.resolveTime) > d.DNSCacheDuration {
  391. // Only let one goroutine re-resolve at a time.
  392. if atomic.SwapInt32(&e.pending, 1) == 0 {
  393. e = nil
  394. }
  395. }
  396. if e == nil {
  397. addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver, deadline)
  398. if err != nil {
  399. item, exist := d.tcpAddrsMap.Load(addr)
  400. e, ok = item.(*tcpAddrEntry)
  401. if exist && ok && e != nil {
  402. // Set pending to 0 so another goroutine can retry.
  403. atomic.StoreInt32(&e.pending, 0)
  404. }
  405. return nil, 0, err
  406. }
  407. e = &tcpAddrEntry{
  408. addrs: addrs,
  409. resolveTime: time.Now(),
  410. }
  411. d.tcpAddrsMap.Store(addr, e)
  412. }
  413. idx := atomic.AddUint32(&e.addrsIdx, 1)
  414. return e.addrs, idx, nil
  415. }
  416. func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver, deadline time.Time) ([]net.TCPAddr, error) {
  417. host, portS, err := net.SplitHostPort(addr)
  418. if err != nil {
  419. return nil, err
  420. }
  421. port, err := strconv.Atoi(portS)
  422. if err != nil {
  423. return nil, err
  424. }
  425. if resolver == nil {
  426. resolver = net.DefaultResolver
  427. }
  428. ctx, cancel := context.WithDeadline(context.Background(), deadline)
  429. defer cancel()
  430. ipaddrs, err := resolver.LookupIPAddr(ctx, host)
  431. if err != nil {
  432. return nil, err
  433. }
  434. n := len(ipaddrs)
  435. addrs := make([]net.TCPAddr, 0, n)
  436. for i := 0; i < n; i++ {
  437. ip := ipaddrs[i]
  438. if !dualStack && ip.IP.To4() == nil {
  439. continue
  440. }
  441. addrs = append(addrs, net.TCPAddr{
  442. IP: ip.IP,
  443. Port: port,
  444. Zone: ip.Zone,
  445. })
  446. }
  447. if len(addrs) == 0 {
  448. return nil, errNoDNSEntries
  449. }
  450. return addrs, nil
  451. }
  452. var errNoDNSEntries = errors.New("couldn't find DNS entries for the given domain. Try using DialDualStack")