Skip to content

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
DonnieBLT committed Jan 1, 2025
1 parent 8659ce0 commit 4df4e41
Showing 1 changed file with 146 additions and 90 deletions.
236 changes: 146 additions & 90 deletions blt/middleware/ip_restrict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

MAX_COUNT = 2147483647

CACHE_TIMEOUT = 86400 # 1 day, adjust as needed


class IPRestrictMiddleware:
"""
Expand All @@ -17,148 +19,202 @@ class IPRestrictMiddleware:
def __init__(self, get_response):
self.get_response = get_response

def get_cached_data(self, cache_key, queryset, timeout=86400):
# ----------------------------------------------------------------------
# Caching and Data Retrieval
# ----------------------------------------------------------------------

def get_blocked_entries(self):
"""
Retrieve all blocked entries from cache or database. We store the
entire list of Blocked objects (or a subset of fields) in cache once.
"""
blocked_data = cache.get("blocked_entries")
if blocked_data is None:
# You can store only the fields needed if you prefer:
# e.g., list(Blocked.objects.values('pk', 'address', 'ip_network', 'user_agent_string', 'count'))
blocked_data = list(Blocked.objects.all())
cache.set("blocked_entries", blocked_data, CACHE_TIMEOUT)
return blocked_data

def get_blocked_ips(self):
"""
Retrieve data from cache or database.
Build a set of blocked IP addresses from the cached blocked entries.
"""
cached_data = cache.get(cache_key)
if cached_data is None:
cached_data = list(filter(None, queryset)) # Filter out None values
cache.set(cache_key, cached_data, timeout=timeout)
return cached_data
ips = cache.get("blocked_ips")
if ips is None:
entries = self.get_blocked_entries()
# Filter out empty or None addresses, convert to a set
ips = {entry.address for entry in entries if entry.address}
cache.set("blocked_ips", ips, CACHE_TIMEOUT)
return ips

def blocked_ips(self):
def get_blocked_networks(self):
"""
Retrieve blocked IP addresses from cache or database.
Build a list of blocked IP networks (ipaddress.ip_network objects)
from the cached blocked entries.
"""
blocked_addresses = Blocked.objects.values_list("address", flat=True)
return set(self.get_cached_data("blocked_ips", blocked_addresses))
networks = cache.get("blocked_networks")
if networks is None:
entries = self.get_blocked_entries()
networks = []
for entry in entries:
if entry.ip_network:
try:
net = ipaddress.ip_network(entry.ip_network, strict=False)
networks.append(net)
except ValueError:
# Skip invalid networks
pass
cache.set("blocked_networks", networks, CACHE_TIMEOUT)
return networks

def blocked_ip_network(self):
def get_blocked_agents(self):
"""
Retrieve blocked IP networks from cache or database.
Build a set of blocked user-agent substrings (lowercase) from the
cached blocked entries.
"""
blocked_network = Blocked.objects.values_list("ip_network", flat=True)
blocked_ip_network = []
agents = cache.get("blocked_agents")
if agents is None:
entries = self.get_blocked_entries()
# Filter out empty or None user_agent_string, convert to lowercase set
agents = {
entry.user_agent_string.lower() for entry in entries if entry.user_agent_string
}
cache.set("blocked_agents", agents, CACHE_TIMEOUT)
return agents

for range_str in self.get_cached_data("blocked_ip_network", blocked_network):
try:
network = ipaddress.ip_network(range_str, strict=False)
blocked_ip_network.append(network)
except ValueError:
# Log the error or handle it as needed, but skip invalid networks
continue
# ----------------------------------------------------------------------
# Core Utility Methods
# ----------------------------------------------------------------------

return blocked_ip_network
def get_client_ip(self, request):
"""
Extract the client IP address from the request.
"""
forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR", "")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
return request.META.get("REMOTE_ADDR", "")

def blocked_agents(self):
def get_user_agent(self, request):
"""
Retrieve blocked user agents from cache or database.
Extract the user agent string from the request.
"""
blocked_user_agents = Blocked.objects.values_list("user_agent_string", flat=True)
return set(self.get_cached_data("blocked_agents", blocked_user_agents))
return request.META.get("HTTP_USER_AGENT", "").strip()

def ip_in_ips(self, ip, blocked_ips):
"""
Check if the IP address is in the list of blocked IPs.
Check if the IP address is in the set of blocked IPs.
"""
return ip in blocked_ips

def ip_in_range(self, ip, blocked_ip_network):
def ip_in_networks(self, ip, blocked_networks):
"""
Check if the IP address is within any of the blocked IP networks.
Returns the matching network if found, else None.
"""
if not ip:
return None
ip_obj = ipaddress.ip_address(ip)
return any(ip_obj in ip_range for ip_range in blocked_ip_network)
for net in blocked_networks:
if ip_obj in net:
return net
return None

def is_user_agent_blocked(self, user_agent, blocked_agents):
"""
Check if the user agent is in the list of blocked user agents by checking if the
full user agent string contains any of the blocked substrings.
Check if the user agent matches any of the blocked user agent
substrings (case-insensitive).
"""
user_agent_str = str(user_agent).strip().lower()
return any(blocked_agent.lower() in user_agent_str for blocked_agent in blocked_agents)
ua_lower = user_agent.lower()
return any(blocked_ua in ua_lower for blocked_ua in blocked_agents)

# ----------------------------------------------------------------------
# Incrementing the Block Count in the DB
# ----------------------------------------------------------------------

def increment_block_count(self, ip=None, network=None, user_agent=None):
"""
Increment the block count for a specific IP, network, or user agent in the Blocked model.
Increment the block count for the matching Blocked entry.
Since the actual update must happen in the DB, we do one minimal query:
we find the relevant entry by IP, or by network, or by user_agent substring.
"""
if not (ip or network or user_agent):
return # Nothing to increment

with transaction.atomic():
# We'll do a single lookup based on what's provided.
# Searching by IP / network is straightforward. Searching by user_agent
# (substring) requires an __icontains lookup or a manual filter.
qs = Blocked.objects.select_for_update()

if ip:
blocked_entry = Blocked.objects.select_for_update().filter(address=ip).first()
qs = qs.filter(address=ip)
elif network:
blocked_entry = (
Blocked.objects.select_for_update().filter(ip_network=network).first()
)
qs = qs.filter(ip_network=network)
elif user_agent:
# Correct lookup: find if any user_agent_string is a substring of the user_agent
blocked_entry = (
Blocked.objects.select_for_update()
.filter(
user_agent_string__in=[
agent
for agent in Blocked.objects.values_list("user_agent_string", flat=True)
if agent.lower() in user_agent.lower()
]
)
.first()
# If multiple user_agent_strings can match, you might want to refine logic here.
qs = qs.filter(user_agent_string__iexact=user_agent) | qs.filter(
user_agent_string__icontains=user_agent
)
else:
return # Nothing to increment

blocked_entry = qs.first()
if blocked_entry:
blocked_entry.count = models.F("count") + 1
blocked_entry.save(update_fields=["count"])

# ----------------------------------------------------------------------
# Recording General IP Usage in IP Model
# ----------------------------------------------------------------------

def record_ip_usage(self, ip, agent, path):
"""
Create or update the IP record for (ip, path) with an incremented count.
"""
with transaction.atomic():
ip_record_qs = IP.objects.select_for_update().filter(address=ip, path=path)
if ip_record_qs.exists():
ip_record = ip_record_qs.first()
ip_record.agent = agent
ip_record.count = min(ip_record.count + 1, MAX_COUNT)
ip_record.save(update_fields=["agent", "count"])

# Clean up any other records with the same (ip, path) but different PKs
ip_record_qs.exclude(pk=ip_record.pk).delete()
else:
IP.objects.create(address=ip, agent=agent, count=1, path=path)

# ----------------------------------------------------------------------
# Main Handler
# ----------------------------------------------------------------------

def __call__(self, request):
ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get(
"REMOTE_ADDR", ""
)
agent = request.META.get("HTTP_USER_AGENT", "").strip()
ip = self.get_client_ip(request)
agent = self.get_user_agent(request)

blocked_ips = self.blocked_ips()
blocked_ip_network = self.blocked_ip_network()
blocked_agents = self.blocked_agents()
# Grab cached block sets/lists (1 DB call if cache is empty)
blocked_ips = self.get_blocked_ips()
blocked_networks = self.get_blocked_networks()
blocked_agents = self.get_blocked_agents()

if self.ip_in_ips(ip, blocked_ips):
# 1) Check if IP is explicitly blocked
if ip and self.ip_in_ips(ip, blocked_ips):
self.increment_block_count(ip=ip)
return HttpResponseForbidden()

if self.ip_in_range(ip, blocked_ip_network):
# Find the specific network that caused the block and increment its count
for network in blocked_ip_network:
if ipaddress.ip_address(ip) in network:
self.increment_block_count(network=str(network))
break
# 2) Check if IP is in any blocked network
network_hit = self.ip_in_networks(ip, blocked_networks)
if network_hit:
self.increment_block_count(network=str(network_hit))
return HttpResponseForbidden()

if self.is_user_agent_blocked(agent, blocked_agents):
# 3) Check if user agent is blocked
if agent and self.is_user_agent_blocked(agent, blocked_agents):
self.increment_block_count(user_agent=agent)
return HttpResponseForbidden()

# 4) Record IP usage if present
if ip:
with transaction.atomic():
# create unique entry for every unique (ip,path) tuple
# if this tuple already exists, we just increment the count.
ip_records = IP.objects.select_for_update().filter(address=ip, path=request.path)
if ip_records.exists():
ip_record = ip_records.first()

# Calculate the new count and ensure it doesn't exceed the MAX_COUNT
new_count = ip_record.count + 1
if new_count > MAX_COUNT:
new_count = MAX_COUNT

ip_record.agent = agent
ip_record.count = new_count
if ip_record.pk:
ip_record.save(update_fields=["agent", "count"])

# Check if a transaction is already active before starting a new one
if not transaction.get_autocommit():
ip_records.exclude(pk=ip_record.pk).delete()
else:
# If no record exists, create a new one
IP.objects.create(address=ip, agent=agent, count=1, path=request.path)
self.record_ip_usage(ip, agent, request.path)

return self.get_response(request)

0 comments on commit 4df4e41

Please sign in to comment.