Skip to content

Commit

Permalink
Merge pull request #884 from thunderstore-io/getcurrentuser-connections
Browse files Browse the repository at this point in the history
Return OAuth connection info along current user's information
  • Loading branch information
anttimaki authored Oct 5, 2023
2 parents e03715e + 0e8c268 commit 91b3338
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 2 deletions.
45 changes: 45 additions & 0 deletions django/thunderstore/social/api/experimental/views/current_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ def get(self, request, format=None):
return Response(profile)


class SocialAuthConnection(TypedDict):
provider: str
username: str
avatar: Optional[str]


class SocialAuthConnectionSerializer(serializers.Serializer):
provider: serializers.CharField()
username: serializers.CharField()
avatar: serializers.CharField()


class SubscriptionStatus(TypedDict):
expires: Optional[datetime.datetime]

Expand All @@ -37,6 +49,7 @@ class SubscriptionStatusSerializer(serializers.Serializer):
class UserProfile(TypedDict):
username: Optional[str]
capabilities: Set[str]
connections: List[SocialAuthConnection]
subscription: SubscriptionStatus
rated_packages: List[str]
teams: List[str]
Expand All @@ -45,6 +58,7 @@ class UserProfile(TypedDict):
class UserProfileSerializer(serializers.Serializer):
username = serializers.CharField()
capabilities = serializers.ListField()
connections = serializers.ListSerializer(child=SocialAuthConnectionSerializer())
subscription = SubscriptionStatusSerializer()
rated_packages = serializers.ListField()
teams = serializers.ListField()
Expand All @@ -54,6 +68,7 @@ def get_empty_profile() -> UserProfile:
return {
"username": None,
"capabilities": set(),
"connections": [],
"subscription": get_subscription_status(user=None),
"rated_packages": [],
"teams": [],
Expand Down Expand Up @@ -81,6 +96,7 @@ def get_user_profile(user: UserType) -> UserProfile:
return {
"username": username,
"capabilities": capabilities,
"connections": get_social_auth_connections(user),
"subscription": get_subscription_status(user),
"rated_packages": rated_packages,
"teams": teams,
Expand All @@ -103,3 +119,32 @@ def get_subscription_status(user: Optional[UserType]) -> SubscriptionStatus:
return {"expires": (now + datetime.timedelta(weeks=4))}

return {"expires": None}


OAUTH_USERNAME_FIELDS = {
"discord": "username",
"github": "login",
"overwolf": "username",
}


OAUTH_AVATAR_FIELDS = {
"discord": "!NOT_SUPPORTED!",
"github": "avatar_url",
"overwolf": "avatar",
}


def get_social_auth_connections(user: UserType) -> List[SocialAuthConnection]:
"""
Return information regarding user's registered OAuth logins.
"""

return [
{
"provider": sa.provider,
"username": sa.extra_data[OAUTH_USERNAME_FIELDS[sa.provider]],
"avatar": sa.extra_data.get(OAUTH_AVATAR_FIELDS[sa.provider]),
}
for sa in user.social_auth.all()
]
82 changes: 80 additions & 2 deletions django/thunderstore/social/tests/test_current_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.utils import timezone
from rest_framework.response import Response
from rest_framework.test import APIClient
from social_django.models import UserSocialAuth # type: ignore

from thunderstore.account.models.user_flag import UserFlag, UserFlagMembership
from thunderstore.core.types import UserType
Expand All @@ -29,12 +30,13 @@ def test_current_user_info__for_unauthenticated_user__is_empty_structure(
assert user_info["username"] is None
assert user_info["subscription"]["expires"] is None
assert len(user_info["capabilities"]) == 0
assert len(user_info["connections"]) == 0
assert len(user_info["rated_packages"]) == 0
assert len(user_info["teams"]) == 0


@pytest.mark.django_db
def test_current_user_info__for_authenticated_user__has_proper_values(
def test_current_user_info__for_authenticated_user__has_basic_values(
api_client: APIClient,
user: UserType,
) -> None:
Expand All @@ -46,6 +48,23 @@ def test_current_user_info__for_authenticated_user__has_proper_values(
user_info = response.json()

assert user_info["username"] == "Test"
assert type(user_info["capabilities"]) == list
assert type(user_info["rated_packages"]) == list
assert type(user_info["teams"]) == list


@pytest.mark.django_db
def test_current_user_info__for_subscriber__has_subscription_expiration(
api_client: APIClient,
user: UserType,
) -> None:
api_client.force_authenticate(user=user)
response = request_user_info(api_client)

assert response.status_code == 200

user_info = response.json()

assert type(user_info["subscription"]) == dict
assert "expires" in user_info["subscription"]
assert user_info["subscription"]["expires"] is None
Expand All @@ -69,7 +88,6 @@ def test_current_user_info__for_authenticated_user__has_proper_values(

user_info = response.json()

assert user_info["username"] == "Test"
assert type(user_info["subscription"]) == dict
assert "expires" in user_info["subscription"]
expiry_datetime = datetime.datetime.fromisoformat(
Expand All @@ -78,3 +96,63 @@ def test_current_user_info__for_authenticated_user__has_proper_values(
assert expiry_datetime > now + datetime.timedelta(
days=27
) and expiry_datetime < now + datetime.timedelta(days=29)


@pytest.mark.django_db
def test_current_user_info__for_oauth_user__has_connections(
api_client: APIClient,
user: UserType,
) -> None:
api_client.force_authenticate(user=user)
response = request_user_info(api_client)

assert response.status_code == 200

user_info = response.json()

assert type(user_info["connections"]) == list
assert len(user_info["connections"]) == 0

UserSocialAuth.objects.bulk_create(
[
UserSocialAuth(
user=user,
provider="discord",
uid="d123",
extra_data={"username": "discord_user"},
),
UserSocialAuth(
user=user,
provider="github",
uid="gh123",
extra_data={"login": "gh_user", "avatar_url": "gh_url"},
),
UserSocialAuth(
user=user,
provider="overwolf",
uid="ow123",
extra_data={"username": "ow_user", "avatar": "ow_url"},
),
]
)

response = request_user_info(api_client)

assert response.status_code == 200

user_info = response.json()

assert type(user_info["connections"]) == list
assert len(user_info["connections"]) == 3

discord = next(c for c in user_info["connections"] if c["provider"] == "discord")
assert discord["username"] == "discord_user"
assert discord["avatar"] is None

github = next(c for c in user_info["connections"] if c["provider"] == "github")
assert github["username"] == "gh_user"
assert github["avatar"] == "gh_url"

overwolf = next(c for c in user_info["connections"] if c["provider"] == "overwolf")
assert overwolf["username"] == "ow_user"
assert overwolf["avatar"] == "ow_url"

0 comments on commit 91b3338

Please sign in to comment.