Skip to content

Commit 4df4e41

Browse files
committed
1 parent 8659ce0 commit 4df4e41

File tree

1 file changed

+146
-90
lines changed

1 file changed

+146
-90
lines changed

blt/middleware/ip_restrict.py

Lines changed: 146 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
MAX_COUNT = 2147483647
1010

11+
CACHE_TIMEOUT = 86400 # 1 day, adjust as needed
12+
1113

1214
class IPRestrictMiddleware:
1315
"""
@@ -17,148 +19,202 @@ class IPRestrictMiddleware:
1719
def __init__(self, get_response):
1820
self.get_response = get_response
1921

20-
def get_cached_data(self, cache_key, queryset, timeout=86400):
22+
# ----------------------------------------------------------------------
23+
# Caching and Data Retrieval
24+
# ----------------------------------------------------------------------
25+
26+
def get_blocked_entries(self):
27+
"""
28+
Retrieve all blocked entries from cache or database. We store the
29+
entire list of Blocked objects (or a subset of fields) in cache once.
30+
"""
31+
blocked_data = cache.get("blocked_entries")
32+
if blocked_data is None:
33+
# You can store only the fields needed if you prefer:
34+
# e.g., list(Blocked.objects.values('pk', 'address', 'ip_network', 'user_agent_string', 'count'))
35+
blocked_data = list(Blocked.objects.all())
36+
cache.set("blocked_entries", blocked_data, CACHE_TIMEOUT)
37+
return blocked_data
38+
39+
def get_blocked_ips(self):
2140
"""
22-
Retrieve data from cache or database.
41+
Build a set of blocked IP addresses from the cached blocked entries.
2342
"""
24-
cached_data = cache.get(cache_key)
25-
if cached_data is None:
26-
cached_data = list(filter(None, queryset)) # Filter out None values
27-
cache.set(cache_key, cached_data, timeout=timeout)
28-
return cached_data
43+
ips = cache.get("blocked_ips")
44+
if ips is None:
45+
entries = self.get_blocked_entries()
46+
# Filter out empty or None addresses, convert to a set
47+
ips = {entry.address for entry in entries if entry.address}
48+
cache.set("blocked_ips", ips, CACHE_TIMEOUT)
49+
return ips
2950

30-
def blocked_ips(self):
51+
def get_blocked_networks(self):
3152
"""
32-
Retrieve blocked IP addresses from cache or database.
53+
Build a list of blocked IP networks (ipaddress.ip_network objects)
54+
from the cached blocked entries.
3355
"""
34-
blocked_addresses = Blocked.objects.values_list("address", flat=True)
35-
return set(self.get_cached_data("blocked_ips", blocked_addresses))
56+
networks = cache.get("blocked_networks")
57+
if networks is None:
58+
entries = self.get_blocked_entries()
59+
networks = []
60+
for entry in entries:
61+
if entry.ip_network:
62+
try:
63+
net = ipaddress.ip_network(entry.ip_network, strict=False)
64+
networks.append(net)
65+
except ValueError:
66+
# Skip invalid networks
67+
pass
68+
cache.set("blocked_networks", networks, CACHE_TIMEOUT)
69+
return networks
3670

37-
def blocked_ip_network(self):
71+
def get_blocked_agents(self):
3872
"""
39-
Retrieve blocked IP networks from cache or database.
73+
Build a set of blocked user-agent substrings (lowercase) from the
74+
cached blocked entries.
4075
"""
41-
blocked_network = Blocked.objects.values_list("ip_network", flat=True)
42-
blocked_ip_network = []
76+
agents = cache.get("blocked_agents")
77+
if agents is None:
78+
entries = self.get_blocked_entries()
79+
# Filter out empty or None user_agent_string, convert to lowercase set
80+
agents = {
81+
entry.user_agent_string.lower() for entry in entries if entry.user_agent_string
82+
}
83+
cache.set("blocked_agents", agents, CACHE_TIMEOUT)
84+
return agents
4385

44-
for range_str in self.get_cached_data("blocked_ip_network", blocked_network):
45-
try:
46-
network = ipaddress.ip_network(range_str, strict=False)
47-
blocked_ip_network.append(network)
48-
except ValueError:
49-
# Log the error or handle it as needed, but skip invalid networks
50-
continue
86+
# ----------------------------------------------------------------------
87+
# Core Utility Methods
88+
# ----------------------------------------------------------------------
5189

52-
return blocked_ip_network
90+
def get_client_ip(self, request):
91+
"""
92+
Extract the client IP address from the request.
93+
"""
94+
forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR", "")
95+
if forwarded_for:
96+
return forwarded_for.split(",")[0].strip()
97+
return request.META.get("REMOTE_ADDR", "")
5398

54-
def blocked_agents(self):
99+
def get_user_agent(self, request):
55100
"""
56-
Retrieve blocked user agents from cache or database.
101+
Extract the user agent string from the request.
57102
"""
58-
blocked_user_agents = Blocked.objects.values_list("user_agent_string", flat=True)
59-
return set(self.get_cached_data("blocked_agents", blocked_user_agents))
103+
return request.META.get("HTTP_USER_AGENT", "").strip()
60104

61105
def ip_in_ips(self, ip, blocked_ips):
62106
"""
63-
Check if the IP address is in the list of blocked IPs.
107+
Check if the IP address is in the set of blocked IPs.
64108
"""
65109
return ip in blocked_ips
66110

67-
def ip_in_range(self, ip, blocked_ip_network):
111+
def ip_in_networks(self, ip, blocked_networks):
68112
"""
69113
Check if the IP address is within any of the blocked IP networks.
114+
Returns the matching network if found, else None.
70115
"""
116+
if not ip:
117+
return None
71118
ip_obj = ipaddress.ip_address(ip)
72-
return any(ip_obj in ip_range for ip_range in blocked_ip_network)
119+
for net in blocked_networks:
120+
if ip_obj in net:
121+
return net
122+
return None
73123

74124
def is_user_agent_blocked(self, user_agent, blocked_agents):
75125
"""
76-
Check if the user agent is in the list of blocked user agents by checking if the
77-
full user agent string contains any of the blocked substrings.
126+
Check if the user agent matches any of the blocked user agent
127+
substrings (case-insensitive).
78128
"""
79-
user_agent_str = str(user_agent).strip().lower()
80-
return any(blocked_agent.lower() in user_agent_str for blocked_agent in blocked_agents)
129+
ua_lower = user_agent.lower()
130+
return any(blocked_ua in ua_lower for blocked_ua in blocked_agents)
131+
132+
# ----------------------------------------------------------------------
133+
# Incrementing the Block Count in the DB
134+
# ----------------------------------------------------------------------
81135

82136
def increment_block_count(self, ip=None, network=None, user_agent=None):
83137
"""
84-
Increment the block count for a specific IP, network, or user agent in the Blocked model.
138+
Increment the block count for the matching Blocked entry.
139+
Since the actual update must happen in the DB, we do one minimal query:
140+
we find the relevant entry by IP, or by network, or by user_agent substring.
85141
"""
142+
if not (ip or network or user_agent):
143+
return # Nothing to increment
144+
86145
with transaction.atomic():
146+
# We'll do a single lookup based on what's provided.
147+
# Searching by IP / network is straightforward. Searching by user_agent
148+
# (substring) requires an __icontains lookup or a manual filter.
149+
qs = Blocked.objects.select_for_update()
150+
87151
if ip:
88-
blocked_entry = Blocked.objects.select_for_update().filter(address=ip).first()
152+
qs = qs.filter(address=ip)
89153
elif network:
90-
blocked_entry = (
91-
Blocked.objects.select_for_update().filter(ip_network=network).first()
92-
)
154+
qs = qs.filter(ip_network=network)
93155
elif user_agent:
94-
# Correct lookup: find if any user_agent_string is a substring of the user_agent
95-
blocked_entry = (
96-
Blocked.objects.select_for_update()
97-
.filter(
98-
user_agent_string__in=[
99-
agent
100-
for agent in Blocked.objects.values_list("user_agent_string", flat=True)
101-
if agent.lower() in user_agent.lower()
102-
]
103-
)
104-
.first()
156+
# If multiple user_agent_strings can match, you might want to refine logic here.
157+
qs = qs.filter(user_agent_string__iexact=user_agent) | qs.filter(
158+
user_agent_string__icontains=user_agent
105159
)
106-
else:
107-
return # Nothing to increment
108160

161+
blocked_entry = qs.first()
109162
if blocked_entry:
110163
blocked_entry.count = models.F("count") + 1
111164
blocked_entry.save(update_fields=["count"])
112165

166+
# ----------------------------------------------------------------------
167+
# Recording General IP Usage in IP Model
168+
# ----------------------------------------------------------------------
169+
170+
def record_ip_usage(self, ip, agent, path):
171+
"""
172+
Create or update the IP record for (ip, path) with an incremented count.
173+
"""
174+
with transaction.atomic():
175+
ip_record_qs = IP.objects.select_for_update().filter(address=ip, path=path)
176+
if ip_record_qs.exists():
177+
ip_record = ip_record_qs.first()
178+
ip_record.agent = agent
179+
ip_record.count = min(ip_record.count + 1, MAX_COUNT)
180+
ip_record.save(update_fields=["agent", "count"])
181+
182+
# Clean up any other records with the same (ip, path) but different PKs
183+
ip_record_qs.exclude(pk=ip_record.pk).delete()
184+
else:
185+
IP.objects.create(address=ip, agent=agent, count=1, path=path)
186+
187+
# ----------------------------------------------------------------------
188+
# Main Handler
189+
# ----------------------------------------------------------------------
190+
113191
def __call__(self, request):
114-
ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get(
115-
"REMOTE_ADDR", ""
116-
)
117-
agent = request.META.get("HTTP_USER_AGENT", "").strip()
192+
ip = self.get_client_ip(request)
193+
agent = self.get_user_agent(request)
118194

119-
blocked_ips = self.blocked_ips()
120-
blocked_ip_network = self.blocked_ip_network()
121-
blocked_agents = self.blocked_agents()
195+
# Grab cached block sets/lists (1 DB call if cache is empty)
196+
blocked_ips = self.get_blocked_ips()
197+
blocked_networks = self.get_blocked_networks()
198+
blocked_agents = self.get_blocked_agents()
122199

123-
if self.ip_in_ips(ip, blocked_ips):
200+
# 1) Check if IP is explicitly blocked
201+
if ip and self.ip_in_ips(ip, blocked_ips):
124202
self.increment_block_count(ip=ip)
125203
return HttpResponseForbidden()
126204

127-
if self.ip_in_range(ip, blocked_ip_network):
128-
# Find the specific network that caused the block and increment its count
129-
for network in blocked_ip_network:
130-
if ipaddress.ip_address(ip) in network:
131-
self.increment_block_count(network=str(network))
132-
break
205+
# 2) Check if IP is in any blocked network
206+
network_hit = self.ip_in_networks(ip, blocked_networks)
207+
if network_hit:
208+
self.increment_block_count(network=str(network_hit))
133209
return HttpResponseForbidden()
134210

135-
if self.is_user_agent_blocked(agent, blocked_agents):
211+
# 3) Check if user agent is blocked
212+
if agent and self.is_user_agent_blocked(agent, blocked_agents):
136213
self.increment_block_count(user_agent=agent)
137214
return HttpResponseForbidden()
138215

216+
# 4) Record IP usage if present
139217
if ip:
140-
with transaction.atomic():
141-
# create unique entry for every unique (ip,path) tuple
142-
# if this tuple already exists, we just increment the count.
143-
ip_records = IP.objects.select_for_update().filter(address=ip, path=request.path)
144-
if ip_records.exists():
145-
ip_record = ip_records.first()
146-
147-
# Calculate the new count and ensure it doesn't exceed the MAX_COUNT
148-
new_count = ip_record.count + 1
149-
if new_count > MAX_COUNT:
150-
new_count = MAX_COUNT
151-
152-
ip_record.agent = agent
153-
ip_record.count = new_count
154-
if ip_record.pk:
155-
ip_record.save(update_fields=["agent", "count"])
156-
157-
# Check if a transaction is already active before starting a new one
158-
if not transaction.get_autocommit():
159-
ip_records.exclude(pk=ip_record.pk).delete()
160-
else:
161-
# If no record exists, create a new one
162-
IP.objects.create(address=ip, agent=agent, count=1, path=request.path)
218+
self.record_ip_usage(ip, agent, request.path)
163219

164220
return self.get_response(request)

0 commit comments

Comments
 (0)