Skip to content

Commit

Permalink
Merge pull request #885 from thunderstore-io/getcurrentuser-teams
Browse files Browse the repository at this point in the history
Augment the returned information regarding current user's teams
  • Loading branch information
anttimaki committed Oct 5, 2023
2 parents 91b3338 + abfa424 commit 216a20f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 15 deletions.
53 changes: 44 additions & 9 deletions django/thunderstore/social/api/experimental/views/current_user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from typing import List, Optional, Set, TypedDict

from django.db.models import Q
from django.utils import timezone
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers
Expand All @@ -9,6 +10,7 @@

from thunderstore.account.models.user_flag import UserFlag
from thunderstore.core.types import UserType
from thunderstore.repository.models import TeamMember


class CurrentUserExperimentalApiView(APIView):
Expand Down Expand Up @@ -46,6 +48,18 @@ class SubscriptionStatusSerializer(serializers.Serializer):
expires = serializers.DateField()


class UserTeam(TypedDict):
name: str
role: str
member_count: int


class UserTeamSerializer(serializers.Serializer):
name: serializers.CharField()
role: serializers.CharField()
member_count: serializers.IntegerField(min_value=0)


class UserProfile(TypedDict):
username: Optional[str]
capabilities: Set[str]
Expand All @@ -61,7 +75,7 @@ class UserProfileSerializer(serializers.Serializer):
connections = serializers.ListSerializer(child=SocialAuthConnectionSerializer())
subscription = SubscriptionStatusSerializer()
rated_packages = serializers.ListField()
teams = serializers.ListField()
teams = serializers.ListSerializer(child=UserTeamSerializer())


def get_empty_profile() -> UserProfile:
Expand All @@ -86,20 +100,13 @@ def get_user_profile(user: UserType) -> UserProfile:
),
)

teams = list(
user.teams.filter(team__is_active=True).values_list(
"team__name",
flat=True,
),
)

return {
"username": username,
"capabilities": capabilities,
"connections": get_social_auth_connections(user),
"subscription": get_subscription_status(user),
"rated_packages": rated_packages,
"teams": teams,
"teams": get_teams(user),
}


Expand Down Expand Up @@ -148,3 +155,31 @@ def get_social_auth_connections(user: UserType) -> List[SocialAuthConnection]:
}
for sa in user.social_auth.all()
]


def get_teams(user: UserType) -> List[UserTeam]:
"""
Return information regarding the teams the user belongs to.
"""
memberships = (
TeamMember.objects.prefetch_related(
"team__members__user__service_account",
)
.exclude(team__is_active=False)
.exclude(~Q(user=user))
)

return [
{
"name": membership.team.name,
"role": membership.role,
"member_count": len(
[
m
for m in membership.team.members.all()
if not hasattr(m.user, "service_account")
],
),
}
for membership in memberships.all()
]
54 changes: 48 additions & 6 deletions django/thunderstore/social/tests/test_current_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from rest_framework.test import APIClient
from social_django.models import UserSocialAuth # type: ignore

from thunderstore.account.factories import ServiceAccountFactory
from thunderstore.account.models.user_flag import UserFlag, UserFlagMembership
from thunderstore.core.types import UserType
from thunderstore.repository.factories import TeamMemberFactory
from thunderstore.repository.models.team import TeamMemberRole


def request_user_info(api_client: APIClient) -> Response:
Expand Down Expand Up @@ -50,7 +53,6 @@ def test_current_user_info__for_authenticated_user__has_basic_values(
assert user_info["username"] == "Test"
assert type(user_info["capabilities"]) == list
assert type(user_info["rated_packages"]) == list
assert type(user_info["teams"]) == list


@pytest.mark.django_db
Expand Down Expand Up @@ -91,11 +93,10 @@ def test_current_user_info__for_subscriber__has_subscription_expiration(
assert type(user_info["subscription"]) == dict
assert "expires" in user_info["subscription"]
expiry_datetime = datetime.datetime.fromisoformat(
user_info["subscription"]["expires"].replace("Z", "+00:00")
user_info["subscription"]["expires"].replace("Z", "+00:00"),
)
assert expiry_datetime > now + datetime.timedelta(
days=27
) and expiry_datetime < now + datetime.timedelta(days=29)
assert now + datetime.timedelta(days=27) < expiry_datetime
assert expiry_datetime < now + datetime.timedelta(days=29)


@pytest.mark.django_db
Expand Down Expand Up @@ -133,7 +134,7 @@ def test_current_user_info__for_oauth_user__has_connections(
uid="ow123",
extra_data={"username": "ow_user", "avatar": "ow_url"},
),
]
],
)

response = request_user_info(api_client)
Expand All @@ -156,3 +157,44 @@ def test_current_user_info__for_oauth_user__has_connections(
overwolf = next(c for c in user_info["connections"] if c["provider"] == "overwolf")
assert overwolf["username"] == "ow_user"
assert overwolf["avatar"] == "ow_url"


@pytest.mark.django_db
def test_current_user_info__for_team_member__has_teams(
api_client: APIClient,
user: UserType,
) -> None:
api_client.force_authenticate(user=user)
response = request_user_info(api_client)

assert response.status_code == 200

user_info = response.json()

assert type(user_info["teams"]) == list
assert len(user_info["teams"]) == 0

# First team contains only the user, second team has another member
# and a service account.
member1 = TeamMemberFactory.create(user=user, role=TeamMemberRole.owner)
member2 = TeamMemberFactory.create(user=user, role=TeamMemberRole.member)
TeamMemberFactory.create(team=member2.team)
sa = ServiceAccountFactory(owner=member2.team)
TeamMemberFactory(user=sa.user, team=member2.team)

response = request_user_info(api_client)

assert response.status_code == 200

user_info = response.json()

assert type(user_info["teams"]) == list
assert len(user_info["teams"]) == 2

team1 = next(t for t in user_info["teams"] if t["name"] == member1.team.name)
assert team1["role"] == TeamMemberRole.owner
assert team1["member_count"] == 1

team2 = next(t for t in user_info["teams"] if t["name"] == member2.team.name)
assert team2["role"] == TeamMemberRole.member
assert team2["member_count"] == 2 # ServiceAccounts do not count.

0 comments on commit 216a20f

Please sign in to comment.