修复优化当表达式内嵌使用属性的情况下出现:Cannot use multiple context instances within a single query execution. Ensure the query use a single context instance.的错误

This commit is contained in:
xuejiaming 2022-09-29 11:28:26 +08:00
parent 6e4afa7484
commit ecf6419ecc
3 changed files with 39 additions and 36 deletions

View File

@ -22,6 +22,23 @@ namespace Sample.MySql.Controllers
public string name { get; set; }
public int count { get; set; }
}
public class ABC
{
private readonly DefaultShardingDbContext _defaultTableDbContext;
public ABC(DefaultShardingDbContext defaultTableDbContext)
{
_defaultTableDbContext = defaultTableDbContext;
}
public IQueryable<SysTest> GetAll()
{
return _defaultTableDbContext.Set<SysTest>();
}
public virtual IQueryable<SysTest> Select => this.GetAll();
}
[ApiController]
[Route("[controller]/[action]")]
public class WeatherForecastController : ControllerBase
@ -29,11 +46,13 @@ namespace Sample.MySql.Controllers
private readonly DefaultShardingDbContext _defaultTableDbContext;
private readonly IShardingRuntimeContext _shardingRuntimeContext;
private readonly ABC _abc;
public WeatherForecastController(DefaultShardingDbContext defaultTableDbContext,IShardingRuntimeContext shardingRuntimeContext)
{
_defaultTableDbContext = defaultTableDbContext;
_shardingRuntimeContext = shardingRuntimeContext;
_abc=new ABC(_defaultTableDbContext);
}
public IQueryable<SysTest> GetAll()
@ -86,7 +105,12 @@ namespace Sample.MySql.Controllers
// var firstOrDefault = _defaultTableDbContext.Set<SysUserMod>().FromSqlRaw($"select * from {nameof(SysUserMod)}").FirstOrDefault();
var sysUserMods1 = _defaultTableDbContext.Set<SysTest>()
.Select(o => new ssss(){ Id = o.Id, C = _abc.Select.Count(x => x.Id == o.Id) }).ToList();
var sysUserMods2 = _defaultTableDbContext.Set<SysTest>()
.Select(o => new ssss(){ Id = o.Id, C = GetAll().Count(x => x.Id == o.Id) }).ToList();
var sysTests = GetAll();
var sysUserMods3 = _defaultTableDbContext.Set<SysTest>()
.Select(o => new ssss(){ Id = o.Id, C = sysTests.Count(x => x.Id == o.Id) }).ToList();
var resultX = await _defaultTableDbContext.Set<SysUserMod>()
.Where(o => o.Id == "2" || o.Id == "3").FirstOrDefaultAsync();
var resultY = await _defaultTableDbContext.Set<SysUserMod>().FirstOrDefaultAsync(o => o.Id == "2" || o.Id == "3");

View File

@ -176,7 +176,7 @@ namespace ShardingCore.Extensions
{
if (memberExpression == null)
throw new ArgumentNullException(nameof(memberExpression));
return (memberExpression.Type.FullName?.StartsWith("System.Linq.IQueryable`1") ?? false) || typeof(DbContext).IsAssignableFrom(memberExpression.Type);
return (memberExpression.Type.FullName?.StartsWith("System.Linq.IQueryable`1") ?? false) ||typeof(IQueryable).IsAssignableFrom(memberExpression.Type) || typeof(DbContext).IsAssignableFrom(memberExpression.Type);
}
public static Type GetSequenceType(this Type type)

View File

@ -18,7 +18,7 @@ namespace ShardingCore.Core.Internal.Visitors
* @Email: 326308290@qq.com
*/
internal class DbContextInnerMemberReferenceReplaceQueryableVisitor : ExpressionVisitor
internal class DbContextInnerMemberReferenceReplaceQueryableVisitor : ShardingExpressionVisitor
{
private readonly DbContext _dbContext;
protected bool RootIsVisit = false;
@ -28,50 +28,29 @@ namespace ShardingCore.Core.Internal.Visitors
_dbContext = dbContext;
}
// public override Expression Visit(Expression node)
// {
// Console.WriteLine("1");
// return base.Visit(node);
// }
protected override Expression VisitMember
(MemberExpression memberExpression)
{
// Recurse down to see if we can simplify...
//if (memberExpression.IsMemberQueryable()) //2x,3x 路由 单元测试 分表和不分表
//{
var expression = Visit(memberExpression.Expression);
// If we've ended up with a constant, and it's a property or a field,
// we can simplify ourselves to a constant
if (expression is ConstantExpression constantExpression)
if (memberExpression.IsMemberQueryable()) //2x,3x 路由 单元测试 分表和不分表
{
object container = constantExpression.Value;
var member = memberExpression.Member;
if (member is FieldInfo fieldInfo)
var expressionValue = GetExpressionValue(memberExpression);
if (expressionValue is IQueryable queryable)
{
object value = fieldInfo.GetValue(container);
if (value is IQueryable queryable)
{
return ReplaceMemberExpression(queryable);
}
if (value is DbContext dbContext)
{
return ReplaceMemberExpression(dbContext);
}
//return Expression.Constant(value);
return ReplaceMemberExpression(queryable);
}
if (member is PropertyInfo propertyInfo)
if (expressionValue is DbContext dbContext)
{
object value = propertyInfo.GetValue(container, null);
if (value is IQueryable queryable)
{
return ReplaceMemberExpression(queryable);
}
if (value is DbContext dbContext)
{
return ReplaceMemberExpression(dbContext);
}
return ReplaceMemberExpression(dbContext);
}
}
//}
}
return base.VisitMember(memberExpression);
}