diff --git a/connfilter.go b/connfilter.go index bc648db..4090755 100644 --- a/connfilter.go +++ b/connfilter.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "log" - "regexp" + "net" "slices" "strconv" "sync" @@ -161,38 +161,8 @@ where g.game_time < 60000 and g.time_started + $1::interval > now() and (i.pkey // stage 7 ip based mute if account == nil { - ipmutes := map[string]bool{} - for i := len(inst.cfgs) - 1; i >= 0; i-- { - o, ok := inst.cfgs[i].GetKeys("ipmute") - if !ok { - continue - } - for _, k := range o { - s, ok := inst.cfgs[i].GetBool("ipmute", k) - if !ok { - continue - } - if !s { - delete(ipmutes, k) - } else { - ipmutes[k] = s - } - } - } - for kip, v := range ipmutes { - if !v { - continue - } - reg, err := regexp.Compile(kip) - if err != nil { - inst.logger.Printf("Failed to compile regexp %q: %s", kip, err.Error()) - continue - } - if reg.Match([]byte(ip)) { - if jd.AllowChat { - jd.AllowChat = false - } - } + if checkIPmuted(inst, ip) { + jd.AllowChat = false } } @@ -206,6 +176,51 @@ where g.game_time < 60000 and g.time_started + $1::interval > now() and (i.pkey return jd, action, "" } +func checkIPmuted(inst *instance, ip string) bool { + clip := net.ParseIP(ip) + if clip == nil { + inst.logger.Printf("ipmute invalid ip %q", ip) + return false + } + ipmutes := map[string]bool{} + for i := len(inst.cfgs) - 1; i >= 0; i-- { + o, ok := inst.cfgs[i].GetKeys("ipmute") + if !ok { + continue + } + for _, k := range o { + s, ok := inst.cfgs[i].GetBool("ipmute", k) + if !ok { + continue + } + if !s { + delete(ipmutes, k) + } else { + ipmutes[k] = s + } + } + } + for kip, v := range ipmutes { + if !v { + continue + } + _, pnt, err := net.ParseCIDR(kip) + if err != nil { + inst.logger.Printf("ipmute ip %q is not in CIDR notation: %s", kip, err) + continue + } + if pnt == nil { + inst.logger.Printf("ipmute ip %q has no network after parsing", kip) + continue + } + if pnt.Contains(clip) { + inst.logger.Printf("ipmute applied to client %q with rule %q", ip, kip) + return true + } + } + return false +} + type joinCheckActionLevel int const (