lookup.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package whitelist
  2. import (
  3. "errors"
  4. "log"
  5. "net"
  6. "net/http"
  7. )
  8. // NetConnLookup extracts an IP from the remote address in the
  9. // net.Conn. A single net.Conn should be passed to Address.
  10. func NetConnLookup(conn net.Conn) (net.IP, error) {
  11. if conn == nil {
  12. return nil, errors.New("whitelist: no connection")
  13. }
  14. netAddr := conn.RemoteAddr()
  15. if netAddr == nil {
  16. return nil, errors.New("whitelist: no address returned")
  17. }
  18. addr, _, err := net.SplitHostPort(netAddr.String())
  19. if err != nil {
  20. return nil, err
  21. }
  22. ip := net.ParseIP(addr)
  23. return ip, nil
  24. }
  25. // HTTPRequestLookup extracts an IP from the remote address in a
  26. // *http.Request. A single *http.Request should be passed to Address.
  27. func HTTPRequestLookup(req *http.Request) (net.IP, error) {
  28. if req == nil {
  29. return nil, errors.New("whitelist: no request")
  30. }
  31. addr, _, err := net.SplitHostPort(req.RemoteAddr)
  32. if err != nil {
  33. return nil, err
  34. }
  35. ip := net.ParseIP(addr)
  36. return ip, nil
  37. }
  38. // Handler wraps an HTTP handler with IP whitelisting.
  39. type Handler struct {
  40. allowHandler http.Handler
  41. denyHandler http.Handler
  42. whitelist ACL
  43. }
  44. // NewHandler returns a new whitelisting-wrapped HTTP handler. The
  45. // allow handler should contain a handler that will be called if the
  46. // request is whitelisted; the deny handler should contain a handler
  47. // that will be called in the request is not whitelisted.
  48. func NewHandler(allow, deny http.Handler, acl ACL) (http.Handler, error) {
  49. if allow == nil {
  50. return nil, errors.New("whitelist: allow cannot be nil")
  51. }
  52. if acl == nil {
  53. return nil, errors.New("whitelist: ACL cannot be nil")
  54. }
  55. return &Handler{
  56. allowHandler: allow,
  57. denyHandler: deny,
  58. whitelist: acl,
  59. }, nil
  60. }
  61. // ServeHTTP wraps the request in a whitelist check.
  62. func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  63. ip, err := HTTPRequestLookup(req)
  64. if err != nil {
  65. log.Printf("failed to lookup request address: %v", err)
  66. status := http.StatusInternalServerError
  67. http.Error(w, http.StatusText(status), status)
  68. return
  69. }
  70. if h.whitelist.Permitted(ip) {
  71. h.allowHandler.ServeHTTP(w, req)
  72. } else {
  73. if h.denyHandler == nil {
  74. status := http.StatusUnauthorized
  75. http.Error(w, http.StatusText(status), status)
  76. } else {
  77. h.denyHandler.ServeHTTP(w, req)
  78. }
  79. }
  80. }
  81. // A HandlerFunc contains a pair of http.HandleFunc-handler functions
  82. // that will be called depending on whether a request is allowed or
  83. // denied.
  84. type HandlerFunc struct {
  85. allow func(http.ResponseWriter, *http.Request)
  86. deny func(http.ResponseWriter, *http.Request)
  87. whitelist ACL
  88. }
  89. // NewHandlerFunc returns a new basic whitelisting handler.
  90. func NewHandlerFunc(allow, deny func(http.ResponseWriter, *http.Request), acl ACL) (*HandlerFunc, error) {
  91. if allow == nil {
  92. return nil, errors.New("whitelist: allow cannot be nil")
  93. }
  94. if acl == nil {
  95. return nil, errors.New("whitelist: ACL cannot be nil")
  96. }
  97. return &HandlerFunc{
  98. allow: allow,
  99. deny: deny,
  100. whitelist: acl,
  101. }, nil
  102. }
  103. // ServeHTTP checks the incoming request to see whether it is permitted,
  104. // and calls the appropriate handle function.
  105. func (h *HandlerFunc) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  106. ip, err := HTTPRequestLookup(req)
  107. if err != nil {
  108. log.Printf("failed to lookup request address: %v", err)
  109. status := http.StatusInternalServerError
  110. http.Error(w, http.StatusText(status), status)
  111. return
  112. }
  113. if h.whitelist.Permitted(ip) {
  114. h.allow(w, req)
  115. } else {
  116. if h.deny == nil {
  117. status := http.StatusUnauthorized
  118. http.Error(w, http.StatusText(status), status)
  119. } else {
  120. h.deny(w, req)
  121. }
  122. }
  123. }