添加表达式查询的include的对象的支持

This commit is contained in:
xuejiaming 2021-12-17 11:06:46 +08:00
parent 3270e789ae
commit 8f7c851e22
12 changed files with 254 additions and 165 deletions

View File

@ -17,6 +17,7 @@ namespace ShardingCore.Core.TrackerManagers
{
bool AddDbContextModel(Type entityType);
bool EntityUseTrack(Type entityType);
bool IsDbContextModel(Type entityType);
}
public interface ITrackerManager<TShardingDbContext>: ITrackerManager where TShardingDbContext:DbContext,IShardingDbContext
{

View File

@ -34,5 +34,10 @@ namespace ShardingCore.Core.TrackerManagers
return false;
return _dbContextModels.Contains(entityType);
}
public bool IsDbContextModel(Type entityType)
{
return _dbContextModels.Contains(entityType);
}
}
}

View File

@ -20,10 +20,10 @@ namespace ShardingCore.Core.VirtualRoutes.DataSourceRoutes.RouteRuleEngine
public class DataSourceRouteRuleContext<T>
{
public ISet<Type> QueryEntities { get; }
public DataSourceRouteRuleContext(IQueryable<T> queryable)
public DataSourceRouteRuleContext(IQueryable<T> queryable,Type dbContextType)
{
Queryable = queryable;
QueryEntities = queryable.ParseQueryableEntities();
QueryEntities = queryable.ParseQueryableEntities(dbContextType);
}
/// <summary>
/// 查询条件

View File

@ -36,7 +36,7 @@ namespace ShardingCore.Core.VirtualRoutes.DataSourceRoutes.RouteRuleEngine
/// <returns></returns>
public DataSourceRouteRuleContext<T> CreateContext<T>(IQueryable<T> queryable)
{
return new DataSourceRouteRuleContext<T>(queryable);
return new DataSourceRouteRuleContext<T>(queryable,typeof(TShardingDbContext));
}
/// <summary>
/// 路由到具体的物理数据源

View File

@ -33,7 +33,7 @@ namespace ShardingCore.Core.VirtualRoutes.TableRoutes.RoutingRuleEngine
public IEnumerable<TableRouteResult> Route<T>(TableRouteRuleContext<T> tableRouteRuleContext)
{
Dictionary<IVirtualTable, ISet<IPhysicTable>> routeMaps = new Dictionary<IVirtualTable, ISet<IPhysicTable>>();
var queryEntities = tableRouteRuleContext.Queryable.ParseQueryableEntities();
var queryEntities = tableRouteRuleContext.Queryable.ParseQueryableEntities(typeof(TShardingDbContext));
foreach (var shardingEntity in queryEntities)

View File

@ -54,7 +54,7 @@ namespace ShardingCore.EFCores
private IQueryCompiler GetQueryCompilerIfNoShardingQuery(IShardingDbContext shardingDbContext, Expression query)
{
var queryEntities = ShardingUtil.GetQueryEntitiesByExpression(query);
var queryEntities = ShardingUtil.GetQueryEntitiesByExpression(query, shardingDbContext.GetType());
var entityMetadataManager = (IEntityMetadataManager)ShardingContainer.GetService(typeof(IEntityMetadataManager<>).GetGenericType0(shardingDbContext.GetType()));
if (queryEntities.All(o => !entityMetadataManager.IsSharding(o)))
{

View File

@ -100,9 +100,9 @@ namespace ShardingCore.Extensions
return nameof(object.Equals).Equals(express.Method.Name);
}
public static ISet<Type> ParseQueryableEntities(this IQueryable queryable)
public static ISet<Type> ParseQueryableEntities(this IQueryable queryable, Type dbContextType)
{
return ShardingUtil.GetQueryEntitiesFilter(queryable);
return ShardingUtil.GetQueryEntitiesFilter(queryable, dbContextType);
}

View File

@ -14,16 +14,16 @@ namespace ShardingCore.Extensions
*/
public static class GenericExtension
{
//public static Type[] GetGenericArguments(this Type type, Type genericType)
//{
// return type.GetInterfaces() //取类型的接口
// .Where(i => IsGenericType(i)) //筛选出相应泛型接口
// .SelectMany(i => i.GetGenericArguments()) //选择所有接口的泛型参数
// .ToArray(); //ToArray
public static Type[] GetGenericArguments(this Type type, Type genericType)
{
return type.GetInterfaces() //取类型的接口
.Where(i => IsGenericType(i)) //筛选出相应泛型接口
.SelectMany(i => i.GetGenericArguments()) //选择所有接口的泛型参数
.ToArray(); //ToArray
// bool IsGenericType(Type type1)
// => type1.IsGenericType && type1.GetGenericTypeDefinition() == genericType;
//}
bool IsGenericType(Type type1)
=> type1.IsGenericType && type1.GetGenericTypeDefinition() == genericType;
}
public static bool HasImplementedRawGeneric(this Type type, Type generic)
{

View File

@ -40,7 +40,7 @@ namespace ShardingCore.Extensions
{
var entityMetadataManager = ShardingContainer.GetService<IEntityMetadataManager<TShardingDbContext>>();
var queryEntities = streamMergeContext.GetOriginalQueryable().ParseQueryableEntities();
var queryEntities = streamMergeContext.GetOriginalQueryable().ParseQueryableEntities(typeof(TShardingDbContext));
//仅一个对象支持分库或者分表的组合
return queryEntities.Count(o=>(entityMetadataManager.IsShardingDataSource(o) &&!entityMetadataManager.IsShardingTable(o)) ||(entityMetadataManager.IsShardingDataSource(o)&& entityMetadataManager.IsShardingTable(o))|| (!entityMetadataManager.IsShardingDataSource(o) && entityMetadataManager.IsShardingTable(o))) ==1;
}

View File

@ -82,7 +82,7 @@ namespace ShardingCore.Sharding
IEnumerable<TableRouteResult> tableRouteResults,
IRouteTailFactory routeTailFactory)
{
QueryEntities = source.ParseQueryableEntities();
QueryEntities = source.ParseQueryableEntities(shardingDbContext.GetType());
//_shardingScopeFactory = shardingScopeFactory;
_source = source;
_shardingDbContext = shardingDbContext;

View File

@ -1,8 +1,11 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query;
using ShardingCore.Core.TrackerManagers;
using ShardingCore.Extensions;
namespace ShardingCore.Core.Internal.Visitors.Querys
@ -19,8 +22,14 @@ namespace ShardingCore.Core.Internal.Visitors.Querys
/// </summary>
internal class QueryEntitiesVisitor : ExpressionVisitor
{
private readonly ITrackerManager _trackerManager;
private readonly ISet<Type> _shardingEntities = new HashSet<Type>();
public QueryEntitiesVisitor(ITrackerManager trackerManager)
{
_trackerManager = trackerManager;
}
public ISet<Type> GetQueryEntities()
{
@ -35,17 +44,56 @@ namespace ShardingCore.Core.Internal.Visitors.Querys
return base.VisitConstant(node);
}
protected override Expression VisitMethodCall(MethodCallExpression node)
{
var methodName = node.Method.Name;
if (methodName == nameof(EntityFrameworkQueryableExtensions.Include) || methodName == nameof(EntityFrameworkQueryableExtensions.ThenInclude))
{
var genericArguments = node.Type.GetGenericArguments();
for (var i = 0; i < genericArguments.Length; i++)
{
var genericArgument = genericArguments[i];
if (typeof(IEnumerable).IsAssignableFrom(genericArgument))
{
var arguments = genericArgument.GetGenericArguments();
foreach (var argument in arguments)
{
//if is db context model
if (_trackerManager.IsDbContextModel(argument))
{
_shardingEntities.Add(argument);
}
}
}
if (!genericArgument.IsSimpleType())
{
//if is db context model
if (_trackerManager.IsDbContextModel(genericArgument))
{
_shardingEntities.Add(genericArgument);
}
}
}
}
return base.VisitMethodCall(node);
}
}
#endif
#if EFCORE5 || EFCORE6
#if EFCORE5 || EFCORE6
/// <summary>
/// 获取分表类型
/// </summary>
internal class QueryEntitiesVisitor : ExpressionVisitor
{
private readonly ITrackerManager _trackerManager;
private readonly ISet<Type> _shardingEntities = new HashSet<Type>();
public QueryEntitiesVisitor(ITrackerManager trackerManager)
{
_trackerManager = trackerManager;
}
public ISet<Type> GetQueryEntities()
{
@ -60,142 +108,177 @@ namespace ShardingCore.Core.Internal.Visitors.Querys
}
return base.VisitExtension(node);
}
protected override Expression VisitMethodCall(MethodCallExpression node)
{
var methodName = node.Method.Name;
if (methodName == nameof(EntityFrameworkQueryableExtensions.Include) || methodName == nameof(EntityFrameworkQueryableExtensions.ThenInclude))
{
var genericArguments = node.Type.GetGenericArguments();
for (var i = 0; i < genericArguments.Length; i++)
{
var genericArgument = genericArguments[i];
if (typeof(IEnumerable).IsAssignableFrom(genericArgument))
{
var arguments = genericArgument.GetGenericArguments();
foreach (var argument in arguments)
{
//if is db context model
if (_trackerManager.IsDbContextModel(argument))
{
_shardingEntities.Add(argument);
}
}
}
if (!genericArgument.IsSimpleType())
{
//if is db context model
if (_trackerManager.IsDbContextModel(genericArgument))
{
_shardingEntities.Add(genericArgument);
}
}
}
}
return base.VisitMethodCall(node);
}
}
#endif
// internal class ShardingEntitiesVisitor : ExpressionVisitor
// {
// private readonly IVirtualTableManager _virtualTableManager;
// private readonly ISet<Type> _shardingEntities = new HashSet<Type>();
//
// public ShardingEntitiesVisitor(IVirtualTableManager virtualTableManager)
// {
// _virtualTableManager = virtualTableManager;
// }
//
// public ISet<Type> GetShardingEntities()
// {
// return _shardingEntities;
// }
//
// private bool IsShardingKey(Expression expression, out Type shardingEntity)
// {
// if (expression is MemberExpression member
// && _virtualTableManager.IsShardingKey(member.Expression.Type, member.Member.Name))
// {
// shardingEntity = member.Expression.Type;
// return true;
// }
//
// shardingEntity = null;
// return false;
// }
//
// private bool IsMethodShardingKey(MethodCallExpression methodCallExpression, out Type shardingEntity)
// {
// if (methodCallExpression.Arguments.IsNotEmpty())
// {
// for (int i = 0; i < methodCallExpression.Arguments.Count; i++)
// {
// if (methodCallExpression.Arguments[i] is MemberExpression member
// && _virtualTableManager.IsShardingKey(member.Expression.Type, member.Member.Name))
// {
// shardingEntity = member.Expression.Type;
// return true;
// }
// }
// }
//
// shardingEntity = null;
// return false;
// }
//
// protected override Expression VisitMethodCall(MethodCallExpression node)
// {
// var methodName = node.Method.Name;
// switch (methodName)
// {
// case nameof(Queryable.Where):
// ResolveWhere(node);
// break;
// }
//
//
// return base.VisitMethodCall(node);
// }
//
// private void ResolveWhere(MethodCallExpression node)
// {
// if (node.Arguments[1] is UnaryExpression unaryExpression)
// {
// if (unaryExpression.Operand is LambdaExpression lambdaExpression)
// {
// Resolve(lambdaExpression);
// }
// }
// }
//
//
// private void Resolve(Expression expression)
// {
// if (expression is LambdaExpression)
// {
// LambdaExpression lambda = expression as LambdaExpression;
// expression = lambda.Body;
// Resolve(expression);
// }
//
// if (expression is BinaryExpression binaryExpression) //解析二元运算符
// {
// ParseGetWhere(binaryExpression);
// }
//
// if (expression is UnaryExpression) //解析一元运算符
// {
// UnaryExpression unary = expression as UnaryExpression;
// if (unary.Operand is MethodCallExpression methodCall1Expression)
// {
// ResolveInFunc(methodCall1Expression, unary.NodeType != ExpressionType.Not);
// }
// }
//
// if (expression is MethodCallExpression methodCallExpression) //解析扩展方法
// {
// ResolveInFunc(methodCallExpression, true);
// }
// }
//
// private void ResolveInFunc(MethodCallExpression methodCallExpression, bool @in)
// {
// if (methodCallExpression.IsEnumerableContains(methodCallExpression.Method.Name) && IsMethodShardingKey(methodCallExpression, out var shardingEntity))
// {
// _shardingEntities.Add(shardingEntity);
// }
// }
//
// private void ParseGetWhere(BinaryExpression binaryExpression)
// {
// //递归获取
// if (binaryExpression.Left is BinaryExpression)
// ParseGetWhere(binaryExpression.Left as BinaryExpression);
// if (binaryExpression.Left is MethodCallExpression methodCallExpression)
// {
// Resolve(methodCallExpression);
// }
//
// if (binaryExpression.Left is UnaryExpression unaryExpression)
// Resolve(unaryExpression);
//
// if (binaryExpression.Right is BinaryExpression)
// ParseGetWhere(binaryExpression.Right as BinaryExpression);
//
// if (IsShardingKey(binaryExpression.Left, out var shardingEntity1))
// {
// _shardingEntities.Add(shardingEntity1);
// }
// else if (IsShardingKey(binaryExpression.Right, out var shardingEntity2))
// {
// _shardingEntities.Add(shardingEntity2);
// }
// }
// }
#endif
// internal class ShardingEntitiesVisitor : ExpressionVisitor
// {
// private readonly IVirtualTableManager _virtualTableManager;
// private readonly ISet<Type> _shardingEntities = new HashSet<Type>();
//
// public ShardingEntitiesVisitor(IVirtualTableManager virtualTableManager)
// {
// _virtualTableManager = virtualTableManager;
// }
//
// public ISet<Type> GetShardingEntities()
// {
// return _shardingEntities;
// }
//
// private bool IsShardingKey(Expression expression, out Type shardingEntity)
// {
// if (expression is MemberExpression member
// && _virtualTableManager.IsShardingKey(member.Expression.Type, member.Member.Name))
// {
// shardingEntity = member.Expression.Type;
// return true;
// }
//
// shardingEntity = null;
// return false;
// }
//
// private bool IsMethodShardingKey(MethodCallExpression methodCallExpression, out Type shardingEntity)
// {
// if (methodCallExpression.Arguments.IsNotEmpty())
// {
// for (int i = 0; i < methodCallExpression.Arguments.Count; i++)
// {
// if (methodCallExpression.Arguments[i] is MemberExpression member
// && _virtualTableManager.IsShardingKey(member.Expression.Type, member.Member.Name))
// {
// shardingEntity = member.Expression.Type;
// return true;
// }
// }
// }
//
// shardingEntity = null;
// return false;
// }
//
// protected override Expression VisitMethodCall(MethodCallExpression node)
// {
// var methodName = node.Method.Name;
// switch (methodName)
// {
// case nameof(Queryable.Where):
// ResolveWhere(node);
// break;
// }
//
//
// return base.VisitMethodCall(node);
// }
//
// private void ResolveWhere(MethodCallExpression node)
// {
// if (node.Arguments[1] is UnaryExpression unaryExpression)
// {
// if (unaryExpression.Operand is LambdaExpression lambdaExpression)
// {
// Resolve(lambdaExpression);
// }
// }
// }
//
//
// private void Resolve(Expression expression)
// {
// if (expression is LambdaExpression)
// {
// LambdaExpression lambda = expression as LambdaExpression;
// expression = lambda.Body;
// Resolve(expression);
// }
//
// if (expression is BinaryExpression binaryExpression) //解析二元运算符
// {
// ParseGetWhere(binaryExpression);
// }
//
// if (expression is UnaryExpression) //解析一元运算符
// {
// UnaryExpression unary = expression as UnaryExpression;
// if (unary.Operand is MethodCallExpression methodCall1Expression)
// {
// ResolveInFunc(methodCall1Expression, unary.NodeType != ExpressionType.Not);
// }
// }
//
// if (expression is MethodCallExpression methodCallExpression) //解析扩展方法
// {
// ResolveInFunc(methodCallExpression, true);
// }
// }
//
// private void ResolveInFunc(MethodCallExpression methodCallExpression, bool @in)
// {
// if (methodCallExpression.IsEnumerableContains(methodCallExpression.Method.Name) && IsMethodShardingKey(methodCallExpression, out var shardingEntity))
// {
// _shardingEntities.Add(shardingEntity);
// }
// }
//
// private void ParseGetWhere(BinaryExpression binaryExpression)
// {
// //递归获取
// if (binaryExpression.Left is BinaryExpression)
// ParseGetWhere(binaryExpression.Left as BinaryExpression);
// if (binaryExpression.Left is MethodCallExpression methodCallExpression)
// {
// Resolve(methodCallExpression);
// }
//
// if (binaryExpression.Left is UnaryExpression unaryExpression)
// Resolve(unaryExpression);
//
// if (binaryExpression.Right is BinaryExpression)
// ParseGetWhere(binaryExpression.Right as BinaryExpression);
//
// if (IsShardingKey(binaryExpression.Left, out var shardingEntity1))
// {
// _shardingEntities.Add(shardingEntity1);
// }
// else if (IsShardingKey(binaryExpression.Right, out var shardingEntity2))
// {
// _shardingEntities.Add(shardingEntity2);
// }
// }
// }
}

View File

@ -11,6 +11,7 @@ using ShardingCore.Core.EntityMetadatas;
using ShardingCore.Core.Internal;
using ShardingCore.Core.Internal.Visitors;
using ShardingCore.Core.Internal.Visitors.Querys;
using ShardingCore.Core.TrackerManagers;
using ShardingCore.Core.VirtualDatabase;
using ShardingCore.Core.VirtualRoutes;
using ShardingCore.Extensions;
@ -65,18 +66,17 @@ namespace ShardingCore.Utils
/// 获取本次查询的所有涉及到的对象
/// </summary>
/// <param name="queryable"></param>
/// <param name="dbContextType"></param>
/// <returns></returns>
public static ISet<Type> GetQueryEntitiesFilter(IQueryable queryable)
public static ISet<Type> GetQueryEntitiesFilter(IQueryable queryable,Type dbContextType)
{
QueryEntitiesVisitor visitor = new QueryEntitiesVisitor();
visitor.Visit(queryable.Expression);
return visitor.GetQueryEntities();
return GetQueryEntitiesByExpression(queryable.Expression, dbContextType);
}
public static ISet<Type> GetQueryEntitiesByExpression(Expression expression)
public static ISet<Type> GetQueryEntitiesByExpression(Expression expression, Type dbContextType)
{
QueryEntitiesVisitor visitor = new QueryEntitiesVisitor();
var trackerManager = (ITrackerManager)ShardingContainer.GetService(typeof(ITrackerManager<>).GetGenericType0(dbContextType));
QueryEntitiesVisitor visitor = new QueryEntitiesVisitor(trackerManager);
visitor.Visit(expression);