Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: Permit autogenerating types aliases for endpoints #4213

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ object GenScala {
private val streamingImplementationOpt: Opts[Option[String]] =
Opts.option[String]("streamingImplementation", "Capability to use for binary streams", "s").orNone

private val generateEndpointTypesOpt: Opts[Boolean] =
Opts.flag("generateEndpointTypes", "Whether to emit explicit type aliases for endpoint declarations", "e").orFalse

private val destDirOpt: Opts[File] =
Opts
.option[String]("destdir", "Destination directory", "d")
Expand All @@ -88,7 +91,8 @@ object GenScala {
jsonLibOpt,
validateNonDiscriminatedOneOfsOpt,
maxSchemasPerFileOpt,
streamingImplementationOpt
streamingImplementationOpt,
generateEndpointTypesOpt
)
.mapN {
case (
Expand All @@ -101,7 +105,8 @@ object GenScala {
jsonLib,
validateNonDiscriminatedOneOfs,
maxSchemasPerFile,
streamingImplementation
streamingImplementation,
generateEndpointTypes
) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

Expand All @@ -116,7 +121,8 @@ object GenScala {
jsonLib.getOrElse("circe"),
streamingImplementation.getOrElse("fs2"),
validateNonDiscriminatedOneOfs,
maxSchemasPerFile.getOrElse(400)
maxSchemasPerFile.getOrElse(400),
generateEndpointTypes
)
)
destFiles <- contents.toVector.traverse { case (fileName, content) => writeGeneratedFile(destDir, fileName, content) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ object StreamingImplementation extends Enumeration {
val Akka, FS2, Pekko, Zio = Value
type StreamingImplementation = Value
}
object EndpointCapabilites extends Enumeration {
val Akka, FS2, Nothing, Pekko, Zio = Value
type EndpointCapabilites = Value
}

object BasicGenerator {

Expand All @@ -40,7 +44,8 @@ object BasicGenerator {
jsonSerdeLib: String,
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int
maxSchemasPerFile: Int,
generateEndpointTypes: Boolean
): Map[String, String] = {
val normalisedJsonLib = jsonSerdeLib.toLowerCase match {
case "circe" => JsonSerdeLib.Circe
Expand All @@ -65,7 +70,14 @@ object BasicGenerator {
}

val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib, normalisedStreamingImplementation)
endpointGenerator.endpointDefs(
doc,
useHeadTagForObjectNames,
targetScala3,
normalisedJsonLib,
normalisedStreamingImplementation,
generateEndpointTypes
)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
jsonSerdeLib = jsonSerdeLib,
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
streamingImplementation = "fs2",
generateEndpointTypes = false
)
}
def gen(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
streamingImplementation = StreamingImplementation.FS2,
generateEndpointTypes = false
)
.endpointDecls(None)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
streamingImplementation = StreamingImplementation.FS2,
generateEndpointTypes = false
)
.endpointDecls(None)
generatedCode should include("val getTestAsdId =")
Expand Down Expand Up @@ -153,7 +154,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
streamingImplementation = StreamingImplementation.FS2,
generateEndpointTypes = false
)
.endpointDecls(None) shouldCompile ()
}
Expand Down Expand Up @@ -205,7 +207,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
streamingImplementation = StreamingImplementation.FS2,
generateEndpointTypes = false
)
.endpointDecls(None)
generatedCode should include(
Expand Down Expand Up @@ -272,7 +275,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
jsonSerdeLib = "circe",
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
streamingImplementation = "fs2",
generateEndpointTypes = false
)("TapirGeneratedEndpoints")
generatedCode should include(
"""file: sttp.model.Part[java.io.File]"""
Expand All @@ -294,7 +298,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
jsonSerdeLib = "circe",
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
streamingImplementation = "fs2",
generateEndpointTypes = false
)("TapirGeneratedEndpoints")
generatedCode shouldCompile ()
val expectedAttrDecls = Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ case class OpenApiConfiguration(
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int,
generateEndpointTypes: Boolean,
additionalPackages: List[(String, File)]
)

