From a55b6161e33505ad19e26b9c3633e071161baa0b Mon Sep 17 00:00:00 2001 From: George Margaritis Date: Fri, 30 Aug 2024 18:43:04 +0300 Subject: [PATCH] Fix access_token expiration and refresh handling in GitHub backend Ensure the correct key is used for access_token expiration in the GitHub backend's extra_data, and save the refresh_token. Previously, the expiration of the access_token was not stored, causing the refresh_token functionality to be skipped. Signed-off-by: George Margaritis --- CHANGELOG.md | 1 + social_core/backends/github.py | 7 ++- social_core/tests/backends/test_github.py | 61 ++++++++++++++++++++++- 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b570d970..c4b573162 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - Handle case where user has not registered a `family-name` with ORCID +- Fix access token expiration and refresh token handling in GitHub backend ## [4.5.4](https://github.com/python-social-auth/social-core/releases/tag/4.5.4) - 2024-04-25 diff --git a/social_core/backends/github.py b/social_core/backends/github.py index e97322d22..14bb7931b 100644 --- a/social_core/backends/github.py +++ b/social_core/backends/github.py @@ -23,7 +23,12 @@ class GithubOAuth2(BaseOAuth2): REDIRECT_STATE = False STATE_PARAMETER = True SEND_USER_AGENT = True - EXTRA_DATA = [("id", "id"), ("expires", "expires"), ("login", "login")] + EXTRA_DATA = [ + ("id", "id"), + ("expires_in", "expires"), + ("login", "login"), + ("refresh_token", "refresh_token"), + ] def api_url(self): return self.API_URL diff --git a/social_core/tests/backends/test_github.py b/social_core/tests/backends/test_github.py index a4edc8d67..eac2d9cbc 100644 --- a/social_core/tests/backends/test_github.py +++ b/social_core/tests/backends/test_github.py @@ -10,7 +10,24 @@ class GithubOAuth2Test(OAuth2Test): backend_path = "social_core.backends.github.GithubOAuth2" user_data_url = "https://api.github.com/user" expected_username = "foobar" - access_token_body = json.dumps({"access_token": "foobar", "token_type": "bearer"}) + access_token_body = json.dumps( + { + "access_token": "foobar", + "token_type": "bearer", + "expires_in": 28800, + "refresh_token": "foobar-refresh-token", + } + ) + refresh_token_body = json.dumps( + { + "access_token": "foobar-new-token", + "token_type": "bearer", + "expires_in": 28800, + "refresh_token": "foobar-new-refresh-token", + "refresh_token_expires_in": 15897600, + "scope": "", + } + ) user_data_body = json.dumps( { "login": "foobar", @@ -46,12 +63,25 @@ class GithubOAuth2Test(OAuth2Test): } ) + def do_login(self): + user = super().do_login() + social = user.social[0] + + self.assertIsNotNone(social.extra_data["expires"]) + self.assertIsNotNone(social.extra_data["refresh_token"]) + + return user + def test_login(self): self.do_login() def test_partial_pipeline(self): self.do_partial_pipeline() + def test_refresh_token(self): + user, social = self.do_refresh_token() + self.assertEqual(social.extra_data["access_token"], "foobar-new-token") + class GithubOAuth2NoEmailTest(GithubOAuth2Test): user_data_body = json.dumps( @@ -122,6 +152,17 @@ def test_partial_pipeline(self): ) self.do_partial_pipeline() + def test_refresh_token(self): + url = "https://api.github.com/user/emails" + HTTPretty.register_uri( + HTTPretty.GET, + url, + status=200, + body=json.dumps([{"email": "foo@bar.com"}]), + content_type="application/json", + ) + self.do_refresh_token() + class GithubOrganizationOAuth2Test(GithubOAuth2Test): backend_path = "social_core.backends.github.GithubOrganizationOAuth2" @@ -139,6 +180,10 @@ def test_partial_pipeline(self): self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_ORG_NAME": "foobar"}) self.do_partial_pipeline() + def test_refresh_token(self): + self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_ORG_NAME": "foobar"}) + self.do_refresh_token() + class GithubOrganizationOAuth2FailTest(GithubOAuth2Test): backend_path = "social_core.backends.github.GithubOrganizationOAuth2" @@ -164,6 +209,11 @@ def test_partial_pipeline(self): with self.assertRaises(AuthFailed): self.do_partial_pipeline() + def test_refresh_token(self): + self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_ORG_NAME": "foobar"}) + with self.assertRaises(AuthFailed): + self.do_refresh_token() + class GithubTeamOAuth2Test(GithubOAuth2Test): backend_path = "social_core.backends.github.GithubTeamOAuth2" @@ -181,6 +231,10 @@ def test_partial_pipeline(self): self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_TEAM_ID": "123"}) self.do_partial_pipeline() + def test_refresh_token(self): + self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_TEAM_ID": "123"}) + self.do_refresh_token() + class GithubTeamOAuth2FailTest(GithubOAuth2Test): backend_path = "social_core.backends.github.GithubTeamOAuth2" @@ -205,3 +259,8 @@ def test_partial_pipeline(self): self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_TEAM_ID": "123"}) with self.assertRaises(AuthFailed): self.do_partial_pipeline() + + def test_refresh_token(self): + self.strategy.set_settings({"SOCIAL_AUTH_GITHUB_TEAM_ID": "123"}) + with self.assertRaises(AuthFailed): + self.do_refresh_token()