sharding/samples/Samples.AbpSharding/AbstractShardingAbpDbContex...

565 lines
19 KiB
C#
Raw Normal View History

2021-08-21 13:20:35 +08:00
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Abp.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Storage;
using ShardingCore;
using ShardingCore.Core;
2021-08-25 19:02:44 +08:00
using ShardingCore.Core.VirtualRoutes.RouteTails.Abstractions;
2021-08-21 13:20:35 +08:00
using ShardingCore.Core.VirtualRoutes.TableRoutes;
using ShardingCore.Core.VirtualTables;
using ShardingCore.DbContexts;
using ShardingCore.DbContexts.ShardingDbContexts;
using ShardingCore.Exceptions;
2021-08-21 13:20:35 +08:00
using ShardingCore.Extensions;
using ShardingCore.Sharding.Abstractions;
namespace Samples.AbpSharding
{
2021-08-21 15:32:51 +08:00
/// <summary>
/// Base "shell" <see cref="AbpDbContext"/> for ShardingCore. Entity operations
/// (Add/Attach/Update/Remove/Entry and their range/async variants) are forwarded to tracked
/// inner <see cref="DbContext"/> instances cached per route tail. <c>SaveChanges</c> saves all
/// cached inner contexts inside a single transaction on this context's connection.
/// </summary>
/// <typeparam name="T">The concrete db context type whose options are built for the inner contexts.</typeparam>
public abstract class AbstractShardingAbpDbContext<T> : AbpDbContext, IShardingTableDbContext<T> where T : AbpDbContext, IShardingTableDbContext
{
    // Tracked inner contexts, keyed by route-tail identity (see GetDbContext with track == true).
    private readonly ConcurrentDictionary<string, DbContext> _dbContextCaches = new ConcurrentDictionary<string, DbContext>();
    private readonly IVirtualTableManager _virtualTableManager;
    private readonly IRouteTailFactory _routeTailFactory;
    private readonly IShardingDbContextFactory _shardingDbContextFactory;
    // May be null if no registered config matches ShardingDbContextType (FirstOrDefault below).
    private readonly IShardingDbContextOptionsBuilderConfig _shardingDbContextOptionsBuilderConfig;
    // Lazily built options shared by all tracked inner contexts. volatile so the double-checked
    // lock in GetShareShardingDbContextOptions publishes the fully built instance safely.
    private volatile DbContextOptions<T> _dbContextOptions;
    private readonly object _createLock = new object();

    public AbstractShardingAbpDbContext(DbContextOptions options) : base(options)
    {
        _shardingDbContextFactory = ShardingContainer.GetService<IShardingDbContextFactory>();
        _virtualTableManager = ShardingContainer.GetService<IVirtualTableManager>();
        _routeTailFactory = ShardingContainer.GetService<IRouteTailFactory>();
        // NOTE: ShardingDbContextType is abstract and is read here, before the derived constructor
        // body runs, so overrides must not depend on derived-class instance state.
        _shardingDbContextOptionsBuilderConfig = ShardingContainer
            .GetService<IEnumerable<IShardingDbContextOptionsBuilderConfig>>()
            .FirstOrDefault(o => o.ShardingDbContextType == ShardingDbContextType);
    }

    /// <summary>
    /// The sharding db context type. It is passed to the db context factory, used for virtual-table
    /// lookup, and used to select the options-builder config.
    /// </summary>
    public abstract Type ShardingDbContextType { get; }

    /// <summary>The actual db context type, i.e. <typeparamref name="T"/>.</summary>
    public Type ActualDbContextType => typeof(T);

    /// <summary>Returns the options-builder config, failing descriptively if none was registered.</summary>
    /// <exception cref="ShardingCoreException">No config matched <see cref="ShardingDbContextType"/>.</exception>
    private IShardingDbContextOptionsBuilderConfig GetOptionsBuilderConfig()
    {
        return _shardingDbContextOptionsBuilderConfig
               ?? throw new ShardingCoreException(
                   $"no {nameof(IShardingDbContextOptionsBuilderConfig)} found for sharding db context type [{ShardingDbContextType}]");
    }

