Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sliding window rate limit #56

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ async def websocket_endpoint(websocket: WebSocket):

The lua script used.

### fixed window
```lua
local key = KEYS[1]
local limit = tonumber(ARGV[1])
Expand All @@ -162,6 +163,27 @@ else
end
```

### sliding window
```lua
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = tonumber(ARGV[2])
local current_time = redis.call('TIME')[1]
local start_time = current_time - expire_time / 1000

redis.call('ZREMRANGEBYSCORE', key, 0, start_time)

local current = redis.call('ZCARD', key)

if current >= limit then
return redis.call("PTTL",key)
else
redis.call("ZADD", key, current_time, current_time)
redis.call('PEXPIRE', key, expire_time)
return 0
end
```

## License

This project is licensed under the
Expand Down
11 changes: 10 additions & 1 deletion examples/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter, WebSocketRateLimiter

from fastapi_limiter.constants import RateLimitType

@asynccontextmanager
async def lifespan(_: FastAPI):
Expand Down Expand Up @@ -52,6 +52,15 @@ async def websocket_endpoint(websocket: WebSocket):
except HTTPException:
await websocket.send_text("Hello again")

@app.get(
"/test_sliding_window",
dependencies=[
Depends(RateLimiter(times=2, seconds=5, rate_limit_type=RateLimitType.SLIDING_WINDOW))
],
)
async def test_sliding_window():
return {"msg": "Hello World"}


if __name__ == "__main__":
uvicorn.run("main:app", debug=True, reload=True)
22 changes: 5 additions & 17 deletions fastapi_limiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from starlette.responses import Response
from starlette.status import HTTP_429_TOO_MANY_REQUESTS
from starlette.websockets import WebSocket
from fastapi_limiter.constants import LuaScript


async def default_identifier(request: Union[Request, WebSocket]):
Expand Down Expand Up @@ -51,22 +52,8 @@ class FastAPILimiter:
identifier: Optional[Callable] = None
http_callback: Optional[Callable] = None
ws_callback: Optional[Callable] = None
lua_script = """local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]

local current = tonumber(redis.call('get', key) or "0")
if current > 0 then
if current + 1 > limit then
return redis.call("PTTL",key)
else
redis.call("INCR", key)
return 0
end
else
redis.call("SET", key, 1,"px",expire_time)
return 0
end"""
lua_sha_fix_window: Optional[str] = None
lua_sha_sliding_window: Optional[str] = None

@classmethod
async def init(
Expand All @@ -82,7 +69,8 @@ async def init(
cls.identifier = identifier
cls.http_callback = http_callback
cls.ws_callback = ws_callback
cls.lua_sha = await redis.script_load(cls.lua_script)
cls.lua_sha_fix_window = await redis.script_load(LuaScript.FIXED_WINDOW_LIMIT_SCRIPT.value)
cls.lua_sha_sliding_window = await redis.script_load(LuaScript.SLIDING_WINDOW_LIMIT_SCRIPT.value)

@classmethod
async def close(cls) -> None:
Expand Down
47 changes: 47 additions & 0 deletions fastapi_limiter/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from enum import Enum


class RateLimitType(Enum):
FIXED_WINDOW = "fixed_window"
SLIDING_WINDOW = "sliding_window"


class LuaScript(Enum):
FIXED_WINDOW_LIMIT_SCRIPT = """
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]

local current = tonumber(redis.call('get', key) or "0")

if current > 0 then
if current + 1 > limit then
return redis.call("PTTL",key)
else
redis.call("INCR", key)
return 0
end
else
redis.call("SET", key, 1, "px", expire_time)
return 0
end
"""
SLIDING_WINDOW_LIMIT_SCRIPT = """
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = tonumber(ARGV[2])
local current_time = redis.call('TIME')[1]
local start_time = current_time - expire_time / 1000

redis.call('ZREMRANGEBYSCORE', key, 0, start_time)

local current = redis.call('ZCARD', key)

if current >= limit then
return redis.call("PTTL",key)
else
redis.call("ZADD", key, current_time, current_time)
redis.call('PEXPIRE', key, expire_time)
return 0
end
"""
24 changes: 18 additions & 6 deletions fastapi_limiter/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.websockets import WebSocket

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.constants import RateLimitType


class RateLimiter:
Expand All @@ -19,16 +20,30 @@ def __init__(
hours: Annotated[int, Field(ge=-1)] = 0,
identifier: Optional[Callable] = None,
callback: Optional[Callable] = None,
rate_limit_type: RateLimitType = RateLimitType.FIXED_WINDOW
):
self.times = times
self.milliseconds = milliseconds + 1000 * seconds + 60000 * minutes + 3600000 * hours
self.identifier = identifier
self.callback = callback
self.rate_limit_type = rate_limit_type

async def _check(self, key):
def _get_lua_sha(self, specific_lua_sha=None):
if specific_lua_sha:
return specific_lua_sha
elif self.rate_limit_type is RateLimitType.SLIDING_WINDOW:
return FastAPILimiter.lua_sha_sliding_window
return FastAPILimiter.lua_sha_fix_window


async def _check(self, key, specific_lua_sha=None):
redis = FastAPILimiter.redis
pexpire = await redis.evalsha(
FastAPILimiter.lua_sha, 1, key, str(self.times), str(self.milliseconds)
self._get_lua_sha(specific_lua_sha),
1,
key,
str(self.times),
str(self.milliseconds)
)
return pexpire

Expand All @@ -53,10 +68,7 @@ async def __call__(self, request: Request, response: Response):
try:
pexpire = await self._check(key)
except pyredis.exceptions.NoScriptError:
FastAPILimiter.lua_sha = await FastAPILimiter.redis.script_load(
FastAPILimiter.lua_script
)
pexpire = await self._check(key)
pexpire = await self._check(key, specific_lua_sha=FastAPILimiter.lua_sha_fix_window)
if pexpire != 0:
return await callback(request, response, pexpire)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ packages = [
]
readme = "README.md"
repository = "https://github.com/long2ice/fastapi-limiter.git"
version = "0.1.6"
version = "0.1.7"

[tool.poetry.dependencies]
redis = ">=4.2.0rc1"
Expand Down
17 changes: 17 additions & 0 deletions tests/test_depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,20 @@ def test_limiter_websockets():
data = ws.receive_text()
assert data == "Hello, world"
ws.close()


def test_limiter_sliding_window():
with TestClient(app) as client:
def req(sleep_times, assert_code):
nonlocal client
response = client.get("/test_sliding_window")
assert response.status_code == assert_code
sleep(sleep_times)

req(4, 200) # 0s
req(1, 200) # 4s
req(1, 200) # 5s
req(1, 429) # 6s
req(1, 429) # 7s
req(1, 429) # 8s
req(1, 200) # 9s