From 67c704111b755274e5daa423c2c923cbb017e8f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Wed, 24 Jan 2024 13:50:40 +0100 Subject: [PATCH] feat: CORS-support (#4343) * WIP: first step towards built in CORS-support in akka-http * Some alloc optimizations, additional tests * Docs, more tests * More missing headers * Compiles on Scala 3 * Worked through API docs and cleaned up and aligned a bit * Some more optimization of the hot path * hash set lookup over loop per header * Headers referencing original project and license * Came to an agreement with the copyright checker * Add some reference bench results * migration guide * Shaped up the benchmark * updated reference bench results --- .../main/scala/akka/http/CorsBenchmark.scala | 127 +++++ .../headers/AccessControlAllowOrigin.java | 7 + .../http/scaladsl/model/headers/headers.scala | 5 +- .../directives/CorsDirectivesSpec.scala | 460 ++++++++++++++++++ akka-http/src/main/resources/reference.conf | 52 ++ .../http/impl/settings/CorsSettingsImpl.scala | 156 ++++++ .../akka/http/javadsl/server/Directives.scala | 5 +- .../akka/http/javadsl/server/Rejections.scala | 8 + .../server/directives/CorsDirectives.scala | 46 ++ .../http/javadsl/settings/CorsSettings.scala | 112 +++++ .../http/scaladsl/server/Directives.scala | 1 + .../akka/http/scaladsl/server/Rejection.scala | 5 + .../scaladsl/server/RejectionHandler.scala | 4 + .../server/directives/CorsDirectives.scala | 150 ++++++ .../http/scaladsl/settings/CorsSettings.scala | 119 +++++ .../main/paradox/compatibility-guidelines.md | 6 +- .../migration-guide/migration-guide-10.6.x.md | 12 + .../routing-dsl/directives/alphabetically.md | 1 + .../routing-dsl/directives/by-trait.md | 1 + .../directives/cors-directives/cors.md | 33 ++ .../directives/cors-directives/index.md | 9 + .../CorsDirectivesExamplesTest.java | 46 ++ .../CorsDirectivesExamplesSpec.scala | 45 ++ project/CopyrightHeader.scala | 3 +- 24 files changed, 1405 insertions(+), 8 deletions(-) create mode 100644 akka-http-bench-jmh/src/main/scala/akka/http/CorsBenchmark.scala create mode 100644 akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CorsDirectivesSpec.scala create mode 100644 akka-http/src/main/scala/akka/http/impl/settings/CorsSettingsImpl.scala create mode 100644 akka-http/src/main/scala/akka/http/javadsl/server/directives/CorsDirectives.scala create mode 100644 akka-http/src/main/scala/akka/http/javadsl/settings/CorsSettings.scala create mode 100644 akka-http/src/main/scala/akka/http/scaladsl/server/directives/CorsDirectives.scala create mode 100644 akka-http/src/main/scala/akka/http/scaladsl/settings/CorsSettings.scala create mode 100644 docs/src/main/paradox/routing-dsl/directives/cors-directives/cors.md create mode 100644 docs/src/main/paradox/routing-dsl/directives/cors-directives/index.md create mode 100644 docs/src/test/java/docs/http/javadsl/server/directives/CorsDirectivesExamplesTest.java create mode 100644 docs/src/test/scala/docs/http/scaladsl/server/directives/CorsDirectivesExamplesSpec.scala diff --git a/akka-http-bench-jmh/src/main/scala/akka/http/CorsBenchmark.scala b/akka-http-bench-jmh/src/main/scala/akka/http/CorsBenchmark.scala new file mode 100644 index 00000000000..d273ab0e903 --- /dev/null +++ b/akka-http-bench-jmh/src/main/scala/akka/http/CorsBenchmark.scala @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http + +import akka.actor.ActorSystem +import akka.dispatch.ExecutionContexts +import akka.http.scaladsl.model.HttpResponse +import akka.http.scaladsl.model.headers.HttpOrigin +import akka.http.scaladsl.model.headers.Origin +import akka.http.scaladsl.model.headers.`Access-Control-Request-Method` +import akka.http.scaladsl.model.HttpMethods +import akka.http.scaladsl.model.HttpRequest +import akka.http.scaladsl.server.Directives +import akka.http.scaladsl.server.Route +import akka.http.scaladsl.settings.CorsSettings +import com.typesafe.config.ConfigFactory +import org.openjdk.jmh.annotations._ + +import java.util.concurrent.TimeUnit +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.concurrent.Await +import scala.concurrent.ExecutionContext + +/* + * This benchmark is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. + * + * Reference results from run on Linux Gen 11, i5 2.60GHz, JDK 17.0.8: + * Benchmark Mode Cnt Score Error Units + * CorsBenchmark.baseline thrpt 10 2638704.917 ± 2642.203 ops/s + * CorsBenchmark.default_cors thrpt 10 1811293.399 ± 11974.729 ops/s + * CorsBenchmark.default_preflight thrpt 10 2637449.643 ± 7249.558 ops/s + * CorsBenchmark.settings_cors thrpt 10 1878642.961 ± 4003.510 ops/s + * CorsBenchmark.settings_preflight thrpt 10 2936688.186 ± 3591.821 ops/s + */ +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.SECONDS) +@BenchmarkMode(Array(Mode.Throughput)) +class CorsBenchmark extends Directives { + private val config = ConfigFactory.parseString("akka.loglevel = ERROR").withFallback(ConfigFactory.load()) + + implicit private val system: ActorSystem = ActorSystem("CorsBenchmark", config) + implicit private val ec: ExecutionContext = scala.concurrent.ExecutionContext.global + + private val corsSettings = CorsSettings(system) + + private var baselineHandler: Function[HttpRequest, Future[HttpResponse]] = _ + private var corsDefaultHandler: Function[HttpRequest, Future[HttpResponse]] = _ + private var corsSettingsHandler: Function[HttpRequest, Future[HttpResponse]] = _ + private var request: HttpRequest = _ + private var requestCors: HttpRequest = _ + private var requestPreflight: HttpRequest = _ + + @Setup + def setup(): Unit = { + baselineHandler = Route.toFunction(path("baseline") { + get { + complete("ok") + } + }) + corsDefaultHandler = Route.toFunction(path("cors") { + cors() { + get { + complete("ok") + } + } + }) + corsSettingsHandler = Route.toFunction(path("cors") { + cors(corsSettings) { + get { + complete("ok") + } + } + }) + + val origin = Origin(HttpOrigin("http://example.com")) + + val base = s"http://127.0.0.1:8080" + request = HttpRequest(uri = s"$base/baseline") + requestCors = HttpRequest( + method = HttpMethods.GET, + uri = s"$base/cors", + headers = List(origin) + ) + requestPreflight = HttpRequest( + method = HttpMethods.OPTIONS, + uri = s"$base/cors", + headers = List(origin, `Access-Control-Request-Method`(HttpMethods.GET)) + ) + } + + @TearDown + def shutdown(): Unit = { + Await.ready(system.terminate(), 5.seconds) + } + + @Benchmark + def baseline(): Unit = { + assert(responseBody(baselineHandler(request)) == "ok") + } + + @Benchmark + def default_cors(): Unit = { + assert(responseBody(corsDefaultHandler(requestCors)) == "ok") + } + + @Benchmark + def default_preflight(): Unit = { + assert(responseBody(corsDefaultHandler(requestPreflight)) == "") + } + + @Benchmark + def settings_cors(): Unit = { + assert(responseBody(corsSettingsHandler(requestCors)) == "ok") + } + + @Benchmark + def settings_preflight(): Unit = { + assert(responseBody(corsSettingsHandler(requestPreflight)) == "") + } + + private def responseBody(response: Future[HttpResponse]): String = + Await.result(response.flatMap(_.entity.toStrict(3.seconds)).map(_.data.utf8String)(ExecutionContexts.parasitic), 3.seconds) +} diff --git a/akka-http-core/src/main/java/akka/http/javadsl/model/headers/AccessControlAllowOrigin.java b/akka-http-core/src/main/java/akka/http/javadsl/model/headers/AccessControlAllowOrigin.java index cc2277dd73c..a4240eb2e20 100644 --- a/akka-http-core/src/main/java/akka/http/javadsl/model/headers/AccessControlAllowOrigin.java +++ b/akka-http-core/src/main/java/akka/http/javadsl/model/headers/AccessControlAllowOrigin.java @@ -11,6 +11,13 @@ public abstract class AccessControlAllowOrigin extends akka.http.scaladsl.model.HttpHeader { public abstract HttpOriginRange range(); + public static AccessControlAllowOrigin wildcard() { + return akka.http.scaladsl.model.headers.Access$minusControl$minusAllow$minusOrigin$.MODULE$.$times(); + } + public static AccessControlAllowOrigin nullOrigin() { + return create(HttpOriginRange.create()); + } + public static AccessControlAllowOrigin create(HttpOriginRange range) { return new akka.http.scaladsl.model.headers.Access$minusControl$minusAllow$minusOrigin(((akka.http.scaladsl.model.headers.HttpOriginRange) range)); } diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala index ebd8ab8d60c..d51af8736ef 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/headers/headers.scala @@ -282,8 +282,9 @@ final case class `Access-Control-Allow-Methods`(methods: immutable.Seq[HttpMetho // https://www.w3.org/TR/cors/#access-control-allow-origin-response-header object `Access-Control-Allow-Origin` extends ModeledCompanion[`Access-Control-Allow-Origin`] { - val `*` = forRange(HttpOriginRange.`*`) - val `null` = forRange(HttpOriginRange()) + val `*`: `Access-Control-Allow-Origin` = forRange(HttpOriginRange.`*`) + val `null`: `Access-Control-Allow-Origin` = forRange(HttpOriginRange()) + def apply(origin: HttpOrigin) = forRange(HttpOriginRange(origin)) /** diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CorsDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CorsDirectivesSpec.scala new file mode 100644 index 00000000000..aefa1fbbd8c --- /dev/null +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/CorsDirectivesSpec.scala @@ -0,0 +1,460 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http.scaladsl.server.directives + +import akka.http.impl.settings.HttpOriginMatcher +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.server.{ CorsRejection, Route, RoutingSpec } +import akka.http.scaladsl.settings.CorsSettings +import org.scalatest.Inspectors.forAll + +// This test is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. +class CorsDirectivesSpec extends RoutingSpec { + import HttpMethods._ + + val actual = "actual" + val exampleOrigin = HttpOrigin("http://example.com") + val exampleStatus = StatusCodes.Created + + val referenceSettings = CorsSettings(system).withAllowCredentials(true) + + override def testConfigSource = + """ + akka.http.cors.allow-credentials = off + """ + + def route(settings: CorsSettings, responseHeaders: Seq[HttpHeader] = Nil): Route = + cors(settings) { + complete(HttpResponse(exampleStatus, responseHeaders, HttpEntity(actual))) + } + + "The cors() directive" should { + "extract its settings from the actor system" in { + val route = cors() { + complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual))) + } + + Get() ~> Origin(exampleOrigin) ~> route ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`.* // no allow-credentials because above + ) + } + } + } + + "The cors(settings) directive" should { + "not affect actual requests when not strict" in { + val settings = referenceSettings + val responseHeaders = Seq(Host("my-host"), `Access-Control-Max-Age`(60)) + Get() ~> { + route(settings, responseHeaders) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + // response headers should be untouched, including the CORS-related ones + response.headers shouldBe responseHeaders + } + } + + "reject requests without Origin header when strict" in { + val settings = referenceSettings.withAllowGenericHttpRequests(false) + Get() ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection("malformed request") + } + } + + "accept actual requests with a single Origin" in { + val settings = referenceSettings + Get() ~> Origin(exampleOrigin) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(exampleOrigin), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "flag credentials allowed when they are" in { + val settings = referenceSettings.withAllowCredentials(true) + Get() ~> Origin(exampleOrigin) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(exampleOrigin.toString), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "accept pre-flight requests with a null origin when allowed-origins = `*`" in { + val settings = referenceSettings + Options() ~> Origin(Seq.empty) ~> `Access-Control-Request-Method`(GET) ~> { + route(settings) + } ~> check { + status shouldBe StatusCodes.OK + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`.`null`, + `Access-Control-Allow-Methods`(settings.allowedMethods.toArray), + `Access-Control-Max-Age`(1800), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "reject pre-flight requests with a null origin when allowed-origins != `*`" in { + val settings = referenceSettings.withAllowedOrigins(Set(exampleOrigin.toString)) + Options() ~> Origin(Seq.empty) ~> `Access-Control-Request-Method`(GET) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection("invalid origin 'null'") + } + } + + "accept actual requests with a null Origin" in { + val settings = referenceSettings + Get() ~> Origin(Seq.empty) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`.`null`, + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "accept actual requests with an Origin matching an allowed subdomain" in { + val subdomainOrigin = HttpOrigin("http://sub.example.com") + + val settings = referenceSettings.withAllowedOrigins(Set("http://*.example.com")) + Get() ~> Origin(subdomainOrigin) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(subdomainOrigin), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "return `Access-Control-Allow-Origin: *` to actual request only when credentials are not allowed" in { + val settings = referenceSettings.withAllowCredentials(false) + Get() ~> Origin(exampleOrigin) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers shouldBe Seq( + `Access-Control-Allow-Origin`.* + ) + } + } + + "return `Access-Control-Expose-Headers` to actual request with all the exposed headers in the settings" in { + val exposedHeaders = Set("X-a", "X-b", "X-c") + val settings = referenceSettings.withExposedHeaders(exposedHeaders) + Get() ~> Origin(exampleOrigin) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers shouldBe Seq( + `Access-Control-Allow-Origin`(exampleOrigin), + `Access-Control-Expose-Headers`(exposedHeaders.toArray), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "remove CORS-related headers from the original response before adding the new ones" in { + val settings = referenceSettings.withExposedHeaders(Set("X-good")) + val responseHeaders = Seq( + Host("my-host"), // untouched + `Access-Control-Allow-Origin`("http://bad.com"), // replaced + `Access-Control-Expose-Headers`("X-bad"), // replaced + `Access-Control-Allow-Credentials`(false), // replaced + `Access-Control-Allow-Methods`(HttpMethods.POST), // removed + `Access-Control-Allow-Headers`("X-bad"), // removed + `Access-Control-Max-Age`(60) // removed + ) + Get() ~> Origin(exampleOrigin) ~> { + route(settings, responseHeaders) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(exampleOrigin), + `Access-Control-Expose-Headers`("X-good"), + `Access-Control-Allow-Credentials`(true), + Host("my-host") + ) + } + } + + "accept valid pre-flight requests" in { + val settings = referenceSettings + Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> { + route(settings) + } ~> check { + response.entity shouldBe HttpEntity.Empty + status shouldBe StatusCodes.OK + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(exampleOrigin), + `Access-Control-Allow-Methods`(settings.allowedMethods.toArray), + `Access-Control-Max-Age`(1800), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "accept actual requests with OPTION method" in { + val settings = referenceSettings + Options() ~> Origin(exampleOrigin) ~> { + route(settings) + } ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(exampleOrigin), + `Access-Control-Allow-Credentials`(true) + ) + } + } + + "reject actual requests with invalid origin" when { + "the origin is null" in { + val settings = referenceSettings.withAllowedOrigins(Set(exampleOrigin.toString)) + Get() ~> Origin(Seq.empty) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection("invalid origin 'null'") + } + } + "there is one origin" in { + val settings = referenceSettings.withAllowedOrigins(Set(exampleOrigin.toString)) + val invalidOrigin = HttpOrigin("http://invalid.com") + Get() ~> Origin(invalidOrigin) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection(s"invalid origin '${invalidOrigin.toString}'") + } + } + } + + "reject pre-flight requests with invalid origin" in { + val settings = referenceSettings.withAllowedOrigins(Set(exampleOrigin.toString)) + val invalidOrigin = HttpOrigin("http://invalid.com") + Options() ~> Origin(invalidOrigin) ~> `Access-Control-Request-Method`(GET) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection(s"invalid origin '${invalidOrigin.toString}'") + } + } + + "reject pre-flight requests with invalid method" in { + val settings = referenceSettings + val invalidMethod = HttpMethods.PATCH + Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(invalidMethod) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection("invalid method 'PATCH'") + } + } + + "reject pre-flight requests with invalid header" in { + val settings = referenceSettings.withAllowedHeaders(Set[String]()) + val invalidHeader = "X-header" + Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(HttpMethods.GET) ~> + `Access-Control-Request-Headers`(invalidHeader) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection("invalid headers 'X-header'") + } + } + + "reject pre-flight requests with multiple origins" in { + val settings = referenceSettings.withAllowGenericHttpRequests(false) + Options() ~> Origin(exampleOrigin, exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> { + route(settings) + } ~> check { + rejection shouldBe CorsRejection("malformed request") + } + } + } + + "The default rejection handler" should { + val settings = referenceSettings + .withAllowGenericHttpRequests(false) + .withAllowedOrigins(Set(exampleOrigin.toString)) + .withAllowedHeaders(Set[String]()) + val sealedRoute = Route.seal(route(settings)) + + "handle the malformed request cause" in { + Get() ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe "CORS: malformed request" + } + } + + "handle a request with invalid origin" when { + "the origin is null" in { + Get() ~> Origin(Seq.empty) ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe s"CORS: invalid origin 'null'" + } + } + "there is one origin" in { + Get() ~> Origin(HttpOrigin("http://invalid.com")) ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe s"CORS: invalid origin 'http://invalid.com'" + } + } + "there are two origins" in { + Get() ~> Origin(HttpOrigin("http://invalid1.com"), HttpOrigin("http://invalid2.com")) ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe s"CORS: invalid origin 'http://invalid1.com http://invalid2.com'" + } + } + } + + "handle a pre-flight request with invalid method" in { + Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(PATCH) ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe s"CORS: invalid method 'PATCH'" + } + } + + "handle a pre-flight request with invalid headers" in { + Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> + `Access-Control-Request-Headers`("X-a", "X-b") ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe s"CORS: invalid headers 'X-a X-b'" + } + } + + "handle multiple CORS rejections" in { + Options() ~> Origin(HttpOrigin("http://invalid.com")) ~> `Access-Control-Request-Method`(PATCH) ~> + `Access-Control-Request-Headers`("X-a", "X-b") ~> { + sealedRoute + } ~> check { + status shouldBe StatusCodes.BadRequest + entityAs[String] shouldBe + s"CORS: invalid origin 'http://invalid.com', invalid method 'PATCH', invalid headers 'X-a X-b'" + } + } + } + + "The CORS origin matcher" should { + "match any origin with *" in { + val origins = Seq( + "http://localhost", + "http://192.168.1.1", + "http://test.com", + "http://test.com:8080", + "https://test.com", + "https://test.com:4433" + ).map(HttpOrigin.apply) + + forAll(origins) { o => HttpOriginMatcher.matchAny.apply(Seq(o)) shouldBe true } + HttpOriginMatcher.matchAny(origins) shouldBe true + } + + "match exact origins" in { + val positives = Seq( + "http://localhost", + "http://test.com", + "https://test.ch:12345" + ).map(HttpOrigin.apply) + + val negatives = Seq( + "http://localhost:80", + "https://localhost", + "http://test.com:8080", + "https://test.ch", + "https://abc.test.uk.co", + ).map(HttpOrigin.apply) + + val matcher = HttpOriginMatcher.apply(Set( + "http://localhost", + "http://test.com", + "https://test.ch:12345" + )) + + forAll(positives) { o => matcher.apply(Seq(o)) shouldBe true } + + forAll(negatives) { o => matcher.apply(Seq(o)) shouldBe false } + + matcher(positives) shouldBe true + matcher(negatives) shouldBe false + matcher(negatives ++ positives) shouldBe true // at least one match + matcher(positives ++ negatives) shouldBe true // at least one match + } + + "match sub-domains with wildcards" in { + val matcher = HttpOriginMatcher( + Set( + "http://test.com", + "https://test.ch:12345", + "https://*.test.uk.co", + "http://*.abc.com:8080", + "http://*abc.com", // Must start with `*.` + "http://abc.*.middle.com" // The wildcard can't be in the middle + ) + ) + + val positives = Seq( + "http://test.com", + "https://test.ch:12345", + "https://sub.test.uk.co", + "https://sub1.sub2.test.uk.co", + "http://sub.abc.com:8080" + ).map(HttpOrigin.apply) + + val negatives = Seq( + "http://test.com:8080", + "http://sub.test.uk.co", // must compare the scheme + "http://sub.abc.com", // must compare the port + "http://abc.test.com", // no wildcard + "http://sub.abc.com", + "http://subabc.com", + "http://abc.sub.middle.com", + "http://abc.middle.com" + ).map(HttpOrigin.apply) + + forAll(positives) { o => matcher.apply(Seq(o)) shouldBe true } + + forAll(negatives) { o => matcher.apply(Seq(o)) shouldBe false } + matcher(negatives ++ positives) shouldBe true + matcher(positives ++ negatives) shouldBe true + } + } + +} diff --git a/akka-http/src/main/resources/reference.conf b/akka-http/src/main/resources/reference.conf index 0c5b6e701ef..e81fadd8fdc 100644 --- a/akka-http/src/main/resources/reference.conf +++ b/akka-http/src/main/resources/reference.conf @@ -62,4 +62,56 @@ akka.http { # This setting can be enabled to pass those empty events to the application for explicit handling. emit-empty-events = off } + + # Configuration for the cors directive, does not apply unless the directive is used + #cors + cors { + # Allow generic requests, that are outside the scope of the specification, for example lacking + # an `Origin` header to pass through the directive. + # + # When false strict CORS filtering is applied and any invalid request will be rejected. + allow-generic-http-requests = on + + # If enabled, the header `Access-Control-Allow-Credentials` + # is included in the response, indicating that the actual request can include user credentials. + # Examples of user credentials are: cookies, HTTP authentication or client-side certificates. + allow-credentials = on + + # List of origins that the CORS filter must allow. + # + # Can also be set to a single `*` to allow access to the resource from any origin. + # + # Controls the content of the `Access-Control-Allow-Origin` response header: if parameter is `*` and + # credentials are not allowed, a `*` is returned in `Access-Control-Allow-Origin`. Otherwise, the origins given in the + # `Origin` request header are echoed. + # + # Hostname starting with `*.` will match any sub-domain. The scheme and the port are always strictly matched. + # + # The actual or preflight request is rejected if any of the origins from the request is not allowed.. + allowed-origins = ["*"] + + # Set of request headers that are allowed when making an actual request. + # + # Controls the content of the `Access-Control-Allow-Headers` header in a preflight response: If set to a single `*`, + # the headers from `Access-Control-Request-Headers` are echoed. Otherwise specified list of header names is returned + # as part of the header. + allowed-headers = ["*"] + + # List of methods allowed when making an actual request. The listed headers are returned as part of the + # `Access-Control-Allow-Methods` preflight response header. + # + # The preflight request will be rejected if the `Access-Control-Request-Method` header's method is not part of the + # list. + allowed-methods = ["GET", "POST", "HEAD", "OPTIONS"] + + # Set of headers (other than simple response headers) that browsers are allowed to access. If not empty, the listed + # headers are returned as part of the `Access-Control-Expose-Headers` header in responses. + exposed-headers = [] + + # The time the browser is allowed to cache the results of a preflight request. This value is + # returned as part of the `Access-Control-Max-Age` preflight response header. If `scala.concurrent.duration.Duration.Zero`, + # the header is not added to the preflight response. + max-age = 1800 seconds + } + #cors } diff --git a/akka-http/src/main/scala/akka/http/impl/settings/CorsSettingsImpl.scala b/akka-http/src/main/scala/akka/http/impl/settings/CorsSettingsImpl.scala new file mode 100644 index 00000000000..6a5697aced8 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/impl/settings/CorsSettingsImpl.scala @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http.impl.settings + +import akka.annotation.InternalApi +import akka.http.impl.util.SettingsCompanionImpl +import akka.http.scaladsl.model.headers.{ HttpOrigin, HttpOriginRange, `Access-Control-Allow-Credentials`, `Access-Control-Allow-Headers`, `Access-Control-Allow-Methods`, `Access-Control-Allow-Origin`, `Access-Control-Expose-Headers`, `Access-Control-Max-Age` } +import akka.http.scaladsl.model.{ HttpHeader, HttpMethod, HttpMethods } +import akka.util.{ ConstantFun, OptionVal } +import com.typesafe.config.Config + +import scala.concurrent.duration.{ Duration, FiniteDuration } +import scala.jdk.CollectionConverters._ +import scala.jdk.DurationConverters._ + +/** + * This implementation is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. + * + * INTERNAL API + */ +@InternalApi +private[akka] case class CorsSettingsImpl( + allowGenericHttpRequests: Boolean, + allowCredentials: Boolean, + allowedOrigins: Set[String], + allowedHeaders: Set[String], + allowedMethods: Set[HttpMethod], + exposedHeaders: Set[String], + maxAge: FiniteDuration +) extends akka.http.scaladsl.settings.CorsSettings { + import CorsSettingsImpl.allowAnySet + + // internals for the directive impl + val originsMatches: Seq[HttpOrigin] => Boolean = HttpOriginMatcher(allowedOrigins) + val headerNameAllowed: String => Boolean = + if (allowedHeaders == Set("*")) ConstantFun.anyToTrue + else allowedHeaders.contains + + private def accessControlExposeHeaders: Option[`Access-Control-Expose-Headers`] = + if (exposedHeaders.nonEmpty) + Some(`Access-Control-Expose-Headers`(exposedHeaders.toArray)) + else + None + + private def accessControlAllowCredentials: Option[`Access-Control-Allow-Credentials`] = + if (allowCredentials) + Some(`Access-Control-Allow-Credentials`(true)) + else + None + + private def accessControlMaxAge: Option[`Access-Control-Max-Age`] = + if (maxAge != Duration.Zero) Some(`Access-Control-Max-Age`(maxAge.toSeconds)) + else None + + private def accessControlAllowMethods: `Access-Control-Allow-Methods` = + `Access-Control-Allow-Methods`(allowedMethods.toArray) + + private def accessControlAllowHeaders(requestHeaders: Seq[String], baseHeaders: List[HttpHeader]): List[HttpHeader] = + if (allowedHeaders == allowAnySet) { + if (requestHeaders.nonEmpty) `Access-Control-Allow-Headers`(requestHeaders) :: baseHeaders + else baseHeaders + } else `Access-Control-Allow-Headers`(requestHeaders) :: baseHeaders + + // single instance if possible + private val sameAccessControlAllowHeaderForAll = + if (allowedOrigins == allowAnySet && !allowCredentials) Some(`Access-Control-Allow-Origin`.*) + else None + + // Cache headers that are always included in a preflight response + private val basePreflightResponseHeaders: List[HttpHeader] = + List(accessControlAllowMethods) ++ accessControlMaxAge ++ accessControlAllowCredentials ++ sameAccessControlAllowHeaderForAll + + // Cache headers that are always included in an actual response + private val baseActualResponseHeaders: List[HttpHeader] = + accessControlExposeHeaders.toList ++ accessControlAllowCredentials ++ sameAccessControlAllowHeaderForAll + + private def accessControlAllowOrigin(origins: Seq[HttpOrigin], baseHeaders: List[HttpHeader]): List[HttpHeader] = + if (sameAccessControlAllowHeaderForAll.isDefined) + // we already included it in the base headers + baseHeaders + else + `Access-Control-Allow-Origin`.forRange(HttpOriginRange.Default(origins)) :: baseHeaders + + def actualResponseHeaders(origins: Seq[HttpOrigin]): List[HttpHeader] = + accessControlAllowOrigin(origins, baseActualResponseHeaders) + + def preflightResponseHeaders(origins: Seq[HttpOrigin], requestHeaders: Seq[String]): List[HttpHeader] = + accessControlAllowHeaders(requestHeaders, accessControlAllowOrigin(origins, basePreflightResponseHeaders)) +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] object CorsSettingsImpl extends SettingsCompanionImpl[CorsSettingsImpl]("akka.http.cors") { + + val allowAnySet = Set("*") + override def fromSubConfig(root: Config, config: Config): CorsSettingsImpl = { + new CorsSettingsImpl( + allowGenericHttpRequests = config.getBoolean("allow-generic-http-requests"), + allowCredentials = config.getBoolean("allow-credentials"), + allowedOrigins = config.getStringList("allowed-origins").asScala.toSet, + allowedHeaders = config.getStringList("allowed-headers").asScala.toSet, + allowedMethods = config.getStringList("allowed-methods").asScala.toSet[String].map(method => + HttpMethods.getForKey(method).getOrElse(HttpMethod.custom(method))), + exposedHeaders = config.getStringList("exposed-headers").asScala.toSet, + maxAge = config.getDuration("max-age").toScala + ) + } +} + +/** + * INTERNAL API + */ +@InternalApi +private[akka] object HttpOriginMatcher { + val matchAny: Seq[HttpOrigin] => Boolean = ConstantFun.anyToTrue + + private def hasWildcard(origin: HttpOrigin): Boolean = + origin.host.host.isNamedHost && origin.host.host.address.startsWith("*.") + + private def strict(origins: Set[HttpOrigin]): HttpOrigin => Boolean = origins.contains + + private def withWildcards(allowedOrigins: Set[HttpOrigin]): HttpOrigin => Boolean = { + val matchers = allowedOrigins.map { wildcardOrigin => + val suffix = wildcardOrigin.host.host.address.stripPrefix("*") + + (origin: HttpOrigin) => + origin.scheme == wildcardOrigin.scheme && + origin.host.port == wildcardOrigin.host.port && + origin.host.host.address.endsWith(suffix) + } + + origin => matchers.exists(_.apply(origin)) + } + + def apply(allowedOrigins: Set[String]): Seq[HttpOrigin] => Boolean = { + if (allowedOrigins == CorsSettingsImpl.allowAnySet) matchAny + else { + val httpOrigins = allowedOrigins.map(HttpOrigin.apply) + val (wildCardAllows, strictAllows) = httpOrigins.partition(hasWildcard) + val strictMatch = strict(strictAllows) + val wildCardMatch = withWildcards(wildCardAllows) + + // strict is cheaper so start with those + val matcher = { (origin: HttpOrigin) => strictMatch(origin) || wildCardMatch(origin) } + + { (origins: Seq[HttpOrigin]) => + origins.exists(matcher) + } + } + } +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala index 8fe2092226a..3106da19e32 100644 --- a/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Directives.scala @@ -5,13 +5,12 @@ package akka.http.javadsl.server import java.util.function.{ BiFunction, Function, Supplier } +import akka.http.javadsl.server.directives.CorsDirectives -import akka.http.javadsl.server.directives.FramedEntityStreamingDirectives import scala.annotation.nowarn - import scala.annotation.varargs -abstract class AllDirectives extends FramedEntityStreamingDirectives +abstract class AllDirectives extends CorsDirectives /** * Collects all default directives into one class for simple importing of static functions. diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala b/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala index 1d87dce870d..1e92d03f557 100644 --- a/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala +++ b/akka-http/src/main/scala/akka/http/javadsl/server/Rejections.scala @@ -449,3 +449,11 @@ object Rejections { def rejectionError(rejection: Rejection) = s.RejectionError(convertToScala(rejection)) } + +/** + * Not for user extension + */ +@DoNotInherit +trait CorsRejection extends Rejection { + def description: String +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/server/directives/CorsDirectives.scala b/akka-http/src/main/scala/akka/http/javadsl/server/directives/CorsDirectives.scala new file mode 100644 index 00000000000..103b6bb9c42 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/server/directives/CorsDirectives.scala @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http.javadsl.server.directives + +import akka.http.javadsl.server.Route +import akka.http.scaladsl.server.directives.{ CorsDirectives => CD } +import akka.http.javadsl.settings.CorsSettings + +import java.util.function.Supplier + +/** + * Directives for CORS, cross origin requests. + * + * For an overview on how CORS works, see the MDN web docs page on CORS: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + * CORS is part of the WHATWG Fetch "Living Standard" https://fetch.spec.whatwg.org/#http-cors-protocol + * + * This implementation is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. + */ +abstract class CorsDirectives extends FramedEntityStreamingDirectives { + import akka.http.javadsl.server.RoutingJavaMapping.Implicits._ + import akka.http.javadsl.server.RoutingJavaMapping._ + + /** + * Wraps its inner route with support for the CORS mechanism, enabling cross origin requests. + * + * The settings are loaded from the Actor System configuration. + */ + def cors(inner: Supplier[Route]): Route = RouteAdapter { + CD.cors() { + inner.get().delegate + } + } + + /** + * Wraps its inner route with support for the CORS mechanism, enabling cross origin requests using the given cors + * settings. + */ + def cors(settings: CorsSettings, inner: Supplier[Route]): Route = RouteAdapter { + CD.cors(settings.asInstanceOf[akka.http.scaladsl.settings.CorsSettings]) { + inner.get().delegate + } + } +} diff --git a/akka-http/src/main/scala/akka/http/javadsl/settings/CorsSettings.scala b/akka-http/src/main/scala/akka/http/javadsl/settings/CorsSettings.scala new file mode 100644 index 00000000000..8580887db8d --- /dev/null +++ b/akka-http/src/main/scala/akka/http/javadsl/settings/CorsSettings.scala @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http.javadsl.settings + +import akka.annotation.DoNotInherit +import akka.http.impl.settings.CorsSettingsImpl +import akka.http.javadsl.model.HttpMethod + +import java.time.Duration +import java.util.{ Set => JSet } +import scala.jdk.CollectionConverters.{ SetHasAsJava, SetHasAsScala } +import scala.jdk.DurationConverters.ScalaDurationOps + +/** + * Settings for the CORS support + * + * This implementation is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. + * + * Not for user extension + */ +@DoNotInherit +abstract class CorsSettings private[akka] { self: CorsSettingsImpl => + import akka.http.impl.util.JavaMapping.Implicits._ + + /** + * Allow generic requests, that are outside the scope of the specification, for example lacking an `Origin` header + * to pass through the directive. + * + * When false strict CORS filtering is applied and any invalid request will be rejected. + */ + def allowGenericHttpRequests: Boolean + + /** + * If enabled, the header `Access-Control-Allow-Credentials` + * is included in the response, indicating that the actual request can include user credentials. Examples of user + * credentials are: cookies, HTTP authentication or client-side certificates. + */ + def allowCredentials: Boolean + + /** + * List of origins that the CORS filter must allow. + * + * Can also be set to a single `*` to allow access to the resource from any origin. + * + * Controls the content of the `Access-Control-Allow-Origin` response header: if parameter is `*` and + * credentials are not allowed, a `*` is returned in `Access-Control-Allow-Origin`. Otherwise, the origins given in the + * `Origin` request header are echoed. + * + * Hostname starting with `*.` will match any sub-domain. The scheme and the port are always strictly matched. + * + * The actual or preflight request is rejected if any of the origins from the request is not allowed. + */ + def getAllowedOrigins: JSet[String] = self.allowedOrigins.asJava + + /** + * Set of request headers that are allowed when making an actual request. + * + * Controls the content of the `Access-Control-Allow-Headers` header in a preflight response: If set to a single `*`, + * the headers from `Access-Control-Request-Headers` are echoed. Otherwise specified list of header names is returned + * as part of the header. + */ + def getAllowedHeaders: JSet[String] = self.allowedHeaders.asJava + + /** + * List of methods allowed when making an actual request. The listed headers are returned as part of the + * `Access-Control-Allow-Methods` preflight response header. + * + * The preflight request will be rejected if the `Access-Control-Request-Method` header's method is not part of the + * list. + */ + def getAllowedMethods: JSet[HttpMethod] = self.allowedMethods.map(_.asJava).asJava + + /** + * Set of headers (other than simple response headers) that browsers are allowed to access. If not empty, the listed + * headers are returned as part of the `Access-Control-Expose-Headers` header in responses. + */ + def getExposedHeaders: JSet[String] = self.exposedHeaders.asJava + + /** + * The time the browser is allowed to cache the results of a preflight request. This value is + * returned as part of the `Access-Control-Max-Age` preflight response header. If `java.time.Duration.ZERO`, + * the header is not added to the preflight response. + */ + def getMaxAge: Duration = self.maxAge.toJava + + def withAllowAnyHeader(): CorsSettings = + self.copy(allowedHeaders = Set("*")) + + def withAllowedHeaders(headerNames: JSet[String]): CorsSettings = + self.copy(allowedHeaders = headerNames.asScala.toSet) + + def withAllowAnyOrigin(): CorsSettings = + self.copy(allowedOrigins = Set("*")) + + def withAllowedOrigins(origins: JSet[String]): CorsSettings = + self.copy(allowedOrigins = origins.asScala.toSet) + + def withAllowedMethods(methods: JSet[HttpMethod]): CorsSettings = + self.copy(allowedMethods = methods.asScala.toSet[HttpMethod].map(_.asScala)) + + def withExposedHeaders(headerNames: JSet[String]): CorsSettings = + self.copy(exposedHeaders = headerNames.asScala.toSet) + + def withAllowGenericHttpRequests(allow: Boolean): CorsSettings = + self.copy(allowGenericHttpRequests = allow) + + def withAllowCredentials(allow: Boolean): CorsSettings = + self.copy(allowCredentials = allow) +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala index c0eab1638dd..3c4680282a2 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Directives.scala @@ -38,6 +38,7 @@ trait Directives extends RouteConcatenation with WebSocketDirectives with FramedEntityStreamingDirectives with AttributeDirectives + with CorsDirectives /** * Collects all default directives into one object for simple importing. diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala index 36b9933b1ba..ae4dc13bfaf 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/Rejection.scala @@ -339,3 +339,8 @@ final case class CircuitBreakerOpenRejection(cause: CircuitBreakerOpenException) * (Custom marshallers can of course use it as well.) */ final case class RejectionError(rejection: Rejection) extends RuntimeException(rejection.toString) + +/** + * Rejection created by the CORS directives. + */ +final case class CorsRejection(description: String) extends jserver.CorsRejection with Rejection diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala index b7a10efa60b..51c4f8ff796 100644 --- a/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/RejectionHandler.scala @@ -286,6 +286,10 @@ object RejectionHandler { headers = `Sec-WebSocket-Protocol`(supported) :: Nil)) } .handle { case ValidationRejection(msg, _) => rejectRequestEntityAndComplete((BadRequest, msg)) } + .handleAll[CorsRejection] { rejections => + val causes = rejections.map(_.description).mkString(", ") + rejectRequestEntityAndComplete((BadRequest, s"CORS: $causes")) + } .handle { case x => sys.error("Unhandled rejection: " + x) } .handleNotFound { rejectRequestEntityAndComplete((NotFound, "The requested resource could not be found.")) } .result() diff --git a/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CorsDirectives.scala b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CorsDirectives.scala new file mode 100644 index 00000000000..97684a6b955 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CorsDirectives.scala @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http.scaladsl.server +package directives + +import akka.http.impl.settings.CorsSettingsImpl +import akka.http.scaladsl.model.HttpMethods.OPTIONS +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{ HttpMethod, HttpResponse, StatusCodes } +import akka.http.scaladsl.settings.CorsSettings +import akka.util.OptionVal + +/** + * Directives for CORS, cross origin requests. + * + * For an overview on how CORS works, see the MDN web docs page on CORS: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + * CORS is part of the WHATWG Fetch "Living Standard" https://fetch.spec.whatwg.org/#http-cors-protocol + * + * This implementation is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. + * + * @groupname cors CORS directives + * @groupprio cors 50 + */ +trait CorsDirectives { + + import BasicDirectives._ + import CorsDirectives._ + import RouteDirectives._ + + /** + * Wraps its inner route with support for the CORS mechanism, enabling cross origin requests using the default cors + * configuration from the actor system. + */ + def cors(): Directive0 = { + extractActorSystem.flatMap { system => + cors(CorsSettings(system)) + } + } + + /** + * Wraps its inner route with support for the CORS mechanism, enabling cross origin requests using the given cors + * settings. + */ + def cors(settings: CorsSettings): Directive0 = { + val settingsImpl = settings.asInstanceOf[CorsSettingsImpl] + + def validateOrigins(origins: Seq[HttpOrigin]): OptionVal[CorsRejection] = + if (settingsImpl.originsMatches(origins)) OptionVal.None + else OptionVal.Some(invalidOriginRejection(origins)) + + def validateMethod(method: HttpMethod): OptionVal[CorsRejection] = + if (settings.allowedMethods.contains(method)) OptionVal.None + else OptionVal.Some(invalidMethodRejection(method)) + + def validateHeaders(headers: Seq[String]): OptionVal[CorsRejection] = { + val invalidHeaders = headers.filterNot(settingsImpl.headerNameAllowed) + if (invalidHeaders.isEmpty) OptionVal.None + else OptionVal.Some(invalidHeadersRejection(invalidHeaders)) + } + + extractRequest.flatMap { request => + val origins = request.header[Origin] match { + case Some(origin) => OptionVal.Some(origin.origins) + case None => OptionVal.None + } + val method = request.header[`Access-Control-Request-Method`] match { + case Some(accessControlMethod) => OptionVal.Some(accessControlMethod.method) + case None => OptionVal.None + } + (request.method, origins, method) match { + case (OPTIONS, OptionVal.Some(origins), OptionVal.Some(requestMethod)) if origins.lengthCompare(1) <= 0 => + // pre-flight CORS request + val headers = request.header[`Access-Control-Request-Headers`] match { + case Some(header) => header.headers + case None => Seq.empty + } + + val rejections = collectRejections( + validateOrigins(origins), + validateMethod(requestMethod), + validateHeaders(headers)) + + if (rejections.isEmpty) { + complete(HttpResponse(StatusCodes.OK, settingsImpl.preflightResponseHeaders(origins, headers))) + } else { + reject(rejections: _*) + } + + case (_, OptionVal.Some(origins), OptionVal.None) => + // actual CORS request + validateOrigins(origins) match { + case OptionVal.Some(rejection) => + reject(rejection) + case _ => + mapResponseHeaders { oldHeaders => + settingsImpl.actualResponseHeaders(origins) ++ oldHeaders.filterNot(h => lcHeaderNamesToClean(h.lowercaseName)) + } + } + + case _ => + // not a valid CORS request, can be allowed through setting + if (settings.allowGenericHttpRequests) pass + else reject(malformedRejection) + } + } + } +} + +object CorsDirectives extends CorsDirectives { + private val NoRejections = Array.empty[Rejection] + + // allocation optimized collection of multiple rejections + private def collectRejections(originsRejection: OptionVal[Rejection], methodRejection: OptionVal[Rejection], headerRejection: OptionVal[Rejection]): Array[Rejection] = + if (originsRejection.isEmpty && methodRejection.isEmpty && headerRejection.isEmpty) NoRejections + else { + def count(opt: OptionVal[_]) = if (opt.isDefined) 1 else 0 + val rejections = Array.ofDim[Rejection](count(originsRejection) + count(methodRejection) + count(headerRejection)) + var idx = 0 + def addIfPresent(opt: OptionVal[Rejection]): Unit = + if (opt.isDefined) { + rejections(idx) = opt.get + idx += 1 + } + addIfPresent(originsRejection) + addIfPresent(methodRejection) + addIfPresent(headerRejection) + rejections + } + + private val lcHeaderNamesToClean: Set[String] = Set( + `Access-Control-Allow-Origin`, + `Access-Control-Expose-Headers`, + `Access-Control-Allow-Credentials`, + `Access-Control-Allow-Methods`, + `Access-Control-Allow-Headers`, + `Access-Control-Max-Age` + ).map(_.lowercaseName) + + private def malformedRejection = CorsRejection("malformed request") + + private def invalidOriginRejection(origins: Seq[HttpOrigin]) = CorsRejection(s"invalid origin '${if (origins.isEmpty) "null" else origins.mkString(" ")}'") + + private def invalidMethodRejection(method: HttpMethod) = CorsRejection(s"invalid method '${method.value}'") + + private def invalidHeadersRejection(headers: Seq[String]) = CorsRejection(s"invalid headers '${headers.mkString(" ")}'") + +} diff --git a/akka-http/src/main/scala/akka/http/scaladsl/settings/CorsSettings.scala b/akka-http/src/main/scala/akka/http/scaladsl/settings/CorsSettings.scala new file mode 100644 index 00000000000..755378664c4 --- /dev/null +++ b/akka-http/src/main/scala/akka/http/scaladsl/settings/CorsSettings.scala @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2024 Lightbend Inc. + * Copyright 2016 Lomig Mégard + */ + +package akka.http.scaladsl.settings + +import akka.actor.ClassicActorSystemProvider +import akka.annotation.{ ApiMayChange, DoNotInherit } +import akka.http.impl.settings.CorsSettingsImpl +import akka.http.scaladsl.model.HttpMethod +import com.typesafe.config.Config + +import scala.concurrent.duration.FiniteDuration + +/** + * Settings for the CORS support + * + * This implementation is based on the akka-http-cors project by Lomig Mégard, licensed under the Apache License, Version 2.0. + * + * Not for user extension + */ +@ApiMayChange @DoNotInherit +trait CorsSettings extends akka.http.javadsl.settings.CorsSettings { self: CorsSettingsImpl => + + /** + * Allow generic requests, that are outside the scope of the specification, for example lacking an `Origin` header + * to pass through the directive. + * + * When false strict CORS filtering is applied and any invalid request will be rejected. + */ + override def allowGenericHttpRequests: Boolean + + /** + * If enabled, the header `Access-Control-Allow-Credentials` + * is included in the response, indicating that the actual request can include user credentials. Examples of user + * credentials are: cookies, HTTP authentication or client-side certificates. + */ + override def allowCredentials: Boolean + + /** + * List of origins that the CORS filter must allow. + * + * Can also be set to a single `*` to allow access to the resource from any origin. + * + * Controls the content of the `Access-Control-Allow-Origin` response header: if parameter is `*` and + * credentials are not allowed, a `*` is returned in `Access-Control-Allow-Origin`. Otherwise, the origins given in the + * `Origin` request header are echoed. + * + * Hostname starting with `*.` will match any sub-domain. The scheme and the port are always strictly matched. + * + * The actual or preflight request is rejected if any of the origins from the request is not allowed. + */ + override def allowedOrigins: Set[String] + + /** + * Set of request headers that are allowed when making an actual request. + * + * Controls the content of the `Access-Control-Allow-Headers` header in a preflight response: If set to a single `*`, + * the headers from `Access-Control-Request-Headers` are echoed. Otherwise specified list of header names is returned + * as part of the header. + */ + override def allowedHeaders: Set[String] + + /** + * List of methods allowed when making an actual request. The listed headers are returned as part of the + * `Access-Control-Allow-Methods` preflight response header. + * + * The preflight request will be rejected if the `Access-Control-Request-Method` header's method is not part of the + * list. + */ + override def allowedMethods: Set[HttpMethod] + + /** + * Set of headers (other than simple response headers) that browsers are allowed to access. If not empty, the listed + * headers are returned as part of the `Access-Control-Expose-Headers` header in responses. + */ + override def exposedHeaders: Set[String] + + /** + * The time the browser is allowed to cache the results of a preflight request. This value is + * returned as part of the `Access-Control-Max-Age` preflight response header. If `scala.concurrent.duration.Duration.Zero`, + * the header is not added to the preflight response. + */ + override def maxAge: FiniteDuration + + def withMaxAge(maxAge: FiniteDuration): CorsSettings = + self.copy(maxAge = maxAge) + + override def withAllowAnyOrigin(): CorsSettings = + self.copy(allowedOrigins = Set("*")) + + def withAllowedOrigins(origins: Set[String]): CorsSettings = + self.copy(allowedOrigins = origins) + + override def withAllowAnyHeader(): CorsSettings = + self.copy(allowedHeaders = Set("*")) + def withAllowedHeaders(headerNames: Set[String]): CorsSettings = + self.copy(allowedHeaders = headerNames) + def withAllowedMethods(methods: Set[HttpMethod]): CorsSettings = + self.copy(allowedMethods = methods) + + def withExposedHeaders(headerNames: Set[String]): CorsSettings = + self.copy(exposedHeaders = headerNames) + + override def withAllowGenericHttpRequests(allow: Boolean): CorsSettings = + self.copy(allowGenericHttpRequests = allow) + + override def withAllowCredentials(allow: Boolean): CorsSettings = + self.copy(allowCredentials = allow) + +} + +object CorsSettings { + def apply(system: ClassicActorSystemProvider): CorsSettings = + CorsSettingsImpl(system.classicSystem) + def apply(config: Config): CorsSettings = + CorsSettingsImpl(config) +} diff --git a/docs/src/main/paradox/compatibility-guidelines.md b/docs/src/main/paradox/compatibility-guidelines.md index 079fe9594c8..605f220cf77 100644 --- a/docs/src/main/paradox/compatibility-guidelines.md +++ b/docs/src/main/paradox/compatibility-guidelines.md @@ -29,6 +29,8 @@ Scala akka.http.scaladsl.unmarshalling.sse.EventStreamUnmarshalling akka.http.scaladsl.OutgoingConnectionBuilder#managedPersistentHttp2 akka.http.scaladsl.OutgoingConnectionBuilder#managedPersistentHttp2WithPriorKnowledge + akka.http.scaladsl.settings.ServerSentEventSettings + akka.http.scaladsl.settings.CorsSettings ``` Java @@ -43,6 +45,8 @@ Java akka.http.javadsl.model.RequestResponseAssociation akka.http.javadsl.OutgoingConnectionBuilder#managedPersistentHttp2WithPriorKnowledge akka.http.javadsl.OutgoingConnectionBuilder#managedPersistentHttp2 + akka.http.javadsl.settings.ServerSentEventSettings + akka.http.javadsl.settings.CorsSettings ``` #### akka-http-caching @@ -82,7 +86,6 @@ Scala akka.http.scaladsl.settings.Http2ServerSettings akka.http.scaladsl.settings.Http2ClientSettings akka.http.scaladsl.settings.PreviewServerSettings - akka.http.scaladsl.settings.ServerSentEventSettings akka.http.scaladsl.model.headers.CacheDirectives.immutableDirective akka.http.scaladsl.model.headers.X-Forwarded-Host akka.http.scaladsl.model.headers.X-Forwarded-Proto @@ -107,7 +110,6 @@ Java akka.http.javadsl.settings.ConnectionPoolSettings#withResponseEntitySubscriptionTimeout akka.http.javadsl.settings.PoolImplementation akka.http.javadsl.settings.PreviewServerSettings - akka.http.javadsl.settings.ServerSentEventSettings ``` ## Versioning and Compatibility diff --git a/docs/src/main/paradox/migration-guide/migration-guide-10.6.x.md b/docs/src/main/paradox/migration-guide/migration-guide-10.6.x.md index eb7c3239421..1f5a1258cd5 100644 --- a/docs/src/main/paradox/migration-guide/migration-guide-10.6.x.md +++ b/docs/src/main/paradox/migration-guide/migration-guide-10.6.x.md @@ -45,3 +45,15 @@ Akka HTTP 10.6.x requires Akka version >= 2.9.0. The Jackson dependency has been updated to 2.15.2 in Akka HTTP 10.6.0. That bump includes many fixes and changes to Jackson, but it should not introduce any incompatibility in serialized format. + +### Built in CORS support + +Built in directives with CORS support @ref[has been added](../routing-dsl/directives/cors-directives/cors.md) heavily inspired +by the pre-existing community library [akka-http-cors](https://github.com/lomigmegard/akka-http-cors). + +Directive API and configuration are similar and migrating should be straightforward. Some of the lower level APIs for implementing +CORS that the library gave access to (`HttpOriginMatcher`) and the `CorsRejection` implementation +is simplified or not available as public API in the new Akka HTTP CORS implementation. + +The new configuration namespace is `akka.http.cors` instead of `akka-http-cors`, the individual setting names are the same +however `allowed-origins`, `allowed-headers` are always lists of values with a single `["*"]` to represent match-any. \ No newline at end of file diff --git a/docs/src/main/paradox/routing-dsl/directives/alphabetically.md b/docs/src/main/paradox/routing-dsl/directives/alphabetically.md index 750dbbe199e..a2cb65a3ffa 100644 --- a/docs/src/main/paradox/routing-dsl/directives/alphabetically.md +++ b/docs/src/main/paradox/routing-dsl/directives/alphabetically.md @@ -25,6 +25,7 @@ |@ref[completeWith](marshalling-directives/completeWith.md) | Uses the marshaller for a given type to extract a completion function | |@ref[conditional](cache-condition-directives/conditional.md) | Wraps its inner route with support for conditional requests as defined by [RFC 7232](https://tools.ietf.org/html/rfc7232) | |@ref[cookie](cookie-directives/cookie.md) | Extracts the @apidoc[HttpCookie] with the given name | +|@ref[cors](cors-directives/cors.md) | Wrapps its inner route with CORS handling | |@ref[decodeRequest](coding-directives/decodeRequest.md) | Decompresses the request if it is `gzip` or `deflate` compressed | |@ref[decodeRequestWith](coding-directives/decodeRequestWith.md) | Decodes the incoming request using one of the given decoders | |@ref[delete](method-directives/delete.md) | Rejects all non-DELETE requests | diff --git a/docs/src/main/paradox/routing-dsl/directives/by-trait.md b/docs/src/main/paradox/routing-dsl/directives/by-trait.md index 70c5c44b041..4935fcaaa56 100644 --- a/docs/src/main/paradox/routing-dsl/directives/by-trait.md +++ b/docs/src/main/paradox/routing-dsl/directives/by-trait.md @@ -89,6 +89,7 @@ All predefined directives are organized into traits that form one part of the ov * [caching-directives/index](caching-directives/index.md) * [coding-directives/index](coding-directives/index.md) * [cookie-directives/index](cookie-directives/index.md) +* [cors-directives/index](cors-directives/index.md) * [debugging-directives/index](debugging-directives/index.md) * [execution-directives/index](execution-directives/index.md) * [file-and-resource-directives/index](file-and-resource-directives/index.md) diff --git a/docs/src/main/paradox/routing-dsl/directives/cors-directives/cors.md b/docs/src/main/paradox/routing-dsl/directives/cors-directives/cors.md new file mode 100644 index 00000000000..0834f369cfa --- /dev/null +++ b/docs/src/main/paradox/routing-dsl/directives/cors-directives/cors.md @@ -0,0 +1,33 @@ +# cors + +@@@ div { .group-scala } + +## Signature + +@@signature [CorsDirectives.scala](/akka-http/src/main/scala/akka/http/scaladsl/server/directives/CorsDirectives.scala) { #cors } + +@@@ + +## Description + +CORS (Cross Origin Resource Sharing) is a mechanism to enable cross origin requests by informing browsers about origins +other than the server itself that the browser can load resources from via HTTP headers. + +For an overview on how CORS works, see the [MDN web docs page on CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) + +The directive uses config defined under `akka.http.cors`, or an explicitly provided `CorsSettings` instance. + +## Example + +The `cors` directive will provide a pre-flight `OPTIONS` handler and let other requests through to the inner route: + +Scala +: @@snip [CorsDirectivesExamplesSpec.scala](/docs/src/test/scala/docs/http/scaladsl/server/directives/CorsDirectivesExamplesSpec.scala) { #cors } + +Java +: @@snip [CorsDirectivesExamplesTest.java](/docs/src/test/java/docs/http/javadsl/server/directives/CorsDirectivesExamplesTest.java) { #cors } + + +## Reference configuration + +@@snip [reference.conf](/akka-http/src/main/resources/reference.conf) { #cors } \ No newline at end of file diff --git a/docs/src/main/paradox/routing-dsl/directives/cors-directives/index.md b/docs/src/main/paradox/routing-dsl/directives/cors-directives/index.md new file mode 100644 index 00000000000..8eaec7a9691 --- /dev/null +++ b/docs/src/main/paradox/routing-dsl/directives/cors-directives/index.md @@ -0,0 +1,9 @@ +# CorsDirectives + +@@toc { depth=1 } + +@@@ index + +* [cors](cors.md) + +@@@ \ No newline at end of file diff --git a/docs/src/test/java/docs/http/javadsl/server/directives/CorsDirectivesExamplesTest.java b/docs/src/test/java/docs/http/javadsl/server/directives/CorsDirectivesExamplesTest.java new file mode 100644 index 00000000000..01ce9856590 --- /dev/null +++ b/docs/src/test/java/docs/http/javadsl/server/directives/CorsDirectivesExamplesTest.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2023 Lightbend Inc. + */ + +package docs.http.javadsl.server.directives; + +import akka.http.javadsl.model.HttpMethods; +import akka.http.javadsl.model.HttpRequest; +import akka.http.javadsl.model.headers.*; +import akka.http.javadsl.server.Rejections; +import akka.http.javadsl.server.Route; +import akka.http.javadsl.testkit.JUnitRouteTest; +import akka.http.javadsl.model.StatusCodes; +import org.junit.Test; + +public class CorsDirectivesExamplesTest extends JUnitRouteTest { + + @Test + public void cors() { + //#cors + final Route route = cors(() -> + complete(StatusCodes.OK) + ); + + // tests: + // preflight + HttpOrigin exampleOrigin = HttpOrigin.parse("http://example.com"); + testRoute(route).run(HttpRequest.OPTIONS("/") + .addHeader(Origin.create(exampleOrigin)) + .addHeader(AccessControlRequestMethod.create(HttpMethods.GET))) + .assertStatusCode(StatusCodes.OK) + .assertHeaderExists(AccessControlAllowOrigin.create(HttpOriginRange.create(exampleOrigin))) + .assertHeaderExists(AccessControlAllowMethods.create(HttpMethods.GET, HttpMethods.POST, HttpMethods.HEAD, HttpMethods.OPTIONS)) + .assertHeaderExists(AccessControlMaxAge.create(1800)) + .assertHeaderExists(AccessControlAllowCredentials.create(true)); + + // regular call + runRouteUnSealed(route, HttpRequest.GET("/") + .addHeader(Origin.create(exampleOrigin))) + .assertStatusCode(StatusCodes.OK) + .assertHeaderExists(AccessControlAllowOrigin.create(HttpOriginRange.create(exampleOrigin))) + .assertHeaderExists(AccessControlAllowCredentials.create(true)); + //#cors + } + +} diff --git a/docs/src/test/scala/docs/http/scaladsl/server/directives/CorsDirectivesExamplesSpec.scala b/docs/src/test/scala/docs/http/scaladsl/server/directives/CorsDirectivesExamplesSpec.scala new file mode 100644 index 00000000000..4a9e738bb32 --- /dev/null +++ b/docs/src/test/scala/docs/http/scaladsl/server/directives/CorsDirectivesExamplesSpec.scala @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2009-2023 Lightbend Inc. + */ + +package docs.http.scaladsl.server.directives + +import akka.http.scaladsl.model.StatusCodes._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.{ HttpOrigin, Origin, `Access-Control-Allow-Credentials`, `Access-Control-Allow-Methods`, `Access-Control-Allow-Origin`, `Access-Control-Max-Age`, `Access-Control-Request-Method` } +import akka.http.scaladsl.server.{ Route, RoutingSpec } +import docs.CompileOnlySpec + +class CorsDirectivesExamplesSpec extends RoutingSpec with CompileOnlySpec { + "cors" in { + //#cors + val route = + cors() { + complete(Ok) + } + + // tests: + + // preflight + Options() ~> Origin(HttpOrigin("http://example.com")) ~> `Access-Control-Request-Method`(HttpMethods.GET) ~> route ~> check { + status shouldBe StatusCodes.OK + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(HttpOrigin("http://example.com")), + `Access-Control-Allow-Methods`(HttpMethods.GET, HttpMethods.POST, HttpMethods.HEAD, HttpMethods.OPTIONS), + `Access-Control-Max-Age`(1800), + `Access-Control-Allow-Credentials`(true) + ) + } + + // regular request + Get() ~> Origin(HttpOrigin("http://example.com")) ~> route ~> check { + status shouldEqual OK + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`(HttpOrigin("http://example.com")), + `Access-Control-Allow-Credentials`(true) + ) + } + + //#cors + } +} diff --git a/project/CopyrightHeader.scala b/project/CopyrightHeader.scala index 633d2df7f38..f0d863cfadc 100644 --- a/project/CopyrightHeader.scala +++ b/project/CopyrightHeader.scala @@ -34,6 +34,7 @@ object CopyrightHeader extends AutoPlugin { // We hard-code this so PR's created in year X will not suddenly in X+1. // Of course we should remember to update it early in the year. val CurrentYear = "2023" + val AlsoOkYear = "2024" // until we bump the above val CopyrightPattern = "Copyright \\([Cc]\\) (\\d{4}(-\\d{4})?) (Lightbend|Typesafe) Inc. <.*>".r val CopyrightHeaderPattern = s"(?s).*${CopyrightPattern}.*".r @@ -45,7 +46,7 @@ object CopyrightHeader extends AutoPlugin { def updateLightbendHeader(header: String): String = header match { case CopyrightHeaderPattern(years, null, _) => - if (years != CurrentYear) + if (years != CurrentYear && years != AlsoOkYear) CopyrightPattern.replaceFirstIn(header, headerFor(years + "-" + CurrentYear)) else CopyrightPattern.replaceFirstIn(header, headerFor(years))