diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs index fdcd59916ed..54e1dbb54bc 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs @@ -23,11 +23,22 @@ namespace MongoDB.Bson.Serialization { + /// + /// An interface implemented by BsonClassMapSerializer. + /// + public interface IBsonClassMapSerializer + { + /// + /// Gets the class map for a BsonClassMapSerializer. + /// + public BsonClassMap ClassMap { get; } + } + /// /// Represents a serializer for a class map. /// /// The type of the class. - public sealed class BsonClassMapSerializer : SerializerBase, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention + public sealed class BsonClassMapSerializer : SerializerBase, IBsonClassMapSerializer, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention { // private fields private readonly BsonClassMap _classMap; @@ -57,6 +68,9 @@ public BsonClassMapSerializer(BsonClassMap classMap) } // public properties + /// + public BsonClassMap ClassMap => _classMap; + /// public IDiscriminatorConvention DiscriminatorConvention => _classMap.GetDiscriminatorConvention(); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 1e40aa1347f..fe58ebba3a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,7 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); - var context = TranslationContext.Create(partiallyEvaluatedOutput, translationOptions); + var context = TranslationContext.Create(translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +106,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions); + var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +150,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions); + var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +188,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions); + var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true)); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index d4c74a22b09..87efbfdd596 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,19 +23,18 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(ConstantExpression constantExpression) + public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer resultSerializer) { - var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); - return Translate(constantExpression, constantSerializer); - } + if (resultSerializer == null) + { + var constantType = constantExpression.Type; + resultSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + } - public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer) - { var constantValue = constantExpression.Value; - var serializedValue = constantSerializer.ToBsonValue(constantValue); + var serializedValue = resultSerializer.ToBsonValue(constantValue); var ast = AstExpression.Constant(serializedValue); - return new AggregationExpression(constantExpression, ast, constantSerializer); + return new AggregationExpression(constantExpression, ast, resultSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index 81a4dfc6ae7..f5ff2417c4b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -27,7 +27,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg internal static class ExpressionToAggregationExpressionTranslator { // public static methods - public static AggregationExpression Translate(TranslationContext context, Expression expression) + public static AggregationExpression Translate(TranslationContext context, Expression expression, IBsonSerializer resultSerializer = null) { switch (expression.NodeType) { @@ -67,7 +67,7 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression, resultSerializer); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: @@ -75,13 +75,13 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.MemberAccess: return MemberExpressionToAggregationExpressionTranslator.Translate(context, (MemberExpression)expression); case ExpressionType.MemberInit: - return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression); + return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression, resultSerializer); case ExpressionType.Negate: return NegateExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression); case ExpressionType.New: - return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression); + return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression, resultSerializer); case ExpressionType.NewArrayInit: - return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression); + return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression, resultSerializer); case ExpressionType.Parameter: return ParameterExpressionToAggregationExpressionTranslator.Translate(context, (ParameterExpression)expression); case ExpressionType.TypeIs: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index e1a8cd5f399..8cad8bb1c99 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -28,22 +28,28 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class MemberInitExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression) + public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression, IBsonSerializer resultSerializer) { if (expression.Type == typeof(BsonDocument)) { return NewBsonDocumentExpressionToAggregationExpressionTranslator.Translate(context, expression); } - return Translate(context, expression, expression.NewExpression, expression.Bindings); + return Translate(context, expression, expression.NewExpression, expression.Bindings, resultSerializer); } public static AggregationExpression Translate( TranslationContext context, Expression expression, NewExpression newExpression, - IReadOnlyList bindings) + IReadOnlyList bindings, + IBsonSerializer resultSerializer) { + if (resultSerializer != null) + { + return TranslateWithTargetSerializer(context, expression, newExpression, bindings, resultSerializer); + } + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; var computedFields = new List(); @@ -100,6 +106,71 @@ public static AggregationExpression Translate( return new AggregationExpression(expression, ast, serializer); } + private static AggregationExpression TranslateWithTargetSerializer( + TranslationContext context, + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings, + IBsonSerializer resultSerializer) + { + if (!(resultSerializer is IBsonDocumentSerializer documentSerializer)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {resultSerializer.GetType()} does not implement IBsonDocumentSerializer."); + } + + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var computedFields = new List(); + + if (constructorInfo != null && constructorArguments.Count > 0) + { + var constructorParameters = constructorInfo.GetParameters(); + + // if the documentSerializer is a BsonClassMappedSerializer we can use the classMap + var classMap = (documentSerializer as IBsonClassMapSerializer)?.ClassMap; + var creatorMap = classMap == null ? null : FindMatchingCreatorMap(classMap, constructorInfo); + if (creatorMap == null && classMap != null) + { + throw new ExpressionNotSupportedException(expression, because: "couldn't find matching creator map"); + } + var creatorMapArguments = creatorMap?.Arguments?.ToArray(); + + for (var i = 0; i < constructorParameters.Length; i++) + { + var parameterName = constructorParameters[i].Name; + var argumentExpression = constructorArguments[i]; + var memberName = creatorMapArguments?[i].Name; // null if there is no classMap + + var (elementName, memberSerializer) = FindMemberElementNameAndSerializer(argumentExpression, classMap, memberName, documentSerializer, parameterName); + if (elementName == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching class member for constructor parameter {parameterName}"); + } + + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression, memberSerializer); + computedFields.Add(AstExpression.ComputedField(elementName, argumentTranslation.Ast)); + } + } + + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var valueExpression = memberAssignment.Expression; + if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}"); + } + var memberSerializer = memberSerializationInfo.Serializer; + + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, memberSerializer); + computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast)); + } + + var ast = AstExpression.ComputedDocument(computedFields); + return new AggregationExpression(expression, ast, documentSerializer); + } + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) { BsonClassMap baseClassMap = null; @@ -190,6 +261,36 @@ private static void EnsureDefaultValue(BsonMemberMap memberMap) memberMap.SetDefaultValue(defaultValue); } + private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo) + => classMap?.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo)); + + private static (string, IBsonSerializer) FindMemberElementNameAndSerializer( + Expression expression, + BsonClassMap classMap, + string memberName, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) + { + // if we have a classMap use it + if (classMap != null) + { + var memberMap = FindMemberMap(expression, classMap, memberName); + return (memberMap.ElementName, memberMap.GetSerializer()); + } + + // otherwise fall back to calling TryGetMemberSerializationInfo on potential matches + var bindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy | BindingFlags.IgnoreCase; + foreach (var memberInfo in documentSerializer.ValueType.GetMember(constructorParameterName, bindingFlags)) + { + if (documentSerializer.TryGetMemberSerializationInfo(memberInfo.Name, out var serializationInfo)) + { + return (serializationInfo.ElementName, serializationInfo.Serializer); + } + } + + return (null, null); + } + private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) { foreach (var memberMap in classMap.DeclaredMemberMaps) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index c6d95c36a43..996ccce14d2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -24,15 +24,34 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class NewArrayInitExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression) + public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression, IBsonSerializer resultSerializer) { - var items = new List(); + IBsonArraySerializer arraySerializer = null; IBsonSerializer itemSerializer = null; + + if (resultSerializer != null) + { + if ((arraySerializer = resultSerializer as IBsonArraySerializer) == null) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {resultSerializer} does not implement IBsonArraySerializer"); + } + if (!arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {resultSerializer} returned false for TryGetItemSerializationInfo"); + } + + itemSerializer = itemSerializationInfo.Serializer; + } + + var items = new List(); foreach (var itemExpression in expression.Expressions) { - var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression); + var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression, itemSerializer); items.Add(itemTranslation.Ast); - itemSerializer ??= itemTranslation.Serializer; + if (itemSerializer == null) + { + itemSerializer = itemTranslation.Serializer; + } // make sure all items are serialized using the same serializer if (!itemTranslation.Serializer.Equals(itemSerializer)) @@ -42,12 +61,14 @@ public static AggregationExpression Translate(TranslationContext context, NewArr } var ast = AstExpression.ComputedArray(items); - - var arrayType = expression.Type; - var itemType = arrayType.GetElementType(); - itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null - var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); - var arraySerializer = (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + if (arraySerializer == null) + { + var arrayType = expression.Type; + var itemType = arrayType.GetElementType(); + itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null + var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); + arraySerializer = (IBsonArraySerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + } return new AggregationExpression(expression, ast, arraySerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs index ee930fd6d5f..4d90d2bb96d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs @@ -17,12 +17,13 @@ using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson; +using MongoDB.Bson.Serialization; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { internal static class NewExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, NewExpression expression) + public static AggregationExpression Translate(TranslationContext context, NewExpression expression, IBsonSerializer resultSerializer) { var expressionType = expression.Type; @@ -46,7 +47,7 @@ public static AggregationExpression Translate(TranslationContext context, NewExp { return NewTupleExpressionToAggregationExpressionTranslator.Translate(context, expression); } - return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty()); + return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty(), resultSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index 29f929085d2..b96a193e323 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -31,7 +31,7 @@ public static ExecutableQuery> Translate TranslateScalar>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions, contextData); + var context = TranslationContext.Create(translationOptions, contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -75,7 +75,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -105,7 +105,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -124,7 +124,7 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch @@ -141,7 +141,7 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +176,7 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions); try { @@ -200,7 +200,7 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(expression, translationOptions); // do not partially evaluate expression + var context = TranslationContext.Create(translationOptions); // do not partially evaluate expression var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbolWithVarName(parameter, varName: "ROOT", documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs index 36fecaaf8b9..c88054db978 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs @@ -40,7 +40,7 @@ public void Nested_Any_should_translate_correctly() var parameter = expression.Parameters[0]; var serializerRegistry = BsonSerializer.SerializerRegistry; var documentSerializer = serializerRegistry.GetSerializer(); - var context = TranslationContext.Create(expression, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..dd4df74ece6 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,226 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test_set_ValueObject_Value_using_creator_map() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_to_derived_value_using_property_setter() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyDerivedValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + B = 42 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_X_using_constructor() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } + } + + class MyValue + { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } + public int Value { get; set; } + } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index 700bbcbf7ee..10f3f2a5d14 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -183,7 +183,7 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue private TranslationContext CreateContext(ParameterExpression parameter) { var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(parameter, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index e312485b14a..fd2e09732c8 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -546,7 +546,7 @@ private ProjectedResult Group(Expression(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(expression, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 4d878ca70b9..0869d70822e 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1154,7 +1154,7 @@ public List Assert(IMongoCollection collection, var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(filter, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body);