Skip to content

Commit

Permalink
Merge pull request #1216 from guardian/LIVE-6383-implement-batch-send
Browse files Browse the repository at this point in the history
LIVE-6383 send requests to Firebase API directly
  • Loading branch information
waisingyiu authored Apr 15, 2024
2 parents 0659bd6 + e937b6b commit 8d86dd7
Show file tree
Hide file tree
Showing 12 changed files with 354 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.gu.notifications.worker

import _root_.models.NotificationMetadata
import _root_.models.Topic
import cats.effect.{ContextShift, IO, Timer}
import com.gu.notifications.worker.cleaning.CleaningClientImpl
import com.gu.notifications.worker.delivery.DeliveryException.InvalidToken
Expand Down Expand Up @@ -30,32 +31,43 @@ class AndroidSender(val config: FcmWorkerConfiguration, val firebaseAppName: Opt
override implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(config.threadPoolSize))

logger.info(s"Using thread pool size: ${config.threadPoolSize}")
logger.info(s"Topics for individual send: ${config.allowedTopicsForIndividualSend.mkString(",")}")
logger.info(s"Concurrency for individual send: ${config.concurrencyForIndividualSend}")

override implicit val ioContextShift: ContextShift[IO] = IO.contextShift(ec)
override implicit val timer: Timer[IO] = IO.timer(ec)

override val deliveryService: IO[Fcm[IO]] =
FcmClient(config.fcmConfig, firebaseAppName).fold(e => IO.raiseError(e), c => IO.delay(new Fcm(c)))
override val maxConcurrency = 100

override val maxConcurrency = config.concurrencyForIndividualSend
override val batchConcurrency = 100

//override the deliverChunkedTokens method to validate the success of sending batch notifications to the FCM client. This implementation could be refactored in the future to make it more streamlined with APNs
override def deliverChunkedTokens(chunkedTokenStream: Stream[IO, (ChunkedTokens, Long, Instant, Int)]): Stream[IO, Unit] = {
chunkedTokenStream.map {
case (chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize) =>
logger.info(Map("notificationId" -> chunkedTokens.notification.id), s"Sending notification ${chunkedTokens.notification.id} in batches")
deliverBatchNotificationStream(Stream.emits(chunkedTokens.toBatchNotificationToSends).covary[IO])
.broadcastTo(
reportBatchSuccesses(chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize),
reportBatchLatency(chunkedTokens, chunkedTokens.metadata),
cleanupBatchFailures(chunkedTokens.notification.id),
trackBatchProgress(chunkedTokens.notification.id))
}.parJoin(maxConcurrency)
if (config.isIndividualSend(chunkedTokens.notification.topic.map(_.toString())))
deliverIndividualNotificationStream(Stream.emits(chunkedTokens.toNotificationToSends).covary[IO])
.broadcastTo(
reportSuccesses(chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize),
cleanupFailures,
trackProgress(chunkedTokens.notification.id))
else {
logger.info(Map("notificationId" -> chunkedTokens.notification.id), s"Sending notification ${chunkedTokens.notification.id} in batches")
deliverBatchNotificationStream(Stream.emits(chunkedTokens.toBatchNotificationToSends).covary[IO])
.broadcastTo(
reportBatchSuccesses(chunkedTokens, sentTime, functionStartTime, sqsMessageBatchSize),
reportBatchLatency(chunkedTokens, chunkedTokens.metadata),
cleanupBatchFailures(chunkedTokens.notification.id),
trackBatchProgress(chunkedTokens.notification.id))
}
}.parJoin(batchConcurrency)
}

def deliverBatchNotificationStream[C <: FcmClient](batchNotificationStream: Stream[IO, BatchNotification]): Stream[IO, Either[DeliveryException, C#BatchSuccess]] = for {
deliveryService <- Stream.eval(deliveryService)
resp <- batchNotificationStream.map(batchNotification => deliveryService.sendBatch(batchNotification.notification, batchNotification.token))
.parJoin(maxConcurrency)
.parJoin(batchConcurrency)
.evalTap(Reporting.logBatch(s"Sending failure: "))
} yield resp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ case class FcmWorkerConfiguration(
cleaningSqsUrl: String,
fcmConfig: FcmConfig,
threadPoolSize: Int,
allowedTopicsForBatchSend: List[String],
) extends WorkerConfiguration
allowedTopicsForIndividualSend: List[String],
concurrencyForIndividualSend: Int
) extends WorkerConfiguration {
def isIndividualSend(topics: List[String]): Boolean =
topics.forall(topic => allowedTopicsForIndividualSend.exists(topic.startsWith(_)))
}