    /// <summary>
    /// Builds options for tracked inner contexts. The options reuse this context's DbConnection, so
    /// the inner contexts can enlist in its transaction during SaveChanges.
    /// </summary>
    private DbContextOptions<T> CreateShareDbContextOptions()
    {
        var optionsBuilder = new DbContextOptionsBuilder<T>();
        GetOptionsBuilderConfig().UseDbContextOptionsBuilder(Database.GetDbConnection(), optionsBuilder);
        return optionsBuilder.Options;
    }

    /// <summary>Builds options for an untracked inner context from this context's connection string.</summary>
    private DbContextOptions<T> CreateMonopolyDbContextOptions()
    {
        var optionsBuilder = new DbContextOptionsBuilder<T>();
        GetOptionsBuilderConfig().UseDbContextOptionsBuilder(Database.GetConnectionString(), optionsBuilder);
        return optionsBuilder.Options;
    }

    /// <summary>Wraps the lazily created, shared-connection options with <paramref name="routeTail"/>.</summary>
    private ShardingDbContextOptions GetShareShardingDbContextOptions(IRouteTail routeTail)
    {
        if (_dbContextOptions == null)
        {
            lock (_createLock)
            {
                if (_dbContextOptions == null)
                {
                    _dbContextOptions = CreateShareDbContextOptions();
                }
            }
        }
        return new ShardingDbContextOptions(_dbContextOptions, routeTail);
    }

    /// <summary>Wraps freshly built connection-string options with <paramref name="routeTail"/>.</summary>
    private ShardingDbContextOptions GetMonopolyShardingDbContextOptions(IRouteTail routeTail)
    {
        return new ShardingDbContextOptions(CreateMonopolyDbContextOptions(), routeTail);
    }

    /// <summary>Returns an inner db context for <paramref name="routeTail"/>.</summary>
    /// <param name="track">
    /// true: return the context cached for this route tail, creating it on first use with the
    /// shared-connection options. false: create a new context with its own connection-string
    /// options. That context is not cached, so this class's Dispose does not dispose it.
    /// </param>
    /// <param name="routeTail">The route tail selecting the physical table(s).</param>
    /// <exception cref="ShardingCoreException">
    /// <paramref name="track"/> is true and the tail is a multi-entity or non-single-query tail.
    /// </exception>
    public DbContext GetDbContext(bool track, IRouteTail routeTail)
    {
        if (!track)
        {
            return _shardingDbContextFactory.Create(ShardingDbContextType, GetMonopolyShardingDbContextOptions(routeTail));
        }

        if (routeTail.IsMultiEntityQuery())
        {
            throw new ShardingCoreException("multi route not support track");
        }
        if (!(routeTail is ISingleQueryRouteTail))
        {
            throw new ShardingCoreException("multi route not support track");
        }

        // GetOrAdd returns the instance that actually ended up in the cache. A TryGetValue/TryAdd
        // pair with an ignored TryAdd result could hand back a context that was never cached.
        return _dbContextCaches.GetOrAdd(
            routeTail.GetRouteTailIdenty(),
            _ => _shardingDbContextFactory.Create(ShardingDbContextType, GetShareShardingDbContextOptions(routeTail)));
    }

    /// <summary>Whether a transaction is currently active on this (outer) context's database facade.</summary>
    public bool IsBeginTransaction => Database.CurrentTransaction != null;

    /// <summary>
    /// Resolves the tracked inner context that <paramref name="entity"/> belongs to. Sharding-table
    /// entities are routed to their physical table's tail; all other entities use the empty tail.
    /// </summary>
    public DbContext CreateGenericDbContext<TEntity>(TEntity entity) where TEntity : class
    {
        var tail = string.Empty;
        if (entity.IsShardingTable())
        {
            // NOTE(review): takes the first routed physical table; presumably a concrete entity
            // routes to exactly one table — confirm.
            var physicTable = _virtualTableManager
                .GetVirtualTable(ShardingDbContextType, entity.GetType())
                .RouteTo(new TableRouteConfig(null, entity as IShardingTable, null))[0];
            tail = physicTable.Tail;
        }
        return GetDbContext(true, _routeTailFactory.Create(tail));
    }

