add dataset prototype and an example for ljspeech
This commit is contained in:
parent
fe4b471036
commit
94b95c7b2c
|
@ -0,0 +1,71 @@
|
||||||
|
"""
|
||||||
|
functions to make batch for arrays which satisfy some conditions.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def text_collate(minibatch):
|
||||||
|
"""
|
||||||
|
minibatch: List[Example]
|
||||||
|
Example: ndarray, shape(T,), dtype: int64
|
||||||
|
"""
|
||||||
|
peek_example = minibatch[0]
|
||||||
|
assert len(peek_example.shape) == 1, "text example is an 1D tensor"
|
||||||
|
|
||||||
|
lengths = [example.shape[0] for example in minibatch] # assume (channel, n_samples) or (n_samples, )
|
||||||
|
max_len = np.max(lengths)
|
||||||
|
|
||||||
|
batch = []
|
||||||
|
for example in minibatch:
|
||||||
|
pad_len = max_len - example.shape[0]
|
||||||
|
batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=0))
|
||||||
|
|
||||||
|
return np.array(batch, dtype=np.int64)
|
||||||
|
|
||||||
|
def wav_collate(minibatch):
|
||||||
|
"""
|
||||||
|
minibatch: List[Example]
|
||||||
|
Example: ndarray, shape(C, T) for multi-channel wav, shape(T,) for mono-channel wav, dtype: float32
|
||||||
|
"""
|
||||||
|
peek_example = minibatch[0]
|
||||||
|
if len(peek_example.shape) == 1:
|
||||||
|
mono_channel = True
|
||||||
|
elif len(peek_example.shape) == 2:
|
||||||
|
mono_channel = False
|
||||||
|
|
||||||
|
lengths = [example.shape[-1] for example in minibatch] # assume (channel, n_samples) or (n_samples, )
|
||||||
|
max_len = np.max(lengths)
|
||||||
|
|
||||||
|
batch = []
|
||||||
|
for example in minibatch:
|
||||||
|
pad_len = max_len - example.shape[-1]
|
||||||
|
if mono_channel:
|
||||||
|
batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=0.))
|
||||||
|
else:
|
||||||
|
batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=0.)) # what about PCM, no
|
||||||
|
|
||||||
|
return np.array(batch, dtype=np.float32)
|
||||||
|
|
||||||
|
def spec_collate(minibatch):
|
||||||
|
"""
|
||||||
|
minibatch: List[Example]
|
||||||
|
Example: ndarray, shape(C, F, T) for multi-channel spectrogram, shape(F, T) for mono-channel spectrogram, dtype: float32
|
||||||
|
"""
|
||||||
|
# assume (F, T) or (C, F, T)
|
||||||
|
peek_example = minibatch[0]
|
||||||
|
if len(peek_example.shape) == 2:
|
||||||
|
mono_channel = True
|
||||||
|
elif len(peek_example.shape) == 3:
|
||||||
|
mono_channel = False
|
||||||
|
|
||||||
|
lengths = [example.shape[-1] for example in minibatch] # assume (channel, F, n_frame) or (F, n_frame)
|
||||||
|
max_len = np.max(lengths)
|
||||||
|
|
||||||
|
batch = []
|
||||||
|
for example in minibatch:
|
||||||
|
pad_len = max_len - example.shape[-1]
|
||||||
|
if mono_channel:
|
||||||
|
batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=0.))
|
||||||
|
else:
|
||||||
|
batch.append(np.pad(example, [(0, 0), (0, 0), (0, pad_len)], mode='constant', constant_values=0.)) # what about PCM, no
|
||||||
|
|
||||||
|
return np.array(batch, dtype=np.float32)
|
|
@ -0,0 +1,83 @@
|
||||||
|
from sampler import SequentialSampler, RandomSampler, BatchSampler
|
||||||
|
|
||||||
|
class DataLoader(object):
|
||||||
|
def __init__(self, dataset, batch_size=1, collate_fn = lambda x: x,
|
||||||
|
sampler=None, shuffle=False, batch_sampler=None, drop_last=False):
|
||||||
|
self.dataset = dataset
|
||||||
|
self.collate_fn = collate_fn
|
||||||
|
|
||||||
|
if batch_sampler is not None:
|
||||||
|
# auto_collation with custom batch_sampler
|
||||||
|
if batch_size != 1 or shuffle or sampler is not None or drop_last:
|
||||||
|
raise ValueError('batch_sampler option is mutually exclusive '
|
||||||
|
'with batch_size, shuffle, sampler, and '
|
||||||
|
'drop_last')
|
||||||
|
batch_size = None
|
||||||
|
drop_last = False
|
||||||
|
elif batch_size is None:
|
||||||
|
# no auto_collation
|
||||||
|
if shuffle or drop_last:
|
||||||
|
raise ValueError('batch_size=None option disables auto-batching '
|
||||||
|
'and is mutually exclusive with '
|
||||||
|
'shuffle, and drop_last')
|
||||||
|
|
||||||
|
if sampler is None: # give default samplers
|
||||||
|
if shuffle:
|
||||||
|
sampler = RandomSampler(dataset)
|
||||||
|
else:
|
||||||
|
sampler = SequentialSampler(dataset)
|
||||||
|
|
||||||
|
if batch_size is not None and batch_sampler is None:
|
||||||
|
# auto_collation without custom batch_sampler
|
||||||
|
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.drop_last = drop_last
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_sampler = batch_sampler
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return DataIterator(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _auto_collation(self):
|
||||||
|
# we will auto batching
|
||||||
|
return self.batch_sampler is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _index_sampler(self):
|
||||||
|
# The actual sampler used for generating indices for `_DatasetFetcher`
|
||||||
|
# (see _utils/fetch.py) to read data at each time. This would be
|
||||||
|
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
|
||||||
|
# We can't change `.sampler` and `.batch_sampler` attributes for BC
|
||||||
|
# reasons.
|
||||||
|
if self._auto_collation:
|
||||||
|
return self.batch_sampler
|
||||||
|
else:
|
||||||
|
return self.sampler
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._index_sampler) # with iterable-style dataset, this will error
|
||||||
|
|
||||||
|
class DataIterator(object):
|
||||||
|
def __init__(self, loader):
|
||||||
|
self.loader = loader
|
||||||
|
self._dataset = loader.dataset
|
||||||
|
|
||||||
|
self._index_sampler = loader._index_sampler
|
||||||
|
self._sampler_iter = iter(self._index_sampler)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
index = self._next_index() # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
|
||||||
|
minibatch = [self._dataset[i] for i in index] # we can abstract it, too to use dynamic batch size
|
||||||
|
minibatch = self.loader.collate_fn(minibatch) # list[Example] -> Batch
|
||||||
|
return minibatch
|
||||||
|
|
||||||
|
def _next_index(self):
|
||||||
|
return next(self._sampler_iter)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._index_sampler)
|
|
@ -0,0 +1,25 @@
|
||||||
|
class Dataset(object):
|
||||||
|
def __init__(self, lazy=True, stream=False):
|
||||||
|
# note that lazy and stream means two different things in our glossary
|
||||||
|
# lazy means to place preprocessing in __getitem__
|
||||||
|
# stram means the data source is itself a stream
|
||||||
|
self.lazy = lazy
|
||||||
|
self.stream = stream
|
||||||
|
|
||||||
|
def _load_metadata(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_example(self):
|
||||||
|
"""return a Record"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _prepare_metadata(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
@ -0,0 +1,105 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import librosa
|
||||||
|
import g2p
|
||||||
|
|
||||||
|
from sampler import SequentialSampler, RandomSampler, BatchSampler
|
||||||
|
from dataset import Dataset
|
||||||
|
from dataloader import DataLoader
|
||||||
|
|
||||||
|
from collate import text_collate, spec_collate
|
||||||
|
|
||||||
|
LJSPEECH_ROOT = Path("/Users/chenfeiyu/projects/LJSpeech-1.1")
|
||||||
|
class LJSpeech(Dataset):
|
||||||
|
def __init__(self, root=LJSPEECH_ROOT, lazy=True, stream=False):
|
||||||
|
super(LJSpeech, self).__init__(lazy, stream)
|
||||||
|
self.root = root
|
||||||
|
self.metadata = self._prepare_metadata() # we can do this just for luck
|
||||||
|
|
||||||
|
if self.stream:
|
||||||
|
self.examples_generator = self._read()
|
||||||
|
|
||||||
|
def _prepare_metadata(self):
|
||||||
|
# if pure-stream case, each _prepare_metadata returns a generator
|
||||||
|
csv_path = self.root.joinpath("metadata.csv")
|
||||||
|
metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3,
|
||||||
|
names=["fname", "raw_text", "normalized_text"])
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _read(self):
|
||||||
|
for _, metadatum in self.metadata.iterrows():
|
||||||
|
example = self._get_example(metadatum)
|
||||||
|
yield example
|
||||||
|
|
||||||
|
def _get_example(self, metadatum):
|
||||||
|
"""All the code for generating an Example from a metadatum. If you want a
|
||||||
|
different preprocessing pipeline, you can override this method.
|
||||||
|
This method may require several processor, each of which has a lot of options.
|
||||||
|
In this case, you'd better pass a composed transform and pass it to the init
|
||||||
|
method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fname, raw_text, normalized_text = metadatum
|
||||||
|
wav_path = self.root.joinpath("wavs", fname + ".wav")
|
||||||
|
|
||||||
|
# load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
|
||||||
|
wav, sample_rate = librosa.load(wav_path, sr=None) # we would rather use functor to hold its parameters
|
||||||
|
trimed, _ = librosa.effects.trim(wav)
|
||||||
|
preemphasized = librosa.effects.preemphasis(trimed)
|
||||||
|
D = librosa.stft(preemphasized)
|
||||||
|
mag, phase = librosa.magphase(D)
|
||||||
|
mel = librosa.feature.melspectrogram(S=mag)
|
||||||
|
|
||||||
|
mag = librosa.amplitude_to_db(S=mag)
|
||||||
|
mel = librosa.amplitude_to_db(S=mel)
|
||||||
|
|
||||||
|
ref_db = 20
|
||||||
|
max_db = 100
|
||||||
|
mel = np.clip((mel - ref_db + max_db) / max_db, 1e-8, 1)
|
||||||
|
mel = np.clip((mag - ref_db + max_db) / max_db, 1e-8, 1)
|
||||||
|
|
||||||
|
phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
|
||||||
|
return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
if self.stream:
|
||||||
|
raise ValueError("__getitem__ is invalid in stream mode")
|
||||||
|
metadatum = self.metadata.iloc[index]
|
||||||
|
example = self._get_example(metadatum)
|
||||||
|
return example
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.stream:
|
||||||
|
for example in self.examples_generator:
|
||||||
|
yield example
|
||||||
|
else:
|
||||||
|
for i in range(len(self)):
|
||||||
|
yield self[i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.stream:
|
||||||
|
raise ValueError("__len__ is invalid in stream mode")
|
||||||
|
return len(self.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def fn(minibatch):
|
||||||
|
mag_batch = []
|
||||||
|
mel_batch = []
|
||||||
|
phoneme_batch = []
|
||||||
|
for example in minibatch:
|
||||||
|
mag, mel, phoneme = example
|
||||||
|
mag_batch.append(mag)
|
||||||
|
mel_batch.append(mel)
|
||||||
|
phoneme_batch.append(phoneme)
|
||||||
|
mag_batch = spec_collate(mag_batch)
|
||||||
|
mel_batch = spec_collate(mel_batch)
|
||||||
|
phoneme_batch = text_collate(phoneme_batch)
|
||||||
|
return (mag_batch, mel_batch, phoneme_batch)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ljspeech = LJSpeech(LJSPEECH_ROOT)
|
||||||
|
ljspeech_loader = DataLoader(ljspeech, batch_size=16, shuffle=True, collate_fn=fn)
|
||||||
|
for i, batch in enumerate(ljspeech_loader):
|
||||||
|
print(i)
|
||||||
|
|
|
@ -0,0 +1,209 @@
|
||||||
|
"""
|
||||||
|
At most cases, we have non-stream dataset, which means we can random access it with __getitem__, and we can get the length of the dataset with __len__.
|
||||||
|
|
||||||
|
This suffices for a sampler. We implemente sampler as iterable of valid indices. By valid, we mean 0 <= index < N, where N is the length of the dataset. We then collect several indices within a batch and use it to collect examples from the dataset with __getitem__. Then collate this examples to form a batch.
|
||||||
|
|
||||||
|
So the sampler is only responsible for generating valid indices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
class Sampler(object):
|
||||||
|
def __init__(self, data_source):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# return a iterator of indices
|
||||||
|
# or a iterator of list[int], for BatchSampler
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialSampler(Sampler):
|
||||||
|
def __init__(self, data_source):
|
||||||
|
self.data_source = data_source
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(range(len(self.data_source)))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_source)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSampler(Sampler):
|
||||||
|
def __init__(self, data_source, replacement=False, num_samples=None):
|
||||||
|
self.data_source = data_source
|
||||||
|
self.replacement = replacement
|
||||||
|
self._num_samples = num_samples
|
||||||
|
|
||||||
|
if not isinstance(self.replacement, bool):
|
||||||
|
raise ValueError("replacement should be a boolean value, but got "
|
||||||
|
"replacement={}".format(self.replacement))
|
||||||
|
|
||||||
|
if self._num_samples is not None and not replacement:
|
||||||
|
raise ValueError("With replacement=False, num_samples should not be specified, "
|
||||||
|
"since a random permutation will be performed.")
|
||||||
|
|
||||||
|
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
|
||||||
|
raise ValueError("num_samples should be a positive integer "
|
||||||
|
"value, but got num_samples={}".format(self.num_samples))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_samples(self):
|
||||||
|
# dataset size might change at runtime
|
||||||
|
if self._num_samples is None:
|
||||||
|
return len(self.data_source)
|
||||||
|
return self._num_samples
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
n = len(self.data_source)
|
||||||
|
if self.replacement:
|
||||||
|
return iter(np.random.randint(0, n, size=(self.num_samples,), dtype=np.int64).tolist())
|
||||||
|
return iter(np.random.permutation(n).tolist())
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_source)
|
||||||
|
|
||||||
|
|
||||||
|
class SubsetRandomSampler(Sampler):
|
||||||
|
r"""Samples elements randomly from a given list of indices, without replacement.
|
||||||
|
Arguments:
|
||||||
|
indices (sequence): a sequence of indices
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, indices):
|
||||||
|
self.indices = indices
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return (self.indices[i] for i in np.random.permutation(len(self.indices)))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.indices)
|
||||||
|
|
||||||
|
|
||||||
|
class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
|
||||||
|
"""Partially randmoized sampler, implemented as a example sampler
|
||||||
|
1. Sort by lengths
|
||||||
|
2. Pick a small patch and randomize it
|
||||||
|
3. Permutate mini-batchs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lengths, batch_size=4, batch_group_size=None,
|
||||||
|
permutate=True):
|
||||||
|
_lengths = np.array(lengths, dtype=np.int64) # maybe better implement length as a sort key
|
||||||
|
self.lengths = np.sort(_lengths)
|
||||||
|
self.sorted_indices = np.argsort(_lengths)
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
if batch_group_size is None:
|
||||||
|
batch_group_size = min(batch_size * 32, len(self.lengths))
|
||||||
|
if batch_group_size % batch_size != 0:
|
||||||
|
batch_group_size -= batch_group_size % batch_size
|
||||||
|
|
||||||
|
self.batch_group_size = batch_group_size
|
||||||
|
assert batch_group_size % batch_size == 0
|
||||||
|
self.permutate = permutate
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
indices = np.copy(self.sorted_indices)
|
||||||
|
batch_group_size = self.batch_group_size
|
||||||
|
s, e = 0, 0
|
||||||
|
for i in range(len(indices) // batch_group_size):
|
||||||
|
s = i * batch_group_size
|
||||||
|
e = s + batch_group_size
|
||||||
|
random.shuffle(indices[s: e]) # inplace
|
||||||
|
|
||||||
|
# Permutate batches
|
||||||
|
if self.permutate:
|
||||||
|
perm = np.arange(len(indices[:e]) // self.batch_size)
|
||||||
|
random.shuffle(perm)
|
||||||
|
indices[:e] = indices[:e].reshape(-1, self.batch_size)[perm, :].reshape(-1)
|
||||||
|
|
||||||
|
# Handle last elements
|
||||||
|
s += batch_group_size
|
||||||
|
#print(indices)
|
||||||
|
if s < len(indices):
|
||||||
|
random.shuffle(indices[s:])
|
||||||
|
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.sorted_indices)
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedRandomSampler(Sampler):
|
||||||
|
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
|
||||||
|
Args:
|
||||||
|
weights (sequence) : a sequence of weights, not necessary summing up to one
|
||||||
|
num_samples (int): number of samples to draw
|
||||||
|
replacement (bool): if ``True``, samples are drawn with replacement.
|
||||||
|
If not, they are drawn without replacement, which means that when a
|
||||||
|
sample index is drawn for a row, it cannot be drawn again for that row.
|
||||||
|
Example:
|
||||||
|
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
|
||||||
|
[0, 0, 0, 1, 0]
|
||||||
|
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
|
||||||
|
[0, 1, 4, 3, 2]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, weights, num_samples, replacement):
|
||||||
|
if not isinstance(num_samples, int) or num_samples <= 0:
|
||||||
|
raise ValueError("num_samples should be a positive integer "
|
||||||
|
"value, but got num_samples={}".format(num_samples))
|
||||||
|
self.weights = np.array(weights, dtype=np.float64)
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.replacement = replacement
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(np.random.choice(len(self.weights), size=(self.num_samples, ),
|
||||||
|
replace=self.replacement, p=self.weights).tolist())
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSampler(Sampler):
|
||||||
|
r"""Wraps another sampler to yield a mini-batch of indices.
|
||||||
|
Args:
|
||||||
|
sampler (Sampler): Base sampler.
|
||||||
|
batch_size (int): Size of mini-batch.
|
||||||
|
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
||||||
|
its size would be less than ``batch_size``
|
||||||
|
Example:
|
||||||
|
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
|
||||||
|
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
|
||||||
|
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
|
||||||
|
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampler, batch_size, drop_last):
|
||||||
|
if not isinstance(sampler, Sampler):
|
||||||
|
raise ValueError("sampler should be an instance of "
|
||||||
|
"Sampler, but got sampler={}"
|
||||||
|
.format(sampler))
|
||||||
|
if not isinstance(batch_size, int) or batch_size <= 0:
|
||||||
|
raise ValueError("batch_size should be a positive integer value, "
|
||||||
|
"but got batch_size={}".format(batch_size))
|
||||||
|
if not isinstance(drop_last, bool):
|
||||||
|
raise ValueError("drop_last should be a boolean value, but got "
|
||||||
|
"drop_last={}".format(drop_last))
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.drop_last = drop_last
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
batch = []
|
||||||
|
for idx in self.sampler:
|
||||||
|
batch.append(idx)
|
||||||
|
if len(batch) == self.batch_size:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if len(batch) > 0 and not self.drop_last:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.drop_last:
|
||||||
|
return len(self.sampler) // self.batch_size
|
||||||
|
else:
|
||||||
|
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
Loading…
Reference in New Issue