Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an encoded discriminator value attribute for coproducts, use it to render const constraints #3955

Merged
merged 6 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions core/src/main/scala/sttp/tapir/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,26 @@ object Schema extends LowPrioritySchema with SchemaCompanionMacros {
val Attribute: AttributeKey[Tuple] = new AttributeKey[Tuple]("sttp.tapir.Schema.Tuple")
}

/** For coproduct schemas, when there's a discriminator field, used to attach the encoded value of the discriminator field. Such value is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coproduct may be also rendererd without a discriminator field (Json object nesting). We need to make sure this case is covered as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this applies only to discriminators. I think if there's no discriminator, then we don't have to do anything, that is we shouldn't add any constraints?

* added to the discriminator field schemas in each of the coproduct's subtypes. When rendering OpenAPI/JSON schema, these values are
* converted to `const` constraints on fields.
*/
case class EncodedDiscriminatorValue(v: String)
object EncodedDiscriminatorValue {
/*
Implementation note: the discriminator value constraint is in fact an enum validator with a single possible enum value. Hence an
alternative design would be to add such validators to discriminator fields, instead of an attribute. However, this has two drawbacks:
1. when adding discriminator fields using `addDiscriminatorField`, we don't have access to the decoded discriminator value - only
to the encoded one, via reverse mapping lookup
2. the validator doesn't necessarily make sense, as it can't be used to validate the deserialiszd object. Usually the discriminator
fields don't even exist on the high-level representations.
That's why instead of re-using the validators, we decided to use a specialised attribute.
*/

val Attribute: AttributeKey[EncodedDiscriminatorValue] =
new AttributeKey[EncodedDiscriminatorValue]("sttp.tapir.Schema.EncodedDiscriminatorValue")
}

/** @param typeParameterShortNames
* full name of type parameters, name is legacy and kept only for backward compatibility
*/
Expand Down
34 changes: 31 additions & 3 deletions core/src/main/scala/sttp/tapir/SchemaType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,39 @@ object SchemaType {
discriminatorSchema: Schema[D] = Schema.string,
discriminatorMapping: Map[String, SRef[_]] = Map.empty
): SCoproduct[T] = {
// used to add encoded discriminator value attributes
val reverseDiscriminatorByNameMapping: Map[SName, String] = discriminatorMapping.toList.map { case (v, ref) => (ref.name, v) }.toMap

SCoproduct(
subtypes.map {
case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _)
if st.fields.forall(_.name != discriminatorName) =>
s.copy(schemaType = st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None)))
case s @ Schema(st: SchemaType.SProduct[Any @unchecked], _, _, _, _, _, _, _, _, _, _) =>
// first, ensuring that the discriminator field is added to the schema type - it might already be present
var targetSt =
if (st.fields.forall(_.name != discriminatorName))
st.copy(fields = st.fields :+ SProductField[Any, D](discriminatorName, discriminatorSchema, _ => None))
else st

// next, modifying the discriminator field, by adding the value attribute (if a value can be found)
targetSt = targetSt.copy(fields = targetSt.fields.map { field =>
if (field.name == discriminatorName) {
val discriminatorValue = s.name.flatMap { subtypeName =>
reverseDiscriminatorByNameMapping.get(subtypeName)
}

discriminatorValue match {
case Some(v) =>
SProductField(
field.name,
field.schema.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(v)),
field.get
)
case None => field
}

} else field
})

s.copy(schemaType = targetSt)
case s => s
},
Some(SDiscriminator(discriminatorName, discriminatorMapping))
Expand Down
9 changes: 8 additions & 1 deletion core/src/test/scala/sttp/tapir/SchemaMacroTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,14 @@ class SchemaMacroTest extends AnyFlatSpec with Matchers with TableDrivenProperty

schemaType.subtypes.foreach { childSchema =>
val childProduct = childSchema.schemaType.asInstanceOf[SProduct[_]]
childProduct.fields.find(_.name.name == "kind") shouldBe Some(SProductField(FieldName("kind"), Schema.string, (_: Any) => None))
val discValue = if (childSchema.name.get.fullName == "sttp.tapir.SchemaMacroTestData.User") "user" else "org"
childProduct.fields.find(_.name.name == "kind") shouldBe Some(
SProductField(
FieldName("kind"),
Schema.string.attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue(discValue)),
(_: Any) => None
)
)
}
}

