[#143],[#121]内嵌子句修复

This commit is contained in:
xuejiaming 2022-04-30 07:50:43 +08:00
parent df209b8c06
commit 6d68bcf1db
17 changed files with 310 additions and 80 deletions

View File

@ -1,5 +1,6 @@
using System;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore;
@ -24,20 +25,38 @@ namespace Sample.SqlServerShardingTable.Controllers
}
public async Task<IActionResult> Query()
{
var sysUser =await _myDbContext.Set<SysUser>().Where(o=>o.Id=="1").FirstOrDefaultAsync();
var dateTime = new DateTime(2021,3,5);
var order = await _myDbContext.Set<Order>().Where(o=>o.CreationTime>= dateTime).OrderBy(o=>o.CreationTime).FirstOrDefaultAsync();
var orderIdOne = await _myDbContext.Set<Order>().FirstOrDefaultAsync(o => o.Id == "3");
Console.WriteLine("123123");
var dateTime = new DateTime(2021,2,1);
var orderSet = _myDbContext.Set<Order>().Where(o=>o.CreationTime== dateTime);
var listAsync = await _myDbContext.Set<SysUser>().Where(o=> orderSet.Any(u=>u.Id== o.Id)).ToListAsync();
//Console.WriteLine("123123456");
//var orderSet1 = _myDbContext.Set<Order>().Select(o => o);
//var listAsync2 = await _myDbContext.Set<SysUser>().Where(o => orderSet1.Any(u => u.Id == o.Id)).ToListAsync();
var sysUsers = await _myDbContext.Set<SysUser>().Where(o => o.Id == "1" || o.Id=="6").ToListAsync();
//Console.WriteLine("456456");
//var dbSet1 = _myDbContext.Set<Order>();
//var dbSet2 = _myDbContext.Set<Order>();
//var queryable = (from u in dbSet1
// join x in dbSet2
// on u.Id equals x.Id
// select u
// );
//var @async = queryable.ToListAsync();
//var sysUser =await _myDbContext.Set<SysUser>().Where(o=>o.Id=="1").FirstOrDefaultAsync();
//var dateTime = new DateTime(2021,3,5);
//var order = await _myDbContext.Set<Order>().Where(o=>o.CreationTime>= dateTime).OrderBy(o=>o.CreationTime).FirstOrDefaultAsync();
//var orderIdOne = await _myDbContext.Set<Order>().FirstOrDefaultAsync(o => o.Id == "3");
return Ok(new object[]
{
sysUser,
order,
orderIdOne,
sysUsers
});
//var sysUsers = await _myDbContext.Set<SysUser>().Where(o => o.Id == "1" || o.Id=="6").ToListAsync();
//return Ok(new object[]
//{
// sysUser,
// order,
// orderIdOne,
// sysUsers
//});
return Ok();
}
public async Task<IActionResult> Query2()
{

View File

@ -5,11 +5,11 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="6.0.1" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="3.1.24" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\ShardingCore\ShardingCore.csproj" />
<ProjectReference Include="..\..\src3x\ShardingCore.3x\ShardingCore.3x.csproj" />
</ItemGroup>
</Project>

View File

@ -161,6 +161,12 @@ namespace ShardingCore.Extensions
return ShardingUtil.GetQueryEntitiesFilter(queryable, dbContextType);
}
public static bool IsMemberQueryable(this MemberExpression memberExpression)
{
if (memberExpression == null)
throw new ArgumentNullException(nameof(memberExpression));
return memberExpression.Type.FullName?.StartsWith("System.Linq.IQueryable`1") ?? false;
}
public static Type GetSequenceType(this Type type)
{

View File

@ -1,9 +1,11 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query;
using ShardingCore.Core.Internal.Visitors;
using ShardingCore.Exceptions;
using ShardingCore.Extensions;
@ -15,13 +17,85 @@ namespace ShardingCore.Core.Internal.Visitors
* @Date: Wednesday, 13 January 2021 16:32:27
* @Email: 326308290@qq.com
*/
internal class DbContextInnerMemberReferenceReplaceQueryableVisitor : ExpressionVisitor
{
private readonly DbContext _dbContext;
public DbContextInnerMemberReferenceReplaceQueryableVisitor(DbContext dbContext)
{
_dbContext = dbContext;
}
protected override Expression VisitMember
(MemberExpression memberExpression)
{
// Recurse down to see if we can simplify...
if (memberExpression.IsMemberQueryable()) //2x,3x 路由 单元测试 分表和不分表
{
var expression = Visit(memberExpression.Expression);
// If we've ended up with a constant, and it's a property or a field,
// we can simplify ourselves to a constant
if (expression is ConstantExpression constantExpression)
{
object container = constantExpression.Value;
var member = memberExpression.Member;
if (member is FieldInfo fieldInfo)
{
object value = fieldInfo.GetValue(container);
if (value is IQueryable queryable)
{
return ReplaceMemberExpression(queryable);
}
//return Expression.Constant(value);
}
if (member is PropertyInfo propertyInfo)
{
object value = propertyInfo.GetValue(container, null);
if (value is IQueryable queryable)
{
return ReplaceMemberExpression(queryable);
}
}
}
}
return base.VisitMember(memberExpression);
}
private MemberExpression ReplaceMemberExpression(IQueryable queryable)
{
var dbContextReplaceQueryableVisitor = new DbContextReplaceQueryableVisitor(_dbContext);
var newExpression = dbContextReplaceQueryableVisitor.Visit(queryable.Expression);
var newQueryable = dbContextReplaceQueryableVisitor.Source.Provider.CreateQuery(newExpression);
var tempVariableGenericType = typeof(TempVariable<>).GetGenericType0(queryable.ElementType);
var tempVariable = Activator.CreateInstance(tempVariableGenericType, newQueryable);
MemberExpression queryableMemberReplaceExpression =
Expression.Property(ConstantExpression.Constant(tempVariable), nameof(TempVariable<object>.Queryable));
return queryableMemberReplaceExpression;
}
internal sealed class TempVariable<T1>
{
public IQueryable<T1> Queryable { get; }
public TempVariable(IQueryable<T1> queryable)
{
Queryable = queryable;
}
}
}
#if EFCORE2 || EFCORE3
internal class DbContextReplaceQueryableVisitor : ExpressionVisitor
internal class DbContextReplaceQueryableVisitor : DbContextInnerMemberReferenceReplaceQueryableVisitor
{
private readonly DbContext _dbContext;
public IQueryable Source;
public DbContextReplaceQueryableVisitor(DbContext dbContext)
public DbContextReplaceQueryableVisitor(DbContext dbContext) : base(dbContext)
{
_dbContext = dbContext;
}
@ -30,13 +104,18 @@ namespace ShardingCore.Core.Internal.Visitors
{
if (node.Value is IQueryable queryable)
{
var dbContextDependencies = typeof(DbContext).GetTypePropertyValue(_dbContext, "DbContextDependencies") as IDbContextDependencies;
var targetIQ = (IQueryable)((IDbSetCache)_dbContext).GetOrAddSet(dbContextDependencies.SetSource, queryable.ElementType);
var dbContextDependencies =
typeof(DbContext).GetTypePropertyValue(_dbContext, "DbContextDependencies") as
IDbContextDependencies;
var targetIQ =
(IQueryable)((IDbSetCache)_dbContext).GetOrAddSet(dbContextDependencies.SetSource,
queryable.ElementType);
IQueryable newQueryable = null;
//if (_isParallelQuery)
// newQueryable = targetIQ.Provider.CreateQuery((Expression)Expression.Call((Expression)null, typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(EntityFrameworkQueryableExtensions.AsNoTracking)).MakeGenericMethod(queryable.ElementType), targetIQ.Expression));
//else
newQueryable = targetIQ.Provider.CreateQuery(targetIQ.Expression);
if (Source == null)
Source = newQueryable;
// return base.Visit(Expression.Constant(newQueryable));
return Expression.Constant(newQueryable);
@ -44,16 +123,18 @@ namespace ShardingCore.Core.Internal.Visitors
return base.VisitConstant(node);
}
}
#endif
#if EFCORE5 || EFCORE6
internal class DbContextReplaceQueryableVisitor : ExpressionVisitor
internal class DbContextReplaceQueryableVisitor : DbContextInnerMemberReferenceReplaceQueryableVisitor
{
private readonly DbContext _dbContext;
public IQueryable Source;
public DbContextReplaceQueryableVisitor(DbContext dbContext)
public DbContextReplaceQueryableVisitor(DbContext dbContext) : base(dbContext)
{
_dbContext = dbContext;
}
@ -62,25 +143,24 @@ namespace ShardingCore.Core.Internal.Visitors
{
if (node is QueryRootExpression queryRootExpression)
{
var dbContextDependencies = typeof(DbContext).GetTypePropertyValue(_dbContext, "DbContextDependencies") as IDbContextDependencies;
var targetIQ = (IQueryable)((IDbSetCache)_dbContext).GetOrAddSet(dbContextDependencies.SetSource, queryRootExpression.EntityType.ClrType);
//AsNoTracking
IQueryable newQueryable = null;
//if (_isParallelQuery)
// newQueryable = targetIQ.Provider.CreateQuery((Expression)Expression.Call((Expression)null, typeof(EntityFrameworkQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(EntityFrameworkQueryableExtensions.AsNoTracking)).MakeGenericMethod(queryRootExpression.EntityType.ClrType), targetIQ.Expression));
//else
newQueryable = targetIQ.Provider.CreateQuery(targetIQ.Expression);
var dbContextDependencies =
typeof(DbContext).GetTypePropertyValue(_dbContext, "DbContextDependencies") as IDbContextDependencies;
var targetIQ =
(IQueryable)((IDbSetCache)_dbContext).GetOrAddSet(dbContextDependencies.SetSource, queryRootExpression.EntityType.ClrType);
var newQueryable = targetIQ.Provider.CreateQuery(targetIQ.Expression);
if (Source == null)
Source = newQueryable;
//如何替换ef5的set
var replaceQueryRoot = new ReplaceSingleQueryRootExpressionVisitor();
replaceQueryRoot.Visit(Source.Expression);
replaceQueryRoot.Visit(newQueryable.Expression);
return base.VisitExtension(replaceQueryRoot.QueryRootExpression);
}
return base.VisitExtension(node);
}
class ReplaceSingleQueryRootExpressionVisitor : ExpressionVisitor
internal sealed class ReplaceSingleQueryRootExpressionVisitor : ExpressionVisitor
{
public QueryRootExpression QueryRootExpression { get; set; }
@ -98,37 +178,5 @@ namespace ShardingCore.Core.Internal.Visitors
}
}
#endif
// class ReplaceQueryableVisitor : ExpressionVisitor
// {
// private readonly QueryRootExpression _queryRootExpression;
// public ReplaceQueryableVisitor(IQueryable newQuery)
// {
// var visitor = new GetQueryRootVisitor();
// visitor.Visit(newQuery.Expression);
// _queryRootExpression = visitor.QueryRootExpression;
// }
//
// protected override Expression VisitExtension(Expression node)
// {
// if (node is QueryRootExpression)
// {
// return _queryRootExpression;
// }
//
// return base.VisitExtension(node);
// }
// }
// class GetQueryRootVisitor : ExpressionVisitor
// {
// public QueryRootExpression QueryRootExpression { get; set; }
// protected override Expression VisitExtension(Expression node)
// {
// if (node is QueryRootExpression expression)
// {
// QueryRootExpression = expression;
// }
//
// return base.VisitExtension(node);
// }
// }
}

View File

@ -3,10 +3,12 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query;
using ShardingCore.Core.Internal.Visitors;
using ShardingCore.Core.TrackerManagers;
using ShardingCore.Extensions;
@ -49,6 +51,39 @@ namespace ShardingCore.Sharding.Visitors.Querys
return base.VisitExtension(node);
}
#endif
protected override Expression VisitMember
(MemberExpression memberExpression)
{
// Recurse down to see if we can simplify...
var expression = Visit(memberExpression.Expression);
// If we've ended up with a constant, and it's a property or a field,
// we can simplify ourselves to a constant
if (expression is ConstantExpression)
{
object container = ((ConstantExpression)expression).Value;
var member = memberExpression.Member;
if (member is FieldInfo fieldInfo)
{
object value = fieldInfo.GetValue(container);
if (value is IQueryable queryable)
{
shardingEntities.Add(queryable.ElementType);
}
//return Expression.Constant(value);
}
if (member is PropertyInfo propertyInfo)
{
object value = propertyInfo.GetValue(container, null);
if (value is IQueryable queryable)
{
shardingEntities.Add(queryable.ElementType);
}
}
}
return base.VisitMember(memberExpression);
}
protected override Expression VisitMethodCall(MethodCallExpression node)
{
switch (node.Method.Name)

View File

@ -3,6 +3,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query;
using ShardingCore.Core.TrackerManagers;
@ -78,6 +79,38 @@ namespace ShardingCore.Core.Internal.Visitors.Querys
}
return base.VisitMethodCall(node);
}
protected override Expression VisitMember
(MemberExpression memberExpression)
{
// Recurse down to see if we can simplify...
var expression = Visit(memberExpression.Expression);
// If we've ended up with a constant, and it's a property or a field,
// we can simplify ourselves to a constant
if (expression is ConstantExpression)
{
object container = ((ConstantExpression)expression).Value;
var member = memberExpression.Member;
if (member is FieldInfo fieldInfo)
{
object value = fieldInfo.GetValue(container);
if (value is IQueryable queryable)
{
_shardingEntities.Add(queryable.ElementType);
}
//return Expression.Constant(value);
}
if (member is PropertyInfo propertyInfo)
{
object value = propertyInfo.GetValue(container, null);
if (value is IQueryable queryable)
{
_shardingEntities.Add(queryable.ElementType);
}
}
}
return base.VisitMember(memberExpression);
}
}
#endif
@ -143,6 +176,38 @@ namespace ShardingCore.Core.Internal.Visitors.Querys
}
return base.VisitMethodCall(node);
}
protected override Expression VisitMember
(MemberExpression memberExpression)
{
// Recurse down to see if we can simplify...
var expression = Visit(memberExpression.Expression);
// If we've ended up with a constant, and it's a property or a field,
// we can simplify ourselves to a constant
if (expression is ConstantExpression)
{
object container = ((ConstantExpression)expression).Value;
var member = memberExpression.Member;
if (member is FieldInfo fieldInfo)
{
object value = fieldInfo.GetValue(container);
if (value is IQueryable queryable)
{
_shardingEntities.Add(queryable.ElementType);
}
//return Expression.Constant(value);
}
if (member is PropertyInfo propertyInfo)
{
object value = propertyInfo.GetValue(container, null);
if (value is IQueryable queryable)
{
_shardingEntities.Add(queryable.ElementType);
}
}
}
return base.VisitMember(memberExpression);
}
}
#endif
// internal class ShardingEntitiesVisitor : ExpressionVisitor

