check compiler + get Oracle working with nulls
This commit is contained in:
parent
19e6efa50b
commit
222bed1011
|
@ -78,3 +78,7 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
|||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLCheckCompiler(compiler.SQLCheckCompiler, SQLCompiler):
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
from django.db.models.sql import compiler
|
||||
|
||||
|
||||
class SQLCompiler(compiler.SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLCheckCompiler(compiler.SQLCheckCompiler, SQLCompiler):
|
||||
# Oracle doesn't support boolean types yet PL/SQL does, meaning that although we
|
||||
# can't compare boolean expressions to NULL in SELECT or WHERE clauses we can in
|
||||
# a function which we can then call from a WHERE clause.
|
||||
check_sql = """\
|
||||
WITH
|
||||
FUNCTION f RETURN NUMBER IS
|
||||
b BOOLEAN;
|
||||
BEGIN
|
||||
b := COALESCE(%s, TRUE);
|
||||
IF b THEN
|
||||
RETURN 1;
|
||||
ELSE
|
||||
RETURN 0;
|
||||
END IF;
|
||||
END f;
|
||||
SELECT 1 FROM dual WHERE f = 1
|
||||
"""
|
||||
|
||||
def as_sql(self):
|
||||
# Avoid case wrapping.
|
||||
self.connection.vendor = "sqlite"
|
||||
condition, params = self.compile(self.query.where)
|
||||
self.connection.vendor = "oracle"
|
||||
|
||||
# Oracle doesn't allow binding params within functions so bind manually
|
||||
# with quote_value().
|
||||
condition %= tuple(
|
||||
self.connection.schema_editor().quote_value(param) for param in params
|
||||
)
|
||||
check_sql = self.check_sql % condition
|
||||
return check_sql, []
|
|
@ -19,6 +19,8 @@ from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
|
|||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
compiler_module = "django.db.backends.oracle.compiler"
|
||||
|
||||
# Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
|
||||
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
|
||||
# SmallAutoField, to preserve backward compatibility.
|
||||
|
|
|
@ -7,16 +7,13 @@ circular import difficulties.
|
|||
"""
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import tree
|
||||
|
||||
logger = logging.getLogger("django.db.models")
|
||||
|
||||
# PathInfo is used when converting lookups (fk__somecol). The contents
|
||||
# describe the relation in Model terms (model Options and Fields for both
|
||||
# sides of the relation. The join_field is the field backing the relation.
|
||||
|
@ -114,29 +111,10 @@ class Q(tree.Node):
|
|||
Do a database query to check if the expressions of the Q instance
|
||||
matches against the expressions.
|
||||
"""
|
||||
# Avoid circular imports.
|
||||
from django.db.models import BooleanField, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.sql import Query
|
||||
from django.db.models.sql.constants import SINGLE
|
||||
from django.db.models.sql import CheckQuery
|
||||
|
||||
query = Query(None)
|
||||
for name, value in against.items():
|
||||
if not hasattr(value, "resolve_expression"):
|
||||
value = Value(value)
|
||||
query.add_annotation(value, name, select=False)
|
||||
query.add_annotation(Value(1), "_check")
|
||||
# This will raise a FieldError if a field is missing in "against".
|
||||
if connections[using].features.supports_comparing_boolean_expr:
|
||||
query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
|
||||
else:
|
||||
query.add_q(self)
|
||||
compiler = query.get_compiler(using=using)
|
||||
try:
|
||||
return compiler.execute_sql(SINGLE) is not None
|
||||
except DatabaseError as e:
|
||||
logger.warning("Got a database error calling check() on %r: %s", self, e)
|
||||
return True
|
||||
query = CheckQuery()
|
||||
return query.do_check(self, against, using)
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
|
|
|
@ -2028,6 +2028,17 @@ class SQLAggregateCompiler(SQLCompiler):
|
|||
return sql, params
|
||||
|
||||
|
||||
class SQLCheckCompiler(SQLCompiler):
|
||||
col_count = 1
|
||||
check_sql = """\
|
||||
SELECT 1 WHERE COALESCE(%s, TRUE)
|
||||
"""
|
||||
|
||||
def as_sql(self):
|
||||
condition, params = self.compile(self.query.where)
|
||||
return self.check_sql % condition, params
|
||||
|
||||
|
||||
def cursor_iter(cursor, sentinel, col_count, itersize):
|
||||
"""
|
||||
Yield blocks of rows from a cursor and ensure the cursor is closed when
|
||||
|
|
|
@ -2,11 +2,21 @@
|
|||
Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||
from django.db import DatabaseError
|
||||
from django.db.models.sql.constants import (
|
||||
CURSOR,
|
||||
GET_ITERATOR_CHUNK_SIZE,
|
||||
NO_RESULTS,
|
||||
SINGLE,
|
||||
)
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
|
||||
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery", "CheckQuery"]
|
||||
|
||||
logger = logging.getLogger("django.db.models")
|
||||
|
||||
|
||||
class DeleteQuery(Query):
|
||||
|
@ -169,3 +179,26 @@ class AggregateQuery(Query):
|
|||
def __init__(self, model, inner_query):
|
||||
self.inner_query = inner_query
|
||||
super().__init__(model)
|
||||
|
||||
|
||||
class CheckQuery(Query):
|
||||
compiler = "SQLCheckCompiler"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(model=None)
|
||||
|
||||
def do_check(self, q, against, using):
|
||||
from django.db.models import Value
|
||||
|
||||
for name, value in against.items():
|
||||
if not hasattr(value, "resolve_expression"):
|
||||
value = Value(value)
|
||||
self.add_annotation(value, name, select=False)
|
||||
# This will raise a FieldError if a field is missing in "against".
|
||||
self.add_q(q)
|
||||
compiler = self.get_compiler(using=using)
|
||||
try:
|
||||
return compiler.execute_sql(SINGLE) is not None
|
||||
except DatabaseError as e:
|
||||
logger.warning("Got a database error calling check() on %r: %s", self, e)
|
||||
return True
|
||||
|
|
|
@ -6,7 +6,7 @@ from django.db.models import F
|
|||
from django.db.models.constraints import BaseConstraint
|
||||
from django.db.models.functions import Lower
|
||||
from django.db.transaction import atomic
|
||||
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
||||
|
||||
from .models import (
|
||||
ChildModel,
|
||||
|
@ -234,7 +234,6 @@ class CheckConstraintTests(TestCase):
|
|||
constraint.validate(Product, Product(price=501, discounted_price=5))
|
||||
constraint.validate(Product, Product(price=499, discounted_price=5))
|
||||
|
||||
@skipUnlessDBFeature("supports_comparing_boolean_expr")
|
||||
def test_validate_nullable_field_with_none(self):
|
||||
# Nullable fields should be considered valid on None values.
|
||||
constraint = models.CheckConstraint(
|
||||
|
@ -243,14 +242,6 @@ class CheckConstraintTests(TestCase):
|
|||
)
|
||||
constraint.validate(Product, Product())
|
||||
|
||||
@skipIfDBFeature("supports_comparing_boolean_expr")
|
||||
def test_validate_nullable_field_with_isnull(self):
|
||||
constraint = models.CheckConstraint(
|
||||
check=models.Q(price__gte=0) | models.Q(price__isnull=True),
|
||||
name="positive_price",
|
||||
)
|
||||
constraint.validate(Product, Product())
|
||||
|
||||
|
||||
class UniqueConstraintTests(TestCase):
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue