diff --git a/compose.yml b/compose.yml index 0141ccad9..16baf038e 100644 --- a/compose.yml +++ b/compose.yml @@ -31,6 +31,7 @@ services: "--canvas-token", "changeme", "--use-canvas", "true", # "--disable-compilation", # Enable me, if desired! + "--client-id", "changeme", ] networks: - autograder diff --git a/docs/getting-started/getting-started.md b/docs/getting-started/getting-started.md index 05f4b081f..9f254d0be 100644 --- a/docs/getting-started/getting-started.md +++ b/docs/getting-started/getting-started.md @@ -109,6 +109,12 @@ Do the following actions: # Use one of the following, but not both --canvas-token --use-canvas false + +#Follow the steps at https://developer.byu.edu/data/api-usage/create-an-oauth-client +#Please choose the sandbox environment +#You are going to want to choose the Auth Code + PKCE option +#For the redirect url, you should use the cas-callback-url +--client-id ``` ### 6. Run the Autograder Locally diff --git a/src/main/java/Main.java b/src/main/java/Main.java index 1e59be2ba..137c5fb93 100644 --- a/src/main/java/Main.java +++ b/src/main/java/Main.java @@ -90,6 +90,9 @@ private static void setupProperties(String[] args) { if (cmd.hasOption("disable-compilation")) { properties.setProperty("run-compilation", "false"); } + if(cmd.hasOption("client-id")){ + properties.setProperty("client-id", cmd.getOptionValue("client-id")); + } } catch (ParseException e) { throw new RuntimeException("Error parsing command line arguments", e); } @@ -109,6 +112,7 @@ private static Options getOptions() { options.addOption(null, "canvas-token", true, "Canvas Token"); options.addOption(null, "use-canvas", true, "Using Canvas"); options.addOption(null, "disable-compilation", false, "Turn off student code compilation"); + options.addOption(null, "client-id", true, "Client ID for BYU OAuth"); return options; } diff --git a/src/main/java/edu/byu/cs/controller/CasController.java b/src/main/java/edu/byu/cs/controller/RedirectController.java similarity index 61% rename from src/main/java/edu/byu/cs/controller/CasController.java rename to src/main/java/edu/byu/cs/controller/RedirectController.java index 8ec60c15e..0eafbef73 100644 --- a/src/main/java/edu/byu/cs/controller/CasController.java +++ b/src/main/java/edu/byu/cs/controller/RedirectController.java @@ -1,53 +1,71 @@ package edu.byu.cs.controller; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; + import edu.byu.cs.canvas.CanvasException; +import edu.byu.cs.controller.exception.UnauthorizedException; import edu.byu.cs.dataAccess.DataAccessException; import edu.byu.cs.model.User; import edu.byu.cs.properties.ApplicationProperties; -import edu.byu.cs.service.CasService; +import edu.byu.cs.service.AuthenticationService; import edu.byu.cs.service.ConfigService; +import static edu.byu.cs.util.JwtUtils.generateToken; import io.javalin.http.Context; +import io.javalin.http.Cookie; import io.javalin.http.Handler; import io.javalin.http.HttpStatus; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; - -import static edu.byu.cs.util.JwtUtils.generateToken; - /** - * Handles CAS-related HTTP endpoints. CAS, standing for Central Authentication Service, - * is BYU's centralized authentication provider for all BYU users + * Handles Redirect related endpoints, including the class chat link (i. e. Slack or Discord) + * and authentication redirects. */ -public class CasController { +public class RedirectController { + public static final Handler callbackGet = ctx -> { - String ticket = ctx.queryParam("ticket"); + String code = ctx.queryParam("code"); + if (code == null){ + throw new UnauthorizedException(); + } + AuthenticationService.TokenResponse response = AuthenticationService.exchangeCodeForTokens(code); User user; try { - user = CasService.callback(ticket); + user = AuthenticationService.callback(response.idToken()); } catch (CanvasException e) { String errorUrlParam = URLEncoder.encode(e.getMessage(), StandardCharsets.UTF_8); ctx.redirect(ApplicationProperties.frontendUrl() + "/login?error=" + errorUrlParam, HttpStatus.FOUND); return; } - // FIXME: secure cookie with httpOnly - ctx.cookie("token", generateToken(user.netId()), 14400); + ctx.cookie (new Cookie( + "token", + generateToken(user.netId()), + "/", + 14400, + AuthenticationService.isSecure(), + 0, + true + )); redirect(ctx); }; + + public static final Handler loginGet = ctx -> { // check if already logged in if (ctx.cookie("token") != null) { redirect(ctx); return; } - ctx.redirect(CasService.BYU_CAS_URL + "/login" + "?service=" + ApplicationProperties.casCallbackUrl()); + ctx.redirect(AuthenticationService.getAuthorizationUrl()); }; - + /** + * Redirects students to the class chat invite. At the time we used Slack, and therefore all references use + * that name + */ private static void redirect(Context ctx) throws DataAccessException { String redirectTo; if(ctx.sessionAttribute("slack") != null) { @@ -64,7 +82,7 @@ private static void redirect(Context ctx) throws DataAccessException { return; } - // TODO: call cas logout endpoint with ticket + // TODO: call logout endpoint with token ctx.removeCookie("token", "/"); ctx.redirect(ApplicationProperties.frontendUrl(), HttpStatus.OK); }; diff --git a/src/main/java/edu/byu/cs/properties/ApplicationProperties.java b/src/main/java/edu/byu/cs/properties/ApplicationProperties.java index 9f972a890..8b9f4016d 100644 --- a/src/main/java/edu/byu/cs/properties/ApplicationProperties.java +++ b/src/main/java/edu/byu/cs/properties/ApplicationProperties.java @@ -58,6 +58,7 @@ public static String frontendUrl() { return mustGet("frontend-url"); } + public static String clientId() {return mustGet("client-id");} public static String casCallbackUrl() { return mustGet("cas-callback-url"); diff --git a/src/main/java/edu/byu/cs/server/endpointprovider/EndpointProviderImpl.java b/src/main/java/edu/byu/cs/server/endpointprovider/EndpointProviderImpl.java index 00d00d68e..6e3aa8779 100644 --- a/src/main/java/edu/byu/cs/server/endpointprovider/EndpointProviderImpl.java +++ b/src/main/java/edu/byu/cs/server/endpointprovider/EndpointProviderImpl.java @@ -93,17 +93,17 @@ public Handler meGet() { @Override public Handler callbackGet() { - return CasController.callbackGet; + return RedirectController.callbackGet; } @Override public Handler loginGet() { - return CasController.loginGet; + return RedirectController.loginGet; } @Override public Handler logoutPost() { - return CasController.logoutPost; + return RedirectController.logoutPost; } // ConfigController diff --git a/src/main/java/edu/byu/cs/service/AuthenticationService.java b/src/main/java/edu/byu/cs/service/AuthenticationService.java new file mode 100644 index 000000000..1f8a2613e --- /dev/null +++ b/src/main/java/edu/byu/cs/service/AuthenticationService.java @@ -0,0 +1,278 @@ +package edu.byu.cs.service; + +import java.io.IOException; +import java.net.URI; +import java.net.URLEncoder; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collection; + +import edu.byu.cs.controller.RedirectController; +import edu.byu.cs.util.JwtUtils; +import edu.byu.cs.util.NetworkUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.gson.Gson; +import com.google.gson.annotations.SerializedName; + +import edu.byu.cs.canvas.CanvasException; +import edu.byu.cs.canvas.CanvasService; +import edu.byu.cs.controller.exception.BadRequestException; +import edu.byu.cs.controller.exception.InternalServerException; +import edu.byu.cs.dataAccess.DaoService; +import edu.byu.cs.dataAccess.DataAccessException; +import edu.byu.cs.dataAccess.daoInterface.UserDao; +import edu.byu.cs.model.User; +import edu.byu.cs.properties.ApplicationProperties; + +/** + * Contains service logic for the {@link RedirectController}.
View the + * BYU API documentation + * to understand how OAuth works, if needed. Other sites on this page are a great resource as well, + * particularly verifying JWT tokens. You should also check out the {@link JwtUtils} class as well. + *

