Replaced Expression.replace_references() with .replace_expressions().

The latter allows for more generic use cases beyond the currently
limited ones constraints validation has.

Refs #28333, #30581.
This commit is contained in:
Simon Charette 2022-08-10 07:51:07 -04:00 committed by Mariusz Felisiak
parent 8533a6af8d
commit 35911078fa
3 changed files with 17 additions and 10 deletions

View File

@ -198,6 +198,7 @@ class ExclusionConstraint(BaseConstraint):
replacement_map = instance._get_field_value_map( replacement_map = instance._get_field_value_map(
meta=model._meta, exclude=exclude meta=model._meta, exclude=exclude
) )
replacements = {F(field): value for field, value in replacement_map.items()}
lookups = [] lookups = []
for idx, (expression, operator) in enumerate(self.expressions): for idx, (expression, operator) in enumerate(self.expressions):
if isinstance(expression, str): if isinstance(expression, str):
@ -210,7 +211,7 @@ class ExclusionConstraint(BaseConstraint):
for expr in expression.flatten(): for expr in expression.flatten():
if isinstance(expr, F) and expr.name in exclude: if isinstance(expr, F) and expr.name in exclude:
return return
rhs_expression = expression.replace_references(replacement_map) rhs_expression = expression.replace_expressions(replacements)
# Remove OpClass because it only has sense during the constraint # Remove OpClass because it only has sense during the constraint
# creation. # creation.
if isinstance(expression, OpClass): if isinstance(expression, OpClass):

View File

@ -332,11 +332,14 @@ class UniqueConstraint(BaseConstraint):
return return
elif isinstance(expression, F) and expression.name in exclude: elif isinstance(expression, F) and expression.name in exclude:
return return
replacement_map = instance._get_field_value_map( replacements = {
meta=model._meta, exclude=exclude F(field): value
) for field, value in instance._get_field_value_map(
meta=model._meta, exclude=exclude
).items()
}
expressions = [ expressions = [
Exact(expr, expr.replace_references(replacement_map)) Exact(expr, expr.replace_expressions(replacements))
for expr in self.expressions for expr in self.expressions
] ]
queryset = queryset.filter(*expressions) queryset = queryset.filter(*expressions)

View File

@ -389,12 +389,15 @@ class BaseExpression:
) )
return clone return clone
def replace_references(self, references_map): def replace_expressions(self, replacements):
if replacement := replacements.get(self):
return replacement
clone = self.copy() clone = self.copy()
source_expressions = clone.get_source_expressions()
clone.set_source_expressions( clone.set_source_expressions(
[ [
expr.replace_references(references_map) expr.replace_expressions(replacements) if expr else None
for expr in self.get_source_expressions() for expr in source_expressions
] ]
) )
return clone return clone
@ -808,8 +811,8 @@ class F(Combinable):
): ):
return query.resolve_ref(self.name, allow_joins, reuse, summarize) return query.resolve_ref(self.name, allow_joins, reuse, summarize)
def replace_references(self, references_map): def replace_expressions(self, replacements):
return references_map.get(self.name, self) return replacements.get(self, self)
def asc(self, **kwargs): def asc(self, **kwargs):
return OrderBy(self, **kwargs) return OrderBy(self, **kwargs)