From eb008a55d06ed979985003d634ee4e9a9099478d Mon Sep 17 00:00:00 2001 From: Markus Strehle <11627201+strehle@users.noreply.github.com> Date: Fri, 14 Jul 2023 18:22:04 +0200 Subject: [PATCH] Refactor: Add Instant to TimeService interface and use TimeService in UaaTokenStore (#2315) * Add mock for TimeService * adopt test and use always class timeService * refactor all usages of Instant -> now to timeService * fix timService mock * simplify test * refactor and minimal changes --- .../identity/uaa/oauth/UaaTokenStore.java | 19 +++---- .../identity/uaa/util/TimeService.java | 3 ++ .../uaa/oauth/UaaTokenStoreTests.java | 49 +++++++++++-------- .../webapp/WEB-INF/spring/oauth-endpoints.xml | 3 +- 4 files changed, 44 insertions(+), 30 deletions(-) diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStore.java b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStore.java index ad089046a24..dba48d178ba 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStore.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStore.java @@ -14,12 +14,12 @@ package org.cloudfoundry.identity.uaa.oauth; - import com.fasterxml.jackson.core.type.TypeReference; import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication; import org.cloudfoundry.identity.uaa.authentication.UaaAuthenticationDetails; import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal; import org.cloudfoundry.identity.uaa.util.JsonUtils; +import org.cloudfoundry.identity.uaa.util.TimeService; import org.cloudfoundry.identity.uaa.util.UaaStringUtils; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.slf4j.Logger; @@ -54,7 +54,6 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicReference; public class UaaTokenStore implements AuthorizationCodeServices { public static final Duration DEFAULT_EXPIRATION_TIME = Duration.ofMinutes(5); @@ -80,6 +79,7 @@ public class UaaTokenStore implements AuthorizationCodeServices { private static final String SQL_CLEAN_STATEMENT = "delete from oauth_code where created < ? and expiresat = 0"; private final DataSource dataSource; + private final TimeService timeService; private final Duration expirationTime; private final RandomValueStringGenerator generator = new RandomValueStringGenerator(32); private final RowMapper rowMapper = new TokenCodeRowMapper(); @@ -87,12 +87,13 @@ public class UaaTokenStore implements AuthorizationCodeServices { private Instant lastClean = Instant.EPOCH; private Semaphore cleanMutex = new Semaphore(1); - public UaaTokenStore(DataSource dataSource) { - this(dataSource, DEFAULT_EXPIRATION_TIME); + public UaaTokenStore(DataSource dataSource, TimeService timeService) { + this(dataSource, timeService, DEFAULT_EXPIRATION_TIME); } - public UaaTokenStore(DataSource dataSource, Duration expirationTime) { + public UaaTokenStore(DataSource dataSource, TimeService timeService, Duration expirationTime) { this.dataSource = dataSource; + this.timeService = timeService; this.expirationTime = expirationTime; } @@ -106,7 +107,7 @@ public String createAuthorizationCode(OAuth2Authentication authentication) { attempt++; try { String code = generator.generate(); - Instant expiresAt = Instant.now().plus(getExpirationTime()); + Instant expiresAt = timeService.getCurrentInstant().plus(getExpirationTime()); String userId = authentication.getUserAuthentication()==null ? null : ((UaaPrincipal)authentication.getUserAuthentication().getPrincipal()).getId(); String clientId = authentication.getOAuth2Request().getClientId(); SqlLobValue data = new SqlLobValue(serializeOauth2Authentication(authentication)); @@ -216,7 +217,7 @@ protected void performExpirationCleanIfEnoughTimeHasElapsed() { if (cleanMutex.tryAcquire()) { //check if we should expire again try { - Instant now = Instant.now(); + Instant now = timeService.getCurrentInstant(); if (enoughTimeHasPassedSinceLastExpirationClean(lastClean, now)) { //avoid concurrent deletes from the same UAA - performance improvement lastClean = now; @@ -324,7 +325,7 @@ public NewTokenCode(String code, String userId, Instant expiresAt, String client @Override boolean isExpired() { - return expiresAt.isBefore(Instant.now()); + return expiresAt.isBefore(timeService.getCurrentInstant()); } @Override @@ -353,7 +354,7 @@ public LegacyTokenCode(String code, String userId, Instant created, String clien @Override boolean isExpired() { - return Instant.now().minus(getExpirationTime()).isAfter(created); + return timeService.getCurrentInstant().minus(getExpirationTime()).isAfter(created); } @Override diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/util/TimeService.java b/server/src/main/java/org/cloudfoundry/identity/uaa/util/TimeService.java index 32c337e2a99..42ea07351dd 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/util/TimeService.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/util/TimeService.java @@ -16,6 +16,7 @@ package org.cloudfoundry.identity.uaa.util; +import java.time.Instant; import java.util.Date; public interface TimeService { @@ -24,4 +25,6 @@ default long getCurrentTimeMillis() { } default Date getCurrentDate() { return new Date(getCurrentTimeMillis()); } + + default Instant getCurrentInstant() { return Instant.now(); } } diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStoreTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStoreTests.java index 770d1909e6e..2127e5a1554 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStoreTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenStoreTests.java @@ -5,6 +5,8 @@ import org.cloudfoundry.identity.uaa.authentication.UaaAuthenticationDetails; import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal; import org.cloudfoundry.identity.uaa.constants.OriginKeys; +import org.cloudfoundry.identity.uaa.util.TimeService; +import org.cloudfoundry.identity.uaa.util.TimeServiceImpl; import org.cloudfoundry.identity.uaa.util.UaaStringUtils; import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.junit.jupiter.api.BeforeEach; @@ -34,10 +36,8 @@ import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.sql.Timestamp; import java.time.Duration; import java.time.Instant; -import java.time.temporal.Temporal; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -56,7 +56,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; @@ -70,6 +69,7 @@ class UaaTokenStoreTests { private OAuth2Authentication clientAuthentication; private OAuth2Authentication usernamePasswordAuthentication; private OAuth2Authentication uaaAuthentication; + private TimeService timeService; private UaaPrincipal principal = new UaaPrincipal("userid", "username", "username@test.org", OriginKeys.UAA, null, IdentityZone.getUaaZoneId()); @@ -86,7 +86,8 @@ void setUp() { List userAuthorities = Collections.singletonList(new SimpleGrantedAuthority( "openid")); - store = new UaaTokenStore(dataSource); + timeService = givenMockedTime(); + store = new UaaTokenStore(dataSource, timeService); legacyCodeServices = new JdbcAuthorizationCodeServices(dataSource); BaseClientDetails client = new BaseClientDetails("clientid", null, "openid", "client_credentials,password", "oauth.login", null); Map parameters = new HashMap<>(); @@ -200,7 +201,7 @@ void retrieveToken() { void retrieveExpiredToken() { String code = store.createAuthorizationCode(clientAuthentication); assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code WHERE code = ?", new Object[]{code}, Integer.class), is(1)); - jdbcTemplate.update("update oauth_code set expiresat = 1"); + doReturn(Instant.now().plus(UaaTokenStore.DEFAULT_EXPIRATION_TIME)).when(timeService).getCurrentInstant(); assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode(code)); } @@ -214,14 +215,13 @@ void retrieveNonExistentToken() { @Test void cleanUpExpiredTokensBasedOnExpiresField() { int count = 10; - store = new UaaTokenStore(dataSource, givenMockedExpiration()); String lastCode = null; for (int i = 0; i < count; i++) { lastCode = store.createAuthorizationCode(clientAuthentication); } assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count)); - jdbcTemplate.update("UPDATE oauth_code SET expiresat = ?", System.currentTimeMillis() - 60000); + doReturn(Instant.now().plus(UaaTokenStore.LEGACY_CODE_EXPIRATION_TIME)).when(timeService).getCurrentInstant(); final String finalLastCode = lastCode; assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode(finalLastCode)); @@ -232,15 +232,14 @@ void cleanUpExpiredTokensBasedOnExpiresField() { void cleanUpLegacyCodesCodesWithoutExpiresAtAfter3Days() { int count = 10; long oneday = 1000 * 60 * 60 * 24; - store = new UaaTokenStore(dataSource, givenMockedExpiration()); for (int i = 0; i < count; i++) { legacyCodeServices.createAuthorizationCode(clientAuthentication); } assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count)); - jdbcTemplate.update("UPDATE oauth_code SET created = ?", new Timestamp(System.currentTimeMillis() - (2 * oneday))); + doReturn(Instant.now().plus(Duration.ofDays(2))).when(timeService).getCurrentInstant(); assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode("non-existent")); assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count)); - jdbcTemplate.update("UPDATE oauth_code SET created = ?", new Timestamp(System.currentTimeMillis() - (4 * oneday))); + doReturn(Instant.now().plus(Duration.ofDays(4))).when(timeService).getCurrentInstant(); assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode("non-existent")); assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(0)); } @@ -289,7 +288,6 @@ void cleanUpUnusedOldTokensMySQLInAnotherTimezone( throw new RuntimeException("Unknown DB profile:" + db); } - store = new UaaTokenStore(sameConnectionDataSource); legacyCodeServices = new JdbcAuthorizationCodeServices(sameConnectionDataSource); int count = 10; String lastCode = null; @@ -304,7 +302,6 @@ void cleanUpUnusedOldTokensMySQLInAnotherTimezone( } assertThat(template.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count - 1)); } finally { - store = new UaaTokenStore(dataSource); legacyCodeServices = new JdbcAuthorizationCodeServices(dataSource); } } @@ -318,7 +315,7 @@ void cleanUpExpiredTokensDeadlockLoser() throws Exception { SameConnectionDataSource sameConnectionDataSource = new SameConnectionDataSource(expirationLoser); - store = new UaaTokenStore(sameConnectionDataSource, Duration.ofMillis(1)); + store = new UaaTokenStore(sameConnectionDataSource, timeService, Duration.ofMillis(1)); int count = 10; for (int i = 0; i < count; i++) { String code = store.createAuthorizationCode(clientAuthentication); @@ -327,8 +324,6 @@ void cleanUpExpiredTokensDeadlockLoser() throws Exception { } catch (InvalidGrantException ignored) { } } - } finally { - store = new UaaTokenStore(dataSource); } } @@ -346,7 +341,7 @@ void testCountingTheExecutedSqlDeleteStatements() throws SQLException { // Given, mocked data source to count how often it is used, call performExpirationClean 10 times. DataSource mockedDataSource = mock(DataSource.class); Instant before = Instant.now(); - store = new UaaTokenStore(mockedDataSource); + store = new UaaTokenStore(mockedDataSource, timeService); // When for (int i = 0; i < 10; i++) { try { @@ -362,6 +357,20 @@ void testCountingTheExecutedSqlDeleteStatements() throws SQLException { assertTrue(after.compareTo(before) < Duration.ofMinutes(5).toNanos()); // Expect us to call the DB only once within 5 minutes. Check this when using the data source object verify(mockedDataSource, atMost(1)).getConnection(); + // When moving time to one hour later from now + doReturn(Instant.now().plus(Duration.ofHours(1))).when(timeService).getCurrentInstant(); + // Then + performExpirationClean(store); + // Expect a 2nd DB call + verify(mockedDataSource, atMost(2)).getConnection(); + } + + private static void performExpirationClean(UaaTokenStore store) { + try { + store.performExpirationCleanIfEnoughTimeHasElapsed(); + } catch (Exception sqlException) { + // ignore + } } public static class SameConnectionDataSource implements DataSource { @@ -482,10 +491,10 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl } } - private static Duration givenMockedExpiration() { - Duration durationMock = mock(Duration.class); - doReturn(Instant.now().plus(UaaTokenStore.DEFAULT_EXPIRATION_TIME)).when(durationMock).addTo(any(Temporal.class)); - return durationMock; + private static TimeService givenMockedTime() { + TimeServiceImpl timeService = mock(TimeServiceImpl.class); + doReturn(Instant.now()).when(timeService).getCurrentInstant(); + return timeService; } private static final byte[] UAA_AUTHENTICATION_DATA_OLD_STYLE = new byte[]{123, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 115, 112, 111, 110, 115, 101, 84, 121, 112, 101, 115, 34, 58, 91, 93, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 115, 111, 117, 114, 99, 101, 73, 100, 115, 34, 58, 91, 93, 44, 34, 117, 115, 101, 114, 65, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 105, 111, 110, 46, 117, 97, 97, 80, 114, 105, 110, 99, 105, 112, 97, 108, 34, 58, 34, 123, 92, 34, 105, 100, 92, 34, 58, 92, 34, 117, 115, 101, 114, 105, 100, 92, 34, 44, 92, 34, 110, 97, 109, 101, 92, 34, 58, 92, 34, 117, 115, 101, 114, 110, 97, 109, 101, 92, 34, 44, 92, 34, 101, 109, 97, 105, 108, 92, 34, 58, 92, 34, 117, 115, 101, 114, 110, 97, 109, 101, 64, 116, 101, 115, 116, 46, 111, 114, 103, 92, 34, 44, 92, 34, 111, 114, 105, 103, 105, 110, 92, 34, 58, 92, 34, 117, 97, 97, 92, 34, 44, 92, 34, 101, 120, 116, 101, 114, 110, 97, 108, 73, 100, 92, 34, 58, 110, 117, 108, 108, 44, 92, 34, 122, 111, 110, 101, 73, 100, 92, 34, 58, 92, 34, 117, 97, 97, 92, 34, 125, 34, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 113, 117, 101, 115, 116, 80, 97, 114, 97, 109, 101, 116, 101, 114, 115, 34, 58, 123, 34, 103, 114, 97, 110, 116, 95, 116, 121, 112, 101, 34, 58, 34, 112, 97, 115, 115, 119, 111, 114, 100, 34, 44, 34, 99, 108, 105, 101, 110, 116, 95, 105, 100, 34, 58, 34, 99, 108, 105, 101, 110, 116, 105, 100, 34, 44, 34, 115, 99, 111, 112, 101, 34, 58, 34, 111, 112, 101, 110, 105, 100, 34, 125, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 100, 105, 114, 101, 99, 116, 85, 114, 105, 34, 58, 110, 117, 108, 108, 44, 34, 117, 115, 101, 114, 65, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 105, 111, 110, 46, 97, 117, 116, 104, 111, 114, 105, 116, 105, 101, 115, 34, 58, 91, 34, 111, 112, 101, 110, 105, 100, 34, 93, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 97, 117, 116, 104, 111, 114, 105, 116, 105, 101, 115, 34, 58, 91, 34, 111, 97, 117, 116, 104, 46, 108, 111, 103, 105, 110, 34, 93, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 99, 108, 105, 101, 110, 116, 73, 100, 34, 58, 34, 99, 108, 105, 101, 110, 116, 105, 100, 34, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 97, 112, 112, 114, 111, 118, 101, 100, 34, 58, 116, 114, 117, 101, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 115, 99, 111, 112, 101, 34, 58, 91, 34, 111, 112, 101, 110, 105, 100, 34, 93, 125}; diff --git a/uaa/src/main/webapp/WEB-INF/spring/oauth-endpoints.xml b/uaa/src/main/webapp/WEB-INF/spring/oauth-endpoints.xml index 0d2169ca528..5b262fac55c 100755 --- a/uaa/src/main/webapp/WEB-INF/spring/oauth-endpoints.xml +++ b/uaa/src/main/webapp/WEB-INF/spring/oauth-endpoints.xml @@ -476,7 +476,8 @@ - + +