Skip to content

Commit

Permalink
Refactor: Add Instant to TimeService interface and use TimeService in…
Browse files Browse the repository at this point in the history
… UaaTokenStore (#2315)

* Add mock for TimeService

* adopt test and use always class timeService

* refactor all usages of Instant -> now to timeService

* fix timService mock

* simplify test

* refactor and minimal changes
  • Loading branch information
strehle authored Jul 14, 2023
1 parent bd39493 commit eb008a5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

package org.cloudfoundry.identity.uaa.oauth;


import com.fasterxml.jackson.core.type.TypeReference;
import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication;
import org.cloudfoundry.identity.uaa.authentication.UaaAuthenticationDetails;
import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.TimeService;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.slf4j.Logger;
Expand Down Expand Up @@ -54,7 +54,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicReference;

public class UaaTokenStore implements AuthorizationCodeServices {
public static final Duration DEFAULT_EXPIRATION_TIME = Duration.ofMinutes(5);
Expand All @@ -80,19 +79,21 @@ public class UaaTokenStore implements AuthorizationCodeServices {
private static final String SQL_CLEAN_STATEMENT = "delete from oauth_code where created < ? and expiresat = 0";

private final DataSource dataSource;
private final TimeService timeService;
private final Duration expirationTime;
private final RandomValueStringGenerator generator = new RandomValueStringGenerator(32);
private final RowMapper rowMapper = new TokenCodeRowMapper();

private Instant lastClean = Instant.EPOCH;
private Semaphore cleanMutex = new Semaphore(1);

public UaaTokenStore(DataSource dataSource) {
this(dataSource, DEFAULT_EXPIRATION_TIME);
public UaaTokenStore(DataSource dataSource, TimeService timeService) {
this(dataSource, timeService, DEFAULT_EXPIRATION_TIME);
}

public UaaTokenStore(DataSource dataSource, Duration expirationTime) {
public UaaTokenStore(DataSource dataSource, TimeService timeService, Duration expirationTime) {
this.dataSource = dataSource;
this.timeService = timeService;
this.expirationTime = expirationTime;
}

Expand All @@ -106,7 +107,7 @@ public String createAuthorizationCode(OAuth2Authentication authentication) {
attempt++;
try {
String code = generator.generate();
Instant expiresAt = Instant.now().plus(getExpirationTime());
Instant expiresAt = timeService.getCurrentInstant().plus(getExpirationTime());
String userId = authentication.getUserAuthentication()==null ? null : ((UaaPrincipal)authentication.getUserAuthentication().getPrincipal()).getId();
String clientId = authentication.getOAuth2Request().getClientId();
SqlLobValue data = new SqlLobValue(serializeOauth2Authentication(authentication));
Expand Down Expand Up @@ -216,7 +217,7 @@ protected void performExpirationCleanIfEnoughTimeHasElapsed() {
if (cleanMutex.tryAcquire()) {
//check if we should expire again
try {
Instant now = Instant.now();
Instant now = timeService.getCurrentInstant();
if (enoughTimeHasPassedSinceLastExpirationClean(lastClean, now)) {
//avoid concurrent deletes from the same UAA - performance improvement
lastClean = now;
Expand Down Expand Up @@ -324,7 +325,7 @@ public NewTokenCode(String code, String userId, Instant expiresAt, String client

@Override
boolean isExpired() {
return expiresAt.isBefore(Instant.now());
return expiresAt.isBefore(timeService.getCurrentInstant());
}

@Override
Expand Down Expand Up @@ -353,7 +354,7 @@ public LegacyTokenCode(String code, String userId, Instant created, String clien

@Override
boolean isExpired() {
return Instant.now().minus(getExpirationTime()).isAfter(created);
return timeService.getCurrentInstant().minus(getExpirationTime()).isAfter(created);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.cloudfoundry.identity.uaa.util;


import java.time.Instant;
import java.util.Date;

public interface TimeService {
Expand All @@ -24,4 +25,6 @@ default long getCurrentTimeMillis() {
}

default Date getCurrentDate() { return new Date(getCurrentTimeMillis()); }

default Instant getCurrentInstant() { return Instant.now(); }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import org.cloudfoundry.identity.uaa.authentication.UaaAuthenticationDetails;
import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.util.TimeService;
import org.cloudfoundry.identity.uaa.util.TimeServiceImpl;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -34,10 +36,8 @@
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.Temporal;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -56,7 +56,6 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
Expand All @@ -70,6 +69,7 @@ class UaaTokenStoreTests {
private OAuth2Authentication clientAuthentication;
private OAuth2Authentication usernamePasswordAuthentication;
private OAuth2Authentication uaaAuthentication;
private TimeService timeService;

private UaaPrincipal principal = new UaaPrincipal("userid", "username", "username@test.org", OriginKeys.UAA, null, IdentityZone.getUaaZoneId());

Expand All @@ -86,7 +86,8 @@ void setUp() {
List<GrantedAuthority> userAuthorities = Collections.singletonList(new SimpleGrantedAuthority(
"openid"));

store = new UaaTokenStore(dataSource);
timeService = givenMockedTime();
store = new UaaTokenStore(dataSource, timeService);
legacyCodeServices = new JdbcAuthorizationCodeServices(dataSource);
BaseClientDetails client = new BaseClientDetails("clientid", null, "openid", "client_credentials,password", "oauth.login", null);
Map<String, String> parameters = new HashMap<>();
Expand Down Expand Up @@ -200,7 +201,7 @@ void retrieveToken() {
void retrieveExpiredToken() {
String code = store.createAuthorizationCode(clientAuthentication);
assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code WHERE code = ?", new Object[]{code}, Integer.class), is(1));
jdbcTemplate.update("update oauth_code set expiresat = 1");
doReturn(Instant.now().plus(UaaTokenStore.DEFAULT_EXPIRATION_TIME)).when(timeService).getCurrentInstant();
assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode(code));
}

Expand All @@ -214,14 +215,13 @@ void retrieveNonExistentToken() {
@Test
void cleanUpExpiredTokensBasedOnExpiresField() {
int count = 10;
store = new UaaTokenStore(dataSource, givenMockedExpiration());
String lastCode = null;
for (int i = 0; i < count; i++) {
lastCode = store.createAuthorizationCode(clientAuthentication);
}
assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count));

jdbcTemplate.update("UPDATE oauth_code SET expiresat = ?", System.currentTimeMillis() - 60000);
doReturn(Instant.now().plus(UaaTokenStore.LEGACY_CODE_EXPIRATION_TIME)).when(timeService).getCurrentInstant();

final String finalLastCode = lastCode;
assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode(finalLastCode));
Expand All @@ -232,15 +232,14 @@ void cleanUpExpiredTokensBasedOnExpiresField() {
void cleanUpLegacyCodesCodesWithoutExpiresAtAfter3Days() {
int count = 10;
long oneday = 1000 * 60 * 60 * 24;
store = new UaaTokenStore(dataSource, givenMockedExpiration());
for (int i = 0; i < count; i++) {
legacyCodeServices.createAuthorizationCode(clientAuthentication);
}
assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count));
jdbcTemplate.update("UPDATE oauth_code SET created = ?", new Timestamp(System.currentTimeMillis() - (2 * oneday)));
doReturn(Instant.now().plus(Duration.ofDays(2))).when(timeService).getCurrentInstant();
assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode("non-existent"));
assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count));
jdbcTemplate.update("UPDATE oauth_code SET created = ?", new Timestamp(System.currentTimeMillis() - (4 * oneday)));
doReturn(Instant.now().plus(Duration.ofDays(4))).when(timeService).getCurrentInstant();
assertThrows(InvalidGrantException.class, () -> store.consumeAuthorizationCode("non-existent"));
assertThat(jdbcTemplate.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(0));
}
Expand Down Expand Up @@ -289,7 +288,6 @@ void cleanUpUnusedOldTokensMySQLInAnotherTimezone(
throw new RuntimeException("Unknown DB profile:" + db);
}

store = new UaaTokenStore(sameConnectionDataSource);
legacyCodeServices = new JdbcAuthorizationCodeServices(sameConnectionDataSource);
int count = 10;
String lastCode = null;
Expand All @@ -304,7 +302,6 @@ void cleanUpUnusedOldTokensMySQLInAnotherTimezone(
}
assertThat(template.queryForObject("SELECT count(*) FROM oauth_code", Integer.class), is(count - 1));
} finally {
store = new UaaTokenStore(dataSource);
legacyCodeServices = new JdbcAuthorizationCodeServices(dataSource);
}
}
Expand All @@ -318,7 +315,7 @@ void cleanUpExpiredTokensDeadlockLoser() throws Exception {

SameConnectionDataSource sameConnectionDataSource = new SameConnectionDataSource(expirationLoser);

store = new UaaTokenStore(sameConnectionDataSource, Duration.ofMillis(1));
store = new UaaTokenStore(sameConnectionDataSource, timeService, Duration.ofMillis(1));
int count = 10;
for (int i = 0; i < count; i++) {
String code = store.createAuthorizationCode(clientAuthentication);
Expand All @@ -327,8 +324,6 @@ void cleanUpExpiredTokensDeadlockLoser() throws Exception {
} catch (InvalidGrantException ignored) {
}
}
} finally {
store = new UaaTokenStore(dataSource);
}
}

Expand All @@ -346,7 +341,7 @@ void testCountingTheExecutedSqlDeleteStatements() throws SQLException {
// Given, mocked data source to count how often it is used, call performExpirationClean 10 times.
DataSource mockedDataSource = mock(DataSource.class);
Instant before = Instant.now();
store = new UaaTokenStore(mockedDataSource);
store = new UaaTokenStore(mockedDataSource, timeService);
// When
for (int i = 0; i < 10; i++) {
try {
Expand All @@ -362,6 +357,20 @@ void testCountingTheExecutedSqlDeleteStatements() throws SQLException {
assertTrue(after.compareTo(before) < Duration.ofMinutes(5).toNanos());
// Expect us to call the DB only once within 5 minutes. Check this when using the data source object
verify(mockedDataSource, atMost(1)).getConnection();
// When moving time to one hour later from now
doReturn(Instant.now().plus(Duration.ofHours(1))).when(timeService).getCurrentInstant();
// Then
performExpirationClean(store);
// Expect a 2nd DB call
verify(mockedDataSource, atMost(2)).getConnection();
}

private static void performExpirationClean(UaaTokenStore store) {
try {
store.performExpirationCleanIfEnoughTimeHasElapsed();
} catch (Exception sqlException) {
// ignore
}
}

public static class SameConnectionDataSource implements DataSource {
Expand Down Expand Up @@ -482,10 +491,10 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl
}
}

private static Duration givenMockedExpiration() {
Duration durationMock = mock(Duration.class);
doReturn(Instant.now().plus(UaaTokenStore.DEFAULT_EXPIRATION_TIME)).when(durationMock).addTo(any(Temporal.class));
return durationMock;
private static TimeService givenMockedTime() {
TimeServiceImpl timeService = mock(TimeServiceImpl.class);
doReturn(Instant.now()).when(timeService).getCurrentInstant();
return timeService;
}

private static final byte[] UAA_AUTHENTICATION_DATA_OLD_STYLE = new byte[]{123, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 115, 112, 111, 110, 115, 101, 84, 121, 112, 101, 115, 34, 58, 91, 93, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 115, 111, 117, 114, 99, 101, 73, 100, 115, 34, 58, 91, 93, 44, 34, 117, 115, 101, 114, 65, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 105, 111, 110, 46, 117, 97, 97, 80, 114, 105, 110, 99, 105, 112, 97, 108, 34, 58, 34, 123, 92, 34, 105, 100, 92, 34, 58, 92, 34, 117, 115, 101, 114, 105, 100, 92, 34, 44, 92, 34, 110, 97, 109, 101, 92, 34, 58, 92, 34, 117, 115, 101, 114, 110, 97, 109, 101, 92, 34, 44, 92, 34, 101, 109, 97, 105, 108, 92, 34, 58, 92, 34, 117, 115, 101, 114, 110, 97, 109, 101, 64, 116, 101, 115, 116, 46, 111, 114, 103, 92, 34, 44, 92, 34, 111, 114, 105, 103, 105, 110, 92, 34, 58, 92, 34, 117, 97, 97, 92, 34, 44, 92, 34, 101, 120, 116, 101, 114, 110, 97, 108, 73, 100, 92, 34, 58, 110, 117, 108, 108, 44, 92, 34, 122, 111, 110, 101, 73, 100, 92, 34, 58, 92, 34, 117, 97, 97, 92, 34, 125, 34, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 113, 117, 101, 115, 116, 80, 97, 114, 97, 109, 101, 116, 101, 114, 115, 34, 58, 123, 34, 103, 114, 97, 110, 116, 95, 116, 121, 112, 101, 34, 58, 34, 112, 97, 115, 115, 119, 111, 114, 100, 34, 44, 34, 99, 108, 105, 101, 110, 116, 95, 105, 100, 34, 58, 34, 99, 108, 105, 101, 110, 116, 105, 100, 34, 44, 34, 115, 99, 111, 112, 101, 34, 58, 34, 111, 112, 101, 110, 105, 100, 34, 125, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 114, 101, 100, 105, 114, 101, 99, 116, 85, 114, 105, 34, 58, 110, 117, 108, 108, 44, 34, 117, 115, 101, 114, 65, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 105, 111, 110, 46, 97, 117, 116, 104, 111, 114, 105, 116, 105, 101, 115, 34, 58, 91, 34, 111, 112, 101, 110, 105, 100, 34, 93, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 97, 117, 116, 104, 111, 114, 105, 116, 105, 101, 115, 34, 58, 91, 34, 111, 97, 117, 116, 104, 46, 108, 111, 103, 105, 110, 34, 93, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 99, 108, 105, 101, 110, 116, 73, 100, 34, 58, 34, 99, 108, 105, 101, 110, 116, 105, 100, 34, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 97, 112, 112, 114, 111, 118, 101, 100, 34, 58, 116, 114, 117, 101, 44, 34, 111, 97, 117, 116, 104, 50, 82, 101, 113, 117, 101, 115, 116, 46, 115, 99, 111, 112, 101, 34, 58, 91, 34, 111, 112, 101, 110, 105, 100, 34, 93, 125};
Expand Down
3 changes: 2 additions & 1 deletion uaa/src/main/webapp/WEB-INF/spring/oauth-endpoints.xml
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,8 @@
<!-- End of PKCE -->

<bean id="authorizationCodeServices" class="org.cloudfoundry.identity.uaa.oauth.UaaTokenStore">
<constructor-arg ref="dataSource"/>
<constructor-arg name="dataSource" ref="dataSource"/>
<constructor-arg name="timeService" ref="timeService"/>
</bean>

<bean id="userApprovalHandler" class="org.cloudfoundry.identity.uaa.user.UaaUserApprovalHandler"/>
Expand Down

0 comments on commit eb008a5

Please sign in to comment.