修复sqlserver依赖注入bug

This commit is contained in:
xuejmnet 2021-02-01 08:16:28 +08:00
parent 071f0356d6
commit a7b20599de
16 changed files with 239 additions and 61 deletions

View File

@ -38,14 +38,7 @@ namespace ShardingCore.MySql
services.AddSingleton(options);
services.AddShardingCore();
services.AddScoped<IVirtualDbContext, VirtualDbContext>();
services.AddScoped<IDbContextOptionsProvider, SqlServerDbContextOptionsProvider>();
services.AddSingleton<IShardingDbContextFactory, ShardingDbContextFactory>();
services.AddSingleton<IShardingTableCreator, ShardingTableCreator>();
services.AddSingleton<IVirtualTableManager, OneDbVirtualTableManager>();
services.AddSingleton(typeof(IVirtualTable<>), typeof(OneDbVirtualTable<>));
services.AddSingleton<IShardingAccessor, ShardingAccessor>();
services.AddSingleton<IShardingScopeFactory, ShardingScopeFactory>();
services.AddScoped<IDbContextOptionsProvider, MySqlDbContextOptionsProvider>();
services.AddSingleton<IShardingParallelDbContextFactory, ShardingMySqlParallelDbContextFactory>();
if (options.HasSharding)
{

View File

@ -20,12 +20,12 @@ namespace ShardingCore.MySql
* @Date: Thursday, 24 December 2020 10:33:51
* @Email: 326308290@qq.com
*/
public class SqlServerDbContextOptionsProvider:IDbContextOptionsProvider
public class MySqlDbContextOptionsProvider:IDbContextOptionsProvider
{
private DbContextOptions _dbContextOptions;
private MySqlConnection _connection;
public SqlServerDbContextOptionsProvider(MySqlOptions mySqlOptions,ILoggerFactory loggerFactory)
public MySqlDbContextOptionsProvider(MySqlOptions mySqlOptions,ILoggerFactory loggerFactory)
{
_connection=new MySqlConnection(mySqlOptions.ConnectionString);
_dbContextOptions = new DbContextOptionsBuilder()

View File

@ -34,15 +34,9 @@ namespace ShardingCore.SqlServer
var options = new SqlServerOptions();
configure(options);
services.AddSingleton(options);
services.AddShardingCore();
services.AddScoped<IVirtualDbContext, VirtualDbContext>();
services.AddScoped<IDbContextOptionsProvider, SqlServerDbContextOptionsProvider>();
services.AddSingleton<IShardingDbContextFactory, ShardingDbContextFactory>();
services.AddSingleton<IShardingTableCreator, ShardingTableCreator>();
services.AddSingleton<IVirtualTableManager, OneDbVirtualTableManager>();
services.AddSingleton(typeof(IVirtualTable<>), typeof(OneDbVirtualTable<>));
services.AddSingleton<IShardingAccessor, ShardingAccessor>();
services.AddSingleton<IShardingScopeFactory, ShardingScopeFactory>();
services.AddSingleton<IShardingParallelDbContextFactory, ShardingSqlServerParallelDbContextFactory>();
if (options.HasSharding)
{

View File

@ -111,6 +111,17 @@ namespace ShardingCore.Core
/// </summary>
/// <returns></returns>
Task<float> FloatSumAsync();
/// <summary>
/// 平均数
/// </summary>
/// <returns></returns>
Task<double> AverageAsync();
/// <summary>
/// 平均数
/// </summary>
/// <returns></returns>
Task<double> LongAverageAsync();
/// <summary>
/// 平均数
/// </summary>

View File

@ -15,5 +15,6 @@ namespace ShardingCore.Core.Internal.RoutingRuleEngines
IRouteRuleEngine CreateEngine();
RouteRuleContext<T> CreateContext<T>(IQueryable<T> queryable);
IEnumerable<RouteResult> Route<T>(IQueryable<T> queryable);
IEnumerable<RouteResult> Route<T>(IQueryable<T> queryable,RouteRuleContext<T> ruleContext);
}
}

View File

@ -35,8 +35,14 @@ namespace ShardingCore.Core.Internal.RoutingRuleEngines
public IEnumerable<RouteResult> Route<T>(IQueryable<T> queryable)
{
var engine = CreateEngine();
var routeRuleContext = CreateContext<T>(queryable);
return engine.Route(routeRuleContext);
var ruleContext = CreateContext<T>(queryable);
return engine.Route(ruleContext);
}
public IEnumerable<RouteResult> Route<T>(IQueryable<T> queryable, RouteRuleContext<T> ruleContext)
{
var engine = CreateEngine();
return engine.Route(ruleContext);
}
}
}