    /// <summary>
    /// Groups <paramref name="entities"/> by their tracked inner context. Invokes
    /// <paramref name="action"/> once per context, in order of each context's first appearance.
    /// </summary>
    private void ForEachDbContextGroup(IEnumerable<object> entities, Action<DbContext, IEnumerable<object>> action)
    {
        foreach (var group in entities.GroupBy(entity => CreateGenericDbContext(entity)))
        {
            action(group.Key, group);
        }
    }

    /// <summary>Async counterpart of <see cref="ForEachDbContextGroup"/>; groups are processed sequentially.</summary>
    private async Task ForEachDbContextGroupAsync(IEnumerable<object> entities, Func<DbContext, IEnumerable<object>, Task> action)
    {
        foreach (var group in entities.GroupBy(entity => CreateGenericDbContext(entity)))
        {
            await action(group.Key, group);
        }
    }

    public override EntityEntry Add(object entity)
    {
        return CreateGenericDbContext(entity).Add(entity);
    }

    public override EntityEntry<TEntity> Add<TEntity>(TEntity entity)
    {
        return CreateGenericDbContext(entity).Add(entity);
    }

    public override ValueTask<EntityEntry<TEntity>> AddAsync<TEntity>(TEntity entity, CancellationToken cancellationToken = default)
    {
        return CreateGenericDbContext(entity).AddAsync(entity, cancellationToken);
    }

    public override ValueTask<EntityEntry> AddAsync(object entity, CancellationToken cancellationToken = default)
    {
        return CreateGenericDbContext(entity).AddAsync(entity, cancellationToken);
    }

    public override void AddRange(params object[] entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.AddRange(items));
    }

