123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- package ipaccess
- import (
- "fmt"
- "net"
- "sort"
- )
- type Policy struct {
- defaultAllow bool
- rules []Rule
- }
// Rule matches addresses inside a network and, optionally, only a fixed set
// of ports. Build rules with NewRule or NewRuleByCIDR so they are validated.
type Rule struct {
	// ipNet is the network an address must fall inside for the rule to match.
	// It must be non-nil for the rule to pass Validate.
	ipNet *net.IPNet
	// ports lists the ports the rule applies to; Validate sorts it ascending.
	// An empty list means the rule applies to every port.
	ports []int
	// allow is the verdict produced when the rule matches.
	allow bool
}
- func NewPolicy(defaultAllow bool, rules []Rule) (*Policy, error) {
- for _, rule := range rules {
- if err := rule.Validate(); err != nil {
- return nil, err
- }
- }
- policy := Policy{
- defaultAllow: defaultAllow,
- rules: rules,
- }
- return &policy, nil
- }
- func NewRuleByCIDR(prefix *string, ports []int, allow bool) (Rule, error) {
- if prefix == nil || len(*prefix) == 0 {
- return Rule{}, fmt.Errorf("no prefix provided")
- }
- _, ipnet, err := net.ParseCIDR(*prefix)
- if err != nil {
- return Rule{}, fmt.Errorf("unable to parse cidr: %s", *prefix)
- }
- return NewRule(ipnet, ports, allow)
- }
- func NewRule(ipnet *net.IPNet, ports []int, allow bool) (Rule, error) {
- rule := Rule{
- ipNet: ipnet,
- ports: ports,
- allow: allow,
- }
- return rule, rule.Validate()
- }
- func (r *Rule) Validate() error {
- if r.ipNet == nil {
- return fmt.Errorf("no ipnet set on the rule")
- }
- if len(r.ports) > 0 {
- sort.Ints(r.ports)
- for _, port := range r.ports {
- if port < 1 || port > 65535 {
- return fmt.Errorf("invalid port %d, needs to be between 1 and 65535", port)
- }
- }
- }
- return nil
- }
- func (h *Policy) Allowed(ip net.IP, port int) (bool, *Rule) {
- if len(h.rules) == 0 {
- return h.defaultAllow, nil
- }
- for _, rule := range h.rules {
- if rule.ipNet.Contains(ip) {
- if len(rule.ports) == 0 {
- return rule.allow, &rule
- } else if pos := sort.SearchInts(rule.ports, port); pos < len(rule.ports) && rule.ports[pos] == port {
- return rule.allow, &rule
- }
- }
- }
- return h.defaultAllow, nil
- }
- func (ipr *Rule) String() string {
- return fmt.Sprintf("prefix:%s/port:%s/allow:%t", ipr.ipNet, ipr.PortsString(), ipr.allow)
- }
- func (ipr *Rule) PortsString() string {
- if len(ipr.ports) > 0 {
- return fmt.Sprint(ipr.ports)
- }
- return "all"
- }
|