Skip to content

Commit ed40a24

Browse files
committedAug 8, 2025·
Drop django-ratelimt, use a built-in
1 parent 62405d0 commit ed40a24

File tree

6 files changed

+687
-5
lines changed

6 files changed

+687
-5
lines changed
 

‎judge/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Tests package for judge app

‎judge/tests/test_ratelimit.py

Lines changed: 396 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,396 @@
1+
"""
2+
Tests for the built-in rate limiting system
3+
"""
4+
5+
import time
6+
from unittest.mock import patch, MagicMock
7+
8+
from django.test import TestCase, RequestFactory
9+
from django.contrib.auth.models import User
10+
from django.core.cache import cache
11+
from django.http import HttpResponse
12+
13+
from judge.utils.ratelimit import (
14+
parse_rate,
15+
get_cache_key,
16+
get_client_ip,
17+
is_rate_limited,
18+
create_rate_limit_response,
19+
ratelimit,
20+
RateLimitExceeded,
21+
)
22+
23+
24+
class ParseRateTestCase(TestCase):
25+
"""Test rate parsing functionality"""
26+
27+
def test_parse_valid_rates(self):
28+
"""Test parsing of valid rate strings"""
29+
test_cases = [
30+
("30/h", (30, 3600)),
31+
("200/h", (200, 3600)),
32+
("10/m", (10, 60)),
33+
("5/s", (5, 1)),
34+
("1/d", (1, 86400)),
35+
("100/H", (100, 3600)), # Case insensitive
36+
]
37+
38+
for rate_str, expected in test_cases:
39+
with self.subTest(rate=rate_str):
40+
result = parse_rate(rate_str)
41+
self.assertEqual(result, expected)
42+
43+
def test_parse_invalid_rates(self):
44+
"""Test parsing of invalid rate strings"""
45+
invalid_rates = [
46+
"",
47+
"30",
48+
"30/",
49+
"/h",
50+
"30/x",
51+
"abc/h",
52+
"30/hour",
53+
"30-h",
54+
]
55+
56+
for rate_str in invalid_rates:
57+
with self.subTest(rate=rate_str):
58+
with self.assertRaises(ValueError):
59+
parse_rate(rate_str)
60+
61+
62+
class GetClientIpTestCase(TestCase):
63+
"""Test client IP detection"""
64+
65+
def setUp(self):
66+
self.factory = RequestFactory()
67+
68+
def test_x_forwarded_for(self):
69+
"""Test IP detection from X-Forwarded-For header"""
70+
request = self.factory.get("/")
71+
request.META["HTTP_X_FORWARDED_FOR"] = "192.168.1.1, 10.0.0.1"
72+
73+
ip = get_client_ip(request)
74+
self.assertEqual(ip, "192.168.1.1")
75+
76+
def test_x_real_ip(self):
77+
"""Test IP detection from X-Real-IP header"""
78+
request = self.factory.get("/")
79+
request.META["HTTP_X_REAL_IP"] = "192.168.1.2"
80+
81+
ip = get_client_ip(request)
82+
self.assertEqual(ip, "192.168.1.2")
83+
84+
def test_remote_addr(self):
85+
"""Test IP detection from REMOTE_ADDR"""
86+
request = self.factory.get("/")
87+
request.META["REMOTE_ADDR"] = "192.168.1.3"
88+
89+
ip = get_client_ip(request)
90+
self.assertEqual(ip, "192.168.1.3")
91+
92+
def test_no_ip_fallback(self):
93+
"""Test fallback when no IP is available"""
94+
request = self.factory.get("/")
95+
# Clear all IP-related headers
96+
if "REMOTE_ADDR" in request.META:
97+
del request.META["REMOTE_ADDR"]
98+
99+
ip = get_client_ip(request)
100+
self.assertEqual(ip, "unknown")
101+
102+
103+
class GetCacheKeyTestCase(TestCase):
104+
"""Test cache key generation"""
105+
106+
def setUp(self):
107+
self.factory = RequestFactory()
108+
self.user = User.objects.create_user(
109+
username="testuser", email="test@example.com", password="testpass"
110+
)
111+
112+
def test_user_key_authenticated(self):
113+
"""Test cache key for authenticated user"""
114+
request = self.factory.get("/")
115+
request.user = self.user
116+
117+
key = get_cache_key(request, "user", "test_view")
118+
expected = f"ratelimit:user:{self.user.id}:test_view"
119+
self.assertEqual(key, expected)
120+
121+
def test_user_key_anonymous(self):
122+
"""Test cache key for anonymous user falls back to IP"""
123+
request = self.factory.get("/")
124+
request.user = MagicMock()
125+
request.user.is_authenticated = False
126+
request.META["REMOTE_ADDR"] = "192.168.1.1"
127+
128+
key = get_cache_key(request, "user", "test_view")
129+
expected = "ratelimit:user:192.168.1.1:test_view"
130+
self.assertEqual(key, expected)
131+
132+
def test_ip_key(self):
133+
"""Test cache key for IP-based rate limiting"""
134+
request = self.factory.get("/")
135+
request.META["REMOTE_ADDR"] = "192.168.1.1"
136+
137+
key = get_cache_key(request, "ip", "test_view")
138+
expected = "ratelimit:ip:192.168.1.1:test_view"
139+
self.assertEqual(key, expected)
140+
141+
def test_header_key(self):
142+
"""Test cache key for header-based rate limiting"""
143+
request = self.factory.get("/")
144+
request.META["HTTP_X_API_KEY"] = "test-api-key"
145+
146+
key = get_cache_key(request, "header:x-api-key", "test_view")
147+
expected = "ratelimit:header:x-api-key:test-api-key:test_view"
148+
self.assertEqual(key, expected)
149+
150+
151+
class IsRateLimitedTestCase(TestCase):
152+
"""Test rate limiting logic"""
153+
154+
def setUp(self):
155+
cache.clear()
156+
157+
def tearDown(self):
158+
cache.clear()
159+
160+
def test_first_request_not_limited(self):
161+
"""Test that first request is not rate limited"""
162+
is_limited, count, reset_time = is_rate_limited("test_key", 5, 3600)
163+
164+
self.assertFalse(is_limited)
165+
self.assertEqual(count, 1)
166+
self.assertGreater(reset_time, time.time())
167+
168+
def test_within_limit_not_limited(self):
169+
"""Test requests within limit are not blocked"""
170+
cache_key = "test_key_within"
171+
172+
# Make 3 requests (limit is 5)
173+
for i in range(3):
174+
is_limited, count, reset_time = is_rate_limited(cache_key, 5, 3600)
175+
self.assertFalse(is_limited)
176+
self.assertEqual(count, i + 1)
177+
178+
def test_exceed_limit_blocked(self):
179+
"""Test requests exceeding limit are blocked"""
180+
cache_key = "test_key_exceed"
181+
182+
# Make requests up to the limit
183+
for i in range(5):
184+
is_limited, count, reset_time = is_rate_limited(cache_key, 5, 3600)
185+
self.assertFalse(is_limited)
186+
187+
# Next request should be blocked
188+
is_limited, count, reset_time = is_rate_limited(cache_key, 5, 3600)
189+
self.assertTrue(is_limited)
190+
self.assertEqual(count, 5)
191+
192+
@patch("time.time")
193+
def test_sliding_window_cleanup(self, mock_time):
194+
"""Test that old timestamps are cleaned up"""
195+
cache_key = "test_key_cleanup"
196+
197+
# Set initial time
198+
mock_time.return_value = 1000
199+
200+
# Make 3 requests
201+
for i in range(3):
202+
is_rate_limited(cache_key, 5, 60) # 5 requests per minute
203+
204+
# Move time forward by 61 seconds (past the window)
205+
mock_time.return_value = 1061
206+
207+
# Next request should not be limited (old timestamps cleaned up)
208+
is_limited, count, reset_time = is_rate_limited(cache_key, 5, 60)
209+
self.assertFalse(is_limited)
210+
self.assertEqual(count, 1) # Only current request
211+
212+
@patch("judge.utils.ratelimit.cache")
213+
def test_cache_failure_allows_request(self, mock_cache):
214+
"""Test that cache failures allow requests (fail-open)"""
215+
mock_cache.get.side_effect = Exception("Cache error")
216+
217+
is_limited, count, reset_time = is_rate_limited("test_key", 5, 3600)
218+
219+
self.assertFalse(is_limited)
220+
self.assertEqual(count, 0)
221+
222+
223+
class CreateRateLimitResponseTestCase(TestCase):
224+
"""Test rate limit response creation"""
225+
226+
def setUp(self):
227+
self.factory = RequestFactory()
228+
229+
def test_response_status_and_headers(self):
230+
"""Test that response has correct status and headers"""
231+
request = self.factory.get("/")
232+
reset_time = int(time.time()) + 3600
233+
234+
response = create_rate_limit_response(request, 30, 35, reset_time)
235+
236+
self.assertEqual(response.status_code, 429)
237+
self.assertEqual(response["X-RateLimit-Limit"], "30")
238+
self.assertEqual(response["X-RateLimit-Remaining"], "0")
239+
self.assertEqual(response["X-RateLimit-Reset"], str(reset_time))
240+
self.assertIn("Retry-After", response)
241+
242+
243+
class RateLimitDecoratorTestCase(TestCase):
244+
"""Test the ratelimit decorator"""
245+
246+
def setUp(self):
247+
self.factory = RequestFactory()
248+
self.user = User.objects.create_user(
249+
username="testuser", email="test@example.com", password="testpass"
250+
)
251+
cache.clear()
252+
253+
def tearDown(self):
254+
cache.clear()
255+
256+
def test_decorator_allows_within_limit(self):
257+
"""Test decorator allows requests within limit"""
258+
259+
@ratelimit(key="ip", rate="5/h")
260+
def test_view(request):
261+
return HttpResponse("OK")
262+
263+
request = self.factory.get("/")
264+
request.META["REMOTE_ADDR"] = "192.168.1.1"
265+
266+
# Make 3 requests (within limit of 5)
267+
for i in range(3):
268+
response = test_view(request)
269+
self.assertEqual(response.status_code, 200)
270+
self.assertEqual(response.content.decode(), "OK")
271+
272+
def test_decorator_blocks_over_limit(self):
273+
"""Test decorator blocks requests over limit"""
274+
275+
@ratelimit(key="ip", rate="2/h")
276+
def test_view(request):
277+
return HttpResponse("OK")
278+
279+
request = self.factory.get("/")
280+
request.META["REMOTE_ADDR"] = "192.168.1.2"
281+
282+
# Make 2 requests (at limit)
283+
for i in range(2):
284+
response = test_view(request)
285+
self.assertEqual(response.status_code, 200)
286+
287+
# Third request should be blocked
288+
response = test_view(request)
289+
self.assertEqual(response.status_code, 429)
290+
291+
def test_method_filtering(self):
292+
"""Test that method filtering works"""
293+
294+
@ratelimit(key="ip", rate="1/h", method=["POST"])
295+
def test_view(request):
296+
return HttpResponse("OK")
297+
298+
request_get = self.factory.get("/")
299+
request_post = self.factory.post("/")
300+
request_get.META["REMOTE_ADDR"] = "192.168.1.3"
301+
request_post.META["REMOTE_ADDR"] = "192.168.1.3"
302+
303+
# GET requests should not be rate limited
304+
response = test_view(request_get)
305+
self.assertEqual(response.status_code, 200)
306+
307+
response = test_view(request_get)
308+
self.assertEqual(response.status_code, 200)
309+
310+
# POST request should be rate limited
311+
response = test_view(request_post)
312+
self.assertEqual(response.status_code, 200)
313+
314+
# Second POST should be blocked
315+
response = test_view(request_post)
316+
self.assertEqual(response.status_code, 429)
317+
318+
def test_user_key_with_authenticated_user(self):
319+
"""Test user-based rate limiting with authenticated user"""
320+
321+
@ratelimit(key="user", rate="2/h")
322+
def test_view(request):
323+
return HttpResponse("OK")
324+
325+
request = self.factory.get("/")
326+
request.user = self.user
327+
328+
# Make 2 requests (at limit)
329+
for i in range(2):
330+
response = test_view(request)
331+
self.assertEqual(response.status_code, 200)
332+
333+
# Third request should be blocked
334+
response = test_view(request)
335+
self.assertEqual(response.status_code, 429)
336+
337+
def test_custom_key_function(self):
338+
"""Test custom key function"""
339+
340+
def custom_key(request):
341+
return f"custom_{request.META.get('HTTP_X_CUSTOM_ID', 'default')}"
342+
343+
@ratelimit(key=custom_key, rate="1/h")
344+
def test_view(request):
345+
return HttpResponse("OK")
346+
347+
request = self.factory.get("/")
348+
request.META["HTTP_X_CUSTOM_ID"] = "test123"
349+
350+
# First request should work
351+
response = test_view(request)
352+
self.assertEqual(response.status_code, 200)
353+
354+
# Second request should be blocked
355+
response = test_view(request)
356+
self.assertEqual(response.status_code, 429)
357+
358+
def test_block_false_allows_over_limit(self):
359+
"""Test that block=False allows requests over limit"""
360+
361+
@ratelimit(key="ip", rate="1/h", block=False)
362+
def test_view(request):
363+
# Check if rate limit status is available
364+
if hasattr(request, "rate_limit_status"):
365+
if request.rate_limit_status["limited"]:
366+
return HttpResponse("Limited but allowed", status=200)
367+
return HttpResponse("OK")
368+
369+
request = self.factory.get("/")
370+
request.META["REMOTE_ADDR"] = "192.168.1.4"
371+
372+
# First request
373+
response = test_view(request)
374+
self.assertEqual(response.status_code, 200)
375+
self.assertEqual(response.content.decode(), "OK")
376+
377+
# Second request should be allowed but marked as limited
378+
response = test_view(request)
379+
self.assertEqual(response.status_code, 200)
380+
self.assertEqual(response.content.decode(), "Limited but allowed")
381+
382+
def test_invalid_rate_raises_error(self):
383+
"""Test that invalid rate format raises ValueError"""
384+
with self.assertRaises(ValueError):
385+
386+
@ratelimit(key="ip", rate="invalid")
387+
def test_view(request):
388+
return HttpResponse("OK")
389+
390+
def test_missing_rate_raises_error(self):
391+
"""Test that missing rate raises ValueError"""
392+
with self.assertRaises(ValueError):
393+
394+
@ratelimit(key="ip")
395+
def test_view(request):
396+
return HttpResponse("OK")

0 commit comments

Comments
 (0)