check compiler + get Oracle working with nulls

This commit is contained in:
David Sanders 2022-09-29 23:25:40 +10:00
parent 19e6efa50b
commit 222bed1011
7 changed files with 112 additions and 38 deletions

View File

@ -78,3 +78,7 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass
class SQLCheckCompiler(compiler.SQLCheckCompiler, SQLCompiler):
pass

View File

@ -0,0 +1,55 @@
from django.db.models.sql import compiler
class SQLCompiler(compiler.SQLCompiler):
pass
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass
class SQLCheckCompiler(compiler.SQLCheckCompiler, SQLCompiler):
# Oracle doesn't support boolean types yet PL/SQL does, meaning that although we
# can't compare boolean expressions to NULL in SELECT or WHERE clauses we can in
# a function which we can then call from a WHERE clause.
check_sql = """\
WITH
FUNCTION f RETURN NUMBER IS
b BOOLEAN;
BEGIN
b := COALESCE(%s, TRUE);
IF b THEN
RETURN 1;
ELSE
RETURN 0;
END IF;
END f;
SELECT 1 FROM dual WHERE f = 1
"""
def as_sql(self):
# Avoid case wrapping.
self.connection.vendor = "sqlite"
condition, params = self.compile(self.query.where)
self.connection.vendor = "oracle"
# Oracle doesn't allow binding params within functions so bind manually
# with quote_value().
condition %= tuple(
self.connection.schema_editor().quote_value(param) for param in params
)
check_sql = self.check_sql % condition
return check_sql, []

View File

@ -19,6 +19,8 @@ from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.oracle.compiler"
# Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
# SmallAutoField, to preserve backward compatibility.

View File

@ -7,16 +7,13 @@ circular import difficulties.
"""
import functools
import inspect
import logging
from collections import namedtuple
from django.core.exceptions import FieldError
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections
from django.db import DEFAULT_DB_ALIAS
from django.db.models.constants import LOOKUP_SEP
from django.utils import tree
logger = logging.getLogger("django.db.models")
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
@ -114,29 +111,10 @@ class Q(tree.Node):
Do a database query to check if the expressions of the Q instance
matches against the expressions.
"""
# Avoid circular imports.
from django.db.models import BooleanField, Value
from django.db.models.functions import Coalesce
from django.db.models.sql import Query
from django.db.models.sql.constants import SINGLE
from django.db.models.sql import CheckQuery
query = Query(None)
for name, value in against.items():
if not hasattr(value, "resolve_expression"):
value = Value(value)
query.add_annotation(value, name, select=False)
query.add_annotation(Value(1), "_check")
# This will raise a FieldError if a field is missing in "against".
if connections[using].features.supports_comparing_boolean_expr:
query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
else:
query.add_q(self)
compiler = query.get_compiler(using=using)
try:
return compiler.execute_sql(SINGLE) is not None
except DatabaseError as e:
logger.warning("Got a database error calling check() on %r: %s", self, e)
return True
query = CheckQuery()
return query.do_check(self, against, using)
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)

View File

@ -2028,6 +2028,17 @@ class SQLAggregateCompiler(SQLCompiler):
return sql, params
class SQLCheckCompiler(SQLCompiler):
col_count = 1
check_sql = """\
SELECT 1 WHERE COALESCE(%s, TRUE)
"""
def as_sql(self):
condition, params = self.compile(self.query.where)
return self.check_sql % condition, params
def cursor_iter(cursor, sentinel, col_count, itersize):
"""
Yield blocks of rows from a cursor and ensure the cursor is closed when

View File

@ -2,11 +2,21 @@
Query subclasses which provide extra functionality beyond simple data retrieval.
"""
import logging
from django.core.exceptions import FieldError
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db import DatabaseError
from django.db.models.sql.constants import (
CURSOR,
GET_ITERATOR_CHUNK_SIZE,
NO_RESULTS,
SINGLE,
)
from django.db.models.sql.query import Query
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery", "CheckQuery"]
logger = logging.getLogger("django.db.models")
class DeleteQuery(Query):
@ -169,3 +179,26 @@ class AggregateQuery(Query):
def __init__(self, model, inner_query):
self.inner_query = inner_query
super().__init__(model)
class CheckQuery(Query):
compiler = "SQLCheckCompiler"
def __init__(self):
super().__init__(model=None)
def do_check(self, q, against, using):
from django.db.models import Value
for name, value in against.items():
if not hasattr(value, "resolve_expression"):
value = Value(value)
self.add_annotation(value, name, select=False)
# This will raise a FieldError if a field is missing in "against".
self.add_q(q)
compiler = self.get_compiler(using=using)
try:
return compiler.execute_sql(SINGLE) is not None
except DatabaseError as e:
logger.warning("Got a database error calling check() on %r: %s", self, e)
return True

View File

@ -6,7 +6,7 @@ from django.db.models import F
from django.db.models.constraints import BaseConstraint
from django.db.models.functions import Lower
from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import (
ChildModel,
@ -234,7 +234,6 @@ class CheckConstraintTests(TestCase):
constraint.validate(Product, Product(price=501, discounted_price=5))
constraint.validate(Product, Product(price=499, discounted_price=5))
@skipUnlessDBFeature("supports_comparing_boolean_expr")
def test_validate_nullable_field_with_none(self):
# Nullable fields should be considered valid on None values.
constraint = models.CheckConstraint(
@ -243,14 +242,6 @@ class CheckConstraintTests(TestCase):
)
constraint.validate(Product, Product())
@skipIfDBFeature("supports_comparing_boolean_expr")
def test_validate_nullable_field_with_isnull(self):
constraint = models.CheckConstraint(
check=models.Q(price__gte=0) | models.Q(price__isnull=True),
name="positive_price",
)
constraint.validate(Product, Product())
class UniqueConstraintTests(TestCase):
@classmethod