View File

@ -29,8 +29,8 @@
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="6.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="6.0.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="6.0.4" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="6.0.4" />
</ItemGroup>
<ItemGroup>

View File

@ -41,8 +41,8 @@
<Compile Remove="..\..\src\ShardingCore\ShardingTableConfig.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="3.1.22" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.22" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="3.1.24" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.24" />
</ItemGroup>
</Project>

View File

@ -31,8 +31,8 @@
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="5.0.13" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="5.0.13" />
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="5.0.16" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="5.0.16" />
</ItemGroup>
</Project>

View File

@ -1692,6 +1692,13 @@ namespace ShardingCore.Test
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).SumAsync();
Assert.Equal(0, sum);
}
[Fact]
public async Task QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o=>o);
var sysUserModInts = await _virtualDbContext.Set<SysUserModInt>().Where(o=>sysUserMods.Select(i=>i.Age).Any(i=>i==o.Age)).ToListAsync();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public async Task Group_API_Test()
// {

View File

@ -1458,6 +1458,13 @@ namespace ShardingCore.Test
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).Sum();
Assert.Equal(0, sum);
}
[Fact]
public void QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToList();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public void Group_API_Test()
// {

View File

@ -1532,6 +1532,14 @@ namespace ShardingCore.Test2x
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).SumAsync();
Assert.Equal(0, sum);
}
[Fact]
public async Task QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = await _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToListAsync();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public async Task Group_API_Test()
// {

