whitelist_test.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package whitelist
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "testing"
  11. "time"
  12. )
  // StringLookup is a test lookup that turns one textual IP address, such as
// "127.0.0.1" or "::1", into a net.IP.
type StringLookup struct{}

// Address expects exactly one string argument and parses it with
// net.ParseIP. It fails when the argument count is not one, when the argument
// is not a string, or when the string is not a parseable IP address.
func (StringLookup) Address(args ...interface{}) (net.IP, error) {
	if len(args) != 1 {
		return nil, errors.New("whitelist: lookup requires a string")
	}

	text, isString := args[0].(string)
	if !isString {
		return nil, errors.New("whitelist: lookup requires a string")
	}

	if parsed := net.ParseIP(text); parsed != nil {
		return parsed, nil
	}
	return nil, errors.New("whitelist: no address found")
}

// slu is the shared StringLookup used by the string-based test helpers.
var slu = StringLookup{}
  32. func checkIPString(wl ACL, addr string, t *testing.T) bool {
  33. ip, err := slu.Address(addr)
  34. if err != nil {
  35. t.Fatalf("%v", err)
  36. }
  37. return wl.Permitted(ip)
  38. }
  39. func addIPString(wl HostACL, addr string, t *testing.T) {
  40. ip, err := slu.Address(addr)
  41. if err != nil {
  42. t.Fatalf("%v", err)
  43. }
  44. wl.Add(ip)
  45. }
  46. func delIPString(wl HostACL, addr string, t *testing.T) {
  47. ip, err := slu.Address(addr)
  48. if err != nil {
  49. t.Fatalf("%v", err)
  50. }
  51. wl.Remove(ip)
  52. }
  53. func TestBasicWhitelist(t *testing.T) {
  54. wl := NewBasic()
  55. if checkIPString(wl, "127.0.0.1", t) {
  56. t.Fatal("whitelist should have denied address")
  57. }
  58. addIPString(wl, "127.0.0.1", t)
  59. if !checkIPString(wl, "127.0.0.1", t) {
  60. t.Fatal("whitelist should have permitted address")
  61. }
  62. delIPString(wl, "127.0.0.1", t)
  63. if checkIPString(wl, "127.0.0.1", t) {
  64. t.Fatal("whitelist should have denied address")
  65. }
  66. addIPString(wl, "::1", t)
  67. if checkIPString(wl, "127.0.0.1", t) {
  68. t.Fatal("whitelist should have denied address")
  69. }
  70. wl.Add(nil)
  71. wl.Remove(nil)
  72. wl.Permitted(nil)
  73. }
  74. func TestStubWhitelist(t *testing.T) {
  75. wl := NewHostStub()
  76. if !checkIPString(wl, "127.0.0.1", t) {
  77. t.Fatal("whitelist should have permitted address")
  78. }
  79. addIPString(wl, "127.0.0.1", t)
  80. if !checkIPString(wl, "127.0.0.1", t) {
  81. t.Fatal("whitelist should have permitted address")
  82. }
  83. delIPString(wl, "127.0.0.1", t)
  84. if !checkIPString(wl, "127.0.0.1", t) {
  85. t.Fatal("whitelist should have permitted address")
  86. }
  87. }
  88. func TestMarshalHost(t *testing.T) {
  89. tv := map[string]*Basic{
  90. "test-a": NewBasic(),
  91. "test-b": NewBasic(),
  92. }
  93. ip := net.ParseIP("192.168.3.1")
  94. tv["test-a"].Add(ip)
  95. ip = net.ParseIP("192.168.3.2")
  96. tv["test-a"].Add(ip)
  97. if len(tv["test-a"].whitelist) != 2 {
  98. t.Fatalf("Expected whitelist to have 2 addresses, but have %d", len(tv["test-a"].whitelist))
  99. }
  100. out, err := json.Marshal(tv)
  101. if err != nil {
  102. t.Fatalf("%v", err)
  103. }
  104. var tvPrime map[string]*Basic
  105. err = json.Unmarshal(out, &tvPrime)
  106. if err != nil {
  107. t.Fatalf("%v", err)
  108. }
  109. }
  110. func TestMarshalHostFail(t *testing.T) {
  111. wl := NewBasic()
  112. badInput := `192.168.3.1/24,127.0.0.1/32`
  113. if err := wl.UnmarshalJSON([]byte(badInput)); err == nil {
  114. t.Fatal("Expected failure unmarshaling bad JSON input.")
  115. }
  116. badInput = `"192.168.3.1/32,127.0.0.252/32"`
  117. if err := wl.UnmarshalJSON([]byte(badInput)); err == nil {
  118. t.Fatal("Expected failure unmarshaling bad JSON input.")
  119. }
  120. }
  // shutdown is closed by TestNetConn to tell setupTestServer to stop serving.