View File

@ -1,10 +1,8 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using ShardingCore.Core.Internal.StreamMerge.Abstractions;
using ShardingCore.Core.Internal.StreamMerge.Enumerators;
namespace ShardingCore.Core.Internal.StreamMerge.GenericMerges
namespace ShardingCore.Core.Internal.StreamMerge.GenericMerges.Proxies
{
/*
* @Author: xjm
@ -14,13 +12,11 @@ namespace ShardingCore.Core.Internal.StreamMerge.GenericMerges
*/
internal class GenericStreamMergeProxyEngine<T> : IDisposable
{
private readonly StreamMergeContext<T> _mergeContext;
private IStreamMergeEngine<T> _streamMergeEngine;
private GenericStreamMergeProxyEngine(StreamMergeContext<T> mergeContext)
{
_mergeContext = mergeContext;
_streamMergeEngine = GenericStreamMergeEngine<T>.Create(mergeContext);
}

View File

@ -15,5 +15,6 @@ namespace ShardingCore.Core.Internal.StreamMerge
{
StreamMergeContext<T> Create<T>(IQueryable<T> queryable, IEnumerable<RouteResult> routeResults);
StreamMergeContext<T> Create<T>(IQueryable<T> queryable);
StreamMergeContext<T> Create<T>(IQueryable<T> queryable, RouteRuleContext<T> ruleContext);
}
}

View File

@ -33,5 +33,9 @@ namespace ShardingCore.Core.Internal.StreamMerge
{
return new StreamMergeContext<T>(queryable, _routingRuleEngineFactory.Route(queryable), _shardingParallelDbContextFactory, _shardingScopeFactory);
}
public StreamMergeContext<T> Create<T>(IQueryable<T> queryable,RouteRuleContext<T> ruleContext)
{
return new StreamMergeContext<T>(queryable, _routingRuleEngineFactory.Route(queryable,ruleContext), _shardingParallelDbContextFactory, _shardingScopeFactory);
}
}
}

View File

@ -8,6 +8,7 @@ using Microsoft.Extensions.DependencyInjection;
using ShardingCore.Core.Internal.RoutingRuleEngines;
using ShardingCore.Core.Internal.StreamMerge;
using ShardingCore.Core.Internal.StreamMerge.GenericMerges;
using ShardingCore.Core.Internal.StreamMerge.GenericMerges.Proxies;
using ShardingCore.Core.VirtualTables;
using ShardingCore.Extensions;
#if EFCORE2
@ -30,18 +31,16 @@ namespace ShardingCore.Core
public class ShardingQueryable<T> : IShardingQueryable<T>
{
private IQueryable<T> _source;
private bool _autoParseRoute = true;
private readonly IStreamMergeContextFactory _streamMergeContextFactory;
private Dictionary<Type, Expression> _routes = new Dictionary<Type, Expression>();
private readonly Dictionary<IVirtualTable, List<string>> _endRoutes = new Dictionary<IVirtualTable, List<string>>();
private readonly IRoutingRuleEngineFactory _routingRuleEngineFactory;
private readonly RouteRuleContext<T> _routeRuleContext;
private ShardingQueryable(IQueryable<T> source)
{
_source = source;
_streamMergeContextFactory = ShardingContainer.Services.GetService<IStreamMergeContextFactory>();
_routingRuleEngineFactory=ShardingContainer.Services.GetService<IRoutingRuleEngineFactory>();
var routingRuleEngineFactory=ShardingContainer.Services.GetService<IRoutingRuleEngineFactory>();
_routeRuleContext = routingRuleEngineFactory.CreateContext<T>(source);
}
public static ShardingQueryable<TSource> Create<TSource>(IQueryable<TSource> source)
@ -52,48 +51,25 @@ namespace ShardingCore.Core
public IShardingQueryable<T> EnableAutoRouteParse()
{
_autoParseRoute = true;
_routeRuleContext.EnableAutoRouteParse();
return this;
}
public IShardingQueryable<T> DisableAutoRouteParse()
{
_autoParseRoute = false;
_routeRuleContext.DisableAutoRouteParse();
return this;
}
public IShardingQueryable<T> AddManualRoute<TShardingEntity>(Expression<Func<TShardingEntity, bool>> predicate) where TShardingEntity : class, IShardingEntity
{
var shardingEntityType = typeof(TShardingEntity);
if (!_routes.ContainsKey(shardingEntityType))
{
((Expression<Func<TShardingEntity, bool>>) _routes[shardingEntityType]).And(predicate);
}
else
{
_routes.Add(typeof(TShardingEntity), predicate);
}
_routeRuleContext.AddRoute(predicate);
return this;
}
public IShardingQueryable<T> AddManualRoute(IVirtualTable virtualTable, string tail)
{
if (_endRoutes.ContainsKey(virtualTable))
{
var tails = _endRoutes[virtualTable];
if (!tails.Contains(tail))
{
tails.Add(tail);
}
}
else
{
_endRoutes.Add(virtualTable, new List<string>()
{
tail
});
}
_routeRuleContext.AddRoute(virtualTable,tail);
return this;
}
@ -101,7 +77,7 @@ namespace ShardingCore.Core
private StreamMergeContext<T> GetContext()
{
return _streamMergeContextFactory.Create(_source);
return _streamMergeContextFactory.Create(_source,_routeRuleContext);
}
private async Task<List<TResult>> GetGenericMergeEngine<TResult>(Func<IQueryable, Task<TResult>> efQuery)
{
@ -210,6 +186,20 @@ namespace ShardingCore.Core
return results.Sum()/results.Count();
}
public async Task<double> AverageAsync()
{
if (typeof(T) != typeof(int))
throw new InvalidOperationException($"{typeof(T)} cast to int failed");
var results = await GetGenericMergeEngine(async queryable => await EntityFrameworkQueryableExtensions.AverageAsync((IQueryable<int>) queryable));
return results.Sum()/results.Count();
}
public async Task<double> LongAverageAsync()
{
if (typeof(T) != typeof(long))
throw new InvalidOperationException($"{typeof(T)} cast to long failed");
var results = await GetGenericMergeEngine(async queryable => await EntityFrameworkQueryableExtensions.AverageAsync((IQueryable<long>) queryable));
return results.Sum()/results.Count();
}
public async Task<double> DoubleAverageAsync()
{
if (typeof(T) != typeof(double))

View File

@ -2,6 +2,11 @@ using System;
using Microsoft.Extensions.DependencyInjection;
using ShardingCore.Core.Internal.RoutingRuleEngines;
using ShardingCore.Core.Internal.StreamMerge;
using ShardingCore.Core.ShardingAccessors;
using ShardingCore.Core.VirtualTables;
using ShardingCore.DbContexts;
using ShardingCore.DbContexts.VirtualDbContexts;
using ShardingCore.TableCreator;
namespace ShardingCore
{
@ -19,6 +24,13 @@ namespace ShardingCore
services.AddScoped<IStreamMergeContextFactory, StreamMergeContextFactory>();
services.AddScoped<IRouteRuleEngine, QueryRouteRuleEngines>();
services.AddScoped<IRoutingRuleEngineFactory, RoutingRuleEngineFactory>();
services.AddScoped<IVirtualDbContext, VirtualDbContext>();
services.AddSingleton<IShardingDbContextFactory, ShardingDbContextFactory>();
services.AddSingleton<IShardingTableCreator, ShardingTableCreator>();
services.AddSingleton<IVirtualTableManager, OneDbVirtualTableManager>();
services.AddSingleton(typeof(IVirtualTable<>), typeof(OneDbVirtualTable<>));
services.AddSingleton<IShardingAccessor, ShardingAccessor>();
services.AddSingleton<IShardingScopeFactory, ShardingScopeFactory>();
return services;
}
}

View File

@ -271,6 +271,100 @@ namespace ShardingCore.Extensions
{
return await ShardingSumAsync(source.Select(keySelector));
}
public static async Task<double> ShardingAverageAsync(this IQueryable<int> source)
{
return await ShardingQueryable<double>.Create(source).AverageAsync();
}
public static double ShardingAverage(this IQueryable<int> source)
{
return ShardingQueryable<double>.Create(source).Average();
}
public static double ShardingAverage<T>(this IQueryable<T> source,Expression<Func<T,int>> keySelector)
{
return ShardingAverage(source.Select(keySelector));
}
public static async Task<double> ShardingAverageAsync<T>(this IQueryable<T> source,Expression<Func<T,int>> keySelector)
{
return await ShardingAverageAsync(source.Select(keySelector));
}
public static async Task<double> ShardingAverageAsync(this IQueryable<long> source)
{
return await ShardingQueryable<double>.Create(source).AverageAsync();
}
public static double ShardingAverage(this IQueryable<long> source)
{
return ShardingQueryable<double>.Create(source).Average();
}
public static double ShardingAverage<T>(this IQueryable<T> source,Expression<Func<T,long>> keySelector)
{
return ShardingAverage(source.Select(keySelector));
}
public static async Task<double> ShardingAverageAsync<T>(this IQueryable<T> source,Expression<Func<T,long>> keySelector)
{
return await ShardingAverageAsync(source.Select(keySelector));
}
public static async Task<double> ShardingAverageAsync(this IQueryable<double> source)
{
return await ShardingQueryable<double>.Create(source).AverageAsync();
}
public static double ShardingAverage(this IQueryable<double> source)
{
return ShardingQueryable<double>.Create(source).Average();
}
public static double ShardingAverage<T>(this IQueryable<T> source,Expression<Func<T,double>> keySelector)
{
return ShardingAverage(source.Select(keySelector));
}
public static async Task<double> ShardingAverageAsync<T>(this IQueryable<T> source,Expression<Func<T,double>> keySelector)
{
return await ShardingAverageAsync(source.Select(keySelector));
}
public static async Task<decimal> ShardingAverageAsync(this IQueryable<decimal> source)
{
return await ShardingQueryable<decimal>.Create(source).DecimalAverageAsync();
}
public static decimal ShardingAverage(this IQueryable<decimal> source)
{
return ShardingQueryable<decimal>.Create(source).DecimalAverage();
}
public static decimal ShardingAverage<T>(this IQueryable<T> source,Expression<Func<T,decimal>> keySelector)
{
return ShardingAverage(source.Select(keySelector));
}
public static async Task<decimal> ShardingAverageAsync<T>(this IQueryable<T> source,Expression<Func<T,decimal>> keySelector)
{
return await ShardingAverageAsync(source.Select(keySelector));
}
public static async Task<float> ShardingAverageAsync(this IQueryable<float> source)
{
return await ShardingQueryable<float>.Create(source).FloatAverageAsync();
}
public static float ShardingAverage(this IQueryable<float> source)
{
return ShardingQueryable<float>.Create(source).FloatAverage();
}
public static float ShardingAverage<T>(this IQueryable<T> source,Expression<Func<T,float>> keySelector)
{
return ShardingAverage(source.Select(keySelector));
}
public static async Task<float> ShardingAverageAsync<T>(this IQueryable<T> source,Expression<Func<T,float>> keySelector)
{
return await ShardingAverageAsync(source.Select(keySelector));
}
}
}

