Skip to content

Commit

Permalink
Parse response with Play Json. Fix error handling. Code factoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
waisingyiu committed Apr 10, 2024
1 parent 58f6762 commit e29c79c
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ class AndroidSender(val config: FcmWorkerConfiguration, val firebaseAppName: Opt
override def deliverChunkedTokens(chunkedTokenStream: Stream[IO, (ChunkedTokens, Long, Instant, Int)]): Stream[IO, Unit] = {
chunkedTokenStream.map {
case (chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize) =>
logger.info(Map("notificationId" -> chunkedTokens.notification.id), s"Sending notification ${chunkedTokens.notification.id} in batches")
deliverBatchNotificationStream(Stream.emits(chunkedTokens.toBatchNotificationToSends).covary[IO])
.broadcastTo(
reportBatchSuccesses(chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize),
reportBatchLatency(chunkedTokens, chunkedTokens.metadata),
cleanupBatchFailures(chunkedTokens.notification.id),
trackBatchProgress(chunkedTokens.notification.id))
if (false) {
logger.info(Map("notificationId" -> chunkedTokens.notification.id), s"Sending notification ${chunkedTokens.notification.id} in batches")
deliverBatchNotificationStream(Stream.emits(chunkedTokens.toBatchNotificationToSends).covary[IO])
.broadcastTo(
reportBatchSuccesses(chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize),
reportBatchLatency(chunkedTokens, chunkedTokens.metadata),
cleanupBatchFailures(chunkedTokens.notification.id),
trackBatchProgress(chunkedTokens.notification.id))
}
else
deliverIndividualNotificationStream(Stream.emits(chunkedTokens.toNotificationToSends).covary[IO])
.broadcastTo(
reportSuccesses(chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize),
cleanupFailures,
trackProgress(chunkedTokens.notification.id))
}.parJoin(maxConcurrency)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.gu.notifications.worker.delivery.fcm
import _root_.models.Notification
import com.google.api.core.{ApiFuture, ApiFutureCallback, ApiFutures}
import com.google.api.client.json.JsonFactory
import com.google.auth.oauth2.GoogleCredentials
import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials}
import com.google.firebase.messaging._
import com.google.firebase.{ErrorCode, FirebaseApp, FirebaseOptions}
import com.gu.notifications.worker.delivery.DeliveryException.{BatchCallFailedRequest, FailedRequest, InvalidToken, UnknownReasonFailedRequest}
Expand Down Expand Up @@ -42,15 +42,11 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
ErrorCode.PERMISSION_DENIED
)

// private val requestFactory = ApiClientUtils.newAuthorizedRequestFactory(app)
// private val childRequestFactory = ApiClientUtils.newUnauthorizedRequestFactory(app)
// private val jsonFactory = app.getOptions().getJsonFactory()

def close(): Unit = firebaseApp.delete()

private final val FCM_URL: String = s"https://fcm.googleapis.com/v1/projects/${projectId}/messages:send";

private val fcmClient: FcmTransportMultiplexedHttp2Impl = new FcmTransportMultiplexedHttp2Impl(credential, FCM_URL, jsonFactory)
private val fcmTransport: FcmTransportOkhttpImpl = new FcmTransportOkhttpImpl(credential, FCM_URL, jsonFactory)

def payloadBuilder: Notification => Option[FcmPayload] = n => FcmPayloadBuilder(n, config.debug)

Expand All @@ -76,7 +72,7 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
} else {
import FirebaseHelpers._
val start = Instant.now
fcmClient.sendAsync(token, payload, dryRun)
fcmTransport.sendAsync(token, payload, dryRun)
// firebaseMessaging
// .sendAsync(message)
// .asScala
Expand Down Expand Up @@ -124,7 +120,7 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
}
}

