Refs #31211 -- Prevented SearchConfig nesting in SearchVector and SearchQuery init.

Passing a SearchConfig instance directly to SearchVector and
SearchQuery would result in nested SearchConfig instance.
This commit is contained in:
Simon Charette 2020-02-24 23:48:07 -05:00 committed by Mariusz Felisiak
parent 3d62ddb026
commit d0f1c03331
2 changed files with 17 additions and 4 deletions

View File

@ -41,6 +41,12 @@ class SearchConfig(Expression):
config = Value(config) config = Value(config)
self.config = config self.config = config
@classmethod
def from_parameter(cls, config):
if config is None or isinstance(config, cls):
return config
return cls(config)
def get_source_expressions(self): def get_source_expressions(self):
return [self.config] return [self.config]
@ -75,7 +81,7 @@ class SearchVector(SearchVectorCombinable, Func):
def __init__(self, *expressions, **extra): def __init__(self, *expressions, **extra):
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
config = self.extra.get('config', self.config) config = self.extra.get('config', self.config)
self.config = SearchConfig(config) if config else None self.config = SearchConfig.from_parameter(config)
weight = self.extra.get('weight') weight = self.extra.get('weight')
if weight is not None and not hasattr(weight, 'resolve_expression'): if weight is not None and not hasattr(weight, 'resolve_expression'):
weight = Value(weight) weight = Value(weight)
@ -162,7 +168,7 @@ class SearchQuery(SearchQueryCombinable, Value):
} }
def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'): def __init__(self, value, output_field=None, *, config=None, invert=False, search_type='plain'):
self.config = SearchConfig(config) if config else None self.config = SearchConfig.from_parameter(config)
self.invert = invert self.invert = invert
if search_type not in self.SEARCH_TYPES: if search_type not in self.SEARCH_TYPES:
raise ValueError("Unknown search_type argument '%s'." % search_type) raise ValueError("Unknown search_type argument '%s'." % search_type)

View File

@ -6,13 +6,13 @@ All text copyright Python (Monty) Pictures. Thanks to sacred-texts.com for the
transcript. transcript.
""" """
from django.contrib.postgres.search import ( from django.contrib.postgres.search import (
SearchQuery, SearchRank, SearchVector, SearchConfig, SearchQuery, SearchRank, SearchVector,
) )
from django.db import connection from django.db import connection
from django.db.models import F from django.db.models import F
from django.test import SimpleTestCase, modify_settings, skipUnlessDBFeature from django.test import SimpleTestCase, modify_settings, skipUnlessDBFeature
from . import PostgreSQLTestCase from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase
from .models import Character, Line, Scene from .models import Character, Line, Scene
@ -118,6 +118,13 @@ class SearchVectorFieldTest(GrailTestData, PostgreSQLTestCase):
self.assertNotIn('COALESCE(COALESCE', str(searched.query)) self.assertNotIn('COALESCE(COALESCE', str(searched.query))
class SearchConfigTests(PostgreSQLSimpleTestCase):
def test_from_parameter(self):
self.assertIsNone(SearchConfig.from_parameter(None))
self.assertEqual(SearchConfig.from_parameter('foo'), SearchConfig('foo'))
self.assertEqual(SearchConfig.from_parameter(SearchConfig('bar')), SearchConfig('bar'))
class MultipleFieldsTest(GrailTestData, PostgreSQLTestCase): class MultipleFieldsTest(GrailTestData, PostgreSQLTestCase):
def test_simple_on_dialogue(self): def test_simple_on_dialogue(self):