Simplified imports from django.db and django.contrib.gis.db.

This commit is contained in:
Nick Pope 2019-08-20 08:54:41 +01:00 committed by Mariusz Felisiak
parent 469bf2db15
commit 335c9c94ac
113 changed files with 382 additions and 450 deletions

View File

@ -10,7 +10,7 @@ from django.core import checks
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist
from django.db import models from django.db import models
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Combinable, F, OrderBy from django.db.models.expressions import Combinable
from django.forms.models import ( from django.forms.models import (
BaseModelForm, BaseModelFormSet, _get_foreign_key, BaseModelForm, BaseModelFormSet, _get_foreign_key,
) )
@ -546,10 +546,10 @@ class BaseModelAdminChecks:
def _check_ordering_item(self, obj, field_name, label): def _check_ordering_item(self, obj, field_name, label):
""" Check that `ordering` refers to existing fields. """ """ Check that `ordering` refers to existing fields. """
if isinstance(field_name, (Combinable, OrderBy)): if isinstance(field_name, (Combinable, models.OrderBy)):
if not isinstance(field_name, OrderBy): if not isinstance(field_name, models.OrderBy):
field_name = field_name.asc() field_name = field_name.asc()
if isinstance(field_name.expression, F): if isinstance(field_name.expression, models.F):
field_name = field_name.expression.name field_name = field_name.expression.name
else: else:
return [] return []

View File

@ -7,7 +7,7 @@ from django.contrib.admin.utils import (
lookup_field, lookup_field,
) )
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db.models.fields.related import ManyToManyRel from django.db.models import ManyToManyRel
from django.forms.utils import flatatt from django.forms.utils import flatatt
from django.template.defaultfilters import capfirst, linebreaksbr from django.template.defaultfilters import capfirst, linebreaksbr
from django.utils.html import conditional_escape, format_html from django.utils.html import conditional_escape, format_html

View File

@ -30,7 +30,6 @@ from django.core.exceptions import (
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.db import models, router, transaction from django.db import models, router, transaction
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms.formsets import DELETION_FIELD_NAME, all_valid from django.forms.formsets import DELETION_FIELD_NAME, all_valid
from django.forms.models import ( from django.forms.models import (
BaseInlineFormSet, inlineformset_factory, modelform_defines_fields, BaseInlineFormSet, inlineformset_factory, modelform_defines_fields,
@ -889,7 +888,7 @@ class ModelAdmin(BaseModelAdmin):
actions = self._filter_actions_by_permissions(request, self._get_base_actions()) actions = self._filter_actions_by_permissions(request, self._get_base_actions())
return {name: (func, name, desc) for func, name, desc in actions} return {name: (func, name, desc) for func, name, desc in actions}
def get_action_choices(self, request, default_choices=BLANK_CHOICE_DASH): def get_action_choices(self, request, default_choices=models.BLANK_CHOICE_DASH):
""" """
Return a list of choices for use in a form object. Each choice is a Return a list of choices for use in a form object. Each choice is a
tuple (name, description). tuple (name, description).

View File

@ -17,8 +17,8 @@ from django.core.exceptions import (
FieldDoesNotExist, ImproperlyConfigured, SuspiciousOperation, FieldDoesNotExist, ImproperlyConfigured, SuspiciousOperation,
) )
from django.core.paginator import InvalidPage from django.core.paginator import InvalidPage
from django.db import models from django.db.models import F, Field, ManyToOneRel, OrderBy
from django.db.models.expressions import Combinable, F, OrderBy from django.db.models.expressions import Combinable
from django.urls import reverse from django.urls import reverse
from django.utils.http import urlencode from django.utils.http import urlencode
from django.utils.timezone import make_aware from django.utils.timezone import make_aware
@ -141,7 +141,7 @@ class ChangeList:
# FieldListFilter class that has been registered for the # FieldListFilter class that has been registered for the
# type of the given field. # type of the given field.
field, field_list_filter_class = list_filter, FieldListFilter.create field, field_list_filter_class = list_filter, FieldListFilter.create
if not isinstance(field, models.Field): if not isinstance(field, Field):
field_path = field field_path = field
field = get_fields_from_path(self.model, field_path)[-1] field = get_fields_from_path(self.model, field_path)[-1]
@ -487,7 +487,7 @@ class ChangeList:
except FieldDoesNotExist: except FieldDoesNotExist:
pass pass
else: else:
if isinstance(field.remote_field, models.ManyToOneRel): if isinstance(field.remote_field, ManyToOneRel):
# <FK>_id field names don't require a join. # <FK>_id field names don't require a join.
if field_name != field.get_attname(): if field_name != field.get_attname():
return True return True

View File

@ -8,7 +8,7 @@ from django import forms
from django.conf import settings from django.conf import settings
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.validators import URLValidator from django.core.validators import URLValidator
from django.db.models.deletion import CASCADE from django.db.models import CASCADE
from django.urls import reverse from django.urls import reverse
from django.urls.exceptions import NoReverseMatch from django.urls.exceptions import NoReverseMatch
from django.utils.html import smart_urlquote from django.utils.html import smart_urlquote

View File

@ -1,9 +1,8 @@
import sys import sys
from django.core.management.color import color_style from django.core.management.color import color_style
from django.db import migrations, transaction from django.db import IntegrityError, migrations, transaction
from django.db.models import Q from django.db.models import Q
from django.db.utils import IntegrityError
WARNING = """ WARNING = """
A problem arose migrating proxy model permissions for {old} to {new}. A problem arose migrating proxy model permissions for {old} to {new}.

View File

@ -7,12 +7,11 @@ from django.contrib.contenttypes.models import ContentType
from django.core import checks from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, models, router, transaction from django.db import DEFAULT_DB_ALIAS, models, router, transaction
from django.db.models import DO_NOTHING from django.db.models import DO_NOTHING, ForeignObject, ForeignObjectRel
from django.db.models.base import ModelBase, make_foreign_order_accessors from django.db.models.base import ModelBase, make_foreign_order_accessors
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.fields.related import ( from django.db.models.fields.related import (
ForeignObject, ForeignObjectRel, ReverseManyToOneDescriptor, ReverseManyToOneDescriptor, lazy_related_operation,
lazy_related_operation,
) )
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -1,6 +1,7 @@
from django.apps import apps as global_apps from django.apps import apps as global_apps
from django.db import DEFAULT_DB_ALIAS, migrations, router, transaction from django.db import (
from django.db.utils import IntegrityError DEFAULT_DB_ALIAS, IntegrityError, migrations, router, transaction,
)
class RenameContentType(migrations.RunPython): class RenameContentType(migrations.RunPython):

View File

@ -1,6 +1,6 @@
import re import re
from django.contrib.gis.db.models import aggregates from django.contrib.gis.db import models
class BaseSpatialFeatures: class BaseSpatialFeatures:
@ -77,19 +77,19 @@ class BaseSpatialFeatures:
# Is the aggregate supported by the database? # Is the aggregate supported by the database?
@property @property
def supports_collect_aggr(self): def supports_collect_aggr(self):
return aggregates.Collect not in self.connection.ops.disallowed_aggregates return models.Collect not in self.connection.ops.disallowed_aggregates
@property @property
def supports_extent_aggr(self): def supports_extent_aggr(self):
return aggregates.Extent not in self.connection.ops.disallowed_aggregates return models.Extent not in self.connection.ops.disallowed_aggregates
@property @property
def supports_make_line_aggr(self): def supports_make_line_aggr(self):
return aggregates.MakeLine not in self.connection.ops.disallowed_aggregates return models.MakeLine not in self.connection.ops.disallowed_aggregates
@property @property
def supports_union_aggr(self): def supports_union_aggr(self):
return aggregates.Union not in self.connection.ops.disallowed_aggregates return models.Union not in self.connection.ops.disallowed_aggregates
def __getattr__(self, name): def __getattr__(self, name):
m = re.match(r'has_(\w*)_function$', name) m = re.match(r'has_(\w*)_function$', name)

View File

@ -3,7 +3,7 @@ from django.contrib.gis.db.models.functions import Distance
from django.contrib.gis.measure import ( from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure, Area as AreaMeasure, Distance as DistanceMeasure,
) )
from django.db.utils import NotSupportedError from django.db import NotSupportedError
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -1,9 +1,9 @@
from django.contrib.gis.db import models
from django.contrib.gis.db.backends.base.adapter import WKTAdapter from django.contrib.gis.db.backends.base.adapter import WKTAdapter
from django.contrib.gis.db.backends.base.operations import ( from django.contrib.gis.db.backends.base.operations import (
BaseSpatialOperations, BaseSpatialOperations,
) )
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geos.geometry import GEOSGeometryBase from django.contrib.gis.geos.geometry import GEOSGeometryBase
from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.geos.prototypes.io import wkb_r
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
@ -49,8 +49,8 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
return operators return operators
disallowed_aggregates = ( disallowed_aggregates = (
aggregates.Collect, aggregates.Extent, aggregates.Extent3D, models.Collect, models.Extent, models.Extent3D, models.MakeLine,
aggregates.MakeLine, aggregates.Union, models.Union,
) )
@cached_property @cached_property

View File

@ -1,8 +1,8 @@
import logging import logging
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
from django.db import OperationalError
from django.db.backends.mysql.schema import DatabaseSchemaEditor from django.db.backends.mysql.schema import DatabaseSchemaEditor
from django.db.utils import OperationalError
logger = logging.getLogger('django.contrib.gis') logger = logging.getLogger('django.contrib.gis')

View File

@ -9,12 +9,12 @@
""" """
import re import re
from django.contrib.gis.db import models
from django.contrib.gis.db.backends.base.operations import ( from django.contrib.gis.db.backends.base.operations import (
BaseSpatialOperations, BaseSpatialOperations,
) )
from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geos.geometry import GEOSGeometry, GEOSGeometryBase from django.contrib.gis.geos.geometry import GEOSGeometry, GEOSGeometryBase
from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.geos.prototypes.io import wkb_r
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
@ -53,7 +53,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
name = 'oracle' name = 'oracle'
oracle = True oracle = True
disallowed_aggregates = (aggregates.Collect, aggregates.Extent3D, aggregates.MakeLine) disallowed_aggregates = (models.Collect, models.Extent3D, models.MakeLine)
Adapter = OracleSpatialAdapter Adapter = OracleSpatialAdapter

View File

@ -1,4 +1,4 @@
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
from django.db.backends.oracle.schema import DatabaseSchemaEditor from django.db.backends.oracle.schema import DatabaseSchemaEditor
from django.db.backends.utils import strip_quotes, truncate_name from django.db.backends.utils import strip_quotes, truncate_name

View File

@ -11,9 +11,9 @@ from django.contrib.gis.geos.geometry import GEOSGeometryBase
from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.geos.prototypes.io import wkb_r
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import NotSupportedError, ProgrammingError
from django.db.backends.postgresql.operations import DatabaseOperations from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.models import Func, Value from django.db.models import Func, Value
from django.db.utils import NotSupportedError, ProgrammingError
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.version import get_version_tuple from django.utils.version import get_version_tuple

View File

@ -2,12 +2,12 @@
SQL functions reference lists: SQL functions reference lists:
https://www.gaia-gis.it/gaia-sins/spatialite-sql-4.3.0.html https://www.gaia-gis.it/gaia-sins/spatialite-sql-4.3.0.html
""" """
from django.contrib.gis.db import models
from django.contrib.gis.db.backends.base.operations import ( from django.contrib.gis.db.backends.base.operations import (
BaseSpatialOperations, BaseSpatialOperations,
) )
from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geos.geometry import GEOSGeometry, GEOSGeometryBase from django.contrib.gis.geos.geometry import GEOSGeometry, GEOSGeometryBase
from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.geos.prototypes.io import wkb_r
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
@ -62,7 +62,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
'dwithin': SpatialOperator(func='PtDistWithin'), 'dwithin': SpatialOperator(func='PtDistWithin'),
} }
disallowed_aggregates = (aggregates.Extent3D,) disallowed_aggregates = (models.Extent3D,)
select = 'CAST (AsEWKB(%s) AS BLOB)' select = 'CAST (AsEWKB(%s) AS BLOB)'

View File

@ -1,5 +1,5 @@
from django.db import DatabaseError
from django.db.backends.sqlite3.schema import DatabaseSchemaEditor from django.db.backends.sqlite3.schema import DatabaseSchemaEditor
from django.db.utils import DatabaseError
class SpatialiteSchemaEditor(DatabaseSchemaEditor): class SpatialiteSchemaEditor(DatabaseSchemaEditor):
@ -35,7 +35,7 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor):
return self.connection.ops.geo_quote_name(name) return self.connection.ops.geo_quote_name(name)
def column_sql(self, model, field, include_default=False): def column_sql(self, model, field, include_default=False):
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
if not isinstance(field, GeometryField): if not isinstance(field, GeometryField):
return super().column_sql(model, field, include_default) return super().column_sql(model, field, include_default)
@ -82,7 +82,7 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor):
self.geometry_sql = [] self.geometry_sql = []
def delete_model(self, model, **kwargs): def delete_model(self, model, **kwargs):
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
# Drop spatial metadata (dropping the table does not automatically remove them) # Drop spatial metadata (dropping the table does not automatically remove them)
for field in model._meta.local_fields: for field in model._meta.local_fields:
if isinstance(field, GeometryField): if isinstance(field, GeometryField):
@ -101,7 +101,7 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor):
super().delete_model(model, **kwargs) super().delete_model(model, **kwargs)
def add_field(self, model, field): def add_field(self, model, field):
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
if isinstance(field, GeometryField): if isinstance(field, GeometryField):
# Populate self.geometry_sql # Populate self.geometry_sql
self.column_sql(model, field) self.column_sql(model, field)
@ -112,7 +112,7 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor):
super().add_field(model, field) super().add_field(model, field)
def remove_field(self, model, field): def remove_field(self, model, field):
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
# NOTE: If the field is a geometry field, the table is just recreated, # NOTE: If the field is a geometry field, the table is just recreated,
# the parent's remove_field can't be used cause it will skip the # the parent's remove_field can't be used cause it will skip the
# recreation if the field does not have a database type. Geometry fields # recreation if the field does not have a database type. Geometry fields
@ -124,7 +124,7 @@ class SpatialiteSchemaEditor(DatabaseSchemaEditor):
super().remove_field(model, field) super().remove_field(model, field)
def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True): def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
# Remove geometry-ness from temp table # Remove geometry-ness from temp table
for field in model._meta.local_fields: for field in model._meta.local_fields:
if isinstance(field, GeometryField): if isinstance(field, GeometryField):

