-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
package akkahttp | ||
|
||
import io.circe._ | ||
import org.apache.pekko.actor.ActorSystem | ||
import org.apache.pekko.http.scaladsl.Http | ||
import org.apache.pekko.http.scaladsl.model.Uri.Authority | ||
import org.apache.pekko.http.scaladsl.model._ | ||
import org.apache.pekko.http.scaladsl.model.headers.{Host, RawHeader} | ||
import org.apache.pekko.http.scaladsl.server.Directives._ | ||
import org.apache.pekko.http.scaladsl.server.Route | ||
import org.apache.pekko.pattern.CircuitBreaker | ||
import org.apache.pekko.stream.scaladsl.{Sink, Source} | ||
import org.slf4j.{Logger, LoggerFactory} | ||
|
||
import java.util.concurrent.ConcurrentHashMap | ||
import java.util.concurrent.atomic.AtomicInteger | ||
import scala.collection.parallel.CollectionConverters.ImmutableIterableIsParallelizable | ||
import scala.concurrent.duration.DurationInt | ||
import scala.concurrent.{ExecutionContext, Future, Promise, TimeoutException} | ||
import scala.util.{Failure, Success} | ||
|
||
/** | ||
* HTTP reverse proxy with: | ||
* - weighted round robin load balancing | ||
* - retry | ||
* | ||
* Setup: | ||
* HTTP client(s) --> ReverseProxy --> Target server(s) | ||
* | ||
* Inspired by: | ||
* https://github.com/mathieuancelin/akka-http-reverse-proxy | ||
* | ||
* curl: | ||
* curl -H "Host: local" -o - -i -w " %{time_total}\n" http://127.0.0.1:8080 | ||
*/ | ||
object ReverseProxy extends App { | ||
val logger: Logger = LoggerFactory.getLogger(this.getClass) | ||
implicit val system: ActorSystem = ActorSystem() | ||
|
||
implicit val executionContext = system.dispatcher | ||
|
||
implicit val http = Http(system) | ||
|
||
val circuitBreakers = new ConcurrentHashMap[String, CircuitBreaker]() | ||
|
||
val proxyHost = "127.0.0.1" | ||
val proxyPort = 8080 | ||
|
||
val targetHostLocal = "local" | ||
|
||
val counter = new AtomicInteger(0) | ||
|
||
// HTTP client(s) | ||
val clients = 1 to 2 | ||
clients.par.foreach(clientID => httpClient(clientID, proxyHost, proxyPort)) | ||
|
||
def httpClient(clientId: Int, proxyHost: String, proxyPort: Int) = { | ||
val nbrOfRequests = 10 | ||
Source(1 to nbrOfRequests) | ||
.throttle(1, 1.seconds) | ||
.wireTap(each => logger.info(s"Client: $clientId about to issue $each/$nbrOfRequests request...")) | ||
.mapAsync(1)(_ => http.singleRequest(HttpRequest(uri = s"http://$proxyHost:$proxyPort/").withHeaders(Seq(RawHeader("Host", targetHostLocal))))) | ||
.wireTap(response => logger.info(s"Client: $clientId got response: ${response.status.intValue()}")) | ||
.runWith(Sink.ignore) | ||
} | ||
|
||
// ReverseProxy | ||
def NotFound(path: String) = HttpResponse( | ||
404, | ||
entity = HttpEntity(ContentTypes.`application/json`, Json.obj("error" -> Json.fromString(s"$path not found")).noSpaces) | ||
) | ||
|
||
def GatewayTimeout() = HttpResponse( | ||
504, | ||
entity = HttpEntity(ContentTypes.`application/json`, Json.obj("error" -> Json.fromString(s"Target servers timeout")).noSpaces) | ||
) | ||
|
||
def BadGateway(message: String) = HttpResponse( | ||
502, | ||
entity = HttpEntity(ContentTypes.`application/json`, Json.obj("error" -> Json.fromString(message)).noSpaces) | ||
) | ||
|
||
val services: Map[String, Seq[Target]] = Map( | ||
targetHostLocal -> Seq( | ||
Target.weighted("http://127.0.0.1:9081", 3), | ||
Target.weighted("http://127.0.0.1:9082", 2), | ||
Target.weighted("http://127.0.0.1:9083", 1) | ||
) | ||
) | ||
|
||
def extractHost(request: HttpRequest): String = request.header[Host].map(_.host.address()).getOrElse("--") | ||
|
||
def handlerWithCircuitBreaker(request: HttpRequest): Future[HttpResponse] = { | ||
val host = extractHost(request) | ||
services.get(host) match { | ||
case Some(rawSeq) => | ||
val seq = rawSeq.flatMap(t => (1 to t.weight).map(_ => t)) | ||
Retry.retry[HttpResponse](3) { | ||
val index = counter.incrementAndGet() % (if (seq.nonEmpty) seq.size else 1) | ||
val target = seq.apply(index) | ||
val circuitBreaker = circuitBreakers.computeIfAbsent(target.url, _ => new CircuitBreaker( | ||
system.scheduler, | ||
maxFailures = 5, | ||
callTimeout = 30.seconds, | ||
resetTimeout = 10.seconds)) | ||
val headersIn: Seq[HttpHeader] = | ||
request.headers.filterNot(t => t.name() == "Host") :+ | ||
Host(target.host) :+ | ||
RawHeader("X-Fowarded-Host", host) :+ | ||
RawHeader("X-Fowarded-Scheme", request.uri.scheme) | ||
val uri = request.uri.copy( | ||
scheme = target.scheme, | ||
authority = Authority(host = Uri.NamedHost(target.host), port = target.port)) | ||
// Filter, to avoid log noise, see: https://github.com/akka/akka-http/issues/64 | ||
val proxyReq = request.withUri(uri).withHeaders(headersIn.toList.filterNot(each => each.name() == "Timeout-Access")) | ||
circuitBreaker.withCircuitBreaker(http.singleRequest(proxyReq)) | ||
}.recover { | ||
case _: akka.pattern.CircuitBreakerOpenException => BadGateway("Circuit breaker opened") | ||
case _: TimeoutException => GatewayTimeout() | ||
case e => BadGateway(e.getMessage) | ||
} | ||
case None => Future.successful(NotFound(host)) | ||
} | ||
} | ||
|
||
val futReverseProxy = Http().newServerAt(proxyHost, proxyPort).bind(handlerWithCircuitBreaker) | ||
|
||
futReverseProxy.onComplete { | ||
case Success(b) => | ||
logger.info("ReverseProxy started, listening on: " + b.localAddress) | ||
case Failure(e) => | ||
logger.info(s"ReverseProxy failed. Exception message: ${e.getMessage}") | ||
system.terminate() | ||
} | ||
|
||
// Local Target servers simulation | ||
val echoRoute: Route = | ||
path("")( | ||
get { | ||
// Adjust to provoke more retries on ReverseProxy | ||
val codes = List(200, 200, 200, 500, 500, 500) | ||
val randomResponse = codes(new scala.util.Random().nextInt(codes.length)) | ||
logger.info(s"Local target server got echo request, reply with: $randomResponse") | ||
complete(StatusCode.int2StatusCode(randomResponse)) | ||
}) | ||
|
||
services.get(targetHostLocal).foreach(targetSeq => | ||
targetSeq.foreach(target => { | ||
val futTargetServer = Http().newServerAt(target.host, target.port).bind(echoRoute) | ||
|
||
futTargetServer.onComplete { | ||
case Success(b) => | ||
logger.info(s"Target server started, listening on: ${b.localAddress}") | ||
case Failure(e) => | ||
logger.info(s"Target server could not bind to... Exception message: ${e.getMessage}") | ||
system.terminate() | ||
} | ||
} | ||
) | ||
) | ||
|
||
// Models | ||
case class Target(scheme: String, host: String, port: Int, weight: Int = 1, protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) { | ||
def url: String = s"$scheme://$host:$port" | ||
} | ||
|
||
object Target { | ||
def apply(url: String): Target = { | ||
url.split("://|:").toList match { | ||
case scheme :: host :: port :: Nil => Target(scheme, host, port.toInt) | ||
case _ => throw new RuntimeException(s"Can not resolve target url: $url.") | ||
} | ||
} | ||
|
||
def weighted(url: String, weight: Int): Target = { | ||
Target(url).copy(weight = weight) | ||
} | ||
} | ||
} | ||
|
||
object Retry { | ||
val logger: Logger = LoggerFactory.getLogger(this.getClass) | ||
|
||
def retry[T](times: Int)(f: => Future[T])(implicit ec: ExecutionContext): Future[T] = { | ||
val promise = Promise[T]() | ||
retryPromise[T](times, promise, None, f) | ||
promise.future | ||
} | ||
|
||
private[this] def retryPromise[T](times: Int, promise: Promise[T], failure: Option[Throwable], | ||
f: => Future[T])(implicit ec: ExecutionContext): Unit = { | ||
(times, failure) match { | ||
case (0, Some(e)) => | ||
promise.tryFailure(e) | ||
case (0, None) => | ||
promise.tryFailure(new RuntimeException("Failure, but lost track of exception :-(")) | ||
case (_, _) => | ||
f.onComplete { | ||
case Success(t) => | ||
// we need to cast here | ||
val code = t.asInstanceOf[HttpResponse].status.intValue() | ||
if (code >= 500) { | ||
logger.info(s"Got 50x server error from target server on attempt: $times/3") | ||
retryPromise[T](times - 1, promise, Some(new RuntimeException("Got 500")), f) | ||
} else { | ||
promise.trySuccess(t) | ||
} | ||
|
||
case Failure(e) => | ||
retryPromise[T](times - 1, promise, Some(e), f) | ||
}(ec) | ||
} | ||
} | ||
} |