diff --git a/play-v29/src/main/scala/com/gu/googleauth/actions.scala b/play-v29/src/main/scala/com/gu/googleauth/actions.scala index 8989a5a..6196807 100644 --- a/play-v29/src/main/scala/com/gu/googleauth/actions.scala +++ b/play-v29/src/main/scala/com/gu/googleauth/actions.scala @@ -89,6 +89,8 @@ class AuthAction[A](val authConfig: GoogleAuthConfig, loginTarget: Call, bodyPar } trait LoginSupport extends Logging { + import Actions.GroupCheckConfig + implicit def wsClient: WSClient /** @@ -154,7 +156,7 @@ trait LoginSupport extends Logging { * Looks up user's Google Groups and ensures they belong to any that are required. Redirects to * `failureRedirectTarget` if the user is not a member of any required group. */ - def enforceGoogleGroups(userIdentity: UserIdentity, requiredGoogleGroups: Set[String], googleGroupChecker: GoogleGroupChecker, errorMessage: String = "Login failure. You do not belong to the required Google groups") + def enforceGoogleGroups(userIdentity: UserIdentity, groupCheckConfig: GroupCheckConfig, googleGroupChecker: GoogleGroupChecker, errorMessage: String = "Login failure. You do not belong to the required Google groups") (implicit request: RequestHeader, ec: ExecutionContext): EitherT[Future, Result, Unit] = { googleGroupChecker.retrieveGroupsFor(userIdentity.email).attemptT .leftMap { t => @@ -162,7 +164,7 @@ trait LoginSupport extends Logging { redirectWithError(failureRedirectTarget, "Login failure. Unable to look up Google Group membership") } .subflatMap { userGroups => - if (Actions.checkGoogleGroups(userGroups, requiredGoogleGroups)) { + if (Actions.checkGoogleGroups(userGroups, groupCheckConfig)) { Right(()) } else { logger.info("Login failure, user not in required Google groups") @@ -187,11 +189,22 @@ trait LoginSupport extends Logging { * * Also ensures the user belongs to the (provided) required Google Groups. */ + @deprecated("Prefer to pass in a GroupCheckConfig object instead.") def processOauth2Callback(requiredGoogleGroups: Set[String], groupChecker: GoogleGroupChecker) + (implicit request: RequestHeader, ec: ExecutionContext): Future[Result] = { + val groupCheckConfig = GroupCheckConfig(requiredGroups = Some(requiredGoogleGroups)) + processOauth2Callback(groupCheckConfig, groupChecker) + } + /** + * Handle the OAuth2 callback, which logs the user in and redirects them appropriately. + * + * Also ensures the user is in the correct Google Groups, as defined by the given GroupCheckConfig + */ + def processOauth2Callback(groupCheckConfig: GroupCheckConfig, groupChecker: GoogleGroupChecker) (implicit request: RequestHeader, ec: ExecutionContext): Future[Result] = { (for { identity <- checkIdentity() - _ <- enforceGoogleGroups(identity, requiredGoogleGroups, groupChecker) + _ <- enforceGoogleGroups(identity, groupCheckConfig, groupChecker) } yield { setupSessionWhenSuccessful(identity) }).merge @@ -216,9 +229,30 @@ trait LoginSupport extends Logging { } object Actions { - private[googleauth] def checkGoogleGroups(userGroups: Set[String], requiredGroups: Set[String]): Boolean = { + /** + * @param requiredGroups If defined, user must be a member of *all* groups in requiredGroups + * @param allowedGroups If defined, user must be a member of *at least one of* the groups in allowedGroups + */ + case class GroupCheckConfig( + requiredGroups: Option[Set[String]] = None, + allowedGroups: Option[Set[String]] = None + ) + + private[googleauth] def checkGoogleGroups(userGroups: Set[String], groupCheckConfig: GroupCheckConfig): Boolean = { + val requiredGroupCheck = groupCheckConfig.requiredGroups.map(required => Actions.checkRequiredGoogleGroups(userGroups, required)).getOrElse(true) + val allowedGroupCheck = groupCheckConfig.allowedGroups.map(allowed => Actions.checkAllowedGoogleGroups(userGroups, allowed)).getOrElse(true) + requiredGroupCheck && allowedGroupCheck + } + + // User must be a member of *all* groups in requiredGroups + private def checkRequiredGoogleGroups(userGroups: Set[String], requiredGroups: Set[String]): Boolean = { userGroups.intersect(requiredGroups) == requiredGroups } + + // User must be a member of *at least one* of the groups in allowedGroups + private def checkAllowedGoogleGroups(userGroups: Set[String], allowedGroups: Set[String]): Boolean = { + allowedGroups.intersect(userGroups).nonEmpty + } } trait Filters extends UserIdentifier with Logging { diff --git a/play-v29/src/sbt-test/example/webapp/app/controllers/Login.scala b/play-v29/src/sbt-test/example/webapp/app/controllers/Login.scala index 9cf297d..b6a72cd 100644 --- a/play-v29/src/sbt-test/example/webapp/app/controllers/Login.scala +++ b/play-v29/src/sbt-test/example/webapp/app/controllers/Login.scala @@ -1,6 +1,7 @@ package controllers import com.gu.googleauth.{GoogleAuthConfig, GoogleGroupChecker, LoginSupport} +import com.gu.googleauth.Actions.GroupCheckConfig import play.api.libs.ws.WSClient import play.api.mvc._ @@ -34,7 +35,7 @@ class Login(requiredGoogleGroups: Set[String], val authConfig: GoogleAuthConfig, */ def oauth2Callback = Action.async { implicit request => // processOauth2Callback() // without Google group membership checks - processOauth2Callback(requiredGoogleGroups, googleGroupChecker) // with optional Google group checks + processOauth2Callback(GroupCheckConfig(requiredGroups = Some(requiredGoogleGroups)), googleGroupChecker) // with optional Google group checks } def logout = Action { implicit request => diff --git a/play-v29/src/test/scala/com/gu/googleauth/GoogleAuthTest.scala b/play-v29/src/test/scala/com/gu/googleauth/GoogleAuthTest.scala index 73aa798..69c7685 100644 --- a/play-v29/src/test/scala/com/gu/googleauth/GoogleAuthTest.scala +++ b/play-v29/src/test/scala/com/gu/googleauth/GoogleAuthTest.scala @@ -2,6 +2,7 @@ package com.gu.googleauth import com.gu.play.secretrotation.DualSecretTransition.TransitioningSecret import com.gu.play.secretrotation.SnapshotProvider +import com.gu.googleauth.Actions.GroupCheckConfig import org.apache.commons.codec.binary.Base64 import org.scalatest.freespec.AsyncFreeSpec import org.scalatest.matchers.should.Matchers @@ -15,30 +16,65 @@ import scala.concurrent.ExecutionContext.global class GoogleAuthTest extends AsyncFreeSpec with Matchers { - "enforceUserGroups" - { + "requiredGroups check" - { val requiredGroups = Set("required-group-1", "required-group-2") "returns false if the user has no groups" in { val userGroups = Set.empty[String] - val result = Actions.checkGoogleGroups(userGroups, requiredGroups) + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(requiredGroups = Some(requiredGroups))) result shouldEqual false } "returns false if the user is missing a group" in { val userGroups = Set(requiredGroups.head) - val result = Actions.checkGoogleGroups(userGroups, requiredGroups) + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(requiredGroups = Some(requiredGroups))) result shouldEqual false } "returns false if the user has other groups but is missing a required group" in { val userGroups = Set(requiredGroups.head, "example-group", "another-group") - val result = Actions.checkGoogleGroups(userGroups, requiredGroups) + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(requiredGroups = Some(requiredGroups))) result shouldEqual false } "returns true if the required groups are present" in { val userGroups = requiredGroups + "example-group" - val result = Actions.checkGoogleGroups(userGroups, requiredGroups) + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(requiredGroups = Some(requiredGroups))) + result shouldEqual true + } + } + + "allowedGroups check" - { + val requiredGroups = Set("required-group-1", "required-group-2") + val allowedGroups = Set("allowed-group-1", "allowed-group-2") + + "returns false if the user has no groups" in { + val userGroups = Set.empty[String] + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(allowedGroups = Some(allowedGroups))) + result shouldEqual false + } + + "returns false if the user has other groups but is missing all allowed groups" in { + val userGroups = Set("example-group", "another-group") + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(allowedGroups = Some(allowedGroups))) + result shouldEqual false + } + + "returns true if the user has one group and is missing another group" in { + val userGroups = Set(allowedGroups.head) + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(allowedGroups = Some(allowedGroups))) + result shouldEqual true + } + + "returns false if the user has an allowed group but is missing required groups" in { + val userGroups = Set(allowedGroups.head) + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(allowedGroups = Some(allowedGroups), requiredGroups = Some(requiredGroups))) + result shouldEqual false + } + + "returns true if the user has an allowed group and has required groups" in { + val userGroups = Set(allowedGroups.head) ++ requiredGroups + val result = Actions.checkGoogleGroups(userGroups, GroupCheckConfig(allowedGroups = Some(allowedGroups), requiredGroups = Some(requiredGroups))) result shouldEqual true } }