8
8
9
9
MAX_COUNT = 2147483647
10
10
11
+ CACHE_TIMEOUT = 86400 # 1 day, adjust as needed
12
+
11
13
12
14
class IPRestrictMiddleware :
13
15
"""
@@ -17,148 +19,202 @@ class IPRestrictMiddleware:
17
19
def __init__ (self , get_response ):
18
20
self .get_response = get_response
19
21
20
- def get_cached_data (self , cache_key , queryset , timeout = 86400 ):
22
+ # ----------------------------------------------------------------------
23
+ # Caching and Data Retrieval
24
+ # ----------------------------------------------------------------------
25
+
26
+ def get_blocked_entries (self ):
27
+ """
28
+ Retrieve all blocked entries from cache or database. We store the
29
+ entire list of Blocked objects (or a subset of fields) in cache once.
30
+ """
31
+ blocked_data = cache .get ("blocked_entries" )
32
+ if blocked_data is None :
33
+ # You can store only the fields needed if you prefer:
34
+ # e.g., list(Blocked.objects.values('pk', 'address', 'ip_network', 'user_agent_string', 'count'))
35
+ blocked_data = list (Blocked .objects .all ())
36
+ cache .set ("blocked_entries" , blocked_data , CACHE_TIMEOUT )
37
+ return blocked_data
38
+
39
+ def get_blocked_ips (self ):
21
40
"""
22
- Retrieve data from cache or database .
41
+ Build a set of blocked IP addresses from the cached blocked entries .
23
42
"""
24
- cached_data = cache .get (cache_key )
25
- if cached_data is None :
26
- cached_data = list (filter (None , queryset )) # Filter out None values
27
- cache .set (cache_key , cached_data , timeout = timeout )
28
- return cached_data
43
+ ips = cache .get ("blocked_ips" )
44
+ if ips is None :
45
+ entries = self .get_blocked_entries ()
46
+ # Filter out empty or None addresses, convert to a set
47
+ ips = {entry .address for entry in entries if entry .address }
48
+ cache .set ("blocked_ips" , ips , CACHE_TIMEOUT )
49
+ return ips
29
50
30
- def blocked_ips (self ):
51
+ def get_blocked_networks (self ):
31
52
"""
32
- Retrieve blocked IP addresses from cache or database.
53
+ Build a list of blocked IP networks (ipaddress.ip_network objects)
54
+ from the cached blocked entries.
33
55
"""
34
- blocked_addresses = Blocked .objects .values_list ("address" , flat = True )
35
- return set (self .get_cached_data ("blocked_ips" , blocked_addresses ))
56
+ networks = cache .get ("blocked_networks" )
57
+ if networks is None :
58
+ entries = self .get_blocked_entries ()
59
+ networks = []
60
+ for entry in entries :
61
+ if entry .ip_network :
62
+ try :
63
+ net = ipaddress .ip_network (entry .ip_network , strict = False )
64
+ networks .append (net )
65
+ except ValueError :
66
+ # Skip invalid networks
67
+ pass
68
+ cache .set ("blocked_networks" , networks , CACHE_TIMEOUT )
69
+ return networks
36
70
37
- def blocked_ip_network (self ):
71
+ def get_blocked_agents (self ):
38
72
"""
39
- Retrieve blocked IP networks from cache or database.
73
+ Build a set of blocked user-agent substrings (lowercase) from the
74
+ cached blocked entries.
40
75
"""
41
- blocked_network = Blocked .objects .values_list ("ip_network" , flat = True )
42
- blocked_ip_network = []
76
+ agents = cache .get ("blocked_agents" )
77
+ if agents is None :
78
+ entries = self .get_blocked_entries ()
79
+ # Filter out empty or None user_agent_string, convert to lowercase set
80
+ agents = {
81
+ entry .user_agent_string .lower () for entry in entries if entry .user_agent_string
82
+ }
83
+ cache .set ("blocked_agents" , agents , CACHE_TIMEOUT )
84
+ return agents
43
85
44
- for range_str in self .get_cached_data ("blocked_ip_network" , blocked_network ):
45
- try :
46
- network = ipaddress .ip_network (range_str , strict = False )
47
- blocked_ip_network .append (network )
48
- except ValueError :
49
- # Log the error or handle it as needed, but skip invalid networks
50
- continue
86
+ # ----------------------------------------------------------------------
87
+ # Core Utility Methods
88
+ # ----------------------------------------------------------------------
51
89
52
- return blocked_ip_network
90
+ def get_client_ip (self , request ):
91
+ """
92
+ Extract the client IP address from the request.
93
+ """
94
+ forwarded_for = request .META .get ("HTTP_X_FORWARDED_FOR" , "" )
95
+ if forwarded_for :
96
+ return forwarded_for .split ("," )[0 ].strip ()
97
+ return request .META .get ("REMOTE_ADDR" , "" )
53
98
54
- def blocked_agents (self ):
99
+ def get_user_agent (self , request ):
55
100
"""
56
- Retrieve blocked user agents from cache or database .
101
+ Extract the user agent string from the request .
57
102
"""
58
- blocked_user_agents = Blocked .objects .values_list ("user_agent_string" , flat = True )
59
- return set (self .get_cached_data ("blocked_agents" , blocked_user_agents ))
103
+ return request .META .get ("HTTP_USER_AGENT" , "" ).strip ()
60
104
61
105
def ip_in_ips (self , ip , blocked_ips ):
62
106
"""
63
- Check if the IP address is in the list of blocked IPs.
107
+ Check if the IP address is in the set of blocked IPs.
64
108
"""
65
109
return ip in blocked_ips
66
110
67
- def ip_in_range (self , ip , blocked_ip_network ):
111
+ def ip_in_networks (self , ip , blocked_networks ):
68
112
"""
69
113
Check if the IP address is within any of the blocked IP networks.
114
+ Returns the matching network if found, else None.
70
115
"""
116
+ if not ip :
117
+ return None
71
118
ip_obj = ipaddress .ip_address (ip )
72
- return any (ip_obj in ip_range for ip_range in blocked_ip_network )
119
+ for net in blocked_networks :
120
+ if ip_obj in net :
121
+ return net
122
+ return None
73
123
74
124
def is_user_agent_blocked (self , user_agent , blocked_agents ):
75
125
"""
76
- Check if the user agent is in the list of blocked user agents by checking if the
77
- full user agent string contains any of the blocked substrings.
126
+ Check if the user agent matches any of the blocked user agent
127
+ substrings (case-insensitive) .
78
128
"""
79
- user_agent_str = str (user_agent ).strip ().lower ()
80
- return any (blocked_agent .lower () in user_agent_str for blocked_agent in blocked_agents )
129
+ ua_lower = user_agent .lower ()
130
+ return any (blocked_ua in ua_lower for blocked_ua in blocked_agents )
131
+
132
+ # ----------------------------------------------------------------------
133
+ # Incrementing the Block Count in the DB
134
+ # ----------------------------------------------------------------------
81
135
82
136
def increment_block_count (self , ip = None , network = None , user_agent = None ):
83
137
"""
84
- Increment the block count for a specific IP, network, or user agent in the Blocked model.
138
+ Increment the block count for the matching Blocked entry.
139
+ Since the actual update must happen in the DB, we do one minimal query:
140
+ we find the relevant entry by IP, or by network, or by user_agent substring.
85
141
"""
142
+ if not (ip or network or user_agent ):
143
+ return # Nothing to increment
144
+
86
145
with transaction .atomic ():
146
+ # We'll do a single lookup based on what's provided.
147
+ # Searching by IP / network is straightforward. Searching by user_agent
148
+ # (substring) requires an __icontains lookup or a manual filter.
149
+ qs = Blocked .objects .select_for_update ()
150
+
87
151
if ip :
88
- blocked_entry = Blocked . objects . select_for_update (). filter (address = ip ). first ( )
152
+ qs = qs . filter (address = ip )
89
153
elif network :
90
- blocked_entry = (
91
- Blocked .objects .select_for_update ().filter (ip_network = network ).first ()
92
- )
154
+ qs = qs .filter (ip_network = network )
93
155
elif user_agent :
94
- # Correct lookup: find if any user_agent_string is a substring of the user_agent
95
- blocked_entry = (
96
- Blocked .objects .select_for_update ()
97
- .filter (
98
- user_agent_string__in = [
99
- agent
100
- for agent in Blocked .objects .values_list ("user_agent_string" , flat = True )
101
- if agent .lower () in user_agent .lower ()
102
- ]
103
- )
104
- .first ()
156
+ # If multiple user_agent_strings can match, you might want to refine logic here.
157
+ qs = qs .filter (user_agent_string__iexact = user_agent ) | qs .filter (
158
+ user_agent_string__icontains = user_agent
105
159
)
106
- else :
107
- return # Nothing to increment
108
160
161
+ blocked_entry = qs .first ()
109
162
if blocked_entry :
110
163
blocked_entry .count = models .F ("count" ) + 1
111
164
blocked_entry .save (update_fields = ["count" ])
112
165
166
+ # ----------------------------------------------------------------------
167
+ # Recording General IP Usage in IP Model
168
+ # ----------------------------------------------------------------------
169
+
170
+ def record_ip_usage (self , ip , agent , path ):
171
+ """
172
+ Create or update the IP record for (ip, path) with an incremented count.
173
+ """
174
+ with transaction .atomic ():
175
+ ip_record_qs = IP .objects .select_for_update ().filter (address = ip , path = path )
176
+ if ip_record_qs .exists ():
177
+ ip_record = ip_record_qs .first ()
178
+ ip_record .agent = agent
179
+ ip_record .count = min (ip_record .count + 1 , MAX_COUNT )
180
+ ip_record .save (update_fields = ["agent" , "count" ])
181
+
182
+ # Clean up any other records with the same (ip, path) but different PKs
183
+ ip_record_qs .exclude (pk = ip_record .pk ).delete ()
184
+ else :
185
+ IP .objects .create (address = ip , agent = agent , count = 1 , path = path )
186
+
187
+ # ----------------------------------------------------------------------
188
+ # Main Handler
189
+ # ----------------------------------------------------------------------
190
+
113
191
def __call__ (self , request ):
114
- ip = request .META .get ("HTTP_X_FORWARDED_FOR" , "" ).split ("," )[0 ].strip () or request .META .get (
115
- "REMOTE_ADDR" , ""
116
- )
117
- agent = request .META .get ("HTTP_USER_AGENT" , "" ).strip ()
192
+ ip = self .get_client_ip (request )
193
+ agent = self .get_user_agent (request )
118
194
119
- blocked_ips = self .blocked_ips ()
120
- blocked_ip_network = self .blocked_ip_network ()
121
- blocked_agents = self .blocked_agents ()
195
+ # Grab cached block sets/lists (1 DB call if cache is empty)
196
+ blocked_ips = self .get_blocked_ips ()
197
+ blocked_networks = self .get_blocked_networks ()
198
+ blocked_agents = self .get_blocked_agents ()
122
199
123
- if self .ip_in_ips (ip , blocked_ips ):
200
+ # 1) Check if IP is explicitly blocked
201
+ if ip and self .ip_in_ips (ip , blocked_ips ):
124
202
self .increment_block_count (ip = ip )
125
203
return HttpResponseForbidden ()
126
204
127
- if self .ip_in_range (ip , blocked_ip_network ):
128
- # Find the specific network that caused the block and increment its count
129
- for network in blocked_ip_network :
130
- if ipaddress .ip_address (ip ) in network :
131
- self .increment_block_count (network = str (network ))
132
- break
205
+ # 2) Check if IP is in any blocked network
206
+ network_hit = self .ip_in_networks (ip , blocked_networks )
207
+ if network_hit :
208
+ self .increment_block_count (network = str (network_hit ))
133
209
return HttpResponseForbidden ()
134
210
135
- if self .is_user_agent_blocked (agent , blocked_agents ):
211
+ # 3) Check if user agent is blocked
212
+ if agent and self .is_user_agent_blocked (agent , blocked_agents ):
136
213
self .increment_block_count (user_agent = agent )
137
214
return HttpResponseForbidden ()
138
215
216
+ # 4) Record IP usage if present
139
217
if ip :
140
- with transaction .atomic ():
141
- # create unique entry for every unique (ip,path) tuple
142
- # if this tuple already exists, we just increment the count.
143
- ip_records = IP .objects .select_for_update ().filter (address = ip , path = request .path )
144
- if ip_records .exists ():
145
- ip_record = ip_records .first ()
146
-
147
- # Calculate the new count and ensure it doesn't exceed the MAX_COUNT
148
- new_count = ip_record .count + 1
149
- if new_count > MAX_COUNT :
150
- new_count = MAX_COUNT
151
-
152
- ip_record .agent = agent
153
- ip_record .count = new_count
154
- if ip_record .pk :
155
- ip_record .save (update_fields = ["agent" , "count" ])
156
-
157
- # Check if a transaction is already active before starting a new one
158
- if not transaction .get_autocommit ():
159
- ip_records .exclude (pk = ip_record .pk ).delete ()
160
- else :
161
- # If no record exists, create a new one
162
- IP .objects .create (address = ip , agent = agent , count = 1 , path = request .path )
218
+ self .record_ip_usage (ip , agent , request .path )
163
219
164
220
return self .get_response (request )
0 commit comments