View File

@ -1,7 +1,7 @@
from django.contrib.gis.db.models.fields import ( from django.contrib.gis.db.models.fields import (
ExtentField, GeometryCollectionField, GeometryField, LineStringField, ExtentField, GeometryCollectionField, GeometryField, LineStringField,
) )
from django.db.models.aggregates import Aggregate from django.db.models import Aggregate
from django.utils.functional import cached_property from django.utils.functional import cached_property
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']

View File

@ -8,7 +8,7 @@ from django.contrib.gis.geos import (
MultiLineString, MultiPoint, MultiPolygon, Point, Polygon, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon,
) )
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db.models.fields import Field from django.db.models import Field
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
# Local cache of the spatial_ref_sys table, which holds SRID data for each # Local cache of the spatial_ref_sys table, which holds SRID data for each

View File

@ -4,12 +4,12 @@ from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
from django.contrib.gis.db.models.sql import AreaField, DistanceField from django.contrib.gis.db.models.sql import AreaField, DistanceField
from django.contrib.gis.geos import GEOSGeometry from django.contrib.gis.geos import GEOSGeometry
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import NotSupportedError
from django.db.models import ( from django.db.models import (
BinaryField, BooleanField, FloatField, IntegerField, TextField, Transform, BinaryField, BooleanField, FloatField, Func, IntegerField, TextField,
Transform, Value,
) )
from django.db.models.expressions import Func, Value
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.db.utils import NotSupportedError
from django.utils.functional import cached_property from django.utils.functional import cached_property
NUMERIC_TYPES = (int, float, Decimal) NUMERIC_TYPES = (int, float, Decimal)

View File

@ -1,8 +1,7 @@
from django.contrib.gis.db.models.fields import BaseSpatialField from django.contrib.gis.db.models.fields import BaseSpatialField
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
from django.db import NotSupportedError from django.db import NotSupportedError
from django.db.models.expressions import Expression from django.db.models import Expression, Lookup, Transform
from django.db.models.lookups import Lookup, Transform
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils.regex_helper import _lazy_re_compile from django.utils.regex_helper import _lazy_re_compile

View File

@ -1,5 +1,5 @@
from django.apps import apps from django.apps import apps
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
from django.contrib.sitemaps import Sitemap from django.contrib.sitemaps import Sitemap
from django.db import models from django.db import models
from django.urls import reverse from django.urls import reverse

View File

@ -1,5 +1,5 @@
from django.apps import apps from django.apps import apps
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
from django.contrib.gis.db.models.functions import AsKML, Transform from django.contrib.gis.db.models.functions import AsKML, Transform
from django.contrib.gis.shortcuts import render_to_kml, render_to_kmz from django.contrib.gis.shortcuts import render_to_kml, render_to_kmz
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist

View File

@ -1,6 +1,5 @@
from django.contrib.postgres.fields import ArrayField, JSONField from django.contrib.postgres.fields import ArrayField, JSONField
from django.db.models import Value from django.db.models import Aggregate, Value
from django.db.models.aggregates import Aggregate
from .mixins import OrderableAggMixin from .mixins import OrderableAggMixin

View File

@ -1,4 +1,4 @@
from django.db.models.expressions import F, OrderBy from django.db.models import F, OrderBy
class OrderableAggMixin: class OrderableAggMixin:

View File

@ -1,5 +1,4 @@
from django.db.models import FloatField, IntegerField from django.db.models import Aggregate, FloatField, IntegerField
from django.db.models.aggregates import Aggregate
__all__ = [ __all__ = [
'CovarPop', 'Corr', 'RegrAvgX', 'RegrAvgY', 'RegrCount', 'RegrIntercept', 'CovarPop', 'Corr', 'RegrAvgX', 'RegrAvgY', 'RegrCount', 'RegrIntercept',

View File

@ -1,5 +1,5 @@
from django.db import NotSupportedError
from django.db.models import Index from django.db.models import Index
from django.db.utils import NotSupportedError
from django.utils.functional import cached_property from django.utils.functional import cached_property
__all__ = [ __all__ = [

View File

@ -1,9 +1,9 @@
from django.contrib.postgres.signals import ( from django.contrib.postgres.signals import (
get_citext_oids, get_hstore_oids, register_type_handlers, get_citext_oids, get_hstore_oids, register_type_handlers,
) )
from django.db import NotSupportedError
from django.db.migrations import AddIndex, RemoveIndex from django.db.migrations import AddIndex, RemoveIndex
from django.db.migrations.operations.base import Operation from django.db.migrations.operations.base import Operation
from django.db.utils import NotSupportedError
class CreateExtension(Operation): class CreateExtension(Operation):

View File

@ -1,9 +1,8 @@
from django.db.models import CharField, Field, FloatField, TextField from django.db.models import (
from django.db.models.expressions import ( CharField, Expression, Field, FloatField, Func, Lookup, TextField, Value,
CombinedExpression, Expression, Func, Value,
) )
from django.db.models.expressions import CombinedExpression
from django.db.models.functions import Cast, Coalesce from django.db.models.functions import Cast, Coalesce
from django.db.models.lookups import Lookup
class SearchVectorExact(Lookup): class SearchVectorExact(Lookup):

View File

@ -3,9 +3,8 @@ from django.core.cache import caches
from django.core.cache.backends.db import BaseDatabaseCache from django.core.cache.backends.db import BaseDatabaseCache
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from django.db import ( from django.db import (
DEFAULT_DB_ALIAS, connections, models, router, transaction, DEFAULT_DB_ALIAS, DatabaseError, connections, models, router, transaction,
) )
from django.db.utils import DatabaseError
class Command(BaseCommand): class Command(BaseCommand):

View File

@ -10,12 +10,12 @@ import pytz
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS, DatabaseError
from django.db.backends import utils from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created from django.db.backends.signals import connection_created
from django.db.transaction import TransactionManagementError from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError, DatabaseErrorWrapper from django.db.utils import DatabaseErrorWrapper
from django.utils import timezone from django.utils import timezone
from django.utils.asyncio import async_unsafe from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -1,4 +1,4 @@
from django.db.utils import ProgrammingError from django.db import ProgrammingError
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -4,7 +4,7 @@ MySQL database backend for Django.
Requires mysqlclient: https://pypi.org/project/mysqlclient/ Requires mysqlclient: https://pypi.org/project/mysqlclient/
""" """
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import utils from django.db import IntegrityError
from django.db.backends import utils as backend_utils from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe from django.utils.asyncio import async_unsafe
@ -75,7 +75,7 @@ class CursorWrapper:
# Map some error codes to IntegrityError, since they seem to be # Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place. # misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror: if e.args[0] in self.codes_for_integrityerror:
raise utils.IntegrityError(*tuple(e.args)) raise IntegrityError(*tuple(e.args))
raise raise
def executemany(self, query, args): def executemany(self, query, args):
@ -85,7 +85,7 @@ class CursorWrapper:
# Map some error codes to IntegrityError, since they seem to be # Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place. # misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror: if e.args[0] in self.codes_for_integrityerror:
raise utils.IntegrityError(*tuple(e.args)) raise IntegrityError(*tuple(e.args))
raise raise
def __getattr__(self, attr): def __getattr__(self, attr):
@ -314,7 +314,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
) )
) )
for bad_row in cursor.fetchall(): for bad_row in cursor.fetchall():
raise utils.IntegrityError( raise IntegrityError(
"The row in table '%s' with primary key '%s' has an invalid " "The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not " "foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s." "have a corresponding value in %s.%s."

View File

@ -6,7 +6,7 @@ from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import ( from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
) )
from django.db.models.indexes import Index from django.db.models import Index
from django.utils.datastructures import OrderedSet from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned')) FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned'))

View File

@ -11,7 +11,7 @@ from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import utils from django.db import IntegrityError
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe from django.utils.asyncio import async_unsafe
from django.utils.encoding import force_bytes, force_str from django.utils.encoding import force_bytes, force_str
@ -74,7 +74,7 @@ def wrap_oracle_errors():
# Convert that case to Django's IntegrityError exception. # Convert that case to Django's IntegrityError exception.
x = e.args[0] x = e.args[0]
if hasattr(x, 'code') and hasattr(x, 'message') and x.code == 2091 and 'ORA-02291' in x.message: if hasattr(x, 'code') and hasattr(x, 'message') and x.code == 2091 and 'ORA-02291' in x.message:
raise utils.IntegrityError(*tuple(e.args)) raise IntegrityError(*tuple(e.args))
raise raise

View File

@ -1,8 +1,8 @@
import sys import sys
from django.conf import settings from django.conf import settings
from django.db import DatabaseError
from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.utils import DatabaseError
from django.utils.crypto import get_random_string from django.utils.crypto import get_random_string
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -1,5 +1,5 @@
from django.db import InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import InterfaceError
class DatabaseFeatures(BaseDatabaseFeatures): class DatabaseFeatures(BaseDatabaseFeatures):

View File

@ -3,11 +3,12 @@ import uuid
from functools import lru_cache from functools import lru_cache
from django.conf import settings from django.conf import settings
from django.db import DatabaseError
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import strip_quotes, truncate_name from django.db.backends.utils import strip_quotes, truncate_name
from django.db.models.expressions import Exists, ExpressionWrapper, RawSQL from django.db.models import AutoField, Exists, ExpressionWrapper
from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
from django.db.utils import DatabaseError
from django.utils import timezone from django.utils import timezone
from django.utils.encoding import force_bytes, force_str from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -466,12 +467,11 @@ END;
return sql return sql
def sequence_reset_sql(self, style, model_list): def sequence_reset_sql(self, style, model_list):
from django.db import models
output = [] output = []
query = self._sequence_reset_sql query = self._sequence_reset_sql
for model in model_list: for model in model_list:
for f in model._meta.local_fields: for f in model._meta.local_fields:
if isinstance(f, models.AutoField): if isinstance(f, AutoField):
no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table) no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table)
table = self.quote_name(model._meta.db_table) table = self.quote_name(model._meta.db_table)
column = self.quote_name(f.column) column = self.quote_name(f.column)

View File

@ -2,8 +2,8 @@ import copy
import datetime import datetime
import re import re
from django.db import DatabaseError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.utils import DatabaseError
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

View File

@ -10,12 +10,11 @@ import warnings
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import connections from django.db import DatabaseError as WrappedDatabaseError, connections
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import ( from django.db.backends.utils import (
CursorDebugWrapper as BaseCursorDebugWrapper, CursorDebugWrapper as BaseCursorDebugWrapper,
) )
from django.db.utils import DatabaseError as WrappedDatabaseError
from django.utils.asyncio import async_unsafe from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.safestring import SafeString from django.utils.safestring import SafeString

View File

@ -1,7 +1,7 @@
import operator import operator
from django.db import InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import InterfaceError
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -1,7 +1,7 @@
from django.db.backends.base.introspection import ( from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo, BaseDatabaseIntrospection, FieldInfo, TableInfo,
) )
from django.db.models.indexes import Index from django.db.models import Index
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):

View File

@ -16,7 +16,7 @@ from sqlite3 import dbapi2 as Database
import pytz import pytz
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import utils from django.db import IntegrityError
from django.db.backends import utils as backend_utils from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils import timezone from django.utils import timezone
@ -328,7 +328,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
), ),
(rowid,), (rowid,),
).fetchone() ).fetchone()
raise utils.IntegrityError( raise IntegrityError(
"The row in table '%s' with primary key '%s' has an " "The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that " "invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % ( "does not have a corresponding value in %s.%s." % (
@ -360,7 +360,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
) )
) )
for bad_row in cursor.fetchall(): for bad_row in cursor.fetchall():
raise utils.IntegrityError( raise IntegrityError(
"The row in table '%s' with primary key '%s' has an " "The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that " "invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s." % ( "does not have a corresponding value in %s.%s." % (

View File

@ -6,7 +6,7 @@ import sqlparse
from django.db.backends.base.introspection import ( from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo, BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
) )
from django.db.models.indexes import Index from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile from django.utils.regex_helper import _lazy_re_compile
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk',)) FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk',))

View File

@ -6,9 +6,8 @@ from itertools import chain
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import utils from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import aggregates, fields
from django.db.models.expressions import Col from django.db.models.expressions import Col
from django.utils import timezone from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.dateparse import parse_date, parse_datetime, parse_time
@ -40,8 +39,8 @@ class DatabaseOperations(BaseDatabaseOperations):
return len(objs) return len(objs)
def check_expression_support(self, expression): def check_expression_support(self, expression):
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField) bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev) bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
if isinstance(expression, bad_aggregates): if isinstance(expression, bad_aggregates):
for expr in expression.get_source_expressions(): for expr in expression.get_source_expressions():
try: try:
@ -52,13 +51,13 @@ class DatabaseOperations(BaseDatabaseOperations):
pass pass
else: else:
if isinstance(output_field, bad_fields): if isinstance(output_field, bad_fields):
raise utils.NotSupportedError( raise NotSupportedError(
'You cannot use Sum, Avg, StdDev, and Variance ' 'You cannot use Sum, Avg, StdDev, and Variance '
'aggregations on date/time fields in sqlite3 ' 'aggregations on date/time fields in sqlite3 '
'since date/time is saved as text.' 'since date/time is saved as text.'
) )
if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1: if isinstance(expression, models.Aggregate) and len(expression.source_expressions) > 1:
raise utils.NotSupportedError( raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions " "SQLite doesn't support DISTINCT on aggregate functions "
"accepting multiple arguments." "accepting multiple arguments."
) )
@ -313,7 +312,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def combine_duration_expression(self, connector, sub_expressions): def combine_duration_expression(self, connector, sub_expressions):
if connector not in ['+', '-']: if connector not in ['+', '-']:
raise utils.DatabaseError('Invalid connector for timedelta: %s.' % connector) raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
fn_params = ["'%s'" % connector] + sub_expressions fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3: if len(fn_params) > 3:
raise ValueError('Too many params for timedelta operations.') raise ValueError('Too many params for timedelta operations.')

