Skip to content
Permalink
Browse files

Add multi-key SortedAggregator support

  • Loading branch information...
a046 committed Feb 22, 2019
1 parent 8fef5b1 commit ce9d178a7202cd34a15785cc9c27f87fca70f7dc
Showing with 1,239 additions and 71 deletions.
  1. +8 −0 src/NRules/NRules.Fluent/Dsl/IQuery.cs
  2. +1 −1 src/NRules/NRules.Fluent/Dsl/QueryExpression.cs
  3. +31 −2 src/NRules/NRules.Fluent/Dsl/QueryExtensions.cs
  4. +9 −3 src/NRules/NRules.RuleModel/AggregateElement.cs
  5. +5 −5 src/NRules/NRules.RuleModel/Builders/AggregateBuilder.cs
  6. +7 −7 src/NRules/NRules.RuleModel/Builders/Element.cs
  7. +7 −14 src/NRules/NRules.RuleModel/Builders/ElementValidator.cs
  8. +1 −1 src/NRules/NRules.RuleModel/Builders/RuleTransformation.cs
  9. +3 −3 src/NRules/NRules.RuleModel/ExpressionCollection.cs
  10. +1 −1 src/NRules/NRules.RuleModel/RuleElementVisitor.cs
  11. +31 −11 src/NRules/NRules/Aggregators/CollectionAggregatorFactory.cs
  12. +2 −2 src/NRules/NRules/Aggregators/FlatteningAggregatorFactory.cs
  13. +4 −4 src/NRules/NRules/Aggregators/GroupByAggregatorFactory.cs
  14. +90 −0 src/NRules/NRules/Aggregators/MultiKeySortedAggregator.cs
  15. +2 −2 src/NRules/NRules/Aggregators/ProjectionAggregatorFactory.cs
  16. +25 −7 src/NRules/NRules/Aggregators/SortedAggregator.cs
  17. +1 −1 src/NRules/NRules/Diagnostics/NodeInfo.cs
  18. +3 −3 src/NRules/NRules/Rete/AggregateNode.cs
  19. +3 −3 src/NRules/NRules/Rete/ReteBuilder.cs
  20. +1 −1 src/NRules/Tests/NRules.IntegrationTests/CustomSelectAggregatorTest.cs
  21. +199 −0 ...ests/NRules.IntegrationTests/OneFactOneMultiKeySortedCollectionAscendingThenDescendingRuleTest.cs
  22. +199 −0 ...ests/NRules.IntegrationTests/OneFactOneMultiKeySortedCollectionDescendingThenAscendingRuleTest.cs
  23. +116 −0 ...ules/Tests/NRules.IntegrationTests/OneFactOneMultiKeySortedCollectionManyChainedThenByRuleTest.cs
  24. +490 −0 src/NRules/Tests/NRules.Tests/Aggregators/MultiKeySortedAggregatorTest.cs
@@ -36,4 +36,12 @@ public interface IQuery<out TSource>
public interface ICollectQuery<out TSource> : IQuery<TSource>
{
}