def parseSendResponse(
def parseSendResponseOld(
notificationId: UUID, token: String, response: Try[String], requestCompletionTime: Instant
): Either[DeliveryException, Success] = response match {
case Success(messageId) =>
Expand All @@ -139,6 +135,27 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
Left(UnknownReasonFailedRequest(notificationId, token))
}

def parseSendResponse(
notificationId: UUID, token: String, response: Try[String], requestCompletionTime: Instant
): Either[DeliveryException, Success] = response match {
case Success(messageId) =>
Right(FcmDeliverySuccess(token, messageId, requestCompletionTime))
case Failure(e: InvalidTokenException) =>
Left(InvalidToken(notificationId, token, e.getMessage()))
case Failure(e: FcmServerException) =>
Left(FailedRequest(notificationId, token, e, Option(e.details.status)))
case Failure(e: UnknownException) =>
Left(FailedRequest(notificationId, token, e, Option(e.details.status)))
case Failure(e: InvalidResponseException) =>
Left(FailedRequest(notificationId, token, e, None))
case Failure(e: QuotaExceededException) =>
Left(FailedRequest(notificationId, token, e, None))
case Failure(e: FcmServerTransportException) =>
Left(FailedRequest(notificationId, token, e, None))
case Failure(_) =>
Left(UnknownReasonFailedRequest(notificationId, token))
}

def parseBatchSendResponse(
notificationId: UUID, tokens: List[String], triedResponse: Try[BatchResponse], requestCompletionTime: Instant
)(cb: Either[DeliveryException, BatchSuccess] => Unit): Unit = triedResponse match {
Expand Down Expand Up @@ -175,7 +192,11 @@ object FcmClient {
case None => FirebaseApp.initializeApp(firebaseOptions)
case Some(name) => FirebaseApp.initializeApp(firebaseOptions, name)
}
new FcmClient(FirebaseMessaging.getInstance(firebaseApp), firebaseApp, config, firebaseOptions.getProjectId(), credential, firebaseOptions.getJsonFactory())
val projectId = credential match {
case s: ServiceAccountCredentials => s.getProjectId()
case _ => ""
}
new FcmClient(FirebaseMessaging.getInstance(firebaseApp), firebaseApp, config, projectId, credential, firebaseOptions.getJsonFactory())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,53 @@ import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit
import scala.jdk.CollectionConverters._
import scala.concurrent.{Future, Promise}
import scala.util.{Try, Failure}
import scala.util.{Try, Success, Failure}
import scala.util.Random
import okhttp3.{Call, Callback, ConnectionPool, Dns, Headers, MediaType, OkHttpClient, Request, RequestBody, Response, ResponseBody}
import play.api.libs.json.{Format, Json, JsError, JsValue, JsSuccess}
import com.google.auth.oauth2.GoogleCredentials
import com.google.api.client.json.{Json, JsonFactory, JsonParser, JsonGenerator}
import com.google.api.client.json.{JsonFactory, JsonGenerator}
import com.google.firebase.messaging._
import com.google.api.client.util.Key
import com.gu.notifications.worker.delivery.FcmPayload
import com.gu.notifications.worker.delivery.fcm.models.payload.{FcmResponse, FcmError, FcmErrorPayload}

case class FcmResponse(@Key("name") messageId: String)

case class FcmErrorPayload(@Key("code") code: Int, @Key("message") message: String, @Key("status") status: String)

case class FcmError(@Key("error") payload: FcmErrorPayload)

case class InvalidResponseException(message: String, responseBody: String, err: Throwable) extends Exception(message, err)
case class InvalidResponseException(responseBody: String) extends Exception("Invalid success response") {
override def getMessage(): String = {
s"${super.getMessage()}. Response body: [${responseBody.take(200)}]"
}
}

case class QuotaExceededException(message: String, details: FcmErrorPayload) extends Exception(message)
case class QuotaExceededException(details: FcmErrorPayload) extends Exception("Request quota exceeded") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class InvalidTokenException(message: String, details: FcmErrorPayload) extends Exception(message)
case class InvalidTokenException(details: FcmErrorPayload) extends Exception("Invalid device token") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class FcmServerException(message: String, details: FcmErrorPayload) extends Exception(message)
case class FcmServerException(details: FcmErrorPayload) extends Exception("FCM server error") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class UnknownException(message: String, details: FcmErrorPayload) extends Exception(message)
case class UnknownException(details: FcmErrorPayload) extends Exception("Unexpected exception") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class FcmServerTransportException(message: String, ex: Throwable) extends Exception(message, ex)
case class FcmServerTransportException(ex: Throwable) extends Exception("Failed to send HTTP request", ex) {
override def getMessage(): String = {
s"${super.getMessage()}. Reason: ${ex.getMessage()}]"
}
}

class FcmTransportMultiplexedHttp2Impl(credential: GoogleCredentials, url: String, jsonFactory: JsonFactory) {
class FcmTransportOkhttpImpl(credential: GoogleCredentials, url: String, jsonFactory: JsonFactory) {

private val okHttpClient: OkHttpClient = new OkHttpClient.Builder()
.dns((hostname: String) => Random.shuffle(Dns.SYSTEM.lookup(hostname).asScala).asJava)
Expand All @@ -45,7 +64,11 @@ class FcmTransportMultiplexedHttp2Impl(credential: GoogleCredentials, url: Strin

private val charSet = StandardCharsets.UTF_8

private val mediaType = MediaType.parse(Json.MEDIA_TYPE)
private val mediaType = MediaType.parse("application/json; charset=UTF-8")

private val authFcmScope = "https://www.googleapis.com/auth/firebase.messaging"

private val scopedCredential = credential.createScoped(authFcmScope)

def shutdown: Unit = {
okHttpClient.dispatcher().executorService().shutdown()
Expand All @@ -54,8 +77,8 @@ class FcmTransportMultiplexedHttp2Impl(credential: GoogleCredentials, url: Strin
}

private def getAccessToken(): String = {
credential.refreshIfExpired()
credential.getAccessToken().getTokenValue()
scopedCredential.refreshIfExpired()
scopedCredential.getAccessToken().getTokenValue()
}

private def createBody(message: Message, dryRun: Boolean): Array[Byte] = {
Expand All @@ -71,18 +94,20 @@ class FcmTransportMultiplexedHttp2Impl(credential: GoogleCredentials, url: Strin
sink.toByteArray()
}

private def parseBody(responseBody: ResponseBody): Try[FcmResponse] = Try {
val parser: JsonParser = jsonFactory.createJsonParser(responseBody.byteStream())
parser.parseAndClose(FcmResponse.getClass()).asInstanceOf[FcmResponse]
}.recoverWith {
case e => Failure(InvalidResponseException("Invalid response for success", responseBody.string(), e))
private def parseBody(responseBody: ResponseBody): Try[FcmResponse] = {
val json: JsValue = Json.parse(responseBody.string())
json.validate[FcmResponse] match {
case JsSuccess(message, _) => Success(message)
case JsError(errors) => Failure(InvalidResponseException(responseBody.string()))
}
}

private def parseError(responseBody: ResponseBody): Try[FcmError] = Try {
val parser: JsonParser = jsonFactory.createJsonParser(responseBody.byteStream())
parser.parseAndClose(FcmError.getClass()).asInstanceOf[FcmError]
}.recoverWith {
case e => Failure(InvalidResponseException("Invalid response for failed", responseBody.string(), e))
private def parseError(responseBody: ResponseBody): Try[FcmError] = {
val json: JsValue = Json.parse(responseBody.string())
json.validate[FcmError] match {
case JsSuccess(message, _) => Success(message)
case JsError(errors) => Failure(InvalidResponseException(responseBody.string()))
}
}

val invalidTokenErrorCodes = Set(
Expand All @@ -102,11 +127,11 @@ class FcmTransportMultiplexedHttp2Impl(credential: GoogleCredentials, url: Strin
if (response.code == 200)
parseBody(response.body())
else
parseError(response.body()).flatMap(fcmError => fcmError.payload.status match {
case code if invalidTokenErrorCodes.contains(code) => Failure(InvalidTokenException("Invalid device token", fcmError.payload))
case code if internalServerErrorCodes.contains(code) => Failure(FcmServerException("FCM server error", fcmError.payload))
case code if quotaExceededErrorCodes.contains(code) => Failure(QuotaExceededException("Request quota exceeded", fcmError.payload))
case _ => Failure(UnknownException("Unexpected exception", fcmError.payload))
parseError(response.body()).flatMap(fcmError => fcmError.error.status match {
case code if invalidTokenErrorCodes.contains(code) => Failure(InvalidTokenException(fcmError.error))
case code if internalServerErrorCodes.contains(code) => Failure(FcmServerException(fcmError.error))
case code if quotaExceededErrorCodes.contains(code) => Failure(QuotaExceededException(fcmError.error))
case _ => Failure(UnknownException(fcmError.error))
})
}

Expand All @@ -124,15 +149,16 @@ class FcmTransportMultiplexedHttp2Impl(credential: GoogleCredentials, url: Strin
.build();
val p = Promise[String]()
okHttpClient.newCall(request).enqueue(new Callback() {
def onFailure(call: Call, ex: IOException): Unit = (
p.failure(FcmServerTransportException(s"Failed to send HTTP request - ${ex.getMessage}", ex))
)
def onResponse(call: Call, response: Response): Unit = (
def onFailure(call: Call, ex: IOException): Unit = {
ex.printStackTrace()
p.failure(FcmServerTransportException(ex))
}
def onResponse(call: Call, response: Response): Unit = {
handleResponse(response).fold(
ex => p.failure(ex),
response => p.success(response.messageId)
response => p.success(response.name)
)
)
}
})
p.future
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.gu.notifications.worker.delivery.fcm.models.payload

import play.api.libs.json.{Format, Json, JsError, JsValue, JsSuccess}

case class FcmResponse(name: String)

object FcmResponse {
implicit val fcmResponseJf: Format[FcmResponse] = Json.format[FcmResponse]
}

case class FcmErrorPayload(code: Int, message: String, status: String) {
override def toString() = s"Code [$code] Status [$status] - $message"
}

object FcmErrorPayload {
implicit val fcmErrorPayloadJf: Format[FcmErrorPayload] = Json.format[FcmErrorPayload]
}

case class FcmError(error: FcmErrorPayload)

object FcmError {
implicit val fcmErrorJf: Format[FcmError] = Json.format[FcmError]
}

0 comments on commit e29c79c

Please sign in to comment.