239 lines
8.7 KiB
Python
239 lines
8.7 KiB
Python
"""
|
|
In most cases we have a non-streaming dataset, which means we can access it randomly with __getitem__ and get its length with __len__.
|
|
|
|
This suffices for a sampler. We implement a sampler as an iterable of valid indices. By valid, we mean 0 <= index < N, where N is the length of the dataset. We then collect several indices into a batch and use them to fetch examples from the dataset with __getitem__. These examples are then collated to form a batch.
|
|
|
|
So the sampler is only responsible for generating valid indices.
|
|
"""
|
|
|
|
|
|
import numpy as np
|
|
import random
|
|
|
|
class Sampler(object):
    """Abstract base class for index samplers.

    A sampler is an iterable over dataset indices, or over lists of indices
    in the case of a batch sampler. Subclasses override ``__iter__``.
    """

    def __init__(self, data_source):
        """Accept ``data_source``; the base class does not store it."""

    def __iter__(self):
        """Return an iterator of indices (or of ``list[int]`` for a batch sampler).

        Raises:
            NotImplementedError: always; concrete samplers must override this.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class SequentialSampler(Sampler):
    """Sampler that yields ``0, 1, ..., len(data_source) - 1`` in order."""

    def __init__(self, data_source):
        """Keep a reference to ``data_source``; only its length is used."""
        self.data_source = data_source

    def __len__(self):
        """Return the number of indices produced in one pass."""
        return len(self.data_source)

    def __iter__(self):
        """Return an iterator over the indices in ascending order."""
        size = len(self.data_source)
        return iter(range(size))
|
|
|
|
|
|
class RandomSampler(Sampler):
    """Sampler that yields dataset indices in random order.

    Args:
        data_source: sized dataset; only ``len(data_source)`` is used.
        replacement (bool): if ``True``, indices are drawn independently with
            replacement; otherwise a random permutation is produced.
        num_samples (int, optional): number of indices to draw. It may only
            be given together with ``replacement=True`` and defaults to
            ``len(data_source)``.

    Raises:
        ValueError: if ``replacement`` is not a bool, if ``num_samples`` is
            given while ``replacement`` is ``False``, or if the effective
            number of samples is not a positive integer.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(replacement, bool):
            message = ("replacement should be a boolean value, but got "
                       "replacement={}")
            raise ValueError(message.format(replacement))

        if not replacement and num_samples is not None:
            raise ValueError(
                "With replacement=False, num_samples should not be specified, "
                "since a random permutation will be performed.")

        count = self.num_samples
        if not (isinstance(count, int) and count > 0):
            message = ("num_samples should be a positive integer "
                       "value, but got num_samples={}")
            raise ValueError(message.format(count))

    @property
    def num_samples(self):
        """Number of indices per pass.

        Falls back to the current ``len(data_source)`` when no explicit value
        was given, because the dataset size might change at runtime.
        """
        explicit = self._num_samples
        return len(self.data_source) if explicit is None else explicit

    def __len__(self):
        """Return the number of indices produced in one pass."""
        return self.num_samples

    def __iter__(self):
        """Return an iterator over freshly drawn random indices."""
        size = len(self.data_source)
        if not self.replacement:
            # Without replacement: every index exactly once, shuffled.
            return iter(np.random.permutation(size).tolist())
        drawn = np.random.randint(0, size, size=(self.num_samples,),
                                  dtype=np.int64)
        return iter(drawn.tolist())
