diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index f85a280b61..73944056d5 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -190,36 +190,31 @@ class ArrayField(CheckFieldDefaultMixin, Field): }) -@ArrayField.register_lookup -class ArrayContains(lookups.DataContains): - def as_sql(self, qn, connection): - sql, params = super().as_sql(qn, connection) - sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) - return sql, params +class ArrayCastRHSMixin: + def process_rhs(self, compiler, connection): + rhs, rhs_params = super().process_rhs(compiler, connection) + cast_type = self.lhs.output_field.db_type(connection) + return '%s::%s' % (rhs, cast_type), rhs_params @ArrayField.register_lookup -class ArrayContainedBy(lookups.ContainedBy): - def as_sql(self, qn, connection): - sql, params = super().as_sql(qn, connection) - sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) - return sql, params +class ArrayContains(ArrayCastRHSMixin, lookups.DataContains): + pass @ArrayField.register_lookup -class ArrayExact(Exact): - def as_sql(self, qn, connection): - sql, params = super().as_sql(qn, connection) - sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) - return sql, params +class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy): + pass @ArrayField.register_lookup -class ArrayOverlap(lookups.Overlap): - def as_sql(self, qn, connection): - sql, params = super().as_sql(qn, connection) - sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection)) - return sql, params +class ArrayExact(ArrayCastRHSMixin, Exact): + pass + + +@ArrayField.register_lookup +class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap): + pass @ArrayField.register_lookup