// proceed is signalled by setupTestServer once its listener is bound, so the
// test does not dial before the server is ready.
var (
	shutdown = make(chan struct{}, 1)
	proceed  = make(chan struct{})
)
  123. func setupTestServer(t *testing.T, wl ACL) {
  124. ln, err := net.Listen("tcp", "127.0.0.1:4141")
  125. if err != nil {
  126. log.Fatalf("%v", err)
  127. }
  128. proceed <- struct{}{}
  129. for {
  130. select {
  131. case _, ok := <-shutdown:
  132. if !ok {
  133. return
  134. }
  135. default:
  136. conn, err := ln.Accept()
  137. if err != nil {
  138. log.Fatalf("%v", err)
  139. }
  140. go handleTestConnection(conn, wl, t)
  141. }
  142. }
  143. }
  144. func handleTestConnection(conn net.Conn, wl ACL, t *testing.T) {
  145. defer conn.Close()
  146. ip, err := NetConnLookup(conn)
  147. if err != nil {
  148. log.Fatalf("%v", err)
  149. }
  150. if wl.Permitted(ip) {
  151. conn.Write([]byte("OK"))
  152. } else {
  153. conn.Write([]byte("NO"))
  154. }
  155. }
  156. func TestNetConn(t *testing.T) {
  157. wl := NewBasic()
  158. defer close(shutdown)
  159. go setupTestServer(t, wl)
  160. <-proceed
  161. conn, err := net.Dial("tcp", "127.0.0.1:4141")
  162. if err != nil {
  163. t.Fatalf("%v", err)
  164. }
  165. body, err := io.ReadAll(conn)
  166. if err != nil {
  167. t.Fatalf("%v", err)
  168. }
  169. if string(body) != "NO" {
  170. t.Fatalf("Expected NO, but received %s", body)
  171. }
  172. conn.Close()
  173. addIPString(wl, "127.0.0.1", t)
  174. conn, err = net.Dial("tcp", "127.0.0.1:4141")
  175. if err != nil {
  176. t.Fatalf("%v", err)
  177. }
  178. body, err = io.ReadAll(conn)
  179. if err != nil {
  180. t.Fatalf("%v", err)
  181. }
  182. if string(body) != "OK" {
  183. t.Fatalf("Expected OK, but received %s", body)
  184. }
  185. conn.Close()
  186. }
  187. func TestBasicDumpLoad(t *testing.T) {
  188. wl := NewBasic()
  189. addIPString(wl, "127.0.0.1", t)
  190. addIPString(wl, "10.0.1.15", t)
  191. addIPString(wl, "192.168.1.5", t)
  192. out := DumpBasic(wl)
  193. loaded, err := LoadBasic(out)
  194. if err != nil {
  195. t.Fatalf("%v", err)
  196. }
  197. dumped := DumpBasic(loaded)
  198. if !bytes.Equal(out, dumped) {
  199. t.Fatalf("dump -> load failed")
  200. }
  201. }
  202. func TestBasicFailedLoad(t *testing.T) {
  203. dump := []byte("192.168.1.5\n192.168.2.3\n192.168.2\n192.168.3.1")
  204. if _, err := LoadBasic(dump); err == nil {
  205. t.Fatalf("LoadBasic should fail on invalid IP address")
  206. }
  207. }
  208. func TestNetConnChecks(t *testing.T) {
  209. if _, err := NetConnLookup(nil); err == nil {
  210. t.Fatal("Address should fail with an invalid argument")
  211. }
  212. }
  213. func TestHTTPRequestLookup(t *testing.T) {
  214. if _, err := HTTPRequestLookup(nil); err == nil {
  215. t.Fatal("Address should fail with an invalid argument")
  216. }
  217. req := new(http.Request)
  218. if _, err := HTTPRequestLookup(req); err == nil {
  219. t.Fatal("Address should fail with an invalid argument")
  220. }
  221. }
  // stubConn is a do-nothing net.Conn whose RemoteAddr can be made to return
// either nil or an address with an empty IP. Both are useful for exercising
// NetConnLookup's failure paths.
type stubConn struct {
	Fake   bool // when true, RemoteAddr returns an *net.IPAddr with an empty IP
	Global bool // NOTE(review): not read anywhere in this file
}

// Compile-time check that *stubConn satisfies net.Conn.
var _ net.Conn = (*stubConn)(nil)

// Read reads nothing and never fails.
func (c *stubConn) Read([]byte) (int, error) { return 0, nil }

// Write writes nothing and never fails.
func (c *stubConn) Write([]byte) (int, error) { return 0, nil }

// Close is a no-op.
func (c *stubConn) Close() error { return nil }

// LocalAddr always returns nil.
func (c *stubConn) LocalAddr() net.Addr { return nil }

// RemoteAddr returns nil unless Fake is set, in which case it returns an
// IPAddr carrying an empty IP.
func (c *stubConn) RemoteAddr() net.Addr {
	if c.Fake {
		return &net.IPAddr{IP: net.IP{}}
	}
	return nil
}

// SetDeadline is a no-op.
func (c *stubConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (c *stubConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (c *stubConn) SetWriteDeadline(time.Time) error { return nil }
  255. func TestStubConn(t *testing.T) {
  256. var conn = new(stubConn)
  257. _, err := NetConnLookup(conn)
  258. if err == nil {
  259. t.Fatal("Address should fail to return an address")
  260. }
  261. conn.Fake = true
  262. _, err = NetConnLookup(conn)
  263. if err == nil {
  264. t.Fatal("Address should fail to return an address")
  265. }
  266. }
  267. func TestValidIP(t *testing.T) {
  268. ip4 := []byte{127, 0, 0, 1}
  269. ip6 := make([]byte, 16)
  270. ip6[15] = 1
  271. if !validIP(ip4) || !validIP(ip6) {
  272. t.Fatal("Failed to validate an IPv4 or an IPv6 address")
  273. }
  274. }