|
|
|
|
|
|
class SubsetRandomSampler(Sampler):
    r"""Yields the elements of a given index list in random order, each once.

    Arguments:
        indices (sequence): the indices to sample from, without replacement
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        """Return the number of indices in the subset."""
        return len(self.indices)

    def __iter__(self):
        """Return a lazy iterator over ``indices`` in a freshly shuffled order."""
        # The permutation is drawn now; the lookups happen as items are consumed.
        order = np.random.permutation(len(self.indices))
        return (self.indices[position] for position in order)
|
|
|
|
|
|
class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
    """Partially randomized sampler that keeps similar-length examples together.

    Implemented as an example sampler. One pass does the following:

    1. Sort the example indices by length.
    2. Split the sorted indices into groups of ``batch_group_size`` and
       shuffle inside each group, so neighbours still have similar lengths.
    3. Optionally permute whole mini-batches within the grouped region.
    4. Shuffle the tail that does not fill a complete group.

    Args:
        lengths (sequence of int): length of every example in the dataset.
        batch_size (int): mini-batch size the indices will be grouped into.
        batch_group_size (int, optional): number of consecutive sorted
            examples shuffled together; must be a multiple of ``batch_size``.
            Defaults to ``min(batch_size * 32, len(lengths))`` rounded down
            to a multiple of ``batch_size``, but never less than
            ``batch_size``.
        permutate (bool): whether to permute mini-batches after the
            group-wise shuffle.

    Raises:
        AssertionError: if ``batch_group_size`` is not a multiple of
            ``batch_size``.
    """

    def __init__(self, lengths, batch_size=4, batch_group_size=None,
                 permutate=True):
        # maybe better implement length as a sort key
        _lengths = np.array(lengths, dtype=np.int64)
        self.lengths = np.sort(_lengths)
        self.sorted_indices = np.argsort(_lengths)

        self.batch_size = batch_size
        if batch_group_size is None:
            batch_group_size = min(batch_size * 32, len(self.lengths))
            batch_group_size -= batch_group_size % batch_size
            # With fewer examples than one batch the rounding above gives 0,
            # which would later divide by zero; fall back to a single batch.
            batch_group_size = max(batch_group_size, batch_size)

        self.batch_group_size = batch_group_size
        assert batch_group_size % batch_size == 0
        self.permutate = permutate

    def __iter__(self):
        """Return an iterator over one partially randomized pass of indices.

        The indices are yielded as plain Python ints.
        """
        indices = np.copy(self.sorted_indices)
        group_size = self.batch_group_size

        # A non-positive group size simply means there is no complete group.
        num_groups = len(indices) // group_size if group_size > 0 else 0
        for group in range(num_groups):
            start = group * group_size
            # The slice is a view, so the shuffle happens in place.
            random.shuffle(indices[start:start + group_size])
        grouped_end = num_groups * group_size

        # Permute whole mini-batches inside the grouped region.
        if self.permutate and grouped_end > 0:
            perm = np.arange(grouped_end // self.batch_size)
            random.shuffle(perm)
            batches = indices[:grouped_end].reshape(-1, self.batch_size)
            indices[:grouped_end] = batches[perm, :].reshape(-1)

        # Handle the last elements: everything past the last complete group.
        # This also covers the case where no complete group exists at all.
        if grouped_end < len(indices):
            random.shuffle(indices[grouped_end:])

        return iter(indices.tolist())

    def __len__(self):
        """Return the number of indices produced in one pass."""
        return len(self.sorted_indices)
|
|
|
|
|
|
class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence) : a sequence of non-negative weights, not
            necessarily summing up to one; they are normalised before
            sampling
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that an
            index can appear at most once per pass.

    Raises:
        ValueError: if ``num_samples`` is not a positive integer (at
            construction), or if the weights do not have a positive sum
            (when iterating). ``numpy.random.choice`` raises its own
            ``ValueError`` for other invalid inputs, such as negative weights.

    Example (the output is random; one possible result is shown):
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))  # doctest: +SKIP
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))  # doctest: +SKIP
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement):
        if not isinstance(num_samples, int) or num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        self.weights = np.array(weights, dtype=np.float64)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        """Return an iterator over ``num_samples`` freshly drawn indices."""
        total = self.weights.sum()
        # ``not total > 0`` also rejects NaN sums.
        if not total > 0:
            raise ValueError("weights should have a positive sum, but got "
                             "sum={}".format(total))
        # numpy.random.choice requires probabilities that sum to one, so the
        # raw weights are normalised here; self.weights stays as provided.
        probabilities = self.weights / total
        drawn = np.random.choice(len(self.weights), size=(self.num_samples,),
                                 replace=self.replacement, p=probabilities)
        return iter(drawn.tolist())

    def __len__(self):
        """Return the number of indices produced in one pass."""
        return self.num_samples
|
|
|
|
|
|
class DistributedSampler(Sampler):
    """Sampler that gives each trainer its own strided share of the indices.

    The index list ``0 .. dataset_size - 1`` is optionally shuffled, padded
    by repeating indices from its start until its length is a multiple of
    ``num_trainers``, and then split so that trainer ``rank`` receives the
    elements at positions ``rank, rank + num_trainers, ...``.

    Args:
        dataset_size (int): number of examples in the dataset.
        num_trainers (int): number of trainers sharing the dataset.
        rank (int): position of this trainer, used as the stride offset.
        shuffle (bool): whether to shuffle the indices before splitting.

    Note:
        Shuffling uses the global ``random`` module state. The shares of the
        different trainers only partition the same ordering if every trainer
        produces the same shuffle -- NOTE(review): confirm callers seed
        ``random`` identically on all trainers.
    """

    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
        self.dataset_size = dataset_size
        self.num_trainers = num_trainers
        self.rank = rank
        # Exact integer ceiling division; avoids going through a float.
        self.num_samples = int(-(-dataset_size // num_trainers))
        self.total_size = self.num_samples * num_trainers
        assert self.total_size >= self.dataset_size
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over the indices assigned to this trainer."""
        indices = list(range(self.dataset_size))
        if self.shuffle:
            random.shuffle(indices)

        # Append extra samples to make it evenly distributed on all trainers.
        # The padding can be longer than the dataset itself (more trainers
        # than examples), so repeat the index list as often as needed.
        padding = self.total_size - self.dataset_size
        if padding > 0:
            repeats = -(-padding // len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # Subset samples for each trainer.
        indices = indices[self.rank:self.total_size:self.num_trainers]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of indices this trainer receives per pass."""
        return self.num_samples
|
|
|
|
|
|
class BatchSampler(Sampler):
    r"""Groups the indices of another sampler into mini-batches.

    Args:
        sampler (Sampler): Base sampler that supplies the indices.
        batch_size (int): Number of indices per mini-batch.
        drop_last (bool): If ``True``, a trailing batch with fewer than
            ``batch_size`` indices is discarded.

    Raises:
        ValueError: if ``sampler`` is not a ``Sampler``, ``batch_size`` is
            not a positive integer, or ``drop_last`` is not a bool.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            message = ("sampler should be an instance of "
                       "Sampler, but got sampler={}")
            raise ValueError(message.format(sampler))
        if not (isinstance(batch_size, int) and batch_size > 0):
            message = ("batch_size should be a positive integer value, "
                       "but got batch_size={}")
            raise ValueError(message.format(batch_size))
        if not isinstance(drop_last, bool):
            message = ("drop_last should be a boolean value, but got "
                       "drop_last={}")
            raise ValueError(message.format(drop_last))

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of indices, each of length ``batch_size`` except possibly the last."""
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) < self.batch_size:
                continue
            yield pending
            pending = []
        # Emit the incomplete trailing batch unless it should be dropped.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        """Return the number of batches produced in one pass."""
        full_batches, leftover = divmod(len(self.sampler), self.batch_size)
        if leftover and not self.drop_last:
            return full_batches + 1
        return full_batches
|