mocks_for_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package connection
  2. import (
  3. "fmt"
  4. "math"
  5. "math/rand"
  6. "net"
  7. "reflect"
  8. "testing/quick"
  9. )
  10. type mockAddrs struct {
  11. // a set of synthetic SRV records
  12. addrMap map[net.SRV][]*net.TCPAddr
  13. // the total number of addresses, aggregated across addrMap.
  14. // For the convenience of test code that would otherwise have to compute
  15. // this by hand every time.
  16. numAddrs int
  17. }
  18. func newMockAddrs(port uint16, numRegions uint8, numAddrsPerRegion uint8) mockAddrs {
  19. addrMap := make(map[net.SRV][]*net.TCPAddr)
  20. numAddrs := 0
  21. for r := uint8(0); r < numRegions; r++ {
  22. var (
  23. srv = net.SRV{Target: fmt.Sprintf("test-region-%v.example.com", r), Port: port}
  24. addrs []*net.TCPAddr
  25. )
  26. for a := uint8(0); a < numAddrsPerRegion; a++ {
  27. addrs = append(addrs, &net.TCPAddr{
  28. IP: net.ParseIP(fmt.Sprintf("10.0.%v.%v", r, a)),
  29. Port: int(port),
  30. })
  31. }
  32. addrMap[srv] = addrs
  33. numAddrs += len(addrs)
  34. }
  35. return mockAddrs{addrMap: addrMap, numAddrs: numAddrs}
  36. }
  37. var _ quick.Generator = mockAddrs{}
  38. func (mockAddrs) Generate(rand *rand.Rand, size int) reflect.Value {
  39. port := uint16(rand.Intn(math.MaxUint16))
  40. numRegions := uint8(1 + rand.Intn(10))
  41. numAddrsPerRegion := uint8(1 + rand.Intn(32))
  42. result := newMockAddrs(port, numRegions, numAddrsPerRegion)
  43. return reflect.ValueOf(result)
  44. }
  45. // Returns a function compatible with net.LookupSRV that will return the SRV
  46. // records from mockAddrs.
  47. func mockNetLookupSRV(
  48. m mockAddrs,
  49. ) func(service, proto, name string) (cname string, addrs []*net.SRV, err error) {
  50. var addrs []*net.SRV
  51. for k := range m.addrMap {
  52. addr := k
  53. addrs = append(addrs, &addr)
  54. // We can't just do
  55. // addrs = append(addrs, &k)
  56. // `k` will be reused by subsequent loop iterations,
  57. // so all the copies of `&k` would point to the same location.
  58. }
  59. return func(_, _, _ string) (string, []*net.SRV, error) {
  60. return "", addrs, nil
  61. }
  62. }
  63. // Returns a function compatible with net.LookupIP that translates the SRV records
  64. // from mockAddrs into IP addresses, based on the TCP addresses in mockAddrs.
  65. func mockNetLookupIP(
  66. m mockAddrs,
  67. ) func(host string) ([]net.IP, error) {
  68. return func(host string) ([]net.IP, error) {
  69. for srv, tcpAddrs := range m.addrMap {
  70. if srv.Target != host {
  71. continue
  72. }
  73. result := make([]net.IP, len(tcpAddrs))
  74. for i, tcpAddr := range tcpAddrs {
  75. result[i] = tcpAddr.IP
  76. }
  77. return result, nil
  78. }
  79. return nil, fmt.Errorf("No IPs for %v", host)
  80. }
  81. }
  82. type mockEdgeServiceDiscoverer struct {
  83. }
  84. func (mr *mockEdgeServiceDiscoverer) Addr() (*net.TCPAddr, error) {
  85. return &net.TCPAddr{
  86. IP: net.ParseIP("127.0.0.1"),
  87. Port: 63102,
  88. }, nil
  89. }
  90. func (mr *mockEdgeServiceDiscoverer) AnyAddr() (*net.TCPAddr, error) {
  91. return &net.TCPAddr{
  92. IP: net.ParseIP("127.0.0.1"),
  93. Port: 63102,
  94. }, nil
  95. }
  96. func (mr *mockEdgeServiceDiscoverer) ReplaceAddr(addr *net.TCPAddr) {}
  97. func (mr *mockEdgeServiceDiscoverer) MarkAddrBad(addr *net.TCPAddr) {}
  98. func (mr *mockEdgeServiceDiscoverer) AvailableAddrs() int {
  99. return 1
  100. }
  101. func (mr *mockEdgeServiceDiscoverer) Refresh() error {
  102. return nil
  103. }