优化average [#138]

This commit is contained in:
xuejiaming 2022-04-29 11:07:32 +08:00
parent 062f823bbf
commit 83145e8bfd
4 changed files with 89 additions and 74 deletions

View File

@ -21,7 +21,7 @@ namespace ShardingCore.Sharding.StreamMergeEngines.AggregateMergeEngines
* @Ver: 1.0
* @Email: 326308290@qq.com
*/
internal class AverageAsyncInMemoryMergeEngine<TEntity, TResult,TSelect> : AbstractNoTripEnsureMethodCallInMemoryAsyncMergeEngine<TEntity, TResult>
internal class AverageAsyncInMemoryMergeEngine<TEntity, TResult, TSelect> : AbstractNoTripEnsureMethodCallInMemoryAsyncMergeEngine<TEntity, TResult>
{
public AverageAsyncInMemoryMergeEngine(StreamMergeContext<TEntity> streamMergeContext) : base(streamMergeContext)
{
@ -29,89 +29,31 @@ namespace ShardingCore.Sharding.StreamMergeEngines.AggregateMergeEngines
private async Task<List<RouteQueryResult<AverageResult<T>>>> AggregateAverageResultAsync<T>(CancellationToken cancellationToken = new CancellationToken())
{
return (await base.ExecuteAsync(
async queryable =>
{
var count = await ((IQueryable<T>)queryable).LongCountAsync(cancellationToken);
var count = 0L;
T sum = default;
var newQueryable = ((IQueryable<T>)queryable);
var r = await newQueryable.GroupBy(o => 1).BuildExpression().FirstOrDefaultAsync(cancellationToken);
if (r != null)
{
count = r.Item1;
sum = r.Item2;
}
if (count <= 0)
{
return default;
}
var sum = await GetSumAsync<T>(queryable, cancellationToken);
return new AverageResult<T>(sum, count);
},
cancellationToken)).Where(o => o.QueryResult != null).ToList();
// return (await base.ExecuteAsync(
// async queryable =>
// {
// var count = 0L;
// T sum = default;
// //MethodInfo sumMethod = typeof(Queryable).GetMethods().First(
// // m => m.Name == nameof(Queryable.Sum)
// // && m.ReturnType == typeof(T)
// // && m.IsGenericMethod);
// //var genericSumMethod = sumMethod.MakeGenericMethod(new[] { source.ElementType });
// var newQueryable = ((IQueryable<T>)queryable).Select(o=>(decimal?)(object)o);
//#if !EFCORE2
// var r = await newQueryable.GroupBy(o=>1).Select(o=>new
// {
// C= o.LongCount(),
// //S = ShardingEntityFrameworkQueryableExtensions.Execute<T,T>(ShardingQueryableMethods.GetSumWithoutSelector(typeof(T)), newQueryable, (Expression)null)
// S = o.Sum()
// }).FirstOrDefaultAsync(cancellationToken);
// ////https://stackoverflow.com/questions/21143179/build-groupby-expression-tree-with-multiple-fields
// ////https://blog.wiseowls.co.nz/index.php/2021/05/13/ef-core-3-1-dynamic-groupby-clause/
// ////https://blog.wiseowls.co.nz/index.php/2021/05/13/ef-core-3-1-dynamic-groupby-clause/
// ////https://stackoverflow.com/questions/39728898/groupby-query-by-linq-expressions-and-lambdas
// // Expression.New(
// // Type.GetType("System.Tuple`" + fields.Length)
// // .MakeGenericType(fields.Select(studentType.GetProperty),
// // fields.Select(f => Expression.PropertyOrField(itemParam, f))
// // )
// // if (r != null)
// // {
// // count = r.C;
// // //sum = r.S;
// // }
//#endif
//#if EFCORE2
// count = await ((IQueryable<T>)queryable).LongCountAsync(cancellationToken);
// if (count <= 0)
// {
// return default;
// }
// sum = await GetSumAsync<T>(queryable, cancellationToken);
//#endif
// return new AverageResult<T>(sum, count);
// },
// cancellationToken)).Where(o => o.QueryResult != null).ToList();
}
private async Task<T> GetSumAsync<T>(IQueryable queryable,
CancellationToken cancellationToken = new CancellationToken())
{
var resultType = typeof(T);
if (!resultType.IsNumericType())
throw new ShardingCoreException(
$"not support {GetStreamMergeContext().MergeQueryCompilerContext.GetQueryExpression().ShardingPrint()} result {resultType}");
#if !EFCORE2
return await ShardingEntityFrameworkQueryableExtensions.ExecuteAsync<T, Task<T>>(ShardingQueryableMethods.GetSumWithoutSelector(resultType), (IQueryable<T>)queryable, (Expression)null, cancellationToken);
#endif
#if EFCORE2
return await ShardingEntityFrameworkQueryableExtensions.ExecuteAsync<T, T>(ShardingQueryableMethods.GetSumWithoutSelector(resultType), (IQueryable<T>)queryable, cancellationToken);
#endif
}
public override async Task<TResult> MergeResultAsync(
CancellationToken cancellationToken = new CancellationToken())
{
@ -133,7 +75,7 @@ namespace ShardingCore.Sharding.StreamMergeEngines.AggregateMergeEngines
var count = queryable.Sum(o => o.Count);
return AggregateExtension.AverageConstant<TSelect, long, TResult>(sum, count);
}
}
}

View File

@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Query;
namespace ShardingCore.Sharding.MergeEngines.AggregateMergeEngines
{
/// <summary>
/// https://github.com/tkhadimullin/dynamic-ef-examples
/// </summary>
/// Author: xjm
/// Created: 2022/4/29 9:32:16
/// Email: 326308290@qq.com
public static class AverageMergeEngineExtension
{
public static IQueryable<Tuple<long, T>> BuildExpression<T>(this IQueryable<IGrouping<int, T>> queryable)
{
var sourceParameter = Expression.Parameter(typeof(IQueryable<IGrouping<int, T>>));
var selectCall = BuildSelect<T>(sourceParameter);
var lambda = Expression.Lambda<Func<IQueryable<IGrouping<int, T>>, IQueryable<Tuple<long, T>>>>(selectCall, sourceParameter);
var compile = lambda.Compile();
return compile(queryable);
}
private static MethodCallExpression BuildSelect<T>(this ParameterExpression sourceParameter)
{
var groupingType = typeof(IGrouping<int, T>);
var selectMethod = ShardingQueryableMethods.Select.MakeGenericMethod(groupingType, typeof(Tuple<long, T>));
var resultParameter = Expression.Parameter(groupingType);
var longCountCall = BuildLongCount<T>(resultParameter);
var sumCall = BuildSum<T>(resultParameter);
var resultSelector = Expression.New(typeof(Tuple<long, T>).GetConstructors().First(), longCountCall, sumCall);
//queryable.Expression,
return Expression.Call(selectMethod, sourceParameter,Expression.Lambda(resultSelector, resultParameter));
}
private static MethodCallExpression BuildLongCount<T>(ParameterExpression resultParameter)
{
var asQueryableMethod =ShardingQueryableMethods.AsQueryable.MakeGenericMethod(typeof(T));
var longCountMethod = ShardingQueryableMethods.LongCountWithoutPredicate.MakeGenericMethod(typeof(T));
return Expression.Call(longCountMethod, Expression.Call(asQueryableMethod, resultParameter));
}
private static MethodCallExpression BuildSum<T>(ParameterExpression resultParameter)
{
var asQueryableMethod =ShardingQueryableMethods.AsQueryable.MakeGenericMethod(typeof(T));
var sumMethod = ShardingQueryableMethods.GetSumWithoutSelector(typeof(T));
return Expression.Call(sumMethod, Expression.Call(asQueryableMethod, resultParameter));
}
}
}

View File

@ -26,6 +26,9 @@ namespace ShardingCore.Sharding
private static Dictionary<Type, MethodInfo> SumWithoutSelectorMethods { get; }
private static Dictionary<Type, MethodInfo> SumWithSelectorMethods { get; }
public static MethodInfo AsQueryable { get; }
public static MethodInfo LongCountWithoutPredicate { get; }
public static MethodInfo Select { get; }
static ShardingQueryableMethods()
{
@ -47,6 +50,20 @@ namespace ShardingCore.Sharding
ShardingQueryableMethods.AverageWithSelectorMethods = new Dictionary<Type, MethodInfo>();
ShardingQueryableMethods.SumWithoutSelectorMethods = new Dictionary<Type, MethodInfo>();
ShardingQueryableMethods.SumWithSelectorMethods = new Dictionary<Type, MethodInfo>();
ShardingQueryableMethods.AsQueryable = GetMethod(nameof(AsQueryable), 1, (Func<Type[], Type[]>)(types => new Type[1]
{
typeof (IEnumerable<>).MakeGenericType(types[0])
}));
ShardingQueryableMethods.LongCountWithoutPredicate = GetMethod("LongCount", 1, (Func<Type[], Type[]>)(types => new Type[1]
{
typeof (IQueryable<>).MakeGenericType(types[0])
}));
ShardingQueryableMethods.Select = GetMethod(nameof(Select), 2, (Func<Type[], Type[]>)(types => new Type[2]
{
typeof (IQueryable<>).MakeGenericType(types[0]),
typeof (Expression<>).MakeGenericType(typeof (Func<,>).MakeGenericType(types[0], types[1]))
}));
foreach (Type type1 in typeArray)
{
Type type = type1;
@ -68,6 +85,7 @@ namespace ShardingCore.Sharding
typeof (IQueryable<>).MakeGenericType(types[0]),
typeof (Expression<>).MakeGenericType(typeof (Func<,>).MakeGenericType(types[0], type))
}));
}
MethodInfo GetMethod(
string name,

View File

@ -979,7 +979,7 @@ namespace ShardingCore.Test
var fourBegin = new DateTime(2021, 4, 1).Date;
var fiveBegin = new DateTime(2021, 5, 1).Date;
var moneyAverage = await _virtualDbContext.Set<Order>()
.Where(o => o.CreateTime >= fourBegin && o.CreateTime <= fiveBegin).Select(o => o.Money).AverageAsync();
.Where(o => o.CreateTime >= fourBegin&& o.CreateTime <= fiveBegin).Select(o => o.Money).AverageAsync();
Assert.Equal(105, moneyAverage);
using (_shardingRouteManager.CreateScope())