8
8
9
9
MAX_COUNT = 2147483647
10
10
11
- CACHE_TIMEOUT = 86400 # 1 day, adjust as needed
12
-
13
11
14
12
class IPRestrictMiddleware :
15
13
"""
@@ -19,202 +17,148 @@ class IPRestrictMiddleware:
19
17
def __init__ (self , get_response ):
20
18
self .get_response = get_response
21
19
22
- # ----------------------------------------------------------------------
23
- # Caching and Data Retrieval
24
- # ----------------------------------------------------------------------
25
-
26
- def get_blocked_entries (self ):
27
- """
28
- Retrieve all blocked entries from cache or database. We store the
29
- entire list of Blocked objects (or a subset of fields) in cache once.
30
- """
31
- blocked_data = cache .get ("blocked_entries" )
32
- if blocked_data is None :
33
- # You can store only the fields needed if you prefer:
34
- # e.g., list(Blocked.objects.values('pk', 'address', 'ip_network', 'user_agent_string', 'count'))
35
- blocked_data = list (Blocked .objects .all ())
36
- cache .set ("blocked_entries" , blocked_data , CACHE_TIMEOUT )
37
- return blocked_data
38
-
39
- def get_blocked_ips (self ):
20
+ def get_cached_data (self , cache_key , queryset , timeout = 86400 ):
40
21
"""
41
- Build a set of blocked IP addresses from the cached blocked entries .
22
+ Retrieve data from cache or database .
42
23
"""
43
- ips = cache .get ("blocked_ips" )
44
- if ips is None :
45
- entries = self .get_blocked_entries ()
46
- # Filter out empty or None addresses, convert to a set
47
- ips = {entry .address for entry in entries if entry .address }
48
- cache .set ("blocked_ips" , ips , CACHE_TIMEOUT )
49
- return ips
24
+ cached_data = cache .get (cache_key )
25
+ if cached_data is None :
26
+ cached_data = list (filter (None , queryset )) # Filter out None values
27
+ cache .set (cache_key , cached_data , timeout = timeout )
28
+ return cached_data
50
29
51
- def get_blocked_networks (self ):
30
+ def blocked_ips (self ):
52
31
"""
53
- Build a list of blocked IP networks (ipaddress.ip_network objects)
54
- from the cached blocked entries.
32
+ Retrieve blocked IP addresses from cache or database.
55
33
"""
56
- networks = cache .get ("blocked_networks" )
57
- if networks is None :
58
- entries = self .get_blocked_entries ()
59
- networks = []
60
- for entry in entries :
61
- if entry .ip_network :
62
- try :
63
- net = ipaddress .ip_network (entry .ip_network , strict = False )
64
- networks .append (net )
65
- except ValueError :
66
- # Skip invalid networks
67
- pass
68
- cache .set ("blocked_networks" , networks , CACHE_TIMEOUT )
69
- return networks
34
+ blocked_addresses = Blocked .objects .values_list ("address" , flat = True )
35
+ return set (self .get_cached_data ("blocked_ips" , blocked_addresses ))
70
36
71
- def get_blocked_agents (self ):
37
+ def blocked_ip_network (self ):
72
38
"""
73
- Build a set of blocked user-agent substrings (lowercase) from the
74
- cached blocked entries.
39
+ Retrieve blocked IP networks from cache or database.
75
40
"""
76
- agents = cache .get ("blocked_agents" )
77
- if agents is None :
78
- entries = self .get_blocked_entries ()
79
- # Filter out empty or None user_agent_string, convert to lowercase set
80
- agents = {
81
- entry .user_agent_string .lower () for entry in entries if entry .user_agent_string
82
- }
83
- cache .set ("blocked_agents" , agents , CACHE_TIMEOUT )
84
- return agents
41
+ blocked_network = Blocked .objects .values_list ("ip_network" , flat = True )
42
+ blocked_ip_network = []
85
43
86
- # ----------------------------------------------------------------------
87
- # Core Utility Methods
88
- # ----------------------------------------------------------------------
44
+ for range_str in self .get_cached_data ("blocked_ip_network" , blocked_network ):
45
+ try :
46
+ network = ipaddress .ip_network (range_str , strict = False )
47
+ blocked_ip_network .append (network )
48
+ except ValueError :
49
+ # Log the error or handle it as needed, but skip invalid networks
50
+ continue
89
51
90
- def get_client_ip (self , request ):
91
- """
92
- Extract the client IP address from the request.
93
- """
94
- forwarded_for = request .META .get ("HTTP_X_FORWARDED_FOR" , "" )
95
- if forwarded_for :
96
- return forwarded_for .split ("," )[0 ].strip ()
97
- return request .META .get ("REMOTE_ADDR" , "" )
52
+ return blocked_ip_network
98
53
99
- def get_user_agent (self , request ):
54
+ def blocked_agents (self ):
100
55
"""
101
- Extract the user agent string from the request .
56
+ Retrieve blocked user agents from cache or database .
102
57
"""
103
- return request .META .get ("HTTP_USER_AGENT" , "" ).strip ()
58
+ blocked_user_agents = Blocked .objects .values_list ("user_agent_string" , flat = True )
59
+ return set (self .get_cached_data ("blocked_agents" , blocked_user_agents ))
104
60
105
61
def ip_in_ips (self , ip , blocked_ips ):
106
62
"""
107
- Check if the IP address is in the set of blocked IPs.
63
+ Check if the IP address is in the list of blocked IPs.
108
64
"""
109
65
return ip in blocked_ips
110
66
111
- def ip_in_networks (self , ip , blocked_networks ):
67
+ def ip_in_range (self , ip , blocked_ip_network ):
112
68
"""
113
69
Check if the IP address is within any of the blocked IP networks.
114
- Returns the matching network if found, else None.
115
70
"""
116
- if not ip :
117
- return None
118
71
ip_obj = ipaddress .ip_address (ip )
119
- for net in blocked_networks :
120
- if ip_obj in net :
121
- return net
122
- return None
72
+ return any (ip_obj in ip_range for ip_range in blocked_ip_network )
123
73
124
74
def is_user_agent_blocked (self , user_agent , blocked_agents ):
125
75
"""
126
- Check if the user agent matches any of the blocked user agent
127
- substrings (case-insensitive) .
76
+ Check if the user agent is in the list of blocked user agents by checking if the
77
+ full user agent string contains any of the blocked substrings .
128
78
"""
129
- ua_lower = user_agent .lower ()
130
- return any (blocked_ua in ua_lower for blocked_ua in blocked_agents )
131
-
132
- # ----------------------------------------------------------------------
133
- # Incrementing the Block Count in the DB
134
- # ----------------------------------------------------------------------
79
+ user_agent_str = str (user_agent ).strip ().lower ()
80
+ return any (blocked_agent .lower () in user_agent_str for blocked_agent in blocked_agents )
135
81
136
82
def increment_block_count (self , ip = None , network = None , user_agent = None ):
137
83
"""
138
- Increment the block count for the matching Blocked entry.
139
- Since the actual update must happen in the DB, we do one minimal query:
140
- we find the relevant entry by IP, or by network, or by user_agent substring.
84
+ Increment the block count for a specific IP, network, or user agent in the Blocked model.
141
85
"""
142
- if not (ip or network or user_agent ):
143
- return # Nothing to increment
144
-
145
86
with transaction .atomic ():
146
- # We'll do a single lookup based on what's provided.
147
- # Searching by IP / network is straightforward. Searching by user_agent
148
- # (substring) requires an __icontains lookup or a manual filter.
149
- qs = Blocked .objects .select_for_update ()
150
-
151
87
if ip :
152
- qs = qs . filter (address = ip )
88
+ blocked_entry = Blocked . objects . select_for_update (). filter (address = ip ). first ( )
153
89
elif network :
154
- qs = qs .filter (ip_network = network )
90
+ blocked_entry = (
91
+ Blocked .objects .select_for_update ().filter (ip_network = network ).first ()
92
+ )
155
93
elif user_agent :
156
- # If multiple user_agent_strings can match, you might want to refine logic here.
157
- qs = qs .filter (user_agent_string__iexact = user_agent ) | qs .filter (
158
- user_agent_string__icontains = user_agent
94
+ # Correct lookup: find if any user_agent_string is a substring of the user_agent
95
+ blocked_entry = (
96
+ Blocked .objects .select_for_update ()
97
+ .filter (
98
+ user_agent_string__in = [
99
+ agent
100
+ for agent in Blocked .objects .values_list ("user_agent_string" , flat = True )
101
+ if agent .lower () in user_agent .lower ()
102
+ ]
103
+ )
104
+ .first ()
159
105
)
106
+ else :
107
+ return # Nothing to increment
160
108
161
- blocked_entry = qs .first ()
162
109
if blocked_entry :
163
110
blocked_entry .count = models .F ("count" ) + 1
164
111
blocked_entry .save (update_fields = ["count" ])
165
112
166
- # ----------------------------------------------------------------------
167
- # Recording General IP Usage in IP Model
168
- # ----------------------------------------------------------------------
169
-
170
- def record_ip_usage (self , ip , agent , path ):
171
- """
172
- Create or update the IP record for (ip, path) with an incremented count.
173
- """
174
- with transaction .atomic ():
175
- ip_record_qs = IP .objects .select_for_update ().filter (address = ip , path = path )
176
- if ip_record_qs .exists ():
177
- ip_record = ip_record_qs .first ()
178
- ip_record .agent = agent
179
- ip_record .count = min (ip_record .count + 1 , MAX_COUNT )
180
- ip_record .save (update_fields = ["agent" , "count" ])
181
-
182
- # Clean up any other records with the same (ip, path) but different PKs
183
- ip_record_qs .exclude (pk = ip_record .pk ).delete ()
184
- else :
185
- IP .objects .create (address = ip , agent = agent , count = 1 , path = path )
186
-
187
- # ----------------------------------------------------------------------
188
- # Main Handler
189
- # ----------------------------------------------------------------------
190
-
191
113
def __call__ (self , request ):
192
- ip = self .get_client_ip (request )
193
- agent = self .get_user_agent (request )
114
+ ip = request .META .get ("HTTP_X_FORWARDED_FOR" , "" ).split ("," )[0 ].strip () or request .META .get (
115
+ "REMOTE_ADDR" , ""
116
+ )
117
+ agent = request .META .get ("HTTP_USER_AGENT" , "" ).strip ()
194
118
195
- # Grab cached block sets/lists (1 DB call if cache is empty)
196
- blocked_ips = self .get_blocked_ips ()
197
- blocked_networks = self .get_blocked_networks ()
198
- blocked_agents = self .get_blocked_agents ()
119
+ blocked_ips = self .blocked_ips ()
120
+ blocked_ip_network = self .blocked_ip_network ()
121
+ blocked_agents = self .blocked_agents ()
199
122
200
- # 1) Check if IP is explicitly blocked
201
- if ip and self .ip_in_ips (ip , blocked_ips ):
123
+ if self .ip_in_ips (ip , blocked_ips ):
202
124
self .increment_block_count (ip = ip )
203
125
return HttpResponseForbidden ()
204
126
205
- # 2) Check if IP is in any blocked network
206
- network_hit = self .ip_in_networks (ip , blocked_networks )
207
- if network_hit :
208
- self .increment_block_count (network = str (network_hit ))
127
+ if self .ip_in_range (ip , blocked_ip_network ):
128
+ # Find the specific network that caused the block and increment its count
129
+ for network in blocked_ip_network :
130
+ if ipaddress .ip_address (ip ) in network :
131
+ self .increment_block_count (network = str (network ))
132
+ break
209
133
return HttpResponseForbidden ()
210
134
211
- # 3) Check if user agent is blocked
212
- if agent and self .is_user_agent_blocked (agent , blocked_agents ):
135
+ if self .is_user_agent_blocked (agent , blocked_agents ):
213
136
self .increment_block_count (user_agent = agent )
214
137
return HttpResponseForbidden ()
215
138
216
- # 4) Record IP usage if present
217
139
if ip :
218
- self .record_ip_usage (ip , agent , request .path )
140
+ with transaction .atomic ():
141
+ # create unique entry for every unique (ip,path) tuple
142
+ # if this tuple already exists, we just increment the count.
143
+ ip_records = IP .objects .select_for_update ().filter (address = ip , path = request .path )
144
+ if ip_records .exists ():
145
+ ip_record = ip_records .first ()
146
+
147
+ # Calculate the new count and ensure it doesn't exceed the MAX_COUNT
148
+ new_count = ip_record .count + 1
149
+ if new_count > MAX_COUNT :
150
+ new_count = MAX_COUNT
151
+
152
+ ip_record .agent = agent
153
+ ip_record .count = new_count
154
+ if ip_record .pk :
155
+ ip_record .save (update_fields = ["agent" , "count" ])
156
+
157
+ # Check if a transaction is already active before starting a new one
158
+ if not transaction .get_autocommit ():
159
+ ip_records .exclude (pk = ip_record .pk ).delete ()
160
+ else :
161
+ # If no record exists, create a new one
162
+ IP .objects .create (address = ip , agent = agent , count = 1 , path = request .path )
219
163
220
164
return self .get_response (request )
0 commit comments