View File

@ -1455,6 +1455,13 @@ namespace ShardingCore.Test2x
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).Sum();
Assert.Equal(0, sum);
}
[Fact]
public void QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToList();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public void Group_API_Test()
// {

View File

@ -1531,6 +1531,13 @@ namespace ShardingCore.Test3x
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).SumAsync();
Assert.Equal(0, sum);
}
[Fact]
public async Task QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = await _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToListAsync();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public async Task Group_API_Test()
// {

View File

@ -1457,6 +1457,13 @@ namespace ShardingCore.Test3x
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).Sum();
Assert.Equal(0, sum);
}
[Fact]
public void QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToList();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public void Group_API_Test()
// {

View File

@ -1531,6 +1531,13 @@ namespace ShardingCore.Test5x
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).SumAsync();
Assert.Equal(0, sum);
}
[Fact]
public async Task QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = await _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToListAsync();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public async Task Group_API_Test()
// {

View File

@ -1457,6 +1457,13 @@ namespace ShardingCore.Test5x
.Where(o => o.CreateTime == fiveBegin).Select(o => o.Money).Sum();
Assert.Equal(0, sum);
}
[Fact]
public void QueryInner_Test()
{
var sysUserMods = _virtualDbContext.Set<SysUserMod>().Select(o => o);
var sysUserModInts = _virtualDbContext.Set<SysUserModInt>().Where(o => sysUserMods.Select(i => i.Age).Any(i => i == o.Age)).ToList();
Assert.Equal(1000, sysUserModInts.Count);
}
// [Fact]
// public void Group_API_Test()
// {