125 lines
3.6 KiB
Python
125 lines
3.6 KiB
Python
import torch
|
|
from torch.nn import functional as F
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
import numpy as np
|
|
import math
|
|
import os
|
|
|
|
import hparams
|
|
import Audio
|
|
from text import text_to_sequence
|
|
from utils import process_text, pad_1D, pad_2D
|
|
|
|
# Module-wide torch device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
class FastSpeechDataset(Dataset):
    """LJSpeech samples: encoded text, ground-truth mel and the ``D`` array.

    Transcript lines come from ``data/train.txt``; the mel and ``D`` arrays
    are ``.npy`` files under ``hparams.mel_ground_truth`` and
    ``hparams.alignment_path`` respectively.
    """

    def __init__(self):
        """Load the transcript lines through ``process_text``."""
        transcript_path = os.path.join("data", "train.txt")
        self.text = process_text(transcript_path)

    def __len__(self):
        """Return how many transcript lines were loaded."""
        return len(self.text)

    def __getitem__(self, idx):
        """Build the sample stored at position ``idx``.

        Returns:
            dict with keys ``"text"`` (the ``text_to_sequence`` output as a
            numpy array), ``"mel_target"`` and ``"D"`` (arrays loaded from
            disk; ``D`` is presumably the alignment/duration data - confirm
            against whatever writes ``hparams.alignment_path``).
        """
        # Mel files are numbered from 1 while the D files are numbered from 0.
        mel_file = "ljspeech-mel-%05d.npy" % (idx+1)
        mel = np.load(os.path.join(hparams.mel_ground_truth, mel_file))

        d_file = str(idx)+".npy"
        D = np.load(os.path.join(hparams.alignment_path, d_file))

        # The last character of the line is dropped before encoding
        # (presumably a trailing newline - verify against process_text).
        line = self.text[idx]
        encoded = text_to_sequence(line[:-1], hparams.text_cleaners)

        return {"text": np.array(encoded),
                "mel_target": mel,
                "D": D}
|
|
|
|
|
|
def reprocess(batch, cut_list):
    """Assemble one padded mini-batch from the samples selected by ``cut_list``.

    Args:
        batch: indexable collection of sample dicts carrying the keys
            ``"text"``, ``"mel_target"`` and ``"D"``.
        cut_list: indices into ``batch``; the output keeps their order.

    Returns:
        dict with keys ``"text"`` and ``"D"`` (results of ``pad_1D``),
        ``"mel_target"`` (result of ``pad_2D``), ``"src_pos"`` and
        ``"mel_pos"`` (position arrays, see below) and ``"mel_max_len"``
        (largest first-axis length among the selected mel targets, an int).

    Raises:
        ValueError: if ``cut_list`` selects no samples (max of an empty
            sequence).
    """
    def position_rows(lengths, width):
        # One row per length n: the values 1..n followed by zeros up to `width`.
        return np.array([
            np.pad(list(range(1, n + 1)), (0, width - n), 'constant')
            for n in lengths
        ])

    texts = [batch[k]["text"] for k in cut_list]
    mel_targets = [batch[k]["mel_target"] for k in cut_list]
    Ds = [batch[k]["D"] for k in cut_list]

    # Positions for the text side, sized by each text's first-axis length.
    text_lengths = [int(t.shape[0]) for t in texts]
    max_len = max(text_lengths)
    src_pos = position_rows(text_lengths, max_len)

    # Positions for the mel side, sized by each mel's first-axis length.
    mel_lengths = [int(m.shape[0]) for m in mel_targets]
    max_mel_len = max(mel_lengths)
    mel_pos = position_rows(mel_lengths, max_mel_len)

    return {"text": pad_1D(texts),
            "mel_target": pad_2D(mel_targets),
            "D": pad_1D(Ds),
            "mel_pos": mel_pos,
            "src_pos": src_pos,
            "mel_max_len": max_mel_len}
|
|
|
|
|
|
def collate_fn(batch):
    """Split a loaded batch into length-sorted mini-batches.

    Samples are ordered by descending ``"text"`` length. With
    ``k = int(sqrt(len(batch)))``, the first ``k * k`` samples of that order
    are cut into ``k`` consecutive groups of ``k`` and each group is handed
    to ``reprocess``. Samples beyond the first ``k * k`` are not used.

    Args:
        batch: list of sample dicts, each with a ``"text"`` array.

    Returns:
        list of ``k`` dicts as produced by ``reprocess`` (empty when
        ``batch`` is empty).
    """
    text_lengths = np.array([sample["text"].shape[0] for sample in batch])
    # Negating the lengths makes the ascending argsort yield longest-first.
    longest_first = np.argsort(-text_lengths)
    group_size = int(math.sqrt(len(batch)))

    groups = [longest_first[k*group_size:(k+1)*group_size]
              for k in range(group_size)]

    return [reprocess(batch, group) for group in groups]
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: walk the whole dataset and count the mini-batches whose
    # mel target size along axis 1 equals the sum of the matching D array.
    dataset = FastSpeechDataset()
    loader_options = dict(batch_size=1,
                          shuffle=False,
                          collate_fn=collate_fn,
                          drop_last=True,
                          num_workers=0)
    training_loader = DataLoader(dataset, **loader_options)
    total_step = hparams.epochs * len(training_loader) * hparams.batch_size

    cnt = 0
    for batchs in training_loader:
        # collate_fn returns a list of mini-batch dicts per loader step.
        for data_of_batch in batchs:
            mel_np = data_of_batch["mel_target"]
            mel_target = torch.from_numpy(mel_np).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)

            # Running count, printed before this mini-batch is checked.
            print(cnt)
            lengths_agree = mel_target.size(1) == D.sum().item()
            if lengths_agree:
                cnt += 1

    print(cnt)
|