View File

@ -2,12 +2,12 @@ import copy
from decimal import Decimal from decimal import Decimal
from django.apps.registry import Apps from django.apps.registry import Apps
from django.db import NotSupportedError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes from django.db.backends.utils import strip_quotes
from django.db.models import UniqueConstraint from django.db.models import UniqueConstraint
from django.db.transaction import atomic from django.db.transaction import atomic
from django.db.utils import NotSupportedError
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

View File

@ -6,7 +6,7 @@ import logging
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from django.db.utils import NotSupportedError from django.db import NotSupportedError
logger = logging.getLogger('django.db.backends') logger = logging.getLogger('django.db.backends')

View File

@ -1,4 +1,4 @@
from django.db.utils import DatabaseError from django.db import DatabaseError
class AmbiguityError(Exception): class AmbiguityError(Exception):

View File

@ -1,5 +1,5 @@
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import NOT_PROVIDED from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property from django.utils.functional import cached_property
from .base import Operation from .base import Operation

View File

@ -4,7 +4,7 @@ import os
import sys import sys
from django.apps import apps from django.apps import apps
from django.db.models.fields import NOT_PROVIDED from django.db.models import NOT_PROVIDED
from django.utils import timezone from django.utils import timezone
from .loader import MigrationLoader from .loader import MigrationLoader

View File

@ -1,6 +1,5 @@
from django.apps.registry import Apps from django.apps.registry import Apps
from django.db import models from django.db import DatabaseError, models
from django.db.utils import DatabaseError
from django.utils.functional import classproperty from django.utils.functional import classproperty
from django.utils.timezone import now from django.utils.timezone import now

View File

@ -5,7 +5,6 @@ from django.apps import AppConfig
from django.apps.registry import Apps, apps as global_apps from django.apps.registry import Apps, apps as global_apps
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.db.models.options import DEFAULT_NAMES, normalize_together from django.db.models.options import DEFAULT_NAMES, normalize_together
from django.db.models.utils import make_model_tuple from django.db.models.utils import make_model_tuple
@ -406,7 +405,7 @@ class ModelState:
for field in model._meta.local_fields: for field in model._meta.local_fields:
if getattr(field, "remote_field", None) and exclude_rels: if getattr(field, "remote_field", None) and exclude_rels:
continue continue
if isinstance(field, OrderWrt): if isinstance(field, models.OrderWrt):
continue continue
name = field.name name = field.name
try: try:

View File

@ -12,7 +12,8 @@ from django.db.models.enums import * # NOQA
from django.db.models.enums import __all__ as enums_all from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import ( from django.db.models.expressions import (
Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func, Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
OuterRef, RowRange, Subquery, Value, ValueRange, When, Window, WindowFrame, OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window,
WindowFrame,
) )
from django.db.models.fields import * # NOQA from django.db.models.fields import * # NOQA
from django.db.models.fields import __all__ as fields_all from django.db.models.fields import __all__ as fields_all
@ -22,16 +23,14 @@ from django.db.models.indexes import * # NOQA
from django.db.models.indexes import __all__ as indexes_all from django.db.models.indexes import __all__ as indexes_all
from django.db.models.lookups import Lookup, Transform from django.db.models.lookups import Lookup, Transform
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.db.models.query import ( from django.db.models.query import Prefetch, QuerySet, prefetch_related_objects
Prefetch, Q, QuerySet, prefetch_related_objects, from django.db.models.query_utils import FilteredRelation, Q
)
from django.db.models.query_utils import FilteredRelation
# Imports that would create circular imports if sorted # Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip from django.db.models.base import DEFERRED, Model # isort:skip
from django.db.models.fields.related import ( # isort:skip from django.db.models.fields.related import ( # isort:skip
ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ForeignKey, ForeignObject, OneToOneField, ManyToManyField,
ManyToOneRel, ManyToManyRel, OneToOneRel, ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel,
) )
@ -41,11 +40,12 @@ __all__ += [
'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT', 'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
'SET_NULL', 'ProtectedError', 'RestrictedError', 'SET_NULL', 'ProtectedError', 'RestrictedError',
'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F', 'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When', 'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value',
'ValueRange', 'When',
'Window', 'WindowFrame', 'Window', 'WindowFrame',
'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager', 'FileField', 'ImageField', 'OrderWrt', 'Lookup', 'Transform', 'Manager',
'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects', 'DEFERRED', 'Model',
'FilteredRelation', 'FilteredRelation',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField', 'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel', 'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel',
] ]

View File

@ -4,10 +4,9 @@ import inspect
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
from django.db import connection from django.db import NotSupportedError, connection
from django.db.models import fields from django.db.models import fields
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.utils import NotSupportedError
from django.utils.deconstruct import deconstructible from django.utils.deconstruct import deconstructible
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.hashable import make_hashable from django.utils.hashable import make_hashable

View File

@ -1,8 +1,8 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value from django.db.models.expressions import Func, Value
from django.db.models.fields import IntegerField from django.db.models.fields import IntegerField
from django.db.models.functions import Coalesce from django.db.models.functions import Coalesce
from django.db.models.lookups import Transform from django.db.models.lookups import Transform
from django.db.utils import NotSupportedError
class BytesToCharFieldConversionMixin: class BytesToCharFieldConversionMixin:

View File

@ -7,9 +7,7 @@ from django.apps import apps
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist
from django.db import connections from django.db import connections
from django.db.models import Manager from django.db.models import AutoField, Manager, OrderWrt
from django.db.models.fields import AutoField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.utils.datastructures import ImmutableList, OrderedSet from django.utils.datastructures import ImmutableList, OrderedSet
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -12,19 +12,17 @@ from itertools import chain
from django.conf import settings from django.conf import settings
from django.core import exceptions from django.core import exceptions
from django.db import ( from django.db import (
DJANGO_VERSION_PICKLE_KEY, IntegrityError, connections, router, DJANGO_VERSION_PICKLE_KEY, IntegrityError, NotSupportedError, connections,
transaction, router, transaction,
) )
from django.db.models import DateField, DateTimeField, sql from django.db.models import AutoField, DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.expressions import Case, Expression, F, Value, When from django.db.models.expressions import Case, Expression, F, Value, When
from django.db.models.fields import AutoField
from django.db.models.functions import Cast, Trunc from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
from django.db.models.utils import resolve_callables from django.db.models.utils import resolve_callables
from django.db.utils import NotSupportedError
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property, partition from django.utils.functional import cached_property, partition
from django.utils.version import get_version from django.utils.version import get_version

View File

@ -4,6 +4,7 @@ from functools import partial
from itertools import chain from itertools import chain
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
from django.db import DatabaseError, NotSupportedError
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import OrderBy, Random, RawSQL, Ref, Value from django.db.models.expressions import OrderBy, Random, RawSQL, Ref, Value
from django.db.models.functions import Cast from django.db.models.functions import Cast
@ -13,7 +14,6 @@ from django.db.models.sql.constants import (
) )
from django.db.models.sql.query import Query, get_order_dir from django.db.models.sql.query import Query, get_order_dir
from django.db.transaction import TransactionManagementError from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError, NotSupportedError
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.hashable import make_hashable from django.utils.hashable import make_hashable

View File

@ -140,7 +140,7 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None,
ignored = [] ignored = []
opts = model._meta opts = model._meta
# Avoid circular import # Avoid circular import
from django.db.models.fields import Field as ModelField from django.db.models import Field as ModelField
sortable_private_fields = [f for f in opts.private_fields if isinstance(f, ModelField)] sortable_private_fields = [f for f in opts.private_fields if isinstance(f, ModelField)]
for f in sorted(chain(opts.concrete_fields, sortable_private_fields, opts.many_to_many)): for f in sorted(chain(opts.concrete_fields, sortable_private_fields, opts.many_to_many)):
if not getattr(f, 'editable', False): if not getattr(f, 'editable', False):

View File

@ -1,6 +1,6 @@
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import InvalidPage, Paginator from django.core.paginator import InvalidPage, Paginator
from django.db.models.query import QuerySet from django.db.models import QuerySet
from django.http import Http404 from django.http import Http404
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.views.generic.base import ContextMixin, TemplateResponseMixin, View from django.views.generic.base import ContextMixin, TemplateResponseMixin, View

View File

@ -42,12 +42,12 @@ the field class we want the lookup to be available for. In this case, the lookup
makes sense on all ``Field`` subclasses, so we register it with ``Field`` makes sense on all ``Field`` subclasses, so we register it with ``Field``
directly:: directly::
from django.db.models.fields import Field from django.db.models import Field
Field.register_lookup(NotEqual) Field.register_lookup(NotEqual)
Lookup registration can also be done using a decorator pattern:: Lookup registration can also be done using a decorator pattern::
from django.db.models.fields import Field from django.db.models import Field
@Field.register_lookup @Field.register_lookup
class NotEqualLookup(Lookup): class NotEqualLookup(Lookup):

View File

@ -10,8 +10,7 @@ from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.contrib.messages.storage.cookie import CookieStorage from django.contrib.messages.storage.cookie import CookieStorage
from django.db import connection, models from django.db import connection, models
from django.db.models import F from django.db.models import F, Field, IntegerField
from django.db.models.fields import Field, IntegerField
from django.db.models.functions import Upper from django.db.models.functions import Upper
from django.db.models.lookups import Contains, Exact from django.db.models.lookups import Contains, Exact
from django.template import Context, Template, TemplateSyntaxError from django.template import Context, Template, TemplateSyntaxError

View File

@ -5,10 +5,9 @@ from decimal import Decimal
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection from django.db import connection
from django.db.models import ( from django.db.models import (
Avg, Count, DecimalField, DurationField, F, FloatField, Func, IntegerField, Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField, Func,
Max, Min, Sum, Value, IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
) )
from django.db.models.expressions import Case, Exists, OuterRef, Subquery, When
from django.db.models.functions import Coalesce from django.db.models.functions import Coalesce
from django.test import TestCase from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature from django.test.testcases import skipUnlessDBFeature

View File

@ -8,10 +8,9 @@ from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection from django.db import connection
from django.db.models import ( from django.db.models import (
Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum, Aggregate, Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev,
Value, Variance, When, Sum, Value, Variance, When,
) )
from django.db.models.aggregates import Aggregate
from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate

View File

@ -3,10 +3,9 @@ import unittest
from io import StringIO from io import StringIO
from unittest import mock from unittest import mock
from django.db import connection from django.db import DatabaseError, connection
from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.mysql.creation import DatabaseCreation from django.db.backends.mysql.creation import DatabaseCreation
from django.db.utils import DatabaseError
from django.test import SimpleTestCase from django.test import SimpleTestCase

View File

@ -2,9 +2,8 @@ import unittest
from io import StringIO from io import StringIO
from unittest import mock from unittest import mock
from django.db import connection from django.db import DatabaseError, connection
from django.db.backends.oracle.creation import DatabaseCreation from django.db.backends.oracle.creation import DatabaseCreation
from django.db.utils import DatabaseError
from django.test import TestCase from django.test import TestCase

View File

@ -1,8 +1,7 @@
import unittest import unittest
from django.db import connection from django.db import DatabaseError, connection
from django.db.models.fields import BooleanField, NullBooleanField from django.db.models import BooleanField, NullBooleanField
from django.db.utils import DatabaseError
from django.test import TransactionTestCase from django.test import TransactionTestCase
from ..models import Square from ..models import Square

View File

@ -3,9 +3,8 @@ from contextlib import contextmanager
from io import StringIO from io import StringIO
from unittest import mock from unittest import mock
from django.db import connection from django.db import DatabaseError, connection
from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.utils import DatabaseError
from django.test import SimpleTestCase from django.test import SimpleTestCase
try: try:

View File

@ -8,11 +8,9 @@ from sqlite3 import dbapi2
from unittest import mock from unittest import mock
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import ConnectionHandler, connection, transaction from django.db import NotSupportedError, connection, transaction
from django.db.models import Avg, StdDev, Sum, Variance from django.db.models import Aggregate, Avg, CharField, StdDev, Sum, Variance
from django.db.models.aggregates import Aggregate from django.db.utils import ConnectionHandler
from django.db.models.fields import CharField
from django.db.utils import NotSupportedError
from django.test import ( from django.test import (
TestCase, TransactionTestCase, override_settings, skipIfDBFeature, TestCase, TransactionTestCase, override_settings, skipIfDBFeature,
) )

View File

@ -1,11 +1,10 @@
"""Tests for django.db.backends.utils""" """Tests for django.db.backends.utils"""
from decimal import Decimal, Rounded from decimal import Decimal, Rounded
from django.db import connection from django.db import NotSupportedError, connection
from django.db.backends.utils import ( from django.db.backends.utils import (
format_number, split_identifier, truncate_name, format_number, split_identifier, truncate_name,
) )
from django.db.utils import NotSupportedError
from django.test import ( from django.test import (
SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature,
) )

View File

@ -5,7 +5,7 @@ from unittest import mock
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
from django.db.models.manager import BaseManager from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet, QuerySet from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
from django.test import ( from django.test import (
SimpleTestCase, TestCase, TransactionTestCase, skipUnlessDBFeature, SimpleTestCase, TestCase, TransactionTestCase, skipUnlessDBFeature,
) )
@ -316,7 +316,7 @@ class ModelTest(TestCase):
# A hacky test for custom QuerySet subclass - refs #17271 # A hacky test for custom QuerySet subclass - refs #17271
Article.objects.create(headline='foo', pub_date=datetime.now()) Article.objects.create(headline='foo', pub_date=datetime.now())
class CustomQuerySet(QuerySet): class CustomQuerySet(models.QuerySet):
def do_something(self): def do_something(self):
return 'did something' return 'did something'
@ -607,7 +607,7 @@ class ManagerTest(SimpleTestCase):
`Manager` will need to be added to `ManagerTest.QUERYSET_PROXY_METHODS`. `Manager` will need to be added to `ManagerTest.QUERYSET_PROXY_METHODS`.
""" """
self.assertEqual( self.assertEqual(
sorted(BaseManager._get_queryset_methods(QuerySet)), sorted(BaseManager._get_queryset_methods(models.QuerySet)),
sorted(self.QUERYSET_PROXY_METHODS), sorted(self.QUERYSET_PROXY_METHODS),
) )
@ -640,7 +640,7 @@ class SelectOnSaveTests(TestCase):
orig_class = Article._base_manager._queryset_class orig_class = Article._base_manager._queryset_class
class FakeQuerySet(QuerySet): class FakeQuerySet(models.QuerySet):
# Make sure the _update method below is in fact called. # Make sure the _update method below is in fact called.
called = False called = False