    public override void AddRange(IEnumerable<object> entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.AddRange(items));
    }

    public override async Task AddRangeAsync(params object[] entities)
    {
        await ForEachDbContextGroupAsync(entities, (dbContext, items) => dbContext.AddRangeAsync(items));
    }

    public override async Task AddRangeAsync(IEnumerable<object> entities, CancellationToken cancellationToken = default)
    {
        // Flow the caller's token to each inner context (it was previously dropped).
        await ForEachDbContextGroupAsync(entities, (dbContext, items) => dbContext.AddRangeAsync(items, cancellationToken));
    }

    public override EntityEntry<TEntity> Attach<TEntity>(TEntity entity)
    {
        return CreateGenericDbContext(entity).Attach(entity);
    }

    public override EntityEntry Attach(object entity)
    {
        return CreateGenericDbContext(entity).Attach(entity);
    }

    public override void AttachRange(params object[] entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.AttachRange(items));
    }

    public override void AttachRange(IEnumerable<object> entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.AttachRange(items));
    }

    public override EntityEntry<TEntity> Entry<TEntity>(TEntity entity)
    {
        return CreateGenericDbContext(entity).Entry(entity);
    }

    public override EntityEntry Entry(object entity)
    {
        return CreateGenericDbContext(entity).Entry(entity);
    }

    public override EntityEntry<TEntity> Update<TEntity>(TEntity entity)
    {
        return CreateGenericDbContext(entity).Update(entity);
    }

    public override EntityEntry Update(object entity)
    {
        return CreateGenericDbContext(entity).Update(entity);
    }

    public override void UpdateRange(params object[] entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.UpdateRange(items));
    }

    public override void UpdateRange(IEnumerable<object> entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.UpdateRange(items));
    }

    public override EntityEntry<TEntity> Remove<TEntity>(TEntity entity)
    {
        return CreateGenericDbContext(entity).Remove(entity);
    }

    public override EntityEntry Remove(object entity)
    {
        return CreateGenericDbContext(entity).Remove(entity);
    }

    public override void RemoveRange(params object[] entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.RemoveRange(items));
    }

    public override void RemoveRange(IEnumerable<object> entities)
    {
        ForEachDbContextGroup(entities, (dbContext, items) => dbContext.RemoveRange(items));
    }

    /// <summary>
    /// Saves every cached inner context inside one transaction on this context's connection.
    /// If the caller already began a transaction, the inner contexts enlist in it and it is left
    /// for the caller to commit/dispose. Otherwise a transaction is begun here, committed on
    /// success, and always disposed (disposing an uncommitted transaction normally rolls it back).
    /// </summary>
    /// <param name="saveChanges">Invokes the desired SaveChanges overload on one inner context.</param>
    /// <returns>The sum of the inner contexts' results.</returns>
    private int SaveChangesCore(Func<DbContext, int> saveChanges)
    {
        // If the transaction is opened internally, it is also committed/disposed internally.
        var ownedTransaction = IsBeginTransaction ? null : Database.BeginTransaction();
        try
        {
            var dbTransaction = Database.CurrentTransaction.GetDbTransaction();
            var affected = 0;
            foreach (var dbContext in _dbContextCaches.Values)
            {
                dbContext.Database.UseTransaction(dbTransaction);
                affected += saveChanges(dbContext);
            }
            ownedTransaction?.Commit();
            return affected;
        }
        finally
        {
            if (ownedTransaction != null)
            {
                // Dispose the captured transaction: after Commit, Database.CurrentTransaction is
                // already null, so disposing via that property would never release it.
                ownedTransaction.Dispose();
                foreach (var dbContext in _dbContextCaches.Values)
                {
                    dbContext.Database.UseTransaction(null);
                }
            }
        }
    }

    /// <summary>
    /// Async counterpart of <see cref="SaveChangesCore"/> with the same transaction-ownership rules.
    /// A transaction begun by the caller is never committed or disposed here.
    /// </summary>
    private async Task<int> SaveChangesCoreAsync(Func<DbContext, Task<int>> saveChangesAsync, CancellationToken cancellationToken)
    {
        // If the transaction is opened internally, it is also committed/disposed internally.
        var ownedTransaction = IsBeginTransaction ? null : await Database.BeginTransactionAsync(cancellationToken);
        try
        {
            var dbTransaction = Database.CurrentTransaction.GetDbTransaction();
            var affected = 0;
            foreach (var dbContext in _dbContextCaches.Values)
            {
                await dbContext.Database.UseTransactionAsync(dbTransaction, cancellationToken: cancellationToken);
                affected += await saveChangesAsync(dbContext);
            }
            if (ownedTransaction != null)
            {
                await ownedTransaction.CommitAsync(cancellationToken);
            }
            return affected;
        }
        finally
        {
            if (ownedTransaction != null)
            {
                await ownedTransaction.DisposeAsync();
                foreach (var dbContext in _dbContextCaches.Values)
                {
                    // CancellationToken.None: cleanup must still run when the caller's token is cancelled.
                    await dbContext.Database.UseTransactionAsync(null, cancellationToken: CancellationToken.None);
                }
            }
        }
    }

    public override int SaveChanges()
    {
        return SaveChangesCore(dbContext => dbContext.SaveChanges());
    }

    public override int SaveChanges(bool acceptAllChangesOnSuccess)
    {
        return SaveChangesCore(dbContext => dbContext.SaveChanges(acceptAllChangesOnSuccess));
    }

    public override async Task<int> SaveChangesAsync(CancellationToken cancellationToken = default)
    {
        return await SaveChangesCoreAsync(dbContext => dbContext.SaveChangesAsync(cancellationToken), cancellationToken);
    }

    public override async Task<int> SaveChangesAsync(bool acceptAllChangesOnSuccess, CancellationToken cancellationToken = default)
    {
        return await SaveChangesCoreAsync(dbContext => dbContext.SaveChangesAsync(acceptAllChangesOnSuccess, cancellationToken), cancellationToken);
    }

    /// <summary>Disposes every cached inner context (best effort), then this context.</summary>
    public override void Dispose()
    {
        foreach (var dbContext in _dbContextCaches.Values)
        {
            try
            {
                dbContext.Dispose();
            }
            catch (Exception e)
            {
                // Best effort: one failing inner context must not stop the others or the outer context.
                Console.WriteLine(e);
            }
        }
        base.Dispose();
    }

    /// <summary>Asynchronously disposes every cached inner context (best effort), then this context.</summary>
    public override async ValueTask DisposeAsync()
    {
        foreach (var dbContext in _dbContextCaches.Values)
        {
            try
            {
                await dbContext.DisposeAsync();
            }
            catch (Exception e)
            {
                // Best effort: one failing inner context must not stop the others or the outer context.
                Console.WriteLine(e);
            }
        }
        await base.DisposeAsync();
    }
}
}