Skip to content

Commit

Permalink
feat: CORS-support (#4343)
Browse files Browse the repository at this point in the history
* WIP: first step towards built in CORS-support in akka-http

* Some alloc optimizations, additional tests

* Docs, more tests

* More missing headers

* Compiles on Scala 3

* Worked through API docs and cleaned up and aligned a bit

* Some more optimization of the hot path

* hash set lookup over loop per header

* Headers referencing original project and license

* Came to an agreement with the copyright checker

* Add some reference bench results

* migration guide

* Shaped up the benchmark

* updated reference bench results
  • Loading branch information
johanandren authored Jan 24, 2024
1 parent 44162d4 commit 67c7041
Show file tree
Hide file tree
Showing 24 changed files with 1,405 additions and 8 deletions.
127 changes: 127 additions & 0 deletions akka-http-bench-jmh/src/main/scala/akka/http/CorsBenchmark.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (C) 2024 Lightbend Inc. <https://www.lightbend.com>
* Copyright 2016 Lomig Mégard
*/

package akka.http

import akka.actor.ActorSystem
import akka.dispatch.ExecutionContexts
import akka.http.scaladsl.model.HttpResponse
import akka.http.scaladsl.model.headers.HttpOrigin
import akka.http.scaladsl.model.headers.Origin
import akka.http.scaladsl.model.headers.`Access-Control-Request-Method`
import akka.http.scaladsl.model.HttpMethods
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.server.Directives
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.settings.CorsSettings
import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._

import java.util.concurrent.TimeUnit
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.concurrent.Await
import scala.concurrent.ExecutionContext

/*
* This benchmark is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0.
*
* Reference results from run on Linux Gen 11, i5 2.60GHz, JDK 17.0.8:
* Benchmark Mode Cnt Score Error Units
* CorsBenchmark.baseline thrpt 10 2638704.917 ± 2642.203 ops/s
* CorsBenchmark.default_cors thrpt 10 1811293.399 ± 11974.729 ops/s
* CorsBenchmark.default_preflight thrpt 10 2637449.643 ± 7249.558 ops/s
* CorsBenchmark.settings_cors thrpt 10 1878642.961 ± 4003.510 ops/s
* CorsBenchmark.settings_preflight thrpt 10 2936688.186 ± 3591.821 ops/s
*/
@State(Scope.Benchmark)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
class CorsBenchmark extends Directives {
private val config = ConfigFactory.parseString("akka.loglevel = ERROR").withFallback(ConfigFactory.load())

implicit private val system: ActorSystem = ActorSystem("CorsBenchmark", config)
implicit private val ec: ExecutionContext = scala.concurrent.ExecutionContext.global

private val corsSettings = CorsSettings(system)

private var baselineHandler: Function[HttpRequest, Future[HttpResponse]] = _
private var corsDefaultHandler: Function[HttpRequest, Future[HttpResponse]] = _
private var corsSettingsHandler: Function[HttpRequest, Future[HttpResponse]] = _
private var request: HttpRequest = _
private var requestCors: HttpRequest = _
private var requestPreflight: HttpRequest = _

@Setup
def setup(): Unit = {
baselineHandler = Route.toFunction(path("baseline") {
get {
complete("ok")
}
})
corsDefaultHandler = Route.toFunction(path("cors") {
cors() {
get {
complete("ok")
}
}
})
corsSettingsHandler = Route.toFunction(path("cors") {
cors(corsSettings) {
get {
complete("ok")
}
}
})

val origin = Origin(HttpOrigin("http://example.com"))

val base = s"http://127.0.0.1:8080"
request = HttpRequest(uri = s"$base/baseline")
requestCors = HttpRequest(
method = HttpMethods.GET,
uri = s"$base/cors",
headers = List(origin)
)
requestPreflight = HttpRequest(
method = HttpMethods.OPTIONS,
uri = s"$base/cors",
headers = List(origin, `Access-Control-Request-Method`(HttpMethods.GET))
)
}

@TearDown
def shutdown(): Unit = {
Await.ready(system.terminate(), 5.seconds)
}

@Benchmark
def baseline(): Unit = {
assert(responseBody(baselineHandler(request)) == "ok")
}

@Benchmark
def default_cors(): Unit = {
assert(responseBody(corsDefaultHandler(requestCors)) == "ok")
}

@Benchmark
def default_preflight(): Unit = {
assert(responseBody(corsDefaultHandler(requestPreflight)) == "")
}

@Benchmark
def settings_cors(): Unit = {
assert(responseBody(corsSettingsHandler(requestCors)) == "ok")
}

@Benchmark
def settings_preflight(): Unit = {
assert(responseBody(corsSettingsHandler(requestPreflight)) == "")
}

private def responseBody(response: Future[HttpResponse]): String =
Await.result(response.flatMap(_.entity.toStrict(3.seconds)).map(_.data.utf8String)(ExecutionContexts.parasitic), 3.seconds)
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
public abstract class AccessControlAllowOrigin extends akka.http.scaladsl.model.HttpHeader {
public abstract HttpOriginRange range();

public static AccessControlAllowOrigin wildcard() {
return akka.http.scaladsl.model.headers.Access$minusControl$minusAllow$minusOrigin$.MODULE$.$times();
}
public static AccessControlAllowOrigin nullOrigin() {
return create(HttpOriginRange.create());
}

public static AccessControlAllowOrigin create(HttpOriginRange range) {
return new akka.http.scaladsl.model.headers.Access$minusControl$minusAllow$minusOrigin(((akka.http.scaladsl.model.headers.HttpOriginRange) range));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,9 @@ final case class `Access-Control-Allow-Methods`(methods: immutable.Seq[HttpMetho

// https://www.w3.org/TR/cors/#access-control-allow-origin-response-header
object `Access-Control-Allow-Origin` extends ModeledCompanion[`Access-Control-Allow-Origin`] {
val `*` = forRange(HttpOriginRange.`*`)
val `null` = forRange(HttpOriginRange())
val `*`: `Access-Control-Allow-Origin` = forRange(HttpOriginRange.`*`)
val `null`: `Access-Control-Allow-Origin` = forRange(HttpOriginRange())

def apply(origin: HttpOrigin) = forRange(HttpOriginRange(origin))

/**
Expand Down
Loading

0 comments on commit 67c7041

Please sign in to comment.