View File

@ -3,8 +3,6 @@ import decimal
import unittest import unittest
from django.db import connection, models from django.db import connection, models
from django.db.models import Avg
from django.db.models.expressions import Value
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.test import ( from django.test import (
TestCase, ignore_warnings, override_settings, skipUnlessDBFeature, TestCase, ignore_warnings, override_settings, skipUnlessDBFeature,
@ -19,7 +17,7 @@ class CastTests(TestCase):
Author.objects.create(name='Bob', age=1, alias='1') Author.objects.create(name='Bob', age=1, alias='1')
def test_cast_from_value(self): def test_cast_from_value(self):
numbers = Author.objects.annotate(cast_integer=Cast(Value('0'), models.IntegerField())) numbers = Author.objects.annotate(cast_integer=Cast(models.Value('0'), models.IntegerField()))
self.assertEqual(numbers.get().cast_integer, 0) self.assertEqual(numbers.get().cast_integer, 0)
def test_cast_from_field(self): def test_cast_from_field(self):
@ -127,7 +125,7 @@ class CastTests(TestCase):
The SQL for the Cast expression is wrapped with parentheses in case The SQL for the Cast expression is wrapped with parentheses in case
it's a complex expression. it's a complex expression.
""" """
list(Author.objects.annotate(cast_float=Cast(Avg('age'), models.FloatField()))) list(Author.objects.annotate(cast_float=Cast(models.Avg('age'), models.FloatField())))
self.assertIn('(AVG("db_functions_author"."age"))::double precision', connection.queries[-1]['sql']) self.assertIn('(AVG("db_functions_author"."age"))::double precision', connection.queries[-1]['sql'])
def test_cast_to_text_field(self): def test_cast_to_text_field(self):

View File

@ -1,9 +1,8 @@
import unittest import unittest
from django.db import connection from django.db import NotSupportedError, connection
from django.db.models import CharField from django.db.models import CharField
from django.db.models.functions import SHA224 from django.db.models.functions import SHA224
from django.db.utils import NotSupportedError
from django.test import TestCase from django.test import TestCase
from django.test.utils import register_lookup from django.test.utils import register_lookup

View File

@ -2,8 +2,8 @@
import unittest import unittest
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, connection from django.db import DEFAULT_DB_ALIAS, ProgrammingError, connection
from django.db.utils import ConnectionHandler, ProgrammingError, load_backend from django.db.utils import ConnectionHandler, load_backend
from django.test import SimpleTestCase, TestCase from django.test import SimpleTestCase, TestCase

View File

@ -1,9 +1,8 @@
from math import ceil from math import ceil
from django.db import connection, models from django.db import connection, models
from django.db.models.deletion import ( from django.db.models import ProtectedError, RestrictedError
Collector, ProtectedError, RestrictedError, from django.db.models.deletion import Collector
)
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature

View File

@ -1,5 +1,4 @@
from django.db.models.aggregates import Sum from django.db.models import F, Sum
from django.db.models.expressions import F
from django.test import TestCase from django.test import TestCase
from .models import Company, Employee from .models import Company, Employee

View File

@ -6,16 +6,14 @@ from copy import deepcopy
from unittest import mock from unittest import mock
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models from django.db import DatabaseError, connection
from django.db.models import CharField, Q, TimeField, UUIDField from django.db.models import (
from django.db.models.aggregates import ( Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField,
Avg, Count, Max, Min, StdDev, Sum, Variance, DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F,
) Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev,
from django.db.models.expressions import ( Subquery, Sum, TimeField, UUIDField, Value, Variance, When,
Case, Col, Combinable, Exists, Expression, ExpressionList,
ExpressionWrapper, F, Func, OrderBy, OuterRef, Random, RawSQL, Ref,
Subquery, Value, When,
) )
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
from django.db.models.functions import ( from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper, Coalesce, Concat, Length, Lower, Substr, Upper,
) )
@ -57,7 +55,7 @@ class BasicExpressionsTests(TestCase):
).values('num_employees', 'salaries').aggregate( ).values('num_employees', 'salaries').aggregate(
result=Sum( result=Sum(
F('salaries') + F('num_employees'), F('salaries') + F('num_employees'),
output_field=models.IntegerField() output_field=IntegerField()
), ),
) )
self.assertEqual(companies['result'], 2395) self.assertEqual(companies['result'], 2395)
@ -79,7 +77,7 @@ class BasicExpressionsTests(TestCase):
def test_filtering_on_annotate_that_uses_q(self): def test_filtering_on_annotate_that_uses_q(self):
self.assertEqual( self.assertEqual(
Company.objects.annotate( Company.objects.annotate(
num_employees_check=ExpressionWrapper(Q(num_employees__gt=3), output_field=models.BooleanField()) num_employees_check=ExpressionWrapper(Q(num_employees__gt=3), output_field=BooleanField())
).filter(num_employees_check=True).count(), ).filter(num_employees_check=True).count(),
2, 2,
) )
@ -87,7 +85,7 @@ class BasicExpressionsTests(TestCase):
def test_filtering_on_q_that_is_boolean(self): def test_filtering_on_q_that_is_boolean(self):
self.assertEqual( self.assertEqual(
Company.objects.filter( Company.objects.filter(
ExpressionWrapper(Q(num_employees__gt=3), output_field=models.BooleanField()) ExpressionWrapper(Q(num_employees__gt=3), output_field=BooleanField())
).count(), ).count(),
2, 2,
) )
@ -95,7 +93,7 @@ class BasicExpressionsTests(TestCase):
def test_filtering_on_rawsql_that_is_boolean(self): def test_filtering_on_rawsql_that_is_boolean(self):
self.assertEqual( self.assertEqual(
Company.objects.filter( Company.objects.filter(
RawSQL('num_employees > %s', (3,), output_field=models.BooleanField()), RawSQL('num_employees > %s', (3,), output_field=BooleanField()),
).count(), ).count(),
2, 2,
) )
@ -438,7 +436,7 @@ class BasicExpressionsTests(TestCase):
def test_exist_single_field_output_field(self): def test_exist_single_field_output_field(self):
queryset = Company.objects.values('pk') queryset = Company.objects.values('pk')
self.assertIsInstance(Exists(queryset).output_field, models.BooleanField) self.assertIsInstance(Exists(queryset).output_field, BooleanField)
def test_subquery(self): def test_subquery(self):
Company.objects.filter(name='Example Inc.').update( Company.objects.filter(name='Example Inc.').update(
@ -452,8 +450,8 @@ class BasicExpressionsTests(TestCase):
is_ceo_of_small_company=Exists(Company.objects.filter(num_employees__lt=200, ceo=OuterRef('pk'))), is_ceo_of_small_company=Exists(Company.objects.filter(num_employees__lt=200, ceo=OuterRef('pk'))),
is_ceo_small_2=~~Exists(Company.objects.filter(num_employees__lt=200, ceo=OuterRef('pk'))), is_ceo_small_2=~~Exists(Company.objects.filter(num_employees__lt=200, ceo=OuterRef('pk'))),
largest_company=Subquery(Company.objects.order_by('-num_employees').filter( largest_company=Subquery(Company.objects.order_by('-num_employees').filter(
models.Q(ceo=OuterRef('pk')) | models.Q(point_of_contact=OuterRef('pk')) Q(ceo=OuterRef('pk')) | Q(point_of_contact=OuterRef('pk'))
).values('name')[:1], output_field=models.CharField()) ).values('name')[:1], output_field=CharField())
).values( ).values(
'firstname', 'firstname',
'is_point_of_contact', 'is_point_of_contact',
@ -533,7 +531,7 @@ class BasicExpressionsTests(TestCase):
contrived = Employee.objects.annotate( contrived = Employee.objects.annotate(
is_point_of_contact=Subquery( is_point_of_contact=Subquery(
outer.filter(pk=OuterRef('pk')).values('is_point_of_contact'), outer.filter(pk=OuterRef('pk')).values('is_point_of_contact'),
output_field=models.BooleanField(), output_field=BooleanField(),
), ),
) )
self.assertCountEqual(contrived.values_list(), outer.values_list()) self.assertCountEqual(contrived.values_list(), outer.values_list())
@ -564,7 +562,7 @@ class BasicExpressionsTests(TestCase):
]) ])
inner = Time.objects.filter(time=OuterRef(OuterRef('time')), pk=OuterRef('start')).values('time') inner = Time.objects.filter(time=OuterRef(OuterRef('time')), pk=OuterRef('start')).values('time')
middle = SimulationRun.objects.annotate(other=Subquery(inner)).values('other')[:1] middle = SimulationRun.objects.annotate(other=Subquery(inner)).values('other')[:1]
outer = Time.objects.annotate(other=Subquery(middle, output_field=models.TimeField())) outer = Time.objects.annotate(other=Subquery(middle, output_field=TimeField()))
# This is a contrived example. It exercises the double OuterRef form. # This is a contrived example. It exercises the double OuterRef form.
self.assertCountEqual(outer, [first, second, third]) self.assertCountEqual(outer, [first, second, third])
@ -574,7 +572,7 @@ class BasicExpressionsTests(TestCase):
SimulationRun.objects.create(start=first, end=second, midpoint='12:00') SimulationRun.objects.create(start=first, end=second, midpoint='12:00')
inner = SimulationRun.objects.filter(start=OuterRef(OuterRef('pk'))).values('start') inner = SimulationRun.objects.filter(start=OuterRef(OuterRef('pk'))).values('start')
middle = Time.objects.annotate(other=Subquery(inner)).values('other')[:1] middle = Time.objects.annotate(other=Subquery(inner)).values('other')[:1]
outer = Time.objects.annotate(other=Subquery(middle, output_field=models.IntegerField())) outer = Time.objects.annotate(other=Subquery(middle, output_field=IntegerField()))
# This exercises the double OuterRef form with AutoField as pk. # This exercises the double OuterRef form with AutoField as pk.
self.assertCountEqual(outer, [first, second]) self.assertCountEqual(outer, [first, second])
@ -582,7 +580,7 @@ class BasicExpressionsTests(TestCase):
Company.objects.filter(num_employees__lt=50).update(ceo=Employee.objects.get(firstname='Frank')) Company.objects.filter(num_employees__lt=50).update(ceo=Employee.objects.get(firstname='Frank'))
inner = Company.objects.filter( inner = Company.objects.filter(
ceo=OuterRef('pk') ceo=OuterRef('pk')
).values('ceo').annotate(total_employees=models.Sum('num_employees')).values('total_employees') ).values('ceo').annotate(total_employees=Sum('num_employees')).values('total_employees')
outer = Employee.objects.annotate(total_employees=Subquery(inner)).filter(salary__lte=Subquery(inner)) outer = Employee.objects.annotate(total_employees=Subquery(inner)).filter(salary__lte=Subquery(inner))
self.assertSequenceEqual( self.assertSequenceEqual(
outer.order_by('-total_employees').values('salary', 'total_employees'), outer.order_by('-total_employees').values('salary', 'total_employees'),
@ -632,7 +630,7 @@ class BasicExpressionsTests(TestCase):
def test_explicit_output_field(self): def test_explicit_output_field(self):
class FuncA(Func): class FuncA(Func):
output_field = models.CharField() output_field = CharField()
class FuncB(Func): class FuncB(Func):
pass pass
@ -656,13 +654,13 @@ class BasicExpressionsTests(TestCase):
Company.objects.annotate( Company.objects.annotate(
salary_raise=OuterRef('num_employees') + F('num_employees'), salary_raise=OuterRef('num_employees') + F('num_employees'),
).order_by('-salary_raise').values('salary_raise')[:1], ).order_by('-salary_raise').values('salary_raise')[:1],
output_field=models.IntegerField(), output_field=IntegerField(),
), ),
).get(pk=self.gmbh.pk) ).get(pk=self.gmbh.pk)
self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332) self.assertEqual(gmbh_salary.max_ceo_salary_raise, 2332)
def test_pickle_expression(self): def test_pickle_expression(self):
expr = Value(1, output_field=models.IntegerField()) expr = Value(1, output_field=IntegerField())
expr.convert_value # populate cached property expr.convert_value # populate cached property
self.assertEqual(pickle.loads(pickle.dumps(expr)), expr) self.assertEqual(pickle.loads(pickle.dumps(expr)), expr)
@ -697,7 +695,7 @@ class BasicExpressionsTests(TestCase):
When(Exists(is_ceo), then=True), When(Exists(is_ceo), then=True),
When(Exists(is_poc), then=True), When(Exists(is_poc), then=True),
default=False, default=False,
output_field=models.BooleanField(), output_field=BooleanField(),
), ),
) )
self.assertSequenceEqual(qs, [self.example_inc.ceo, self.foobar_ltd.ceo, self.max]) self.assertSequenceEqual(qs, [self.example_inc.ceo, self.foobar_ltd.ceo, self.max])
@ -986,18 +984,18 @@ class SimpleExpressionTests(SimpleTestCase):
def test_equal(self): def test_equal(self):
self.assertEqual(Expression(), Expression()) self.assertEqual(Expression(), Expression())
self.assertEqual( self.assertEqual(
Expression(models.IntegerField()), Expression(IntegerField()),
Expression(output_field=models.IntegerField()) Expression(output_field=IntegerField())
) )
self.assertEqual(Expression(models.IntegerField()), mock.ANY) self.assertEqual(Expression(IntegerField()), mock.ANY)
self.assertNotEqual( self.assertNotEqual(
Expression(models.IntegerField()), Expression(IntegerField()),
Expression(models.CharField()) Expression(CharField())
) )
class TestModel(models.Model): class TestModel(Model):
field = models.IntegerField() field = IntegerField()
other_field = models.IntegerField() other_field = IntegerField()
self.assertNotEqual( self.assertNotEqual(
Expression(TestModel._meta.get_field('field')), Expression(TestModel._meta.get_field('field')),
@ -1007,17 +1005,17 @@ class SimpleExpressionTests(SimpleTestCase):
def test_hash(self): def test_hash(self):
self.assertEqual(hash(Expression()), hash(Expression())) self.assertEqual(hash(Expression()), hash(Expression()))
self.assertEqual( self.assertEqual(
hash(Expression(models.IntegerField())), hash(Expression(IntegerField())),
hash(Expression(output_field=models.IntegerField())) hash(Expression(output_field=IntegerField()))
) )
self.assertNotEqual( self.assertNotEqual(
hash(Expression(models.IntegerField())), hash(Expression(IntegerField())),
hash(Expression(models.CharField())), hash(Expression(CharField())),
) )
class TestModel(models.Model): class TestModel(Model):
field = models.IntegerField() field = IntegerField()
other_field = models.IntegerField() other_field = IntegerField()
self.assertNotEqual( self.assertNotEqual(
hash(Expression(TestModel._meta.get_field('field'))), hash(Expression(TestModel._meta.get_field('field'))),
@ -1392,8 +1390,8 @@ class FTimeDeltaTests(TestCase):
self.assertEqual(delta_math, ['e4']) self.assertEqual(delta_math, ['e4'])
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper( queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
F('start') + Value(None, output_field=models.DurationField()), F('start') + Value(None, output_field=DurationField()),
output_field=models.DateTimeField(), output_field=DateTimeField(),
)) ))
self.assertIsNone(queryset.first().shifted) self.assertIsNone(queryset.first().shifted)
@ -1401,7 +1399,7 @@ class FTimeDeltaTests(TestCase):
def test_date_subtraction(self): def test_date_subtraction(self):
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
completion_duration=ExpressionWrapper( completion_duration=ExpressionWrapper(
F('completed') - F('assigned'), output_field=models.DurationField() F('completed') - F('assigned'), output_field=DurationField()
) )
) )
@ -1415,14 +1413,14 @@ class FTimeDeltaTests(TestCase):
self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'}) self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'})
queryset = Experiment.objects.annotate(difference=ExpressionWrapper( queryset = Experiment.objects.annotate(difference=ExpressionWrapper(
F('completed') - Value(None, output_field=models.DateField()), F('completed') - Value(None, output_field=DateField()),
output_field=models.DurationField(), output_field=DurationField(),
)) ))
self.assertIsNone(queryset.first().difference) self.assertIsNone(queryset.first().difference)
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper( queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
F('completed') - Value(None, output_field=models.DurationField()), F('completed') - Value(None, output_field=DurationField()),
output_field=models.DateField(), output_field=DateField(),
)) ))
self.assertIsNone(queryset.first().shifted) self.assertIsNone(queryset.first().shifted)
@ -1431,7 +1429,7 @@ class FTimeDeltaTests(TestCase):
subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('completed') subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('completed')
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
difference=ExpressionWrapper( difference=ExpressionWrapper(
subquery - F('completed'), output_field=models.DurationField(), subquery - F('completed'), output_field=DurationField(),
), ),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertTrue(queryset.exists()) self.assertTrue(queryset.exists())
@ -1441,8 +1439,8 @@ class FTimeDeltaTests(TestCase):
Time.objects.create(time=datetime.time(12, 30, 15, 2345)) Time.objects.create(time=datetime.time(12, 30, 15, 2345))
queryset = Time.objects.annotate( queryset = Time.objects.annotate(
difference=ExpressionWrapper( difference=ExpressionWrapper(
F('time') - Value(datetime.time(11, 15, 0), output_field=models.TimeField()), F('time') - Value(datetime.time(11, 15, 0), output_field=TimeField()),
output_field=models.DurationField(), output_field=DurationField(),
) )
) )
self.assertEqual( self.assertEqual(
@ -1451,14 +1449,14 @@ class FTimeDeltaTests(TestCase):
) )
queryset = Time.objects.annotate(difference=ExpressionWrapper( queryset = Time.objects.annotate(difference=ExpressionWrapper(
F('time') - Value(None, output_field=models.TimeField()), F('time') - Value(None, output_field=TimeField()),
output_field=models.DurationField(), output_field=DurationField(),
)) ))
self.assertIsNone(queryset.first().difference) self.assertIsNone(queryset.first().difference)
queryset = Time.objects.annotate(shifted=ExpressionWrapper( queryset = Time.objects.annotate(shifted=ExpressionWrapper(
F('time') - Value(None, output_field=models.DurationField()), F('time') - Value(None, output_field=DurationField()),
output_field=models.TimeField(), output_field=TimeField(),
)) ))
self.assertIsNone(queryset.first().shifted) self.assertIsNone(queryset.first().shifted)
@ -1468,7 +1466,7 @@ class FTimeDeltaTests(TestCase):
subquery = Time.objects.filter(pk=OuterRef('pk')).values('time') subquery = Time.objects.filter(pk=OuterRef('pk')).values('time')
queryset = Time.objects.annotate( queryset = Time.objects.annotate(
difference=ExpressionWrapper( difference=ExpressionWrapper(
subquery - F('time'), output_field=models.DurationField(), subquery - F('time'), output_field=DurationField(),
), ),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertTrue(queryset.exists()) self.assertTrue(queryset.exists())
@ -1486,14 +1484,14 @@ class FTimeDeltaTests(TestCase):
self.assertEqual(over_estimate, ['e4']) self.assertEqual(over_estimate, ['e4'])
queryset = Experiment.objects.annotate(difference=ExpressionWrapper( queryset = Experiment.objects.annotate(difference=ExpressionWrapper(
F('start') - Value(None, output_field=models.DateTimeField()), F('start') - Value(None, output_field=DateTimeField()),
output_field=models.DurationField(), output_field=DurationField(),
)) ))
self.assertIsNone(queryset.first().difference) self.assertIsNone(queryset.first().difference)
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper( queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
F('start') - Value(None, output_field=models.DurationField()), F('start') - Value(None, output_field=DurationField()),
output_field=models.DateTimeField(), output_field=DateTimeField(),
)) ))
self.assertIsNone(queryset.first().shifted) self.assertIsNone(queryset.first().shifted)
@ -1502,7 +1500,7 @@ class FTimeDeltaTests(TestCase):
subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('start') subquery = Experiment.objects.filter(pk=OuterRef('pk')).values('start')
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
difference=ExpressionWrapper( difference=ExpressionWrapper(
subquery - F('start'), output_field=models.DurationField(), subquery - F('start'), output_field=DurationField(),
), ),
).filter(difference=datetime.timedelta()) ).filter(difference=datetime.timedelta())
self.assertTrue(queryset.exists()) self.assertTrue(queryset.exists())
@ -1512,7 +1510,7 @@ class FTimeDeltaTests(TestCase):
delta = datetime.timedelta(microseconds=8999999999999999) delta = datetime.timedelta(microseconds=8999999999999999)
Experiment.objects.update(end=F('start') + delta) Experiment.objects.update(end=F('start') + delta)
qs = Experiment.objects.annotate( qs = Experiment.objects.annotate(
delta=ExpressionWrapper(F('end') - F('start'), output_field=models.DurationField()) delta=ExpressionWrapper(F('end') - F('start'), output_field=DurationField())
) )
for e in qs: for e in qs:
self.assertEqual(e.delta, delta) self.assertEqual(e.delta, delta)
@ -1530,14 +1528,14 @@ class FTimeDeltaTests(TestCase):
delta = datetime.timedelta(microseconds=8999999999999999) delta = datetime.timedelta(microseconds=8999999999999999)
qs = Experiment.objects.annotate(dt=ExpressionWrapper( qs = Experiment.objects.annotate(dt=ExpressionWrapper(
F('start') + delta, F('start') + delta,
output_field=models.DateTimeField(), output_field=DateTimeField(),
)) ))
for e in qs: for e in qs:
self.assertEqual(e.dt, e.start + delta) self.assertEqual(e.dt, e.start + delta)
def test_date_minus_duration(self): def test_date_minus_duration(self):
more_than_4_days = Experiment.objects.filter( more_than_4_days = Experiment.objects.filter(
assigned__lt=F('completed') - Value(datetime.timedelta(days=4), output_field=models.DurationField()) assigned__lt=F('completed') - Value(datetime.timedelta(days=4), output_field=DurationField())
) )
self.assertQuerysetEqual(more_than_4_days, ['e3', 'e4', 'e5'], lambda e: e.name) self.assertQuerysetEqual(more_than_4_days, ['e3', 'e4', 'e5'], lambda e: e.name)
@ -1661,7 +1659,7 @@ class ReprTests(SimpleTestCase):
self.assertEqual(repr(F('published')), "F(published)") self.assertEqual(repr(F('published')), "F(published)")
self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>") self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>")
self.assertEqual( self.assertEqual(
repr(ExpressionWrapper(F('cost') + F('tax'), models.IntegerField())), repr(ExpressionWrapper(F('cost') + F('tax'), IntegerField())),
"ExpressionWrapper(F(cost) + F(tax))" "ExpressionWrapper(F(cost) + F(tax))"
) )
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)") self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")

