123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
- package whitelist
- import (
- "io"
- "log"
- "net/http"
- "net/http/httptest"
- "strings"
- "sync"
- "testing"
- )
// testHandler is a minimal http.Handler that answers every request with a
// fixed body, so a test can tell from the response which handler ran.
type testHandler struct {
	Message string // body written for every request
}

// newTestHandler returns an http.Handler whose responses consist solely of m.
func newTestHandler(m string) http.Handler {
	th := new(testHandler)
	th.Message = m
	return th
}

// ServeHTTP writes the configured message. The request is not inspected, and
// the write error is ignored, as is usual for a throwaway test fixture.
func (h *testHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	msg := []byte(h.Message)
	w.Write(msg)
}

// Canned handlers passed to NewHandler as the allow and deny branches.
var (
	testAllowHandler = newTestHandler("OK")
	testDenyHandler  = newTestHandler("NO")
)
// testHTTPResponse performs a GET request against url and returns the
// response body as a string. The status code is not checked.
//
// A transport or read error fails the test immediately via t.Fatalf, so this
// helper must only be called from the goroutine running the test.
func testHTTPResponse(url string, t *testing.T) string {
	t.Helper()
	resp, err := http.Get(url)
	if err != nil {
		t.Fatalf("%v", err)
	}
	// Deferred so the body is also closed when ReadAll fails and t.Fatalf
	// unwinds the goroutine (runtime.Goexit still runs deferred calls).
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("%v", err)
	}
	return string(body)
}
// testWorker issues 100 sequential GET requests to url and expects every
// response body to be "NO". It stops at the first transport error, read error
// or unexpected body, and reports it through t.Errorf. It is meant to run in
// its own goroutine and always marks wg as done when it returns.
//
// Failures use t.Errorf rather than t.Fatalf or log.Fatalf. FailNow may only
// be called from the goroutine running the test, and log.Fatalf would exit the
// whole test binary without any test report.
func testWorker(url string, t *testing.T, wg *sync.WaitGroup) {
	defer wg.Done()
	for i := 0; i < 100; i++ {
		resp, err := http.Get(url)
		if err != nil {
			t.Errorf("%v", err)
			return
		}
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			t.Errorf("%v", err)
			return
		}
		if response := string(body); response != "NO" {
			t.Errorf("Expected NO, but got %s", response)
			return
		}
	}
}
- func TestHostStubHTTP(t *testing.T) {
- wl := NewHostStub()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- response := testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- addIPString(wl, "127.0.0.1", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- delIPString(wl, "127.0.0.1", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- }
- func TestNetStubHTTP(t *testing.T) {
- wl := NewNetStub()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- response := testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- testAddNet(wl, "127.0.0.1/32", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- testDelNet(wl, "127.0.0.1/32", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- }
- func TestBasicHTTP(t *testing.T) {
- wl := NewBasic()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- response := testHTTPResponse(srv.URL, t)
- if response != "NO" {
- t.Fatalf("Expected NO, but got %s", response)
- }
- addIPString(wl, "127.0.0.1", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- delIPString(wl, "127.0.0.1", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "NO" {
- t.Fatalf("Expected NO, but got %s", response)
- }
- }
- func TestBasicHTTPDefaultDeny(t *testing.T) {
- wl := NewBasic()
- h, err := NewHandler(testAllowHandler, nil, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- expected := "Unauthorized"
- response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- }
- func TestBasicHTTPWorkers(t *testing.T) {
- wl := NewBasic()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- wg := new(sync.WaitGroup)
- defer srv.Close()
- for i := 0; i < 16; i++ {
- wg.Add(1)
- go testWorker(srv.URL, t, wg)
- }
- wg.Wait()
- }
- func TestFailHTTP(t *testing.T) {
- wl := NewBasic()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- w := httptest.NewRecorder()
- req := new(http.Request)
- if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
- t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
- }
- }
- var testHandlerFunc *HandlerFunc
// newTestHandlerFunc returns a plain handler function that answers every
// request with the body m, ignoring the request and any write error.
func newTestHandlerFunc(m string) func(http.ResponseWriter, *http.Request) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(m))
	}
	return handler
}

// Canned handler functions passed to NewHandlerFunc as the allow and deny
// branches.
var (
	testAllowHandlerFunc = newTestHandlerFunc("OK")
	testDenyHandlerFunc  = newTestHandlerFunc("NO")
)
- func TestSetupHandlerFuncFails(t *testing.T) {
- wl := NewBasic()
- _, err := NewHandlerFunc(nil, testDenyHandlerFunc, wl)
- if err == nil {
- t.Fatal("expected NewHandlerFunc to fail with nil allow handler")
- }
- _, err = NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, nil)
- if err == nil {
- t.Fatal("expected NewHandlerFunc to fail with nil whitelist")
- }
- _, err = NewHandlerFunc(testAllowHandlerFunc, nil, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- }
- func TestSetupHandlerFunc(t *testing.T) {
- wl := NewBasic()
- h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- expected := "NO"
- response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- h.deny = nil
- expected = "Unauthorized"
- response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- addIPString(wl, "127.0.0.1", t)
- expected = "OK"
- response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- }
- func TestFailHTTPFunc(t *testing.T) {
- wl := NewBasic()
- h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- w := httptest.NewRecorder()
- req := new(http.Request)
- if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
- t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
- }
- }
- func TestBasicNetHTTP(t *testing.T) {
- wl := NewBasicNet()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- response := testHTTPResponse(srv.URL, t)
- if response != "NO" {
- t.Fatalf("Expected NO, but got %s", response)
- }
- testAddNet(wl, "127.0.0.1/32", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "OK" {
- t.Fatalf("Expected OK, but got %s", response)
- }
- testDelNet(wl, "127.0.0.1/32", t)
- response = testHTTPResponse(srv.URL, t)
- if response != "NO" {
- t.Fatalf("Expected NO, but got %s", response)
- }
- }
- func TestBasicNetHTTPDefaultDeny(t *testing.T) {
- wl := NewBasicNet()
- h, err := NewHandler(testAllowHandler, nil, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- expected := "Unauthorized"
- response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- }
- func TestBasicNetHTTPWorkers(t *testing.T) {
- wl := NewBasicNet()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- wg := new(sync.WaitGroup)
- defer srv.Close()
- for i := 0; i < 16; i++ {
- wg.Add(1)
- go testWorker(srv.URL, t, wg)
- }
- wg.Wait()
- }
- func TestNetFailHTTP(t *testing.T) {
- wl := NewBasicNet()
- h, err := NewHandler(testAllowHandler, testDenyHandler, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- w := httptest.NewRecorder()
- req := new(http.Request)
- if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
- t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
- }
- }
- func TestSetupNetHandlerFuncFails(t *testing.T) {
- wl := NewBasicNet()
- _, err := NewHandlerFunc(nil, testDenyHandlerFunc, wl)
- if err == nil {
- t.Fatal("expected NewHandlerFunc to fail with nil allow handler")
- }
- _, err = NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, nil)
- if err == nil {
- t.Fatal("expected NewHandlerFunc to fail with nil whitelist")
- }
- _, err = NewHandlerFunc(testAllowHandlerFunc, nil, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- }
- func TestSetupNetHandlerFunc(t *testing.T) {
- wl := NewBasicNet()
- h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
- if err != nil {
- log.Fatalf("%v", err)
- }
- srv := httptest.NewServer(h)
- defer srv.Close()
- expected := "NO"
- response := strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- h.deny = nil
- expected = "Unauthorized"
- response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- testAddNet(wl, "127.0.0.1/32", t)
- expected = "OK"
- response = strings.TrimSpace(testHTTPResponse(srv.URL, t))
- if response != expected {
- t.Fatalf("Expected %s, but got %s", expected, response)
- }
- }
- func TestNetFailHTTPFunc(t *testing.T) {
- wl := NewBasicNet()
- h, err := NewHandlerFunc(testAllowHandlerFunc, testDenyHandlerFunc, wl)
- if err != nil {
- t.Fatalf("%v", err)
- }
- w := httptest.NewRecorder()
- req := new(http.Request)
- if h.ServeHTTP(w, req); w.Code != http.StatusInternalServerError {
- t.Fatalf("Expect HTTP 500, but got HTTP %d", w.Code)
- }
- }
- func TestHandlerFunc(t *testing.T) {
- var acl ACL
- _, err := NewHandler(testAllowHandler, testDenyHandler, acl)
- if err == nil || err.Error() != "whitelist: ACL cannot be nil" {
- t.Fatal("Expected error with nil allow handler.")
- }
- acl = NewBasic()
- _, err = NewHandler(nil, testDenyHandler, acl)
- if err == nil || err.Error() != "whitelist: allow cannot be nil" {
- t.Fatal("Expected error with nil ACL.")
- }
- }
|