diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 0c8cc2377..e23f56879 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -653,7 +653,7 @@ func (ts *TokenTestSuite) TestPasswordVerificationHook() { } -func (ts *TokenTestSuite) TestCustomAccessToken() { +func (ts *TokenTestSuite) TestCustomAccessTokenHook() { type customAccessTokenTestcase struct { desc string uri string @@ -762,6 +762,42 @@ end; $$ language plpgsql;`, } } +func (ts *TokenTestSuite) TestAddCustomClaims() { + // Use an autoconfirmed signup and a hook to add claims which are returned in response + ts.Config.Hook.CustomAccessToken.Enabled = true + ts.Config.Hook.CustomAccessToken.URI = "pg-functions://postgres/auth/custom_access_token_add_claim" + ts.Config.Mailer.Autoconfirm = true + require.NoError(ts.T(), ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint()) + hookFunctionSQL := ` create or replace function custom_access_token_add_claim(input jsonb) returns jsonb as $$ declare result jsonb; begin if jsonb_typeof(jsonb_object_field(input, 'claims')) is null then result := jsonb_build_object('error', jsonb_build_object('http_code', 400, 'message', 'Input does not contain claims field')); return result; end if; + input := jsonb_set(input, '{claims,app_metadata,newclaim}', '"newcustomclaim"', true); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;` + err := ts.API.db.RawQuery(hookFunctionSQL).Exec() + require.NoError(ts.T(), err) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test0@example.com", + "password": "test1213", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + sessionResponse := AccessTokenResponse{} + + require.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&sessionResponse)) + require.Equal(ts.T(), sessionResponse.User.AppMetaData["newclaim"], "newcustomclaim") + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName) + require.NoError(ts.T(), ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.CustomAccessToken.Enabled = false +} + func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() { companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil)