from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import Exists, F, IntegerField, OuterRef, Value
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature

from .models import Number, ReservedName


@skipUnlessDBFeature('supports_select_union')
class QuerySetSetOperationTests(TestCase):
    """
    QuerySet.union(), intersection(), and difference() behave as expected
    when run against the Number and ReservedName test models.
    """

    @classmethod
    def setUpTestData(cls):
        # Ten rows: num runs 0..9 while other_num runs 10..1.
        Number.objects.bulk_create(
            Number(num=n, other_num=10 - n) for n in range(10)
        )

    def number_transform(self, value):
        """Return the ``num`` attribute of ``value`` for comparisons."""
        return value.num

    def assertNumbersEqual(self, queryset, expected_numbers, ordered=True):
        """
        Assert that the ``num`` values of the rows in ``queryset`` match
        ``expected_numbers`` (order-sensitive unless ``ordered`` is False).
        """
        self.assertQuerysetEqual(
            queryset,
            expected_numbers,
            transform=self.number_transform,
            ordered=ordered,
        )

    def test_simple_union(self):
        """union() merges the rows of three disjoint filters."""
        low = Number.objects.filter(num__lte=1)
        high = Number.objects.filter(num__gte=8)
        exactly_five = Number.objects.filter(num=5)
        combined = low.union(high, exactly_five)
        self.assertNumbersEqual(combined, [0, 1, 5, 8, 9], ordered=False)

    @skipUnlessDBFeature('supports_select_intersection')
    def test_simple_intersection(self):
        """intersection() keeps only the row shared by all three filters."""
        at_most_five = Number.objects.filter(num__lte=5)
        at_least_five = Number.objects.filter(num__gte=5)
        four_to_six = Number.objects.filter(num__gte=4, num__lte=6)
        combined = at_most_five.intersection(at_least_five, four_to_six)
        self.assertNumbersEqual(combined, [5], ordered=False)

    @skipUnlessDBFeature('supports_select_intersection')
    def test_intersection_with_values(self):
        """values() and values_list() can be applied to an intersection."""
        ReservedName.objects.create(name='a', order=2)
        names = ReservedName.objects.all()
        row = names.intersection(names).values('name', 'order', 'id').get()
        self.assertEqual(row['name'], 'a')
        self.assertEqual(row['order'], 2)
        row = names.intersection(names).values_list('name', 'order', 'id').get()
        self.assertEqual(row[:2], ('a', 2))

    @skipUnlessDBFeature('supports_select_difference')
    def test_simple_difference(self):
        """difference() removes the rows that the second queryset matches."""
        at_most_five = Number.objects.filter(num__lte=5)
        at_most_four = Number.objects.filter(num__lte=4)
        remainder = at_most_five.difference(at_most_four)
        self.assertNumbersEqual(remainder, [5], ordered=False)

    def test_union_distinct(self):
        """union() drops duplicate rows unless all=True is passed."""
        first = Number.objects.all()
        second = Number.objects.all()
        self.assertEqual(len(list(first.union(second, all=True))), 20)
        self.assertEqual(len(list(first.union(second))), 10)

    @skipUnlessDBFeature('supports_select_intersection')
    def test_intersection_with_empty_qs(self):
        """Intersecting with an empty queryset always gives no rows."""
        everything = Number.objects.all()
        none_qs = Number.objects.none()
        empty_filter = Number.objects.filter(pk__in=[])
        self.assertEqual(len(everything.intersection(none_qs)), 0)
        self.assertEqual(len(everything.intersection(empty_filter)), 0)
        self.assertEqual(len(none_qs.intersection(everything)), 0)
        self.assertEqual(len(empty_filter.intersection(everything)), 0)
        self.assertEqual(len(none_qs.intersection(none_qs)), 0)
        self.assertEqual(len(empty_filter.intersection(empty_filter)), 0)

    @skipUnlessDBFeature('supports_select_difference')
    def test_difference_with_empty_qs(self):
        """
        Subtracting an empty queryset leaves every row; subtracting from an
        empty queryset gives no rows.
        """
        everything = Number.objects.all()
        none_qs = Number.objects.none()
        empty_filter = Number.objects.filter(pk__in=[])
        self.assertEqual(len(everything.difference(none_qs)), 10)
        self.assertEqual(len(everything.difference(empty_filter)), 10)
        self.assertEqual(len(none_qs.difference(everything)), 0)
        self.assertEqual(len(empty_filter.difference(everything)), 0)
        self.assertEqual(len(none_qs.difference(none_qs)), 0)
        self.assertEqual(len(empty_filter.difference(empty_filter)), 0)

    @skipUnlessDBFeature('supports_select_difference')
    def test_difference_with_values(self):
        """values() and values_list() can be applied to a difference."""
        ReservedName.objects.create(name='a', order=2)
        names = ReservedName.objects.all()
        no_names = ReservedName.objects.none()
        row = names.difference(no_names).values('name', 'order', 'id').get()
        self.assertEqual(row['name'], 'a')
        self.assertEqual(row['order'], 2)
        row = names.difference(no_names).values_list('name', 'order', 'id').get()
        self.assertEqual(row[:2], ('a', 2))

    def test_union_with_empty_qs(self):
        """Empty querysets contribute no rows to a union, on either side."""
        everything = Number.objects.all()
        none_qs = Number.objects.none()
        empty_filter = Number.objects.filter(pk__in=[])
        self.assertEqual(len(everything.union(none_qs)), 10)
        self.assertEqual(len(none_qs.union(everything)), 10)
        self.assertEqual(len(everything.union(empty_filter)), 10)
        self.assertEqual(len(empty_filter.union(everything)), 10)
        self.assertEqual(len(none_qs.union(everything, everything, everything)), 10)
        self.assertEqual(len(none_qs.union(everything, everything, all=True)), 20)
        self.assertEqual(len(none_qs.union(none_qs)), 0)
        self.assertEqual(len(empty_filter.union(empty_filter)), 0)

    def test_limits(self):
        """A union can be sliced."""
        first = Number.objects.all()
        second = Number.objects.all()
        self.assertEqual(len(list(first.union(second)[:2])), 2)

    def test_ordering(self):
        """A union can be ordered by a descending field name."""
        zero_one = Number.objects.filter(num__lte=1)
        two_three = Number.objects.filter(num__gte=2, num__lte=3)
        descending = zero_one.union(two_three).order_by('-num')
        self.assertNumbersEqual(descending, [3, 2, 1, 0])

    def test_ordering_by_f_expression(self):
        """A union can be ordered by a descending F() expression."""
        zero_one = Number.objects.filter(num__lte=1)
        two_three = Number.objects.filter(num__gte=2, num__lte=3)
        descending = zero_one.union(two_three).order_by(F('num').desc())
        self.assertNumbersEqual(descending, [3, 2, 1, 0])

    def test_union_with_values(self):
        """values() and values_list() can be applied to a union."""
        ReservedName.objects.create(name='a', order=2)
        names = ReservedName.objects.all()
        row = names.union(names).values('name', 'order', 'id').get()
        self.assertEqual(row['name'], 'a')
        self.assertEqual(row['order'], 2)
        row = names.union(names).values_list('name', 'order', 'id').get()
        self.assertEqual(row[:2], ('a', 2))
        # A different set of columns may be selected on a new union.
        row = names.union(names).values_list('order').get()
        self.assertEqual(row, (2,))

    def test_union_with_two_annotated_values_list(self):
        """Two annotated values_list() querysets can be combined."""
        with_constant_count = Number.objects.filter(num=1).annotate(
            count=Value(0, IntegerField()),
        ).values_list('num', 'count')
        with_swapped_annotations = Number.objects.filter(num=2).values('pk').annotate(
            count=F('num'),
        ).annotate(
            num=Value(1, IntegerField()),
        ).values_list('num', 'count')
        self.assertCountEqual(
            with_constant_count.union(with_swapped_annotations),
            [(1, 0), (2, 1)],
        )

    def test_union_with_extra_and_values_list(self):
        """
        A queryset using extra() plus values_list() can be combined with one
        that only uses extra().
        """
        with_values_list = Number.objects.filter(num=1).extra(
            select={'count': 0},
        ).values_list('num', 'count')
        extra_only = Number.objects.filter(num=2).extra(select={'count': 1})
        self.assertCountEqual(with_values_list.union(extra_only), [(1, 0), (2, 1)])

    def test_union_with_values_list_on_annotated_and_unannotated(self):
        """
        values_list() works on a union of an annotated and an unannotated
        queryset.
        """
        ReservedName.objects.create(name='rn1', order=1)
        matching_names = ReservedName.objects.filter(order=OuterRef('num'))
        annotated = Number.objects.annotate(
            has_reserved_name=Exists(matching_names)
        ).filter(has_reserved_name=True)
        unannotated = Number.objects.filter(num=9)
        self.assertCountEqual(
            annotated.union(unannotated).values_list('num', flat=True),
            [1, 9],
        )

    def test_union_with_values_list_and_order(self):
        """
        order_by() and values_list() can be chained on a union in either
        order and give the same results.
        """
        ReservedName.objects.bulk_create([
            ReservedName(name='rn1', order=7),
            ReservedName(name='rn2', order=5),
            ReservedName(name='rn0', order=6),
            ReservedName(name='rn9', order=-1),
        ])
        six_and_up = ReservedName.objects.filter(order__gte=6)
        five_and_down = ReservedName.objects.filter(order__lte=5)
        union_qs = six_and_up.union(five_and_down)
        cases = (
            # Ordering on one column.
            (union_qs.order_by('-pk').values_list('order', flat=True), [-1, 6, 5, 7]),
            (union_qs.order_by('pk').values_list('order', flat=True), [7, 5, 6, -1]),
            (union_qs.values_list('order', flat=True).order_by('-pk'), [-1, 6, 5, 7]),
            (union_qs.values_list('order', flat=True).order_by('pk'), [7, 5, 6, -1]),
            # Ordering on several columns.
            (union_qs.order_by('-name', 'pk').values_list('order', flat=True), [-1, 5, 7, 6]),
            (union_qs.values_list('order', flat=True).order_by('-name', 'pk'), [-1, 5, 7, 6]),
        )
        for qs, expected in cases:
            with self.subTest(qs=qs):
                self.assertEqual(list(qs), expected)

    def test_count_union(self):
        """count() works on a union of values() querysets."""
        zero_one = Number.objects.filter(num__lte=1).values('num')
        two_three = Number.objects.filter(num__gte=2, num__lte=3).values('num')
        self.assertEqual(zero_one.union(two_three).count(), 4)

    def test_count_union_empty_result(self):
        """count() on a union of empty querysets is zero."""
        empty = Number.objects.filter(pk__in=[])
        self.assertEqual(empty.union(empty).count(), 0)

    @skipUnlessDBFeature('supports_select_difference')
    def test_count_difference(self):
        """count() works on a difference."""
        below_ten = Number.objects.filter(num__lt=10)
        below_nine = Number.objects.filter(num__lt=9)
        self.assertEqual(below_ten.difference(below_nine).count(), 1)

    @skipUnlessDBFeature('supports_select_intersection')
    def test_count_intersection(self):
        """count() works on an intersection."""
        at_least_five = Number.objects.filter(num__gte=5)
        at_most_five = Number.objects.filter(num__lte=5)
        self.assertEqual(at_least_five.intersection(at_most_five).count(), 1)

    @skipUnlessDBFeature('supports_slicing_ordering_in_compound')
    def test_ordering_subqueries(self):
        """Ordered, sliced querysets can be combined and reordered."""
        two_smallest = Number.objects.order_by('num')[:2]
        two_largest = Number.objects.order_by('-num')[:2]
        combined = two_smallest.union(two_largest).order_by('-num')[:4]
        self.assertNumbersEqual(combined, [9, 8, 1, 0])

    @skipIfDBFeature('supports_slicing_ordering_in_compound')
    def test_unsupported_ordering_slicing_raises_db_error(self):
        """
        Slicing or ordering a queryset that is part of a union raises
        DatabaseError on backends that can't handle it.
        """
        first = Number.objects.all()
        second = Number.objects.all()
        limit_msg = 'LIMIT/OFFSET not allowed in subqueries of compound statements'
        with self.assertRaisesMessage(DatabaseError, limit_msg):
            list(first.union(second[:10]))
        order_msg = 'ORDER BY not allowed in subqueries of compound statements'
        with self.assertRaisesMessage(DatabaseError, order_msg):
            list(first.order_by('id').union(second))

    @skipIfDBFeature('supports_select_intersection')
    def test_unsupported_intersection_raises_db_error(self):
        """
        intersection() raises NotSupportedError on backends without
        intersection support.
        """
        first = Number.objects.all()
        second = Number.objects.all()
        expected_msg = 'intersection is not supported on this database backend'
        with self.assertRaisesMessage(NotSupportedError, expected_msg):
            list(first.intersection(second))

    def test_combining_multiple_models(self):
        """Querysets of different models can be combined with union()."""
        ReservedName.objects.create(name='99 little bugs', order=99)
        nums = Number.objects.filter(num=1).values_list('num', flat=True)
        orders = ReservedName.objects.values_list('order')
        self.assertEqual(list(nums.union(orders).order_by('num')), [1, 99])

    def test_order_raises_on_non_selected_column(self):
        """
        Ordering a union by a name that isn't a selected column of its first
        queryset raises DatabaseError.
        """
        aliased = Number.objects.filter().annotate(
            annotation=Value(1, IntegerField()),
        ).values('annotation', num2=F('num'))
        plain = Number.objects.filter().values('id', 'num')
        # Both selected names are valid ordering terms.
        list(aliased.union(plain).order_by('annotation'))
        list(aliased.union(plain).order_by('num2'))
        expected_msg = 'ORDER BY term does not match any column in the result set'
        # 'id' isn't selected by the first queryset.
        with self.assertRaisesMessage(DatabaseError, expected_msg):
            list(aliased.union(plain).order_by('id'))
        # 'num' was re-aliased to 'num2' in the first queryset.
        with self.assertRaisesMessage(DatabaseError, expected_msg):
            list(aliased.union(plain).order_by('num'))
        # With the operands swapped, 'num' is a selected column again.
        list(plain.union(aliased).order_by('num'))

    @skipUnlessDBFeature('supports_select_difference', 'supports_select_intersection')
    def test_qs_with_subcompound_qs(self):
        """A combined queryset can be an operand of another combination."""
        everything = Number.objects.all()
        above_one = Number.objects.intersection(Number.objects.filter(num__gt=1))
        self.assertEqual(everything.difference(above_one).count(), 2)

    def test_order_by_same_type(self):
        """A union of a queryset with itself can be ordered by its fields."""
        qs = Number.objects.all()
        union = qs.union(qs)
        ascending = list(range(10))
        self.assertNumbersEqual(union.order_by('num'), ascending)
        # other_num decreases as num increases, so the num order is reversed.
        self.assertNumbersEqual(union.order_by('other_num'), reversed(ascending))

    def test_unsupported_operations_on_combined_qs(self):
        """
        QuerySet methods that aren't allowed after a set operation raise
        NotSupportedError.
        """
        qs = Number.objects.all()
        msg = 'Calling QuerySet.%s() after %s() is not supported.'
        operations = (
            'annotate',
            'defer',
            'delete',
            'exclude',
            'extra',
            'filter',
            'only',
            'prefetch_related',
            'select_related',
            'update',
        )
        # Only exercise the combinators the current backend supports.
        features = connection.features
        combinators = ['union']
        if features.supports_select_difference:
            combinators.append('difference')
        if features.supports_select_intersection:
            combinators.append('intersection')
        for combinator in combinators:
            for operation in operations:
                with self.subTest(combinator=combinator, operation=operation):
                    expected_msg = msg % (operation, combinator)
                    with self.assertRaisesMessage(NotSupportedError, expected_msg):
                        combined = getattr(qs, combinator)(qs)
                        getattr(combined, operation)()