183 lines
9.2 KiB
C#
183 lines
9.2 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Text;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.EntityFrameworkCore.Query;
|
|
using Microsoft.EntityFrameworkCore.Query.Internal;
|
|
using ShardingCore.Exceptions;
|
|
using ShardingCore.Extensions;
|
|
using ShardingCore.Extensions.ShardingQueryableExtensions;
|
|
using ShardingCore.Sharding.Abstractions;
|
|
using ShardingCore.Sharding.ShardingExecutors.Abstractions;
|
|
using ShardingCore.Sharding.Visitors.Selects;
|
|
|
|
#if EFCORE2
|
|
using Microsoft.EntityFrameworkCore.Query.Sql;
|
|
#endif
|
|
|
|
namespace ShardingCore.Sharding.MergeContexts
{
    /// <summary>
    /// Produces the rewritten queryable for a sharded query: the pagination of the
    /// combined queryable is replaced by a single take window, group-by queries get
    /// an order (or have their existing order validated), and average aggregates are
    /// bound to sibling count/sum aggregates.
    /// </summary>
    public sealed class QueryableRewriteEngine : IQueryableRewriteEngine
    {
        /// <summary>
        /// Names of terminal query methods that are rejected when the query also
        /// carries a positive skip (Last/LastOrDefault/Single/SingleOrDefault).
        /// </summary>
        private static readonly ISet<string> singleEntityMethodNames = new HashSet<string>();

        /// <summary>
        /// Names of First/FirstOrDefault. Populated by the static constructor but not
        /// read anywhere in this class.
        /// </summary>
        private static readonly ISet<string> supportSingleEntityMethodNames = new HashSet<string>();

        static QueryableRewriteEngine()
        {
            supportSingleEntityMethodNames.Add(nameof(Enumerable.First));
            supportSingleEntityMethodNames.Add(nameof(Enumerable.FirstOrDefault));
            singleEntityMethodNames.Add(nameof(Enumerable.Last));
            singleEntityMethodNames.Add(nameof(Enumerable.LastOrDefault));
            singleEntityMethodNames.Add(nameof(Enumerable.Single));
            singleEntityMethodNames.Add(nameof(Enumerable.SingleOrDefault));
        }

