from copy import copy

from django.conf import settings
from django.db.models.expressions import Func, Value
from django.db.models.fields import (
    DateField, DateTimeField, Field, IntegerField, TimeField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.six.moves import range


class Lookup(object):
    lookup_name = None

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        self.rhs = self.get_prep_lookup()
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # We should warn the user as soon as possible if he is trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            # We need to import QuerySet here so as to avoid circular
            from django.db.models.query import QuerySet
            if isinstance(rhs, QuerySet):
                raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def batch_process_rhs(self, compiler, connection, rhs=None):
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            params = self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, rhs, connection, prepared=True)
            sqls, sqls_params = ['%s'] * len(params), params
        return sqls, sqls_params

    def get_prep_lookup(self):
        return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)

    def get_db_prep_lookup(self, value, connection):
        return (
            '%s', self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, value, connection, prepared=True))

    def process_lhs(self, compiler, connection, lhs=None):
        lhs = lhs or self.lhs
        return compiler.compile(lhs)

    def process_rhs(self, compiler, connection):
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        # Due to historical reasons there are a couple of different
        # ways to produce sql here. get_compiler is likely a Query
        # instance, _as_sql QuerySet and as_sql just something with
        # as_sql. Finally the value can of course be just plain
        # Python value.
        if hasattr(value, 'get_compiler'):
            value = value.get_compiler(connection=connection)
        if hasattr(value, 'as_sql'):
            sql, params = compiler.compile(value)
            return '(' + sql + ')', params
        if hasattr(value, '_as_sql'):
            sql, params = value._as_sql(connection=connection)
            return '(' + sql + ')', params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        return not(
            hasattr(self.rhs, 'as_sql') or
            hasattr(self.rhs, '_as_sql') or
            hasattr(self.rhs, 'get_compiler'))

    def relabeled_clone(self, relabels):
        new = copy(self)
        new.lhs = new.lhs.relabeled_clone(relabels)
        if hasattr(new.rhs, 'relabeled_clone'):
            new.rhs = new.rhs.relabeled_clone(relabels)
        return new

    def get_group_by_cols(self):
        cols = self.lhs.get_group_by_cols()
        if hasattr(self.rhs, 'get_group_by_cols'):
            cols.extend(self.rhs.get_group_by_cols())
        return cols

    def as_sql(self, compiler, connection):
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)


class Transform(RegisterLookupMixin, Func):
    """
    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.
    """
    bilateral = False

    def __init__(self, expression, **extra):
        # Restrict Transform to allow only a single expression.
        super(Transform, self).__init__(expression, **extra)

    @property
    def lhs(self):
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if self.bilateral:
            bilateral_transforms.append(self.__class__)
        return bilateral_transforms


class BuiltinLookup(Lookup):
    def process_lhs(self, compiler, connection, lhs=None):
        lhs_sql, params = super(BuiltinLookup, self).process_lhs(
            compiler, connection, lhs)
        field_internal_type = self.lhs.output_field.get_internal_type()
        db_type = self.lhs.output_field.db_type(connection=connection)
        lhs_sql = connection.ops.field_cast_sql(
            db_type, field_internal_type) % lhs_sql
        lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
        return lhs_sql, params

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs


class Exact(BuiltinLookup):
    lookup_name = 'exact'
Field.register_lookup(Exact)


class IExact(BuiltinLookup):
    lookup_name = 'iexact'

    def process_rhs(self, qn, connection):
        rhs, params = super(IExact, self).process_rhs(qn, connection)
        if params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params


Field.register_lookup(IExact)


class GreaterThan(BuiltinLookup):
    lookup_name = 'gt'
Field.register_lookup(GreaterThan)


class GreaterThanOrEqual(BuiltinLookup):
    lookup_name = 'gte'
Field.register_lookup(GreaterThanOrEqual)


class LessThan(BuiltinLookup):
    lookup_name = 'lt'
Field.register_lookup(LessThan)


class LessThanOrEqual(BuiltinLookup):
    lookup_name = 'lte'
Field.register_lookup(LessThanOrEqual)


class In(BuiltinLookup):
    lookup_name = 'in'

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable, we use batch_process_rhs
            # to prepare/transform those values
            rhs = list(self.rhs)
            if not rhs:
                from django.db.models.sql.datastructures import EmptyResultSet
                raise EmptyResultSet
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            return super(In, self).process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs

    def as_sql(self, compiler, connection):
        max_in_list_size = connection.ops.max_in_list_size()
        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super(In, self).as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        in_clause_elements = ['(']
        params = []
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(' OR ')
            in_clause_elements.append('%s IN (' % lhs)
            params.extend(lhs_params)
            sqls = rhs[offset: offset + max_in_list_size]
            sqls_params = rhs_params[offset: offset + max_in_list_size]
            param_group = ', '.join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(')')
            params.extend(sqls_params)
        in_clause_elements.append(')')
        return ''.join(in_clause_elements), params
