Skip to content

Commit

Permalink
Use BaseTypeInfo API to fix up polymorphic schemas (dotnet#56908)
Browse files Browse the repository at this point in the history
* Use BaseTypeInfo API to fix up polymorphic schemas

* Change schema for polymorphic types with non-abstract base class

* Add comments, update API name, and add more tests

* More feedback and tests
  • Loading branch information
captainsafia authored Aug 7, 2024
1 parent abc946f commit 257d690
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 30 deletions.
36 changes: 32 additions & 4 deletions src/OpenApi/src/Extensions/JsonNodeSchemaExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -334,17 +334,27 @@ static bool SupportsNullableProperty(BindingSource bindingSource) =>bindingSourc
}

/// <summary>
/// Applies the polymorphism options to the target schema following OpenAPI v3's conventions.
/// Applies the polymorphism options defined by System.Text.Json to the target schema following OpenAPI v3's
/// conventions for the discriminator property.
/// </summary>
/// <param name="schema">The <see cref="JsonNode"/> produced by the underlying schema generator.</param>
/// <param name="context">The <see cref="JsonSchemaExporterContext"/> associated with the current type.</param>
/// <param name="createSchemaReferenceId">A delegate that generates the reference ID to create for a type.</param>
internal static void ApplyPolymorphismOptions(this JsonNode schema, JsonSchemaExporterContext context, Func<JsonTypeInfo, string?> createSchemaReferenceId)
internal static void MapPolymorphismOptionsToDiscriminator(this JsonNode schema, JsonSchemaExporterContext context, Func<JsonTypeInfo, string?> createSchemaReferenceId)
{
// The `context.Path.Length == 0` check is used to ensure that we only apply the polymorphism options
// The `context.BaseTypeInfo == null` check is used to ensure that we only apply the polymorphism options
// to the top-level schema and not to any nested schemas that are generated.
if (context.TypeInfo.PolymorphismOptions is { } polymorphismOptions && context.Path.Length == 0)
if (context.TypeInfo.PolymorphismOptions is { } polymorphismOptions && context.BaseTypeInfo == null)
{
// System.Text.Json supports serializing to a non-abstract base class if no discriminator is provided.
// OpenAPI requires that all polymorphic sub-schemas have an associated discriminator. If the base type
// doesn't declare itself as its own derived type via [JsonDerived], then it can't have a discriminator,
// which OpenAPI requires. In that case, we exit early to avoid mapping the polymorphism options
// to the `discriminator` property and return an un-discriminated `anyOf` schema instead.
if (IsNonAbstractTypeWithoutDerivedTypeReference(context))
{
return;
}
var mappings = new JsonObject();
foreach (var derivedType in polymorphismOptions.DerivedTypes)
{
Expand Down Expand Up @@ -376,6 +386,24 @@ internal static void ApplySchemaReferenceId(this JsonNode schema, JsonSchemaExpo
{
schema[OpenApiConstants.SchemaId] = schemaReferenceId;
}
// If the type is a non-abstract base class that is not one of the derived types then mark it as a base schema.
if (context.BaseTypeInfo == context.TypeInfo &&
IsNonAbstractTypeWithoutDerivedTypeReference(context))
{
schema[OpenApiConstants.SchemaId] = "Base";
}
}

/// <summary>
/// Returns <langword ref="true" /> if the current type is a non-abstract base class that is not defined as its
/// own derived type.
/// </summary>
/// <param name="context">The <see cref="JsonSchemaExporterContext"/> associated with the current type.</param>
private static bool IsNonAbstractTypeWithoutDerivedTypeReference(JsonSchemaExporterContext context)
{
return !context.TypeInfo.Type.IsAbstract
&& context.TypeInfo.PolymorphismOptions is { } polymorphismOptions
&& !polymorphismOptions.DerivedTypes.Any(type => type.DerivedType == context.TypeInfo.Type);
}

/// <summary>
Expand Down
5 changes: 4 additions & 1 deletion src/OpenApi/src/Schemas/OpenApiJsonSchema.Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,10 @@ public static void ReadProperty(ref Utf8JsonReader reader, string propertyName,
case OpenApiSchemaKeywords.DiscriminatorMappingKeyword:
reader.Read();
var mappings = ReadDictionary<string>(ref reader);
schema.Discriminator.Mapping = mappings;
if (mappings is not null)
{
schema.Discriminator.Mapping = mappings;
}
break;
case OpenApiConstants.SchemaId:
reader.Read();
Expand Down
1 change: 0 additions & 1 deletion src/OpenApi/src/Schemas/OpenApiSchemaKeywords.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ internal class OpenApiSchemaKeywords
public const string MinItemsKeyword = "minItems";
public const string MaxItemsKeyword = "maxItems";
public const string RefKeyword = "$ref";
public const string SchemaIdKeyword = "x-schema-id";
public const string ConstKeyword = "const";
}
2 changes: 1 addition & 1 deletion src/OpenApi/src/Services/Schemas/OpenApiSchemaService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal sealed class OpenApiSchemaService(
var createSchemaReferenceId = optionsMonitor.Get(documentName).CreateSchemaReferenceId;
schema.ApplyPrimitiveTypesAndFormats(context, createSchemaReferenceId);
schema.ApplySchemaReferenceId(context, createSchemaReferenceId);
schema.ApplyPolymorphismOptions(context, createSchemaReferenceId);
schema.MapPolymorphismOptionsToDiscriminator(context, createSchemaReferenceId);
if (context.PropertyInfo is { } jsonPropertyInfo)
{
// Short-circuit STJ's handling of nested properties, which uses a reference to the
Expand Down
39 changes: 36 additions & 3 deletions src/OpenApi/src/Services/Schemas/OpenApiSchemaStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public void PopulateSchemaIntoReferenceCache(OpenApiSchema schema, bool captureS
// Only capture top-level schemas by ref. Nested schemas will follow the
// reference by duplicate rules.
AddOrUpdateSchemaByReference(schema, captureSchemaByRef: captureSchemaByRef);
AddOrUpdateAnyOfSubSchemaByReference(schema);
if (schema.AdditionalProperties is not null)
{
AddOrUpdateSchemaByReference(schema.AdditionalProperties);
Expand All @@ -99,24 +100,56 @@ public void PopulateSchemaIntoReferenceCache(OpenApiSchema schema, bool captureS
AddOrUpdateSchemaByReference(allOfSchema);
}
}
if (schema.Properties is not null)
{
foreach (var property in schema.Properties.Values)
{
AddOrUpdateSchemaByReference(property);
}
}
}

private void AddOrUpdateAnyOfSubSchemaByReference(OpenApiSchema schema)
{
if (schema.AnyOf is not null)
{
// AnyOf schemas in a polymorphic type should contain a reference to the parent schema
// ID to support disambiguating between a derived type on its own and a derived type
// as part of a polymorphic schema.
var baseTypeSchemaId = schema.Annotations is not null && schema.Annotations.TryGetValue(OpenApiConstants.SchemaId, out var schemaId) ? schemaId?.ToString() : null;
var baseTypeSchemaId = schema.Annotations is not null && schema.Annotations.TryGetValue(OpenApiConstants.SchemaId, out var schemaId)
? schemaId?.ToString()
: null;
foreach (var anyOfSchema in schema.AnyOf)
{
AddOrUpdateSchemaByReference(anyOfSchema, baseTypeSchemaId);
}
}
if (schema.Properties is not null)

if (schema.Items is not null)
{
AddOrUpdateAnyOfSubSchemaByReference(schema.Items);
}

if (schema.Properties is { Count: > 0 })
{
foreach (var property in schema.Properties.Values)
{
AddOrUpdateSchemaByReference(property);
AddOrUpdateAnyOfSubSchemaByReference(property);
}
}

if (schema.AllOf is not null)
{
foreach (var allOfSchema in schema.AllOf)
{
AddOrUpdateAnyOfSubSchemaByReference(allOfSchema);
}
}

if (schema.AdditionalProperties is not null)
{
AddOrUpdateAnyOfSubSchemaByReference(schema.AdditionalProperties);
}
}

private void AddOrUpdateSchemaByReference(OpenApiSchema schema, string? baseTypeSchemaId = null, bool captureSchemaByRef = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,53 @@ await VerifyOpenApiDocument(builder, document =>
}

[Fact]
public async Task HandlesPolymorphicTypesWithNonAbstractBaseClass()
public async Task HandlesPolymorphicTypesWithNonAbstractBaseClassWithNoDiscriminator()
{
// Arrange
var builder = CreateBuilder();

// Act
builder.MapPost("/api", (Color color) => { });

// Assert
await VerifyOpenApiDocument(builder, document =>
{
var operation = document.Paths["/api"].Operations[OperationType.Post];
Assert.NotNull(operation.RequestBody);
var requestBody = operation.RequestBody.Content;
Assert.True(requestBody.TryGetValue("application/json", out var mediaType));
var schema = mediaType.Schema.GetEffective(document);
// Assert discriminator mappings are not configured for this type since we
// can't meet OpenAPI's restrictions that derived types _always_ have a discriminator
// property associated with them.
Assert.Null(schema.Discriminator);
Assert.Collection(schema.AnyOf,
schema => Assert.Equal("ColorPaintColor", schema.Reference.Id),
schema => Assert.Equal("ColorFabricColor", schema.Reference.Id),
schema => Assert.Equal("ColorBase", schema.Reference.Id));
// Assert schema with discriminator = "paint" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("ColorPaintColor", out var paintSchema));
Assert.Contains("$type", paintSchema.Properties.Keys);
Assert.Equal("paint", ((OpenApiString)paintSchema.Properties["$type"].Enum.First()).Value);
// Assert schema with discriminator = "fabric" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("ColorFabricColor", out var fabricSchema));
Assert.Contains("$type", fabricSchema.Properties.Keys);
Assert.Equal("fabric", ((OpenApiString)fabricSchema.Properties["$type"].Enum.First()).Value);
// Assert that schema for `Color` has been inserted into the components without a discriminator
Assert.True(document.Components.Schemas.TryGetValue("ColorBase", out var colorSchema));
Assert.DoesNotContain("$type", colorSchema.Properties.Keys);
});
}

[Fact]
public async Task HandlesPolymorphicTypesWithNonAbstractBaseClassAndDiscriminator()
{
// Arrange
var builder = CreateBuilder();

// Act
builder.MapPost("/api", (Pet pet) => { });

// Assert
await VerifyOpenApiDocument(builder, document =>
{
Expand All @@ -148,30 +187,113 @@ await VerifyOpenApiDocument(builder, document =>
// Assert discriminator mappings have been configured correctly
Assert.Equal("$type", schema.Discriminator.PropertyName);
Assert.Collection(schema.Discriminator.Mapping,
item => Assert.Equal("paint", item.Key),
item => Assert.Equal("fabric", item.Key)
item => Assert.Equal("cat", item.Key),
item => Assert.Equal("dog", item.Key),
item => Assert.Equal("pet", item.Key)
);
Assert.Collection(schema.Discriminator.Mapping,
item => Assert.Equal("#/components/schemas/ColorPaintColor", item.Value),
item => Assert.Equal("#/components/schemas/ColorFabricColor", item.Value)
item => Assert.Equal("#/components/schemas/PetCat", item.Value),
item => Assert.Equal("#/components/schemas/PetDog", item.Value),
item => Assert.Equal("#/components/schemas/PetPet", item.Value)
);
// Note that our implementation diverges from the OpenAPI specification here. OpenAPI
// requires that derived types in a polymorphic schema _always_ have a discriminator
// OpenAPI requires that derived types in a polymorphic schema _always_ have a discriminator
// property associated with them. STJ permits the discriminator to be omitted from the
// if the base type is a non-abstract class and falls back to serializing to this base
// type. This is a known limitation of the current implementation.
// type. In this scenario, we check that the base class is not included in the `anyOf`
// schema.
Assert.Collection(schema.AnyOf,
schema => Assert.Equal("ColorPaintColor", schema.Reference.Id),
schema => Assert.Equal("ColorFabricColor", schema.Reference.Id),
schema => Assert.Equal("ColorColor", schema.Reference.Id));
// Assert schema with discriminator = "paint" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("ColorPaintColor", out var paintSchema));
Assert.Contains(schema.Discriminator.PropertyName, paintSchema.Properties.Keys);
Assert.Equal("paint", ((OpenApiString)paintSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
// Assert schema with discriminator = "fabric" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("ColorFabricColor", out var fabricSchema));
Assert.Contains(schema.Discriminator.PropertyName, fabricSchema.Properties.Keys);
Assert.Equal("fabric", ((OpenApiString)fabricSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
schema => Assert.Equal("PetCat", schema.Reference.Id),
schema => Assert.Equal("PetDog", schema.Reference.Id),
schema => Assert.Equal("PetPet", schema.Reference.Id));
// Assert schema with discriminator = "dog" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("PetDog", out var dogSchema));
Assert.Contains(schema.Discriminator.PropertyName, dogSchema.Properties.Keys);
Assert.Equal("dog", ((OpenApiString)dogSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
// Assert schema with discriminator = "cat" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("PetCat", out var catSchema));
Assert.Contains(schema.Discriminator.PropertyName, catSchema.Properties.Keys);
Assert.Equal("cat", ((OpenApiString)catSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
// Assert schema with discriminator = "cat" has been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("PetPet", out var petSchema));
Assert.Contains(schema.Discriminator.PropertyName, petSchema.Properties.Keys);
Assert.Equal("pet", ((OpenApiString)petSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
});
}

[Fact]
public async Task HandlesPolymorphicTypesWithNoExplicitDiscriminators()
{
// Arrange
var builder = CreateBuilder();

// Act
builder.MapPost("/api", (Organism color) => { });

// Assert
await VerifyOpenApiDocument(builder, document =>
{
var operation = document.Paths["/api"].Operations[OperationType.Post];
Assert.NotNull(operation.RequestBody);
var requestBody = operation.RequestBody.Content;
Assert.True(requestBody.TryGetValue("application/json", out var mediaType));
var schema = mediaType.Schema.GetEffective(document);
// Assert discriminator mappings are not configured for this type since we
// can't meet OpenAPI's restrictions that derived types _always_ have a discriminator
// property associated with them.
Assert.Null(schema.Discriminator);
Assert.Collection(schema.AnyOf,
schema => Assert.Equal("OrganismAnimal", schema.Reference.Id),
schema => Assert.Equal("OrganismPlant", schema.Reference.Id),
schema => Assert.Equal("OrganismBase", schema.Reference.Id));
// Assert that schemas without discriminators have been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("OrganismAnimal", out var animalSchema));
Assert.DoesNotContain("$type", animalSchema.Properties.Keys);
Assert.True(document.Components.Schemas.TryGetValue("OrganismPlant", out var plantSchema));
Assert.DoesNotContain("$type", plantSchema.Properties.Keys);
Assert.True(document.Components.Schemas.TryGetValue("OrganismBase", out var baseSchema));
Assert.DoesNotContain("$type", baseSchema.Properties.Keys);
});
}

[Fact]
public async Task HandlesPolymorphicTypesWithSelfReference()
{
// Arrange
var builder = CreateBuilder();

// Act
builder.MapPost("/api", (Employee color) => { });

// Assert
await VerifyOpenApiDocument(builder, document =>
{
var operation = document.Paths["/api"].Operations[OperationType.Post];
Assert.NotNull(operation.RequestBody);
var requestBody = operation.RequestBody.Content;
Assert.True(requestBody.TryGetValue("application/json", out var mediaType));
Assert.Equal("Employee", mediaType.Schema.Reference.Id);
var schema = mediaType.Schema.GetEffective(document);
// Assert that discriminator mappings are configured correctly for type.
Assert.Equal("$type", schema.Discriminator.PropertyName);
Assert.Collection(schema.Discriminator.Mapping,
item => Assert.Equal("manager", item.Key),
item => Assert.Equal("employee", item.Key)
);
Assert.Collection(schema.Discriminator.Mapping,
item => Assert.Equal("#/components/schemas/EmployeeManager", item.Value),
item => Assert.Equal("#/components/schemas/EmployeeEmployee", item.Value)
);
// Assert that anyOf schemas use the correct reference IDs.
Assert.Collection(schema.AnyOf,
schema => Assert.Equal("EmployeeManager", schema.Reference.Id),
schema => Assert.Equal("EmployeeEmployee", schema.Reference.Id));
// Assert that schemas without discriminators have been inserted into the components
Assert.True(document.Components.Schemas.TryGetValue("EmployeeManager", out var managerSchema));
Assert.Equal("manager", ((OpenApiString)managerSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
Assert.True(document.Components.Schemas.TryGetValue("EmployeeEmployee", out var employeeSchema));
Assert.Equal("employee", ((OpenApiString)employeeSchema.Properties[schema.Discriminator.PropertyName].Enum.First()).Value);
// Assert that the schema has a correct self-reference to the base-type. This points to the schema that contains the discriminator.
Assert.Equal("Employee", employeeSchema.Properties["manager"].Reference.Id);
});
}
}
Loading

0 comments on commit 257d690

Please sign in to comment.