case class CleanerConfiguration(jdbcConfig: JdbcConfig)

Expand Down Expand Up @@ -112,8 +116,17 @@ object Configuration {
def fetchFirebase(): FcmWorkerConfiguration = {
val config = fetchConfiguration(confPrefixFromPlatform)

def getStringList(path: String): List[String] =
config.getString(path).split(",").toList
def getStringList(path: String): List[String] =
if (config.hasPath(path))
config.getString(path).split(",").toList
else
List()

def getOptionalInt(path: String, defVal: Int): Int =
if (config.hasPath(path))
config.getInt(path)
else
defVal

FcmWorkerConfiguration(
config.getString("cleaningSqsUrl"),
Expand All @@ -123,7 +136,8 @@ object Configuration {
dryRun = config.getBoolean("dryrun")
),
config.getInt("fcm.threadPoolSize"),
getStringList("fcm.allowedTopicsForBatchSend")
getStringList("fcm.allowedTopicsForIndividualSend"),
getOptionalInt("fcm.concurrencyForIndividualSend", 100)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ class IOSSender(val config: ApnsWorkerConfiguration, val metricNs: String) exten
override val deliveryService: IO[Apns[IO]] =
ApnsClient(config.apnsConfig).fold(e => IO.raiseError(e), c => IO.delay(new Apns(c)))
override val maxConcurrency = config.apnsConfig.maxConcurrency

override val batchConcurrency = config.apnsConfig.maxConcurrency
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ object NotificationWorkerLocalRun extends App {
),
imageUrl = None,
importance = Major,
topic = List(Topic(Breaking, "uk"), Topic(Breaking, "us"), Topic(Breaking, "au"), Topic(Breaking, "international"), Topic(Breaking, "europe")),
topic = List(Topic(Breaking, "international")),
// topic = List(Topic(Breaking, "uk"), Topic(Breaking, "us"), Topic(Breaking, "au"), Topic(Breaking, "international"), Topic(Breaking, "europe")),
dryRun = None
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ trait SenderRequestHandler[C <: DeliveryClient] extends Logging {
val cleaningClient: CleaningClient
val cloudwatch: Cloudwatch
val maxConcurrency: Int
val batchConcurrency: Int

def env = Env()

Expand Down Expand Up @@ -88,7 +89,7 @@ trait SenderRequestHandler[C <: DeliveryClient] extends Logging {
reportLatency(chunkedTokens, chunkedTokens.metadata),
cleanupFailures,
trackProgress(chunkedTokens.notification.id))
}.parJoin(maxConcurrency)
}.parJoin(batchConcurrency)
}

def handleChunkTokens(event: SQSEvent, context: Context): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package com.gu.notifications.worker.delivery.fcm

import _root_.models.Notification
import com.google.api.core.{ApiFuture, ApiFutureCallback, ApiFutures}
import com.google.auth.oauth2.GoogleCredentials
import com.google.api.client.json.JsonFactory
import com.google.auth.oauth2.{GoogleCredentials, ServiceAccountCredentials}
import com.google.firebase.messaging._
import com.google.firebase.{ErrorCode, FirebaseApp, FirebaseOptions}
import com.gu.notifications.worker.delivery.DeliveryException.{BatchCallFailedRequest, FailedRequest, InvalidToken, UnknownReasonFailedRequest}
Expand All @@ -22,7 +23,9 @@ import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp, config: FcmConfig)
import okhttp3.{Headers, MediaType, OkHttpClient, Request, RequestBody, Response, ResponseBody}

