From 222bed1011e580cb64978f53031aa2b07450c802 Mon Sep 17 00:00:00 2001 From: David Sanders Date: Thu, 29 Sep 2022 23:25:40 +1000 Subject: [PATCH] check compiler + get Oracle working with nulls --- django/db/backends/mysql/compiler.py | 4 ++ django/db/backends/oracle/compiler.py | 55 +++++++++++++++++++++++++ django/db/backends/oracle/operations.py | 2 + django/db/models/query_utils.py | 30 ++------------ django/db/models/sql/compiler.py | 11 +++++ django/db/models/sql/subqueries.py | 37 ++++++++++++++++- tests/constraints/tests.py | 11 +---- 7 files changed, 112 insertions(+), 38 deletions(-) create mode 100644 django/db/backends/oracle/compiler.py diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index bd2715fb43..33ef2a12dc 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -78,3 +78,7 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): pass + + +class SQLCheckCompiler(compiler.SQLCheckCompiler, SQLCompiler): + pass diff --git a/django/db/backends/oracle/compiler.py b/django/db/backends/oracle/compiler.py new file mode 100644 index 0000000000..6fdf6d3163 --- /dev/null +++ b/django/db/backends/oracle/compiler.py @@ -0,0 +1,55 @@ +from django.db.models.sql import compiler + + +class SQLCompiler(compiler.SQLCompiler): + pass + + +class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): + pass + + +class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): + pass + + +class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): + pass + + +class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): + pass + + +class SQLCheckCompiler(compiler.SQLCheckCompiler, SQLCompiler): + # Oracle doesn't support boolean types yet PL/SQL does, meaning that although we + # can't compare boolean expressions to NULL in SELECT or WHERE clauses we can in + # a function which we can then call from a WHERE clause. + check_sql = """\ +WITH + FUNCTION f RETURN NUMBER IS + b BOOLEAN; + BEGIN + b := COALESCE(%s, TRUE); + IF b THEN + RETURN 1; + ELSE + RETURN 0; + END IF; + END f; +SELECT 1 FROM dual WHERE f = 1 +""" + + def as_sql(self): + # Avoid case wrapping. + self.connection.vendor = "sqlite" + condition, params = self.compile(self.query.where) + self.connection.vendor = "oracle" + + # Oracle doesn't allow binding params within functions so bind manually + # with quote_value(). + condition %= tuple( + self.connection.schema_editor().quote_value(param) for param in params + ) + check_sql = self.check_sql % condition + return check_sql, [] diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 78f998183e..fad29a51aa 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -19,6 +19,8 @@ from .utils import BulkInsertMapper, InsertVar, Oracle_datetime class DatabaseOperations(BaseDatabaseOperations): + compiler_module = "django.db.backends.oracle.compiler" + # Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields. # SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by # SmallAutoField, to preserve backward compatibility. diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 4a83fc380d..0cbc0be823 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -7,16 +7,13 @@ circular import difficulties. """ import functools import inspect -import logging from collections import namedtuple from django.core.exceptions import FieldError -from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections +from django.db import DEFAULT_DB_ALIAS from django.db.models.constants import LOOKUP_SEP from django.utils import tree -logger = logging.getLogger("django.db.models") - # PathInfo is used when converting lookups (fk__somecol). The contents # describe the relation in Model terms (model Options and Fields for both # sides of the relation. The join_field is the field backing the relation. @@ -114,29 +111,10 @@ class Q(tree.Node): Do a database query to check if the expressions of the Q instance matches against the expressions. """ - # Avoid circular imports. - from django.db.models import BooleanField, Value - from django.db.models.functions import Coalesce - from django.db.models.sql import Query - from django.db.models.sql.constants import SINGLE + from django.db.models.sql import CheckQuery - query = Query(None) - for name, value in against.items(): - if not hasattr(value, "resolve_expression"): - value = Value(value) - query.add_annotation(value, name, select=False) - query.add_annotation(Value(1), "_check") - # This will raise a FieldError if a field is missing in "against". - if connections[using].features.supports_comparing_boolean_expr: - query.add_q(Q(Coalesce(self, True, output_field=BooleanField()))) - else: - query.add_q(self) - compiler = query.get_compiler(using=using) - try: - return compiler.execute_sql(SINGLE) is not None - except DatabaseError as e: - logger.warning("Got a database error calling check() on %r: %s", self, e) - return True + query = CheckQuery() + return query.do_check(self, against, using) def deconstruct(self): path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index c7efa469a8..ca06e14091 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -2028,6 +2028,17 @@ class SQLAggregateCompiler(SQLCompiler): return sql, params +class SQLCheckCompiler(SQLCompiler): + col_count = 1 + check_sql = """\ +SELECT 1 WHERE COALESCE(%s, TRUE) +""" + + def as_sql(self): + condition, params = self.compile(self.query.where) + return self.check_sql % condition, params + + def cursor_iter(cursor, sentinel, col_count, itersize): """ Yield blocks of rows from a cursor and ensure the cursor is closed when diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index d8a246d369..fa35bf0d40 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -2,11 +2,21 @@ Query subclasses which provide extra functionality beyond simple data retrieval. """ +import logging + from django.core.exceptions import FieldError -from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS +from django.db import DatabaseError +from django.db.models.sql.constants import ( + CURSOR, + GET_ITERATOR_CHUNK_SIZE, + NO_RESULTS, + SINGLE, +) from django.db.models.sql.query import Query -__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"] +__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery", "CheckQuery"] + +logger = logging.getLogger("django.db.models") class DeleteQuery(Query): @@ -169,3 +179,26 @@ class AggregateQuery(Query): def __init__(self, model, inner_query): self.inner_query = inner_query super().__init__(model) + + +class CheckQuery(Query): + compiler = "SQLCheckCompiler" + + def __init__(self): + super().__init__(model=None) + + def do_check(self, q, against, using): + from django.db.models import Value + + for name, value in against.items(): + if not hasattr(value, "resolve_expression"): + value = Value(value) + self.add_annotation(value, name, select=False) + # This will raise a FieldError if a field is missing in "against". + self.add_q(q) + compiler = self.get_compiler(using=using) + try: + return compiler.execute_sql(SINGLE) is not None + except DatabaseError as e: + logger.warning("Got a database error calling check() on %r: %s", self, e) + return True diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 5a498f0d73..c105073177 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -6,7 +6,7 @@ from django.db.models import F from django.db.models.constraints import BaseConstraint from django.db.models.functions import Lower from django.db.transaction import atomic -from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from .models import ( ChildModel, @@ -234,7 +234,6 @@ class CheckConstraintTests(TestCase): constraint.validate(Product, Product(price=501, discounted_price=5)) constraint.validate(Product, Product(price=499, discounted_price=5)) - @skipUnlessDBFeature("supports_comparing_boolean_expr") def test_validate_nullable_field_with_none(self): # Nullable fields should be considered valid on None values. constraint = models.CheckConstraint( @@ -243,14 +242,6 @@ class CheckConstraintTests(TestCase): ) constraint.validate(Product, Product()) - @skipIfDBFeature("supports_comparing_boolean_expr") - def test_validate_nullable_field_with_isnull(self): - constraint = models.CheckConstraint( - check=models.Q(price__gte=0) | models.Q(price__isnull=True), - name="positive_price", - ) - constraint.validate(Product, Product()) - class UniqueConstraintTests(TestCase): @classmethod