access.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package ipaccess
  2. import (
  3. "fmt"
  4. "net"
  5. "sort"
  6. )
  7. type Policy struct {
  8. defaultAllow bool
  9. rules []Rule
  10. }
  11. type Rule struct {
  12. ipNet *net.IPNet
  13. ports []int
  14. allow bool
  15. }
  16. func NewPolicy(defaultAllow bool, rules []Rule) (*Policy, error) {
  17. for _, rule := range rules {
  18. if err := rule.Validate(); err != nil {
  19. return nil, err
  20. }
  21. }
  22. policy := Policy{
  23. defaultAllow: defaultAllow,
  24. rules: rules,
  25. }
  26. return &policy, nil
  27. }
  28. func NewRuleByCIDR(prefix *string, ports []int, allow bool) (Rule, error) {
  29. if prefix == nil || len(*prefix) == 0 {
  30. return Rule{}, fmt.Errorf("no prefix provided")
  31. }
  32. _, ipnet, err := net.ParseCIDR(*prefix)
  33. if err != nil {
  34. return Rule{}, fmt.Errorf("unable to parse cidr: %s", *prefix)
  35. }
  36. return NewRule(ipnet, ports, allow)
  37. }
  38. func NewRule(ipnet *net.IPNet, ports []int, allow bool) (Rule, error) {
  39. rule := Rule{
  40. ipNet: ipnet,
  41. ports: ports,
  42. allow: allow,
  43. }
  44. return rule, rule.Validate()
  45. }
  46. func (r *Rule) Validate() error {
  47. if r.ipNet == nil {
  48. return fmt.Errorf("no ipnet set on the rule")
  49. }
  50. if len(r.ports) > 0 {
  51. sort.Ints(r.ports)
  52. for _, port := range r.ports {
  53. if port < 1 || port > 65535 {
  54. return fmt.Errorf("invalid port %d, needs to be between 1 and 65535", port)
  55. }
  56. }
  57. }
  58. return nil
  59. }
  60. func (h *Policy) Allowed(ip net.IP, port int) (bool, *Rule) {
  61. if len(h.rules) == 0 {
  62. return h.defaultAllow, nil
  63. }
  64. for _, rule := range h.rules {
  65. if rule.ipNet.Contains(ip) {
  66. if len(rule.ports) == 0 {
  67. return rule.allow, &rule
  68. } else if pos := sort.SearchInts(rule.ports, port); pos < len(rule.ports) && rule.ports[pos] == port {
  69. return rule.allow, &rule
  70. }
  71. }
  72. }
  73. return h.defaultAllow, nil
  74. }
  75. func (ipr *Rule) String() string {
  76. return fmt.Sprintf("prefix:%s/port:%s/allow:%t", ipr.ipNet, ipr.PortsString(), ipr.allow)
  77. }
  78. func (ipr *Rule) PortsString() string {
  79. if len(ipr.ports) > 0 {
  80. return fmt.Sprint(ipr.ports)
  81. }
  82. return "all"
  83. }