class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp, config: FcmConfig, projectId: String, credential: GoogleCredentials, jsonFactory: JsonFactory)
extends DeliveryClient with Logging {

type Success = FcmDeliverySuccess
Expand All @@ -41,6 +44,10 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,

def close(): Unit = firebaseApp.delete()

private final val FCM_URL: String = s"https://fcm.googleapis.com/v1/projects/${projectId}/messages:send";

private val fcmTransport: FcmTransport = new FcmTransportJdkImpl(credential, FCM_URL, jsonFactory)

def payloadBuilder: Notification => Option[FcmPayload] = n => FcmPayloadBuilder(n, config.debug)

def isUnregistered(e: FirebaseMessagingException): Boolean = {
Expand All @@ -65,9 +72,7 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
} else {
import FirebaseHelpers._
val start = Instant.now
firebaseMessaging
.sendAsync(message)
.asScala
fcmTransport.sendAsync(token, payload, dryRun)
.onComplete { response =>
val requestCompletionTime = Instant.now
logger.info(Map(
Expand Down Expand Up @@ -112,7 +117,7 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
}
}

def parseSendResponse(
def parseFirebaseSdkSendResponse(
notificationId: UUID, token: String, response: Try[String], requestCompletionTime: Instant
): Either[DeliveryException, Success] = response match {
case Success(messageId) =>
Expand All @@ -127,6 +132,27 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
Left(UnknownReasonFailedRequest(notificationId, token))
}

def parseSendResponse(
notificationId: UUID, token: String, response: Try[String], requestCompletionTime: Instant
): Either[DeliveryException, Success] = response match {
case Success(messageId) =>
Right(FcmDeliverySuccess(token, messageId, requestCompletionTime))
case Failure(e: InvalidTokenException) =>
Left(InvalidToken(notificationId, token, e.getMessage()))
case Failure(e: FcmServerException) =>
Left(FailedRequest(notificationId, token, e, Option(e.details.status)))
case Failure(e: UnknownException) =>
Left(FailedRequest(notificationId, token, e, Option(e.details.status)))
case Failure(e: InvalidResponseException) =>
Left(FailedRequest(notificationId, token, e, None))
case Failure(e: QuotaExceededException) =>
Left(FailedRequest(notificationId, token, e, None))
case Failure(e: FcmServerTransportException) =>
Left(FailedRequest(notificationId, token, e, None))
case Failure(_) =>
Left(UnknownReasonFailedRequest(notificationId, token))
}

def parseBatchSendResponse(
notificationId: UUID, tokens: List[String], triedResponse: Try[BatchResponse], requestCompletionTime: Instant
)(cb: Either[DeliveryException, BatchSuccess] => Unit): Unit = triedResponse match {
Expand All @@ -136,7 +162,7 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
batchResponse.getResponses.asScala.toList.zip(tokens).map { el => {
val (r, token) = el
if (!r.isSuccessful) {
parseSendResponse(notificationId, token, Failure(r.getException), requestCompletionTime)
parseFirebaseSdkSendResponse(notificationId, token, Failure(r.getException), requestCompletionTime)
} else {
Right(FcmDeliverySuccess(s"Token in batch response succeeded", token, requestCompletionTime))
}
Expand All @@ -151,19 +177,24 @@ class FcmClient (firebaseMessaging: FirebaseMessaging, firebaseApp: FirebaseApp,
}

object FcmClient {
def apply(config: FcmConfig, firebaseAppName: Option[String]): Try[FcmClient] = {
def apply(config: FcmConfig, firebaseAppName: Option[String]): Try[FcmClient] =
Try {
val credential = GoogleCredentials.fromStream(new ByteArrayInputStream(config.serviceAccountKey.getBytes))
val firebaseOptions: FirebaseOptions = FirebaseOptions.builder()
.setCredentials(GoogleCredentials.fromStream(new ByteArrayInputStream(config.serviceAccountKey.getBytes)))
.setCredentials(credential)
.setHttpTransport(new OkGoogleHttpTransport)
.setConnectTimeout(10000) // 10 seconds
.build
firebaseAppName match {
val firebaseApp = firebaseAppName match {
case None => FirebaseApp.initializeApp(firebaseOptions)
case Some(name) => FirebaseApp.initializeApp(firebaseOptions, name)
}
}.map(app => new FcmClient(FirebaseMessaging.getInstance(app), app, config))
}
val projectId = credential match {
case s: ServiceAccountCredentials => s.getProjectId()
case _ => ""
}
new FcmClient(FirebaseMessaging.getInstance(firebaseApp), firebaseApp, config, projectId, credential, firebaseOptions.getJsonFactory())
}
}

object FirebaseHelpers {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.gu.notifications.worker.delivery.fcm

import com.gu.notifications.worker.delivery.fcm.models.payload.FcmErrorPayload

case class InvalidResponseException(responseBody: String) extends Exception("Invalid success response") {
override def getMessage(): String = {
s"${super.getMessage()}. Response body: [${responseBody.take(200)}]"
}
}

case class QuotaExceededException(details: FcmErrorPayload) extends Exception("Request quota exceeded") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class InvalidTokenException(details: FcmErrorPayload) extends Exception("Invalid device token") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class FcmServerException(details: FcmErrorPayload) extends Exception("FCM server error") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class UnknownException(details: FcmErrorPayload) extends Exception("Unexpected exception") {
override def getMessage(): String = {
s"${super.getMessage()}. Details: ${details.toString()}]"
}
}

case class FcmServerTransportException(ex: Throwable) extends Exception("Failed to send HTTP request", ex) {
override def getMessage(): String = {
s"${super.getMessage()}. Reason: ${ex.getMessage()}]"
}
}
Loading

0 comments on commit 8d86dd7

Please sign in to comment.