Refs #30581 -- Added Q.check() hook.
This commit is contained in:
parent
1109e66990
commit
5d91dc8ee3
|
@ -8,12 +8,16 @@ circular import difficulties.
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
|
from django.db import DEFAULT_DB_ALIAS, DatabaseError
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.utils import tree
|
from django.utils import tree
|
||||||
|
|
||||||
|
logger = logging.getLogger("django.db.models")
|
||||||
|
|
||||||
# PathInfo is used when converting lookups (fk__somecol). The contents
|
# PathInfo is used when converting lookups (fk__somecol). The contents
|
||||||
# describe the relation in Model terms (model Options and Fields for both
|
# describe the relation in Model terms (model Options and Fields for both
|
||||||
# sides of the relation. The join_field is the field backing the relation.
|
# sides of the relation. The join_field is the field backing the relation.
|
||||||
|
@ -110,6 +114,31 @@ class Q(tree.Node):
|
||||||
else:
|
else:
|
||||||
yield child
|
yield child
|
||||||
|
|
||||||
|
def check(self, against, using=DEFAULT_DB_ALIAS):
|
||||||
|
"""
|
||||||
|
Do a database query to check if the expressions of the Q instance
|
||||||
|
matches against the expressions.
|
||||||
|
"""
|
||||||
|
# Avoid circular imports.
|
||||||
|
from django.db.models import Value
|
||||||
|
from django.db.models.sql import Query
|
||||||
|
from django.db.models.sql.constants import SINGLE
|
||||||
|
|
||||||
|
query = Query(None)
|
||||||
|
for name, value in against.items():
|
||||||
|
if not hasattr(value, "resolve_expression"):
|
||||||
|
value = Value(value)
|
||||||
|
query.add_annotation(value, name, select=False)
|
||||||
|
query.add_annotation(Value(1), "_check")
|
||||||
|
# This will raise a FieldError if a field is missing in "against".
|
||||||
|
query.add_q(self)
|
||||||
|
compiler = query.get_compiler(using=using)
|
||||||
|
try:
|
||||||
|
return compiler.execute_sql(SINGLE) is not None
|
||||||
|
except DatabaseError as e:
|
||||||
|
logger.warning("Got a database error calling check() on %r: %s", self, e)
|
||||||
|
return True
|
||||||
|
|
||||||
def deconstruct(self):
|
def deconstruct(self):
|
||||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||||
if path.startswith("django.db.models.query_utils"):
|
if path.startswith("django.db.models.query_utils"):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
BooleanField,
|
BooleanField,
|
||||||
Exists,
|
Exists,
|
||||||
|
@ -10,7 +11,7 @@ from django.db.models import (
|
||||||
from django.db.models.expressions import RawSQL
|
from django.db.models.expressions import RawSQL
|
||||||
from django.db.models.functions import Lower
|
from django.db.models.functions import Lower
|
||||||
from django.db.models.sql.where import NothingNode
|
from django.db.models.sql.where import NothingNode
|
||||||
from django.test import SimpleTestCase
|
from django.test import SimpleTestCase, TestCase
|
||||||
|
|
||||||
from .models import Tag
|
from .models import Tag
|
||||||
|
|
||||||
|
@ -214,3 +215,40 @@ class QTests(SimpleTestCase):
|
||||||
)
|
)
|
||||||
flatten = list(q.flatten())
|
flatten = list(q.flatten())
|
||||||
self.assertEqual(len(flatten), 7)
|
self.assertEqual(len(flatten), 7)
|
||||||
|
|
||||||
|
|
||||||
|
class QCheckTests(TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
q = Q(price__gt=20)
|
||||||
|
self.assertIs(q.check({"price": 30}), True)
|
||||||
|
self.assertIs(q.check({"price": 10}), False)
|
||||||
|
|
||||||
|
def test_expression(self):
|
||||||
|
q = Q(name="test")
|
||||||
|
self.assertIs(q.check({"name": Lower(Value("TeSt"))}), True)
|
||||||
|
self.assertIs(q.check({"name": Value("other")}), False)
|
||||||
|
|
||||||
|
def test_missing_field(self):
|
||||||
|
q = Q(description__startswith="prefix")
|
||||||
|
msg = "Cannot resolve keyword 'description' into field."
|
||||||
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
q.check({"name": "test"})
|
||||||
|
|
||||||
|
def test_boolean_expression(self):
|
||||||
|
q = Q(ExpressionWrapper(Q(price__gt=20), output_field=BooleanField()))
|
||||||
|
self.assertIs(q.check({"price": 25}), True)
|
||||||
|
self.assertIs(q.check({"price": Value(10)}), False)
|
||||||
|
|
||||||
|
def test_rawsql(self):
|
||||||
|
"""
|
||||||
|
RawSQL expressions cause a database error because "price" cannot be
|
||||||
|
replaced by its value. In this case, Q.check() logs a warning and
|
||||||
|
return True.
|
||||||
|
"""
|
||||||
|
q = Q(RawSQL("price > %s", params=(20,), output_field=BooleanField()))
|
||||||
|
with self.assertLogs("django.db.models", "WARNING") as cm:
|
||||||
|
self.assertIs(q.check({"price": 10}), True)
|
||||||
|
self.assertIn(
|
||||||
|
f"Got a database error calling check() on {q!r}: ",
|
||||||
|
cm.records[0].getMessage(),
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue