30 lines
1.2 KiB
Python
30 lines
1.2 KiB
Python
|
from pathlib import Path
|
||
|
import numpy as np
|
||
|
from paddle import fluid
|
||
|
from parakeet.data.sampler import DistributedSampler
|
||
|
from parakeet.data.datacargo import DataCargo
|
||
|
from preprocess import batch_examples, LJSpeech, batch_examples_postnet
|
||
|
|
||
|
class LJSpeechLoader:
    """Wire an LJSpeech dataset into a fluid ``DataLoader`` for one rank.

    The resulting loader is exposed as ``self.reader``.
    """

    def __init__(self, config, nranks, rank, is_postnet=False):
        """Build the dataset, sampler, batching pipeline and reader.

        Args:
            config: Object providing ``use_gpu``, ``data_path`` and
                ``batch_size`` attributes.
            nranks: Value the batch size is divided by and that is passed
                to ``DistributedSampler``; ``config.batch_size`` must be
                an exact multiple of it.
            rank: Passed to ``DistributedSampler`` and, when
                ``config.use_gpu`` is truthy, to ``fluid.CUDAPlace``.
            is_postnet: When truthy, batches are collated with
                ``batch_examples_postnet`` instead of ``batch_examples``.

        Raises:
            AssertionError: If ``config.batch_size`` is not divisible by
                ``nranks`` (only while assertions are enabled).
        """
        if config.use_gpu:
            device = fluid.CUDAPlace(rank)
        else:
            device = fluid.CPUPlace()

        ljspeech = LJSpeech(Path(config.data_path))
        rank_sampler = DistributedSampler(len(ljspeech), nranks, rank)

        # The configured batch size is split evenly, so it has to divide
        # cleanly by nranks.
        assert config.batch_size % nranks == 0
        per_rank_batch = config.batch_size // nranks

        # Only the collate function differs between the two modes.
        collate = batch_examples_postnet if is_postnet else batch_examples
        cargo = DataCargo(
            ljspeech,
            sampler=rank_sampler,
            batch_size=per_rank_batch,
            shuffle=True,
            collate_fn=collate,
            drop_last=True)

        self.reader = fluid.io.DataLoader.from_generator(
            capacity=32,
            iterable=True,
            use_double_buffer=True,
            return_list=True)
        self.reader.set_batch_generator(cargo, device)
|
||
|
|