Field.register_lookup(In)


class PatternLookup(BuiltinLookup):

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
                or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
            pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
            return pattern.format(rhs)
        else:
            return super(PatternLookup, self).get_rhs_op(connection, rhs)


class Contains(PatternLookup):
    lookup_name = 'contains'

    def process_rhs(self, qn, connection):
        rhs, params = super(Contains, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(Contains)


class IContains(Contains):
    lookup_name = 'icontains'
Field.register_lookup(IContains)


class StartsWith(PatternLookup):
    lookup_name = 'startswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(StartsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(StartsWith)


class IStartsWith(PatternLookup):
    lookup_name = 'istartswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(IStartsWith)


class EndsWith(PatternLookup):
    lookup_name = 'endswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(EndsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(EndsWith)


class IEndsWith(PatternLookup):
    lookup_name = 'iendswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(IEndsWith)


class Between(BuiltinLookup):
    def get_rhs_op(self, connection, rhs):
        return "BETWEEN %s AND %s" % (rhs, rhs)


class Range(BuiltinLookup):
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        return "BETWEEN %s AND %s" % (rhs[0], rhs[1])

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable of 2 values, we use batch_process_rhs
            # to prepare/transform those values
            return self.batch_process_rhs(compiler, connection)
        else:
            return super(Range, self).process_rhs(compiler, connection)
Field.register_lookup(Range)


class IsNull(BuiltinLookup):
    lookup_name = 'isnull'

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        if self.rhs:
            return "%s IS NULL" % sql, params
        else:
            return "%s IS NOT NULL" % sql, params
Field.register_lookup(IsNull)


class Search(BuiltinLookup):
    lookup_name = 'search'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
        return sql_template, lhs_params + rhs_params
Field.register_lookup(Search)


class Regex(BuiltinLookup):
    lookup_name = 'regex'

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super(Regex, self).as_sql(compiler, connection)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            sql_template = connection.ops.regex_lookup(self.lookup_name)
            return sql_template % (lhs, rhs), lhs_params + rhs_params
Field.register_lookup(Regex)


class IRegex(Regex):
    lookup_name = 'iregex'
Field.register_lookup(IRegex)


class DateTimeDateTransform(Transform):
    lookup_name = 'date'

    @cached_property
    def output_field(self):
        return DateField()

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
        lhs_params.extend(tz_params)
        return sql, lhs_params


class DateTransform(Transform):
    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        if isinstance(lhs_output_field, DateTimeField):
            tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
            params.extend(tz_params)
        elif isinstance(lhs_output_field, DateField):
            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, TimeField):
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        else:
            raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
        return sql, params

    @cached_property
    def output_field(self):
        return IntegerField()


class YearTransform(DateTransform):
    lookup_name = 'year'


class YearLookup(Lookup):
    def year_lookup_bounds(self, connection, year):
        output_field = self.lhs.lhs.output_field
        if isinstance(output_field, DateTimeField):
            bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
        else:
            bounds = connection.ops.year_lookup_bounds_for_date_field(year)
        return bounds


@YearTransform.register_lookup
class YearExact(YearLookup):
    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        bounds = self.year_lookup_bounds(connection, rhs_params[0])
        params.extend(bounds)
        return '%s BETWEEN %%s AND %%s' % lhs_sql, params


class YearComparisonLookup(YearLookup):
    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
        params.append(self.get_bound(start, finish))
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound(self):
        raise NotImplementedError(
            'subclasses of YearComparisonLookup must provide a get_bound() method'
        )


@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
    lookup_name = 'gt'

    def get_bound(self, start, finish):
        return finish


@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
    lookup_name = 'gte'

    def get_bound(self, start, finish):
        return start


@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
    lookup_name = 'lt'

    def get_bound(self, start, finish):
        return start


@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
    lookup_name = 'lte'

    def get_bound(self, start, finish):
        return finish


class MonthTransform(DateTransform):
    lookup_name = 'month'


class DayTransform(DateTransform):
    lookup_name = 'day'


class WeekDayTransform(DateTransform):
    lookup_name = 'week_day'


class HourTransform(DateTransform):
    lookup_name = 'hour'


class MinuteTransform(DateTransform):
    lookup_name = 'minute'


class SecondTransform(DateTransform):
    lookup_name = 'second'


DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)

TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)

DateTimeField.register_lookup(DateTimeDateTransform)
DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
DateTimeField.register_lookup(HourTransform)
DateTimeField.register_lookup(MinuteTransform)
DateTimeField.register_lookup(SecondTransform)