Skip to content
Permalink
Browse files

Merge pull request #181 from a046/feature-multiaggregator

Add multi-key SortedAggregator support
  • Loading branch information...
snikolayev committed Feb 23, 2019
2 parents 8fef5b1 + 7d1f833 commit 1cd8d2e41229d5d2ceee11c787e50924258f1777
@@ -36,4 +36,12 @@ public interface IQuery<out TSource>
public interface ICollectQuery<out TSource> : IQuery<TSource>
{
}

/// <summary>
/// Intermediate query chain element used for OrderBy modifiers.
/// </summary>
/// <typeparam name="TSource">Type of the element the query operates on.</typeparam>
public interface IOrderedQuery<out TSource> : IQuery<TSource>
{
}
}
@@ -4,7 +4,7 @@
/// Expression builder for queries.
/// </summary>
/// <typeparam name="TSource">Type of query source.</typeparam>
public class QueryExpression<TSource> : IQuery<TSource>, ICollectQuery<TSource>
public class QueryExpression<TSource> : IQuery<TSource>, ICollectQuery<TSource>, IOrderedQuery<TSource>
{
/// <summary>
/// Constructs a query expression builder that wraps a <see cref="IQueryBuilder"/>.
@@ -129,7 +129,7 @@ public static ICollectQuery<IEnumerable<TSource>> Collect<TSource>(this IQuery<T
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IQuery<IEnumerable<TSource>> OrderBy<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
public static IOrderedQuery<IEnumerable<TSource>> OrderBy<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Ascending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
@@ -143,7 +143,35 @@ public static ICollectQuery<IEnumerable<TSource>> Collect<TSource>(this IQuery<T
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IQuery<IEnumerable<TSource>> OrderByDescending<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
public static IOrderedQuery<IEnumerable<TSource>> OrderByDescending<TSource, TKey>(this ICollectQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Descending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
}

/// <summary>
/// Configures sorted matching facts to subsequently be sorted ascending by key.
/// </summary>
/// <typeparam name="TSource">Type of source facts.</typeparam>
/// <typeparam name="TKey">Type of sorting key.</typeparam>
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IOrderedQuery<IEnumerable<TSource>> ThenBy<TSource, TKey>(this IOrderedQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Ascending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
}

/// <summary>
/// Configures sorted matching facts to subsequently be sorted descending by key.
/// </summary>
/// <typeparam name="TSource">Type of source facts.</typeparam>
/// <typeparam name="TKey">Type of sorting key.</typeparam>
/// <param name="source">Query expression builder.</param>
/// <param name="keySelector">Key selection expression used for sorting.</param>
/// <returns>Query expression builder.</returns>
public static IOrderedQuery<IEnumerable<TSource>> ThenByDescending<TSource, TKey>(this IOrderedQuery<IEnumerable<TSource>> source, Expression<Func<TSource, TKey>> keySelector)
{
source.Builder.OrderBy(keySelector, SortDirection.Descending);
return new QueryExpression<IEnumerable<TSource>>(source.Builder);
@@ -12,6 +12,12 @@ public class AggregateElement : PatternSourceElement
public const string ProjectName = "Project";
public const string FlattenName = "Flatten";

public const string SelectorName = "Selector";
public const string ElementSelectorName = "ElementSelector";
public const string KeySelectorName = "KeySelector";
public const string KeySelectorAscendingName = "KeySelectorAscending";
public const string KeySelectorDescendingName = "KeySelectorDescending";

/// <summary>
/// Fact source of the aggregate.
/// </summary>
@@ -59,7 +59,7 @@ public void Collect()
/// <param name="sortDirection">Order to sort the aggregation in.</param>
public void OrderBy(LambdaExpression keySelector, SortDirection sortDirection)
{
var expressionName = sortDirection == SortDirection.Ascending ? "KeySelectorAscending" : "KeySelectorDescending";
var expressionName = sortDirection == SortDirection.Ascending ? AggregateElement.KeySelectorAscendingName : AggregateElement.KeySelectorDescendingName;
AddExpression(expressionName, keySelector);
}

@@ -71,8 +71,8 @@ public void OrderBy(LambdaExpression keySelector, SortDirection sortDirection)
public void GroupBy(LambdaExpression keySelector, LambdaExpression elementSelector)
{
_name = AggregateElement.GroupByName;
AddExpression("KeySelector", keySelector);
AddExpression("ElementSelector", elementSelector);
AddExpression(AggregateElement.KeySelectorName, keySelector);
AddExpression(AggregateElement.ElementSelectorName, elementSelector);
}

/// <summary>
@@ -82,7 +82,7 @@ public void GroupBy(LambdaExpression keySelector, LambdaExpression elementSelect
public void Project(LambdaExpression selector)
{
_name = AggregateElement.ProjectName;
AddExpression("Selector", selector);
AddExpression(AggregateElement.SelectorName, selector);
}

/// <summary>
@@ -92,7 +92,7 @@ public void Project(LambdaExpression selector)
public void Flatten(LambdaExpression selector)
{
_name = AggregateElement.FlattenName;
AddExpression("Selector", selector);
AddExpression(AggregateElement.SelectorName, selector);
}

/// <summary>
@@ -671,4 +671,4 @@ public static ActionElement Action(LambdaExpression expression)
return element;
}
}
}
}
@@ -93,17 +93,10 @@ public static void ValidateCollectAggregate(AggregateElement element)
$"Collect result must be a collection of source elements. ElementType={sourceType}, ResultType={resultType}");
}

