Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathan Parkinson committed Feb 27, 2019
2 parents daa4e3c + 349978e commit 5e3a78e
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class Person
public int? SpouseId { get; set; }
public Person Spouse { get; set; }

public List<Order> Orders { get; set; } = new List<Order>();
public IEnumerable<Order> Orders { get; set; } = new List<Order>();

public static Expression<Func<Person, Order, bool>> ExtraJoinOptions() => (p, o) => o.OrderId < 99;
}
Expand Down
2 changes: 1 addition & 1 deletion src/LinqToDB.Include/Accessors/IPropertyAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ interface IPropertyAccessor
Type MemberEntityType { get; }
HashSet<IPropertyAccessor> Properties { get; }
string PropertyName { get; }
bool IsMemberTypeICollection { get; }
bool IsMemberTypeIEnumerable { get; }


IPropertyAccessor FindAccessor(List<string> pathParts);
Expand Down
2 changes: 1 addition & 1 deletion src/LinqToDB.Include/Accessors/PropertyAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public PropertyAccessor(MemberExpression exp, MappingSchema mappingSchema)
_declaringType = exp.Member.DeclaringType;
_memberType = exp.Type;
_memberEntityType = typeof(TProperty);
_isMemberTypeICollection = _memberType.IsICollection();
_isMemberTypeIEnumerable = _memberType.IsIEnumerable();
_isMemberEntityTypeIEnumerable = _memberEntityType.IsIEnumerable();

ChildEntityDescriptor = mappingSchema.GetEntityDescriptor(_memberEntityType);
Expand Down
4 changes: 2 additions & 2 deletions src/LinqToDB.Include/Accessors/PropertyAccessorAbstract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ abstract class PropertyAccessor<TClass> : IPropertyAccessor<TClass> where TClass
protected Type _declaringType;
protected Type _memberType;
protected Type _memberEntityType;
protected bool _isMemberTypeICollection;
protected bool _isMemberTypeIEnumerable;
protected bool _isMemberEntityTypeIEnumerable;

internal abstract void Load(List<TClass> entities, IQueryable<TClass> query);
Expand All @@ -21,7 +21,7 @@ abstract class PropertyAccessor<TClass> : IPropertyAccessor<TClass> where TClass
public Type MemberType { get => _memberType; }
public Type MemberEntityType { get => _memberEntityType; }

public bool IsMemberTypeICollection { get => _isMemberTypeICollection; }
public bool IsMemberTypeIEnumerable { get => _isMemberTypeIEnumerable; }
public bool IsMemberEntityTypeIEnumerable { get => _isMemberEntityTypeIEnumerable; }

public abstract HashSet<IPropertyAccessor> Properties { get; }
Expand Down
65 changes: 11 additions & 54 deletions src/LinqToDB.Include/ReflectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,26 @@ internal static Type GetTypeToUse(this Type type)
}

internal static bool IsIEnumerable(this Type type)
{
if (type.IsGenericType)
{
if (type.IsInterface)
{
return type.GetGenericTypeDefinition().GetInterfaces()
.Any(x => x.GetGenericTypeDefinition() == typeof(IEnumerable<>));
}

var genericTypeDefinition = type.GetGenericTypeDefinition();

return genericTypeDefinition.GetInterfaces()
.Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));
}

return false;
}

internal static bool IsICollection(this Type type)
{
if (type.IsGenericType)
{
if (type.IsInterface)
{
var def = type.GetGenericTypeDefinition();

return def == typeof(ICollection<>) ||
def.GetInterfaces().Any(x => x.GetGenericTypeDefinition() == typeof(ICollection<>));
return def == typeof(IEnumerable<>) || type.GetGenericTypeDefinition().GetInterfaces()
.Any(x => x.GetGenericTypeDefinition() == typeof(IEnumerable<>));
}

var genericTypeDefinition = type.GetGenericTypeDefinition();

return genericTypeDefinition.GetInterfaces()
.Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ICollection<>));
.Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));
}

return false;
}