View File

@ -5,9 +5,11 @@ from operator import attrgetter, itemgetter
from uuid import UUID from uuid import UUID
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import models from django.db.models import (
from django.db.models import F, Max, Min, Q, Sum, Value BinaryField, Case, CharField, Count, DurationField, F,
from django.db.models.expressions import Case, When GenericIPAddressField, IntegerField, Max, Min, Q, Sum, TextField,
TimeField, UUIDField, Value, When,
)
from django.test import SimpleTestCase, TestCase from django.test import SimpleTestCase, TestCase
from .models import CaseTestModel, Client, FKCaseTestModel, O2OCaseTestModel from .models import CaseTestModel, Client, FKCaseTestModel, O2OCaseTestModel
@ -57,7 +59,7 @@ class CaseExpressionTests(TestCase):
# GROUP BY on Oracle fails with TextField/BinaryField; see #24096. # GROUP BY on Oracle fails with TextField/BinaryField; see #24096.
cls.non_lob_fields = [ cls.non_lob_fields = [
f.name for f in CaseTestModel._meta.get_fields() f.name for f in CaseTestModel._meta.get_fields()
if not (f.is_relation and f.auto_created) and not isinstance(f, (models.BinaryField, models.TextField)) if not (f.is_relation and f.auto_created) and not isinstance(f, (BinaryField, TextField))
] ]
def test_annotate(self): def test_annotate(self):
@ -66,7 +68,7 @@ class CaseExpressionTests(TestCase):
When(integer=1, then=Value('one')), When(integer=1, then=Value('one')),
When(integer=2, then=Value('two')), When(integer=2, then=Value('two')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 'one'), (2, 'two'), (3, 'other'), (2, 'two'), (3, 'other'), (3, 'other'), (4, 'other')], [(1, 'one'), (2, 'two'), (3, 'other'), (2, 'two'), (3, 'other'), (3, 'other'), (4, 'other')],
transform=attrgetter('integer', 'test') transform=attrgetter('integer', 'test')
@ -77,7 +79,7 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.annotate(test=Case( CaseTestModel.objects.annotate(test=Case(
When(integer=1, then=1), When(integer=1, then=1),
When(integer=2, then=2), When(integer=2, then=2),
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)],
transform=attrgetter('integer', 'test') transform=attrgetter('integer', 'test')
@ -99,7 +101,7 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.annotate(f_test=Case( CaseTestModel.objects.annotate(f_test=Case(
When(integer2=F('integer'), then=Value('equal')), When(integer2=F('integer'), then=Value('equal')),
When(integer2=F('integer') + 1, then=Value('+1')), When(integer2=F('integer') + 1, then=Value('+1')),
output_field=models.CharField(), output_field=CharField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')],
transform=attrgetter('integer', 'f_test') transform=attrgetter('integer', 'f_test')
@ -133,7 +135,7 @@ class CaseExpressionTests(TestCase):
When(integer2=F('o2o_rel__integer'), then=Value('equal')), When(integer2=F('o2o_rel__integer'), then=Value('equal')),
When(integer2=F('o2o_rel__integer') + 1, then=Value('+1')), When(integer2=F('o2o_rel__integer') + 1, then=Value('+1')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, 'other')], [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, 'other')],
transform=attrgetter('integer', 'join_test') transform=attrgetter('integer', 'join_test')
@ -146,7 +148,7 @@ class CaseExpressionTests(TestCase):
When(o2o_rel__integer=2, then=Value('two')), When(o2o_rel__integer=2, then=Value('two')),
When(o2o_rel__integer=3, then=Value('three')), When(o2o_rel__integer=3, then=Value('three')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 'one'), (2, 'two'), (3, 'three'), (2, 'two'), (3, 'three'), (3, 'three'), (4, 'one')], [(1, 'one'), (2, 'two'), (3, 'three'), (2, 'two'), (3, 'three'), (3, 'three'), (4, 'one')],
transform=attrgetter('integer', 'join_test') transform=attrgetter('integer', 'join_test')
@ -176,7 +178,7 @@ class CaseExpressionTests(TestCase):
f_test=Case( f_test=Case(
When(integer2=F('integer'), then=Value('equal')), When(integer2=F('integer'), then=Value('equal')),
When(integer2=F('f_plus_1'), then=Value('+1')), When(integer2=F('f_plus_1'), then=Value('+1')),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')], [(1, 'equal'), (2, '+1'), (3, '+1'), (2, 'equal'), (3, '+1'), (3, 'equal'), (4, '+1')],
@ -193,7 +195,7 @@ class CaseExpressionTests(TestCase):
When(f_minus_2=0, then=Value('zero')), When(f_minus_2=0, then=Value('zero')),
When(f_minus_2=1, then=Value('one')), When(f_minus_2=1, then=Value('one')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[(1, 'negative one'), (2, 'zero'), (3, 'one'), (2, 'zero'), (3, 'one'), (3, 'one'), (4, 'other')], [(1, 'negative one'), (2, 'zero'), (3, 'one'), (2, 'zero'), (3, 'one'), (3, 'one'), (4, 'other')],
@ -224,7 +226,7 @@ class CaseExpressionTests(TestCase):
test=Case( test=Case(
When(integer2=F('min'), then=Value('min')), When(integer2=F('min'), then=Value('min')),
When(integer2=F('max'), then=Value('max')), When(integer2=F('max'), then=Value('max')),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[(1, 1, 'min'), (2, 3, 'max'), (3, 4, 'max'), (2, 2, 'min'), (3, 4, 'max'), (3, 3, 'min'), (4, 5, 'min')], [(1, 1, 'min'), (2, 3, 'max'), (3, 4, 'max'), (2, 2, 'min'), (3, 4, 'max'), (3, 3, 'min'), (4, 5, 'min')],
@ -240,7 +242,7 @@ class CaseExpressionTests(TestCase):
When(max=3, then=Value('max = 3')), When(max=3, then=Value('max = 3')),
When(max=4, then=Value('max = 4')), When(max=4, then=Value('max = 4')),
default=Value(''), default=Value(''),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[(1, 1, ''), (2, 3, 'max = 3'), (3, 4, 'max = 4'), (2, 3, 'max = 3'), [(1, 1, ''), (2, 3, 'max = 3'), (3, 4, 'max = 4'), (2, 3, 'max = 3'),
@ -254,7 +256,7 @@ class CaseExpressionTests(TestCase):
When(integer=1, then=Value('one')), When(integer=1, then=Value('one')),
When(integer=2, then=Value('two')), When(integer=2, then=Value('two')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
)).exclude(test='other').order_by('pk'), )).exclude(test='other').order_by('pk'),
[(1, 'one'), (2, 'two'), (2, 'two')], [(1, 'one'), (2, 'two'), (2, 'two')],
transform=attrgetter('integer', 'test') transform=attrgetter('integer', 'test')
@ -267,7 +269,7 @@ class CaseExpressionTests(TestCase):
When(integer=2, then=Value('two')), When(integer=2, then=Value('two')),
When(integer=3, then=Value('three')), When(integer=3, then=Value('three')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
)).order_by('test').values_list('integer', flat=True)), )).order_by('test').values_list('integer', flat=True)),
[1, 4, 3, 3, 3, 2, 2] [1, 4, 3, 3, 3, 2, 2]
) )
@ -276,7 +278,7 @@ class CaseExpressionTests(TestCase):
objects = CaseTestModel.objects.annotate( objects = CaseTestModel.objects.annotate(
selected=Case( selected=Case(
When(pk__in=[], then=Value('selected')), When(pk__in=[], then=Value('selected')),
default=Value('not selected'), output_field=models.CharField() default=Value('not selected'), output_field=CharField()
) )
) )
self.assertEqual(len(objects), CaseTestModel.objects.count()) self.assertEqual(len(objects), CaseTestModel.objects.count())
@ -289,7 +291,7 @@ class CaseExpressionTests(TestCase):
When(integer=1, then=2), When(integer=1, then=2),
When(integer=2, then=1), When(integer=2, then=1),
default=3, default=3,
output_field=models.IntegerField(), output_field=IntegerField(),
) + 1, ) + 1,
).order_by('pk'), ).order_by('pk'),
[(1, 3), (2, 2), (3, 4), (2, 2), (3, 4), (3, 4), (4, 4)], [(1, 3), (2, 2), (3, 4), (2, 2), (3, 4), (3, 4), (4, 4)],
@ -303,7 +305,7 @@ class CaseExpressionTests(TestCase):
test=Case( test=Case(
When(integer=F('integer2'), then='pk'), When(integer=F('integer2'), then='pk'),
When(integer=4, then='pk'), When(integer=4, then='pk'),
output_field=models.IntegerField(), output_field=IntegerField(),
), ),
).values('test')).order_by('pk'), ).values('test')).order_by('pk'),
[(1, 1), (2, 2), (3, 3), (4, 5)], [(1, 1), (2, 2), (3, 3), (4, 5)],
@ -314,7 +316,7 @@ class CaseExpressionTests(TestCase):
SOME_CASE = Case( SOME_CASE = Case(
When(pk=0, then=Value('0')), When(pk=0, then=Value('0')),
default=Value('1'), default=Value('1'),
output_field=models.CharField(), output_field=CharField(),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.annotate(somecase=SOME_CASE).order_by('pk'), CaseTestModel.objects.annotate(somecase=SOME_CASE).order_by('pk'),
@ -325,21 +327,21 @@ class CaseExpressionTests(TestCase):
def test_aggregate(self): def test_aggregate(self):
self.assertEqual( self.assertEqual(
CaseTestModel.objects.aggregate( CaseTestModel.objects.aggregate(
one=models.Sum(Case( one=Sum(Case(
When(integer=1, then=1), When(integer=1, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
two=models.Sum(Case( two=Sum(Case(
When(integer=2, then=1), When(integer=2, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
three=models.Sum(Case( three=Sum(Case(
When(integer=3, then=1), When(integer=3, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
four=models.Sum(Case( four=Sum(Case(
When(integer=4, then=1), When(integer=4, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
), ),
{'one': 1, 'two': 2, 'three': 3, 'four': 1} {'one': 1, 'two': 2, 'three': 3, 'four': 1}
@ -348,9 +350,9 @@ class CaseExpressionTests(TestCase):
def test_aggregate_with_expression_as_value(self): def test_aggregate_with_expression_as_value(self):
self.assertEqual( self.assertEqual(
CaseTestModel.objects.aggregate( CaseTestModel.objects.aggregate(
one=models.Sum(Case(When(integer=1, then='integer'))), one=Sum(Case(When(integer=1, then='integer'))),
two=models.Sum(Case(When(integer=2, then=F('integer') - 1))), two=Sum(Case(When(integer=2, then=F('integer') - 1))),
three=models.Sum(Case(When(integer=3, then=F('integer') + 1))), three=Sum(Case(When(integer=3, then=F('integer') + 1))),
), ),
{'one': 1, 'two': 2, 'three': 12} {'one': 1, 'two': 2, 'three': 12}
) )
@ -358,13 +360,13 @@ class CaseExpressionTests(TestCase):
def test_aggregate_with_expression_as_condition(self): def test_aggregate_with_expression_as_condition(self):
self.assertEqual( self.assertEqual(
CaseTestModel.objects.aggregate( CaseTestModel.objects.aggregate(
equal=models.Sum(Case( equal=Sum(Case(
When(integer2=F('integer'), then=1), When(integer2=F('integer'), then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
plus_one=models.Sum(Case( plus_one=Sum(Case(
When(integer2=F('integer') + 1, then=1), When(integer2=F('integer') + 1, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
), ),
{'equal': 3, 'plus_one': 4} {'equal': 3, 'plus_one': 4}
@ -376,7 +378,7 @@ class CaseExpressionTests(TestCase):
When(integer=2, then=3), When(integer=2, then=3),
When(integer=3, then=4), When(integer=3, then=4),
default=1, default=1,
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)], [(1, 1), (2, 3), (3, 4), (3, 4)],
transform=attrgetter('integer', 'integer2') transform=attrgetter('integer', 'integer2')
@ -387,7 +389,7 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.filter(integer2=Case( CaseTestModel.objects.filter(integer2=Case(
When(integer=2, then=3), When(integer=2, then=3),
When(integer=3, then=4), When(integer=3, then=4),
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(2, 3), (3, 4), (3, 4)], [(2, 3), (3, 4), (3, 4)],
transform=attrgetter('integer', 'integer2') transform=attrgetter('integer', 'integer2')
@ -409,7 +411,7 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.filter(string=Case( CaseTestModel.objects.filter(string=Case(
When(integer2=F('integer'), then=Value('2')), When(integer2=F('integer'), then=Value('2')),
When(integer2=F('integer') + 1, then=Value('3')), When(integer2=F('integer') + 1, then=Value('3')),
output_field=models.CharField(), output_field=CharField(),
)).order_by('pk'), )).order_by('pk'),
[(3, 4, '3'), (2, 2, '2'), (3, 4, '3')], [(3, 4, '3'), (2, 2, '2'), (3, 4, '3')],
transform=attrgetter('integer', 'integer2', 'string') transform=attrgetter('integer', 'integer2', 'string')
@ -431,7 +433,7 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.filter(integer=Case( CaseTestModel.objects.filter(integer=Case(
When(integer2=F('o2o_rel__integer') + 1, then=2), When(integer2=F('o2o_rel__integer') + 1, then=2),
When(integer2=F('o2o_rel__integer'), then=3), When(integer2=F('o2o_rel__integer'), then=3),
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(2, 3), (3, 3)], [(2, 3), (3, 3)],
transform=attrgetter('integer', 'integer2') transform=attrgetter('integer', 'integer2')
@ -443,7 +445,7 @@ class CaseExpressionTests(TestCase):
When(o2o_rel__integer=1, then=1), When(o2o_rel__integer=1, then=1),
When(o2o_rel__integer=2, then=3), When(o2o_rel__integer=2, then=3),
When(o2o_rel__integer=3, then=4), When(o2o_rel__integer=3, then=4),
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)], [(1, 1), (2, 3), (3, 4), (3, 4)],
transform=attrgetter('integer', 'integer2') transform=attrgetter('integer', 'integer2')
@ -472,7 +474,7 @@ class CaseExpressionTests(TestCase):
integer=Case( integer=Case(
When(integer2=F('integer'), then=2), When(integer2=F('integer'), then=2),
When(integer2=F('f_plus_1'), then=3), When(integer2=F('f_plus_1'), then=3),
output_field=models.IntegerField(), output_field=IntegerField(),
), ),
).order_by('pk'), ).order_by('pk'),
[(3, 4), (2, 2), (3, 4)], [(3, 4), (2, 2), (3, 4)],
@ -488,7 +490,7 @@ class CaseExpressionTests(TestCase):
When(f_plus_1=3, then=3), When(f_plus_1=3, then=3),
When(f_plus_1=4, then=4), When(f_plus_1=4, then=4),
default=1, default=1,
output_field=models.IntegerField(), output_field=IntegerField(),
), ),
).order_by('pk'), ).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)], [(1, 1), (2, 3), (3, 4), (3, 4)],
@ -599,7 +601,7 @@ class CaseExpressionTests(TestCase):
integer=Case( integer=Case(
When(integer2=F('o2o_rel__integer') + 1, then=2), When(integer2=F('o2o_rel__integer') + 1, then=2),
When(integer2=F('o2o_rel__integer'), then=3), When(integer2=F('o2o_rel__integer'), then=3),
output_field=models.IntegerField(), output_field=IntegerField(),
), ),
) )
@ -611,7 +613,7 @@ class CaseExpressionTests(TestCase):
When(o2o_rel__integer=2, then=Value('two')), When(o2o_rel__integer=2, then=Value('two')),
When(o2o_rel__integer=3, then=Value('three')), When(o2o_rel__integer=3, then=Value('three')),
default=Value('other'), default=Value('other'),
output_field=models.CharField(), output_field=CharField(),
), ),
) )
@ -631,9 +633,9 @@ class CaseExpressionTests(TestCase):
def test_update_binary(self): def test_update_binary(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
binary=Case( binary=Case(
When(integer=1, then=Value(b'one', output_field=models.BinaryField())), When(integer=1, then=Value(b'one', output_field=BinaryField())),
When(integer=2, then=Value(b'two', output_field=models.BinaryField())), When(integer=2, then=Value(b'two', output_field=BinaryField())),
default=Value(b'', output_field=models.BinaryField()), default=Value(b'', output_field=BinaryField()),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -714,8 +716,8 @@ class CaseExpressionTests(TestCase):
duration=Case( duration=Case(
# fails on sqlite if output_field is not set explicitly on all # fails on sqlite if output_field is not set explicitly on all
# Values containing timedeltas # Values containing timedeltas
When(integer=1, then=Value(timedelta(1), output_field=models.DurationField())), When(integer=1, then=Value(timedelta(1), output_field=DurationField())),
When(integer=2, then=Value(timedelta(2), output_field=models.DurationField())), When(integer=2, then=Value(timedelta(2), output_field=DurationField())),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -798,7 +800,7 @@ class CaseExpressionTests(TestCase):
# fails on postgresql if output_field is not set explicitly # fails on postgresql if output_field is not set explicitly
When(integer=1, then=Value('1.1.1.1')), When(integer=1, then=Value('1.1.1.1')),
When(integer=2, then=Value('2.2.2.2')), When(integer=2, then=Value('2.2.2.2')),
output_field=models.GenericIPAddressField(), output_field=GenericIPAddressField(),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -902,8 +904,8 @@ class CaseExpressionTests(TestCase):
def test_update_string(self): def test_update_string(self):
CaseTestModel.objects.filter(string__in=['1', '2']).update( CaseTestModel.objects.filter(string__in=['1', '2']).update(
string=Case( string=Case(
When(integer=1, then=Value('1', output_field=models.CharField())), When(integer=1, then=Value('1', output_field=CharField())),
When(integer=2, then=Value('2', output_field=models.CharField())), When(integer=2, then=Value('2', output_field=CharField())),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -931,8 +933,8 @@ class CaseExpressionTests(TestCase):
time=Case( time=Case(
# fails on sqlite if output_field is not set explicitly on all # fails on sqlite if output_field is not set explicitly on all
# Values containing times # Values containing times
When(integer=1, then=Value(time(1), output_field=models.TimeField())), When(integer=1, then=Value(time(1), output_field=TimeField())),
When(integer=2, then=Value(time(2), output_field=models.TimeField())), When(integer=2, then=Value(time(2), output_field=TimeField())),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -965,11 +967,11 @@ class CaseExpressionTests(TestCase):
# Values containing UUIDs # Values containing UUIDs
When(integer=1, then=Value( When(integer=1, then=Value(
UUID('11111111111111111111111111111111'), UUID('11111111111111111111111111111111'),
output_field=models.UUIDField(), output_field=UUIDField(),
)), )),
When(integer=2, then=Value( When(integer=2, then=Value(
UUID('22222222222222222222222222222222'), UUID('22222222222222222222222222222222'),
output_field=models.UUIDField(), output_field=UUIDField(),
)), )),
), ),
) )
@ -1009,7 +1011,7 @@ class CaseExpressionTests(TestCase):
When(integer__lt=2, then=Value('less than 2')), When(integer__lt=2, then=Value('less than 2')),
When(integer__gt=2, then=Value('greater than 2')), When(integer__gt=2, then=Value('greater than 2')),
default=Value('equal to 2'), default=Value('equal to 2'),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[ [
@ -1025,7 +1027,7 @@ class CaseExpressionTests(TestCase):
test=Case( test=Case(
When(integer=2, integer2=3, then=Value('when')), When(integer=2, integer2=3, then=Value('when')),
default=Value('default'), default=Value('default'),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[ [
@ -1041,7 +1043,7 @@ class CaseExpressionTests(TestCase):
test=Case( test=Case(
When(Q(integer=2) | Q(integer2=3), then=Value('when')), When(Q(integer=2) | Q(integer2=3), then=Value('when')),
default=Value('default'), default=Value('default'),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[ [
@ -1057,7 +1059,7 @@ class CaseExpressionTests(TestCase):
When(integer=1, then=2), When(integer=1, then=2),
When(integer=2, then=1), When(integer=2, then=1),
default=3, default=3,
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by('test', 'pk'), )).order_by('test', 'pk'),
[(2, 1), (2, 1), (1, 2)], [(2, 1), (2, 1), (1, 2)],
transform=attrgetter('integer', 'test') transform=attrgetter('integer', 'test')
@ -1069,7 +1071,7 @@ class CaseExpressionTests(TestCase):
When(integer=1, then=2), When(integer=1, then=2),
When(integer=2, then=1), When(integer=2, then=1),
default=3, default=3,
output_field=models.IntegerField(), output_field=IntegerField(),
)).order_by(F('test').asc(), 'pk'), )).order_by(F('test').asc(), 'pk'),
[(2, 1), (2, 1), (1, 2)], [(2, 1), (2, 1), (1, 2)],
transform=attrgetter('integer', 'test') transform=attrgetter('integer', 'test')
@ -1088,7 +1090,7 @@ class CaseExpressionTests(TestCase):
foo=Case( foo=Case(
When(fk_rel__pk=1, then=2), When(fk_rel__pk=1, then=2),
default=3, default=3,
output_field=models.IntegerField() output_field=IntegerField()
), ),
), ),
[(o, 3)], [(o, 3)],
@ -1100,7 +1102,7 @@ class CaseExpressionTests(TestCase):
foo=Case( foo=Case(
When(fk_rel__isnull=True, then=2), When(fk_rel__isnull=True, then=2),
default=3, default=3,
output_field=models.IntegerField() output_field=IntegerField()
), ),
), ),
[(o, 2)], [(o, 2)],
@ -1120,12 +1122,12 @@ class CaseExpressionTests(TestCase):
foo=Case( foo=Case(
When(fk_rel__pk=1, then=2), When(fk_rel__pk=1, then=2),
default=3, default=3,
output_field=models.IntegerField() output_field=IntegerField()
), ),
bar=Case( bar=Case(
When(fk_rel__pk=1, then=4), When(fk_rel__pk=1, then=4),
default=5, default=5,
output_field=models.IntegerField() output_field=IntegerField()
), ),
), ),
[(o, 3, 5)], [(o, 3, 5)],
@ -1137,12 +1139,12 @@ class CaseExpressionTests(TestCase):
foo=Case( foo=Case(
When(fk_rel__isnull=True, then=2), When(fk_rel__isnull=True, then=2),
default=3, default=3,
output_field=models.IntegerField() output_field=IntegerField()
), ),
bar=Case( bar=Case(
When(fk_rel__isnull=True, then=4), When(fk_rel__isnull=True, then=4),
default=5, default=5,
output_field=models.IntegerField() output_field=IntegerField()
), ),
), ),
[(o, 2, 4)], [(o, 2, 4)],
@ -1152,9 +1154,9 @@ class CaseExpressionTests(TestCase):
def test_m2m_exclude(self): def test_m2m_exclude(self):
CaseTestModel.objects.create(integer=10, integer2=1, string='1') CaseTestModel.objects.create(integer=10, integer2=1, string='1')
qs = CaseTestModel.objects.values_list('id', 'integer').annotate( qs = CaseTestModel.objects.values_list('id', 'integer').annotate(
cnt=models.Sum( cnt=Sum(
Case(When(~Q(fk_rel__integer=1), then=1), default=2), Case(When(~Q(fk_rel__integer=1), then=1), default=2),
output_field=models.IntegerField() output_field=IntegerField()
), ),
).order_by('integer') ).order_by('integer')
# The first o has 2 as its fk_rel__integer=1, thus it hits the # The first o has 2 as its fk_rel__integer=1, thus it hits the
@ -1174,14 +1176,14 @@ class CaseExpressionTests(TestCase):
# Need to use values before annotate so that Oracle will not group # Need to use values before annotate so that Oracle will not group
# by fields it isn't capable of grouping by. # by fields it isn't capable of grouping by.
qs = CaseTestModel.objects.values_list('id', 'integer').annotate( qs = CaseTestModel.objects.values_list('id', 'integer').annotate(
cnt=models.Sum( cnt=Sum(
Case(When(~Q(fk_rel__integer=1), then=1), default=2), Case(When(~Q(fk_rel__integer=1), then=1), default=2),
output_field=models.IntegerField() output_field=IntegerField()
), ),
).annotate( ).annotate(
cnt2=models.Sum( cnt2=Sum(
Case(When(~Q(fk_rel__integer=1), then=1), default=2), Case(When(~Q(fk_rel__integer=1), then=1), default=2),
output_field=models.IntegerField() output_field=IntegerField()
), ),
).order_by('integer') ).order_by('integer')
self.assertEqual(str(qs.query).count(' JOIN '), 1) self.assertEqual(str(qs.query).count(' JOIN '), 1)
@ -1218,7 +1220,7 @@ class CaseDocumentationExamples(TestCase):
When(account_type=Client.GOLD, then=Value('5%')), When(account_type=Client.GOLD, then=Value('5%')),
When(account_type=Client.PLATINUM, then=Value('10%')), When(account_type=Client.PLATINUM, then=Value('10%')),
default=Value('0%'), default=Value('0%'),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')], [('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')],
@ -1234,7 +1236,7 @@ class CaseDocumentationExamples(TestCase):
When(registered_on__lte=a_year_ago, then=Value('10%')), When(registered_on__lte=a_year_ago, then=Value('10%')),
When(registered_on__lte=a_month_ago, then=Value('5%')), When(registered_on__lte=a_month_ago, then=Value('5%')),
default=Value('0%'), default=Value('0%'),
output_field=models.CharField(), output_field=CharField(),
), ),
).order_by('pk'), ).order_by('pk'),
[('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')], [('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')],
@ -1275,26 +1277,26 @@ class CaseDocumentationExamples(TestCase):
) )
self.assertEqual( self.assertEqual(
Client.objects.aggregate( Client.objects.aggregate(
regular=models.Count('pk', filter=Q(account_type=Client.REGULAR)), regular=Count('pk', filter=Q(account_type=Client.REGULAR)),
gold=models.Count('pk', filter=Q(account_type=Client.GOLD)), gold=Count('pk', filter=Q(account_type=Client.GOLD)),
platinum=models.Count('pk', filter=Q(account_type=Client.PLATINUM)), platinum=Count('pk', filter=Q(account_type=Client.PLATINUM)),
), ),
{'regular': 2, 'gold': 1, 'platinum': 3} {'regular': 2, 'gold': 1, 'platinum': 3}
) )
# This was the example before the filter argument was added. # This was the example before the filter argument was added.
self.assertEqual( self.assertEqual(
Client.objects.aggregate( Client.objects.aggregate(
regular=models.Sum(Case( regular=Sum(Case(
When(account_type=Client.REGULAR, then=1), When(account_type=Client.REGULAR, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
gold=models.Sum(Case( gold=Sum(Case(
When(account_type=Client.GOLD, then=1), When(account_type=Client.GOLD, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
platinum=models.Sum(Case( platinum=Sum(Case(
When(account_type=Client.PLATINUM, then=1), When(account_type=Client.PLATINUM, then=1),
output_field=models.IntegerField(), output_field=IntegerField(),
)), )),
), ),
{'regular': 2, 'gold': 1, 'platinum': 3} {'regular': 2, 'gold': 1, 'platinum': 3}
@ -1318,12 +1320,12 @@ class CaseDocumentationExamples(TestCase):
expression_1 = Case( expression_1 = Case(
When(account_type__in=[Client.REGULAR, Client.GOLD], then=1), When(account_type__in=[Client.REGULAR, Client.GOLD], then=1),
default=2, default=2,
output_field=models.IntegerField(), output_field=IntegerField(),
) )
expression_2 = Case( expression_2 = Case(
When(account_type__in=(Client.REGULAR, Client.GOLD), then=1), When(account_type__in=(Client.REGULAR, Client.GOLD), then=1),
default=2, default=2,
output_field=models.IntegerField(), output_field=IntegerField(),
) )
expression_3 = Case(When(account_type__in=[Client.REGULAR, Client.GOLD], then=1), default=2) expression_3 = Case(When(account_type__in=[Client.REGULAR, Client.GOLD], then=1), default=2)
expression_4 = Case(When(account_type__in=[Client.PLATINUM, Client.GOLD], then=2), default=1) expression_4 = Case(When(account_type__in=[Client.PLATINUM, Client.GOLD], then=2), default=1)
@ -1347,7 +1349,7 @@ class CaseWhenTests(SimpleTestCase):
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
When(condition=object()) When(condition=object())
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
When(condition=Value(1, output_field=models.IntegerField())) When(condition=Value(1, output_field=IntegerField()))
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
When() When()

View File

@ -4,10 +4,9 @@ from unittest import mock, skipIf
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import NotSupportedError, connection from django.db import NotSupportedError, connection
from django.db.models import ( from django.db.models import (
BooleanField, Case, F, Func, OuterRef, Q, RowRange, Subquery, Value, Avg, BooleanField, Case, F, Func, Max, Min, OuterRef, Q, RowRange,
ValueRange, When, Window, WindowFrame, Subquery, Sum, Value, ValueRange, When, Window, WindowFrame,
) )
from django.db.models.aggregates import Avg, Max, Min, Sum
from django.db.models.functions import ( from django.db.models.functions import (
CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead, CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead,
NthValue, Ntile, PercentRank, Rank, RowNumber, Upper, NthValue, Ntile, PercentRank, Rank, RowNumber, Upper,

View File

@ -1,5 +1,4 @@
from django.db import models from django.db import models
from django.db.models.fields.related import ForeignObject
class Address(models.Model): class Address(models.Model):
@ -15,7 +14,7 @@ class Address(models.Model):
class Customer(models.Model): class Customer(models.Model):
company = models.CharField(max_length=1) company = models.CharField(max_length=1)
customer_id = models.IntegerField() customer_id = models.IntegerField()
address = ForeignObject( address = models.ForeignObject(
Address, models.CASCADE, null=True, Address, models.CASCADE, null=True,
# order mismatches the Contact ForeignObject. # order mismatches the Contact ForeignObject.
from_fields=['company', 'customer_id'], from_fields=['company', 'customer_id'],
@ -31,7 +30,7 @@ class Customer(models.Model):
class Contact(models.Model): class Contact(models.Model):
company_code = models.CharField(max_length=1) company_code = models.CharField(max_length=1)
customer_code = models.IntegerField() customer_code = models.IntegerField()
customer = ForeignObject( customer = models.ForeignObject(
Customer, models.CASCADE, related_name='contacts', Customer, models.CASCADE, related_name='contacts',
to_fields=['customer_id', 'company'], to_fields=['customer_id', 'company'],
from_fields=['customer_code', 'company_code'], from_fields=['customer_code', 'company_code'],

View File

@ -1,12 +1,10 @@
from django.db import models from django.db import models
from django.db.models.fields.related import ( from django.db.models.fields.related import ReverseManyToOneDescriptor
ForeignObjectRel, ReverseManyToOneDescriptor,
)
from django.db.models.lookups import StartsWith from django.db.models.lookups import StartsWith
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
class CustomForeignObjectRel(ForeignObjectRel): class CustomForeignObjectRel(models.ForeignObjectRel):
""" """
Define some extra Field methods so this Rel acts more like a Field, which Define some extra Field methods so this Rel acts more like a Field, which
lets us use ReverseManyToOneDescriptor in both directions. lets us use ReverseManyToOneDescriptor in both directions.

View File

@ -3,7 +3,6 @@ from operator import attrgetter
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import models from django.db import models
from django.db.models.fields.related import ForeignObject
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.utils import isolate_apps from django.test.utils import isolate_apps
from django.utils import translation from django.utils import translation
@ -436,7 +435,7 @@ class TestModelCheckTests(SimpleTestCase):
a = models.PositiveIntegerField() a = models.PositiveIntegerField()
b = models.PositiveIntegerField() b = models.PositiveIntegerField()
value = models.CharField(max_length=255) value = models.CharField(max_length=255)
parent = ForeignObject( parent = models.ForeignObject(
Parent, Parent,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
from_fields=('a', 'b'), from_fields=('a', 'b'),
@ -461,7 +460,7 @@ class TestModelCheckTests(SimpleTestCase):
b = models.PositiveIntegerField() b = models.PositiveIntegerField()
c = models.PositiveIntegerField() c = models.PositiveIntegerField()
d = models.CharField(max_length=255) d = models.CharField(max_length=255)
parent = ForeignObject( parent = models.ForeignObject(
Parent, Parent,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
from_fields=('a', 'b', 'c'), from_fields=('a', 'b', 'c'),

View File

@ -3,7 +3,6 @@ from django.contrib.contenttypes.fields import (
) )
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
from django.db.models.deletion import ProtectedError
__all__ = ('Link', 'Place', 'Restaurant', 'Person', 'Address', __all__ = ('Link', 'Place', 'Restaurant', 'Person', 'Address',
'CharLink', 'TextLink', 'OddRelation1', 'OddRelation2', 'CharLink', 'TextLink', 'OddRelation1', 'OddRelation2',
@ -214,7 +213,7 @@ class Related(models.Model):
def prevent_deletes(sender, instance, **kwargs): def prevent_deletes(sender, instance, **kwargs):
raise ProtectedError("Not allowed to delete.", [instance]) raise models.ProtectedError("Not allowed to delete.", [instance])
models.signals.pre_delete.connect(prevent_deletes, sender=Node) models.signals.pre_delete.connect(prevent_deletes, sender=Node)

View File

@ -1,6 +1,5 @@
from django.db.models import Q, Sum from django.db import IntegrityError
from django.db.models.deletion import ProtectedError from django.db.models import ProtectedError, Q, Sum
from django.db.utils import IntegrityError
from django.forms.models import modelform_factory from django.forms.models import modelform_factory
from django.test import TestCase, skipIfDBFeature from django.test import TestCase, skipIfDBFeature

View File

@ -1,6 +1,6 @@
import copy import copy
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models import GeometryField
from django.contrib.gis.db.models.sql import AreaField, DistanceField from django.contrib.gis.db.models.sql import AreaField, DistanceField
from django.test import SimpleTestCase from django.test import SimpleTestCase

View File

@ -1,5 +1,4 @@
from django.db import connection, models from django.db import connection, models
from django.db.models.expressions import Func
from django.test import SimpleTestCase from django.test import SimpleTestCase
from .utils import FuncTestMixin from .utils import FuncTestMixin
@ -8,7 +7,7 @@ from .utils import FuncTestMixin
def test_mutation(raises=True): def test_mutation(raises=True):
def wrapper(mutation_func): def wrapper(mutation_func):
def test(test_case_instance, *args, **kwargs): def test(test_case_instance, *args, **kwargs):
class TestFunc(Func): class TestFunc(models.Func):
output_field = models.IntegerField() output_field = models.IntegerField()
def __init__(self): def __init__(self):

View File

@ -5,7 +5,7 @@ from unittest import mock
from django.conf import settings from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, connection from django.db import DEFAULT_DB_ALIAS, connection
from django.db.models.expressions import Func from django.db.models import Func
def skipUnlessGISLookup(*gis_lookups): def skipUnlessGISLookup(*gis_lookups):

View File

@ -2,10 +2,7 @@ import datetime
from unittest import skipIf, skipUnless from unittest import skipIf, skipUnless
from django.db import connection from django.db import connection
from django.db.models import Index from django.db.models import CASCADE, ForeignKey, Index, Q
from django.db.models.deletion import CASCADE
from django.db.models.fields.related import ForeignKey
from django.db.models.query_utils import Q
from django.test import ( from django.test import (
TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature,
) )

View File

@ -1,8 +1,7 @@
from unittest import mock, skipUnless from unittest import mock, skipUnless
from django.db import connection from django.db import DatabaseError, connection
from django.db.models import Index from django.db.models import Index
from django.db.utils import DatabaseError
from django.test import TransactionTestCase, skipUnlessDBFeature from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import ( from .models import (

View File

@ -2,7 +2,6 @@ from unittest import mock
from django.core.checks import Error, Warning as DjangoWarning from django.core.checks import Error, Warning as DjangoWarning
from django.db import connection, models from django.db import connection, models
from django.db.models.fields.related import ForeignObject
from django.test.testcases import SimpleTestCase from django.test.testcases import SimpleTestCase
from django.test.utils import isolate_apps, override_settings from django.test.utils import isolate_apps, override_settings
@ -608,7 +607,7 @@ class RelativeFieldTests(SimpleTestCase):
class Child(models.Model): class Child(models.Model):
a = models.PositiveIntegerField() a = models.PositiveIntegerField()
b = models.PositiveIntegerField() b = models.PositiveIntegerField()
parent = ForeignObject( parent = models.ForeignObject(
Parent, Parent,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
from_fields=('a', 'b'), from_fields=('a', 'b'),
@ -633,7 +632,7 @@ class RelativeFieldTests(SimpleTestCase):
class Child(models.Model): class Child(models.Model):
a = models.PositiveIntegerField() a = models.PositiveIntegerField()
b = models.PositiveIntegerField() b = models.PositiveIntegerField()
parent = ForeignObject( parent = models.ForeignObject(
'invalid_models_tests.Parent', 'invalid_models_tests.Parent',
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
from_fields=('a', 'b'), from_fields=('a', 'b'),
@ -1441,7 +1440,7 @@ class M2mThroughFieldsTests(SimpleTestCase):
a = models.PositiveIntegerField() a = models.PositiveIntegerField()
b = models.PositiveIntegerField() b = models.PositiveIntegerField()
value = models.CharField(max_length=255) value = models.CharField(max_length=255)
parent = ForeignObject( parent = models.ForeignObject(
Parent, Parent,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
from_fields=('a', 'b'), from_fields=('a', 'b'),
@ -1477,7 +1476,7 @@ class M2mThroughFieldsTests(SimpleTestCase):
b = models.PositiveIntegerField() b = models.PositiveIntegerField()
d = models.PositiveIntegerField() d = models.PositiveIntegerField()
value = models.CharField(max_length=255) value = models.CharField(max_length=255)
parent = ForeignObject( parent = models.ForeignObject(
Parent, Parent,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
from_fields=('a', 'b', 'd'), from_fields=('a', 'b', 'd'),

View File

@ -1,5 +1,4 @@
from django.db.models.aggregates import Sum from django.db.models import F, Sum
from django.db.models.expressions import F
from django.test import TestCase from django.test import TestCase
from .models import Product, Stock from .models import Product, Stock

View File

@ -1,7 +1,6 @@
from datetime import datetime from datetime import datetime
from django.db.models import Value from django.db.models import DateTimeField, Value
from django.db.models.fields import DateTimeField
from django.db.models.lookups import YearLookup from django.db.models.lookups import YearLookup
from django.test import SimpleTestCase from django.test import SimpleTestCase

View File

@ -5,8 +5,7 @@ from operator import attrgetter
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection from django.db import connection
from django.db.models import Max from django.db.models import Exists, Max, OuterRef
from django.db.models.expressions import Exists, OuterRef
from django.db.models.functions import Substr from django.db.models.functions import Substr
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango40Warning from django.utils.deprecation import RemovedInDjango40Warning

View File

@ -2,8 +2,7 @@ import datetime
from copy import deepcopy from copy import deepcopy
from django.core.exceptions import FieldError, MultipleObjectsReturned from django.core.exceptions import FieldError, MultipleObjectsReturned
from django.db import models, transaction from django.db import IntegrityError, models, transaction
from django.db.utils import IntegrityError
from django.test import TestCase from django.test import TestCase
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy

View File

@ -1,12 +1,11 @@
from unittest import mock from unittest import mock
from django.apps.registry import apps as global_apps from django.apps.registry import apps as global_apps
from django.db import connection from django.db import DatabaseError, connection
from django.db.migrations.exceptions import InvalidMigrationPlan from django.db.migrations.exceptions import InvalidMigrationPlan
from django.db.migrations.executor import MigrationExecutor from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.graph import MigrationGraph from django.db.migrations.graph import MigrationGraph
from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.recorder import MigrationRecorder
from django.db.utils import DatabaseError
from django.test import ( from django.test import (
SimpleTestCase, modify_settings, override_settings, skipUnlessDBFeature, SimpleTestCase, modify_settings, override_settings, skipUnlessDBFeature,
) )

View File

@ -1,12 +1,12 @@
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist
from django.db import connection, migrations, models, transaction from django.db import (
IntegrityError, connection, migrations, models, transaction,
)
from django.db.migrations.migration import Migration from django.db.migrations.migration import Migration
from django.db.migrations.operations import CreateModel from django.db.migrations.operations import CreateModel
from django.db.migrations.operations.fields import FieldOperation from django.db.migrations.operations.fields import FieldOperation
from django.db.migrations.state import ModelState, ProjectState from django.db.migrations.state import ModelState, ProjectState
from django.db.models.fields import NOT_PROVIDED
from django.db.transaction import atomic from django.db.transaction import atomic
from django.db.utils import IntegrityError
from django.test import SimpleTestCase, override_settings, skipUnlessDBFeature from django.test import SimpleTestCase, override_settings, skipUnlessDBFeature
from .models import FoodManager, FoodQuerySet, UnicodeModel from .models import FoodManager, FoodQuerySet, UnicodeModel
@ -979,7 +979,7 @@ class OperationTests(OperationTestBase):
f for n, f in new_state.models["test_adflpd", "pony"].fields f for n, f in new_state.models["test_adflpd", "pony"].fields
if n == "height" if n == "height"
][0] ][0]
self.assertEqual(field.default, NOT_PROVIDED) self.assertEqual(field.default, models.NOT_PROVIDED)
# Test the database alteration # Test the database alteration
project_state.apps.get_model("test_adflpd", "pony").objects.create( project_state.apps.get_model("test_adflpd", "pony").objects.create(
weight=4, weight=4,

View File

@ -8,10 +8,7 @@ from django.contrib.contenttypes.fields import (
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.files.storage import FileSystemStorage from django.core.files.storage import FileSystemStorage
from django.db import models from django.db import models
from django.db.models.fields.files import ImageField, ImageFieldFile from django.db.models.fields.files import ImageFieldFile
from django.db.models.fields.related import (
ForeignKey, ForeignObject, ManyToManyField, OneToOneField,
)
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
try: try:
@ -255,7 +252,7 @@ if Image:
self.was_opened = True self.was_opened = True
super().open() super().open()
class TestImageField(ImageField): class TestImageField(models.ImageField):
attr_class = TestImageFieldFile attr_class = TestImageFieldFile
# Set up a temp directory for file storage. # Set up a temp directory for file storage.
@ -359,20 +356,20 @@ class AllFieldsModel(models.Model):
url = models.URLField() url = models.URLField()
uuid = models.UUIDField() uuid = models.UUIDField()
fo = ForeignObject( fo = models.ForeignObject(
'self', 'self',
on_delete=models.CASCADE, on_delete=models.CASCADE,
from_fields=['positive_integer'], from_fields=['positive_integer'],
to_fields=['id'], to_fields=['id'],
related_name='reverse' related_name='reverse'
) )
fk = ForeignKey( fk = models.ForeignKey(
'self', 'self',
models.CASCADE, models.CASCADE,
related_name='reverse2' related_name='reverse2'
) )
m2m = ManyToManyField('self') m2m = models.ManyToManyField('self')
oto = OneToOneField('self', models.CASCADE) oto = models.OneToOneField('self', models.CASCADE)
object_id = models.PositiveIntegerField() object_id = models.PositiveIntegerField()
content_type = models.ForeignKey(ContentType, models.CASCADE) content_type = models.ForeignKey(ContentType, models.CASCADE)

View File

@ -3,15 +3,11 @@ from django.contrib.contenttypes.fields import (
GenericForeignKey, GenericRelation, GenericForeignKey, GenericRelation,
) )
from django.db import models from django.db import models
from django.db.models.fields.related import (
ForeignKey, ForeignObject, ForeignObjectRel, ManyToManyField, ManyToOneRel,
OneToOneField,
)
from .models import AllFieldsModel from .models import AllFieldsModel
NON_CONCRETE_FIELDS = ( NON_CONCRETE_FIELDS = (
ForeignObject, models.ForeignObject,
GenericForeignKey, GenericForeignKey,
GenericRelation, GenericRelation,
) )
@ -23,32 +19,32 @@ NON_EDITABLE_FIELDS = (
) )
RELATION_FIELDS = ( RELATION_FIELDS = (
ForeignKey, models.ForeignKey,
ForeignObject, models.ForeignObject,
ManyToManyField, models.ManyToManyField,
OneToOneField, models.OneToOneField,
GenericForeignKey, GenericForeignKey,
GenericRelation, GenericRelation,
) )
MANY_TO_MANY_CLASSES = { MANY_TO_MANY_CLASSES = {
ManyToManyField, models.ManyToManyField,
} }
MANY_TO_ONE_CLASSES = { MANY_TO_ONE_CLASSES = {
ForeignObject, models.ForeignObject,
ForeignKey, models.ForeignKey,
GenericForeignKey, GenericForeignKey,
} }
ONE_TO_MANY_CLASSES = { ONE_TO_MANY_CLASSES = {
ForeignObjectRel, models.ForeignObjectRel,
ManyToOneRel, models.ManyToOneRel,
GenericRelation, GenericRelation,
} }
ONE_TO_ONE_CLASSES = { ONE_TO_ONE_CLASSES = {
OneToOneField, models.OneToOneField,
} }
FLAG_PROPERTIES = ( FLAG_PROPERTIES = (

View File

@ -8,7 +8,7 @@ from pathlib import Path
from django.core.files import File, temp from django.core.files import File, temp
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
from django.core.files.uploadedfile import TemporaryUploadedFile from django.core.files.uploadedfile import TemporaryUploadedFile
from django.db.utils import IntegrityError from django.db import IntegrityError
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from .models import Document from .models import Document

View File

@ -1,14 +1,14 @@
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.db.models.fields import ( from django.db.models import (
AutoField, BinaryField, BooleanField, CharField, DateField, DateTimeField, AutoField, BinaryField, BooleanField, CharField, DateField, DateTimeField,
DecimalField, EmailField, FilePathField, FloatField, GenericIPAddressField, DecimalField, EmailField, FileField, FilePathField, FloatField,
IntegerField, IPAddressField, NullBooleanField, PositiveBigIntegerField, GenericIPAddressField, ImageField, IntegerField, IPAddressField,
PositiveIntegerField, PositiveSmallIntegerField, SlugField, NullBooleanField, PositiveBigIntegerField, PositiveIntegerField,
SmallIntegerField, TextField, TimeField, URLField, PositiveSmallIntegerField, SlugField, SmallIntegerField, TextField,
TimeField, URLField,
) )
from django.db.models.fields.files import FileField, ImageField
from django.test import SimpleTestCase from django.test import SimpleTestCase
from django.utils.functional import lazy from django.utils.functional import lazy

Some files were not shown because too many files have changed in this diff Show More