var keySelectorAscending = element.Expressions.FindSingleOrDefault("KeySelectorAscending")?.Expression;
var keySelectorDescending = element.Expressions.FindSingleOrDefault("KeySelectorDescending")?.Expression;
var keySelectorsAscending = element.Expressions.Find(AggregateElement.KeySelectorAscendingName);
var keySelectorsDescending = element.Expressions.Find(AggregateElement.KeySelectorDescendingName);

if (keySelectorAscending != null && keySelectorDescending != null)
{
throw new ArgumentException(
"Must have a single key selector for sorting");
}

var sortKeySelector = keySelectorAscending ?? keySelectorDescending;
if (sortKeySelector != null)
foreach (var sortKeySelector in keySelectorsAscending.Concat(keySelectorsDescending).Select(x => x.Expression))
{
if (sortKeySelector.Parameters.Count == 0)
{
@@ -123,7 +116,7 @@ public static void ValidateGroupByAggregate(AggregateElement element)
{
var sourceType = element.Source.ValueType;
var resultType = element.ResultType;
var keySelector = element.Expressions["KeySelector"].Expression;
var keySelector = element.Expressions[AggregateElement.KeySelectorName].Expression;
if (keySelector.Parameters.Count == 0)
{
throw new ArgumentException(
@@ -162,7 +155,7 @@ public static void ValidateProjectAggregate(AggregateElement element)
{
var sourceType = element.Source.ValueType;
var resultType = element.ResultType;
var selector = element.Expressions["Selector"].Expression;
var selector = element.Expressions[AggregateElement.SelectorName].Expression;
if (selector.Parameters.Count == 0)
{
throw new ArgumentException(
@@ -186,7 +179,7 @@ public static void ValidateFlattenAggregate(AggregateElement element)
{
var sourceType = element.Source.ValueType;
var resultType = element.ResultType;
var selector = element.Expressions["Selector"].Expression;
var selector = element.Expressions[AggregateElement.SelectorName].Expression;
if (selector.Parameters.Count != 1)
{
throw new ArgumentException(
@@ -6,8 +6,8 @@
namespace NRules.RuleModel
{
/// <summary>
/// Ordered readonly collection of named expressions.
/// </summary>
/// Ordered readonly collection of named expressions.
/// /// </summary>
public class ExpressionCollection : IEnumerable<NamedExpressionElement>
{
private readonly List<NamedExpressionElement> _expressions;
@@ -52,7 +52,7 @@ public IEnumerable<NamedExpressionElement> Find(string name)
}

/// <summary>
/// Retrieves single expression by name.
/// Retrieves only expression by name.
/// </summary>
/// <param name="name">Expression name.</param>
/// <returns>Matching expression or <c>null</c>.</returns>
@@ -8,7 +8,7 @@
namespace NRules.Aggregators
{
/// <summary>
/// Aggregator factory for collection aggregator.
/// Aggregator factory for collection aggregator, including modifiers such as OrderBy.
/// </summary>
internal class CollectionAggregatorFactory : IAggregatorFactory
{
@@ -18,15 +18,19 @@ public void Compile(AggregateElement element, IEnumerable<IAggregateExpression>
{
var sourceType = element.Source.ValueType;

var ascendingSortSelector = element.Expressions.FindSingleOrDefault("KeySelectorAscending");
var descendingSortSelector = element.Expressions.FindSingleOrDefault("KeySelectorDescending");
if (ascendingSortSelector != null)
var sortCriteriaKeySelectors = compiledExpressions.Where(x => x.Name == AggregateElement.KeySelectorAscendingName || x.Name == AggregateElement.KeySelectorDescendingName).ToArray();
if (sortCriteriaKeySelectors.Any())
{
_factory = CreateSortedAggregatorFactory(sourceType, SortDirection.Ascending, ascendingSortSelector, compiledExpressions.FindSingle("KeySelectorAscending"));
}
else if (descendingSortSelector != null)
{
_factory = CreateSortedAggregatorFactory(sourceType, SortDirection.Descending, descendingSortSelector, compiledExpressions.FindSingle("KeySelectorDescending"));
if (sortCriteriaKeySelectors.Length == 1)
{
var keySelector = sortCriteriaKeySelectors[0];
_factory = CreateSingleKeySortedAggregatorFactory(sourceType, GetSortDirection(keySelector.Name), element.Expressions.FindSingleOrDefault(keySelector.Name), keySelector);
}
else
{
var sortCriteria = compiledExpressions.Select(x => new SortCriteria(x, GetSortDirection(x.Name))).ToArray();
_factory = CreateMultiKeySortedAggregatorFactory(sourceType, sortCriteria);
}
}
else
{
@@ -37,7 +41,12 @@ public void Compile(AggregateElement element, IEnumerable<IAggregateExpression>
}
}

private static Func<IAggregator> CreateSortedAggregatorFactory(Type sourceType, SortDirection sortDirection, NamedExpressionElement selector, IAggregateExpression compiledSelector)
private static SortDirection GetSortDirection(string keySelectorName)
{
return keySelectorName == AggregateElement.KeySelectorAscendingName ? SortDirection.Ascending : SortDirection.Descending;
}

private static Func<IAggregator> CreateSingleKeySortedAggregatorFactory(Type sourceType, SortDirection sortDirection, NamedExpressionElement selector, IAggregateExpression compiledSelector)
{
var resultType = selector.Expression.ReturnType;
var aggregatorType = typeof(SortedAggregator<,>).MakeGenericType(sourceType, resultType);
@@ -49,9 +58,20 @@ private static Func<IAggregator> CreateSortedAggregatorFactory(Type sourceType,
return factoryExpression.Compile();
}

private static Func<IAggregator> CreateMultiKeySortedAggregatorFactory(Type sourceType, SortCriteria[] sortCriterias)
{
var aggregatorType = typeof(MultiKeySortedAggregator<>).MakeGenericType(sourceType);

var ctor = aggregatorType.GetTypeInfo().DeclaredConstructors.Single();
var factoryExpression = Expression.Lambda<Func<IAggregator>>(
Expression.New(ctor, Expression.Constant(sortCriterias)));

return factoryExpression.Compile();
}

public IAggregator Create()
{
return _factory();
}
}
}
}
@@ -2,8 +2,8 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using NRules.RuleModel;
using System.Reflection;
using NRules.RuleModel;

namespace NRules.Aggregators
{
@@ -21,7 +21,7 @@ public void Compile(AggregateElement element, IEnumerable<IAggregateExpression>
var resultType = element.ResultType;
Type aggregatorType = typeof(FlatteningAggregator<,>).MakeGenericType(sourceType, resultType);

var compiledSelector = compiledExpressions.FindSingle("Selector");
var compiledSelector = compiledExpressions.FindSingle(AggregateElement.SelectorName);
var ctor = aggregatorType.GetTypeInfo().DeclaredConstructors.Single();
var factoryExpression = Expression.Lambda<Func<IAggregator>>(
Expression.New(ctor, Expression.Constant(compiledSelector)));
@@ -16,16 +16,16 @@ internal class GroupByAggregatorFactory : IAggregatorFactory

public void Compile(AggregateElement element, IEnumerable<IAggregateExpression> compiledExpressions)
{
var keySelector = element.Expressions["KeySelector"];
var elementSelector = element.Expressions["ElementSelector"];
var keySelector = element.Expressions[AggregateElement.KeySelectorName];
var elementSelector = element.Expressions[AggregateElement.ElementSelectorName];

var sourceType = element.Source.ValueType;
var keyType = keySelector.Expression.ReturnType;
var elementType = elementSelector.Expression.ReturnType;
Type aggregatorType = typeof(GroupByAggregator<,,>).MakeGenericType(sourceType, keyType, elementType);

var compiledKeySelector = compiledExpressions.FindSingle("KeySelector");
var compiledElementSelector = compiledExpressions.FindSingle("ElementSelector");
var compiledKeySelector = compiledExpressions.FindSingle(AggregateElement.KeySelectorName);
var compiledElementSelector = compiledExpressions.FindSingle(AggregateElement.ElementSelectorName);
var ctor = aggregatorType.GetTypeInfo().DeclaredConstructors.Single();
var factoryExpression = Expression.Lambda<Func<IAggregator>>(
Expression.New(ctor, Expression.Constant(compiledKeySelector), Expression.Constant(compiledElementSelector)));
@@ -0,0 +1,76 @@
using System.Collections.Generic;
using System.Linq;
using NRules.RuleModel;
using NRules.Utilities;

namespace NRules.Aggregators
{
internal class SortCriteria
{
public SortCriteria(IAggregateExpression expression, SortDirection direction)
{
KeySelector = expression;
Direction = direction;
}

public IAggregateExpression KeySelector { get; }

public SortDirection Direction { get; }
}

/// <summary>
/// Aggregate that adds matching facts into a collection sorted by a given key selector and sort direction.
/// </summary>
/// <typeparam name="TSource">Type of elements to collect.</typeparam>
internal class MultiKeySortedAggregator<TSource> : SortedAggregatorBase<TSource, object[]>
{
private readonly SortCriteria[] _sortCriterias;

public MultiKeySortedAggregator(IEnumerable<SortCriteria> sortCriterias)
: base(GetComparer(sortCriterias))
{
_sortCriterias = sortCriterias.ToArray();
}

private static IComparer<object[]> GetComparer(IEnumerable<SortCriteria> sortCriterias)
{
var comparers = new List<IComparer<object>>();
foreach (var sortCriteria in sortCriterias)
{
var defaultComparer = (IComparer<object>)Comparer<object>.Default;
var comparer = sortCriteria.Direction == SortDirection.Ascending ? defaultComparer : new ReverseComparer<object>(defaultComparer);
comparers.Add(comparer);
}

return new MultiKeyComparer(comparers);
}

protected override object[] GetKey(AggregationContext context, ITuple tuple, IFact fact)
{
return _sortCriterias.Select(x => x.KeySelector.Invoke(context, tuple, fact)).ToArray();
}
}

internal class MultiKeyComparer : IComparer<object[]>
{
readonly IComparer<object>[] _comparers;

public MultiKeyComparer(IEnumerable<IComparer<object>> comparers)
{
_comparers = comparers.ToArray();
}

public int Compare(object[] x, object[] y)
{
var result = 0;

for (int i = 0; i < _comparers.Length; i++)
{
result = _comparers[i].Compare(x[i], y[i]);
if (result != 0) break;
}

return result;
}
}
}
Oops, something went wrong.

0 comments on commit 1cd8d2e

Please sign in to comment.
You can’t perform that action at this time.