Expand All @@ -27,6 +28,7 @@ trait OpenapiCodegenKeys {
lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file")
lazy val openapiAdditionalPackages = settingKey[List[(String, File)]]("Addition package -> spec mappings to generate.")
lazy val openapiStreamingImplementation = settingKey[String]("Implementation for streamTextBody. Supports: akka, fs2, pekko, zio.")
lazy val openapiGenerateEndpointTypes = settingKey[Boolean]("Whether to emit explicit types for endpoint denfs")
lazy val openapiOpenApiConfiguration =
settingKey[OpenApiConfiguration]("Aggregation of other settings. Manually set value will be disregarded.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
openapiStreamingImplementation.value,
openapiValidateNonDiscriminatedOneOfs.value,
openapiMaxSchemasPerFile.value,
openapiGenerateEndpointTypes.value,
openapiAdditionalPackages.value
)
def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq(
Expand All @@ -44,6 +45,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
openapiMaxSchemasPerFile := 400,
openapiAdditionalPackages := Nil,
openapiStreamingImplementation := "fs2",
openapiGenerateEndpointTypes := false,
standardParamSetting
)

Expand Down Expand Up @@ -73,6 +75,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
c.streamingImplementation,
c.validateNonDiscriminatedOneOfs,
c.maxSchemasPerFile,
c.generateEndpointTypes,
srcDir,
taskStreams.cacheDirectory,
sv.startsWith("3"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ case class OpenapiCodegenTask(
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int,
generateEndpointTypes: Boolean,
dir: File,
cacheDir: File,
targetScala3: Boolean,
Expand Down Expand Up @@ -59,7 +60,8 @@ case class OpenapiCodegenTask(
jsonSerdeLib,
streamingImplementation,
validateNonDiscriminatedOneOfs,
maxSchemasPerFile
maxSchemasPerFile,
generateEndpointTypes
)
.map { case (objectName, fileBody) =>
val file = directory / s"$objectName.scala"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ object TapirGeneratedEndpoints {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}


case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] {
// Case-insensitive mapping
def decode(s: String): sttp.tapir.DecodeResult[T] =
Expand All @@ -63,9 +62,16 @@ object TapirGeneratedEndpoints {
}
def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] =
EnumExtraParamSupport(enumName, T)
sealed trait Error
sealed trait ADTWithoutDiscriminator
sealed trait ADTWithDiscriminator
sealed trait ADTWithDiscriminatorNoMapping
case class SimpleError (
message: String
) extends Error
case class NotFoundError (
reason: String
) extends Error
case class SubtypeWithoutD1 (
s: String,
i: Option[Int] = None,
Expand Down Expand Up @@ -119,34 +125,39 @@ object TapirGeneratedEndpoints {



lazy val getBinaryTest =
type GetBinaryTestEndpoint = Endpoint[Unit, Unit, Unit, sttp.capabilities.pekko.PekkoStreams.BinaryStream, sttp.capabilities.pekko.PekkoStreams]
lazy val getBinaryTest: GetBinaryTestEndpoint =
endpoint
.get
.in(("binary" / "test"))
.out(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream()).description("Response CSV body"))

lazy val postBinaryTest =
type PostBinaryTestEndpoint = Endpoint[Unit, sttp.capabilities.pekko.PekkoStreams.BinaryStream, Unit, String, sttp.capabilities.pekko.PekkoStreams]
lazy val postBinaryTest: PostBinaryTestEndpoint =
endpoint
.post
.in(("binary" / "test"))
.in(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream()))
.out(jsonBody[String].description("successful operation"))

lazy val putAdtTest =
type PutAdtTestEndpoint = Endpoint[Unit, ADTWithoutDiscriminator, Unit, ADTWithoutDiscriminator, Any]
lazy val putAdtTest: PutAdtTestEndpoint =
endpoint
.put
.in(("adt" / "test"))
.in(jsonBody[ADTWithoutDiscriminator])
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))

lazy val postAdtTest =
type PostAdtTestEndpoint = Endpoint[Unit, ADTWithDiscriminatorNoMapping, Unit, ADTWithDiscriminator, Any]
lazy val postAdtTest: PostAdtTestEndpoint =
endpoint
.post
.in(("adt" / "test"))
.in(jsonBody[ADTWithDiscriminatorNoMapping])
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))

lazy val postInlineEnumTest =
type PostInlineEnumTestEndpoint = Endpoint[Unit, (PostInlineEnumTestQueryEnum, Option[PostInlineEnumTestQueryOptEnum], List[PostInlineEnumTestQuerySeqEnum], Option[List[PostInlineEnumTestQueryOptSeqEnum]], ObjectWithInlineEnum), Unit, Unit, Any]
lazy val postInlineEnumTest: PostInlineEnumTestEndpoint =
endpoint
.post
.in(("inline" / "enum" / "test"))
Expand Down Expand Up @@ -197,7 +208,14 @@ object TapirGeneratedEndpoints {
extraCodecSupport[PostInlineEnumTestQueryOptSeqEnum]("PostInlineEnumTestQueryOptSeqEnum", PostInlineEnumTestQueryOptSeqEnum)
}

