From 833a2340589a018cc5160259a5b46661cff1ba26 Mon Sep 17 00:00:00 2001 From: xuejiaming <326308290@qq.com> Date: Wed, 16 Nov 2022 21:41:40 +0800 Subject: [PATCH] [#217] --- .../Controllers/WeatherForecastController.cs | 27 ++++++++++--------- .../QueryCompilerExecutor.cs | 2 ++ .../DbContextReplaceQueryableVisitor2_6.cs | 19 ++++++++++--- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/samples/Sample.MySql/Controllers/WeatherForecastController.cs b/samples/Sample.MySql/Controllers/WeatherForecastController.cs index c7643e90..5cb5e449 100644 --- a/samples/Sample.MySql/Controllers/WeatherForecastController.cs +++ b/samples/Sample.MySql/Controllers/WeatherForecastController.cs @@ -171,19 +171,20 @@ namespace Sample.MySql.Controllers [HttpGet] public async Task Get2() { - var sql= from a in _defaultTableDbContext.Set() - join b in _defaultTableDbContext.Set() - on a.Id equals b.Id into t1 - from aa1 in t1.DefaultIfEmpty() - // join bc in _defaultTableDbContext.Set() - // on a.Id equals bc.Id into t2 - // from aa2 in t2.DefaultIfEmpty() - select new - { - ID = a.Id - }; - var listAsync =await sql.ToListAsync(); - // var sysUserMods = await _defaultTableDbContext.Set().FromSqlRaw("select * from SysUserMod where id='2'").ToListAsync(); + // var sql= from a in _defaultTableDbContext.Set() + // join b in _defaultTableDbContext.Set() + // on a.Id equals b.Id into t1 + // from aa1 in t1.DefaultIfEmpty() + // // join bc in _defaultTableDbContext.Set() + // // on a.Id equals bc.Id into t2 + // // from aa2 in t2.DefaultIfEmpty() + // select new + // { + // ID = a.Id + // }; + // var listAsync =await sql.ToListAsync(); + var sysUserMods1 = await _defaultTableDbContext.Set().FromSqlRaw("select * from SysUserMod where id='2'").ToListAsync(); + var sysUserMods2 = await _defaultTableDbContext.Set().FromSqlRaw("select * from SysTest where id='2'").ToListAsync(); return Ok(); } } diff --git a/src/ShardingCore/Sharding/ShardingExecutors/QueryCompilerExecutor.cs b/src/ShardingCore/Sharding/ShardingExecutors/QueryCompilerExecutor.cs index 1ccfa1fa..75137c65 100644 --- a/src/ShardingCore/Sharding/ShardingExecutors/QueryCompilerExecutor.cs +++ b/src/ShardingCore/Sharding/ShardingExecutors/QueryCompilerExecutor.cs @@ -15,10 +15,12 @@ namespace ShardingCore.Sharding.ShardingExecutors { private readonly IQueryCompiler _queryCompiler; private readonly Expression _queryExpression; + private readonly Expression _originalQueryExpression; public QueryCompilerExecutor(DbContext dbContext,Expression queryExpression) { _queryCompiler = dbContext.GetService(); + _originalQueryExpression = queryExpression; _queryExpression = queryExpression.ReplaceDbContextExpression(dbContext); } diff --git a/src/ShardingCore/Sharding/Visitors/DbContextReplaceQueryableVisitor2_6.cs b/src/ShardingCore/Sharding/Visitors/DbContextReplaceQueryableVisitor2_6.cs index c5cde54b..9b13a3c2 100644 --- a/src/ShardingCore/Sharding/Visitors/DbContextReplaceQueryableVisitor2_6.cs +++ b/src/ShardingCore/Sharding/Visitors/DbContextReplaceQueryableVisitor2_6.cs @@ -6,6 +6,7 @@ using System.Reflection; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; using ShardingCore.Core.Internal.Visitors; using ShardingCore.Exceptions; using ShardingCore.Extensions; @@ -198,11 +199,21 @@ namespace ShardingCore.Core.Internal.Visitors var newQueryable = targetIQ.Provider.CreateQuery(targetIQ.Expression); if (Source == null) Source = newQueryable; - //如何替换ef5的set - var replaceQueryRoot = new ReplaceSingleQueryRootExpressionVisitor(); - replaceQueryRoot.Visit(newQueryable.Expression); RootIsVisit = true; - return base.VisitExtension(replaceQueryRoot.QueryRootExpression); + if (queryRootExpression is FromSqlQueryRootExpression fromSqlQueryRootExpression) + { + var sqlQueryRootExpression = new FromSqlQueryRootExpression(newQueryable.Provider as IAsyncQueryProvider, + queryRootExpression.EntityType, fromSqlQueryRootExpression.Sql, + fromSqlQueryRootExpression.Argument); + + return base.VisitExtension(sqlQueryRootExpression); + } + else + { + var replaceQueryRoot = new ReplaceSingleQueryRootExpressionVisitor(); + replaceQueryRoot.Visit(newQueryable.Expression); + return base.VisitExtension(replaceQueryRoot.QueryRootExpression); + } } return base.VisitExtension(node);