dnshostprovider.go 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package zk
  2. import (
  3. "fmt"
  4. "net"
  5. "sync"
  6. )
  // DNSHostProvider is the default HostProvider. Its behavior mirrors the
  // Java client's StaticHostProvider: host names are resolved through DNS a
  // single time, inside Init. Re-resolving on a timer, or after connection
  // trouble, would be a natural extension.
  type DNSHostProvider struct {
  	// mu guards every field below. Locking uniformly now leaves room for
  	// asynchronous updates to be added later.
  	mu sync.Mutex

  	// servers is the shuffled list of host:port addresses built by Init.
  	servers []string
  	// curr is the index of the entry most recently handed out by Next
  	// (-1 immediately after Init).
  	curr int
  	// last is the index at which Next reports retryStart. Init sets it to
  	// -1, the first Next after that sets it to 0, and Connected sets it to
  	// curr.
  	last int

  	// lookupHost, when non-nil, replaces net.LookupHost (a test hook).
  	lookupHost func(string) ([]string, error)
  }
  18. // Init is called first, with the servers specified in the connection
  19. // string. It uses DNS to look up addresses for each server, then
  20. // shuffles them all together.
  21. func (hp *DNSHostProvider) Init(servers []string) error {
  22. hp.mu.Lock()
  23. defer hp.mu.Unlock()
  24. lookupHost := hp.lookupHost
  25. if lookupHost == nil {
  26. lookupHost = net.LookupHost
  27. }
  28. found := []string{}
  29. for _, server := range servers {
  30. host, port, err := net.SplitHostPort(server)
  31. if err != nil {
  32. return err
  33. }
  34. addrs, err := lookupHost(host)
  35. if err != nil {
  36. return err
  37. }
  38. for _, addr := range addrs {
  39. found = append(found, net.JoinHostPort(addr, port))
  40. }
  41. }
  42. if len(found) == 0 {
  43. return fmt.Errorf("No hosts found for addresses %q", servers)
  44. }
  45. // Randomize the order of the servers to avoid creating hotspots
  46. stringShuffle(found)
  47. hp.servers = found
  48. hp.curr = -1
  49. hp.last = -1
  50. return nil
  51. }
  52. // Len returns the number of servers available
  53. func (hp *DNSHostProvider) Len() int {
  54. hp.mu.Lock()
  55. defer hp.mu.Unlock()
  56. return len(hp.servers)
  57. }
  58. // Next returns the next server to connect to. retryStart will be true
  59. // if we've looped through all known servers without Connected() being
  60. // called.
  61. func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
  62. hp.mu.Lock()
  63. defer hp.mu.Unlock()
  64. hp.curr = (hp.curr + 1) % len(hp.servers)
  65. retryStart = hp.curr == hp.last
  66. if hp.last == -1 {
  67. hp.last = 0
  68. }
  69. return hp.servers[hp.curr], retryStart
  70. }
  71. // Connected notifies the HostProvider of a successful connection.
  72. func (hp *DNSHostProvider) Connected() {
  73. hp.mu.Lock()
  74. defer hp.mu.Unlock()
  75. hp.last = hp.curr
  76. }