diff --git a/api/app.py b/api/app.py index da496f0..7168573 100644 --- a/api/app.py +++ b/api/app.py @@ -10,7 +10,15 @@ log = logging.getLogger() -app = FastAPI() +app = FastAPI( + title="Tech With Tim", + docs_url="/api/docs", + redoc_url="/api/redoc", + openapi_url="/api/docs/openapi.json", + openapi_tags=[ + {"name": "roles", "description": "Manage roles"}, + ], +) app.router.prefix = "/api" app.router.default_response_class = JSONResponse diff --git a/api/dependencies.py b/api/dependencies.py new file mode 100644 index 0000000..9696a9e --- /dev/null +++ b/api/dependencies.py @@ -0,0 +1,74 @@ +import jwt +import utils +import config + +from api.models import User +from typing import List, Union +from fastapi import Depends, HTTPException, Request + +from api.models import Role +from api.models.permissions import BasePermission + + +def authorization(app_only: bool = False, user_only: bool = False): + if app_only and user_only: + raise ValueError("app_only and user_only are mutually exclusive") + + async def inner(request: Request): + """Attempts to locate and decode JWT token.""" + token = request.headers.get("authorization") + + if token is None: + raise HTTPException(status_code=401) + + try: + data = jwt.decode( + jwt=token, + algorithms=["HS256"], + key=config.secret_key(), + ) + except jwt.PyJWTError: + raise HTTPException(status_code=401, detail="Invalid token.") + + data["uid"] = int(data["uid"]) + + user = await User.fetch(data["uid"]) + if not user: + raise HTTPException(status_code=401, detail="Invalid token.") + + if app_only and not user.app: + raise HTTPException(status_code=403, detail="Users can't use this endpoint") + + if user_only and user.app: + raise HTTPException(status_code=403, detail="Bots can't use this endpoint") + + return user + + return Depends(inner) + + +def has_permissions(permissions: List[Union[int, BasePermission]]): + async def inner(user=authorization()): + query = """ + SELECT * + FROM roles r + WHERE r.id IN ( + SELECT ur.role_id + FROM userroles ur + WHERE ur.user_id = $1 + ) + """ + records = await Role.pool.fetch(query, user.id) + if not records: + raise HTTPException(403, "Missing Permissions") + + user_permissions = 0 + for record in records: + user_permissions |= record["permissions"] + + if not utils.has_permissions(user_permissions, permissions): + raise HTTPException(403, "Missing Permissions") + + return [Role(**record) for record in records] + + return Depends(inner) diff --git a/api/models b/api/models index 50531e1..f11c37d 160000 --- a/api/models +++ b/api/models @@ -1 +1 @@ -Subproject commit 50531e15ab5eac0db6c1d15469a396932f6f1b37 +Subproject commit f11c37d62259f36731d0981a8fb99b3d78186550 diff --git a/api/versions/v1/routers/roles/__init__.py b/api/versions/v1/routers/roles/__init__.py new file mode 100644 index 0000000..9bd4168 --- /dev/null +++ b/api/versions/v1/routers/roles/__init__.py @@ -0,0 +1,4 @@ +from .routes import router + + +__all__ = (router,) diff --git a/api/versions/v1/routers/roles/models.py b/api/versions/v1/routers/roles/models.py new file mode 100644 index 0000000..712cb89 --- /dev/null +++ b/api/versions/v1/routers/roles/models.py @@ -0,0 +1,27 @@ +from typing import List, Optional +from pydantic import BaseModel, Field + + +class RoleResponse(BaseModel): + id: str + name: str + position: int + permissions: int + color: Optional[int] + + +class DetailedRoleResponse(RoleResponse): + members: List[str] + + +class NewRoleBody(BaseModel): + name: str = Field(..., min_length=4, max_length=32) + color: Optional[int] = Field(None, le=0xFFFFFF, ge=0) + permissions: Optional[int] = Field(0, ge=0) + + +class UpdateRoleBody(BaseModel): + name: str = Field("", min_length=4, max_length=64) + color: Optional[int] = Field(None, le=0xFFFFFF, ge=0) + permissions: int = Field(0, ge=0) + position: int = Field(0, ge=0) diff --git a/api/versions/v1/routers/roles/routes.py b/api/versions/v1/routers/roles/routes.py new file mode 100644 index 0000000..4beca11 --- /dev/null +++ b/api/versions/v1/routers/roles/routes.py @@ -0,0 +1,283 @@ +import utils +import asyncpg + +from typing import List, Union +from fastapi import APIRouter, HTTPException, Response + +from api.models import Role, UserRole +from api.dependencies import has_permissions +from api.models.permissions import ManageRoles +from api.versions.v1.routers.roles.models import ( + NewRoleBody, + RoleResponse, + UpdateRoleBody, + DetailedRoleResponse, +) + + +router = APIRouter(prefix="/roles") + + +@router.get("", tags=["roles"], response_model=List[RoleResponse]) +async def fetch_all_roles(): + """Fetch all roles""" + + query = """ + SELECT *, + r.id::TEXT + FROM roles r + """ + records = await Role.pool.fetch(query) + + return [dict(record) for record in records] + + +@router.get( + "/{id}", + tags=["roles"], + response_model=DetailedRoleResponse, + responses={ + 404: {"description": "Role not found"}, + }, +) +async def fetch_role(id: int): + """Fetch a role by its id""" + + query = """ + SELECT *, + id::TEXT, + COALESCE( + ( + SELECT json_agg(ur.user_id::TEXT) + FROM userroles ur + WHERE ur.role_id = r.id + ), '[]' + ) members + FROM roles r + WHERE r.id = $1 + """ + record = await Role.pool.fetchrow(query, id) + + if not record: + raise HTTPException(404, "Role not found") + + return dict(record) + + +@router.post( + "", + tags=["roles"], + response_model=RoleResponse, + responses={ + 201: {"description": "Role Created Successfully"}, + 401: {"description": "Unauthorized"}, + 403: {"description": "Missing Permissions"}, + 409: {"description": "Role with that name already exists"}, + }, + status_code=201, +) +async def create_role(body: NewRoleBody, roles=has_permissions([ManageRoles()])): + # Check if the user has administrator permission or all the permissions provided in the role + user_permissions = 0 + for role in roles: + user_permissions |= role.permissions + + if not utils.has_permission(user_permissions, body.permissions): + raise HTTPException(403, "Missing Permissions") + + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), $1, $2, $3, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + + try: + record = await Role.pool.fetchrow( + query, body.name, body.color, body.permissions + ) + except asyncpg.exceptions.UniqueViolationError: + raise HTTPException(409, "Role with that name already exists") + + return utils.JSONResponse(status_code=201, content=dict(record)) + + +@router.patch( + "/{id}", + tags=["roles"], + responses={ + 204: {"description": "Role Updated Successfully"}, + 401: {"description": "Unauthorized"}, + 403: {"description": "Missing Permissions"}, + 404: {"description": "Role not found"}, + 409: {"description": "Role with that name already exists"}, + }, + status_code=204, +) +async def update_role( + id: int, + body: UpdateRoleBody, + roles=has_permissions([ManageRoles()]), +): + role = await Role.fetch(id) + if not role: + raise HTTPException(404, "Role Not Found") + + # Check if the user has administrator permission or all the permissions provided in the role + user_permissions = 0 + for r in roles: + user_permissions |= r.permissions + + top_role = min(roles, key=lambda role: role.position) + if top_role.position >= role.position: + raise HTTPException(403, "Missing Permissions") + + data = body.dict(exclude_unset=True) + if not utils.has_permission(user_permissions, body.permissions): + raise HTTPException(403, "Missing Permissions") + + if name := data.get("name", None): + record = await Role.pool.fetchrow("SELECT * FROM roles WHERE name = $1", name) + + if record: + raise HTTPException(409, "Role with that name already exists") + + if ( + position := data.pop("position", None) + ) is not None and position != role.position: + if position <= top_role.position: + raise HTTPException(403, "Missing Permissions") + + if position > role.position: + new_pos = position + 0.5 + else: + new_pos = position - 0.5 + + query = """ + UPDATE roles r SET position = $1 + WHERE r.id = $2; + """ + await Role.pool.execute(query, new_pos, id) + + query = """ + WITH todo AS ( + SELECT r.id, + ROW_NUMBER() OVER (ORDER BY position) AS position + FROM roles r + ) + UPDATE roles r SET + position = td.position + FROM todo td + WHERE r.id = td.id; + """ + await Role.pool.execute(query) + + if data: + query = "UPDATE ROLES SET " + query += ", ".join("%s = $%d" % (key, i) for i, key in enumerate(data, 2)) + query += " WHERE id = $1" + + await Role.pool.execute(query, id, *data.values()) + + return Response(status_code=204, content="") + + +@router.delete( + "/{id}", + tags=["roles"], + responses={ + 204: {"description": "Role Updated Successfully"}, + 401: {"description": "Unauthorized"}, + 403: {"description": "Missing Permissions"}, + 404: {"description": "Role not found"}, + }, + status_code=204, +) +async def delete_role(id: int, roles=has_permissions([ManageRoles()])): + role = await Role.fetch(id) + if not role: + raise HTTPException(404, "Role Not Found") + + top_role = min(roles, key=lambda role: role.position) + if top_role.position >= role.position: + raise HTTPException(403, "Missing Permissions") + + query = """ + WITH deleted AS ( + DELETE FROM roles r + WHERE r.id = $1 + RETURNING r.id + ), + to_update AS ( + SELECT r.id, + ROW_NUMBER() OVER (ORDER BY r.position) AS position + FROM roles r + WHERE r.id != (SELECT id FROM deleted) + ) + UPDATE roles r SET + position = tu.position + FROM to_update tu + WHERE r.id = tu.id + """ + await Role.pool.execute(query, id) + + return Response(status_code=204, content="") + + +@router.put( + "/{role_id}/members/{member_id}", + tags=["roles"], + responses={ + 204: {"description": "Role assigned to member"}, + 401: {"description": "Unauthorized"}, + 403: {"description": "Missing Permissions"}, + 404: {"description": "Role or member not found"}, + 409: {"description": "User already has the role"}, + }, + status_code=204, +) +async def add_member_to_role( + role_id: int, member_id: int, roles=has_permissions([ManageRoles()]) +) -> Union[Response, utils.JSONResponse]: + role = await Role.fetch(role_id) + if not role: + raise HTTPException(404, "Role Not Found") + + top_role = min(roles, key=lambda role: role.position) + if top_role.position >= role.position: + raise HTTPException(403, "Missing Permissions") + + try: + await UserRole.create(member_id, role_id) + except asyncpg.exceptions.UniqueViolationError: + raise HTTPException(409, "User already has the role") + except asyncpg.exceptions.ForeignKeyViolationError: + raise HTTPException(404, "Member not found") + + return Response(status_code=204, content="") + + +@router.delete( + "/{role_id}/members/{member_id}", + tags=["roles"], + responses={ + 204: {"description": "Role removed from member"}, + 401: {"description": "Unauthorized"}, + 403: {"description": "Missing Permissions"}, + 404: {"description": "Role not found"}, + }, + status_code=204, +) +async def remove_member_from_role( + role_id: int, member_id: int, roles=has_permissions([ManageRoles()]) +) -> Union[Response, utils.JSONResponse]: + role = await Role.fetch(role_id) + if not role: + raise HTTPException(404, "Role Not Found") + + top_role = min(roles, key=lambda role: role.position) + if top_role.position >= role.position: + raise HTTPException(403, "Missing Permissions") + + await UserRole.delete(member_id, role_id) + + return Response(status_code=204, content="") diff --git a/api/versions/v1/routers/router.py b/api/versions/v1/routers/router.py index f9810cc..5fe0698 100644 --- a/api/versions/v1/routers/router.py +++ b/api/versions/v1/routers/router.py @@ -1,6 +1,9 @@ from fastapi import APIRouter + from . import auth +from . import roles router = APIRouter(prefix="/v1") router.include_router(auth.router) +router.include_router(roles.router) diff --git a/tests/conftest.py b/tests/conftest.py index c5f93ab..8928c74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ -from launch import prepare_postgres, safe_create_tables, delete_tables +import jwt import config +import pytest +import asyncio -from httpx import AsyncClient from postDB import Model -import asyncio -import pytest +from httpx import AsyncClient + +from api.models import User +from launch import prepare_postgres, safe_create_tables, delete_tables @pytest.fixture(scope="session") @@ -34,6 +37,20 @@ async def db(event_loop) -> bool: await delete_tables() +@pytest.fixture(scope="function") +async def user(db): + yield await User.create(0, "Test", "0001") + await db.execute("""DELETE FROM users WHERE username = 'Test'""") + + +@pytest.fixture(scope="function") +async def token(user, db): + yield jwt.encode( + {"uid": user.id}, + key=config.secret_key(), + ) + + def pytest_addoption(parser): parser.addoption( "--no-db", diff --git a/tests/test_auth.py b/tests/test_auth.py index e1eee88..55e46b2 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -84,10 +84,10 @@ async def exchange_code(**kwargs): async def get_user(**kwargs): return { - "username": "M7MD", - "discriminator": "1701", - "id": 601173582516584602, - "avatar": "135fa48ba8f26417c4b9818ae2e37aa0", + "id": 1, + "username": "Test2", + "avatar": "avatar", + "discriminator": "0001", } mocker.patch("api.versions.v1.routers.auth.routes.get_user", new=get_user) @@ -99,3 +99,5 @@ async def get_user(**kwargs): ) assert res.status_code == 200 + + await db.execute("DELETE FROM users WHERE id = 1") diff --git a/tests/test_roles.py b/tests/test_roles.py new file mode 100644 index 0000000..dfc6dc0 --- /dev/null +++ b/tests/test_roles.py @@ -0,0 +1,393 @@ +import pytest + +from httpx import AsyncClient + +from api.models import Role, UserRole +from api.models.permissions import ManageRoles + + +@pytest.fixture +async def manage_roles_role(db): + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), $1, $2, $3, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + record = await Role.pool.fetchrow(query, "Roles Manager", 0x0, ManageRoles().value) + yield Role(**record) + await db.execute("DELETE FROM roles WHERE id = $1;", record["id"]) + + +@pytest.mark.db +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("data", "status"), + [ + ({}, 422), + ({"name": ""}, 422), + ({"permissions": -1}, 422), + ({"name": "test1", "color": 0xFFFFFFF}, 422), + ({"name": "test1", "color": -0x000001}, 422), + ({"name": "test2", "color": 0x000000, "permissions": 8}, 403), + ({"name": "test2", "color": 0x000000, "permissions": 0}, 201), + ({"name": "test2", "color": 0x000000, "permissions": 0}, 409), + ], +) +async def test_role_create( + app: AsyncClient, db, user, token, manage_roles_role, data, status +): + try: + await UserRole.create(user.id, manage_roles_role.id) + res = await app.post( + "/api/v1/roles", json=data, headers={"Authorization": token} + ) + assert res.status_code == status + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + if status == 409: + await db.execute("DELETE FROM roles WHERE name = $1", data["name"]) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_fetch_all_roles(app: AsyncClient): + res = await app.get("/api/v1/roles") + + assert res.status_code == 200 + assert type(res.json()) == list + + +@pytest.mark.db +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("request_data", "new_data", "status"), + [ + ({}, {"name": "test update", "permissions": 0, "color": 0}, 204), + ({"name": ""}, {"name": "test update", "permissions": 0, "color": 0}, 422), + ( + {"permissions": -1}, + {"name": "test update", "permissions": 0, "color": 0}, + 422, + ), + ( + {"color": 0xFFFFFFF}, + {"name": "test update", "permissions": 0, "color": 0}, + 422, + ), + ( + {"color": -0x000001}, + {"name": "test update", "permissions": 0, "color": 0}, + 422, + ), + ( + {"color": 0x5, "permissions": 8}, + {"name": "test update", "permissions": 0, "color": 0x0}, + 403, + ), + ( + {"color": 0x5, "permissions": ManageRoles().value}, + {"name": "test update", "permissions": ManageRoles().value, "color": 0x5}, + 204, + ), + ], +) +async def test_role_update( + app: AsyncClient, db, user, token, manage_roles_role, request_data, new_data, status +): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test update', 0, 0, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.patch( + f"/api/v1/roles/{role.id}", + json=request_data, + headers={"Authorization": token}, + ) + + assert res.status_code == status + + role = await Role.fetch(role.id) + + data = role.as_dict() + data.pop("id") + data.pop("position") + + assert data == new_data + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_role_delete(app: AsyncClient, db, user, token, manage_roles_role): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test delete', 0, 0, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.delete( + f"/api/v1/roles/{role.id}", + headers={"Authorization": token}, + ) + + assert res.status_code == 204 + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_role_delete_high_position( + app: AsyncClient, db, user, token, manage_roles_role +): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test delete', 0, 0, 0) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.delete( + f"/api/v1/roles/{role.id}", + headers={"Authorization": token}, + ) + + assert res.status_code == 403 + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_role_add(app: AsyncClient, db, user, token, manage_roles_role): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test add', 0, 0, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.put( + f"/api/v1/roles/{role.id}/members/{user.id}", + headers={"Authorization": token}, + ) + + assert res.status_code == 204 + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_role_add_high_position( + app: AsyncClient, db, user, token, manage_roles_role +): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test add', 0, 0, 0) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.put( + f"/api/v1/roles/{role.id}/members/{user.id}", + headers={"Authorization": token}, + ) + + assert res.status_code == 403 + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_role_remove(app: AsyncClient, db, user, token, manage_roles_role): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test remove', 0, 0, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.delete( + f"/api/v1/roles/{role.id}/members/{user.id}", + headers={"Authorization": token}, + ) + + assert res.status_code == 204 + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_role_remove_high_position( + app: AsyncClient, db, user, token, manage_roles_role +): + try: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), 'test remove', 0, 0, 0) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query)) + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.delete( + f"/api/v1/roles/{role.id}/members/{user.id}", + headers={"Authorization": token}, + ) + + assert res.status_code == 403 + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_update_role_positions_up( + app: AsyncClient, db, user, token, manage_roles_role +): + try: + roles = [] + # manage roles -> 1 -> 3 -> 2 -> 4 + role_names = ["1", "3", "2", "4"] + for role_name in role_names: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), $1, 0, 0, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query, role_name)) + roles.append(role) + + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.patch( + f"/api/v1/roles/{roles[2].id}", + json={"position": 3}, + headers={"Authorization": token}, + ) + assert res.status_code == 204 + + res = await app.get("/api/v1/roles") + new_roles = sorted(res.json(), key=lambda x: x["position"]) + + for i, role in enumerate(new_roles, 1): + assert ( + role["position"] == i + ) # make sure roles are ordered with no missing positions + + for i in range(1, 5): + assert new_roles[i]["name"] == str(i) + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + for role in roles: + await db.execute("DELETE FROM roles WHERE id = $1", role.id) + + +@pytest.mark.db +@pytest.mark.asyncio +async def test_update_role_positions_down( + app: AsyncClient, db, user, token, manage_roles_role +): + try: + roles = [] + # manage roles -> 1 -> 3 -> 2 -> 4 + role_names = ["1", "3", "2", "4"] + for role_name in role_names: + query = """ + INSERT INTO roles (id, name, color, permissions, position) + VALUES (create_snowflake(), $1, 0, 0, (SELECT COUNT(*) FROM roles) + 1) + RETURNING *; + """ + role = Role(**await Role.pool.fetchrow(query, role_name)) + roles.append(role) + + await UserRole.create(user.id, manage_roles_role.id) + + res = await app.patch( + f"/api/v1/roles/{roles[1].id}", + json={"position": 4}, + headers={"Authorization": token}, + ) + assert res.status_code == 204 + + res = await app.get("/api/v1/roles") + new_roles = sorted(res.json(), key=lambda x: x["position"]) + + for i, role in enumerate(new_roles, 1): + assert ( + role["position"] == i + ) # make sure roles are ordered with no missing positions + + for i in range(1, 5): + assert new_roles[i]["name"] == str(i) + finally: + await db.execute( + "DELETE FROM userroles WHERE role_id = $1 AND user_id = $2;", + manage_roles_role.id, + user.id, + ) + for role in roles: + await db.execute("DELETE FROM roles WHERE id = $1", role.id) diff --git a/utils/__init__.py b/utils/__init__.py index 7670ed7..7723b5a 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,7 +1,10 @@ from .time import snowflake_time from .response import JSONResponse +from .permissions import has_permission, has_permissions __all__ = ( - "JSONResponse", - "snowflake_time", + JSONResponse, + snowflake_time, + has_permission, + has_permissions, ) diff --git a/utils/permissions.py b/utils/permissions.py new file mode 100644 index 0000000..e824932 --- /dev/null +++ b/utils/permissions.py @@ -0,0 +1,30 @@ +from typing import Union, List +from api.models.permissions import BasePermission, Administrator + + +def has_permissions( + permissions: int, required: List[Union[int, BasePermission]] +) -> bool: + """Returns `True` if `permissions` has all required permissions""" + if permissions & Administrator().value: + return True + + all_perms = 0 + for perm in required: + if isinstance(perm, int): + all_perms |= perm + else: + all_perms |= perm.value + + return permissions & all_perms == all_perms + + +def has_permission(permissions: int, permission: Union[BasePermission, int]) -> bool: + """Returns `True` if `permissions` has required permission""" + if permissions & Administrator().value: + return True + + if isinstance(permission, int): + return permissions & permission == permission + + return permissions & permission.value == permission.value