Fixed #32317 -- Refactored loaddata command to make it extensible.

Moved deeply nested blocks out of inner loops to improve readability
and maintainability.

Thanks to Mariusz Felisiak, Shreyas Ravi, and Paolo Melchiorre for
feedback.
This commit is contained in:
William Schwartz 2020-12-30 11:32:46 -06:00 committed by Mariusz Felisiak
parent 3954bf50fb
commit de32fe83a2
1 changed file with 91 additions and 63 deletions

View File

@ -84,6 +84,33 @@ class Command(BaseCommand):
if transaction.get_autocommit(self.using):
connections[self.using].close()
@cached_property
def compression_formats(self):
    """
    Return a dict keyed by compression format name.

    Each value is an ``(opener, mode)`` pair: the callable used to open a
    fixture stored in that format and the mode argument to pass to it.
    The ``None`` key covers uncompressed files. The 'bz2', 'lzma' and 'xz'
    entries are only registered when the ``has_bz2`` / ``has_lzma`` flags
    are set.
    """
    # Forcing binary mode may be revisited after dropping Python 2 support (see #22399)
    openers = {
        None: (open, 'rb'),
        'gz': (gzip.GzipFile, 'rb'),
        'zip': (SingleZipReader, 'r'),
        # The 'stdin' opener ignores its arguments and hands back sys.stdin.
        'stdin': (lambda *args: sys.stdin, None),
    }
    if has_bz2:
        openers['bz2'] = (bz2.BZ2File, 'r')
    if has_lzma:
        # Both names are served by the same opener and mode.
        for alias in ('lzma', 'xz'):
            openers[alias] = (lzma.LZMAFile, 'r')
    return openers
def reset_sequences(self, connection, models):
    """
    Run the sequence-reset SQL for ``models`` on ``connection``.

    The statements come from ``connection.ops.sequence_reset_sql()``. When
    that returns nothing, no cursor is opened and nothing is written.
    """
    statements = connection.ops.sequence_reset_sql(no_style(), models)
    if not statements:
        return
    if self.verbosity >= 2:
        self.stdout.write('Resetting sequences')
    with connection.cursor() as cursor:
        for statement in statements:
            cursor.execute(statement)
def loaddata(self, fixture_labels):
connection = connections[self.using]
@ -94,18 +121,6 @@ class Command(BaseCommand):
self.models = set()
self.serialization_formats = serializers.get_public_serializer_formats()
# Forcing binary mode may be revisited after dropping Python 2 support (see #22399)
self.compression_formats = {
None: (open, 'rb'),
'gz': (gzip.GzipFile, 'rb'),
'zip': (SingleZipReader, 'r'),
'stdin': (lambda *args: sys.stdin, None),
}
if has_bz2:
self.compression_formats['bz2'] = (bz2.BZ2File, 'r')
if has_lzma:
self.compression_formats['lzma'] = (lzma.LZMAFile, 'r')
self.compression_formats['xz'] = (lzma.LZMAFile, 'r')
# Django's test suite repeatedly tries to load initial_data fixtures
# from apps that don't have any fixtures. Because disabling constraint
@ -136,13 +151,7 @@ class Command(BaseCommand):
# If we found even one object in a fixture, we need to reset the
# database sequences.
if self.loaded_object_count > 0:
sequence_sql = connection.ops.sequence_reset_sql(no_style(), self.models)
if sequence_sql:
if self.verbosity >= 2:
self.stdout.write('Resetting sequences')
with connection.cursor() as cursor:
for line in sequence_sql:
cursor.execute(line)
self.reset_sequences(connection, self.models)
if self.verbosity >= 1:
if self.fixture_object_count == self.loaded_object_count:
@ -156,6 +165,31 @@ class Command(BaseCommand):
% (self.loaded_object_count, self.fixture_object_count, self.fixture_count)
)
def save_obj(self, obj):
    """
    Save a deserialized object if permitted and report whether it was saved.

    Return False without doing anything when the object's app config is in
    ``self.excluded_apps`` or its type is in ``self.excluded_models``.
    Otherwise save it to ``self.using`` if the router allows it, and record
    its model class in ``self.models``. Objects that have deferred fields
    are queued on ``self.objs_with_deferred_fields`` whether or not they
    were saved.

    DatabaseError, IntegrityError, and ValueError raised while saving are
    re-raised with a message naming the object that failed.
    """
    instance = obj.object
    if instance._meta.app_config in self.excluded_apps:
        return False
    if type(instance) in self.excluded_models:
        return False

    model = instance.__class__
    saved = bool(router.allow_migrate_model(self.using, model))
    if saved:
        self.models.add(model)
        try:
            obj.save(using=self.using)
        # psycopg2 raises ValueError if data contains NUL chars.
        except (DatabaseError, IntegrityError, ValueError) as exc:
            details = {
                'object_label': instance._meta.label,
                'pk': instance.pk,
                'error_msg': exc,
            }
            exc.args = (
                'Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s' % details,
            )
            raise
    if obj.deferred_fields:
        self.objs_with_deferred_fields.append(obj)
    return saved
