diff --git a/src/LinqToDB.Include.Tests/LinqToDB.Include.Tests/TestModel/Person.cs b/src/LinqToDB.Include.Tests/LinqToDB.Include.Tests/TestModel/Person.cs index bc4e74a..fb1cd7f 100644 --- a/src/LinqToDB.Include.Tests/LinqToDB.Include.Tests/TestModel/Person.cs +++ b/src/LinqToDB.Include.Tests/LinqToDB.Include.Tests/TestModel/Person.cs @@ -18,7 +18,7 @@ public class Person public int? SpouseId { get; set; } public Person Spouse { get; set; } - public List Orders { get; set; } = new List(); + public IEnumerable Orders { get; set; } = new List(); public static Expression> ExtraJoinOptions() => (p, o) => o.OrderId < 99; } diff --git a/src/LinqToDB.Include/Accessors/IPropertyAccessor.cs b/src/LinqToDB.Include/Accessors/IPropertyAccessor.cs index 0ae1aa5..7abadf8 100644 --- a/src/LinqToDB.Include/Accessors/IPropertyAccessor.cs +++ b/src/LinqToDB.Include/Accessors/IPropertyAccessor.cs @@ -10,7 +10,7 @@ interface IPropertyAccessor Type MemberEntityType { get; } HashSet Properties { get; } string PropertyName { get; } - bool IsMemberTypeICollection { get; } + bool IsMemberTypeIEnumerable { get; } IPropertyAccessor FindAccessor(List pathParts); diff --git a/src/LinqToDB.Include/Accessors/PropertyAccessor.cs b/src/LinqToDB.Include/Accessors/PropertyAccessor.cs index dc5274c..f499dba 100644 --- a/src/LinqToDB.Include/Accessors/PropertyAccessor.cs +++ b/src/LinqToDB.Include/Accessors/PropertyAccessor.cs @@ -18,7 +18,7 @@ public PropertyAccessor(MemberExpression exp, MappingSchema mappingSchema) _declaringType = exp.Member.DeclaringType; _memberType = exp.Type; _memberEntityType = typeof(TProperty); - _isMemberTypeICollection = _memberType.IsICollection(); + _isMemberTypeIEnumerable = _memberType.IsIEnumerable(); _isMemberEntityTypeIEnumerable = _memberEntityType.IsIEnumerable(); ChildEntityDescriptor = mappingSchema.GetEntityDescriptor(_memberEntityType); diff --git a/src/LinqToDB.Include/Accessors/PropertyAccessorAbstract.cs b/src/LinqToDB.Include/Accessors/PropertyAccessorAbstract.cs index f38fafb..e6cc9af 100644 --- a/src/LinqToDB.Include/Accessors/PropertyAccessorAbstract.cs +++ b/src/LinqToDB.Include/Accessors/PropertyAccessorAbstract.cs @@ -11,7 +11,7 @@ abstract class PropertyAccessor : IPropertyAccessor where TClass protected Type _declaringType; protected Type _memberType; protected Type _memberEntityType; - protected bool _isMemberTypeICollection; + protected bool _isMemberTypeIEnumerable; protected bool _isMemberEntityTypeIEnumerable; internal abstract void Load(List entities, IQueryable query); @@ -21,7 +21,7 @@ abstract class PropertyAccessor : IPropertyAccessor where TClass public Type MemberType { get => _memberType; } public Type MemberEntityType { get => _memberEntityType; } - public bool IsMemberTypeICollection { get => _isMemberTypeICollection; } + public bool IsMemberTypeIEnumerable { get => _isMemberTypeIEnumerable; } public bool IsMemberEntityTypeIEnumerable { get => _isMemberEntityTypeIEnumerable; } public abstract HashSet Properties { get; } diff --git a/src/LinqToDB.Include/ReflectionExtensions.cs b/src/LinqToDB.Include/ReflectionExtensions.cs index 8200f86..204223e 100644 --- a/src/LinqToDB.Include/ReflectionExtensions.cs +++ b/src/LinqToDB.Include/ReflectionExtensions.cs @@ -28,25 +28,6 @@ internal static Type GetTypeToUse(this Type type) } internal static bool IsIEnumerable(this Type type) - { - if (type.IsGenericType) - { - if (type.IsInterface) - { - return type.GetGenericTypeDefinition().GetInterfaces() - .Any(x => x.GetGenericTypeDefinition() == typeof(IEnumerable<>)); - } - - var genericTypeDefinition = type.GetGenericTypeDefinition(); - - return genericTypeDefinition.GetInterfaces() - .Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>)); - } - - return false; - } - - internal static bool IsICollection(this Type type) { if (type.IsGenericType) { @@ -54,19 +35,19 @@ internal static bool IsICollection(this Type type) { var def = type.GetGenericTypeDefinition(); - return def == typeof(ICollection<>) || - def.GetInterfaces().Any(x => x.GetGenericTypeDefinition() == typeof(ICollection<>)); + return def == typeof(IEnumerable<>) || type.GetGenericTypeDefinition().GetInterfaces() + .Any(x => x.GetGenericTypeDefinition() == typeof(IEnumerable<>)); } var genericTypeDefinition = type.GetGenericTypeDefinition(); return genericTypeDefinition.GetInterfaces() - .Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ICollection<>)); + .Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>)); } return false; } - + internal static Action CreatePropertySetter( this Type elementType, string propertyName) { @@ -81,38 +62,16 @@ internal static Action CreatePropertySetter( return action.Compile(); } - [Obsolete] - internal static Action> CreateICollectionPropertySetter( - this Type elementType, string propertyName) - { - var pi = elementType.GetProperty(propertyName, BindingFlags.Public | BindingFlags.Instance); - var mi = pi.GetSetMethod(); - - var oParam = Expression.Parameter(elementType, "obj"); - var vParam = Expression.Parameter(typeof(ICollection), "val"); - var mce = Expression.Call(oParam, mi, vParam); - - var clearMethod = Expression.Call(vParam, typeof(ICollection).GetMethod("Clear")); - var ifCriteria = Expression.Equal(Expression.Call(oParam, pi.GetGetMethod()), Expression.Constant(null)); - var ifnullSetter = Expression.IfThenElse(ifCriteria, mce, clearMethod); - - var action = Expression.Lambda>>(ifnullSetter, oParam, vParam); - - return action.Compile(); - } - - - internal static Action CreateCollectionPropertySetter( this Type elementType, string propertyName, Type propertyType) { - var pi = elementType.GetProperty(propertyName, BindingFlags.Public | BindingFlags.Instance); - var mi = pi.GetSetMethod(); - var oParam = Expression.Parameter(elementType, "obj"); var vParam = Expression.Parameter(typeof(TValue), "val"); - var mce = Expression.Call(Expression.Property(oParam, propertyName), - typeof(ICollection).GetMethod("Add"), vParam); + var mce = Expression.Call( + Expression.Convert( + Expression.Property(oParam, propertyName) + , typeof(ICollection)) + , typeof(ICollection).GetMethod("Add"), vParam); var action = Expression.Lambda>(mce, oParam, vParam); @@ -153,9 +112,7 @@ internal static Action CreatePropertySetup(this Type i var mi = pi.GetSetMethod(); var mce = Expression.Call(parentParam, mi, newCollection); - var @if = Expression.IfThenElse(isParamNull, - mce, - Expression.Call(property, typeof(ICollection).GetMethod(nameof(ICollection.Clear)))); + var @if = Expression.IfThen(isParamNull, mce); var finalCode = Expression.Lambda>(@if, parentParam); @@ -183,7 +140,7 @@ private static Type GetTypeToCreate(Type type) { typeNum = 2; } - else if (def == typeof(ICollection<>)) + else if (def == typeof(ICollection<>) || def == typeof(IEnumerable<>)) { typeNum = 3; } diff --git a/src/LinqToDB.Include/Setters/EntityPropertySetter.cs b/src/LinqToDB.Include/Setters/EntityPropertySetter.cs index 26b1dd4..58fd9f9 100644 --- a/src/LinqToDB.Include/Setters/EntityPropertySetter.cs +++ b/src/LinqToDB.Include/Setters/EntityPropertySetter.cs @@ -109,7 +109,7 @@ private static void MatchEntityLookup(PropertyAccessor(schema.PropertyName, schema.MemberType); @@ -146,7 +146,7 @@ private static void MatchEntityList(PropertyAccessor(schema.PropertyName, schema.MemberType); diff --git a/src/LinqToDB.Include/Visitors/PropertyVisitor.cs b/src/LinqToDB.Include/Visitors/PropertyVisitor.cs index a9c52c4..7ad9dee 100644 --- a/src/LinqToDB.Include/Visitors/PropertyVisitor.cs +++ b/src/LinqToDB.Include/Visitors/PropertyVisitor.cs @@ -15,11 +15,11 @@ internal PropertyVisitor(IRootAccessor rootAccessor) { _rootAccessor = rootAccessor; } - + private static void AddFilterForDynamicType(IPropertyAccessor accessor, - Expression> includeFilter) - where T : class - where TProperty : class + Expression> includeFilter) + where T : class + where TProperty : class { var accessorImpl = accessor as PropertyAccessor; if (accessorImpl == null) @@ -29,7 +29,7 @@ private static void AddFilterForDynamicType(IPropertyAccessor a accessorImpl.AddFilter(includeFilter); } - public IRootAccessor MapProperties(Expression> expr, + public IRootAccessor MapProperties(Expression> expr, Expression> includeFilter = null) where TProperty : class { @@ -41,7 +41,7 @@ public IRootAccessor MapProperties(Expression accessorImpl) @@ -53,8 +53,8 @@ public IRootAccessor MapProperties(Expression t.IsGenericType && - t.GetGenericTypeDefinition() == typeof(IEnumerable<>))) + if (genericTypeDefinition == typeof(IEnumerable<>) || + genericTypeDefinition.GetInterfaces().Any(t => t.IsGenericType && + t.GetGenericTypeDefinition() == typeof(IEnumerable<>))) { return type.GetGenericArguments()[0]; }