type GetOneofErrorTestEndpoint = Endpoint[Unit, Unit, Error, Unit, Any]
lazy val getOneofErrorTest: GetOneofErrorTestEndpoint =
endpoint
.get
.in(("oneof" / "error" / "test"))
.errorOut(oneOf[Error](oneOfVariant(sttp.model.StatusCode(404), jsonBody[NotFoundError].description("Not found")), oneOfVariant(sttp.model.StatusCode(400), jsonBody[SimpleError].description("Not found"))))
.out(statusCode(sttp.model.StatusCode(204)).description("No response"))

lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest)
lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest, getOneofErrorTest)

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ object TapirGeneratedEndpointsJsonSerdes {
}
} yield res
}
implicit lazy val simpleErrorJsonDecoder: io.circe.Decoder[SimpleError] = io.circe.generic.semiauto.deriveDecoder[SimpleError]
implicit lazy val simpleErrorJsonEncoder: io.circe.Encoder[SimpleError] = io.circe.generic.semiauto.deriveEncoder[SimpleError]
implicit lazy val notFoundErrorJsonDecoder: io.circe.Decoder[NotFoundError] = io.circe.generic.semiauto.deriveDecoder[NotFoundError]
implicit lazy val notFoundErrorJsonEncoder: io.circe.Encoder[NotFoundError] = io.circe.generic.semiauto.deriveEncoder[NotFoundError]
implicit lazy val subtypeWithoutD1JsonDecoder: io.circe.Decoder[SubtypeWithoutD1] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithoutD1]
implicit lazy val subtypeWithoutD1JsonEncoder: io.circe.Encoder[SubtypeWithoutD1] = io.circe.generic.semiauto.deriveEncoder[SubtypeWithoutD1]
implicit lazy val subtypeWithD1JsonDecoder: io.circe.Decoder[SubtypeWithD1] = io.circe.generic.semiauto.deriveDecoder[SubtypeWithD1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ object TapirGeneratedEndpointsSchemas {
import sttp.tapir.generated.TapirGeneratedEndpoints._
import sttp.tapir.generic.auto._
implicit lazy val anEnumTapirSchema: sttp.tapir.Schema[AnEnum] = sttp.tapir.Schema.derived
implicit lazy val notFoundErrorTapirSchema: sttp.tapir.Schema[NotFoundError] = sttp.tapir.Schema.derived
implicit lazy val objectWithInlineEnumInlineEnumTapirSchema: sttp.tapir.Schema[ObjectWithInlineEnumInlineEnum] = sttp.tapir.Schema.derived
implicit lazy val objectWithInlineEnumTapirSchema: sttp.tapir.Schema[ObjectWithInlineEnum] = sttp.tapir.Schema.derived
implicit lazy val simpleErrorTapirSchema: sttp.tapir.Schema[SimpleError] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithD1TapirSchema: sttp.tapir.Schema[SubtypeWithD1] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithD2TapirSchema: sttp.tapir.Schema[SubtypeWithD2] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithoutD1TapirSchema: sttp.tapir.Schema[SubtypeWithoutD1] = sttp.tapir.Schema.derived
Expand Down Expand Up @@ -38,6 +40,7 @@ object TapirGeneratedEndpointsSchemas {
case _ => throw new IllegalStateException("Derived schema for ADTWithDiscriminatorNoMapping should be a coproduct")
}
}
implicit lazy val errorTapirSchema: sttp.tapir.Schema[Error] = sttp.tapir.Schema.derived
implicit lazy val subtypeWithoutD3TapirSchema: sttp.tapir.Schema[SubtypeWithoutD3] = sttp.tapir.Schema.derived
implicit lazy val aDTWithoutDiscriminatorTapirSchema: sttp.tapir.Schema[ADTWithoutDiscriminator] = sttp.tapir.Schema.derived
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ lazy val root = (project in file("."))
.settings(
scalaVersion := "2.13.15",
version := "0.1",
openapiStreamingImplementation := "pekko"
openapiStreamingImplementation := "pekko",
openapiGenerateEndpointTypes := true
)

libraryDependencies ++= Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,24 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/ObjectWithInlineEnum'
'/oneof/error/test':
get:
responses:
"204":
description: "No response"
"404":
description: Not found
content:
application/json:
schema:
$ref: '#/components/schemas/NotFoundError'
default:
description: Not found
content:
application/json:
schema:
$ref: '#/components/schemas/SimpleError'


components:
schemas:
Expand Down Expand Up @@ -247,4 +265,26 @@ components:
- foo1
- foo2
- foo3
- foo4
- foo4
Error:
title: Error
type: object
oneOf:
- $ref: '#/components/schemas/NotFoundError'
- $ref: '#/components/schemas/SimpleError'
NotFoundError:
title: NotFoundError
required:
- reason
type: object
properties:
reason:
type: string
SimpleError:
title: SimpleError
required:
- message
type: object
properties:
message:
type: string
Loading