def load_label(self, fixture_label):
"""Load fixtures files for a given label."""
show_progress = self.verbosity >= 3
@ -179,29 +213,13 @@ class Command(BaseCommand):
for obj in objects:
objects_in_fixture += 1
if (obj.object._meta.app_config in self.excluded_apps or
type(obj.object) in self.excluded_models):
continue
if router.allow_migrate_model(self.using, obj.object.__class__):
if self.save_obj(obj):
loaded_objects_in_fixture += 1
self.models.add(obj.object.__class__)
try:
obj.save(using=self.using)
# psycopg2 raises ValueError if data contains NUL chars.
except (DatabaseError, IntegrityError, ValueError) as e:
e.args = ("Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s" % {
'object_label': obj.object._meta.label,
'pk': obj.object.pk,
'error_msg': e,
},)
raise
if show_progress:
self.stdout.write(
'\rProcessed %i object(s).' % loaded_objects_in_fixture,
ending=''
)
if obj.deferred_fields:
self.objs_with_deferred_fields.append(obj)
except Exception as e:
if not isinstance(e, CommandError):
e.args = ("Problem installing fixture '%s': %s" % (fixture_file, e),)
@ -221,20 +239,7 @@ class Command(BaseCommand):
RuntimeWarning
)
@functools.lru_cache(maxsize=None)
def find_fixtures(self, fixture_label):
"""Find fixture files for a given label."""
if fixture_label == READ_STDIN:
return [(READ_STDIN, None, READ_STDIN)]
fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label)
databases = [self.using, None]
cmp_fmts = list(self.compression_formats) if cmp_fmt is None else [cmp_fmt]
ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt]
if self.verbosity >= 2:
self.stdout.write("Loading '%s' fixtures..." % fixture_name)
def get_fixture_name_and_dirs(self, fixture_name):
dirname, basename = os.path.split(fixture_name)
if os.path.isabs(fixture_name):
fixture_dirs = [dirname]
@ -242,25 +247,48 @@ class Command(BaseCommand):
fixture_dirs = self.fixture_dirs
if os.path.sep in os.path.normpath(fixture_name):
fixture_dirs = [os.path.join(dir_, dirname) for dir_ in fixture_dirs]
fixture_name = basename
return basename, fixture_dirs
suffixes = (
'.'.join(ext for ext in combo if ext)
for combo in product(databases, ser_fmts, cmp_fmts)
)
targets = {'.'.join((fixture_name, suffix)) for suffix in suffixes}
def get_targets(self, fixture_name, ser_fmt, cmp_fmt):
    """
    Return the set of candidate file names for ``fixture_name``.

    Each name is ``fixture_name`` followed by a dot-joined suffix built
    from one database part (``self.using`` or nothing), one serialization
    format, and one compression format. Empty parts are left out of the
    suffix. When ``ser_fmt`` or ``cmp_fmt`` is None, every entry of
    ``self.serialization_formats`` / ``self.compression_formats`` is tried
    instead.
    """
    database_parts = [self.using, None]
    if cmp_fmt is None:
        compression_parts = self.compression_formats
    else:
        compression_parts = [cmp_fmt]
    if ser_fmt is None:
        serialization_parts = self.serialization_formats
    else:
        serialization_parts = [ser_fmt]

    targets = set()
    for parts in product(database_parts, serialization_parts, compression_parts):
        # Falsy parts (e.g. the None database or "no compression") are dropped.
        suffix = '.'.join(filter(None, parts))
        targets.add('%s.%s' % (fixture_name, suffix))
    return targets
def find_fixture_files_in_dir(self, fixture_dir, fixture_name, targets):
    """
    Return the fixture files in ``fixture_dir`` whose base name is in ``targets``.

    Candidates are the paths that start with ``fixture_dir/fixture_name``.
    The prefix is glob-escaped, so special characters in it are matched
    literally. Each result is a ``(path, fixture_dir, fixture_name)`` tuple.
    """
    pattern = glob.escape(os.path.join(fixture_dir, fixture_name)) + '*'
    return [
        # Save the fixture_dir and fixture_name for future error messages.
        (match, fixture_dir, fixture_name)
        for match in glob.iglob(pattern)
        if os.path.basename(match) in targets
    ]
@functools.lru_cache(maxsize=None)
def find_fixtures(self, fixture_label):
"""Find fixture files for a given label."""
if fixture_label == READ_STDIN:
return [(READ_STDIN, None, READ_STDIN)]
fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label)
if self.verbosity >= 2:
self.stdout.write("Loading '%s' fixtures..." % fixture_name)
fixture_name, fixture_dirs = self.get_fixture_name_and_dirs(fixture_name)
targets = self.get_targets(fixture_name, ser_fmt, cmp_fmt)
fixture_files = []
for fixture_dir in fixture_dirs:
if self.verbosity >= 2:
self.stdout.write("Checking %s for fixtures..." % humanize(fixture_dir))
fixture_files_in_dir = []
path = os.path.join(fixture_dir, fixture_name)
for candidate in glob.iglob(glob.escape(path) + '*'):
if os.path.basename(candidate) in targets:
# Save the fixture_dir and fixture_name for future error messages.
fixture_files_in_dir.append((candidate, fixture_dir, fixture_name))
fixture_files_in_dir = self.find_fixture_files_in_dir(
fixture_dir, fixture_name, targets,
)
if self.verbosity >= 2 and not fixture_files_in_dir:
self.stdout.write("No fixture '%s' in %s." %
(fixture_name, humanize(fixture_dir)))