        /// <summary>
        /// Rewrites the combined queryable of <paramref name="mergeQueryCompilerContext"/>
        /// using the pagination, order-by, group-by and select information found in
        /// <paramref name="parseResult"/>.
        /// </summary>
        /// <param name="mergeQueryCompilerContext">Compiler context that supplies the combined queryable, the query method name, the fixed take and the union-all-merge flags.</param>
        /// <param name="parseResult">Parse result that supplies the pagination, order-by, group-by and select contexts.</param>
        /// <returns>A <see cref="RewriteResult"/> holding the original combined queryable and the rewritten queryable.</returns>
        /// <exception cref="ShardingCoreInvalidOperationException">
        /// A positive skip is combined with Last/LastOrDefault/Single/SingleOrDefault on a non-enumerable query;
        /// or a group-by query's order items do not match its non-aggregate select items;
        /// or an average aggregate has neither a count nor a matching sum aggregate to bind to.
        /// </exception>
        /// <exception cref="ShardingCoreException">Union-all merge was requested but the sharding DbContext does not support it.</exception>
        /// <remarks>
        /// When a group-by query has no order items, this method appends one order item per
        /// non-aggregate select property to the order collection obtained from <paramref name="parseResult"/>.
        /// </remarks>
        public IRewriteResult GetRewriteQueryable(IMergeQueryCompilerContext mergeQueryCompilerContext, IParseResult parseResult)
        {
            var paginationContext = parseResult.GetPaginationContext();
            var orderByContext = parseResult.GetOrderByContext();
            var groupByContext = parseResult.GetGroupByContext();
            var selectContext = parseResult.GetSelectContext();
            var skip = paginationContext.Skip;
            var take = paginationContext.Take;
            var orders = orderByContext.PropertyOrders;

            if (skip.HasValue && skip.Value > 0 && !mergeQueryCompilerContext.IsEnumerableQuery())
            {
                var queryMethodName = mergeQueryCompilerContext.GetQueryMethodName();
                if (singleEntityMethodNames.Contains(queryMethodName))
                {
                    // TODO: add compatibility handling instead of rejecting the query.
                    throw new ShardingCoreInvalidOperationException(
                        $"single query:[{mergeQueryCompilerContext.GetQueryExpression().ShardingPrint()}] cant use skip:{skip.Value},u should use {nameof(Enumerable.ToList)} than use skip in {nameof(IEnumerable<object>)}");
                }
            }

            var combineQueryable = mergeQueryCompilerContext.GetQueryCombineResult().GetCombineQueryable();

            // Strip the original pagination; a single take of (take + skip) rows is re-applied below.
            // NOTE(review): presumably each data source must return the first skip + take rows so the
            // merge step can apply the real skip afterwards - confirm against the merge engine.
            var reWriteQueryable = combineQueryable;
            if (take.HasValue)
            {
                reWriteQueryable = reWriteQueryable.RemoveTake();
            }

            if (skip.HasValue)
            {
                reWriteQueryable = reWriteQueryable.RemoveSkip();
            }

            // A fixed take (the original comment mentions "first or default") takes precedence over the query's own take.
            var fixedTake = mergeQueryCompilerContext.GetFixedTake();
            if (fixedTake.HasValue)
            {
                reWriteQueryable = skip.HasValue
                    ? reWriteQueryable.ReSkip(0).ReTake(fixedTake.Value + skip.Value)
                    : reWriteQueryable.ReTake(fixedTake.Value);
            }
            else if (take.HasValue)
            {
                reWriteQueryable = skip.HasValue
                    ? reWriteQueryable.ReSkip(0).ReTake(take.Value + skip.Value)
                    : reWriteQueryable.ReTake(take.Value);
            }

            // Query contains a group by.
            if (groupByContext.GroupExpression != null)
            {
                if (orders.IsEmpty())
                {
                    // No explicit order: order by every non-aggregate select property, ascending.
                    var selectProperties = selectContext.SelectProperties.Where(o => !(o is SelectAggregateProperty)).ToArray();
                    if (selectProperties.IsNotEmpty())
                    {
                        var sort = string.Join(",", selectProperties.Select(o => $"{o.PropertyName} asc"));
                        reWriteQueryable = reWriteQueryable.OrderWithExpression(sort, null);
                        foreach (var orderProperty in selectProperties)
                        {
                            // Record the generated order on the shared order collection as well.
                            orders.AddLast(new PropertyOrder(orderProperty.PropertyName, true, orderProperty.OwnerType));
                        }
                    }
                }
                else if (!mergeQueryCompilerContext.UseUnionAllMerge())
                {
                    // The explicit order items must match the non-aggregate select items, in the same sequence.
                    // Materialized once so the query is not re-enumerated for the count and the comparison.
                    var ss = selectContext.SelectProperties
                        .Where(o => !(o is SelectAggregateProperty))
                        .Select(o => o.PropertyName)
                        .ToList();
                    var os = orders.Select(o => o.PropertyExpression).ToList();

                    if (os.Count != ss.Count)
                    {
                        throw new ShardingCoreInvalidOperationException("group by query order items not equal select un-aggregate items");
                    }

                    for (int i = 0; i < os.Count; i++)
                    {
                        if (!os[i].Equals(ss[i]))
                        {
                            throw new ShardingCoreInvalidOperationException($"group by query order items not equal select un-aggregate items: order:[{os[i]}],select:[{ss[i]}]");
                        }
                    }
                }

                if (selectContext.HasAverage())
                {
                    var averageSelectProperties = selectContext.SelectProperties.OfType<SelectAverageProperty>().ToList();
                    var selectAggregateProperties = selectContext.SelectProperties.OfType<SelectAggregateProperty>().Where(o => !(o is SelectAverageProperty)).ToList();
                    foreach (var averageSelectProperty in averageSelectProperties)
                    {
                        // Bind the first count aggregate found, regardless of which property it counts.
                        var selectCountProperty = selectAggregateProperties.FirstOrDefault(o => o is SelectCountProperty);
                        if (selectCountProperty != null)
                        {
                            averageSelectProperty.BindCountProperty(selectCountProperty.Property);
                        }

                        // Bind the sum aggregate computed over the same source property as the average.
                        var selectSumProperty = selectAggregateProperties.FirstOrDefault(o => o is SelectSumProperty sumProperty && sumProperty.FromProperty == averageSelectProperty.FromProperty);
                        if (selectSumProperty != null)
                        {
                            averageSelectProperty.BindSumProperty(selectSumProperty.Property);
                        }

                        // Only fails when neither binding exists; one of the two is accepted.
                        if (averageSelectProperty.CountProperty == null && averageSelectProperty.SumProperty == null)
                        {
                            throw new ShardingCoreInvalidOperationException(
                                $"use aggregate function average error,not found count aggregate function and not found sum aggregate function that property name same as average aggregate function property name:[{averageSelectProperty.FromProperty?.Name}]");
                        }
                    }
                }
            }

            // Short-circuit: the DbContext capability is only queried when union-all merge was requested.
            if (mergeQueryCompilerContext.UseUnionAllMerge() && !mergeQueryCompilerContext.GetShardingDbContext().SupportUnionAllMerge())
            {
                throw new ShardingCoreException(
                    $"if use {nameof(EntityFrameworkShardingQueryableExtension.UseUnionAllMerge)} plz rewrite {nameof(IQuerySqlGeneratorFactory)} with {nameof(IUnionAllMergeQuerySqlGeneratorFactory)} and {nameof(IQueryCompiler)} with {nameof(IUnionAllMergeQueryCompiler)}");
            }

            return new RewriteResult(combineQueryable, reWriteQueryable);
        }
    }
}
|