Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Add Instant to TimeService interface and use TimeService in UaaTokenStore #2315

Merged
merged 10 commits into from
Jul 14, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.cloudfoundry.identity.uaa.authentication.UaaAuthenticationDetails;
import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.TimeService;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.slf4j.Logger;
Expand Down Expand Up @@ -79,18 +80,20 @@ public class UaaTokenStore implements AuthorizationCodeServices {
private static final String SQL_CLEAN_STATEMENT = "delete from oauth_code where created < ? and expiresat = 0";

private final DataSource dataSource;
private final TimeService timeService;
private final Duration expirationTime;
private final RandomValueStringGenerator generator = new RandomValueStringGenerator(32);
private final RowMapper rowMapper = new TokenCodeRowMapper();

private AtomicReference<Instant> lastClean = new AtomicReference<>(Instant.EPOCH);

public UaaTokenStore(DataSource dataSource) {
this(dataSource, DEFAULT_EXPIRATION_TIME);
public UaaTokenStore(DataSource dataSource, TimeService timeService) {
this(dataSource, timeService, DEFAULT_EXPIRATION_TIME);
}

public UaaTokenStore(DataSource dataSource, Duration expirationTime) {
public UaaTokenStore(DataSource dataSource, TimeService timeService, Duration expirationTime) {
this.dataSource = dataSource;
this.timeService = timeService;
this.expirationTime = expirationTime;
}

Expand All @@ -103,7 +106,7 @@ public String createAuthorizationCode(OAuth2Authentication authentication) {
while ((tries++)<=max_tries) {
try {
String code = generator.generate();
Instant expiresAt = Instant.now().plus(getExpirationTime());
Instant expiresAt = timeService.getCurrentInstant().plus(getExpirationTime());
String userId = authentication.getUserAuthentication()==null ? null : ((UaaPrincipal)authentication.getUserAuthentication().getPrincipal()).getId();
String clientId = authentication.getOAuth2Request().getClientId();
SqlLobValue data = new SqlLobValue(serializeOauth2Authentication(authentication));
Expand Down Expand Up @@ -213,7 +216,7 @@ else if (map.get(USER_AUTHENTICATION_UAA_PRINCIPAL)!=null) {
protected void performExpirationClean() {
Instant last = lastClean.get();
//check if we should expire again
Instant now = Instant.now();
Instant now = timeService.getCurrentInstant();
if (enoughTimeHasPassedSinceLastExpirationClean(last, now)) {
//avoid concurrent deletes from the same UAA - performance improvement
if (lastClean.compareAndSet(last, now)) {
Expand Down Expand Up @@ -314,7 +317,7 @@ public NewTokenCode(String code, String userId, Instant expiresAt, String client

@Override
boolean isExpired() {
return expiresAt.isBefore(Instant.now());
return expiresAt.isBefore(timeService.getCurrentInstant());
}

@Override
Expand Down Expand Up @@ -343,7 +346,7 @@ public LegacyTokenCode(String code, String userId, Instant created, String clien

@Override
boolean isExpired() {
return Instant.now().minus(getExpirationTime()).isAfter(created);
return timeService.getCurrentInstant().minus(getExpirationTime()).isAfter(created);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
package org.cloudfoundry.identity.uaa.util;


import java.time.Instant;
import java.util.Date;

public interface TimeService {
default long getCurrentTimeMillis() {
return System.currentTimeMillis();
return getCurrentInstant().toEpochMilli();
}

default Date getCurrentDate() { return new Date(getCurrentTimeMillis()); }
default Date getCurrentDate() { return Date.from(getCurrentInstant()); }

default Instant getCurrentInstant() { return Instant.now(); }
strehle marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.client.BaseClientDetails;

import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Date;
import java.util.List;

Expand Down Expand Up @@ -50,6 +52,7 @@ public void ensureRequiredApprovals_happyCase() {
approval.setStatus(Approval.ApprovalStatus.APPROVED);
approval.setExpiresAt(new Date(approvalExpiry));
when(timeService.getCurrentTimeMillis()).thenReturn(approvalExpiry - 1L);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(approvalExpiry - 1L, ChronoUnit.MILLIS));
when(timeService.getCurrentDate()).thenCallRealMethod();

List<Approval> approvals = Lists.newArrayList(approval);
Expand All @@ -68,7 +71,7 @@ public void ensureRequiredApprovals_throwsWhenApprovalsExpired() {
approval.setScope("foo.read");
approval.setStatus(Approval.ApprovalStatus.APPROVED);
approval.setExpiresAt(new Date(approvalExpiry));
when(timeService.getCurrentTimeMillis()).thenReturn(approvalExpiry + 1L);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(approvalExpiry + 1L, ChronoUnit.MILLIS));
strehle marked this conversation as resolved.
Show resolved Hide resolved
when(timeService.getCurrentDate()).thenCallRealMethod();

List<Approval> approvals = Lists.newArrayList(approval);
Expand Down Expand Up @@ -112,7 +115,7 @@ public void ensureRequiredApprovals_iteratesThroughAllApprovalsAndScopes() {
approval3.setStatus(Approval.ApprovalStatus.APPROVED);
approval3.setExpiresAt(new Date(approvalExpiry));

when(timeService.getCurrentTimeMillis()).thenReturn(approvalExpiry - 1L);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(approvalExpiry - 1L, ChronoUnit.MILLIS));
when(timeService.getCurrentDate()).thenCallRealMethod();

List<Approval> approvals = Lists.newArrayList(approval1, approval2, approval3);
Expand Down Expand Up @@ -140,7 +143,7 @@ public void ensureRequiredApprovals_throwsIfAnyRequestedScopesAreNotApproved() {
approval3.setStatus(Approval.ApprovalStatus.APPROVED);
approval3.setExpiresAt(new Date(approvalExpiry));

when(timeService.getCurrentTimeMillis()).thenReturn(approvalExpiry - 1L);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(approvalExpiry - 1L, ChronoUnit.MILLIS));
when(timeService.getCurrentDate()).thenCallRealMethod();

List<Approval> approvals = Lists.newArrayList(approval1, approval2, approval3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.HttpRequestMethodNotSupportedException;

import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -223,7 +225,7 @@ public void setUp(boolean opaque) throws Exception {

nowMillis = 10000L;
timeService = mock(TimeService.class);
when(timeService.getCurrentTimeMillis()).thenReturn(nowMillis);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(nowMillis, ChronoUnit.MILLIS));
when(timeService.getCurrentDate()).thenCallRealMethod();
userAuthorities = new ArrayList<>();
userAuthorities.add(new SimpleGrantedAuthority("read"));
Expand Down Expand Up @@ -942,7 +944,7 @@ public void testExpiredToken() throws Exception {
tokenServices.setClientDetailsService(clientDetailsService);
OAuth2AccessToken accessToken = tokenServices.createAccessToken(authentication);

when(timeService.getCurrentTimeMillis()).thenReturn(nowMillis + validitySeconds.longValue() * 1000 + 1L);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(nowMillis + validitySeconds.longValue() * 1000 + 1L, ChronoUnit.MILLIS));
endpoint.checkToken(accessToken.getValue(), Collections.emptyList(), request);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.*;

import static java.util.Collections.*;
Expand Down Expand Up @@ -128,6 +130,7 @@ public void setUp() throws Exception {
tokenServices = tokenSupport.getUaaTokenServices();
tokenProvisioning = tokenSupport.getTokenProvisioning();
when(tokenSupport.timeService.getCurrentTimeMillis()).thenReturn(1000L);
when(tokenSupport.timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(1000L, ChronoUnit.MILLIS));
}

@After
Expand Down Expand Up @@ -249,6 +252,7 @@ public void refreshAccessToken_buildsIdToken_withRolesAndAttributesAndACR() thro

TimeService timeService = mock(TimeService.class);
when(timeService.getCurrentTimeMillis()).thenReturn(1000L);
when(timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(1000L, ChronoUnit.MILLIS));
when(timeService.getCurrentDate()).thenCallRealMethod();
RefreshTokenCreator refreshTokenCreator = mock(RefreshTokenCreator.class);
ApprovalService approvalService = mock(ApprovalService.class);
Expand Down Expand Up @@ -1823,6 +1827,7 @@ public void testLoadAuthenticationWithAnExpiredToken() {
assertThat(accessToken, validFor(is(1)));

when(tokenSupport.timeService.getCurrentTimeMillis()).thenReturn(2001L);
when(tokenSupport.timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(2001L, ChronoUnit.MILLIS));
try {
tokenServices.loadAuthentication(accessToken.getValue());
fail("Expected Exception was not thrown");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.springframework.security.oauth2.provider.TokenRequest;
import org.springframework.security.oauth2.provider.client.BaseClientDetails;

import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
Expand Down Expand Up @@ -55,7 +57,7 @@ void setUp() throws Exception {
persistToken.setExpiration(expiration);

tokenServices = tokenSupport.getUaaTokenServices();
when(tokenSupport.timeService.getCurrentTimeMillis()).thenReturn(1000L);
when(tokenSupport.timeService.getCurrentInstant()).thenReturn(Instant.EPOCH.plus(1000L, ChronoUnit.MILLIS));
new IdentityZoneManagerImpl().getCurrentIdentityZone().getConfig().getTokenPolicy().setRefreshTokenFormat(TokenConstants.TokenFormat.OPAQUE.getStringValue());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ public TokenTestSupport(UaaTokenEnhancer tokenEnhancer, KeyInfoService keyInfo)
requestFactory = new DefaultOAuth2RequestFactory(clientDetailsService);
timeService = mock(TimeService.class);
approvalService = new ApprovalService(timeService, approvalStore);
when(timeService.getCurrentInstant()).thenCallRealMethod();
when(timeService.getCurrentDate()).thenCallRealMethod();
TokenEndpointBuilder tokenEndpointBuilder = new TokenEndpointBuilder(DEFAULT_ISSUER);
keyInfoService = keyInfo != null ? keyInfo : new KeyInfoService(DEFAULT_ISSUER);
Expand Down
Loading