From d5e3140cd824453fd1a5201df618c42949e5f69e Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Thu, 26 Dec 2024 09:57:31 -0800 Subject: [PATCH 1/8] CSHARP-5435: NotSupportedException when using Set with sub-documents in an Update with Aggregation Pipeline --- .../Jira/CSharp5435Tests.cs | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..356830e4119 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,90 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test2() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + Long = x.ValueObject.Value, + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + } + + class MyValue + { + public int Value { get; set; } + } + } +} From 6afb5290cc42ad30a5a3acd6455749cce4cca859 Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Thu, 2 Jan 2025 10:34:49 -0800 Subject: [PATCH 2/8] Add derived class test --- .../Linq3Implementation/Jira/CSharp5435Tests.cs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs index 356830e4119..ad1c7940fe4 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -34,7 +34,7 @@ public void Test() var pipelineError = new EmptyPipelineDefinition() .Set(x => new MyDocument() { - ValueObject = new() + ValueObject = new MyValue() { Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 } @@ -45,7 +45,7 @@ public void Test() } [Fact] - public void Test2() + public void TestDerived() { var coll = GetCollection(); var doc = new MyDocument(); @@ -54,7 +54,11 @@ public void Test2() var pipelineError = new EmptyPipelineDefinition() .Set(x => new MyDocument() { - Long = x.ValueObject.Value, + ValueObject = new MyDerivedValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + B = 42 + } }); var updateError = Builders.Update.Pipeline(pipelineError); @@ -86,5 +90,10 @@ class MyValue { public int Value { get; set; } } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } } } From 55e1883dea41bdcc1ff1dad173e030c1328d2cf0 Mon Sep 17 00:00:00 2001 From: rstam Date: Wed, 8 Jan 2025 09:55:01 -0800 Subject: [PATCH 3/8] CSHARP-5435: Proof of concept of using "target serializer". --- .../Serializers/BsonClassMapSerializer.cs | 16 +- ...essionToAggregationExpressionTranslator.cs | 167 +++++------------- .../ExpressionToSetStageTranslator.cs | 4 +- .../Translators/TranslationContext.cs | 6 + .../Translators/TranslationContextData.cs | 10 +- .../Jira/CSharp3922Tests.cs | 2 + .../Jira/CSharp4289Tests.cs | 1 + .../Jira/CSharp4524Tests.cs | 6 +- .../Jira/CSharp4586Tests.cs | 2 + .../Jira/CSharp5435Tests.cs | 7 + ...nToAggregationExpressionTranslatorTests.cs | 8 +- .../IntegrationTestBase.cs | 2 + 12 files changed, 104 insertions(+), 127 deletions(-) diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs index fdcd59916ed..a83f752a38d 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs @@ -23,11 +23,22 @@ namespace MongoDB.Bson.Serialization { + /// + /// + /// + public interface IBsonClassMapSerializer + { + /// + /// + /// + public BsonClassMap ClassMap { get; } + } + /// /// Represents a serializer for a class map. /// /// The type of the class. - public sealed class BsonClassMapSerializer : SerializerBase, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention + public sealed class BsonClassMapSerializer : SerializerBase, IBsonClassMapSerializer, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention { // private fields private readonly BsonClassMap _classMap; @@ -57,6 +68,9 @@ public BsonClassMapSerializer(BsonClassMap classMap) } // public properties + /// + public BsonClassMap ClassMap => _classMap; + /// public IDiscriminatorConvention DiscriminatorConvention => _classMap.GetDiscriminatorConvention(); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index e1a8cd5f399..1bc82b153dd 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -22,7 +22,6 @@ using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; -using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -47,10 +46,29 @@ public static AggregationExpression Translate( var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; var computedFields = new List(); - var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); - if (constructorInfo != null && creatorMap != null) + var targetSerializer = context.Data?.GetValueOrDefault("TargetSerializer", null); + var targetType = newExpression.Type; + var serializer = targetSerializer?.ValueType == targetType ? targetSerializer : BsonSerializer.LookupSerializer(targetType); + if (!(serializer is IBsonDocumentSerializer documentSerializer)) { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {serializer.GetType()} does not implement IBsonDocumentSerializer."); + } + + if (constructorInfo != null && constructorInfo.GetParameters().Length > 0) + { + // for now we only support constructors with BsonClassMapSerializers + if (!(documentSerializer is IBsonClassMapSerializer bsonClassMapSerializer)) + { + throw new ExpressionNotSupportedException(expression, because: "constructors are only supported for BsonClassMapSerializer"); + } + var classMap = bsonClassMapSerializer.ClassMap; + var creatorMap = FindMatchingCreatorMap(classMap, constructorInfo); + if (creatorMap == null) + { + throw new ExpressionNotSupportedException(expression, because: "couldn't find matching creator map"); + } + var constructorParameters = constructorInfo.GetParameters(); var creatorMapParameters = creatorMap.Arguments?.ToArray(); if (constructorParameters.Length > 0) @@ -67,15 +85,16 @@ public static AggregationExpression Translate( for (var i = 0; i < creatorMapParameters.Length; i++) { var creatorMapParameter = creatorMapParameters[i]; - var constructorArgumentExpression = constructorArguments[i]; - var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); - var constructorArgumentType = constructorArgumentExpression.Type; - var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); - var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); - EnsureDefaultValue(memberMap); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); + var memberMap = FindMatchingMemberMap(creatorMap, creatorMapParameter.Name); + if (memberMap == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching member map for constructor parameter {creatorMapParameter.Name}"); + } + + var argumentContext = context.WithData("TargetSerializer", memberMap.GetSerializer()); + var argumentExpression = constructorArguments[i]; + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(argumentContext, argumentExpression); + computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, argumentTranslation.Ast)); } } } @@ -83,129 +102,39 @@ public static AggregationExpression Translate( foreach (var binding in bindings) { var memberAssignment = (MemberAssignment)binding; - var member = memberAssignment.Member; - var memberMap = FindMemberMap(expression, classMap, member.Name); var valueExpression = memberAssignment.Expression; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); - memberMap.SetSerializer(memberSerializer); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); - } - - var ast = AstExpression.ComputedDocument(computedFields); - classMap.Freeze(); - var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); - var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); - - return new AggregationExpression(expression, ast, serializer); - } - - private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) - { - BsonClassMap baseClassMap = null; - if (classType.BaseType != null) - { - baseClassMap = CreateClassMap(classType.BaseType, null, out _); - } - - var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); - var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); - var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); - if (constructorInfo != null) - { - creatorMap = classMap.MapConstructor(constructorInfo); - } - else - { - creatorMap = null; - } - - classMap.AutoMap(); - classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here - - return classMap; - } - - private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) - { - var memberType = memberMap.MemberType; - var memberSerializer = memberMap.GetSerializer(); - var sourceType = sourceSerializer.ValueType; - - if (memberType != sourceType && - memberType.ImplementsIEnumerable(out var memberItemType) && - sourceType.ImplementsIEnumerable(out var sourceItemType) && - sourceItemType == memberItemType && - sourceSerializer is IBsonArraySerializer sourceArraySerializer && - sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && - memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) - { - var sourceItemSerializer = sourceItemSerializationInfo.Serializer; - return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); - } - - return sourceSerializer; - } - - private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) - { - var declaringClassMap = classMap; - while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) - { - declaringClassMap = declaringClassMap.BaseClassMap; - - if (declaringClassMap == null) + var member = memberAssignment.Member; + if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}"); } - } - foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) - { - if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) - { - return memberMap; - } + var valueContext = context.WithData("TargetSerializer", memberSerializationInfo.Serializer); + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(valueContext, valueExpression); + computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast)); } - return declaringClassMap.MapMember(creatorMapParameter); + var ast = AstExpression.ComputedDocument(computedFields); - static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) - { - var memberInfo = memberMap.MemberInfo; - return - memberInfo.MemberType == creatorMapParameter.MemberType && - memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); - } + return new AggregationExpression(expression, ast, documentSerializer); } - private static void EnsureDefaultValue(BsonMemberMap memberMap) - { - if (memberMap.IsDefaultValueSpecified) - { - return; - } - - var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; - memberMap.SetDefaultValue(defaultValue); - } + private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo) + => classMap.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo)); - private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + private static BsonMemberMap FindMatchingMemberMap(BsonCreatorMap creatorMap, string memberName) { - foreach (var memberMap in classMap.DeclaredMemberMaps) + var arguments = creatorMap.Arguments.ToArray(); + for (var index = 0; index < arguments.Length; index++) { - if (memberMap.MemberName == memberName) + if (arguments[index].Name.Equals(memberName, StringComparison.Ordinal)) { - return memberMap; + var elementName = creatorMap.ElementNames.ElementAt(index); + return creatorMap.ClassMap.AllMemberMaps.FirstOrDefault(m => m.ElementName.Equals(elementName, StringComparison.Ordinal)); } } - if (classMap.BaseClassMap != null) - { - return FindMemberMap(expression, classMap.BaseClassMap, memberName); - } - - throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + return null; } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs index 018c2ae5051..85ccf143b7c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs @@ -153,7 +153,9 @@ private static AstComputedField CreateComputedField(TranslationContext context, } else { - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var valueSerializer = serializationInfo.Serializer; + var valueContext = context.WithData("TargetSerializer", valueSerializer); + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(valueContext, valueExpression); ThrowIfMemberAndValueSerializersAreNotCompatible(valueExpression, memberSerializer, valueTranslation.Serializer); valueAst = valueTranslation.Ast; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs index 55629aac049..57ef2e0f269 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs @@ -101,6 +101,12 @@ public override string ToString() return $"{{ SymbolTable : {_symbolTable} }}"; } + public TranslationContext WithData(string key, object value) + { + var data = _data == null ? new TranslationContextData(key, value) : _data.With(key, value); + return new TranslationContext(_symbolTable, _nameGenerator, _translationOptions, data); + } + public TranslationContext WithSingleSymbol(Symbol newSymbol) { var newSymbolTable = new SymbolTable(newSymbol); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs index d2b3faef80a..573ea089f4c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs @@ -27,6 +27,12 @@ public TranslationContextData() { } + public TranslationContextData(string key, object value) + : this(new Dictionary()) + { + _data[key] = value; + } + private TranslationContextData(Dictionary data) { _data = Ensure.IsNotNull(data, nameof(data)); @@ -42,10 +48,10 @@ public TValue GetValueOrDefault(string key, TValue defaultValue) return _data.TryGetValue(key, out var value) ? (TValue)value : defaultValue; } - public TranslationContextData With(string key, TValue value) + public TranslationContextData With(string key, object value) { var clonedData = new Dictionary(_data); - clonedData.Add(key, value); + clonedData[key] = value; return new TranslationContextData(clonedData); } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs index 0c16dcda035..8ce88f7022f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs @@ -14,6 +14,7 @@ */ using System.Linq; +using MongoDB.Bson.Serialization.Attributes; using Xunit; namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira @@ -80,6 +81,7 @@ private class C private class D { + [BsonConstructor] public D(int x) { X = x; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs index e9dc5de6f7b..14b5dc5808d 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs @@ -75,6 +75,7 @@ public class C public class R { + [BsonConstructor] public R(string v) { V = v; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs index 4a6c6ea8c5b..97ca6f5004b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs @@ -16,6 +16,7 @@ using System; using FluentAssertions; using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; using Xunit; @@ -66,9 +67,10 @@ public enum SpawnPeriod { LIVE, MIDNIGHT, MORNING, EVENING } public struct SpawnData { - public readonly DateTime Date; - public readonly SpawnPeriod Period; + [BsonElement] public readonly DateTime Date; + [BsonElement] public readonly SpawnPeriod Period; + [BsonConstructor] public SpawnData(DateTime date, SpawnPeriod period) { // Normally there is more complex handling here, value-type semantics are important, there are custom comparison operators, etc. hence the point of this struct. diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs index 2f8d3d47df6..f41439e4da7 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs @@ -14,6 +14,7 @@ */ using FluentAssertions; +using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; using Xunit; @@ -166,6 +167,7 @@ public View1(string id) private class View2 { + [BsonConstructor] public View2(string id) { Id = id; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs index ad1c7940fe4..ada15cb79b2 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -16,6 +16,7 @@ using System.Linq; using FluentAssertions; using MongoDB.Bson; +using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Bson.Serialization.Serializers; using Xunit; @@ -41,6 +42,12 @@ public void Test() }); var updateError = Builders.Update.Pipeline(pipelineError); + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs index 442b594e5fa..1197bcf1594 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs @@ -194,10 +194,11 @@ public struct SpawnDataStructParameterless public class SpawnDataClass { - public readonly int Identifier; + [BsonElement] public readonly int Identifier; public DateTime SpawnDate; private string spawnText; + [BsonConstructor] public SpawnDataClass(int identifier, DateTime spawnDate) { Identifier = identifier; @@ -213,10 +214,11 @@ public string SpawnText public class SpawnDataClassWithAdditionalParameter { - public readonly int Identifier; + [BsonElement] public readonly int Identifier; public DateTime SpawnDate; public int AdditionalField; + [BsonConstructor] public SpawnDataClassWithAdditionalParameter(int identifier, DateTime spawnDate, int additionalField) { Identifier = identifier; @@ -233,6 +235,7 @@ public struct SpawnDataStruct private string spawnText; // this constructor is required for the test to compile + [BsonConstructor] public SpawnDataStruct(int identifier, DateTime spawnDate) { Identifier = identifier; @@ -258,6 +261,7 @@ public string SpawnText public class InheritedSpawnData : SpawnDataClass { + [BsonConstructor] public InheritedSpawnData(int identifier, DateTime spawnDate) : base(identifier, spawnDate) { diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs index e422523ff9d..4641ee6a87f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs @@ -268,10 +268,12 @@ private void InsertJoin() public class RootView { + [BsonConstructor] public RootView() { } + [BsonConstructor] public RootView(string property) { Property = property; From a5dcc5206212e923949e505035cd49887b31f9b2 Mon Sep 17 00:00:00 2001 From: rstam Date: Thu, 9 Jan 2025 10:55:48 -0800 Subject: [PATCH 4/8] CSHARP-5435: Added known serializers to TranslationContext. --- ...ingWithOutputExpressionStageDefinitions.cs | 8 +- ...essionToAggregationExpressionTranslator.cs | 189 +++++++++++++++++- .../ExpressionToExecutableQueryTranslator.cs | 4 +- .../ConcatMethodToPipelineTranslator.cs | 2 +- .../UnionMethodToPipelineTranslator.cs | 2 +- .../ExpressionToSetStageTranslator.cs | 2 +- .../Translators/TranslationContext.cs | 27 ++- .../Translators/TranslationContextData.cs | 10 +- .../Linq/LinqProviderAdapter.cs | 14 +- .../Jira/CSharp1585Tests.cs | 2 +- .../Jira/CSharp3922Tests.cs | 2 - .../Jira/CSharp4289Tests.cs | 1 - .../Jira/CSharp4524Tests.cs | 6 +- .../Jira/CSharp4586Tests.cs | 2 - ...nToAggregationExpressionTranslatorTests.cs | 8 +- ...arisonExpressionToFilterTranslatorTests.cs | 2 +- .../IntegrationTestBase.cs | 2 - .../AggregateGroupTranslatorTests.cs | 2 +- .../LegacyPredicateTranslatorTests.cs | 2 +- .../Translators/PredicateTranslatorTests.cs | 2 +- 20 files changed, 227 insertions(+), 62 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 1e40aa1347f..aa1f2922fc8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,7 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); - var context = TranslationContext.Create(partiallyEvaluatedOutput, translationOptions); + var context = TranslationContext.Create(translationOptions, inputSerializer); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +106,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions); + var context = TranslationContext.Create(translationOptions, inputSerializer); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +150,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions); + var context = TranslationContext.Create(translationOptions, inputSerializer); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +188,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(partiallyEvaluatedGroupBy, translationOptions); + var context = TranslationContext.Create(translationOptions, inputSerializer); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true)); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 1bc82b153dd..6458c00756d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -22,6 +22,7 @@ using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -43,18 +44,84 @@ public static AggregationExpression Translate( NewExpression newExpression, IReadOnlyList bindings) { + var targetSerializer = context.GetKnownSerializer(newExpression.Type); + if (targetSerializer != null) + { + return TranslateWithTargetSerializer(context, expression, newExpression, bindings, targetSerializer); + } + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; var computedFields = new List(); + var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap); + + if (constructorInfo != null && creatorMap != null) + { + var constructorParameters = constructorInfo.GetParameters(); + var creatorMapParameters = creatorMap.Arguments?.ToArray(); + if (constructorParameters.Length > 0) + { + if (creatorMapParameters == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); + } + if (creatorMapParameters.Length != constructorParameters.Length) + { + throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); + } + + for (var i = 0; i < creatorMapParameters.Length; i++) + { + var creatorMapParameter = creatorMapParameters[i]; + var constructorArgumentExpression = constructorArguments[i]; + var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression); + var constructorArgumentType = constructorArgumentExpression.Type; + var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); + var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); + EnsureDefaultValue(memberMap); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); + memberMap.SetSerializer(memberSerializer); + computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); + } + } + } - var targetSerializer = context.Data?.GetValueOrDefault("TargetSerializer", null); - var targetType = newExpression.Type; - var serializer = targetSerializer?.ValueType == targetType ? targetSerializer : BsonSerializer.LookupSerializer(targetType); - if (!(serializer is IBsonDocumentSerializer documentSerializer)) + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var memberMap = FindMemberMap(expression, classMap, member.Name); + var valueExpression = memberAssignment.Expression; + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); + memberMap.SetSerializer(memberSerializer); + computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); + } + + var ast = AstExpression.ComputedDocument(computedFields); + classMap.Freeze(); + var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type); + var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap); + + return new AggregationExpression(expression, ast, serializer); + } + + private static AggregationExpression TranslateWithTargetSerializer( + TranslationContext context, + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings, + IBsonSerializer targetSerializer) + { + if (!(targetSerializer is IBsonDocumentSerializer documentSerializer)) { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {serializer.GetType()} does not implement IBsonDocumentSerializer."); + throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer.GetType()} does not implement IBsonDocumentSerializer."); } + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var computedFields = new List(); + if (constructorInfo != null && constructorInfo.GetParameters().Length > 0) { // for now we only support constructors with BsonClassMapSerializers @@ -91,7 +158,7 @@ public static AggregationExpression Translate( throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching member map for constructor parameter {creatorMapParameter.Name}"); } - var argumentContext = context.WithData("TargetSerializer", memberMap.GetSerializer()); + var argumentContext = context.WithKnownSerializer(memberMap.GetSerializer()); var argumentExpression = constructorArguments[i]; var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(argumentContext, argumentExpression); computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, argumentTranslation.Ast)); @@ -109,7 +176,7 @@ public static AggregationExpression Translate( throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}"); } - var valueContext = context.WithData("TargetSerializer", memberSerializationInfo.Serializer); + var valueContext = context.WithKnownSerializer(memberSerializationInfo.Serializer); var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(valueContext, valueExpression); computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast)); } @@ -119,6 +186,96 @@ public static AggregationExpression Translate( return new AggregationExpression(expression, ast, documentSerializer); } + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) + { + BsonClassMap baseClassMap = null; + if (classType.BaseType != null) + { + baseClassMap = CreateClassMap(classType.BaseType, null, out _); + } + + var classMapType = typeof(BsonClassMap<>).MakeGenericType(classType); + var classMapConstructorInfo = classMapType.GetConstructor(new Type[] { typeof(BsonClassMap) }); + var classMap = (BsonClassMap)classMapConstructorInfo.Invoke(new object[] { baseClassMap }); + if (constructorInfo != null) + { + creatorMap = classMap.MapConstructor(constructorInfo); + } + else + { + creatorMap = null; + } + + classMap.AutoMap(); + classMap.IdMemberMap?.SetElementName("_id"); // normally happens when Freeze is called but we need it sooner here + + return classMap; + } + + private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) + { + var memberType = memberMap.MemberType; + var memberSerializer = memberMap.GetSerializer(); + var sourceType = sourceSerializer.ValueType; + + if (memberType != sourceType && + memberType.ImplementsIEnumerable(out var memberItemType) && + sourceType.ImplementsIEnumerable(out var sourceItemType) && + sourceItemType == memberItemType && + sourceSerializer is IBsonArraySerializer sourceArraySerializer && + sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && + memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) + { + var sourceItemSerializer = sourceItemSerializationInfo.Serializer; + return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); + } + + return sourceSerializer; + } + + private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) + { + var declaringClassMap = classMap; + while (declaringClassMap.ClassType != creatorMapParameter.DeclaringType) + { + declaringClassMap = declaringClassMap.BaseClassMap; + + if (declaringClassMap == null) + { + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching property for constructor parameter: {creatorMapParameter.Name}"); + } + } + + foreach (var memberMap in declaringClassMap.DeclaredMemberMaps) + { + if (MemberMapMatchesCreatorMapParameter(memberMap, creatorMapParameter)) + { + return memberMap; + } + } + + return declaringClassMap.MapMember(creatorMapParameter); + + static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberInfo creatorMapParameter) + { + var memberInfo = memberMap.MemberInfo; + return + memberInfo.MemberType == creatorMapParameter.MemberType && + memberInfo.Name.Equals(creatorMapParameter.Name, StringComparison.OrdinalIgnoreCase); + } + } + + private static void EnsureDefaultValue(BsonMemberMap memberMap) + { + if (memberMap.IsDefaultValueSpecified) + { + return; + } + + var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null; + memberMap.SetDefaultValue(defaultValue); + } + private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo) => classMap.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo)); @@ -136,5 +293,23 @@ private static BsonMemberMap FindMatchingMemberMap(BsonCreatorMap creatorMap, st return null; } + + private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + { + foreach (var memberMap in classMap.DeclaredMemberMaps) + { + if (memberMap.MemberName == memberName) + { + return memberMap; + } + } + + if (classMap.BaseClassMap != null) + { + return FindMemberMap(expression, classMap.BaseClassMap, memberName); + } + + throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index 29f929085d2..0107bee990c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -31,7 +31,7 @@ public static ExecutableQuery> Translate TranslateScalar _knownSerializers; private readonly NameGenerator _nameGenerator; private readonly SymbolTable _symbolTable; private readonly ExpressionTranslationOptions _translationOptions; private TranslationContext( + ExpressionTranslationOptions translationOptions, + IEnumerable knownSerializers, SymbolTable symbolTable, NameGenerator nameGenerator, - ExpressionTranslationOptions translationOptions, TranslationContextData data = null) { _symbolTable = Ensure.IsNotNull(symbolTable, nameof(symbolTable)); - _nameGenerator = Ensure.IsNotNull(nameGenerator, nameof(nameGenerator)); _translationOptions = translationOptions ?? new ExpressionTranslationOptions(); + _knownSerializers = knownSerializers?.AsReadOnlyList() ?? []; + _nameGenerator = Ensure.IsNotNull(nameGenerator, nameof(nameGenerator)); _data = data; // can be null } // public properties public TranslationContextData Data => _data; + public IReadOnlyList KnownSerializers => _knownSerializers; public NameGenerator NameGenerator => _nameGenerator; public SymbolTable SymbolTable => _symbolTable; public ExpressionTranslationOptions TranslationOptions => _translationOptions; @@ -96,15 +102,18 @@ public Symbol CreateSymbolWithVarName(ParameterExpression parameter, string varN return CreateSymbol(parameter, name: parameterName, varName, serializer, isCurrent); } + public IBsonSerializer GetKnownSerializer(Type type) + => _knownSerializers?.FirstOrDefault(serializer => serializer.ValueType == type); + public override string ToString() { return $"{{ SymbolTable : {_symbolTable} }}"; } - public TranslationContext WithData(string key, object value) + public TranslationContext WithKnownSerializer(IBsonSerializer serializer) { - var data = _data == null ? new TranslationContextData(key, value) : _data.With(key, value); - return new TranslationContext(_symbolTable, _nameGenerator, _translationOptions, data); + var knownSerializers = _knownSerializers.Prepend(serializer).AsReadOnlyList(); + return new TranslationContext(_translationOptions, knownSerializers, _symbolTable, _nameGenerator, _data); } public TranslationContext WithSingleSymbol(Symbol newSymbol) @@ -127,7 +136,7 @@ public TranslationContext WithSymbols(params Symbol[] newSymbols) public TranslationContext WithSymbolTable(SymbolTable symbolTable) { - return new TranslationContext(symbolTable, _nameGenerator, _translationOptions, _data); + return new TranslationContext(_translationOptions, _knownSerializers, symbolTable, _nameGenerator, _data); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs index 573ea089f4c..d2b3faef80a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContextData.cs @@ -27,12 +27,6 @@ public TranslationContextData() { } - public TranslationContextData(string key, object value) - : this(new Dictionary()) - { - _data[key] = value; - } - private TranslationContextData(Dictionary data) { _data = Ensure.IsNotNull(data, nameof(data)); @@ -48,10 +42,10 @@ public TValue GetValueOrDefault(string key, TValue defaultValue) return _data.TryGetValue(key, out var value) ? (TValue)value : defaultValue; } - public TranslationContextData With(string key, object value) + public TranslationContextData With(string key, TValue value) { var clonedData = new Dictionary(_data); - clonedData[key] = value; + clonedData.Add(key, value); return new TranslationContextData(clonedData); } } diff --git a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs index 69c7cd21b8c..2878af74cd7 100644 --- a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs +++ b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs @@ -60,7 +60,7 @@ internal static BsonValue TranslateExpressionToAggregateExpression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions, contextData); + var context = TranslationContext.Create(translationOptions, sourceSerializer, contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -75,7 +75,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions, documentSerializer); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -105,7 +105,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions, documentSerializer); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -124,7 +124,7 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions, elementSerializer); var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch @@ -141,7 +141,7 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions, documentSerializer); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +176,7 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(expression, translationOptions); + var context = TranslationContext.Create(translationOptions, inputSerializer); try { @@ -200,7 +200,7 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(expression, translationOptions); // do not partially evaluate expression + var context = TranslationContext.Create(translationOptions, documentSerializer); // do not partially evaluate expression var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbolWithVarName(parameter, varName: "ROOT", documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs index 36fecaaf8b9..801d7c5d03e 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs @@ -40,7 +40,7 @@ public void Nested_Any_should_translate_correctly() var parameter = expression.Parameters[0]; var serializerRegistry = BsonSerializer.SerializerRegistry; var documentSerializer = serializerRegistry.GetSerializer(); - var context = TranslationContext.Create(expression, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null, documentSerializer); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs index 8ce88f7022f..0c16dcda035 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp3922Tests.cs @@ -14,7 +14,6 @@ */ using System.Linq; -using MongoDB.Bson.Serialization.Attributes; using Xunit; namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira @@ -81,7 +80,6 @@ private class C private class D { - [BsonConstructor] public D(int x) { X = x; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs index 14b5dc5808d..e9dc5de6f7b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4289Tests.cs @@ -75,7 +75,6 @@ public class C public class R { - [BsonConstructor] public R(string v) { V = v; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs index 97ca6f5004b..4a6c6ea8c5b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4524Tests.cs @@ -16,7 +16,6 @@ using System; using FluentAssertions; using MongoDB.Bson.Serialization; -using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; using Xunit; @@ -67,10 +66,9 @@ public enum SpawnPeriod { LIVE, MIDNIGHT, MORNING, EVENING } public struct SpawnData { - [BsonElement] public readonly DateTime Date; - [BsonElement] public readonly SpawnPeriod Period; + public readonly DateTime Date; + public readonly SpawnPeriod Period; - [BsonConstructor] public SpawnData(DateTime date, SpawnPeriod period) { // Normally there is more complex handling here, value-type semantics are important, there are custom comparison operators, etc. hence the point of this struct. diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs index f41439e4da7..2f8d3d47df6 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4586Tests.cs @@ -14,7 +14,6 @@ */ using FluentAssertions; -using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; using Xunit; @@ -167,7 +166,6 @@ public View1(string id) private class View2 { - [BsonConstructor] public View2(string id) { Id = id; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs index 1197bcf1594..442b594e5fa 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslatorTests.cs @@ -194,11 +194,10 @@ public struct SpawnDataStructParameterless public class SpawnDataClass { - [BsonElement] public readonly int Identifier; + public readonly int Identifier; public DateTime SpawnDate; private string spawnText; - [BsonConstructor] public SpawnDataClass(int identifier, DateTime spawnDate) { Identifier = identifier; @@ -214,11 +213,10 @@ public string SpawnText public class SpawnDataClassWithAdditionalParameter { - [BsonElement] public readonly int Identifier; + public readonly int Identifier; public DateTime SpawnDate; public int AdditionalField; - [BsonConstructor] public SpawnDataClassWithAdditionalParameter(int identifier, DateTime spawnDate, int additionalField) { Identifier = identifier; @@ -235,7 +233,6 @@ public struct SpawnDataStruct private string spawnText; // this constructor is required for the test to compile - [BsonConstructor] public SpawnDataStruct(int identifier, DateTime spawnDate) { Identifier = identifier; @@ -261,7 +258,6 @@ public string SpawnText public class InheritedSpawnData : SpawnDataClass { - [BsonConstructor] public InheritedSpawnData(int identifier, DateTime spawnDate) : base(identifier, spawnDate) { diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index 700bbcbf7ee..e6203d7271f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -183,7 +183,7 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue private TranslationContext CreateContext(ParameterExpression parameter) { var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(parameter, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null, serializer); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs index 4641ee6a87f..e422523ff9d 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/IntegrationTestBase.cs @@ -268,12 +268,10 @@ private void InsertJoin() public class RootView { - [BsonConstructor] public RootView() { } - [BsonConstructor] public RootView(string property) { Property = property; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index e312485b14a..4d8bfca1484 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -546,7 +546,7 @@ private ProjectedResult Group(Expression(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(expression, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null, serializer); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 4d878ca70b9..30f6435db34 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1154,7 +1154,7 @@ public List Assert(IMongoCollection collection, var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(filter, translationOptions: null); + var context = TranslationContext.Create(translationOptions: null, serializer); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body); From 5d066319c35bea49a6572f3766440d6812e275a9 Mon Sep 17 00:00:00 2001 From: rstam Date: Thu, 9 Jan 2025 15:26:38 -0800 Subject: [PATCH 5/8] Add support for constructors in new expressions when not using BsonClassMapSerializer. --- ...essionToAggregationExpressionTranslator.cs | 77 +++++++++---------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 6458c00756d..4660252f36c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -124,45 +124,32 @@ private static AggregationExpression TranslateWithTargetSerializer( if (constructorInfo != null && constructorInfo.GetParameters().Length > 0) { - // for now we only support constructors with BsonClassMapSerializers - if (!(documentSerializer is IBsonClassMapSerializer bsonClassMapSerializer)) - { - throw new ExpressionNotSupportedException(expression, because: "constructors are only supported for BsonClassMapSerializer"); - } - var classMap = bsonClassMapSerializer.ClassMap; - var creatorMap = FindMatchingCreatorMap(classMap, constructorInfo); - if (creatorMap == null) + var constructorParameters = constructorInfo.GetParameters(); + + // if the documentSerializer is a BsonClassMappedSerializer we can use the classMap + var classMap = (documentSerializer as IBsonClassMapSerializer)?.ClassMap; + var creatorMap = classMap == null ? null : FindMatchingCreatorMap(classMap, constructorInfo); + if (creatorMap == null && classMap != null) { throw new ExpressionNotSupportedException(expression, because: "couldn't find matching creator map"); } + var creatorMapArguments = creatorMap?.Arguments?.ToArray(); - var constructorParameters = constructorInfo.GetParameters(); - var creatorMapParameters = creatorMap.Arguments?.ToArray(); - if (constructorParameters.Length > 0) + for (var i = 0; i < constructorParameters.Length; i++) { - if (creatorMapParameters == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters."); - } - if (creatorMapParameters.Length != constructorParameters.Length) - { - throw new ExpressionNotSupportedException(expression, because: $"the constructor has {constructorParameters} parameters but the creatorMap has {creatorMapParameters.Length} parameters."); - } + var parameterName = constructorParameters[i].Name; + var argumentExpression = constructorArguments[i]; + var memberName = creatorMapArguments?[i].Name; // null if there is no classMap - for (var i = 0; i < creatorMapParameters.Length; i++) + var (elementName, memberSerializer) = FindMemberElementNameAndSerializer(argumentExpression, classMap, memberName, documentSerializer, parameterName); + if (elementName == null) { - var creatorMapParameter = creatorMapParameters[i]; - var memberMap = FindMatchingMemberMap(creatorMap, creatorMapParameter.Name); - if (memberMap == null) - { - throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching member map for constructor parameter {creatorMapParameter.Name}"); - } - - var argumentContext = context.WithKnownSerializer(memberMap.GetSerializer()); - var argumentExpression = constructorArguments[i]; - var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(argumentContext, argumentExpression); - computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, argumentTranslation.Ast)); + throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching class member for constructor parameter {parameterName}"); } + + var argumentContext = context.WithKnownSerializer(memberSerializer); + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(argumentContext, argumentExpression); + computedFields.Add(AstExpression.ComputedField(elementName, argumentTranslation.Ast)); } } @@ -277,21 +264,33 @@ private static void EnsureDefaultValue(BsonMemberMap memberMap) } private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo) - => classMap.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo)); + => classMap?.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo)); - private static BsonMemberMap FindMatchingMemberMap(BsonCreatorMap creatorMap, string memberName) + private static (string, IBsonSerializer) FindMemberElementNameAndSerializer( + Expression expression, + BsonClassMap classMap, + string memberName, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) { - var arguments = creatorMap.Arguments.ToArray(); - for (var index = 0; index < arguments.Length; index++) + // if we have a creatorMap use it + if (classMap != null) + { + var memberMap = FindMemberMap(expression, classMap, memberName); + return (memberMap.ElementName, memberMap.GetSerializer()); + } + + // otherwise fall back to calling TryGetMemberSerializationInfo on potential matches + var bindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy | BindingFlags.IgnoreCase; + foreach (var memberInfo in documentSerializer.ValueType.GetMember(constructorParameterName, bindingFlags)) { - if (arguments[index].Name.Equals(memberName, StringComparison.Ordinal)) + if (documentSerializer.TryGetMemberSerializationInfo(memberInfo.Name, out var serializationInfo)) { - var elementName = creatorMap.ElementNames.ElementAt(index); - return creatorMap.ClassMap.AllMemberMaps.FirstOrDefault(m => m.ElementName.Equals(elementName, StringComparison.Ordinal)); + return (serializationInfo.ElementName, serializationInfo.Serializer); } } - return null; + return (null, null); } private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) From e0d3f5d477f910c2ea9da1ee9a3445e89601189a Mon Sep 17 00:00:00 2001 From: rstam Date: Thu, 9 Jan 2025 17:21:23 -0800 Subject: [PATCH 6/8] Support array target serializers. --- ...essionToAggregationExpressionTranslator.cs | 6 +- ...essionToAggregationExpressionTranslator.cs | 2 +- ...essionToAggregationExpressionTranslator.cs | 42 ++++-- .../Translators/TranslationContext.cs | 1 + .../Jira/CSharp5435Tests.cs | 124 +++++++++++++++++- 5 files changed, 161 insertions(+), 14 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index d4c74a22b09..bf4829d0040 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,10 +23,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(ConstantExpression constantExpression) + public static AggregationExpression Translate(TranslationContext context, ConstantExpression constantExpression) { var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + var constantSerializer = + context.GetKnownSerializer(constantType) ?? + (StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType)); return Translate(constantExpression, constantSerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index 81a4dfc6ae7..a88ffb55ea1 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -67,7 +67,7 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate(context, (ConstantExpression)expression); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index c6d95c36a43..8f805e846b9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -26,13 +26,35 @@ internal static class NewArrayInitExpressionToAggregationExpressionTranslator { public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression) { - var items = new List(); + IBsonArraySerializer arraySerializer = null; IBsonSerializer itemSerializer = null; + + var targetSerializer = context.GetKnownSerializer(expression.Type); + if (targetSerializer != null) + { + if ((arraySerializer = targetSerializer as IBsonArraySerializer) == null) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer} does not implement IBsonArraySerializer"); + } + if (!arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer} returned false for TryGetItemSerializationInfo"); + } + + itemSerializer = itemSerializationInfo.Serializer; + } + + var items = new List(); + var itemContext = itemSerializer == null ? context : context.WithKnownSerializer(itemSerializer); foreach (var itemExpression in expression.Expressions) { - var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression); + var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(itemContext, itemExpression); items.Add(itemTranslation.Ast); - itemSerializer ??= itemTranslation.Serializer; + if (itemSerializer == null) + { + itemSerializer = itemTranslation.Serializer; + itemContext = context.WithKnownSerializer(itemSerializer); + } // make sure all items are serialized using the same serializer if (!itemTranslation.Serializer.Equals(itemSerializer)) @@ -42,12 +64,14 @@ public static AggregationExpression Translate(TranslationContext context, NewArr } var ast = AstExpression.ComputedArray(items); - - var arrayType = expression.Type; - var itemType = arrayType.GetElementType(); - itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null - var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); - var arraySerializer = (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + if (arraySerializer == null) + { + var arrayType = expression.Type; + var itemType = arrayType.GetElementType(); + itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null + var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); + arraySerializer = (IBsonArraySerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + } return new AggregationExpression(expression, ast, arraySerializer); } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs index 392da238130..512162279cb 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/TranslationContext.cs @@ -112,6 +112,7 @@ public override string ToString() public TranslationContext WithKnownSerializer(IBsonSerializer serializer) { + Ensure.IsNotNull(serializer, nameof(serializer)); var knownSerializers = _knownSerializers.Prepend(serializer).AsReadOnlyList(); return new TranslationContext(_translationOptions, knownSerializers, _symbolTable, _nameGenerator, _data); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs index ada15cb79b2..e56cc4f285f 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -16,6 +16,7 @@ using System.Linq; using FluentAssertions; using MongoDB.Bson; +using MongoDB.Bson.IO; using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Bson.Serialization.Serializers; @@ -26,7 +27,30 @@ namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira public class CSharp5435Tests : Linq3IntegrationTest { [Fact] - public void Test() + public void Test_set_ValueObject_Value_using_creator_map() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() { var coll = GetCollection(); var doc = new MyDocument(); @@ -52,7 +76,7 @@ public void Test() } [Fact] - public void TestDerived() + public void Test_set_ValueObject_to_derived_value_using_property_setter() { var coll = GetCollection(); var doc = new MyDocument(); @@ -72,6 +96,52 @@ public void TestDerived() coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); } + [Fact] + public void Test_set_X_using_constructor() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', '0'] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + private IMongoCollection GetCollection() { var collection = GetCollection("test"); @@ -91,10 +161,21 @@ class MyDocument public MyValue ValueObject { get; set; } public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } } class MyValue { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } public int Value { get; set; } } @@ -102,5 +183,44 @@ class MyDerivedValue : MyValue { public int B { get; set; } } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } } } From 0d8369d2074913a449c018fbc1019738535c796b Mon Sep 17 00:00:00 2001 From: rstam Date: Fri, 10 Jan 2025 09:09:40 -0800 Subject: [PATCH 7/8] CSHARP-5435: Added resultSerializer. --- .../Serializers/BsonClassMapSerializer.cs | 4 +-- ...ingWithOutputExpressionStageDefinitions.cs | 8 ++--- ...essionToAggregationExpressionTranslator.cs | 19 +++++------ ...essionToAggregationExpressionTranslator.cs | 10 +++--- ...essionToAggregationExpressionTranslator.cs | 32 +++++++++---------- ...essionToAggregationExpressionTranslator.cs | 15 ++++----- ...essionToAggregationExpressionTranslator.cs | 5 +-- .../ExpressionToExecutableQueryTranslator.cs | 4 +-- .../ConcatMethodToPipelineTranslator.cs | 2 +- .../UnionMethodToPipelineTranslator.cs | 2 +- .../ExpressionToSetStageTranslator.cs | 5 ++- .../Translators/TranslationContext.cs | 19 ++--------- .../Linq/LinqProviderAdapter.cs | 14 ++++---- .../Jira/CSharp1585Tests.cs | 2 +- .../Jira/CSharp5435Tests.cs | 2 +- ...arisonExpressionToFilterTranslatorTests.cs | 2 +- .../AggregateGroupTranslatorTests.cs | 2 +- .../LegacyPredicateTranslatorTests.cs | 2 +- .../Translators/PredicateTranslatorTests.cs | 2 +- 19 files changed, 64 insertions(+), 87 deletions(-) diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs index a83f752a38d..54e1dbb54bc 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs @@ -24,12 +24,12 @@ namespace MongoDB.Bson.Serialization { /// - /// + /// An interface implemented by BsonClassMapSerializer. /// public interface IBsonClassMapSerializer { /// - /// + /// Gets the class map for a BsonClassMapSerializer. /// public BsonClassMap ClassMap { get; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index aa1f2922fc8..fe58ebba3a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -64,7 +64,7 @@ private AstStage RenderProjectStage( out IBsonSerializer outputSerializer) { var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); - var context = TranslationContext.Create(translationOptions, inputSerializer); + var context = TranslationContext.Create(translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); outputSerializer = (IBsonSerializer)projectSerializer; @@ -106,7 +106,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions, inputSerializer); + var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -150,7 +150,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer, TInput>> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions, inputSerializer); + var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var valueSerializer = (IBsonSerializer)groupByTranslation.Serializer; @@ -188,7 +188,7 @@ protected override AstStage RenderGroupingStage( out IBsonSerializer> groupingOutputSerializer) { var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); - var context = TranslationContext.Create(translationOptions, inputSerializer); + var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.Var("ROOT", isCurrent: true)); var groupBySerializer = (IBsonSerializer)groupByTranslation.Serializer; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index bf4829d0040..87efbfdd596 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -23,21 +23,18 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, ConstantExpression constantExpression) + public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer resultSerializer) { - var constantType = constantExpression.Type; - var constantSerializer = - context.GetKnownSerializer(constantType) ?? - (StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType)); - return Translate(constantExpression, constantSerializer); - } + if (resultSerializer == null) + { + var constantType = constantExpression.Type; + resultSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + } - public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer) - { var constantValue = constantExpression.Value; - var serializedValue = constantSerializer.ToBsonValue(constantValue); + var serializedValue = resultSerializer.ToBsonValue(constantValue); var ast = AstExpression.Constant(serializedValue); - return new AggregationExpression(constantExpression, ast, constantSerializer); + return new AggregationExpression(constantExpression, ast, resultSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index a88ffb55ea1..f5ff2417c4b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -27,7 +27,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg internal static class ExpressionToAggregationExpressionTranslator { // public static methods - public static AggregationExpression Translate(TranslationContext context, Expression expression) + public static AggregationExpression Translate(TranslationContext context, Expression expression, IBsonSerializer resultSerializer = null) { switch (expression.NodeType) { @@ -67,7 +67,7 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate(context, (ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression, resultSerializer); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: @@ -75,13 +75,13 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.MemberAccess: return MemberExpressionToAggregationExpressionTranslator.Translate(context, (MemberExpression)expression); case ExpressionType.MemberInit: - return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression); + return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression, resultSerializer); case ExpressionType.Negate: return NegateExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression); case ExpressionType.New: - return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression); + return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression, resultSerializer); case ExpressionType.NewArrayInit: - return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression); + return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression, resultSerializer); case ExpressionType.Parameter: return ParameterExpressionToAggregationExpressionTranslator.Translate(context, (ParameterExpression)expression); case ExpressionType.TypeIs: diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 4660252f36c..8cad8bb1c99 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -28,26 +28,26 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class MemberInitExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression) + public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression, IBsonSerializer resultSerializer) { if (expression.Type == typeof(BsonDocument)) { return NewBsonDocumentExpressionToAggregationExpressionTranslator.Translate(context, expression); } - return Translate(context, expression, expression.NewExpression, expression.Bindings); + return Translate(context, expression, expression.NewExpression, expression.Bindings, resultSerializer); } public static AggregationExpression Translate( TranslationContext context, Expression expression, NewExpression newExpression, - IReadOnlyList bindings) + IReadOnlyList bindings, + IBsonSerializer resultSerializer) { - var targetSerializer = context.GetKnownSerializer(newExpression.Type); - if (targetSerializer != null) + if (resultSerializer != null) { - return TranslateWithTargetSerializer(context, expression, newExpression, bindings, targetSerializer); + return TranslateWithTargetSerializer(context, expression, newExpression, bindings, resultSerializer); } var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct @@ -111,18 +111,18 @@ private static AggregationExpression TranslateWithTargetSerializer( Expression expression, NewExpression newExpression, IReadOnlyList bindings, - IBsonSerializer targetSerializer) + IBsonSerializer resultSerializer) { - if (!(targetSerializer is IBsonDocumentSerializer documentSerializer)) + if (!(resultSerializer is IBsonDocumentSerializer documentSerializer)) { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer.GetType()} does not implement IBsonDocumentSerializer."); + throw new ExpressionNotSupportedException(expression, because: $"serializer class {resultSerializer.GetType()} does not implement IBsonDocumentSerializer."); } var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; var computedFields = new List(); - if (constructorInfo != null && constructorInfo.GetParameters().Length > 0) + if (constructorInfo != null && constructorArguments.Count > 0) { var constructorParameters = constructorInfo.GetParameters(); @@ -147,8 +147,7 @@ private static AggregationExpression TranslateWithTargetSerializer( throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching class member for constructor parameter {parameterName}"); } - var argumentContext = context.WithKnownSerializer(memberSerializer); - var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(argumentContext, argumentExpression); + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression, memberSerializer); computedFields.Add(AstExpression.ComputedField(elementName, argumentTranslation.Ast)); } } @@ -156,20 +155,19 @@ private static AggregationExpression TranslateWithTargetSerializer( foreach (var binding in bindings) { var memberAssignment = (MemberAssignment)binding; - var valueExpression = memberAssignment.Expression; var member = memberAssignment.Member; + var valueExpression = memberAssignment.Expression; if (!documentSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) { throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}"); } + var memberSerializer = memberSerializationInfo.Serializer; - var valueContext = context.WithKnownSerializer(memberSerializationInfo.Serializer); - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(valueContext, valueExpression); + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, memberSerializer); computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast)); } var ast = AstExpression.ComputedDocument(computedFields); - return new AggregationExpression(expression, ast, documentSerializer); } @@ -273,7 +271,7 @@ private static (string, IBsonSerializer) FindMemberElementNameAndSerializer( IBsonDocumentSerializer documentSerializer, string constructorParameterName) { - // if we have a creatorMap use it + // if we have a classMap use it if (classMap != null) { var memberMap = FindMemberMap(expression, classMap, memberName); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index 8f805e846b9..996ccce14d2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -24,36 +24,33 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class NewArrayInitExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression) + public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression, IBsonSerializer resultSerializer) { IBsonArraySerializer arraySerializer = null; IBsonSerializer itemSerializer = null; - var targetSerializer = context.GetKnownSerializer(expression.Type); - if (targetSerializer != null) + if (resultSerializer != null) { - if ((arraySerializer = targetSerializer as IBsonArraySerializer) == null) + if ((arraySerializer = resultSerializer as IBsonArraySerializer) == null) { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer} does not implement IBsonArraySerializer"); + throw new ExpressionNotSupportedException(expression, because: $"serializer class {resultSerializer} does not implement IBsonArraySerializer"); } if (!arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) { - throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer} returned false for TryGetItemSerializationInfo"); + throw new ExpressionNotSupportedException(expression, because: $"serializer class {resultSerializer} returned false for TryGetItemSerializationInfo"); } itemSerializer = itemSerializationInfo.Serializer; } var items = new List(); - var itemContext = itemSerializer == null ? context : context.WithKnownSerializer(itemSerializer); foreach (var itemExpression in expression.Expressions) { - var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(itemContext, itemExpression); + var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression, itemSerializer); items.Add(itemTranslation.Ast); if (itemSerializer == null) { itemSerializer = itemTranslation.Serializer; - itemContext = context.WithKnownSerializer(itemSerializer); } // make sure all items are serialized using the same serializer diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs index ee930fd6d5f..4d90d2bb96d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs @@ -17,12 +17,13 @@ using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson; +using MongoDB.Bson.Serialization; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { internal static class NewExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, NewExpression expression) + public static AggregationExpression Translate(TranslationContext context, NewExpression expression, IBsonSerializer resultSerializer) { var expressionType = expression.Type; @@ -46,7 +47,7 @@ public static AggregationExpression Translate(TranslationContext context, NewExp { return NewTupleExpressionToAggregationExpressionTranslator.Translate(context, expression); } - return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty()); + return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty(), resultSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index 0107bee990c..b96a193e323 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -31,7 +31,7 @@ public static ExecutableQuery> Translate TranslateScalar _knownSerializers; private readonly NameGenerator _nameGenerator; private readonly SymbolTable _symbolTable; private readonly ExpressionTranslationOptions _translationOptions; private TranslationContext( ExpressionTranslationOptions translationOptions, - IEnumerable knownSerializers, SymbolTable symbolTable, NameGenerator nameGenerator, TranslationContextData data = null) { _symbolTable = Ensure.IsNotNull(symbolTable, nameof(symbolTable)); _translationOptions = translationOptions ?? new ExpressionTranslationOptions(); - _knownSerializers = knownSerializers?.AsReadOnlyList() ?? []; _nameGenerator = Ensure.IsNotNull(nameGenerator, nameof(nameGenerator)); _data = data; // can be null } // public properties public TranslationContextData Data => _data; - public IReadOnlyList KnownSerializers => _knownSerializers; public NameGenerator NameGenerator => _nameGenerator; public SymbolTable SymbolTable => _symbolTable; public ExpressionTranslationOptions TranslationOptions => _translationOptions; @@ -102,21 +97,11 @@ public Symbol CreateSymbolWithVarName(ParameterExpression parameter, string varN return CreateSymbol(parameter, name: parameterName, varName, serializer, isCurrent); } - public IBsonSerializer GetKnownSerializer(Type type) - => _knownSerializers?.FirstOrDefault(serializer => serializer.ValueType == type); - public override string ToString() { return $"{{ SymbolTable : {_symbolTable} }}"; } - public TranslationContext WithKnownSerializer(IBsonSerializer serializer) - { - Ensure.IsNotNull(serializer, nameof(serializer)); - var knownSerializers = _knownSerializers.Prepend(serializer).AsReadOnlyList(); - return new TranslationContext(_translationOptions, knownSerializers, _symbolTable, _nameGenerator, _data); - } - public TranslationContext WithSingleSymbol(Symbol newSymbol) { var newSymbolTable = new SymbolTable(newSymbol); @@ -137,7 +122,7 @@ public TranslationContext WithSymbols(params Symbol[] newSymbols) public TranslationContext WithSymbolTable(SymbolTable symbolTable) { - return new TranslationContext(_translationOptions, _knownSerializers, symbolTable, _nameGenerator, _data); + return new TranslationContext(_translationOptions, symbolTable, _nameGenerator, _data); } } } diff --git a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs index 2878af74cd7..5b436d76703 100644 --- a/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs +++ b/src/MongoDB.Driver/Linq/LinqProviderAdapter.cs @@ -60,7 +60,7 @@ internal static BsonValue TranslateExpressionToAggregateExpression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions, sourceSerializer, contextData); + var context = TranslationContext.Create(translationOptions, contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -75,7 +75,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( { expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions, documentSerializer); + var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var body = RemovePossibleConvertToObject(expression.Body); @@ -105,7 +105,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); var parameter = expression.Parameters.Single(); - var context = TranslationContext.Create(translationOptions, documentSerializer); + var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, expression.Body); @@ -124,7 +124,7 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions, elementSerializer); + var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element context = context.WithSingleSymbol(symbol); // @ is the only symbol visible inside an $elemMatch @@ -141,7 +141,7 @@ internal static BsonDocument TranslateExpressionToFilter( ExpressionTranslationOptions translationOptions) { expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions, documentSerializer); + var context = TranslationContext.Create(translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +176,7 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec } expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); - var context = TranslationContext.Create(translationOptions, inputSerializer); + var context = TranslationContext.Create(translationOptions); try { @@ -200,7 +200,7 @@ internal static BsonDocument TranslateExpressionToSetStage( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - var context = TranslationContext.Create(translationOptions, documentSerializer); // do not partially evaluate expression + var context = TranslationContext.Create(translationOptions); // do not partially evaluate expression var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbolWithVarName(parameter, varName: "ROOT", documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs index 801d7c5d03e..c88054db978 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp1585Tests.cs @@ -40,7 +40,7 @@ public void Nested_Any_should_translate_correctly() var parameter = expression.Parameters[0]; var serializerRegistry = BsonSerializer.SerializerRegistry; var documentSerializer = serializerRegistry.GetSerializer(); - var context = TranslationContext.Create(translationOptions: null, documentSerializer); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); context = context.WithSymbol(symbol); var filter = ExpressionToFilterTranslator.Translate(context, expression.Body, exprOk: false); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs index e56cc4f285f..dd4df74ece6 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -137,7 +137,7 @@ public void Test_set_A() updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) .AsBsonArray .Cast(); - AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', '0'] }] } }"); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs index e6203d7271f..10f3f2a5d14 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/ModuloComparisonExpressionToFilterTranslatorTests.cs @@ -183,7 +183,7 @@ private void Assert(AstFilter result, string path, BsonValue divisor, BsonValue private TranslationContext CreateContext(ParameterExpression parameter) { var serializer = BsonSerializer.LookupSerializer(parameter.Type); - var context = TranslationContext.Create(translationOptions: null, serializer); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); return context.WithSymbol(symbol); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index 4d8bfca1484..fd2e09732c8 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -546,7 +546,7 @@ private ProjectedResult Group(Expression(Expression> expression, int var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); - var context = TranslationContext.Create(translationOptions: null, serializer); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, expression.Body); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 30f6435db34..0869d70822e 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1154,7 +1154,7 @@ public List Assert(IMongoCollection collection, var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single(); - var context = TranslationContext.Create(translationOptions: null, serializer); + var context = TranslationContext.Create(translationOptions: null); var symbol = context.CreateSymbol(parameter, serializer, isCurrent: true); context = context.WithSymbol(symbol); var filterAst = ExpressionToFilterTranslator.Translate(context, filter.Body); From 314be7f5f3d6c41e57ecfc0c6d243a70e3205f25 Mon Sep 17 00:00:00 2001 From: rstam Date: Fri, 10 Jan 2025 09:20:31 -0800 Subject: [PATCH 8/8] CSHARP-5435: Fix rebasing issue. --- .../LookupMethodToPipelineTranslator.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs index 183b00bf1c8..995e73bf185 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/LookupMethodToPipelineTranslator.cs @@ -281,7 +281,7 @@ private static TranslatedPipeline TranslateDocumentsPipelineGeneric