Skip to content

Commit

Permalink
Add initial implementation of breach limits
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Aug 23, 2023
1 parent f72854b commit 6686ba5
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 0 deletions.
1 change: 1 addition & 0 deletions flask_limiter/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ConfigVars:
HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"
IN_MEMORY_FALLBACK_ENABLED = "RATELIMIT_IN_MEMORY_FALLBACK_ENABLED"
BREACH_LIMITS = "RATELIMIT_BREACH"


class HeaderNames(enum.Enum):
Expand Down
40 changes: 40 additions & 0 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
on_breach: Optional[
Callable[[RequestLimit], Optional[flask.wrappers.Response]]
] = None,
breach_limits: Optional[List[Union[str, Callable[[], str]]]] = None,
in_memory_fallback: Optional[List[str]] = None,
in_memory_fallback_enabled: Optional[bool] = None,
retry_after: Optional[str] = None,
Expand Down Expand Up @@ -232,6 +233,20 @@ def __init__(
else []
)

self._breach_limits = (
[
LimitGroup(
limit_provider=limit,
key_function=self._key_func,
scope="meta",
shared=True,
)
for limit in breach_limits
]
if breach_limits
else []
)

if in_memory_fallback:
for limit in in_memory_fallback:
self._in_memory_fallback.append(
Expand Down Expand Up @@ -420,6 +435,17 @@ def init_app(self, app: flask.Flask) -> None:
group.deduct_when = self._default_limits_deduct_when
group.cost = self._default_limits_cost
self.limit_manager.set_default_limits(default_limit_groups)

breach_limits = config.get(ConfigVars.BREACH_LIMITS, None)
if not self._breach_limits and breach_limits:
self._breach_limits = [
LimitGroup(
limit_provider=app_limits,
key_function=self._key_func,
scope="meta",
shared=True,
)
]
self.__configure_fallbacks(app, self._strategy)

if self not in app.extensions.setdefault("limiter", set()):
Expand Down Expand Up @@ -984,6 +1010,15 @@ def __evaluate_limits(self, endpoint: str, limits: List[Limit]) -> None:
failed_limits: List[Tuple[Limit, List[str]]] = []
limit_for_header: Optional[RequestLimit] = None
view_limits: List[RequestLimit] = []
for lim in itertools.chain(*self._breach_limits):
limit_key, scope = lim.key_func(), lim.scope_for(endpoint, None)
args = [limit_key, scope]
if not self.limiter.test(lim.limit, *args):
breach_limit = RequestLimit(self, lim.limit, args, True, lim.shared)
self.context.view_rate_limit = breach_limit
self.context.view_rate_limits = [breach_limit]
raise RateLimitExceeded(lim)

for lim in sorted(limits, key=lambda x: x.limit):
if lim.is_exempt or lim.method_exempt:
continue
Expand Down Expand Up @@ -1055,6 +1090,11 @@ def __evaluate_limits(self, endpoint: str, limits: List[Limit]) -> None:
else:
raise err
if failed_limits:
for lim in itertools.chain(*self._breach_limits):
limit_scope = lim.scope_for(endpoint, flask.request.method)
limit_key = lim.key_func()
args = [limit_key, limit_scope]
self.limiter.hit(lim.limit, *args)
raise RateLimitExceeded(
sorted(failed_limits, key=lambda x: x[0].limit)[0][0],
response=on_breach_response,
Expand Down
41 changes: 41 additions & 0 deletions tests/test_flask_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,44 @@ def app_test1():
assert cli.get("/test1").status_code == 429
timeline.forward(50)
assert cli.get("/test1").status_code == 200


def test_breach_limits(extension_factory):
app, limiter = extension_factory(
default_limits=["2/second"], breach_limits=["2/minute", "3/hour", "4/day"]
)

@app.route("/")
def root():
return "root"

with hiro.Timeline().freeze() as timeline:
start = time.time()
print(start)
with app.test_client() as cli:
for _ in range(2):
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 429
timeline.forward(1)

# blocked because of max 2 breaches/minute
assert cli.get("/").status_code == 429
timeline.forward(59)
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 429
timeline.forward(59)
# blocked because of max 3 breaches/hour
assert cli.get("/").status_code == 429
# forward to 1 hour since start
timeline.forward(60 * 58)
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 429
# forward another hour and it should now be blocked for the day
timeline.forward(60 * 60)
assert cli.get("/").status_code == 429
# forward 22 hours
timeline.forward(60 * 60 * 22)
assert cli.get("/").status_code == 200

0 comments on commit 6686ba5

Please sign in to comment.