diff --git a/nuget-publish.bat b/nuget-publish.bat index 263fe3b9..ac176ad9 100644 --- a/nuget-publish.bat +++ b/nuget-publish.bat @@ -1,8 +1,8 @@ :start ::定义版本 -set EFCORE2=2.3.1.11 -set EFCORE3=3.3.1.11 -set EFCORE5=5.3.1.11 +set EFCORE2=2.3.1.12 +set EFCORE3=3.3.1.12 +set EFCORE5=5.3.1.12 ::删除所有bin与obj下的文件 @echo off diff --git a/src/ShardingCore/Sharding/Visitors/QueryableRouteDiscoverVisitor.cs b/src/ShardingCore/Sharding/Visitors/QueryableRouteDiscoverVisitor.cs index 68545b30..c109bd9c 100644 --- a/src/ShardingCore/Sharding/Visitors/QueryableRouteDiscoverVisitor.cs +++ b/src/ShardingCore/Sharding/Visitors/QueryableRouteDiscoverVisitor.cs @@ -2,6 +2,7 @@ using System; using System.Collections; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using ShardingCore.Core.VirtualDatabase; using ShardingCore.Core.VirtualRoutes; using ShardingCore.Core.VirtualTables; @@ -64,10 +65,16 @@ namespace ShardingCore.Core.Internal.Visitors private bool IsConstantOrMember(Expression expression) { return expression is ConstantExpression - || (expression is MemberExpression member && (member.Expression is ConstantExpression || member.Expression is MemberExpression || member.Expression is MemberExpression)); + || (expression is MemberExpression member && (member.Expression is ConstantExpression || member.Expression is MemberExpression || member.Expression is MemberExpression)) + || expression is MethodCallExpression; } - private object GetFieldValue(Expression expression) + private bool IsMethodCall(Expression expression) + { + return expression is MethodCallExpression; + } + + private object GetShardingKeyValue(Expression expression) { if (expression is ConstantExpression) return (expression as ConstantExpression).Value; @@ -81,9 +88,27 @@ namespace ShardingCore.Core.Internal.Visitors if (expression is MemberExpression member1Expression) { + var target = GetShardingKeyValue(member1Expression.Expression); + if (member1Expression.Member is FieldInfo field) + return field.GetValue(target); + if (member1Expression.Member is PropertyInfo property) + return property.GetValue(target); return Expression.Lambda(member1Expression).Compile().DynamicInvoke(); } + if (expression is MethodCallExpression methodCallExpression) + { + return Expression.Lambda(methodCallExpression).Compile().DynamicInvoke(); + //return methodCallExpression.Method.Invoke( + // GetShardingKeyValue(methodCallExpression.Object), + // methodCallExpression.Arguments + // .Select( + // a => GetShardingKeyValue(a) + // ) + // .ToArray() + //); + } + throw new ShardingKeyGetValueException("cant get value " + expression); } @@ -213,18 +238,18 @@ namespace ShardingCore.Core.Internal.Visitors //单个 else { - bool paramterAtLeft; + bool paramterAtLeft=false; object value = null; - if (IsShardingKey(binaryExpression.Left) && IsConstantOrMember(binaryExpression.Right)) + if (IsShardingKey(binaryExpression.Left)&&IsConstantOrMember(binaryExpression.Right)) { paramterAtLeft = true; - value = GetFieldValue(binaryExpression.Right); + value = GetShardingKeyValue(binaryExpression.Right); } else if (IsConstantOrMember(binaryExpression.Left) && IsShardingKey(binaryExpression.Right)) { paramterAtLeft = false; - value = GetFieldValue(binaryExpression.Left); + value = GetShardingKeyValue(binaryExpression.Left); } else return x => true; diff --git a/test/ShardingCore.Test50/ShardingTest.cs b/test/ShardingCore.Test50/ShardingTest.cs index 73015abe..c2d9ff8f 100644 --- a/test/ShardingCore.Test50/ShardingTest.cs +++ b/test/ShardingCore.Test50/ShardingTest.cs @@ -172,7 +172,7 @@ namespace ShardingCore.Test50 public async Task ToList_Id_In_Test() { var ids = new[] {"1", "2", "3", "4"}; - var sysUserMods = await _virtualDbContext.Set().Where(o => ids.Contains(o.Id)).ToListAsync(); + var sysUserMods = await _virtualDbContext.Set().Where(o => new List { "1", "2", "3", "4" }.Contains(o.Id)).ToListAsync(); foreach (var id in ids) { Assert.Contains(sysUserMods, o => o.Id == id); @@ -184,8 +184,11 @@ namespace ShardingCore.Test50 [Fact] public async Task ToList_Id_Eq_Test() { - var mods = await _virtualDbContext.Set().Where(o => o.Id == "3").ToListAsync(); + var id= 3; + var mods = await _virtualDbContext.Set().Where(o => o.Id == id.ToString()).ToListAsync(); Assert.Single(mods); + var mods1 = await _virtualDbContext.Set().Where(o => o.Id == "4").ToListAsync(); + Assert.Single(mods1); Assert.Equal("3", mods[0].Id); }