View File

@ -71,5 +71,26 @@ namespace ShardingCore.Extensions
{
return AsyncHelper.RunSync(() => queryable.MinAsync());
}
public static double Average<T>(this IShardingQueryable<T> queryable)
{
return AsyncHelper.RunSync(() => queryable.AverageAsync());
}
public static double LongAverage<T>(this IShardingQueryable<T> queryable)
{
return AsyncHelper.RunSync(() => queryable.LongAverageAsync());
}
public static double DoubleAverage<T>(this IShardingQueryable<T> queryable)
{
return AsyncHelper.RunSync(() => queryable.DoubleAverageAsync());
}
public static decimal DecimalAverage<T>(this IShardingQueryable<T> queryable)
{
return AsyncHelper.RunSync(() => queryable.DecimalAverageAsync());
}
public static float FloatAverage<T>(this IShardingQueryable<T> queryable)
{
return AsyncHelper.RunSync(() => queryable.FloatAverageAsync());
}
}
}

View File

@ -20,6 +20,9 @@
<Compile Include="..\..\src\ShardingCore\**\*.cs" />
<Compile Remove="..\..\src\ShardingCore\obj\**" />
<Compile Remove="..\..\src\ShardingCore\bin\**" />
<Compile Update="..\..\src\ShardingCore\Core\Internal\StreamMerge\StreamMergeExtension.cs">
<Link>Core\Internal\StreamMerge\StreamMergeExtension.cs</Link>
</Compile>
</ItemGroup>
<ItemGroup>

View File

@ -173,10 +173,18 @@ namespace ShardingCore.Test50.MySql
[Fact]
public async Task FirstOrDefault5()
{
var sysUserMod=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_1001").ShardingFirstOrDefaultAsync();
var sysUserMod=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_101").ShardingFirstOrDefaultAsync();
Assert.Null(sysUserMod);
var sysUserRange=await _virtualDbContext.Set<SysUserRange>().Where(o=>o.Name=="name_range_1001").ShardingFirstOrDefaultAsync();
Assert.Null(sysUserRange);
}
[Fact]
public async Task Count_Test()
{
var a=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_100").ShardingCountAsync();
Assert.Equal(1,a);
var b=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name!="name_100").ShardingCountAsync();
Assert.Equal(99,b);
}
}
}

