http_test.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. package whitelist
  2. import (
  3. "io"
  4. "log"
  5. "net/http"
  6. "net/http/httptest"
  7. "strings"
  8. "sync"
  9. "testing"
  10. )
  11. type testHandler struct {
  12. Message string
  13. }
  14. func newTestHandler(m string) http.Handler {
  15. return &testHandler{Message: m}
  16. }
  17. func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  18. w.Write([]byte(h.Message))
  19. }
  20. var testAllowHandler = newTestHandler("OK")
  21. var testDenyHandler = newTestHandler("NO")
  22. func testHTTPResponse(url string, t *testing.T) string {
  23. resp, err := http.Get(url)
  24. if err != nil {
  25. t.Fatalf("%v", err)
  26. }
  27. body, err := io.ReadAll(resp.Body)
  28. if err != nil {
  29. t.Fatalf("%v", err)
  30. }
  31. resp.Body.Close()
  32. return string(body)
  33. }
  34. func testWorker(url string, t *testing.T, wg *sync.WaitGroup) {
  35. for i := 0; i < 100; i++ {
  36. response := testHTTPResponse(url, t)
  37. if response != "NO" {
  38. log.Fatalf("Expected NO, but got %s", response)
  39. }
  40. }
  41. wg.Done()
  42. }
  43. func TestHostStubHTTP(t *testing.T) {
  44. wl := NewHostStub()
  45. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  46. if err != nil {
  47. t.Fatalf("%v", err)
  48. }
  49. srv := httptest.NewServer(h)
  50. defer srv.Close()
  51. response := testHTTPResponse(srv.URL, t)
  52. if response != "OK" {
  53. t.Fatalf("Expected OK, but got %s", response)
  54. }
  55. addIPString(wl, "127.0.0.1", t)
  56. response = testHTTPResponse(srv.URL, t)
  57. if response != "OK" {
  58. t.Fatalf("Expected OK, but got %s", response)
  59. }
  60. delIPString(wl, "127.0.0.1", t)
  61. response = testHTTPResponse(srv.URL, t)
  62. if response != "OK" {
  63. t.Fatalf("Expected OK, but got %s", response)
  64. }
  65. }
  66. func TestNetStubHTTP(t *testing.T) {
  67. wl := NewNetStub()
  68. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  69. if err != nil {
  70. t.Fatalf("%v", err)
  71. }
  72. srv := httptest.NewServer(h)
  73. defer srv.Close()
  74. response := testHTTPResponse(srv.URL, t)
  75. if response != "OK" {
  76. t.Fatalf("Expected OK, but got %s", response)
  77. }
  78. testAddNet(wl, "127.0.0.1/32", t)
  79. response = testHTTPResponse(srv.URL, t)
  80. if response != "OK" {
  81. t.Fatalf("Expected OK, but got %s", response)
  82. }
  83. testDelNet(wl, "127.0.0.1/32", t)
  84. response = testHTTPResponse(srv.URL, t)
  85. if response != "OK" {
  86. t.Fatalf("Expected OK, but got %s", response)
  87. }
  88. }
  89. func TestBasicHTTP(t *testing.T) {
  90. wl := NewBasic()
  91. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  92. if err != nil {
  93. t.Fatalf("%v", err)
  94. }
  95. srv := httptest.NewServer(h)
  96. defer srv.Close()
  97. response := testHTTPResponse(srv.URL, t)
  98. if response != "NO" {
  99. t.Fatalf("Expected NO, but got %s", response)
  100. }
  101. addIPString(wl, "127.0.0.1", t)
  102. response = testHTTPResponse(srv.URL, t)
  103. if response != "OK" {
  104. t.Fatalf("Expected OK, but got %s", response)
  105. }
  106. delIPString(wl, "127.0.0.1", t)
  107. response = testHTTPResponse(srv.URL, t)
  108. if response != "NO" {
  109. t.Fatalf("Expected NO, but got %s", response)
  110. }
  111. }
  112. func TestBasicHTTPDefaultDeny(t *testing.T) {
  113. wl := NewBasic()
  114. h, err := NewHandler(testAllowHandler, nil, wl)
  115. if err != nil {
  116. t.Fatalf("%v", err)
  117. }
  118. srv := httptest.NewServer(h)
  119. defer srv.Close()
  120. expected := "Unauthorized"
  121. response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
  122. if response != expected {
  123. t.Fatalf("Expected %s, but got %s", expected, response)
  124. }
  125. }
  126. func TestBasicHTTPWorkers(t *testing.T) {
  127. wl := NewBasic()
  128. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  129. if err != nil {
  130. t.Fatalf("%v", err)
  131. }
  132. srv := httptest.NewServer(h)
  133. wg := new(sync.WaitGroup)
  134. defer srv.Close()
  135. for i := 0; i < 16; i++ {
  136. wg.Add(1)
  137. go testWorker(srv.URL, t, wg)
  138. }
  139. wg.Wait()
  140. }
  141. func TestFailHTTP(t *testing.T) {
  142. wl := NewBasic()
  143. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  144. if err != nil {
  145. t.Fatalf("%v", err)
  146. }
  147. w := httptest.NewRecorder()
  148. req := new(http.Request)
  149. if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
  150. t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
  151. }
  152. }
  153. var testHandlerFunc *HandlerFunc
  154. func newTestHandlerFunc(m string) func(http.ResponseWriter, *http.Request) {
  155. return func(w http.ResponseWriter, r *http.Request) {
  156. w.Write([]byte(m))
  157. }
  158. }
  159. var testAllowHandlerFunc = newTestHandlerFunc("OK")
  160. var testDenyHandlerFunc = newTestHandlerFunc("NO")
  161. func TestSetupHandlerFuncFails(t *testing.T) {
  162. wl := NewBasic()
  163. _, err := NewHandlerFunc(nil, testDenyHandlerFunc, wl)
  164. if err == nil {
  165. t.Fatal("expected NewHandlerFunc to fail with nil allow handler")
  166. }
  167. _, err = NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, nil)
  168. if err == nil {
  169. t.Fatal("expected NewHandlerFunc to fail with nil whitelist")
  170. }
  171. _, err = NewHandlerFunc(testAllowHandlerFunc, nil, wl)
  172. if err != nil {
  173. t.Fatalf("%v", err)
  174. }
  175. }
  176. func TestSetupHandlerFunc(t *testing.T) {
  177. wl := NewBasic()
  178. h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
  179. if err != nil {
  180. t.Fatalf("%v", err)
  181. }
  182. srv := httptest.NewServer(h)
  183. defer srv.Close()
  184. expected := "NO"
  185. response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
  186. if response != expected {
  187. t.Fatalf("Expected %s, but got %s", expected, response)
  188. }
  189. h.deny = nil
  190. expected = "Unauthorized"
  191. response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
  192. if response != expected {
  193. t.Fatalf("Expected %s, but got %s", expected, response)
  194. }
  195. addIPString(wl, "127.0.0.1", t)
  196. expected = "OK"
  197. response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
  198. if response != expected {
  199. t.Fatalf("Expected %s, but got %s", expected, response)
  200. }
  201. }
  202. func TestFailHTTPFunc(t *testing.T) {
  203. wl := NewBasic()
  204. h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
  205. if err != nil {
  206. t.Fatalf("%v", err)
  207. }
  208. w := httptest.NewRecorder()
  209. req := new(http.Request)
  210. if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
  211. t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
  212. }
  213. }
  214. func TestBasicNetHTTP(t *testing.T) {
  215. wl := NewBasicNet()
  216. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  217. if err != nil {
  218. t.Fatalf("%v", err)
  219. }
  220. srv := httptest.NewServer(h)
  221. defer srv.Close()
  222. response := testHTTPResponse(srv.URL, t)
  223. if response != "NO" {
  224. t.Fatalf("Expected NO, but got %s", response)
  225. }
  226. testAddNet(wl, "127.0.0.1/32", t)
  227. response = testHTTPResponse(srv.URL, t)
  228. if response != "OK" {
  229. t.Fatalf("Expected OK, but got %s", response)
  230. }
  231. testDelNet(wl, "127.0.0.1/32", t)
  232. response = testHTTPResponse(srv.URL, t)
  233. if response != "NO" {
  234. t.Fatalf("Expected NO, but got %s", response)
  235. }
  236. }
  237. func TestBasicNetHTTPDefaultDeny(t *testing.T) {
  238. wl := NewBasicNet()
  239. h, err := NewHandler(testAllowHandler, nil, wl)
  240. if err != nil {
  241. t.Fatalf("%v", err)
  242. }
  243. srv := httptest.NewServer(h)
  244. defer srv.Close()
  245. expected := "Unauthorized"
  246. response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
  247. if response != expected {
  248. t.Fatalf("Expected %s, but got %s", expected, response)
  249. }
  250. }
  251. func TestBasicNetHTTPWorkers(t *testing.T) {
  252. wl := NewBasicNet()
  253. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  254. if err != nil {
  255. t.Fatalf("%v", err)
  256. }
  257. srv := httptest.NewServer(h)
  258. wg := new(sync.WaitGroup)
  259. defer srv.Close()
  260. for i := 0; i < 16; i++ {
  261. wg.Add(1)
  262. go testWorker(srv.URL, t, wg)
  263. }
  264. wg.Wait()
  265. }
  266. func TestNetFailHTTP(t *testing.T) {
  267. wl := NewBasicNet()
  268. h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
  269. if err != nil {
  270. t.Fatalf("%v", err)
  271. }
  272. w := httptest.NewRecorder()
  273. req := new(http.Request)
  274. if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
  275. t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
  276. }
  277. }
  278. func TestSetupNetHandlerFuncFails(t *testing.T) {
  279. wl := NewBasicNet()
  280. _, err := NewHandlerFunc(nil, testDenyHandlerFunc, wl)
  281. if err == nil {
  282. t.Fatal("expected NewHandlerFunc to fail with nil allow handler")
  283. }
  284. _, err = NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, nil)
  285. if err == nil {
  286. t.Fatal("expected NewHandlerFunc to fail with nil whitelist")
  287. }
  288. _, err = NewHandlerFunc(testAllowHandlerFunc, nil, wl)
  289. if err != nil {
  290. t.Fatalf("%v", err)
  291. }
  292. }
  293. func TestSetupNetHandlerFunc(t *testing.T) {
  294. wl := NewBasicNet()
  295. h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
  296. if err != nil {
  297. log.Fatalf("%v", err)
  298. }
  299. srv := httptest.NewServer(h)
  300. defer srv.Close()
  301. expected := "NO"
  302. response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
  303. if response != expected {
  304. t.Fatalf("Expected %s, but got %s", expected, response)
  305. }
  306. h.deny = nil
  307. expected = "Unauthorized"
  308. response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
  309. if response != expected {
  310. t.Fatalf("Expected %s, but got %s", expected, response)
  311. }
  312. testAddNet(wl, "127.0.0.1/32", t)
  313. expected = "OK"
  314. response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
  315. if response != expected {
  316. t.Fatalf("Expected %s, but got %s", expected, response)
  317. }
  318. }
  319. func TestNetFailHTTPFunc(t *testing.T) {
  320. wl := NewBasicNet()
  321. h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
  322. if err != nil {
  323. t.Fatalf("%v", err)
  324. }
  325. w := httptest.NewRecorder()
  326. req := new(http.Request)
  327. if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
  328. t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
  329. }
  330. }
  331. func TestHandlerFunc(t *testing.T) {
  332. var acl ACL
  333. _, err := NewHandler(testAllowHandler, testDenyHandler, acl)
  334. if err == nil || err.Error() != "whitelist: ACL cannot be nil" {
  335. t.Fatal("Expected error with nil allow handler.")
  336. }
  337. acl = NewBasic()
  338. _, err = NewHandler(nil, testDenyHandler, acl)
  339. if err == nil || err.Error() != "whitelist: allow cannot be nil" {
  340. t.Fatal("Expected error with nil ACL.")
  341. }
  342. }