/// <summary>
/// Intermediate query chain element used for OrderBy modifiers.
/// </summary>
/// <typeparam name="TSource">Type of the element the query operates on.</typeparam>
public interface IOrderedQuery<out TSource> : IQuery<TSource>
{
}
}
@@ -4,7 +4,7 @@
/// Expression builder for queries.
/// </summary>
/// <typeparam name="TSource">Type of query source.</typeparam>
public class QueryExpression<TSource> : IQuery<TSource>, ICollectQuery<TSource>
public class QueryExpression<TSource> : IQuery<TSource>, ICollectQuery<TSource>, IOrderedQuery<TSource>
{
/// <summary>
/// Constructs a query expression builder that wraps a <see cref="IQueryBuilder"/>.
@@ -129,7 +129,7 @@ public static ICollectQuery<IEnumerable<TSource>> Collect<TSource>(this IQuery<T
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IQuery<IEnumerable<TSource>> OrderBy<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
public static IOrderedQuery<IEnumerable<TSource>> OrderBy<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Ascending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
@@ -143,7 +143,36 @@ public static ICollectQuery<IEnumerable<TSource>> Collect<TSource>(this IQuery<T
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IQuery<IEnumerable<TSource>> OrderByDescending<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
public static IOrderedQuery<IEnumerable<TSource>> OrderByDescending<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Descending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
}


/// <summary>
/// Configures sorted matching facts to subsequently be sorted ascending by key.
/// </summary>
/// <typeparam name="TSource">Type of source facts.</typeparam>
/// <typeparam name="TKey">Type of sorting key.</typeparam>
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IOrderedQuery<IEnumerable<TSource>> ThenBy<TSource, TKey>(this IOrderedQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Ascending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
}

/// <summary>
/// Configures sorted matching facts to subsequently be sorted descending by key.
/// </summary>
/// <typeparam name="TSource">Type of source facts.</typeparam>
/// <typeparam name="TKey">Type of sorting key.</typeparam>
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IOrderedQuery<IEnumerable<TSource>> ThenByDescending<TSource, TKey>(this IOrderedQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Descending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
@@ -12,6 +12,12 @@ public class AggregateElement : PatternSourceElement
public const string ProjectName = "Project";
public const string FlattenName = "Flatten";

public const string SelectorName = "Selector";
public const string ElementSelectorName = "ElementSelector";
public const string KeySelectorName = "KeySelector";
public const string KeySelectorAscendingName = "KeySelectorAscending";
public const string KeySelectorDescendingName = "KeySelectorDescending";

/// <summary>
/// Fact source of the aggregate.
/// </summary>
@@ -30,13 +36,13 @@ public class AggregateElement : PatternSourceElement
/// <summary>
/// Expressions used by the aggregate.
/// </summary>
public ExpressionCollection Expressions { get; }
public ExpressionCollection ExpressionCollection { get; }

internal AggregateElement(Type resultType, string name, ExpressionCollection expressions, PatternElement source, Type customFactoryType)
internal AggregateElement(Type resultType, string name, ExpressionCollection expressionCollection, PatternElement source, Type customFactoryType)
: base(resultType)
{
Name = name;
Expressions = expressions;
ExpressionCollection = expressionCollection;
Source = source;
CustomFactoryType = customFactoryType;

@@ -59,7 +59,7 @@ public void Collect()
/// <param name="sortDirection">Order to sort the aggregation in.</param>
public void OrderBy(LambdaExpression keySelector, SortDirection sortDirection)
{
var expressionName = sortDirection == SortDirection.Ascending ? "KeySelectorAscending" : "KeySelectorDescending";
var expressionName = sortDirection == SortDirection.Ascending ? AggregateElement.KeySelectorAscendingName : AggregateElement.KeySelectorDescendingName;
AddExpression(expressionName, keySelector);
}

@@ -71,8 +71,8 @@ public void OrderBy(LambdaExpression keySelector, SortDirection sortDirection)
public void GroupBy(LambdaExpression keySelector, LambdaExpression elementSelector)
{
_name = AggregateElement.GroupByName;
AddExpression("KeySelector", keySelector);
AddExpression("ElementSelector", elementSelector);
AddExpression(AggregateElement.KeySelectorName, keySelector);
AddExpression(AggregateElement.ElementSelectorName, elementSelector);
}

/// <summary>
@@ -82,7 +82,7 @@ public void GroupBy(LambdaExpression keySelector, LambdaExpression elementSelect
public void Project(LambdaExpression selector)
{
_name = AggregateElement.ProjectName;
AddExpression("Selector", selector);
AddExpression(AggregateElement.SelectorName, selector);
}

/// <summary>
@@ -92,7 +92,7 @@ public void Project(LambdaExpression selector)
public void Flatten(LambdaExpression selector)
{
_name = AggregateElement.FlattenName;
AddExpression("Selector", selector);
AddExpression(AggregateElement.SelectorName, selector);
}

/// <summary>
@@ -506,10 +506,10 @@ public static AggregateElement GroupBy(Type resultType, LambdaExpression keySele
resultType = groupingType.MakeGenericType(keySelector.ReturnType, elementSelector.ReturnType);
}

var expressions = new List<KeyValuePair<string, LambdaExpression>>
var expressions = new Dictionary<string, LambdaExpression>
{
new KeyValuePair<string, LambdaExpression>("KeySelector", keySelector),
new KeyValuePair<string, LambdaExpression>("ElementSelector", elementSelector)
{ AggregateElement.KeySelectorName, keySelector },
{ AggregateElement.ElementSelectorName, elementSelector }
};
var element = Aggregate(resultType, AggregateElement.GroupByName, expressions, source);
return element;
@@ -544,9 +544,9 @@ public static AggregateElement Project(Type resultType, LambdaExpression selecto
resultType = selector.ReturnType;
}

var expressions = new List<KeyValuePair<string, LambdaExpression>>
var expressions = new Dictionary<string, LambdaExpression>
{
new KeyValuePair<string, LambdaExpression>("Selector", selector)
{ AggregateElement.SelectorName, selector }
};
var element = Aggregate(resultType, AggregateElement.ProjectName, expressions, source);
return element;
@@ -564,9 +564,9 @@ public static AggregateElement Flatten(Type resultType, LambdaExpression selecto
if (selector == null)
throw new ArgumentNullException(nameof(selector), "Flattening selector not provided");

var expressions = new List<KeyValuePair<string, LambdaExpression>>
var expressions = new Dictionary<string, LambdaExpression>
{
new KeyValuePair<string, LambdaExpression>("Selector", selector)
{ AggregateElement.SelectorName, selector }
};
var element = Aggregate(resultType, AggregateElement.FlattenName, expressions, source);
return element;
@@ -93,17 +93,10 @@ public static void ValidateCollectAggregate(AggregateElement element)
$"Collect result must be a collection of source elements. ElementType={sourceType}, ResultType={resultType}");
}

var keySelectorAscending = element.Expressions.FindSingleOrDefault("KeySelectorAscending")?.Expression;
var keySelectorDescending = element.Expressions.FindSingleOrDefault("KeySelectorDescending")?.Expression;
var keySelectorsAscending = element.ExpressionCollection.Find(AggregateElement.KeySelectorAscendingName);
var keySelectorsDescending = element.ExpressionCollection.Find(AggregateElement.KeySelectorDescendingName);

if (keySelectorAscending != null && keySelectorDescending != null)
{
throw new ArgumentException(
"Must have a single key selector for sorting");
}

var sortKeySelector = keySelectorAscending ?? keySelectorDescending;
if (sortKeySelector != null)
foreach (var sortKeySelector in keySelectorsAscending.Concat(keySelectorsDescending).Select(x => x.Expression))
{
if (sortKeySelector.Parameters.Count == 0)
{
@@ -123,7 +116,7 @@ public static void ValidateGroupByAggregate(AggregateElement element)
{
var sourceType = element.Source.ValueType;
var resultType = element.ResultType;
var keySelector = element.Expressions["KeySelector"].Expression;
var keySelector = element.ExpressionCollection[AggregateElement.KeySelectorName].Expression;
if (keySelector.Parameters.Count == 0)
{
throw new ArgumentException(
@@ -136,7 +129,7 @@ public static void ValidateGroupByAggregate(AggregateElement element)
$"KeySelector={keySelector}, ExpectedType={sourceType}, ActualType={keySelector.Parameters[0].Type}");
}

var elementSelector = element.Expressions["ElementSelector"].Expression;
var elementSelector = element.ExpressionCollection["ElementSelector"].Expression;
if (elementSelector.Parameters.Count == 0)
{
throw new ArgumentException(
@@ -162,7 +155,7 @@ public static void ValidateProjectAggregate(AggregateElement element)
{
var sourceType = element.Source.ValueType;
var resultType = element.ResultType;
var selector = element.Expressions["Selector"].Expression;
var selector = element.ExpressionCollection[AggregateElement.SelectorName].Expression;
if (selector.Parameters.Count == 0)
{
throw new ArgumentException(
@@ -186,7 +179,7 @@ public static void ValidateFlattenAggregate(AggregateElement element)
{
var sourceType = element.Source.ValueType;
var resultType = element.ResultType;
var selector = element.Expressions["Selector"].Expression;
var selector = element.ExpressionCollection[AggregateElement.SelectorName].Expression;
if (selector.Parameters.Count != 1)
{
throw new ArgumentException(
@@ -81,7 +81,7 @@ protected internal override void VisitAggregate(Context context, AggregateElemen
var source = Transform<PatternElement>(context, element.Source);
if (context.IsModified)
{
var aggregateExpressions = element.Expressions.Select(x => new KeyValuePair<string, LambdaExpression>(x.Name, x.Expression));
var aggregateExpressions = element.ExpressionCollection.Select(x => new KeyValuePair<string, LambdaExpression>(x.Name, x.Expression));
var newElement = Element.Aggregate(element.ResultType, element.Name, aggregateExpressions, source, element.CustomFactoryType);
Result(context, newElement);
}
@@ -6,7 +6,7 @@
namespace NRules.RuleModel
{
/// <summary>
/// Ordered readonly collection of named expressions.
/// Sorted readonly map of named expressions.
/// </summary>
public class ExpressionCollection : IEnumerable<NamedExpressionElement>
{
@@ -18,7 +18,7 @@ public ExpressionCollection(IEnumerable<NamedExpressionElement> expressions)
}

/// <summary>
/// Number of expressions in the collection.
/// Number of expressions in the map.
/// </summary>
public int Count => _expressions.Count;

@@ -52,7 +52,7 @@ public IEnumerable<NamedExpressionElement> Find(string name)
}

/// <summary>
/// Retrieves single expression by name.
/// Retrieves only expression by name.
/// </summary>
/// <param name="name">Expression name.</param>
/// <returns>Matching expression or <c>null</c>.</returns>
@@ -30,7 +30,7 @@ protected internal virtual void VisitCondition(TContext context, ConditionElemen

protected internal virtual void VisitAggregate(TContext context, AggregateElement element)
{
foreach (var expression in element.Expressions)
foreach (var expression in element.ExpressionCollection)
{
expression.Accept(context, this);
}
@@ -8,7 +8,7 @@
namespace NRules.Aggregators
{
/// <summary>
/// Aggregator factory for collection aggregator.
/// Aggregator factory for collection aggregator, including modifiers such as OrderBy.
/// </summary>
internal class CollectionAggregatorFactory : IAggregatorFactory
{
@@ -18,15 +18,19 @@ public void Compile(AggregateElement element, IEnumerable<IAggregateExpression>
{
var sourceType = element.Source.ValueType;

var ascendingSortSelector = element.Expressions.FindSingleOrDefault("KeySelectorAscending");
var descendingSortSelector = element.Expressions.FindSingleOrDefault("KeySelectorDescending");
if (ascendingSortSelector != null)
var sortCriteriaKeySelectors = compiledExpressions.Where(x => x.Name == AggregateElement.KeySelectorAscendingName || x.Name == AggregateElement.KeySelectorDescendingName).ToArray();
if (sortCriteriaKeySelectors.Any())
{
_factory = CreateSortedAggregatorFactory(sourceType, SortDirection.Ascending, ascendingSortSelector, compiledExpressions.FindSingle("KeySelectorAscending"));
}
else if (descendingSortSelector != null)
{
_factory = CreateSortedAggregatorFactory(sourceType, SortDirection.Descending, descendingSortSelector, compiledExpressions.FindSingle("KeySelectorDescending"));
if (sortCriteriaKeySelectors.Length == 1)
{
var keySelector = sortCriteriaKeySelectors[0];
_factory = CreateSingleKeySortedAggregatorFactory(sourceType, GetSortDirection(keySelector.Name), element.ExpressionCollection.FindSingleOrDefault(keySelector.Name), keySelector);
}
else
{
var sortCriteria = compiledExpressions.Select(x => new SortCriteria(x, GetSortDirection(x.Name))).ToArray();
_factory = CreateMultiKeySortedAggregatorFactory(sourceType, sortCriteria);
}
}
else
{
@@ -37,7 +41,12 @@ public void Compile(AggregateElement element, IEnumerable<IAggregateExpression>
}
}

private static Func<IAggregator> CreateSortedAggregatorFactory(Type sourceType, SortDirection sortDirection, NamedExpressionElement selector, IAggregateExpression compiledSelector)
private static SortDirection GetSortDirection(string keySelectorName)
{
return keySelectorName == AggregateElement.KeySelectorAscendingName ? SortDirection.Ascending : SortDirection.Descending;
}

private static Func<IAggregator> CreateSingleKeySortedAggregatorFactory(Type sourceType, SortDirection sortDirection, NamedExpressionElement selector, IAggregateExpression compiledSelector)
{
var resultType = selector.Expression.ReturnType;
var aggregatorType = typeof(SortedAggregator<,>).MakeGenericType(sourceType, resultType);
@@ -49,9 +58,20 @@ private static Func<IAggregator> CreateSortedAggregatorFactory(Type sourceType,
return factoryExpression.Compile();
}

private static Func<IAggregator> CreateMultiKeySortedAggregatorFactory(Type sourceType, SortCriteria[] sortCriterias)
{
var aggregatorType = typeof(MultiKeySortedAggregator<>).MakeGenericType(sourceType);

var ctor = aggregatorType.GetTypeInfo().DeclaredConstructors.Single();
var factoryExpression = Expression.Lambda<Func<IAggregator>>(
Expression.New(ctor, Expression.Constant(sortCriterias)));

return factoryExpression.Compile();
}

public IAggregator Create()
{
return _factory();
}
}
}
}
@@ -2,8 +2,8 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using NRules.RuleModel;
using System.Reflection;
using NRules.RuleModel;

namespace NRules.Aggregators
{
@@ -21,7 +21,7 @@ public void Compile(AggregateElement element, IEnumerable<IAggregateExpression>
var resultType = element.ResultType;
Type aggregatorType = typeof(FlatteningAggregator<,>).MakeGenericType(sourceType, resultType);

var compiledSelector = compiledExpressions.FindSingle("Selector");
var compiledSelector = compiledExpressions.FindSingle(AggregateElement.SelectorName);
var ctor = aggregatorType.GetTypeInfo().DeclaredConstructors.Single();
var factoryExpression = Expression.Lambda<Func<IAggregator>>(
Expression.New(ctor, Expression.Constant(compiledSelector)));
Oops, something went wrong.

0 comments on commit ce9d178

Please sign in to comment.
You can’t perform that action at this time.