View File

@ -35,10 +35,22 @@ namespace ShardingCore.Test50
{
var modascs=await _virtualDbContext.Set<SysUserMod>().OrderBy(o=>o.Age).ToShardingListAsync();
Assert.Equal(100,modascs.Count);
Assert.Equal(100,modascs.Last().Age);
var i = 1;
foreach (var age in modascs)
{
Assert.Equal(i,age.Age);
i++;
}
var moddescs=await _virtualDbContext.Set<SysUserMod>().OrderByDescending(o=>o.Age).ToShardingListAsync();
Assert.Equal(100,moddescs.Count);
Assert.Equal(1,moddescs.Last().Age);
var j = 100;
foreach (var age in moddescs)
{
Assert.Equal(j,age.Age);
j--;
}
}
[Fact]
public async Task ToList_Id_In_Test()
@ -51,6 +63,8 @@ namespace ShardingCore.Test50
Assert.Contains(sysUserMods, o =>o.Id==id);
Assert.Contains(sysUserRanges, o =>o.Id==id);
}
Assert.DoesNotContain(sysUserMods,o=>o.Age>4);
Assert.DoesNotContain(sysUserRanges,o=>o.Age>4);
}
[Fact]
public async Task ToList_Id_Eq_Test()
@ -73,6 +87,21 @@ namespace ShardingCore.Test50
Assert.DoesNotContain(ranges,o=>o.Id=="3");
}
[Fact]
public async Task ToList_Id_Not_Eq_Skip_Test()
{
var mods=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Id!="3").OrderBy(o=>o.Age).Skip(2).ToShardingListAsync();
Assert.Equal(97,mods.Count);
Assert.DoesNotContain(mods,o=>o.Id=="3");
Assert.Equal(4,mods[0].Age);
Assert.Equal(5,mods[1].Age);
var modsDesc=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Id!="3").OrderByDescending(o=>o.Age).Skip(13).ToShardingListAsync();
Assert.Equal(86,modsDesc.Count);
Assert.DoesNotContain(mods,o=>o.Id=="3");
Assert.Equal(87,modsDesc[0].Age);
Assert.Equal(86,modsDesc[1].Age);
}
[Fact]
public async Task ToList_Name_Eq_Test()
{
var mods=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_3").ToShardingListAsync();
@ -101,8 +130,15 @@ namespace ShardingCore.Test50
[Fact]
public async Task FirstOrDefault_Order_By_Id_Test()
{
var sysUserModAge=await _virtualDbContext.Set<SysUserMod>().OrderBy(o=>o.Age).ShardingFirstOrDefaultAsync();
Assert.True(sysUserModAge!=null&&sysUserModAge.Id=="1");
var sysUserModAgeDesc=await _virtualDbContext.Set<SysUserMod>().OrderByDescending(o=>o.Age).ShardingFirstOrDefaultAsync();
Assert.True(sysUserModAgeDesc!=null&&sysUserModAgeDesc.Id=="100");
var sysUserMod=await _virtualDbContext.Set<SysUserMod>().OrderBy(o=>o.Id).ShardingFirstOrDefaultAsync();
Assert.True(sysUserMod!=null&&sysUserMod.Id=="1");
var sysUserModDesc=await _virtualDbContext.Set<SysUserMod>().OrderByDescending(o=>o.Id).ShardingFirstOrDefaultAsync();
Assert.True(sysUserModDesc!=null&&sysUserModDesc.Id=="99");
var sysUserRange=await _virtualDbContext.Set<SysUserRange>().OrderBy(o=>o.Id).ShardingFirstOrDefaultAsync();
Assert.True(sysUserRange!=null&&sysUserRange.Id=="1");
}
@ -137,10 +173,18 @@ namespace ShardingCore.Test50
[Fact]
public async Task FirstOrDefault5()
{
var sysUserMod=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_1001").ShardingFirstOrDefaultAsync();
var sysUserMod=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_101").ShardingFirstOrDefaultAsync();
Assert.Null(sysUserMod);
var sysUserRange=await _virtualDbContext.Set<SysUserRange>().Where(o=>o.Name=="name_range_1001").ShardingFirstOrDefaultAsync();
Assert.Null(sysUserRange);
}
[Fact]
public async Task Count_Test()
{
var a=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name=="name_100").ShardingCountAsync();
Assert.Equal(1,a);
var b=await _virtualDbContext.Set<SysUserMod>().Where(o=>o.Name!="name_100").ShardingCountAsync();
Assert.Equal(99,b);
}
}
}