+ * The {@code AuthenticationService} ensures the user authenticates before they access + * and use the AutoGrader. + */ +public class AuthenticationService { + public static final String BYU_API_URL = "https://api-sandbox.byu.edu"; + private static final Logger LOGGER = LoggerFactory.getLogger(AuthenticationService.class); + + //initializing these 30 seconds behind the current time ensures the config is cached on start + private static Instant configExpiration = Instant.now().minusSeconds(30); + private static Instant keyExpiration = Instant.now().minusSeconds(30); + + public static OpenIDConfig config; + + + /** + * Validates an identity token and retrieves the associated user. + *
+ * If the user exists in the database, they are returned directly. Otherwise, the user + * is retrieved from Canvas and stored in the database before being returned + * + * @param ticket the identity token in the form of a JWT + * @return the user, either stored in the database or from Canvas if not + * @throws InternalServerException if an error arose during ticket validation or user retrieval + * @throws BadRequestException if JWT validation failed + * @throws DataAccessException if there was an issue storing the user in the database + * @throws CanvasException if there was an issue getting the user from Canvas + */ + public static User callback(String ticket) throws InternalServerException, BadRequestException, DataAccessException, CanvasException { + String netId; + try { + netId = AuthenticationService.validateToken(ticket); + } catch (IOException | InterruptedException e) { + LOGGER.error("Error validating ticket", e); + throw new InternalServerException("Error validating ticket", e); + } + + if (netId == null) { + throw new BadRequestException("Ticket validation failed"); + } + + UserDao userDao = DaoService.getUserDao(); + + User user; + // Check if student is already in the database + try { + user = userDao.getUser(netId); + } catch (DataAccessException e) { + LOGGER.error("Couldn't get user from database", e); + throw new InternalServerException("Couldn't get user from database", e); + } + + // If there isn't a student in the database with this netId + if (user == null) { + try { + user = CanvasService.getCanvasIntegration().getUser(netId); + } catch (CanvasException e) { + LOGGER.error("Error getting user from canvas", e); + throw e; + } + + userDao.insertUser(user); + LOGGER.info("Registered {}", user); + } + return user; + } + + /** + * + * @param token the JWT token to validate + * @return the JWT subject (currently the netid) + * @throws InternalServerException when unable to grab keys or OpenID config + */ + public static String validateToken(String token) throws InternalServerException, IOException, InterruptedException { + if (isExpired(keyExpiration)){ + refreshConfig(); + cacheJWK(); + + } + return JwtUtils.validateTokenAgainstKeys(token); + } + + public static TokenResponse exchangeCodeForTokens(String code) throws IOException, InterruptedException, InternalServerException { + refreshConfig(); + + String formData = "grant_type=authorization_code" + + "&client_id=" + URLEncoder.encode(ApplicationProperties.clientId(), StandardCharsets.UTF_8) + + "&code=" + URLEncoder.encode(code, StandardCharsets.UTF_8) + + "&redirect_uri=" + URLEncoder.encode(ApplicationProperties.casCallbackUrl(), StandardCharsets.UTF_8); + + + HttpResponse response = NetworkUtils.makeParameterizedPostRequest(config.tokenEndpoint, formData); + + return new Gson().fromJson(response.body(), TokenResponse.class); + + } + + private static void refreshConfig() throws InternalServerException, IOException, InterruptedException { + if (isExpired(configExpiration)){ + cacheBYUOpenIDConfig(); + } + } + + + public record TokenResponse( + @SerializedName("access_token") String accessToken, + @SerializedName("id_token") String idToken, + @SerializedName("refresh_token") String refreshToken, + @SerializedName("expires_in") int expiresIn, + @SerializedName("token_type") String tokenType + + ) {} + + /** + * Caches the info from the api needed to complete the OAuth transaction. + * @throws InternalServerException when there is something suspicious about the OpenID config + */ + private static void cacheBYUOpenIDConfig() throws InternalServerException, IOException, InterruptedException { + HttpResponse response = NetworkUtils.makeJsonGetRequest(BYU_API_URL + + "/.well-known/openid-configuration"); + + configExpiration = NetworkUtils.getCacheTime(response); + + OpenIDConfig config = new Gson().fromJson(response.body(), OpenIDConfig.class); + if (!isValidConfig(config)){ + throw new InternalServerException("Unable to verify OpenID config", null); + } + + AuthenticationService.config = config; + + } + + /** + * Grabs a set of JWKs from the endpoint specified in the config. These are public keys used to verify that + * JWTs received are in fact from BYU. The sandbox api should usually only have one at a time, but they + * can rotate the keys whenever, so we must be able to account for multiple. + */ + private static void cacheJWK () throws IOException, InterruptedException, InternalServerException { + + HttpResponse response = NetworkUtils.makeJsonGetRequest(config.keyUri); + + keyExpiration = NetworkUtils.getCacheTime(response); + + JwtUtils.readJWKs(response.body()); + + } + + /** + * Some of the fields delivered for the OpenID config. Call the endpoint yourself to see the full config. + * @param issuer - should be a byu api, and the specific API called + * @param authorizationEndpoint - where the browser should redirect the user on login + * @param tokenEndpoint - where the browser should confirm the redirect worked + * @param keyUri - where public keys to verify JWT tokens are received from + * @param scopes - only scope currently is openid + * @param encryptions - types of encryptions supported when signing the JWT tokens + */ + public record OpenIDConfig( + String issuer, + @SerializedName("authorization_endpoint") String authorizationEndpoint, + @SerializedName("token_endpoint") String tokenEndpoint, + @SerializedName("jwks_uri") String keyUri, + @SerializedName("scopes_supported") Collection scopes, + @SerializedName("id_token_signing_alg_values_supported")Collection encryptions + ){} + + /** + * Ensures the config came from the issuer, BYU API, and that any redirect links also are from the BYU API. + *

+ * Also logs any changes to the OpenID config that may need to be looked at. + * @param config an OpenID config + * @return true if valid, false if there's a glaring problem + */ + private static boolean isValidConfig(OpenIDConfig config){ + if (!config.issuer().equals(BYU_API_URL)){ + return false; + } + if (!config.equals(AuthenticationService.config) && AuthenticationService.config != null){ + LOGGER.info("OpenID config has changed: {}", config); + } + if (config.scopes().size()!= 1){ + LOGGER.warn("Config has multiple scopes: {}", config); + } + if (config.encryptions().size()!=1){ + LOGGER.warn("Config has multiple encryption types: {}", config); + } + return isValidUrl(config.authorizationEndpoint) && isValidUrl(config.tokenEndpoint()) && + isValidUrl(config.keyUri()); + } + + private static boolean isExpired(Instant time){ + return time.isBefore(Instant.now()); + } + + /** + * Validates that a URL uses HTTPS and has the same host as the BYU API URL. + * @param urlString the URL to validate + * @return true if the URL is valid, false otherwise + */ + private static boolean isValidUrl(String urlString) { + try { + URI uri = new URI(urlString); + URI baseUri = new URI(BYU_API_URL); + + // Verify HTTPS is used + if (!"https".equals(uri.getScheme())) { + return false; + } + + // Verify the host matches the base API URL's host + String host = uri.getHost(); + String expectedHost = baseUri.getHost(); + return host != null && host.equals(expectedHost); + + } catch (Exception e) { + return false; + } + } + + /** + * @return authorization url with parameters filled in + * @throws InternalServerException when unable to reload OpenID config + */ + public static String getAuthorizationUrl() throws InternalServerException{ + try { + if (isExpired(configExpiration)) { + cacheBYUOpenIDConfig(); + } + } catch (IOException | InterruptedException e){ + LOGGER.error("Unable to cache OpenID Config", e); + throw new InternalServerException("Unable to verify identity", e); + } + return AuthenticationService.config.authorizationEndpoint() + + "?response_type=code&client_id=" + URLEncoder.encode(ApplicationProperties.clientId(), StandardCharsets.UTF_8) + + "&redirect_uri=" + URLEncoder.encode(ApplicationProperties.casCallbackUrl(), StandardCharsets.UTF_8) + + "&scope=" + URLEncoder.encode("openid", StandardCharsets.UTF_8); + } + + /** + * Evaluates the frontend url and if it is secure returns true. + */ + public static boolean isSecure() { + return ApplicationProperties.frontendUrl().startsWith("https"); + } +} diff --git a/src/main/java/edu/byu/cs/service/CasService.java b/src/main/java/edu/byu/cs/service/CasService.java deleted file mode 100644 index 2a91fa598..000000000 --- a/src/main/java/edu/byu/cs/service/CasService.java +++ /dev/null @@ -1,114 +0,0 @@ -package edu.byu.cs.service; - -import com.fasterxml.jackson.dataformat.xml.XmlMapper; -import edu.byu.cs.canvas.CanvasException; -import edu.byu.cs.canvas.CanvasService; -import edu.byu.cs.controller.exception.BadRequestException; -import edu.byu.cs.controller.exception.InternalServerException; -import edu.byu.cs.dataAccess.DaoService; -import edu.byu.cs.dataAccess.DataAccessException; -import edu.byu.cs.dataAccess.daoInterface.UserDao; -import edu.byu.cs.model.User; -import edu.byu.cs.properties.ApplicationProperties; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.net.ssl.HttpsURLConnection; -import java.io.IOException; -import java.net.URI; -import java.util.Map; - -/** - * Contains service logic for the {@link edu.byu.cs.controller.CasController}.
View the - * Berkeley CAS docs - * to understand how CAS, or Central Authentication Service, works, if needed. - *

- * The {@code CasService} ensures user authentication using BYU's CAS before they access - * and use the AutoGrader. - */ -public class CasService { - private static final Logger LOGGER = LoggerFactory.getLogger(CasService.class); - public static final String BYU_CAS_URL = "https://cas.byu.edu/cas"; - - /** - * Validates a CAS ticket and retrieves the associated user. - *
- * If the user exists in the database, they are returned directly. Otherwise, the user - * is retrieved from Canvas and stored in the database before being returned - * - * @param ticket the CAS ticket to validate - * @return the user, either stored in the database or from Canvas if not - * @throws InternalServerException if an error arose during ticket validation or user retrieval - * @throws BadRequestException if ticket validation failed - * @throws DataAccessException if there was an issue storing the user in the database - * @throws CanvasException if there was an issue getting the user from Canvas - */ - public static User callback(String ticket) throws InternalServerException, BadRequestException, DataAccessException, CanvasException { - String netId; - try { - netId = CasService.validateCasTicket(ticket); - } catch (IOException e) { - LOGGER.error("Error validating ticket", e); - throw new InternalServerException("Error validating ticket", e); - } - - if (netId == null) { - throw new BadRequestException("Ticket validation failed"); - } - - UserDao userDao = DaoService.getUserDao(); - - User user; - // Check if student is already in the database - try { - user = userDao.getUser(netId); - } catch (DataAccessException e) { - LOGGER.error("Couldn't get user from database", e); - throw new InternalServerException("Couldn't get user from database", e); - } - - // If there isn't a student in the database with this netId - if (user == null) { - try { - user = CanvasService.getCanvasIntegration().getUser(netId); - } catch (CanvasException e) { - LOGGER.error("Error getting user from canvas", e); - throw e; - } - - userDao.insertUser(user); - LOGGER.info("Registered {}", user); - } - return user; - } - - /** - * Validates a CAS ticket and returns the netId of the user if valid
- * Berkeley CAS docs - * - * @param ticket the ticket to validate - * @return the netId of the user if valid, null otherwise - * @throws IOException if there is an error with the CAS server response - */ - public static String validateCasTicket(String ticket) throws IOException { - String validationUrl = BYU_CAS_URL + "/serviceValidate" + "?ticket=" + ticket + "&service=" + ApplicationProperties.casCallbackUrl(); - - - URI uri = URI.create(validationUrl); - HttpsURLConnection connection = (HttpsURLConnection) uri.toURL().openConnection(); - - try { - String body = new String(connection.getInputStream().readAllBytes()); - - Map casServiceResponse = XmlMapper.builder().build().readValue(body, Map.class); - return (String) ((Map) casServiceResponse.get("authenticationSuccess")).get("user"); - - } catch (Exception e) { - LOGGER.error("Error with response from CAS server:", e); - throw e; - } finally { - connection.disconnect(); - } - } - -} diff --git a/src/main/java/edu/byu/cs/util/JwtUtils.java b/src/main/java/edu/byu/cs/util/JwtUtils.java index 55afae8c2..33e01504a 100644 --- a/src/main/java/edu/byu/cs/util/JwtUtils.java +++ b/src/main/java/edu/byu/cs/util/JwtUtils.java @@ -1,11 +1,15 @@ package edu.byu.cs.util; -import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.*; +import io.jsonwebtoken.security.Jwk; +import io.jsonwebtoken.security.JwkSet; +import io.jsonwebtoken.security.Jwks; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.crypto.KeyGenerator; import javax.crypto.SecretKey; +import java.security.Key; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.time.Instant; @@ -20,6 +24,7 @@ */ public class JwtUtils { private static final SecretKey key = generateSecretKey(); + private static JwkSet byuPublicKeys; private static final Logger LOGGER = LoggerFactory.getLogger(JwtUtils.class); @@ -69,4 +74,31 @@ private static SecretKey generateSecretKey() { keyGenerator.init(512, new SecureRandom()); return keyGenerator.generateKey(); } + + public static String validateTokenAgainstKeys(String token){ + Locator locator = header -> { + for (Jwk key : byuPublicKeys) { + if (header instanceof ProtectedHeader protectedHeader) { + if (protectedHeader.getKeyId().equals(key.getId())) { + return key.toKey(); + } + } + } + return null; + }; + return Jwts.parser() + .keyLocator(locator) + .build() + .parseSignedClaims(token) + .getPayload() + .getSubject(); + } + + public static void readJWKs(String json){ + byuPublicKeys = Jwks.setParser() + .build() + .parse(json); + } + + } diff --git a/src/main/java/edu/byu/cs/util/NetworkUtils.java b/src/main/java/edu/byu/cs/util/NetworkUtils.java index d1d0e3cdf..2ed06b3fd 100644 --- a/src/main/java/edu/byu/cs/util/NetworkUtils.java +++ b/src/main/java/edu/byu/cs/util/NetworkUtils.java @@ -1,5 +1,6 @@ package edu.byu.cs.util; +import edu.byu.cs.controller.exception.InternalServerException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -8,6 +9,9 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.time.Instant; +import java.util.NoSuchElementException; +import java.util.Optional; /** * A utility class that provides methods for making HTTP Requests @@ -17,19 +21,42 @@ public class NetworkUtils { private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUtils.class); /** - * Leverages the built-in {@link java.net.http.HttpClient} library to make a basic HTTP Get request. + * Leverages the built-in {@link java.net.http.HttpClient} library to make an HTTP Get request. + * Requires a json response. * * @param url The URL to request * @return The {@link HttpResponse} response, or the errors generated in the process. */ - public static HttpResponse makeGetRequest(String url) throws IOException, InterruptedException { + public static HttpResponse makeJsonGetRequest(String url) throws IOException, InterruptedException { try (HttpClient httpClient = HttpClient.newHttpClient()) { HttpRequest request = HttpRequest.newBuilder() .uri(URI.create(url)) + .header("Accept", "application/json") .GET() // HTTP GET method .build(); - return httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + if (isFailure(response.statusCode())){ + LOGGER.warn("Error making GET request to '{}': {} status returned", url, response.statusCode()); + } + return response; + } + } + + public static HttpResponse makeParameterizedPostRequest(String url, String formData) throws IOException, InterruptedException { + try(HttpClient httpClient = HttpClient.newHttpClient()){ + HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(formData)) + .build(); + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + if (isFailure(response.statusCode())){ + LOGGER.warn("Error making POST request to '{}': {} status returned with form data: {}", + url, response.statusCode(), formData); + } + return response; } } @@ -41,7 +68,8 @@ public static HttpResponse makeGetRequest(String url) throws IOException */ public static String readGetRequestBody(String url) { try { - HttpResponse response = makeGetRequest(url); + HttpResponse response = makeJsonGetRequest(url); + return response.body(); } catch (IOException | InterruptedException e) { System.err.print("Error making GET request to '" + url + "': " + e.getMessage()); @@ -49,4 +77,29 @@ public static String readGetRequestBody(String url) { return null; } } + + private static boolean isFailure(int status){ + return status / 100 != 2; + } + + /** + * Grabs the Instant that the Cache-Control indicates expiration + * + * @param response http response with Cache-Control header + * @return Expire time of given response + * @throws InternalServerException if unable to find the Cache-Control header + */ + public static Instant getCacheTime(HttpResponse response) throws InternalServerException { + Optional cache = response.headers().firstValue("Cache-Control"); + try{ + String seconds = cache.get().replace("max-age=", ""); + if (Long.parseLong(seconds) > 0){ + return Instant.now().plusSeconds(Long.parseLong(seconds)); + } + else throw new InternalServerException("Invalid cache time", new IllegalArgumentException()); + } + catch (NoSuchElementException e) { + throw new InternalServerException("Unable to determine cache time", e); + } + } } diff --git a/src/test/java/edu/byu/cs/service/AuthenticationServiceTests.java b/src/test/java/edu/byu/cs/service/AuthenticationServiceTests.java new file mode 100644 index 000000000..f0a59761d --- /dev/null +++ b/src/test/java/edu/byu/cs/service/AuthenticationServiceTests.java @@ -0,0 +1,217 @@ +package edu.byu.cs.service; + +import com.google.gson.Gson; +import edu.byu.cs.controller.exception.BadRequestException; +import edu.byu.cs.controller.exception.InternalServerException; +import edu.byu.cs.dataAccess.DaoService; +import edu.byu.cs.dataAccess.daoInterface.UserDao; +import edu.byu.cs.model.User; +import edu.byu.cs.properties.ApplicationProperties; +import edu.byu.cs.util.JwtUtils; +import edu.byu.cs.util.NetworkUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; + +import java.io.IOException; +import java.net.URLEncoder; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +class AuthenticationServiceTests { + + private UserDao mockUserDao; + private User testUser; + + + @BeforeEach + void setUp() { + mockUserDao = mock(UserDao.class); + testUser = new User("test_netid", + 0, + "FirstName", + "LastName", + null, + User.Role.ADMIN); + } + + // ==================== callback() Tests ==================== + + //tests valid url and valid config + @Test + @DisplayName("callback should return user from database if user exists") + void callback_returnUserFromDatabase_whenUserExists() throws Exception { + try (MockedStatic daoServiceMock = mockStatic(DaoService.class); + MockedStatic jwtUtilsMock = mockStatic(JwtUtils.class); + MockedStatic networkUtilsMock = mockStatic(NetworkUtils.class)) { + + daoServiceMock.when(DaoService::getUserDao).thenReturn(mockUserDao); + when(mockUserDao.getUser("test_netid")).thenReturn(testUser); + jwtUtilsMock.when(() -> JwtUtils.validateTokenAgainstKeys("valid_token")).thenReturn("test_netid"); + + setupMockedValidOpenIDConfig(networkUtilsMock); + + User result = AuthenticationService.callback("valid_token"); + + assertEquals(testUser, result); + verify(mockUserDao).getUser("test_netid"); + verify(mockUserDao, never()).insertUser(any()); + } + } + + @Test + @DisplayName("callback should fetch user from Canvas and store in database if user doesn't exist") + void callback_fetchFromCanvasAndStore_whenUserNotInDatabase() throws Exception { + try (MockedStatic daoServiceMock = mockStatic(DaoService.class); + MockedStatic jwtUtilsMock = mockStatic(JwtUtils.class); + MockedStatic networkUtilsMock = mockStatic(NetworkUtils.class); + MockedStatic appPropsMock = mockStatic(ApplicationProperties.class)) { + + daoServiceMock.when(DaoService::getUserDao).thenReturn(mockUserDao); + when(mockUserDao.getUser("test_netid")).thenReturn(null); + jwtUtilsMock.when(() -> JwtUtils.validateTokenAgainstKeys("valid_token")).thenReturn("test_netid"); + + appPropsMock.when(ApplicationProperties::useCanvas).thenReturn(false); + + setupMockedValidOpenIDConfig(networkUtilsMock); + + User result = AuthenticationService.callback("valid_token"); + + assertEquals(testUser, result); + verify(mockUserDao).insertUser(testUser); + } + } + + @Test + @DisplayName("callback should throw BadRequestException when token validation returns null") + void callback_throwBadRequestException_whenTokenValidationFails() throws Exception { + try (MockedStatic jwtUtilsMock = mockStatic(JwtUtils.class); + MockedStatic networkUtilsMock = mockStatic(NetworkUtils.class)) { + + jwtUtilsMock.when(() -> JwtUtils.validateTokenAgainstKeys("invalid_token")).thenReturn(null); + setupMockedValidOpenIDConfig(networkUtilsMock); + + assertThrows(BadRequestException.class, () -> AuthenticationService.callback("invalid_token")); + } + } + + @Test + @DisplayName("callback should throw InternalErrorException when OpenID returned is suspicious") + void callback_throwInternalErrorException_whenOpenIDCacheLooksSuspicious() throws Exception { + try (MockedStatic daoServiceMock = mockStatic(DaoService.class); + MockedStatic jwtUtilsMock = mockStatic(JwtUtils.class); + MockedStatic networkUtilsMock = mockStatic(NetworkUtils.class); + MockedStatic appPropsMock = mockStatic(ApplicationProperties.class)) { + + daoServiceMock.when(DaoService::getUserDao).thenReturn(mockUserDao); + when(mockUserDao.getUser("test_netid")).thenReturn(null); + jwtUtilsMock.when(() -> JwtUtils.validateTokenAgainstKeys("valid_token")).thenReturn("test_netid"); + + appPropsMock.when(ApplicationProperties::useCanvas).thenReturn(false); + + setupSuspiciousValidOpenIDConfig(networkUtilsMock); + + Assertions.assertThrows(InternalServerException.class, + ()-> AuthenticationService.callback("valid_token")); + + verify(mockUserDao, times(0)).insertUser(testUser); + } + } + + // ==================== isSecure() Tests ==================== + + @Test + @DisplayName("isSecure should return true when frontend URL starts with https") + void isSecure_returnTrue_whenFrontendUrlIsHttps() { + try (MockedStatic appPropsMock = mockStatic(ApplicationProperties.class)) { + appPropsMock.when(ApplicationProperties::frontendUrl).thenReturn("https://example.com"); + + assertTrue(AuthenticationService.isSecure()); + } + } + + @Test + @DisplayName("isSecure should return false when frontend URL does not start with https") + void isSecure_returnFalse_whenFrontendUrlIsNotHttps() { + try (MockedStatic appPropsMock = mockStatic(ApplicationProperties.class)) { + appPropsMock.when(ApplicationProperties::frontendUrl).thenReturn("http://example.com"); + + assertFalse(AuthenticationService.isSecure()); + } + } + + // ==================== Authorization Url Test ============ + + @Test + @DisplayName("getAuthorizationUrl returns a filled out authorization request") + void getAuthUrl() throws Exception{ + try (MockedStatic networkUtilsMock = mockStatic(NetworkUtils.class); + MockedStatic appPropsMock = mockStatic(ApplicationProperties.class)) { + + + appPropsMock.when(ApplicationProperties::casCallbackUrl).thenReturn("https://cs240.click/auth/callback"); + appPropsMock.when(ApplicationProperties::clientId).thenReturn("cs240"); + setupMockedValidOpenIDConfig(networkUtilsMock); + + Assertions.assertEquals("https://api-sandbox.byu.edu/auth?response_type=code" + + "&client_id=cs240" + + "&redirect_uri=" + + URLEncoder.encode("https://cs240.click/auth/callback", StandardCharsets.UTF_8) + + "&scope=openid", + AuthenticationService.getAuthorizationUrl()); + + } + } + + // ==================== Helper Methods ==================== + + private void setupMockedValidOpenIDConfig(MockedStatic networkUtilsMock) throws IOException, InterruptedException { + HttpResponse mockResponse = mock(HttpResponse.class); + when(mockResponse.body()).thenReturn(createValidOpenIDConfigJson()); + networkUtilsMock.when(() -> NetworkUtils.makeJsonGetRequest(anyString())) + .thenReturn(mockResponse); + networkUtilsMock.when(() -> NetworkUtils.getCacheTime(any())) + .thenReturn(Instant.now().plusSeconds(3600)); + } + + private void setupSuspiciousValidOpenIDConfig(MockedStatic networkUtilsMock){ + HttpResponse mockResponse = mock(HttpResponse.class); + when(mockResponse.body()).thenReturn(createBadIssuerOpenIDConfigJson()); + networkUtilsMock.when(() -> NetworkUtils.makeJsonGetRequest(anyString())) + .thenReturn(mockResponse); + networkUtilsMock.when(() -> NetworkUtils.getCacheTime(any())) + .thenReturn(Instant.now().plusSeconds(0)); + } + + private String createValidOpenIDConfigJson() { + AuthenticationService.OpenIDConfig config = new AuthenticationService.OpenIDConfig( + AuthenticationService.BYU_API_URL, + "https://api-sandbox.byu.edu/auth", + "https://api-sandbox.byu.edu/token", + "https://api-sandbox.byu.edu/jwks", + List.of("openid"), + List.of("RS256") + ); + return new Gson().toJson(config); + } + + private String createBadIssuerOpenIDConfigJson(){ + AuthenticationService.OpenIDConfig config = new AuthenticationService.OpenIDConfig( + "https://badactor.com", + "https://api-sandbox.byu.edu/auth", + "https://api-sandbox.byu.edu/token", + "https://api-sandbox.byu.edu/jwks", + List.of("openid"), + List.of("RS256") + ); + return new Gson().toJson(config); + } +} diff --git a/src/test/java/edu/byu/cs/util/JwtUtilsTest.java b/src/test/java/edu/byu/cs/util/JwtUtilsTest.java index dd406cc42..082d08905 100644 --- a/src/test/java/edu/byu/cs/util/JwtUtilsTest.java +++ b/src/test/java/edu/byu/cs/util/JwtUtilsTest.java @@ -1,11 +1,21 @@ package edu.byu.cs.util; import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.jackson.io.JacksonSerializer; +import io.jsonwebtoken.security.*; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import java.nio.charset.StandardCharsets; +import java.security.*; +import java.security.KeyPair; +import io.jsonwebtoken.security.SignatureException; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; import static org.junit.jupiter.api.Assertions.*; @@ -41,6 +51,60 @@ void validateToken__expired() { assertNull(JwtUtils.validateToken(token)); } + @ParameterizedTest(name = "validateTokenAgainst{0}Keys") + @ValueSource(ints = {1, 2, 3}) + void validateTokenAgainstKeys(int size) throws Exception{ + HashMap map = generateKeyPairs(size); + JwkSet set = generateJwks(map); + + String token = generateToken(map.get(size-1).getPrivate(), size-1); + byte[] bytes = new JacksonSerializer().serialize(set); + String serialized = new String(bytes, StandardCharsets.UTF_8); + JwtUtils.readJWKs(serialized); + String netId = JwtUtils.validateTokenAgainstKeys(token); + assertEquals("testNetId", netId); + } + + @Test + void invalidTokenNotVerifiedByAnyKey() throws Exception{ + HashMap map = generateKeyPairs(3); + JwkSet set = generateJwks(map); + + //sign with a fake key + KeyPair fake = generateKeyPairs(1).get(0); + + String token = generateToken(fake.getPrivate(), 2); + byte[] bytes = new JacksonSerializer().serialize(set); + String serialized = new String(bytes, StandardCharsets.UTF_8); + JwtUtils.readJWKs(serialized); + assertThrows(SignatureException.class, ()-> JwtUtils.validateTokenAgainstKeys(token)); + } + + private HashMap generateKeyPairs(int size) throws NoSuchAlgorithmException { + KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA"); + generator.initialize(2048); + HashMap pairs = new HashMap<>(); + for (int i = 0; i < size; i++){ + KeyPair pair = generator.generateKeyPair(); + pairs.put(i, pair); + } + return pairs; + } + + private JwkSet generateJwks(HashMap pairs) { + HashSet> set = new HashSet<>(); + for (int i = 0; i < pairs.size(); i++){ + KeyPair pair = pairs.get(i); + Jwk jwk = Jwks.builder() + .key(pair.getPublic()) + .id(Integer.toString(i)) + .build(); + set.add(jwk); + } + JwkSet jwks = Jwks.set().add(set).build(); + return jwks; + } + private String generateToken(boolean expired) { Instant expiration = expired ? Instant.now().minus(1, ChronoUnit.HOURS) @@ -50,4 +114,15 @@ private String generateToken(boolean expired) { .expiration(Date.from(expiration)) .compact(); } + + private String generateToken(PrivateKey key, int id){ + return Jwts.builder() + .header() + .keyId(Integer.toString(id)) + .and() + .subject("testNetId") + .signWith(key) + .compact(); + } + }