Skip to content

Commit cbb5058

Browse files
committed
Refactor code structure and enhance readability in core.py
1 parent 4df4e41 commit cbb5058

File tree

1 file changed

+90
-146
lines changed

1 file changed

+90
-146
lines changed

blt/middleware/ip_restrict.py

Lines changed: 90 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
MAX_COUNT = 2147483647
1010

11-
CACHE_TIMEOUT = 86400 # 1 day, adjust as needed
12-
1311

1412
class IPRestrictMiddleware:
1513
"""
@@ -19,202 +17,148 @@ class IPRestrictMiddleware:
1917
def __init__(self, get_response):
2018
self.get_response = get_response
2119

22-
# ----------------------------------------------------------------------
23-
# Caching and Data Retrieval
24-
# ----------------------------------------------------------------------
25-
26-
def get_blocked_entries(self):
27-
"""
28-
Retrieve all blocked entries from cache or database. We store the
29-
entire list of Blocked objects (or a subset of fields) in cache once.
30-
"""
31-
blocked_data = cache.get("blocked_entries")
32-
if blocked_data is None:
33-
# You can store only the fields needed if you prefer:
34-
# e.g., list(Blocked.objects.values('pk', 'address', 'ip_network', 'user_agent_string', 'count'))
35-
blocked_data = list(Blocked.objects.all())
36-
cache.set("blocked_entries", blocked_data, CACHE_TIMEOUT)
37-
return blocked_data
38-
39-
def get_blocked_ips(self):
20+
def get_cached_data(self, cache_key, queryset, timeout=86400):
4021
"""
41-
Build a set of blocked IP addresses from the cached blocked entries.
22+
Retrieve data from cache or database.
4223
"""
43-
ips = cache.get("blocked_ips")
44-
if ips is None:
45-
entries = self.get_blocked_entries()
46-
# Filter out empty or None addresses, convert to a set
47-
ips = {entry.address for entry in entries if entry.address}
48-
cache.set("blocked_ips", ips, CACHE_TIMEOUT)
49-
return ips
24+
cached_data = cache.get(cache_key)
25+
if cached_data is None:
26+
cached_data = list(filter(None, queryset)) # Filter out None values
27+
cache.set(cache_key, cached_data, timeout=timeout)
28+
return cached_data
5029

51-
def get_blocked_networks(self):
30+
def blocked_ips(self):
5231
"""
53-
Build a list of blocked IP networks (ipaddress.ip_network objects)
54-
from the cached blocked entries.
32+
Retrieve blocked IP addresses from cache or database.
5533
"""
56-
networks = cache.get("blocked_networks")
57-
if networks is None:
58-
entries = self.get_blocked_entries()
59-
networks = []
60-
for entry in entries:
61-
if entry.ip_network:
62-
try:
63-
net = ipaddress.ip_network(entry.ip_network, strict=False)
64-
networks.append(net)
65-
except ValueError:
66-
# Skip invalid networks
67-
pass
68-
cache.set("blocked_networks", networks, CACHE_TIMEOUT)
69-
return networks
34+
blocked_addresses = Blocked.objects.values_list("address", flat=True)
35+
return set(self.get_cached_data("blocked_ips", blocked_addresses))
7036

71-
def get_blocked_agents(self):
37+
def blocked_ip_network(self):
7238
"""
73-
Build a set of blocked user-agent substrings (lowercase) from the
74-
cached blocked entries.
39+
Retrieve blocked IP networks from cache or database.
7540
"""
76-
agents = cache.get("blocked_agents")
77-
if agents is None:
78-
entries = self.get_blocked_entries()
79-
# Filter out empty or None user_agent_string, convert to lowercase set
80-
agents = {
81-
entry.user_agent_string.lower() for entry in entries if entry.user_agent_string
82-
}
83-
cache.set("blocked_agents", agents, CACHE_TIMEOUT)
84-
return agents
41+
blocked_network = Blocked.objects.values_list("ip_network", flat=True)
42+
blocked_ip_network = []
8543

86-
# ----------------------------------------------------------------------
87-
# Core Utility Methods
88-
# ----------------------------------------------------------------------
44+
for range_str in self.get_cached_data("blocked_ip_network", blocked_network):
45+
try:
46+
network = ipaddress.ip_network(range_str, strict=False)
47+
blocked_ip_network.append(network)
48+
except ValueError:
49+
# Log the error or handle it as needed, but skip invalid networks
50+
continue
8951

90-
def get_client_ip(self, request):
91-
"""
92-
Extract the client IP address from the request.
93-
"""
94-
forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR", "")
95-
if forwarded_for:
96-
return forwarded_for.split(",")[0].strip()
97-
return request.META.get("REMOTE_ADDR", "")
52+
return blocked_ip_network
9853

99-
def get_user_agent(self, request):
54+
def blocked_agents(self):
10055
"""
101-
Extract the user agent string from the request.
56+
Retrieve blocked user agents from cache or database.
10257
"""
103-
return request.META.get("HTTP_USER_AGENT", "").strip()
58+
blocked_user_agents = Blocked.objects.values_list("user_agent_string", flat=True)
59+
return set(self.get_cached_data("blocked_agents", blocked_user_agents))
10460

10561
def ip_in_ips(self, ip, blocked_ips):
10662
"""
107-
Check if the IP address is in the set of blocked IPs.
63+
Check if the IP address is in the list of blocked IPs.
10864
"""
10965
return ip in blocked_ips
11066

111-
def ip_in_networks(self, ip, blocked_networks):
67+
def ip_in_range(self, ip, blocked_ip_network):
11268
"""
11369
Check if the IP address is within any of the blocked IP networks.
114-
Returns the matching network if found, else None.
11570
"""
116-
if not ip:
117-
return None
11871
ip_obj = ipaddress.ip_address(ip)
119-
for net in blocked_networks:
120-
if ip_obj in net:
121-
return net
122-
return None
72+
return any(ip_obj in ip_range for ip_range in blocked_ip_network)
12373

12474
def is_user_agent_blocked(self, user_agent, blocked_agents):
12575
"""
126-
Check if the user agent matches any of the blocked user agent
127-
substrings (case-insensitive).
76+
Check if the user agent is in the list of blocked user agents by checking if the
77+
full user agent string contains any of the blocked substrings.
12878
"""
129-
ua_lower = user_agent.lower()
130-
return any(blocked_ua in ua_lower for blocked_ua in blocked_agents)
131-
132-
# ----------------------------------------------------------------------
133-
# Incrementing the Block Count in the DB
134-
# ----------------------------------------------------------------------
79+
user_agent_str = str(user_agent).strip().lower()
80+
return any(blocked_agent.lower() in user_agent_str for blocked_agent in blocked_agents)
13581

13682
def increment_block_count(self, ip=None, network=None, user_agent=None):
13783
"""
138-
Increment the block count for the matching Blocked entry.
139-
Since the actual update must happen in the DB, we do one minimal query:
140-
we find the relevant entry by IP, or by network, or by user_agent substring.
84+
Increment the block count for a specific IP, network, or user agent in the Blocked model.
14185
"""
142-
if not (ip or network or user_agent):
143-
return # Nothing to increment
144-
14586
with transaction.atomic():
146-
# We'll do a single lookup based on what's provided.
147-
# Searching by IP / network is straightforward. Searching by user_agent
148-
# (substring) requires an __icontains lookup or a manual filter.
149-
qs = Blocked.objects.select_for_update()
150-
15187
if ip:
152-
qs = qs.filter(address=ip)
88+
blocked_entry = Blocked.objects.select_for_update().filter(address=ip).first()
15389
elif network:
154-
qs = qs.filter(ip_network=network)
90+
blocked_entry = (
91+
Blocked.objects.select_for_update().filter(ip_network=network).first()
92+
)
15593
elif user_agent:
156-
# If multiple user_agent_strings can match, you might want to refine logic here.
157-
qs = qs.filter(user_agent_string__iexact=user_agent) | qs.filter(
158-
user_agent_string__icontains=user_agent
94+
# Correct lookup: find if any user_agent_string is a substring of the user_agent
95+
blocked_entry = (
96+
Blocked.objects.select_for_update()
97+
.filter(
98+
user_agent_string__in=[
99+
agent
100+
for agent in Blocked.objects.values_list("user_agent_string", flat=True)
101+
if agent.lower() in user_agent.lower()
102+
]
103+
)
104+
.first()
159105
)
106+
else:
107+
return # Nothing to increment
160108

161-
blocked_entry = qs.first()
162109
if blocked_entry:
163110
blocked_entry.count = models.F("count") + 1
164111
blocked_entry.save(update_fields=["count"])
165112

166-
# ----------------------------------------------------------------------
167-
# Recording General IP Usage in IP Model
168-
# ----------------------------------------------------------------------
169-
170-
def record_ip_usage(self, ip, agent, path):
171-
"""
172-
Create or update the IP record for (ip, path) with an incremented count.
173-
"""
174-
with transaction.atomic():
175-
ip_record_qs = IP.objects.select_for_update().filter(address=ip, path=path)
176-
if ip_record_qs.exists():
177-
ip_record = ip_record_qs.first()
178-
ip_record.agent = agent
179-
ip_record.count = min(ip_record.count + 1, MAX_COUNT)
180-
ip_record.save(update_fields=["agent", "count"])
181-
182-
# Clean up any other records with the same (ip, path) but different PKs
183-
ip_record_qs.exclude(pk=ip_record.pk).delete()
184-
else:
185-
IP.objects.create(address=ip, agent=agent, count=1, path=path)
186-
187-
# ----------------------------------------------------------------------
188-
# Main Handler
189-
# ----------------------------------------------------------------------
190-
191113
def __call__(self, request):
192-
ip = self.get_client_ip(request)
193-
agent = self.get_user_agent(request)
114+
ip = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip() or request.META.get(
115+
"REMOTE_ADDR", ""
116+
)
117+
agent = request.META.get("HTTP_USER_AGENT", "").strip()
194118

195-
# Grab cached block sets/lists (1 DB call if cache is empty)
196-
blocked_ips = self.get_blocked_ips()
197-
blocked_networks = self.get_blocked_networks()
198-
blocked_agents = self.get_blocked_agents()
119+
blocked_ips = self.blocked_ips()
120+
blocked_ip_network = self.blocked_ip_network()
121+
blocked_agents = self.blocked_agents()
199122

200-
# 1) Check if IP is explicitly blocked
201-
if ip and self.ip_in_ips(ip, blocked_ips):
123+
if self.ip_in_ips(ip, blocked_ips):
202124
self.increment_block_count(ip=ip)
203125
return HttpResponseForbidden()
204126

205-
# 2) Check if IP is in any blocked network
206-
network_hit = self.ip_in_networks(ip, blocked_networks)
207-
if network_hit:
208-
self.increment_block_count(network=str(network_hit))
127+
if self.ip_in_range(ip, blocked_ip_network):
128+
# Find the specific network that caused the block and increment its count
129+
for network in blocked_ip_network:
130+
if ipaddress.ip_address(ip) in network:
131+
self.increment_block_count(network=str(network))
132+
break
209133
return HttpResponseForbidden()
210134

211-
# 3) Check if user agent is blocked
212-
if agent and self.is_user_agent_blocked(agent, blocked_agents):
135+
if self.is_user_agent_blocked(agent, blocked_agents):
213136
self.increment_block_count(user_agent=agent)
214137
return HttpResponseForbidden()
215138

216-
# 4) Record IP usage if present
217139
if ip:
218-
self.record_ip_usage(ip, agent, request.path)
140+
with transaction.atomic():
141+
# create unique entry for every unique (ip,path) tuple
142+
# if this tuple already exists, we just increment the count.
143+
ip_records = IP.objects.select_for_update().filter(address=ip, path=request.path)
144+
if ip_records.exists():
145+
ip_record = ip_records.first()
146+
147+
# Calculate the new count and ensure it doesn't exceed the MAX_COUNT
148+
new_count = ip_record.count + 1
149+
if new_count > MAX_COUNT:
150+
new_count = MAX_COUNT
151+
152+
ip_record.agent = agent
153+
ip_record.count = new_count
154+
if ip_record.pk:
155+
ip_record.save(update_fields=["agent", "count"])
156+
157+
# Check if a transaction is already active before starting a new one
158+
if not transaction.get_autocommit():
159+
ip_records.exclude(pk=ip_record.pk).delete()
160+
else:
161+
# If no record exists, create a new one
162+
IP.objects.create(address=ip, agent=agent, count=1, path=request.path)
219163

220164
return self.get_response(request)

0 commit comments

Comments
 (0)