internal static Action<TElement, TValue> CreatePropertySetter<TElement, TValue>(
this Type elementType, string propertyName)
{
Expand All @@ -81,38 +62,16 @@ internal static Action<TElement, TValue> CreatePropertySetter<TElement, TValue>(
return action.Compile();
}

[Obsolete]
internal static Action<TElement, ICollection<TValue>> CreateICollectionPropertySetter<TElement, TValue>(
this Type elementType, string propertyName)
{
var pi = elementType.GetProperty(propertyName, BindingFlags.Public | BindingFlags.Instance);
var mi = pi.GetSetMethod();

var oParam = Expression.Parameter(elementType, "obj");
var vParam = Expression.Parameter(typeof(ICollection<TValue>), "val");
var mce = Expression.Call(oParam, mi, vParam);

var clearMethod = Expression.Call(vParam, typeof(ICollection<TValue>).GetMethod("Clear"));
var ifCriteria = Expression.Equal(Expression.Call(oParam, pi.GetGetMethod()), Expression.Constant(null));
var ifnullSetter = Expression.IfThenElse(ifCriteria, mce, clearMethod);

var action = Expression.Lambda<Action<TElement, ICollection<TValue>>>(ifnullSetter, oParam, vParam);

return action.Compile();
}



internal static Action<TElement, TValue> CreateCollectionPropertySetter<TElement, TValue>(
this Type elementType, string propertyName, Type propertyType)
{
var pi = elementType.GetProperty(propertyName, BindingFlags.Public | BindingFlags.Instance);
var mi = pi.GetSetMethod();

var oParam = Expression.Parameter(elementType, "obj");
var vParam = Expression.Parameter(typeof(TValue), "val");
var mce = Expression.Call(Expression.Property(oParam, propertyName),
typeof(ICollection<TValue>).GetMethod("Add"), vParam);
var mce = Expression.Call(
Expression.Convert(
Expression.Property(oParam, propertyName)
, typeof(ICollection<TValue>))
, typeof(ICollection<TValue>).GetMethod("Add"), vParam);

var action = Expression.Lambda<Action<TElement, TValue>>(mce, oParam, vParam);

Expand Down Expand Up @@ -153,9 +112,7 @@ internal static Action<TParent> CreatePropertySetup<TParent, TChild>(this Type i
var mi = pi.GetSetMethod();
var mce = Expression.Call(parentParam, mi, newCollection);

var @if = Expression.IfThenElse(isParamNull,
mce,
Expression.Call(property, typeof(ICollection<TChild>).GetMethod(nameof(ICollection<int>.Clear))));
var @if = Expression.IfThen(isParamNull, mce);

var finalCode = Expression.Lambda<Action<TParent>>(@if, parentParam);

Expand Down Expand Up @@ -183,7 +140,7 @@ private static Type GetTypeToCreate(Type type)
{
typeNum = 2;
}
else if (def == typeof(ICollection<>))
else if (def == typeof(ICollection<>) || def == typeof(IEnumerable<>))
{
typeNum = 3;
}
Expand Down
4 changes: 2 additions & 2 deletions src/LinqToDB.Include/Setters/EntityPropertySetter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private static void MatchEntityLookup<TParent, TChild>(PropertyAccessor<TParent,
where TParent : class
where TChild : class
{
if (schema.IsMemberTypeICollection)
if (schema.IsMemberTypeIEnumerable)
{
var setter = schema.DeclaringType.CreateCollectionPropertySetter<TParent, TChild>(schema.PropertyName,
schema.MemberType);
Expand Down Expand Up @@ -146,7 +146,7 @@ private static void MatchEntityList<TParent, TChild>(PropertyAccessor<TParent, T
where TParent : class
where TChild : class
{
if (schema.IsMemberTypeICollection)
if (schema.IsMemberTypeIEnumerable)
{
var setter = schema.DeclaringType.CreateCollectionPropertySetter<TParent, TChild>(schema.PropertyName,
schema.MemberType);
Expand Down
24 changes: 12 additions & 12 deletions src/LinqToDB.Include/Visitors/PropertyVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ internal PropertyVisitor(IRootAccessor<TClass> rootAccessor)
{
_rootAccessor = rootAccessor;
}

private static void AddFilterForDynamicType<T, TProperty>(IPropertyAccessor<T> accessor,
Expression<Func<TProperty, bool>> includeFilter)
where T : class
where TProperty : class
Expression<Func<TProperty, bool>> includeFilter)
where T : class
where TProperty : class
{
var accessorImpl = accessor as PropertyAccessor<T, TProperty>;
if (accessorImpl == null)
Expand All @@ -29,7 +29,7 @@ private static void AddFilterForDynamicType<T, TProperty>(IPropertyAccessor<T> a
accessorImpl.AddFilter(includeFilter);
}

public IRootAccessor<TClass> MapProperties<TProperty>(Expression<Func<TClass, TProperty>> expr,
public IRootAccessor<TClass> MapProperties<TProperty>(Expression<Func<TClass, TProperty>> expr,
Expression<Func<TProperty, bool>> includeFilter = null)
where TProperty : class
{
Expand All @@ -41,7 +41,7 @@ public IRootAccessor<TClass> MapProperties<TProperty>(Expression<Func<TClass, TP
{
_rootAccessor.Properties.Add(accessor);
}

if (includeFilter != null)
{
if (latestAccessor is PropertyAccessor<TClass, TProperty> accessorImpl)
Expand All @@ -53,8 +53,8 @@ public IRootAccessor<TClass> MapProperties<TProperty>(Expression<Func<TClass, TP
//can type checking be added here?
var propertyAccessor = _rootAccessor.GetByPath(PathWalker.GetPath(expr));
dynamic dynamicAccessor = propertyAccessor;
AddFilterForDynamicType(dynamicAccessor, includeFilter);
}
AddFilterForDynamicType(dynamicAccessor, includeFilter);
}
}

return _rootAccessor;
Expand Down Expand Up @@ -99,7 +99,7 @@ private IPropertyAccessor CreateAccessor(MemberExpression node)
var declaringType = GetTypeToUse(node.Member.DeclaringType);
var nodeType = GetTypeToUse(node.Type);

if (latestAccessor != null && latestAccessor.DeclaringType == declaringType &&
if (latestAccessor != null && latestAccessor.DeclaringType == declaringType &&
latestAccessor.PropertyName == node.Member.Name)
{
return latestAccessor;
Expand All @@ -125,9 +125,9 @@ internal static Type GetTypeToUse(Type type)
{
var genericTypeDefinition = type.GetGenericTypeDefinition();

if (genericTypeDefinition.GetInterfaces()
.Any(t => t.IsGenericType &&
t.GetGenericTypeDefinition() == typeof(IEnumerable<>)))
if (genericTypeDefinition == typeof(IEnumerable<>) ||
genericTypeDefinition.GetInterfaces().Any(t => t.IsGenericType &&
t.GetGenericTypeDefinition() == typeof(IEnumerable<>)))
{
return type.GetGenericArguments()[0];
}
Expand Down

0 comments on commit 5e3a78e

Please sign in to comment.