Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions todo/middlewares/team_access_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
from django.http import JsonResponse
from django.urls import resolve
from rest_framework import status

from todo.constants.messages import ApiErrors
from todo.constants.role import RoleScope
from todo.services.user_role_service import UserRoleService

logger = logging.getLogger(__name__)


class TeamAccessMiddleware:
"""
Middleware to handle team access control for specific routes.
Only applies to routes that contain 'teams/<team_id>' pattern.
"""

def __init__(self, get_response):
self.get_response = get_response
self.protected_routes = [
"team_detail",
"team_activity_timeline",
]

def __call__(self, request):
resolved_url = resolve(request.path_info)
route_name = resolved_url.url_name

if route_name in self.protected_routes:
try:
team_id = resolved_url.kwargs.get("team_id")

if not team_id:
return JsonResponse({"detail": "Team ID is required."}, status=status.HTTP_400_BAD_REQUEST)

user_id = getattr(request, "user_id", None)

user_team_roles = UserRoleService.get_user_roles(
user_id=user_id, scope=RoleScope.TEAM.value, team_id=team_id
)

if not user_team_roles:
return JsonResponse({"detail": ApiErrors.UNAUTHORIZED_TITLE}, status=status.HTTP_403_FORBIDDEN)

except Exception as e:
logger.error(f"Error in TeamAccessMiddleware: {str(e)}")
return JsonResponse(
{"detail": ApiErrors.INTERNAL_SERVER_ERROR}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

response = self.get_response(request)
return response
79 changes: 79 additions & 0 deletions todo/tests/unit/middlewares/test_team_access_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from unittest import TestCase
from unittest.mock import Mock, patch
from django.http import HttpRequest, JsonResponse
from rest_framework import status
import json

from todo.middlewares.team_access_middleware import TeamAccessMiddleware
from todo.constants.messages import ApiErrors


class TeamAccessMiddlewareTests(TestCase):
def setUp(self):
self.get_response = Mock(return_value=JsonResponse({"data": "success"}))
self.middleware = TeamAccessMiddleware(self.get_response)
self.request = Mock(spec=HttpRequest)
self.request.user_id = "user123"
self.request.path_info = "/v1/teams/team123"

@patch("todo.middlewares.team_access_middleware.resolve")
def test_protected_route_with_valid_access(self, mock_resolve):
mock_resolve.return_value.url_name = "team_detail"
mock_resolve.return_value.kwargs = {"team_id": "team123"}

with patch("todo.middlewares.team_access_middleware.UserRoleService.get_user_roles") as mock_get_roles:
mock_get_roles.return_value = [{"role": "admin"}]

response = self.middleware(self.request)

self.assertEqual(response.status_code, 200)
self.get_response.assert_called_once_with(self.request)

@patch("todo.middlewares.team_access_middleware.resolve")
def test_protected_route_with_no_access(self, mock_resolve):
mock_resolve.return_value.url_name = "team_detail"
mock_resolve.return_value.kwargs = {"team_id": "team123"}

with patch("todo.middlewares.team_access_middleware.UserRoleService.get_user_roles") as mock_get_roles:
mock_get_roles.return_value = []

response = self.middleware(self.request)

self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
response_data = json.loads(response.content)
self.assertEqual(response_data["detail"], ApiErrors.UNAUTHORIZED_TITLE)

@patch("todo.middlewares.team_access_middleware.resolve")
def test_unprotected_route_bypasses_middleware(self, mock_resolve):
mock_resolve.return_value.url_name = "task_list"
mock_resolve.return_value.kwargs = {}

response = self.middleware(self.request)

self.assertEqual(response.status_code, 200)
self.get_response.assert_called_once_with(self.request)

@patch("todo.middlewares.team_access_middleware.resolve")
def test_middleware_handles_exception_with_500(self, mock_resolve):
mock_resolve.return_value.url_name = "team_detail"
mock_resolve.return_value.kwargs = {"team_id": "team123"}

with patch("todo.middlewares.team_access_middleware.UserRoleService.get_user_roles") as mock_get_roles:
mock_get_roles.side_effect = Exception("Database connection failed")

response = self.middleware(self.request)

self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
response_data = json.loads(response.content)
self.assertEqual(response_data["detail"], ApiErrors.INTERNAL_SERVER_ERROR)

@patch("todo.middlewares.team_access_middleware.resolve")
def test_missing_team_id_returns_400(self, mock_resolve):
mock_resolve.return_value.url_name = "team_detail"
mock_resolve.return_value.kwargs = {}

response = self.middleware(self.request)

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
response_data = json.loads(response.content)
self.assertEqual(response_data["detail"], "Team ID is required.")
5 changes: 5 additions & 0 deletions todo/views/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class TaskListView(APIView):
responses={
200: OpenApiResponse(response=GetTasksResponse, description="Successful response"),
400: OpenApiResponse(description="Bad request"),
403: OpenApiResponse(description="Forbidden"),
500: OpenApiResponse(description="Internal server error"),
},
)
Expand Down Expand Up @@ -108,6 +109,10 @@ def get(self, request: Request):
team_id=team_id,
status_filter=status_filter,
)

if response.error and response.error.get("code") == "FORBIDDEN":
return Response(data=response.model_dump(mode="json"), status=status.HTTP_403_FORBIDDEN)

return Response(data=response.model_dump(mode="json"), status=status.HTTP_200_OK)

@extend_schema(
Expand Down
4 changes: 4 additions & 0 deletions todo/views/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def get(self, request: Request, team_id: str):
200: OpenApiResponse(response=TeamDTO, description="Team updated successfully"),
400: OpenApiResponse(description="Bad request - validation error or invalid member IDs"),
404: OpenApiResponse(description="Team not found"),
403: OpenApiResponse(description="Forbidden"),
500: OpenApiResponse(description="Internal server error"),
},
)
Expand Down Expand Up @@ -288,6 +289,7 @@ class AddTeamMembersView(APIView):
200: OpenApiResponse(response=TeamDTO, description="Team members added successfully"),
400: OpenApiResponse(description="Bad request - validation error or user not a team member"),
404: OpenApiResponse(description="Team not found"),
403: OpenApiResponse(description="Forbidden"),
500: OpenApiResponse(description="Internal server error"),
},
)
Expand Down Expand Up @@ -404,6 +406,7 @@ class TeamActivityTimelineView(APIView):
},
description="Team activity timeline returned successfully",
),
403: OpenApiResponse(description="Forbidden"),
404: OpenApiResponse(description="Team not found"),
},
)
Expand Down Expand Up @@ -470,6 +473,7 @@ class RemoveTeamMemberView(APIView):
],
responses={
204: OpenApiResponse(description="User removed from team successfully."),
403: OpenApiResponse(description="Forbidden"),
404: OpenApiResponse(description="Team or user not found."),
400: OpenApiResponse(description="Bad request or other error."),
},
Expand Down
1 change: 1 addition & 0 deletions todo_project/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.common.CommonMiddleware",
"todo.middlewares.jwt_auth.JWTAuthenticationMiddleware",
"todo.middlewares.team_access_middleware.TeamAccessMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]

Expand Down