Expand Down
18 changes: 15 additions & 3 deletions core/src/test/scala/sttp/tapir/generic/SchemaGenericAutoTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,13 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {
schemaType.asInstanceOf[SCoproduct[Entity]].subtypes should contain theSameElementsAs List(
Schema(
SProduct[Organization](
List(field(FieldName("name"), Schema(SString())), field(FieldName("who_am_i"), Schema(SString())))
List(
field(FieldName("name"), Schema(SString())),
field(
FieldName("who_am_i"),
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Organization"))
)
)
),
Some(SName("sttp.tapir.generic.Organization"))
),
Expand All @@ -254,15 +260,21 @@ class SchemaGenericAutoTest extends AsyncFlatSpec with Matchers {
List(
field(FieldName("first"), Schema(SString())),
field(FieldName("age"), Schema(SInteger(), format = Some("int32"))),
field(FieldName("who_am_i"), Schema(SString()))
field(
FieldName("who_am_i"),
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("Person"))
)
)
),
Some(SName("sttp.tapir.generic.Person"))
),
Schema(
SProduct[UnknownEntity.type](
List(
field(FieldName("who_am_i"), Schema(SString()))
field(
FieldName("who_am_i"),
Schema(SString()).attribute(Schema.EncodedDiscriminatorValue.Attribute, Schema.EncodedDiscriminatorValue("UnknownEntity"))
)
)
),
Some(SName("sttp.tapir.generic.UnknownEntity"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ private[docs] class TSchemaToASchema(
// The primary motivation for using schema name as fallback title is to improve Swagger UX with
// `oneOf` schemas in OpenAPI 3.1. See https://github.com/softwaremill/tapir/issues/3447 for details.
def fallbackTitle = tschema.name.map(fallbackSchemaTitle)

val const = tschema.attribute(TSchema.EncodedDiscriminatorValue.Attribute).map(_.v).map(v => ExampleSingleValue(v))

oschema
.copy(title = titleFromAttr orElse fallbackTitle)
.copy(title = titleFromAttr.orElse(fallbackTitle))
.copy(uniqueItems = tschema.attribute(UniqueItems.Attribute).map(_.uniqueItems))
.copy(const = const)
}

private def addMetadata(oschema: ASchema, tschema: TSchema[_]): ASchema = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
asyncapi: 2.6.0
info:
title: discriminator
version: '1.0'
channels:
/animals:
subscribe:
operationId: onAnimals
message:
$ref: '#/components/messages/Animal'
publish:
operationId: sendAnimals
message:
$ref: '#/components/messages/GetAnimal'
bindings:
ws:
method: GET
components:
schemas:
GetAnimal:
title: GetAnimal
type: object
required:
- name
properties:
name:
type: string
Animal:
title: Animal
oneOf:
- $ref: '#/components/schemas/Cat'
- $ref: '#/components/schemas/Dog'
discriminator: pet
Cat:
title: Cat
type: object
required:
- name
- pet
properties:
name:
type: string
pet:
type: string
const: Cat
Dog:
title: Dog
type: object
required:
- name
- breed
- pet
properties:
name:
type: string
breed:
type: string
pet:
type: string
const: Dog
messages:
GetAnimal:
payload:
$ref: '#/components/schemas/GetAnimal'
contentType: application/json
Animal:
payload:
$ref: '#/components/schemas/Animal'
contentType: application/json
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers {
.out(
webSocketBody[Fruit, CodecFormat.Json, Int, CodecFormat.Json](AkkaStreams)
// TODO: missing `RequestInfo.example(example: EndpointIO.Example)` and friends
.pipe(e => e.copy(requestsInfo = e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple"))))
.pipe(e =>
e.copy(requestsInfo =
e.requestsInfo.example(Example.of(Fruit("apple")).name("Apple").summary("Sample representation of apple"))
)
)
)

val expectedYaml = loadYaml("expected_json_example_name_summary.yml")
Expand Down Expand Up @@ -232,6 +236,22 @@ class VerifyAsyncAPIYamlTest extends AnyFunSuite with Matchers {
noIndentation(yaml) shouldBe loadYaml("expected_flags_header.yml")
}

test("should work with discriminators") {
case class GetAnimal(name: String)
sealed trait Animal
case class Cat(name: String) extends Animal
case class Dog(name: String, breed: String) extends Animal
implicit val configuration: sttp.tapir.generic.Configuration = sttp.tapir.generic.Configuration.default.withDiscriminator("pet")

val animalEndpoint = endpoint.get
.in("animals")
.out(webSocketBody[GetAnimal, CodecFormat.Json, Animal, CodecFormat.Json](AkkaStreams))

val yaml = AsyncAPIInterpreter().toAsyncAPI(animalEndpoint, "discriminator", "1.0").toYaml

noIndentation(yaml) shouldBe loadYaml("expected_coproduct_with_discriminator.yml")
}

private def loadYaml(fileName: String): String = {
noIndentation(Source.fromInputStream(getClass.getResourceAsStream(s"/$fileName")).getLines().mkString("\n"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ components:
properties:
name:
type: string
const: sml
Person:
title: Person
type: object
Expand All @@ -42,6 +43,7 @@ components:
properties:
name:
type: string
const: john
age:
type: integer
format: int32
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ components:
properties:
name:
type: string
const: sml
Person:
title: Person
type: object
Expand All @@ -50,6 +51,7 @@ components:
properties:
name:
type: string
const: john
age:
type: integer
format: int32
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ components:
- red
shapeType:
type: string
const: Square
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ components:
type: string
kind:
type: string
const: organization
Person:
title: Person
type: object
Expand All @@ -100,4 +101,5 @@ components:
type: integer
format: int32
kind:
type: string
type: string
const: person
Loading
Loading